From 9a762219cce76b84a0984ee8e094311ad3c8d84f Mon Sep 17 00:00:00 2001 From: Rohith Date: Thu, 6 Jul 2017 16:44:39 +0100 Subject: [PATCH 1/7] Upstream URL Options - making the upstream url optional, assuming no url we act as a open proxy --- CHANGELOG.md | 1 + Dockerfile | 1 + Makefile | 32 ++- README.md | 20 +- main.go => cmd/keycloak-proxy/main.go | 0 cli.go => cmd/keycloak-proxy/proxy.go | 61 ++--- .../keycloak-proxy/proxy_test.go | 4 +- config_sample.yml | 45 +++- config.go => pkg/api/config.go | 98 ++++---- config_test.go => pkg/api/config_test.go | 10 +- doc.go => pkg/api/doc.go | 215 ++++++------------ resource.go => pkg/api/resource.go | 76 ++++--- resource_test.go => pkg/api/resource_test.go | 56 +++-- pkg/certs/doc.go | 27 +++ rotation.go => pkg/certs/rotate/rotation.go | 23 +- .../certs/rotate/rotation_test.go | 19 +- pkg/constants/const.go | 124 ++++++++++ pkg/errors/errors.go | 41 ++++ cookies.go => pkg/server/cookies.go | 2 +- cookies_test.go => pkg/server/cookies_test.go | 2 +- pkg/server/doc.go | 76 +++++++ forwarding.go => pkg/server/forwarding.go | 39 +++- handlers.go => pkg/server/handlers.go | 31 +-- .../server/handlers_test.go | 45 ++-- middleware.go => pkg/server/middleware.go | 36 +-- .../server/middleware_test.go | 104 +++++---- misc.go => pkg/server/misc.go | 8 +- misc_test.go => pkg/server/misc_test.go | 2 +- oauth.go => pkg/server/oauth.go | 14 +- oauth_test.go => pkg/server/oauth_test.go | 2 +- server.go => pkg/server/server.go | 86 ++++--- server_test.go => pkg/server/server_test.go | 50 ++-- session.go => pkg/server/session.go | 22 +- session_test.go => pkg/server/session_test.go | 12 +- stores.go => pkg/server/store.go | 35 +-- user_context.go => pkg/server/user_context.go | 19 +- .../server/user_context_test.go | 2 +- store_boltdb.go => pkg/store/boltdb.go | 4 +- .../store/boltdb_test.go | 2 +- pkg/store/doc.go | 30 +++ store_redis.go => pkg/store/redis.go | 4 +- pkg/store/store.go | 43 ++++ stores_test.go => pkg/store/store_test.go | 8 +- utils.go => pkg/utils/utils.go | 164 ++++++------- utils_test.go => pkg/utils/utils_test.go | 155 +++++-------- 45 files changed, 1067 insertions(+), 783 deletions(-) rename main.go => cmd/keycloak-proxy/main.go (100%) rename cli.go => cmd/keycloak-proxy/proxy.go (77%) rename cli_test.go => cmd/keycloak-proxy/proxy_test.go (92%) rename config.go => pkg/api/config.go (61%) rename config_test.go => pkg/api/config_test.go (94%) rename doc.go => pkg/api/doc.go (77%) rename resource.go => pkg/api/resource.go (61%) rename resource_test.go => pkg/api/resource_test.go (62%) create mode 100644 pkg/certs/doc.go rename rotation.go => pkg/certs/rotate/rotation.go (88%) rename rotation_test.go => pkg/certs/rotate/rotation_test.go (76%) create mode 100644 pkg/constants/const.go create mode 100644 pkg/errors/errors.go rename cookies.go => pkg/server/cookies.go (99%) rename cookies_test.go => pkg/server/cookies_test.go (99%) create mode 100644 pkg/server/doc.go rename forwarding.go => pkg/server/forwarding.go (85%) rename handlers.go => pkg/server/handlers.go (93%) rename handlers_test.go => pkg/server/handlers_test.go (83%) rename middleware.go => pkg/server/middleware.go (92%) rename middleware_test.go => pkg/server/middleware_test.go (93%) rename misc.go => pkg/server/misc.go (94%) rename misc_test.go => pkg/server/misc_test.go (98%) rename oauth.go => pkg/server/oauth.go (91%) rename oauth_test.go => pkg/server/oauth_test.go (99%) rename server.go => pkg/server/server.go (88%) rename server_test.go => pkg/server/server_test.go (93%) rename session.go => pkg/server/session.go (82%) rename session_test.go => pkg/server/session_test.go (90%) rename stores.go => pkg/server/store.go (70%) rename user_context.go => pkg/server/user_context.go (80%) rename user_context_test.go => pkg/server/user_context_test.go (99%) rename store_boltdb.go => pkg/store/boltdb.go (96%) rename store_boltdb_test.go => pkg/store/boltdb_test.go (99%) create mode 100644 pkg/store/doc.go rename store_redis.go => pkg/store/redis.go (95%) create mode 100644 pkg/store/store.go rename stores_test.go => pkg/store/store_test.go (85%) rename utils.go => pkg/utils/utils.go (56%) rename utils_test.go => pkg/utils/utils_test.go (71%) diff --git a/CHANGELOG.md b/CHANGELOG.md index f2065b442..4b72ad55b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ FEATURES * updated the base image to apline 3.6 in commit [0fdebaf821](https://github.com/gambol99/keycloak-proxy/pull/236/commits/0fdebaf8215e9480896f01ec7ab2ef7caa242da1) * moved to use zap for the logging [#PR237](https://github.com/gambol99/keycloak-proxy/pull/237) * making the X-Auth-Token optional in the upstream headers via the --enable-token-header [#PR247](https://github.com/gambol99/keycloak-proxy/pull/247) +* the upstream url is optional, meaning when not configured via --upstream-url is will proxy all requests to the Host header [#PR248](https://github.com/gambol99/keycloak-proxy/pull/248) * adding the ability to load a CA authority to provide trust on upstream endpoint [#PR248](https://github.com/gambol99/keycloak-proxy/pull/248) BREAKING CHANGES: diff --git a/Dockerfile b/Dockerfile index b2472f690..67990e013 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,6 @@ FROM alpine:3.6 MAINTAINER Rohith Jayawardene + LABEL Name=keycloak-proxy \ Release=https://github.com/gambol99/keycloak-proxy \ Url=https://github.com/gambol99/keycloak-proxy \ diff --git a/Makefile b/Makefile index 3e3a68ad1..6ec909b7b 100644 --- a/Makefile +++ b/Makefile @@ -9,9 +9,8 @@ GIT_SHA=$(shell git --no-pager describe --always --dirty) BUILD_TIME=$(shell date '+%s') VERSION ?= $(shell awk '/release.*=/ { print $$3 }' doc.go | sed 's/"//g') DEPS=$(shell go list -f '{{range .TestImports}}{{.}} {{end}}' ./...) -PACKAGES=$(shell go list ./...) -LFLAGS ?= -X main.gitsha=${GIT_SHA} -X main.compiled=${BUILD_TIME} -VETARGS ?= -asmdecl -atomic -bool -buildtags -copylocks -methods -nilfunc -printf -rangeloops -shift -structtags -unsafeptr +PACKAGES=$(shell go list ./... | grep -v vendor) +LFLAGS ?= -X constants.Gitsha=${GIT_SHA} -X constants.Compiled=${BUILD_TIME} .PHONY: test authors changelog build docker static release lint cover vet glide-install @@ -24,7 +23,7 @@ golang: build: golang @echo "--> Compiling the project" @mkdir -p bin - go build -ldflags "${LFLAGS}" -o bin/${NAME} + go build -ldflags "${LFLAGS}" -o bin/${NAME} cmd/keycloak-proxy/*.go static: golang deps @echo "--> Compiling the static binary" @@ -39,9 +38,9 @@ docker-build: -e GOOS=linux golang:${GOVERSION} \ make static -docker-test: +docker-test: static docker @echo "--> Running the docker test" - docker run --rm -ti -p 3000:3000 \ + docker run --rm -ti --net=host \ -v ${ROOT_DIR}/config.yml:/etc/keycloak/config.yml:ro \ -v ${ROOT_DIR}/tests:/opt/tests:ro \ ${REGISTRY}/${AUTHOR}/${NAME}:${VERSION} --config /etc/keycloak/config.yml @@ -94,7 +93,7 @@ vet: @go tool vet 2>/dev/null ; if [ $$? -eq 3 ]; then \ go get golang.org/x/tools/cmd/vet; \ fi - @go tool vet $(VETARGS) *.go + @go vet $(PACKAGES) lint: @echo "--> Running golint" @@ -105,12 +104,11 @@ lint: gofmt: @echo "--> Running gofmt check" - @gofmt -s -l *.go \ - | grep -q \.go ; if [ $$? -eq 0 ]; then \ - echo "You need to runn the make format, we have file unformatted"; \ - gofmt -s -l *.go; \ - exit 1; \ - fi + @gofmt -s -l *.go | grep -q \.go ; if [ $$? -eq 0 ]; then \ + echo "You need to runn the make format, we have file unformatted"; \ + gofmt -s -l *.go; \ + exit 1; \ + fi verify: @echo "--> Verifying the code" @@ -127,18 +125,18 @@ bench: coverage: @echo "--> Running go coverage" @go test -coverprofile cover.out - @go tool cover -html=cover.out -o cover.html + @go tool cover $(PACKAGES) -html=cover.out -o cover.html cover: @echo "--> Running go cover" - @go test --cover + @go test --cover $(PACKAGES) test: @echo "--> Running the tests" @if [ ! -d "vendor" ]; then \ make glide-install; \ fi - @go test -v + @go test -v $(PACKAGES) @$(MAKE) golang @$(MAKE) gofmt @$(MAKE) vet @@ -146,7 +144,7 @@ test: all: test echo "--> Performing all tests" - @${MAKE} verify + @$(MAKE) verify @$(MAKE) bench @$(MAKE) coverage diff --git a/README.md b/README.md index 63ecb0849..9721e13e3 100644 --- a/README.md +++ b/README.md @@ -212,16 +212,16 @@ Note, anything defined in the configuration file can also be configured as comma ```shell bin/keycloak-proxy \ - --discovery-url=https://keycloak.example.com/auth/realms/ \ - --client-id= \ - --client-secret= \ - --listen=127.0.0.1:3000 \ # unix sockets format unix://path - --redirection-url=http://127.0.0.1:3000 \ - --enable-refresh-token=true \ - --encryption-key=AgXa7xRcoClDEU0ZDSH4X0XhL5Qy2Z2j \ - --upstream-url=http://127.0.0.1:80 \ - --resources="uri=/admin*|methods=GET|roles=test1,test2" \ - --resources="uri=/backend*|roles=test1" + --discovery-url=https://keycloak.example.com/auth/realms/ \ + --client-id= \ + --client-secret= \ + --listen=127.0.0.1:3000 \ # unix sockets format unix://path + --redirection-url=http://127.0.0.1:3000 \ + --enable-refresh-token=true \ + --encryption-key=AgXa7xRcoClDEU0ZDSH4X0XhL5Qy2Z2j \ + --upstream-url=http://127.0.0.1:80 \ + --resources="uri=/admin*|methods=GET|roles=test1,test2" \ + --resources="uri=/backend*|roles=test1" ``` #### **HTTP Routing** diff --git a/main.go b/cmd/keycloak-proxy/main.go similarity index 100% rename from main.go rename to cmd/keycloak-proxy/main.go diff --git a/cli.go b/cmd/keycloak-proxy/proxy.go similarity index 77% rename from cli.go rename to cmd/keycloak-proxy/proxy.go index ceb951b67..2262a384d 100644 --- a/cli.go +++ b/cmd/keycloak-proxy/proxy.go @@ -23,18 +23,27 @@ import ( "syscall" "time" + "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/server" + "github.com/gambol99/keycloak-proxy/pkg/utils" + "github.com/urfave/cli" ) +const ( + envPrefix = "PROXY_" +) + // newOauthProxyApp creates a new cli application and runs it func newOauthProxyApp() *cli.App { - config := newDefaultConfig() + config := api.NewDefaultConfig() app := cli.NewApp() - app.Name = prog - app.Usage = description - app.Version = getVersion() - app.Author = author - app.Email = email + app.Name = constants.Prog + app.Usage = constants.Description + app.Version = constants.GetVersion() + app.Author = constants.Author + app.Email = constants.Email app.Flags = getCommandLineOptions() app.UsageText = "keycloak-proxy [options]" @@ -49,30 +58,30 @@ func newOauthProxyApp() *cli.App { configFile := cx.String("config") // step: do we have a configuration file? if configFile != "" { - if err := readConfigFile(configFile, config); err != nil { - return printError("unable to read the configuration file: %s, error: %s", configFile, err.Error()) + if err := utils.ReadConfigFile(configFile, config); err != nil { + return utils.PrintError("unable to read the configuration file: %s, error: %s", configFile, err.Error()) } } // step: parse the command line options if err := parseCLIOptions(cx, config); err != nil { - return printError(err.Error()) + return utils.PrintError(err.Error()) } // step: validate the configuration - if err := config.isValid(); err != nil { - return printError(err.Error()) + if err := config.IsValid(); err != nil { + return utils.PrintError(err.Error()) } // step: create the proxy - proxy, err := newProxy(config) + proxy, err := server.New(config) if err != nil { - return printError(err.Error()) + return utils.PrintError(err.Error()) } // step: start the service if err := proxy.Run(); err != nil { - return printError(err.Error()) + return utils.PrintError(err.Error()) } // step: setup the termination signals @@ -89,11 +98,11 @@ func newOauthProxyApp() *cli.App { // getCommandLineOptions builds the command line options by reflecting the Config struct and extracting // the tagged information func getCommandLineOptions() []cli.Flag { - defaults := newDefaultConfig() + defaults := api.NewDefaultConfig() var flags []cli.Flag - count := reflect.TypeOf(Config{}).NumField() + count := reflect.TypeOf(api.Config{}).NumField() for i := 0; i < count; i++ { - field := reflect.TypeOf(Config{}).Field(i) + field := reflect.TypeOf(api.Config{}).Field(i) usage, found := field.Tag.Lookup("usage") if !found { continue @@ -150,7 +159,7 @@ func getCommandLineOptions() []cli.Flag { } // parseCLIOptions parses the command line options and constructs a config object -func parseCLIOptions(cx *cli.Context, config *Config) (err error) { +func parseCLIOptions(cx *cli.Context, config *api.Config) (err error) { // step: we can ignore these options in the Config struct ignoredOptions := []string{"tag-data", "match-claims", "resources", "headers"} // step: iterate the Config and grab command line options via reflection @@ -158,7 +167,7 @@ func parseCLIOptions(cx *cli.Context, config *Config) (err error) { for i := 0; i < count; i++ { field := reflect.TypeOf(config).Elem().Field(i) name := field.Tag.Get("yaml") - if containedIn(name, ignoredOptions) { + if utils.ContainedIn(name, ignoredOptions) { continue } @@ -181,29 +190,29 @@ func parseCLIOptions(cx *cli.Context, config *Config) (err error) { } } if cx.IsSet("tag") { - tags, err := decodeKeyPairs(cx.StringSlice("tag")) + tags, err := utils.DecodeKeyPairs(cx.StringSlice("tag")) if err != nil { return err } - mergeMaps(config.Tags, tags) + utils.MergeMaps(config.Tags, tags) } if cx.IsSet("match-claims") { - claims, err := decodeKeyPairs(cx.StringSlice("match-claims")) + claims, err := utils.DecodeKeyPairs(cx.StringSlice("match-claims")) if err != nil { return err } - mergeMaps(config.MatchClaims, claims) + utils.MergeMaps(config.MatchClaims, claims) } if cx.IsSet("headers") { - headers, err := decodeKeyPairs(cx.StringSlice("headers")) + headers, err := utils.DecodeKeyPairs(cx.StringSlice("headers")) if err != nil { return err } - mergeMaps(config.Headers, headers) + utils.MergeMaps(config.Headers, headers) } if cx.IsSet("resources") { for _, x := range cx.StringSlice("resources") { - resource, err := newResource().parse(x) + resource, err := api.NewResource().Parse(x) if err != nil { return fmt.Errorf("invalid resource %s, %s", x, err) } diff --git a/cli_test.go b/cmd/keycloak-proxy/proxy_test.go similarity index 92% rename from cli_test.go rename to cmd/keycloak-proxy/proxy_test.go index 6a8b0037a..044385d92 100644 --- a/cli_test.go +++ b/cmd/keycloak-proxy/proxy_test.go @@ -18,6 +18,8 @@ package main import ( "testing" + "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/stretchr/testify/assert" "github.com/urfave/cli" ) @@ -37,7 +39,7 @@ func TestReadOptions(t *testing.T) { c := cli.NewApp() c.Flags = getCommandLineOptions() c.Action = func(cx *cli.Context) error { - parseCLIOptions(cx, &Config{}) + parseCLIOptions(cx, &api.Config{}) return nil } c.Run([]string{""}) diff --git a/config_sample.yml b/config_sample.yml index 67ccfe6df..2662dfc51 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -1,4 +1,6 @@ - +#### +# Sample Configuration +#### # is the url for retrieve the openid configuration - normally the /auth/realm/ discovery-url: https://keycloak.example.com/auth/realms/commons # the client id for the 'client' application @@ -47,14 +49,33 @@ headers: myheader_name: my_header_value # a map of claims that MUST exist in the token presented and the value is it MUST match # So for example, you could match the audience or the issuer or some custom attribute -match-claims: - aud: openvpn - iss: https://keycloak.example.com/auth/realms/commons -# a list of claims to inject into the authentication headers i.e. given_name -> X-Auth-Given-Name -add-claims: -- given_name -- family_name -- name + +virtual: +- hostname: default + resources: + - uri: /admin/test + # the methods on this url that should be protected, if missing, we assuming all + methods: [ GET ] + # a list of roles the user must have in order to accces urls under the above + roles: [ openvpn:vpn-test ] + - uri: /admin/white_listed + # permits a url prefix through, bypassing the admission controls + white-listed: true + - uri: /admin/* + methods: + - GET + roles: + - openvpn:vpn-user + - openvpn:prod-vpn + match-claims: + aud: openvpn + iss: https://keycloak.example.com/auth/realms/commons + # a list of claims to inject into the authentication headers i.e. given_name -> X-Auth-Given-Name + add-claims: + - given_name + - family_name + - name + # a collection of resource i.e. urls that you wish to protect resources: - uri: /admin/test @@ -63,7 +84,7 @@ resources: - GET # a list of roles the user must have in order to accces urls under the above roles: - - openvpn:vpn-test + - openvpn:vpn-test - uri: /admin/white_listed # permits a url prefix through, bypassing the admission controls white-listed: true @@ -71,8 +92,8 @@ resources: methods: - GET roles: - - openvpn:vpn-user - - openvpn:prod-vpn + - openvpn:vpn-user + - openvpn:prod-vpn # an array of origins (Access-Control-Allow-Origin) cors-origins: [] diff --git a/config.go b/pkg/api/config.go similarity index 61% rename from config.go rename to pkg/api/config.go index c98c0eddf..bf02f52e2 100644 --- a/config.go +++ b/pkg/api/config.go @@ -24,8 +24,8 @@ import ( "time" ) -// newDefaultConfig returns a initialized config -func newDefaultConfig() *Config { +// NewDefaultConfig returns a initialized config +func NewDefaultConfig() *Config { return &Config{ AccessTokenDuration: time.Duration(720) * time.Hour, Tags: make(map[string]string), @@ -45,57 +45,45 @@ func newDefaultConfig() *Config { } } -// isValid validates if the config is valid -func (r *Config) isValid() error { - if r.Listen == "" { +// IsValid validates if the config is valid +func (c *Config) IsValid() error { + if c.Listen == "" { return errors.New("you have not specified the listening interface") } - if r.TLSCertificate != "" && r.TLSPrivateKey == "" { + if c.TLSCertificate != "" && c.TLSPrivateKey == "" { return errors.New("you have not provided a private key") } - if r.TLSPrivateKey != "" && r.TLSCertificate == "" { + if c.TLSPrivateKey != "" && c.TLSCertificate == "" { return errors.New("you have not provided a certificate file") } - if r.TLSCertificate != "" && !fileExists(r.TLSCertificate) { - return fmt.Errorf("the tls certificate %s does not exist", r.TLSCertificate) - } - if r.TLSPrivateKey != "" && !fileExists(r.TLSPrivateKey) { - return fmt.Errorf("the tls private key %s does not exist", r.TLSPrivateKey) - } - if r.TLSCaCertificate != "" && !fileExists(r.TLSCaCertificate) { - return fmt.Errorf("the tls ca certificate file %s does not exist", r.TLSCaCertificate) - } - if r.TLSClientCertificate != "" && !fileExists(r.TLSClientCertificate) { - return fmt.Errorf("the tls client certificate %s does not exist", r.TLSClientCertificate) - } - if r.UseLetsEncrypt && r.LetsEncryptCacheDir == "" { + if c.UseLetsEncrypt && c.LetsEncryptCacheDir == "" { return fmt.Errorf("the letsencrypt cache dir has not been set") } if r.EnableForwarding { - if r.ClientID == "" { + if c.ClientID == "" { return errors.New("you have not specified the client id") } - if r.DiscoveryURL == "" { + if c.DiscoveryURL == "" { return errors.New("you have not specified the discovery url") } - if r.ForwardingUsername == "" { + if c.ForwardingUsername == "" { return errors.New("no forwarding username") } - if r.ForwardingPassword == "" { + if c.ForwardingPassword == "" { return errors.New("no forwarding password") } - if r.TLSCertificate != "" { + if c.TLSCertificate != "" { return errors.New("you don't need to specify a tls-certificate, use tls-ca-certificate instead") } - if r.TLSPrivateKey != "" { + if c.TLSPrivateKey != "" { return errors.New("you don't need to specify the tls-private-key, use tls-ca-key instead") } } else { - if r.Upstream == "" { + if c.Upstream == "" { return errors.New("you have not specified an upstream endpoint to proxy to") } - if _, err := url.Parse(r.Upstream); err != nil { + if _, err := url.Parse(c.Upstream); err != nil { return fmt.Errorf("the upstream endpoint is invalid, %s", err) } if r.SkipUpstreamTLSVerify && r.UpstreamCA != "" { @@ -103,59 +91,59 @@ func (r *Config) isValid() error { } // step: if the skip verification is off, we need the below - if !r.SkipTokenVerification { - if r.ClientID == "" { + if !c.SkipTokenVerification { + if c.ClientID == "" { return errors.New("you have not specified the client id") } - if r.DiscoveryURL == "" { + if c.DiscoveryURL == "" { return errors.New("you have not specified the discovery url") } - if strings.HasSuffix(r.RedirectionURL, "/") { - r.RedirectionURL = strings.TrimSuffix(r.RedirectionURL, "/") + if strings.HasSuffix(c.RedirectionURL, "/") { + c.RedirectionURL = strings.TrimSuffix(c.RedirectionURL, "/") } - if !r.EnableSecurityFilter { - if r.EnableHTTPSRedirect { + if !c.EnableSecurityFilter { + if c.EnableHTTPSRedirect { return errors.New("the security filter must be switch on for this feature: http-redirect") } - if r.EnableBrowserXSSFilter { + if c.EnableBrowserXSSFilter { return errors.New("the security filter must be switch on for this feature: brower-xss-filter") } - if r.EnableFrameDeny { + if c.EnableFrameDeny { return errors.New("the security filter must be switch on for this feature: frame-deny-filter") } - if r.ContentSecurityPolicy != "" { + if c.ContentSecurityPolicy != "" { return errors.New("the security filter must be switch on for this feature: content-security-policy") } - if len(r.Hostnames) > 0 { + if len(c.Hostnames) > 0 { return errors.New("the security filter must be switch on for this feature: hostnames") } } - if r.EnableEncryptedToken && r.EncryptionKey == "" { + if c.EnableEncryptedToken && c.EncryptionKey == "" { return errors.New("you have not specified an encryption key for encoding the access token") } - if r.EnableRefreshTokens && r.EncryptionKey == "" { + if c.EnableRefreshTokens && c.EncryptionKey == "" { return errors.New("you have not specified an encryption key for encoding the session state") } - if r.EnableRefreshTokens && (len(r.EncryptionKey) != 16 && len(r.EncryptionKey) != 32) { - return fmt.Errorf("the encryption key (%d) must be either 16 or 32 characters for AES-128/AES-256 selection", len(r.EncryptionKey)) + if c.EnableRefreshTokens && (len(c.EncryptionKey) != 16 && len(c.EncryptionKey) != 32) { + return fmt.Errorf("the encryption key (%d) must be either 16 or 32 characters for AES-128/AES-256 selection", len(c.EncryptionKey)) } - if !r.NoRedirects && r.SecureCookie && r.RedirectionURL != "" && !strings.HasPrefix(r.RedirectionURL, "https") { + if !c.NoRedirects && c.SecureCookie && c.RedirectionURL != "" && !strings.HasPrefix(c.RedirectionURL, "https") { return errors.New("the cookie is set to secure but your redirection url is non-tls") } - if r.StoreURL != "" { - if _, err := url.Parse(r.StoreURL); err != nil { + if c.StoreURL != "" { + if _, err := url.Parse(c.StoreURL); err != nil { return fmt.Errorf("the store url is invalid, error: %s", err) } } } // check: ensure each of the resource are valid - for _, resource := range r.Resources { - if err := resource.valid(); err != nil { + for _, resource := range c.Resources { + if err := resource.IsValid(); err != nil { return err } } // step: validate the claims are validate regex's - for k, claim := range r.MatchClaims { + for k, claim := range c.MatchClaims { if _, err := regexp.Compile(claim); err != nil { return fmt.Errorf("the claim matcher: %s for claim: %s is not a valid regex", claim, k) } @@ -165,12 +153,12 @@ func (r *Config) isValid() error { return nil } -// hasCustomSignInPage checks if there is a custom sign in page -func (r *Config) hasCustomSignInPage() bool { - return r.SignInPage != "" +// HasCustomSignInPage check if a custom page is require +func (c *Config) HasCustomSignInPage() bool { + return c.SignInPage != "" } -// hasForbiddenPage checks if there is a custom forbidden page -func (r *Config) hasCustomForbiddenPage() bool { - return r.ForbiddenPage != "" +// HasCustomForbiddenPage checks if we have a custom forbidden page +func (c *Config) HasCustomForbiddenPage() bool { + return c.ForbiddenPage != "" } diff --git a/config_test.go b/pkg/api/config_test.go similarity index 94% rename from config_test.go rename to pkg/api/config_test.go index 19704c5e4..c0c6050d0 100644 --- a/config_test.go +++ b/pkg/api/config_test.go @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package api import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestNewDefaultConfig(t *testing.T) { - if config := newDefaultConfig(); config == nil { - t.Error("we should have received a config") - } + assert.NotNil(t, NewDefaultConfig(), "we should have received a config") } func TestIsConfig(t *testing.T) { @@ -140,7 +140,7 @@ func TestIsConfig(t *testing.T) { } for i, c := range tests { - if err := c.Config.isValid(); err != nil && c.Ok { + if err := c.Config.IsValid(); err != nil && c.Ok { t.Errorf("test case %d, the config should not have errored, error: %s", i, err) } } diff --git a/doc.go b/pkg/api/doc.go similarity index 77% rename from doc.go rename to pkg/api/doc.go index fe25325e3..26c263514 100644 --- a/doc.go +++ b/pkg/api/doc.go @@ -13,86 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package api -import ( - "errors" - "fmt" - "net/http" - "strconv" - "time" - - "github.com/gambol99/go-oidc/jose" -) - -var ( - release = "v2.1.0-rc3" - gitsha = "no gitsha provided" - compiled = "0" - version = "" -) - -const ( - prog = "keycloak-proxy" - author = "Rohith" - email = "gambol99@gmail.com" - description = "is a proxy using the keycloak service for auth and authorization" - - authorizationHeader = "Authorization" - contextScopeName = "context.scope.name" - envPrefix = "PROXY_" - headerUpgrade = "Upgrade" - httpSchema = "http" - versionHeader = "X-Auth-Proxy-Version" - - oauthURL = "/oauth" - authorizationURL = "/authorize" - callbackURL = "/callback" - expiredURL = "/expired" - healthURL = "/health" - loginURL = "/login" - logoutURL = "/logout" - metricsURL = "/metrics" - tokenURL = "/token" - debugURL = "/debug/pprof" - - claimPreferredName = "preferred_username" - claimAudience = "aud" - claimResourceAccess = "resource_access" - claimRealmAccess = "realm_access" - claimResourceRoles = "roles" -) - -const ( - headerXForwardedFor = "X-Forwarded-For" - headerXForwardedProto = "X-Forwarded-Proto" - headerXForwardedProtocol = "X-Forwarded-Protocol" - headerXForwardedSsl = "X-Forwarded-Ssl" - headerXRealIP = "X-Real-IP" - headerXRequestID = "X-Request-ID" -) - -var ( - // ErrSessionNotFound no session found in the request - ErrSessionNotFound = errors.New("authentication session not found") - // ErrNoSessionStateFound means there was not persist state - ErrNoSessionStateFound = errors.New("no session state found") - // ErrInvalidSession the session is invalid - ErrInvalidSession = errors.New("invalid session identifier") - // ErrAccessTokenExpired indicates the access token has expired - ErrAccessTokenExpired = errors.New("the access token has expired") - // ErrRefreshTokenExpired indicates the refresh token as expired - ErrRefreshTokenExpired = errors.New("the refresh token has expired") - // ErrNoTokenAudience indicates their is not audience in the token - ErrNoTokenAudience = errors.New("the token does not audience in claims") - // ErrDecryption indicates we can't decrypt the token - ErrDecryption = errors.New("failed to decrypt token") -) +import "time" // Resource represents a url resource to protect type Resource struct { - // URL the url for the resource - URL string `json:"uri" yaml:"uri"` + // URI the url for the resource + URI string `json:"uri" yaml:"uri"` + // Hostname is a hostname type + Hostname string `json:"hostname" yaml:"hostname"` // Methods the method type Methods []string `json:"methods" yaml:"methods"` // WhiteListed permits the prefix through @@ -101,6 +31,60 @@ type Resource struct { Roles []string `json:"roles" yaml:"roles"` } +// VirtualHost defines the structure for a virtual host +type VirtualHost struct { + // Hostname is the hostname of the site + Hostname string `json:"hostname" yaml:"hostname"` + // Upstream is the defined upstream + Upstream string `json:"upstream" yaml:"upstream"` + // Roles is a collection of roles to access this virtual host + Roles []string `json:"roles" yaml:"roles"` + // Resources is a collection of resources for this virtual host + Resources []*Resource `json:"resources" yaml:"resources"` + // Headers permits adding customs headers across the board + Headers map[string]string `json:"headers" yaml:"headers"` + + // TLSCertificate is the location for a tls certificate + TLSCertificate string `json:"tls-cert" yaml:"tls-cert"` + // TLSPrivateKey is the location of a tls private key + TLSPrivateKey string `json:"tls-private-key" yaml:"tls-private-key"` + // TLSCaCertificate is the CA certificate which the client cert must be signed + TLSCaCertificate string `json:"tls-ca-certificate" yaml:"tls-ca-certificate"` + + // EnableTokenHeader adds the JWT token to the upstream authentication headers + EnableTokenHeader bool `json:"enable-token-header" yaml:"enable-token-header"` + // EnableAuthorizationHeader indicates we should pass the authorization header + EnableAuthorizationHeader bool `json:"enable-authorization-header" yaml:"enable-authorization-header"` + + // MatchClaims is a series of checks, the claims in the token must match those here + MatchClaims map[string]string `json:"match-claims" yaml:"match-claims"` + // AddClaims is a series of claims that should be added to the auth headers + AddClaims []string `json:"add-claims" yaml:"add-claims"` + // CorsOrigins is a list of origins permitted + CorsOrigins []string `json:"cors-origins" yaml:"cors-origins"` + // CorsMethods is a set of access control methods + CorsMethods []string `json:"cors-methods" yaml:"cors-methods"` + // CorsHeaders is a set of cors headers + CorsHeaders []string `json:"cors-headers" yaml:"cors-headers"` + // CorsExposedHeaders are the exposed header fields + CorsExposedHeaders []string `json:"cors-exposed-headers" yaml:"cors-exposed-headers"` + // CorsCredentials set the credentials flag + CorsCredentials bool `json:"cors-credentials" yaml:"cors-credentials"` + // CorsMaxAge is the age for CORS + CorsMaxAge time.Duration `json:"cors-max-age" yaml:"cors-max-age"` + + // NoRedirects informs we should hand back a 401 not a redirect + NoRedirects bool `json:"no-redirects" yaml:"no-redirects"` + // SkipTokenVerification tells the service to skipp verifying the access token - for testing purposes + SkipTokenVerification bool `json:"skip-token-verification" yaml:"skip-token-verification"` + // UpstreamKeepalives specifies whether we use keepalives on the upstream + UpstreamKeepalives bool `json:"upstream-keepalives" yaml:"upstream-keepalives"` + // UpstreamTimeout is the maximum amount of time a dial will wait for a connect to complete + UpstreamTimeout time.Duration `json:"upstream-timeout" yaml:"upstream-timeout"` + // UpstreamKeepaliveTimeout + UpstreamKeepaliveTimeout time.Duration `json:"upstream-keepalive-timeout" yaml:"upstream-keepalive-timeout"` +} + // Config is the configuration for the proxy type Config struct { // ConfigFile is the binding interface @@ -131,6 +115,8 @@ type Config struct { Resources []*Resource `json:"resources" yaml:"resources" usage:"list of resources 'uri=/admin|methods=GET,PUT|roles=role1,role2'"` // Headers permits adding customs headers across the board Headers map[string]string `json:"headers" yaml:"headers" usage:"custom headers to the upstream request, key=value"` + // VirtualHosts is a collection of vitualhosts + VirtualHosts []*VirtualHost `json:"virtual-hosts" yaml:"virtual-hosts"` // EnableTokenHeader adds the JWT token to the upstream authentication headers EnableTokenHeader bool `json:"enable-token-header" yaml:"enable-token-header" usage:"enables the token authentication header X-Auth-Token to upstream"` @@ -162,6 +148,8 @@ type Config struct { EnableContentNoSniff bool `json:"filter-content-nosniff" yaml:"filter-content-nosniff" usage:"adds the X-Content-Type-Options header with the value nosniff"` // EnableFrameDeny indicates the filter is on EnableFrameDeny bool `json:"filter-frame-deny" yaml:"filter-frame-deny" usage:"enable to the frame deny header"` + // EnableProxyProtocol controls the proxy protocol + EnableProxyProtocol bool `json:"enabled-proxy-protocol" yaml:"enabled-proxy-protocol" usage:"enable proxy protocol"` // ContentSecurityPolicy allows the Content-Security-Policy header value to be set with a custom value ContentSecurityPolicy string `json:"content-security-policy" yaml:"content-security-policy" usage:"specify the content security policy"` // LocalhostMetrics indicated the metrics can only be consume via localhost @@ -233,8 +221,6 @@ type Config struct { UpstreamKeepaliveTimeout time.Duration `json:"upstream-keepalive-timeout" yaml:"upstream-keepalive-timeout" usage:"specifies the keep-alive period for an active network connection"` // Verbose switches on debug logging Verbose bool `json:"verbose" yaml:"verbose" usage:"switch on debug / verbose logging"` - // EnableProxyProtocol controls the proxy protocol - EnableProxyProtocol bool `json:"enabled-proxy-protocol" yaml:"enabled-proxy-protocol" usage:"enable proxy protocol"` // UseLetsEncrypt controls if we should use letsencrypt to retrieve certificates UseLetsEncrypt bool `json:"use-letsencrypt" yaml:"use-letsencrypt" usage:"use letsencrypt for certificates"` @@ -259,76 +245,3 @@ type Config struct { // DisableAllLogging indicates no logging at all DisableAllLogging bool `json:"disable-all-logging" yaml:"disable-all-logging" usage:"disables all logging to stdout and stderr"` } - -// getVersion returns the proxy version -func getVersion() string { - if version == "" { - tm, err := strconv.ParseInt(compiled, 10, 64) - if err != nil { - return "unable to parse compiled time" - } - version = fmt.Sprintf("%s (git+sha: %s, built: %s)", release, gitsha, time.Unix(tm, 0).Format("02-01-2006")) - } - - return version -} - -// RequestScope is a request level context scope passed between middleware -type RequestScope struct { - // AccessDenied indicates the request should not be proxied on - AccessDenied bool - // Identity is the user Identity of the request - Identity *userContext -} - -// storage is used to hold the offline refresh token, assuming you don't want to use -// the default practice of a encrypted cookie -type storage interface { - // Set the token to the store - Set(string, string) error - // Get retrieves a token from the store - Get(string) (string, error) - // Delete removes a key from the store - Delete(string) error - // Close is used to close off any resources - Close() error -} - -// reverseProxy is a wrapper -type reverseProxy interface { - ServeHTTP(rw http.ResponseWriter, req *http.Request) -} - -// userContext represents a user -type userContext struct { - // the id of the user - id string - // the email associated to the user - email string - // a name of the user - name string - // the preferred name - preferredName string - // the expiration of the access token - expiresAt time.Time - // a set of roles associated - roles []string - // the audience for the token - audience string - // the access token itself - token jose.JWT - // the claims associated to the token - claims jose.Claims - // whether the context is from a session cookie or authorization header - bearerToken bool -} - -// tokenResponse -type tokenResponse struct { - TokenType string `json:"token_type"` - AccessToken string `json:"access_token"` - IDToken string `json:"id_token"` - RefreshToken string `json:"refresh_token,omitempty"` - ExpiresIn int `json:"expires_in"` - Scope string `json:"scope,omitempty"` -} diff --git a/resource.go b/pkg/api/resource.go similarity index 61% rename from resource.go rename to pkg/api/resource.go index 2a5eb8003..e3c72a446 100644 --- a/resource.go +++ b/pkg/api/resource.go @@ -13,48 +13,52 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package api import ( "errors" "fmt" "strconv" "strings" + + "github.com/gambol99/keycloak-proxy/pkg/constants" ) -func newResource() *Resource { +// NewResource returns a new resource +func NewResource() *Resource { return &Resource{ - Methods: allHTTPMethods, + Methods: constants.AllHTTPMethods, } } -// parse decodes a resource definition -func (r *Resource) parse(resource string) (*Resource, error) { +// Parse decodes a resource definition +func (r *Resource) Parse(resource string) (*Resource, error) { if resource == "" { return nil, errors.New("the resource has no options") } for _, x := range strings.Split(resource, "|") { - kp := strings.Split(x, "=") - if len(kp) != 2 { + items := strings.Split(x, "=") + if len(items) != 2 { return nil, errors.New("invalid resource keypair, should be (uri|roles|methods|white-listed)=comma_values") } - switch kp[0] { + + switch items[0] { case "uri": - r.URL = kp[1] - if !strings.HasPrefix(r.URL, "/") { + r.URI = items[1] + if !strings.HasPrefix(r.URI, "/") { return nil, errors.New("the resource uri should start with a '/'") } case "methods": - r.Methods = strings.Split(kp[1], ",") + r.Methods = strings.Split(items[1], ",") if len(r.Methods) == 1 { if r.Methods[0] == "any" || r.Methods[0] == "ANY" { - r.Methods = allHTTPMethods + r.Methods = constants.AllHTTPMethods } } case "roles": - r.Roles = strings.Split(kp[1], ",") + r.Roles = strings.Split(items[1], ",") case "white-listed": - value, err := strconv.ParseBool(kp[1]) + value, err := strconv.ParseBool(items[1]) if err != nil { return nil, errors.New("the value of whitelisted must be true|TRUE|T or it's false equivalent") } @@ -67,55 +71,59 @@ func (r *Resource) parse(resource string) (*Resource, error) { return r, nil } -// valid ensure the resource is valid -func (r *Resource) valid() error { +// IsValid ensure the resource is valid +func (r *Resource) IsValid() error { + if strings.HasPrefix(r.URI, constants.OauthURL) { + return errors.New("this is used by the oauth handlers") + } if r.Methods == nil { r.Methods = make([]string, 0) } if r.Roles == nil { r.Roles = make([]string, 0) } - if strings.HasPrefix(r.URL, oauthURL) { - return errors.New("this is used by the oauth handlers") - } - if r.URL == "" { - return errors.New("resource does not have url") + if r.URI == "" { + return errors.New("neither uri or hostname specified") } // step: add any of no methods if len(r.Methods) <= 0 { - r.Methods = allHTTPMethods - } - // step: check the method is valid - for _, m := range r.Methods { - if !isValidHTTPMethod(m) { - return fmt.Errorf("invalid method %s", m) - } + r.Methods = constants.AllHTTPMethods } return nil } -// getRoles returns a list of roles for this resource -func (r Resource) getRoles() string { +// GetRoles returns a list of roles for this resource +func (r Resource) GetRoles() string { return strings.Join(r.Roles, ",") } // String returns a string representation of the resource func (r Resource) String() string { if r.WhiteListed { - return fmt.Sprintf("uri: %s, white-listed", r.URL) + return fmt.Sprintf("uri: %s, white-listed", r.URI) } - roles := "authentication only" + roles := "auth only" methods := "ANY" if len(r.Roles) > 0 { roles = strings.Join(r.Roles, ",") } - if len(r.Methods) > 0 { methods = strings.Join(r.Methods, ",") } - return fmt.Sprintf("uri: %s, methods: %s, required: %s", r.URL, methods, roles) + return fmt.Sprintf("uri: %s, methods: %s, required: %s", r.URI, methods, roles) +} + +// isValidHTTPMethod ensure this is a valid http method type +func isValidHTTPMethod(method string) bool { + for _, x := range constants.AllHTTPMethods { + if method == x { + return true + } + } + + return false } diff --git a/resource_test.go b/pkg/api/resource_test.go similarity index 62% rename from resource_test.go rename to pkg/api/resource_test.go index 688252551..9331ec13a 100644 --- a/resource_test.go +++ b/pkg/api/resource_test.go @@ -13,11 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package api import ( "testing" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/stretchr/testify/assert" ) @@ -33,7 +35,7 @@ func TestDecodeResourceBad(t *testing.T) { {Option: "uri=/|white-listed=ERROR"}, } for i, c := range cs { - if _, err := newResource().parse(c.Option); err == nil { + if _, err := NewResource().Parse(c.Option); err == nil { t.Errorf("case %d should have errored", i) } } @@ -46,35 +48,35 @@ func TestResourceParseOk(t *testing.T) { }{ { Option: "uri=/admin", - Resource: &Resource{URL: "/admin", Methods: allHTTPMethods}, + Resource: &Resource{URI: "/admin", Methods: constants.AllHTTPMethods}, }, { Option: "uri=/", - Resource: &Resource{URL: "/", Methods: allHTTPMethods}, + Resource: &Resource{URI: "/", Methods: constants.AllHTTPMethods}, }, { Option: "uri=/admin/sso|roles=test,test1", - Resource: &Resource{URL: "/admin/sso", Roles: []string{"test", "test1"}, Methods: allHTTPMethods}, + Resource: &Resource{URI: "/admin/sso", Roles: []string{"test", "test1"}, Methods: constants.AllHTTPMethods}, }, { Option: "uri=/admin/sso|roles=test,test1|methods=GET,POST", - Resource: &Resource{URL: "/admin/sso", Roles: []string{"test", "test1"}, Methods: []string{"GET", "POST"}}, + Resource: &Resource{URI: "/admin/sso", Roles: []string{"test", "test1"}, Methods: []string{"GET", "POST"}}, }, { Option: "uri=/allow_me|white-listed=true", - Resource: &Resource{URL: "/allow_me", WhiteListed: true, Methods: allHTTPMethods}, + Resource: &Resource{URI: "/allow_me", WhiteListed: true, Methods: constants.AllHTTPMethods}, }, { Option: "uri=/*|methods=any", - Resource: &Resource{URL: "/*", Methods: allHTTPMethods}, + Resource: &Resource{URI: "/*", Methods: constants.AllHTTPMethods}, }, { Option: "uri=/*|methods=any", - Resource: &Resource{URL: "/*", Methods: allHTTPMethods}, + Resource: &Resource{URI: "/*", Methods: constants.AllHTTPMethods}, }, } for i, x := range cs { - r, err := newResource().parse(x.Option) + r, err := NewResource().Parse(x.Option) assert.NoError(t, err, "case %d should not have errored with: %s", i, err) assert.Equal(t, r, x.Resource, "case %d, expected: %#v, got: %#v", i, x.Resource, r) } @@ -86,35 +88,45 @@ func TestIsValid(t *testing.T) { Ok bool }{ { - Resource: &Resource{URL: "/test"}, - Ok: true, + Resource: &Resource{URI: "/test"}, Ok: true, }, { - Resource: &Resource{URL: "/test", Methods: []string{"GET"}}, - Ok: true, + Resource: &Resource{URI: "/test", Methods: []string{"GET"}}, Ok: true, }, { Resource: &Resource{}, }, { - Resource: &Resource{URL: "/oauth"}, + Resource: &Resource{URI: "/oauth"}, }, { - Resource: &Resource{ - URL: "/test", - Methods: []string{"NO_SUCH_METHOD"}, - }, + Resource: &Resource{URI: "/test", Methods: []string{"NO_SUCH_METHOD"}}, }, } for i, c := range testCases { - err := c.Resource.valid() - if err != nil && c.Ok { + if err := c.Resource.IsValid(); err != nil && c.Ok { t.Errorf("case %d should not have failed, error: %s", i, err) } } } +func TestIsValidHTTPMethod(t *testing.T) { + cs := []struct { + Method string + Ok bool + }{ + {Method: "GET", Ok: true}, + {Method: "GETT"}, + {Method: "CONNECT", Ok: false}, + {Method: "PUT", Ok: true}, + {Method: "PATCH", Ok: true}, + } + for _, x := range cs { + assert.Equal(t, x.Ok, isValidHTTPMethod(x.Method)) + } +} + func TestResourceString(t *testing.T) { resource := &Resource{ Roles: []string{"1", "2", "3"}, @@ -129,7 +141,7 @@ func TestGetRoles(t *testing.T) { Roles: []string{"1", "2", "3"}, } - if resource.getRoles() != "1,2,3" { + if resource.GetRoles() != "1,2,3" { t.Error("the resource roles not as expected") } } diff --git a/pkg/certs/doc.go b/pkg/certs/doc.go new file mode 100644 index 000000000..89d493980 --- /dev/null +++ b/pkg/certs/doc.go @@ -0,0 +1,27 @@ +/* +Copyright 2015 All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package certs + +import ( + "crypto/tls" +) + +// Provider is a TLS certificate provider +type Provider interface { + // GetCertificate returns a TLS certificate + GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) +} diff --git a/rotation.go b/pkg/certs/rotate/rotation.go similarity index 88% rename from rotation.go rename to pkg/certs/rotate/rotation.go index 3d8479ee6..888785d85 100644 --- a/rotation.go +++ b/pkg/certs/rotate/rotation.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package rotate import ( "crypto/tls" @@ -21,6 +21,9 @@ import ( "path" "sync" + "github.com/gambol99/keycloak-proxy/pkg/certs" + "github.com/gambol99/keycloak-proxy/pkg/utils" + "github.com/fsnotify/fsnotify" "go.uber.org/zap" ) @@ -37,20 +40,26 @@ type certificationRotation struct { log *zap.Logger } -// newCertificateRotator creates a new certificate -func newCertificateRotator(cert, key string, log *zap.Logger) (*certificationRotation, error) { +// New creates a new certificate +func New(cert, key string, log *zap.Logger) (certs.Provider, error) { // step: attempt to load the certificate certificate, err := tls.LoadX509KeyPair(cert, key) if err != nil { return nil, err } - // step: are we watching the files for changes? - return &certificationRotation{ + svc := &certificationRotation{ certificate: certificate, certificateFile: cert, log: log, privateKeyFile: key, - }, nil + } + + // start watching the certificates + if err := svc.watch(); err != nil { + return nil, err + } + + return svc, nil } // watch is responsible for adding a file notification and watch on the files for changes @@ -79,7 +88,7 @@ func (c *certificationRotation) watch() error { case event := <-watcher.Events: if event.Op&fsnotify.Write == fsnotify.Write { // step: does the change effect our files? - if !containedIn(event.Name, filewatchPaths) { + if !utils.ContainedIn(event.Name, filewatchPaths) { continue } // step: reload the certificate diff --git a/rotation_test.go b/pkg/certs/rotate/rotation_test.go similarity index 76% rename from rotation_test.go rename to pkg/certs/rotate/rotation_test.go index 01f96af75..9710ed9bd 100644 --- a/rotation_test.go +++ b/pkg/certs/rotate/rotation_test.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package rotate import ( "crypto/tls" @@ -24,12 +24,13 @@ import ( ) const ( - testCertificateFile = "./tests/proxy.pem" - testPrivateKeyFile = "./tests/proxy-key.pem" + testCertificateFile = "../../../tests/proxy.pem" + testPrivateKeyFile = "../../../tests/proxy-key.pem" ) func newTestCertificateRotator(t *testing.T) *certificationRotation { - c, err := newCertificateRotator(testCertificateFile, testPrivateKeyFile, zap.NewNop()) + p, err := New(testCertificateFile, testPrivateKeyFile, zap.NewNop()) + c := p.(*certificationRotation) assert.NotNil(t, c) assert.Equal(t, testCertificateFile, c.certificateFile) assert.Equal(t, testPrivateKeyFile, c.privateKeyFile) @@ -41,13 +42,13 @@ func newTestCertificateRotator(t *testing.T) *certificationRotation { } func TestNewCeritifacteRotator(t *testing.T) { - c, err := newCertificateRotator(testCertificateFile, testPrivateKeyFile, zap.NewNop()) + c, err := New(testCertificateFile, testPrivateKeyFile, zap.NewNop()) assert.NotNil(t, c) assert.NoError(t, err) } func TestNewCeritifacteRotatorFailure(t *testing.T) { - c, err := newCertificateRotator("./tests/does_not_exist", testPrivateKeyFile, zap.NewNop()) + c, err := New("./tests/does_not_exist", testPrivateKeyFile, zap.NewNop()) assert.Nil(t, c) assert.Error(t, err) } @@ -68,9 +69,3 @@ func TestLoadCertificate(t *testing.T) { assert.NoError(t, err) assert.Equal(t, &tls.Certificate{}, crt) } - -func TestWatchCertificate(t *testing.T) { - c := newTestCertificateRotator(t) - err := c.watch() - assert.NoError(t, err) -} diff --git a/pkg/constants/const.go b/pkg/constants/const.go new file mode 100644 index 000000000..789ffdafc --- /dev/null +++ b/pkg/constants/const.go @@ -0,0 +1,124 @@ +/* +Copyright 2017 All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package constants + +import ( + "fmt" + "net/http" + "strconv" + "time" +) + +var ( + // Release is the release version + Release = "v2.1.0-rc2" + // Gitsha is the gitsha + Gitsha = "no gitsha provided" + // Compiled is the build time + Compiled = "0" + // Version is a version string + Version = "" +) + +const ( + // Prog is the name of the service + Prog = "keycloak-proxy" + // Author is the writer + Author = "Rohith Jayawardene" + // Email is the author email + Email = "gambol99@gmail.com" + // Description is a short hand description + Description = "is a proxy using the keycloak service for auth and authorization" + // ClaimPreferredName is the keycloak username claim + ClaimPreferredName = "preferred_username" + // ClaimAudience is tha audience claim + ClaimAudience = "aud" + // ClaimResourceAccess is the keycloak client roles + ClaimResourceAccess = "resource_access" + // ClaimRealmAccess is the keycloak realm roles + ClaimRealmAccess = "realm_access" + // ClaimResourceRoles is the roles claims + ClaimResourceRoles = "roles" + // HeaderUpgrade indicates a connecttion upgrade1 + HeaderUpgrade = "Upgrade" + // HTTPSchema is the http schema + HTTPSchema = "http" + // HTTPSSchema is the https schema + HTTPSSchema = "https" + // HeaderXForwardedFor is a HTTP header + HeaderXForwardedFor = "X-Forwarded-For" + // HeaderXForwardedProto is a HTTP header + HeaderXForwardedProto = "X-Forwarded-Proto" + // HeaderXForwardedProtocol is a HTTP header + HeaderXForwardedProtocol = "X-Forwarded-Protocol" + // HeaderXForwardedSSL is a HTTP header + HeaderXForwardedSSL = "X-Forwarded-SSL" + // HeaderXRealIP is a HTTP header + HeaderXRealIP = "X-Real-IP" + // AuthorizationHeader is a http authorization header + AuthorizationHeader = "Authorization" + // VersionHeader is a verion http header + VersionHeader = "X-Auth-Proxy-Version" + + // OauthURL is the base oauth uri + OauthURL = "/oauth" + // AuthorizationURL is the uri for oauth authorization + AuthorizationURL = "/authorize" + // CallbackURL is the uri for oauth callbacks + CallbackURL = "/callback" + // ExpiredURL is the expiration handler + ExpiredURL = "/expired" + // HealthURL is the health handler + HealthURL = "/health" + // LoginURL is the login handler + LoginURL = "/login" + // LogoutURL is the logout handler + LogoutURL = "/logout" + // MetricsURL is the uri for the metrics handler + MetricsURL = "/metrics" + // TokenURL is the uri for the tokens handler + TokenURL = "/token" + // DebugURL is the uri for the debug endpoint + DebugURL = "/debug/pprof" +) + +var ( + // AllHTTPMethods contains all the http methods + AllHTTPMethods = []string{ + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + http.MethodPut, + http.MethodTrace, + } +) + +// GetVersion returns the proxy version +func GetVersion() string { + if Version == "" { + tm, err := strconv.ParseInt(Compiled, 10, 64) + if err != nil { + return "unable to parse compiled time" + } + Version = fmt.Sprintf("%s (git+sha: %s, built: %s)", Release, Gitsha, time.Unix(tm, 0).Format("02-01-2006")) + } + + return Version +} diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go new file mode 100644 index 000000000..5d0a5cab7 --- /dev/null +++ b/pkg/errors/errors.go @@ -0,0 +1,41 @@ +/* +Copyright 2015 All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package errors + +import "errors" + +var ( + // ErrSessionNotFound no session found in the request + ErrSessionNotFound = errors.New("authentication session not found") + // ErrNoSessionStateFound means there was not persist state + ErrNoSessionStateFound = errors.New("no session state found") + // ErrInvalidSession the session is invalid + ErrInvalidSession = errors.New("invalid session identifier") + // ErrAccessTokenExpired indicates the access token has expired + ErrAccessTokenExpired = errors.New("the access token has expired") + // ErrRefreshTokenExpired indicates the refresh token as expired + ErrRefreshTokenExpired = errors.New("the refresh token has expired") + // ErrNoTokenAudience indicates their is not audience in the token + ErrNoTokenAudience = errors.New("the token does not audience in claims") + // ErrDecryption indicates we can't decrypt the token + ErrDecryption = errors.New("failed to decrypt token") + // ErrUnsupportedStore indicates the storage type is not supported + ErrUnsupportedStore = errors.New("unsupport store type") + // ErrDecryptionTextSmall indicates the encryption key is too small + ErrDecryptionTextSmall = errors.New("failed to decrypt the ciphertext, the text is too short") + // ErrUserInfoValidation indicates the token was not validated by userinfo endpoint + ErrUserInfoValidation = errors.New("token not validate by userinfo endpoint") +) diff --git a/cookies.go b/pkg/server/cookies.go similarity index 99% rename from cookies.go rename to pkg/server/cookies.go index a852d1b39..cf7c66789 100644 --- a/cookies.go +++ b/pkg/server/cookies.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "net/http" diff --git a/cookies_test.go b/pkg/server/cookies_test.go similarity index 99% rename from cookies_test.go rename to pkg/server/cookies_test.go index 343f811b9..2b68654ea 100644 --- a/cookies_test.go +++ b/pkg/server/cookies_test.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "net/http" diff --git a/pkg/server/doc.go b/pkg/server/doc.go new file mode 100644 index 000000000..075502e18 --- /dev/null +++ b/pkg/server/doc.go @@ -0,0 +1,76 @@ +/* +Copyright 2015 All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package server + +import ( + "net/http" + "time" + + "github.com/gambol99/go-oidc/jose" +) + +const ( + // contextScopeName is the context value name in for a request + contextScopeName = "context.scope.name" +) + +// RequestScope is a request level context scope passed between middleware +type RequestScope struct { + // AccessDenied indicates the request should not be proxied on + AccessDenied bool + // Identity is the user Identity of the request + Identity *userContext +} + +// reverseProxy is a wrapper +type reverseProxy interface { + ServeHTTP(rw http.ResponseWriter, req *http.Request) +} + +// userContext represents a user +type userContext struct { + // the id of the user + id string + // the email associated to the user + email string + // a name of the user + name string + // the preferred name + preferredName string + // the expiration of the access token + expiresAt time.Time + // a set of roles associated + roles []string + // the audience for the token + audience string + // the access token itself + token jose.JWT + // the claims associated to the token + claims jose.Claims + // whether the context is from a session cookie or authorization header + bearerToken bool +} + +// tokenResponse +type tokenResponse struct { + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope,omitempty"` +} diff --git a/forwarding.go b/pkg/server/forwarding.go similarity index 85% rename from forwarding.go rename to pkg/server/forwarding.go index e7669fff8..59cfcf713 100644 --- a/forwarding.go +++ b/pkg/server/forwarding.go @@ -13,13 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "fmt" "net/http" "time" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/errors" + "github.com/gambol99/keycloak-proxy/pkg/utils" + "github.com/gambol99/go-oidc/jose" "github.com/gambol99/go-oidc/oidc" "go.uber.org/zap" @@ -39,9 +43,9 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler { } } - if isUpgradedConnection(req) { + if utils.IsUpgradedConnection(req) { r.log.Debug("upgrading the connnection", zap.String("client_ip", req.RemoteAddr)) - if err := tryUpdateConnection(req, w, r.endpoint); err != nil { + if err := utils.TryUpdateConnection(req, w, r.endpoint); err != nil { r.log.Error("failed to upgrade connection", zap.Error(err)) w.WriteHeader(http.StatusInternalServerError) return @@ -56,11 +60,24 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler { // By default goproxy only provides a forwarding proxy, thus all requests have to be absolute // and we must update the host headers - req.URL.Host = r.endpoint.Host - req.URL.Scheme = r.endpoint.Scheme - req.Host = r.endpoint.Host + // https://github.com/gambol99/keycloak-proxy/pull/248 - updated to permit optional upstreams + switch r.endpoint { + case nil: + req.URL.Host = req.Host + switch req.TLS { + case nil: + req.URL.Scheme = constants.HTTPSchema + default: + req.URL.Scheme = constants.HTTPSSchema + } + default: + // override with the specificed upstream endpoint + req.URL.Host = r.endpoint.Host + req.URL.Scheme = r.endpoint.Scheme + req.Host = r.endpoint.Host + } - req.Header.Add("X-Forwarded-For", realIP(req)) + req.Header.Add("X-Forwarded-For", utils.RealIP(req)) req.Header.Set("X-Forwarded-Host", req.URL.Host) req.Header.Set("X-Forwarded-Proto", req.Header.Get("X-Forwarded-Proto")) @@ -149,7 +166,7 @@ func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) { if err != nil { state.login = true switch err { - case ErrRefreshTokenExpired: + case errors.ErrRefreshTokenExpired: r.log.Warn("the refresh token has expired, need to login again", zap.String("subject", state.identity.ID), zap.String("email", state.identity.Email)) @@ -185,7 +202,7 @@ func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) { // wait for an expiration to come close if state.wait { // set the expiration of the access token within a random 85% of actual expiration - duration := getWithin(state.expiration, 0.85) + duration := utils.GetWithin(state.expiration, 0.85) r.log.Info("waiting for expiration of access token", zap.String("token_expiration", state.expiration.Format(time.RFC3339)), zap.String("renewel_duration", duration.String())) @@ -199,8 +216,8 @@ func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) { hostname := req.Host req.URL.Host = hostname // is the host being signed? - if len(r.config.ForwardingDomains) == 0 || containsSubString(hostname, r.config.ForwardingDomains) { - req.Header.Set("X-Forwarded-Agent", prog) + if len(r.config.ForwardingDomains) == 0 || utils.ContainsSubString(hostname, r.config.ForwardingDomains) { + req.Header.Set("X-Forwarded-Agent", constants.Prog) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", state.token.Encode())) } } diff --git a/handlers.go b/pkg/server/handlers.go similarity index 93% rename from handlers.go rename to pkg/server/handlers.go index 944ce1801..212b59072 100644 --- a/handlers.go +++ b/pkg/server/handlers.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "bytes" @@ -30,6 +30,9 @@ import ( "strings" "time" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/utils" + "github.com/gambol99/go-oidc/oauth2" "github.com/pressly/chi" "go.uber.org/zap" @@ -42,14 +45,14 @@ func (r *oauthProxy) getRedirectionURL(w http.ResponseWriter, req *http.Request) case "": // need to determine the scheme, cx.Request.URL.Scheme doesn't have it, best way is to default // and then check for TLS - scheme := "http" + scheme := constants.HTTPSchema if req.TLS != nil { - scheme = "https" + scheme = constants.HTTPSSchema } // @QUESTION: should I use the X-Forwarded-
?? .. redirect = fmt.Sprintf("%s://%s", - defaultTo(req.Header.Get("X-Forwarded-Proto"), scheme), - defaultTo(req.Header.Get("X-Forwarded-Host"), req.Host)) + utils.DefaultTo(req.Header.Get("X-Forwarded-Proto"), scheme), + utils.DefaultTo(req.Header.Get("X-Forwarded-Host"), req.Host)) default: redirect = r.config.RedirectionURL } @@ -72,7 +75,7 @@ func (r *oauthProxy) oauthAuthorizationHandler(w http.ResponseWriter, req *http. // step: set the access type of the session var accessType string - if containedIn("offline", r.config.Scopes) { + if utils.ContainedIn("offline", r.config.Scopes) { accessType = "offline" } @@ -83,11 +86,11 @@ func (r *oauthProxy) oauthAuthorizationHandler(w http.ResponseWriter, req *http. zap.String("client_ip", req.RemoteAddr)) // step: if we have a custom sign in page, lets display that - if r.config.hasCustomSignInPage() { + if r.config.HasCustomSignInPage() { model := make(map[string]string) model["redirect"] = authURL w.WriteHeader(http.StatusOK) - r.Render(w, path.Base(r.config.SignInPage), mergeMaps(model, r.config.Tags)) + r.Render(w, path.Base(r.config.SignInPage), utils.MergeMaps(model, r.config.Tags)) return } @@ -149,7 +152,7 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque // step: are we encrypting the access token? if r.config.EnableEncryptedToken { - if accessToken, err = encodeText(accessToken, r.config.EncryptionKey); err != nil { + if accessToken, err = utils.EncodeText(accessToken, r.config.EncryptionKey); err != nil { r.log.Error("unable to encode the access token", zap.Error(err)) w.WriteHeader(http.StatusInternalServerError) return @@ -164,7 +167,7 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque // step: does the response has a refresh token and we are NOT ignore refresh tokens? if r.config.EnableRefreshTokens && resp.RefreshToken != "" { var encrypted string - encrypted, err = encodeText(resp.RefreshToken, r.config.EncryptionKey) + encrypted, err = utils.EncodeText(resp.RefreshToken, r.config.EncryptionKey) if err != nil { r.log.Error("failed to encrypt the refresh token", zap.Error(err)) w.WriteHeader(http.StatusInternalServerError) @@ -300,7 +303,7 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { if r.idp.EndSessionEndpoint != nil { revokeDefault = r.idp.EndSessionEndpoint.String() } - revocationURL := defaultTo(r.config.RevocationEndpoint, revokeDefault) + revocationURL := utils.DefaultTo(r.config.RevocationEndpoint, revokeDefault) // step: do we have a revocation endpoint? if revocationURL != "" { @@ -377,7 +380,7 @@ func (r *oauthProxy) tokenHandler(w http.ResponseWriter, req *http.Request) { // healthHandler is a health check handler for the service func (r *oauthProxy) healthHandler(w http.ResponseWriter, req *http.Request) { - w.Header().Set(versionHeader, getVersion()) + w.Header().Set(constants.VersionHeader, constants.GetVersion()) w.WriteHeader(http.StatusOK) w.Write([]byte("OK\n")) } @@ -420,7 +423,7 @@ func (r *oauthProxy) debugHandler(w http.ResponseWriter, req *http.Request) { // proxyMetricsHandler forwards the request into the prometheus handler func (r *oauthProxy) proxyMetricsHandler(w http.ResponseWriter, req *http.Request) { if r.config.LocalhostMetrics { - if !net.ParseIP(realIP(req)).IsLoopback() { + if !net.ParseIP(utils.RealIP(req)).IsLoopback() { r.accessForbidden(w, req) return } @@ -441,7 +444,7 @@ func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) } ecrypted = token // returns encryped, avoid encoding twice - token, err = decodeText(token, r.config.EncryptionKey) + token, err = utils.DecodeText(token, r.config.EncryptionKey) return } diff --git a/handlers_test.go b/pkg/server/handlers_test.go similarity index 83% rename from handlers_test.go rename to pkg/server/handlers_test.go index 80f6a47c2..4e95f2649 100644 --- a/handlers_test.go +++ b/pkg/server/handlers_test.go @@ -13,17 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "net/http" "testing" "time" + + "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/gambol99/keycloak-proxy/pkg/constants" ) func TestDebugHandler(t *testing.T) { c := newFakeKeycloakConfig() - c.Resources = make([]*Resource, 0) + c.Resources = make([]*api.Resource, 0) c.EnableProfiling = true requests := []fakeRequest{ {URI: "/debug/pprof/no_there", ExpectedCode: http.StatusNotFound}, @@ -41,7 +44,7 @@ func TestDebugHandler(t *testing.T) { } func TestExpirationHandler(t *testing.T) { - uri := oauthURL + expiredURL + uri := constants.OauthURL + constants.ExpiredURL requests := []fakeRequest{ { URI: uri, @@ -78,8 +81,8 @@ func TestLoginHandlerDisabled(t *testing.T) { c := newFakeKeycloakConfig() c.EnableLoginHandler = false requests := []fakeRequest{ - {URI: oauthURL + loginURL, Method: http.MethodPost, ExpectedCode: http.StatusNotImplemented}, - {URI: oauthURL + loginURL, ExpectedCode: http.StatusMethodNotAllowed}, + {URI: constants.OauthURL + constants.LoginURL, Method: http.MethodPost, ExpectedCode: http.StatusNotImplemented}, + {URI: constants.OauthURL + constants.LoginURL, ExpectedCode: http.StatusMethodNotAllowed}, } newFakeProxy(c).RunTests(t, requests) } @@ -94,7 +97,7 @@ func TestLoginHandlerNotDisabled(t *testing.T) { } func TestLoginHandler(t *testing.T) { - uri := oauthURL + loginURL + uri := constants.OauthURL + constants.LoginURL requests := []fakeRequest{ { URI: uri, @@ -137,7 +140,7 @@ func TestLoginHandler(t *testing.T) { func TestLogoutHandlerBadRequest(t *testing.T) { requests := []fakeRequest{ - {URI: oauthURL + logoutURL, ExpectedCode: http.StatusBadRequest}, + {URI: constants.OauthURL + constants.LogoutURL, ExpectedCode: http.StatusBadRequest}, } newFakeProxy(nil).RunTests(t, requests) } @@ -145,17 +148,17 @@ func TestLogoutHandlerBadRequest(t *testing.T) { func TestLogoutHandlerBadToken(t *testing.T) { requests := []fakeRequest{ { - URI: oauthURL + logoutURL, + URI: constants.OauthURL + constants.LogoutURL, ExpectedCode: http.StatusBadRequest, }, { - URI: oauthURL + logoutURL, + URI: constants.OauthURL + constants.LogoutURL, HasCookieToken: true, RawToken: "this.is.a.bad.token", ExpectedCode: http.StatusBadRequest, }, { - URI: oauthURL + logoutURL, + URI: constants.OauthURL + constants.LogoutURL, RawToken: "this.is.a.bad.token", ExpectedCode: http.StatusBadRequest, }, @@ -166,12 +169,12 @@ func TestLogoutHandlerBadToken(t *testing.T) { func TestLogoutHandlerGood(t *testing.T) { requests := []fakeRequest{ { - URI: oauthURL + logoutURL, + URI: constants.OauthURL + constants.LogoutURL, HasToken: true, ExpectedCode: http.StatusOK, }, { - URI: oauthURL + logoutURL + "?redirect=http://example.com", + URI: constants.OauthURL + constants.LogoutURL + "?redirect=http://example.com", HasToken: true, ExpectedCode: http.StatusTemporaryRedirect, ExpectedLocation: "http://example.com", @@ -181,7 +184,7 @@ func TestLogoutHandlerGood(t *testing.T) { } func TestTokenHandler(t *testing.T) { - uri := oauthURL + tokenURL + uri := constants.OauthURL + constants.TokenURL requests := []fakeRequest{ { URI: uri, @@ -228,7 +231,7 @@ func TestAuthorizationURLWithSkipToken(t *testing.T) { c.SkipTokenVerification = true newFakeProxy(c).RunTests(t, []fakeRequest{ { - URI: oauthURL + authorizationURL, + URI: constants.OauthURL + constants.AuthorizationURL, ExpectedCode: http.StatusNotAcceptable, }, }) @@ -278,28 +281,28 @@ func TestCallbackURL(t *testing.T) { cfg := newFakeKeycloakConfig() requests := []fakeRequest{ { - URI: oauthURL + callbackURL, + URI: constants.OauthURL + constants.CallbackURL, Method: http.MethodPost, ExpectedCode: http.StatusMethodNotAllowed, }, { - URI: oauthURL + callbackURL, + URI: constants.OauthURL + constants.CallbackURL, ExpectedCode: http.StatusBadRequest, }, { - URI: oauthURL + callbackURL + "?code=fake", + URI: constants.OauthURL + constants.CallbackURL + "?code=fake", ExpectedCookies: []string{cfg.CookieAccessName}, ExpectedLocation: "/", ExpectedCode: http.StatusTemporaryRedirect, }, { - URI: oauthURL + callbackURL + "?code=fake&state=/admin", + URI: constants.OauthURL + constants.CallbackURL + "?code=fake&state=/admin", ExpectedCookies: []string{cfg.CookieAccessName}, ExpectedLocation: "/", ExpectedCode: http.StatusTemporaryRedirect, }, { - URI: oauthURL + callbackURL + "?code=fake&state=L2FkbWlu", + URI: constants.OauthURL + constants.CallbackURL + "?code=fake&state=L2FkbWlu", ExpectedCookies: []string{cfg.CookieAccessName}, ExpectedLocation: "/admin", ExpectedCode: http.StatusTemporaryRedirect, @@ -311,12 +314,12 @@ func TestCallbackURL(t *testing.T) { func TestHealthHandler(t *testing.T) { requests := []fakeRequest{ { - URI: oauthURL + healthURL, + URI: constants.OauthURL + constants.HealthURL, ExpectedCode: http.StatusOK, ExpectedContent: "OK\n", }, { - URI: oauthURL + healthURL, + URI: constants.OauthURL + constants.HealthURL, Method: http.MethodHead, ExpectedCode: http.StatusMethodNotAllowed, }, diff --git a/middleware.go b/pkg/server/middleware.go similarity index 92% rename from middleware.go rename to pkg/server/middleware.go index d7bc79108..0b02fd905 100644 --- a/middleware.go +++ b/pkg/server/middleware.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "context" @@ -23,6 +23,11 @@ import ( "strings" "time" + "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/errors" + "github.com/gambol99/keycloak-proxy/pkg/utils" + "github.com/PuerkitoBio/purell" "github.com/gambol99/go-oidc/jose" "github.com/go-chi/chi/middleware" @@ -80,7 +85,8 @@ func (r *oauthProxy) loggingMiddleware(next http.Handler) http.Handler { // metricsMiddleware is responsible for collecting metrics func (r *oauthProxy) metricsMiddleware(next http.Handler) http.Handler { - r.log.Info("enabled the service metrics middleware, available on", zap.String("path", fmt.Sprintf("%s%s", oauthURL, metricsURL))) + r.log.Info("enabled the service metrics middleware, available on", zap.String("path", fmt.Sprintf("%s%s", + constants.OauthURL, constants.MetricsURL))) statusMetrics := prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -100,7 +106,7 @@ func (r *oauthProxy) metricsMiddleware(next http.Handler) http.Handler { } // authenticationMiddleware is responsible for verifying the access token -func (r *oauthProxy) authenticationMiddleware(resource *Resource) func(http.Handler) http.Handler { +func (r *oauthProxy) authenticationMiddleware(resource *api.Resource) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { clientIP := req.RemoteAddr @@ -133,7 +139,7 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) func(http.Hand // step: if the error post verification is anything other than a token // expired error we immediately throw an access forbidden - as there is // something messed up in the token - if err != ErrAccessTokenExpired { + if err != errors.ErrAccessTokenExpired { r.log.Error("access token failed verification", zap.String("client_ip", clientIP), zap.Error(err)) @@ -173,7 +179,7 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) func(http.Hand token, exp, err := getRefreshedToken(r.client, refresh) if err != nil { switch err { - case ErrRefreshTokenExpired: + case errors.ErrRefreshTokenExpired: r.log.Warn("refresh token has expired, cannot retrieve access token", zap.String("client_ip", clientIP), zap.String("email", user.email)) @@ -197,7 +203,7 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) func(http.Hand accessToken := token.Encode() if r.config.EnableEncryptedToken { - if accessToken, err = encodeText(accessToken, r.config.EncryptionKey); err != nil { + if accessToken, err = utils.EncodeText(accessToken, r.config.EncryptionKey); err != nil { r.log.Error("unable to encode the access token", zap.Error(err)) w.WriteHeader(http.StatusInternalServerError) return @@ -229,7 +235,7 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) func(http.Hand } // admissionMiddleware is responsible checking the access token against the protected resource -func (r *oauthProxy) admissionMiddleware(resource *Resource) func(http.Handler) http.Handler { +func (r *oauthProxy) admissionMiddleware(resource *api.Resource) func(http.Handler) http.Handler { claimMatches := make(map[string]*regexp.Regexp) for k, v := range r.config.MatchClaims { claimMatches[k] = regexp.MustCompile(v) @@ -247,12 +253,12 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) func(http.Handler) // step: we need to check the roles if roles := len(resource.Roles); roles > 0 { - if !hasRoles(resource.Roles, user.roles) { + if !utils.HasRoles(resource.Roles, user.roles) { r.log.Warn("access denied, invalid roles", zap.String("access", "denied"), zap.String("email", user.email), - zap.String("resource", resource.URL), - zap.String("required", resource.getRoles())) + zap.String("resource", resource.URI), + zap.String("required", resource.GetRoles())) next.ServeHTTP(w, req.WithContext(r.accessForbidden(w, req))) return @@ -266,7 +272,7 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) func(http.Handler) r.log.Error("unable to extract the claim from token", zap.String("access", "denied"), zap.String("email", user.email), - zap.String("resource", resource.URL), + zap.String("resource", resource.URI), zap.Error(err)) next.ServeHTTP(w, req.WithContext(r.accessForbidden(w, req))) @@ -278,7 +284,7 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) func(http.Handler) zap.String("access", "denied"), zap.String("claim", claimName), zap.String("email", user.email), - zap.String("resource", resource.URL)) + zap.String("resource", resource.URI)) next.ServeHTTP(w, req.WithContext(r.accessForbidden(w, req))) return @@ -292,7 +298,7 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) func(http.Handler) zap.String("email", user.email), zap.String("issued", value), zap.String("required", match.String()), - zap.String("resource", resource.URL)) + zap.String("resource", resource.URI)) next.ServeHTTP(w, req.WithContext(r.accessForbidden(w, req))) return @@ -303,7 +309,7 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) func(http.Handler) zap.String("access", "permitted"), zap.String("email", user.email), zap.Duration("expires", time.Until(user.expiresAt)), - zap.String("resource", resource.URL)) + zap.String("resource", resource.URI)) next.ServeHTTP(w, req) }) @@ -314,7 +320,7 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) func(http.Handler) func (r *oauthProxy) headersMiddleware(custom []string) func(http.Handler) http.Handler { customClaims := make(map[string]string) for _, x := range custom { - customClaims[x] = fmt.Sprintf("X-Auth-%s", toHeader(x)) + customClaims[x] = fmt.Sprintf("X-Auth-%s", utils.ToHeader(x)) } return func(next http.Handler) http.Handler { diff --git a/middleware_test.go b/pkg/server/middleware_test.go similarity index 93% rename from middleware_test.go rename to pkg/server/middleware_test.go index 13f9f3cb8..303497b00 100644 --- a/middleware_test.go +++ b/pkg/server/middleware_test.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "fmt" @@ -25,6 +25,10 @@ import ( "testing" "time" + "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/utils" + "github.com/gambol99/go-oidc/jose" "github.com/go-resty/resty" "github.com/rs/cors" @@ -65,13 +69,13 @@ type fakeRequest struct { } type fakeProxy struct { - config *Config + config *api.Config idp *fakeAuthServer proxy *oauthProxy cookies map[string]*http.Cookie } -func newFakeProxy(c *Config) *fakeProxy { +func newFakeProxy(c *api.Config) *fakeProxy { log.SetOutput(ioutil.Discard) if c == nil { c = newFakeKeycloakConfig() @@ -80,7 +84,7 @@ func newFakeProxy(c *Config) *fakeProxy { c.DiscoveryURL = auth.getLocation() c.RevocationEndpoint = auth.getRevocationURL() c.Verbose = true - proxy, err := newProxy(c) + proxy, err := New(c) if err != nil { panic("failed to create fake proxy service, error: " + err.Error()) } @@ -242,7 +246,7 @@ func (f *fakeProxy) RunTests(t *testing.T, requests []fakeRequest) { g := len(resp.Cookies()) assert.Equal(t, l, g, "case %d, expected %d cookies, got: %d", i, l, g) for _, x := range c.ExpectedCookies { - assert.NotNil(t, findCookie(x, resp.Cookies()), "case %d, expected cookie %s not found", i, x) + assert.NotNil(t, utils.FindCookie(x, resp.Cookies()), "case %d, expected cookie %s not found", i, x) } } if c.OnResponse != nil { @@ -270,7 +274,7 @@ func (f *fakeProxy) performUserLogin(uri string) error { return nil } -func setRequestAuthentication(cfg *Config, client *resty.Client, request *resty.Request, c *fakeRequest, token string) { +func setRequestAuthentication(cfg *api.Config, client *resty.Client, request *resty.Request, c *fakeRequest, token string) { switch c.HasCookieToken { case true: client.SetCookie(&http.Cookie{ @@ -289,12 +293,12 @@ func TestMetricsMiddleware(t *testing.T) { cfg.LocalhostMetrics = true requests := []fakeRequest{ { - URI: oauthURL + metricsURL, + URI: constants.OauthURL + constants.MetricsURL, ExpectedCode: http.StatusOK, ExpectedContentContains: "http_request_total", }, { - URI: oauthURL + metricsURL, + URI: constants.OauthURL + constants.MetricsURL, Headers: map[string]string{ "X-Forwarded-For": "10.0.0.1", }, @@ -328,25 +332,25 @@ func TestOauthRequests(t *testing.T) { func TestStrangeRoutingError(t *testing.T) { cfg := newFakeKeycloakConfig() - cfg.Resources = []*Resource{ + cfg.Resources = []*api.Resource{ { - URL: "/api/v1/events/123456789", - Methods: allHTTPMethods, + URI: "/api/v1/events/123456789", + Methods: constants.AllHTTPMethods, Roles: []string{"user"}, }, { - URL: "/api/v1/events/404", - Methods: allHTTPMethods, + URI: "/api/v1/events/404", + Methods: constants.AllHTTPMethods, Roles: []string{"monitoring"}, }, { - URL: "/api/v1/audit/*", - Methods: allHTTPMethods, + URI: "/api/v1/audit/*", + Methods: constants.AllHTTPMethods, Roles: []string{"auditor", "dev"}, }, { - URL: "/*", - Methods: allHTTPMethods, + URI: "/*", + Methods: constants.AllHTTPMethods, Roles: []string{"dev"}, }, } @@ -401,10 +405,10 @@ func TestStrangeRoutingError(t *testing.T) { func TestNoProxyingRequests(t *testing.T) { c := newFakeKeycloakConfig() - c.Resources = []*Resource{ + c.Resources = []*api.Resource{ { - URL: "/*", - Methods: allHTTPMethods, + URI: "/*", + Methods: constants.AllHTTPMethods, }, } requests := []fakeRequest{ @@ -434,10 +438,10 @@ func TestNoProxyingRequests(t *testing.T) { func TestStrangeAdminRequests(t *testing.T) { cfg := newFakeKeycloakConfig() - cfg.Resources = []*Resource{ + cfg.Resources = []*api.Resource{ { - URL: "/admin*", - Methods: allHTTPMethods, + URI: "/admin*", + Methods: constants.AllHTTPMethods, Roles: []string{fakeAdminRole}, }, } @@ -502,16 +506,16 @@ func TestStrangeAdminRequests(t *testing.T) { func TestWhiteListedRequests(t *testing.T) { cfg := newFakeKeycloakConfig() - cfg.Resources = []*Resource{ + cfg.Resources = []*api.Resource{ { - URL: "/*", - Methods: allHTTPMethods, + URI: "/*", + Methods: constants.AllHTTPMethods, Roles: []string{fakeTestRole}, }, { - URL: "/whitelist*", + URI: "/whitelist*", WhiteListed: true, - Methods: allHTTPMethods, + Methods: constants.AllHTTPMethods, }, } requests := []fakeRequest{ @@ -548,40 +552,40 @@ func TestWhiteListedRequests(t *testing.T) { func TestRolePermissionsMiddleware(t *testing.T) { cfg := newFakeKeycloakConfig() - cfg.Resources = []*Resource{ + cfg.Resources = []*api.Resource{ { - URL: "/admin*", - Methods: allHTTPMethods, + URI: "/admin*", + Methods: constants.AllHTTPMethods, Roles: []string{fakeAdminRole}, }, { - URL: "/test*", + URI: "/test*", Methods: []string{"GET"}, Roles: []string{fakeTestRole}, }, { - URL: "/test_admin_role*", + URI: "/test_admin_role*", Methods: []string{"GET"}, Roles: []string{fakeAdminRole, fakeTestRole}, }, { - URL: "/section/*", - Methods: allHTTPMethods, + URI: "/section/*", + Methods: constants.AllHTTPMethods, Roles: []string{fakeAdminRole}, }, { - URL: "/section/one", - Methods: allHTTPMethods, + URI: "/section/one", + Methods: constants.AllHTTPMethods, Roles: []string{"one"}, }, { - URL: "/whitelist", + URI: "/whitelist", Methods: []string{"GET"}, Roles: []string{}, }, { - URL: "/*", - Methods: allHTTPMethods, + URI: "/*", + Methods: constants.AllHTTPMethods, Roles: []string{fakeTestRole}, }, } @@ -922,25 +926,25 @@ func TestCustomHeadersHandler(t *testing.T) { func TestAdmissionHandlerRoles(t *testing.T) { cfg := newFakeKeycloakConfig() cfg.NoRedirects = true - cfg.Resources = []*Resource{ + cfg.Resources = []*api.Resource{ { - URL: "/admin", - Methods: allHTTPMethods, + URI: "/admin", + Methods: constants.AllHTTPMethods, Roles: []string{"admin"}, }, { - URL: "/test", + URI: "/test", Methods: []string{"GET"}, Roles: []string{"test"}, }, { - URL: "/either", - Methods: allHTTPMethods, + URI: "/either", + Methods: constants.AllHTTPMethods, Roles: []string{"admin", "test"}, }, { - URL: "/", - Methods: allHTTPMethods, + URI: "/", + Methods: constants.AllHTTPMethods, }, } requests := []fakeRequest{ @@ -1037,7 +1041,7 @@ func TestCustomHeaders(t *testing.T) { } for _, c := range requests { cfg := newFakeKeycloakConfig() - cfg.Resources = []*Resource{{URL: "/admin*", Methods: allHTTPMethods}} + cfg.Resources = []*api.Resource{{URI: "/admin*", Methods: constants.AllHTTPMethods}} cfg.Headers = c.Headers newFakeProxy(cfg).RunTests(t, []fakeRequest{c.Request}) } @@ -1129,7 +1133,7 @@ func TestRolesAdmissionHandlerClaims(t *testing.T) { } for _, c := range requests { cfg := newFakeKeycloakConfig() - cfg.Resources = []*Resource{{URL: "/admin*", Methods: allHTTPMethods}} + cfg.Resources = []*api.Resource{{URI: "/admin*", Methods: constants.AllHTTPMethods}} cfg.MatchClaims = c.Matches newFakeProxy(cfg).RunTests(t, []fakeRequest{c.Request}) } diff --git a/misc.go b/pkg/server/misc.go similarity index 94% rename from misc.go rename to pkg/server/misc.go index 41cfc1f87..b5f48cb1c 100644 --- a/misc.go +++ b/pkg/server/misc.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "context" @@ -23,6 +23,8 @@ import ( "path" "time" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/go-oidc/jose" "go.uber.org/zap" ) @@ -46,7 +48,7 @@ func (r *oauthProxy) revokeProxy(w http.ResponseWriter, req *http.Request) conte func (r *oauthProxy) accessForbidden(w http.ResponseWriter, req *http.Request) context.Context { w.WriteHeader(http.StatusForbidden) // are we using a custom http template for 403? - if r.config.hasCustomForbiddenPage() { + if r.config.HasCustomForbiddenPage() { name := path.Base(r.config.ForbiddenPage) if err := r.Render(w, name, r.config.Tags); err != nil { r.log.Error("failed to render the template", zap.Error(err), zap.String("template", name)) @@ -78,7 +80,7 @@ func (r *oauthProxy) redirectToAuthorization(w http.ResponseWriter, req *http.Re w.WriteHeader(http.StatusForbidden) return r.revokeProxy(w, req) } - r.redirectToURL(oauthURL+authorizationURL+authQuery, w, req) + r.redirectToURL(constants.OauthURL+constants.AuthorizationURL+authQuery, w, req) return r.revokeProxy(w, req) } diff --git a/misc_test.go b/pkg/server/misc_test.go similarity index 98% rename from misc_test.go rename to pkg/server/misc_test.go index 9adb121f6..11baf98c8 100644 --- a/misc_test.go +++ b/pkg/server/misc_test.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "net/http" diff --git a/oauth.go b/pkg/server/oauth.go similarity index 91% rename from oauth.go rename to pkg/server/oauth.go index b66131a4c..0f90df072 100644 --- a/oauth.go +++ b/pkg/server/oauth.go @@ -13,17 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "encoding/json" - "errors" "fmt" "io/ioutil" "net/http" "strings" "time" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/errors" + "github.com/gambol99/go-oidc/jose" "github.com/gambol99/go-oidc/oauth2" "github.com/gambol99/go-oidc/oidc" @@ -48,7 +50,7 @@ func (r *oauthProxy) getOAuthClient(redirectionURL string) (*oauth2.Client, erro func verifyToken(client *oidc.Client, token jose.JWT) error { if err := client.VerifyJWT(token); err != nil { if strings.Contains(err.Error(), "token is expired") { - return ErrAccessTokenExpired + return errors.ErrAccessTokenExpired } return err } @@ -65,7 +67,7 @@ func getRefreshedToken(client *oidc.Client, t string) (jose.JWT, time.Time, erro response, err := getToken(cl, oauth2.GrantTypeRefreshToken, t) if err != nil { if strings.Contains(err.Error(), "token expired") { - return jose.JWT{}, time.Time{}, ErrRefreshTokenExpired + return jose.JWT{}, time.Time{}, errors.ErrRefreshTokenExpired } return jose.JWT{}, time.Time{}, err } @@ -89,14 +91,14 @@ func getUserinfo(client *oauth2.Client, endpoint string, token string) (jose.Cla if err != nil { return nil, err } - req.Header.Set(authorizationHeader, fmt.Sprintf("Bearer %s", token)) + req.Header.Set(constants.AuthorizationHeader, fmt.Sprintf("Bearer %s", token)) resp, err := client.HttpClient().Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { - return nil, errors.New("token not validate by userinfo endpoint") + return nil, errors.ErrUserInfoValidation } content, err := ioutil.ReadAll(resp.Body) if err != nil { diff --git a/oauth_test.go b/pkg/server/oauth_test.go similarity index 99% rename from oauth_test.go rename to pkg/server/oauth_test.go index 3d7c382a3..90062a43a 100644 --- a/oauth_test.go +++ b/pkg/server/oauth_test.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "crypto/x509" diff --git a/server.go b/pkg/server/server.go similarity index 88% rename from server.go rename to pkg/server/server.go index fddd051bb..8a7472dbe 100644 --- a/server.go +++ b/pkg/server/server.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "context" @@ -36,6 +36,12 @@ import ( httplog "log" + "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/gambol99/keycloak-proxy/pkg/certs/rotate" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/store" + "github.com/gambol99/keycloak-proxy/pkg/utils" + "github.com/armon/go-proxyproto" "github.com/gambol99/go-oidc/oidc" "github.com/gambol99/goproxy" @@ -48,7 +54,7 @@ import ( type oauthProxy struct { client *oidc.Client - config *Config + config *api.Config endpoint *url.URL idp oidc.ProviderConfig idpClient *http.Client @@ -57,7 +63,7 @@ type oauthProxy struct { metricsHandler http.Handler router http.Handler server *http.Server - store storage + store store.Storage templates *template.Template upstream reverseProxy } @@ -67,15 +73,18 @@ func init() { runtime.GOMAXPROCS(runtime.NumCPU()) // set the core } -// newProxy create's a new proxy from configuration -func newProxy(config *Config) (*oauthProxy, error) { +// New create's a new proxy from configuration +func New(config *api.Config) (*oauthProxy, error) { // create the service logger log, err := createLogger(config) if err != nil { return nil, err } - log.Info("starting the service", zap.String("prog", prog), zap.String("author", author), zap.String("version", version)) + log.Info("starting the service", + zap.String("prog", constants.Prog), + zap.String("author", constants.Author), + zap.String("version", constants.Version)) svc := &oauthProxy{ config: config, log: log, @@ -83,13 +92,15 @@ func newProxy(config *Config) (*oauthProxy, error) { } // parse the upstream endpoint - if svc.endpoint, err = url.Parse(config.Upstream); err != nil { - return nil, err + if config.Upstream != "" { + if svc.endpoint, err = url.Parse(config.Upstream); err != nil { + return nil, err + } } // initialize the store if any if config.StoreURL != "" { - if svc.store, err = createStorage(config.StoreURL); err != nil { + if svc.store, err = store.New(config.StoreURL); err != nil { return nil, err } } @@ -122,7 +133,7 @@ func newProxy(config *Config) (*oauthProxy, error) { } // createLogger is responsible for creating the service logger -func createLogger(config *Config) (*zap.Logger, error) { +func createLogger(config *api.Config) (*zap.Logger, error) { httplog.SetOutput(ioutil.Discard) // disable the http logger if config.DisableAllLogging { return zap.NewNop(), nil @@ -184,22 +195,22 @@ func (r *oauthProxy) createReverseProxy() error { r.router = engine // step: add the routing for oauth - engine.With(proxyDenyMiddleware).Route(oauthURL, func(e chi.Router) { + engine.With(proxyDenyMiddleware).Route(constants.OauthURL, func(e chi.Router) { e.MethodNotAllowed(methodNotAllowHandlder) - e.Get(authorizationURL, r.oauthAuthorizationHandler) - e.Get(callbackURL, r.oauthCallbackHandler) - e.Get(expiredURL, r.expirationHandler) - e.Get(healthURL, r.healthHandler) - e.Get(logoutURL, r.logoutHandler) - e.Get(tokenURL, r.tokenHandler) - e.Post(loginURL, r.loginHandler) + e.Get(constants.AuthorizationURL, r.oauthAuthorizationHandler) + e.Get(constants.CallbackURL, r.oauthCallbackHandler) + e.Get(constants.ExpiredURL, r.expirationHandler) + e.Get(constants.HealthURL, r.healthHandler) + e.Get(constants.LogoutURL, r.logoutHandler) + e.Get(constants.TokenURL, r.tokenHandler) + e.Post(constants.LoginURL, r.loginHandler) if r.config.EnableMetrics { - e.Get(metricsURL, r.proxyMetricsHandler) + e.Get(constants.MetricsURL, r.proxyMetricsHandler) } }) if r.config.EnableProfiling { - engine.With(proxyDenyMiddleware).Route(debugURL, func(e chi.Router) { + engine.With(proxyDenyMiddleware).Route(constants.DebugURL, func(e chi.Router) { r.log.Warn("enabling the debug profiling on /debug/pprof") e.Get("/{name}", r.debugHandler) e.Post("/{name}", r.debugHandler) @@ -212,11 +223,11 @@ func (r *oauthProxy) createReverseProxy() error { } // step: provision in the protected resources for _, x := range r.config.Resources { - if x.URL[len(x.URL)-1:] == "/" { + if x.URI[len(x.URI)-1:] == "/" { r.log.Warn("the resource url is not a prefix", - zap.String("resource", x.URL), - zap.String("change", x.URL), - zap.String("ammended", strings.TrimRight(x.URL, "/"))) + zap.String("resource", x.URI), + zap.String("change", x.URI), + zap.String("ammended", strings.TrimRight(x.URI, "/"))) } } @@ -230,11 +241,11 @@ func (r *oauthProxy) createReverseProxy() error { switch x.WhiteListed { case false: for _, m := range x.Methods { - e.MethodFunc(m, x.URL, emptyHandler) + e.MethodFunc(m, x.URI, emptyHandler) } default: for _, m := range x.Methods { - engine.MethodFunc(m, x.URL, emptyHandler) + engine.MethodFunc(m, x.URI, emptyHandler) } } } @@ -247,6 +258,9 @@ func (r *oauthProxy) createReverseProxy() error { if r.config.EnableEncryptedToken { r.log.Info("session access tokens will be encrypted") } + if r.config.Upstream == "" && (r.config.EnableAuthorizationHeader || r.config.EnableTokenHeader) { + r.log.Warn("no upstream has been configured, hence acting as an open reverse proxy, you sure you want tokens proxied") + } return nil } @@ -269,7 +283,7 @@ func (r *oauthProxy) createForwardingProxy() error { // setup the tls configuration if r.config.TLSCaCertificate != "" && r.config.TLSCaPrivateKey != "" { - ca, err := loadCA(r.config.TLSCaCertificate, r.config.TLSCaPrivateKey) + ca, err := utils.LoadCA(r.config.TLSCaCertificate, r.config.TLSCaPrivateKey) if err != nil { return fmt.Errorf("unable to load certificate authority, error: %s", err) } @@ -328,7 +342,6 @@ func (r *oauthProxy) Run() error { hostnames: r.config.Hostnames, redirectionURL: r.config.RedirectionURL, }) - if err != nil { return err } @@ -398,7 +411,7 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er // are we create a unix socket or tcp listener? if strings.HasPrefix(config.listen, "unix://") { socket := strings.Trim(config.listen, "unix://") - if exists := fileExists(socket); exists { + if exists := utils.FileExists(socket); exists { if err = os.Remove(socket); err != nil { return nil, err } @@ -493,24 +506,25 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er } // createUpstreamProxy create a reverse http proxy from the upstream -func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error { +func (r *oauthProxy) createUpstreamProxy(endpoint *url.URL) error { dialer := (&net.Dialer{ KeepAlive: r.config.UpstreamKeepaliveTimeout, Timeout: r.config.UpstreamTimeout, }).Dial // are we using a unix socket? - if upstream != nil && upstream.Scheme == "unix" { - r.log.Info("using unix socket for upstream", zap.String("socket", fmt.Sprintf("%s%s", upstream.Host, upstream.Path))) + if endpoint != nil && endpoint.Scheme == "unix" { + r.log.Info("using unix socket for endpoint", zap.String("socket", fmt.Sprintf("%s%s", endpoint.Host, endpoint.Path))) - socketPath := fmt.Sprintf("%s%s", upstream.Host, upstream.Path) + socketPath := fmt.Sprintf("%s%s", endpoint.Host, endpoint.Path) dialer = func(network, address string) (net.Conn, error) { return net.Dial("unix", socketPath) } - upstream.Path = "" - upstream.Host = "domain-sock" - upstream.Scheme = "http" + endpoint.Path = "" + endpoint.Host = "domain-sock" + endpoint.Scheme = "http" } + // create the upstream tls configure tlsConfig := &tls.Config{InsecureSkipVerify: r.config.SkipUpstreamTLSVerify} diff --git a/server_test.go b/pkg/server/server_test.go similarity index 93% rename from server_test.go rename to pkg/server/server_test.go index 5b8e0ef15..0821210b5 100644 --- a/server_test.go +++ b/pkg/server/server_test.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "encoding/json" @@ -26,6 +26,9 @@ import ( "testing" "time" + "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/go-oidc/jose" "github.com/stretchr/testify/assert" ) @@ -71,12 +74,11 @@ func TestNewKeycloakProxy(t *testing.T) { cfg.Listen = "127.0.0.1:0" cfg.ListenHTTP = "" - proxy, err := newProxy(cfg) + proxy, err := New(cfg) assert.NoError(t, err) assert.NotNil(t, proxy) assert.NotNil(t, proxy.config) assert.NotNil(t, proxy.router) - assert.NotNil(t, proxy.endpoint) assert.NoError(t, proxy.Run()) } @@ -127,11 +129,11 @@ func TestForwardingProxy(t *testing.T) { func TestForbiddenTemplate(t *testing.T) { cfg := newFakeKeycloakConfig() - cfg.ForbiddenPage = "templates/forbidden.html.tmpl" - cfg.Resources = []*Resource{ + cfg.ForbiddenPage = "../../templates/forbidden.html.tmpl" + cfg.Resources = []*api.Resource{ { - URL: "/*", - Methods: allHTTPMethods, + URI: "/*", + Methods: constants.AllHTTPMethods, Roles: []string{fakeAdminRole}, }, } @@ -149,17 +151,17 @@ func TestForbiddenTemplate(t *testing.T) { func TestAuthorizationTemplate(t *testing.T) { cfg := newFakeKeycloakConfig() - cfg.SignInPage = "templates/sign_in.html.tmpl" - cfg.Resources = []*Resource{ + cfg.SignInPage = "../../templates/sign_in.html.tmpl" + cfg.Resources = []*api.Resource{ { - URL: "/*", - Methods: allHTTPMethods, + URI: "/*", + Methods: constants.AllHTTPMethods, Roles: []string{fakeAdminRole}, }, } requests := []fakeRequest{ { - URI: oauthURL + authorizationURL, + URI: constants.OauthURL + constants.AuthorizationURL, Redirects: true, ExpectedCode: http.StatusOK, ExpectedContentContains: "Sign In", @@ -332,7 +334,7 @@ func newTestService() string { return u } -func newTestProxyService(config *Config) (*oauthProxy, *fakeAuthServer, string) { +func newTestProxyService(config *api.Config) (*oauthProxy, *fakeAuthServer, string) { auth := newFakeAuthServer() if config == nil { config = newFakeKeycloakConfig() @@ -342,7 +344,7 @@ func newTestProxyService(config *Config) (*oauthProxy, *fakeAuthServer, string) config.Verbose = false config.EnableLogging = false - proxy, err := newProxy(config) + proxy, err := New(config) if err != nil { panic("failed to create proxy service, error: " + err.Error()) } @@ -373,8 +375,8 @@ func newFakeHTTPRequest(method, path string) *http.Request { } } -func newFakeKeycloakConfig() *Config { - return &Config{ +func newFakeKeycloakConfig() *api.Config { + return &api.Config{ ClientID: fakeClientID, ClientSecret: fakeSecret, CookieAccessName: "kc-access", @@ -388,31 +390,31 @@ func newFakeKeycloakConfig() *Config { Listen: "127.0.0.1:0", Scopes: []string{}, Verbose: true, - Resources: []*Resource{ + Resources: []*api.Resource{ { - URL: fakeAdminRoleURL, + URI: fakeAdminRoleURL, Methods: []string{"GET"}, Roles: []string{fakeAdminRole}, }, { - URL: fakeTestRoleURL, + URI: fakeTestRoleURL, Methods: []string{"GET"}, Roles: []string{fakeTestRole}, }, { - URL: fakeTestAdminRolesURL, + URI: fakeTestAdminRolesURL, Methods: []string{"GET"}, Roles: []string{fakeAdminRole, fakeTestRole}, }, { - URL: fakeAuthAllURL, - Methods: allHTTPMethods, + URI: fakeAuthAllURL, + Methods: constants.AllHTTPMethods, Roles: []string{}, }, { - URL: fakeTestWhitelistedURL, + URI: fakeTestWhitelistedURL, WhiteListed: true, - Methods: allHTTPMethods, + Methods: constants.AllHTTPMethods, Roles: []string{}, }, }, diff --git a/session.go b/pkg/server/session.go similarity index 82% rename from session.go rename to pkg/server/session.go index 6bb11f90a..4bdb4261b 100644 --- a/session.go +++ b/pkg/server/session.go @@ -13,12 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "net/http" "strings" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/errors" + "github.com/gambol99/keycloak-proxy/pkg/utils" + "github.com/gambol99/go-oidc/jose" "go.uber.org/zap" ) @@ -32,8 +36,8 @@ func (r *oauthProxy) getIdentity(req *http.Request) (*userContext, error) { return nil, err } if r.config.EnableEncryptedToken { - if access, err = decodeText(access, r.config.EncryptionKey); err != nil { - return nil, ErrDecryption + if access, err = utils.DecodeText(access, r.config.EncryptionKey); err != nil { + return nil, errors.ErrDecryption } } token, err := jose.ParseJWT(access) @@ -71,7 +75,7 @@ func getTokenInRequest(req *http.Request, name string) (string, bool, error) { // step: check for a token in the authorization header token, err := getTokenInBearer(req) if err != nil { - if err != ErrSessionNotFound { + if err != errors.ErrSessionNotFound { return "", false, err } if token, err = getTokenInCookie(req, name); err != nil { @@ -85,14 +89,14 @@ func getTokenInRequest(req *http.Request, name string) (string, bool, error) { // getTokenInBearer retrieves a access token from the authorization header func getTokenInBearer(req *http.Request) (string, error) { - token := req.Header.Get(authorizationHeader) + token := req.Header.Get(constants.AuthorizationHeader) if token == "" { - return "", ErrSessionNotFound + return "", errors.ErrSessionNotFound } items := strings.Split(token, " ") if len(items) != 2 { - return "", ErrInvalidSession + return "", errors.ErrInvalidSession } return items[1], nil @@ -100,8 +104,8 @@ func getTokenInBearer(req *http.Request) (string, error) { // getTokenInCookie retrieves the access token from the request cookies func getTokenInCookie(req *http.Request, name string) (string, error) { - if cookie := findCookie(name, req.Cookies()); cookie != nil { + if cookie := utils.FindCookie(name, req.Cookies()); cookie != nil { return cookie.Value, nil } - return "", ErrSessionNotFound + return "", errors.ErrSessionNotFound } diff --git a/session_test.go b/pkg/server/session_test.go similarity index 90% rename from session_test.go rename to pkg/server/session_test.go index 841b94f7a..7bbed19b2 100644 --- a/session_test.go +++ b/pkg/server/session_test.go @@ -13,13 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "fmt" "net/http" "testing" + "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/errors" + "github.com/stretchr/testify/assert" ) @@ -63,7 +67,7 @@ func TestGetIndentity(t *testing.T) { } func TestGetTokenInRequest(t *testing.T) { - defaultName := newDefaultConfig().CookieAccessName + defaultName := api.NewDefaultConfig().CookieAccessName token := newTestToken("test").getToken() cs := []struct { Token string @@ -72,7 +76,7 @@ func TestGetTokenInRequest(t *testing.T) { }{ { Token: "", - Error: ErrSessionNotFound, + Error: errors.ErrSessionNotFound, }, { Token: token.Encode(), @@ -89,7 +93,7 @@ func TestGetTokenInRequest(t *testing.T) { if x.Token != "" { switch x.IsBearer { case true: - req.Header.Set(authorizationHeader, "Bearer "+x.Token) + req.Header.Set(constants.AuthorizationHeader, "Bearer "+x.Token) default: req.AddCookie(&http.Cookie{ Name: defaultName, diff --git a/stores.go b/pkg/server/store.go similarity index 70% rename from stores.go rename to pkg/server/store.go index 6277db8af..03a6edd56 100644 --- a/stores.go +++ b/pkg/server/store.go @@ -13,37 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( - "fmt" - "net/url" + "github.com/gambol99/keycloak-proxy/pkg/errors" + "github.com/gambol99/keycloak-proxy/pkg/utils" "github.com/gambol99/go-oidc/jose" "go.uber.org/zap" ) -// createStorage creates the store client for use -func createStorage(location string) (storage, error) { - var store storage - var err error - - u, err := url.Parse(location) - if err != nil { - return nil, err - } - switch u.Scheme { - case "redis": - store, err = newRedisStore(u) - case "boltdb": - store, err = newBoltDBStore(u) - default: - return nil, fmt.Errorf("unsupport store: %s", u.Scheme) - } - - return store, err -} - // useStore checks if we are using a store to hold the refresh tokens func (r *oauthProxy) useStore() bool { return r.store != nil @@ -51,18 +30,18 @@ func (r *oauthProxy) useStore() bool { // StoreRefreshToken the token to the store func (r *oauthProxy) StoreRefreshToken(token jose.JWT, value string) error { - return r.store.Set(getHashKey(&token), value) + return r.store.Set(utils.GetHashKey(&token), value) } // Get retrieves a token from the store, the key we are using here is the access token func (r *oauthProxy) GetRefreshToken(token jose.JWT) (string, error) { // step: the key is the access token - v, err := r.store.Get(getHashKey(&token)) + v, err := r.store.Get(utils.GetHashKey(&token)) if err != nil { return v, err } if v == "" { - return v, ErrNoSessionStateFound + return v, errors.ErrNoSessionStateFound } return v, nil @@ -70,7 +49,7 @@ func (r *oauthProxy) GetRefreshToken(token jose.JWT) (string, error) { // DeleteRefreshToken removes a key from the store func (r *oauthProxy) DeleteRefreshToken(token jose.JWT) error { - if err := r.store.Delete(getHashKey(&token)); err != nil { + if err := r.store.Delete(utils.GetHashKey(&token)); err != nil { r.log.Error("unable to delete token", zap.Error(err)) return err diff --git a/user_context.go b/pkg/server/user_context.go similarity index 80% rename from user_context.go rename to pkg/server/user_context.go index 875085f2d..32e293ae0 100644 --- a/user_context.go +++ b/pkg/server/user_context.go @@ -13,13 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "fmt" "strings" "time" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/errors" + "github.com/gambol99/go-oidc/jose" "github.com/gambol99/go-oidc/oidc" ) @@ -35,18 +38,18 @@ func extractIdentity(token jose.JWT) (*userContext, error) { return nil, err } // step: ensure we have and can extract the preferred name of the user, if not, we set to the ID - preferredName, found, err := claims.StringClaim(claimPreferredName) + preferredName, found, err := claims.StringClaim(constants.ClaimPreferredName) if err != nil || !found { preferredName = identity.Email } - audience, found, err := claims.StringClaim(claimAudience) + audience, found, err := claims.StringClaim(constants.ClaimAudience) if err != nil || !found { - return nil, ErrNoTokenAudience + return nil, errors.ErrNoTokenAudience } // step: extract the realm roles var list []string - if realmRoles, found := claims[claimRealmAccess].(map[string]interface{}); found { - if roles, found := realmRoles[claimResourceRoles]; found { + if realmRoles, found := claims[constants.ClaimRealmAccess].(map[string]interface{}); found { + if roles, found := realmRoles[constants.ClaimResourceRoles]; found { for _, r := range roles.([]interface{}) { list = append(list, fmt.Sprintf("%s", r)) } @@ -54,10 +57,10 @@ func extractIdentity(token jose.JWT) (*userContext, error) { } // step: extract the client roles from the access token - if accesses, found := claims[claimResourceAccess].(map[string]interface{}); found { + if accesses, found := claims[constants.ClaimResourceAccess].(map[string]interface{}); found { for roleName, roleList := range accesses { scopes := roleList.(map[string]interface{}) - if roles, found := scopes[claimResourceRoles]; found { + if roles, found := scopes[constants.ClaimResourceRoles]; found { for _, r := range roles.([]interface{}) { list = append(list, fmt.Sprintf("%s:%s", roleName, r)) } diff --git a/user_context_test.go b/pkg/server/user_context_test.go similarity index 99% rename from user_context_test.go rename to pkg/server/user_context_test.go index b28c45f50..6e7ea7078 100644 --- a/user_context_test.go +++ b/pkg/server/user_context_test.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package server import ( "testing" diff --git a/store_boltdb.go b/pkg/store/boltdb.go similarity index 96% rename from store_boltdb.go rename to pkg/store/boltdb.go index 135b97b16..dbddb08b1 100644 --- a/store_boltdb.go +++ b/pkg/store/boltdb.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package store import ( "errors" @@ -38,7 +38,7 @@ type boltdbStore struct { client *bolt.DB } -func newBoltDBStore(location *url.URL) (storage, error) { +func newBoltDBStore(location *url.URL) (Storage, error) { // step: drop the initial slash path := strings.TrimPrefix(location.Path, "/") db, err := bolt.Open(path, 0600, &bolt.Options{ diff --git a/store_boltdb_test.go b/pkg/store/boltdb_test.go similarity index 99% rename from store_boltdb_test.go rename to pkg/store/boltdb_test.go index a5d6788b3..bb7db6d09 100644 --- a/store_boltdb_test.go +++ b/pkg/store/boltdb_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package store import ( "fmt" diff --git a/pkg/store/doc.go b/pkg/store/doc.go new file mode 100644 index 000000000..fb88436c8 --- /dev/null +++ b/pkg/store/doc.go @@ -0,0 +1,30 @@ +/* +Copyright 2017 All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package store + +// Storage is used to hold the offline refresh token, assuming you don't want to use +// the default practice of a encrypted cookie +type Storage interface { + // Set the token to the store + Set(string, string) error + // Get retrieves a token from the store + Get(string) (string, error) + // Delete removes a key from the store + Delete(string) error + // Close is used to close off any resources + Close() error +} diff --git a/store_redis.go b/pkg/store/redis.go similarity index 95% rename from store_redis.go rename to pkg/store/redis.go index 8979ce20a..740357aee 100644 --- a/store_redis.go +++ b/pkg/store/redis.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package store import ( "net/url" @@ -27,7 +27,7 @@ type redisStore struct { } // newRedisStore creates a new redis store -func newRedisStore(location *url.URL) (storage, error) { +func newRedisStore(location *url.URL) (Storage, error) { // step: get any password password := "" if location.User != nil { diff --git a/pkg/store/store.go b/pkg/store/store.go new file mode 100644 index 000000000..71395f113 --- /dev/null +++ b/pkg/store/store.go @@ -0,0 +1,43 @@ +/* +Copyright 2015 All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package store + +import ( + "net/url" + + "github.com/gambol99/keycloak-proxy/pkg/errors" +) + +// New creates the store client for use +func New(location string) (Storage, error) { + var store Storage + var err error + + u, err := url.Parse(location) + if err != nil { + return nil, err + } + switch u.Scheme { + case "redis": + store, err = newRedisStore(u) + case "boltdb": + store, err = newBoltDBStore(u) + default: + return nil, errors.ErrUnsupportedStore + } + + return store, err +} diff --git a/stores_test.go b/pkg/store/store_test.go similarity index 85% rename from stores_test.go rename to pkg/store/store_test.go index 1fa342183..f62e0d9da 100644 --- a/stores_test.go +++ b/pkg/store/store_test.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package store import ( "os" @@ -23,13 +23,13 @@ import ( ) func TestCreateStorageRedis(t *testing.T) { - store, err := createStorage("redis://127.0.0.1") + store, err := New("redis://127.0.0.1") assert.NotNil(t, store) assert.NoError(t, err) } func TestCreateStorageBoltDB(t *testing.T) { - store, err := createStorage("boltdb:////tmp/bolt") + store, err := New("boltdb:////tmp/bolt") assert.NotNil(t, store) assert.NoError(t, err) if store != nil { @@ -38,7 +38,7 @@ func TestCreateStorageBoltDB(t *testing.T) { } func TestCreateStorageFail(t *testing.T) { - store, err := createStorage("not_there:///tmp/bolt") + store, err := New("not_there:///tmp/bolt") assert.Nil(t, store) assert.Error(t, err) } diff --git a/utils.go b/pkg/utils/utils.go similarity index 56% rename from utils.go rename to pkg/utils/utils.go index 9d30d06ca..f11950f18 100644 --- a/utils.go +++ b/pkg/utils/utils.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package utils import ( "crypto/aes" @@ -24,7 +24,6 @@ import ( "crypto/x509" "encoding/base64" "encoding/json" - "errors" "fmt" "io" "io/ioutil" @@ -40,30 +39,20 @@ import ( "unicode" "unicode/utf8" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/gambol99/keycloak-proxy/pkg/errors" + "github.com/gambol99/go-oidc/jose" "github.com/urfave/cli" "gopkg.in/yaml.v2" ) -var ( - allHTTPMethods = []string{ - http.MethodDelete, - http.MethodGet, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - http.MethodPut, - http.MethodTrace, - } -) - var ( symbolsFilter = regexp.MustCompilePOSIX("[_$><\\[\\].,\\+-/'%^&*()!\\\\]+") ) -// readConfigFile reads and parses the configuration file -func readConfigFile(filename string, config *Config) error { +// ReadConfigFile reads and parses the configuration file +func ReadConfigFile(filename string, data interface{}) error { content, err := ioutil.ReadFile(filename) if err != nil { return err @@ -71,16 +60,16 @@ func readConfigFile(filename string, config *Config) error { // step: attempt to un-marshal the data switch ext := filepath.Ext(filename); ext { case "json": - err = json.Unmarshal(content, config) + err = json.Unmarshal(content, data) default: - err = yaml.Unmarshal(content, config) + err = yaml.Unmarshal(content, data) } return err } -// encryptDataBlock encrypts the plaintext string with the key -func encryptDataBlock(plaintext, key []byte) ([]byte, error) { +// EncryptDataBlock encrypts the plaintext string with the key +func EncryptDataBlock(plaintext, key []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { return []byte{}, err @@ -97,8 +86,8 @@ func encryptDataBlock(plaintext, key []byte) ([]byte, error) { return gcm.Seal(nonce, nonce, plaintext, nil), nil } -// decryptDataBlock decrypts some cipher text -func decryptDataBlock(cipherText, key []byte) ([]byte, error) { +// DecryptDataBlock decrypts some cipher text +func DecryptDataBlock(cipherText, key []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { return []byte{}, err @@ -109,16 +98,16 @@ func decryptDataBlock(cipherText, key []byte) ([]byte, error) { } nonceSize := gcm.NonceSize() if len(cipherText) < nonceSize { - return nil, errors.New("failed to decrypt the ciphertext, the text is too short") + return nil, errors.ErrDecryptionTextSmall } nonce, input := cipherText[:nonceSize], cipherText[nonceSize:] return gcm.Open(nil, nonce, input, nil) } -// encodeText encodes the session state information into a value for a cookie to consume -func encodeText(plaintext string, key string) (string, error) { - cipherText, err := encryptDataBlock([]byte(plaintext), []byte(key)) +// EncodeText encodes the session state information into a value for a cookie to consume +func EncodeText(plaintext string, key string) (string, error) { + cipherText, err := EncryptDataBlock([]byte(plaintext), []byte(key)) if err != nil { return "", err } @@ -126,23 +115,23 @@ func encodeText(plaintext string, key string) (string, error) { return base64.RawStdEncoding.EncodeToString(cipherText), nil } -// decodeText decodes the session state cookie value -func decodeText(state, key string) (string, error) { +// DecodeText decodes the session state cookie value +func DecodeText(state, key string) (string, error) { cipherText, err := base64.RawStdEncoding.DecodeString(state) if err != nil { return "", err } // step: decrypt the cookie back in the expiration|token - encoded, err := decryptDataBlock(cipherText, []byte(key)) + encoded, err := DecryptDataBlock(cipherText, []byte(key)) if err != nil { - return "", ErrInvalidSession + return "", errors.ErrInvalidSession } return string(encoded), nil } -// decodeKeyPairs converts a list of strings (key=pair) to a map -func decodeKeyPairs(list []string) (map[string]string, error) { +// DecodeKeyPairs converts a list of strings (key=pair) to a map +func DecodeKeyPairs(list []string) (map[string]string, error) { kp := make(map[string]string) for _, x := range list { @@ -156,19 +145,8 @@ func decodeKeyPairs(list []string) (map[string]string, error) { return kp, nil } -// isValidHTTPMethod ensure this is a valid http method type -func isValidHTTPMethod(method string) bool { - for _, x := range allHTTPMethods { - if method == x { - return true - } - } - - return false -} - -// defaultTo returns the value of the default -func defaultTo(v, d string) string { +// DefaultTo returns the value of the default +func DefaultTo(v, d string) string { if v != "" { return v } @@ -176,8 +154,8 @@ func defaultTo(v, d string) string { return d } -// fileExists check if a file exists -func fileExists(filename string) bool { +// FileExists check if a file exists +func FileExists(filename string) bool { if _, err := os.Stat(filename); err != nil { if os.IsNotExist(err) { return false @@ -187,10 +165,10 @@ func fileExists(filename string) bool { return true } -// hasRoles checks the scopes are the same -func hasRoles(required, issued []string) bool { +// HasRoles checks the scopes are the same +func HasRoles(required, issued []string) bool { for _, role := range required { - if !containedIn(role, issued) { + if !ContainedIn(role, issued) { return false } } @@ -198,8 +176,8 @@ func hasRoles(required, issued []string) bool { return true } -// containedIn checks if a value in a list of a strings -func containedIn(value string, list []string) bool { +// ContainedIn checks if a value in a list of a strings +func ContainedIn(value string, list []string) bool { for _, x := range list { if x == value { return true @@ -209,8 +187,8 @@ func containedIn(value string, list []string) bool { return false } -// containsSubString checks if substring exists -func containsSubString(value string, list []string) bool { +// ContainsSubString checks if substring exists +func ContainsSubString(value string, list []string) bool { for _, x := range list { if strings.Contains(value, x) { return true @@ -220,10 +198,10 @@ func containsSubString(value string, list []string) bool { return false } -// tryDialEndpoint dials the upstream endpoint via plain -func tryDialEndpoint(location *url.URL) (net.Conn, error) { - switch dialAddress := dialAddress(location); location.Scheme { - case httpSchema: +// TryDialEndpoint dials the upstream endpoint via plain +func TryDialEndpoint(location *url.URL) (net.Conn, error) { + switch dialAddress := DialAddress(location); location.Scheme { + case constants.HTTPSchema: return net.Dial("tcp", dialAddress) default: return tls.Dial("tcp", dialAddress, &tls.Config{ @@ -233,21 +211,21 @@ func tryDialEndpoint(location *url.URL) (net.Conn, error) { } } -// isUpgradedConnection checks to see if the request is requesting -func isUpgradedConnection(req *http.Request) bool { - return req.Header.Get(headerUpgrade) != "" +// IsUpgradedConnection checks to see if the request is requesting +func IsUpgradedConnection(req *http.Request) bool { + return req.Header.Get(constants.HeaderUpgrade) != "" } -// transferBytes transfers bytes between the sink and source -func transferBytes(src io.Reader, dest io.Writer, wg *sync.WaitGroup) (int64, error) { +// TransferBytes transfers bytes between the sink and source +func TransferBytes(src io.Reader, dest io.Writer, wg *sync.WaitGroup) (int64, error) { defer wg.Done() return io.Copy(dest, src) } -// tryUpdateConnection attempt to upgrade the connection to a http pdy stream -func tryUpdateConnection(req *http.Request, writer http.ResponseWriter, endpoint *url.URL) error { +// TryUpdateConnection attempt to upgrade the connection to a http pdy stream +func TryUpdateConnection(req *http.Request, writer http.ResponseWriter, endpoint *url.URL) error { // step: dial the endpoint - tlsConn, err := tryDialEndpoint(endpoint) + tlsConn, err := TryDialEndpoint(endpoint) if err != nil { return err } @@ -268,19 +246,19 @@ func tryUpdateConnection(req *http.Request, writer http.ResponseWriter, endpoint // step: copy the date between client and upstream endpoint var wg sync.WaitGroup wg.Add(2) - go transferBytes(tlsConn, clientConn, &wg) - go transferBytes(clientConn, tlsConn, &wg) + go TransferBytes(tlsConn, clientConn, &wg) + go TransferBytes(clientConn, tlsConn, &wg) wg.Wait() return nil } -// dialAddress extracts the dial address from the url -func dialAddress(location *url.URL) string { +// DialAddress extracts the dial address from the url +func DialAddress(location *url.URL) string { items := strings.Split(location.Host, ":") if len(items) != 2 { switch location.Scheme { - case httpSchema: + case constants.HTTPSchema: return location.Host + ":80" default: return location.Host + ":443" @@ -290,8 +268,8 @@ func dialAddress(location *url.URL) string { return location.Host } -// findCookie looks for a cookie in a list of cookies -func findCookie(name string, cookies []*http.Cookie) *http.Cookie { +// FindCookie looks for a cookie in a list of cookies +func FindCookie(name string, cookies []*http.Cookie) *http.Cookie { for _, cookie := range cookies { if cookie.Name == name { return cookie @@ -301,20 +279,20 @@ func findCookie(name string, cookies []*http.Cookie) *http.Cookie { return nil } -// toHeader is a helper method to play nice in the headers -func toHeader(v string) string { +// ToHeader is a helper method to play nice in the headers +func ToHeader(v string) string { var list []string // step: filter out any symbols and convert to dashes for _, x := range symbolsFilter.Split(v, -1) { - list = append(list, capitalize(x)) + list = append(list, Capitalize(x)) } return strings.Join(list, "-") } -// capitalize capitalizes the first letter of a word -func capitalize(s string) string { +// Capitalize capitalizes the first letter of a word +func Capitalize(s string) string { if s == "" { return "" } @@ -323,8 +301,8 @@ func capitalize(s string) string { return string(unicode.ToUpper(r)) + s[n:] } -// mergeMaps simples copies the keys from source to destination -func mergeMaps(dest, source map[string]string) map[string]string { +// MergeMaps simples copies the keys from source to destination +func MergeMaps(dest, source map[string]string) map[string]string { for k, v := range source { dest[k] = v } @@ -332,8 +310,8 @@ func mergeMaps(dest, source map[string]string) map[string]string { return dest } -// loadCA loads the certificate authority -func loadCA(cert, key string) (*tls.Certificate, error) { +// LoadCA loads the certificate authority +func LoadCA(cert, key string) (*tls.Certificate, error) { caCert, err := ioutil.ReadFile(cert) if err != nil { return nil, err @@ -354,9 +332,9 @@ func loadCA(cert, key string) (*tls.Certificate, error) { return &ca, err } -// getWithin calculates a duration of x percent of the time period, i.e. something +// GetWithin calculates a duration of x percent of the time period, i.e. something // expires in 1 hours, get me a duration within 80% -func getWithin(expires time.Time, within float64) time.Duration { +func GetWithin(expires time.Time, within float64) time.Duration { left := expires.UTC().Sub(time.Now().UTC()).Seconds() if left <= 0 { return time.Duration(0) @@ -366,23 +344,23 @@ func getWithin(expires time.Time, within float64) time.Duration { return time.Duration(seconds) * time.Second } -// getHashKey returns a hash of the encodes jwt token -func getHashKey(token *jose.JWT) string { +// GetHashKey returns a hash of the encodes jwt token +func GetHashKey(token *jose.JWT) string { hash := md5.Sum([]byte(token.Encode())) return base64.RawStdEncoding.EncodeToString(hash[:]) } -// printError display the command line usage and error -func printError(message string, args ...interface{}) *cli.ExitError { +// PrintError display the command line usage and error +func PrintError(message string, args ...interface{}) *cli.ExitError { return cli.NewExitError(fmt.Sprintf("[error] "+message, args...), 1) } -// realIP retrieves the client ip address from a http request -func realIP(req *http.Request) string { +// RealIP retrieves the client ip address from a http request +func RealIP(req *http.Request) string { ra := req.RemoteAddr - if ip := req.Header.Get(headerXForwardedFor); ip != "" { + if ip := req.Header.Get(constants.HeaderXForwardedFor); ip != "" { ra = strings.Split(ip, ", ")[0] - } else if ip := req.Header.Get(headerXRealIP); ip != "" { + } else if ip := req.Header.Get(constants.HeaderXRealIP); ip != "" { ra = ip } else { ra, _, _ = net.SplitHostPort(ra) diff --git a/utils_test.go b/pkg/utils/utils_test.go similarity index 71% rename from utils_test.go rename to pkg/utils/utils_test.go index eabafb43f..a2800cfa9 100644 --- a/utils_test.go +++ b/pkg/utils/utils_test.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package utils import ( "bytes" @@ -26,6 +26,8 @@ import ( "testing" "time" + "github.com/gambol99/keycloak-proxy/pkg/constants" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -50,7 +52,7 @@ func TestDecodeKeyPairs(t *testing.T) { } for i, c := range testCases { - kp, err := decodeKeyPairs(c.List) + kp, err := DecodeKeyPairs(c.List) if err != nil && c.Ok { t.Errorf("test case %d should not have failed", i) continue @@ -82,7 +84,7 @@ func TestDefaultTo(t *testing.T) { }, } for _, c := range cs { - assert.Equal(t, c.Expected, defaultTo(c.Value, c.Default)) + assert.Equal(t, c.Expected, DefaultTo(c.Value, c.Default)) } } @@ -108,7 +110,7 @@ func TestEncryptDataBlock(t *testing.T) { } for i, test := range testCase { - _, err := encryptDataBlock(bytes.NewBufferString(test.Text).Bytes(), bytes.NewBufferString(test.Key).Bytes()) + _, err := EncryptDataBlock(bytes.NewBufferString(test.Text).Bytes(), bytes.NewBufferString(test.Key).Bytes()) if err != nil && test.Ok { t.Errorf("test case: %d should not have failed, %s", i, err) } @@ -116,7 +118,7 @@ func TestEncryptDataBlock(t *testing.T) { } func TestEncodeText(t *testing.T) { - session, err := encodeText("12245325632323263762", "1gjrlcjQ8RyKANngp9607txr5fF5fhf1") + session, err := EncodeText("12245325632323263762", "1gjrlcjQ8RyKANngp9607txr5fF5fhf1") assert.NotEmpty(t, session) assert.NoError(t, err) } @@ -146,7 +148,7 @@ func TestEncryptedText(t *testing.T) { func BenchmarkEncryptDataBlock(b *testing.B) { for n := 0; n < b.N; n++ { - encryptDataBlock(fakePlainText, fakeKey) + EncryptDataBlock(fakePlainText, fakeKey) } } @@ -154,7 +156,7 @@ func BenchmarkEncodeText(b *testing.B) { text := string(fakePlainText) key := string(fakeKey) for n := 0; n < b.N; n++ { - encodeText(text, key) + EncodeText(text, key) } } @@ -162,7 +164,7 @@ func BenchmarkDecodeText(b *testing.B) { t := string(fakeCipherText) k := string(fakeKey) for n := 0; n < b.N; n++ { - if _, err := decodeText(t, k); err != nil { + if _, err := DecodeText(t, k); err != nil { b.FailNow() } } @@ -172,11 +174,11 @@ func TestDecodeText(t *testing.T) { fakeKey := "HYLNt2JSzD7Lpz0djTRudmlOpbwx1oHB" fakeText := "12245325632323263762" - encrypted, err := encodeText(fakeText, fakeKey) + encrypted, err := EncodeText(fakeText, fakeKey) require.NoError(t, err) assert.NotEmpty(t, encrypted) - decoded, _ := decodeText(encrypted, fakeKey) + decoded, _ := DecodeText(encrypted, fakeKey) assert.NotNil(t, decoded, "the session should not have been nil") assert.Equal(t, decoded, fakeText, "the decoded text is not the same") } @@ -185,8 +187,8 @@ func TestFindCookie(t *testing.T) { cookies := []*http.Cookie{ {Name: "cookie_there"}, } - assert.NotNil(t, findCookie("cookie_there", cookies)) - assert.Nil(t, findCookie("not_there", cookies)) + assert.NotNil(t, FindCookie("cookie_there", cookies)) + assert.Nil(t, FindCookie("not_there", cookies)) } func TestDecryptDataBlock(t *testing.T) { @@ -208,12 +210,12 @@ func TestDecryptDataBlock(t *testing.T) { } for i, test := range testCase { - cipher, err := encryptDataBlock(bytes.NewBufferString(test.Text).Bytes(), bytes.NewBufferString(test.Key).Bytes()) + cipher, err := EncryptDataBlock(bytes.NewBufferString(test.Text).Bytes(), bytes.NewBufferString(test.Key).Bytes()) if err != nil && test.Ok { t.Errorf("test case: %d should not have failed, %s", i, err) } - plain, err := decryptDataBlock(cipher, bytes.NewBufferString(test.Key).Bytes()) + plain, err := DecryptDataBlock(cipher, bytes.NewBufferString(test.Key).Bytes()) if err != nil { t.Errorf("test case: %d should not have failed, %s", i, err) } @@ -248,69 +250,53 @@ func TestHasRoles(t *testing.T) { } for i, test := range testCases { - if !hasRoles(test.Required, test.Roles) && test.Ok { + if !HasRoles(test.Required, test.Roles) && test.Ok { assert.Fail(t, "test case: %i should have ok, %s, %s", i, test.Roles, test.Required) } } } func TestContainedIn(t *testing.T) { - assert.False(t, containedIn("1", []string{"2", "3", "4"})) - assert.True(t, containedIn("1", []string{"1", "2", "3", "4"})) + assert.False(t, ContainedIn("1", []string{"2", "3", "4"})) + assert.True(t, ContainedIn("1", []string{"1", "2", "3", "4"})) } func TestContainsSubString(t *testing.T) { - assert.False(t, containsSubString("bar.com", []string{"foo.bar.com"})) - assert.True(t, containsSubString("www.foo.bar.com", []string{"foo.bar.com"})) - assert.True(t, containsSubString("foo.bar.com", []string{"bar.com"})) - assert.True(t, containsSubString("star.domain.com", []string{"domain.com", "domain1.com"})) - assert.True(t, containsSubString("star.domain1.com", []string{"domain.com", "domain1.com"})) - assert.True(t, containsSubString("test.test.svc.cluster.local", []string{"svc.cluster.local"})) - - assert.False(t, containsSubString("star.domain1.com", []string{"domain.com", "sub.domain1.com"})) - assert.False(t, containsSubString("svc.cluster.local", []string{"nginx.pr1.svc.cluster.local"})) - assert.False(t, containsSubString("cluster.local", []string{"nginx.pr1.svc.cluster.local"})) - assert.False(t, containsSubString("pr1", []string{"nginx.pr1.svc.cluster.local"})) + assert.False(t, ContainsSubString("bar.com", []string{"foo.bar.com"})) + assert.True(t, ContainsSubString("www.foo.bar.com", []string{"foo.bar.com"})) + assert.True(t, ContainsSubString("foo.bar.com", []string{"bar.com"})) + assert.True(t, ContainsSubString("star.domain.com", []string{"domain.com", "domain1.com"})) + assert.True(t, ContainsSubString("star.domain1.com", []string{"domain.com", "domain1.com"})) + assert.True(t, ContainsSubString("test.test.svc.cluster.local", []string{"svc.cluster.local"})) + + assert.False(t, ContainsSubString("star.domain1.com", []string{"domain.com", "sub.domain1.com"})) + assert.False(t, ContainsSubString("svc.cluster.local", []string{"nginx.pr1.svc.cluster.local"})) + assert.False(t, ContainsSubString("cluster.local", []string{"nginx.pr1.svc.cluster.local"})) + assert.False(t, ContainsSubString("pr1", []string{"nginx.pr1.svc.cluster.local"})) } func BenchmarkContainsSubString(t *testing.B) { for n := 0; n < t.N; n++ { - containsSubString("svc.cluster.local", []string{"nginx.pr1.svc.cluster.local"}) + ContainsSubString("svc.cluster.local", []string{"nginx.pr1.svc.cluster.local"}) } } func TestDialAddress(t *testing.T) { - assert.Equal(t, dialAddress(getFakeURL("http://127.0.0.1")), "127.0.0.1:80") - assert.Equal(t, dialAddress(getFakeURL("https://127.0.0.1")), "127.0.0.1:443") - assert.Equal(t, dialAddress(getFakeURL("http://127.0.0.1:8080")), "127.0.0.1:8080") + assert.Equal(t, DialAddress(getFakeURL("http://127.0.0.1")), "127.0.0.1:80") + assert.Equal(t, DialAddress(getFakeURL("https://127.0.0.1")), "127.0.0.1:443") + assert.Equal(t, DialAddress(getFakeURL("http://127.0.0.1:8080")), "127.0.0.1:8080") } func TestIsUpgradedConnection(t *testing.T) { header := http.Header{} - header.Add(headerUpgrade, "") - assert.False(t, isUpgradedConnection(&http.Request{Header: header})) - header.Set(headerUpgrade, "set") - assert.True(t, isUpgradedConnection(&http.Request{Header: header})) -} - -func TestIdValidHTTPMethod(t *testing.T) { - cs := []struct { - Method string - Ok bool - }{ - {Method: "GET", Ok: true}, - {Method: "GETT"}, - {Method: "CONNECT", Ok: false}, - {Method: "PUT", Ok: true}, - {Method: "PATCH", Ok: true}, - } - for _, x := range cs { - assert.Equal(t, x.Ok, isValidHTTPMethod(x.Method)) - } + header.Add(constants.HeaderUpgrade, "") + assert.False(t, IsUpgradedConnection(&http.Request{Header: header})) + header.Set(constants.HeaderUpgrade, "set") + assert.True(t, IsUpgradedConnection(&http.Request{Header: header})) } func TestFileExists(t *testing.T) { - if fileExists("no_such_file_exsit_32323232") { + if FileExists("no_such_file_exsit_32323232") { t.Error("we should have received false") } tmpfile, err := ioutil.TempFile("/tmp", fmt.Sprintf("test_file_%d", os.Getpid())) @@ -319,7 +305,7 @@ func TestFileExists(t *testing.T) { } defer os.Remove(tmpfile.Name()) - if !fileExists(tmpfile.Name()) { + if !FileExists(tmpfile.Name()) { t.Error("we should have received a true") } } @@ -342,7 +328,7 @@ func TestGetWithin(t *testing.T) { }, } for _, x := range cs { - assert.Equal(t, x.Expected, getWithin(x.Expires, x.Percent)) + assert.Equal(t, x.Expected, GetWithin(x.Expires, x.Percent)) } } @@ -365,8 +351,8 @@ func TestToHeader(t *testing.T) { }, } for i, x := range cases { - assert.Equal(t, x.Expected, toHeader(x.Word), "case %d, expected: %s but got: %s", - i, x.Expected, toHeader(x.Word)) + assert.Equal(t, x.Expected, ToHeader(x.Word), "case %d, expected: %s but got: %s", + i, x.Expected, ToHeader(x.Word)) } } @@ -389,8 +375,8 @@ func TestCapitalize(t *testing.T) { }, } for i, x := range cases { - assert.Equal(t, x.Expected, capitalize(x.Word), "case %d, expected: %s but got: %s", i, x.Expected, - capitalize(x.Word)) + assert.Equal(t, x.Expected, Capitalize(x.Word), "case %d, expected: %s but got: %s", i, x.Expected, + Capitalize(x.Word)) } } @@ -416,7 +402,7 @@ func TestMergeMaps(t *testing.T) { }, } for i, x := range cases { - merged := mergeMaps(x.Dest, x.Source) + merged := MergeMaps(x.Dest, x.Source) if !reflect.DeepEqual(x.Expected, merged) { t.Errorf("case %d, expected: %v but got: %v", i, x.Expected, merged) } @@ -424,41 +410,20 @@ func TestMergeMaps(t *testing.T) { } func TestReadConfiguration(t *testing.T) { - testCases := []struct { - Content string - Ok bool - }{ - { - Content: ` -discovery_url: https://keyclock.domain.com/ -client-id: -secret: -`, - }, - { - Content: ` -discovery_url: https://keyclock.domain.com -client-id: -secret: -upstream-url: http://127.0.0.1:8080 -redirection_url: http://127.0.0.1:3000 -`, - Ok: true, - }, - } - - for i, test := range testCases { - // step: write the fake config file - file := writeFakeConfigFile(t, test.Content) - - config := new(Config) - err := readConfigFile(file.Name(), config) - if test.Ok && err != nil { - os.Remove(file.Name()) - t.Errorf("test case %d should not have failed, config: %v, error: %s", i, config, err) - } - os.Remove(file.Name()) - } + var test struct { + ID int `yaml:"id"` + Name string `yaml:"name"` + } + assert.Error(t, ReadConfigFile("not_found", nil)) + content := ` +id: 12 +name: test +` + file := writeFakeConfigFile(t, content) + assert.NoError(t, ReadConfigFile(file.Name(), &test)) + assert.Equal(t, 12, test.ID) + assert.Equal(t, "test", test.Name) + os.Remove(file.Name()) } func getFakeURL(location string) *url.URL { From 5c8b4f91ed884e70bacb69ef2f68a5cc929e115d Mon Sep 17 00:00:00 2001 From: Rohith Date: Sun, 23 Jul 2017 15:07:09 +0100 Subject: [PATCH 2/7] - rebased in master branch --- pkg/api/config.go | 27 +++--- pkg/certs/letsencrypt/main.go | 78 +++++++++++++++++ pkg/certs/letsencrypt/main_test.go | 16 ++++ pkg/certs/rotate/{rotation.go => main.go} | 9 +- .../rotate/{rotation_test.go => main_test.go} | 12 +-- pkg/errors/errors.go | 2 + pkg/server/server.go | 87 +++++-------------- 7 files changed, 140 insertions(+), 91 deletions(-) create mode 100644 pkg/certs/letsencrypt/main.go create mode 100644 pkg/certs/letsencrypt/main_test.go rename pkg/certs/rotate/{rotation.go => main.go} (93%) rename pkg/certs/rotate/{rotation_test.go => main_test.go} (84%) diff --git a/pkg/api/config.go b/pkg/api/config.go index bf02f52e2..446088087 100644 --- a/pkg/api/config.go +++ b/pkg/api/config.go @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package api import ( "errors" @@ -28,18 +28,19 @@ import ( func NewDefaultConfig() *Config { return &Config{ AccessTokenDuration: time.Duration(720) * time.Hour, - Tags: make(map[string]string), - MatchClaims: make(map[string]string), - Headers: make(map[string]string), - UpstreamTimeout: time.Duration(10) * time.Second, - UpstreamKeepaliveTimeout: time.Duration(10) * time.Second, - EnableAuthorizationHeader: true, - EnableTokenHeader: true, CookieAccessName: "kc-access", CookieRefreshName: "kc-state", + EnableAuthorizationHeader: true, + EnableTokenHeader: true, + Headers: make(map[string]string), + LetsEncryptCacheDir: "./cache/", + MatchClaims: make(map[string]string), SecureCookie: true, - SkipUpstreamTLSVerify: true, SkipOpenIDProviderTLSVerify: false, + SkipUpstreamTLSVerify: true, + Tags: make(map[string]string), + UpstreamKeepaliveTimeout: time.Duration(10) * time.Second, + UpstreamTimeout: time.Duration(10) * time.Second, UseLetsEncrypt: false, LetsEncryptCacheDir: "./cache/", } @@ -56,11 +57,11 @@ func (c *Config) IsValid() error { if c.TLSPrivateKey != "" && c.TLSCertificate == "" { return errors.New("you have not provided a certificate file") } - if c.UseLetsEncrypt && c.LetsEncryptCacheDir == "" { + if c.UseLetsEncrypt && c.LetsEncryptCacheDir == "" { return fmt.Errorf("the letsencrypt cache dir has not been set") } - if r.EnableForwarding { + if c.EnableForwarding { if c.ClientID == "" { return errors.New("you have not specified the client id") } @@ -86,10 +87,6 @@ func (c *Config) IsValid() error { if _, err := url.Parse(c.Upstream); err != nil { return fmt.Errorf("the upstream endpoint is invalid, %s", err) } - if r.SkipUpstreamTLSVerify && r.UpstreamCA != "" { - return fmt.Errorf("you cannot skip upstream tls and load a root ca: %s to verify it", r.UpstreamCA) - } - // step: if the skip verification is off, we need the below if !c.SkipTokenVerification { if c.ClientID == "" { diff --git a/pkg/certs/letsencrypt/main.go b/pkg/certs/letsencrypt/main.go new file mode 100644 index 000000000..530159b32 --- /dev/null +++ b/pkg/certs/letsencrypt/main.go @@ -0,0 +1,78 @@ +/* +Copyright 2017 All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package letsencrypt + +import ( + "context" + "crypto/tls" + "net/url" + + "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/gambol99/keycloak-proxy/pkg/certs" + "github.com/gambol99/keycloak-proxy/pkg/errors" + + "go.uber.org/zap" + "golang.org/x/crypto/acme/autocert" +) + +type provider struct { + manager *autocert.Manager + hostnames []string + redirectionURL string +} + +// New returns a letsencrypt provider +func New(c *api.Config, log *zap.Logger) (certs.Provider, error) { + p := &provider{ + hostnames: c.Hostnames, + redirectionURL: c.RedirectionURL, + } + p.manager = &autocert.Manager{ + Prompt: autocert.AcceptTOS, + Cache: autocert.DirCache(c.LetsEncryptCacheDir), + HostPolicy: p.enforceHostPolicy, + } + + return p, nil +} + +// enforceHostPolicy is responsible for the hostname policy +func (p *provider) enforceHostPolicy(_ context.Context, hostname string) error { + if len(p.hostnames) > 0 { + found := false + for _, h := range p.hostnames { + found = found || (h == hostname) + } + if !found { + return errors.ErrHostNotConfigured + } + } else if p.redirectionURL != "" { + u, err := url.Parse(p.redirectionURL) + if err != nil { + return err + } + if u.Host != hostname { + return errors.ErrHostNotConfigured + } + } + + return nil +} + +// GetCertificate just wraps the letsencrypt method +func (p *provider) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { + return p.manager.GetCertificate(h) +} diff --git a/pkg/certs/letsencrypt/main_test.go b/pkg/certs/letsencrypt/main_test.go new file mode 100644 index 000000000..b94be5e66 --- /dev/null +++ b/pkg/certs/letsencrypt/main_test.go @@ -0,0 +1,16 @@ +/* +Copyright 2017 All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package letsencrypt diff --git a/pkg/certs/rotate/rotation.go b/pkg/certs/rotate/main.go similarity index 93% rename from pkg/certs/rotate/rotation.go rename to pkg/certs/rotate/main.go index 888785d85..67b1c8d63 100644 --- a/pkg/certs/rotate/rotation.go +++ b/pkg/certs/rotate/main.go @@ -21,6 +21,7 @@ import ( "path" "sync" + "github.com/gambol99/keycloak-proxy/pkg/api" "github.com/gambol99/keycloak-proxy/pkg/certs" "github.com/gambol99/keycloak-proxy/pkg/utils" @@ -41,17 +42,17 @@ type certificationRotation struct { } // New creates a new certificate -func New(cert, key string, log *zap.Logger) (certs.Provider, error) { +func New(c *api.Config, log *zap.Logger) (certs.Provider, error) { // step: attempt to load the certificate - certificate, err := tls.LoadX509KeyPair(cert, key) + certificate, err := tls.LoadX509KeyPair(c.TLSCertificate, c.TLSPrivateKey) if err != nil { return nil, err } svc := &certificationRotation{ certificate: certificate, - certificateFile: cert, + certificateFile: c.TLSCertificate, log: log, - privateKeyFile: key, + privateKeyFile: c.TLSPrivateKey, } // start watching the certificates diff --git a/pkg/certs/rotate/rotation_test.go b/pkg/certs/rotate/main_test.go similarity index 84% rename from pkg/certs/rotate/rotation_test.go rename to pkg/certs/rotate/main_test.go index 9710ed9bd..ce137dda5 100644 --- a/pkg/certs/rotate/rotation_test.go +++ b/pkg/certs/rotate/main_test.go @@ -19,6 +19,8 @@ import ( "crypto/tls" "testing" + "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -29,7 +31,7 @@ const ( ) func newTestCertificateRotator(t *testing.T) *certificationRotation { - p, err := New(testCertificateFile, testPrivateKeyFile, zap.NewNop()) + p, err := New(&api.Config{TLSCertificate: testCertificateFile, TLSPrivateKey: testPrivateKeyFile}, zap.NewNop()) c := p.(*certificationRotation) assert.NotNil(t, c) assert.Equal(t, testCertificateFile, c.certificateFile) @@ -41,14 +43,8 @@ func newTestCertificateRotator(t *testing.T) *certificationRotation { return c } -func TestNewCeritifacteRotator(t *testing.T) { - c, err := New(testCertificateFile, testPrivateKeyFile, zap.NewNop()) - assert.NotNil(t, c) - assert.NoError(t, err) -} - func TestNewCeritifacteRotatorFailure(t *testing.T) { - c, err := New("./tests/does_not_exist", testPrivateKeyFile, zap.NewNop()) + c, err := New(&api.Config{}, zap.NewNop()) assert.Nil(t, c) assert.Error(t, err) } diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index 5d0a5cab7..e46ca5669 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -38,4 +38,6 @@ var ( ErrDecryptionTextSmall = errors.New("failed to decrypt the ciphertext, the text is too short") // ErrUserInfoValidation indicates the token was not validated by userinfo endpoint ErrUserInfoValidation = errors.New("token not validate by userinfo endpoint") + // ErrHostNotConfigured indicates the hostname was not found + ErrHostNotConfigured = errors.New("acme/autocert: host not configured") ) diff --git a/pkg/server/server.go b/pkg/server/server.go index 8a7472dbe..f1968fdae 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -16,7 +16,6 @@ limitations under the License. package server import ( - "context" "crypto/tls" "crypto/x509" "errors" @@ -24,6 +23,7 @@ import ( "html/template" "io" "io/ioutil" + httplog "log" "net" "net/http" "net/url" @@ -32,11 +32,9 @@ import ( "strings" "time" - "golang.org/x/crypto/acme/autocert" - - httplog "log" - "github.com/gambol99/keycloak-proxy/pkg/api" + "github.com/gambol99/keycloak-proxy/pkg/certs" + "github.com/gambol99/keycloak-proxy/pkg/certs/letsencrypt" "github.com/gambol99/keycloak-proxy/pkg/certs/rotate" "github.com/gambol99/keycloak-proxy/pkg/constants" "github.com/gambol99/keycloak-proxy/pkg/store" @@ -81,10 +79,7 @@ func New(config *api.Config) (*oauthProxy, error) { return nil, err } - log.Info("starting the service", - zap.String("prog", constants.Prog), - zap.String("author", constants.Author), - zap.String("version", constants.Version)) + log.Info("starting the service", zap.String("prog", constants.Prog), zap.String("author", constants.Author), zap.String("version", constants.Version)) svc := &oauthProxy{ config: config, log: log, @@ -146,7 +141,6 @@ func createLogger(config *api.Config) (*zap.Logger, error) { if !config.EnableJSONLogging { c.Encoding = "console" } - // are we running verbose mode? if config.Verbose { httplog.SetOutput(os.Stderr) c.DisableCaller = false @@ -389,20 +383,18 @@ func (r *oauthProxy) Run() error { // listenerConfig encapsulate listener options type listenerConfig struct { - listen string // the interface to bind the listener to - certificate string // the path to the certificate if any - privateKey string // the path to the private key if any ca string // the path to a certificate authority + certificate string // the path to the certificate if any clientCert string // the path to a client certificate to use for mutual tls - proxyProtocol bool // whether to enable proxy protocol on the listen hostnames []string // list of hostnames the service will respond to + letsEncryptCacheDir string // the path to cache letsencrypt certificates + listen string // the interface to bind the listener to + privateKey string // the path to the private key if any + proxyProtocol bool // whether to enable proxy protocol on the listen redirectionURL string // url to redirect to useLetsEncrypt bool // whether to use lets encrypt for retrieving ssl certificates - letsEncryptCacheDir string // the path to cache letsencrypt certificates } -var ErrHostNotConfigured = errors.New("acme/autocert: host not configured") - // createHTTPListener is responsible for creating a listening socket func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, error) { var listener net.Listener @@ -432,62 +424,27 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er listener = &proxyproto.Listener{Listener: listener} } - // does the socket require TLS? - if (config.certificate != "" && config.privateKey != "") || config.useLetsEncrypt { - getCertificate := func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - return nil, errors.New("Not configured") + // the default get certificate methid when tls is enabled + if config.useLetsEncrypt || config.certificate != "" { + var getCert certs.Provider + tlsConfig := &tls.Config{ + PreferServerCipherSuites: true, } if config.useLetsEncrypt { - m := autocert.Manager{ - Prompt: autocert.AcceptTOS, - Cache: autocert.DirCache(config.letsEncryptCacheDir), - HostPolicy: func(_ context.Context, host string) error { - if len(config.hostnames) > 0 { - found := false - - for _, h := range config.hostnames { - found = found || (h == host) - } - - if !found { - return ErrHostNotConfigured - } - } else if config.redirectionURL != "" { - if u, err := url.Parse(config.redirectionURL); err != nil { - return err - } else if u.Host != host { - return ErrHostNotConfigured - } - } - - return nil - }, - } - - getCertificate = m.GetCertificate - } else { - r.log.Info("tls support enabled", - zap.String("certificate", config.certificate), zap.String("private_key", config.privateKey)) - // creating a certificate rotation - rotate, err := newCertificateRotator(config.certificate, config.privateKey, r.log) + r.log.Info("enabling letsencrypt tls support") + getCert, err = letsencrypt.New(r.config, r.log) if err != nil { return nil, err } - // start watching the files for changes - if err := rotate.watch(); err != nil { + } else if config.certificate != "" { + r.log.Info("enabling tls support") + getCert, err = rotate.New(r.config, r.log) + if err != nil { return nil, err } - - getCertificate = rotate.GetCertificate - } - - tlsConfig := &tls.Config{ - PreferServerCipherSuites: true, - GetCertificate: getCertificate, } - - listener = tls.NewListener(listener, tlsConfig) + tlsConfig.GetCertificate = getCert.GetCertificate // are we doing mutual tls? if config.clientCert != "" { @@ -500,6 +457,8 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er tlsConfig.ClientCAs = caCertPool tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert } + + listener = tls.NewListener(listener, tlsConfig) } return listener, nil From f536a6137f68d6ed3a9b0fb1249ef98823ca2780 Mon Sep 17 00:00:00 2001 From: Rohith Date: Sun, 23 Jul 2017 15:55:17 +0100 Subject: [PATCH 3/7] a --- pkg/api/config.go | 21 +++++----- pkg/api/doc.go | 1 - pkg/api/resource.go | 9 ++--- pkg/certs/letsencrypt/main.go | 36 ++++++++++------- pkg/certs/rotate/main.go | 75 +++++++++++++++++------------------ pkg/certs/rotate/main_test.go | 6 +-- pkg/errors/errors.go | 2 + pkg/server/handlers.go | 13 ++---- pkg/utils/utils.go | 17 ++++---- 9 files changed, 86 insertions(+), 94 deletions(-) diff --git a/pkg/api/config.go b/pkg/api/config.go index 446088087..728cad270 100644 --- a/pkg/api/config.go +++ b/pkg/api/config.go @@ -27,17 +27,16 @@ import ( // NewDefaultConfig returns a initialized config func NewDefaultConfig() *Config { return &Config{ - AccessTokenDuration: time.Duration(720) * time.Hour, - CookieAccessName: "kc-access", - CookieRefreshName: "kc-state", - EnableAuthorizationHeader: true, - EnableTokenHeader: true, - Headers: make(map[string]string), - LetsEncryptCacheDir: "./cache/", - MatchClaims: make(map[string]string), - SecureCookie: true, - SkipOpenIDProviderTLSVerify: false, - SkipUpstreamTLSVerify: true, + AccessTokenDuration: time.Duration(720) * time.Hour, + CookieAccessName: "kc-access", + CookieRefreshName: "kc-state", + EnableAuthorizationHeader: true, + EnableTokenHeader: true, + Headers: make(map[string]string), + LetsEncryptCacheDir: "./cache/", + MatchClaims: make(map[string]string), + SecureCookie: true, + SkipUpstreamTLSVerify: true, Tags: make(map[string]string), UpstreamKeepaliveTimeout: time.Duration(10) * time.Second, UpstreamTimeout: time.Duration(10) * time.Second, diff --git a/pkg/api/doc.go b/pkg/api/doc.go index 26c263514..949dad896 100644 --- a/pkg/api/doc.go +++ b/pkg/api/doc.go @@ -224,7 +224,6 @@ type Config struct { // UseLetsEncrypt controls if we should use letsencrypt to retrieve certificates UseLetsEncrypt bool `json:"use-letsencrypt" yaml:"use-letsencrypt" usage:"use letsencrypt for certificates"` - // LetsEncryptCacheDir is the path to store letsencrypt certificates LetsEncryptCacheDir string `json:"letsencrypt-cache-dir" yaml:"letsencrypt-cache-dir" usage:"path where cached letsencrypt certificates are stored"` diff --git a/pkg/api/resource.go b/pkg/api/resource.go index e3c72a446..5222b5454 100644 --- a/pkg/api/resource.go +++ b/pkg/api/resource.go @@ -26,15 +26,13 @@ import ( // NewResource returns a new resource func NewResource() *Resource { - return &Resource{ - Methods: constants.AllHTTPMethods, - } + return &Resource{Methods: constants.AllHTTPMethods} } // Parse decodes a resource definition func (r *Resource) Parse(resource string) (*Resource, error) { if resource == "" { - return nil, errors.New("the resource has no options") + return nil, errors.New("resource has no options") } for _, x := range strings.Split(resource, "|") { items := strings.Split(x, "=") @@ -46,7 +44,7 @@ func (r *Resource) Parse(resource string) (*Resource, error) { case "uri": r.URI = items[1] if !strings.HasPrefix(r.URI, "/") { - return nil, errors.New("the resource uri should start with a '/'") + return nil, errors.New("resource uri should start with a '/'") } case "methods": r.Methods = strings.Split(items[1], ",") @@ -85,7 +83,6 @@ func (r *Resource) IsValid() error { if r.URI == "" { return errors.New("neither uri or hostname specified") } - // step: add any of no methods if len(r.Methods) <= 0 { r.Methods = constants.AllHTTPMethods } diff --git a/pkg/certs/letsencrypt/main.go b/pkg/certs/letsencrypt/main.go index 530159b32..c530aeae3 100644 --- a/pkg/certs/letsencrypt/main.go +++ b/pkg/certs/letsencrypt/main.go @@ -29,38 +29,36 @@ import ( ) type provider struct { - manager *autocert.Manager - hostnames []string - redirectionURL string + manager *autocert.Manager + config *api.Config } // New returns a letsencrypt provider func New(c *api.Config, log *zap.Logger) (certs.Provider, error) { p := &provider{ - hostnames: c.Hostnames, - redirectionURL: c.RedirectionURL, - } - p.manager = &autocert.Manager{ - Prompt: autocert.AcceptTOS, - Cache: autocert.DirCache(c.LetsEncryptCacheDir), - HostPolicy: p.enforceHostPolicy, + config: c, + manager: &autocert.Manager{ + Prompt: autocert.AcceptTOS, + Cache: autocert.DirCache(c.LetsEncryptCacheDir), + }, } + p.manager.HostPolicy = p.enforceHostPolicy return p, nil } // enforceHostPolicy is responsible for the hostname policy func (p *provider) enforceHostPolicy(_ context.Context, hostname string) error { - if len(p.hostnames) > 0 { + if len(p.hostnames()) > 0 { found := false - for _, h := range p.hostnames { + for _, h := range p.hostnames() { found = found || (h == hostname) } if !found { return errors.ErrHostNotConfigured } - } else if p.redirectionURL != "" { - u, err := url.Parse(p.redirectionURL) + } else if p.redirectionURL() != "" { + u, err := url.Parse(p.redirectionURL()) if err != nil { return err } @@ -76,3 +74,13 @@ func (p *provider) enforceHostPolicy(_ context.Context, hostname string) error { func (p *provider) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { return p.manager.GetCertificate(h) } + +// hostnames returns a list of hostnames from the config +func (p *provider) hostnames() []string { + return p.config.Hostnames +} + +// redirectionURL returns the redirectionURL from config +func (p *provider) redirectionURL() string { + return p.config.RedirectionURL +} diff --git a/pkg/certs/rotate/main.go b/pkg/certs/rotate/main.go index 67b1c8d63..38a1ee77a 100644 --- a/pkg/certs/rotate/main.go +++ b/pkg/certs/rotate/main.go @@ -29,33 +29,24 @@ import ( "go.uber.org/zap" ) -type certificationRotation struct { +type provider struct { sync.RWMutex - // certificate holds the current issuing certificate + config *api.Config certificate tls.Certificate - // certificateFile is the path the certificate - certificateFile string - // the privateKeyFile is the path of the private key - privateKeyFile string - // the logger for this service - log *zap.Logger + log *zap.Logger } -// New creates a new certificate +// New creates a new rotate provider func New(c *api.Config, log *zap.Logger) (certs.Provider, error) { - // step: attempt to load the certificate certificate, err := tls.LoadX509KeyPair(c.TLSCertificate, c.TLSPrivateKey) if err != nil { return nil, err } - svc := &certificationRotation{ - certificate: certificate, - certificateFile: c.TLSCertificate, - log: log, - privateKeyFile: c.TLSPrivateKey, + svc := &provider{ + certificate: certificate, + config: c, + log: log, } - - // start watching the certificates if err := svc.watch(); err != nil { return nil, err } @@ -64,26 +55,26 @@ func New(c *api.Config, log *zap.Logger) (certs.Provider, error) { } // watch is responsible for adding a file notification and watch on the files for changes -func (c *certificationRotation) watch() error { - c.log.Info("adding a file watch on the certificates, certificate", - zap.String("certificate", c.certificateFile), - zap.String("private_key", c.privateKeyFile)) +func (p *provider) watch() error { + p.log.Info("adding a file watch on the tls certificates", + zap.String("certificate", p.tlsCertificate()), + zap.String("private_key", p.tlsPrivateKey())) watcher, err := fsnotify.NewWatcher() if err != nil { return err } // add the files to the watch list - for _, x := range []string{c.certificateFile, c.privateKeyFile} { + for _, x := range []string{p.tlsCertificate(), p.tlsPrivateKey()} { if err := watcher.Add(path.Dir(x)); err != nil { return fmt.Errorf("unable to add watch on directory: %s, error: %s", path.Dir(x), err) } } // step: watching for events - filewatchPaths := []string{c.certificateFile, c.privateKeyFile} + filewatchPaths := []string{p.tlsCertificate(), p.tlsPrivateKey()} go func() { - c.log.Info("starting to watch changes to the tls certificate files") + p.log.Info("starting to watch changes to the tls certificate files") for { select { case event := <-watcher.Events: @@ -93,19 +84,17 @@ func (c *certificationRotation) watch() error { continue } // step: reload the certificate - certificate, err := tls.LoadX509KeyPair(c.certificateFile, c.privateKeyFile) + certificate, err := tls.LoadX509KeyPair(p.tlsCertificate(), p.tlsPrivateKey()) if err != nil { - c.log.Error("unable to load the updated certificate", + p.log.Error("unable to load the updated certificate", zap.String("filename", event.Name), zap.Error(err)) } - // step: load the new certificate - c.storeCertificate(certificate) - // step: print a debug message for us - c.log.Info("replacing the server certifacte with updated version") + p.storeCertificate(certificate) + p.log.Info("replacing the server certifacte with updated version") } case err := <-watcher.Errors: - c.log.Error("recieved an error from the file watcher", zap.Error(err)) + p.log.Error("recieved an error from the file watcher", zap.Error(err)) } } }() @@ -114,18 +103,26 @@ func (c *certificationRotation) watch() error { } // storeCertificate provides entrypoint to update the certificate -func (c *certificationRotation) storeCertificate(certifacte tls.Certificate) error { - c.Lock() - defer c.Unlock() - c.certificate = certifacte +func (p *provider) storeCertificate(certifacte tls.Certificate) error { + p.Lock() + defer p.Unlock() + p.certificate = certifacte return nil } // GetCertificate is responsible for retrieving -func (c *certificationRotation) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - c.RLock() - defer c.RUnlock() +func (p *provider) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + p.RLock() + defer p.RUnlock() + + return &p.certificate, nil +} + +func (p *provider) tlsCertificate() string { + return p.config.TLSCertificate +} - return &c.certificate, nil +func (p *provider) tlsPrivateKey() string { + return p.config.TLSPrivateKey } diff --git a/pkg/certs/rotate/main_test.go b/pkg/certs/rotate/main_test.go index ce137dda5..58615c7ba 100644 --- a/pkg/certs/rotate/main_test.go +++ b/pkg/certs/rotate/main_test.go @@ -30,12 +30,10 @@ const ( testPrivateKeyFile = "../../../tests/proxy-key.pem" ) -func newTestCertificateRotator(t *testing.T) *certificationRotation { +func newTestCertificateRotator(t *testing.T) *provider { p, err := New(&api.Config{TLSCertificate: testCertificateFile, TLSPrivateKey: testPrivateKeyFile}, zap.NewNop()) - c := p.(*certificationRotation) + c := p.(*provider) assert.NotNil(t, c) - assert.Equal(t, testCertificateFile, c.certificateFile) - assert.Equal(t, testPrivateKeyFile, c.privateKeyFile) if !assert.NoError(t, err) { t.Fatalf("unable to create the certificate rotator, error: %s", err) } diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index e46ca5669..b5945f5f2 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -40,4 +40,6 @@ var ( ErrUserInfoValidation = errors.New("token not validate by userinfo endpoint") // ErrHostNotConfigured indicates the hostname was not found ErrHostNotConfigured = errors.New("acme/autocert: host not configured") + // ErrInvalidFormat indicates the config format is unreadable + ErrInvalidFormat = errors.New("invalid config format") ) diff --git a/pkg/server/handlers.go b/pkg/server/handlers.go index 212b59072..555e70cc3 100644 --- a/pkg/server/handlers.go +++ b/pkg/server/handlers.go @@ -43,8 +43,6 @@ func (r *oauthProxy) getRedirectionURL(w http.ResponseWriter, req *http.Request) var redirect string switch r.config.RedirectionURL { case "": - // need to determine the scheme, cx.Request.URL.Scheme doesn't have it, best way is to default - // and then check for TLS scheme := constants.HTTPSchema if req.TLS != nil { scheme = constants.HTTPSSchema @@ -91,7 +89,6 @@ func (r *oauthProxy) oauthAuthorizationHandler(w http.ResponseWriter, req *http. model["redirect"] = authURL w.WriteHeader(http.StatusOK) r.Render(w, path.Base(r.config.SignInPage), utils.MergeMaps(model, r.config.Tags)) - return } @@ -275,7 +272,6 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { // the user can specify a url to redirect the back redirectURL := req.URL.Query().Get("redirect") - // step: drop the access token user, err := r.getIdentity(req) if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -305,7 +301,6 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { } revocationURL := utils.DefaultTo(r.config.RevocationEndpoint, revokeDefault) - // step: do we have a revocation endpoint? if revocationURL != "" { client, err := r.client.OAuthClient() if err != nil { @@ -335,7 +330,6 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { return } - // step: check the response switch response.StatusCode { case http.StatusNoContent: r.log.Info("successfully logged out of the endpoint", zap.String("email", user.email)) @@ -346,7 +340,7 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { zap.String("response", fmt.Sprintf("%s", content))) } } - // step: should we redirect the user + if redirectURL != "" { r.redirectToURL(redirectURL, w, req) } @@ -359,7 +353,6 @@ func (r *oauthProxy) expirationHandler(w http.ResponseWriter, req *http.Request) w.WriteHeader(http.StatusUnauthorized) return } - if user.isExpired() { w.WriteHeader(http.StatusUnauthorized) return @@ -432,7 +425,7 @@ func (r *oauthProxy) proxyMetricsHandler(w http.ResponseWriter, req *http.Reques } // retrieveRefreshToken retrieves the refresh token from store or cookie -func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) (token, ecrypted string, err error) { +func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) (token, encrypted string, err error) { switch r.useStore() { case true: token, err = r.GetRefreshToken(user.token) @@ -443,7 +436,7 @@ func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) return } - ecrypted = token // returns encryped, avoid encoding twice + encrypted = token token, err = utils.DecodeText(token, r.config.EncryptionKey) return } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index f11950f18..0166a875a 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -39,10 +39,10 @@ import ( "unicode" "unicode/utf8" + "github.com/gambol99/go-oidc/jose" "github.com/gambol99/keycloak-proxy/pkg/constants" "github.com/gambol99/keycloak-proxy/pkg/errors" - "github.com/gambol99/go-oidc/jose" "github.com/urfave/cli" "gopkg.in/yaml.v2" ) @@ -57,12 +57,15 @@ func ReadConfigFile(filename string, data interface{}) error { if err != nil { return err } - // step: attempt to un-marshal the data switch ext := filepath.Ext(filename); ext { - case "json": + case ".json": err = json.Unmarshal(content, data) - default: + case ".yaml": + fallthrough + case ".yml": err = yaml.Unmarshal(content, data) + default: + return errors.ErrInvalidFormat } return err @@ -150,7 +153,6 @@ func DefaultTo(v, d string) string { if v != "" { return v } - return d } @@ -224,7 +226,6 @@ func TransferBytes(src io.Reader, dest io.Writer, wg *sync.WaitGroup) (int64, er // TryUpdateConnection attempt to upgrade the connection to a http pdy stream func TryUpdateConnection(req *http.Request, writer http.ResponseWriter, endpoint *url.URL) error { - // step: dial the endpoint tlsConn, err := TryDialEndpoint(endpoint) if err != nil { return err @@ -243,7 +244,7 @@ func TryUpdateConnection(req *http.Request, writer http.ResponseWriter, endpoint return err } - // step: copy the date between client and upstream endpoint + // step: copy the data between client and upstream endpoint var wg sync.WaitGroup wg.Add(2) go TransferBytes(tlsConn, clientConn, &wg) @@ -282,8 +283,6 @@ func FindCookie(name string, cookies []*http.Cookie) *http.Cookie { // ToHeader is a helper method to play nice in the headers func ToHeader(v string) string { var list []string - - // step: filter out any symbols and convert to dashes for _, x := range symbolsFilter.Split(v, -1) { list = append(list, Capitalize(x)) } From 2b9287570b25aa173b0522cf84011fec666c8774 Mon Sep 17 00:00:00 2001 From: Rohith Date: Sun, 23 Jul 2017 19:45:40 +0100 Subject: [PATCH 4/7] - updated the Makefile to get the correct version file --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 6ec909b7b..27f1a6783 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ ROOT_DIR=${PWD} HARDWARE=$(shell uname -m) GIT_SHA=$(shell git --no-pager describe --always --dirty) BUILD_TIME=$(shell date '+%s') -VERSION ?= $(shell awk '/release.*=/ { print $$3 }' doc.go | sed 's/"//g') +VERSION ?= $(shell awk '/Release.*=/ { print $$3 }' pkg/constants/const.go | sed 's/"//g') DEPS=$(shell go list -f '{{range .TestImports}}{{.}} {{end}}' ./...) PACKAGES=$(shell go list ./... | grep -v vendor) LFLAGS ?= -X constants.Gitsha=${GIT_SHA} -X constants.Compiled=${BUILD_TIME} From 0356a4260f33b2093368517f40ed26c62fd8c27c Mon Sep 17 00:00:00 2001 From: Rohith Date: Sun, 23 Jul 2017 19:45:58 +0100 Subject: [PATCH 5/7] - updated the Dockerfile to use stages --- Dockerfile | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index 67990e013..405b55dfb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,16 +1,23 @@ -FROM alpine:3.6 -MAINTAINER Rohith Jayawardene +FROM golang:1.8 as build +RUN go get -d github.com/gambol99/keycloak-proxy \ + && cd /go/src/github.com/gambol99/keycloak-proxy \ + && make static + +FROM ubuntu as certs +RUN apt-get update && apt-get install -y ca-certificates + +FROM scratch +COPY --from=build /go/src/github.com/gambol99/keycloak-proxy/bin/keycloak-proxy /opt/keycloak-proxy +COPY --from=certs /etc/ssl/certs /etc/ssl/certs LABEL Name=keycloak-proxy \ + Maintainer="Rohith Jayawardene " \ Release=https://github.com/gambol99/keycloak-proxy \ Url=https://github.com/gambol99/keycloak-proxy \ Help=https://github.com/gambol99/keycloak-proxy/issues -RUN apk add ca-certificates --update - ADD templates/ /opt/templates -ADD bin/keycloak-proxy /opt/keycloak-proxy -WORKDIR "/opt" +WORKDIR /opt -ENTRYPOINT [ "/opt/keycloak-proxy" ] +CMD [ "/opt/keycloak-proxy" ] From dc9cb740b2b3842e4f4dec37bb90f3f37b85ec8b Mon Sep 17 00:00:00 2001 From: Rohith Date: Wed, 2 Aug 2017 23:50:17 +0100 Subject: [PATCH 6/7] a --- CHANGELOG.md | 1 + config_sample.yml | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b72ad55b..26fc7c092 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ FEATURES * moved to use zap for the logging [#PR237](https://github.com/gambol99/keycloak-proxy/pull/237) * making the X-Auth-Token optional in the upstream headers via the --enable-token-header [#PR247](https://github.com/gambol99/keycloak-proxy/pull/247) * the upstream url is optional, meaning when not configured via --upstream-url is will proxy all requests to the Host header [#PR248](https://github.com/gambol99/keycloak-proxy/pull/248) +* updated the Dockerfile to use stages and build in one go [#PR?](https://github.com/gambol99/keycloak-proxy/pull/?] * adding the ability to load a CA authority to provide trust on upstream endpoint [#PR248](https://github.com/gambol99/keycloak-proxy/pull/248) BREAKING CHANGES: diff --git a/config_sample.yml b/config_sample.yml index 2662dfc51..40bc7849d 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -50,8 +50,9 @@ headers: # a map of claims that MUST exist in the token presented and the value is it MUST match # So for example, you could match the audience or the issuer or some custom attribute -virtual: -- hostname: default +hosts: +- hostname: _default_ + upstream-url: https://ingress.svc.cluster.local resources: - uri: /admin/test # the methods on this url that should be protected, if missing, we assuming all From 530ff934e66bc5897656e9e568b1673dcc9b481f Mon Sep 17 00:00:00 2001 From: Rohith Date: Sat, 5 Aug 2017 17:57:16 +0100 Subject: [PATCH 7/7] Utils unit testing fix - fixing up the ReadConfigFile unit test --- pkg/api/config.go | 3 +-- pkg/utils/utils_test.go | 49 ++++++++++++++++++++++------------------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/pkg/api/config.go b/pkg/api/config.go index 728cad270..e86a9a611 100644 --- a/pkg/api/config.go +++ b/pkg/api/config.go @@ -40,8 +40,7 @@ func NewDefaultConfig() *Config { Tags: make(map[string]string), UpstreamKeepaliveTimeout: time.Duration(10) * time.Second, UpstreamTimeout: time.Duration(10) * time.Second, - UseLetsEncrypt: false, - LetsEncryptCacheDir: "./cache/", + UseLetsEncrypt: false, } } diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index a2800cfa9..258f41180 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -26,6 +26,7 @@ import ( "testing" "time" + "github.com/gambol99/keycloak-proxy/pkg/api" "github.com/gambol99/keycloak-proxy/pkg/constants" "github.com/stretchr/testify/assert" @@ -409,21 +410,28 @@ func TestMergeMaps(t *testing.T) { } } -func TestReadConfiguration(t *testing.T) { - var test struct { - ID int `yaml:"id"` - Name string `yaml:"name"` - } +func TestReadConfigurationNotFound(t *testing.T) { assert.Error(t, ReadConfigFile("not_found", nil)) - content := ` -id: 12 -name: test -` - file := writeFakeConfigFile(t, content) - assert.NoError(t, ReadConfigFile(file.Name(), &test)) - assert.Equal(t, 12, test.ID) - assert.Equal(t, "test", test.Name) - os.Remove(file.Name()) +} + +func TestReadConfigurationOK(t *testing.T) { + cs := []struct { + Content string + Expected api.Config + }{ + { + Content: "upstream-url: http://127.0.0.1:8080", + Expected: api.Config{Upstream: "http://127.0.0.1:8080"}, + }, + } + for _, c := range cs { + config := api.Config{} + filename := writeFakeConfigFile(t, c.Content) + defer os.Remove(filename) + assert.NoError(t, ReadConfigFile(filename, &config)) + assert.Equal(t, c.Expected, config) + + } } func getFakeURL(location string) *url.URL { @@ -431,16 +439,11 @@ func getFakeURL(location string) *url.URL { return u } -func writeFakeConfigFile(t *testing.T, content string) *os.File { - f, err := ioutil.TempFile("", "node_label_file") - if err != nil { - t.Fatalf("unexpected error creating node_label_file: %v", err) - } - f.Close() - - if err := ioutil.WriteFile(f.Name(), []byte(content), 0700); err != nil { +func writeFakeConfigFile(t *testing.T, content string) string { + filename := "/tmp/keycloak_proxy_test.yml" + if err := ioutil.WriteFile(filename, []byte(content), 0700); err != nil { t.Fatalf("unexpected error writing node label file: %v", err) } - return f + return filename }