Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor grpc-server package #598

Merged
merged 1 commit into from
Apr 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions cmd/tink-server/internal/postgres_setup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package internal

import (
"database/sql"

"github.com/packethost/pkg/log"
"github.com/tinkerbell/tink/db"
)

// SetupPostgres initializes a connection to a postgres database.
func SetupPostgres(connInfo string, onlyMigrate bool, logger log.Logger) (db.Database, error) {
dbCon, err := sql.Open("postgres", connInfo)
if err != nil {
return nil, err
}
tinkDB := db.Connect(dbCon, logger)

if onlyMigrate {
logger.Info("Applying migrations. This process will end when migrations will take place.")
numAppliedMigrations, err := tinkDB.Migrate()
if err != nil {
return nil, err
}
logger.With("num_applied_migrations", numAppliedMigrations).Info("Migrations applied successfully")
return nil, nil
}

numAvailableMigrations, err := tinkDB.CheckRequiredMigrations()
if err != nil {
return nil, err
}
if numAvailableMigrations != 0 {
logger.Info("Your database schema is not up to date. Please apply migrations running tink-server with env var ONLY_MIGRATION set.")
}
return *tinkDB, nil
}
74 changes: 45 additions & 29 deletions cmd/tink-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,28 @@ package main

import (
"context"
"database/sql"
"crypto/tls"
"fmt"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"

"github.com/equinix-labs/otel-init-go/otelinit"
"github.com/packethost/pkg/env"
"github.com/packethost/pkg/log"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"github.com/tinkerbell/tink/db"
grpcServer "github.com/tinkerbell/tink/grpc-server"
httpServer "github.com/tinkerbell/tink/http-server"
"github.com/tinkerbell/tink/cmd/tink-server/internal"
grpcserver "github.com/tinkerbell/tink/grpc-server"
httpserver "github.com/tinkerbell/tink/http-server"
"github.com/tinkerbell/tink/metrics"
"github.com/tinkerbell/tink/server"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

// version is set at build time.
Expand Down Expand Up @@ -130,48 +135,59 @@ func NewRootCommand(config *DaemonConfig, logger log.Logger) *cobra.Command {
config.PGPassword,
config.PGSSLMode,
)

dbCon, err := sql.Open("postgres", connInfo)
database, err := internal.SetupPostgres(connInfo, config.OnlyMigration, logger)
if err != nil {
return err
}
tinkDB := db.Connect(dbCon, logger)

if config.OnlyMigration {
logger.Info("Applying migrations. This process will end when migrations will take place.")
numAppliedMigrations, err := tinkDB.Migrate()
return nil
}

var (
grpcOpts []grpc.ServerOption
certPEM []byte
certModTime *time.Time
)
if config.TLS {
certsDir := os.Getenv("TINKERBELL_CERTS_DIR")
if certsDir == "" {
certsDir = filepath.Join("/certs", config.Facility)
}
var cert *tls.Certificate
cert, certPEM, certModTime, err = grpcserver.GetCerts(certsDir)
if err != nil {
return err
}
logger.With("num_applied_migrations", numAppliedMigrations).Info("Migrations applied successfully")
return nil
grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewServerTLSFromCert(cert)))
}

numAvailableMigrations, err := tinkDB.CheckRequiredMigrations()
tinkAPI, err := server.NewDBServer(
logger,
database,
server.WithCerts(*certModTime, certPEM),
)
if err != nil {
return err
}
if numAvailableMigrations != 0 {
logger.Info("Your database schema is not up to date. Please apply migrations running tink-server with env var ONLY_MIGRATION set.")
}

grpcConfig := &grpcServer.ConfigGRPCServer{
Facility: config.Facility,
TLSCert: "insecure",
GRPCAuthority: config.GRPCAuthority,
DB: tinkDB,
}
if config.TLS {
grpcConfig.TLSCert = config.TLSCert
// Start the gRPC server in the background
addr, err := grpcserver.SetupGRPC(
ctx,
tinkAPI,
config.GRPCAuthority,
grpcOpts,
errCh)
if err != nil {
return err
}
cert, modT := grpcServer.SetupGRPC(ctx, logger, grpcConfig, errCh)
logger.With("address", addr).Info("started listener")

httpConfig := &httpServer.Config{
httpConfig := &httpserver.Config{
HTTPAuthority: config.HTTPAuthority,
CertPEM: cert,
ModTime: modT,
CertPEM: certPEM,
ModTime: *certModTime,
}
httpServer.SetupHTTP(ctx, logger, httpConfig, errCh)
httpserver.SetupHTTP(ctx, logger, httpConfig, errCh)

Comment on lines +174 to 191
Copy link
Member

@chrisdoherty4 chrisdoherty4 Mar 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could really benefit from using https://github.com/oklog/run. Unnecessary for this PR, but if you get time separately it would help clean up the subsequent error channel code that does static draining of the channel.

select {
case err = <-errCh:
Expand Down
161 changes: 50 additions & 111 deletions grpc-server/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,148 +7,87 @@ import (
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"

grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/packethost/pkg/log"
"github.com/pkg/errors"
"github.com/tinkerbell/tink/db"
"github.com/tinkerbell/tink/protos/hardware"
"github.com/tinkerbell/tink/protos/template"
"github.com/tinkerbell/tink/protos/workflow"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/reflection"
)

// Server is the gRPC server for tinkerbell.
type server struct {
cert []byte
modT time.Time
// GetCerts returns a TLS certificate, PEM bytes, and file modification time for a
// given path. An error is returned for any failure.
//
// The public key is expected to be named "bundle.pem" and the private key
// "server.pem".
func GetCerts(certsDir string) (*tls.Certificate, []byte, *time.Time, error) {
certFile, err := os.Open(filepath.Join(certsDir, "bundle.pem"))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

G304: Potential file inclusion via variable
(at-me in a reply with help or ignore)

if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to open TLS cert")
}

db db.Database
quit <-chan struct{}
stat, err := certFile.Stat()
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to stat TLS cert")
}
modT := stat.ModTime()
certPEM, err := ioutil.ReadAll(certFile)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to read TLS cert")
}
err = certFile.Close()
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to close TLS cert")
}

dbLock sync.RWMutex
dbReady bool
keyPEM, err := ioutil.ReadFile(filepath.Join(certsDir, "server-key.pem"))
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to read TLS key")
}

watchLock sync.RWMutex
watch map[string]chan string
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to parse TLS file content")
}

logger log.Logger
return &cert, certPEM, &modT, nil
}

type ConfigGRPCServer struct {
micahhausler marked this conversation as resolved.
Show resolved Hide resolved
Facility string
TLSCert string
GRPCAuthority string
DB db.Database
// Registrar is an interface for registering APIs on a gRPC server.
type Registrar interface {
Register(*grpc.Server)
}

// SetupGRPC setup and return a gRPC server.
func SetupGRPC(ctx context.Context, logger log.Logger, config *ConfigGRPCServer, errCh chan<- error) ([]byte, time.Time) {
// SetupGRPC opens a listener and serves a given Registrar's APIs on a gRPC server
// and returns the listener's address or an error.
func SetupGRPC(ctx context.Context, r Registrar, listenAddr string, opts []grpc.ServerOption, errCh chan<- error) (serverAddr string, err error) {
params := []grpc.ServerOption{
grpc_middleware.WithUnaryServerChain(grpc_prometheus.UnaryServerInterceptor, otelgrpc.UnaryServerInterceptor()),
grpc_middleware.WithStreamServerChain(grpc_prometheus.StreamServerInterceptor, otelgrpc.StreamServerInterceptor()),
}
server := &server{
db: config.DB,
dbReady: true,
logger: logger,
}
cert := config.TLSCert
switch cert {
case "insecure":
// server.cert *must* be nil, which it is because that is the default value
// server.modT doesn't matter
case "":
tlsCert, certPEM, modT := getCerts(config.Facility, logger)
params = append(params, grpc.Creds(credentials.NewServerTLSFromCert(&tlsCert)))
server.cert = certPEM
server.modT = modT
default:
server.cert = []byte(cert)
server.modT = time.Now()
}
params = append(params, opts...)

// register servers
s := grpc.NewServer(params...)
template.RegisterTemplateServiceServer(s, server)
workflow.RegisterWorkflowServiceServer(s, server)
hardware.RegisterHardwareServiceServer(s, server)
r.Register(s)
reflection.Register(s)

grpc_prometheus.Register(s)

go func() {
lis, err := net.Listen("tcp", config.GRPCAuthority)
if err != nil {
err = errors.Wrap(err, "failed to listen")
logger.Error(err)
panic(err)
}

errCh <- s.Serve(lis)
}()

go func() {
<-ctx.Done()
s.GracefulStop()
}()
return server.cert, server.modT
}

func getCerts(facility string, logger log.Logger) (tls.Certificate, []byte, time.Time) {
var (
certPEM []byte
modT time.Time
)

certsDir := os.Getenv("TINKERBELL_CERTS_DIR")
if certsDir == "" {
certsDir = "/certs/" + facility
}
if !strings.HasSuffix(certsDir, "/") {
certsDir += "/"
}

certFile, err := os.Open(filepath.Clean(certsDir + "bundle.pem"))
lis, err := net.Listen("tcp", listenAddr)
if err != nil {
err = errors.Wrap(err, "failed to open TLS cert")
logger.Error(err)
panic(err)
return "", errors.Wrap(err, "failed to listen")
}

if stat, err := certFile.Stat(); err != nil {
err = errors.Wrap(err, "failed to stat TLS cert")
logger.Error(err)
panic(err)
} else {
modT = stat.ModTime()
}
go func(errChan chan<- error) {
errChan <- s.Serve(lis)
}(errCh)

certPEM, err = ioutil.ReadAll(certFile)
if err != nil {
err = errors.Wrap(err, "failed to read TLS cert")
logger.Error(err)
panic(err)
}
keyPEM, err := ioutil.ReadFile(filepath.Clean(certsDir + "server-key.pem"))
if err != nil {
err = errors.Wrap(err, "failed to read TLS key")
logger.Error(err)
panic(err)
}
go func(ctx context.Context, s *grpc.Server) {
<-ctx.Done()
s.GracefulStop()
}(ctx, s)

cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
err = errors.Wrap(err, "failed to ingest TLS files")
logger.Error(err)
panic(err)
}
return cert, certPEM, modT
return lis.Addr().String(), nil
}
Loading