Skip to content

Commit

Permalink
feat: add TLS support to API server
Browse files Browse the repository at this point in the history
* TLS support

Add config options for TLS connections

* TLS support

Use TLS if config provides certificate and key file

* TLS support

Removed secure option as there is no need for it

* TLS support

Added TLS config options

* TLS support

Removed unnecessary secure config flag

* Update api_server.go

* chore: minor formatting

---------

Co-authored-by: devgianlu <[email protected]>
  • Loading branch information
tylkie and devgianlu authored Sep 10, 2024
1 parent b98dad6 commit ab6c852
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
19 changes: 15 additions & 4 deletions cmd/daemon/api_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ const timeout = 10 * time.Second

type ApiServer struct {
allowOrigin string
certFile string
keyFile string

close bool
listener net.Listener
Expand Down Expand Up @@ -244,8 +246,8 @@ type ApiEventDataShuffleContext struct {
Value bool `json:"value"`
}

func NewApiServer(address string, port int, allowOrigin string) (_ *ApiServer, err error) {
s := &ApiServer{allowOrigin: allowOrigin}
func NewApiServer(address string, port int, allowOrigin string, certFile string, keyFile string) (_ *ApiServer, err error) {
s := &ApiServer{allowOrigin: allowOrigin, certFile: certFile, keyFile: keyFile}
s.requests = make(chan ApiRequest)

s.listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", address, port))
Expand Down Expand Up @@ -502,7 +504,11 @@ func (s *ApiServer) serve() {
m.HandleFunc("/events", func(w http.ResponseWriter, r *http.Request) {
opts := &websocket.AcceptOptions{}
if len(s.allowOrigin) > 0 {
opts.OriginPatterns = []string{s.allowOrigin}
allow := s.allowOrigin
if strings.HasPrefix(allow, "https://") {
allow = s.allowOrigin[8:]
}
opts.OriginPatterns = []string{allow}
}

c, err := websocket.Accept(w, r, opts)
Expand Down Expand Up @@ -540,7 +546,12 @@ func (s *ApiServer) serve() {
}
})

err := http.Serve(s.listener, s.allowOriginMiddleware(m))
var err error
if len(s.certFile) > 0 && len(s.keyFile) > 0 {
err = http.ServeTLS(s.listener, s.allowOriginMiddleware(m), s.certFile, s.keyFile)
} else {
err = http.Serve(s.listener, s.allowOriginMiddleware(m))
}
if s.close {
return
} else if err != nil {
Expand Down
4 changes: 3 additions & 1 deletion cmd/daemon/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,8 @@ type Config struct {
Address string `yaml:"address"`
Port int `yaml:"port"`
AllowOrigin string `yaml:"allow_origin"`
CertFile string `yaml:"cert_file"`
KeyFile string `yaml:"key_file"`
} `yaml:"server"`
Credentials struct {
Type string `yaml:"type"`
Expand Down Expand Up @@ -432,7 +434,7 @@ func main() {

// create api server if needed
if cfg.Server.Enabled {
app.server, err = NewApiServer(cfg.Server.Address, cfg.Server.Port, cfg.Server.AllowOrigin)
app.server, err = NewApiServer(cfg.Server.Address, cfg.Server.Port, cfg.Server.AllowOrigin, cfg.Server.CertFile, cfg.Server.KeyFile)
if err != nil {
log.WithError(err).Fatal("failed creating api server")
}
Expand Down
10 changes: 10 additions & 0 deletions config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@
"type": "string",
"description": "The value for the Access-Control-Allow-Origin header",
"default": ""
},
"cert_file": {
"type": "string",
"description": "File path of the certificate file to use for TLS",
"default": ""
},
"key_file": {
"type": "string",
"description": "File path of the private key file to use for TLS",
"default": ""
}
}
},
Expand Down

0 comments on commit ab6c852

Please sign in to comment.