From 35967190380a343ea0e588a7bd94e14d11307a92 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Mon, 2 Jan 2023 16:39:39 +0100 Subject: [PATCH] Add `AfterEagerFind()` See https://github.com/gobuffalo/pop/issues/557 Closes https://github.com/gobuffalo/pop/issues/476 Signed-off-by: aeneasr <3372410+aeneasr@users.noreply.github.com> --- callbacks.go | 22 +++++++++-- callbacks_test.go | 13 ++++++- finders.go | 37 +++++++++++++------ pop_test.go | 6 +++ .../20181104135800_callbacks_users.up.fizz | 1 + 5 files changed, 63 insertions(+), 16 deletions(-) diff --git a/callbacks.go b/callbacks.go index b841df642..81c86c7f9 100644 --- a/callbacks.go +++ b/callbacks.go @@ -12,12 +12,24 @@ type AfterFindable interface { AfterFind(*Connection) error } -func (m *Model) afterFind(c *Connection) error { - if x, ok := m.Value.(AfterFindable); ok { +// AfterEagerFindable callback will be called after a record, or records, +// has been retrieved from the database and their associations have been +// eagerly loaded. +type AfterEagerFindable interface { + AfterEagerFind(*Connection) error +} + +func (m *Model) afterFind(c *Connection, eager bool) error { + if x, ok := m.Value.(AfterFindable); ok && !eager { if err := x.AfterFind(c); err != nil { return err } } + if x, ok := m.Value.(AfterEagerFindable); ok && eager { + if err := x.AfterEagerFind(c); err != nil { + return err + } + } // if the "model" is a slice/array we want // to loop through each of the elements in the collection @@ -34,9 +46,13 @@ func (m *Model) afterFind(c *Connection) error { wg.Go(func() error { y := rv.Index(i) y = y.Addr() - if x, ok := y.Interface().(AfterFindable); ok { + if x, ok := y.Interface().(AfterFindable); ok && !eager { return x.AfterFind(c) } + + if x, ok := y.Interface().(AfterEagerFindable); ok && eager { + return x.AfterEagerFind(c) + } return nil }) }(i) diff --git a/callbacks_test.go b/callbacks_test.go index 02ecd4072..ac527c9a4 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -45,6 +45,10 @@ func Test_Callbacks(t *testing.T) { r.Equal("AF", user.AfterF) r.NoError(tx.Find(user, user.ID)) r.Equal("AfterFind", user.AfterF) + r.Empty(user.AfterEF) + + r.NoError(tx.Eager().Find(user, user.ID)) + r.Equal("AfterEagerFind", user.AfterEF) r.NoError(tx.Destroy(user)) @@ -70,11 +74,16 @@ func Test_Callbacks_on_Slice(t *testing.T) { users := CallbacksUsers{} r.NoError(tx.All(&users)) - r.Len(users, 2) - for _, u := range users { r.Equal("AfterFind", u.AfterF) + r.Empty(u.AfterEF) + } + + r.NoError(tx.Eager().All(&users)) + r.Len(users, 2) + for _, u := range users { + r.Equal("AfterEagerFind", u.AfterEF) } }) } diff --git a/finders.go b/finders.go index 2afa0b474..d59cc9e85 100644 --- a/finders.go +++ b/finders.go @@ -66,13 +66,14 @@ func (c *Connection) First(model interface{}) error { // // q.Where("name = ?", "mark").First(&User{}) func (q *Query) First(model interface{}) error { + var m *Model err := q.Connection.timeFunc("First", func() error { q.Limit(1) - m := NewModel(model, q.Connection.Context()) + m = NewModel(model, q.Connection.Context()) if err := q.Connection.Dialect.SelectOne(q.Connection, m, *q); err != nil { return err } - return m.afterFind(q.Connection) + return m.afterFind(q.Connection, false) }) if err != nil { @@ -80,10 +81,14 @@ func (q *Query) First(model interface{}) error { } if q.eager { - err = q.eagerAssociations(model) + err := q.eagerAssociations(model) q.disableEager() - return err + if err != nil { + return err + } + return m.afterFind(q.Connection, true) } + return nil } @@ -98,14 +103,15 @@ func (c *Connection) Last(model interface{}) error { // // q.Where("name = ?", "mark").Last(&User{}) func (q *Query) Last(model interface{}) error { + var m *Model err := q.Connection.timeFunc("Last", func() error { q.Limit(1) q.Order("created_at DESC, id DESC") - m := NewModel(model, q.Connection.Context()) + m = NewModel(model, q.Connection.Context()) if err := q.Connection.Dialect.SelectOne(q.Connection, m, *q); err != nil { return err } - return m.afterFind(q.Connection) + return m.afterFind(q.Connection, false) }) if err != nil { @@ -115,7 +121,10 @@ func (q *Query) Last(model interface{}) error { if q.eager { err = q.eagerAssociations(model) q.disableEager() - return err + if err != nil { + return err + } + return m.afterFind(q.Connection, true) } return nil @@ -132,17 +141,20 @@ func (c *Connection) All(models interface{}) error { // // q.Where("name = ?", "mark").All(&[]User{}) func (q *Query) All(models interface{}) error { + var m *Model err := q.Connection.timeFunc("All", func() error { - m := NewModel(models, q.Connection.Context()) + m = NewModel(models, q.Connection.Context()) err := q.Connection.Dialect.SelectMany(q.Connection, m, *q) if err != nil { return err } + err = q.paginateModel(models) if err != nil { return err } - return m.afterFind(q.Connection) + + return m.afterFind(q.Connection, false) }) if err != nil { @@ -152,7 +164,10 @@ func (q *Query) All(models interface{}) error { if q.eager { err = q.eagerAssociations(models) q.disableEager() - return err + if err != nil { + return err + } + return m.afterFind(q.Connection, true) } return nil @@ -301,7 +316,7 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error { // Exists returns true/false if a record exists in the database that matches // the query. // -// q.Where("name = ?", "mark").Exists(&User{}) +// q.Where("name = ?", "mark").Exists(&User{}) func (q *Query) Exists(model interface{}) (bool, error) { tmpQuery := Q(q.Connection) q.Clone(tmpQuery) // avoid meddling with original query diff --git a/pop_test.go b/pop_test.go index 95fd8648d..1382c39aa 100644 --- a/pop_test.go +++ b/pop_test.go @@ -364,6 +364,7 @@ type CallbacksUser struct { AfterU string `db:"after_u"` AfterD string `db:"after_d"` AfterF string `db:"after_f"` + AfterEF string `db:"after_ef"` CreatedAt time.Time `json:"created_at" db:"created_at"` UpdatedAt time.Time `json:"updated_at" db:"updated_at"` } @@ -420,6 +421,11 @@ func (u *CallbacksUser) AfterFind(tx *Connection) error { return nil } +func (u *CallbacksUser) AfterEagerFind(tx *Connection) error { + u.AfterEF = "AfterEagerFind" + return nil +} + type Label struct { ID string `db:"id"` } diff --git a/testdata/migrations/20181104135800_callbacks_users.up.fizz b/testdata/migrations/20181104135800_callbacks_users.up.fizz index bc07aba90..b0e5ae8fd 100644 --- a/testdata/migrations/20181104135800_callbacks_users.up.fizz +++ b/testdata/migrations/20181104135800_callbacks_users.up.fizz @@ -9,6 +9,7 @@ create_table("callbacks_users") { t.Column("after_u", "string", {}) t.Column("after_d", "string", {}) t.Column("after_f", "string", {}) + t.Column("after_ef", "string", {}) t.Column("before_v", "string", {}) t.Timestamps() } \ No newline at end of file