From 93fa13d09800b38b6c49e5a9dcfebb93d6b8cb40 Mon Sep 17 00:00:00 2001 From: Kevin Klues Date: Tue, 23 Apr 2024 12:29:25 +0000 Subject: [PATCH] Assign dgxa100 mock methods as function pointers instead of overwriting Signed-off-by: Kevin Klues --- pkg/nvml/mock/dgxa100/dgxa100.go | 355 ++++++++++++++++--------------- 1 file changed, 186 insertions(+), 169 deletions(-) diff --git a/pkg/nvml/mock/dgxa100/dgxa100.go b/pkg/nvml/mock/dgxa100/dgxa100.go index a042435..e15121b 100644 --- a/pkg/nvml/mock/dgxa100/dgxa100.go +++ b/pkg/nvml/mock/dgxa100/dgxa100.go @@ -26,6 +26,7 @@ import ( type Server struct { mock.Interface + mock.ExtendedInterface Devices [8]nvml.Device DriverVersion string NvmlVersion string @@ -70,7 +71,7 @@ var _ nvml.GpuInstance = (*GpuInstance)(nil) var _ nvml.ComputeInstance = (*ComputeInstance)(nil) func New() nvml.Interface { - return &Server{ + server := &Server{ Devices: [8]nvml.Device{ NewDevice(0), NewDevice(1), @@ -85,10 +86,12 @@ func New() nvml.Interface { NvmlVersion: "12.550.54.15", CudaDriverVersion: 12040, } + server.setMockFuncs() + return server } func NewDevice(index int) nvml.Device { - return &Device{ + device := &Device{ UUID: "GPU-" + uuid.New().String(), Name: "Mock NVIDIA A100-SXM4-40GB", Brand: nvml.BRAND_NVIDIA, @@ -104,241 +107,255 @@ func NewDevice(index int) nvml.Device { GpuInstanceCounter: 0, MemoryInfo: nvml.Memory{42949672960, 0, 0}, } + device.setMockFuncs() + return device } func NewGpuInstance(info nvml.GpuInstanceInfo) nvml.GpuInstance { - return &GpuInstance{ + gi := &GpuInstance{ Info: info, ComputeInstances: make(map[*ComputeInstance]struct{}), ComputeInstanceCounter: 0, } + gi.setMockFuncs() + return gi } func NewComputeInstance(info nvml.ComputeInstanceInfo) nvml.ComputeInstance { - return &ComputeInstance{ + ci := &ComputeInstance{ Info: info, } + ci.setMockFuncs() + return ci } -func (n *Server) Extensions() nvml.ExtendedInterface { - return n -} - -func (n *Server) LookupSymbol(symbol string) error { - return nil -} +func (s *Server) setMockFuncs() { + s.ExtensionsFunc = func() nvml.ExtendedInterface { + return s + } -func (n *Server) Init() nvml.Return { - return nvml.SUCCESS -} + s.LookupSymbolFunc = func(symbol string) error { + return nil + } -func (n *Server) Shutdown() nvml.Return { - return nvml.SUCCESS -} + s.InitFunc = func() nvml.Return { + return nvml.SUCCESS + } -func (n *Server) SystemGetDriverVersion() (string, nvml.Return) { - return n.DriverVersion, nvml.SUCCESS -} + s.ShutdownFunc = func() nvml.Return { + return nvml.SUCCESS + } -func (n *Server) SystemGetNVMLVersion() (string, nvml.Return) { - return n.NvmlVersion, nvml.SUCCESS -} + s.SystemGetDriverVersionFunc = func() (string, nvml.Return) { + return s.DriverVersion, nvml.SUCCESS + } -func (n *Server) SystemGetCudaDriverVersion() (int, nvml.Return) { - return n.CudaDriverVersion, nvml.SUCCESS -} + s.SystemGetNVMLVersionFunc = func() (string, nvml.Return) { + return s.NvmlVersion, nvml.SUCCESS + } -func (n *Server) DeviceGetCount() (int, nvml.Return) { - return len(n.Devices), nvml.SUCCESS -} + s.SystemGetCudaDriverVersionFunc = func() (int, nvml.Return) { + return s.CudaDriverVersion, nvml.SUCCESS + } -func (n *Server) DeviceGetHandleByIndex(index int) (nvml.Device, nvml.Return) { - if index < 0 || index >= len(n.Devices) { - return nil, nvml.ERROR_INVALID_ARGUMENT + s.DeviceGetCountFunc = func() (int, nvml.Return) { + return len(s.Devices), nvml.SUCCESS } - return n.Devices[index], nvml.SUCCESS -} -func (n *Server) DeviceGetHandleByUUID(uuid string) (nvml.Device, nvml.Return) { - for _, d := range n.Devices { - if uuid == d.(*Device).UUID { - return d, nvml.SUCCESS + s.DeviceGetHandleByIndexFunc = func(index int) (nvml.Device, nvml.Return) { + if index < 0 || index >= len(s.Devices) { + return nil, nvml.ERROR_INVALID_ARGUMENT } + return s.Devices[index], nvml.SUCCESS } - return nil, nvml.ERROR_INVALID_ARGUMENT -} -func (n *Server) DeviceGetHandleByPciBusId(busID string) (nvml.Device, nvml.Return) { - for _, d := range n.Devices { - if busID == d.(*Device).PciBusID { - return d, nvml.SUCCESS + s.DeviceGetHandleByUUIDFunc = func(uuid string) (nvml.Device, nvml.Return) { + for _, d := range s.Devices { + if uuid == d.(*Device).UUID { + return d, nvml.SUCCESS + } } + return nil, nvml.ERROR_INVALID_ARGUMENT } - return nil, nvml.ERROR_INVALID_ARGUMENT -} - -func (d *Device) GetMinorNumber() (int, nvml.Return) { - return d.Minor, nvml.SUCCESS -} -func (d *Device) GetIndex() (int, nvml.Return) { - return d.Index, nvml.SUCCESS + s.DeviceGetHandleByPciBusIdFunc = func(busID string) (nvml.Device, nvml.Return) { + for _, d := range s.Devices { + if busID == d.(*Device).PciBusID { + return d, nvml.SUCCESS + } + } + return nil, nvml.ERROR_INVALID_ARGUMENT + } } -func (d *Device) GetCudaComputeCapability() (int, int, nvml.Return) { - return d.CudaComputeCapability.Major, d.CudaComputeCapability.Minor, nvml.SUCCESS -} +func (d *Device) setMockFuncs() { + d.GetMinorNumberFunc = func() (int, nvml.Return) { + return d.Minor, nvml.SUCCESS + } -func (d *Device) GetUUID() (string, nvml.Return) { - return d.UUID, nvml.SUCCESS -} + d.GetIndexFunc = func() (int, nvml.Return) { + return d.Index, nvml.SUCCESS + } -func (d *Device) GetName() (string, nvml.Return) { - return d.Name, nvml.SUCCESS -} + d.GetCudaComputeCapabilityFunc = func() (int, int, nvml.Return) { + return d.CudaComputeCapability.Major, d.CudaComputeCapability.Minor, nvml.SUCCESS + } -func (d *Device) GetBrand() (nvml.BrandType, nvml.Return) { - return d.Brand, nvml.SUCCESS -} + d.GetUUIDFunc = func() (string, nvml.Return) { + return d.UUID, nvml.SUCCESS + } -func (d *Device) GetArchitecture() (nvml.DeviceArchitecture, nvml.Return) { - return d.Architecture, nvml.SUCCESS -} + d.GetNameFunc = func() (string, nvml.Return) { + return d.Name, nvml.SUCCESS + } -func (d *Device) GetMemoryInfo() (nvml.Memory, nvml.Return) { - return d.MemoryInfo, nvml.SUCCESS -} + d.GetBrandFunc = func() (nvml.BrandType, nvml.Return) { + return d.Brand, nvml.SUCCESS + } -func (d *Device) GetPciInfo() (nvml.PciInfo, nvml.Return) { - p := nvml.PciInfo{ - PciDeviceId: 0x20B010DE, + d.GetArchitectureFunc = func() (nvml.DeviceArchitecture, nvml.Return) { + return d.Architecture, nvml.SUCCESS } - return p, nvml.SUCCESS -} -func (d *Device) SetMigMode(mode int) (nvml.Return, nvml.Return) { - d.MigMode = mode - return nvml.SUCCESS, nvml.SUCCESS -} + d.GetMemoryInfoFunc = func() (nvml.Memory, nvml.Return) { + return d.MemoryInfo, nvml.SUCCESS + } -func (d *Device) GetMigMode() (int, int, nvml.Return) { - return d.MigMode, d.MigMode, nvml.SUCCESS -} + d.GetPciInfoFunc = func() (nvml.PciInfo, nvml.Return) { + p := nvml.PciInfo{ + PciDeviceId: 0x20B010DE, + } + return p, nvml.SUCCESS + } -func (d *Device) GetGpuInstanceProfileInfo(giProfileId int) (nvml.GpuInstanceProfileInfo, nvml.Return) { - if giProfileId < 0 || giProfileId >= nvml.GPU_INSTANCE_PROFILE_COUNT { - return nvml.GpuInstanceProfileInfo{}, nvml.ERROR_INVALID_ARGUMENT + d.SetMigModeFunc = func(mode int) (nvml.Return, nvml.Return) { + d.MigMode = mode + return nvml.SUCCESS, nvml.SUCCESS } - if _, exists := MIGProfiles.GpuInstanceProfiles[giProfileId]; !exists { - return nvml.GpuInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED + d.GetMigModeFunc = func() (int, int, nvml.Return) { + return d.MigMode, d.MigMode, nvml.SUCCESS } - return MIGProfiles.GpuInstanceProfiles[giProfileId], nvml.SUCCESS -} + d.GetGpuInstanceProfileInfoFunc = func(giProfileId int) (nvml.GpuInstanceProfileInfo, nvml.Return) { + if giProfileId < 0 || giProfileId >= nvml.GPU_INSTANCE_PROFILE_COUNT { + return nvml.GpuInstanceProfileInfo{}, nvml.ERROR_INVALID_ARGUMENT + } -func (d *Device) GetGpuInstancePossiblePlacements(info *nvml.GpuInstanceProfileInfo) ([]nvml.GpuInstancePlacement, nvml.Return) { - return MIGPlacements.GpuInstancePossiblePlacements[int(info.Id)], nvml.SUCCESS -} + if _, exists := MIGProfiles.GpuInstanceProfiles[giProfileId]; !exists { + return nvml.GpuInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED + } -func (d *Device) CreateGpuInstance(info *nvml.GpuInstanceProfileInfo) (nvml.GpuInstance, nvml.Return) { - giInfo := nvml.GpuInstanceInfo{ - Device: d, - Id: d.GpuInstanceCounter, - ProfileId: info.Id, + return MIGProfiles.GpuInstanceProfiles[giProfileId], nvml.SUCCESS } - d.GpuInstanceCounter++ - gi := NewGpuInstance(giInfo) - d.GpuInstances[gi.(*GpuInstance)] = struct{}{} - return gi, nvml.SUCCESS -} -func (d *Device) CreateGpuInstanceWithPlacement(info *nvml.GpuInstanceProfileInfo, placement *nvml.GpuInstancePlacement) (nvml.GpuInstance, nvml.Return) { - giInfo := nvml.GpuInstanceInfo{ - Device: d, - Id: d.GpuInstanceCounter, - ProfileId: info.Id, - Placement: *placement, - } - d.GpuInstanceCounter++ - gi := NewGpuInstance(giInfo) - d.GpuInstances[gi.(*GpuInstance)] = struct{}{} - return gi, nvml.SUCCESS -} + d.GetGpuInstancePossiblePlacementsFunc = func(info *nvml.GpuInstanceProfileInfo) ([]nvml.GpuInstancePlacement, nvml.Return) { + return MIGPlacements.GpuInstancePossiblePlacements[int(info.Id)], nvml.SUCCESS + } -func (d *Device) GetGpuInstances(info *nvml.GpuInstanceProfileInfo) ([]nvml.GpuInstance, nvml.Return) { - var gis []nvml.GpuInstance - for gi := range d.GpuInstances { - if gi.Info.ProfileId == info.Id { - gis = append(gis, gi) + d.CreateGpuInstanceFunc = func(info *nvml.GpuInstanceProfileInfo) (nvml.GpuInstance, nvml.Return) { + giInfo := nvml.GpuInstanceInfo{ + Device: d, + Id: d.GpuInstanceCounter, + ProfileId: info.Id, } + d.GpuInstanceCounter++ + gi := NewGpuInstance(giInfo) + d.GpuInstances[gi.(*GpuInstance)] = struct{}{} + return gi, nvml.SUCCESS } - return gis, nvml.SUCCESS -} -func (gi *GpuInstance) GetInfo() (nvml.GpuInstanceInfo, nvml.Return) { - return gi.Info, nvml.SUCCESS -} + d.CreateGpuInstanceWithPlacementFunc = func(info *nvml.GpuInstanceProfileInfo, placement *nvml.GpuInstancePlacement) (nvml.GpuInstance, nvml.Return) { + giInfo := nvml.GpuInstanceInfo{ + Device: d, + Id: d.GpuInstanceCounter, + ProfileId: info.Id, + Placement: *placement, + } + d.GpuInstanceCounter++ + gi := NewGpuInstance(giInfo) + d.GpuInstances[gi.(*GpuInstance)] = struct{}{} + return gi, nvml.SUCCESS + } -func (gi *GpuInstance) GetComputeInstanceProfileInfo(ciProfileId int, ciEngProfileId int) (nvml.ComputeInstanceProfileInfo, nvml.Return) { - if ciProfileId < 0 || ciProfileId >= nvml.COMPUTE_INSTANCE_PROFILE_COUNT { - return nvml.ComputeInstanceProfileInfo{}, nvml.ERROR_INVALID_ARGUMENT + d.GetGpuInstancesFunc = func(info *nvml.GpuInstanceProfileInfo) ([]nvml.GpuInstance, nvml.Return) { + var gis []nvml.GpuInstance + for gi := range d.GpuInstances { + if gi.Info.ProfileId == info.Id { + gis = append(gis, gi) + } + } + return gis, nvml.SUCCESS } +} - if ciEngProfileId != nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED { - return nvml.ComputeInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED +func (gi *GpuInstance) setMockFuncs() { + gi.GetInfoFunc = func() (nvml.GpuInstanceInfo, nvml.Return) { + return gi.Info, nvml.SUCCESS } - giProfileId := int(gi.Info.ProfileId) + gi.GetComputeInstanceProfileInfoFunc = func(ciProfileId int, ciEngProfileId int) (nvml.ComputeInstanceProfileInfo, nvml.Return) { + if ciProfileId < 0 || ciProfileId >= nvml.COMPUTE_INSTANCE_PROFILE_COUNT { + return nvml.ComputeInstanceProfileInfo{}, nvml.ERROR_INVALID_ARGUMENT + } + + if ciEngProfileId != nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED { + return nvml.ComputeInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED + } - if _, exists := MIGProfiles.ComputeInstanceProfiles[giProfileId]; !exists { - return nvml.ComputeInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED - } + giProfileId := int(gi.Info.ProfileId) - if _, exists := MIGProfiles.ComputeInstanceProfiles[giProfileId][ciProfileId]; !exists { - return nvml.ComputeInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED - } + if _, exists := MIGProfiles.ComputeInstanceProfiles[giProfileId]; !exists { + return nvml.ComputeInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED + } - return MIGProfiles.ComputeInstanceProfiles[giProfileId][ciProfileId], nvml.SUCCESS -} + if _, exists := MIGProfiles.ComputeInstanceProfiles[giProfileId][ciProfileId]; !exists { + return nvml.ComputeInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED + } -func (gi *GpuInstance) GetComputeInstancePossiblePlacements(info *nvml.ComputeInstanceProfileInfo) ([]nvml.ComputeInstancePlacement, nvml.Return) { - return MIGPlacements.ComputeInstancePossiblePlacements[int(gi.Info.Id)][int(info.Id)], nvml.SUCCESS -} + return MIGProfiles.ComputeInstanceProfiles[giProfileId][ciProfileId], nvml.SUCCESS + } -func (gi *GpuInstance) CreateComputeInstance(info *nvml.ComputeInstanceProfileInfo) (nvml.ComputeInstance, nvml.Return) { - ciInfo := nvml.ComputeInstanceInfo{ - Device: gi.Info.Device, - GpuInstance: gi, - Id: gi.ComputeInstanceCounter, - ProfileId: info.Id, - } - gi.ComputeInstanceCounter++ - ci := NewComputeInstance(ciInfo) - gi.ComputeInstances[ci.(*ComputeInstance)] = struct{}{} - return ci, nvml.SUCCESS -} + gi.GetComputeInstancePossiblePlacementsFunc = func(info *nvml.ComputeInstanceProfileInfo) ([]nvml.ComputeInstancePlacement, nvml.Return) { + return MIGPlacements.ComputeInstancePossiblePlacements[int(gi.Info.Id)][int(info.Id)], nvml.SUCCESS + } -func (gi *GpuInstance) GetComputeInstances(info *nvml.ComputeInstanceProfileInfo) ([]nvml.ComputeInstance, nvml.Return) { - var cis []nvml.ComputeInstance - for ci := range gi.ComputeInstances { - if ci.Info.ProfileId == info.Id { - cis = append(cis, ci) + gi.CreateComputeInstanceFunc = func(info *nvml.ComputeInstanceProfileInfo) (nvml.ComputeInstance, nvml.Return) { + ciInfo := nvml.ComputeInstanceInfo{ + Device: gi.Info.Device, + GpuInstance: gi, + Id: gi.ComputeInstanceCounter, + ProfileId: info.Id, } + gi.ComputeInstanceCounter++ + ci := NewComputeInstance(ciInfo) + gi.ComputeInstances[ci.(*ComputeInstance)] = struct{}{} + return ci, nvml.SUCCESS } - return cis, nvml.SUCCESS -} -func (gi *GpuInstance) Destroy() nvml.Return { - delete(gi.Info.Device.(*Device).GpuInstances, gi) - return nvml.SUCCESS -} + gi.GetComputeInstancesFunc = func(info *nvml.ComputeInstanceProfileInfo) ([]nvml.ComputeInstance, nvml.Return) { + var cis []nvml.ComputeInstance + for ci := range gi.ComputeInstances { + if ci.Info.ProfileId == info.Id { + cis = append(cis, ci) + } + } + return cis, nvml.SUCCESS + } -func (ci *ComputeInstance) GetInfo() (nvml.ComputeInstanceInfo, nvml.Return) { - return ci.Info, nvml.SUCCESS + gi.DestroyFunc = func() nvml.Return { + delete(gi.Info.Device.(*Device).GpuInstances, gi) + return nvml.SUCCESS + } } -func (ci *ComputeInstance) Destroy() nvml.Return { - delete(ci.Info.GpuInstance.(*GpuInstance).ComputeInstances, ci) - return nvml.SUCCESS +func (ci *ComputeInstance) setMockFuncs() { + ci.GetInfoFunc = func() (nvml.ComputeInstanceInfo, nvml.Return) { + return ci.Info, nvml.SUCCESS + } + + ci.DestroyFunc = func() nvml.Return { + delete(ci.Info.GpuInstance.(*GpuInstance).ComputeInstances, ci) + return nvml.SUCCESS + } }