Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/spec/v1/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ func (vs *VGPUConfigSpec) MatchesDeviceFilter(deviceID types.DeviceID) bool {
}

for _, df := range deviceFilter {
newDeviceID, _ := types.NewDeviceIDFromString(df)
if newDeviceID == deviceID {
filterDeviceID, _ := types.NewDeviceIDFromString(df)
if filterDeviceID.Matches(deviceID) {
return true
}
}
Expand Down
77 changes: 77 additions & 0 deletions api/spec/v1/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package v1

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/NVIDIA/vgpu-device-manager/pkg/types"
)

func TestVGPUConfigSpecMatchesDeviceFilter(t *testing.T) {
a16 := types.NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE)

testCases := []struct {
name string
filter interface{}
want bool
}{
{
name: "empty filter matches all devices",
filter: "",
want: true,
},
{
name: "primary-only filter matches subsystem-aware hardware",
filter: "0x25B610DE",
want: true,
},
{
name: "matching subsystem filter matches hardware",
filter: "0x25B610DE:0x14A910DE",
want: true,
},
{
name: "sibling subsystem filter does not match hardware",
filter: "0x25B610DE:0x157E10DE",
want: false,
},
{
name: "multiple filters stop at a matching subsystem entry",
filter: []string{"0x25B610DE:0x157E10DE", "0x25B610DE:0x14A910DE"},
want: true,
},
{
name: "malformed filters with extra separators do not match",
filter: "0x25B610DE:0x14A910DE:extra",
want: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
spec := VGPUConfigSpec{
DeviceFilter: tc.filter,
Devices: "all",
}

require.Equal(t, tc.want, spec.MatchesDeviceFilter(a16))
})
}
}
2 changes: 1 addition & 1 deletion cmd/nvidia-vgpu-dm/assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func WalkSelectedVGPUConfigForEachGPU(vgpuConfig v1.VGPUConfigSpecSlice, f func(
}

for i, gpu := range gpus {
deviceID := types.NewDeviceID(gpu.Device, gpu.Vendor)
deviceID := types.NewDeviceIDWithSubsystem(gpu.Device, gpu.Vendor, gpu.SubsystemDevice, gpu.SubsystemVendor)

if !vc.MatchesDeviceFilter(deviceID) {
continue
Expand Down
106 changes: 96 additions & 10 deletions pkg/types/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,125 @@ package types

import (
"fmt"
"math"
"strconv"
"strings"
)

// DeviceID represents a GPU Device ID as read from a GPUs PCIe config space.
type DeviceID uint32
type DeviceID struct {
Device uint16
Vendor uint16
SubsystemDevice uint16
SubsystemVendor uint16
HasSubsystem bool
}

// NewDeviceID constructs a new 'DeviceID' from the device and vendor values pulled from a GPUs PCIe config space.
func NewDeviceID(device, vendor uint16) DeviceID {
return DeviceID((uint32(device) << 16) | uint32(vendor))
return DeviceID{
Device: device,
Vendor: vendor,
}
}

// NewDeviceIDWithSubsystem constructs a new 'DeviceID' with subsystem values.
func NewDeviceIDWithSubsystem(device, vendor, subDevice, subVendor uint16) DeviceID {
return DeviceID{
Device: device,
Vendor: vendor,
SubsystemDevice: subDevice,
SubsystemVendor: subVendor,
HasSubsystem: true,
}
}

// NewDeviceIDFromString constructs a 'DeviceID' from its string representation.
func NewDeviceIDFromString(str string) (DeviceID, error) {
deviceID, err := strconv.ParseUint(str, 0, 32)
parts := strings.Split(str, ":")
if len(parts) > 2 {
return DeviceID{}, fmt.Errorf(
"invalid DeviceID format '%v': expected '<devicevendor>' or '<devicevendor>:<subdevicevendor>'",
str,
)
}

deviceIDRaw, err := strconv.ParseUint(parts[0], 0, 32)
if err != nil {
return 0, fmt.Errorf("unable to create DeviceID from string '%v': %v", str, err)
return DeviceID{}, fmt.Errorf("unable to create DeviceID from string '%v': %v", str, err)
}

device, vendor, err := splitRawDeviceID(deviceIDRaw)
if err != nil {
return DeviceID{}, fmt.Errorf("unable to create DeviceID from string '%v': %v", str, err)
}

deviceID := DeviceID{
Device: device,
Vendor: vendor,
}

if len(parts) == 2 {
subIDRaw, err := strconv.ParseUint(parts[1], 0, 32)
if err != nil {
return DeviceID{}, fmt.Errorf("unable to create Subsystem from string '%v': %v", str, err)
}

subsystemDevice, subsystemVendor, err := splitRawDeviceID(subIDRaw)
if err != nil {
return DeviceID{}, fmt.Errorf("unable to create Subsystem from string '%v': %v", str, err)
}

deviceID.SubsystemDevice = subsystemDevice
deviceID.SubsystemVendor = subsystemVendor
deviceID.HasSubsystem = true
}
return DeviceID(deviceID), nil

return deviceID, nil
}

// String returns a 'DeviceID' as a string.
func (d DeviceID) String() string {
return fmt.Sprintf("0x%X", uint32(d))
primary := fmt.Sprintf("0x%04X%04X", d.Device, d.Vendor)
if d.HasSubsystem {
return fmt.Sprintf("%s:0x%04X%04X", primary, d.SubsystemDevice, d.SubsystemVendor)
}
return primary
}

// GetVendor returns the 'vendor' portion of a 'DeviceID'.
func (d DeviceID) GetVendor() uint16 {
// nolint:gosec // DeviceID is constructed from two uint16 numbers
return uint16(d)
return d.Vendor
}

// GetDevice returns the 'device' portion of a 'DeviceID'.
func (d DeviceID) GetDevice() uint16 {
// nolint:gosec // DeviceID is constructed from two uint16 numbers
return uint16(d >> 16)
return d.Device
}

// Matches checks if a hardware GPU matches the DeviceID filter.
// If the filter has a subsystem defined, it requires an exact match on all 4 components.
// Otherwise, it only matches on the primary device and vendor IDs.
func (filter DeviceID) Matches(hardware DeviceID) bool {
if filter.Device != hardware.Device || filter.Vendor != hardware.Vendor {
return false
}

if filter.HasSubsystem {
if filter.SubsystemDevice != hardware.SubsystemDevice || filter.SubsystemVendor != hardware.SubsystemVendor {
return false
}
}

return true
}

func splitRawDeviceID(raw uint64) (uint16, uint16, error) {
device := raw >> 16
vendor := raw & math.MaxUint16
if device > math.MaxUint16 || vendor > math.MaxUint16 {
return 0, 0, fmt.Errorf("value 0x%X is out of range for a PCI device ID", raw)
}

return uint16(device), uint16(vendor), nil
}
141 changes: 141 additions & 0 deletions pkg/types/device_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package types

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestNewDeviceID(t *testing.T) {
require.Equal(t, DeviceID{
Device: 0x25B6,
Vendor: 0x10DE,
}, NewDeviceID(0x25B6, 0x10DE))
}

func TestNewDeviceIDWithSubsystem(t *testing.T) {
require.Equal(t, DeviceID{
Device: 0x25B6,
Vendor: 0x10DE,
SubsystemDevice: 0x14A9,
SubsystemVendor: 0x10DE,
HasSubsystem: true,
}, NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE))
}

func TestNewDeviceIDFromString(t *testing.T) {
testCases := []struct {
name string
input string
want DeviceID
wantErr string
}{
{
name: "primary only",
input: "0x25B610DE",
want: NewDeviceID(0x25B6, 0x10DE),
},
{
name: "with subsystem",
input: "0x25B610DE:0x14A910DE",
want: NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE),
},
{
name: "invalid primary",
input: "not-a-device-id",
wantErr: "unable to create DeviceID",
},
{
name: "invalid subsystem",
input: "0x25B610DE:not-a-subsystem",
wantErr: "unable to create Subsystem",
},
{
name: "too many separators",
input: "0x25B610DE:0x14A910DE:extra",
wantErr: "invalid DeviceID format",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := NewDeviceIDFromString(tc.input)
if tc.wantErr != "" {
require.Error(t, err)
require.ErrorContains(t, err, tc.wantErr)
return
}

require.NoError(t, err)
require.Equal(t, tc.want, got)
})
}
}

func TestDeviceIDString(t *testing.T) {
require.Equal(t, "0x25B610DE", NewDeviceID(0x25B6, 0x10DE).String())
require.Equal(
t,
"0x25B610DE:0x14A910DE",
NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE).String(),
)
}

func TestDeviceIDMatches(t *testing.T) {
a16 := NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE)
a2 := NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x157E, 0x10DE)

testCases := []struct {
name string
filter DeviceID
hardware DeviceID
want bool
}{
{
name: "primary only filter matches same primary id",
filter: NewDeviceID(0x25B6, 0x10DE),
hardware: a16,
want: true,
},
{
name: "subsystem filter matches exact hardware",
filter: NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE),
hardware: a16,
want: true,
},
{
name: "subsystem filter rejects sibling subsystem",
filter: NewDeviceIDWithSubsystem(0x25B6, 0x10DE, 0x14A9, 0x10DE),
hardware: a2,
want: false,
},
{
name: "different primary id does not match",
filter: NewDeviceID(0x1E30, 0x10DE),
hardware: a16,
want: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.want, tc.filter.Matches(tc.hardware))
})
}
}
Loading