Skip to content

Commit

Permalink
[no-relnote] Address integer overflow linting errors
Browse files Browse the repository at this point in the history
Signed-off-by: Evan Lezar <[email protected]>
elezar committed Aug 22, 2024
1 parent 9a890ac commit 3f0378d
Showing 6 changed files with 39 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -22,6 +22,9 @@ linters-settings:
local-prefixes: github.com/NVIDIA/k8s-device-plugin

issues:
exclude:
# A conversion of a uint8 to an int cannot overflow.
- "G115: integer overflow conversion uint8 -> int"
exclude-rules:
# We use math/rand instead of crypto/rand for unique names in e2e tests.
- path: tests/e2e/
4 changes: 2 additions & 2 deletions api/config/v1/replicas.go
Original file line number Diff line number Diff line change
@@ -299,9 +299,9 @@ func (s *ReplicatedDevices) UnmarshalJSON(b []byte) error {
result := make([]ReplicatedDeviceRef, len(slice))
for i, s := range slice {
// Match a uint as a GPU index and convert it to a string
var index uint
var index uint64
if err = json.Unmarshal(s, &index); err == nil {
result[i] = ReplicatedDeviceRef(strconv.Itoa(int(index)))
result[i] = ReplicatedDeviceRef(strconv.FormatUint(index, 10))
continue
}
// Match strings as valid entries if they are GPU indices, MIG indices, or UUIDs
1 change: 1 addition & 0 deletions internal/cuda/api.go
Original file line number Diff line number Diff line change
@@ -65,6 +65,7 @@ func DriverGetVersion() (int, Result) {
// DeviceGet returns the device with the specified index.
func DeviceGet(index int) (Device, Result) {
var device Device
//nolint:gosec // Since index is internal-only, we ignore possible overflow errors here.
r := cuDeviceGet(&device, int32(index))

return device, r
38 changes: 24 additions & 14 deletions internal/rm/health.go
Original file line number Diff line number Diff line change
@@ -88,8 +88,8 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
}()

parentToDeviceMap := make(map[string]*Device)
deviceIDToGiMap := make(map[string]int)
deviceIDToCiMap := make(map[string]int)
deviceIDToGiMap := make(map[string]uint32)
deviceIDToCiMap := make(map[string]uint32)

eventMask := uint64(nvml.EventTypeXidCriticalError | nvml.EventTypeDoubleBitEccError | nvml.EventTypeSingleBitEccError)
for _, d := range devices {
@@ -112,7 +112,7 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic

supportedEvents, ret := gpu.GetSupportedEventTypes()
if ret != nvml.SUCCESS {
klog.Infof("Unable to determine the supported events for %v: %v; marking it as unhealthy", d.ID, ret)
klog.Infof("unable to determine the supported events for %v: %v; marking it as unhealthy", d.ID, ret)
unhealthy <- d
continue
}
@@ -176,7 +176,7 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
if d.IsMigDevice() && e.GpuInstanceId != 0xFFFFFFFF && e.ComputeInstanceId != 0xFFFFFFFF {
gi := deviceIDToGiMap[d.ID]
ci := deviceIDToCiMap[d.ID]
if !(uint32(gi) == e.GpuInstanceId && uint32(ci) == e.ComputeInstanceId) {
if !(gi == e.GpuInstanceId && ci == e.ComputeInstanceId) {
continue
}
klog.Infof("Event for mig device %v (gi=%v, ci=%v)", d.ID, gi, ci)
@@ -215,15 +215,15 @@ func getAdditionalXids(input string) []uint64 {
// getDevicePlacement returns the placement of the specified device.
// For a MIG device the placement is defined by the 3-tuple <parent UUID, GI, CI>
// For a full device the returned 3-tuple is the device's uuid and 0xFFFFFFFF for the other two elements.
func (r *nvmlResourceManager) getDevicePlacement(d *Device) (string, int, int, error) {
func (r *nvmlResourceManager) getDevicePlacement(d *Device) (string, uint32, uint32, error) {
if !d.IsMigDevice() {
return d.GetUUID(), 0xFFFFFFFF, 0xFFFFFFFF, nil
}
return r.getMigDeviceParts(d)
}

// getMigDeviceParts returns the parent GI and CI ids of the MIG device.
func (r *nvmlResourceManager) getMigDeviceParts(d *Device) (string, int, int, error) {
func (r *nvmlResourceManager) getMigDeviceParts(d *Device) (string, uint32, uint32, error) {
if !d.IsMigDevice() {
return "", 0, 0, fmt.Errorf("cannot get GI and CI of full device")
}
@@ -250,32 +250,42 @@ func (r *nvmlResourceManager) getMigDeviceParts(d *Device) (string, int, int, er
if ret != nvml.SUCCESS {
return "", 0, 0, fmt.Errorf("failed to get Compute Instance ID: %v", ret)
}
return parentUUID, gi, ci, nil
//nolint:gosec // We know that the values returned from Get*InstanceId are within the valid uint32 range.
return parentUUID, uint32(gi), uint32(ci), nil
}
return parseMigDeviceUUID(uuid)
}

// parseMigDeviceUUID splits the MIG device UUID into the parent device UUID and ci and gi
func parseMigDeviceUUID(mig string) (string, int, int, error) {
func parseMigDeviceUUID(mig string) (string, uint32, uint32, error) {
tokens := strings.SplitN(mig, "-", 2)
if len(tokens) != 2 || tokens[0] != "MIG" {
return "", 0, 0, fmt.Errorf("Unable to parse UUID as MIG device")
return "", 0, 0, fmt.Errorf("unable to parse UUID as MIG device")
}

tokens = strings.SplitN(tokens[1], "/", 3)
if len(tokens) != 3 || !strings.HasPrefix(tokens[0], "GPU-") {
return "", 0, 0, fmt.Errorf("Unable to parse UUID as MIG device")
return "", 0, 0, fmt.Errorf("unable to parse UUID as MIG device")
}

gi, err := strconv.Atoi(tokens[1])
gi, err := toUint32(tokens[1])
if err != nil {
return "", 0, 0, fmt.Errorf("Unable to parse UUID as MIG device")
return "", 0, 0, fmt.Errorf("unable to parse UUID as MIG device")
}

ci, err := strconv.Atoi(tokens[2])
ci, err := toUint32(tokens[2])
if err != nil {
return "", 0, 0, fmt.Errorf("Unable to parse UUID as MIG device")
return "", 0, 0, fmt.Errorf("unable to parse UUID as MIG device")
}

return tokens[0], gi, ci, nil
}

func toUint32(s string) (uint32, error) {
u, err := strconv.ParseUint(s, 10, 32)
if err != nil {
return 0, err
}
//nolint:gosec // Since we parse s with a 32-bit size this will not overflow.
return uint32(u), nil
}
10 changes: 5 additions & 5 deletions internal/vgpu/pciutil.go
Original file line number Diff line number Diff line change
@@ -122,11 +122,11 @@ func (d *PCIDevice) GetVendorSpecificCapability() ([]byte, error) {
}

var visited [256]byte
pos := int(GetByte(d.Config, PciCapabilityList))
pos := GetByte(d.Config, PciCapabilityList)
for pos != 0 {
id := int(GetByte(d.Config, pos+PciCapabilityListID))
next := int(GetByte(d.Config, pos+PciCapabilityListNext))
length := int(GetByte(d.Config, pos+PciCapabilityLength))
id := GetByte(d.Config, pos+PciCapabilityListID)
next := GetByte(d.Config, pos+PciCapabilityListNext)
length := GetByte(d.Config, pos+PciCapabilityLength)

if visited[pos] != 0 {
// chain looped
@@ -149,7 +149,7 @@ func (d *PCIDevice) GetVendorSpecificCapability() ([]byte, error) {
}

// GetByte returns a single byte of data at specified position
func GetByte(buffer []byte, pos int) uint8 {
func GetByte(buffer []byte, pos uint8) uint8 {
return buffer[pos]
}

8 changes: 4 additions & 4 deletions internal/vgpu/vgpu.go
Original file line number Diff line number Diff line change
@@ -40,7 +40,7 @@ type Info struct {

const (
// VGPUCapabilityRecordStart indicates offset of beginning vGPU capability record
VGPUCapabilityRecordStart = 5
VGPUCapabilityRecordStart uint8 = 5
// HostDriverVersionLength indicates max length of driver version
HostDriverVersionLength = 10
// HostDriverBranchLength indicates max length of driver branch
@@ -116,14 +116,14 @@ func (d *Device) GetInfo() (*Info, error) {
foundDriverVersionRecord := false
pos := VGPUCapabilityRecordStart
record := GetByte(d.vGPUCapability, VGPUCapabilityRecordStart)
for record != 0 && pos < len(d.vGPUCapability) {
for record != 0 && int(pos) < len(d.vGPUCapability) {
// find next record
recordLength := GetByte(d.vGPUCapability, pos+1)
pos += int(recordLength)
pos += recordLength
record = GetByte(d.vGPUCapability, pos)
}

if record == 0 && pos+2+HostDriverVersionLength+HostDriverBranchLength <= len(d.vGPUCapability) {
if record == 0 && int(pos+2+HostDriverVersionLength+HostDriverBranchLength) <= len(d.vGPUCapability) {
foundDriverVersionRecord = true
// found vGPU host driver version record type
// initialized at record data byte, i.e pos + 1(record id byte) + 1(record lengh byte)

0 comments on commit 3f0378d

Please sign in to comment.