diff --git a/src/functions.ts b/src/functions.ts index 862dd28..04af7fd 100644 --- a/src/functions.ts +++ b/src/functions.ts @@ -11,10 +11,6 @@ import SpanContext from './span_context'; * @return a REFERENCE_CHILD_OF reference pointing to `spanContext` */ export function childOf(spanContext: SpanContext | Span): Reference { - // Allow the user to pass a Span instead of a SpanContext - if (spanContext instanceof Span) { - spanContext = spanContext.context(); - } return new Reference(Constants.REFERENCE_CHILD_OF, spanContext); } @@ -26,9 +22,5 @@ export function childOf(spanContext: SpanContext | Span): Reference { * @return a REFERENCE_FOLLOWS_FROM reference pointing to `spanContext` */ export function followsFrom(spanContext: SpanContext | Span): Reference { - // Allow the user to pass a Span instead of a SpanContext - if (spanContext instanceof Span) { - spanContext = spanContext.context(); - } return new Reference(Constants.REFERENCE_FOLLOWS_FROM, spanContext); } diff --git a/src/reference.ts b/src/reference.ts index cf87e25..30d4a92 100644 --- a/src/reference.ts +++ b/src/reference.ts @@ -1,6 +1,23 @@ import Span from './span'; import SpanContext from './span_context'; +const toContext = (contextOrSpan: SpanContext | Span): SpanContext => { + if (contextOrSpan instanceof SpanContext) { + return contextOrSpan; + } + + // Second check is for cases when a Span implementation does not extend + // opentracing.Span class directly (like Jaeger), just implements the same interface. + // The only false-positive case here is a non-extending SpanContext class, + // which has a method called "context". + // But that's too much of a specification violation to take care of. + if (contextOrSpan instanceof Span || 'context' in contextOrSpan) { + return contextOrSpan.context(); + } + + return contextOrSpan; +}; + /** * Reference pairs a reference type constant (e.g., REFERENCE_CHILD_OF or * REFERENCE_FOLLOWS_FROM) with the SpanContext it points to. @@ -39,9 +56,6 @@ export default class Reference { */ constructor(type: string, referencedContext: SpanContext | Span) { this._type = type; - this._referencedContext = ( - referencedContext instanceof Span ? - referencedContext.context() : - referencedContext); + this._referencedContext = toContext(referencedContext); } } diff --git a/src/test/opentracing_api.ts b/src/test/opentracing_api.ts index ec4892b..3c99474 100644 --- a/src/test/opentracing_api.ts +++ b/src/test/opentracing_api.ts @@ -1,6 +1,11 @@ import { expect } from 'chai'; import * as opentracing from '../index'; +import MockContext from '../mock_tracer/mock_context'; +import MockSpan from '../mock_tracer/mock_span'; +import MockTracer from '../mock_tracer/mock_tracer'; +import Span from '../span'; +import SpanContext from '../span_context'; export function opentracingAPITests(): void { describe('Opentracing API', () => { @@ -85,6 +90,36 @@ export function opentracingAPITests(): void { const ref = new opentracing.Reference(opentracing.REFERENCE_CHILD_OF, span.context()); expect(ref).to.be.an('object'); }); + + it('should get context from custom extending span classes', () => { + const span = new MockSpan(new MockTracer()); + const ref = new opentracing.Reference(opentracing.REFERENCE_CHILD_OF, span); + expect(ref.referencedContext() instanceof SpanContext).to.equal(true); + }); + + it('should get context from custom non-extending span classes', () => { + const ctx = new SpanContext(); + const pseudoSpan = { + context: () => ctx + } as Span; + const ref = new opentracing.Reference(opentracing.REFERENCE_CHILD_OF, pseudoSpan); + expect(ref.referencedContext()).to.equal(ctx); + }); + + it('should use extending contexts', () => { + const ctx = new MockContext({} as MockSpan); + const ref = new opentracing.Reference(opentracing.REFERENCE_CHILD_OF, ctx); + expect(ref.referencedContext()).to.equal(ctx); + }); + + it('should use non-extending contexts', () => { + const ctx = { + toTraceId: () => '', + toSpanId: () => '' + }; + const ref = new opentracing.Reference(opentracing.REFERENCE_CHILD_OF, ctx); + expect(ref.referencedContext()).to.equal(ctx); + }); }); describe('BinaryCarrier', () => {