Skip to content

Commit

Permalink
[cherry-pick] Fix endpoint (#115)
Browse files (browse the repository at this point in the history)
* improve tgi startup time

* fix endpoint

---------

Co-authored-by: James <[email protected]>
  • Loading branch information
ganisback and James authored Sep 12, 2024
1 parent 3d876e1 commit 4463b45
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 29 deletions.
1 change: 1 addition & 0 deletions builder/deploy/scheduler/deploy_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ func (t *DeployRunner) makeDeployRequest() (*types.RunRequest, error) {
// runtime framework port for model
envMap["port"] = strconv.Itoa(deploy.ContainerPort)
envMap["HF_ENDPOINT"] = t.modelDownloadEndpoint // "https://hub-stg.opencsg.com/"
envMap["HF_HUB_OFFLINE"] = "1"
}

if deploy.Type == types.FinetuneType {
Expand Down
36 changes: 12 additions & 24 deletions component/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"strconv"

"opencsg.com/csghub-server/builder/deploy"
deployStatus "opencsg.com/csghub-server/builder/deploy/common"
"opencsg.com/csghub-server/builder/git/gitserver"
"opencsg.com/csghub-server/builder/inference"
"opencsg.com/csghub-server/builder/store/database"
Expand Down Expand Up @@ -449,23 +448,7 @@ func (c *ModelComponent) GetServerless(ctx context.Context, namespace, name, cur
if deploy == nil {
return nil, nil
}
var endpoint string
if len(deploy.SvcName) > 0 && deploy.Status == deployStatus.Running {
cls, err := c.cluster.ByClusterID(ctx, deploy.ClusterID)
zone := ""
provider := ""
if err != nil {
return nil, fmt.Errorf("get cluster with error: %w", err)
} else {
zone = cls.Zone
provider = cls.Provider
}
regionDomain := ""
if len(zone) > 0 && len(provider) > 0 {
regionDomain = fmt.Sprintf(".%s.%s", zone, provider)
}
endpoint = fmt.Sprintf("%s%s.%s", deploy.SvcName, regionDomain, c.publicRootDomain)
}
endpoint, _ := c.generateEndpoint(ctx, deploy)

resDeploy := types.DeployRepo{
DeployID: deploy.ID,
Expand Down Expand Up @@ -890,15 +873,20 @@ func (c *ModelComponent) ListModelsOfRuntimeFrameworks(ctx context.Context, curr
newError := fmt.Errorf("failed to get public model repos, error:%w", err)
return nil, 0, newError
}
// define EnableInference
enableInference := deployType == types.InferenceType
enableFinetune := deployType == types.FinetuneType

for _, repo := range repos {
resModels = append(resModels, types.Model{
Name: repo.Name,
Nickname: repo.Nickname,
Description: repo.Description,
Path: repo.Path,
RepositoryID: repo.ID,
Private: repo.Private,
Name: repo.Name,
Nickname: repo.Nickname,
Description: repo.Description,
Path: repo.Path,
RepositoryID: repo.ID,
Private: repo.Private,
EnableInference: enableInference,
EnableFinetune: enableFinetune,
})
}
return resModels, total, nil
Expand Down
8 changes: 4 additions & 4 deletions component/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,7 @@ func (c *RepoComponent) DeployDetail(ctx context.Context, detailReq types.Deploy
return nil, err
}

endpoint := c.generateEndpoint(ctx, deploy)
endpoint, _ := c.generateEndpoint(ctx, deploy)

req := types.DeployRepo{
DeployID: deploy.ID,
Expand Down Expand Up @@ -1907,13 +1907,13 @@ func (c *RepoComponent) DeployDetail(ctx context.Context, detailReq types.Deploy
}

// generate endpoint
func (c *RepoComponent) generateEndpoint(ctx context.Context, deploy *database.Deploy) string {
func (c *RepoComponent) generateEndpoint(ctx context.Context, deploy *database.Deploy) (string, string) {
var endpoint string
provider := ""
if len(deploy.SvcName) > 0 && deploy.Status == deployStatus.Running {
// todo: zone.provider.endpoint to support multi-zone, multi-provider
cls, err := c.cluster.ByClusterID(ctx, deploy.ClusterID)
zone := ""
provider := ""
if err != nil {
slog.Warn("Get cluster with error", slog.Any("error", err))
} else {
Expand All @@ -1934,7 +1934,7 @@ func (c *RepoComponent) generateEndpoint(ctx context.Context, deploy *database.D

}

return endpoint
return endpoint, provider
}

func deployStatusCodeToString(code int) string {
Expand Down
9 changes: 8 additions & 1 deletion component/space.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"log/slog"
"net/url"
"strconv"
"strings"

Expand Down Expand Up @@ -203,7 +204,13 @@ func (c *SpaceComponent) Show(ctx context.Context, namespace, name, currentUser
var endpoint string
svcName, status, _ := c.status(ctx, space)
if len(svcName) > 0 {
endpoint = fmt.Sprintf("%s.%s", svcName, c.publicRootDomain)
if c.publicRootDomain == "" {
endpoint, _ = url.JoinPath(c.serverBaseUrl, "endpoint", svcName)
endpoint = strings.Replace(endpoint, "http://", "", 1)
endpoint = strings.Replace(endpoint, "https://", "", 1)
} else {
endpoint = fmt.Sprintf("%s.%s", svcName, c.publicRootDomain)
}
}

likeExists, err := c.uls.IsExist(ctx, currentUser, space.Repository.ID)
Expand Down
7 changes: 7 additions & 0 deletions component/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,12 @@ func (c *UserComponent) ListDeploys(ctx context.Context, repoType types.Reposito

var resDeploys []types.DeployRepo
for _, deploy := range deploys {
d := &database.Deploy{
SvcName: deploy.SvcName,
ClusterID: deploy.ClusterID,
Status: deploy.Status,
}
endpoint, _ := c.repoComponent.generateEndpoint(ctx, d)
repoPath := strings.TrimPrefix(deploy.GitPath, string(repoType)+"s_")
var hardware types.HardWare
json.Unmarshal([]byte(deploy.Hardware), &hardware)
Expand Down Expand Up @@ -545,6 +551,7 @@ func (c *UserComponent) ListDeploys(ctx context.Context, repoType types.Reposito
Type: deploy.Type,
ResourceType: resourceType,
RepoTag: tag,
Endpoint: endpoint,
})
}
return resDeploys, total, nil
Expand Down
1 change: 1 addition & 0 deletions servicerunner/component/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func (s *ServiceComponent) GenerateService(request types.SVCRequest, srvName str
templateAnnotations["autoscaling.knative.dev/target-utilization-percentage"] = "90"
templateAnnotations["autoscaling.knative.dev/min-scale"] = strconv.Itoa(request.MinReplica)
templateAnnotations["autoscaling.knative.dev/max-scale"] = strconv.Itoa(request.MaxReplica)
templateAnnotations["serving.knative.dev/progress-deadline"] = fmt.Sprintf("%dm", s.env.Model.DeployTimeoutInMin)
}
initialDelaySeconds := 10
periodSeconds := 10
Expand Down

0 comments on commit 4463b45

Please sign in to comment.