diff --git a/example/idp/idp.go b/example/idp/idp.go index 81ecfeed..2ff8fc87 100644 --- a/example/idp/idp.go +++ b/example/idp/idp.go @@ -2,18 +2,15 @@ package main import ( "crypto" + "crypto/x509" "encoding/pem" "flag" - "log" "net/url" - "golang.org/x/crypto/bcrypt" - - "github.com/zenazn/goji" - - "crypto/x509" - + "github.com/crewjam/saml/logger" "github.com/crewjam/saml/samlidp" + "github.com/zenazn/goji" + "golang.org/x/crypto/bcrypt" ) var key = func() crypto.PrivateKey { @@ -73,22 +70,24 @@ UzreO96WzlBBMtY= }() func main() { + logr := logger.DefaultLogger baseURLstr := flag.String("idp", "", "The URL to the IDP") flag.Parse() baseURL, err := url.Parse(*baseURLstr) if err != nil { - log.Fatalf("cannot parse base URL: %v", err) + logr.Fatalf("cannot parse base URL: %v", err) } idpServer, err := samlidp.New(samlidp.Options{ URL: *baseURL, Key: key, + Logger: logr, Certificate: cert, Store: &samlidp.MemoryStore{}, }) if err != nil { - log.Fatalf("%s", err) + logr.Fatalf("%s", err) } hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("hunter2"), bcrypt.DefaultCost) @@ -101,7 +100,7 @@ func main() { GivenName: "Alice", }) if err != nil { - log.Fatalf("%s", err) + logr.Fatalf("%s", err) } err = idpServer.Store.Put("/users/bob", samlidp.User{ @@ -114,7 +113,7 @@ func main() { GivenName: "Bob", }) if err != nil { - log.Fatalf("%s", err) + logr.Fatalf("%s", err) } goji.Handle("/*", idpServer) diff --git a/example/service.go b/example/service.go index 57d2d9f0..6a4f03f5 100644 --- a/example/service.go +++ b/example/service.go @@ -9,7 +9,6 @@ import ( "encoding/xml" "flag" "fmt" - "log" "net/http" "net/url" "strings" @@ -19,6 +18,7 @@ import ( "github.com/zenazn/goji" "github.com/zenazn/goji/web" + "github.com/crewjam/saml/logger" "github.com/crewjam/saml/samlsp" ) @@ -100,6 +100,7 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ== ) func main() { + logr := logger.DefaultLogger rootURLstr := flag.String("url", "https://962766ce.ngrok.io", "The base URL of this service") idpMetadataURLstr := flag.String("idp", "https://516becc2.ngrok.io/metadata", "The metadata URL for the IDP") flag.Parse() @@ -126,12 +127,13 @@ func main() { samlSP, err := samlsp.New(samlsp.Options{ URL: *rootURL, Key: keyPair.PrivateKey.(*rsa.PrivateKey), + Logger: logr, Certificate: keyPair.Leaf, AllowIDPInitiated: true, IDPMetadataURL: idpMetadataURL, }) if err != nil { - log.Fatalf("%s", err) + logr.Fatalf("%s", err) } // register with the service provider diff --git a/identity_provider.go b/identity_provider.go index 1b48ec69..921c7746 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -11,7 +11,6 @@ import ( "fmt" "io" "io/ioutil" - "log" "net/http" "net/url" "os" @@ -20,6 +19,7 @@ import ( "time" "github.com/beevik/etree" + "github.com/crewjam/saml/logger" "github.com/crewjam/saml/xmlenc" dsig "github.com/russellhaering/goxmldsig" ) @@ -79,6 +79,7 @@ type ServiceProviderProvider interface { // and password). type IdentityProvider struct { Key crypto.PrivateKey + Logger logger.Interface Certificate *x509.Certificate MetadataURL url.URL SSOURL url.URL @@ -169,13 +170,13 @@ func (idp *IdentityProvider) ServeMetadata(w http.ResponseWriter, r *http.Reques func (idp *IdentityProvider) ServeSSO(w http.ResponseWriter, r *http.Request) { req, err := NewIdpAuthnRequest(idp, r) if err != nil { - log.Printf("failed to parse request: %s", err) + idp.Logger.Printf("failed to parse request: %s", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } if err := req.Validate(); err != nil { - log.Printf("failed to validate request: %s", err) + idp.Logger.Printf("failed to validate request: %s", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } @@ -190,12 +191,12 @@ func (idp *IdentityProvider) ServeSSO(w http.ResponseWriter, r *http.Request) { // we have a valid session and must make a SAML assertion if err := req.MakeAssertion(session); err != nil { - log.Printf("failed to make assertion: %s", err) + idp.Logger.Printf("failed to make assertion: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } if err := req.WriteResponse(w); err != nil { - log.Printf("failed to write response: %s", err) + idp.Logger.Printf("failed to write response: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -219,11 +220,11 @@ func (idp *IdentityProvider) ServeIDPInitiated(w http.ResponseWriter, r *http.Re var err error req.ServiceProviderMetadata, err = idp.ServiceProviderProvider.GetServiceProvider(r, serviceProviderID) if err == os.ErrNotExist { - log.Printf("cannot find service provider: %s", serviceProviderID) + idp.Logger.Printf("cannot find service provider: %s", serviceProviderID) http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } else if err != nil { - log.Printf("cannot find service provider %s: %v", serviceProviderID, err) + idp.Logger.Printf("cannot find service provider %s: %v", serviceProviderID, err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -234,12 +235,12 @@ func (idp *IdentityProvider) ServeIDPInitiated(w http.ResponseWriter, r *http.Re } if err := req.MakeAssertion(session); err != nil { - log.Printf("failed to make assertion: %s", err) + idp.Logger.Printf("failed to make assertion: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } if err := req.WriteResponse(w); err != nil { - log.Printf("failed to write response: %s", err) + idp.Logger.Printf("failed to write response: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } diff --git a/identity_provider_test.go b/identity_provider_test.go index ffde5200..cc95e79c 100644 --- a/identity_provider_test.go +++ b/identity_provider_test.go @@ -18,6 +18,7 @@ import ( "os" "github.com/beevik/etree" + "github.com/crewjam/saml/logger" "github.com/crewjam/saml/testsaml" "github.com/crewjam/saml/xmlenc" "github.com/dgrijalva/jwt-go" @@ -125,6 +126,7 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ== test.IDP = IdentityProvider{ Key: test.Key, Certificate: test.Certificate, + Logger: logger.DefaultLogger, MetadataURL: mustParseURL("https://idp.example.com/saml/metadata"), SSOURL: mustParseURL("https://idp.example.com/saml/sso"), ServiceProviderProvider: &mockServiceProviderProvider{ diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 00000000..c211aba6 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,31 @@ +package logger + +import ( + "log" + "os" +) + +// Interface provides the minimal logging interface +type Interface interface { + // Printf prints to the logger using the format. + Printf(format string, v ...interface{}) + // Print prints to the logger. + Print(v ...interface{}) + // Println prints new line. + Println(v ...interface{}) + // Fatal is equivalent to Print() followed by a call to os.Exit(1). + Fatal(v ...interface{}) + // Fatalf is equivalent to Printf() followed by a call to os.Exit(1). + Fatalf(format string, v ...interface{}) + // Fatalln is equivalent to Println() followed by a call to os.Exit(1). + Fatalln(v ...interface{}) + // Panic is equivalent to Print() followed by a call to panic(). + Panic(v ...interface{}) + // Panicf is equivalent to Printf() followed by a call to panic(). + Panicf(format string, v ...interface{}) + // Panicln is equivalent to Println() followed by a call to panic(). + Panicln(v ...interface{}) +} + +// DefaultLogger logs messages to os.Stdout +var DefaultLogger = log.New(os.Stdout, "", log.LstdFlags) diff --git a/samlidp/samlidp.go b/samlidp/samlidp.go index cd6228a3..237eb69b 100644 --- a/samlidp/samlidp.go +++ b/samlidp/samlidp.go @@ -4,13 +4,13 @@ package samlidp import ( "crypto" + "crypto/x509" "net/http" "net/url" "sync" - "crypto/x509" - "github.com/crewjam/saml" + "github.com/crewjam/saml/logger" "github.com/zenazn/goji/web" ) @@ -18,6 +18,7 @@ import ( type Options struct { URL url.URL Key crypto.PrivateKey + Logger logger.Interface Certificate *x509.Certificate Store Store } @@ -35,6 +36,7 @@ type Options struct { type Server struct { http.Handler idpConfigMu sync.RWMutex // protects calls into the IDP + logger logger.Interface serviceProviders map[string]*saml.Metadata IDP saml.IdentityProvider // the underlying IDP Store Store // the data store @@ -46,16 +48,24 @@ func New(opts Options) (*Server, error) { metadataURL.Path = metadataURL.Path + "/metadata" ssoURL := opts.URL ssoURL.Path = ssoURL.Path + "/sso" + logr := opts.Logger + if logr == nil { + logr = logger.DefaultLogger + } + s := &Server{ serviceProviders: map[string]*saml.Metadata{}, IDP: saml.IdentityProvider{ Key: opts.Key, + Logger: logr, Certificate: opts.Certificate, MetadataURL: metadataURL, SSOURL: ssoURL, }, - Store: opts.Store, + logger: logr, + Store: opts.Store, } + s.IDP.SessionProvider = s s.IDP.ServiceProviderProvider = s diff --git a/samlidp/samlidp_test.go b/samlidp/samlidp_test.go index 46ae61ba..f98259c0 100644 --- a/samlidp/samlidp_test.go +++ b/samlidp/samlidp_test.go @@ -16,6 +16,7 @@ import ( "crypto/rsa" "github.com/crewjam/saml" + "github.com/crewjam/saml/logger" "github.com/dgrijalva/jwt-go" ) @@ -123,6 +124,7 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ== MetadataURL: mustParseURL("https://sp.example.com/saml2/metadata"), AcsURL: mustParseURL("https://sp.example.com/saml2/acs"), IDPMetadata: &saml.Metadata{}, + Logger: logger.DefaultLogger, } test.Key = mustParsePrivateKey("-----BEGIN RSA PRIVATE KEY-----\nMIICXgIBAAKBgQDU8wdiaFmPfTyRYuFlVPi866WrH/2JubkHzp89bBQopDaLXYxi\n3PTu3O6Q/KaKxMOFBqrInwqpv/omOGZ4ycQ51O9I+Yc7ybVlW94lTo2gpGf+Y/8E\nPsVbnZaFutRctJ4dVIp9aQ2TpLiGT0xX1OzBO/JEgq9GzDRf+B+eqSuglwIDAQAB\nAoGBAMuy1eN6cgFiCOgBsB3gVDdTKpww87Qk5ivjqEt28SmXO13A1KNVPS6oQ8SJ\nCT5Azc6X/BIAoJCURVL+LHdqebogKljhH/3yIel1kH19vr4E2kTM/tYH+qj8afUS\nJEmArUzsmmK8ccuNqBcllqdwCZjxL4CHDUmyRudFcHVX9oyhAkEA/OV1OkjM3CLU\nN3sqELdMmHq5QZCUihBmk3/N5OvGdqAFGBlEeewlepEVxkh7JnaNXAXrKHRVu/f/\nfbCQxH+qrwJBANeQERF97b9Sibp9xgolb749UWNlAdqmEpmlvmS202TdcaaT1msU\n4rRLiQN3X9O9mq4LZMSVethrQAdX1whawpkCQQDk1yGf7xZpMJ8F4U5sN+F4rLyM\nRq8Sy8p2OBTwzCUXXK+fYeXjybsUUMr6VMYTRP2fQr/LKJIX+E5ZxvcIyFmDAkEA\nyfjNVUNVaIbQTzEbRlRvT6MqR+PTCefC072NF9aJWR93JimspGZMR7viY6IM4lrr\nvBkm0F5yXKaYtoiiDMzlOQJADqmEwXl0D72ZG/2KDg8b4QZEmC9i5gidpQwJXUc6\nhU+IVQoLxRq0fBib/36K9tcrrO5Ba4iEvDcNY+D8yGbUtA==\n-----END RSA PRIVATE KEY-----\n") test.Certificate = mustParseCertificate("-----BEGIN CERTIFICATE-----\nMIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJV\nUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0\nMB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMx\nCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCB\nnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9\nibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmH\nO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKv\nRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgk\nakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeT\nQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvn\nOwJlNCASPZRH/JmF8tX0hoHuAQ==\n-----END CERTIFICATE-----\n") @@ -131,10 +133,11 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ== var err error test.Server, err = New(Options{ - URL: url.URL{Scheme: "https", Host: "idp.example.com"}, - Key: test.Key, Certificate: test.Certificate, + Key: test.Key, + Logger: logger.DefaultLogger, Store: &test.Store, + URL: url.URL{Scheme: "https", Host: "idp.example.com"}, }) c.Assert(err, IsNil) diff --git a/samlidp/service.go b/samlidp/service.go index c982f97b..befdab95 100644 --- a/samlidp/service.go +++ b/samlidp/service.go @@ -4,7 +4,6 @@ import ( "encoding/json" "encoding/xml" "fmt" - "log" "net/http" "os" @@ -40,7 +39,7 @@ func (s *Server) GetServiceProvider(r *http.Request, serviceProviderID string) ( func (s *Server) HandleListServices(c web.C, w http.ResponseWriter, r *http.Request) { services, err := s.Store.List("/services/") if err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -56,7 +55,7 @@ func (s *Server) HandleGetService(c web.C, w http.ResponseWriter, r *http.Reques service := Service{} err := s.Store.Get(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service) if err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -68,14 +67,14 @@ func (s *Server) HandleGetService(c web.C, w http.ResponseWriter, r *http.Reques func (s *Server) HandlePutService(c web.C, w http.ResponseWriter, r *http.Request) { service := Service{} if err := xml.NewDecoder(r.Body).Decode(&service.Metadata); err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } err := s.Store.Put(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service) if err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -92,13 +91,13 @@ func (s *Server) HandleDeleteService(c web.C, w http.ResponseWriter, r *http.Req service := Service{} err := s.Store.Get(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service) if err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } if err := s.Store.Delete(fmt.Sprintf("/services/%s", c.URLParams["id"])); err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } diff --git a/samlidp/shortcut.go b/samlidp/shortcut.go index 7c4e2431..151a84ea 100644 --- a/samlidp/shortcut.go +++ b/samlidp/shortcut.go @@ -3,7 +3,6 @@ package samlidp import ( "encoding/json" "fmt" - "log" "net/http" "github.com/zenazn/goji/web" @@ -91,7 +90,7 @@ func (s *Server) HandleIDPInitiated(c web.C, w http.ResponseWriter, r *http.Requ shortcutName := c.URLParams["shortcut"] shortcut := Shortcut{} if err := s.Store.Get(fmt.Sprintf("/shortcuts/%s", shortcutName), &shortcut); err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } diff --git a/samlidp/user.go b/samlidp/user.go index a60c17ab..23751020 100644 --- a/samlidp/user.go +++ b/samlidp/user.go @@ -3,7 +3,6 @@ package samlidp import ( "encoding/json" "fmt" - "log" "net/http" "github.com/zenazn/goji/web" @@ -28,7 +27,7 @@ type User struct { func (s *Server) HandleListUsers(c web.C, w http.ResponseWriter, r *http.Request) { users, err := s.Store.List("/users/") if err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -44,7 +43,7 @@ func (s *Server) HandleGetUser(c web.C, w http.ResponseWriter, r *http.Request) user := User{} err := s.Store.Get(fmt.Sprintf("/users/%s", c.URLParams["id"]), &user) if err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -59,7 +58,7 @@ func (s *Server) HandleGetUser(c web.C, w http.ResponseWriter, r *http.Request) func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) { user := User{} if err := json.NewDecoder(r.Body).Decode(&user); err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } @@ -69,7 +68,7 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) var err error user.HashedPassword, err = bcrypt.GenerateFromPassword([]byte(*user.PlaintextPassword), bcrypt.DefaultCost) if err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -82,7 +81,7 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) case err == ErrNotFound: // nop default: - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -91,7 +90,7 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) err := s.Store.Put(fmt.Sprintf("/users/%s", c.URLParams["id"]), &user) if err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -102,7 +101,7 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) func (s *Server) HandleDeleteUser(c web.C, w http.ResponseWriter, r *http.Request) { err := s.Store.Delete(fmt.Sprintf("/users/%s", c.URLParams["id"])) if err != nil { - log.Printf("ERROR: %s", err) + s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } diff --git a/samlsp/middleware.go b/samlsp/middleware.go index 54061998..69e6ef4f 100644 --- a/samlsp/middleware.go +++ b/samlsp/middleware.go @@ -1,19 +1,16 @@ package samlsp import ( + "crypto/x509" "encoding/base64" "encoding/xml" "fmt" - "log" "net/http" "strings" "time" - "github.com/dgrijalva/jwt-go" - - "crypto/x509" - "github.com/crewjam/saml" + "github.com/dgrijalva/jwt-go" ) // Middleware implements middleware than allows a web application @@ -86,7 +83,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { assertion, err := m.ServiceProvider.ParseResponse(r, m.getPossibleRequestIDs(r)) if err != nil { if parseErr, ok := err.(*saml.InvalidResponseError); ok { - log.Printf("RESPONSE: ===\n%s\n===\nNOW: %s\nERROR: %s", + m.ServiceProvider.Logger.Printf("RESPONSE: ===\n%s\n===\nNOW: %s\nERROR: %s", parseErr.Response, parseErr.Now, parseErr.PrivateErr) } http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) @@ -185,7 +182,7 @@ func (m *Middleware) getPossibleRequestIDs(r *http.Request) []string { if !strings.HasPrefix(cookie.Name, "saml_") { continue } - log.Printf("getPossibleRequestIDs: cookie: %s", cookie.String()) + m.ServiceProvider.Logger.Printf("getPossibleRequestIDs: cookie: %s", cookie.String()) jwtParser := jwt.Parser{ ValidMethods: []string{jwtSigningMethod.Name}, @@ -195,7 +192,7 @@ func (m *Middleware) getPossibleRequestIDs(r *http.Request) []string { return secretBlock, nil }) if err != nil || !token.Valid { - log.Printf("... invalid token %s", err) + m.ServiceProvider.Logger.Printf("... invalid token %s", err) continue } claims := token.Claims.(jwt.MapClaims) @@ -225,7 +222,7 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion if r.Form.Get("RelayState") != "" { stateCookie, err := r.Cookie(fmt.Sprintf("saml_%s", r.Form.Get("RelayState"))) if err != nil { - log.Printf("cannot find corresponding cookie: %s", fmt.Sprintf("saml_%s", r.Form.Get("RelayState"))) + m.ServiceProvider.Logger.Printf("cannot find corresponding cookie: %s", fmt.Sprintf("saml_%s", r.Form.Get("RelayState"))) http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } @@ -237,7 +234,7 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion return secretBlock, nil }) if err != nil || !state.Valid { - log.Printf("Cannot decode state JWT: %s (%s)", err, stateCookie.Value) + m.ServiceProvider.Logger.Printf("Cannot decode state JWT: %s (%s)", err, stateCookie.Value) http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } @@ -313,15 +310,15 @@ func (m *Middleware) IsAuthorized(r *http.Request) bool { return secretBlock, nil }) if err != nil || !token.Valid { - log.Printf("ERROR: invalid token: %s", err) + m.ServiceProvider.Logger.Printf("ERROR: invalid token: %s", err) return false } if err := tokenClaims.StandardClaims.Valid(); err != nil { - log.Printf("ERROR: invalid token claims: %s", err) + m.ServiceProvider.Logger.Printf("ERROR: invalid token claims: %s", err) return false } if tokenClaims.Audience != m.ServiceProvider.Metadata().EntityID { - log.Printf("ERROR: invalid audience: %s", err) + m.ServiceProvider.Logger.Printf("ERROR: invalid audience: %s", err) return false } diff --git a/samlsp/middleware_test.go b/samlsp/middleware_test.go index 42649858..bc6a2581 100644 --- a/samlsp/middleware_test.go +++ b/samlsp/middleware_test.go @@ -20,6 +20,7 @@ import ( "crypto/x509" "github.com/crewjam/saml" + "github.com/crewjam/saml/logger" "github.com/crewjam/saml/testsaml" ) @@ -73,6 +74,7 @@ func (test *MiddlewareTest) SetUpTest(c *C) { MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), IDPMetadata: &saml.Metadata{}, + Logger: logger.DefaultLogger, }, CookieName: "ttt", CookieMaxAge: time.Hour * 2, diff --git a/samlsp/samlsp.go b/samlsp/samlsp.go index e74746ac..a05850b9 100644 --- a/samlsp/samlsp.go +++ b/samlsp/samlsp.go @@ -3,24 +3,24 @@ package samlsp import ( + "crypto/rsa" + "crypto/x509" "encoding/xml" "fmt" "io/ioutil" - "log" "net/http" "net/url" "time" - "crypto/rsa" - "crypto/x509" - "github.com/crewjam/saml" + "github.com/crewjam/saml/logger" ) // Options represents the parameters for creating a new middleware type Options struct { URL url.URL Key *rsa.PrivateKey + Logger logger.Interface Certificate *x509.Certificate AllowIDPInitiated bool IDPMetadata *saml.Metadata @@ -34,10 +34,15 @@ func New(opts Options) (*Middleware, error) { metadataURL.Path = metadataURL.Path + "/saml/metadata" acsURL := opts.URL acsURL.Path = acsURL.Path + "/saml/acs" + logr := opts.Logger + if logr == nil { + logr = logger.DefaultLogger + } m := &Middleware{ ServiceProvider: saml.ServiceProvider{ Key: opts.Key, + Logger: logr, Certificate: opts.Certificate, MetadataURL: metadataURL, AcsURL: acsURL, @@ -76,7 +81,7 @@ func New(opts Options) (*Middleware, error) { if i > 10 { return nil, err } - log.Printf("ERROR: %s: %s (will retry)", opts.IDPMetadataURL, err) + logr.Printf("ERROR: %s: %s (will retry)", opts.IDPMetadataURL, err) time.Sleep(5 * time.Second) continue } diff --git a/service_provider.go b/service_provider.go index b240ac4b..89d12a77 100644 --- a/service_provider.go +++ b/service_provider.go @@ -16,13 +16,16 @@ import ( "time" "github.com/beevik/etree" + "github.com/crewjam/saml/logger" "github.com/crewjam/saml/xmlenc" dsig "github.com/russellhaering/goxmldsig" "github.com/russellhaering/goxmldsig/etreeutils" ) +// NameIDFormat is the format of the id type NameIDFormat string +// Name ID formats const ( UnspecifiedNameIDFormat NameIDFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:unspecified" TransientNameIDFormat NameIDFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient" @@ -63,6 +66,9 @@ type ServiceProvider struct { // MetadataValidDuration is a duration used to calculate validUntil // attribute in the metadata endpoint MetadataValidDuration time.Duration + + // Logger is used to log messages for example in the event of errors + Logger logger.Interface } // MaxIssueDelay is the longest allowed time between when a SAML assertion is