From 7f702d0f2d3b2edd5f831f9d7cb09d820605945e Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 17 Mar 2022 22:31:39 +0000 Subject: [PATCH] Refactor grpc-server package * Untangle certificate acquisition * Make gRPC server accept different registrable APIs * Move postgres setup logic out of tink-server's main.go to an internal package * Move gRPC server implementation to the new /server package * Remove unused global `workflowData` Signed-off-by: Micah Hausler --- cmd/tink-server/internal/postgres_setup.go | 36 ++++ cmd/tink-server/main.go | 74 ++++---- grpc-server/grpc_server.go | 161 ++++++------------ grpc-server/grpc_server_test.go | 128 ++++++++++++-- grpc-server/testdata/bundle.pem | 32 ++++ grpc-server/testdata/server-key.pem | 53 ++++++ server/dbserver.go | 78 +++++++++ .../dbserver_hardware.go | 24 +-- .../dbserver_hardware_test.go | 2 +- .../dbserver_template.go | 12 +- .../dbserver_template_test.go | 2 +- .../dbserver_worker_workflow.go | 90 ++++------ .../dbserver_worker_workflow_test.go | 6 +- .../dbserver_workflow.go | 14 +- .../dbserver_workflow_test.go | 2 +- 15 files changed, 478 insertions(+), 236 deletions(-) create mode 100644 cmd/tink-server/internal/postgres_setup.go create mode 100644 grpc-server/testdata/bundle.pem create mode 100644 grpc-server/testdata/server-key.pem create mode 100644 server/dbserver.go rename grpc-server/hardware.go => server/dbserver_hardware.go (88%) rename grpc-server/hardware_test.go => server/dbserver_hardware_test.go (99%) rename grpc-server/template.go => server/dbserver_template.go (88%) rename grpc-server/template_test.go => server/dbserver_template_test.go (99%) rename grpc-server/tinkerbell.go => server/dbserver_worker_workflow.go (63%) rename grpc-server/tinkerbell_test.go => server/dbserver_worker_workflow_test.go (99%) rename grpc-server/workflow.go => server/dbserver_workflow.go (92%) rename grpc-server/workflow_test.go => server/dbserver_workflow_test.go (99%) diff --git a/cmd/tink-server/internal/postgres_setup.go b/cmd/tink-server/internal/postgres_setup.go new file mode 100644 index 000000000..b354fe4b8 --- /dev/null +++ b/cmd/tink-server/internal/postgres_setup.go @@ -0,0 +1,36 @@ +package internal + +import ( + "database/sql" + + "github.com/packethost/pkg/log" + "github.com/tinkerbell/tink/db" +) + +// SetupPostgres initializes a connection to a postgres database. +func SetupPostgres(connInfo string, onlyMigrate bool, logger log.Logger) (db.Database, error) { + dbCon, err := sql.Open("postgres", connInfo) + if err != nil { + return nil, err + } + tinkDB := db.Connect(dbCon, logger) + + if onlyMigrate { + logger.Info("Applying migrations. This process will end when migrations will take place.") + numAppliedMigrations, err := tinkDB.Migrate() + if err != nil { + return nil, err + } + logger.With("num_applied_migrations", numAppliedMigrations).Info("Migrations applied successfully") + return nil, nil + } + + numAvailableMigrations, err := tinkDB.CheckRequiredMigrations() + if err != nil { + return nil, err + } + if numAvailableMigrations != 0 { + logger.Info("Your database schema is not up to date. Please apply migrations running tink-server with env var ONLY_MIGRATION set.") + } + return *tinkDB, nil +} diff --git a/cmd/tink-server/main.go b/cmd/tink-server/main.go index 4908c2e50..d4b2801b7 100644 --- a/cmd/tink-server/main.go +++ b/cmd/tink-server/main.go @@ -2,12 +2,14 @@ package main import ( "context" - "database/sql" + "crypto/tls" "fmt" "os" "os/signal" + "path/filepath" "strings" "syscall" + "time" "github.com/equinix-labs/otel-init-go/otelinit" "github.com/packethost/pkg/env" @@ -15,10 +17,13 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" - "github.com/tinkerbell/tink/db" - grpcServer "github.com/tinkerbell/tink/grpc-server" - httpServer "github.com/tinkerbell/tink/http-server" + "github.com/tinkerbell/tink/cmd/tink-server/internal" + grpcserver "github.com/tinkerbell/tink/grpc-server" + httpserver "github.com/tinkerbell/tink/http-server" "github.com/tinkerbell/tink/metrics" + "github.com/tinkerbell/tink/server" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) // version is set at build time. @@ -130,48 +135,59 @@ func NewRootCommand(config *DaemonConfig, logger log.Logger) *cobra.Command { config.PGPassword, config.PGSSLMode, ) - - dbCon, err := sql.Open("postgres", connInfo) + database, err := internal.SetupPostgres(connInfo, config.OnlyMigration, logger) if err != nil { return err } - tinkDB := db.Connect(dbCon, logger) - if config.OnlyMigration { - logger.Info("Applying migrations. This process will end when migrations will take place.") - numAppliedMigrations, err := tinkDB.Migrate() + return nil + } + + var ( + grpcOpts []grpc.ServerOption + certPEM []byte + certModTime *time.Time + ) + if config.TLS { + certsDir := os.Getenv("TINKERBELL_CERTS_DIR") + if certsDir == "" { + certsDir = filepath.Join("/certs", config.Facility) + } + var cert *tls.Certificate + cert, certPEM, certModTime, err = grpcserver.GetCerts(certsDir) if err != nil { return err } - logger.With("num_applied_migrations", numAppliedMigrations).Info("Migrations applied successfully") - return nil + grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewServerTLSFromCert(cert))) } - numAvailableMigrations, err := tinkDB.CheckRequiredMigrations() + tinkAPI, err := server.NewDBServer( + logger, + database, + server.WithCerts(*certModTime, certPEM), + ) if err != nil { return err } - if numAvailableMigrations != 0 { - logger.Info("Your database schema is not up to date. Please apply migrations running tink-server with env var ONLY_MIGRATION set.") - } - grpcConfig := &grpcServer.ConfigGRPCServer{ - Facility: config.Facility, - TLSCert: "insecure", - GRPCAuthority: config.GRPCAuthority, - DB: tinkDB, - } - if config.TLS { - grpcConfig.TLSCert = config.TLSCert + // Start the gRPC server in the background + addr, err := grpcserver.SetupGRPC( + ctx, + tinkAPI, + config.GRPCAuthority, + grpcOpts, + errCh) + if err != nil { + return err } - cert, modT := grpcServer.SetupGRPC(ctx, logger, grpcConfig, errCh) + logger.With("address", addr).Info("started listener") - httpConfig := &httpServer.Config{ + httpConfig := &httpserver.Config{ HTTPAuthority: config.HTTPAuthority, - CertPEM: cert, - ModTime: modT, + CertPEM: certPEM, + ModTime: *certModTime, } - httpServer.SetupHTTP(ctx, logger, httpConfig, errCh) + httpserver.SetupHTTP(ctx, logger, httpConfig, errCh) select { case err = <-errCh: diff --git a/grpc-server/grpc_server.go b/grpc-server/grpc_server.go index 383aca2c4..40cb61f07 100644 --- a/grpc-server/grpc_server.go +++ b/grpc-server/grpc_server.go @@ -7,148 +7,87 @@ import ( "net" "os" "path/filepath" - "strings" - "sync" "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - "github.com/packethost/pkg/log" "github.com/pkg/errors" - "github.com/tinkerbell/tink/db" - "github.com/tinkerbell/tink/protos/hardware" - "github.com/tinkerbell/tink/protos/template" - "github.com/tinkerbell/tink/protos/workflow" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/reflection" ) -// Server is the gRPC server for tinkerbell. -type server struct { - cert []byte - modT time.Time +// GetCerts returns a TLS certificate, PEM bytes, and file modification time for a +// given path. An error is returned for any failure. +// +// The public key is expected to be named "bundle.pem" and the private key +// "server.pem". +func GetCerts(certsDir string) (*tls.Certificate, []byte, *time.Time, error) { + certFile, err := os.Open(filepath.Join(certsDir, "bundle.pem")) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to open TLS cert") + } - db db.Database - quit <-chan struct{} + stat, err := certFile.Stat() + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to stat TLS cert") + } + modT := stat.ModTime() + certPEM, err := ioutil.ReadAll(certFile) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to read TLS cert") + } + err = certFile.Close() + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to close TLS cert") + } - dbLock sync.RWMutex - dbReady bool + keyPEM, err := ioutil.ReadFile(filepath.Join(certsDir, "server-key.pem")) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to read TLS key") + } - watchLock sync.RWMutex - watch map[string]chan string + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to parse TLS file content") + } - logger log.Logger + return &cert, certPEM, &modT, nil } -type ConfigGRPCServer struct { - Facility string - TLSCert string - GRPCAuthority string - DB db.Database +// Registrar is an interface for registering APIs on a gRPC server. +type Registrar interface { + Register(*grpc.Server) } -// SetupGRPC setup and return a gRPC server. -func SetupGRPC(ctx context.Context, logger log.Logger, config *ConfigGRPCServer, errCh chan<- error) ([]byte, time.Time) { +// SetupGRPC opens a listener and serves a given Registrar's APIs on a gRPC server +// and returns the listener's address or an error. +func SetupGRPC(ctx context.Context, r Registrar, listenAddr string, opts []grpc.ServerOption, errCh chan<- error) (serverAddr string, err error) { params := []grpc.ServerOption{ grpc_middleware.WithUnaryServerChain(grpc_prometheus.UnaryServerInterceptor, otelgrpc.UnaryServerInterceptor()), grpc_middleware.WithStreamServerChain(grpc_prometheus.StreamServerInterceptor, otelgrpc.StreamServerInterceptor()), } - server := &server{ - db: config.DB, - dbReady: true, - logger: logger, - } - cert := config.TLSCert - switch cert { - case "insecure": - // server.cert *must* be nil, which it is because that is the default value - // server.modT doesn't matter - case "": - tlsCert, certPEM, modT := getCerts(config.Facility, logger) - params = append(params, grpc.Creds(credentials.NewServerTLSFromCert(&tlsCert))) - server.cert = certPEM - server.modT = modT - default: - server.cert = []byte(cert) - server.modT = time.Now() - } + params = append(params, opts...) // register servers s := grpc.NewServer(params...) - template.RegisterTemplateServiceServer(s, server) - workflow.RegisterWorkflowServiceServer(s, server) - hardware.RegisterHardwareServiceServer(s, server) + r.Register(s) reflection.Register(s) - grpc_prometheus.Register(s) - go func() { - lis, err := net.Listen("tcp", config.GRPCAuthority) - if err != nil { - err = errors.Wrap(err, "failed to listen") - logger.Error(err) - panic(err) - } - - errCh <- s.Serve(lis) - }() - - go func() { - <-ctx.Done() - s.GracefulStop() - }() - return server.cert, server.modT -} - -func getCerts(facility string, logger log.Logger) (tls.Certificate, []byte, time.Time) { - var ( - certPEM []byte - modT time.Time - ) - - certsDir := os.Getenv("TINKERBELL_CERTS_DIR") - if certsDir == "" { - certsDir = "/certs/" + facility - } - if !strings.HasSuffix(certsDir, "/") { - certsDir += "/" - } - - certFile, err := os.Open(filepath.Clean(certsDir + "bundle.pem")) + lis, err := net.Listen("tcp", listenAddr) if err != nil { - err = errors.Wrap(err, "failed to open TLS cert") - logger.Error(err) - panic(err) + return "", errors.Wrap(err, "failed to listen") } - if stat, err := certFile.Stat(); err != nil { - err = errors.Wrap(err, "failed to stat TLS cert") - logger.Error(err) - panic(err) - } else { - modT = stat.ModTime() - } + go func(errChan chan<- error) { + errChan <- s.Serve(lis) + }(errCh) - certPEM, err = ioutil.ReadAll(certFile) - if err != nil { - err = errors.Wrap(err, "failed to read TLS cert") - logger.Error(err) - panic(err) - } - keyPEM, err := ioutil.ReadFile(filepath.Clean(certsDir + "server-key.pem")) - if err != nil { - err = errors.Wrap(err, "failed to read TLS key") - logger.Error(err) - panic(err) - } + go func(ctx context.Context, s *grpc.Server) { + <-ctx.Done() + s.GracefulStop() + }(ctx, s) - cert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - err = errors.Wrap(err, "failed to ingest TLS files") - logger.Error(err) - panic(err) - } - return cert, certPEM, modT + return lis.Addr().String(), nil } diff --git a/grpc-server/grpc_server_test.go b/grpc-server/grpc_server_test.go index 4781b7d69..488d0629d 100644 --- a/grpc-server/grpc_server_test.go +++ b/grpc-server/grpc_server_test.go @@ -3,12 +3,15 @@ package grpcserver import ( "context" "fmt" + "io/ioutil" + "path/filepath" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/ktr0731/evans/grpc" "github.com/packethost/pkg/log" + "github.com/tinkerbell/tink/server" ) func TestSetupGRPC(t *testing.T) { @@ -16,26 +19,51 @@ func TestSetupGRPC(t *testing.T) { server string client string } - tests := map[string]struct { + tests := []struct { + name string input input want []string err error }{ - "successful grpc client call": {input: input{server: "127.0.0.1:55005", client: "127.0.0.1:55005"}, want: []string{"HardwareService", "TemplateService", "WorkflowService", "ServerReflection"}}, - "grpc client fail to communicate": {input: input{server: "127.0.0.1:0", client: "127.0.0.1:55007"}, err: fmt.Errorf("failed to list services from reflection enabled gRPC server: rpc error: code = Unavailable desc = connection error: desc = \"transport: Error while dialing dial tcp 127.0.0.1:55007: connect: connection refused\"")}, + { + name: "successful grpc client call", + input: input{ + server: "127.0.0.1:55005", + client: "127.0.0.1:55005", + }, + want: []string{"HardwareService", "TemplateService", "WorkflowService", "ServerReflection"}, + }, + { + name: "grpc client fail to communicate", + input: input{ + server: "127.0.0.1:0", + client: "127.0.0.1:55007", + }, + err: fmt.Errorf("failed to list services from reflection enabled gRPC server: rpc error: code = Unavailable desc = connection error: desc = \"transport: Error while dialing dial tcp 127.0.0.1:55007: connect: connection refused\""), + }, } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() errCh := make(chan error) - logger, _ := log.Init("test_package") - SetupGRPC(ctx, logger, &ConfigGRPCServer{ - Facility: "onprem", - TLSCert: "just can't be an empty string", - GRPCAuthority: tc.input.server, - }, errCh) + logger := log.Test(t, "test_package") + tinkServer, _ := server.NewDBServer( + logger, + nil, + ) + _, err := SetupGRPC( + ctx, + tinkServer, + tc.input.server, + nil, + errCh) + if err != nil { + t.Errorf("failed to set up gRPC server: %v", err) + return + } + client, err := grpc.NewClient(tc.input.client, "name", true, false, "", "", "", nil) if err != nil { t.Fatal(err) @@ -71,3 +99,81 @@ func TestSetupGRPC(t *testing.T) { }) } } + +func TestGetCerts(t *testing.T) { + cases := []struct { + name string + setupFunc func(t *testing.T) (string, error) + wanterr error + }{ + { + "Real key file", + func(t *testing.T) (string, error) { + t.Helper() + return "./testdata", nil + }, + nil, + }, + { + "No cert", + func(t *testing.T) (string, error) { + t.Helper() + return "./not-a-directory", nil + }, + fmt.Errorf("failed to open TLS cert: open not-a-directory/bundle.pem: no such file or directory"), + }, + { + "empty content", + func(t *testing.T) (string, error) { + t.Helper() + tdir := t.TempDir() + err := ioutil.WriteFile(filepath.Join(tdir, "bundle.pem"), []byte{}, 0o644) + if err != nil { + return "", err + } + err = ioutil.WriteFile(filepath.Join(tdir, "server-key.pem"), []byte{}, 0o644) + if err != nil { + return "", err + } + return tdir, nil + }, + fmt.Errorf("failed to parse TLS file content: tls: failed to find any PEM data in certificate input"), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + input, err := tc.setupFunc(t) + if err != nil { + t.Errorf("Failed to setup test: %v", err) + return + } + gotCert, gotBytes, modTime, err := GetCerts(input) + + if tc.wanterr == nil { + if gotCert == nil { + t.Error("Missing expected cert, got nil") + } + if gotBytes == nil { + t.Error("Missing expected cert bytes, got nil") + } + if modTime == nil { + t.Error("Missing expected cert mod time, got nil") + } + } + if tc.wanterr == nil && err == nil { + return + } + if tc.wanterr != nil { + if err == nil { + t.Errorf("Missing expected error %s", tc.wanterr.Error()) + return + } + if tc.wanterr.Error() != err.Error() { + t.Errorf("Got different error. Wanted %s, got %s", tc.wanterr.Error(), err.Error()) + } + return + } + }) + } +} diff --git a/grpc-server/testdata/bundle.pem b/grpc-server/testdata/bundle.pem new file mode 100644 index 000000000..e7130206c --- /dev/null +++ b/grpc-server/testdata/bundle.pem @@ -0,0 +1,32 @@ +TEST CERTIFICATE ONLY - NOT FOR REAL USAGE +-----BEGIN CERTIFICATE----- +MIIFTzCCAzegAwIBAgIUBviL+dMgIkMf7R8nPspjyb+LcGgwDQYJKoZIhvcNAQEL +BQAwNzENMAsGA1UECgwEQ05DRjETMBEGA1UECwwKVGlua2VyYmVsbDERMA8GA1UE +AwwIdGVzdC1rZXkwHhcNMjIwMzE4MTUyMjIyWhcNMzIwMzE1MTUyMjIyWjA3MQ0w +CwYDVQQKDARDTkNGMRMwEQYDVQQLDApUaW5rZXJiZWxsMREwDwYDVQQDDAh0ZXN0 +LWtleTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBALonX7dDury4h569 +ZQlBSUTL5+gdFxKMw0rrMwVxcHDm8Q4t9OlomHZu37Epc91XRfui4YEYglTo3Zzl +36gwkAPcfvGQ0KUtC1nRXs0FNOfvIm+Uv++l8S2B3BVaEABvcdpMQPkyZnKLhssd +0jF7/SFH96fdY8Xmwf0LPHcobjrQGURLU8S7mtLMJUiC0ByDTfTye6cTLFdELqzu +3LYbnT8ZceLOvLDmERJoQEgJxg2Lhr1H2TMmEUDy+svAEEyEUqYIXrPhrfN2wWUY +rIZgUupDOxvpSIzTeo/kHNrI6tbJ3laaz93Xb/0c6v5Wq/B7pViEQr5gPw+Q/h2+ +3padZIwuKAOCZLdILl0gQbUSYJ3CrOTzXBp8k++FLA8Oa8ENG3JzFXZpfsPV4Ovb +T/dulrG3zHXLjEZ/H7oNyvOLV92fJDBguGJEUeRqD8p6bmhcctygGQBBiqaKyxnl ++mGadsDrMWyegm7ua25hHnFgLofzBUEAiXGh1z0XdZAizT/oyPRn2TGAslgedoqK +wcmvOeZp+GVZi9qE8Fbd/cbdcbMWHoVBafSUh3nKiMAWaD8VeB8NKA3p2PK9j0Qn +L73kCyovDI4n9pUkkNv3pWId5DXKDAUyVgC3/ZeIwMj8lKufoG97lUyf4gOJj3xC +6fyZxaZqNk6vrGPJdY8vkF+r4e9NAgMBAAGjUzBRMB0GA1UdDgQWBBRnYKNBhu9S +IEqFqyKeJ8R7hGqagzAfBgNVHSMEGDAWgBRnYKNBhu9SIEqFqyKeJ8R7hGqagzAP +BgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4ICAQBybdkX6LoOsHKFBSRE +85yCc/tq5DNA24BAyvLBnHa4oenVt4g+cDT+eZKe/rZlHZwH6CSuWBI3XG0eyxjt +1TNkGtPcPZFHt/xe0x1nFE1XuWtWF9pwzJvJ3fDVd3IwyXoTbQ040jaNC3csVsMV +yLJNPKDT4ULVM76L89Gh/GsuWlJM9ocwfAWmfWTIUnDtJWVDw+TKvli4+IFPVStE +dFqB95AybJ1pI+0OSUioLkvrd+7udFTxn6QaqKGPDOocA+cGkR+oHmV1g72ucPaF +iTeTIT+/rzpGmmdoi1SGNdm+8+0afRHTRCjc14wRXSBDgOLtIk8f6R4lo3ZZY9xO +2PM6h1khWThjt6aEEVWL40dS9we6iIbTlc/auYR97EWMIZnlaHjOg1A4tn/XIlVR +bJ6M1x1eqCQ0bv/VLte671FXuUIgbvu6XIIDv3kW+/YDWKNkiIm2uGvrR/wFMz7r +xqyH0PzVVO9C21nNd09ZFRZM5+SF8N9NavSL2Q7m1RV8E8Boj1kJOAjYmeWYbDFB +hq5CZKXw59+WKga0ETZb3CHA7SXK0S8+8lhHdusXsV4vdgB6jL/hDH6UbFi7cxB9 +mU5Z9/wAcNUR8+UCdiW+ZpK76VFqEgWIZJzhswZ4mXET6nqypYw5XE0s3WUz2+Sh +HFFHJvKS/BrOklUc9AUGFKujZw== +-----END CERTIFICATE----- diff --git a/grpc-server/testdata/server-key.pem b/grpc-server/testdata/server-key.pem new file mode 100644 index 000000000..a3efc6182 --- /dev/null +++ b/grpc-server/testdata/server-key.pem @@ -0,0 +1,53 @@ +TEST PRIVATE KEY ONLY - NOT FOR REAL USAGE +-----BEGIN PRIVATE KEY----- +MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQC6J1+3Q7q8uIee +vWUJQUlEy+foHRcSjMNK6zMFcXBw5vEOLfTpaJh2bt+xKXPdV0X7ouGBGIJU6N2c +5d+oMJAD3H7xkNClLQtZ0V7NBTTn7yJvlL/vpfEtgdwVWhAAb3HaTED5MmZyi4bL +HdIxe/0hR/en3WPF5sH9Czx3KG460BlES1PEu5rSzCVIgtAcg0308nunEyxXRC6s +7ty2G50/GXHizryw5hESaEBICcYNi4a9R9kzJhFA8vrLwBBMhFKmCF6z4a3zdsFl +GKyGYFLqQzsb6UiM03qP5BzayOrWyd5Wms/d12/9HOr+Vqvwe6VYhEK+YD8PkP4d +vt6WnWSMLigDgmS3SC5dIEG1EmCdwqzk81wafJPvhSwPDmvBDRtycxV2aX7D1eDr +20/3bpaxt8x1y4xGfx+6Dcrzi1fdnyQwYLhiRFHkag/Kem5oXHLcoBkAQYqmissZ +5fphmnbA6zFsnoJu7mtuYR5xYC6H8wVBAIlxodc9F3WQIs0/6Mj0Z9kxgLJYHnaK +isHJrznmafhlWYvahPBW3f3G3XGzFh6FQWn0lId5yojAFmg/FXgfDSgN6djyvY9E +Jy+95AsqLwyOJ/aVJJDb96ViHeQ1ygwFMlYAt/2XiMDI/JSrn6Bve5VMn+IDiY98 +Qun8mcWmajZOr6xjyXWPL5Bfq+HvTQIDAQABAoICAQCHPeXXIji/tRyqohSOdcUC +W1W/l6rUijmz605lDPZQwCevUooVLS1fFcwkTOZlj2tDlyFYBfNiNtASlhs4eReY +BpCfdcNvzVrqxSansrmuK1kMUbhkJl4i3q6DQKxRdKX1n+KwaQJvA5lJZf/4fYj4 +re1qInjDJZQYABrMwy3aQqeoq0VPr7Cap0AK/yatIGP5qlVVm6NiPyHd96eYElXa +quTHj8Uci/kpM15IN/mQi6a3S0SsWWK9mgnFD2OIA7Z786bB7xrOv59fkF5/Penf +UjrWW1SfI2Fuup6QANpEc+K5br3IASXWcTT01QkTrPECbwyCfTAzYfaMq8fRCsYp +ppzla+F/YOKT2xtNyBj3dYnLyqEoapiL1knzoPlXAfnwu5Nb/YdQrt/AnX25Cga7 +8l3Rk/nMaGtOwv9Wj7HrRMC9a754+ILFWyGYQw8GTww4a4nJ7aXrXQJTh8WYGFRp +hQUbOX3lWT4BVhm2ENQxaqGsB/sDSSMke9dhf7OvBnxtArxDQvFHulWkPQd5idYD +pDFwlgCLCp9USeQp3cwieS1ukMPYVHjx+pgGknUqzEPscjCmqgttOhMRS2f1m2KK +bs4kJBzipwM981qx4WTp2rhXBKbFmSQYR5Rv8vgjzPaaCxg96YH+V+/4Jd5fQtw3 +R9mYa1Xg43FLuiR/brkbIQKCAQEA86qFknjZKIe/5CCwaa5lGXj80fjdJqX1is2v +YJmrw/Kupt7dTS8+6ZQPHj74eGab0gyww5OVWKCsvKWgTdq+qIL9eQBl1v9B7Ckn +aRX6SfMDMbftlzQQYy2redz/X6PJaUggOsHWtNWP2gbKFeeBaZbDOjDCV0xlsThC +9gIp2Yhya1EDcMXod6/FTuEE4oyMwktU6bLxwKv9fu0Oqo7g1Wr9COM4xal1sLhq +FUPpguYhWkKBR8sF/LoOuf18stdcMv/X/194Ry5Bc1TVb9oIDV6E7M8Wamfc5ngh +w5PwRyWqkziBhTOgUgJ9n+frBBV9ieNxPOtnsjxfbyOkXiIs6QKCAQEAw5OYZaTu +g9yXlI70sYVKcF1OMZZ+yY+MgRaDy2if/0Mr2ZGGh4UhBfSXM1oCKYJt+Oxx3v/F +YbSzGDeJ0ZefcOU5RL8jePx3/rn8pSjfCTYGZG0DE3bFdCzbWRdmcgz3RiFHc12V +eKmrIES5rk3Jv78SzrvbMOStU7DNMX82PlVKlZEnVFmGgoOc0xPGN6CNF7RDyCc7 +LJgaKBCWHx8PaSwOpcc+GBoAUXDh//tiUt5sD1vlc7vZKc49EHPZoUtuRbqHfNI9 +zCoKg+yO34A/0L4Syqehind6QVlfMcavtx4mgGA6olR9f6O/kxL9CKZYsjVje+SB +3GFGL6rGhg5gxQKCAQEAgmPXwn/ExTmPaAZOIN9f3net4ransUzRzosutCTHk73D +1CwihHEp21iNloPf9p4B+C8uUBojqx+gD/sZg0/xAr/F4ABkft5tanDDVCqcmwHd +zbc8/tKvikMgJcArMAS3fQ1Joeeke3Rk3CkR7xLJX7V7lyIMfSa2rFUNEBQsTOoF +QIRDsQ9WzOVUUld7g7fugvJI6B4H3DCtIES+umpmyg3MhfsBoFSEVCL7MZH73T5E +zsYT5FUySQFPbBvHpPQ1tFzQOynddUm9YHgfFxG3iV/xBb/zoNEflnzmpH//7jKr +yshMFvl/ayNGElHKo47UdPsu14ipHunLr++Ev5LOMQKCAQBJ05ZMkETlC8lAb/JY +bKtb3SzeNSQpLAHq3LfledojvpR37aIt2AhOOjU1Uj3Ms4qV05NsjXpR1qdgdd6V +ernaIP1MQSa/zfXx9v0yz0naLUWedTQbDdOddi1a5SVr2g8hrwBMwT/iK5IIfUjm +TkDhG9yao3krbLctB2l9zLqKLyIXcZK6GY1YCRyS5T0G1JlOIGMR1BVXURdWlmRE +3TGxDst8sshyyqXiGE2HlrpX89QwvAzSck+Yo1yTsFevtkyrD62DZc2kGx6bDBom +rj/oqUdornyhS1agAn+Xx5ue8UexYCHiEyjInOR9PUa9FCYZJ2QlaW3H5gRbjAii +pBzNAoIBAAP0k83hf8gbIvRbdJayhdRRjF7tK5ad7imp8/LhZqdYpkZG8CvxrGft +R6xc+Hebl3XNJfTrLLyKY7dZt3MmSVKxiEg5+QAh49ih3nCG30A4Uu3Ime4GiKB6 +Twa7UdqSplWsLoXbUV4mPZQCzefUVdlDqYrO57eeSSmRrkJsdJONpjuhnehVVznp +DhNNbvhGhsZarLLwBJ15SmQC0iOR6QVOFjJbgK5Ldpu3QBHDYpRqYMUEHBUkaohM +HwKd3dtQHWJ4QS3VbF9CDs/tnlQH+xn9l/Qwl6WJcypxoKvATgOeJHrMCN4fQIll +Lnmk2S0UPNTJwgZQsKgRGGsrF1v+Hlw= +-----END PRIVATE KEY----- diff --git a/server/dbserver.go b/server/dbserver.go new file mode 100644 index 000000000..d61387081 --- /dev/null +++ b/server/dbserver.go @@ -0,0 +1,78 @@ +package server + +import ( + "sync" + "time" + + "github.com/packethost/pkg/log" + "github.com/tinkerbell/tink/db" + "github.com/tinkerbell/tink/protos/hardware" + "github.com/tinkerbell/tink/protos/template" + "github.com/tinkerbell/tink/protos/workflow" + "google.golang.org/grpc" +) + +const ( + errInvalidWorkerID = "invalid worker id" + errInvalidWorkflowID = "invalid workflow id" + errInvalidTaskName = "invalid task name" + errInvalidActionName = "invalid action name" + errInvalidTaskReported = "reported task name does not match the current action details" + errInvalidActionReported = "reported action name does not match the current action details" + + msgReceivedStatus = "received action status: %s" + msgCurrentWfContext = "current workflow context" + msgSendWfContext = "send workflow context: %s" +) + +// DBServerOption is a type for modifying a DBServer. +type DBServerOption func(*DBServer) error + +// WithCerts sets a certificate mod time and material on a server. +func WithCerts(modTime time.Time, publicCertPEM []byte) DBServerOption { + return func(s *DBServer) error { + s.modT = modTime + s.cert = publicCertPEM + return nil + } +} + +// DBServer is a gRPC Server for database-backed Tinkerbell. +type DBServer struct { + cert []byte + modT time.Time + + db db.Database + quit <-chan struct{} + + dbLock sync.RWMutex + dbReady bool + + watchLock sync.RWMutex + watch map[string]chan string + + logger log.Logger +} + +// NewServer returns a new Tinkerbell server. +func NewDBServer(l log.Logger, database db.Database, opts ...DBServerOption) (*DBServer, error) { + ts := &DBServer{ + db: database, + logger: l, + dbReady: true, + } + for _, opt := range opts { + if err := opt(ts); err != nil { + return nil, err + } + } + + return ts, nil +} + +// Register registers Template, Workflow, and Hardware APIs on a gRPC server. +func (s *DBServer) Register(gserver *grpc.Server) { + template.RegisterTemplateServiceServer(gserver, s) + workflow.RegisterWorkflowServiceServer(gserver, s) + hardware.RegisterHardwareServiceServer(gserver, s) +} diff --git a/grpc-server/hardware.go b/server/dbserver_hardware.go similarity index 88% rename from grpc-server/hardware.go rename to server/dbserver_hardware.go index 1d8100c12..427374d34 100644 --- a/grpc-server/hardware.go +++ b/server/dbserver_hardware.go @@ -1,4 +1,4 @@ -package grpcserver +package server import ( "context" @@ -21,7 +21,7 @@ const ( duplicateMAC = "Duplicate MAC address found" ) -func (s *server) Push(ctx context.Context, in *hardware.PushRequest) (*hardware.Empty, error) { +func (s *DBServer) Push(ctx context.Context, in *hardware.PushRequest) (*hardware.Empty, error) { s.logger.Info("push") labels := prometheus.Labels{"method": "Push", "op": ""} metrics.CacheInFlight.With(labels).Inc() @@ -95,7 +95,7 @@ func (s *server) Push(ctx context.Context, in *hardware.PushRequest) (*hardware. return &hardware.Empty{}, err } -func (s *server) by(method string, fn func() (string, error)) (*hardware.Hardware, error) { +func (s *DBServer) by(method string, fn func() (string, error)) (*hardware.Hardware, error) { labels := prometheus.Labels{"method": method, "op": "get"} metrics.CacheTotals.With(labels).Inc() @@ -128,27 +128,27 @@ func (s *server) by(method string, fn func() (string, error)) (*hardware.Hardwar return hw, nil } -func (s *server) ByMAC(ctx context.Context, in *hardware.GetRequest) (*hardware.Hardware, error) { +func (s *DBServer) ByMAC(ctx context.Context, in *hardware.GetRequest) (*hardware.Hardware, error) { return s.by("ByMAC", func() (string, error) { return s.db.GetByMAC(ctx, in.Mac) }) } -func (s *server) ByIP(ctx context.Context, in *hardware.GetRequest) (*hardware.Hardware, error) { +func (s *DBServer) ByIP(ctx context.Context, in *hardware.GetRequest) (*hardware.Hardware, error) { return s.by("ByIP", func() (string, error) { return s.db.GetByIP(ctx, in.Ip) }) } // ByID implements hardware.ByID. -func (s *server) ByID(ctx context.Context, in *hardware.GetRequest) (*hardware.Hardware, error) { +func (s *DBServer) ByID(ctx context.Context, in *hardware.GetRequest) (*hardware.Hardware, error) { return s.by("ByID", func() (string, error) { return s.db.GetByID(ctx, in.Id) }) } // ALL implements hardware.All. -func (s *server) All(_ *hardware.Empty, stream hardware.HardwareService_AllServer) error { +func (s *DBServer) All(_ *hardware.Empty, stream hardware.HardwareService_AllServer) error { labels := prometheus.Labels{"method": "All", "op": "get"} metrics.CacheTotals.With(labels).Inc() @@ -181,7 +181,7 @@ func (s *server) All(_ *hardware.Empty, stream hardware.HardwareService_AllServe return nil } -func (s *server) DeprecatedWatch(in *hardware.GetRequest, stream hardware.HardwareService_DeprecatedWatchServer) error { +func (s *DBServer) DeprecatedWatch(in *hardware.GetRequest, stream hardware.HardwareService_DeprecatedWatchServer) error { l := s.logger.With("id", in.Id) ch := make(chan string, 1) @@ -242,16 +242,16 @@ func (s *server) DeprecatedWatch(in *hardware.GetRequest, stream hardware.Hardwa } // Cert returns the public cert that can be served to clients. -func (s *server) Cert() []byte { +func (s *DBServer) Cert() []byte { return s.cert } // ModTime returns the modified-time of the grpc cert. -func (s *server) ModTime() time.Time { +func (s *DBServer) ModTime() time.Time { return s.modT } -func (s *server) Delete(ctx context.Context, in *hardware.DeleteRequest) (*hardware.Empty, error) { +func (s *DBServer) Delete(ctx context.Context, in *hardware.DeleteRequest) (*hardware.Empty, error) { s.logger.Info("delete") labels := prometheus.Labels{"method": "Delete", "op": ""} metrics.CacheInFlight.With(labels).Inc() @@ -303,7 +303,7 @@ func (s *server) Delete(ctx context.Context, in *hardware.DeleteRequest) (*hardw return &hardware.Empty{}, err } -func (s *server) validateHardwareData(ctx context.Context, hw *hardware.Hardware) error { +func (s *DBServer) validateHardwareData(ctx context.Context, hw *hardware.Hardware) error { for _, iface := range hw.GetNetwork().GetInterfaces() { mac := iface.GetDhcp().GetMac() diff --git a/grpc-server/hardware_test.go b/server/dbserver_hardware_test.go similarity index 99% rename from grpc-server/hardware_test.go rename to server/dbserver_hardware_test.go index 78d1e0217..bd0228c5a 100644 --- a/grpc-server/hardware_test.go +++ b/server/dbserver_hardware_test.go @@ -1,4 +1,4 @@ -package grpcserver +package server import ( "testing" diff --git a/grpc-server/template.go b/server/dbserver_template.go similarity index 88% rename from grpc-server/template.go rename to server/dbserver_template.go index 0dc2cd53c..3858f7d53 100644 --- a/grpc-server/template.go +++ b/server/dbserver_template.go @@ -1,4 +1,4 @@ -package grpcserver +package server import ( "context" @@ -14,7 +14,7 @@ import ( ) // CreateTemplate implements template.CreateTemplate. -func (s *server) CreateTemplate(ctx context.Context, in *template.WorkflowTemplate) (*template.CreateResponse, error) { +func (s *DBServer) CreateTemplate(ctx context.Context, in *template.WorkflowTemplate) (*template.CreateResponse, error) { s.logger.Info("createtemplate") labels := prometheus.Labels{"method": "CreateTemplate", "op": ""} metrics.CacheInFlight.With(labels).Inc() @@ -44,7 +44,7 @@ func (s *server) CreateTemplate(ctx context.Context, in *template.WorkflowTempla } // GetTemplate implements template.GetTemplate. -func (s *server) GetTemplate(ctx context.Context, in *template.GetRequest) (*template.WorkflowTemplate, error) { +func (s *DBServer) GetTemplate(ctx context.Context, in *template.GetRequest) (*template.WorkflowTemplate, error) { s.logger.Info("gettemplate") labels := prometheus.Labels{"method": "GetTemplate", "op": ""} metrics.CacheInFlight.With(labels).Inc() @@ -76,7 +76,7 @@ func (s *server) GetTemplate(ctx context.Context, in *template.GetRequest) (*tem } // DeleteTemplate implements template.DeleteTemplate. -func (s *server) DeleteTemplate(ctx context.Context, in *template.GetRequest) (*template.Empty, error) { +func (s *DBServer) DeleteTemplate(ctx context.Context, in *template.GetRequest) (*template.Empty, error) { s.logger.Info("deletetemplate") labels := prometheus.Labels{"method": "DeleteTemplate", "op": ""} metrics.CacheInFlight.With(labels).Inc() @@ -104,7 +104,7 @@ func (s *server) DeleteTemplate(ctx context.Context, in *template.GetRequest) (* } // ListTemplates implements template.ListTemplates. -func (s *server) ListTemplates(in *template.ListRequest, stream template.TemplateService_ListTemplatesServer) error { +func (s *DBServer) ListTemplates(in *template.ListRequest, stream template.TemplateService_ListTemplatesServer) error { s.logger.Info("listtemplates") labels := prometheus.Labels{"method": "ListTemplates", "op": "list"} metrics.CacheTotals.With(labels).Inc() @@ -139,7 +139,7 @@ func (s *server) ListTemplates(in *template.ListRequest, stream template.Templat } // UpdateTemplate implements template.UpdateTemplate. -func (s *server) UpdateTemplate(ctx context.Context, in *template.WorkflowTemplate) (*template.Empty, error) { +func (s *DBServer) UpdateTemplate(ctx context.Context, in *template.WorkflowTemplate) (*template.Empty, error) { s.logger.Info("updatetemplate") labels := prometheus.Labels{"method": "UpdateTemplate", "op": ""} metrics.CacheInFlight.With(labels).Inc() diff --git a/grpc-server/template_test.go b/server/dbserver_template_test.go similarity index 99% rename from grpc-server/template_test.go rename to server/dbserver_template_test.go index ca9a6569d..93af6fd0d 100644 --- a/grpc-server/template_test.go +++ b/server/dbserver_template_test.go @@ -1,4 +1,4 @@ -package grpcserver +package server import ( "context" diff --git a/grpc-server/tinkerbell.go b/server/dbserver_worker_workflow.go similarity index 63% rename from grpc-server/tinkerbell.go rename to server/dbserver_worker_workflow.go index 8d3476608..0e4bc73a3 100644 --- a/grpc-server/tinkerbell.go +++ b/server/dbserver_worker_workflow.go @@ -1,4 +1,4 @@ -package grpcserver +package server import ( "context" @@ -8,28 +8,13 @@ import ( "github.com/packethost/pkg/log" "github.com/tinkerbell/tink/db" - pb "github.com/tinkerbell/tink/protos/workflow" + "github.com/tinkerbell/tink/protos/workflow" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) -var workflowData = make(map[string]int) - -const ( - errInvalidWorkerID = "invalid worker id" - errInvalidWorkflowID = "invalid workflow id" - errInvalidTaskName = "invalid task name" - errInvalidActionName = "invalid action name" - errInvalidTaskReported = "reported task name does not match the current action details" - errInvalidActionReported = "reported action name does not match the current action details" - - msgReceivedStatus = "received action status: %s" - msgCurrentWfContext = "current workflow context" - msgSendWfContext = "send workflow context: %s" -) - // GetWorkflowContexts implements tinkerbell.GetWorkflowContexts. -func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.WorkflowService_GetWorkflowContextsServer) error { +func (s *DBServer) GetWorkflowContexts(req *workflow.WorkflowContextRequest, stream workflow.WorkflowService_GetWorkflowContextsServer) error { wfs, err := getWorkflowsForWorker(stream.Context(), s.db, req.WorkerId) if err != nil { return err @@ -49,14 +34,14 @@ func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.W } // GetWorkflowContextList implements tinkerbell.GetWorkflowContextList. -func (s *server) GetWorkflowContextList(ctx context.Context, req *pb.WorkflowContextRequest) (*pb.WorkflowContextList, error) { +func (s *DBServer) GetWorkflowContextList(ctx context.Context, req *workflow.WorkflowContextRequest) (*workflow.WorkflowContextList, error) { wfs, err := getWorkflowsForWorker(ctx, s.db, req.WorkerId) if err != nil { return nil, err } if wfs != nil { - wfContexts := []*pb.WorkflowContext{} + wfContexts := []*workflow.WorkflowContext{} for _, wf := range wfs { wfContext, err := s.db.GetWorkflowContexts(ctx, wf) if err != nil { @@ -64,7 +49,7 @@ func (s *server) GetWorkflowContextList(ctx context.Context, req *pb.WorkflowCon } wfContexts = append(wfContexts, wfContext) } - return &pb.WorkflowContextList{ + return &workflow.WorkflowContextList{ WorkflowContexts: wfContexts, }, nil } @@ -72,7 +57,7 @@ func (s *server) GetWorkflowContextList(ctx context.Context, req *pb.WorkflowCon } // GetWorkflowActions implements tinkerbell.GetWorkflowActions. -func (s *server) GetWorkflowActions(ctx context.Context, req *pb.WorkflowActionsRequest) (*pb.WorkflowActionList, error) { +func (s *DBServer) GetWorkflowActions(ctx context.Context, req *workflow.WorkflowActionsRequest) (*workflow.WorkflowActionList, error) { wfID := req.GetWorkflowId() if wfID == "" { return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkflowID) @@ -81,7 +66,7 @@ func (s *server) GetWorkflowActions(ctx context.Context, req *pb.WorkflowActions } // ReportActionStatus implements tinkerbell.ReportActionStatus. -func (s *server) ReportActionStatus(ctx context.Context, req *pb.WorkflowActionStatus) (*pb.Empty, error) { +func (s *DBServer) ReportActionStatus(ctx context.Context, req *workflow.WorkflowActionStatus) (*workflow.Empty, error) { wfID := req.GetWorkflowId() if wfID == "" { return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkflowID) @@ -93,7 +78,7 @@ func (s *server) ReportActionStatus(ctx context.Context, req *pb.WorkflowActionS return nil, status.Errorf(codes.InvalidArgument, errInvalidActionName) } - l := s.logger.With("actionName", req.GetActionName(), "workflowID", req.GetWorkflowId()) + l := s.logger.With("actionName", req.GetActionName(), "workflowID", req.GetWorkflowId(), "taskName", req.GetTaskName()) l.Info(fmt.Sprintf(msgReceivedStatus, req.GetActionStatus())) wfContext, err := s.db.GetWorkflowContexts(ctx, wfID) @@ -106,7 +91,7 @@ func (s *server) ReportActionStatus(ctx context.Context, req *pb.WorkflowActionS } actionIndex := wfContext.GetCurrentActionIndex() - if req.GetActionStatus() == pb.State_STATE_RUNNING { + if req.GetActionStatus() == workflow.State_STATE_RUNNING { if wfContext.GetCurrentAction() != "" { actionIndex++ } @@ -125,14 +110,14 @@ func (s *server) ReportActionStatus(ctx context.Context, req *pb.WorkflowActionS wfContext.CurrentActionIndex = actionIndex err = s.db.UpdateWorkflowState(ctx, wfContext) if err != nil { - return &pb.Empty{}, status.Errorf(codes.Aborted, err.Error()) + return &workflow.Empty{}, status.Errorf(codes.Aborted, err.Error()) } // TODO the below "time" would be a part of the request which is coming form worker. t := time.Now() err = s.db.InsertIntoWorkflowEventTable(ctx, req, t) if err != nil { - return &pb.Empty{}, status.Error(codes.Aborted, err.Error()) + return &workflow.Empty{}, status.Error(codes.Aborted, err.Error()) } l = s.logger.With( @@ -145,55 +130,52 @@ func (s *server) ReportActionStatus(ctx context.Context, req *pb.WorkflowActionS "totalNumberOfActions", wfContext.GetTotalNumberOfActions(), ) l.Info(msgCurrentWfContext) - return &pb.Empty{}, nil + return &workflow.Empty{}, nil } // UpdateWorkflowData updates workflow ephemeral data. -func (s *server) UpdateWorkflowData(ctx context.Context, req *pb.UpdateWorkflowDataRequest) (*pb.Empty, error) { - wfID := req.GetWorkflowId() - if wfID == "" { - return &pb.Empty{}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowID) - } - _, ok := workflowData[wfID] - if !ok { - workflowData[wfID] = 1 +func (s *DBServer) UpdateWorkflowData(ctx context.Context, req *workflow.UpdateWorkflowDataRequest) (*workflow.Empty, error) { + if wfID := req.GetWorkflowId(); wfID == "" { + return &workflow.Empty{}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowID) } + err := s.db.InsertIntoWfDataTable(ctx, req) if err != nil { - return &pb.Empty{}, status.Errorf(codes.Aborted, err.Error()) + return &workflow.Empty{}, status.Errorf(codes.Aborted, err.Error()) } - return &pb.Empty{}, nil + return &workflow.Empty{}, nil } // GetWorkflowData gets the ephemeral data for a workflow. -func (s *server) GetWorkflowData(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) { +func (s *DBServer) GetWorkflowData(ctx context.Context, req *workflow.GetWorkflowDataRequest) (*workflow.GetWorkflowDataResponse, error) { if id := req.GetWorkflowId(); id == "" { - return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowID) + return &workflow.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowID) } data, err := s.db.GetfromWfDataTable(ctx, req) if err != nil { - return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.Aborted, err.Error()) + s.logger.Error(err, "Error getting from data table") + return &workflow.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.Aborted, err.Error()) } - return &pb.GetWorkflowDataResponse{Data: data}, nil + return &workflow.GetWorkflowDataResponse{Data: data}, nil } // GetWorkflowMetadata returns metadata wrt to the ephemeral data of a workflow. -func (s *server) GetWorkflowMetadata(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) { +func (s *DBServer) GetWorkflowMetadata(ctx context.Context, req *workflow.GetWorkflowDataRequest) (*workflow.GetWorkflowDataResponse, error) { data, err := s.db.GetWorkflowMetadata(ctx, req) if err != nil { - return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.Aborted, err.Error()) + return &workflow.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.Aborted, err.Error()) } - return &pb.GetWorkflowDataResponse{Data: data}, nil + return &workflow.GetWorkflowDataResponse{Data: data}, nil } // GetWorkflowDataVersion returns the latest version of data for a workflow. -func (s *server) GetWorkflowDataVersion(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) { +func (s *DBServer) GetWorkflowDataVersion(ctx context.Context, req *workflow.GetWorkflowDataRequest) (*workflow.GetWorkflowDataResponse, error) { version, err := s.db.GetWorkflowDataVersion(ctx, req.WorkflowId) if err != nil { - return &pb.GetWorkflowDataResponse{Version: version}, status.Errorf(codes.Aborted, err.Error()) + return &workflow.GetWorkflowDataResponse{Version: version}, status.Errorf(codes.Aborted, err.Error()) } - return &pb.GetWorkflowDataResponse{Version: version}, nil + return &workflow.GetWorkflowDataResponse{Version: version}, nil } func getWorkflowsForWorker(ctx context.Context, d db.Database, id string) ([]string, error) { @@ -207,7 +189,7 @@ func getWorkflowsForWorker(ctx context.Context, d db.Database, id string) ([]str return wfs, nil } -func getWorkflowActions(ctx context.Context, d db.Database, wfID string) (*pb.WorkflowActionList, error) { +func getWorkflowActions(ctx context.Context, d db.Database, wfID string) (*workflow.WorkflowActionList, error) { actions, err := d.GetWorkflowActions(ctx, wfID) if err != nil { return nil, status.Errorf(codes.Aborted, errInvalidWorkflowID) @@ -217,16 +199,16 @@ func getWorkflowActions(ctx context.Context, d db.Database, wfID string) (*pb.Wo // isApplicableToSend checks if a particular workflow context is applicable or if it is needed to // be sent to a worker based on the state of the current action and the targeted workerID. -func isApplicableToSend(ctx context.Context, logger log.Logger, wfContext *pb.WorkflowContext, workerID string, d db.Database) bool { - if wfContext.GetCurrentActionState() == pb.State_STATE_FAILED || - wfContext.GetCurrentActionState() == pb.State_STATE_TIMEOUT { +func isApplicableToSend(ctx context.Context, logger log.Logger, wfContext *workflow.WorkflowContext, workerID string, d db.Database) bool { + if wfContext.GetCurrentActionState() == workflow.State_STATE_FAILED || + wfContext.GetCurrentActionState() == workflow.State_STATE_TIMEOUT { return false } actions, err := getWorkflowActions(ctx, d, wfContext.GetWorkflowId()) if err != nil { return false } - if wfContext.GetCurrentActionState() == pb.State_STATE_SUCCESS { + if wfContext.GetCurrentActionState() == workflow.State_STATE_SUCCESS { if isLastAction(wfContext, actions) { return false } @@ -243,6 +225,6 @@ func isApplicableToSend(ctx context.Context, logger log.Logger, wfContext *pb.Wo return false } -func isLastAction(wfContext *pb.WorkflowContext, actions *pb.WorkflowActionList) bool { +func isLastAction(wfContext *workflow.WorkflowContext, actions *workflow.WorkflowActionList) bool { return int(wfContext.GetCurrentActionIndex()) == len(actions.GetActionList())-1 } diff --git a/grpc-server/tinkerbell_test.go b/server/dbserver_worker_workflow_test.go similarity index 99% rename from grpc-server/tinkerbell_test.go rename to server/dbserver_worker_workflow_test.go index 1c994042c..370f222f0 100644 --- a/grpc-server/tinkerbell_test.go +++ b/server/dbserver_worker_workflow_test.go @@ -1,4 +1,4 @@ -package grpcserver +package server import ( "context" @@ -28,14 +28,14 @@ const ( var wfData = []byte("{'os': 'ubuntu', 'base_url': 'http://192.168.1.1/'}") -func testServer(t *testing.T, d db.Database) *server { +func testServer(t *testing.T, d db.Database) *DBServer { t.Helper() l, err := log.Init("github.com/tinkerbell/tink") if err != nil { t.Errorf("log init failed: %v", err) } - return &server{ + return &DBServer{ logger: l, db: d, } diff --git a/grpc-server/workflow.go b/server/dbserver_workflow.go similarity index 92% rename from grpc-server/workflow.go rename to server/dbserver_workflow.go index f37ec4339..c02761b6a 100644 --- a/grpc-server/workflow.go +++ b/server/dbserver_workflow.go @@ -1,4 +1,4 @@ -package grpcserver +package server import ( "context" @@ -17,7 +17,7 @@ import ( const errFailedToGetTemplate = "failed to get template with ID %s" // CreateWorkflow implements workflow.CreateWorkflow. -func (s *server) CreateWorkflow(ctx context.Context, in *workflow.CreateRequest) (*workflow.CreateResponse, error) { +func (s *DBServer) CreateWorkflow(ctx context.Context, in *workflow.CreateRequest) (*workflow.CreateResponse, error) { s.logger.Info("createworkflow") labels := prometheus.Labels{"method": "CreateWorkflow", "op": ""} metrics.CacheInFlight.With(labels).Inc() @@ -72,7 +72,7 @@ func (s *server) CreateWorkflow(ctx context.Context, in *workflow.CreateRequest) } // GetWorkflow implements workflow.GetWorkflow. -func (s *server) GetWorkflow(ctx context.Context, in *workflow.GetRequest) (*workflow.Workflow, error) { +func (s *DBServer) GetWorkflow(ctx context.Context, in *workflow.GetRequest) (*workflow.Workflow, error) { s.logger.Info("getworkflow") labels := prometheus.Labels{"method": "GetWorkflow", "op": ""} metrics.CacheInFlight.With(labels).Inc() @@ -123,7 +123,7 @@ func (s *server) GetWorkflow(ctx context.Context, in *workflow.GetRequest) (*wor } // DeleteWorkflow implements workflow.DeleteWorkflow. -func (s *server) DeleteWorkflow(ctx context.Context, in *workflow.GetRequest) (*workflow.Empty, error) { +func (s *DBServer) DeleteWorkflow(ctx context.Context, in *workflow.GetRequest) (*workflow.Empty, error) { s.logger.Info("deleteworkflow") labels := prometheus.Labels{"method": "DeleteWorkflow", "op": ""} metrics.CacheInFlight.With(labels).Inc() @@ -152,7 +152,7 @@ func (s *server) DeleteWorkflow(ctx context.Context, in *workflow.GetRequest) (* } // ListWorkflows implements workflow.ListWorkflows. -func (s *server) ListWorkflows(_ *workflow.Empty, stream workflow.WorkflowService_ListWorkflowsServer) error { +func (s *DBServer) ListWorkflows(_ *workflow.Empty, stream workflow.WorkflowService_ListWorkflowsServer) error { s.logger.Info("listworkflows") labels := prometheus.Labels{"method": "ListWorkflows", "op": "list"} metrics.CacheTotals.With(labels).Inc() @@ -189,7 +189,7 @@ func (s *server) ListWorkflows(_ *workflow.Empty, stream workflow.WorkflowServic return nil } -func (s *server) GetWorkflowContext(ctx context.Context, in *workflow.GetRequest) (*workflow.WorkflowContext, error) { +func (s *DBServer) GetWorkflowContext(ctx context.Context, in *workflow.GetRequest) (*workflow.WorkflowContext, error) { s.logger.Info("GetworkflowContext") labels := prometheus.Labels{"method": "GetWorkflowContext", "op": ""} metrics.CacheInFlight.With(labels).Inc() @@ -236,7 +236,7 @@ func (s *server) GetWorkflowContext(ctx context.Context, in *workflow.GetRequest } // ShowWorflowevents implements workflow.ShowWorflowEvents. -func (s *server) ShowWorkflowEvents(req *workflow.GetRequest, stream workflow.WorkflowService_ShowWorkflowEventsServer) error { +func (s *DBServer) ShowWorkflowEvents(req *workflow.GetRequest, stream workflow.WorkflowService_ShowWorkflowEventsServer) error { s.logger.Info("List workflows Events") labels := prometheus.Labels{"method": "ShowWorkflowEvents", "op": "list"} metrics.CacheTotals.With(labels).Inc() diff --git a/grpc-server/workflow_test.go b/server/dbserver_workflow_test.go similarity index 99% rename from grpc-server/workflow_test.go rename to server/dbserver_workflow_test.go index 6ae191939..b5475fa60 100644 --- a/grpc-server/workflow_test.go +++ b/server/dbserver_workflow_test.go @@ -1,4 +1,4 @@ -package grpcserver +package server import ( "context"