Skip to content

Commit

Permalink
refactor: adding context and merge artifact service creation into mlf…
Browse files Browse the repository at this point in the history
…low (#80)

* refactor: adding context and merge artifact service creation into mlflow

* refactor: mlflow unit test

* refactor: change context argument sequence

* refactor: change unit test context sequence

* refactor: naming service and client

* refactor: remove gcs config for context

* refactor: change definition order, struct naming

---------

Co-authored-by: Alexander <[email protected]>
  • Loading branch information
zenovore and Alexander authored Apr 10, 2023
1 parent 543add5 commit 3117e9e
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 83 deletions.
45 changes: 25 additions & 20 deletions api/pkg/artifact/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,24 @@ import (
"google.golang.org/api/iterator"
)

type gcsClient struct {
API *storage.Client
Config Config
}
type Config struct {
Ctx context.Context
}

type Service interface {
DeleteArtifact(url string) error
DeleteArtifact(ctx context.Context, url string) error
}

func NewGcsClient(api *storage.Client, cfg Config) Service {
return &gcsClient{
API: api,
Config: cfg,
}
type GcsArtifactClient struct {
API *storage.Client
}

func (gc *gcsClient) DeleteArtifact(url string) error {
func (gac *GcsArtifactClient) DeleteArtifact(ctx context.Context, url string) error {
// Get bucket name and gcsPrefix
// the [5:] is to remove the "gs://" on the artifact uri
// ex : gs://bucketName/path → bucketName/path
gcsBucket, gcsLocation := gc.getGcsBucketAndLocation(url[5:])
gcsBucket, gcsLocation := gac.getGcsBucketAndLocation(url[5:])

// Sets the name for the bucket.
bucket := gc.API.Bucket(gcsBucket)
bucket := gac.API.Bucket(gcsBucket)

it := bucket.Objects(gc.Config.Ctx, &storage.Query{
it := bucket.Objects(ctx, &storage.Query{
Prefix: gcsLocation,
})
for {
Expand All @@ -47,16 +36,32 @@ func (gc *gcsClient) DeleteArtifact(url string) error {
if err != nil {
return err
}
if err := bucket.Object(attrs.Name).Delete(gc.Config.Ctx); err != nil {
if err := bucket.Object(attrs.Name).Delete(ctx); err != nil {
return err
}
}
return nil
}

func (gc *gcsClient) getGcsBucketAndLocation(str string) (string, string) {
func (gac *GcsArtifactClient) getGcsBucketAndLocation(str string) (string, string) {
// Split string using delimiter
// ex : bucketName/path/path1/item → (bucketName , path/path1/item)
splitStr := strings.SplitN(str, "/", 2)
return splitStr[0], splitStr[1]
}

func NewGcsArtifactClient(api *storage.Client) Service {
return &GcsArtifactClient{
API: api,
}
}

type NopArtifactClient struct{}

func (nac *NopArtifactClient) DeleteArtifact(ctx context.Context, url string) error {
return nil
}

func NewNopArtifactClient() Service {
return &NopArtifactClient{}
}
28 changes: 16 additions & 12 deletions api/pkg/artifact/mocks/artifact.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 24 additions & 9 deletions api/pkg/client/mlflow/mlflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@ package mlflow

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"

"cloud.google.com/go/storage"

"github.com/gojek/mlp/api/pkg/artifact"
)

type Service interface {
searchRunsForExperiment(ExperimentID string) (SearchRunsResponse, error)
searchRunData(RunID string) (SearchRunResponse, error)
DeleteExperiment(ExperimentID string, deleteArtifact bool) error
DeleteRun(RunID, artifactURL string, deleteArtifact bool) error
DeleteExperiment(ctx context.Context, ExperimentID string, deleteArtifact bool) error
DeleteRun(ctx context.Context, RunID, artifactURL string, deleteArtifact bool) error
}

type mlflowService struct {
Expand All @@ -22,12 +25,24 @@ type mlflowService struct {
Config Config
}

func NewMlflowService(httpClient *http.Client, config Config, artifactService artifact.Service) Service {
func NewMlflowService(httpClient *http.Client, config Config) (Service, error) {
var artifactService artifact.Service
if config.ArtifactServiceType == "nop" {
artifactService = artifact.NewNopArtifactClient()
} else if config.ArtifactServiceType == "gcs" {
api, err := storage.NewClient(context.Background())
if err != nil {
return &mlflowService{}, fmt.Errorf("failed initializing gcs for mlflow delete package")
}
artifactService = artifact.NewGcsArtifactClient(api)
} else {
return &mlflowService{}, fmt.Errorf("invalid artifact service type")
}
return &mlflowService{
API: httpClient,
Config: config,
ArtifactService: artifactService,
}
}, nil
}

func (mfs *mlflowService) httpCall(method string, url string, body []byte, response interface{}) error {
Expand Down Expand Up @@ -101,19 +116,19 @@ func (mfs *mlflowService) searchRunData(RunID string) (SearchRunResponse, error)
return runResponse, nil
}

func (mfs *mlflowService) DeleteExperiment(ExperimentID string, deleteArtifact bool) error {
func (mfs *mlflowService) DeleteExperiment(ctx context.Context, ExperimentID string, deleteArtifact bool) error {

relatedRunData, err := mfs.searchRunsForExperiment(ExperimentID)
if err != nil {
return err
}
// Error handling for empty/no run for the experiment
if len(relatedRunData.RunsData) == 0 {
return fmt.Errorf("There are no related run for experiment id %s", ExperimentID)
return fmt.Errorf("there are no related run for experiment id %s", ExperimentID)
}
// Error Handling, when a RunID failed to delete return error
for _, run := range relatedRunData.RunsData {
err = mfs.DeleteRun(run.Info.RunID, run.Info.ArtifactURI, deleteArtifact)
err = mfs.DeleteRun(ctx, run.Info.RunID, run.Info.ArtifactURI, deleteArtifact)
if err != nil {
return fmt.Errorf("deletion failed for run_id %s for experiment id %s: %s", run.Info.RunID, ExperimentID, err)
}
Expand All @@ -122,7 +137,7 @@ func (mfs *mlflowService) DeleteExperiment(ExperimentID string, deleteArtifact b
return nil
}

func (mfs *mlflowService) DeleteRun(RunID, artifactURL string, deleteArtifact bool) error {
func (mfs *mlflowService) DeleteRun(ctx context.Context, RunID, artifactURL string, deleteArtifact bool) error {
if artifactURL == "" {
runDetail, err := mfs.searchRunData(RunID)
if err != nil {
Expand All @@ -131,7 +146,7 @@ func (mfs *mlflowService) DeleteRun(RunID, artifactURL string, deleteArtifact bo
artifactURL = runDetail.RunData.Info.ArtifactURI
}
if deleteArtifact {
err := mfs.ArtifactService.DeleteArtifact(artifactURL)
err := mfs.ArtifactService.DeleteArtifact(ctx, artifactURL)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 3117e9e

Please sign in to comment.