From e4e5494690f7698e8bb473bda27026f6669d2945 Mon Sep 17 00:00:00 2001
From: Andrew Moon <andy.moon@pollenware.com>
Date: Wed, 12 Aug 2015 18:38:50 -0500
Subject: [PATCH] Added support for embedded structs when inserting or
 updating.

---
 dataset_insert.go      | 35 +++++++++++++++------
 dataset_insert_test.go | 70 +++++++++++++++++++++++++++++++++++++-----
 dataset_update.go      | 26 +++++++++++-----
 dataset_update_test.go | 53 ++++++++++++++++++++++++++++++++
 errors.go              |  2 +-
 5 files changed, 161 insertions(+), 25 deletions(-)

diff --git a/dataset_insert.go b/dataset_insert.go
index 75a5b1b0..8832e60d 100644
--- a/dataset_insert.go
+++ b/dataset_insert.go
@@ -3,6 +3,7 @@ package goqu
 import (
 	"reflect"
 	"sort"
+	"time"
 )
 
 //Generates the default INSERT statement. If Prepared has been called with true then the statement will not be interpolated. See examples.
@@ -94,16 +95,7 @@ func (me *Dataset) getInsertColsAndVals(rows ...interface{}) (columns ColumnList
 				rowCols []interface{}
 				rowVals []interface{}
 			)
-			for j := 0; j < newRowValue.NumField(); j++ {
-				f := newRowValue.Field(j)
-				t := newRowValue.Type().Field(j)
-				if me.canInsertField(t) {
-					if columns == nil {
-						rowCols = append(rowCols, t.Tag.Get("db"))
-					}
-					rowVals = append(rowVals, f.Interface())
-				}
-			}
+			rowCols, rowVals = me.getFieldsValues(newRowValue)
 			if columns == nil {
 				columns = cols(rowCols...)
 			}
@@ -115,6 +107,29 @@ func (me *Dataset) getInsertColsAndVals(rows ...interface{}) (columns ColumnList
 	return columns, vals, nil
 }
 
+func (me *Dataset) getFieldsValues(value reflect.Value) (rowCols []interface{}, rowVals []interface{}) {
+	if value.IsValid() {
+		for i := 0; i < value.NumField(); i++ {
+			v := value.Field(i)
+
+			kind := v.Kind()
+			if (reflect.TypeOf(v.Interface()).Name() == reflect.TypeOf((*time.Time)(nil)).Elem().Name()) || ((kind != reflect.Struct) && (kind != reflect.Ptr)) {
+				t := value.Type().Field(i)
+				if me.canInsertField(t) {
+					rowCols = append(rowCols, t.Tag.Get("db"))
+					rowVals = append(rowVals, v.Interface())
+				}
+			} else {
+				cols, vals := me.getFieldsValues(reflect.Indirect(reflect.ValueOf(v.Interface())))
+				rowCols = append(rowCols, cols...)
+				rowVals = append(rowVals, vals...)
+			}
+		}
+	}
+
+	return rowCols, rowVals
+}
+
 //Creates an INSERT statement with the columns and values passed in
 func (me *Dataset) insertSql(cols ColumnList, values [][]interface{}, prepared bool) (string, []interface{}, error) {
 	buf := NewSqlBuilder(prepared)
diff --git a/dataset_insert_test.go b/dataset_insert_test.go
index bc16a166..9c9d8b5e 100644
--- a/dataset_insert_test.go
+++ b/dataset_insert_test.go
@@ -3,6 +3,8 @@ package goqu
 import (
 	"github.com/DATA-DOG/go-sqlmock"
 	"github.com/technotronicoz/testify/assert"
+
+	"time"
 )
 
 func (me *datasetTest) TestInsertSqlNoReturning() {
@@ -36,21 +38,75 @@ func (me *datasetTest) TestInsertSqlWithStructs() {
 	t := me.T()
 	ds1 := From("items")
 	type item struct {
+		Address string    `db:"address"`
+		Name    string    `db:"name"`
+		Created time.Time `db:"created"`
+	}
+	created, _ := time.Parse("2006-01-02", "2015-01-01")
+	sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr", Created: created})
+	assert.NoError(t, err)
+	assert.Equal(t, sql, `INSERT INTO "items" ("address", "name", "created") VALUES ('111 Test Addr', 'Test', '`+created.Format(time.RFC3339Nano)+`')`)
+
+	sql, _, err = ds1.ToInsertSql(
+		item{Address: "111 Test Addr", Name: "Test1", Created: created},
+		item{Address: "211 Test Addr", Name: "Test2", Created: created},
+		item{Address: "311 Test Addr", Name: "Test3", Created: created},
+		item{Address: "411 Test Addr", Name: "Test4", Created: created},
+	)
+	assert.NoError(t, err)
+	assert.Equal(t, sql, `INSERT INTO "items" ("address", "name", "created") VALUES ('111 Test Addr', 'Test1', '`+created.Format(time.RFC3339Nano)+`'), ('211 Test Addr', 'Test2', '`+created.Format(time.RFC3339Nano)+`'), ('311 Test Addr', 'Test3', '`+created.Format(time.RFC3339Nano)+`'), ('411 Test Addr', 'Test4', '`+created.Format(time.RFC3339Nano)+`')`)
+}
+
+func (me *datasetTest) TestInsertSqlWithEmbeddedStruct() {
+	t := me.T()
+	ds1 := From("items")
+	type phone struct {
+		Primary string `db:"primary_phone"`
+		Home    string `db:"home_phone"`
+	}
+	type item struct {
+		phone
 		Address string `db:"address"`
 		Name    string `db:"name"`
 	}
-	sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr"})
+	sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr", phone: phone{Home: "123123", Primary: "456456"}})
 	assert.NoError(t, err)
-	assert.Equal(t, sql, `INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test')`)
+	assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test')`)
 
 	sql, _, err = ds1.ToInsertSql(
-		item{Address: "111 Test Addr", Name: "Test1"},
-		item{Address: "211 Test Addr", Name: "Test2"},
-		item{Address: "311 Test Addr", Name: "Test3"},
-		item{Address: "411 Test Addr", Name: "Test4"},
+		item{Address: "111 Test Addr", Name: "Test1", phone: phone{Home: "123123", Primary: "456456"}},
+		item{Address: "211 Test Addr", Name: "Test2", phone: phone{Home: "123123", Primary: "456456"}},
+		item{Address: "311 Test Addr", Name: "Test3", phone: phone{Home: "123123", Primary: "456456"}},
+		item{Address: "411 Test Addr", Name: "Test4", phone: phone{Home: "123123", Primary: "456456"}},
 	)
 	assert.NoError(t, err)
-	assert.Equal(t, sql, `INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test1'), ('211 Test Addr', 'Test2'), ('311 Test Addr', 'Test3'), ('411 Test Addr', 'Test4')`)
+	assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test1'), ('456456', '123123', '211 Test Addr', 'Test2'), ('456456', '123123', '311 Test Addr', 'Test3'), ('456456', '123123', '411 Test Addr', 'Test4')`)
+}
+
+func (me *datasetTest) TestInsertSqlWithEmbeddedStructPtr() {
+	t := me.T()
+	ds1 := From("items")
+	type phone struct {
+		Primary string `db:"primary_phone"`
+		Home    string `db:"home_phone"`
+	}
+	type item struct {
+		*phone
+		Address string `db:"address"`
+		Name    string `db:"name"`
+	}
+	sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr", phone: &phone{Home: "123123", Primary: "456456"}})
+	assert.NoError(t, err)
+	assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test')`)
+
+	sql, _, err = ds1.ToInsertSql(
+	item{Address: "111 Test Addr", Name: "Test1", phone: &phone{Home: "123123", Primary: "456456"}},
+	item{Address: "211 Test Addr", Name: "Test2", phone: &phone{Home: "123123", Primary: "456456"}},
+	item{Address: "311 Test Addr", Name: "Test3", phone: &phone{Home: "123123", Primary: "456456"}},
+	item{Address: "411 Test Addr", Name: "Test4", phone: &phone{Home: "123123", Primary: "456456"}},
+	)
+	assert.NoError(t, err)
+	assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test1'), ('456456', '123123', '211 Test Addr', 'Test2'), ('456456', '123123', '311 Test Addr', 'Test3'), ('456456', '123123', '411 Test Addr', 'Test4')`)
 }
 
 func (me *datasetTest) TestInsertSqlWithMaps() {
diff --git a/dataset_update.go b/dataset_update.go
index 84444d19..eca50aa9 100644
--- a/dataset_update.go
+++ b/dataset_update.go
@@ -3,6 +3,7 @@ package goqu
 import (
 	"reflect"
 	"sort"
+	"time"
 )
 
 func (me *Dataset) canUpdateField(field reflect.StructField) bool {
@@ -38,13 +39,7 @@ func (me *Dataset) ToUpdateSql(update interface{}) (string, []interface{}, error
 			updates = append(updates, I(key.String()).Set(updateValue.MapIndex(key).Interface()))
 		}
 	case reflect.Struct:
-		for j := 0; j < updateValue.NumField(); j++ {
-			f := updateValue.Field(j)
-			t := updateValue.Type().Field(j)
-			if me.canUpdateField(t) {
-				updates = append(updates, I(t.Tag.Get("db")).Set(f.Interface()))
-			}
-		}
+		updates = me.getUpdateExpression(updateValue)
 	default:
 		return "", nil, NewGoquError("Unsupported update interface type %+v", updateValue.Type())
 	}
@@ -81,3 +76,20 @@ func (me *Dataset) ToUpdateSql(update interface{}) (string, []interface{}, error
 	sql, args := buf.ToSql()
 	return sql, args, nil
 }
+
+func (me *Dataset) getUpdateExpression(value reflect.Value) (updates []UpdateExpression) {
+	for i := 0; i < value.NumField(); i++ {
+		v := value.Field(i)
+		kind := v.Kind()
+		if reflect.TypeOf(v.Interface()).Name() == reflect.TypeOf((*time.Time)(nil)).Elem().Name() || (kind != reflect.Struct && kind != reflect.Ptr) {
+			t := value.Type().Field(i)
+			if me.canUpdateField(t) {
+				updates = append(updates, I(t.Tag.Get("db")).Set(v.Interface()))
+			}
+		} else {
+			updates = append(updates, me.getUpdateExpression(reflect.Indirect(reflect.ValueOf(v.Interface())))...)
+		}
+	}
+
+	return updates
+}
diff --git a/dataset_update_test.go b/dataset_update_test.go
index 95706bee..fa60a0d0 100644
--- a/dataset_update_test.go
+++ b/dataset_update_test.go
@@ -3,6 +3,7 @@ package goqu
 import (
 	"database/sql/driver"
 	"fmt"
+	"time"
 
 	"github.com/DATA-DOG/go-sqlmock"
 	"github.com/technotronicoz/testify/assert"
@@ -221,6 +222,58 @@ func (me *datasetTest) TestPreparedUpdateSqlWithSkipupdateTag() {
 	assert.Equal(t, sql, `UPDATE "items" SET "name"=?`)
 }
 
+func (me *datasetTest) TestPreparedUpdateSqlWithEmbeddedStruct() {
+	t := me.T()
+	ds1 := From("items")
+	type phone struct {
+		Primary string    `db:"primary_phone"`
+		Home    string    `db:"home_phone"`
+		Created time.Time `db:"phone_created"`
+	}
+	type item struct {
+		phone
+		Address string    `db:"address" goqu:"skipupdate"`
+		Name    string    `db:"name"`
+		Created time.Time `db:"created"`
+	}
+	created, _ := time.Parse("2006-01-02", "2015-01-01")
+
+	sql, args, err := ds1.Prepared(true).ToUpdateSql(item{Name: "Test", Address: "111 Test Addr", Created: created, phone: phone{
+		Home:    "123123",
+		Primary: "456456",
+		Created: created,
+	}})
+	assert.NoError(t, err)
+	assert.Equal(t, args, []interface{}{"456456", "123123", created, "Test", created})
+	assert.Equal(t, sql, `UPDATE "items" SET "primary_phone"=?,"home_phone"=?,"phone_created"=?,"name"=?,"created"=?`)
+}
+
+func (me *datasetTest) TestPreparedUpdateSqlWithEmbeddedStructPtr() {
+	t := me.T()
+	ds1 := From("items")
+	type phone struct {
+		Primary string    `db:"primary_phone"`
+		Home    string    `db:"home_phone"`
+		Created time.Time `db:"phone_created"`
+	}
+	type item struct {
+		*phone
+		Address string    `db:"address" goqu:"skipupdate"`
+		Name    string    `db:"name"`
+		Created time.Time `db:"created"`
+	}
+	created, _ := time.Parse("2006-01-02", "2015-01-01")
+
+	sql, args, err := ds1.Prepared(true).ToUpdateSql(item{Name: "Test", Address: "111 Test Addr", Created: created, phone: &phone{
+		Home:    "123123",
+		Primary: "456456",
+		Created: created,
+	}})
+	assert.NoError(t, err)
+	assert.Equal(t, args, []interface{}{"456456", "123123", created, "Test", created})
+	assert.Equal(t, sql, `UPDATE "items" SET "primary_phone"=?,"home_phone"=?,"phone_created"=?,"name"=?,"created"=?`)
+}
+
 func (me *datasetTest) TestPreparedUpdateSqlWithWhere() {
 	t := me.T()
 	ds1 := From("items")
diff --git a/errors.go b/errors.go
index 22c9c424..285a0848 100644
--- a/errors.go
+++ b/errors.go
@@ -25,4 +25,4 @@ type GoquError struct {
 
 func (me GoquError) Error() string {
 	return me.err
-}
\ No newline at end of file
+}