Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SigV4 support #78

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
maxRetriesConfig = &configuration{flag: "max-retries", envFlag: "max_retries", defaultValue: strconv.Itoa(awsClient.DefaultRetryerMaxNumRetries)}
defaultDatabaseConfig = &configuration{flag: "default-database", envFlag: "default_database", defaultValue: ""}
defaultTableConfig = &configuration{flag: "default-table", envFlag: "default_table", defaultValue: ""}
enableSigV4AuthConfig = &configuration{flag: "enable-sigv4-auth", envFlag: "enable_sigv4_auth", defaultValue: "true"}
listenAddrConfig = &configuration{flag: "web.listen-address", envFlag: "", defaultValue: ":9201"}
telemetryPathConfig = &configuration{flag: "web.telemetry-path", envFlag: "", defaultValue: "/metrics"}
failOnLabelConfig = &configuration{flag: "fail-on-long-label", envFlag: "fail_on_long_label", defaultValue: "false"}
Expand Down
34 changes: 28 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/go-kit/log"
"github.com/gogo/protobuf/proto"
"github.com/golang/snappy"
Expand Down Expand Up @@ -83,6 +84,7 @@ type connectionConfig struct {
defaultDatabase string
defaultTable string
enableLogging bool
enableSigV4Auth bool
failOnLongMetricLabelName bool
failOnInvalidSample bool
listenAddr string
Expand Down Expand Up @@ -145,9 +147,20 @@ func lambdaHandler(req events.APIGatewayProxyRequest) (events.APIGatewayProxyRes

logger := cfg.createLogger()

awsCredentials, ok := parseBasicAuth(req.Headers[basicAuthHeader])
if !ok {
return createErrorResponse(errors.NewParseBasicAuthHeaderError().(*errors.ParseBasicAuthHeaderError).Message())
var awsCredentials *credentials.Credentials
var ok bool

// If SigV4 authentication has been enabled, such as when write requests originate
// from the OpenTelemetry collector, credentials will be taken from the local environment.
// Otherwise, basic auth is used for AWS credentials
if cfg.enableSigV4Auth {
sess := session.Must(session.NewSession())
awsCredentials = sess.Config.Credentials
} else {
awsCredentials, ok = parseBasicAuth(req.Headers[basicAuthHeader])
if !ok {
return createErrorResponse(errors.NewParseBasicAuthHeaderError().(*errors.ParseBasicAuthHeaderError).Message())
}
}

awsConfigs := cfg.buildAWSConfig()
Expand Down Expand Up @@ -280,7 +293,7 @@ func (cfg *connectionConfig) createLogger() (logger log.Logger) {
}

// parseBoolFromStrings parses the boolean configuration options from the strings in connectionConfig.
func (cfg *connectionConfig) parseBoolFromStrings(enableLogging, failOnLongMetricLabelName, failOnInvalidSample string) error {
func (cfg *connectionConfig) parseBoolFromStrings(enableLogging, failOnLongMetricLabelName, failOnInvalidSample, enableSigV4Auth string) error {
var err error

cfg.enableLogging, err = strconv.ParseBool(enableLogging)
Expand All @@ -304,6 +317,13 @@ func (cfg *connectionConfig) parseBoolFromStrings(enableLogging, failOnLongMetri
return timestreamError
}

cfg.enableSigV4Auth, err = strconv.ParseBool(enableSigV4Auth)
if err != nil {
timestreamError := errors.NewParseSampleOptionError(failOnInvalidSample)
fmt.Println(timestreamError.Error())
return timestreamError
}

return nil
}

Expand All @@ -328,7 +348,7 @@ func parseEnvironmentVariables() (*connectionConfig, error) {
cfg.defaultTable = getOrDefault(defaultTableConfig)

var err error
err = cfg.parseBoolFromStrings(getOrDefault(enableLogConfig), getOrDefault(failOnLabelConfig), getOrDefault(failOnInvalidSampleConfig))
err = cfg.parseBoolFromStrings(getOrDefault(enableLogConfig), getOrDefault(failOnLabelConfig), getOrDefault(failOnInvalidSampleConfig), getOrDefault(enableSigV4AuthConfig))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -357,6 +377,7 @@ func parseFlags() *connectionConfig {
}

var enableLogging string
var enableSigV4Auth string
var failOnLongMetricLabelName string
var failOnInvalidSample string

Expand All @@ -373,6 +394,7 @@ func parseFlags() *connectionConfig {
Default(failOnInvalidSampleConfig.defaultValue).StringVar(&failOnInvalidSample)
a.Flag(certificateConfig.flag, "TLS server certificate file.").Default(certificateConfig.defaultValue).StringVar(&cfg.certificate)
a.Flag(keyConfig.flag, "TLS server private key file.").Default(keyConfig.defaultValue).StringVar(&cfg.key)
a.Flag(enableSigV4AuthConfig.flag, "Whether to enable SigV4 authentication with the API Gateway. Default to 'false'.").Default(enableSigV4AuthConfig.defaultValue).StringVar(&enableSigV4Auth)

flag.AddFlags(a, &cfg.promlogConfig)

Expand All @@ -381,7 +403,7 @@ func parseFlags() *connectionConfig {
os.Exit(1)
}

if err := cfg.parseBoolFromStrings(enableLogging, failOnLongMetricLabelName, failOnInvalidSample); err != nil {
if err := cfg.parseBoolFromStrings(enableLogging, failOnLongMetricLabelName, failOnInvalidSample, enableSigV4Auth); err != nil {
os.Exit(1)
}

Expand Down
4 changes: 4 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ func setUp() ([]string, *connectionConfig) {
defaultDatabase: "foo",
defaultTable: "bar",
enableLogging: true,
enableSigV4Auth: true,
listenAddr: ":9201",
maxRetries: 3,
telemetryPath: "/metrics",
Expand Down Expand Up @@ -365,6 +366,7 @@ func TestLambdaHandlerPrepareRequest(t *testing.T) {
lambdaOptions: []lambdaEnvOptions{
{key: defaultTableConfig.envFlag, value: tableValue},
{key: defaultDatabaseConfig.envFlag, value: databaseValue},
{key: enableSigV4AuthConfig.envFlag, value: "false"},
},
inputRequest: events.APIGatewayProxyRequest{
IsBase64Encoded: true,
Expand All @@ -379,6 +381,7 @@ func TestLambdaHandlerPrepareRequest(t *testing.T) {
lambdaOptions: []lambdaEnvOptions{
{key: defaultTableConfig.envFlag, value: tableValue},
{key: defaultDatabaseConfig.envFlag, value: databaseValue},
{key: enableSigV4AuthConfig.envFlag, value: "false"},
},
inputRequest: events.APIGatewayProxyRequest{
IsBase64Encoded: true,
Expand Down Expand Up @@ -658,6 +661,7 @@ func TestParseEnvironmentVariables(t *testing.T) {
clientConfig: &clientConfig{region: "us-east-1"},
promlogConfig: defaultLogConfig,
enableLogging: true,
enableSigV4Auth: true,
failOnInvalidSample: false,
failOnLongMetricLabelName: false,
maxRetries: 3,
Expand Down
Loading