From e57c6d92b5d876cfeb9e1345ba25f0b45cabea51 Mon Sep 17 00:00:00 2001 From: Phu-Hiep DUONG Date: Mon, 4 Nov 2024 22:33:16 +0100 Subject: [PATCH 1/4] refactor: simplify codes without logic changes --- getters.go | 18 +++++++++--------- ore.go | 7 ++++--- serviceResolver.go | 15 +++++++-------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/getters.go b/getters.go index 1ab5c3a..c58e381 100644 --- a/getters.go +++ b/getters.go @@ -5,32 +5,32 @@ import ( "sort" ) -func getLastRegisteredResolver(typeId typeID) (serviceResolver, int) { +func getLastRegisteredResolver(typeID typeID) serviceResolver { // try to get service resolver from container lock.RLock() - resolvers, resolverExists := container[typeId] + resolvers, resolverExists := container[typeID] lock.RUnlock() if !resolverExists { - return nil, -1 + return nil } count := len(resolvers) if count == 0 { - return nil, -1 + return nil } // index of the last implementation lastIndex := count - 1 - return resolvers[lastIndex], lastIndex + return resolvers[lastIndex] } // Get Retrieves an instance based on type and key (panics if no valid implementations) func Get[T any](ctx context.Context, key ...KeyStringer) (T, context.Context) { pointerTypeName := getPointerTypeName[T]() typeID := getTypeID(pointerTypeName, key) - lastRegisteredResolver, lastIndex := getLastRegisteredResolver(typeID) + lastRegisteredResolver := getLastRegisteredResolver(typeID) if lastRegisteredResolver == nil { //not found, T is an alias lock.RLock() @@ -47,7 +47,7 @@ func Get[T any](ctx context.Context, key ...KeyStringer) (T, context.Context) { for i := count - 1; i >= 0; i-- { impl := implementations[i] typeID = getTypeID(impl, key) - lastRegisteredResolver, lastIndex = getLastRegisteredResolver(typeID) + lastRegisteredResolver = getLastRegisteredResolver(typeID) if lastRegisteredResolver != nil { break } @@ -56,7 +56,7 @@ func Get[T any](ctx context.Context, key ...KeyStringer) (T, context.Context) { if lastRegisteredResolver == nil { panic(noValidImplementation[T]()) } - con, ctx := lastRegisteredResolver.resolveService(ctx, typeID, lastIndex) + con, ctx := lastRegisteredResolver.resolveService(ctx) return con.value.(T), ctx } @@ -92,7 +92,7 @@ func GetList[T any](ctx context.Context, key ...KeyStringer) ([]T, context.Conte for index := 0; index < len(resolvers); index++ { resolver := resolvers[index] - con, newCtx := resolver.resolveService(ctx, typeID, index) + con, newCtx := resolver.resolveService(ctx) servicesArray = append(servicesArray, con.value.(T)) ctx = newCtx } diff --git a/ore.go b/ore.go index 1d46404..0060da4 100644 --- a/ore.go +++ b/ore.go @@ -43,7 +43,7 @@ func typeIdentifier[T any](key []KeyStringer) typeID { } // Appends a service resolver to the container with type and key -func appendToContainer[T any](resolver serviceResolver, key []KeyStringer) { +func appendToContainer[T any](resolver serviceResolverImpl[T], key []KeyStringer) { if isBuilt { panic(alreadyBuiltCannotAdd) } @@ -51,13 +51,14 @@ func appendToContainer[T any](resolver serviceResolver, key []KeyStringer) { typeID := typeIdentifier[T](key) lock.Lock() + resolver.ID = contextKey{typeID, len(container[typeID])} container[typeID] = append(container[typeID], resolver) lock.Unlock() } -func replaceServiceResolver(typeId typeID, index int, resolver serviceResolver) { +func replaceServiceResolver[T any](resolver serviceResolverImpl[T]) { lock.Lock() - container[typeId][index] = resolver + container[resolver.ID.typeID][resolver.ID.index] = resolver lock.Unlock() } diff --git a/serviceResolver.go b/serviceResolver.go index de6970f..131627c 100644 --- a/serviceResolver.go +++ b/serviceResolver.go @@ -10,7 +10,7 @@ type ( ) type serviceResolver interface { - resolveService(ctx context.Context, typeId typeID, index int) (*concrete, context.Context) + resolveService(ctx context.Context) (*concrete, context.Context) //return the invoked singleton value, or false if the resolver is not a singleton or has not been invoked getInvokedSingleton() (con *concrete, isInvokedSingleton bool) } @@ -20,22 +20,21 @@ type serviceResolverImpl[T any] struct { creatorInstance Creator[T] singletonConcrete *concrete lifetime Lifetime + ID contextKey } // make sure that the `serviceResolverImpl` struct implements the `serviceResolver` interface var _ serviceResolver = serviceResolverImpl[any]{} -func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeID typeID, index int) (*concrete, context.Context) { +func (this serviceResolverImpl[T]) resolveService(ctx context.Context) (*concrete, context.Context) { // try get concrete implementation if this.lifetime == Singleton && this.singletonConcrete != nil { return this.singletonConcrete, ctx } - ctxKey := contextKey{typeID, index} - // try get concrete from context scope if this.lifetime == Scoped { - scopedConcrete, ok := ctx.Value(ctxKey).(*concrete) + scopedConcrete, ok := ctx.Value(this.ID).(*concrete) if ok { return scopedConcrete, ctx } @@ -59,15 +58,15 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeID ty // if scoped, attach to the current context if this.lifetime == Scoped { - ctx = context.WithValue(ctx, ctxKey, con) - ctx = addToContextKeysRepository(ctx, ctxKey) + ctx = context.WithValue(ctx, this.ID, con) + ctx = addToContextKeysRepository(ctx, this.ID) } // if was lazily-created, then attach the newly-created concrete implementation // to the service resolver if this.lifetime == Singleton { this.singletonConcrete = con - replaceServiceResolver(typeID, index, this) + replaceServiceResolver(this) return con, ctx } From fa6cb01fffdae9b5639d40dc25f1535055098ccb Mon Sep 17 00:00:00 2001 From: Phu-Hiep DUONG Date: Wed, 6 Nov 2024 16:58:08 +0100 Subject: [PATCH 2/4] add validation feature - handle circular DI: panic with a nice error instead of call stack overflow - detect lifetime misalignment to panic - add ore.validate() func to panic early if the container is bad configured --- README.md | 22 +- concrete.go | 4 + errors.go | 8 + get_test.go | 350 ++++++++++++++--------- getters.go | 8 +- internal/models/disposable.go | 20 ++ internal/testtools/assert2/assertions.go | 72 +++++ lifetimes.go | 22 +- ore.go | 28 +- registrars.go | 12 +- serviceResolver.go | 93 +++++- utils.go | 16 +- validate_test.go | 212 ++++++++++++++ 13 files changed, 703 insertions(+), 164 deletions(-) create mode 100644 internal/testtools/assert2/assertions.go create mode 100644 validate_test.go diff --git a/README.md b/README.md index 428cdd8..4004c2d 100644 --- a/README.md +++ b/README.md @@ -321,6 +321,22 @@ func TestGetInterfaceAlias(t *testing.T) { Alias is also scoped by key. When you "Get" an alias with keys for eg: `ore.Get[IPerson](ctx, "module1")` then Ore would return only Services registered under this key ("module1") and panic if no service found. +### Registration validation + +It is recommended to build your container (which means register ALL the resolvers) only ONCE on application start. +Next, it is recommended to call `ore.Validate()` + +- either in a test which is automatically run on your CI/CD (option 1) +- or on application start, just after resolvers registration (option 2) + +option 1 (run `ore.Validate` on test) is often a better choice. + +`ore.Validate()` invokes ALL your registered resolvers, it panics when something gone wrong. The purpose of this function is to panic early when the Container is bad configured: + +- Missing depedency: you forgot to register certain resolvers. +- Circular dependency: A depends on B whic depends on A. +- Lifetime misalignment: a longer lifetime service (eg. Singleton) depends on a shorter one (eg Transient). + ### Graceful application termination On application termination, you want to call `Shutdown()` on all the "Singletons" objects which have been created during the application life time. @@ -353,7 +369,8 @@ In resume, the `ore.GetResolvedSingletons[TInterface]()` function returns a list - It returns only the instances which had been invoked (a.k.a resolved). - All the implementations including "keyed" one will be returned. -- The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the most recently created one. +- The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the "most recently created" one. + - if "A" depends on "B", "C", Ore will make sure to return "B" and "C" first in the list so that they would be shutdowned before "A". However Ore won't guarantee the order of "B" and "C" ### Graceful context termination @@ -392,7 +409,8 @@ The `ore.GetResolvedScopedInstances[TInterface](context)` function returns a lis - It returns only the instances which had been invoked (a.k.a resolved) during the context life time. - All the implementations including "keyed" one will be returned. -- The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the most recently created one. +- The returned instances are sorted by invocation order, the first one being the most "recently created" one. + - if "A" depends on "B", "C", Ore will make sure to return "B" and "C" first in the list so that they would be Disposed before "A". However Ore won't guarantee the order of "B" and "C" ## More Complex Example diff --git a/concrete.go b/concrete.go index 29200d8..c27cbd4 100644 --- a/concrete.go +++ b/concrete.go @@ -10,4 +10,8 @@ type concrete struct { createdAt time.Time // the lifetime of this concrete lifetime Lifetime + // the invocation deep level, the bigger the value, the deeper it was resolved in the dependency chain + // for example: A depends on B, B depends on C, C depends on D + // A will have invocationLevel = 1, B = 2, C = 3, D = 4 + invocationLevel int } diff --git a/errors.go b/errors.go index 5597b48..7f95b15 100644 --- a/errors.go +++ b/errors.go @@ -14,6 +14,14 @@ func nilVal[T any]() error { return fmt.Errorf("nil implementation for type: %s", reflect.TypeFor[T]()) } +func lifetimeMisalignment(resolver resolverMetadata, depResolver resolverMetadata) error { + return fmt.Errorf("detect lifetime misalignment: %s depends on %s", resolver, depResolver) +} + +func cyclicDependency(resolver resolverMetadata) error { + return fmt.Errorf("detect cyclic dependency where: %s depends on itself", resolver) +} + var alreadyBuilt = errors.New("services container is already built") var alreadyBuiltCannotAdd = errors.New("cannot appendToContainer, services container is already built") var nilKey = errors.New("cannot have nil keys") diff --git a/get_test.go b/get_test.go index 9ef6e52..98df359 100644 --- a/get_test.go +++ b/get_test.go @@ -75,152 +75,228 @@ func TestGetKeyed(t *testing.T) { } func TestGetResolvedSingletons(t *testing.T) { - //Arrange - clearAll() - RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { - return &m.DisposableService1{Name: "A1"}, ctx - }) - RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { - return &m.DisposableService1{Name: "A2"}, ctx - }) - RegisterEagerSingleton(&m.DisposableService2{Name: "E1"}) - RegisterEagerSingleton(&m.DisposableService2{Name: "E2"}) - RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService3, context.Context) { - return &m.DisposableService3{Name: "S1"}, ctx - }) - RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService3, context.Context) { - return &m.DisposableService3{Name: "S2"}, ctx - }) - RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService4, context.Context) { - return &m.DisposableService4{Name: "X1"}, ctx + t.Run("When multiple lifetimes and keys are registered", func(t *testing.T) { + //Arrange + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + return &m.DisposableService1{Name: "A1"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + return &m.DisposableService1{Name: "A2"}, ctx + }) + RegisterEagerSingleton(&m.DisposableService2{Name: "E1"}) + RegisterEagerSingleton(&m.DisposableService2{Name: "E2"}) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "S1"}, ctx + }) + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "S2"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService4, context.Context) { + return &m.DisposableService4{Name: "X1"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService4, context.Context) { + return &m.DisposableService4{Name: "X2"}, ctx + }, "somekey") + + ctx := context.Background() + //Act + disposables := GetResolvedSingletons[m.Disposer]() //E1, E2 + assert.Equal(t, 2, len(disposables)) + + //invoke A1, A2 + _, ctx = GetList[*m.DisposableService1](ctx) //A1, A2 + + //Act + disposables = GetResolvedSingletons[m.Disposer]() //E1, E2, A1, A2 + assert.Equal(t, 4, len(disposables)) + + //invoke S1, S2, X1 + RegisterAlias[fmt.Stringer, *m.DisposableService3]() + RegisterAlias[fmt.Stringer, *m.DisposableService4]() + _, ctx = GetList[fmt.Stringer](ctx) //S1, S2, X1 + + //Act + //because S1, S2 are not singleton, so they won't be returned, only X1 will be returned in addition + disposables = GetResolvedSingletons[m.Disposer]() //E1, E2, A1, A2, X1 + assert.Equal(t, 5, len(disposables)) + + //invoke X2 in "somekey" scope + _, _ = GetList[fmt.Stringer](ctx, "somekey") + + //Act + //all invoked singleton would be returned whatever keys they are registered with + disposables = GetResolvedSingletons[m.Disposer]() //E1, E2, A1, A2, X1, X2 + assert.Equal(t, 6, len(disposables)) }) - RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService4, context.Context) { - return &m.DisposableService4{Name: "X2"}, ctx - }, "somekey") - - ctx := context.Background() - //Act - disposables := GetResolvedSingletons[m.Disposer]() //E1, E2 - assert.Equal(t, 2, len(disposables)) - - //invoke A1, A2 - _, ctx = GetList[*m.DisposableService1](ctx) //A1, A2 - - //Act - disposables = GetResolvedSingletons[m.Disposer]() //E1, E2, A1, A2 - assert.Equal(t, 4, len(disposables)) - - //invoke S1, S2, X1 - RegisterAlias[fmt.Stringer, *m.DisposableService3]() - RegisterAlias[fmt.Stringer, *m.DisposableService4]() - _, ctx = GetList[fmt.Stringer](ctx) //S1, S2, X1 - - //Act - //because S1, S2 are not singleton, so they won't be returned, only X1 will be returned in addition - disposables = GetResolvedSingletons[m.Disposer]() //E1, E2, A1, A2, X1 - assert.Equal(t, 5, len(disposables)) - - //invoke X2 in "somekey" scope - _, _ = GetList[fmt.Stringer](ctx, "somekey") - - //Act - //all invoked singleton would be returned whatever keys they are registered with - disposables = GetResolvedSingletons[m.Disposer]() //E1, E2, A1, A2, X1, X2 - assert.Equal(t, 6, len(disposables)) -} -func TestGetResolvedSingletonsOrder(t *testing.T) { - //Arrange - clearAll() - RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { - return &m.DisposableService1{Name: "A"}, ctx - }) - RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService2, context.Context) { - return &m.DisposableService2{Name: "B"}, ctx + t.Run("respect invocation chronological time order", func(t *testing.T) { + //Arrange + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + return &m.DisposableService1{Name: "A"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "B"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "C"}, ctx + }) + + ctx := context.Background() + + //invocation order: [A,C,B] + _, ctx = Get[*m.DisposableService1](ctx) + time.Sleep(1 * time.Microsecond) + _, ctx = Get[*m.DisposableService3](ctx) + time.Sleep(1 * time.Microsecond) + _, _ = Get[*m.DisposableService2](ctx) + + //Act + disposables := GetResolvedSingletons[m.Disposer]() //B, A + + //Assert that the order is [B,C,A], the most recent invocation would be returned first + assert.Equal(t, 3, len(disposables)) + assert.Equal(t, "B", disposables[0].String()) + assert.Equal(t, "C", disposables[1].String()) + assert.Equal(t, "A", disposables[2].String()) }) - RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService3, context.Context) { - return &m.DisposableService3{Name: "C"}, ctx + t.Run("deeper invocation level is returned first", func(t *testing.T) { + //Arrange + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 calls 2 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService2, context.Context) { + _, ctx = Get[*m.DisposableService3](ctx) //2 calls 3 + return &m.DisposableService2{Name: "2"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "3"}, ctx + }) + + //invocation order: [1,2,3] + _, _ = Get[*m.DisposableService1](context.Background()) + + //Act + disposables := GetResolvedSingletons[m.Disposer]() + + //Assert that the order is [B,C,A], the deepest invocation level would be returned first + assert.Equal(t, 3, len(disposables)) + assert.Equal(t, "3", disposables[0].String()) + assert.Equal(t, "2", disposables[1].String()) + assert.Equal(t, "1", disposables[2].String()) }) - - ctx := context.Background() - - //invocation order: [A,C,B] - _, ctx = Get[*m.DisposableService1](ctx) - time.Sleep(1 * time.Millisecond) - _, ctx = Get[*m.DisposableService3](ctx) - time.Sleep(1 * time.Millisecond) - _, _ = Get[*m.DisposableService2](ctx) - - //Act - disposables := GetResolvedSingletons[m.Disposer]() //B, A - - //Assert that the order is [B,C,A], the most recent invocation would be returned first - assert.Equal(t, 3, len(disposables)) - assert.Equal(t, "B", disposables[0].String()) - assert.Equal(t, "C", disposables[1].String()) - assert.Equal(t, "A", disposables[2].String()) } func TestGetResolvedScopedInstances(t *testing.T) { - clearAll() - RegisterEagerSingleton(&m.DisposableService1{Name: "S1"}) - RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService1, context.Context) { - return &m.DisposableService1{Name: "S2"}, ctx + t.Run("When multiple lifetimes and keys are registered", func(t *testing.T) { + clearAll() + RegisterEagerSingleton(&m.DisposableService1{Name: "S1"}) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService1, context.Context) { + return &m.DisposableService1{Name: "S2"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "T1"}, ctx + }, "module1") + + ctx := context.Background() + + //Act + disposables := GetResolvedScopedInstances[m.Disposer](ctx) //empty + assert.Empty(t, disposables) + + //invoke S2 + _, ctx = GetList[*m.DisposableService1](ctx) + + //Act + disposables = GetResolvedScopedInstances[m.Disposer](ctx) //S2 + assert.Equal(t, 1, len(disposables)) + assert.Equal(t, "S2", disposables[0].String()) + + //invoke the keyed service T1 + _, ctx = GetList[*m.DisposableService2](ctx, "module1") + + //Act + disposables = GetResolvedScopedInstances[m.Disposer](ctx) //S2, T1 + assert.Equal(t, 2, len(disposables)) }) - RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { - return &m.DisposableService2{Name: "T1"}, ctx - }, "module1") - ctx := context.Background() - - //Act - disposables := GetResolvedScopedInstances[m.Disposer](ctx) //empty - assert.Empty(t, disposables) - - //invoke S2 - _, ctx = GetList[*m.DisposableService1](ctx) - - //Act - disposables = GetResolvedScopedInstances[m.Disposer](ctx) //S2 - assert.Equal(t, 1, len(disposables)) - assert.Equal(t, "S2", disposables[0].String()) - - //invoke the keyed service T1 - _, ctx = GetList[*m.DisposableService2](ctx, "module1") - - //Act - disposables = GetResolvedScopedInstances[m.Disposer](ctx) //S2, T1 - assert.Equal(t, 2, len(disposables)) -} - -func TestGetResolvedScopedInstancesOrder(t *testing.T) { - //Arrange - clearAll() - RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService1, context.Context) { - return &m.DisposableService1{Name: "A"}, ctx - }) - RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { - return &m.DisposableService2{Name: "B"}, ctx - }) - RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService3, context.Context) { - return &m.DisposableService3{Name: "C"}, ctx + t.Run("respect invocation chronological time order", func(t *testing.T) { + //Arrange + clearAll() + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService1, context.Context) { + return &m.DisposableService1{Name: "A"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "B"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "C"}, ctx + }) + + ctx := context.Background() + + //invocation order: [A,C,B] + _, ctx = Get[*m.DisposableService1](ctx) + time.Sleep(1 * time.Microsecond) + _, ctx = Get[*m.DisposableService3](ctx) + time.Sleep(1 * time.Microsecond) + _, ctx = Get[*m.DisposableService2](ctx) + + //Act + disposables := GetResolvedScopedInstances[m.Disposer](ctx) //B, A + + //Assert that the order is [B,C,A], the most recent invocation would be returned first + assert.Equal(t, 3, len(disposables)) + assert.Equal(t, "B", disposables[0].String()) + assert.Equal(t, "C", disposables[1].String()) + assert.Equal(t, "A", disposables[2].String()) }) - ctx := context.Background() - - //invocation order: [A,C,B] - _, ctx = Get[*m.DisposableService1](ctx) - time.Sleep(1 * time.Millisecond) - _, ctx = Get[*m.DisposableService3](ctx) - time.Sleep(1 * time.Millisecond) - _, ctx = Get[*m.DisposableService2](ctx) - - //Act - disposables := GetResolvedScopedInstances[m.Disposer](ctx) //B, A - - //Assert that the order is [B,C,A], the most recent invocation would be returned first - assert.Equal(t, 3, len(disposables)) - assert.Equal(t, "B", disposables[0].String()) - assert.Equal(t, "C", disposables[1].String()) - assert.Equal(t, "A", disposables[2].String()) + t.Run("respect invocation deep level", func(t *testing.T) { + //Arrange + clearAll() + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService1, context.Context) { + //1 calls 3 + _, ctx = Get[*m.DisposableService3](ctx) + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "2"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "3"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService4, context.Context) { + //4 calls 1, 2 + _, ctx = Get[*m.DisposableService1](ctx) + _, ctx = Get[*m.DisposableService2](ctx) + return &m.DisposableService4{Name: "4"}, ctx + }) + + ctx := context.Background() + + //invocation order: [4,1,3,2]. + _, ctx = Get[*m.DisposableService4](ctx) + + //Act + disposables := GetResolvedScopedInstances[m.Disposer](ctx) + + assert.Equal(t, 4, len(disposables)) + + //find the position of the disposables + index1 := m.FindIndexOf(disposables, "1") + index2 := m.FindIndexOf(disposables, "2") + index3 := m.FindIndexOf(disposables, "3") + index4 := m.FindIndexOf(disposables, "4") + + //Assert that 4 should be disposed after 1 and 2 (because 4 calls 1 and 2) + assert.Greater(t, index4, index1) + assert.Greater(t, index4, index2) + + //Assert that 1 should be disposed after 3 (because 1 calls 3) + assert.Greater(t, index1, index3) + }) } diff --git a/getters.go b/getters.go index c58e381..a27d202 100644 --- a/getters.go +++ b/getters.go @@ -102,7 +102,8 @@ func GetList[T any](ctx context.Context, key ...KeyStringer) ([]T, context.Conte } // GetResolvedSingletons retrieves a list of Singleton instances that implement the [TInterface]. -// The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the most recently created one. +// The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the "most recently" created one. +// If an instance "A" depends on certain instances "B" and "C" then this function guarantee to return "B" and "C" before "A" in the list. // It would return only the instances which had been resolved. Other lazy implementations which have never been invoked will not be returned. // This function is useful for cleaning operations. // @@ -135,6 +136,7 @@ func GetResolvedSingletons[TInterface any]() []TInterface { // GetResolvedScopedInstances retrieves a list of Scoped instances that implement the [TInterface]. // The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the most recently created one. +// If an instance "A" depends on certain instances "B" and "C" then this function guarantee to return "B" and "C" before "A" in the list. // It would return only the instances which had been resolved. Other lazy implementations which have never been invoked will not be returned. // This function is useful for cleaning operations. // @@ -167,7 +169,9 @@ func GetResolvedScopedInstances[TInterface any](ctx context.Context) []TInterfac func sortAndSelect[TInterface any](list []*concrete) []TInterface { //sorting sort.Slice(list, func(i, j int) bool { - return list[i].createdAt.After(list[j].createdAt) + return list[i].createdAt.After(list[j].createdAt) || + (list[i].createdAt == list[j].createdAt && + list[i].invocationLevel > list[j].invocationLevel) }) //selecting diff --git a/internal/models/disposable.go b/internal/models/disposable.go index 31eb215..f70c6c8 100644 --- a/internal/models/disposable.go +++ b/internal/models/disposable.go @@ -50,3 +50,23 @@ func (*DisposableService4) Dispose() {} func (this *DisposableService4) String() string { return this.Name } + +var _ Disposer = (*DisposableService5)(nil) + +type DisposableService5 struct { + Name string +} + +func (*DisposableService5) Dispose() {} +func (this *DisposableService5) String() string { + return this.Name +} + +func FindIndexOf(disposables []Disposer, name string) int { + for i, disposable := range disposables { + if disposable.String() == name { + return i + } + } + return -1 +} diff --git a/internal/testtools/assert2/assertions.go b/internal/testtools/assert2/assertions.go new file mode 100644 index 0000000..7c8942d --- /dev/null +++ b/internal/testtools/assert2/assertions.go @@ -0,0 +1,72 @@ +// assert2 package add missing assertions from testify/assert package +package assert2 + +import ( + "fmt" + "runtime/debug" + + "github.com/stretchr/testify/assert" +) + +type tHelper interface { + Helper() +} + +type StringMatcher = func(s string) bool + +// PanicsWithError asserts that the code inside the specified PanicTestFunc +// panics, and that the recovered panic value is an error that satisfies the +// StringMatcher. +// +// assert.PanicsWithError(t, ErrorStartsWith("crazy error"), func(){ GoCrazy() }) +func PanicsWithError(t assert.TestingT, errStringMatcher StringMatcher, f assert.PanicTestFunc, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + funcDidPanic, panicValue, panickedStack := didPanic(f) + if !funcDidPanic { + return assert.Fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", f, panicValue), msgAndArgs...) + } + panicErr, ok := panicValue.(error) + if !ok || !errStringMatcher(panicErr.Error()) { + return assert.Fail(t, fmt.Sprintf("func %#v panic with unexpected Panic value:\t%#v\n\tPanic stack:\t%s", f, panicValue, panickedStack), msgAndArgs...) + } + + return true +} + +func ErrorStartsWith(prefix string) StringMatcher { + return func(s string) bool { + return s != "" && s[:len(prefix)] == prefix + } +} + +func ErrorEndsWith(suffix string) StringMatcher { + return func(s string) bool { + return s != "" && s[len(s)-len(suffix):] == suffix + } +} + +func ErrorContains(substr string) StringMatcher { + return func(s string) bool { + return s != "" && s[:len(substr)] == substr + } +} + +// didPanic returns true if the function passed to it panics. Otherwise, it returns false. +func didPanic(f assert.PanicTestFunc) (didPanic bool, message interface{}, stack string) { + didPanic = true + + defer func() { + message = recover() + if didPanic { + stack = string(debug.Stack()) + } + }() + + // call the target function + f() + didPanic = false + + return +} diff --git a/lifetimes.go b/lifetimes.go index f2d1fbd..d6d854a 100644 --- a/lifetimes.go +++ b/lifetimes.go @@ -1,9 +1,23 @@ package ore -type Lifetime string +type Lifetime int +// The bigger the value, the longer the lifetime const ( - Singleton Lifetime = "singleton" - Transient Lifetime = "transient" - Scoped Lifetime = "scoped" + Transient Lifetime = 0 + Scoped Lifetime = 1 + Singleton Lifetime = 2 ) + +func (this Lifetime) String() string { + switch this { + case 0: + return "Transient" + case 1: + return "Scoped" + case 2: + return "Singleton" + default: + return "Unknow" + } +} diff --git a/ore.go b/ore.go index 0060da4..b0f7428 100644 --- a/ore.go +++ b/ore.go @@ -14,11 +14,9 @@ var ( aliases = map[pointerTypeName][]pointerTypeName{} //contextKeysRepositoryID is a special context key. The value of this key is the collection of other context keys stored in the context. - contextKeysRepositoryID = contextKey{ - typeID{ - pointerTypeName: "", - oreKey: "The context keys repository", - }, -1} + contextKeysRepositoryID specialContextKey = "The context keys repository" + //contextKeyResolversChain is a special context key. The value of this key is the [ResolversChain]. + contextKeyResolversChain specialContextKey = "Dependencies chain" ) type contextKeysRepository = []contextKey @@ -51,14 +49,14 @@ func appendToContainer[T any](resolver serviceResolverImpl[T], key []KeyStringer typeID := typeIdentifier[T](key) lock.Lock() - resolver.ID = contextKey{typeID, len(container[typeID])} + resolver.id = contextKey{typeID, len(container[typeID])} container[typeID] = append(container[typeID], resolver) lock.Unlock() } func replaceServiceResolver[T any](resolver serviceResolverImpl[T]) { lock.Lock() - container[resolver.ID.typeID][resolver.ID.index] = resolver + container[resolver.id.typeID][resolver.id.index] = resolver lock.Unlock() } @@ -85,3 +83,19 @@ func Build() { isBuilt = true } + +// Validate invokes all registered resolvers. It panics if any of them fails. +// It is recommended to call this function on application start, or in the CI/CD test pipeline +// The objectif is to panic early when the container is bad configured. For eg: +// +// - (1) Missing depedency (forget to register certain resolvers) +// - (2) cyclic dependency +// - (3) lifetime misalignment (a longer lifetime service depends on a shorter one). +func Validate() { + ctx := context.Background() + for _, resolvers := range container { + for _, resolver := range resolvers { + _, ctx = resolver.resolveService(ctx) + } + } +} diff --git a/registrars.go b/registrars.go index 32bcd06..871575a 100644 --- a/registrars.go +++ b/registrars.go @@ -13,7 +13,9 @@ func RegisterLazyCreator[T any](lifetime Lifetime, creator Creator[T], key ...Ke } e := serviceResolverImpl[T]{ - lifetime: lifetime, + resolverMetadata: resolverMetadata{ + lifetime: lifetime, + }, creatorInstance: creator, } appendToContainer[T](e, key) @@ -26,7 +28,9 @@ func RegisterEagerSingleton[T comparable](impl T, key ...KeyStringer) { } e := serviceResolverImpl[T]{ - lifetime: Singleton, + resolverMetadata: resolverMetadata{ + lifetime: Singleton, + }, singletonConcrete: &concrete{ value: impl, lifetime: Singleton, @@ -43,7 +47,9 @@ func RegisterLazyFunc[T any](lifetime Lifetime, initializer Initializer[T], key } e := serviceResolverImpl[T]{ - lifetime: lifetime, + resolverMetadata: resolverMetadata{ + lifetime: lifetime, + }, anonymousInitializer: &initializer, } appendToContainer[T](e, key) diff --git a/serviceResolver.go b/serviceResolver.go index 131627c..5085612 100644 --- a/serviceResolver.go +++ b/serviceResolver.go @@ -1,7 +1,9 @@ package ore import ( + "container/list" "context" + "fmt" "time" ) @@ -10,19 +12,47 @@ type ( ) type serviceResolver interface { + fmt.Stringer resolveService(ctx context.Context) (*concrete, context.Context) //return the invoked singleton value, or false if the resolver is not a singleton or has not been invoked getInvokedSingleton() (con *concrete, isInvokedSingleton bool) } +type resolverMetadata struct { + id contextKey + lifetime Lifetime +} + type serviceResolverImpl[T any] struct { + resolverMetadata anonymousInitializer *Initializer[T] creatorInstance Creator[T] singletonConcrete *concrete - lifetime Lifetime - ID contextKey } +// resolversChain is a linkedList[resolverMetadata], describing a dependencies chain which a resolver has to invoke other resolvers to resolve its dependencies. +// Before a resolver creates a new concrete value it would be registered to the resolversChain. +// Once the concrete is resolved (with help of other resolvers), then it would be removed from the chain. +// +// While a Resolver forms a tree with other dependent resolvers. +// +// Example: +// +// A calls B and C; B calls D; C calls E. +// +// then resolversChain is a "path" in the tree from the root to one of the bottom. +// +// Example: +// +// A -> B -> D or A -> C -> E +// +// The resolversChain is stored in the context. Analyze the chain will help to +// +// - (1) detect the invocation level +// - (2) detect cyclic dependencies +// - (3) detect lifetime misalignment (when a service of longer lifetime depends on a service of shorter lifetime) +type resolversChain = *list.List + // make sure that the `serviceResolverImpl` struct implements the `serviceResolver` interface var _ serviceResolver = serviceResolverImpl[any]{} @@ -34,12 +64,27 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context) (*concret // try get concrete from context scope if this.lifetime == Scoped { - scopedConcrete, ok := ctx.Value(this.ID).(*concrete) + scopedConcrete, ok := ctx.Value(this.id).(*concrete) if ok { return scopedConcrete, ctx } } + // this resolver is about to create a new concrete value, we have to put it to the resolversChain until the creation done + + // get the current currentChain from the context + var currentChain resolversChain + untypedCurrentChain := ctx.Value(contextKeyResolversChain) + if untypedCurrentChain == nil { + currentChain = list.New() + ctx = context.WithValue(ctx, contextKeyResolversChain, currentChain) + } else { + currentChain = untypedCurrentChain.(resolversChain) + } + + // push this newest resolver to the resolversChain + marker := appendResolver(currentChain, this.resolverMetadata) + var concreteValue T // first, try make concrete implementation from `anonymousInitializer` @@ -50,16 +95,22 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context) (*concret concreteValue, ctx = this.creatorInstance.New(ctx) } + invocationLevel := currentChain.Len() + + // the concreteValue is created, we must to remove it from the resolversChain so that downstream resolvers (meaning the future resolvers) won't link to it + currentChain.Remove(marker) + con := &concrete{ - value: concreteValue, - lifetime: this.lifetime, - createdAt: time.Now(), + value: concreteValue, + lifetime: this.lifetime, + createdAt: time.Now(), + invocationLevel: invocationLevel, } // if scoped, attach to the current context if this.lifetime == Scoped { - ctx = context.WithValue(ctx, this.ID, con) - ctx = addToContextKeysRepository(ctx, this.ID) + ctx = context.WithValue(ctx, this.id, con) + ctx = addToContextKeysRepository(ctx, this.id) } // if was lazily-created, then attach the newly-created concrete implementation @@ -73,6 +124,28 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context) (*concret return con, ctx } +// appendToResolversChain push the given resolver to the Back of the ResolversChain. +// `marker.previous` refers to the calling (parent) resolver +func appendResolver(chain resolversChain, currentResolver resolverMetadata) (marker *list.Element) { + if chain.Len() != 0 { + //detect lifetime misalignment + lastElem := chain.Back() + lastResolver := lastElem.Value.(resolverMetadata) + if lastResolver.lifetime > currentResolver.lifetime { + panic(lifetimeMisalignment(lastResolver, currentResolver)) + } + + //detect cyclic dependencies + for e := chain.Back(); e != nil; e = e.Prev() { + if e.Value.(resolverMetadata).id == currentResolver.id { + panic(cyclicDependency(currentResolver)) + } + } + } + marker = chain.PushBack(currentResolver) // `marker.previous` refers to the calling (parent) resolver + return marker +} + func (this serviceResolverImpl[T]) getInvokedSingleton() (con *concrete, isInvokedSingleton bool) { if this.lifetime == Singleton && this.singletonConcrete != nil { return this.singletonConcrete, true @@ -89,3 +162,7 @@ func addToContextKeysRepository(ctx context.Context, newContextKey contextKey) c } return context.WithValue(ctx, contextKeysRepositoryID, repository) } + +func (this resolverMetadata) String() string { + return fmt.Sprintf("Resolver(%s, type={%s}, key='%s')", this.lifetime, getUnderlyingTypeName(this.id.pointerTypeName), this.id.oreKey) +} diff --git a/utils.go b/utils.go index 187367a..6ead85e 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,11 @@ package ore -import "fmt" +import ( + "fmt" + "strings" +) + +type specialContextKey string type contextKey struct { typeID @@ -29,3 +34,12 @@ func getPointerTypeName[T any]() pointerTypeName { var mockValue *T return pointerTypeName(fmt.Sprintf("%T", mockValue)) } + +func getUnderlyingTypeName(ptn pointerTypeName) string { + s := string(ptn) + index := strings.Index(s, "*") + if index == -1 { + return s // no '*' found, return the original string + } + return s[:index] + s[index+1:] +} diff --git a/validate_test.go b/validate_test.go new file mode 100644 index 0000000..e2b20a2 --- /dev/null +++ b/validate_test.go @@ -0,0 +1,212 @@ +package ore + +import ( + "context" + "testing" + + m "github.com/firasdarwish/ore/internal/models" + "github.com/firasdarwish/ore/internal/testtools/assert2" + "github.com/stretchr/testify/assert" +) + +func TestValidate_CircularDepsUniformLifetype(t *testing.T) { + for _, lt := range types { + t.Run("Direct circular "+lt.String()+" (1 calls 1)", func(t *testing.T) { + clearAll() + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService1](ctx) //1 calls 1 + return &m.DisposableService1{Name: "1"}, ctx + }) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect cyclic dependency"), Validate) + }) + t.Run("Indirect circular "+lt.String()+" (1 calls 2 calls 3 calls 1)", func(t *testing.T) { + clearAll() + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 calls 2 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService2, context.Context) { + _, ctx = Get[*m.DisposableService3](ctx) //2 calls 3 + return &m.DisposableService2{Name: "2"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService3, context.Context) { + _, ctx = Get[*m.DisposableService1](ctx) //3 calls 1 + return &m.DisposableService3{Name: "3"}, ctx + }) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect cyclic dependency"), Validate) + }) + t.Run("Middle circular "+lt.String()+" (1 calls 2 calls 3 calls 4 calls 2)", func(t *testing.T) { + clearAll() + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 calls 2 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService2, context.Context) { + _, ctx = Get[*m.DisposableService3](ctx) //2 calls 3 + return &m.DisposableService2{Name: "2"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService3, context.Context) { + _, ctx = Get[*m.DisposableService4](ctx) //3 calls 4 + return &m.DisposableService3{Name: "3"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService4, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //4 calls 2 + return &m.DisposableService4{Name: "4"}, ctx + }) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect cyclic dependency"), Validate) + }) + t.Run("circular on complex tree "+lt.String()+"", func(t *testing.T) { + clearAll() + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 calls 2 + _, ctx = Get[*m.DisposableService3](ctx) //1 calls 3 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService2, context.Context) { + _, ctx = Get[*m.DisposableService4](ctx) //2 calls 4 + _, ctx = Get[*m.DisposableService5](ctx) //2 calls 5 + return &m.DisposableService2{Name: "2"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService3, context.Context) { + _, ctx = Get[*m.DisposableService4](ctx) //3 calls 4 + return &m.DisposableService3{Name: "3"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService4, context.Context) { + _, ctx = Get[*m.DisposableService5](ctx) //4 calls 5 + return &m.DisposableService4{Name: "4"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService5, context.Context) { + _, ctx = Get[*m.DisposableService3](ctx) //5 calls 3 => circular here: 5->3->4->5 + return &m.DisposableService5{Name: "5"}, ctx + }) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect cyclic dependency"), Validate) + }) + t.Run("fake circular top down "+lt.String()+": (1 calls 2 (x2) calls 3 calls 4, 2 calls 4)", func(t *testing.T) { + clearAll() + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 calls 2 + _, ctx = Get[*m.DisposableService2](ctx) //1 calls 2 again + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService2, context.Context) { + _, ctx = Get[*m.DisposableService3](ctx) //2 calls 3 + _, ctx = Get[*m.DisposableService4](ctx) //2 calls 4 + return &m.DisposableService2{Name: "2"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService3, context.Context) { + _, ctx = Get[*m.DisposableService4](ctx) //3 calls 4 + _, ctx = Get[*m.DisposableService4](ctx) //3 calls 4 + return &m.DisposableService3{Name: "3"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService4, context.Context) { + return &m.DisposableService4{Name: "4"}, ctx + }) + assert.NotPanics(t, Validate) + }) + t.Run("fake circular sibling "+lt.String()+": 1 calls 2 & 3; 2 calls 3)", func(t *testing.T) { + clearAll() + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 calls 2 + _, ctx = Get[*m.DisposableService3](ctx) //1 calls 3 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService2, context.Context) { + _, ctx = Get[*m.DisposableService3](ctx) //2 calls 3 + return &m.DisposableService2{Name: "2"}, ctx + }) + RegisterLazyFunc(lt, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "3"}, ctx + }) + assert.NotPanics(t, Validate) + }) + } +} + +func TestValidate_CircularMixedLifetype(t *testing.T) { + clearAll() + + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { + _, ctx = Get[*m.DisposableService4](ctx) //2 calls 4 + _, ctx = Get[*m.DisposableService5](ctx) //2 calls 5 + return &m.DisposableService2{Name: "2"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService3, context.Context) { + _, ctx = Get[*m.DisposableService4](ctx) //3 calls 4 + return &m.DisposableService3{Name: "3"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService4, context.Context) { + _, ctx = Get[*m.DisposableService5](ctx) //4 calls 5 + return &m.DisposableService4{Name: "4"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService5, context.Context) { + _, ctx = Get[*m.DisposableService3](ctx) //5 calls 3 => circular here: 5->3->4->5 + return &m.DisposableService5{Name: "5"}, ctx + }) + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 calls 2 + _, ctx = Get[*m.DisposableService3](ctx) //1 calls 3 + return &m.DisposableService1{Name: "1"}, ctx + }) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect cyclic dependency"), Validate) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect cyclic dependency"), func() { + _, _ = Get[*m.DisposableService1](context.Background()) + }) +} + +func TestValidate_LifetimeAlignment(t *testing.T) { + t.Run("Singleton depends on Scoped", func(t *testing.T) { + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 depends on 2 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "2"}, ctx + }) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect lifetime misalignment"), Validate) + }) + t.Run("Scoped depends on Transient", func(t *testing.T) { + clearAll() + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 depends on 2 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "2"}, ctx + }) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect lifetime misalignment"), Validate) + }) + t.Run("Singleton depends on Transient", func(t *testing.T) { + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 depends on 2 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService2, context.Context) { + _, ctx = Get[*m.DisposableService3](ctx) //2 depends on 3 + return &m.DisposableService2{Name: "2"}, ctx + }) + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "3"}, ctx + }) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect lifetime misalignment"), Validate) + }) +} + +func TestValidate_MissingDependency(t *testing.T) { + clearAll() + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 depends on 2 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { + _, ctx = Get[*m.DisposableService3](ctx) //2 depends on 3 + return &m.DisposableService2{Name: "2"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService3, context.Context) { + _, ctx = Get[*m.DisposableService4](ctx) //3 depends on 4 + return &m.DisposableService3{Name: "3"}, ctx + }) + //forget to register 4 + assert2.PanicsWithError(t, assert2.ErrorStartsWith("implementation not found for type"), Validate) +} From 91edbcf4cf894945d401842e0eeb3f9bd15279a0 Mon Sep 17 00:00:00 2001 From: Phu-Hiep DUONG Date: Thu, 7 Nov 2024 03:33:18 +0100 Subject: [PATCH 3/4] add benchmark and feature flags --- README.md | 2 +- examples/benchperf/README.md | 74 ++++++++ examples/benchperf/bench_test.go | 36 ++++ examples/benchperf/internal/DiOre.go | 52 +++++ examples/benchperf/internal/DiSamber.go | 48 +++++ examples/benchperf/internal/model.go | 240 ++++++++++++++++++++++++ examples/benchperf/main.go | 26 +++ examples/go.mod | 12 ++ examples/go.sum | 6 + examples/shutdownerdemo/main.go | 2 + ore.go | 25 ++- serviceResolver.go | 42 +++-- utils.go | 1 + validate_test.go | 76 ++++---- 14 files changed, 589 insertions(+), 53 deletions(-) create mode 100644 examples/benchperf/README.md create mode 100644 examples/benchperf/bench_test.go create mode 100644 examples/benchperf/internal/DiOre.go create mode 100644 examples/benchperf/internal/DiSamber.go create mode 100644 examples/benchperf/internal/model.go create mode 100644 examples/benchperf/main.go create mode 100644 examples/go.mod create mode 100644 examples/go.sum diff --git a/README.md b/README.md index 4004c2d..50a753a 100644 --- a/README.md +++ b/README.md @@ -334,7 +334,7 @@ option 1 (run `ore.Validate` on test) is often a better choice. `ore.Validate()` invokes ALL your registered resolvers, it panics when something gone wrong. The purpose of this function is to panic early when the Container is bad configured: - Missing depedency: you forgot to register certain resolvers. -- Circular dependency: A depends on B whic depends on A. +- Circular dependency: A depends on B which depends on A. - Lifetime misalignment: a longer lifetime service (eg. Singleton) depends on a shorter one (eg Transient). ### Graceful application termination diff --git a/examples/benchperf/README.md b/examples/benchperf/README.md new file mode 100644 index 0000000..9f4ded0 --- /dev/null +++ b/examples/benchperf/README.md @@ -0,0 +1,74 @@ +# Benchmark comparison + +This sample will compare Ore (current commit of Nov 2024) to [samber/do/v2 v2.0.0-beta.7](https://github.com/samber/do). +We registered the below dependency graphs to both Ore and SamberDo, then ask them to create the concrete A. + +We will only benchmark the creation, not the registration. Because registration usually happens only once on application startup => + not very interesting to benchmark. + +## Data Model + +- This data model has only 2 singletons `F` and `Gb` => they will be created only once +- Every other concrete are `Transient` => they will be created each time the container create a new `A` +- We don't test the "Scoped" lifetime in this excercise because SamberDo doesn't has equivalent support for it. The "Scoped" functionality of SamberDo means "Sub Module" rather than a lifetime. + +```mermaid +flowchart TD +A["A
"] +B["B
"] +C["C
"] +D["D

"] +E["E

"] +F["F
Singleton"] +G(["G
(interface)"]) +Gb("Gb
Singleton") +Ga("Ga
") +DGa("DGa
(decorator)") +H(["H
(interface)
"]) +Hr["Hr
(real)"] +Hm["Hm
(mock)"] + +A-->B +A-->C +B-->D +B-->E +D-->H +D-->F +Hr -.implement..-> H +Hm -.implement..-> H +E-->DGa +E-->Gb +E-->Gc +DGa-->|decorate| Ga +Ga -.implement..-> G +Gb -.implement..-> G +Gc -.implement..-> G +DGa -.implement..-> G +``` + +## Run the benchmark by yourself + +```sh + go test -benchmem -bench . + ``` + +## Sample results + +On my machine, Ore always perform faster and use less memory than Samber/Do: + +```text +Benchmark_Ore-12 415822 2565 ns/op 2089 B/op 57 allocs/op +Benchmark_SamberDo-12 221941 4954 ns/op 2184 B/op 70 allocs/op +``` + +And with `ore.DisableValidation = true` + +```text +Benchmark_Ore-12 785088 1668 ns/op 1080 B/op 30 allocs/op +Benchmark_SamberDo-12 227851 4940 ns/op 2184 B/op 70 allocs/op +``` + +As any benchmarks, please take these number "relatively" as a general idea: + +- These numbers are probably outdated at the moment you are reading them +- You might got a very different numbers when running them on your machine or on production machine. diff --git a/examples/benchperf/bench_test.go b/examples/benchperf/bench_test.go new file mode 100644 index 0000000..0f48255 --- /dev/null +++ b/examples/benchperf/bench_test.go @@ -0,0 +1,36 @@ +package main + +import ( + "context" + i "examples/benchperf/internal" + "testing" + + "github.com/firasdarwish/ore" + "github.com/samber/do/v2" +) + +// func Benchmark_Ore_NoValidation(b *testing.B) { +// i.BuildContainerOre() +// ore.DisableValidation = true +// ctx := context.Background() +// b.ResetTimer() +// for n := 0; n < b.N; n++ { +// _, ctx = ore.Get[*i.A](ctx) +// } +// } + +var _ = i.BuildContainerOre() +var injector = i.BuildContainerDo() +var ctx = context.Background() + +func Benchmark_Ore(b *testing.B) { + for n := 0; n < b.N; n++ { + _, ctx = ore.Get[*i.A](ctx) + } +} + +func Benchmark_SamberDo(b *testing.B) { + for n := 0; n < b.N; n++ { + _ = do.MustInvoke[*i.A](injector) + } +} diff --git a/examples/benchperf/internal/DiOre.go b/examples/benchperf/internal/DiOre.go new file mode 100644 index 0000000..7c22471 --- /dev/null +++ b/examples/benchperf/internal/DiOre.go @@ -0,0 +1,52 @@ +package internal + +import ( + "context" + + "github.com/firasdarwish/ore" +) + +func BuildContainerOre() bool { + ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (*A, context.Context) { + b, ctx := ore.Get[*B](ctx) + c, ctx := ore.Get[*C](ctx) + return NewA(b, c), ctx + }) + ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (*B, context.Context) { + d, ctx := ore.Get[*D](ctx) + e, ctx := ore.Get[*E](ctx) + return NewB(d, e), ctx + }) + ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (*C, context.Context) { + return NewC(), ctx + }) + ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (*D, context.Context) { + f, ctx := ore.Get[*F](ctx) + h, ctx := ore.Get[H](ctx) + return NewD(f, h), ctx + }) + ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (*E, context.Context) { + gs, ctx := ore.GetList[G](ctx) + return NewE(gs), ctx + }) + ore.RegisterLazyFunc(ore.Singleton, func(ctx context.Context) (*F, context.Context) { + return NewF(), ctx + }) + ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (*Ga, context.Context) { + return NewGa(), ctx + }) + ore.RegisterLazyFunc(ore.Singleton, func(ctx context.Context) (G, context.Context) { + return NewGb(), ctx + }) + ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (G, context.Context) { + return NewGc(), ctx + }) + ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (G, context.Context) { + ga, ctx := ore.Get[*Ga](ctx) + return NewDGa(ga), ctx + }) + ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (H, context.Context) { + return NewHr(), ctx + }) + return true +} diff --git a/examples/benchperf/internal/DiSamber.go b/examples/benchperf/internal/DiSamber.go new file mode 100644 index 0000000..1605875 --- /dev/null +++ b/examples/benchperf/internal/DiSamber.go @@ -0,0 +1,48 @@ +package internal + +import ( + "github.com/samber/do/v2" +) + +func BuildContainerDo() do.Injector { + injector := do.New() + do.ProvideTransient(injector, func(inj do.Injector) (*A, error) { + return NewA(do.MustInvoke[*B](inj), do.MustInvoke[*C](inj)), nil + }) + do.ProvideTransient(injector, func(inj do.Injector) (*B, error) { + return NewB(do.MustInvoke[*D](inj), do.MustInvoke[*E](inj)), nil + }) + do.ProvideTransient(injector, func(inj do.Injector) (*C, error) { + return NewC(), nil + }) + do.ProvideTransient(injector, func(inj do.Injector) (*D, error) { + return NewD(do.MustInvoke[*F](inj), do.MustInvoke[H](inj)), nil + }) + do.ProvideTransient(injector, func(inj do.Injector) (*E, error) { + gs := []G{ + do.MustInvoke[*DGa](inj), + do.MustInvoke[*Gb](inj), + do.MustInvoke[*Gc](inj), + } + return NewE(gs), nil + }) + do.Provide(injector, func(inj do.Injector) (*F, error) { + return NewF(), nil + }) + do.ProvideTransient(injector, func(inj do.Injector) (*Ga, error) { + return NewGa(), nil + }) + do.Provide(injector, func(inj do.Injector) (*Gb, error) { + return NewGb(), nil + }) + do.ProvideTransient(injector, func(inj do.Injector) (*Gc, error) { + return NewGc(), nil + }) + do.ProvideTransient(injector, func(inj do.Injector) (*DGa, error) { + return NewDGa(do.MustInvoke[*Ga](inj)), nil + }) + do.ProvideTransient(injector, func(inj do.Injector) (H, error) { + return NewHr(), nil + }) + return injector +} diff --git a/examples/benchperf/internal/model.go b/examples/benchperf/internal/model.go new file mode 100644 index 0000000..16855a1 --- /dev/null +++ b/examples/benchperf/internal/model.go @@ -0,0 +1,240 @@ +package internal + +import ( + "fmt" + "log" + "sync/atomic" + + "github.com/samber/do/v2" +) + +var counter uint64 +var countIdEnabled = false + +func ResetCounter() { + atomic.StoreUint64(&counter, 0) +} + +func generateId(prefix string) string { + if countIdEnabled { + return fmt.Sprintf("%s-%02d", prefix, atomic.AddUint64(&counter, 1)) + } + return prefix +} + +type A struct { + id string + b *B `do:""` + c *C `do:""` +} + +func NewA(b *B, c *C) *A { + return &A{id: generateId("A"), b: b, c: c} +} +func (this *A) ToString() string { + return fmt.Sprintf("%s { %s, %s }", this.id, this.b.ToString(), this.c.ToString()) +} + +type B struct { + id string + d *D `do:""` + e *E `do:""` +} + +func NewB(d *D, e *E) *B { + return &B{id: generateId("B"), d: d, e: e} +} +func (this *B) ToString() string { + return fmt.Sprintf("%s { %s, %s }", this.id, this.d.ToString(), this.e.ToString()) +} + +type C struct { + id string +} + +func NewC() *C { + return &C{id: generateId("C")} +} + +func (this *C) ToString() string { + return this.id +} + +type D struct { + id string + f *F `do:""` + h H `do:""` +} + +func NewD(f *F, h H) *D { + return &D{id: generateId("D"), f: f, h: h} +} +func (this *D) ToString() string { + return fmt.Sprintf("%s { %s, %s }", this.id, this.f.ToString(), this.h.ToString()) +} + +var _ do.Shutdowner = (*D)(nil) + +func (this *D) Shutdown() { + if countIdEnabled { + log.Println("Shutdown " + this.id) + } +} + +type E struct { + id string + g []G `do:""` +} + +func NewE(g []G) *E { + return &E{id: generateId("E"), g: g} +} +func (this *E) ToString() string { + resu := this.id + "{ " + for _, gItem := range this.g { + resu += gItem.ToString() + ", " + } + resu += " }" + return resu +} + +var _ do.Shutdowner = (*E)(nil) + +func (this *E) Shutdown() { + if countIdEnabled { + log.Println("Shutdown " + this.id) + } +} + +type F struct { + id string +} + +func NewF() *F { + return &F{id: generateId("F")} +} +func (this *F) ToString() string { + return this.id +} + +type G interface { + GetId() string + ToString() string +} + +var _ G = (*Ga)(nil) + +type Ga struct { + id string +} + +func NewGa() *Ga { + return &Ga{id: generateId("Ga")} +} +func (this *Ga) GetId() string { + return this.id +} +func (this *Ga) ToString() string { + return this.id +} + +var _ do.Shutdowner = (*Ga)(nil) + +func (this *Ga) Shutdown() { + if countIdEnabled { + log.Println("Shutdown " + this.id) + } +} + +var _ G = (*Gb)(nil) + +type Gb struct { + id string +} + +func NewGb() *Gb { + return &Gb{id: generateId("Gb")} +} +func (this *Gb) ToString() string { + return this.id +} + +func (this *Gb) GetId() string { + return this.id +} + +var _ G = (*Gc)(nil) + +type Gc struct { + id string +} + +func NewGc() *Gc { + return &Gc{id: generateId("Gc")} +} +func (this *Gc) ToString() string { + return this.id +} + +func (this *Gc) GetId() string { + return this.id +} + +var _ G = (*DGa)(nil) + +type DGa struct { + core *Ga `do:""` + id string +} + +func NewDGa(core *Ga) *DGa { + return &DGa{core: core, id: generateId("DGa")} +} +func (this *DGa) ToString() string { + return fmt.Sprintf("%s { %s }", this.id, this.core.ToString()) +} + +func (this *DGa) GetId() string { + return this.id +} + +type H interface { + do.Shutdowner + ToString() string +} + +type Hr struct { + id string +} + +var _ H = (*Hr)(nil) + +func NewHr() *Hr { + return &Hr{id: generateId("Hr")} +} +func (this *Hr) ToString() string { + return this.id +} +func (this *Hr) Shutdown() { + if countIdEnabled { + log.Println("Shutdown " + this.id) + } +} + +type Hm struct { + id string +} + +var _ H = (*Hm)(nil) + +func NewHm() *Hm { + return &Hm{id: generateId("Hm")} +} +func (this *Hm) ToString() string { + return this.id +} +func (this *Hm) Shutdown() { + if countIdEnabled { + log.Println("Shutdown " + this.id) + } +} diff --git a/examples/benchperf/main.go b/examples/benchperf/main.go new file mode 100644 index 0000000..8b09b1a --- /dev/null +++ b/examples/benchperf/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "context" + i "examples/benchperf/internal" + "log" + + "github.com/firasdarwish/ore" + "github.com/samber/do/v2" +) + +func main() { + i.BuildContainerOre() + a1, _ := ore.Get[*i.A](context.Background()) + log.Println(a1.ToString()) + a2, _ := ore.Get[*i.A](context.Background()) + log.Println(a2.ToString()) + + i.ResetCounter() + + injector := i.BuildContainerDo() + a3 := do.MustInvoke[*i.A](injector) + log.Println(a3.ToString()) + a4 := do.MustInvoke[*i.A](injector) + log.Println(a4.ToString()) +} diff --git a/examples/go.mod b/examples/go.mod new file mode 100644 index 0000000..7eb67ac --- /dev/null +++ b/examples/go.mod @@ -0,0 +1,12 @@ +module examples + +go 1.23.2 + +require github.com/firasdarwish/ore v0.3.0 + +require ( + github.com/samber/do/v2 v2.0.0-beta.7 // indirect + github.com/samber/go-type-to-string v1.4.0 // indirect +) + +replace github.com/firasdarwish/ore => ../ diff --git a/examples/go.sum b/examples/go.sum new file mode 100644 index 0000000..b30d0a2 --- /dev/null +++ b/examples/go.sum @@ -0,0 +1,6 @@ +github.com/firasdarwish/ore v0.3.0 h1:jk5g5xB7+5hhhkKoeafRVxfFBFnlI0KWx9Qeb276B00= +github.com/firasdarwish/ore v0.3.0/go.mod h1:Hii37a86OsbOZ6xNRG1TjXOguhekZgWY8VjtCbWIrTI= +github.com/samber/do/v2 v2.0.0-beta.7 h1:tmdLOVSCbTA6uGWLU5poi/nZvMRh5QxXFJ9vHytU+Jk= +github.com/samber/do/v2 v2.0.0-beta.7/go.mod h1:+LpV3vu4L81Q1JMZNSkMvSkW9lt4e5eJoXoZHkeBS4c= +github.com/samber/go-type-to-string v1.4.0 h1:KXphToZgiFdnJQxryU25brhlh/CqY/cwJVeX2rfmow0= +github.com/samber/go-type-to-string v1.4.0/go.mod h1:jpU77vIDoIxkahknKDoEx9C8bQ1ADnh2sotZ8I4QqBU= diff --git a/examples/shutdownerdemo/main.go b/examples/shutdownerdemo/main.go index 7b6fb1f..d300197 100644 --- a/examples/shutdownerdemo/main.go +++ b/examples/shutdownerdemo/main.go @@ -44,6 +44,8 @@ func main() { ore.RegisterEagerSingleton[*myGlobalRepo](&myGlobalRepo{}) ore.RegisterLazyCreator(ore.Scoped, &myScopedRepo{}) + ore.Validate() + wg := sync.WaitGroup{} wg.Add(1) diff --git a/ore.go b/ore.go index b0f7428..013c788 100644 --- a/ore.go +++ b/ore.go @@ -6,9 +6,25 @@ import ( ) var ( - lock = &sync.RWMutex{} - isBuilt = false - container = map[typeID][]serviceResolver{} + //DisableValidation set to true to skip validation. + // Use case: you called the [Validate] function (either in the test pipeline or on application startup). + // So you are confident that your registrations are good: + // + // - no missing dependencies + // - no circular dependencies + // - no lifetime misalignment (a longer lifetime service depends on a shorter one). + // + // You don't need Ore to validate over and over again each time it creates a new concrete. It's just a waste of resource + // especially when you will need Ore to create milion of transient concretes and any "pico" seconds or memory allocation matter for you + // + // In this case, you can put DisableValidation to false. + // + // This config would impact also the the [GetResolvedSingletons] and the [GetResolvedScopedInstances] functions, + // the returning order would be no longer guaranteed. + DisableValidation = false + lock = &sync.RWMutex{} + isBuilt = false + container = map[typeID][]serviceResolver{} //map the alias type (usually an interface) to the original types (usually implementations of the interface) aliases = map[pointerTypeName][]pointerTypeName{} @@ -92,6 +108,9 @@ func Build() { // - (2) cyclic dependency // - (3) lifetime misalignment (a longer lifetime service depends on a shorter one). func Validate() { + if DisableValidation { + panic("Validation is disabled") + } ctx := context.Background() for _, resolvers := range container { for _, resolver := range resolvers { diff --git a/serviceResolver.go b/serviceResolver.go index 5085612..6be13b8 100644 --- a/serviceResolver.go +++ b/serviceResolver.go @@ -74,18 +74,21 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context) (*concret // get the current currentChain from the context var currentChain resolversChain - untypedCurrentChain := ctx.Value(contextKeyResolversChain) - if untypedCurrentChain == nil { - currentChain = list.New() - ctx = context.WithValue(ctx, contextKeyResolversChain, currentChain) - } else { - currentChain = untypedCurrentChain.(resolversChain) - } - - // push this newest resolver to the resolversChain - marker := appendResolver(currentChain, this.resolverMetadata) + var marker *list.Element + if !DisableValidation { + untypedCurrentChain := ctx.Value(contextKeyResolversChain) + if untypedCurrentChain == nil { + currentChain = list.New() + ctx = context.WithValue(ctx, contextKeyResolversChain, currentChain) + } else { + currentChain = untypedCurrentChain.(resolversChain) + } + // push this newest resolver to the resolversChain + marker = appendResolver(currentChain, this.resolverMetadata) + } var concreteValue T + createdAt := time.Now() // first, try make concrete implementation from `anonymousInitializer` // if nil, try the concrete implementation `Creator` @@ -95,15 +98,18 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context) (*concret concreteValue, ctx = this.creatorInstance.New(ctx) } - invocationLevel := currentChain.Len() + invocationLevel := 0 + if !DisableValidation { + invocationLevel = currentChain.Len() - // the concreteValue is created, we must to remove it from the resolversChain so that downstream resolvers (meaning the future resolvers) won't link to it - currentChain.Remove(marker) + // the concreteValue is created, we must to remove it from the resolversChain so that downstream resolvers (meaning the future resolvers) won't link to it + currentChain.Remove(marker) + } con := &concrete{ value: concreteValue, lifetime: this.lifetime, - createdAt: time.Now(), + createdAt: createdAt, invocationLevel: invocationLevel, } @@ -166,3 +172,11 @@ func addToContextKeysRepository(ctx context.Context, newContextKey contextKey) c func (this resolverMetadata) String() string { return fmt.Sprintf("Resolver(%s, type={%s}, key='%s')", this.lifetime, getUnderlyingTypeName(this.id.pointerTypeName), this.id.oreKey) } + +// func toString(resolversChain resolversChain) string { +// var sb string +// for e := resolversChain.Front(); e != nil; e = e.Next() { +// sb = fmt.Sprintf("%s%s\n", sb, e.Value.(resolverMetadata).String()) +// } +// return sb +// } diff --git a/utils.go b/utils.go index 6ead85e..c34a01b 100644 --- a/utils.go +++ b/utils.go @@ -26,6 +26,7 @@ func clearAll() { container = make(map[typeID][]serviceResolver) aliases = make(map[pointerTypeName][]pointerTypeName) isBuilt = false + DisableValidation = false } // Get type name of *T. diff --git a/validate_test.go b/validate_test.go index e2b20a2..a0d6fb9 100644 --- a/validate_test.go +++ b/validate_test.go @@ -153,44 +153,42 @@ func TestValidate_CircularMixedLifetype(t *testing.T) { }) } -func TestValidate_LifetimeAlignment(t *testing.T) { - t.Run("Singleton depends on Scoped", func(t *testing.T) { - clearAll() - RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { - _, ctx = Get[*m.DisposableService2](ctx) //1 depends on 2 - return &m.DisposableService1{Name: "1"}, ctx - }) - RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { - return &m.DisposableService2{Name: "2"}, ctx - }) - assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect lifetime misalignment"), Validate) +func TestValidate_LifetimeAlignment_SingletonCallsScoped(t *testing.T) { + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 depends on 2 + return &m.DisposableService1{Name: "1"}, ctx }) - t.Run("Scoped depends on Transient", func(t *testing.T) { - clearAll() - RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService1, context.Context) { - _, ctx = Get[*m.DisposableService2](ctx) //1 depends on 2 - return &m.DisposableService1{Name: "1"}, ctx - }) - RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService2, context.Context) { - return &m.DisposableService2{Name: "2"}, ctx - }) - assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect lifetime misalignment"), Validate) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "2"}, ctx }) - t.Run("Singleton depends on Transient", func(t *testing.T) { - clearAll() - RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { - _, ctx = Get[*m.DisposableService2](ctx) //1 depends on 2 - return &m.DisposableService1{Name: "1"}, ctx - }) - RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService2, context.Context) { - _, ctx = Get[*m.DisposableService3](ctx) //2 depends on 3 - return &m.DisposableService2{Name: "2"}, ctx - }) - RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService3, context.Context) { - return &m.DisposableService3{Name: "3"}, ctx - }) - assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect lifetime misalignment"), Validate) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect lifetime misalignment"), Validate) +} +func TestValidate_LifetimeAlignment_ScopedCallsTransient(t *testing.T) { + clearAll() + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 depends on 2 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "2"}, ctx + }) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect lifetime misalignment"), Validate) +} +func TestValidate_LifetimeAlignment_SingletonCallsTransient(t *testing.T) { + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + _, ctx = Get[*m.DisposableService2](ctx) //1 depends on 2 + return &m.DisposableService1{Name: "1"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService2, context.Context) { + _, ctx = Get[*m.DisposableService3](ctx) //2 depends on 3 + return &m.DisposableService2{Name: "2"}, ctx }) + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "3"}, ctx + }) + assert2.PanicsWithError(t, assert2.ErrorStartsWith("detect lifetime misalignment"), Validate) } func TestValidate_MissingDependency(t *testing.T) { @@ -210,3 +208,11 @@ func TestValidate_MissingDependency(t *testing.T) { //forget to register 4 assert2.PanicsWithError(t, assert2.ErrorStartsWith("implementation not found for type"), Validate) } + +// func TestValidate_DisableValidation(t *testing.T) { +// clearAll() +// DisableValidation = true +// assert.Panics(t, Validate) +// DisableValidation = false +// assert.NotPanics(t, Validate) +// } From 74aa0513f9f4d03c4d1e98be100009ddf258139c Mon Sep 17 00:00:00 2001 From: Phu-Hiep DUONG Date: Fri, 8 Nov 2024 10:10:12 +0100 Subject: [PATCH 4/4] enhance some english wordings --- README.md | 48 +++++++++++++++++++---------- concrete.go | 19 +++++++----- examples/benchperf/README.md | 6 ++-- get_test.go | 8 +++++ getters.go | 4 +-- ore.go | 13 ++++---- registrars.go | 6 ++-- serviceResolver.go | 60 +++++++++++++++++++----------------- 8 files changed, 98 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index 50a753a..6039e43 100644 --- a/README.md +++ b/README.md @@ -323,20 +323,31 @@ Alias is also scoped by key. When you "Get" an alias with keys for eg: `ore.Get[ ### Registration validation -It is recommended to build your container (which means register ALL the resolvers) only ONCE on application start. -Next, it is recommended to call `ore.Validate()` +Once finishing all your registrations, it is recommended to call `ore.Validate()`. -- either in a test which is automatically run on your CI/CD (option 1) -- or on application start, just after resolvers registration (option 2) - -option 1 (run `ore.Validate` on test) is often a better choice. - -`ore.Validate()` invokes ALL your registered resolvers, it panics when something gone wrong. The purpose of this function is to panic early when the Container is bad configured: +`ore.Validate()` invokes ALL your registered resolvers. The purpose is to panic early if your registrations were in bad shape: - Missing depedency: you forgot to register certain resolvers. - Circular dependency: A depends on B which depends on A. - Lifetime misalignment: a longer lifetime service (eg. Singleton) depends on a shorter one (eg Transient). +### Registration recommendation + +(1) You should call `ore.Validate()` + +- either in a test which is automatically run on your CI/CD pipeline (option 1) +- or on application start, just after all the registrations (option 2) + +option 1 (run `ore.Validate` on test) is usually a better choice. + +(2) It is recommended to build your container (which means register ALL the resolvers) only ONCE on application start => Please don't call `ore.RegisterXX` all over the place. + +(3) Keep the object creation function (a.k.a resolvers) simple. Their only responsibility should be **object creation**. + +- they should not spawn new goroutine +- they should not open database connection +- they should not contain any "if" statement or other business logic + ### Graceful application termination On application termination, you want to call `Shutdown()` on all the "Singletons" objects which have been created during the application life time. @@ -353,13 +364,15 @@ ore.RegisterEagerSingleton(&Logger{}) //*Logger implements Shutdowner ore.RegisterEagerSingleton(&SomeRepository{}) //*SomeRepository implements Shutdowner ore.RegisterEagerSingleton(&SomeService{}, "some_module") //*SomeService implements Shutdowner -//On application termination, Ore can help to retreive all the singletons implementation of the `Shutdowner` interface. -//There might be other `Shutdowner`'s implementation which were lazily registered but have never been created (a.k.a invoked). +//On application termination, Ore can help to retreive all the singletons implementation +//of the `Shutdowner` interface. +//There might be other `Shutdowner`'s implementation which were lazily registered but +//have never been created. //Ore will ignore them, and return only the concrete instances which can be Shutdown() shutdowables := ore.GetResolvedSingletons[Shutdowner]() //Now we can Shutdown() them all and gracefully terminate our application. -//The most recently created instance will be Shutdown() first +//The most recently invoked instance will be Shutdown() first for _, instance := range disposables { instance.Shutdown() } @@ -369,8 +382,8 @@ In resume, the `ore.GetResolvedSingletons[TInterface]()` function returns a list - It returns only the instances which had been invoked (a.k.a resolved). - All the implementations including "keyed" one will be returned. -- The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the "most recently created" one. - - if "A" depends on "B", "C", Ore will make sure to return "B" and "C" first in the list so that they would be shutdowned before "A". However Ore won't guarantee the order of "B" and "C" +- The returned instances are sorted by the invocation order, the first one being lastest invoked one. + - if "A" depends on "B", "C", Ore will make sure to return "B" and "C" first in the list so that they would be shutdowned before "A". ### Graceful context termination @@ -394,7 +407,7 @@ go func() { <-ctx.Done() // Wait for the context to be canceled // Perform your cleanup tasks here disposables := ore.GetResolvedScopedInstances[Disposer](ctx) - //The most recently created instance will be Dispose() first + //The most recently invoked instance will be Dispose() first for _, d := range disposables { _ = d.Dispose(ctx) } @@ -409,8 +422,8 @@ The `ore.GetResolvedScopedInstances[TInterface](context)` function returns a lis - It returns only the instances which had been invoked (a.k.a resolved) during the context life time. - All the implementations including "keyed" one will be returned. -- The returned instances are sorted by invocation order, the first one being the most "recently created" one. - - if "A" depends on "B", "C", Ore will make sure to return "B" and "C" first in the list so that they would be Disposed before "A". However Ore won't guarantee the order of "B" and "C" +- The returned instances are sorted by invocation order, the first one being the lastest invoked one. + - if "A" depends on "B", "C", Ore will make sure to return "B" and "C" first in the list so that they would be Disposed before "A". ## More Complex Example @@ -480,6 +493,9 @@ BenchmarkGetList BenchmarkGetList-20 1852132 637.0 ns/op ``` +Checkout also [examples/benchperf/README.md](examples/benchperf/README.md) + + # 👤 Contributors ![Contributors](https://contrib.rocks/image?repo=firasdarwish/ore) diff --git a/concrete.go b/concrete.go index c27cbd4..e55e53c 100644 --- a/concrete.go +++ b/concrete.go @@ -4,14 +4,19 @@ import "time" // concrete holds the resolved instance value and other metadata type concrete struct { - // the value implementation + //the value implementation value any - // the creation time - createdAt time.Time - // the lifetime of this concrete + + //invocationTime is the time when the resolver had been invoked, it is different from the "creationTime" + //of the concrete. Eg: A calls B, then the invocationTime of A is before B, but the creationTime of A is after B + //(because B was created before A) + invocationTime time.Time + + //the lifetime of this concrete lifetime Lifetime - // the invocation deep level, the bigger the value, the deeper it was resolved in the dependency chain - // for example: A depends on B, B depends on C, C depends on D - // A will have invocationLevel = 1, B = 2, C = 3, D = 4 + + //the invocation deep level, the bigger the value, the deeper it was resolved in the dependency chain + //for example: A depends on B, B depends on C, C depends on D + //A will have invocationLevel = 1, B = 2, C = 3, D = 4 invocationLevel int } diff --git a/examples/benchperf/README.md b/examples/benchperf/README.md index 9f4ded0..58ba7cd 100644 --- a/examples/benchperf/README.md +++ b/examples/benchperf/README.md @@ -1,7 +1,7 @@ # Benchmark comparison This sample will compare Ore (current commit of Nov 2024) to [samber/do/v2 v2.0.0-beta.7](https://github.com/samber/do). -We registered the below dependency graphs to both Ore and SamberDo, then ask them to create the concrete A. +We registered the below dependency graphs to both Ore and SamberDo, then ask them to create the concrete `A`. We will only benchmark the creation, not the registration. Because registration usually happens only once on application startup => not very interesting to benchmark. @@ -9,8 +9,8 @@ We will only benchmark the creation, not the registration. Because registration ## Data Model - This data model has only 2 singletons `F` and `Gb` => they will be created only once -- Every other concrete are `Transient` => they will be created each time the container create a new `A` -- We don't test the "Scoped" lifetime in this excercise because SamberDo doesn't has equivalent support for it. The "Scoped" functionality of SamberDo means "Sub Module" rather than a lifetime. +- Other concretes are `Transient` => they will be created each time the container create a new `A` concrete. +- We don't test the "Scoped" lifetime in this excercise because SamberDo doesn't have equivalent support for it. [The "Scoped" functionality of SamberDo](https://do.samber.dev/docs/container/scope) means "Sub Module" rather than a lifetime. ```mermaid flowchart TD diff --git a/get_test.go b/get_test.go index 98df359..1a9b797 100644 --- a/get_test.go +++ b/get_test.go @@ -74,6 +74,14 @@ func TestGetKeyed(t *testing.T) { } } +func TestGetKeyedUnhashable(t *testing.T) { + RegisterLazyCreator(Singleton, &simpleCounter{}, "a") + _, _ = Get[someCounter](context.Background(), "a") + + RegisterLazyCreator(Singleton, &simpleCounter{}, []string{"a", "b"}) + _, _ = Get[someCounter](context.Background(), []string{"a", "b"}) +} + func TestGetResolvedSingletons(t *testing.T) { t.Run("When multiple lifetimes and keys are registered", func(t *testing.T) { //Arrange diff --git a/getters.go b/getters.go index a27d202..fb592be 100644 --- a/getters.go +++ b/getters.go @@ -169,8 +169,8 @@ func GetResolvedScopedInstances[TInterface any](ctx context.Context) []TInterfac func sortAndSelect[TInterface any](list []*concrete) []TInterface { //sorting sort.Slice(list, func(i, j int) bool { - return list[i].createdAt.After(list[j].createdAt) || - (list[i].createdAt == list[j].createdAt && + return list[i].invocationTime.After(list[j].invocationTime) || + (list[i].invocationTime == list[j].invocationTime && list[i].invocationLevel > list[j].invocationLevel) }) diff --git a/ore.go b/ore.go index 013c788..776893d 100644 --- a/ore.go +++ b/ore.go @@ -6,7 +6,7 @@ import ( ) var ( - //DisableValidation set to true to skip validation. + //DisableValidation is false by default, Set to true to skip validation. // Use case: you called the [Validate] function (either in the test pipeline or on application startup). // So you are confident that your registrations are good: // @@ -14,10 +14,11 @@ var ( // - no circular dependencies // - no lifetime misalignment (a longer lifetime service depends on a shorter one). // - // You don't need Ore to validate over and over again each time it creates a new concrete. It's just a waste of resource - // especially when you will need Ore to create milion of transient concretes and any "pico" seconds or memory allocation matter for you + // You don't need Ore to validate over and over again each time it creates a new concrete. + // It's a waste of resource especially when you will need Ore to create milion of transient concretes + // and any "pico" seconds or memory allocation matter for you. // - // In this case, you can put DisableValidation to false. + // In this case, you can set DisableValidation = true. // // This config would impact also the the [GetResolvedSingletons] and the [GetResolvedScopedInstances] functions, // the returning order would be no longer guaranteed. @@ -31,8 +32,8 @@ var ( //contextKeysRepositoryID is a special context key. The value of this key is the collection of other context keys stored in the context. contextKeysRepositoryID specialContextKey = "The context keys repository" - //contextKeyResolversChain is a special context key. The value of this key is the [ResolversChain]. - contextKeyResolversChain specialContextKey = "Dependencies chain" + //contextKeyResolversStack is a special context key. The value of this key is the [ResolversStack]. + contextKeyResolversStack specialContextKey = "Dependencies stack" ) type contextKeysRepository = []contextKey diff --git a/registrars.go b/registrars.go index 871575a..4a0bc56 100644 --- a/registrars.go +++ b/registrars.go @@ -32,9 +32,9 @@ func RegisterEagerSingleton[T comparable](impl T, key ...KeyStringer) { lifetime: Singleton, }, singletonConcrete: &concrete{ - value: impl, - lifetime: Singleton, - createdAt: time.Now(), + value: impl, + lifetime: Singleton, + invocationTime: time.Now(), }, } appendToContainer[T](e, key) diff --git a/serviceResolver.go b/serviceResolver.go index 6be13b8..9c57ec3 100644 --- a/serviceResolver.go +++ b/serviceResolver.go @@ -30,9 +30,10 @@ type serviceResolverImpl[T any] struct { singletonConcrete *concrete } -// resolversChain is a linkedList[resolverMetadata], describing a dependencies chain which a resolver has to invoke other resolvers to resolve its dependencies. -// Before a resolver creates a new concrete value it would be registered to the resolversChain. -// Once the concrete is resolved (with help of other resolvers), then it would be removed from the chain. +// resolversStack is a stack of [resolverMetadata], similar to a call stack describing How a resolver has +// to call other resolvers to resolve its dependencies. +// Before a resolver creates a new concrete value it would be registered (pushed) to the stack. +// Once the concrete is resolved (with help of other resolvers), then it would be removed (poped) from the stack. // // While a Resolver forms a tree with other dependent resolvers. // @@ -40,18 +41,18 @@ type serviceResolverImpl[T any] struct { // // A calls B and C; B calls D; C calls E. // -// then resolversChain is a "path" in the tree from the root to one of the bottom. +// then resolversStack holds a "path" in the tree from the root to one of the bottom. // // Example: // // A -> B -> D or A -> C -> E // -// The resolversChain is stored in the context. Analyze the chain will help to +// The resolversStack is stored in the context. Analyze the stack will help to // // - (1) detect the invocation level // - (2) detect cyclic dependencies // - (3) detect lifetime misalignment (when a service of longer lifetime depends on a service of shorter lifetime) -type resolversChain = *list.List +type resolversStack = *list.List // make sure that the `serviceResolverImpl` struct implements the `serviceResolver` interface var _ serviceResolver = serviceResolverImpl[any]{} @@ -70,25 +71,25 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context) (*concret } } - // this resolver is about to create a new concrete value, we have to put it to the resolversChain until the creation done + // this resolver is about to create a new concrete value, we have to put it to the resolversStack until the creation done - // get the current currentChain from the context - var currentChain resolversChain + // get the current currentStack from the context + var currentStack resolversStack var marker *list.Element if !DisableValidation { - untypedCurrentChain := ctx.Value(contextKeyResolversChain) - if untypedCurrentChain == nil { - currentChain = list.New() - ctx = context.WithValue(ctx, contextKeyResolversChain, currentChain) + untypedCurrentStack := ctx.Value(contextKeyResolversStack) + if untypedCurrentStack == nil { + currentStack = list.New() + ctx = context.WithValue(ctx, contextKeyResolversStack, currentStack) } else { - currentChain = untypedCurrentChain.(resolversChain) + currentStack = untypedCurrentStack.(resolversStack) } - // push this newest resolver to the resolversChain - marker = appendResolver(currentChain, this.resolverMetadata) + // push the current resolver to the resolversStack + marker = pushToStack(currentStack, this.resolverMetadata) } var concreteValue T - createdAt := time.Now() + invocationTime := time.Now() // first, try make concrete implementation from `anonymousInitializer` // if nil, try the concrete implementation `Creator` @@ -100,16 +101,17 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context) (*concret invocationLevel := 0 if !DisableValidation { - invocationLevel = currentChain.Len() + invocationLevel = currentStack.Len() - // the concreteValue is created, we must to remove it from the resolversChain so that downstream resolvers (meaning the future resolvers) won't link to it - currentChain.Remove(marker) + //the concreteValue is created, we must to pop the current resolvers from the stack + //so that future resolvers won't link to it + currentStack.Remove(marker) } con := &concrete{ value: concreteValue, lifetime: this.lifetime, - createdAt: createdAt, + invocationTime: invocationTime, invocationLevel: invocationLevel, } @@ -130,25 +132,25 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context) (*concret return con, ctx } -// appendToResolversChain push the given resolver to the Back of the ResolversChain. +// pushToStack appends the given resolver to the Back of the given resolversStack. // `marker.previous` refers to the calling (parent) resolver -func appendResolver(chain resolversChain, currentResolver resolverMetadata) (marker *list.Element) { - if chain.Len() != 0 { +func pushToStack(stack resolversStack, currentResolver resolverMetadata) (marker *list.Element) { + if stack.Len() != 0 { //detect lifetime misalignment - lastElem := chain.Back() + lastElem := stack.Back() lastResolver := lastElem.Value.(resolverMetadata) if lastResolver.lifetime > currentResolver.lifetime { panic(lifetimeMisalignment(lastResolver, currentResolver)) } //detect cyclic dependencies - for e := chain.Back(); e != nil; e = e.Prev() { + for e := stack.Back(); e != nil; e = e.Prev() { if e.Value.(resolverMetadata).id == currentResolver.id { panic(cyclicDependency(currentResolver)) } } } - marker = chain.PushBack(currentResolver) // `marker.previous` refers to the calling (parent) resolver + marker = stack.PushBack(currentResolver) // `marker.previous` refers to the calling (parent) resolver return marker } @@ -173,9 +175,9 @@ func (this resolverMetadata) String() string { return fmt.Sprintf("Resolver(%s, type={%s}, key='%s')", this.lifetime, getUnderlyingTypeName(this.id.pointerTypeName), this.id.oreKey) } -// func toString(resolversChain resolversChain) string { +// func toString(resolversStack resolversStack) string { // var sb string -// for e := resolversChain.Front(); e != nil; e = e.Next() { +// for e := resolversStack.Front(); e != nil; e = e.Next() { // sb = fmt.Sprintf("%s%s\n", sb, e.Value.(resolverMetadata).String()) // } // return sb