diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go index 5b21fc1..5fac8c1 100644 --- a/pkg/nvlib/device/device.go +++ b/pkg/nvlib/device/device.go @@ -32,6 +32,7 @@ type Device interface { GetMigDevices() ([]MigDevice, error) GetMigProfiles() ([]MigProfile, error) GetPCIBusID() (string, error) + IsFabricAttached() (bool, error) IsMigCapable() (bool, error) IsMigEnabled() (bool, error) VisitMigDevices(func(j int, m MigDevice) error) error @@ -208,6 +209,47 @@ func (d *device) IsMigEnabled() (bool, error) { return (mode == nvml.DEVICE_MIG_ENABLE), nil } +// IsFabricAttached checks if a device is attached to a GPU fabric. +func (d *device) IsFabricAttached() (bool, error) { + if d.lib.hasSymbol("nvmlDeviceGetGpuFabricInfo") { + info, ret := d.GetGpuFabricInfo() + if ret == nvml.ERROR_NOT_SUPPORTED { + return false, nil + } + if ret != nvml.SUCCESS { + return false, fmt.Errorf("error getting GPU Fabric Info: %v", ret) + } + if info.State != nvml.GPU_FABRIC_STATE_COMPLETED { + return false, nil + } + if nvml.Return(info.Status) != nvml.SUCCESS { + return false, nil + } + + return true, nil + } + + if d.lib.hasSymbol("nvmlDeviceGetGpuFabricInfoV") { + info, ret := d.GetGpuFabricInfoV().V2() + if ret == nvml.ERROR_NOT_SUPPORTED { + return false, nil + } + if ret != nvml.SUCCESS { + return false, fmt.Errorf("error getting GPU Fabric Info: %v", ret) + } + if info.State != nvml.GPU_FABRIC_STATE_COMPLETED { + return false, nil + } + if nvml.Return(info.Status) != nvml.SUCCESS { + return false, nil + } + + return true, nil + } + + return false, nil +} + // VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it. func (d *device) VisitMigDevices(visit func(int, MigDevice) error) error { capable, err := d.IsMigCapable()