diff --git a/gorp.go b/gorp.go index aac8c6b9..2a3ccca7 100644 --- a/gorp.go +++ b/gorp.go @@ -1370,17 +1370,21 @@ func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string, // Determine where the results are: written to i, or returned in list if t, _ := toSliceType(i); t == nil { for _, v := range list { - err = runHook("PostGet", reflect.ValueOf(v), hookArg(exec)) - if err != nil { - return nil, err + if v, ok := v.(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } } } } else { resultsValue := reflect.Indirect(reflect.ValueOf(i)) for i := 0; i < resultsValue.Len(); i++ { - err = runHook("PostGet", resultsValue.Index(i), hookArg(exec)) - if err != nil { - return nil, err + if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } } } } @@ -1708,16 +1712,17 @@ func get(m *DbMap, exec SqlExecutor, i interface{}, } } - err = runHook("PostGet", v, hookArg(exec)) - if err != nil { - return nil, err + if v, ok := v.Interface().(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } } return v.Interface(), nil } func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { - hookarg := hookArg(exec) count := int64(0) for _, ptr := range list { table, elem, err := m.tableForPointer(ptr, true) @@ -1725,10 +1730,12 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { return -1, err } - eptr := elem.Addr() - err = runHook("PreDelete", eptr, hookarg) - if err != nil { - return -1, err + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreDelete); ok { + err = v.PreDelete(exec) + if err != nil { + return -1, err + } } bi, err := table.bindDelete(elem) @@ -1752,9 +1759,11 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { count += rows - err = runHook("PostDelete", eptr, hookarg) - if err != nil { - return -1, err + if v, ok := eval.(HasPostDelete); ok { + err := v.PostDelete(exec) + if err != nil { + return -1, err + } } } @@ -1762,7 +1771,6 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { } func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { - hookarg := hookArg(exec) count := int64(0) for _, ptr := range list { table, elem, err := m.tableForPointer(ptr, true) @@ -1770,10 +1778,12 @@ func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { return -1, err } - eptr := elem.Addr() - err = runHook("PreUpdate", eptr, hookarg) - if err != nil { - return -1, err + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreUpdate); ok { + err = v.PreUpdate(exec) + if err != nil { + return -1, err + } } bi, err := table.bindUpdate(elem) @@ -1802,26 +1812,29 @@ func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { count += rows - err = runHook("PostUpdate", eptr, hookarg) - if err != nil { - return -1, err + if v, ok := eval.(HasPostUpdate); ok { + err = v.PostUpdate(exec) + if err != nil { + return -1, err + } } } return count, nil } func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { - hookarg := hookArg(exec) for _, ptr := range list { table, elem, err := m.tableForPointer(ptr, false) if err != nil { return err } - eptr := elem.Addr() - err = runHook("PreInsert", eptr, hookarg) - if err != nil { - return err + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreInsert); ok { + err := v.PreInsert(exec) + if err != nil { + return err + } } bi, err := table.bindInsert(elem) @@ -1850,25 +1863,11 @@ func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { } } - err = runHook("PostInsert", eptr, hookarg) - if err != nil { - return err - } - } - return nil -} - -func hookArg(exec SqlExecutor) []reflect.Value { - execval := reflect.ValueOf(exec) - return []reflect.Value{execval} -} - -func runHook(name string, eptr reflect.Value, arg []reflect.Value) error { - hook := eptr.MethodByName(name) - if hook != zeroVal { - ret := hook.Call(arg) - if len(ret) > 0 && !ret[0].IsNil() { - return ret[0].Interface().(error) + if v, ok := eval.(HasPostInsert); ok { + err := v.PostInsert(exec) + if err != nil { + return err + } } } return nil @@ -1889,3 +1888,31 @@ func lockError(m *DbMap, exec SqlExecutor, tableName string, } return -1, ole } + +type HasPostGet interface { + PostGet(SqlExecutor) error +} + +type HasPostDelete interface { + PostDelete(SqlExecutor) error +} + +type HasPostUpdate interface { + PostUpdate(SqlExecutor) error +} + +type HasPostInsert interface { + PostInsert(SqlExecutor) error +} + +type HasPreDelete interface { + PreDelete(SqlExecutor) error +} + +type HasPreUpdate interface { + PreUpdate(SqlExecutor) error +} + +type HasPreInsert interface { + PreInsert(SqlExecutor) error +}