Skip to content

Commit

Permalink
Replace Library interface with Extensions interface
Browse files Browse the repository at this point in the history
The methods in this interface represent extensions to the core NVML API that
are only accessible through calling GetExtensions() against the Interface in
use (or at the package level for the default interface).

Signed-off-by: Kevin Klues <[email protected]>
  • Loading branch information
klueska committed Apr 12, 2024
1 parent d72aa68 commit 1fa43fd
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 71 deletions.
2 changes: 1 addition & 1 deletion gen/nvml/generateapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var GeneratableInterfaces = []GeneratableInterfacePoperties{
{
Type: "library",
Interface: "Interface",
Exclude: []string{"Lookup"},
Exclude: []string{"LookupSymbol"},
PackageMethodsAliasedFrom: "libnvml",
},
{
Expand Down
16 changes: 11 additions & 5 deletions pkg/nvml/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@

package nvml

// ExtendedInterface defines a set of extensions to the core NVML API.
//
// TODO: For now the list of methods in this interface need to be kept in sync
// with the list of excluded methods for the Interface type in
// gen/nvml/generateapi.go. In the future we should automate this.
//
//go:generate moq -out mock/extendedinterface.go -pkg mock . ExtendedInterface:ExtendedInterface
type ExtendedInterface interface {
LookupSymbol(string) error
}

// libraryOptions hold the paramaters than can be set by a LibraryOption
type libraryOptions struct {
path string
Expand All @@ -25,11 +36,6 @@ type libraryOptions struct {
// LibraryOption represents a functional option to configure the underlying NVML library
type LibraryOption func(*libraryOptions)

// Library defines a set of functions defined on the underlying dynamic library.
type Library interface {
Lookup(string) error
}

// WithLibraryPath provides an option to set the library name to be used by the NVML library.
func WithLibraryPath(path string) LibraryOption {
return func(o *libraryOptions) {
Expand Down
50 changes: 25 additions & 25 deletions pkg/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ func (l *library) init(opts ...LibraryOption) {
l.dl = dl.New(o.path, o.flags)
}

func (l *library) GetLibrary() Library {
func (l *library) Extensions() ExtendedInterface {
return l
}

// Lookup checks whether the specified library symbol exists in the library.
// LookupSymbol checks whether the specified library symbol exists in the library.
// Note that this requires that the library be loaded.
func (l *library) Lookup(name string) error {
func (l *library) LookupSymbol(name string) error {
if l == nil || l.refcount == 0 {
return fmt.Errorf("error looking up %s: %w", name, errLibraryNotLoaded)
}
Expand Down Expand Up @@ -198,93 +198,93 @@ func (pis ProcessInfo_v2Slice) ToProcessInfoSlice() []ProcessInfo {
// When new versioned symbols are added, these would have to be initialized above and have
// corresponding checks and subsequent assignments added below.
func (l *library) updateVersionedSymbols() {
err := l.Lookup("nvmlInit_v2")
err := l.LookupSymbol("nvmlInit_v2")
if err == nil {
nvmlInit = nvmlInit_v2
}
err = l.Lookup("nvmlDeviceGetPciInfo_v2")
err = l.LookupSymbol("nvmlDeviceGetPciInfo_v2")
if err == nil {
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v2
}
err = l.Lookup("nvmlDeviceGetPciInfo_v3")
err = l.LookupSymbol("nvmlDeviceGetPciInfo_v3")
if err == nil {
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v3
}
err = l.Lookup("nvmlDeviceGetCount_v2")
err = l.LookupSymbol("nvmlDeviceGetCount_v2")
if err == nil {
nvmlDeviceGetCount = nvmlDeviceGetCount_v2
}
err = l.Lookup("nvmlDeviceGetHandleByIndex_v2")
err = l.LookupSymbol("nvmlDeviceGetHandleByIndex_v2")
if err == nil {
nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v2
}
err = l.Lookup("nvmlDeviceGetHandleByPciBusId_v2")
err = l.LookupSymbol("nvmlDeviceGetHandleByPciBusId_v2")
if err == nil {
nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v2
}
err = l.Lookup("nvmlDeviceGetNvLinkRemotePciInfo_v2")
err = l.LookupSymbol("nvmlDeviceGetNvLinkRemotePciInfo_v2")
if err == nil {
nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v2
}
// Unable to overwrite nvmlDeviceRemoveGpu() because the v2 function takes
// a different set of parameters than the v1 function.
//err = l.Lookup("nvmlDeviceRemoveGpu_v2")
//err = l.LookupSymbol("nvmlDeviceRemoveGpu_v2")
//if err == nil {
// nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v2
//}
err = l.Lookup("nvmlDeviceGetGridLicensableFeatures_v2")
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v2")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v2
}
err = l.Lookup("nvmlDeviceGetGridLicensableFeatures_v3")
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v3")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v3
}
err = l.Lookup("nvmlDeviceGetGridLicensableFeatures_v4")
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v4")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v4
}
err = l.Lookup("nvmlEventSetWait_v2")
err = l.LookupSymbol("nvmlEventSetWait_v2")
if err == nil {
nvmlEventSetWait = nvmlEventSetWait_v2
}
err = l.Lookup("nvmlDeviceGetAttributes_v2")
err = l.LookupSymbol("nvmlDeviceGetAttributes_v2")
if err == nil {
nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v2
}
err = l.Lookup("nvmlComputeInstanceGetInfo_v2")
err = l.LookupSymbol("nvmlComputeInstanceGetInfo_v2")
if err == nil {
nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v2
}
err = l.Lookup("nvmlDeviceGetComputeRunningProcesses_v2")
err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v2")
if err == nil {
deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v2
}
err = l.Lookup("nvmlDeviceGetComputeRunningProcesses_v3")
err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v3")
if err == nil {
deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v3
}
err = l.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v2")
err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v2")
if err == nil {
deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v2
}
err = l.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v3")
err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v3")
if err == nil {
deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v3
}
err = l.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v2")
err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v2")
if err == nil {
deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v2
}
err = l.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v3")
err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v3")
if err == nil {
deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v3
}
err = l.Lookup("nvmlDeviceGetGpuInstancePossiblePlacements_v2")
err = l.LookupSymbol("nvmlDeviceGetGpuInstancePossiblePlacements_v2")
if err == nil {
nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v2
}
err = l.Lookup("nvmlVgpuInstanceGetLicenseInfo_v2")
err = l.LookupSymbol("nvmlVgpuInstanceGetLicenseInfo_v2")
if err == nil {
nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v2
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/nvml/lib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestLookupFromDefault(t *testing.T) {
if !tc.skipLoadLibrary {
require.ErrorIs(t, l.load(), tc.expectedLoadError)
}
require.ErrorIs(t, l.Lookup("symbol"), tc.expectedLookupErrror)
require.ErrorIs(t, l.LookupSymbol("symbol"), tc.expectedLookupErrror)
require.ErrorIs(t, l.close(), tc.expectedCloseError)
if tc.expectedCloseError == nil {
require.Equal(t, 0, int(l.refcount))
Expand Down
75 changes: 75 additions & 0 deletions pkg/nvml/mock/extendedinterface.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

74 changes: 37 additions & 37 deletions pkg/nvml/mock/interface.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 1fa43fd

Please sign in to comment.