Skip to content

Commit

Permalink
Enhancement: can now listen on a unix socket
Browse files Browse the repository at this point in the history
  • Loading branch information
ae-govau committed Feb 4, 2024
1 parent 18472d9 commit 08ef523
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 4 deletions.
10 changes: 10 additions & 0 deletions changelog/unreleased/pull-272
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Enhancement: can now listen on a unix socket

If `--listen unix:/tmp/foo` is passed, the server will listen on a unix socket. This is triggered by the prefix `unix:`.

This is useful in combination with remote port portforwarding to enable remote server to backup locally, e.g.

```bash
rest-server --listen unix:/tmp/foo &
ssh -R /tmp/foo:/tmp/foo user@host restic -r rest:http+unix:/tmp/foo:/repo backup
```
20 changes: 16 additions & 4 deletions cmd/rest-server/listener_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"log"
"net"
"strings"

"github.com/coreos/go-systemd/v22/activation"
)
Expand All @@ -23,12 +24,23 @@ func findListener(addr string) (listener net.Listener, err error) {
switch len(listeners) {
case 0:
// no listeners found, listen manually
listener, err = net.Listen("tcp", addr)
if err != nil {
return nil, fmt.Errorf("listen on %v failed: %w", addr, err)
if strings.HasPrefix(addr, "unix:") { // if we want to listen on a unix socket
unixAddr, err := net.ResolveUnixAddr("unix", strings.TrimPrefix(addr, "unix:"))
if err != nil {
return nil, fmt.Errorf("unable to understand unix address %s: %w", addr, err)
}
listener, err = net.ListenUnix("unix", unixAddr)
if err != nil {
return nil, fmt.Errorf("listen on %v failed: %w", addr, err)
}
} else { // assume tcp
listener, err = net.Listen("tcp", addr)
if err != nil {
return nil, fmt.Errorf("listen on %v failed: %w", addr, err)
}
}

log.Printf("start server on %v", addr)
log.Printf("start server on %v", listener.Addr())
return listener, nil

case 1:
Expand Down
75 changes: 75 additions & 0 deletions cmd/rest-server/listener_unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//go:build !windows
// +build !windows

package main

import (
"context"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"
)

func TestUnixSocket(t *testing.T) {
td := t.TempDir()

// this is the socket we'll listen on and connect to
tempSocket := filepath.Join(td, "sock")

// create some content and parent dirs
if err := os.MkdirAll(filepath.Join(td, "data", "repo1"), 0700); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(td, "data", "repo1", "config"), []byte("foo"), 0700); err != nil {
t.Fatal(err)
}

// run the following twice, to test that the server will
// cleanup its socket file when quitting, which won't happen
// if it doesn't exit gracefully
for i := 0; i < 2; i++ {
err := testServerWithArgs([]string{
"--no-auth",
"--path", filepath.Join(td, "data"),
"--listen", fmt.Sprintf("unix:%s", tempSocket),
}, time.Second, func(ctx context.Context, _ *restServerApp) error {
// custom client that will talk HTTP to unix socket
client := http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", tempSocket)
},
},
}
for _, test := range []struct {
Path string
StatusCode int
}{
{"/repo1/", http.StatusMethodNotAllowed},
{"/repo1/config", http.StatusOK},
{"/repo2/config", http.StatusNotFound},
} {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://ignored"+test.Path, nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
resp.Body.Close()
if resp.StatusCode != test.StatusCode {
return fmt.Errorf("expected %d from server, instead got %d (path %s)", test.StatusCode, resp.StatusCode, test.Path)
}
}
return nil
})
if err != nil {
t.Fatal(err)
}
}
}

0 comments on commit 08ef523

Please sign in to comment.