Skip to content

Commit

Permalink
Use interface for logging to allow structured logs
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavel Nikolov authored and crewjam committed May 22, 2017
1 parent f102ca0 commit 6b5dd2d
Show file tree
Hide file tree
Showing 14 changed files with 117 additions and 62 deletions.
21 changes: 10 additions & 11 deletions example/idp/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@ package main

import (
"crypto"
"crypto/x509"
"encoding/pem"
"flag"
"log"
"net/url"

"golang.org/x/crypto/bcrypt"

"github.com/zenazn/goji"

"crypto/x509"

"github.com/crewjam/saml/logger"
"github.com/crewjam/saml/samlidp"
"github.com/zenazn/goji"
"golang.org/x/crypto/bcrypt"
)

var key = func() crypto.PrivateKey {
Expand Down Expand Up @@ -73,22 +70,24 @@ UzreO96WzlBBMtY=
}()

func main() {
logr := logger.DefaultLogger
baseURLstr := flag.String("idp", "", "The URL to the IDP")
flag.Parse()

baseURL, err := url.Parse(*baseURLstr)
if err != nil {
log.Fatalf("cannot parse base URL: %v", err)
logr.Fatalf("cannot parse base URL: %v", err)
}

idpServer, err := samlidp.New(samlidp.Options{
URL: *baseURL,
Key: key,
Logger: logr,
Certificate: cert,
Store: &samlidp.MemoryStore{},
})
if err != nil {
log.Fatalf("%s", err)
logr.Fatalf("%s", err)
}

hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("hunter2"), bcrypt.DefaultCost)
Expand All @@ -101,7 +100,7 @@ func main() {
GivenName: "Alice",
})
if err != nil {
log.Fatalf("%s", err)
logr.Fatalf("%s", err)
}

err = idpServer.Store.Put("/users/bob", samlidp.User{
Expand All @@ -114,7 +113,7 @@ func main() {
GivenName: "Bob",
})
if err != nil {
log.Fatalf("%s", err)
logr.Fatalf("%s", err)
}

goji.Handle("/*", idpServer)
Expand Down
6 changes: 4 additions & 2 deletions example/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"encoding/xml"
"flag"
"fmt"
"log"
"net/http"
"net/url"
"strings"
Expand All @@ -19,6 +18,7 @@ import (
"github.com/zenazn/goji"
"github.com/zenazn/goji/web"

"github.com/crewjam/saml/logger"
"github.com/crewjam/saml/samlsp"
)

Expand Down Expand Up @@ -100,6 +100,7 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ==
)

func main() {
logr := logger.DefaultLogger
rootURLstr := flag.String("url", "https://962766ce.ngrok.io", "The base URL of this service")
idpMetadataURLstr := flag.String("idp", "https://516becc2.ngrok.io/metadata", "The metadata URL for the IDP")
flag.Parse()
Expand All @@ -126,12 +127,13 @@ func main() {
samlSP, err := samlsp.New(samlsp.Options{
URL: *rootURL,
Key: keyPair.PrivateKey.(*rsa.PrivateKey),
Logger: logr,
Certificate: keyPair.Leaf,
AllowIDPInitiated: true,
IDPMetadataURL: idpMetadataURL,
})
if err != nil {
log.Fatalf("%s", err)
logr.Fatalf("%s", err)
}

// register with the service provider
Expand Down
19 changes: 10 additions & 9 deletions identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/url"
"os"
Expand All @@ -20,6 +19,7 @@ import (
"time"

"github.com/beevik/etree"
"github.com/crewjam/saml/logger"
"github.com/crewjam/saml/xmlenc"
dsig "github.com/russellhaering/goxmldsig"
)
Expand Down Expand Up @@ -79,6 +79,7 @@ type ServiceProviderProvider interface {
// and password).
type IdentityProvider struct {
Key crypto.PrivateKey
Logger logger.Interface
Certificate *x509.Certificate
MetadataURL url.URL
SSOURL url.URL
Expand Down Expand Up @@ -169,13 +170,13 @@ func (idp *IdentityProvider) ServeMetadata(w http.ResponseWriter, r *http.Reques
func (idp *IdentityProvider) ServeSSO(w http.ResponseWriter, r *http.Request) {
req, err := NewIdpAuthnRequest(idp, r)
if err != nil {
log.Printf("failed to parse request: %s", err)
idp.Logger.Printf("failed to parse request: %s", err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}

if err := req.Validate(); err != nil {
log.Printf("failed to validate request: %s", err)
idp.Logger.Printf("failed to validate request: %s", err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
Expand All @@ -190,12 +191,12 @@ func (idp *IdentityProvider) ServeSSO(w http.ResponseWriter, r *http.Request) {

// we have a valid session and must make a SAML assertion
if err := req.MakeAssertion(session); err != nil {
log.Printf("failed to make assertion: %s", err)
idp.Logger.Printf("failed to make assertion: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if err := req.WriteResponse(w); err != nil {
log.Printf("failed to write response: %s", err)
idp.Logger.Printf("failed to write response: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
Expand All @@ -219,11 +220,11 @@ func (idp *IdentityProvider) ServeIDPInitiated(w http.ResponseWriter, r *http.Re
var err error
req.ServiceProviderMetadata, err = idp.ServiceProviderProvider.GetServiceProvider(r, serviceProviderID)
if err == os.ErrNotExist {
log.Printf("cannot find service provider: %s", serviceProviderID)
idp.Logger.Printf("cannot find service provider: %s", serviceProviderID)
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
} else if err != nil {
log.Printf("cannot find service provider %s: %v", serviceProviderID, err)
idp.Logger.Printf("cannot find service provider %s: %v", serviceProviderID, err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
Expand All @@ -234,12 +235,12 @@ func (idp *IdentityProvider) ServeIDPInitiated(w http.ResponseWriter, r *http.Re
}

if err := req.MakeAssertion(session); err != nil {
log.Printf("failed to make assertion: %s", err)
idp.Logger.Printf("failed to make assertion: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if err := req.WriteResponse(w); err != nil {
log.Printf("failed to write response: %s", err)
idp.Logger.Printf("failed to write response: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
Expand Down
2 changes: 2 additions & 0 deletions identity_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"os"

"github.com/beevik/etree"
"github.com/crewjam/saml/logger"
"github.com/crewjam/saml/testsaml"
"github.com/crewjam/saml/xmlenc"
"github.com/dgrijalva/jwt-go"
Expand Down Expand Up @@ -125,6 +126,7 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ==
test.IDP = IdentityProvider{
Key: test.Key,
Certificate: test.Certificate,
Logger: logger.DefaultLogger,
MetadataURL: mustParseURL("https://idp.example.com/saml/metadata"),
SSOURL: mustParseURL("https://idp.example.com/saml/sso"),
ServiceProviderProvider: &mockServiceProviderProvider{
Expand Down
31 changes: 31 additions & 0 deletions logger/logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package logger

import (
"log"
"os"
)

// Interface provides the minimal logging interface
type Interface interface {
// Printf prints to the logger using the format.
Printf(format string, v ...interface{})
// Print prints to the logger.
Print(v ...interface{})
// Println prints new line.
Println(v ...interface{})
// Fatal is equivalent to Print() followed by a call to os.Exit(1).
Fatal(v ...interface{})
// Fatalf is equivalent to Printf() followed by a call to os.Exit(1).
Fatalf(format string, v ...interface{})
// Fatalln is equivalent to Println() followed by a call to os.Exit(1).
Fatalln(v ...interface{})
// Panic is equivalent to Print() followed by a call to panic().
Panic(v ...interface{})
// Panicf is equivalent to Printf() followed by a call to panic().
Panicf(format string, v ...interface{})
// Panicln is equivalent to Println() followed by a call to panic().
Panicln(v ...interface{})
}

// DefaultLogger logs messages to os.Stdout
var DefaultLogger = log.New(os.Stdout, "", log.LstdFlags)
16 changes: 13 additions & 3 deletions samlidp/samlidp.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@ package samlidp

import (
"crypto"
"crypto/x509"
"net/http"
"net/url"
"sync"

"crypto/x509"

"github.com/crewjam/saml"
"github.com/crewjam/saml/logger"
"github.com/zenazn/goji/web"
)

// Options represent the parameters to New() for creating a new IDP server
type Options struct {
URL url.URL
Key crypto.PrivateKey
Logger logger.Interface
Certificate *x509.Certificate
Store Store
}
Expand All @@ -35,6 +36,7 @@ type Options struct {
type Server struct {
http.Handler
idpConfigMu sync.RWMutex // protects calls into the IDP
logger logger.Interface
serviceProviders map[string]*saml.Metadata
IDP saml.IdentityProvider // the underlying IDP
Store Store // the data store
Expand All @@ -46,16 +48,24 @@ func New(opts Options) (*Server, error) {
metadataURL.Path = metadataURL.Path + "/metadata"
ssoURL := opts.URL
ssoURL.Path = ssoURL.Path + "/sso"
logr := opts.Logger
if logr == nil {
logr = logger.DefaultLogger
}

s := &Server{
serviceProviders: map[string]*saml.Metadata{},
IDP: saml.IdentityProvider{
Key: opts.Key,
Logger: logr,
Certificate: opts.Certificate,
MetadataURL: metadataURL,
SSOURL: ssoURL,
},
Store: opts.Store,
logger: logr,
Store: opts.Store,
}

s.IDP.SessionProvider = s
s.IDP.ServiceProviderProvider = s

Expand Down
7 changes: 5 additions & 2 deletions samlidp/samlidp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"crypto/rsa"

"github.com/crewjam/saml"
"github.com/crewjam/saml/logger"
"github.com/dgrijalva/jwt-go"
)

Expand Down Expand Up @@ -123,6 +124,7 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ==
MetadataURL: mustParseURL("https://sp.example.com/saml2/metadata"),
AcsURL: mustParseURL("https://sp.example.com/saml2/acs"),
IDPMetadata: &saml.Metadata{},
Logger: logger.DefaultLogger,
}
test.Key = mustParsePrivateKey("-----BEGIN RSA PRIVATE KEY-----\nMIICXgIBAAKBgQDU8wdiaFmPfTyRYuFlVPi866WrH/2JubkHzp89bBQopDaLXYxi\n3PTu3O6Q/KaKxMOFBqrInwqpv/omOGZ4ycQ51O9I+Yc7ybVlW94lTo2gpGf+Y/8E\nPsVbnZaFutRctJ4dVIp9aQ2TpLiGT0xX1OzBO/JEgq9GzDRf+B+eqSuglwIDAQAB\nAoGBAMuy1eN6cgFiCOgBsB3gVDdTKpww87Qk5ivjqEt28SmXO13A1KNVPS6oQ8SJ\nCT5Azc6X/BIAoJCURVL+LHdqebogKljhH/3yIel1kH19vr4E2kTM/tYH+qj8afUS\nJEmArUzsmmK8ccuNqBcllqdwCZjxL4CHDUmyRudFcHVX9oyhAkEA/OV1OkjM3CLU\nN3sqELdMmHq5QZCUihBmk3/N5OvGdqAFGBlEeewlepEVxkh7JnaNXAXrKHRVu/f/\nfbCQxH+qrwJBANeQERF97b9Sibp9xgolb749UWNlAdqmEpmlvmS202TdcaaT1msU\n4rRLiQN3X9O9mq4LZMSVethrQAdX1whawpkCQQDk1yGf7xZpMJ8F4U5sN+F4rLyM\nRq8Sy8p2OBTwzCUXXK+fYeXjybsUUMr6VMYTRP2fQr/LKJIX+E5ZxvcIyFmDAkEA\nyfjNVUNVaIbQTzEbRlRvT6MqR+PTCefC072NF9aJWR93JimspGZMR7viY6IM4lrr\nvBkm0F5yXKaYtoiiDMzlOQJADqmEwXl0D72ZG/2KDg8b4QZEmC9i5gidpQwJXUc6\nhU+IVQoLxRq0fBib/36K9tcrrO5Ba4iEvDcNY+D8yGbUtA==\n-----END RSA PRIVATE KEY-----\n")
test.Certificate = mustParseCertificate("-----BEGIN CERTIFICATE-----\nMIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJV\nUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0\nMB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMx\nCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCB\nnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9\nibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmH\nO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKv\nRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgk\nakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeT\nQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvn\nOwJlNCASPZRH/JmF8tX0hoHuAQ==\n-----END CERTIFICATE-----\n")
Expand All @@ -131,10 +133,11 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ==

var err error
test.Server, err = New(Options{
URL: url.URL{Scheme: "https", Host: "idp.example.com"},
Key: test.Key,
Certificate: test.Certificate,
Key: test.Key,
Logger: logger.DefaultLogger,
Store: &test.Store,
URL: url.URL{Scheme: "https", Host: "idp.example.com"},
})
c.Assert(err, IsNil)

Expand Down
13 changes: 6 additions & 7 deletions samlidp/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"encoding/xml"
"fmt"
"log"
"net/http"
"os"

Expand Down Expand Up @@ -40,7 +39,7 @@ func (s *Server) GetServiceProvider(r *http.Request, serviceProviderID string) (
func (s *Server) HandleListServices(c web.C, w http.ResponseWriter, r *http.Request) {
services, err := s.Store.List("/services/")
if err != nil {
log.Printf("ERROR: %s", err)
s.logger.Printf("ERROR: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
Expand All @@ -56,7 +55,7 @@ func (s *Server) HandleGetService(c web.C, w http.ResponseWriter, r *http.Reques
service := Service{}
err := s.Store.Get(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service)
if err != nil {
log.Printf("ERROR: %s", err)
s.logger.Printf("ERROR: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
Expand All @@ -68,14 +67,14 @@ func (s *Server) HandleGetService(c web.C, w http.ResponseWriter, r *http.Reques
func (s *Server) HandlePutService(c web.C, w http.ResponseWriter, r *http.Request) {
service := Service{}
if err := xml.NewDecoder(r.Body).Decode(&service.Metadata); err != nil {
log.Printf("ERROR: %s", err)
s.logger.Printf("ERROR: %s", err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}

err := s.Store.Put(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service)
if err != nil {
log.Printf("ERROR: %s", err)
s.logger.Printf("ERROR: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
Expand All @@ -92,13 +91,13 @@ func (s *Server) HandleDeleteService(c web.C, w http.ResponseWriter, r *http.Req
service := Service{}
err := s.Store.Get(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service)
if err != nil {
log.Printf("ERROR: %s", err)
s.logger.Printf("ERROR: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

if err := s.Store.Delete(fmt.Sprintf("/services/%s", c.URLParams["id"])); err != nil {
log.Printf("ERROR: %s", err)
s.logger.Printf("ERROR: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
Expand Down
Loading

0 comments on commit 6b5dd2d

Please sign in to comment.