From 4755e95933e18cad2f53c371446f7f7fe78f55be Mon Sep 17 00:00:00 2001 From: Phu-Hiep DUONG Date: Wed, 30 Oct 2024 00:33:44 +0100 Subject: [PATCH 1/6] refactor: fix codes smell reported by go linter * fix S1028: should use fmt.Errorf(...) instead of errors.New(fmt.Sprintf(...)) (gosimple) * fix SA5009: Printf format %c has arg #1 of wrong type *T (staticcheck) * fix SA1029: should not use built-in string type as context key. => Add more string alias for logic and type safety * fix SA9004: only the first constant in this group has an explicit type (staticcheck) * rename "entry" to "serviceResolver" * make serviceResolver return "any" instead of T so that resolving alias would be possible later * enhance codes duplications ("replaceEntry") * make many test structs (for eg. "Counter") private, so that invisible from consumers * fix others warnings in tests codes --- README.md | 4 +++ benchmarks_test.go | 22 ++++++------- creator_test.go | 70 ++++++++++++++++++++--------------------- eager_singleton_test.go | 48 ++++++++++++++-------------- entry.go | 54 ------------------------------- errors.go | 4 +-- get_list_test.go | 16 +++++----- get_test.go | 18 +++++------ getters.go | 52 ++++++++++++------------------ initializer_test.go | 70 ++++++++++++++++++++--------------------- lifetimes.go | 4 +-- ore.go | 27 +++++++++------- ore_test.go | 25 +++++++++++++-- registrars.go | 10 +++--- serviceResolver.go | 64 +++++++++++++++++++++++++++++++++++++ test.go | 41 +++--------------------- utils.go | 17 ++++++++-- 17 files changed, 275 insertions(+), 271 deletions(-) delete mode 100644 entry.go create mode 100644 serviceResolver.go diff --git a/README.md b/README.md index 49772d9..35dc573 100644 --- a/README.md +++ b/README.md @@ -245,6 +245,10 @@ func main() { ``` +#### Injecting Mocks in Tests + +The last registered implementation takes precedence, so you can register a mock implementation in the test, which will override the real implementation. +
### Keyed Services Retrieval Example diff --git a/benchmarks_test.go b/benchmarks_test.go index 8086628..db190bf 100644 --- a/benchmarks_test.go +++ b/benchmarks_test.go @@ -10,7 +10,7 @@ func BenchmarkRegisterLazyFunc(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - RegisterLazyFunc[Counter](Scoped, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](Scoped, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) } @@ -21,7 +21,7 @@ func BenchmarkRegisterLazyCreator(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - RegisterLazyCreator[Counter](Scoped, &simpleCounter{}) + RegisterLazyCreator[someCounter](Scoped, &simpleCounter{}) } } @@ -30,44 +30,44 @@ func BenchmarkRegisterEagerSingleton(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - RegisterEagerSingleton[Counter](&simpleCounter{}) + RegisterEagerSingleton[someCounter](&simpleCounter{}) } } func BenchmarkGet(b *testing.B) { clearAll() - RegisterLazyFunc[Counter](Scoped, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](Scoped, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) - RegisterEagerSingleton[Counter](&simpleCounter{}) + RegisterEagerSingleton[someCounter](&simpleCounter{}) - RegisterLazyCreator[Counter](Scoped, &simpleCounter{}) + RegisterLazyCreator[someCounter](Scoped, &simpleCounter{}) ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { - Get[Counter](ctx) + Get[someCounter](ctx) } } func BenchmarkGetList(b *testing.B) { clearAll() - RegisterLazyFunc[Counter](Scoped, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](Scoped, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) - RegisterEagerSingleton[Counter](&simpleCounter{}) + RegisterEagerSingleton[someCounter](&simpleCounter{}) - RegisterLazyCreator[Counter](Scoped, &simpleCounter{}) + RegisterLazyCreator[someCounter](Scoped, &simpleCounter{}) ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { - GetList[Counter](ctx) + GetList[someCounter](ctx) } } diff --git a/creator_test.go b/creator_test.go index 8c0251f..3fe7f41 100644 --- a/creator_test.go +++ b/creator_test.go @@ -9,9 +9,9 @@ func TestRegisterLazyCreator(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) - c, _ := Get[Counter](context.Background()) + c, _ := Get[someCounter](context.Background()) c.AddOne() c.AddOne() @@ -25,32 +25,32 @@ func TestRegisterLazyCreator(t *testing.T) { func TestRegisterLazyCreatorNilFuncTransient(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyCreator[Counter](Transient, nil) + RegisterLazyCreator[someCounter](Transient, nil) } func TestRegisterLazyCreatorNilFuncScoped(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyCreator[Counter](Scoped, nil) + RegisterLazyCreator[someCounter](Scoped, nil) } func TestRegisterLazyCreatorNilFuncSingleton(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyCreator[Counter](Singleton, nil) + RegisterLazyCreator[someCounter](Singleton, nil) } func TestRegisterLazyCreatorMultipleImplementations(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) - counters, _ := GetList[Counter](context.Background()) + counters, _ := GetList[someCounter](context.Background()) if got := len(counters); got != 3 { t.Errorf("got %v, expected %v", got, 3) @@ -62,13 +62,13 @@ func TestRegisterLazyCreatorMultipleImplementationsKeyed(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}, "firas") + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}, "firas") - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}, "firas") + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}, "firas") - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) - counters, _ := GetList[Counter](context.Background(), "firas") + counters, _ := GetList[someCounter](context.Background(), "firas") if got := len(counters); got != 2 { t.Errorf("got %v, expected %v", got, 2) @@ -81,16 +81,16 @@ func TestRegisterLazyCreatorSingletonState(t *testing.T) { clearAll() - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) - c, _ := Get[Counter](context.Background()) + c, _ := Get[someCounter](context.Background()) c.AddOne() c.AddOne() - c, _ = Get[Counter](context.Background()) + c, _ = Get[someCounter](context.Background()) c.AddOne() - c, _ = Get[Counter](context.Background()) + c, _ = Get[someCounter](context.Background()) c.AddOne() c.AddOne() c.AddOne() @@ -105,18 +105,18 @@ func TestRegisterLazyCreatorScopedState(t *testing.T) { clearAll() - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) ctx := context.Background() - c, ctx := Get[Counter](ctx) + c, ctx := Get[someCounter](ctx) c.AddOne() c.AddOne() - c, ctx = Get[Counter](ctx) + c, ctx = Get[someCounter](ctx) c.AddOne() - c, ctx = Get[Counter](ctx) + c, _ = Get[someCounter](ctx) c.AddOne() c.AddOne() c.AddOne() @@ -131,18 +131,18 @@ func TestRegisterLazyCreatorTransientState(t *testing.T) { clearAll() - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) ctx := context.Background() - c, ctx := Get[Counter](ctx) + c, ctx := Get[someCounter](ctx) c.AddOne() c.AddOne() - c, ctx = Get[Counter](ctx) + c, ctx = Get[someCounter](ctx) c.AddOne() - c, ctx = Get[Counter](ctx) + c, _ = Get[someCounter](ctx) c.AddOne() c.AddOne() c.AddOne() @@ -155,24 +155,24 @@ func TestRegisterLazyCreatorTransientState(t *testing.T) { func TestRegisterLazyCreatorNilKeyOnRegistering(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyCreator[Counter](Scoped, &simpleCounter{}, nil) + RegisterLazyCreator[someCounter](Scoped, &simpleCounter{}, nil) } func TestRegisterLazyCreatorNilKeyOnGetting(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyCreator[Counter](Scoped, &simpleCounter{}, "firas") + RegisterLazyCreator[someCounter](Scoped, &simpleCounter{}, "firas") - Get[Counter](context.Background(), nil) + Get[someCounter](context.Background(), nil) } func TestRegisterLazyCreatorGeneric(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyCreator[CounterGeneric[uint]](registrationType, &counterGeneric[uint]{}) + RegisterLazyCreator[someCounterGeneric[uint]](registrationType, &counterGeneric[uint]{}) - c, _ := Get[CounterGeneric[uint]](context.Background()) + c, _ := Get[someCounterGeneric[uint]](context.Background()) c.Add(5) c.Add(5) @@ -187,13 +187,13 @@ func TestRegisterLazyCreatorMultipleGenericImplementations(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyCreator[CounterGeneric[uint]](registrationType, &counterGeneric[uint]{}) + RegisterLazyCreator[someCounterGeneric[uint]](registrationType, &counterGeneric[uint]{}) - RegisterLazyCreator[CounterGeneric[uint]](registrationType, &counterGeneric[uint]{}) + RegisterLazyCreator[someCounterGeneric[uint]](registrationType, &counterGeneric[uint]{}) - RegisterLazyCreator[CounterGeneric[uint]](registrationType, &counterGeneric[uint]{}) + RegisterLazyCreator[someCounterGeneric[uint]](registrationType, &counterGeneric[uint]{}) - counters, _ := GetList[CounterGeneric[uint]](context.Background()) + counters, _ := GetList[someCounterGeneric[uint]](context.Background()) if got := len(counters); got != 3 { t.Errorf("got %v, expected %v", got, 3) @@ -216,7 +216,7 @@ func TestRegisterLazyCreatorScopedNested(t *testing.T) { a2, ctx := Get[*a](ctx) a2.C.Counter += 1 - a3, ctx := Get[*a](ctx) + a3, _ := Get[*a](ctx) a3.C.Counter += 1 if got := a2.C.Counter; got != 3 { diff --git a/eager_singleton_test.go b/eager_singleton_test.go index 71386f8..913c573 100644 --- a/eager_singleton_test.go +++ b/eager_singleton_test.go @@ -8,9 +8,9 @@ import ( func TestRegisterEagerSingleton(t *testing.T) { clearAll() - RegisterEagerSingleton[Counter](&simpleCounter{}) + RegisterEagerSingleton[someCounter](&simpleCounter{}) - c, _ := Get[Counter](context.Background()) + c, _ := Get[someCounter](context.Background()) c.AddOne() c.AddOne() @@ -23,17 +23,17 @@ func TestRegisterEagerSingleton(t *testing.T) { func TestRegisterEagerSingletonNilImplementation(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterEagerSingleton[Counter](nil) + RegisterEagerSingleton[someCounter](nil) } func TestRegisterEagerSingletonMultipleImplementations(t *testing.T) { clearAll() - RegisterEagerSingleton[Counter](&simpleCounter{}) - RegisterEagerSingleton[Counter](&simpleCounter{}) - RegisterEagerSingleton[Counter](&simpleCounter{}) + RegisterEagerSingleton[someCounter](&simpleCounter{}) + RegisterEagerSingleton[someCounter](&simpleCounter{}) + RegisterEagerSingleton[someCounter](&simpleCounter{}) - counters, _ := GetList[Counter](context.Background()) + counters, _ := GetList[someCounter](context.Background()) if got := len(counters); got != 3 { t.Errorf("got %v, expected %v", got, 3) @@ -43,12 +43,12 @@ func TestRegisterEagerSingletonMultipleImplementations(t *testing.T) { func TestRegisterEagerSingletonMultipleImplementationsKeyed(t *testing.T) { clearAll() - RegisterEagerSingleton[Counter](&simpleCounter{}, "firas") - RegisterEagerSingleton[Counter](&simpleCounter{}, "firas") + RegisterEagerSingleton[someCounter](&simpleCounter{}, "firas") + RegisterEagerSingleton[someCounter](&simpleCounter{}, "firas") - RegisterEagerSingleton[Counter](&simpleCounter{}) + RegisterEagerSingleton[someCounter](&simpleCounter{}) - counters, _ := GetList[Counter](context.Background(), "firas") + counters, _ := GetList[someCounter](context.Background(), "firas") if got := len(counters); got != 2 { t.Errorf("got %v, expected %v", got, 2) @@ -58,16 +58,16 @@ func TestRegisterEagerSingletonMultipleImplementationsKeyed(t *testing.T) { func TestRegisterEagerSingletonSingletonState(t *testing.T) { clearAll() - RegisterEagerSingleton[Counter](&simpleCounter{}) + RegisterEagerSingleton[someCounter](&simpleCounter{}) - c, _ := Get[Counter](context.Background()) + c, _ := Get[someCounter](context.Background()) c.AddOne() c.AddOne() - c, _ = Get[Counter](context.Background()) + c, _ = Get[someCounter](context.Background()) c.AddOne() - c, _ = Get[Counter](context.Background()) + c, _ = Get[someCounter](context.Background()) c.AddOne() c.AddOne() c.AddOne() @@ -80,23 +80,23 @@ func TestRegisterEagerSingletonSingletonState(t *testing.T) { func TestRegisterEagerSingletonNilKeyOnRegistering(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterEagerSingleton[Counter](&simpleCounter{}, nil) + RegisterEagerSingleton[someCounter](&simpleCounter{}, nil) } func TestRegisterEagerSingletonNilKeyOnGetting(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterEagerSingleton[Counter](&simpleCounter{}, "firas") + RegisterEagerSingleton[someCounter](&simpleCounter{}, "firas") - Get[Counter](context.Background(), nil) + Get[someCounter](context.Background(), nil) } func TestRegisterEagerSingletonGeneric(t *testing.T) { clearAll() - RegisterEagerSingleton[CounterGeneric[uint]](&counterGeneric[uint]{}) + RegisterEagerSingleton[someCounterGeneric[uint]](&counterGeneric[uint]{}) - c, _ := Get[CounterGeneric[uint]](context.Background()) + c, _ := Get[someCounterGeneric[uint]](context.Background()) c.Add(5) c.Add(5) @@ -109,11 +109,11 @@ func TestRegisterEagerSingletonGeneric(t *testing.T) { func TestRegisterEagerSingletonMultipleGenericImplementations(t *testing.T) { clearAll() - RegisterEagerSingleton[CounterGeneric[uint]](&counterGeneric[uint]{}) - RegisterEagerSingleton[CounterGeneric[uint]](&counterGeneric[uint]{}) - RegisterEagerSingleton[CounterGeneric[uint]](&counterGeneric[uint]{}) + RegisterEagerSingleton[someCounterGeneric[uint]](&counterGeneric[uint]{}) + RegisterEagerSingleton[someCounterGeneric[uint]](&counterGeneric[uint]{}) + RegisterEagerSingleton[someCounterGeneric[uint]](&counterGeneric[uint]{}) - counters, _ := GetList[CounterGeneric[uint]](context.Background()) + counters, _ := GetList[someCounterGeneric[uint]](context.Background()) if got := len(counters); got != 3 { t.Errorf("got %v, expected %v", got, 3) diff --git a/entry.go b/entry.go deleted file mode 100644 index 2c2a107..0000000 --- a/entry.go +++ /dev/null @@ -1,54 +0,0 @@ -package ore - -import "context" - -type ( - Initializer[T any] func(ctx context.Context) (T, context.Context) -) - -type entry[T any] struct { - anonymousInitializer *Initializer[T] - creatorInstance Creator[T] - concrete *T - lifetime Lifetime -} - -func (i *entry[T]) load(ctx context.Context, ctxTidVal string) (T, context.Context, bool) { - // try get concrete implementation - if i.lifetime == Singleton && i.concrete != nil { - return *i.concrete, ctx, false - } - - // try get concrete from context scope - if i.lifetime == Scoped { - fromCtx, ok := ctx.Value(ctxTidVal).(T) - if ok { - return fromCtx, ctx, false - } - } - - var con T - - // first, try make concrete implementation from `anonymousInitializer` - // if nil, try the concrete implementation `Creator` - if i.anonymousInitializer != nil { - con, ctx = (*i.anonymousInitializer)(ctx) - } else { - con, ctx = i.creatorInstance.New(ctx) - } - - // if scoped, attach to the current context - if i.lifetime == Scoped { - ctx = context.WithValue(ctx, ctxTidVal, con) - } - - // if was lazily-created, then attach the newly-created concrete implementation - // to the entry - if i.lifetime == Singleton { - i.concrete = &con - - return con, ctx, true - } - - return con, ctx, false -} diff --git a/errors.go b/errors.go index df57341..5597b48 100644 --- a/errors.go +++ b/errors.go @@ -7,11 +7,11 @@ import ( ) func noValidImplementation[T any]() error { - return errors.New(fmt.Sprintf("implementation not found for type: %s", reflect.TypeFor[T]())) + return fmt.Errorf("implementation not found for type: %s", reflect.TypeFor[T]()) } func nilVal[T any]() error { - return errors.New(fmt.Sprintf("nil implementation for type: %s", reflect.TypeFor[T]())) + return fmt.Errorf("nil implementation for type: %s", reflect.TypeFor[T]()) } var alreadyBuilt = errors.New("services container is already built") diff --git a/get_list_test.go b/get_list_test.go index a785275..6731087 100644 --- a/get_list_test.go +++ b/get_list_test.go @@ -9,9 +9,9 @@ func TestGetList(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) - counters, _ := GetList[Counter](context.Background()) + counters, _ := GetList[someCounter](context.Background()) if got := len(counters); got != 1 { t.Errorf("got %v, expected %v", got, 1) @@ -21,7 +21,7 @@ func TestGetList(t *testing.T) { func TestGetListShouldNotPanicIfNoImplementations(t *testing.T) { clearAll() - services, _ := GetList[Counter](context.Background()) + services, _ := GetList[someCounter](context.Background()) if len(services) != 0 { t.Errorf("got %v, expected %v", len(services), 0) } @@ -33,12 +33,12 @@ func TestGetListKeyed(t *testing.T) { key := "somekeyhere" - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}, key) - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}, key) - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}, key) - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}, "Firas") + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}, key) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}, key) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}, key) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}, "Firas") - counters, _ := GetList[Counter](context.Background(), key) + counters, _ := GetList[someCounter](context.Background(), key) if got := len(counters); got != 3 { t.Errorf("got %v, expected %v", got, 3) } diff --git a/get_test.go b/get_test.go index bb90cf2..5ae94a7 100644 --- a/get_test.go +++ b/get_test.go @@ -10,9 +10,9 @@ func TestGet(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) - c, _ := Get[Counter](context.Background()) + c, _ := Get[someCounter](context.Background()) c.AddOne() c.AddOne() @@ -27,13 +27,13 @@ func TestGetLatestByDefault(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}) - c, _ := Get[Counter](context.Background()) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}) + c, _ := Get[someCounter](context.Background()) c.AddOne() c.AddOne() - RegisterLazyCreator[Counter](registrationType, &simpleCounter2{}) - c, _ = Get[Counter](context.Background()) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter2{}) + c, _ = Get[someCounter](context.Background()) c.AddOne() c.AddOne() c.AddOne() @@ -48,7 +48,7 @@ func TestGetLatestByDefault(t *testing.T) { func TestGetPanicIfNoImplementations(t *testing.T) { clearAll() defer mustHavePanicked(t) - Get[Counter](context.Background()) + Get[someCounter](context.Background()) } func TestGetKeyed(t *testing.T) { @@ -57,9 +57,9 @@ func TestGetKeyed(t *testing.T) { key := fmt.Sprintf("keynum: %v", i) - RegisterLazyCreator[Counter](registrationType, &simpleCounter{}, key) + RegisterLazyCreator[someCounter](registrationType, &simpleCounter{}, key) - c, _ := Get[Counter](context.Background(), key) + c, _ := Get[someCounter](context.Background(), key) c.AddOne() c.AddOne() diff --git a/getters.go b/getters.go index 8aa890d..73774f5 100644 --- a/getters.go +++ b/getters.go @@ -9,32 +9,26 @@ func Get[T any](ctx context.Context, key ...KeyStringer) (T, context.Context) { // generate type identifier typeId := typeIdentifier[T](key) - // try to get entry from container + // try to get service resolver from container lock.RLock() - entries, entryExists := container[typeId] + resolvers, resolverExists := container[typeId] lock.RUnlock() - if !entryExists { + if !resolverExists { panic(noValidImplementation[T]()) } - entriesCount := len(entries) + count := len(resolvers) - if entriesCount == 0 { + if count == 0 { panic(noValidImplementation[T]()) } - // index of the last implementation - index := entriesCount - 1 - - implementation := entries[index].(entry[T]) - - service, ctx, updateEntry := implementation.load(ctx, contextValueId(typeId, index)) - if updateEntry { - replaceEntry[T](typeId, index, implementation) - } - - return service, ctx + // lastIndex of the last implementation + lastIndex := count - 1 + lastRegisteredResolver := resolvers[lastIndex] + service, ctx := lastRegisteredResolver.resolveService(ctx, typeId, lastIndex) + return service.(T), ctx } // GetList Retrieves a list of instances based on type and key @@ -42,33 +36,27 @@ func GetList[T any](ctx context.Context, key ...KeyStringer) ([]T, context.Conte // generate type identifier typeId := typeIdentifier[T](key) - // try to get entry from container + // try to get service resolver from container lock.RLock() - entries, entryExists := container[typeId] + resolvers, resolverExists := container[typeId] lock.RUnlock() - if !entryExists { + if !resolverExists { return make([]T, 0), nil } - entriesCount := len(entries) + count := len(resolvers) - if entriesCount == 0 { + if count == 0 { return make([]T, 0), nil } - servicesArray := make([]T, entriesCount) - - for index := 0; index < entriesCount; index++ { - e := entries[index].(entry[T]) - - service, newCtx, updateEntry := e.load(ctx, contextValueId(typeId, index)) - - if updateEntry { - replaceEntry[T](typeId, index, e) - } + servicesArray := make([]T, count) - servicesArray[index] = service + for index := 0; index < count; index++ { + resolver := resolvers[index] + service, newCtx := resolver.resolveService(ctx, typeId, index) + servicesArray[index] = service.(T) ctx = newCtx } diff --git a/initializer_test.go b/initializer_test.go index c396779..a4ec332 100644 --- a/initializer_test.go +++ b/initializer_test.go @@ -9,11 +9,11 @@ func TestRegisterLazyFunc(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyFunc[Counter](registrationType, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](registrationType, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) - c, _ := Get[Counter](context.Background()) + c, _ := Get[someCounter](context.Background()) c.AddOne() c.AddOne() @@ -27,38 +27,38 @@ func TestRegisterLazyFunc(t *testing.T) { func TestRegisterLazyFuncNilFuncTransient(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyFunc[Counter](Transient, nil) + RegisterLazyFunc[someCounter](Transient, nil) } func TestRegisterLazyFuncNilFuncScoped(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyFunc[Counter](Scoped, nil) + RegisterLazyFunc[someCounter](Scoped, nil) } func TestRegisterLazyFuncNilFuncSingleton(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyFunc[Counter](Singleton, nil) + RegisterLazyFunc[someCounter](Singleton, nil) } func TestRegisterLazyFuncMultipleImplementations(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyFunc[Counter](registrationType, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](registrationType, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) - RegisterLazyFunc[Counter](registrationType, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](registrationType, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) - RegisterLazyFunc[Counter](registrationType, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](registrationType, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) - counters, _ := GetList[Counter](context.Background()) + counters, _ := GetList[someCounter](context.Background()) if got := len(counters); got != 3 { t.Errorf("got %v, expected %v", got, 3) @@ -70,19 +70,19 @@ func TestRegisterLazyFuncMultipleImplementationsKeyed(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyFunc[Counter](registrationType, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](registrationType, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }, "firas") - RegisterLazyFunc[Counter](registrationType, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](registrationType, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }, "firas") - RegisterLazyFunc[Counter](registrationType, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](registrationType, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) - counters, _ := GetList[Counter](context.Background(), "firas") + counters, _ := GetList[someCounter](context.Background(), "firas") if got := len(counters); got != 2 { t.Errorf("got %v, expected %v", got, 2) @@ -95,18 +95,18 @@ func TestRegisterLazyFuncSingletonState(t *testing.T) { clearAll() - RegisterLazyFunc[Counter](registrationType, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](registrationType, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) - c, _ := Get[Counter](context.Background()) + c, _ := Get[someCounter](context.Background()) c.AddOne() c.AddOne() - c, _ = Get[Counter](context.Background()) + c, _ = Get[someCounter](context.Background()) c.AddOne() - c, _ = Get[Counter](context.Background()) + c, _ = Get[someCounter](context.Background()) c.AddOne() c.AddOne() c.AddOne() @@ -121,20 +121,20 @@ func TestRegisterLazyFuncScopedState(t *testing.T) { clearAll() - RegisterLazyFunc[Counter](registrationType, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](registrationType, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) ctx := context.Background() - c, ctx := Get[Counter](ctx) + c, ctx := Get[someCounter](ctx) c.AddOne() c.AddOne() - c, ctx = Get[Counter](ctx) + c, ctx = Get[someCounter](ctx) c.AddOne() - c, ctx = Get[Counter](ctx) + c, _ = Get[someCounter](ctx) c.AddOne() c.AddOne() c.AddOne() @@ -149,20 +149,20 @@ func TestRegisterLazyFuncTransientState(t *testing.T) { clearAll() - RegisterLazyFunc[Counter](registrationType, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](registrationType, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }) ctx := context.Background() - c, ctx := Get[Counter](ctx) + c, ctx := Get[someCounter](ctx) c.AddOne() c.AddOne() - c, ctx = Get[Counter](ctx) + c, ctx = Get[someCounter](ctx) c.AddOne() - c, ctx = Get[Counter](ctx) + c, _ = Get[someCounter](ctx) c.AddOne() c.AddOne() c.AddOne() @@ -175,7 +175,7 @@ func TestRegisterLazyFuncTransientState(t *testing.T) { func TestRegisterLazyFuncNilKeyOnRegistering(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyFunc[Counter](Scoped, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](Scoped, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }, nil) } @@ -183,22 +183,22 @@ func TestRegisterLazyFuncNilKeyOnRegistering(t *testing.T) { func TestRegisterLazyFuncNilKeyOnGetting(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyFunc[Counter](Scoped, func(ctx context.Context) (Counter, context.Context) { + RegisterLazyFunc[someCounter](Scoped, func(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx }, "firas") - Get[Counter](context.Background(), nil) + Get[someCounter](context.Background(), nil) } func TestRegisterLazyFuncGeneric(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyFunc[CounterGeneric[uint]](registrationType, func(ctx context.Context) (CounterGeneric[uint], context.Context) { + RegisterLazyFunc[someCounterGeneric[uint]](registrationType, func(ctx context.Context) (someCounterGeneric[uint], context.Context) { return &counterGeneric[uint]{}, ctx }) - c, _ := Get[CounterGeneric[uint]](context.Background()) + c, _ := Get[someCounterGeneric[uint]](context.Background()) c.Add(5) c.Add(5) @@ -213,19 +213,19 @@ func TestRegisterLazyFuncMultipleGenericImplementations(t *testing.T) { for _, registrationType := range types { clearAll() - RegisterLazyFunc[CounterGeneric[uint]](registrationType, func(ctx context.Context) (CounterGeneric[uint], context.Context) { + RegisterLazyFunc[someCounterGeneric[uint]](registrationType, func(ctx context.Context) (someCounterGeneric[uint], context.Context) { return &counterGeneric[uint]{}, ctx }) - RegisterLazyFunc[CounterGeneric[uint]](registrationType, func(ctx context.Context) (CounterGeneric[uint], context.Context) { + RegisterLazyFunc[someCounterGeneric[uint]](registrationType, func(ctx context.Context) (someCounterGeneric[uint], context.Context) { return &counterGeneric[uint]{}, ctx }) - RegisterLazyFunc[CounterGeneric[uint]](registrationType, func(ctx context.Context) (CounterGeneric[uint], context.Context) { + RegisterLazyFunc[someCounterGeneric[uint]](registrationType, func(ctx context.Context) (someCounterGeneric[uint], context.Context) { return &counterGeneric[uint]{}, ctx }) - counters, _ := GetList[CounterGeneric[uint]](context.Background()) + counters, _ := GetList[someCounterGeneric[uint]](context.Background()) if got := len(counters); got != 3 { t.Errorf("got %v, expected %v", got, 3) @@ -255,7 +255,7 @@ func TestRegisterLazyFuncScopedNested(t *testing.T) { a2, ctx := Get[*a](ctx) a2.C.Counter += 1 - a3, ctx := Get[*a](ctx) + a3, _ := Get[*a](ctx) a3.C.Counter += 1 if got := a2.C.Counter; got != 3 { diff --git a/lifetimes.go b/lifetimes.go index e7fd477..f2d1fbd 100644 --- a/lifetimes.go +++ b/lifetimes.go @@ -4,6 +4,6 @@ type Lifetime string const ( Singleton Lifetime = "singleton" - Transient = "transient" - Scoped = "scoped" + Transient Lifetime = "transient" + Scoped Lifetime = "scoped" ) diff --git a/ore.go b/ore.go index 15601f3..16be60a 100644 --- a/ore.go +++ b/ore.go @@ -9,29 +9,32 @@ import ( var ( lock = &sync.RWMutex{} isBuilt = false - container = map[string][]any{} + container = map[typeID][]serviceResolver{} ) type Creator[T any] interface { New(ctx context.Context) (T, context.Context) } -// Generates a unique identifier for an entry based on type and key(s) -func typeIdentifier[T any](key []KeyStringer) string { +// Generates a unique identifier for a service resolver based on type and key(s) +func getTypeId(pointerTypeName pointerTypeName, key []KeyStringer) typeID { for _, stringer := range key { if stringer == nil { panic(nilKey) } } - - var mockType *T customKey := oreKey(key) - tt := fmt.Sprintf("%c:%v", mockType, customKey) - return tt + tt := fmt.Sprintf("%s:%v", pointerTypeName, customKey) + return typeID(tt) +} + +// Generates a unique identifier for a service resolver based on type and key(s) +func typeIdentifier[T any](key []KeyStringer) typeID { + return getTypeId(getPointerTypeName[T](), key) } -// Appends an entry to the container with type and key -func appendToContainer[T any](entry entry[T], key []KeyStringer) { +// Appends a service resolver to the container with type and key +func appendToContainer[T any](resolver serviceResolver, key []KeyStringer) { if isBuilt { panic(alreadyBuiltCannotAdd) } @@ -39,13 +42,13 @@ func appendToContainer[T any](entry entry[T], key []KeyStringer) { typeId := typeIdentifier[T](key) lock.Lock() - container[typeId] = append(container[typeId], entry) + container[typeId] = append(container[typeId], resolver) lock.Unlock() } -func replaceEntry[T any](typeId string, index int, entry entry[T]) { +func replaceServiceResolver(typeId typeID, index int, resolver serviceResolver) { lock.Lock() - container[typeId][index] = entry + container[typeId][index] = resolver lock.Unlock() } diff --git a/ore_test.go b/ore_test.go index 53eef01..0130668 100644 --- a/ore_test.go +++ b/ore_test.go @@ -1,12 +1,31 @@ package ore -import "testing" +import ( + "testing" +) func TestBuild(t *testing.T) { clearAll() defer mustHavePanicked(t) - RegisterLazyCreator[Counter](Scoped, &simpleCounter{}) + RegisterLazyCreator[someCounter](Scoped, &simpleCounter{}) Build() - RegisterLazyCreator[Counter](Scoped, &simpleCounter{}) + RegisterLazyCreator[someCounter](Scoped, &simpleCounter{}) +} + +type A1 struct{} +type A2 struct{} + +func TestTypeIdentifier(t *testing.T) { + id1 := typeIdentifier[*A1]([]KeyStringer{}) + id2 := typeIdentifier[*A2]([]KeyStringer{}) + if id1 == id2 { + t.Errorf("got the same identifier value %v, expected different values", id1) + } + + id3 := typeIdentifier[*A1]([]KeyStringer{"a", "b"}) + id4 := typeIdentifier[*A1]([]KeyStringer{"a", "b"}) + if id3 != id4 { + t.Errorf("got %v, expected %v", id3, id4) + } } diff --git a/registrars.go b/registrars.go index 9a20533..df52ba6 100644 --- a/registrars.go +++ b/registrars.go @@ -6,7 +6,7 @@ func RegisterLazyCreator[T any](lifetime Lifetime, creator Creator[T], key ...Ke panic(nilVal[T]()) } - e := entry[T]{ + e := serviceResolverImpl[T]{ lifetime: lifetime, creatorInstance: creator, } @@ -19,9 +19,9 @@ func RegisterEagerSingleton[T comparable](impl T, key ...KeyStringer) { panic(nilVal[T]()) } - e := entry[T]{ - lifetime: Singleton, - concrete: &impl, + e := serviceResolverImpl[T]{ + lifetime: Singleton, + singletonConcrete: &impl, } appendToContainer[T](e, key) } @@ -32,7 +32,7 @@ func RegisterLazyFunc[T any](lifetime Lifetime, initializer Initializer[T], key panic(nilVal[T]()) } - e := entry[T]{ + e := serviceResolverImpl[T]{ lifetime: lifetime, anonymousInitializer: &initializer, } diff --git a/serviceResolver.go b/serviceResolver.go new file mode 100644 index 0000000..30adabc --- /dev/null +++ b/serviceResolver.go @@ -0,0 +1,64 @@ +package ore + +import "context" + +type ( + Initializer[T any] func(ctx context.Context) (T, context.Context) +) + +type serviceResolver interface { + resolveService(ctx context.Context, typeId typeID, index int) (any, context.Context) +} + +type serviceResolverImpl[T any] struct { + anonymousInitializer *Initializer[T] + creatorInstance Creator[T] + singletonConcrete *T + lifetime Lifetime +} + +//make sure that the `serviceResolverImpl` struct implements the `serviceResolver` interface +var _ serviceResolver = serviceResolverImpl[any]{} + +func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeId typeID, index int) (any, context.Context) { + + ctxTidVal := getContextValueID(typeId, index) + + // try get concrete implementation + if this.lifetime == Singleton && this.singletonConcrete != nil { + return *this.singletonConcrete, ctx + } + + // try get concrete from context scope + if this.lifetime == Scoped { + scopedConcrete, ok := ctx.Value(ctxTidVal).(T) + if ok { + return scopedConcrete, ctx + } + } + + var con T + + // first, try make concrete implementation from `anonymousInitializer` + // if nil, try the concrete implementation `Creator` + if this.anonymousInitializer != nil { + con, ctx = (*this.anonymousInitializer)(ctx) + } else { + con, ctx = this.creatorInstance.New(ctx) + } + + // if scoped, attach to the current context + if this.lifetime == Scoped { + ctx = context.WithValue(ctx, ctxTidVal, con) + } + + // if was lazily-created, then attach the newly-created concrete implementation + // to the service resolver + if this.lifetime == Singleton { + this.singletonConcrete = &con + replaceServiceResolver(typeId, index, this) + return con, ctx + } + + return con, ctx +} diff --git a/test.go b/test.go index 6d6fb35..3c12afe 100644 --- a/test.go +++ b/test.go @@ -2,8 +2,6 @@ package ore import ( "context" - "io" - "strconv" "testing" ) @@ -15,21 +13,16 @@ func mustHavePanicked(t *testing.T) { } } -type Counter interface { +type someCounter interface { AddOne() GetCount() int } -type CounterWriter interface { - Add(number int) - GetCount() int -} - type numeric interface { uint } -type CounterGeneric[T numeric] interface { +type someCounterGeneric[T numeric] interface { Add(number T) GetCount() T } @@ -46,7 +39,7 @@ func (c *simpleCounter) GetCount() int { return c.counter } -func (c *simpleCounter) New(ctx context.Context) (Counter, context.Context) { +func (c *simpleCounter) New(ctx context.Context) (someCounter, context.Context) { return &simpleCounter{}, ctx } @@ -62,34 +55,10 @@ func (c *simpleCounter2) GetCount() int { return c.counter } -func (c *simpleCounter2) New(ctx context.Context) (Counter, context.Context) { +func (c *simpleCounter2) New(ctx context.Context) (someCounter, context.Context) { return &simpleCounter2{}, ctx } -type counterWriter struct { - counter int - writer io.Writer -} - -func (c *counterWriter) Add(number int) { - _, _ = c.writer.Write([]byte("New Number Added: " + strconv.Itoa(number))) - c.counter += number -} - -func (c *counterWriter) GetCount() int { - _, _ = c.writer.Write([]byte("Total Count: " + strconv.Itoa(c.counter))) - return c.counter -} - -func (c *counterWriter) New(ctx context.Context) CounterWriter { - - writer, _ := Get[io.Writer](ctx) - - return &counterWriter{ - writer: writer, - } -} - type counterGeneric[T numeric] struct { counter T } @@ -102,7 +71,7 @@ func (c *counterGeneric[T]) GetCount() T { return c.counter } -func (c *counterGeneric[T]) New(ctx context.Context) (CounterGeneric[T], context.Context) { +func (c *counterGeneric[T]) New(ctx context.Context) (someCounterGeneric[T], context.Context) { return &counterGeneric[T]{}, ctx } diff --git a/utils.go b/utils.go index 1333d1c..3c5c60f 100644 --- a/utils.go +++ b/utils.go @@ -2,16 +2,27 @@ package ore import "fmt" +type contextValueID string +type typeID string +type pointerTypeName string + func isNil[T comparable](impl T) bool { var mock T return impl == mock } func clearAll() { - container = make(map[string][]any) + container = make(map[typeID][]serviceResolver) isBuilt = false } -func contextValueId(typeId string, index int) string { - return fmt.Sprintln(typeId, index) +func getContextValueID(typeId typeID, index int) contextValueID { + return contextValueID(fmt.Sprintln(typeId, index)) +} + +// Get type name of *T. +// it allocates less memory and is faster than `reflect.TypeFor[*T]().String()` +func getPointerTypeName[T any]() pointerTypeName { + var mockValue *T + return pointerTypeName(fmt.Sprintf("%T", mockValue)) } From 7b298387a4589f38f6b980728935caa0f7172c86 Mon Sep 17 00:00:00 2001 From: Phu-Hiep DUONG Date: Thu, 31 Oct 2024 21:54:40 +0100 Subject: [PATCH 2/6] add new feature RegisterAlias --- README.md | 37 ++++++++++ alias_test.go | 67 ++++++++++++++++++ examples/aliasdemo/alias_test.go | 114 +++++++++++++++++++++++++++++++ getters.go | 89 +++++++++++++++++------- ore.go | 23 ++++++- registrars.go | 18 +++++ utils.go | 1 + 7 files changed, 321 insertions(+), 28 deletions(-) create mode 100644 alias_test.go create mode 100644 examples/aliasdemo/alias_test.go diff --git a/README.md b/README.md index 35dc573..8e72cc1 100644 --- a/README.md +++ b/README.md @@ -284,6 +284,43 @@ func main() { ``` +### Alias: Register struct, get interface + +```go +type IPerson interface{} +type Broker struct { + Name string +} //implements IPerson + +type Trader struct { + Name string +} //implements IPerson + +func TestGetInterfaceAlias(t *testing.T) { + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Broker, context.Context) { + return &Broker{Name: "Peter"}, ctx + }) + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Broker, context.Context) { + return &Broker{Name: "John"}, ctx + }) + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Trader, context.Context) { + return &Trader{Name: "Mary"}, ctx + }) + + ore.RegisterAlias[IPerson, *Trader]() //link IPerson to *Trader + ore.RegisterAlias[IPerson, *Broker]() //link IPerson to *Broker + + //no IPerson was registered to the container, but we can still `Get` it out of the container. + //(1) IPerson is alias to both *Broker and *Trader. *Broker takes precedence because it's the last one linked to IPerson. + //(2) multiple *Borker (Peter and John) are registered to the container, the last registered (John) takes precedence. + person, _ := ore.Get[IPerson](context.Background()) // will return the broker John + + personList, _ := ore.GetList[IPerson](context.Background()) // will return all registered broker and trader +} +``` + +Alias is also scoped by key. When you "Get" an alias with keys for eg: `ore.Get[IPerson](ctx, "module1")` then Ore would return only Services registered under this key ("module1") and panic if no service found. + ## More Complex Example ```go diff --git a/alias_test.go b/alias_test.go new file mode 100644 index 0000000..dbe0f77 --- /dev/null +++ b/alias_test.go @@ -0,0 +1,67 @@ +package ore + +import ( + "context" + "testing" +) + +func TestGetWithAlias(t *testing.T) { + for _, registrationType := range types { + clearAll() + + RegisterLazyFunc(registrationType, func(ctx context.Context) (*simpleCounterUint, context.Context) { + return &simpleCounterUint{}, ctx + }) + RegisterAlias[someCounterGeneric[uint], *simpleCounterUint]() + + c, _ := Get[someCounterGeneric[uint]](context.Background()) + + c.Add(1) + c.Add(1) + + if got := c.GetCount(); got != 2 { + t.Errorf("got %v, expected %v", got, 2) + } + } +} + +func TestGetListWithAlias(t *testing.T) { + for _, registrationType := range types { + clearAll() + + for i := 0; i < 3; i++ { + RegisterLazyFunc(registrationType, func(ctx context.Context) (*simpleCounterUint, context.Context) { + return &simpleCounterUint{}, ctx + }) + } + + RegisterAlias[someCounterGeneric[uint], *simpleCounterUint]() + + counters, _ := GetList[someCounterGeneric[uint]](context.Background()) + if got := len(counters); got != 3 { + t.Errorf("got %v, expected %v", got, 3) + } + + c := counters[1] + c.Add(1) + c.Add(1) + + if got := c.GetCount(); got != 2 { + t.Errorf("got %v, expected %v", got, 2) + } + } +} + +var _ someCounterGeneric[uint] = (*simpleCounterUint)(nil) + +type simpleCounterUint struct { + counter uint +} + +func (this *simpleCounterUint) Add(number uint) { + this.counter += number +} + +func (this *simpleCounterUint) GetCount() uint { + return this.counter +} diff --git a/examples/aliasdemo/alias_test.go b/examples/aliasdemo/alias_test.go new file mode 100644 index 0000000..e4a9ab8 --- /dev/null +++ b/examples/aliasdemo/alias_test.go @@ -0,0 +1,114 @@ +package aliasdemo + +import ( + "context" + "testing" + + "github.com/firasdarwish/ore" +) + +func TestGetInterfaceAliasWithKeys(t *testing.T) { + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Broker, context.Context) { + return &Broker{Name: "Peter1"}, ctx + }, "module1") + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Broker, context.Context) { + return &Broker{Name: "John1"}, ctx + }, "module1") + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Trader, context.Context) { + return &Trader{Name: "Mary1"}, ctx + }, "module1") + + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Broker, context.Context) { + return &Broker{Name: "John2"}, ctx + }, "module2") + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Trader, context.Context) { + return &Trader{Name: "Mary2"}, ctx + }, "module2") + + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Trader, context.Context) { + return &Trader{Name: "Mary3"}, ctx + }, "module3") + + ore.RegisterAlias[IPerson, *Trader]() //link IPerson to *Trader + ore.RegisterAlias[IPerson, *Broker]() //link IPerson to *Broker + + ctx := context.Background() + + //no IPerson was registered to the container, but we can still `Get` it. + //(1) IPerson is alias to both *Broker and *Trader. *Broker takes precedence because it's the last one linked to IPerson. + //(2) multiple *Borker (Peter and John) are registered to the container, the last registered (John) takes precedence. + person1, ctx := ore.Get[IPerson](ctx, "module1") // will return the broker John + switch person := person1.(type) { + case *Broker: + if person.Name != "John1" { + t.Errorf("got %v, expected %v", person.Name, "John1") + } + case *Trader: + t.Errorf("got Trader, expected Broker") + } + + personList1, ctx := ore.GetList[IPerson](ctx, "module1") // will return all registered broker and trader + if len(personList1) != 3 { + t.Errorf("got %v, expected %v", len(personList1), 3) + } + + person2, ctx := ore.Get[IPerson](ctx, "module2") // will return the broker John + if person2.(*Broker).Name != "John2" { + t.Errorf("got %v, expected %v", person2.(*Broker).Name, "John2") + } + + personList2, ctx := ore.GetList[IPerson](ctx, "module2") // will return all registered broker and trader + if len(personList2) != 2 { + t.Errorf("got %v, expected %v", len(personList2), 2) + } + + person3, ctx := ore.Get[IPerson](ctx, "module3") // will return the trader Mary + if person3.(*Trader).Name != "Mary3" { + t.Errorf("got %v, expected %v", person3.(*Trader).Name, "Mary3") + } + + personList3, ctx := ore.GetList[IPerson](ctx, "module3") // will return all registered broker and trader + if len(personList3) != 1 { + t.Errorf("got %v, expected %v", len(personList3), 1) + } + + personListNoModule, _ := ore.GetList[IPerson](ctx) // will return all registered broker and trader without keys + if len(personListNoModule) != 0 { + t.Errorf("got %v, expected %v", len(personListNoModule), 0) + } +} + +// func TestGetInterfaceAliasWithDifferentScope(t *testing.T) { +// module := "TestGetInterfaceAliasWithDifferentScope" +// ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (*Broker, context.Context) { +// return &Broker{Name: "Transient"}, ctx +// }, module) +// ore.RegisterLazyFunc(ore.Singleton, func(ctx context.Context) (*Broker, context.Context) { +// return &Broker{Name: "Singleton"}, ctx +// }, module) +// ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Broker, context.Context) { +// return &Broker{Name: "Scoped"}, ctx +// }, module) +// ore.RegisterAlias[IPerson, *Broker]() //link IPerson to *Broker + +// ctx := context.Background() + +// person, ctx := ore.Get[IPerson](ctx, module) +// if person.(*Broker).Name != "Scoped" { +// t.Errorf("got %v, expected %v", person.(*Broker).Name, "Scoped") +// } + +// personList, _ := ore.GetList[IPerson](ctx, module) +// if len(personList) != 2 { +// t.Errorf("got %v, expected %v", len(personList), 2) +// } +// } + +type IPerson interface{} +type Broker struct { + Name string +} //implements IPerson + +type Trader struct { + Name string +} //implements IPerson diff --git a/getters.go b/getters.go index 73774f5..b1be99a 100644 --- a/getters.go +++ b/getters.go @@ -4,60 +4,97 @@ import ( "context" ) -// Get Retrieves an instance based on type and key (panics if no valid implementations) -func Get[T any](ctx context.Context, key ...KeyStringer) (T, context.Context) { - // generate type identifier - typeId := typeIdentifier[T](key) - +func getLastRegisteredResolver(typeId typeID) (serviceResolver, int) { // try to get service resolver from container lock.RLock() resolvers, resolverExists := container[typeId] lock.RUnlock() if !resolverExists { - panic(noValidImplementation[T]()) + return nil, -1 } count := len(resolvers) if count == 0 { - panic(noValidImplementation[T]()) + return nil, -1 } - // lastIndex of the last implementation + // index of the last implementation lastIndex := count - 1 - lastRegisteredResolver := resolvers[lastIndex] - service, ctx := lastRegisteredResolver.resolveService(ctx, typeId, lastIndex) + return resolvers[lastIndex], lastIndex +} + +// Get Retrieves an instance based on type and key (panics if no valid implementations) +func Get[T any](ctx context.Context, key ...KeyStringer) (T, context.Context) { + pointerTypeName := getPointerTypeName[T]() + typeID := getTypeID(pointerTypeName, key) + lastRegisteredResolver, lastIndex := getLastRegisteredResolver(typeID) + if lastRegisteredResolver == nil { //not found, T is an alias + + lock.RLock() + implementations, implExists := aliases[pointerTypeName] + lock.RUnlock() + + if !implExists { + panic(noValidImplementation[T]()) + } + count := len(implementations) + if count == 0 { + panic(noValidImplementation[T]()) + } + for i := count - 1; i >= 0; i-- { + impl := implementations[i] + typeID = getTypeID(impl, key) + lastRegisteredResolver, lastIndex = getLastRegisteredResolver(typeID) + if lastRegisteredResolver != nil { + break + } + } + } + if lastRegisteredResolver == nil { + panic(noValidImplementation[T]()) + } + service, ctx := lastRegisteredResolver.resolveService(ctx, typeID, lastIndex) return service.(T), ctx } // GetList Retrieves a list of instances based on type and key func GetList[T any](ctx context.Context, key ...KeyStringer) ([]T, context.Context) { - // generate type identifier - typeId := typeIdentifier[T](key) + inputPointerTypeName := getPointerTypeName[T]() - // try to get service resolver from container lock.RLock() - resolvers, resolverExists := container[typeId] + pointerTypeNames, implExists := aliases[inputPointerTypeName] lock.RUnlock() - if !resolverExists { - return make([]T, 0), nil + if implExists { + pointerTypeNames = append(pointerTypeNames, inputPointerTypeName) + } else { + pointerTypeNames = []pointerTypeName{inputPointerTypeName} } - count := len(resolvers) + servicesArray := []T{} - if count == 0 { - return make([]T, 0), nil - } + for i := 0; i < len(pointerTypeNames); i++ { + pointerTypeName := pointerTypeNames[i] + // generate type identifier + typeID := getTypeID(pointerTypeName, key) + + // try to get service resolver from container + lock.RLock() + resolvers, resolverExists := container[typeID] + lock.RUnlock() - servicesArray := make([]T, count) + if !resolverExists { + continue + } - for index := 0; index < count; index++ { - resolver := resolvers[index] - service, newCtx := resolver.resolveService(ctx, typeId, index) - servicesArray[index] = service.(T) - ctx = newCtx + for index := 0; index < len(resolvers); index++ { + resolver := resolvers[index] + service, newCtx := resolver.resolveService(ctx, typeID, index) + servicesArray = append(servicesArray, service.(T)) + ctx = newCtx + } } return servicesArray, ctx diff --git a/ore.go b/ore.go index 16be60a..f07f43a 100644 --- a/ore.go +++ b/ore.go @@ -10,6 +10,9 @@ var ( lock = &sync.RWMutex{} isBuilt = false container = map[typeID][]serviceResolver{} + + //map the alias type (usually an interface) to the original types (usually implementations of the interface) + aliases = map[pointerTypeName][]pointerTypeName{} ) type Creator[T any] interface { @@ -17,7 +20,7 @@ type Creator[T any] interface { } // Generates a unique identifier for a service resolver based on type and key(s) -func getTypeId(pointerTypeName pointerTypeName, key []KeyStringer) typeID { +func getTypeID(pointerTypeName pointerTypeName, key []KeyStringer) typeID { for _, stringer := range key { if stringer == nil { panic(nilKey) @@ -30,7 +33,7 @@ func getTypeId(pointerTypeName pointerTypeName, key []KeyStringer) typeID { // Generates a unique identifier for a service resolver based on type and key(s) func typeIdentifier[T any](key []KeyStringer) typeID { - return getTypeId(getPointerTypeName[T](), key) + return getTypeID(getPointerTypeName[T](), key) } // Appends a service resolver to the container with type and key @@ -52,6 +55,22 @@ func replaceServiceResolver(typeId typeID, index int, resolver serviceResolver) lock.Unlock() } +func appendToAliases[TInterface, TImpl any]() { + originalType := getPointerTypeName[TImpl]() + aliasType := getPointerTypeName[TInterface]() + if originalType == aliasType { + return + } + lock.Lock() + for _, ot := range aliases[aliasType] { + if ot == originalType { + return //already registered + } + } + aliases[aliasType] = append(aliases[aliasType], originalType) + lock.Unlock() +} + func Build() { if isBuilt { panic(alreadyBuilt) diff --git a/registrars.go b/registrars.go index df52ba6..7462a71 100644 --- a/registrars.go +++ b/registrars.go @@ -1,5 +1,10 @@ package ore +import ( + "fmt" + "reflect" +) + // RegisterLazyCreator Registers a lazily initialized value using a `Creator[T]` interface func RegisterLazyCreator[T any](lifetime Lifetime, creator Creator[T], key ...KeyStringer) { if creator == nil { @@ -38,3 +43,16 @@ func RegisterLazyFunc[T any](lifetime Lifetime, initializer Initializer[T], key } appendToContainer[T](e, key) } + +// RegisterAlias Registers an interface type to a concrete implementation. +// Allowing you to register the concrete implementation to the container and later get the interface from it. +func RegisterAlias[TInterface, TImpl any]() { + interfaceType := reflect.TypeFor[TInterface]() + implType := reflect.TypeFor[TImpl]() + + if !implType.Implements(interfaceType) { + panic(fmt.Errorf("%s does not implements %s", implType, interfaceType)) + } + + appendToAliases[TInterface, TImpl]() +} diff --git a/utils.go b/utils.go index 3c5c60f..20f382f 100644 --- a/utils.go +++ b/utils.go @@ -13,6 +13,7 @@ func isNil[T comparable](impl T) bool { func clearAll() { container = make(map[typeID][]serviceResolver) + aliases = make(map[pointerTypeName][]pointerTypeName) isBuilt = false } From 9d96656f2d9f4eeb4f45d7841aa02405e8b41bde Mon Sep 17 00:00:00 2001 From: Phu-Hiep DUONG Date: Sat, 2 Nov 2024 00:33:27 +0100 Subject: [PATCH 3/6] add more alias test --- alias_test.go | 144 ++++++++++++++++++++++++++++--- examples/aliasdemo/alias_test.go | 86 +++++++----------- examples/{ => simple}/main.go | 3 +- examples/{ => simple}/service.go | 2 +- go.mod | 8 ++ go.sum | 9 ++ internal/models/person.go | 12 +++ test.go => utils_test.go | 0 8 files changed, 195 insertions(+), 69 deletions(-) rename examples/{ => simple}/main.go (98%) rename examples/{ => simple}/service.go (97%) create mode 100644 go.sum create mode 100644 internal/models/person.go rename test.go => utils_test.go (100%) diff --git a/alias_test.go b/alias_test.go index dbe0f77..d0dcb87 100644 --- a/alias_test.go +++ b/alias_test.go @@ -3,9 +3,137 @@ package ore import ( "context" "testing" + + m "github.com/firasdarwish/ore/internal/models" + "github.com/stretchr/testify/assert" ) -func TestGetWithAlias(t *testing.T) { +func TestAliasResolverConflict(t *testing.T) { + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (m.IPerson, context.Context) { + return &m.Trader{Name: "Peter Singleton"}, ctx + }) + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.Broker, context.Context) { + return &m.Broker{Name: "Mary Transient"}, ctx + }) + + RegisterAlias[m.IPerson, *m.Trader]() + RegisterAlias[m.IPerson, *m.Broker]() + + ctx := context.Background() + + //The last registered IPerson is "Mary Transient", it would normally takes precedence. + //However we registered a direct resolver for IPerson which is "Peter Singleton". + //So Ore won't treat IPerson as an alias and will resolve IPerson directly as "Peter Singleton" + person, ctx := Get[m.IPerson](ctx) + assert.Equal(t, person.(*m.Trader).Name, "Peter Singleton") + + //GetlList will return all possible IPerson whatever alias or from direct resolver. + personList, _ := GetList[m.IPerson](ctx) + assert.Equal(t, len(personList), 2) +} + +func TestAliasOfAliasIsNotAllow(t *testing.T) { + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.Trader, context.Context) { + return &m.Trader{Name: "Peter Singleton"}, ctx + }) + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.Broker, context.Context) { + return &m.Broker{Name: "Mary Transient"}, ctx + }) + + RegisterAlias[m.IPerson, *m.Trader]() + RegisterAlias[m.IPerson, *m.Broker]() + RegisterAlias[m.IHuman, m.IPerson]() //alias of alias + + assert.Panics(t, func() { + _, _ = Get[m.IHuman](context.Background()) + }, "implementation not found for type: IHuman") + + humans, _ := GetList[m.IHuman](context.Background()) + assert.Empty(t, humans) +} + +func TestAliasWithDifferentScope(t *testing.T) { + clearAll() + module := "TestGetInterfaceAliasWithDifferentScope" + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.Broker, context.Context) { + return &m.Broker{Name: "Transient"}, ctx + }, module) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.Broker, context.Context) { + return &m.Broker{Name: "Singleton"}, ctx + }, module) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.Broker, context.Context) { + return &m.Broker{Name: "Scoped"}, ctx + }, module) + RegisterAlias[m.IPerson, *m.Broker]() //link m.IPerson to *m.Broker + + ctx := context.Background() + + person, ctx := Get[m.IPerson](ctx, module) + assert.Equal(t, person.(*m.Broker).Name, "Scoped") + + personList, _ := GetList[m.IPerson](ctx, module) + assert.Equal(t, len(personList), 3) +} + +func TestAliasIsScopedByKeys(t *testing.T) { + clearAll() + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.Broker, context.Context) { + return &m.Broker{Name: "Peter1"}, ctx + }, "module1") + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.Broker, context.Context) { + return &m.Broker{Name: "John1"}, ctx + }, "module1") + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.Trader, context.Context) { + return &m.Trader{Name: "Mary1"}, ctx + }, "module1") + + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.Broker, context.Context) { + return &m.Broker{Name: "John2"}, ctx + }, "module2") + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.Trader, context.Context) { + return &m.Trader{Name: "Mary2"}, ctx + }, "module2") + + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.Trader, context.Context) { + return &m.Trader{Name: "Mary3"}, ctx + }, "module3") + + RegisterAlias[m.IPerson, *m.Trader]() //link m.IPerson to *m.Trader + RegisterAlias[m.IPerson, *m.Broker]() //link m.IPerson to *m.Broker + + ctx := context.Background() + + person1, ctx := Get[m.IPerson](ctx, "module1") // will return the m.Broker John + assert.Equal(t, person1.(*m.Broker).Name, "John1") + + personList1, ctx := GetList[m.IPerson](ctx, "module1") // will return all registered m.Broker and m.Trader + assert.Equal(t, len(personList1), 3) + + person2, ctx := Get[m.IPerson](ctx, "module2") // will return the m.Broker John + assert.Equal(t, person2.(*m.Broker).Name, "John2") + + personList2, ctx := GetList[m.IPerson](ctx, "module2") // will return all registered m.Broker and m.Trader + assert.Equal(t, len(personList2), 2) + + person3, ctx := Get[m.IPerson](ctx, "module3") // will return the m.Trader Mary + assert.Equal(t, person3.(*m.Trader).Name, "Mary3") + + personList3, ctx := GetList[m.IPerson](ctx, "module3") // will return all registered m.Broker and m.Trader + assert.Equal(t, len(personList3), 1) + + personListNoModule, _ := GetList[m.IPerson](ctx) // will return all registered m.Broker and m.Trader without keys + assert.Empty(t, personListNoModule) +} + +func TestInvalidAlias(t *testing.T) { + assert.Panics(t, func() { + RegisterAlias[error, *m.Broker]() //register a struct (Broker) that does not implement interface (error) + }, "Broker does not implements error") +} + +func TestGetGenericAlias(t *testing.T) { for _, registrationType := range types { clearAll() @@ -19,13 +147,11 @@ func TestGetWithAlias(t *testing.T) { c.Add(1) c.Add(1) - if got := c.GetCount(); got != 2 { - t.Errorf("got %v, expected %v", got, 2) - } + assert.Equal(t, uint(2), c.GetCount()) } } -func TestGetListWithAlias(t *testing.T) { +func TestGetListGenericAlias(t *testing.T) { for _, registrationType := range types { clearAll() @@ -38,17 +164,13 @@ func TestGetListWithAlias(t *testing.T) { RegisterAlias[someCounterGeneric[uint], *simpleCounterUint]() counters, _ := GetList[someCounterGeneric[uint]](context.Background()) - if got := len(counters); got != 3 { - t.Errorf("got %v, expected %v", got, 3) - } + assert.Equal(t, len(counters), 3) c := counters[1] c.Add(1) c.Add(1) - if got := c.GetCount(); got != 2 { - t.Errorf("got %v, expected %v", got, 2) - } + assert.Equal(t, uint(2), c.GetCount()) } } diff --git a/examples/aliasdemo/alias_test.go b/examples/aliasdemo/alias_test.go index e4a9ab8..fdc8b05 100644 --- a/examples/aliasdemo/alias_test.go +++ b/examples/aliasdemo/alias_test.go @@ -8,107 +8,81 @@ import ( ) func TestGetInterfaceAliasWithKeys(t *testing.T) { - ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Broker, context.Context) { - return &Broker{Name: "Peter1"}, ctx + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*broker, context.Context) { + return &broker{Name: "Peter1"}, ctx }, "module1") - ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Broker, context.Context) { - return &Broker{Name: "John1"}, ctx + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*broker, context.Context) { + return &broker{Name: "John1"}, ctx }, "module1") - ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Trader, context.Context) { - return &Trader{Name: "Mary1"}, ctx + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*trader, context.Context) { + return &trader{Name: "Mary1"}, ctx }, "module1") - ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Broker, context.Context) { - return &Broker{Name: "John2"}, ctx + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*broker, context.Context) { + return &broker{Name: "John2"}, ctx }, "module2") - ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Trader, context.Context) { - return &Trader{Name: "Mary2"}, ctx + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*trader, context.Context) { + return &trader{Name: "Mary2"}, ctx }, "module2") - ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Trader, context.Context) { - return &Trader{Name: "Mary3"}, ctx + ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*trader, context.Context) { + return &trader{Name: "Mary3"}, ctx }, "module3") - ore.RegisterAlias[IPerson, *Trader]() //link IPerson to *Trader - ore.RegisterAlias[IPerson, *Broker]() //link IPerson to *Broker + ore.RegisterAlias[iPerson, *trader]() //link IPerson to *Trader + ore.RegisterAlias[iPerson, *broker]() //link IPerson to *Broker ctx := context.Background() //no IPerson was registered to the container, but we can still `Get` it. //(1) IPerson is alias to both *Broker and *Trader. *Broker takes precedence because it's the last one linked to IPerson. //(2) multiple *Borker (Peter and John) are registered to the container, the last registered (John) takes precedence. - person1, ctx := ore.Get[IPerson](ctx, "module1") // will return the broker John + person1, ctx := ore.Get[iPerson](ctx, "module1") // will return the broker John switch person := person1.(type) { - case *Broker: + case *broker: if person.Name != "John1" { t.Errorf("got %v, expected %v", person.Name, "John1") } - case *Trader: + case *trader: t.Errorf("got Trader, expected Broker") } - personList1, ctx := ore.GetList[IPerson](ctx, "module1") // will return all registered broker and trader + personList1, ctx := ore.GetList[iPerson](ctx, "module1") // will return all registered broker and trader if len(personList1) != 3 { t.Errorf("got %v, expected %v", len(personList1), 3) } - person2, ctx := ore.Get[IPerson](ctx, "module2") // will return the broker John - if person2.(*Broker).Name != "John2" { - t.Errorf("got %v, expected %v", person2.(*Broker).Name, "John2") + person2, ctx := ore.Get[iPerson](ctx, "module2") // will return the broker John + if person2.(*broker).Name != "John2" { + t.Errorf("got %v, expected %v", person2.(*broker).Name, "John2") } - personList2, ctx := ore.GetList[IPerson](ctx, "module2") // will return all registered broker and trader + personList2, ctx := ore.GetList[iPerson](ctx, "module2") // will return all registered broker and trader if len(personList2) != 2 { t.Errorf("got %v, expected %v", len(personList2), 2) } - person3, ctx := ore.Get[IPerson](ctx, "module3") // will return the trader Mary - if person3.(*Trader).Name != "Mary3" { - t.Errorf("got %v, expected %v", person3.(*Trader).Name, "Mary3") + person3, ctx := ore.Get[iPerson](ctx, "module3") // will return the trader Mary + if person3.(*trader).Name != "Mary3" { + t.Errorf("got %v, expected %v", person3.(*trader).Name, "Mary3") } - personList3, ctx := ore.GetList[IPerson](ctx, "module3") // will return all registered broker and trader + personList3, ctx := ore.GetList[iPerson](ctx, "module3") // will return all registered broker and trader if len(personList3) != 1 { t.Errorf("got %v, expected %v", len(personList3), 1) } - personListNoModule, _ := ore.GetList[IPerson](ctx) // will return all registered broker and trader without keys + personListNoModule, _ := ore.GetList[iPerson](ctx) // will return all registered broker and trader without keys if len(personListNoModule) != 0 { t.Errorf("got %v, expected %v", len(personListNoModule), 0) } } -// func TestGetInterfaceAliasWithDifferentScope(t *testing.T) { -// module := "TestGetInterfaceAliasWithDifferentScope" -// ore.RegisterLazyFunc(ore.Transient, func(ctx context.Context) (*Broker, context.Context) { -// return &Broker{Name: "Transient"}, ctx -// }, module) -// ore.RegisterLazyFunc(ore.Singleton, func(ctx context.Context) (*Broker, context.Context) { -// return &Broker{Name: "Singleton"}, ctx -// }, module) -// ore.RegisterLazyFunc(ore.Scoped, func(ctx context.Context) (*Broker, context.Context) { -// return &Broker{Name: "Scoped"}, ctx -// }, module) -// ore.RegisterAlias[IPerson, *Broker]() //link IPerson to *Broker - -// ctx := context.Background() - -// person, ctx := ore.Get[IPerson](ctx, module) -// if person.(*Broker).Name != "Scoped" { -// t.Errorf("got %v, expected %v", person.(*Broker).Name, "Scoped") -// } - -// personList, _ := ore.GetList[IPerson](ctx, module) -// if len(personList) != 2 { -// t.Errorf("got %v, expected %v", len(personList), 2) -// } -// } - -type IPerson interface{} -type Broker struct { +type iPerson interface{} +type broker struct { Name string } //implements IPerson -type Trader struct { +type trader struct { Name string } //implements IPerson diff --git a/examples/main.go b/examples/simple/main.go similarity index 98% rename from examples/main.go rename to examples/simple/main.go index 99d7bf6..41ff1e6 100644 --- a/examples/main.go +++ b/examples/simple/main.go @@ -1,8 +1,9 @@ -package main +package simple import ( "context" "fmt" + "github.com/firasdarwish/ore" ) diff --git a/examples/service.go b/examples/simple/service.go similarity index 97% rename from examples/service.go rename to examples/simple/service.go index c05fbd0..3516ecd 100644 --- a/examples/service.go +++ b/examples/simple/service.go @@ -1,4 +1,4 @@ -package main +package simple import ( "context" diff --git a/go.mod b/go.mod index 50f9ee1..13c5b2e 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,11 @@ module github.com/firasdarwish/ore go 1.22 + +require github.com/stretchr/testify v1.9.0 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e20fa14 --- /dev/null +++ b/go.sum @@ -0,0 +1,9 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/models/person.go b/internal/models/person.go new file mode 100644 index 0000000..0eeb105 --- /dev/null +++ b/internal/models/person.go @@ -0,0 +1,12 @@ +package models + +type IPerson interface{} +type Broker struct { + Name string +} //implements IPerson + +type Trader struct { + Name string +} //implements IPerson + +type IHuman interface{} diff --git a/test.go b/utils_test.go similarity index 100% rename from test.go rename to utils_test.go From 87caf113f2a139b186b65f1df0838de1aafdc2f3 Mon Sep 17 00:00:00 2001 From: Phu-Hiep DUONG Date: Sat, 2 Nov 2024 17:22:14 +0100 Subject: [PATCH 4/6] refactor: use struct instead of string for ID object 1) create struct is faster than concatenate string 2) comparing struct equality is also faster than string 3) struct is more flexible, later we can easily access to different part/member of the struct instead of using string "contain" or regular expression --- ore.go | 9 +++------ ore_test.go | 10 ++++------ serviceResolver.go | 13 ++++++------- utils.go | 14 ++++++++------ 4 files changed, 21 insertions(+), 25 deletions(-) diff --git a/ore.go b/ore.go index f07f43a..0048464 100644 --- a/ore.go +++ b/ore.go @@ -2,7 +2,6 @@ package ore import ( "context" - "fmt" "sync" ) @@ -26,9 +25,7 @@ func getTypeID(pointerTypeName pointerTypeName, key []KeyStringer) typeID { panic(nilKey) } } - customKey := oreKey(key) - tt := fmt.Sprintf("%s:%v", pointerTypeName, customKey) - return typeID(tt) + return typeID{pointerTypeName, oreKey(key)} } // Generates a unique identifier for a service resolver based on type and key(s) @@ -42,10 +39,10 @@ func appendToContainer[T any](resolver serviceResolver, key []KeyStringer) { panic(alreadyBuiltCannotAdd) } - typeId := typeIdentifier[T](key) + typeID := typeIdentifier[T](key) lock.Lock() - container[typeId] = append(container[typeId], resolver) + container[typeID] = append(container[typeID], resolver) lock.Unlock() } diff --git a/ore_test.go b/ore_test.go index 0130668..2b68d49 100644 --- a/ore_test.go +++ b/ore_test.go @@ -2,6 +2,8 @@ package ore import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestBuild(t *testing.T) { @@ -19,13 +21,9 @@ type A2 struct{} func TestTypeIdentifier(t *testing.T) { id1 := typeIdentifier[*A1]([]KeyStringer{}) id2 := typeIdentifier[*A2]([]KeyStringer{}) - if id1 == id2 { - t.Errorf("got the same identifier value %v, expected different values", id1) - } + assert.NotEqual(t, id1, id2) id3 := typeIdentifier[*A1]([]KeyStringer{"a", "b"}) id4 := typeIdentifier[*A1]([]KeyStringer{"a", "b"}) - if id3 != id4 { - t.Errorf("got %v, expected %v", id3, id4) - } + assert.Equal(t, id3, id4) } diff --git a/serviceResolver.go b/serviceResolver.go index 30adabc..e053ca1 100644 --- a/serviceResolver.go +++ b/serviceResolver.go @@ -20,18 +20,17 @@ type serviceResolverImpl[T any] struct { //make sure that the `serviceResolverImpl` struct implements the `serviceResolver` interface var _ serviceResolver = serviceResolverImpl[any]{} -func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeId typeID, index int) (any, context.Context) { - - ctxTidVal := getContextValueID(typeId, index) - +func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeID typeID, index int) (any, context.Context) { // try get concrete implementation if this.lifetime == Singleton && this.singletonConcrete != nil { return *this.singletonConcrete, ctx } + ctxKey := contextKey{typeID, index} + // try get concrete from context scope if this.lifetime == Scoped { - scopedConcrete, ok := ctx.Value(ctxTidVal).(T) + scopedConcrete, ok := ctx.Value(ctxKey).(T) if ok { return scopedConcrete, ctx } @@ -49,14 +48,14 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeId ty // if scoped, attach to the current context if this.lifetime == Scoped { - ctx = context.WithValue(ctx, ctxTidVal, con) + ctx = context.WithValue(ctx, ctxKey, con) } // if was lazily-created, then attach the newly-created concrete implementation // to the service resolver if this.lifetime == Singleton { this.singletonConcrete = &con - replaceServiceResolver(typeId, index, this) + replaceServiceResolver(typeID, index, this) return con, ctx } diff --git a/utils.go b/utils.go index 20f382f..187367a 100644 --- a/utils.go +++ b/utils.go @@ -2,8 +2,14 @@ package ore import "fmt" -type contextValueID string -type typeID string +type contextKey struct { + typeID + index int +} +type typeID struct { + pointerTypeName pointerTypeName + oreKey string +} type pointerTypeName string func isNil[T comparable](impl T) bool { @@ -17,10 +23,6 @@ func clearAll() { isBuilt = false } -func getContextValueID(typeId typeID, index int) contextValueID { - return contextValueID(fmt.Sprintln(typeId, index)) -} - // Get type name of *T. // it allocates less memory and is faster than `reflect.TypeFor[*T]().String()` func getPointerTypeName[T any]() pointerTypeName { From d9ef42b65a8939a7f194d76ae13f703a45ba3465 Mon Sep 17 00:00:00 2001 From: Phu-Hiep DUONG Date: Sat, 2 Nov 2024 13:29:44 +0100 Subject: [PATCH 5/6] add support for graceful service or scope shutdown * Add GetResolvedSingletons func * Add GetResolvedScopedInstances func --- README.md | 71 ++++++++++++++++++++++++++ examples/shutdownerdemo/main.go | 75 +++++++++++++++++++++++++++ get_test.go | 90 +++++++++++++++++++++++++++++++++ getters.go | 53 +++++++++++++++++++ internal/models/disposable.go | 52 +++++++++++++++++++ ore.go | 9 ++++ serviceResolver.go | 20 ++++++++ 7 files changed, 370 insertions(+) create mode 100644 examples/shutdownerdemo/main.go create mode 100644 internal/models/disposable.go diff --git a/README.md b/README.md index 8e72cc1..89f9cf7 100644 --- a/README.md +++ b/README.md @@ -321,6 +321,77 @@ func TestGetInterfaceAlias(t *testing.T) { Alias is also scoped by key. When you "Get" an alias with keys for eg: `ore.Get[IPerson](ctx, "module1")` then Ore would return only Services registered under this key ("module1") and panic if no service found. +### Graceful application termination + +On application termination, you want to call `Shutdown()` on all the "Singletons" objects which have been created during the application life time. + +Here how Ore can help you: + +```go +// Assuming that the Application provides certain instances with Singleton lifetime. +// Some of these singletons implement a custom `Shutdowner` interface (defined within the application) +type Shutdowner interface { + Shutdown() +} +ore.RegisterEagerSingleton(&Logger{}) //*Logger implements Shutdowner +ore.RegisterEagerSingleton(&SomeRepository{}) //*SomeRepository implements Shutdowner +ore.RegisterEagerSingleton(&SomeService{}, "some_module") //*SomeService implements Shutdowner + +//On application termination, Ore can help to retreive all the singletons implementation of the `Shutdowner` interface. +//There might be other `Shutdowner`'s implementation which were lazily registered but have never been created (a.k.a invoked). +//Ore will ignore them, and return only the concrete instances which can be Shutdown() +shutdowables := ore.GetResolvedSingletons[Shutdowner]() + +//Now we can Shutdown() them all and gracefully terminate our application. +for _, instance := range disposables { + instance.Shutdown() +} +``` + +In resume, the `ore.GetResolvedSingletons[TInterface]()` function returns a list of Singleton implementations of the `[TInterface]`. + +- It returns only the instances which had been invoked (a.k.a resolved). +- All the implementations including "keyed" one will be returned. + +### Graceful context termination + +On context termination, you want to call `Dispose()` on all the "Scoped" objects which have been created during the context life time. + +Here how Ore can help you: + +```go +//Assuming that your Application provides certain instances with Scoped lifetime. +//Some of them implements a "Disposer" interface (defined winthin the application). +type Disposer interface { + Dispose() +} +ore.RegisterLazyCreator(ore.Scoped, &SomeDisposableService{}) //*SomeDisposableService implements Disposer + +//a new request arrive +ctx, cancel := context.WithCancel(context.Background()) + +//start a go routine that will clean up resources when the context is canceled +go func() { + <-ctx.Done() // Wait for the context to be canceled + // Perform your cleanup tasks here + disposables := ore.GetResolvedScopedInstances[Disposer](ctx) + for _, d := range disposables { + _ = d.Dispose(ctx) + } +}() +... +ore.Get[*SomeDisposableService](ctx) //invoke some scoped services +cancel() //cancel the ctx + +``` + +The `ore.GetResolvedScopedInstances[TInterface](context)` function returns a list of implementations of the `[TInterface]` which are Scoped in the input context: + +- It returns only the instances which had been invoked (a.k.a resolved) during the context life time. +- All the implementations including "keyed" one will be returned. + +This function would help us to gracefully terminate the context. Example: + ## More Complex Example ```go diff --git a/examples/shutdownerdemo/main.go b/examples/shutdownerdemo/main.go new file mode 100644 index 0000000..7b6fb1f --- /dev/null +++ b/examples/shutdownerdemo/main.go @@ -0,0 +1,75 @@ +package main + +import ( + "context" + "log" + "sync" + + "github.com/firasdarwish/ore" +) + +type shutdowner interface { + Shutdown() error +} + +type disposer interface { + Dispose(ctx context.Context) error +} + +type myGlobalRepo struct { +} + +var _ shutdowner = (*myGlobalRepo)(nil) + +func (*myGlobalRepo) Shutdown() error { + log.Println("shutdown globalRepo") + return nil +} + +type myScopedRepo struct { +} + +var _ disposer = (*myScopedRepo)(nil) + +func (*myScopedRepo) Dispose(ctx context.Context) error { + log.Println("dispose scopedRepo") + return nil +} + +func (*myScopedRepo) New(ctx context.Context) (*myScopedRepo, context.Context) { + return &myScopedRepo{}, ctx +} + +func main() { + ore.RegisterEagerSingleton[*myGlobalRepo](&myGlobalRepo{}) + ore.RegisterLazyCreator(ore.Scoped, &myScopedRepo{}) + + wg := sync.WaitGroup{} + wg.Add(1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // Ensure context is canceled when main exits + + //start a go routine that will clean up resources when the context is canceled + go func() { + <-ctx.Done() // Wait for the context to be canceled + // Perform your cleanup tasks here + disposables := ore.GetResolvedScopedInstances[disposer](ctx) + for _, d := range disposables { + _ = d.Dispose(ctx) + } + wg.Done() + }() + + //invoke the scoped service + _, ctx = ore.Get[*myScopedRepo](ctx) + + //cancel the context will trigger the cleanup + cancel() + wg.Wait() // Wait for the goroutine to finish + + shutdownables := ore.GetResolvedSingletons[shutdowner]() + for _, s := range shutdownables { + _ = s.Shutdown() + } +} diff --git a/get_test.go b/get_test.go index 5ae94a7..b78b947 100644 --- a/get_test.go +++ b/get_test.go @@ -4,6 +4,9 @@ import ( "context" "fmt" "testing" + + m "github.com/firasdarwish/ore/internal/models" + "github.com/stretchr/testify/assert" ) func TestGet(t *testing.T) { @@ -69,3 +72,90 @@ func TestGetKeyed(t *testing.T) { } } } + +func TestGetResolvedSingletons(t *testing.T) { + //Arrange + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + return &m.DisposableService1{Name: "A1"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + return &m.DisposableService1{Name: "A2"}, ctx + }) + RegisterEagerSingleton(&m.DisposableService2{Name: "E1"}) + RegisterEagerSingleton(&m.DisposableService2{Name: "E2"}) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "S1"}, ctx + }) + RegisterLazyFunc(Transient, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "S2"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService4, context.Context) { + return &m.DisposableService4{Name: "X1"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService4, context.Context) { + return &m.DisposableService4{Name: "X2"}, ctx + }, "somekey") + + ctx := context.Background() + //Act + disposables := GetResolvedSingletons[m.Disposer]() //E1, E2 + assert.Equal(t, 2, len(disposables)) + + //invoke A1, A2 + _, ctx = GetList[*m.DisposableService1](ctx) //A1, A2 + + //Act + disposables = GetResolvedSingletons[m.Disposer]() //E1, E2, A1, A2 + assert.Equal(t, 4, len(disposables)) + + //invoke S1, S2, X1 + RegisterAlias[fmt.Stringer, *m.DisposableService3]() + RegisterAlias[fmt.Stringer, *m.DisposableService4]() + _, ctx = GetList[fmt.Stringer](ctx) //S1, S2, X1 + + //Act + //because S1, S2 are not singleton, so they won't be returned, only X1 will be returned in addition + disposables = GetResolvedSingletons[m.Disposer]() //E1, E2, A1, A2, X1 + assert.Equal(t, 5, len(disposables)) + + //invoke X2 in "somekey" scope + _, _ = GetList[fmt.Stringer](ctx, "somekey") + + //Act + //all invoked singleton would be returned whatever keys they are registered with + disposables = GetResolvedSingletons[m.Disposer]() //E1, E2, A1, A2, X1, X2 + assert.Equal(t, 6, len(disposables)) +} + +func TestGetResolvedScopedInstances(t *testing.T) { + clearAll() + RegisterEagerSingleton(&m.DisposableService1{Name: "S1"}) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService1, context.Context) { + return &m.DisposableService1{Name: "S2"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "T1"}, ctx + }, "module1") + + ctx := context.Background() + + //Act + disposables := GetResolvedScopedInstances[m.Disposer](ctx) //empty + assert.Empty(t, disposables) + + //invoke S2 + _, ctx = GetList[*m.DisposableService1](ctx) + + //Act + disposables = GetResolvedScopedInstances[m.Disposer](ctx) //S2 + assert.Equal(t, 1, len(disposables)) + assert.Equal(t, "S2", disposables[0].String()) + + //invoke the keyed service T1 + _, ctx = GetList[*m.DisposableService2](ctx, "module1") + + //Act + disposables = GetResolvedScopedInstances[m.Disposer](ctx) //S2, T1 + assert.Equal(t, 2, len(disposables)) +} diff --git a/getters.go b/getters.go index b1be99a..2863950 100644 --- a/getters.go +++ b/getters.go @@ -99,3 +99,56 @@ func GetList[T any](ctx context.Context, key ...KeyStringer) ([]T, context.Conte return servicesArray, ctx } + +// GetResolvedSingletons retrieves a list of Singleton instances that implement the [TInterface]. +// It would return only the instances which had been resolved. Other lazy implementations which have never been invoked will not be returned. +// This function is useful for cleaning operations. +// +// Example: +// +// disposableSingletons := ore.GetResolvedSingletons[Disposer]() +// for _, disposable := range disposableSingletons { +// disposable.Dispose() +// } +func GetResolvedSingletons[TInterface any]() []TInterface { + lock.RLock() + defer lock.RUnlock() + + result := []TInterface{} + for _, resolvers := range container { + for _, resolver := range resolvers { + invokedValue, isInvokedSingleton := resolver.getInvokedSingleton() + if isInvokedSingleton { + if instance, ok := invokedValue.(TInterface); ok { + result = append(result, instance) + } + } + } + } + return result +} + +// GetResolvedScopedInstances retrieves a list of Scoped instances that implement the [TInterface]. +// It would return only the instances which had been resolved. Other lazy implementations which have never been invoked will not be returned. +// This function is useful for cleaning operations. +// +// Example: +// +// disposableInstances := ore.GetResolvedScopedInstances[Disposer](ctx) +// for _, disposable := range disposableInstances { +// disposable.Dispose() +// } +func GetResolvedScopedInstances[TInterface any](ctx context.Context) []TInterface { + contextKeyRepository, ok := ctx.Value(contextKeysRepositoryID).(contextKeysRepository) + if !ok { + return []TInterface{} + } + result := []TInterface{} + for _, contextKey := range contextKeyRepository { + invokedValue := ctx.Value(contextKey) + if instance, ok := invokedValue.(TInterface); ok { + result = append(result, instance) + } + } + return result +} diff --git a/internal/models/disposable.go b/internal/models/disposable.go new file mode 100644 index 0000000..31eb215 --- /dev/null +++ b/internal/models/disposable.go @@ -0,0 +1,52 @@ +package models + +import "fmt" + +type Disposer interface { + fmt.Stringer + Dispose() +} + +var _ Disposer = (*DisposableService1)(nil) + +type DisposableService1 struct { + Name string +} + +func (*DisposableService1) Dispose() {} +func (this *DisposableService1) String() string { + return this.Name +} + +var _ Disposer = (*DisposableService2)(nil) + +type DisposableService2 struct { + Name string +} + +func (*DisposableService2) Dispose() {} +func (this *DisposableService2) String() string { + return this.Name +} + +var _ Disposer = (*DisposableService3)(nil) + +type DisposableService3 struct { + Name string +} + +func (*DisposableService3) Dispose() {} +func (this *DisposableService3) String() string { + return this.Name +} + +var _ Disposer = (*DisposableService4)(nil) + +type DisposableService4 struct { + Name string +} + +func (*DisposableService4) Dispose() {} +func (this *DisposableService4) String() string { + return this.Name +} diff --git a/ore.go b/ore.go index 0048464..ca3823a 100644 --- a/ore.go +++ b/ore.go @@ -12,8 +12,17 @@ var ( //map the alias type (usually an interface) to the original types (usually implementations of the interface) aliases = map[pointerTypeName][]pointerTypeName{} + + //this is a special context key. The value of this key is the collection of other context keys stored in the context. + contextKeysRepositoryID = contextKey{ + typeID{ + pointerTypeName: "", + oreKey: "The context keys collection of Concrete Scoped Instances", + }, -1} ) +type contextKeysRepository = []contextKey + type Creator[T any] interface { New(ctx context.Context) (T, context.Context) } diff --git a/serviceResolver.go b/serviceResolver.go index e053ca1..05375e2 100644 --- a/serviceResolver.go +++ b/serviceResolver.go @@ -8,6 +8,8 @@ type ( type serviceResolver interface { resolveService(ctx context.Context, typeId typeID, index int) (any, context.Context) + //return the invoked singleton value, or false if the resolver is not a singleton or has not been invoked + getInvokedSingleton() (con any, isInvokedSingleton bool) } type serviceResolverImpl[T any] struct { @@ -49,6 +51,7 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeID ty // if scoped, attach to the current context if this.lifetime == Scoped { ctx = context.WithValue(ctx, ctxKey, con) + ctx = addToContextKeysRepository(ctx, ctxKey) } // if was lazily-created, then attach the newly-created concrete implementation @@ -61,3 +64,20 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeID ty return con, ctx } + +func (this serviceResolverImpl[T]) getInvokedSingleton() (con any, isInvokedSingleton bool) { + if this.lifetime == Singleton && this.singletonConcrete != nil { + return *this.singletonConcrete, true + } + return nil, false +} + +func addToContextKeysRepository(ctx context.Context, newContextKey contextKey) context.Context { + repository, ok := ctx.Value(contextKeysRepositoryID).(contextKeysRepository) + if ok { + repository = append(repository, newContextKey) + } else { + repository = contextKeysRepository{newContextKey} + } + return context.WithValue(ctx, contextKeysRepositoryID, repository) +} From 6c3ed6ea23caf91456b8de0a82297d44953171f4 Mon Sep 17 00:00:00 2001 From: Phu-Hiep DUONG Date: Mon, 4 Nov 2024 07:41:00 +0100 Subject: [PATCH 6/6] respect the invocation order for gracefully shutdown --- README.md | 6 +++-- concrete.go | 13 ++++++++++ get_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++ getters.go | 51 ++++++++++++++++++++++++++---------- ore.go | 4 +-- registrars.go | 9 +++++-- serviceResolver.go | 37 ++++++++++++++++---------- 7 files changed, 152 insertions(+), 33 deletions(-) create mode 100644 concrete.go diff --git a/README.md b/README.md index 89f9cf7..2bd70cc 100644 --- a/README.md +++ b/README.md @@ -343,6 +343,7 @@ ore.RegisterEagerSingleton(&SomeService{}, "some_module") //*SomeService impleme shutdowables := ore.GetResolvedSingletons[Shutdowner]() //Now we can Shutdown() them all and gracefully terminate our application. +//The most recently created instance will be Shutdown() first for _, instance := range disposables { instance.Shutdown() } @@ -352,6 +353,7 @@ In resume, the `ore.GetResolvedSingletons[TInterface]()` function returns a list - It returns only the instances which had been invoked (a.k.a resolved). - All the implementations including "keyed" one will be returned. +- The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the most recently created one. ### Graceful context termination @@ -375,6 +377,7 @@ go func() { <-ctx.Done() // Wait for the context to be canceled // Perform your cleanup tasks here disposables := ore.GetResolvedScopedInstances[Disposer](ctx) + //The most recently created instance will be Dispose() first for _, d := range disposables { _ = d.Dispose(ctx) } @@ -389,8 +392,7 @@ The `ore.GetResolvedScopedInstances[TInterface](context)` function returns a lis - It returns only the instances which had been invoked (a.k.a resolved) during the context life time. - All the implementations including "keyed" one will be returned. - -This function would help us to gracefully terminate the context. Example: +- The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the most recently created one. ## More Complex Example diff --git a/concrete.go b/concrete.go new file mode 100644 index 0000000..29200d8 --- /dev/null +++ b/concrete.go @@ -0,0 +1,13 @@ +package ore + +import "time" + +// concrete holds the resolved instance value and other metadata +type concrete struct { + // the value implementation + value any + // the creation time + createdAt time.Time + // the lifetime of this concrete + lifetime Lifetime +} diff --git a/get_test.go b/get_test.go index b78b947..9ef6e52 100644 --- a/get_test.go +++ b/get_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" m "github.com/firasdarwish/ore/internal/models" "github.com/stretchr/testify/assert" @@ -128,6 +129,38 @@ func TestGetResolvedSingletons(t *testing.T) { assert.Equal(t, 6, len(disposables)) } +func TestGetResolvedSingletonsOrder(t *testing.T) { + //Arrange + clearAll() + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService1, context.Context) { + return &m.DisposableService1{Name: "A"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "B"}, ctx + }) + RegisterLazyFunc(Singleton, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "C"}, ctx + }) + + ctx := context.Background() + + //invocation order: [A,C,B] + _, ctx = Get[*m.DisposableService1](ctx) + time.Sleep(1 * time.Millisecond) + _, ctx = Get[*m.DisposableService3](ctx) + time.Sleep(1 * time.Millisecond) + _, _ = Get[*m.DisposableService2](ctx) + + //Act + disposables := GetResolvedSingletons[m.Disposer]() //B, A + + //Assert that the order is [B,C,A], the most recent invocation would be returned first + assert.Equal(t, 3, len(disposables)) + assert.Equal(t, "B", disposables[0].String()) + assert.Equal(t, "C", disposables[1].String()) + assert.Equal(t, "A", disposables[2].String()) +} + func TestGetResolvedScopedInstances(t *testing.T) { clearAll() RegisterEagerSingleton(&m.DisposableService1{Name: "S1"}) @@ -159,3 +192,35 @@ func TestGetResolvedScopedInstances(t *testing.T) { disposables = GetResolvedScopedInstances[m.Disposer](ctx) //S2, T1 assert.Equal(t, 2, len(disposables)) } + +func TestGetResolvedScopedInstancesOrder(t *testing.T) { + //Arrange + clearAll() + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService1, context.Context) { + return &m.DisposableService1{Name: "A"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService2, context.Context) { + return &m.DisposableService2{Name: "B"}, ctx + }) + RegisterLazyFunc(Scoped, func(ctx context.Context) (*m.DisposableService3, context.Context) { + return &m.DisposableService3{Name: "C"}, ctx + }) + + ctx := context.Background() + + //invocation order: [A,C,B] + _, ctx = Get[*m.DisposableService1](ctx) + time.Sleep(1 * time.Millisecond) + _, ctx = Get[*m.DisposableService3](ctx) + time.Sleep(1 * time.Millisecond) + _, ctx = Get[*m.DisposableService2](ctx) + + //Act + disposables := GetResolvedScopedInstances[m.Disposer](ctx) //B, A + + //Assert that the order is [B,C,A], the most recent invocation would be returned first + assert.Equal(t, 3, len(disposables)) + assert.Equal(t, "B", disposables[0].String()) + assert.Equal(t, "C", disposables[1].String()) + assert.Equal(t, "A", disposables[2].String()) +} diff --git a/getters.go b/getters.go index 2863950..1ab5c3a 100644 --- a/getters.go +++ b/getters.go @@ -2,6 +2,7 @@ package ore import ( "context" + "sort" ) func getLastRegisteredResolver(typeId typeID) (serviceResolver, int) { @@ -55,8 +56,8 @@ func Get[T any](ctx context.Context, key ...KeyStringer) (T, context.Context) { if lastRegisteredResolver == nil { panic(noValidImplementation[T]()) } - service, ctx := lastRegisteredResolver.resolveService(ctx, typeID, lastIndex) - return service.(T), ctx + con, ctx := lastRegisteredResolver.resolveService(ctx, typeID, lastIndex) + return con.value.(T), ctx } // GetList Retrieves a list of instances based on type and key @@ -91,8 +92,8 @@ func GetList[T any](ctx context.Context, key ...KeyStringer) ([]T, context.Conte for index := 0; index < len(resolvers); index++ { resolver := resolvers[index] - service, newCtx := resolver.resolveService(ctx, typeID, index) - servicesArray = append(servicesArray, service.(T)) + con, newCtx := resolver.resolveService(ctx, typeID, index) + servicesArray = append(servicesArray, con.value.(T)) ctx = newCtx } } @@ -101,6 +102,7 @@ func GetList[T any](ctx context.Context, key ...KeyStringer) ([]T, context.Conte } // GetResolvedSingletons retrieves a list of Singleton instances that implement the [TInterface]. +// The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the most recently created one. // It would return only the instances which had been resolved. Other lazy implementations which have never been invoked will not be returned. // This function is useful for cleaning operations. // @@ -114,21 +116,25 @@ func GetResolvedSingletons[TInterface any]() []TInterface { lock.RLock() defer lock.RUnlock() - result := []TInterface{} + list := []*concrete{} + + //filtering for _, resolvers := range container { for _, resolver := range resolvers { - invokedValue, isInvokedSingleton := resolver.getInvokedSingleton() + con, isInvokedSingleton := resolver.getInvokedSingleton() if isInvokedSingleton { - if instance, ok := invokedValue.(TInterface); ok { - result = append(result, instance) + if _, ok := con.value.(TInterface); ok { + list = append(list, con) } } } } - return result + + return sortAndSelect[TInterface](list) } // GetResolvedScopedInstances retrieves a list of Scoped instances that implement the [TInterface]. +// The returned instances are sorted by creation time (a.k.a the invocation order), the first one being the most recently created one. // It would return only the instances which had been resolved. Other lazy implementations which have never been invoked will not be returned. // This function is useful for cleaning operations. // @@ -143,12 +149,31 @@ func GetResolvedScopedInstances[TInterface any](ctx context.Context) []TInterfac if !ok { return []TInterface{} } - result := []TInterface{} + + list := []*concrete{} + + //filtering for _, contextKey := range contextKeyRepository { - invokedValue := ctx.Value(contextKey) - if instance, ok := invokedValue.(TInterface); ok { - result = append(result, instance) + con := ctx.Value(contextKey).(*concrete) + if _, ok := con.value.(TInterface); ok { + list = append(list, con) } } + + return sortAndSelect[TInterface](list) +} + +// sortAndSelect sorts concretes by invocation order and return its value. +func sortAndSelect[TInterface any](list []*concrete) []TInterface { + //sorting + sort.Slice(list, func(i, j int) bool { + return list[i].createdAt.After(list[j].createdAt) + }) + + //selecting + result := make([]TInterface, len(list)) + for i := 0; i < len(list); i++ { + result[i] = list[i].value.(TInterface) + } return result } diff --git a/ore.go b/ore.go index ca3823a..1d46404 100644 --- a/ore.go +++ b/ore.go @@ -13,11 +13,11 @@ var ( //map the alias type (usually an interface) to the original types (usually implementations of the interface) aliases = map[pointerTypeName][]pointerTypeName{} - //this is a special context key. The value of this key is the collection of other context keys stored in the context. + //contextKeysRepositoryID is a special context key. The value of this key is the collection of other context keys stored in the context. contextKeysRepositoryID = contextKey{ typeID{ pointerTypeName: "", - oreKey: "The context keys collection of Concrete Scoped Instances", + oreKey: "The context keys repository", }, -1} ) diff --git a/registrars.go b/registrars.go index 7462a71..32bcd06 100644 --- a/registrars.go +++ b/registrars.go @@ -3,6 +3,7 @@ package ore import ( "fmt" "reflect" + "time" ) // RegisterLazyCreator Registers a lazily initialized value using a `Creator[T]` interface @@ -25,8 +26,12 @@ func RegisterEagerSingleton[T comparable](impl T, key ...KeyStringer) { } e := serviceResolverImpl[T]{ - lifetime: Singleton, - singletonConcrete: &impl, + lifetime: Singleton, + singletonConcrete: &concrete{ + value: impl, + lifetime: Singleton, + createdAt: time.Now(), + }, } appendToContainer[T](e, key) } diff --git a/serviceResolver.go b/serviceResolver.go index 05375e2..de6970f 100644 --- a/serviceResolver.go +++ b/serviceResolver.go @@ -1,51 +1,60 @@ package ore -import "context" +import ( + "context" + "time" +) type ( Initializer[T any] func(ctx context.Context) (T, context.Context) ) type serviceResolver interface { - resolveService(ctx context.Context, typeId typeID, index int) (any, context.Context) + resolveService(ctx context.Context, typeId typeID, index int) (*concrete, context.Context) //return the invoked singleton value, or false if the resolver is not a singleton or has not been invoked - getInvokedSingleton() (con any, isInvokedSingleton bool) + getInvokedSingleton() (con *concrete, isInvokedSingleton bool) } type serviceResolverImpl[T any] struct { anonymousInitializer *Initializer[T] creatorInstance Creator[T] - singletonConcrete *T + singletonConcrete *concrete lifetime Lifetime } -//make sure that the `serviceResolverImpl` struct implements the `serviceResolver` interface +// make sure that the `serviceResolverImpl` struct implements the `serviceResolver` interface var _ serviceResolver = serviceResolverImpl[any]{} -func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeID typeID, index int) (any, context.Context) { +func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeID typeID, index int) (*concrete, context.Context) { // try get concrete implementation if this.lifetime == Singleton && this.singletonConcrete != nil { - return *this.singletonConcrete, ctx + return this.singletonConcrete, ctx } ctxKey := contextKey{typeID, index} // try get concrete from context scope if this.lifetime == Scoped { - scopedConcrete, ok := ctx.Value(ctxKey).(T) + scopedConcrete, ok := ctx.Value(ctxKey).(*concrete) if ok { return scopedConcrete, ctx } } - var con T + var concreteValue T // first, try make concrete implementation from `anonymousInitializer` // if nil, try the concrete implementation `Creator` if this.anonymousInitializer != nil { - con, ctx = (*this.anonymousInitializer)(ctx) + concreteValue, ctx = (*this.anonymousInitializer)(ctx) } else { - con, ctx = this.creatorInstance.New(ctx) + concreteValue, ctx = this.creatorInstance.New(ctx) + } + + con := &concrete{ + value: concreteValue, + lifetime: this.lifetime, + createdAt: time.Now(), } // if scoped, attach to the current context @@ -57,7 +66,7 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeID ty // if was lazily-created, then attach the newly-created concrete implementation // to the service resolver if this.lifetime == Singleton { - this.singletonConcrete = &con + this.singletonConcrete = con replaceServiceResolver(typeID, index, this) return con, ctx } @@ -65,9 +74,9 @@ func (this serviceResolverImpl[T]) resolveService(ctx context.Context, typeID ty return con, ctx } -func (this serviceResolverImpl[T]) getInvokedSingleton() (con any, isInvokedSingleton bool) { +func (this serviceResolverImpl[T]) getInvokedSingleton() (con *concrete, isInvokedSingleton bool) { if this.lifetime == Singleton && this.singletonConcrete != nil { - return *this.singletonConcrete, true + return this.singletonConcrete, true } return nil, false }