Skip to content

Commit

Permalink
Refactor grpc-server package
Browse files Browse the repository at this point in the history
* Untangled certificate acquisition
* Made gRPC server accept different registrable APIs
* Moved postgres setup logic out of tink-server's main.go to
  an internal package

Signed-off-by: Micah Hausler <[email protected]>
  • Loading branch information
micahhausler committed Mar 17, 2022
1 parent e1e2af6 commit e3588d5
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 163 deletions.
37 changes: 37 additions & 0 deletions cmd/tink-server/internal/postgres_setup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package internal

import (
"database/sql"

"github.com/packethost/pkg/log"
"github.com/tinkerbell/tink/db"
)

// SetupPostgres initializes a connection to a postgres database.
func SetupPostgres(connInfo string, onlyMigrate bool, logger log.Logger) (db.Database, error) {

dbCon, err := sql.Open("postgres", connInfo)
if err != nil {
return nil, err
}
tinkDB := db.Connect(dbCon, logger)

if onlyMigrate {
logger.Info("Applying migrations. This process will end when migrations will take place.")
numAppliedMigrations, err := tinkDB.Migrate()
if err != nil {
return nil, err
}
logger.With("num_applied_migrations", numAppliedMigrations).Info("Migrations applied successfully")
return nil, nil
}

numAvailableMigrations, err := tinkDB.CheckRequiredMigrations()
if err != nil {
return nil, err
}
if numAvailableMigrations != 0 {
logger.Info("Your database schema is not up to date. Please apply migrations running tink-server with env var ONLY_MIGRATION set.")
}
return *tinkDB, nil
}
69 changes: 41 additions & 28 deletions cmd/tink-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,27 @@ package main

import (
"context"
"database/sql"
"crypto/tls"
"fmt"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"

"github.com/equinix-labs/otel-init-go/otelinit"
"github.com/packethost/pkg/env"
"github.com/packethost/pkg/log"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"github.com/tinkerbell/tink/db"
grpcServer "github.com/tinkerbell/tink/grpc-server"
"github.com/tinkerbell/tink/cmd/tink-server/internal"
grpcserver "github.com/tinkerbell/tink/grpc-server"
httpServer "github.com/tinkerbell/tink/http-server"
"github.com/tinkerbell/tink/metrics"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

// version is set at build time.
Expand Down Expand Up @@ -130,46 +134,55 @@ func NewRootCommand(config *DaemonConfig, logger log.Logger) *cobra.Command {
config.PGPassword,
config.PGSSLMode,
)

dbCon, err := sql.Open("postgres", connInfo)
database, err := internal.SetupPostgres(connInfo, config.OnlyMigration, logger)
if err != nil {
return err
}
tinkDB := db.Connect(dbCon, logger)

if config.OnlyMigration {
logger.Info("Applying migrations. This process will end when migrations will take place.")
numAppliedMigrations, err := tinkDB.Migrate()
return nil
}

var (
grpcOpts []grpc.ServerOption
certPEM []byte
certModTime *time.Time
)
if config.TLS {
certsDir := os.Getenv("TINKERBELL_CERTS_DIR")
if certsDir == "" {
certsDir = filepath.Join("/certs", config.Facility)
}
var cert *tls.Certificate
cert, certPEM, certModTime, err = grpcserver.GetCerts(certsDir)
if err != nil {
return err
}
logger.With("num_applied_migrations", numAppliedMigrations).Info("Migrations applied successfully")
return nil
grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewServerTLSFromCert(cert)))
}

numAvailableMigrations, err := tinkDB.CheckRequiredMigrations()
tinkAPI := grpcserver.NewTinkServer(
logger,
database,
grpcserver.WithCerts(*certModTime, certPEM),
)

// Start the gRPC server in the background
addr, err := grpcserver.SetupGRPC(
ctx,
logger,
tinkAPI,
config.GRPCAuthority,
grpcOpts,
errCh)
if err != nil {
return err
}
if numAvailableMigrations != 0 {
logger.Info("Your database schema is not up to date. Please apply migrations running tink-server with env var ONLY_MIGRATION set.")
}

grpcConfig := &grpcServer.ConfigGRPCServer{
Facility: config.Facility,
TLSCert: "insecure",
GRPCAuthority: config.GRPCAuthority,
DB: tinkDB,
}
if config.TLS {
grpcConfig.TLSCert = config.TLSCert
}
cert, modT := grpcServer.SetupGRPC(ctx, logger, grpcConfig, errCh)
logger.With("address", addr).Info("started listener")

httpConfig := &httpServer.Config{
HTTPAuthority: config.HTTPAuthority,
CertPEM: cert,
ModTime: modT,
CertPEM: certPEM,
ModTime: *certModTime,
}
httpServer.SetupHTTP(ctx, logger, httpConfig, errCh)

Expand Down
187 changes: 91 additions & 96 deletions grpc-server/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"

Expand All @@ -21,12 +20,61 @@ import (
"github.com/tinkerbell/tink/protos/workflow"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/reflection"
)

// Server is the gRPC server for tinkerbell.
type server struct {
// GetCerts returns a TLS certificate, PEM bytes, and file modification time for a
// given path. An error is returned for any failure.
//
//
// The public key is expected to be named "bundle.pem" and the private key
// "server.pem".
func GetCerts(certsDir string) (*tls.Certificate, []byte, *time.Time, error) {
certFile, err := os.Open(filepath.Join(certsDir, "bundle.pem"))
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to open TLS cert")
}

stat, err := certFile.Stat()
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to stat TLS cert")
}
modT := stat.ModTime()
certPEM, err := ioutil.ReadAll(certFile)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to read TLS cert")
}
err = certFile.Close()
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to close TLS cert")
}

keyPEM, err := ioutil.ReadFile(filepath.Join(certsDir, "server-key.pem"))
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to read TLS key")
}

cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to parse TLS file content")
}

return &cert, certPEM, &modT, nil
}

// Option is a type for modifying a TinkServer.
type Option func(*TinkServer)

// WithCerts sets a certificate mod time and material on a server
func WithCerts(modTime time.Time, publicCertPEM []byte) Option {
return func(s *TinkServer) {
s.modT = modTime
s.cert = publicCertPEM
}
}

// Server is the gRPC TinkServer for tinkerbell.
type TinkServer struct {
cert []byte
modT time.Time

Expand All @@ -42,113 +90,60 @@ type server struct {
logger log.Logger
}

type ConfigGRPCServer struct {
Facility string
TLSCert string
GRPCAuthority string
DB db.Database
// NewTinkServer returns a new Tinkerbell server.
func NewTinkServer(l log.Logger, database db.Database, opts ...Option) *TinkServer {
ts := &TinkServer{
db: database,
logger: l,
dbReady: true,
}
for _, opt := range opts {
opt(ts)
}

return ts
}

// SetupGRPC setup and return a gRPC server.
func SetupGRPC(ctx context.Context, logger log.Logger, config *ConfigGRPCServer, errCh chan<- error) ([]byte, time.Time) {
// Register registers Template, Workflow, and Hardware APIs on a gRPC server.
func (s *TinkServer) Register(server *grpc.Server) {
template.RegisterTemplateServiceServer(server, s)
workflow.RegisterWorkflowServiceServer(server, s)
hardware.RegisterHardwareServiceServer(server, s)
}

// Registrar is an interface for registering APIs on a gRPC server.
type Registrar interface {
Register(*grpc.Server)
}

// SetupGRPC opens a listener and serves a given Registrar's APIs on a gRPC server
// and returns the listener's address or an error.
func SetupGRPC(ctx context.Context, logger log.Logger, r Registrar, listenAddr string, opts []grpc.ServerOption, errCh chan<- error) (serverAddr string, err error) {
params := []grpc.ServerOption{
grpc_middleware.WithUnaryServerChain(grpc_prometheus.UnaryServerInterceptor, otelgrpc.UnaryServerInterceptor()),
grpc_middleware.WithStreamServerChain(grpc_prometheus.StreamServerInterceptor, otelgrpc.StreamServerInterceptor()),
}
server := &server{
db: config.DB,
dbReady: true,
logger: logger,
}
cert := config.TLSCert
switch cert {
case "insecure":
// server.cert *must* be nil, which it is because that is the default value
// server.modT doesn't matter
case "":
tlsCert, certPEM, modT := getCerts(config.Facility, logger)
params = append(params, grpc.Creds(credentials.NewServerTLSFromCert(&tlsCert)))
server.cert = certPEM
server.modT = modT
default:
server.cert = []byte(cert)
server.modT = time.Now()
}
params = append(params, opts...)

// register servers
s := grpc.NewServer(params...)
template.RegisterTemplateServiceServer(s, server)
workflow.RegisterWorkflowServiceServer(s, server)
hardware.RegisterHardwareServiceServer(s, server)
r.Register(s)
reflection.Register(s)

grpc_prometheus.Register(s)

go func() {
lis, err := net.Listen("tcp", config.GRPCAuthority)
if err != nil {
err = errors.Wrap(err, "failed to listen")
logger.Error(err)
panic(err)
}

errCh <- s.Serve(lis)
}()

go func() {
<-ctx.Done()
s.GracefulStop()
}()
return server.cert, server.modT
}

func getCerts(facility string, logger log.Logger) (tls.Certificate, []byte, time.Time) {
var (
certPEM []byte
modT time.Time
)

certsDir := os.Getenv("TINKERBELL_CERTS_DIR")
if certsDir == "" {
certsDir = "/certs/" + facility
}
if !strings.HasSuffix(certsDir, "/") {
certsDir += "/"
}

certFile, err := os.Open(filepath.Clean(certsDir + "bundle.pem"))
lis, err := net.Listen("tcp", listenAddr)
if err != nil {
err = errors.Wrap(err, "failed to open TLS cert")
logger.Error(err)
panic(err)
return "", errors.Wrap(err, "failed to listen")
}

if stat, err := certFile.Stat(); err != nil {
err = errors.Wrap(err, "failed to stat TLS cert")
logger.Error(err)
panic(err)
} else {
modT = stat.ModTime()
}
go func(errChan chan<- error) {
errChan <- s.Serve(lis)
}(errCh)

certPEM, err = ioutil.ReadAll(certFile)
if err != nil {
err = errors.Wrap(err, "failed to read TLS cert")
logger.Error(err)
panic(err)
}
keyPEM, err := ioutil.ReadFile(filepath.Clean(certsDir + "server-key.pem"))
if err != nil {
err = errors.Wrap(err, "failed to read TLS key")
logger.Error(err)
panic(err)
}
go func(ctx context.Context, s *grpc.Server) {
<-ctx.Done()
s.GracefulStop()
}(ctx, s)

cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
err = errors.Wrap(err, "failed to ingest TLS files")
logger.Error(err)
panic(err)
}
return cert, certPEM, modT
return lis.Addr().String(), nil
}
Loading

0 comments on commit e3588d5

Please sign in to comment.