Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: adding context and merge artifact service creation into mlflow #80

Merged
merged 7 commits into from
Apr 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 25 additions & 20 deletions api/pkg/artifact/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,24 @@ import (
"google.golang.org/api/iterator"
)

// gcsClient pairs a Google Cloud Storage client handle with the Config its
// methods read their context from.
type gcsClient struct {
// API is the storage client used to open buckets, list objects and delete them.
API *storage.Client
// Config supplies the context passed to the storage calls.
Config Config
}
// Config holds the settings of gcsClient; its only field is the context used
// for storage requests.
type Config struct {
// Ctx is the context handed to the bucket listing and object deletion calls.
Ctx context.Context
}

// Service is the artifact-deletion abstraction of this package: DeleteArtifact
// removes whatever is stored at url and reports the first error encountered.
type Service interface {
DeleteArtifact(url string) error
DeleteArtifact(ctx context.Context, url string) error
}

func NewGcsClient(api *storage.Client, cfg Config) Service {
return &gcsClient{
API: api,
Config: cfg,
}
// GcsArtifactClient deletes artifacts through a Google Cloud Storage client.
type GcsArtifactClient struct {
// API is the storage client used to open buckets, list objects and delete them.
API *storage.Client
}

func (gc *gcsClient) DeleteArtifact(url string) error {
func (gac *GcsArtifactClient) DeleteArtifact(ctx context.Context, url string) error {
// Get bucket name and gcsPrefix
// the [5:] is to remove the "gs://" on the artifact uri
// ex : gs://bucketName/path → bucketName/path
gcsBucket, gcsLocation := gc.getGcsBucketAndLocation(url[5:])
gcsBucket, gcsLocation := gac.getGcsBucketAndLocation(url[5:])

// Sets the name for the bucket.
bucket := gc.API.Bucket(gcsBucket)
bucket := gac.API.Bucket(gcsBucket)

it := bucket.Objects(gc.Config.Ctx, &storage.Query{
it := bucket.Objects(ctx, &storage.Query{
Prefix: gcsLocation,
})
for {
Expand All @@ -47,16 +36,32 @@ func (gc *gcsClient) DeleteArtifact(url string) error {
if err != nil {
return err
}
if err := bucket.Object(attrs.Name).Delete(gc.Config.Ctx); err != nil {
if err := bucket.Object(attrs.Name).Delete(ctx); err != nil {
return err
}
}
return nil
}

func (gc *gcsClient) getGcsBucketAndLocation(str string) (string, string) {
// getGcsBucketAndLocation separates a "bucket/object/path" string at its first
// slash and returns the bucket name followed by everything after that slash.
// ex : bucketName/path/path1/item → (bucketName , path/path1/item)
//
// The input must contain at least one "/": without it there is no second
// element and the index expression below panics.
func (gac *GcsArtifactClient) getGcsBucketAndLocation(str string) (string, string) {
	// Cap the split at two parts so slashes inside the object path are kept.
	parts := strings.SplitN(str, "/", 2)
	bucketName, objectPath := parts[0], parts[1]
	return bucketName, objectPath
}

// NewGcsArtifactClient wraps the given storage client in a GcsArtifactClient
// and returns it as a Service.
func NewGcsArtifactClient(api *storage.Client) Service {
	client := &GcsArtifactClient{API: api}
	return client
}

// NopArtifactClient is an artifact client that performs no work; it carries
// no state.
type NopArtifactClient struct{}

// DeleteArtifact ignores ctx and url and always returns nil.
func (n *NopArtifactClient) DeleteArtifact(ctx context.Context, url string) error {
	return nil
}

// NewNopArtifactClient returns a Service backed by a NopArtifactClient, whose
// DeleteArtifact does nothing and always succeeds.
func NewNopArtifactClient() Service {
	nop := new(NopArtifactClient)
	return nop
}
28 changes: 16 additions & 12 deletions api/pkg/artifact/mocks/artifact.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 24 additions & 9 deletions api/pkg/client/mlflow/mlflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@ package mlflow

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"

"cloud.google.com/go/storage"

"github.com/gojek/mlp/api/pkg/artifact"
)

// Service is the set of MLflow operations provided by this package: looking
// up runs and deleting experiments or individual runs.
type Service interface {
// The two search methods are unexported, so only types declared in this
// package can satisfy Service.
searchRunsForExperiment(ExperimentID string) (SearchRunsResponse, error)
searchRunData(RunID string) (SearchRunResponse, error)
// deleteArtifact selects whether the stored artifacts of the affected runs
// are deleted as well.
DeleteExperiment(ExperimentID string, deleteArtifact bool) error
DeleteRun(RunID, artifactURL string, deleteArtifact bool) error
DeleteExperiment(ctx context.Context, ExperimentID string, deleteArtifact bool) error
DeleteRun(ctx context.Context, RunID, artifactURL string, deleteArtifact bool) error
}

type mlflowService struct {
Expand All @@ -22,12 +25,24 @@ type mlflowService struct {
Config Config
}

func NewMlflowService(httpClient *http.Client, config Config, artifactService artifact.Service) Service {
func NewMlflowService(httpClient *http.Client, config Config) (Service, error) {
var artifactService artifact.Service
if config.ArtifactServiceType == "nop" {
artifactService = artifact.NewNopArtifactClient()
} else if config.ArtifactServiceType == "gcs" {
api, err := storage.NewClient(context.Background())
if err != nil {
return &mlflowService{}, fmt.Errorf("failed initializing gcs for mlflow delete package")
}
artifactService = artifact.NewGcsArtifactClient(api)
} else {
return &mlflowService{}, fmt.Errorf("invalid artifact service type")
}
return &mlflowService{
API: httpClient,
Config: config,
ArtifactService: artifactService,
}
}, nil
}

func (mfs *mlflowService) httpCall(method string, url string, body []byte, response interface{}) error {
Expand Down Expand Up @@ -101,19 +116,19 @@ func (mfs *mlflowService) searchRunData(RunID string) (SearchRunResponse, error)
return runResponse, nil
}

func (mfs *mlflowService) DeleteExperiment(ExperimentID string, deleteArtifact bool) error {
func (mfs *mlflowService) DeleteExperiment(ctx context.Context, ExperimentID string, deleteArtifact bool) error {

relatedRunData, err := mfs.searchRunsForExperiment(ExperimentID)
if err != nil {
return err
}
// Error handling for empty/no run for the experiment
if len(relatedRunData.RunsData) == 0 {
return fmt.Errorf("There are no related run for experiment id %s", ExperimentID)
return fmt.Errorf("there are no related run for experiment id %s", ExperimentID)
}
// Error Handling, when a RunID failed to delete return error
for _, run := range relatedRunData.RunsData {
err = mfs.DeleteRun(run.Info.RunID, run.Info.ArtifactURI, deleteArtifact)
err = mfs.DeleteRun(ctx, run.Info.RunID, run.Info.ArtifactURI, deleteArtifact)
if err != nil {
return fmt.Errorf("deletion failed for run_id %s for experiment id %s: %s", run.Info.RunID, ExperimentID, err)
}
Expand All @@ -122,7 +137,7 @@ func (mfs *mlflowService) DeleteExperiment(ExperimentID string, deleteArtifact b
return nil
}

func (mfs *mlflowService) DeleteRun(RunID, artifactURL string, deleteArtifact bool) error {
func (mfs *mlflowService) DeleteRun(ctx context.Context, RunID, artifactURL string, deleteArtifact bool) error {
if artifactURL == "" {
runDetail, err := mfs.searchRunData(RunID)
if err != nil {
Expand All @@ -131,7 +146,7 @@ func (mfs *mlflowService) DeleteRun(RunID, artifactURL string, deleteArtifact bo
artifactURL = runDetail.RunData.Info.ArtifactURI
}
if deleteArtifact {
err := mfs.ArtifactService.DeleteArtifact(artifactURL)
err := mfs.ArtifactService.DeleteArtifact(ctx, artifactURL)
if err != nil {
return err
}
Expand Down
Loading