Skip to content

Commit

Permalink
fix negative waitgroup, fix cert expiry date, better auto renewal str…
Browse files Browse the repository at this point in the history
…ategy
  • Loading branch information
yusing committed Mar 23, 2024
1 parent fff790b commit 539ef91
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 62 deletions.
Binary file modified bin/go-proxy
Binary file not shown.
176 changes: 120 additions & 56 deletions src/go-proxy/autocert.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,21 @@ import (

"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/registration"
)

type ProviderOptions = map[string]string
type ProviderGenerator = func(ProviderOptions) (challenge.Provider, error)
type CertExpiries = map[string]time.Time

type AutoCertConfig struct {
Email string
Domains []string `yaml:",flow"`
Provider string
Options map[string]string `yaml:",flow"`
Options ProviderOptions `yaml:",flow"`
}

type AutoCertUser struct {
Expand All @@ -46,11 +51,11 @@ func (u *AutoCertUser) GetPrivateKey() crypto.PrivateKey {
type AutoCertProvider interface {
GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error)
GetName() string
GetExpiry() time.Time
GetExpiries() CertExpiries
LoadCert() bool
ObtainCert() error

needRenew() bool
RenewalOn() time.Time
ScheduleRenewal()
}

func (cfg AutoCertConfig) GetProvider() (AutoCertProvider, error) {
Expand Down Expand Up @@ -78,58 +83,56 @@ func (cfg AutoCertConfig) GetProvider() (AutoCertProvider, error) {
if err != nil {
return nil, fmt.Errorf("unable to create lego client: %v", err)
}
base := &AutoCertProviderBase{
base := &autoCertProvider{
name: cfg.Provider,
cfg: cfg,
user: user,
legoCfg: legoCfg,
client: legoClient,
}
switch cfg.Provider {
case "cloudflare":
return NewAutoCertCFProvider(base, cfg.Options)
gen, ok := providersGenMap[cfg.Provider]
if !ok {
return nil, fmt.Errorf("unknown provider: %s", cfg.Provider)
}
legoProvider, err := gen(cfg.Options)
if err != nil {
return nil, fmt.Errorf("unable to create provider: %v", err)
}
return nil, fmt.Errorf("unknown provider: %s", cfg.Provider)
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
if err != nil {
return nil, fmt.Errorf("unable to set challenge provider: %v", err)
}
return base, nil
}

type AutoCertProviderBase struct {
type autoCertProvider struct {
name string
cfg AutoCertConfig
user *AutoCertUser
legoCfg *lego.Config
client *lego.Client

tlsCert *tls.Certificate
expiry time.Time
mutex sync.Mutex
tlsCert *tls.Certificate
certExpiries CertExpiries
mutex sync.Mutex
}

func (p *AutoCertProviderBase) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
func (p *autoCertProvider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
if p.tlsCert == nil {
aclog.Fatal("no certificate available")
}
if p.needRenew() {
p.mutex.Lock()
defer p.mutex.Unlock()
if p.needRenew() {
err := p.ObtainCert()
if err != nil {
return nil, err
}
}
}
return p.tlsCert, nil
}

func (p *AutoCertProviderBase) GetName() string {
func (p *autoCertProvider) GetName() string {
return p.name
}

func (p *AutoCertProviderBase) GetExpiry() time.Time {
return p.expiry
func (p *autoCertProvider) GetExpiries() CertExpiries {
return p.certExpiries
}

func (p *AutoCertProviderBase) ObtainCert() error {
func (p *autoCertProvider) ObtainCert() error {
client := p.client
if p.user.Registration == nil {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
Expand All @@ -154,30 +157,55 @@ func (p *AutoCertProviderBase) ObtainCert() error {
if err != nil {
return err
}
p.tlsCert = &tlsCert
x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[len(tlsCert.Certificate)-1])
expiries, err := getCertExpiries(&tlsCert)
if err != nil {
return err
}
p.expiry = x509Cert.NotAfter
p.tlsCert = &tlsCert
p.certExpiries = expiries
return nil
}

func (p *AutoCertProviderBase) LoadCert() bool {
func (p *autoCertProvider) LoadCert() bool {
cert, err := tls.LoadX509KeyPair(certFileDefault, keyFileDefault)
if err != nil {
return false
}
x509Cert, err := x509.ParseCertificate(cert.Certificate[len(cert.Certificate)-1])
expiries, err := getCertExpiries(&cert)
if err != nil {
return false
}
p.tlsCert = &cert
p.expiry = x509Cert.NotAfter
p.certExpiries = expiries
p.renewIfNeeded()
return true
}

func (p *AutoCertProviderBase) saveCert(cert *certificate.Resource) error {
func (p *autoCertProvider) RenewalOn() time.Time {
t := time.Now().AddDate(0, 0, 3)
for _, expiry := range p.certExpiries {
if expiry.Before(t) {
return time.Now()
}
return t
}
// this line should never be reached
panic("no certificate available")
}

func (p *autoCertProvider) ScheduleRenewal() {
for {
t := time.Until(p.RenewalOn())
aclog.Infof("next renewal in %v", t)
time.Sleep(t)
err := p.renewIfNeeded()
if err != nil {
aclog.Fatal(err)
}
}
}

func (p *autoCertProvider) saveCert(cert *certificate.Resource) error {
err := os.MkdirAll(path.Dir(certFileDefault), 0644)
if err != nil {
return fmt.Errorf("unable to create cert directory: %v", err)
Expand All @@ -193,36 +221,68 @@ func (p *AutoCertProviderBase) saveCert(cert *certificate.Resource) error {
return nil
}

func (p *AutoCertProviderBase) needRenew() bool {
return p.expiry.Before(time.Now().Add(24 * time.Hour))
func (p *autoCertProvider) needRenewal() bool {
return time.Now().After(p.RenewalOn())
}

type AutoCertCFProvider struct {
*AutoCertProviderBase
*cloudflare.Config
}
func (p *autoCertProvider) renewIfNeeded() error {
if !p.needRenewal() {
return nil
}

func NewAutoCertCFProvider(base *AutoCertProviderBase, opt map[string]string) (*AutoCertCFProvider, error) {
p := &AutoCertCFProvider{
base,
cloudflare.NewDefaultConfig(),
p.mutex.Lock()
defer p.mutex.Unlock()

if !p.needRenewal() {
return nil
}
err := setOptions(p.Config, opt)
if err != nil {
return nil, err

trials := 0
for {
err := p.ObtainCert()
if err == nil {
return nil
}
trials++
if trials > 3 {
return fmt.Errorf("unable to renew certificate: %v after 3 trials", err)
}
aclog.Errorf("failed to renew certificate: %v, trying again in 5 seconds", err)
time.Sleep(5 * time.Second)
}
legoProvider, err := cloudflare.NewDNSProviderConfig(p.Config)
if err != nil {
return nil, fmt.Errorf("unable to create cloudflare provider: %v", err)
}

func providerGenerator[CT interface{}, PT challenge.Provider](defaultCfg func() *CT, newProvider func(*CT) (PT, error)) ProviderGenerator {
return func(opt ProviderOptions) (challenge.Provider, error) {
cfg := defaultCfg()
err := setOptions(cfg, opt)
if err != nil {
return nil, err
}
p, err := newProvider(cfg)
if err != nil {
return nil, err
}
return p, nil
}
err = p.client.Challenge.SetDNS01Provider(legoProvider)
if err != nil {
return nil, fmt.Errorf("unable to set challenge provider: %v", err)
}

func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
r := make(CertExpiries, len(cert.Certificate))
for _, cert := range cert.Certificate {
x509Cert, err := x509.ParseCertificate(cert)
if err != nil {
return nil, err
}
if x509Cert.IsCA {
continue
}
r[x509Cert.Subject.CommonName] = x509Cert.NotAfter
}
return p, nil
return r, nil
}

func setOptions[T interface{}](cfg *T, opt map[string]string) error {
func setOptions[T interface{}](cfg *T, opt ProviderOptions) error {
for k, v := range opt {
err := SetFieldFromSnake(cfg, k, v)
if err != nil {
Expand All @@ -231,3 +291,7 @@ func setOptions[T interface{}](cfg *T, opt map[string]string) error {
}
return nil
}

var providersGenMap = map[string]ProviderGenerator{
"cloudflare": providerGenerator(cloudflare.NewDefaultConfig, cloudflare.NewDNSProviderConfig),
}
11 changes: 6 additions & 5 deletions src/go-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
)

func main() {
// flag.Parse()
runtime.GOMAXPROCS(runtime.NumCPU())

logrus.SetFormatter(&logrus.TextFormatter{
Expand Down Expand Up @@ -52,7 +51,10 @@ func main() {
aclog.Fatal("error obtaining certificate ", err)
}
}
aclog.Infof("certificate will be expired at %v and get renewed", autoCertProvider.GetExpiry())
for name, expiry := range autoCertProvider.GetExpiries() {
aclog.Infof("certificate %q: expire on %v", name, expiry)
}
go autoCertProvider.ScheduleRenewal()
}
proxyServer = NewServer(
"proxy",
Expand Down Expand Up @@ -86,9 +88,8 @@ func main() {
signal.Notify(sig, syscall.SIGHUP)

<-sig
cfg.StopWatching()
StopFSWatcher()
StopDockerWatcher()
// cfg.StopWatching()

cfg.StopProviders()
panelServer.Stop()
proxyServer.Stop()
Expand Down
2 changes: 2 additions & 0 deletions src/go-proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ func (s *Server) Stop() {
if s.httpStarted {
errHTTP := s.http.Shutdown(ctx)
s.handleErr("http", errHTTP)
s.httpStarted = false
}

if s.httpsStarted {
errHTTPS := s.https.Shutdown(ctx)
s.handleErr("https", errHTTPS)
s.httpsStarted = false
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/go-proxy/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ func InitFSWatcher() {
func InitDockerWatcher() {
// stop all docker client on watcher stop
go func() {
defer dockerWatcherWg.Done()
<-dockerWatcherStop
ParallelForEachValue(
dockerWatchMap.Iterator(),
(*dockerWatcher).Dispose,
)
dockerWatcherWg.Done()
}()
}

Expand Down

0 comments on commit 539ef91

Please sign in to comment.