Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: close connection in test, setup timeout in metric server and use… #49

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions internal/db/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package db

import (
"context"
"crypto/rand"
"fmt"
"math/rand"
"math/big"

"github.com/babylonlabs-io/staking-api-service/internal/db/model"
"github.com/babylonlabs-io/staking-api-service/internal/types"
Expand Down Expand Up @@ -86,7 +87,11 @@ func (db *Database) IncrementOverallStats(
upsertUpdate["$inc"].(bson.M)["total_stakers"] = 1
}

upsertFilter := bson.M{"_id": db.generateOverallStatsId()}
shardNumber, err := db.generateOverallStatsId()
if err != nil {
return nil, err
}
upsertFilter := bson.M{"_id": shardNumber}

_, err = overallStatsClient.UpdateOne(sessCtx, upsertFilter, upsertUpdate, options.Update().SetUpsert(true))
if err != nil {
Expand Down Expand Up @@ -132,7 +137,11 @@ func (db *Database) SubtractOverallStats(
return nil, err
}

upsertFilter := bson.M{"_id": db.generateOverallStatsId()}
shardNumber, err := db.generateOverallStatsId()
if err != nil {
return nil, err
}
upsertFilter := bson.M{"_id": shardNumber}

_, err = overallStatsClient.UpdateOne(sessCtx, upsertFilter, upsertUpdate, options.Update().SetUpsert(true))
if err != nil {
Expand Down Expand Up @@ -188,8 +197,15 @@ func (db *Database) GetOverallStats(ctx context.Context) (*model.OverallStatsDoc
// Generate the id for the overall stats document. Id is a random number ranged from 0-LogicalShardCount-1
// It's a logical shard to avoid locking the same field during concurrent writes
// The sharding number should never be reduced after roll out
func (db *Database) generateOverallStatsId() string {
return fmt.Sprint(rand.Intn(int(db.cfg.LogicalShardCount)))
func (db *Database) generateOverallStatsId() (string, error) {
max := big.NewInt(int64(db.cfg.LogicalShardCount))
// Generate a secure random number within the range [0, max)
n, err := rand.Int(rand.Reader, max)
if err != nil {
return "", err
}

return fmt.Sprint(n), nil
}

func (db *Database) updateStatsLockByFieldName(ctx context.Context, stakingTxHashHex, state string, fieldName string) error {
Expand Down
24 changes: 18 additions & 6 deletions internal/observability/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ import (
type Outcome string

const (
Success Outcome = "success"
Error Outcome = "error"
Success Outcome = "success"
Error Outcome = "error"
MetricRequestTimeout time.Duration = 5 * time.Second
MetricRequestIdleTimeout time.Duration = 10 * time.Second
)

func (O Outcome) String() string {
Expand Down Expand Up @@ -49,11 +51,21 @@ func initMetricsRouter(metricsPort int) {
promhttp.Handler().ServeHTTP(w, r)
})

// Create a custom server with timeout settings
metricsAddr := fmt.Sprintf(":%d", metricsPort)
server := &http.Server{
Addr: metricsAddr,
Handler: metricsRouter,
ReadTimeout: MetricRequestTimeout,
WriteTimeout: MetricRequestTimeout,
IdleTimeout: MetricRequestIdleTimeout,
}

// Start the server in a separate goroutine
go func() {
metricsAddr := fmt.Sprintf(":%d", metricsPort)
err := http.ListenAndServe(metricsAddr, metricsRouter)
if err != nil {
log.Fatal().Err(err).Msgf("error starting metrics server on %s", metricsAddr)
log.Printf("Starting metrics server on %s", metricsAddr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatal().Err(err).Msgf("Error starting metrics server on %s", metricsAddr)
}
}()
}
Expand Down
8 changes: 6 additions & 2 deletions tests/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,12 @@ type TestServer struct {
func (ts *TestServer) Close() {
ts.Server.Close()
ts.Queues.StopReceivingMessages()
ts.Conn.Close()
ts.channel.Close()
if err := ts.Conn.Close(); err != nil {
log.Fatal("failed to close connection in test: ", err)
}
if err := ts.channel.Close(); err != nil {
log.Fatal("failed to close channel in test: ", err)
}
}

func loadTestConfig(t *testing.T) *config.Config {
Expand Down
Loading