diff --git a/.gitignore b/.gitignore index cf5efeb4..18195b3b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ node_modules # Go workspaces go.work go.work.sum +.idea \ No newline at end of file diff --git a/sync2/client.go b/sync2/client.go index e072e5b5..7d1dcbd6 100644 --- a/sync2/client.go +++ b/sync2/client.go @@ -4,9 +4,11 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" + "net" "net/http" "net/url" + "strings" "time" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" @@ -39,16 +41,30 @@ type HTTPClient struct { } func NewHTTPClient(shortTimeout, longTimeout time.Duration, destHomeServer string) *HTTPClient { + baseUrl := destHomeServer + if strings.HasPrefix(destHomeServer, "/") { + baseUrl = "http://unix" + } + return &HTTPClient{ - LongTimeoutClient: &http.Client{ - Timeout: longTimeout, - Transport: otelhttp.NewTransport(http.DefaultTransport), - }, - Client: &http.Client{ - Timeout: shortTimeout, - Transport: otelhttp.NewTransport(http.DefaultTransport), - }, - DestinationServer: destHomeServer, + LongTimeoutClient: newClient(longTimeout, destHomeServer), + Client: newClient(shortTimeout, destHomeServer), + DestinationServer: baseUrl, + } +} + +func newClient(timeout time.Duration, destHomeServer string) *http.Client { + transport := http.DefaultTransport + if strings.HasPrefix(destHomeServer, "/") { + transport = &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", destHomeServer) + }, + } + } + return &http.Client{ + Timeout: timeout, + Transport: otelhttp.NewTransport(transport), } } @@ -66,7 +82,7 @@ func (v *HTTPClient) Versions(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("/versions returned HTTP %d", res.StatusCode) } defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { return nil, err } @@ -99,7 +115,7 @@ func (v *HTTPClient) WhoAmI(ctx context.Context, accessToken string) (string, st return "", "", fmt.Errorf("/whoami returned HTTP %d", res.StatusCode) } defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { return "", "", err } diff --git a/v3.go b/v3.go index 8040fc6b..5ea5cbda 100644 --- a/v3.go +++ b/v3.go @@ -4,7 +4,10 @@ import ( "context" "embed" "encoding/json" + "errors" "fmt" + "io/fs" + "net" "net/http" "os" "strings" @@ -216,12 +219,18 @@ func RunSyncV3Server(h http.Handler, bindAddr, destV2Server, tlsCert, tlsKey str // Block forever var err error - if tlsCert != "" && tlsKey != "" { - logger.Info().Msgf("listening TLS on %s", bindAddr) - err = http.ListenAndServeTLS(bindAddr, tlsCert, tlsKey, srv) + if strings.HasPrefix(bindAddr, "/") { + logger.Info().Msgf("listening on unix socket %s", bindAddr) + listener := unixSocketListener(bindAddr) + err = http.Serve(listener, srv) } else { - logger.Info().Msgf("listening on %s", bindAddr) - err = http.ListenAndServe(bindAddr, srv) + if tlsCert != "" && tlsKey != "" { + logger.Info().Msgf("listening TLS on %s", bindAddr) + err = http.ListenAndServeTLS(bindAddr, tlsCert, tlsKey, srv) + } else { + logger.Info().Msgf("listening on %s", bindAddr) + err = http.ListenAndServe(bindAddr, srv) + } } if err != nil { sentry.CaptureException(err) @@ -230,6 +239,22 @@ func RunSyncV3Server(h http.Handler, bindAddr, destV2Server, tlsCert, tlsKey str } } +func unixSocketListener(bindAddr string) net.Listener { + err := os.Remove(bindAddr) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + logger.Fatal().Err(err).Msg("failed to remove existing unix socket") + } + listener, err := net.Listen("unix", bindAddr) + if err != nil { + logger.Fatal().Err(err).Msg("failed to serve unix socket") + } + err = os.Chmod(bindAddr, 0755) + if err != nil { + logger.Fatal().Err(err).Msg("failed to set unix socket permissions") + } + return listener +} + type HandlerError struct { StatusCode int Err error