-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.go
130 lines (120 loc) · 3.77 KB
/
model.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package marqo
import (
"fmt"
"net/http"
)
type (
	// GetModelsResponse is the decoded body returned by GetModels: the list
	// of models the server reports as loaded.
	GetModelsResponse struct {
		// Models holds one entry per loaded model (JSON key "models").
		Models []Model `json:"models"`
	}

	// Model describes a single loaded model as reported by the server.
	Model struct {
		// ModelName is the model identifier (JSON key "model_name").
		ModelName string `json:"model_name"`
		// ModelDevice is the device the server reports the model is
		// loaded on (JSON key "model_device").
		ModelDevice string `json:"model_device"`
	}
)
// GetModels returns the models currently loaded on the server.
//
// It sends a GET request to the client's BaseURL + "/models" and decodes the
// response body into a GetModelsResponse (via SetSuccessResult).
//
// Returns:
//
//	*GetModelsResponse: the decoded list of loaded models; nil on error.
//	error: the transport error, or a status-code error when the server
//	replies with anything other than 200 OK; nil on success.
//
// Failures are logged on the client's logger (tagged method=GetModels)
// before being returned.
//
// Example usage:
//
//	modelsResponse, err := client.GetModels()
//	if err != nil {
//		log.Fatalf("Failed to get models: %v", err)
//	}
//	fmt.Printf("Loaded models: %+v\n", modelsResponse.Models)
func (c *Client) GetModels() (*GetModelsResponse, error) {
	logger := c.logger.With("method", "GetModels")
	var result GetModelsResponse
	resp, err := c.reqClient.
		R().
		SetSuccessResult(&result).
		Get(c.reqClient.BaseURL + "/models")
	if err != nil {
		logger.Error("error getting models", "error", err)
		return nil, err
	}
	if resp.Response.StatusCode != http.StatusOK {
		logger.Error("error getting models", "status_code", resp.Response.StatusCode)
		return nil, fmt.Errorf("error getting models: status code: %v", resp.Response.StatusCode)
	}
	// Pass the payload as a key/value attribute (like the other log calls in
	// this file) instead of pre-formatting it into the message: the message
	// stays constant and handlers can index the "models" field.
	logger.Info("response models", "models", result.Models)
	return &result, nil
}
// EjectModelRequest identifies a model to evict from the server cache.
//
// EjectModel checks the `validate` tags with validate.Struct before any
// request is sent. The fields are then sent as the model_name and
// model_device query parameters.
type EjectModelRequest struct {
	// ModelName names the model to eject; it must be non-empty.
	ModelName string `validate:"required" json:"model_name"`
	// ModelDevice is the device the model is loaded on; its tag restricts it
	// to "cpu" or "cuda".
	ModelDevice string `validate:"oneof=cpu cuda" json:"model_device"`
}
// EjectModel asks the server to evict a model from its cache.
//
// The request is first checked with validate.Struct. If it passes, a DELETE
// is issued to the client's BaseURL + "/models", with model_name and
// model_device passed as query parameters.
//
// Parameters:
//
//	ejectModelReq (*EjectModelRequest): the name and device of the model to evict.
//
// Returns:
//
//	error: the validation error, the transport error, or a status-code error
//	when the server replies with anything other than 200 OK; nil on success.
//
// Every failure is logged on the client's logger (tagged method=EjectModel)
// before being returned; success is logged at Info with the model name/device.
//
// Example usage:
//
//	err := client.EjectModel(&EjectModelRequest{
//		ModelName:   "example_model",
//		ModelDevice: "cpu",
//	})
//	if err != nil {
//		log.Fatalf("Failed to eject model: %v", err)
//	}
func (c *Client) EjectModel(ejectModelReq *EjectModelRequest) error {
	methodLog := c.logger.With("method", "EjectModel")

	if vErr := validate.Struct(ejectModelReq); vErr != nil {
		methodLog.Error("error validating eject model request",
			"error", vErr)
		return vErr
	}

	query := map[string]string{
		"model_name":   ejectModelReq.ModelName,
		"model_device": ejectModelReq.ModelDevice,
	}
	endpoint := c.reqClient.BaseURL + "/models"

	resp, reqErr := c.reqClient.R().SetQueryParams(query).Delete(endpoint)
	if reqErr != nil {
		methodLog.Error("error ejecting model", "error", reqErr)
		return reqErr
	}

	if code := resp.Response.StatusCode; code != http.StatusOK {
		methodLog.Error("error ejecting model", "status_code", code)
		return fmt.Errorf("error ejecting model: status code: %v", code)
	}

	methodLog.Info("ejected model successfully",
		"model_name", ejectModelReq.ModelName,
		"model_device", ejectModelReq.ModelDevice)
	return nil
}