Skip to content

Commit

Permalink
feat(v0.4.4): v0.4.4
Browse files Browse the repository at this point in the history
  • Loading branch information
Ccheers committed Jan 21, 2025
1 parent 2b7df7f commit 646a508
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 11 deletions.
49 changes: 48 additions & 1 deletion adapter/kratos/transport/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ import (
"context"
"encoding/json"
http "net/http"
"reflect"
"strings"

"github.com/ccheers/xpkg/sync/errgroup"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)

type WebSocket[T any, R any] struct {
Expand All @@ -21,6 +24,7 @@ type Encoding struct {
errorEncodeFunc func(w http.ResponseWriter, err error)
requestDecodeFunc func(r *http.Request, req interface{}) error

wsReqDecodeFunc func(bs []byte, req interface{}) error
replyEncodeFunc func(ws *websocket.Conn, resp interface{})
replyErrorEncodeFunc func(ws *websocket.Conn, err error)
}
Expand Down Expand Up @@ -54,6 +58,9 @@ func defaultWSOptions() WSOptions {
requestDecodeFunc: func(r *http.Request, req interface{}) error {
return json.NewDecoder(r.Body).Decode(req)
},
wsReqDecodeFunc: func(bs []byte, req interface{}) error {
return unmarshalJSON(bs, req)
},
replyEncodeFunc: func(ws *websocket.Conn, resp interface{}) {
_ = ws.WriteJSON(map[string]interface{}{
"code": 200,
Expand Down Expand Up @@ -99,6 +106,12 @@ func WithRequestDecodeFunc(fn func(r *http.Request, req interface{}) error) WSOp
}
}

func WithWsReqDecodeFunc(fn func(bs []byte, req interface{}) error) WSOptionFunc {
return func(options *WSOptions) {
options.encoding.wsReqDecodeFunc = fn
}
}

func WithReplyEncodeFunc(fn func(ws *websocket.Conn, resp interface{})) WSOptionFunc {
return func(options *WSOptions) {
options.encoding.replyEncodeFunc = fn
Expand Down Expand Up @@ -193,7 +206,7 @@ func (x *WebSocket[T, R]) readLoop(ctx context.Context, ws *websocket.Conn) {
default:
}
var dst R
err := ws.ReadJSON(&dst)
_, bs, err := ws.ReadMessage()
if err != nil {
if strings.Contains(err.Error(), "connection reset by peer") {
return
Expand All @@ -204,6 +217,12 @@ func (x *WebSocket[T, R]) readLoop(ctx context.Context, ws *websocket.Conn) {
continue
}

err = x.options.encoding.wsReqDecodeFunc(bs, &dst)
if err != nil {
x.options.encoding.replyErrorEncodeFunc(ws, err)
continue
}

if validate, ok := (interface{})(&dst).(interface{ Validate() error }); ok {
err := validate.Validate()
if err != nil {
Expand Down Expand Up @@ -237,3 +256,31 @@ func (x *WebSocket[T, R]) writeLoop(ctx context.Context, ws *websocket.Conn) {
x.options.encoding.replyEncodeFunc(ws, resp)
}
}

var (
// unmarshalOptions is a configurable JSON format parser.
unmarshalOptions = protojson.UnmarshalOptions{
DiscardUnknown: true,
}
)

func unmarshalJSON(data []byte, v interface{}) error {
switch m := v.(type) {
case json.Unmarshaler:
return m.UnmarshalJSON(data)
case proto.Message:
return unmarshalOptions.Unmarshal(data, m)
default:
rv := reflect.ValueOf(v)
for rv := rv; rv.Kind() == reflect.Ptr; {
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
rv = rv.Elem()
}
if m, ok := reflect.Indirect(rv).Interface().(proto.Message); ok {
return unmarshalOptions.Unmarshal(data, m)
}
return json.Unmarshal(data, m)
}
}
4 changes: 2 additions & 2 deletions client/xvm/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ func TestVM(t *testing.T) {

datas := []string{
//"server_name",
"entity_count{server_name=\"xxxxx\"}",
//"online_number",
//"entity_count",
"online_number",
//"lock_entity_status",
//"lock_lb_status",
//"count(entity_count)",
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
go.etcd.io/etcd/client/v3 v3.5.7
go.opentelemetry.io/otel v1.28.0
go.opentelemetry.io/otel/metric v1.28.0
go.opentelemetry.io/otel/sdk v1.28.0
go.opentelemetry.io/otel/sdk/metric v1.28.0
go.opentelemetry.io/otel/trace v1.28.0
go.uber.org/zap v1.27.0
Expand Down Expand Up @@ -63,11 +64,11 @@ require (
github.com/satori/go.uuid v1.2.0 // indirect
go.etcd.io/etcd/api/v3 v3.5.7 // indirect
go.etcd.io/etcd/client/pkg/v3 v3.5.7 // indirect
go.opentelemetry.io/otel/sdk v1.28.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
golang.org/x/net v0.28.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.23.0 // indirect
golang.org/x/text v0.17.0 // indirect
google.golang.org/genproto v0.0.0-20231212172506-995d672761c0 // indirect
Expand Down
5 changes: 4 additions & 1 deletion mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func NewMysql(opts ...DBConfigOption) (*sql.DB, func(), error) {
opt(&cfg)
}
netAddr := fmt.Sprintf("tcp(%s:%d)", cfg.Host, cfg.Port)
dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s&charset=utf8mb4", cfg.User, cfg.Pass, netAddr, cfg.DBName)
dsn := fmt.Sprintf("%s:%s@%s/%s?loc=Local&charset=utf8mb4&parseTime=True", cfg.User, cfg.Pass, netAddr, cfg.DBName)

driverName, err := otelsql.Register(
"mysql",
Expand All @@ -89,6 +89,9 @@ func NewMysql(opts ...DBConfigOption) (*sql.DB, func(), error) {
semconv.DBSystemMySQL,
),
otelsql.WithMeterProvider(cfg.MeterProvider),
otelsql.WithSpanOptions(otelsql.SpanOptions{
DisableErrSkip: true,
}),
)
if err != nil {
return nil, nil, err
Expand Down
2 changes: 2 additions & 0 deletions redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func NewRedis(opts ...RedisConfigOption) (*redis.Client, error) {
redisClient := redis.NewClient(&redis.Options{
Network: "tcp",
Addr: addr,
Password: cfg.Pass,
DialTimeout: time.Duration(cfg.ReadTimeout) * time.Second,
ReadTimeout: time.Duration(cfg.ReadTimeout) * time.Second,
WriteTimeout: time.Duration(cfg.WriteTimeout) * time.Second,
Expand Down Expand Up @@ -104,6 +105,7 @@ func NewRedisV8(opts ...RedisConfigOption) (*redisv8.Client, error) {
redisClient := redisv8.NewClient(&redisv8.Options{
Network: "tcp",
Addr: addr,
Password: cfg.Pass,
DialTimeout: time.Duration(cfg.ReadTimeout) * time.Second,
ReadTimeout: time.Duration(cfg.ReadTimeout) * time.Second,
WriteTimeout: time.Duration(cfg.WriteTimeout) * time.Second,
Expand Down
49 changes: 43 additions & 6 deletions transport/chttp/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,39 @@ import (
"io"
nethttp "net/http"
"reflect"
"strconv"
"strings"
"time"

"github.com/go-kratos/kratos/v2/transport/http"
"github.com/opendevops-cn/codo-golang-sdk/cerr"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)

type options struct {
propagator propagation.TextMapPropagator
}

var optionsDefault = options{
propagator: propagation.NewCompositeTextMapPropagator(propagation.Baggage{}, propagation.TraceContext{}),
}

type Resp struct {
// 业务 code
Code cerr.ErrCode `json:"code"`
// 开发看
Msg string `json:"msg"`
// 用户看
Msg string `json:"msg"`
// 开发看
Reason string `json:"reason"`
// 服务器时间戳
Timestamp uint32 `json:"timestamp"`
// 服务器毫秒时间戳
Timestamp string `json:"timestamp"`
// 结构化数据
Result json.RawMessage `json:"result"`
// TraceID
TraceID string `json:"trace_id"`
}

var (
Expand All @@ -47,27 +60,46 @@ func ResponseEncoder(writer nethttp.ResponseWriter, request *nethttp.Request, i
if err != nil {
return err
}

ctx := optionsDefault.propagator.Extract(request.Context(), propagation.HeaderCarrier(request.Header))
sp := trace.SpanContextFromContext(ctx)
milliSecondsStr := strconv.Itoa(int(time.Now().UnixMilli()))

// 写入
writer.WriteHeader(nethttp.StatusOK)
return json.NewEncoder(writer).Encode(&Resp{
Code: cerr.SCode,
Msg: "success",
Timestamp: uint32(time.Now().Unix()),
Reason: "success",
Timestamp: milliSecondsStr,
Result: bs,
TraceID: sp.TraceID().String(),
})
}

func RequestBodyDecoder(r *nethttp.Request, i interface{}) error {
const megaBytes4 = 4 << 20
if r.ContentLength == 0 || r.ContentLength > megaBytes4 {
return nil
}

data, err := io.ReadAll(r.Body)
if err != nil {
return cerr.New(cerr.EParamUnparsedCode, err)
}

if len(data) == 0 {
return nil
}

// reset body.
r.Body = io.NopCloser(bytes.NewBuffer(data))

err = unmarshalJSON(data, i)
if err != nil {
return cerr.New(cerr.EParamUnparsedCode, err)
}

return nil
}

Expand Down Expand Up @@ -116,13 +148,18 @@ func ErrorEncoder(writer http.ResponseWriter, request *http.Request, err error)
statusCode := codeError.Code.AsHTTPCode()
msg := codeError.Code.String()

ctx := optionsDefault.propagator.Extract(request.Context(), propagation.HeaderCarrier(request.Header))
sp := trace.SpanContextFromContext(ctx)
milliSecondsStr := strconv.Itoa(int(time.Now().UnixMilli()))

// 写入
writer.WriteHeader(statusCode)
_ = json.NewEncoder(writer).Encode(&Resp{
Code: errCode,
Msg: msg,
Reason: err.Error(),
Timestamp: uint32(time.Now().Unix()),
Timestamp: milliSecondsStr,
TraceID: sp.TraceID().String(),
})
}

Expand Down

0 comments on commit 646a508

Please sign in to comment.