diff --git a/api/spec/v1/helpers.go b/api/spec/v1/helpers.go index 4af5dd07..dafdb7ca 100644 --- a/api/spec/v1/helpers.go +++ b/api/spec/v1/helpers.go @@ -37,8 +37,8 @@ func (vs *VGPUConfigSpec) MatchesDeviceFilter(deviceID types.DeviceID) bool { } for _, df := range deviceFilter { - newDeviceID, _ := types.NewDeviceIDFromString(df) - if newDeviceID == deviceID { + filterDeviceID, _ := types.NewDeviceIDFromString(df) + if filterDeviceID.Matches(deviceID) { return true } } diff --git a/api/spec/v1/helpers_test.go b/api/spec/v1/helpers_test.go new file mode 100644 index 00000000..f7a8e400 --- /dev/null +++ b/api/spec/v1/helpers_test.go @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package v1 + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/vgpu-device-manager/pkg/types" +) + +func TestVGPUConfigSpecMatchesDeviceFilter(t *testing.T) { + a16 := types.NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE) + + testCases := []struct { + name string + filter interface{} + want bool + }{ + { + name: "empty filter matches all devices", + filter: "", + want: true, + }, + { + name: "primary-only filter matches subsystem-aware hardware", + filter: "0x25B610DE", + want: true, + }, + { + name: "matching subsystem filter matches hardware", + filter: "0x25B610DE:0x14A910DE", + want: true, + }, + { + name: "sibling subsystem filter does not match hardware", + filter: "0x25B610DE:0x157E10DE", + want: false, + }, + { + name: "multiple filters stop at a matching subsystem entry", + filter: []string{"0x25B610DE:0x157E10DE", "0x25B610DE:0x14A910DE"}, + want: true, + }, + { + name: "malformed filters with extra separators do not match", + filter: "0x25B610DE:0x14A910DE:extra", + want: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + spec := VGPUConfigSpec{ + DeviceFilter: tc.filter, + Devices: "all", + } + + require.Equal(t, tc.want, spec.MatchesDeviceFilter(a16)) + }) + } +} diff --git a/cmd/nvidia-vgpu-dm/assert/assert.go b/cmd/nvidia-vgpu-dm/assert/assert.go index 72537645..cd52e262 100644 --- a/cmd/nvidia-vgpu-dm/assert/assert.go +++ b/cmd/nvidia-vgpu-dm/assert/assert.go @@ -205,7 +205,7 @@ func WalkSelectedVGPUConfigForEachGPU(vgpuConfig v1.VGPUConfigSpecSlice, f func( } for i, gpu := range gpus { - deviceID := types.NewDeviceID(gpu.Device, gpu.Vendor) + deviceID := types.NewDeviceIDWithSubsystem(gpu.Device, gpu.Vendor, gpu.SubsystemDevice, gpu.SubsystemVendor) if !vc.MatchesDeviceFilter(deviceID) { continue diff --git a/pkg/types/device.go b/pkg/types/device.go index 5227572c..e7a2ec6a 100644 --- a/pkg/types/device.go +++ b/pkg/types/device.go @@ -18,39 +18,125 @@ package types import ( "fmt" + "math" "strconv" + "strings" ) // DeviceID represents a GPU Device ID as read from a GPUs PCIe config space. -type DeviceID uint32 +type DeviceID struct { + Device uint16 + Vendor uint16 + SubsystemDevice uint16 + SubsystemVendor uint16 + HasSubsystem bool +} // NewDeviceID constructs a new 'DeviceID' from the device and vendor values pulled from a GPUs PCIe config space. func NewDeviceID(device, vendor uint16) DeviceID { - return DeviceID((uint32(device) << 16) | uint32(vendor)) + return DeviceID{ + Device: device, + Vendor: vendor, + } +} + +// NewDeviceIDWithSubsystem constructs a new 'DeviceID' with subsystem values. +func NewDeviceIDWithSubsystem(device, vendor, subDevice, subVendor uint16) DeviceID { + return DeviceID{ + Device: device, + Vendor: vendor, + SubsystemDevice: subDevice, + SubsystemVendor: subVendor, + HasSubsystem: true, + } } // NewDeviceIDFromString constructs a 'DeviceID' from its string representation. func NewDeviceIDFromString(str string) (DeviceID, error) { - deviceID, err := strconv.ParseUint(str, 0, 32) + parts := strings.Split(str, ":") + if len(parts) > 2 { + return DeviceID{}, fmt.Errorf( + "invalid DeviceID format '%v': expected '' or ':'", + str, + ) + } + + deviceIDRaw, err := strconv.ParseUint(parts[0], 0, 32) if err != nil { - return 0, fmt.Errorf("unable to create DeviceID from string '%v': %v", str, err) + return DeviceID{}, fmt.Errorf("unable to create DeviceID from string '%v': %v", str, err) + } + + device, vendor, err := splitRawDeviceID(deviceIDRaw) + if err != nil { + return DeviceID{}, fmt.Errorf("unable to create DeviceID from string '%v': %v", str, err) + } + + deviceID := DeviceID{ + Device: device, + Vendor: vendor, + } + + if len(parts) == 2 { + subIDRaw, err := strconv.ParseUint(parts[1], 0, 32) + if err != nil { + return DeviceID{}, fmt.Errorf("unable to create Subsystem from string '%v': %v", str, err) + } + + subsystemDevice, subsystemVendor, err := splitRawDeviceID(subIDRaw) + if err != nil { + return DeviceID{}, fmt.Errorf("unable to create Subsystem from string '%v': %v", str, err) + } + + deviceID.SubsystemDevice = subsystemDevice + deviceID.SubsystemVendor = subsystemVendor + deviceID.HasSubsystem = true } - return DeviceID(deviceID), nil + + return deviceID, nil } // String returns a 'DeviceID' as a string. func (d DeviceID) String() string { - return fmt.Sprintf("0x%X", uint32(d)) + primary := fmt.Sprintf("0x%04X%04X", d.Device, d.Vendor) + if d.HasSubsystem { + return fmt.Sprintf("%s:0x%04X%04X", primary, d.SubsystemDevice, d.SubsystemVendor) + } + return primary } // GetVendor returns the 'vendor' portion of a 'DeviceID'. func (d DeviceID) GetVendor() uint16 { - // nolint:gosec // DeviceID is constructed from two uint16 numbers - return uint16(d) + return d.Vendor } // GetDevice returns the 'device' portion of a 'DeviceID'. func (d DeviceID) GetDevice() uint16 { - // nolint:gosec // DeviceID is constructed from two uint16 numbers - return uint16(d >> 16) + return d.Device +} + +// Matches checks if a hardware GPU matches the DeviceID filter. +// If the filter has a subsystem defined, it requires an exact match on all 4 components. +// Otherwise, it only matches on the primary device and vendor IDs. +func (filter DeviceID) Matches(hardware DeviceID) bool { + if filter.Device != hardware.Device || filter.Vendor != hardware.Vendor { + return false + } + + if filter.HasSubsystem { + if filter.SubsystemDevice != hardware.SubsystemDevice || filter.SubsystemVendor != hardware.SubsystemVendor { + return false + } + } + + return true +} + +func splitRawDeviceID(raw uint64) (uint16, uint16, error) { + device := raw >> 16 + vendor := raw & math.MaxUint16 + if device > math.MaxUint16 || vendor > math.MaxUint16 { + return 0, 0, fmt.Errorf("value 0x%X is out of range for a PCI device ID", raw) + } + + return uint16(device), uint16(vendor), nil } diff --git a/pkg/types/device_test.go b/pkg/types/device_test.go new file mode 100644 index 00000000..ec01db52 --- /dev/null +++ b/pkg/types/device_test.go @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewDeviceID(t *testing.T) { + require.Equal(t, DeviceID{ + Device: 0x25B6, + Vendor: 0x10DE, + }, NewDeviceID(0x25B6, 0x10DE)) +} + +func TestNewDeviceIDWithSubsystem(t *testing.T) { + require.Equal(t, DeviceID{ + Device: 0x25B6, + Vendor: 0x10DE, + SubsystemDevice: 0x14A9, + SubsystemVendor: 0x10DE, + HasSubsystem: true, + }, NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE)) +} + +func TestNewDeviceIDFromString(t *testing.T) { + testCases := []struct { + name string + input string + want DeviceID + wantErr string + }{ + { + name: "primary only", + input: "0x25B610DE", + want: NewDeviceID(0x25B6, 0x10DE), + }, + { + name: "with subsystem", + input: "0x25B610DE:0x14A910DE", + want: NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE), + }, + { + name: "invalid primary", + input: "not-a-device-id", + wantErr: "unable to create DeviceID", + }, + { + name: "invalid subsystem", + input: "0x25B610DE:not-a-subsystem", + wantErr: "unable to create Subsystem", + }, + { + name: "too many separators", + input: "0x25B610DE:0x14A910DE:extra", + wantErr: "invalid DeviceID format", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := NewDeviceIDFromString(tc.input) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + return + } + + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestDeviceIDString(t *testing.T) { + require.Equal(t, "0x25B610DE", NewDeviceID(0x25B6, 0x10DE).String()) + require.Equal( + t, + "0x25B610DE:0x14A910DE", + NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE).String(), + ) +} + +func TestDeviceIDMatches(t *testing.T) { + a16 := NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE) + a2 := NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x157E, 0x10DE) + + testCases := []struct { + name string + filter DeviceID + hardware DeviceID + want bool + }{ + { + name: "primary only filter matches same primary id", + filter: NewDeviceID(0x25B6, 0x10DE), + hardware: a16, + want: true, + }, + { + name: "subsystem filter matches exact hardware", + filter: NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE), + hardware: a16, + want: true, + }, + { + name: "subsystem filter rejects sibling subsystem", + filter: NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE), + hardware: a2, + want: false, + }, + { + name: "different primary id does not match", + filter: NewDeviceID(0x1E30, 0x10DE), + hardware: a16, + want: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.want, tc.filter.Matches(tc.hardware)) + }) + } +} diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/nvpci.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/nvpci.go index f41d0f24..c6e9f475 100644 --- a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/nvpci.go +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/nvpci.go @@ -112,6 +112,8 @@ type NvidiaPCIDevice struct { Class uint32 ClassName string Device uint16 + SubsystemVendor uint16 + SubsystemDevice uint16 DeviceName string Driver string IommuGroup int @@ -283,6 +285,20 @@ func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevi return nil, fmt.Errorf("unable to convert device string to uint16: %v", deviceStr) } + subVendor, err := os.ReadFile(path.Join(devicePath, "subsystem_vendor")) + var subVendorID uint64 + if err == nil { + subVendorStr := strings.TrimSpace(string(subVendor)) + subVendorID, _ = strconv.ParseUint(subVendorStr, 0, 16) + } + + subDevice, err := os.ReadFile(path.Join(devicePath, "subsystem_device")) + var subDeviceID uint64 + if err == nil { + subDeviceStr := strings.TrimSpace(string(subDevice)) + subDeviceID, _ = strconv.ParseUint(subDeviceStr, 0, 16) + } + driver, err := getDriver(devicePath) if err != nil { return nil, fmt.Errorf("unable to detect driver for %s: %w", address, err) @@ -381,6 +397,8 @@ func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevi Vendor: uint16(vendorID), Class: uint32(classID), Device: uint16(deviceID), + SubsystemVendor: uint16(subVendorID), + SubsystemDevice: uint16(subDeviceID), Driver: driver, IommuGroup: int(iommuGroup), IommuFD: iommuFD,