Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for JSON string in defaultImage field for overriding pipeline specific images in the mappings #293

Merged
merged 12 commits into from
Jan 28, 2025
Merged
33 changes: 23 additions & 10 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,27 @@ var containerHostPorts = map[string]string{
"live-video-to-video": "8900",
}

// Mapping for per pipeline container images.
// Default pipeline container image mapping to use if no overrides are provided.
var defaultBaseImage = "livepeer/ai-runner:latest"
var pipelineToImage = map[string]string{
"segment-anything-2": "livepeer/ai-runner:segment-anything-2",
"text-to-speech": "livepeer/ai-runner:text-to-speech",
"audio-to-text": "livepeer/ai-runner:audio-to-text",
"llm": "livepeer/ai-runner:llm",
}

var livePipelineToImage = map[string]string{
"streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion",
"comfyui": "livepeer/ai-runner:live-app-comfyui",
"segment_anything_2": "livepeer/ai-runner:live-app-segment_anything_2",
"noop": "livepeer/ai-runner:live-app-noop",
}

type ImageOverrides struct {
Default string `json:"default"`
Batch map[string]string `json:"batch"`
Live map[string]string `json:"live"`
}

// DockerClient is an interface for the Docker client, allowing for mocking in tests.
// NOTE: ensure any docker.Client methods used in this package are added.
type DockerClient interface {
Expand All @@ -91,9 +97,9 @@ var _ DockerClient = (*docker.Client)(nil)
var dockerWaitUntilRunningFunc = dockerWaitUntilRunning

type DockerManager struct {
defaultImage string
gpus []string
modelDir string
gpus []string
modelDir string
overrides ImageOverrides

dockerClient DockerClient
// gpu ID => container name
Expand All @@ -103,7 +109,7 @@ type DockerManager struct {
mu *sync.Mutex
}

func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
func NewDockerManager(overrides ImageOverrides, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
if err := removeExistingContainers(ctx, client); err != nil {
cancel()
Expand All @@ -112,9 +118,9 @@ func NewDockerManager(defaultImage string, gpus []string, modelDir string, clien
cancel()

manager := &DockerManager{
defaultImage: defaultImage,
gpus: gpus,
modelDir: modelDir,
overrides: overrides,
dockerClient: client,
gpuContainers: make(map[string]string),
containers: make(map[string]*RunnerContainer),
Expand Down Expand Up @@ -215,17 +221,24 @@ func (m *DockerManager) returnContainer(rc *RunnerContainer) {
func (m *DockerManager) getContainerImageName(pipeline, modelID string) (string, error) {
if pipeline == "live-video-to-video" {
// We currently use the model ID as the live pipeline name for legacy reasons.
if image, ok := livePipelineToImage[modelID]; ok {
if image, ok := m.overrides.Live[modelID]; ok {
return image, nil
} else if image, ok := livePipelineToImage[modelID]; ok {
return image, nil
}
return "", fmt.Errorf("no container image found for live pipeline %s", modelID)
}

if image, ok := pipelineToImage[pipeline]; ok {
if image, ok := m.overrides.Batch[pipeline]; ok {
return image, nil
} else if image, ok := pipelineToImage[pipeline]; ok {
return image, nil
}

return m.defaultImage, nil
if m.overrides.Default != "" {
return m.overrides.Default, nil
}
return defaultBaseImage, nil
}

// HasCapacity checks if an unused managed container exists or if a GPU is available for a new container.
Expand Down
95 changes: 89 additions & 6 deletions worker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ func NewMockServer() *MockServer {
// createDockerManager creates a DockerManager with a mock DockerClient.
func createDockerManager(mockDockerClient *MockDockerClient) *DockerManager {
return &DockerManager{
defaultImage: "default-image",
gpus: []string{"gpu0"},
modelDir: "/models",
overrides: ImageOverrides{Default: "default-image"},
dockerClient: mockDockerClient,
gpuContainers: make(map[string]string),
containers: make(map[string]*RunnerContainer),
Expand All @@ -110,10 +110,10 @@ func TestNewDockerManager(t *testing.T) {
mockDockerClient := new(MockDockerClient)

createAndVerifyManager := func() *DockerManager {
manager, err := NewDockerManager("default-image", []string{"gpu0"}, "/models", mockDockerClient)
manager, err := NewDockerManager(ImageOverrides{Default: "default-image"}, []string{"gpu0"}, "/models", mockDockerClient)
require.NoError(t, err)
require.NotNil(t, manager)
require.Equal(t, "default-image", manager.defaultImage)
require.Equal(t, "default-image", manager.overrides.Default)
require.Equal(t, []string{"gpu0"}, manager.gpus)
require.Equal(t, "/models", manager.modelDir)
require.Equal(t, mockDockerClient, manager.dockerClient)
Expand Down Expand Up @@ -301,47 +301,130 @@ func TestDockerManager_returnContainer(t *testing.T) {

func TestDockerManager_getContainerImageName(t *testing.T) {
mockDockerClient := new(MockDockerClient)
manager := createDockerManager(mockDockerClient)
dockerManager := createDockerManager(mockDockerClient)

tests := []struct {
name string
setup func(*DockerManager, *MockDockerClient)
pipeline string
modelID string
expectedImage string
expectError bool
}{
{
name: "live-video-to-video with valid modelID",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "live-video-to-video",
modelID: "streamdiffusion",
expectedImage: "livepeer/ai-runner:live-app-streamdiffusion",
expectError: false,
},
{
name: "live-video-to-video with invalid modelID",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "live-video-to-video",
modelID: "invalid-model",
expectError: true,
},
{
name: "valid pipeline",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "text-to-speech",
modelID: "",
expectedImage: "livepeer/ai-runner:text-to-speech",
expectError: false,
},
{
name: "invalid pipeline",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "invalid-pipeline",
modelID: "",
expectedImage: "default-image",
expectError: false,
},
{
name: "override default image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Default: "custom-image",
}
},
pipeline: "",
modelID: "",
expectedImage: "custom-image",
expectError: false,
},
{
name: "override batch image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Batch: map[string]string{
"text-to-speech": "custom-image",
},
}
},
pipeline: "text-to-speech",
modelID: "",
expectedImage: "custom-image",
expectError: false,
},
{
name: "override live image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Live: map[string]string{
"streamdiffusion": "custom-image",
},
}
},
pipeline: "live-video-to-video",
modelID: "streamdiffusion",
expectedImage: "custom-image",
expectError: false,
},
{
name: "non-overridden batch image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Default: "default-image",
Batch: map[string]string{
"text-to-speech": "custom-batch-image",
},
Live: map[string]string{
"streamdiffusion": "custom-live-image",
},
}
},
pipeline: "audio-to-text",
modelID: "",
expectedImage: "livepeer/ai-runner:audio-to-text",
expectError: false,
},
{
name: "non-overridden live image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Default: "default-image",
Batch: map[string]string{
"text-to-speech": "custom-batch-image",
},
Live: map[string]string{
"streamdiffusion": "custom-live-image",
},
}
},
pipeline: "live-video-to-video",
modelID: "comfyui",
expectedImage: "livepeer/ai-runner:live-app-comfyui",
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
image, err := manager.getContainerImageName(tt.pipeline, tt.modelID)
tt.setup(dockerManager, mockDockerClient)

image, err := dockerManager.getContainerImageName(tt.pipeline, tt.modelID)
if tt.expectError {
require.Error(t, err)
require.Equal(t, fmt.Sprintf("no container image found for live pipeline %s", tt.modelID), err.Error())
Expand Down Expand Up @@ -500,7 +583,7 @@ func TestDockerManager_createContainer(t *testing.T) {
dockerManager.gpus = []string{gpu}
dockerManager.gpuContainers = make(map[string]string)
dockerManager.containers = make(map[string]*RunnerContainer)
dockerManager.defaultImage = containerImage
dockerManager.overrides.Default = containerImage

mockDockerClient.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(container.CreateResponse{ID: containerID}, nil)
mockDockerClient.On("ContainerStart", mock.Anything, containerID, mock.Anything).Return(nil)
Expand Down
4 changes: 2 additions & 2 deletions worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ type Worker struct {
mu *sync.Mutex
}

func NewWorker(defaultImage string, gpus []string, modelDir string) (*Worker, error) {
func NewWorker(imageOverrides ImageOverrides, gpus []string, modelDir string) (*Worker, error) {
dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
if err != nil {
return nil, err
}

manager, err := NewDockerManager(defaultImage, gpus, modelDir, dockerClient)
manager, err := NewDockerManager(imageOverrides, gpus, modelDir, dockerClient)
if err != nil {
return nil, err
}
Expand Down
Loading