Skip to content

Commit

Permalink
refactor: enhance tests and improve overridePipelineImages structure
Browse files Browse the repository at this point in the history
This commit expands the test coverage to ensure more robust behavior and
refactors the `overridePipelineImages` function to improve error
handling and readability.
  • Loading branch information
rickstaa committed Jan 22, 2025
1 parent d7fe6f6 commit aca9cc8
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 117 deletions.
88 changes: 45 additions & 43 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,57 @@ var containerHostPorts = map[string]string{
}

// Mapping for per pipeline container images.
var defaultBaseImage = "livepeer/ai-runner:latest"
var pipelineToImage = map[string]string{
"segment-anything-2": "livepeer/ai-runner:segment-anything-2",
"text-to-speech": "livepeer/ai-runner:text-to-speech",
"audio-to-text": "livepeer/ai-runner:audio-to-text",
"llm": "livepeer/ai-runner:llm",
}

var livePipelineToImage = map[string]string{
"streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion",
"comfyui": "livepeer/ai-runner:live-app-comfyui",
"segment_anything_2": "livepeer/ai-runner:live-app-segment_anything_2",
"noop": "livepeer/ai-runner:live-app-noop",
}

// overridePipelineImages updates base and pipeline images with the provided overrides.
func overridePipelineImages(imageOverrides string) error {
if imageOverrides == "" {
return fmt.Errorf("empty string is not a valid image override")
}

// Handle JSON format for multiple pipeline images.
var imageMap map[string]string
if err := json.Unmarshal([]byte(imageOverrides), &imageMap); err == nil {
for pipeline, image := range imageMap {
if pipeline == "base" {
defaultBaseImage = image
continue
}

// Check and update the pipeline images.
if _, exists := pipelineToImage[pipeline]; exists {
pipelineToImage[pipeline] = image
} else if _, exists := livePipelineToImage[pipeline]; exists {
livePipelineToImage[pipeline] = image
} else {
return fmt.Errorf("can't override docker image for unknown pipeline: %s", pipeline)
}
}
return nil
}

// Check for invalid docker image string.
if strings.ContainsAny(imageOverrides, "{}[]\",") {
return fmt.Errorf("invalid JSON format for image overrides")
}

// Update the base image.
defaultBaseImage = imageOverrides
return nil
}

// DockerClient is an interface for the Docker client, allowing for mocking in tests.
// NOTE: ensure any docker.Client methods used in this package are added.
type DockerClient interface {
Expand Down Expand Up @@ -103,58 +140,23 @@ type DockerManager struct {
mu *sync.Mutex
}

// updatePipelineMappings updates the specified mapping with pipeline to image overriding.
// It logs a warning if a pipeline is not found in the given mapping.
//
// Parameters:
// - overrides: A map of pipeline names to custom image names.
// - mapping: The map to be updated with the provided overrides.
// - mapName: The name of the map (used for logging purposes).
func updatePipelineMappings(overrides map[string]string, mapping map[string]string, mapName string) {
for pipeline, image := range overrides {
if _, exists := mapping[pipeline]; exists {
mapping[pipeline] = image
} else {
slog.Warn("Pipeline not found in map", "map", mapName, "pipeline", pipeline)
}
}
}

// overridePipelineImages function parses a JSON string containing pipeline-to-image mappings and overrides the default mappings if valid.
// It updates the `pipelineToImage` and `livePipelineToImage` maps with custom images.
// Parameters:
// - defaultImage: A string that can either be containerImage name or a JSON string with overrides for pipeline-to-image mappings.
//
// Returns:
// - error: An error if the JSON parsing fails or if the mapping is not found in existing maps else `nil`.
func overridePipelineImages(defaultImage string) error {
if strings.HasPrefix(defaultImage, "{") || strings.HasSuffix(defaultImage, "}") {
var pipelineOverrides map[string]string
if err := json.Unmarshal([]byte(defaultImage), &pipelineOverrides); err != nil {
slog.Error("Error parsing JSON", "error", err)
return err
}
updatePipelineMappings(pipelineOverrides, pipelineToImage, "pipelineToImage")
updatePipelineMappings(pipelineOverrides, livePipelineToImage, "livePipelineToImage")
}
return nil
}

func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
func NewDockerManager(imageOverrides string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
if err := removeExistingContainers(ctx, client); err != nil {
cancel()
return nil, err
}
cancel()

// call to handle image overriding logic
if err := overridePipelineImages(defaultImage); err != nil {
return nil, err
// Override pipeline images if provided.
if imageOverrides != "" {
if err := overridePipelineImages(imageOverrides); err != nil {
return nil, err
}
}

manager := &DockerManager{
defaultImage: defaultImage,
defaultImage: defaultBaseImage,
gpus: gpus,
modelDir: modelDir,
dockerClient: client,
Expand Down
189 changes: 117 additions & 72 deletions worker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,123 @@ func createDockerManager(mockDockerClient *MockDockerClient) *DockerManager {
}
}

// copyMap returns a deep copy of the given map.
func copyMap(m map[string]string) map[string]string {
copy := make(map[string]string)
for k, v := range m {
copy[k] = v
}
return copy
}

func TestOverridePipelineImages(t *testing.T) {
// Store the original values of the maps.
originalDefaultBaseImage := defaultBaseImage
originalPipelineToImage := copyMap(pipelineToImage)
originalLivePipelineToImage := copyMap(livePipelineToImage)

tests := []struct {
name string
inputJSON string
expectedBase string
expectedPipelineImages map[string]string
expectedLiveImages map[string]string
expectError bool
}{
{
name: "ValidPipelineOverrides",
inputJSON: `{"segment-anything-2": "custom-image:1.0", "text-to-speech": "speech-image:2.0"}`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: map[string]string{
"segment-anything-2": "custom-image:1.0",
"text-to-speech": "speech-image:2.0",
"audio-to-text": originalPipelineToImage["audio-to-text"],
},
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "OverrideBaseImage",
inputJSON: "new-base-image:latest",
expectedBase: "new-base-image:latest",
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "OverrideBaseImageJSON",
inputJSON: `{"base": "new-base-image:latest"}`,
expectedBase: "new-base-image:latest",
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "EmptyJSON",
inputJSON: `{}`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "MalformedJSON",
inputJSON: `{"segment-anything-2": "custom-image:1.0"`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: true,
},
{
name: "EmptyString",
inputJSON: "",
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: true,
},
{
name: "UnknownPipeline",
inputJSON: `{"unknown-pipeline": "unknown-image:latest"}`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Register a cleanup function to reset state after the subtest.
t.Cleanup(func() {
defaultBaseImage = originalDefaultBaseImage
pipelineToImage = copyMap(originalPipelineToImage)
livePipelineToImage = copyMap(originalLivePipelineToImage)
})

// Call overridePipelineImages function with the mock data.
err := overridePipelineImages(tt.inputJSON)

if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedBase, defaultBaseImage)

// Verify the expected pipeline images.
for pipeline, expectedImage := range tt.expectedPipelineImages {
require.Equal(t, expectedImage, pipelineToImage[pipeline])
}

// Verify the expected live pipeline images.
for livePipeline, expectedImage := range tt.expectedLiveImages {
require.Equal(t, expectedImage, livePipelineToImage[livePipeline])
}
}
})
}
}

func TestNewDockerManager(t *testing.T) {
mockDockerClient := new(MockDockerClient)

Expand Down Expand Up @@ -986,75 +1103,3 @@ func TestDockerWaitUntilRunning(t *testing.T) {
mockDockerClient.AssertExpectations(t)
})
}

func TestDockerManager_overridePipelineImages(t *testing.T) {
mockDockerClient := new(MockDockerClient)
dockerManager := createDockerManager(mockDockerClient)

tests := []struct {
name string
inputJSON string
pipeline string
expectedImage string
expectError bool
}{
{
name: "ValidOverride",
inputJSON: `{"segment-anything-2": "custom-image:1.0"}`,
pipeline: "segment-anything-2",
expectedImage: "custom-image:1.0",
expectError: false,
},
{
name: "MultipleOverrides",
inputJSON: `{"segment-anything-2": "custom-image:1.0", "text-to-speech": "speech-image:2.0"}`,
pipeline: "text-to-speech",
expectedImage: "speech-image:2.0",
expectError: false,
},
{
name: "NoOverrideFallback",
inputJSON: `{"segment-anything-2": "custom-image:1.0"}`,
pipeline: "streamdiffusion",
expectedImage: "default-image",
expectError: false,
},
{
name: "EmptyJSON",
inputJSON: `{}`,
pipeline: "segment-anything-2",
expectedImage: "custom-image:1.0",
expectError: false,
},
{
name: "MalformedJSON",
inputJSON: `{"segment-anything-2": "custom-image:1.0"`,
pipeline: "segment-anything-2",
expectError: true,
},
{
name: "RegularStringInput",
inputJSON: "",
pipeline: "image-to-video",
expectedImage: "default-image",
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Call overridePipelineImages function with the mock data.
err := overridePipelineImages(tt.inputJSON)

if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)

// Verify the expected image.
image, _ := dockerManager.getContainerImageName(tt.pipeline, "")
require.Equal(t, tt.expectedImage, image)
}
})
}
}
4 changes: 2 additions & 2 deletions worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ type Worker struct {
mu *sync.Mutex
}

func NewWorker(defaultImage string, gpus []string, modelDir string) (*Worker, error) {
func NewWorker(imageOverrides string, gpus []string, modelDir string) (*Worker, error) {
dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
if err != nil {
return nil, err
}

manager, err := NewDockerManager(defaultImage, gpus, modelDir, dockerClient)
manager, err := NewDockerManager(imageOverrides, gpus, modelDir, dockerClient)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit aca9cc8

Please sign in to comment.