diff --git a/CHANGELOG.md b/CHANGELOG.md index b4c0d047217..137bf9ff910 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Code instrumented with the `go.opentelemetry.io/otel/metric` will need to be mod ### Added - Add go 1.18 to our compatibility tests. (#2679) +- Allow configuring the Sampler with the `OTEL_TRACES_SAMPLER` and `OTEL_TRACES_SAMPLER_ARG` environment variables. (#2305, #2517) ### Changed diff --git a/sdk/trace/provider.go b/sdk/trace/provider.go index 0643d1a1611..77b86a84feb 100644 --- a/sdk/trace/provider.go +++ b/sdk/trace/provider.go @@ -99,6 +99,7 @@ func NewTracerProvider(opts ...TracerProviderOption) *TracerProvider { o := tracerProviderConfig{ spanLimits: NewSpanLimits(), } + o = applyTracerProviderEnvConfigs(o) for _, opt := range opts { o = opt.apply(o) @@ -335,7 +336,10 @@ func WithIDGenerator(g IDGenerator) TracerProviderOption { // Tracers the TracerProvider creates to make their sampling decisions for the // Spans they create. // -// If this option is not used, the TracerProvider will use a +// This option overrides the Sampler configured through the OTEL_TRACES_SAMPLER +// and OTEL_TRACES_SAMPLER_ARG environment variables. If this option is not used +// and the sampler is not configured through environment variables or the environment +// contains invalid/unsupported configuration, the TracerProvider will use a // ParentBased(AlwaysSample) Sampler by default. func WithSampler(s Sampler) TracerProviderOption { return traceProviderOptionFunc(func(cfg tracerProviderConfig) tracerProviderConfig { @@ -408,6 +412,29 @@ func WithRawSpanLimits(limits SpanLimits) TracerProviderOption { }) } +func applyTracerProviderEnvConfigs(cfg tracerProviderConfig) tracerProviderConfig { + for _, opt := range tracerProviderOptionsFromEnv() { + cfg = opt.apply(cfg) + } + + return cfg +} + +func tracerProviderOptionsFromEnv() []TracerProviderOption { + var opts []TracerProviderOption + + sampler, err := samplerFromEnv() + if err != nil { + otel.Handle(err) + } + + if sampler != nil { + opts = append(opts, WithSampler(sampler)) + } + + return opts +} + // ensureValidTracerProviderConfig ensures that given TracerProviderConfig is valid. func ensureValidTracerProviderConfig(cfg tracerProviderConfig) tracerProviderConfig { if cfg.sampler == nil { diff --git a/sdk/trace/provider_test.go b/sdk/trace/provider_test.go index e2fce31d7f7..0cfe584c926 100644 --- a/sdk/trace/provider_test.go +++ b/sdk/trace/provider_test.go @@ -17,10 +17,14 @@ package trace import ( "context" "errors" + "fmt" + "math/rand" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + ottest "go.opentelemetry.io/otel/internal/internaltest" "go.opentelemetry.io/otel/trace" ) @@ -94,3 +98,164 @@ func TestSchemaURL(t *testing.T) { tracerStruct := tracerIface.(*tracer) assert.EqualValues(t, schemaURL, tracerStruct.instrumentationLibrary.SchemaURL) } + +func TestTracerProviderSamplerConfigFromEnv(t *testing.T) { + type testCase struct { + sampler string + samplerArg string + argOptional bool + description string + errorType error + invalidArgErrorType interface{} + } + + randFloat := rand.Float64() + + tests := []testCase{ + { + sampler: "invalid-sampler", + argOptional: true, + description: ParentBased(AlwaysSample()).Description(), + errorType: errUnsupportedSampler("invalid-sampler"), + invalidArgErrorType: func() *errUnsupportedSampler { e := errUnsupportedSampler("invalid-sampler"); return &e }(), + }, + { + sampler: "always_on", + argOptional: true, + description: AlwaysSample().Description(), + }, + { + sampler: "always_off", + argOptional: true, + description: NeverSample().Description(), + }, + { + sampler: "traceidratio", + samplerArg: fmt.Sprintf("%g", randFloat), + description: TraceIDRatioBased(randFloat).Description(), + }, + { + sampler: "traceidratio", + samplerArg: fmt.Sprintf("%g", -randFloat), + description: TraceIDRatioBased(1.0).Description(), + errorType: errNegativeTraceIDRatio, + }, + { + sampler: "traceidratio", + samplerArg: fmt.Sprintf("%g", 1+randFloat), + description: TraceIDRatioBased(1.0).Description(), + errorType: errGreaterThanOneTraceIDRatio, + }, + { + sampler: "traceidratio", + argOptional: true, + description: TraceIDRatioBased(1.0).Description(), + invalidArgErrorType: new(samplerArgParseError), + }, + { + sampler: "parentbased_always_on", + argOptional: true, + description: ParentBased(AlwaysSample()).Description(), + }, + { + sampler: "parentbased_always_off", + argOptional: true, + description: ParentBased(NeverSample()).Description(), + }, + { + sampler: "parentbased_traceidratio", + samplerArg: fmt.Sprintf("%g", randFloat), + description: ParentBased(TraceIDRatioBased(randFloat)).Description(), + }, + { + sampler: "parentbased_traceidratio", + samplerArg: fmt.Sprintf("%g", -randFloat), + description: ParentBased(TraceIDRatioBased(1.0)).Description(), + errorType: errNegativeTraceIDRatio, + }, + { + sampler: "parentbased_traceidratio", + samplerArg: fmt.Sprintf("%g", 1+randFloat), + description: ParentBased(TraceIDRatioBased(1.0)).Description(), + errorType: errGreaterThanOneTraceIDRatio, + }, + { + sampler: "parentbased_traceidratio", + argOptional: true, + description: ParentBased(TraceIDRatioBased(1.0)).Description(), + invalidArgErrorType: new(samplerArgParseError), + }, + } + + handler.Reset() + + for _, test := range tests { + t.Run(test.sampler, func(t *testing.T) { + envVars := map[string]string{ + "OTEL_TRACES_SAMPLER": test.sampler, + } + + if test.samplerArg != "" { + envVars["OTEL_TRACES_SAMPLER_ARG"] = test.samplerArg + } + envStore, err := ottest.SetEnvVariables(envVars) + require.NoError(t, err) + t.Cleanup(func() { + handler.Reset() + require.NoError(t, envStore.Restore()) + }) + + stp := NewTracerProvider(WithSyncer(NewTestExporter())) + assert.Equal(t, test.description, stp.sampler.Description()) + if test.errorType != nil { + testStoredError(t, test.errorType) + } else { + assert.Empty(t, handler.errs) + } + + if test.argOptional { + t.Run("invalid sampler arg", func(t *testing.T) { + envStore, err := ottest.SetEnvVariables(map[string]string{ + "OTEL_TRACES_SAMPLER": test.sampler, + "OTEL_TRACES_SAMPLER_ARG": "invalid-ignored-string", + }) + require.NoError(t, err) + t.Cleanup(func() { + handler.Reset() + require.NoError(t, envStore.Restore()) + }) + + stp := NewTracerProvider(WithSyncer(NewTestExporter())) + t.Cleanup(func() { + require.NoError(t, stp.Shutdown(context.Background())) + }) + assert.Equal(t, test.description, stp.sampler.Description()) + + if test.invalidArgErrorType != nil { + testStoredError(t, test.invalidArgErrorType) + } else { + assert.Empty(t, handler.errs) + } + }) + } + }) + } +} + +func testStoredError(t *testing.T, target interface{}) { + t.Helper() + + if assert.Len(t, handler.errs, 1) && assert.Error(t, handler.errs[0]) { + err := handler.errs[0] + + require.Implements(t, (*error)(nil), target) + require.NotNil(t, target.(error)) + + defer handler.Reset() + if errors.Is(err, target.(error)) { + return + } + + assert.ErrorAs(t, err, target) + } +} diff --git a/sdk/trace/sampler_env.go b/sdk/trace/sampler_env.go new file mode 100644 index 00000000000..97f80576e68 --- /dev/null +++ b/sdk/trace/sampler_env.go @@ -0,0 +1,107 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trace // import "go.opentelemetry.io/otel/sdk/trace" + +import ( + "errors" + "fmt" + "os" + "strconv" + "strings" +) + +const ( + tracesSamplerKey = "OTEL_TRACES_SAMPLER" + tracesSamplerArgKey = "OTEL_TRACES_SAMPLER_ARG" + + samplerAlwaysOn = "always_on" + samplerAlwaysOff = "always_off" + samplerTraceIDRatio = "traceidratio" + samplerParentBasedAlwaysOn = "parentbased_always_on" + samplerParsedBasedAlwaysOff = "parentbased_always_off" + samplerParentBasedTraceIDRatio = "parentbased_traceidratio" +) + +type errUnsupportedSampler string + +func (e errUnsupportedSampler) Error() string { + return fmt.Sprintf("unsupported sampler: %s", string(e)) +} + +var ( + errNegativeTraceIDRatio = errors.New("invalid trace ID ratio: less than 0.0") + errGreaterThanOneTraceIDRatio = errors.New("invalid trace ID ratio: greater than 1.0") +) + +type samplerArgParseError struct { + parseErr error +} + +func (e samplerArgParseError) Error() string { + return fmt.Sprintf("parsing sampler argument: %s", e.parseErr.Error()) +} + +func (e samplerArgParseError) Unwrap() error { + return e.parseErr +} + +func samplerFromEnv() (Sampler, error) { + sampler, ok := os.LookupEnv(tracesSamplerKey) + if !ok { + return nil, nil + } + + sampler = strings.ToLower(strings.TrimSpace(sampler)) + samplerArg, hasSamplerArg := os.LookupEnv(tracesSamplerArgKey) + samplerArg = strings.TrimSpace(samplerArg) + + switch sampler { + case samplerAlwaysOn: + return AlwaysSample(), nil + case samplerAlwaysOff: + return NeverSample(), nil + case samplerTraceIDRatio: + ratio, err := parseTraceIDRatio(samplerArg, hasSamplerArg) + return ratio, err + case samplerParentBasedAlwaysOn: + return ParentBased(AlwaysSample()), nil + case samplerParsedBasedAlwaysOff: + return ParentBased(NeverSample()), nil + case samplerParentBasedTraceIDRatio: + ratio, err := parseTraceIDRatio(samplerArg, hasSamplerArg) + return ParentBased(ratio), err + default: + return nil, errUnsupportedSampler(sampler) + } + +} + +func parseTraceIDRatio(arg string, hasSamplerArg bool) (Sampler, error) { + if !hasSamplerArg { + return TraceIDRatioBased(1.0), nil + } + v, err := strconv.ParseFloat(arg, 64) + if err != nil { + return TraceIDRatioBased(1.0), samplerArgParseError{err} + } + if v < 0.0 { + return TraceIDRatioBased(1.0), errNegativeTraceIDRatio + } + if v > 1.0 { + return TraceIDRatioBased(1.0), errGreaterThanOneTraceIDRatio + } + + return TraceIDRatioBased(v), nil +}