From 8d36fc64693fd6adb9bb6cb6ac87feb8be3eda30 Mon Sep 17 00:00:00 2001 From: SMLuthi Date: Wed, 16 Dec 2020 12:07:11 -0800 Subject: [PATCH] bcda-3942 Feature: Upgrade to gorm v2 (#601) * update gorm to v2 * remove hstore dep * update gopkg * remove db close as its not supported in gorm v2 * remove db close err check as its not supported in gorm v2 --- Gopkg.lock | 51 +- Gopkg.toml | 9 +- bcda/api/requests.go | 40 +- bcda/api/requests_test.go | 4 +- bcda/api/v1/api.go | 4 +- bcda/api/v1/api_test.go | 8 +- bcda/api/v2/api_test.go | 4 +- bcda/auth/alpha_test.go | 8 +- bcda/auth/api_test.go | 5 +- bcda/auth/middleware.go | 2 +- bcda/auth/models_test.go | 2 +- bcda/auth/okta_test.go | 10 +- bcda/auth/ssas_test.go | 4 +- bcda/bcdacli/cli.go | 4 +- bcda/bcdacli/cli_test.go | 2 +- bcda/cclf/cclf.go | 12 +- bcda/cclf/cclf_test.go | 15 +- bcda/database/connection.go | 10 +- bcda/database/connection_test.go | 12 +- bcda/database/utils.go | 9 +- bcda/models/models.go | 19 +- bcda/models/models_test.go | 10 +- bcda/models/postgres/repository.go | 16 +- bcda/models/postgres/repository_test.go | 87 +- bcda/models/service_test.go | 2 +- bcda/suppression/suppression.go | 5 +- bcda/suppression/suppression_test.go | 6 +- bcda/testUtils/utils.go | 28 + bcdaworker/main.go | 7 +- bcdaworker/main_test.go | 4 +- db/migrations/migrations_test.go | 19 +- vendor/github.com/jinzhu/gorm/association.go | 377 ----- vendor/github.com/jinzhu/gorm/callback.go | 250 --- .../github.com/jinzhu/gorm/callback_create.go | 197 --- .../github.com/jinzhu/gorm/callback_delete.go | 63 - .../github.com/jinzhu/gorm/callback_query.go | 109 -- .../jinzhu/gorm/callback_query_preload.go | 410 ----- .../jinzhu/gorm/callback_row_query.go | 41 - .../github.com/jinzhu/gorm/callback_save.go | 170 -- .../github.com/jinzhu/gorm/callback_update.go | 121 -- vendor/github.com/jinzhu/gorm/dialect.go | 147 -- .../github.com/jinzhu/gorm/dialect_common.go | 196 --- .../github.com/jinzhu/gorm/dialect_mysql.go | 246 --- .../jinzhu/gorm/dialect_postgres.go | 147 -- .../github.com/jinzhu/gorm/dialect_sqlite3.go | 107 -- .../jinzhu/gorm/dialects/postgres/postgres.go | 81 - .../github.com/jinzhu/gorm/docker-compose.yml | 30 - vendor/github.com/jinzhu/gorm/errors.go | 72 - vendor/github.com/jinzhu/gorm/field.go | 66 - vendor/github.com/jinzhu/gorm/go.mod | 15 - vendor/github.com/jinzhu/gorm/go.sum | 29 - vendor/github.com/jinzhu/gorm/interface.go | 24 - .../jinzhu/gorm/join_table_handler.go | 211 --- vendor/github.com/jinzhu/gorm/logger.go | 141 -- vendor/github.com/jinzhu/gorm/main.go | 881 ---------- vendor/github.com/jinzhu/gorm/model.go | 14 - vendor/github.com/jinzhu/gorm/model_struct.go | 671 -------- vendor/github.com/jinzhu/gorm/naming.go | 124 -- vendor/github.com/jinzhu/gorm/scope.go | 1421 ----------------- vendor/github.com/jinzhu/gorm/search.go | 153 -- vendor/github.com/jinzhu/gorm/test_all.sh | 5 - vendor/github.com/jinzhu/gorm/utils.go | 226 --- vendor/github.com/jinzhu/gorm/wercker.yml | 154 -- vendor/github.com/jinzhu/now/Guardfile | 3 + .../github.com/jinzhu/{gorm => now}/License | 0 vendor/github.com/jinzhu/now/README.md | 134 ++ vendor/github.com/jinzhu/now/go.mod | 3 + vendor/github.com/jinzhu/now/main.go | 194 +++ vendor/github.com/jinzhu/now/now.go | 213 +++ vendor/github.com/jinzhu/now/wercker.yml | 23 + vendor/github.com/lib/pq/hstore/hstore.go | 118 -- vendor/gorm.io/driver/postgres/License | 21 + vendor/gorm.io/driver/postgres/README.md | 16 + vendor/gorm.io/driver/postgres/go.mod | 5 + vendor/gorm.io/driver/postgres/migrator.go | 151 ++ vendor/gorm.io/driver/postgres/postgres.go | 127 ++ .../jinzhu => gorm.io}/gorm/.gitignore | 1 + vendor/gorm.io/gorm/License | 21 + .../jinzhu => gorm.io}/gorm/README.md | 23 +- vendor/gorm.io/gorm/association.go | 482 ++++++ vendor/gorm.io/gorm/callbacks.go | 298 ++++ vendor/gorm.io/gorm/callbacks/associations.go | 369 +++++ vendor/gorm.io/gorm/callbacks/callbacks.go | 51 + vendor/gorm.io/gorm/callbacks/callmethod.go | 23 + vendor/gorm.io/gorm/callbacks/create.go | 361 +++++ vendor/gorm.io/gorm/callbacks/delete.go | 165 ++ vendor/gorm.io/gorm/callbacks/helper.go | 90 ++ vendor/gorm.io/gorm/callbacks/interfaces.go | 39 + vendor/gorm.io/gorm/callbacks/preload.go | 155 ++ vendor/gorm.io/gorm/callbacks/query.go | 228 +++ vendor/gorm.io/gorm/callbacks/raw.go | 16 + vendor/gorm.io/gorm/callbacks/row.go | 21 + vendor/gorm.io/gorm/callbacks/transaction.go | 29 + vendor/gorm.io/gorm/callbacks/update.go | 263 +++ vendor/gorm.io/gorm/chainable_api.go | 292 ++++ vendor/gorm.io/gorm/clause/clause.go | 88 + vendor/gorm.io/gorm/clause/delete.go | 23 + vendor/gorm.io/gorm/clause/expression.go | 301 ++++ vendor/gorm.io/gorm/clause/from.go | 37 + vendor/gorm.io/gorm/clause/group_by.go | 42 + vendor/gorm.io/gorm/clause/insert.go | 39 + vendor/gorm.io/gorm/clause/joins.go | 47 + vendor/gorm.io/gorm/clause/limit.go | 48 + vendor/gorm.io/gorm/clause/locking.go | 31 + vendor/gorm.io/gorm/clause/on_conflict.go | 45 + vendor/gorm.io/gorm/clause/order_by.go | 54 + vendor/gorm.io/gorm/clause/returning.go | 30 + vendor/gorm.io/gorm/clause/select.go | 45 + vendor/gorm.io/gorm/clause/set.go | 60 + vendor/gorm.io/gorm/clause/update.go | 38 + vendor/gorm.io/gorm/clause/values.go | 45 + vendor/gorm.io/gorm/clause/where.go | 177 ++ vendor/gorm.io/gorm/clause/with.go | 4 + vendor/gorm.io/gorm/errors.go | 34 + vendor/gorm.io/gorm/finisher_api.go | 605 +++++++ vendor/gorm.io/gorm/go.mod | 8 + vendor/gorm.io/gorm/go.sum | 4 + vendor/gorm.io/gorm/gorm.go | 386 +++++ vendor/gorm.io/gorm/interfaces.go | 59 + vendor/gorm.io/gorm/logger/logger.go | 183 +++ vendor/gorm.io/gorm/logger/sql.go | 127 ++ vendor/gorm.io/gorm/migrator.go | 70 + vendor/gorm.io/gorm/migrator/migrator.go | 706 ++++++++ vendor/gorm.io/gorm/model.go | 15 + vendor/gorm.io/gorm/prepare_stmt.go | 150 ++ vendor/gorm.io/gorm/scan.go | 247 +++ vendor/gorm.io/gorm/schema/check.go | 32 + vendor/gorm.io/gorm/schema/field.go | 819 ++++++++++ vendor/gorm.io/gorm/schema/index.go | 141 ++ vendor/gorm.io/gorm/schema/interfaces.go | 25 + vendor/gorm.io/gorm/schema/naming.go | 145 ++ vendor/gorm.io/gorm/schema/relationship.go | 566 +++++++ vendor/gorm.io/gorm/schema/schema.go | 280 ++++ vendor/gorm.io/gorm/schema/utils.go | 197 +++ vendor/gorm.io/gorm/soft_delete.go | 136 ++ vendor/gorm.io/gorm/statement.go | 594 +++++++ vendor/gorm.io/gorm/utils/utils.go | 113 ++ 137 files changed, 10824 insertions(+), 7224 deletions(-) delete mode 100644 vendor/github.com/jinzhu/gorm/association.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_create.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_delete.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_query.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_query_preload.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_row_query.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_save.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_update.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_common.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_mysql.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_postgres.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_sqlite3.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go delete mode 100644 vendor/github.com/jinzhu/gorm/docker-compose.yml delete mode 100644 vendor/github.com/jinzhu/gorm/errors.go delete mode 100644 vendor/github.com/jinzhu/gorm/field.go delete mode 100644 vendor/github.com/jinzhu/gorm/go.mod delete mode 100644 vendor/github.com/jinzhu/gorm/go.sum delete mode 100644 vendor/github.com/jinzhu/gorm/interface.go delete mode 100644 vendor/github.com/jinzhu/gorm/join_table_handler.go delete mode 100644 vendor/github.com/jinzhu/gorm/logger.go delete mode 100644 vendor/github.com/jinzhu/gorm/main.go delete mode 100644 vendor/github.com/jinzhu/gorm/model.go delete mode 100644 vendor/github.com/jinzhu/gorm/model_struct.go delete mode 100644 vendor/github.com/jinzhu/gorm/naming.go delete mode 100644 vendor/github.com/jinzhu/gorm/scope.go delete mode 100644 vendor/github.com/jinzhu/gorm/search.go delete mode 100755 vendor/github.com/jinzhu/gorm/test_all.sh delete mode 100644 vendor/github.com/jinzhu/gorm/utils.go delete mode 100644 vendor/github.com/jinzhu/gorm/wercker.yml create mode 100644 vendor/github.com/jinzhu/now/Guardfile rename vendor/github.com/jinzhu/{gorm => now}/License (100%) create mode 100644 vendor/github.com/jinzhu/now/README.md create mode 100644 vendor/github.com/jinzhu/now/go.mod create mode 100644 vendor/github.com/jinzhu/now/main.go create mode 100644 vendor/github.com/jinzhu/now/now.go create mode 100644 vendor/github.com/jinzhu/now/wercker.yml delete mode 100644 vendor/github.com/lib/pq/hstore/hstore.go create mode 100644 vendor/gorm.io/driver/postgres/License create mode 100644 vendor/gorm.io/driver/postgres/README.md create mode 100644 vendor/gorm.io/driver/postgres/go.mod create mode 100644 vendor/gorm.io/driver/postgres/migrator.go create mode 100644 vendor/gorm.io/driver/postgres/postgres.go rename vendor/{github.com/jinzhu => gorm.io}/gorm/.gitignore (82%) create mode 100644 vendor/gorm.io/gorm/License rename vendor/{github.com/jinzhu => gorm.io}/gorm/README.md (54%) create mode 100644 vendor/gorm.io/gorm/association.go create mode 100644 vendor/gorm.io/gorm/callbacks.go create mode 100644 vendor/gorm.io/gorm/callbacks/associations.go create mode 100644 vendor/gorm.io/gorm/callbacks/callbacks.go create mode 100644 vendor/gorm.io/gorm/callbacks/callmethod.go create mode 100644 vendor/gorm.io/gorm/callbacks/create.go create mode 100644 vendor/gorm.io/gorm/callbacks/delete.go create mode 100644 vendor/gorm.io/gorm/callbacks/helper.go create mode 100644 vendor/gorm.io/gorm/callbacks/interfaces.go create mode 100644 vendor/gorm.io/gorm/callbacks/preload.go create mode 100644 vendor/gorm.io/gorm/callbacks/query.go create mode 100644 vendor/gorm.io/gorm/callbacks/raw.go create mode 100644 vendor/gorm.io/gorm/callbacks/row.go create mode 100644 vendor/gorm.io/gorm/callbacks/transaction.go create mode 100644 vendor/gorm.io/gorm/callbacks/update.go create mode 100644 vendor/gorm.io/gorm/chainable_api.go create mode 100644 vendor/gorm.io/gorm/clause/clause.go create mode 100644 vendor/gorm.io/gorm/clause/delete.go create mode 100644 vendor/gorm.io/gorm/clause/expression.go create mode 100644 vendor/gorm.io/gorm/clause/from.go create mode 100644 vendor/gorm.io/gorm/clause/group_by.go create mode 100644 vendor/gorm.io/gorm/clause/insert.go create mode 100644 vendor/gorm.io/gorm/clause/joins.go create mode 100644 vendor/gorm.io/gorm/clause/limit.go create mode 100644 vendor/gorm.io/gorm/clause/locking.go create mode 100644 vendor/gorm.io/gorm/clause/on_conflict.go create mode 100644 vendor/gorm.io/gorm/clause/order_by.go create mode 100644 vendor/gorm.io/gorm/clause/returning.go create mode 100644 vendor/gorm.io/gorm/clause/select.go create mode 100644 vendor/gorm.io/gorm/clause/set.go create mode 100644 vendor/gorm.io/gorm/clause/update.go create mode 100644 vendor/gorm.io/gorm/clause/values.go create mode 100644 vendor/gorm.io/gorm/clause/where.go create mode 100644 vendor/gorm.io/gorm/clause/with.go create mode 100644 vendor/gorm.io/gorm/errors.go create mode 100644 vendor/gorm.io/gorm/finisher_api.go create mode 100644 vendor/gorm.io/gorm/go.mod create mode 100644 vendor/gorm.io/gorm/go.sum create mode 100644 vendor/gorm.io/gorm/gorm.go create mode 100644 vendor/gorm.io/gorm/interfaces.go create mode 100644 vendor/gorm.io/gorm/logger/logger.go create mode 100644 vendor/gorm.io/gorm/logger/sql.go create mode 100644 vendor/gorm.io/gorm/migrator.go create mode 100644 vendor/gorm.io/gorm/migrator/migrator.go create mode 100644 vendor/gorm.io/gorm/model.go create mode 100644 vendor/gorm.io/gorm/prepare_stmt.go create mode 100644 vendor/gorm.io/gorm/scan.go create mode 100644 vendor/gorm.io/gorm/schema/check.go create mode 100644 vendor/gorm.io/gorm/schema/field.go create mode 100644 vendor/gorm.io/gorm/schema/index.go create mode 100644 vendor/gorm.io/gorm/schema/interfaces.go create mode 100644 vendor/gorm.io/gorm/schema/naming.go create mode 100644 vendor/gorm.io/gorm/schema/relationship.go create mode 100644 vendor/gorm.io/gorm/schema/schema.go create mode 100644 vendor/gorm.io/gorm/schema/utils.go create mode 100644 vendor/gorm.io/gorm/soft_delete.go create mode 100644 vendor/gorm.io/gorm/statement.go create mode 100644 vendor/gorm.io/gorm/utils/utils.go diff --git a/Gopkg.lock b/Gopkg.lock index deeb4a0af..20400b85b 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -148,17 +148,6 @@ pruneopts = "UT" revision = "39bbc98d99d7b666759f84514859becf8067128f" -[[projects]] - digest = "1:e40b74b04fa9a70ff39869139c9fd15ec9a9eefde3f51ef3b65405a60ce30e0e" - name = "github.com/jinzhu/gorm" - packages = [ - ".", - "dialects/postgres", - ] - pruneopts = "UT" - revision = "79a77d771dee4e4b60e9c543e8663bbc80466670" - version = "v1.9.12" - [[projects]] branch = "master" digest = "1:fd97437fbb6b7dce04132cf06775bd258cce305c44add58eb55ca86c6c325160" @@ -167,6 +156,14 @@ pruneopts = "UT" revision = "04140366298a54a039076d798123ffa108fff46c" +[[projects]] + digest = "1:3458e5dd5934a9c852e11a4ea54a07de091fdad4c36d0ac033bb90bc16b0a046" + name = "github.com/jinzhu/now" + packages = ["."] + pruneopts = "UT" + revision = "7e7333ac029d4aad7e88500a30889fcd22489425" + version = "v1.1.1" + [[projects]] digest = "1:bb81097a5b62634f3e9fec1014657855610c82d19b9a40c17612e32651e35dca" name = "github.com/jmespath/go-jmespath" @@ -183,11 +180,10 @@ version = "v1.0.2" [[projects]] - digest = "1:2d370ad7a48d3523291ad53a04b9b1510c182880927f800b46221cc465427800" + digest = "1:37ce7d7d80531b227023331002c0d42b4b4b291a96798c82a049d03a54ba79e4" name = "github.com/lib/pq" packages = [ ".", - "hstore", "oid", ] pruneopts = "UT" @@ -518,6 +514,30 @@ pruneopts = "UT" revision = "9856a29383ce1c59f308dd1cf0363a79b5bef6b5" +[[projects]] + digest = "1:cf0eeff871fb146d0a77cfb6eecc38e79f7458fe6d05726ee59654d9799711ca" + name = "gorm.io/driver/postgres" + packages = ["."] + pruneopts = "UT" + revision = "2591a12c4fdf62e81734996d0361db21fe90ae45" + version = "v0.2.4" + +[[projects]] + digest = "1:5051ed026b2633aab6de4ac829eb0b781ab9f92ab66370b3601dbc41feacb19d" + name = "gorm.io/gorm" + packages = [ + ".", + "callbacks", + "clause", + "logger", + "migrator", + "schema", + "utils", + ] + pruneopts = "UT" + revision = "f2321ca164c0e5fd6cdcd5727152b39f2062ca6b" + version = "v1.20.8" + [solve-meta] analyzer-name = "dep" analyzer-version = 1 @@ -534,10 +554,7 @@ "github.com/go-chi/chi/middleware", "github.com/go-chi/render", "github.com/jackc/pgx", - "github.com/jinzhu/gorm", - "github.com/jinzhu/gorm/dialects/postgres", "github.com/lib/pq", - "github.com/newrelic/go-agent", "github.com/newrelic/go-agent/_integrations/nrlogrus", "github.com/newrelic/go-agent/v3/newrelic", "github.com/otiai10/copy", @@ -555,6 +572,8 @@ "github.com/tsenart/vegeta/lib/plot", "github.com/urfave/cli", "golang.org/x/crypto/pbkdf2", + "gorm.io/driver/postgres", + "gorm.io/gorm", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 302b2bbb5..680231a21 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -24,7 +24,6 @@ # go-tests = true # unused-packages = true - [[constraint]] name = "github.com/bgentry/que-go" version = "1.0.1" @@ -76,3 +75,11 @@ [[constraint]] name = "github.com/newrelic/go-agent" version = "3.9.0" + +[[constraint]] + name = "gorm.io/gorm" + version = "1.20.8" + +[[constraint]] + name = "gorm.io/driver/postgres" + version = "0.2.4" diff --git a/bcda/api/requests.go b/bcda/api/requests.go index 4bd006d10..fa25ba8d8 100644 --- a/bcda/api/requests.go +++ b/bcda/api/requests.go @@ -68,13 +68,17 @@ func NewHandler(resources []string, basePath string) *Handler { runoutClaimThru := utils.FromEnv("RUNOUT_CLAIM_THRU_DATE", "2020-12-31") runoutClaimThruDate, err := time.Parse(claimThruDateFmt, runoutClaimThru) if err != nil { - log.Fatalf("Failed to parse RUNOUT_CLAIM_THRU_DATE '%s'. Err: %s", runoutClaimThru, err.Error()) + log.Fatalf("Failed to parse RUNOUT_CLAIM_THRU_DATE '%s'. Err: %v", runoutClaimThru, err) } db := database.GetGORMDbConnection() - db.DB().SetMaxOpenConns(utils.GetEnvInt("BCDA_DB_MAX_OPEN_CONNS", 25)) - db.DB().SetMaxIdleConns(utils.GetEnvInt("BCDA_DB_MAX_IDLE_CONNS", 25)) - db.DB().SetConnMaxLifetime(time.Duration(utils.GetEnvInt("BCDA_DB_CONN_MAX_LIFETIME_MIN", 5)) * time.Minute) + dbc, err := db.DB() + if err != nil { + log.Fatalf("Failed to retrieve database connection. Err: %v", err) + } + dbc.SetMaxOpenConns(utils.GetEnvInt("BCDA_DB_MAX_OPEN_CONNS", 25)) + dbc.SetMaxIdleConns(utils.GetEnvInt("BCDA_DB_MAX_IDLE_CONNS", 25)) + dbc.SetConnMaxLifetime(time.Duration(utils.GetEnvInt("BCDA_DB_CONN_MAX_LIFETIME_MIN", 5)) * time.Minute) repository := postgres.NewRepository(db) h.svc = models.NewService(repository, cutoffDuration, utils.GetEnvInt("BCDA_SUPPRESSION_LOOKBACK_DAYS", 60), runoutCutoffDuration, runoutClaimThruDate, @@ -173,21 +177,23 @@ func (h *Handler) bulkRequest(resourceTypes []string, w http.ResponseWriter, r * // their job finishes or time expires (+24 hours default) for any remaining jobs left in a pending or in-progress state. // Overall, this will prevent a queue of concurrent calls from slowing up our system. // NOTE: this logic is relevant to PROD only; simultaneous requests in our lower environments is acceptable (i.e., shared opensbx creds) - if (os.Getenv("DEPLOYMENT_TARGET") == "prod") && - (!db.Find(&pendingAndInProgressJobs, "aco_id = ? AND status IN (?, ?)", acoID, "In Progress", "Pending").RecordNotFound()) { - if types, err := check429(pendingAndInProgressJobs, resourceTypes, version); err != nil { - if _, ok := err.(duplicateTypeError); ok { - w.Header().Set("Retry-After", strconv.Itoa(utils.GetEnvInt("CLIENT_RETRY_AFTER_IN_SECONDS", 0))) - w.WriteHeader(http.StatusTooManyRequests) + if os.Getenv("DEPLOYMENT_TARGET") == "prod" { + db.Where("aco_id = ? AND status IN (?, ?)", acoID, "In Progress", "Pending").Find(&pendingAndInProgressJobs) + if len(pendingAndInProgressJobs) > 0 { + if types, err := check429(pendingAndInProgressJobs, resourceTypes, version); err != nil { + if _, ok := err.(duplicateTypeError); ok { + w.Header().Set("Retry-After", strconv.Itoa(utils.GetEnvInt("CLIENT_RETRY_AFTER_IN_SECONDS", 0))) + w.WriteHeader(http.StatusTooManyRequests) + } else { + log.Error(err) + oo := responseutils.CreateOpOutcome(responseutils.Error, responseutils.Exception, responseutils.Processing, "") + responseutils.WriteError(oo, w, http.StatusInternalServerError) + } + + return } else { - log.Error(err) - oo := responseutils.CreateOpOutcome(responseutils.Error, responseutils.Exception, responseutils.Processing, "") - responseutils.WriteError(oo, w, http.StatusInternalServerError) + resourceTypes = types } - - return - } else { - resourceTypes = types } } diff --git a/bcda/api/requests_test.go b/bcda/api/requests_test.go index a21cf3a20..b4fb5d448 100644 --- a/bcda/api/requests_test.go +++ b/bcda/api/requests_test.go @@ -22,9 +22,9 @@ import ( "github.com/pborman/uuid" "github.com/go-chi/chi" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "gorm.io/gorm" ) type RequestsTestSuite struct { @@ -61,7 +61,7 @@ func (s *RequestsTestSuite) SetupTest() { func (s *RequestsTestSuite) TearDownSuite() { s.NoError(s.db.Unscoped().Delete(&models.Job{}, "aco_id = ?", s.acoID).Error) s.NoError(s.db.Unscoped().Delete(&models.ACO{}, "uuid = ?", s.acoID).Error) - s.db.Close() + database.Close(s.db) } func (s *RequestsTestSuite) TearDownTest() { diff --git a/bcda/api/v1/api.go b/bcda/api/v1/api.go index 247fee93e..897f3a1c7 100644 --- a/bcda/api/v1/api.go +++ b/bcda/api/v1/api.go @@ -61,7 +61,7 @@ func BulkPatientRequest(w http.ResponseWriter, r *http.Request) { Start data export (for the specified group identifier) for all supported resource types Initiates a job to collect data from the Blue Button API for your ACO. The supported Group identifiers are `all` and `runout`. - + The `all` identifier returns data for the group of all patients attributed to the requesting ACO. If used when specifying `_since`: all claims data which has been updated since the specified date will be returned for beneficiaries which have been attributed to the ACO since before the specified date; and all historical claims data will be returned for beneficiaries which have been newly attributed to the ACO since the specified date. The `runout` identifier returns claims runouts data. @@ -113,7 +113,7 @@ func JobStatus(w http.ResponseWriter, r *http.Request) { defer database.Close(db) var job models.Job - err := db.Find(&job, "id = ?", jobID).Error + err := db.First(&job, jobID).Error if err != nil { log.Print(err) oo := responseutils.CreateOpOutcome(responseutils.Error, responseutils.Exception, responseutils.DbErr, "") diff --git a/bcda/api/v1/api_test.go b/bcda/api/v1/api_test.go index 3efa40d21..a3a11a23d 100644 --- a/bcda/api/v1/api_test.go +++ b/bcda/api/v1/api_test.go @@ -19,10 +19,10 @@ import ( fhirmodels "github.com/eug48/fhir/models" "github.com/go-chi/chi" "github.com/jackc/pgx" - "github.com/jinzhu/gorm" "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "gorm.io/gorm" api "github.com/CMSgov/bcda-app/bcda/api" "github.com/CMSgov/bcda-app/bcda/auth" @@ -1050,14 +1050,14 @@ func (s *APITestSuite) TestAuthInfoOkta() { auth.SetProvider(originalProvider) } -func (s *APITestSuite) verifyJobCount(acoID string, expectedJobCount int) { +func (s *APITestSuite) verifyJobCount(acoID string, expectedJobCount int64) { count, err := s.getJobCount(acoID) assert.NoError(s.T(), err) assert.Equal(s.T(), expectedJobCount, count) } -func (s *APITestSuite) getJobCount(acoID string) (int, error) { - var count int +func (s *APITestSuite) getJobCount(acoID string) (int64, error) { + var count int64 err := s.db.Model(&models.Job{}).Where("aco_id = ?", acoID).Count(&count).Error return count, err } diff --git a/bcda/api/v2/api_test.go b/bcda/api/v2/api_test.go index c5d300718..e409e0920 100644 --- a/bcda/api/v2/api_test.go +++ b/bcda/api/v2/api_test.go @@ -19,11 +19,11 @@ import ( "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/models" "github.com/go-chi/chi" - "github.com/jinzhu/gorm" "github.com/pborman/uuid" "github.com/samply/golang-fhir-models/fhir-models/fhir" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "gorm.io/gorm" ) const ( @@ -57,7 +57,7 @@ func (s *APITestSuite) SetupSuite() { func (s *APITestSuite) TearDownSuite() { s.cleanup() - s.db.Close() + database.Close(s.db) } func TestAPITestSuite(t *testing.T) { diff --git a/bcda/auth/alpha_test.go b/bcda/auth/alpha_test.go index 28dd602e5..0fef465e7 100644 --- a/bcda/auth/alpha_test.go +++ b/bcda/auth/alpha_test.go @@ -10,10 +10,10 @@ import ( "github.com/CMSgov/bcda-app/bcda/models" "github.com/CMSgov/bcda-app/bcda/testUtils" "github.com/dgrijalva/jwt-go" - "github.com/jinzhu/gorm" "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "gorm.io/gorm" ) var connections = make(map[string]*gorm.DB) @@ -52,7 +52,11 @@ func (s *AlphaAuthPluginTestSuite) AfterTest(suiteName, testName string) { if !ok { s.FailNow("WTF? no db connection for %s", testName) } - if err := c.Close(); err != nil { + gc, err := c.DB() + if err != nil { + s.FailNow("error retrieving db connection: %s", err) + } + if err := gc.Close(); err != nil { s.FailNow("error closing db connection for %s because %s", testName, err) } } diff --git a/bcda/auth/api_test.go b/bcda/auth/api_test.go index fb75c7c64..a22761edc 100644 --- a/bcda/auth/api_test.go +++ b/bcda/auth/api_test.go @@ -3,17 +3,18 @@ package auth_test import ( "encoding/json" "fmt" - "github.com/go-chi/chi" "io/ioutil" "net/http" "net/http/httptest" "testing" + "github.com/go-chi/chi" + "github.com/CMSgov/bcda-app/bcda/constants" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "gorm.io/gorm" "github.com/CMSgov/bcda-app/bcda/auth" "github.com/CMSgov/bcda-app/bcda/database" diff --git a/bcda/auth/middleware.go b/bcda/auth/middleware.go index 683dda1f7..47924e450 100644 --- a/bcda/auth/middleware.go +++ b/bcda/auth/middleware.go @@ -164,7 +164,7 @@ func RequireTokenJobMatch(next http.Handler) http.Handler { defer database.Close(db) var job models.Job - err = db.Find(&job, "id = ? and aco_id = ?", i, ad.ACOID).Error + err = db.First(&job, "id = ? and aco_id = ?", i, ad.ACOID).Error if err != nil { log.Error(err) oo := responseutils.CreateOpOutcome(responseutils.Error, responseutils.Exception, responseutils.Not_found, "") diff --git a/bcda/auth/models_test.go b/bcda/auth/models_test.go index dc0e02b4b..4e64a8e1e 100644 --- a/bcda/auth/models_test.go +++ b/bcda/auth/models_test.go @@ -7,10 +7,10 @@ import ( "github.com/CMSgov/bcda-app/bcda/auth" "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/testUtils" - "github.com/jinzhu/gorm" "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "gorm.io/gorm" ) type ModelsTestSuite struct { diff --git a/bcda/auth/okta_test.go b/bcda/auth/okta_test.go index d6e7b9217..cde21332f 100644 --- a/bcda/auth/okta_test.go +++ b/bcda/auth/okta_test.go @@ -1,6 +1,7 @@ package auth import ( + "errors" "fmt" "regexp" "testing" @@ -12,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "gorm.io/gorm" ) const KnownFixtureACO = "DBBD1CE1-AE24-435C-807D-ED45953077D3" @@ -25,14 +27,14 @@ type OktaAuthPluginTestSuite struct { func (s *OktaAuthPluginTestSuite) SetupSuite() { db := database.GetGORMDbConnection() + defer func() { - if err := db.Close(); err != nil { - assert.Failf(s.T(), err.Error(), "okta plugin test") - } + database.Close(db) }() var aco models.ACO - if db.Find(&aco, "UUID = ?", uuid.Parse(KnownFixtureACO)).RecordNotFound() { + err := db.First(&aco, "UUID = ?", uuid.Parse(KnownFixtureACO)).Error + if errors.Is(err, gorm.ErrRecordNotFound) { assert.NotNil(s.T(), fmt.Errorf("Unable to find ACO %s", KnownFixtureACO)) return } diff --git a/bcda/auth/ssas_test.go b/bcda/auth/ssas_test.go index ab5e4bc5b..9bfe4bb1a 100644 --- a/bcda/auth/ssas_test.go +++ b/bcda/auth/ssas_test.go @@ -56,7 +56,7 @@ func (s *SSASPluginTestSuite) SetupSuite() { func (s *SSASPluginTestSuite) SetupTest() { db := database.GetGORMDbConnection() - defer db.Close() + defer database.Close(db) db.Create(&models.ACO{ UUID: uuid.Parse(testACOUUID), @@ -75,7 +75,7 @@ func (s *SSASPluginTestSuite) TearDownTest() { os.Setenv("BCDA_SSAS_SECRET", origSSASSecret) db := database.GetGORMDbConnection() - defer db.Close() + defer database.Close(db) db.Unscoped().Delete(models.ACO{}, "uuid = ?", testACOUUID) } diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index 7a6a19bdd..c81de9b0d 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -524,7 +524,7 @@ func createGroup(id, name, acoID string) (string, error) { aco.GroupID = ssasID db := database.GetGORMDbConnection() - defer db.Close() + defer database.Close(db) err = db.Save(&aco).Error if err != nil { @@ -678,7 +678,7 @@ func setBlacklistState(cmsID string, blacklistState bool) error { return err } db := database.GetGORMDbConnection() - defer db.Close() + defer database.Close(db) return db.Model(&aco).Update("blacklisted", blacklistState).Error } diff --git a/bcda/bcdacli/cli_test.go b/bcda/bcdacli/cli_test.go index a09596704..0f93deb25 100644 --- a/bcda/bcdacli/cli_test.go +++ b/bcda/bcdacli/cli_test.go @@ -22,12 +22,12 @@ import ( "github.com/CMSgov/bcda-app/bcda/testUtils" "github.com/CMSgov/bcda-app/bcda/utils" "github.com/go-chi/chi" - "github.com/jinzhu/gorm" "github.com/pborman/uuid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "github.com/urfave/cli" + "gorm.io/gorm" ) var origDate string diff --git a/bcda/cclf/cclf.go b/bcda/cclf/cclf.go index 1a658eb28..f8a081c09 100644 --- a/bcda/cclf/cclf.go +++ b/bcda/cclf/cclf.go @@ -12,9 +12,9 @@ import ( "strings" "time" - "github.com/jinzhu/gorm" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "gorm.io/gorm" "github.com/CMSgov/bcda-app/bcda/cclf/metrics" "github.com/CMSgov/bcda-app/bcda/constants" @@ -241,7 +241,11 @@ func importCCLF8(ctx context.Context, fileMetadata *cclfFileMetadata) (err error sc := bufio.NewScanner(rc) // Open transaction to encompass entire CCLF file ingest. - txn, err := db.DB().Begin() + sdb, err := db.DB() + if err != nil { + return err + } + txn, err := sdb.Begin() if err != nil { return err } @@ -395,7 +399,7 @@ func orderACOs(cclfMap map[string]map[metadataKey][]*cclfFileMetadata) []string var acoOrder []string db := database.GetGORMDbConnection() - defer db.Close() + defer database.Close(db) priorityACOs := getPriorityACOs(db) for _, acoID := range priorityACOs { @@ -423,7 +427,7 @@ func getPriorityACOs(db *gorm.DB) []string { ` var acoIDs []string - if err := db.Raw(query).Pluck("aco_id", &acoIDs).Error; err != nil { + if err := db.Raw(query).Scan(&acoIDs).Error; err != nil { log.Warnf("Failed to query for active ACOs %s. No ACOs are prioritized.", err.Error()) return nil } diff --git a/bcda/cclf/cclf_test.go b/bcda/cclf/cclf_test.go index c50efae88..970c20af0 100644 --- a/bcda/cclf/cclf_test.go +++ b/bcda/cclf/cclf_test.go @@ -19,7 +19,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/constants" "github.com/CMSgov/bcda-app/bcda/testUtils" - "github.com/jinzhu/gorm" + "gorm.io/gorm" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" @@ -336,19 +336,10 @@ func (s *CCLFTestSuite) TestGetPriorityACOs() { for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - gdb, err := gorm.Open("postgres", db) - if err != nil { - t.Fatalf("Failed to instantiate gorm db %s", err.Error()) - } - gdb.LogMode(true) + gdb, mock := testUtils.GetGormMock(t) defer func() { assert.NoError(t, mock.ExpectationsWereMet()) - gdb.Close() - db.Close() + database.Close(gdb) }() expected := mock.ExpectQuery(query) diff --git a/bcda/database/connection.go b/bcda/database/connection.go index 3b84f2a8c..dadd5b1d7 100644 --- a/bcda/database/connection.go +++ b/bcda/database/connection.go @@ -5,9 +5,9 @@ import ( "log" "os" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/postgres" _ "github.com/lib/pq" + "gorm.io/driver/postgres" + "gorm.io/gorm" ) // Variable substitution to support testing. @@ -28,13 +28,9 @@ func GetDbConnection() *sql.DB { func GetGORMDbConnection() *gorm.DB { databaseURL := os.Getenv("DATABASE_URL") - db, err := gorm.Open("postgres", databaseURL) + db, err := gorm.Open(postgres.Open(databaseURL), &gorm.Config{}) if err != nil { LogFatal(err) } - pingErr := db.DB().Ping() - if pingErr != nil { - LogFatal(pingErr) - } return db } diff --git a/bcda/database/connection_test.go b/bcda/database/connection_test.go index fd14f75e2..5edeec29e 100644 --- a/bcda/database/connection_test.go +++ b/bcda/database/connection_test.go @@ -6,9 +6,9 @@ import ( "os" "testing" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "gorm.io/gorm" ) type ConnectionTestSuite struct { @@ -40,11 +40,12 @@ func (suite *ConnectionTestSuite) TestDbConnections() { // asert that Ping returns an error assert.NotNil(suite.T(), suite.db.Ping(), "Database should fail to connect (negative scenario)") - assert.NotNil(suite.T(), suite.gormdb.DB().Ping(), "Gorm database should fail to connect (negative scenario)") + gdb, _ := suite.gormdb.DB() + assert.NotNil(suite.T(), gdb.Ping(), "Gorm database should fail to connect (negative scenario)") // close DBs to reset the test suite.db.Close() - suite.gormdb.Close() + Close(suite.gormdb) // set the database URL back to the real value to test the positive scenarios os.Setenv("DATABASE_URL", actualDatabaseURL) @@ -53,11 +54,12 @@ func (suite *ConnectionTestSuite) TestDbConnections() { defer suite.db.Close() suite.gormdb = GetGORMDbConnection() - defer suite.gormdb.Close() + defer Close(suite.gormdb) // assert that Ping() does not return an error assert.Nil(suite.T(), suite.db.Ping(), "Error connecting to sql database") - assert.Nil(suite.T(), suite.gormdb.DB().Ping(), "Error connecting to gorm database") + gdb, _ = suite.gormdb.DB() + assert.Nil(suite.T(), gdb.Ping(), "Error connecting to gorm database") } diff --git a/bcda/database/utils.go b/bcda/database/utils.go index 2b1bb8700..387be01f9 100644 --- a/bcda/database/utils.go +++ b/bcda/database/utils.go @@ -3,12 +3,17 @@ package database import ( "runtime" - "github.com/jinzhu/gorm" log "github.com/sirupsen/logrus" + "gorm.io/gorm" ) func Close(db *gorm.DB) { - if err := db.Close(); err != nil { + dbc, err := db.DB() + if err != nil { + log.Infof("failed to retrieve db connection: %v", err) + return + } + if err := dbc.Close(); err != nil { _, file, line, _ := runtime.Caller(1) log.Infof("failed to close db connection at %s#%d because %s", file, line, err) } diff --git a/bcda/models/models.go b/bcda/models/models.go index afae5931e..67db22190 100644 --- a/bcda/models/models.go +++ b/bcda/models/models.go @@ -16,10 +16,10 @@ import ( "github.com/CMSgov/bcda-app/bcda/client" "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/utils" - "github.com/jinzhu/gorm" "github.com/pborman/uuid" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "gorm.io/gorm" ) type Job struct { @@ -31,7 +31,6 @@ type Job struct { TransactionTime time.Time // most recent data load transaction time from BFD JobCount int CompletedJobCount int - JobKeys []JobKey } func (job *Job) CheckCompletedAndCleanup(db *gorm.DB) (bool, error) { @@ -83,7 +82,6 @@ func (j *Job) StatusMessage() string { type JobKey struct { gorm.Model - Job Job `gorm:"foreignkey:jobID"` JobID uint `gorm:"primary_key" json:"job_id"` FileName string `gorm:"type:char(127)"` ResourceType string @@ -223,7 +221,7 @@ type CCLFFile struct { func (cclfFile *CCLFFile) Delete() error { db := database.GetGORMDbConnection() - defer db.Close() + defer database.Close(db) err := db.Unscoped().Where("file_id = ?", cclfFile.ID).Delete(&CCLFBeneficiary{}).Error if err != nil { return err @@ -235,11 +233,11 @@ func (cclfFile *CCLFFile) Delete() error { // https://www.cms.gov/Medicare/New-Medicare-Card/Understanding-the-MBI-with-Format.pdf type CCLFBeneficiary struct { gorm.Model - CCLFFile CCLFFile - FileID uint `gorm:"not null;index:idx_cclf_beneficiaries_file_id"` - HICN string `gorm:"type:varchar(11);not null;index:idx_cclf_beneficiaries_hicn"` - MBI string `gorm:"type:char(11);not null;index:idx_cclf_beneficiaries_mbi"` - BlueButtonID string `gorm:"type: text;index:idx_cclf_beneficiaries_bb_id"` + CCLFFile CCLFFile `gorm:"foreignkey:file_id;association_foreignkey:id"` + FileID uint `gorm:"not null;index:idx_cclf_beneficiaries_file_id"` + HICN string `gorm:"type:varchar(11);not null;index:idx_cclf_beneficiaries_hicn"` + MBI string `gorm:"type:char(11);not null;index:idx_cclf_beneficiaries_mbi"` + BlueButtonID string `gorm:"type: text;index:idx_cclf_beneficiaries_bb_id"` } type SuppressionFile struct { @@ -251,7 +249,7 @@ type SuppressionFile struct { func (suppressionFile *SuppressionFile) Delete() error { db := database.GetGORMDbConnection() - defer db.Close() + defer database.Close(db) err := db.Unscoped().Where("file_id = ?", suppressionFile.ID).Delete(&Suppression{}).Error if err != nil { return err @@ -261,7 +259,6 @@ func (suppressionFile *SuppressionFile) Delete() error { type Suppression struct { gorm.Model - SuppressionFile SuppressionFile FileID uint `gorm:"not null"` MBI string `gorm:"type:varchar(11);index:idx_suppression_mbi"` HICN string `gorm:"type:varchar(11)"` diff --git a/bcda/models/models_test.go b/bcda/models/models_test.go index 7bea8621d..d1436638d 100644 --- a/bcda/models/models_test.go +++ b/bcda/models/models_test.go @@ -19,10 +19,10 @@ import ( "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/testUtils" "github.com/go-chi/chi" - "github.com/jinzhu/gorm" "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "gorm.io/gorm" ) type ModelsTestSuite struct { @@ -493,7 +493,7 @@ func (s *ModelsTestSuite) TestDuplicateCCLFFileNames() { for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { var err error - var expectedFileCount int + var expectedFileCount int64 for _, acoID := range tt.acoIDs { cclfFile := &CCLFFile{ Name: tt.fileName, @@ -517,7 +517,7 @@ func (s *ModelsTestSuite) TestDuplicateCCLFFileNames() { assert.NoError(t, err) } - var count int + var count int64 s.db.Model(&CCLFFile{}).Where("name = ?", tt.fileName).Count(&count) assert.True(t, expectedFileCount > 0) assert.Equal(t, expectedFileCount, count) @@ -538,11 +538,11 @@ func (s *ModelsTestSuite) TestCMSID() { defer s.db.Unscoped().Delete(aco) var actualCMSID []string - assert.NoError(s.T(), s.db.Find(&ACO{}, "id = ?", aco.ID).Pluck("cms_id", &actualCMSID).Error) + assert.NoError(s.T(), s.db.Model(&ACO{}).Where("id = ?", aco.ID).Pluck("cms_id", &actualCMSID).Error) assert.Equal(s.T(), 1, len(actualCMSID)) assert.Equal(s.T(), cmsID, actualCMSID[0]) - assert.NoError(s.T(), s.db.Find(&CCLFFile{}, "id = ?", cclfFile.ID).Pluck("aco_cms_id", &actualCMSID).Error) + assert.NoError(s.T(), s.db.Model(&CCLFFile{}).Where("id = ?", cclfFile.ID).Pluck("aco_cms_id", &actualCMSID).Error) assert.Equal(s.T(), 1, len(actualCMSID)) assert.Equal(s.T(), cmsID, actualCMSID[0]) } diff --git a/bcda/models/postgres/repository.go b/bcda/models/postgres/repository.go index d7d7e53e2..a9117e223 100644 --- a/bcda/models/postgres/repository.go +++ b/bcda/models/postgres/repository.go @@ -1,13 +1,14 @@ package postgres import ( + "errors" "strconv" "time" "github.com/CMSgov/bcda-app/bcda/constants" "github.com/CMSgov/bcda-app/bcda/models" - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) // Ensure Repository satisfies the interface @@ -52,7 +53,7 @@ func (r *Repository) GetLatestCCLFFile(cmsID string, cclfNum int, importStatus s } result = result.Order("timestamp DESC").First(&cclfFile) - if result.RecordNotFound() { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } @@ -70,16 +71,11 @@ func (r *Repository) GetCCLFBeneficiaryMBIs(cclfFileID uint) ([]string, error) { } func (r *Repository) GetCCLFBeneficiaries(cclfFileID uint, ignoredMBIs []string) ([]*models.CCLFBeneficiary, error) { - - const ( - // this is used to get unique ids for de-duplicating MBIs that are listed multiple times in the CCLF8 file - idQuery = "SELECT id FROM ( SELECT max(id) as id, mbi FROM cclf_beneficiaries WHERE file_id = ? GROUP BY mbi ) as id" - ) var beneficiaries []*models.CCLFBeneficiary // NOTE: We changed the query that was being used for "old benes" // By querying by IDs, we really should not need to also query by the corresponding MBIs as well - query := r.db.Where("id in (?)", r.db.Raw(idQuery, cclfFileID).SubQuery()) + query := r.db.Where("id in (?)", r.db.Table("cclf_beneficiaries").Select("MAX(id)").Where("file_id = ?", cclfFileID).Group("mbi")) if len(ignoredMBIs) != 0 { query = query.Not("mbi", ignoredMBIs) @@ -100,12 +96,12 @@ func (r *Repository) GetSuppressedMBIs(lookbackDays int) ([]string, error) { FROM ( SELECT mbi, MAX(effective_date) max_date FROM suppressions - WHERE (NOW() - interval '`+strconv.Itoa(lookbackDays)+` days') < effective_date AND effective_date <= NOW() + WHERE (NOW() - interval '` + strconv.Itoa(lookbackDays) + ` days') < effective_date AND effective_date <= NOW() AND preference_indicator != '' GROUP BY mbi ) h JOIN suppressions s ON s.mbi = h.mbi and s.effective_date = h.max_date - WHERE preference_indicator = 'N'`).Pluck("mbi", &suppressedMBIs).Error; err != nil { + WHERE preference_indicator = 'N'`).Scan(&suppressedMBIs).Error; err != nil { return nil, err } diff --git a/bcda/models/postgres/repository_test.go b/bcda/models/postgres/repository_test.go index 7c658be59..881675829 100644 --- a/bcda/models/postgres/repository_test.go +++ b/bcda/models/postgres/repository_test.go @@ -11,9 +11,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/CMSgov/bcda-app/bcda/constants" + "github.com/CMSgov/bcda-app/bcda/database" + "github.com/CMSgov/bcda-app/bcda/testUtils" "github.com/CMSgov/bcda-app/bcda/models" - "github.com/jinzhu/gorm" + "gorm.io/gorm" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/suite" @@ -45,7 +47,7 @@ func (r *RepositoryTestSuite) TestGetLatestCCLFFile() { time.Time{}, time.Time{}, models.FileTypeDefault, - `SELECT * FROM "cclf_files" WHERE "cclf_files"."deleted_at" IS NULL AND ((aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4)) ORDER BY timestamp DESC`, + `SELECT * FROM "cclf_files" WHERE (aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4) AND "cclf_files"."deleted_at" IS NULL ORDER BY timestamp DESC,"cclf_files"."id" LIMIT 1`, getCCLFFile(cclfNum, cmsID, importStatus, models.FileTypeDefault), }, { @@ -53,7 +55,7 @@ func (r *RepositoryTestSuite) TestGetLatestCCLFFile() { time.Time{}, time.Time{}, models.FileTypeRunout, - `SELECT * FROM "cclf_files" WHERE "cclf_files"."deleted_at" IS NULL AND ((aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4)) ORDER BY timestamp DESC`, + `SELECT * FROM "cclf_files" WHERE (aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4) AND "cclf_files"."deleted_at" IS NULL ORDER BY timestamp DESC,"cclf_files"."id" LIMIT 1`, getCCLFFile(cclfNum, cmsID, importStatus, models.FileTypeRunout), }, { @@ -61,7 +63,7 @@ func (r *RepositoryTestSuite) TestGetLatestCCLFFile() { time.Now(), time.Time{}, models.FileTypeDefault, - `SELECT * FROM "cclf_files" WHERE "cclf_files"."deleted_at" IS NULL AND ((aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4 AND timestamp >= $5)) ORDER BY timestamp DESC`, + `SELECT * FROM "cclf_files" WHERE (aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4 AND timestamp >= $5) AND "cclf_files"."deleted_at" IS NULL ORDER BY timestamp DESC,"cclf_files"."id" LIMIT 1`, getCCLFFile(cclfNum, cmsID, importStatus, models.FileTypeDefault), }, { @@ -69,7 +71,7 @@ func (r *RepositoryTestSuite) TestGetLatestCCLFFile() { time.Time{}, time.Now(), models.FileTypeDefault, - `SELECT * FROM "cclf_files" WHERE "cclf_files"."deleted_at" IS NULL AND ((aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4 AND timestamp <= $5)) ORDER BY timestamp DESC`, + `SELECT * FROM "cclf_files" WHERE (aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4 AND timestamp <= $5) AND "cclf_files"."deleted_at" IS NULL ORDER BY timestamp DESC,"cclf_files"."id" LIMIT 1`, getCCLFFile(cclfNum, cmsID, importStatus, models.FileTypeDefault), }, { @@ -77,7 +79,7 @@ func (r *RepositoryTestSuite) TestGetLatestCCLFFile() { time.Now(), time.Now(), models.FileTypeDefault, - `SELECT * FROM "cclf_files" WHERE "cclf_files"."deleted_at" IS NULL AND ((aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4 AND timestamp >= $5 AND timestamp <= $6)) ORDER BY timestamp DESC`, + `SELECT * FROM "cclf_files" WHERE (aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4 AND timestamp >= $5 AND timestamp <= $6) AND "cclf_files"."deleted_at" IS NULL ORDER BY timestamp DESC,"cclf_files"."id" LIMIT 1`, getCCLFFile(cclfNum, cmsID, importStatus, models.FileTypeDefault), }, { @@ -85,7 +87,7 @@ func (r *RepositoryTestSuite) TestGetLatestCCLFFile() { time.Time{}, time.Time{}, models.FileTypeDefault, - `SELECT * FROM "cclf_files" WHERE "cclf_files"."deleted_at" IS NULL AND ((aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4)) ORDER BY timestamp DESC`, + `SELECT * FROM "cclf_files" WHERE (aco_cms_id = $1 AND cclf_num = $2 AND import_status = $3 AND type = $4) AND "cclf_files"."deleted_at" IS NULL ORDER BY timestamp DESC,"cclf_files"."id" LIMIT 1`, nil, }, } @@ -93,20 +95,11 @@ func (r *RepositoryTestSuite) TestGetLatestCCLFFile() { for _, tt := range tests { r.T().Run(tt.name, func(t *testing.T) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - gdb, err := gorm.Open("postgres", db) - if err != nil { - t.Fatalf("Failed to instantiate gorm db %s", err.Error()) - } - + gdb, mock := testUtils.GetGormMock(t) defer func() { - err = mock.ExpectationsWereMet() + err := mock.ExpectationsWereMet() assert.NoError(t, err) - gdb.Close() - db.Close() + database.Close(gdb) }() repository := NewRepository(gdb) @@ -148,12 +141,12 @@ func (r *RepositoryTestSuite) TestGetCCLFBeneficiaryMBIs() { }{ { "HappyPath", - `SELECT mbi FROM "cclf_beneficiaries" WHERE (file_id = $1)`, + `SELECT "mbi" FROM "cclf_beneficiaries" WHERE file_id = $1`, nil, }, { "ErrorOnQuery", - `SELECT mbi FROM "cclf_beneficiaries" WHERE (file_id = $1)`, + `SELECT "mbi" FROM "cclf_beneficiaries" WHERE file_id = $1`, fmt.Errorf("Some SQL error"), }, } @@ -163,20 +156,12 @@ func (r *RepositoryTestSuite) TestGetCCLFBeneficiaryMBIs() { mbis := []string{"0", "1", "2"} cclfFileID := uint(rand.Int63()) - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - gdb, err := gorm.Open("postgres", db) - if err != nil { - t.Fatalf("Failed to instantiate gorm db %s", err.Error()) - } + gdb, mock := testUtils.GetGormMock(t) defer func() { - err = mock.ExpectationsWereMet() + err := mock.ExpectationsWereMet() assert.NoError(t, err) - gdb.Close() - db.Close() + database.Close(gdb) }() repository := NewRepository(gdb) @@ -215,7 +200,7 @@ func (r *RepositoryTestSuite) TestGetCCLFBeneficiaries() { }{ { "NoIgnoreMBIs", - `SELECT * FROM "cclf_beneficiaries" WHERE "cclf_beneficiaries"."deleted_at" IS NULL AND ((id in (( SELECT id FROM ( SELECT max(id) as id, mbi FROM cclf_beneficiaries WHERE file_id = $1 GROUP BY mbi ) as id))))`, + `SELECT * FROM "cclf_beneficiaries" WHERE id in (SELECT MAX(id) FROM "cclf_beneficiaries" WHERE file_id = $1 GROUP BY "mbi") AND "cclf_beneficiaries"."deleted_at" IS NULL`, nil, []*models.CCLFBeneficiary{ getCCLFBeneficiary(), @@ -227,7 +212,7 @@ func (r *RepositoryTestSuite) TestGetCCLFBeneficiaries() { }, { "IgnoredMBIs", - `SELECT * FROM "cclf_beneficiaries" WHERE "cclf_beneficiaries"."deleted_at" IS NULL AND ((id in (( SELECT id FROM ( SELECT max(id) as id, mbi FROM cclf_beneficiaries WHERE file_id = $1 GROUP BY mbi ) as id))) AND ("cclf_beneficiaries"."mbi" NOT IN ($2,$3)))`, + `SELECT * FROM "cclf_beneficiaries" WHERE id in (SELECT MAX(id) FROM "cclf_beneficiaries" WHERE file_id = $1 GROUP BY "mbi") AND "mbi" <> ($2,$3) AND "cclf_beneficiaries"."deleted_at`, []string{"123", "456"}, []*models.CCLFBeneficiary{ getCCLFBeneficiary(), @@ -236,7 +221,7 @@ func (r *RepositoryTestSuite) TestGetCCLFBeneficiaries() { }, { "ErrorOnQuery", - `SELECT * FROM "cclf_beneficiaries" WHERE "cclf_beneficiaries"."deleted_at" IS NULL AND ((id in (( SELECT id FROM ( SELECT max(id) as id, mbi FROM cclf_beneficiaries WHERE file_id = $1 GROUP BY mbi ) as id))))`, + `SELECT * FROM "cclf_beneficiaries" WHERE id in (SELECT MAX(id) FROM "cclf_beneficiaries" WHERE file_id = $1 GROUP BY "mbi") AND "cclf_beneficiaries"."deleted_at" IS NULL`, nil, nil, fmt.Errorf("Some SQL error"), @@ -246,21 +231,11 @@ func (r *RepositoryTestSuite) TestGetCCLFBeneficiaries() { for _, tt := range tests { r.T().Run(tt.name, func(t *testing.T) { cclfFileID := uint(rand.Int63()) - - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - gdb, err := gorm.Open("postgres", db) - if err != nil { - t.Fatalf("Failed to instantiate gorm db %s", err.Error()) - } - + gdb, mock := testUtils.GetGormMock(t) defer func() { - err = mock.ExpectationsWereMet() + err := mock.ExpectationsWereMet() assert.NoError(t, err) - gdb.Close() - db.Close() + database.Close(gdb) }() repository := NewRepository(gdb) @@ -321,21 +296,11 @@ func (r *RepositoryTestSuite) TestGetSuppressedMBIs() { for _, tt := range tests { r.T().Run(tt.name, func(t *testing.T) { suppressedMBIs := []string{"0", "1", "2"} - - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - gdb, err := gorm.Open("postgres", db) - if err != nil { - t.Fatalf("Failed to instantiate gorm db %s", err.Error()) - } - + gdb, mock := testUtils.GetGormMock(t) defer func() { - err = mock.ExpectationsWereMet() + err := mock.ExpectationsWereMet() assert.NoError(t, err) - gdb.Close() - db.Close() + database.Close(gdb) }() repository := NewRepository(gdb) diff --git a/bcda/models/service_test.go b/bcda/models/service_test.go index 1e485618d..1d1a6a86c 100644 --- a/bcda/models/service_test.go +++ b/bcda/models/service_test.go @@ -16,8 +16,8 @@ import ( "github.com/stretchr/testify/assert" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/suite" + "gorm.io/gorm" ) const ( diff --git a/bcda/suppression/suppression.go b/bcda/suppression/suppression.go index 160d3be08..24bf13284 100644 --- a/bcda/suppression/suppression.go +++ b/bcda/suppression/suppression.go @@ -4,18 +4,19 @@ import ( "bufio" "bytes" "fmt" - "github.com/CMSgov/bcda-app/bcda/constants" "os" "path/filepath" "regexp" "strconv" "time" + "github.com/CMSgov/bcda-app/bcda/constants" + "github.com/CMSgov/bcda-app/bcda/utils" - "github.com/jinzhu/gorm" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "gorm.io/gorm" "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/models" diff --git a/bcda/suppression/suppression_test.go b/bcda/suppression/suppression_test.go index 733f491b4..c123f8957 100644 --- a/bcda/suppression/suppression_test.go +++ b/bcda/suppression/suppression_test.go @@ -13,7 +13,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/constants" "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/testUtils" - "github.com/jinzhu/gorm" + "gorm.io/gorm" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" @@ -56,7 +56,7 @@ func TestSuppressionTestSuite(t *testing.T) { func (s *SuppressionTestSuite) TestImportSuppression() { assert := assert.New(s.T()) db := database.GetGORMDbConnection() - defer db.Close() + defer database.Close(db) // 181120 file fileTime, _ := time.Parse(time.RFC3339, "2018-11-20T10:00:00Z") @@ -129,7 +129,7 @@ func (s *SuppressionTestSuite) TestImportSuppression() { func (s *SuppressionTestSuite) TestImportSuppression_MissingData() { assert := assert.New(s.T()) db := database.GetGORMDbConnection() - defer db.Close() + defer database.Close(db) metadata := &suppressionFileMetadata{} err := importSuppressionData(metadata) diff --git a/bcda/testUtils/utils.go b/bcda/testUtils/utils.go index f1209d8f9..cbbd06678 100644 --- a/bcda/testUtils/utils.go +++ b/bcda/testUtils/utils.go @@ -10,9 +10,13 @@ import ( "path/filepath" "testing" + "github.com/DATA-DOG/go-sqlmock" "github.com/otiai10/copy" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" ) // PrintSeparator prints a line of stars to stdout @@ -135,3 +139,27 @@ func GetRandomIPV4Address(t *testing.T) string { return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3]) } + +// GetGormMock returns a gorm.DB along with a sqlmock instance used for testing +// This implementation is based off a newer version of GORM's postgres driver. +// See: https://github.com/go-gorm/postgres/blob/v1.0.5/postgres.go#L24 +// In the newer versions, you can explicitly set the ConnPool on the postgres.Config struct. +// It allows the caller to inject sqlmock's db instance into gorm without forcing the caller to +// rely on connecting via DSN, which will always fail when using sqlmock. +func GetGormMock(t *testing.T) (*gorm.DB, sqlmock.Sqlmock) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + gdb, err := gorm.Open(nil, &gorm.Config{ + ConnPool: db, + }) + if err != nil { + t.Fatalf("Failed to instantiate gorm db %s", err.Error()) + } + gdb.Dialector = &postgres.Dialector{} + callbacks.RegisterDefaultCallbacks(gdb, &callbacks.Config{}) + + return gdb, mock +} diff --git a/bcdaworker/main.go b/bcdaworker/main.go index 4d45491d4..cd08e3fc6 100644 --- a/bcdaworker/main.go +++ b/bcdaworker/main.go @@ -5,6 +5,7 @@ import ( "context" "database/sql" "encoding/json" + goerrors "errors" "fmt" "os" "os/signal" @@ -14,11 +15,11 @@ import ( "github.com/bgentry/que-go" "github.com/jackc/pgx" - "github.com/jinzhu/gorm" "github.com/newrelic/go-agent/v3/newrelic" "github.com/pborman/uuid" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "gorm.io/gorm" "github.com/CMSgov/bcda-app/bcda/client" "github.com/CMSgov/bcda-app/bcda/database" @@ -87,7 +88,7 @@ func processJob(j *que.Job) error { var exportJob models.Job result := db.First(&exportJob, "ID = ?", jobArgs.ID) - if result.RecordNotFound() { + if goerrors.Is(result.Error, gorm.ErrRecordNotFound) { // Based on the current backoff delay (j.ErrorCount^4 + 3 seconds), this should've given // us plenty of headroom to ensure that the parent job will never be found. maxNotFoundRetries := int32(utils.GetEnvInt("BCDA_WORKER_MAX_JOB_NOT_FOUND_RETRIES", 3)) @@ -430,7 +431,7 @@ func updateJobStats(jID uint, db *gorm.DB) { var j models.Job if err := db.First(&j, jID).Error; err == nil { - db.Model(&j).Update(models.Job{CompletedJobCount: j.CompletedJobCount + 1}) + db.Model(&j).Update("completed_job_count", j.CompletedJobCount+1) } } diff --git a/bcdaworker/main_test.go b/bcdaworker/main_test.go index 376d00755..90182cc2f 100644 --- a/bcdaworker/main_test.go +++ b/bcdaworker/main_test.go @@ -13,8 +13,8 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/mock" + "gorm.io/gorm" "github.com/bgentry/que-go" "github.com/pborman/uuid" @@ -90,7 +90,7 @@ func (s *MainTestSuite) TearDownSuite() { testUtils.SetUnitTestKeysForAuth() s.db.Unscoped().Where("aco_id = ?", s.testACO.UUID).Delete(&models.Job{}) s.db.Unscoped().Delete(s.testACO) - s.db.Close() + database.Close(s.db) os.RemoveAll(os.Getenv("FHIR_STAGING_DIR")) os.RemoveAll(os.Getenv("FHIR_PAYLOAD_DIR")) } diff --git a/db/migrations/migrations_test.go b/db/migrations/migrations_test.go index 132115552..5fd9af37c 100644 --- a/db/migrations/migrations_test.go +++ b/db/migrations/migrations_test.go @@ -14,7 +14,8 @@ import ( "github.com/CMSgov/bcda-app/bcda/models" - "github.com/jinzhu/gorm" + "gorm.io/driver/postgres" + "gorm.io/gorm" "github.com/stretchr/testify/assert" @@ -77,11 +78,11 @@ func (s *MigrationTestSuite) TestBCDAMigration() { migrationPath: "./bcda/", dbURL: s.bcdaDBURL, } - db, err := gorm.Open("postgres", s.bcdaDBURL) + db, err := gorm.Open(postgres.Open(s.bcdaDBURL), &gorm.Config{}) if err != nil { assert.FailNowf(s.T(), "Failed to open postgres connection", err.Error()) } - defer db.Close() + defer database.Close(db) migration1Tables := []string{"acos", "cclf_beneficiaries", "cclf_beneficiary_xrefs", "cclf_files", "job_keys", "jobs", "suppression_files", "suppressions"} @@ -96,7 +97,7 @@ func (s *MigrationTestSuite) TestBCDAMigration() { func(t *testing.T) { migrator.runMigration(t, "1") for _, table := range migration1Tables { - assert.True(t, db.HasTable(table), fmt.Sprintf("Table %s should exist", table)) + assert.True(t, db.Migrator().HasTable(table), fmt.Sprintf("Table %s should exist", table)) } }, }, @@ -220,7 +221,7 @@ func (s *MigrationTestSuite) TestBCDAMigration() { func(t *testing.T) { migrator.runMigration(t, "0") for _, table := range migration1Tables { - assert.False(t, db.HasTable(table), fmt.Sprintf("Table %s should not exist", table)) + assert.False(t, db.Migrator().HasTable(table), fmt.Sprintf("Table %s should not exist", table)) } }, }, @@ -236,11 +237,11 @@ func (s *MigrationTestSuite) TestBCDAQueueMigration() { migrationPath: "./bcda_queue/", dbURL: s.bcdaQueueDBURL, } - db, err := gorm.Open("postgres", s.bcdaQueueDBURL) + db, err := gorm.Open(postgres.Open(s.bcdaQueueDBURL), &gorm.Config{}) if err != nil { assert.FailNowf(s.T(), "Failed to open postgres connection", err.Error()) } - defer db.Close() + defer database.Close(db) migration1Tables := []string{"que_jobs"} @@ -254,7 +255,7 @@ func (s *MigrationTestSuite) TestBCDAQueueMigration() { func(t *testing.T) { migrator.runMigration(t, "1") for _, table := range migration1Tables { - assert.True(t, db.HasTable(table), fmt.Sprintf("Table %s should exist", table)) + assert.True(t, db.Migrator().HasTable(table), fmt.Sprintf("Table %s should exist", table)) } }, }, @@ -263,7 +264,7 @@ func (s *MigrationTestSuite) TestBCDAQueueMigration() { func(t *testing.T) { migrator.runMigration(t, "0") for _, table := range migration1Tables { - assert.False(t, db.HasTable(table), fmt.Sprintf("Table %s should not exist", table)) + assert.False(t, db.Migrator().HasTable(table), fmt.Sprintf("Table %s should not exist", table)) } }, }, diff --git a/vendor/github.com/jinzhu/gorm/association.go b/vendor/github.com/jinzhu/gorm/association.go deleted file mode 100644 index a73344fe6..000000000 --- a/vendor/github.com/jinzhu/gorm/association.go +++ /dev/null @@ -1,377 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Association Mode contains some helper methods to handle relationship things easily. -type Association struct { - Error error - scope *Scope - column string - field *Field -} - -// Find find out all related associations -func (association *Association) Find(value interface{}) *Association { - association.scope.related(value, association.column) - return association.setErr(association.scope.db.Error) -} - -// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to -func (association *Association) Append(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - if relationship := association.field.Relationship; relationship.Kind == "has_one" { - return association.Replace(values...) - } - return association.saveAssociations(values...) -} - -// Replace replace current associations with new one -func (association *Association) Replace(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - // Append new values - association.field.Set(reflect.Zero(association.field.Field.Type())) - association.saveAssociations(values...) - - // Belongs To - if relationship.Kind == "belongs_to" { - // Set foreign key to be null when clearing value (length equals 0) - if len(values) == 0 { - // Set foreign key to be nil - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) - } - } else { - // Polymorphic Relations - if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - - // Delete Relations except new created - if len(values) > 0 { - var associationForeignFieldNames, associationForeignDBNames []string - if relationship.Kind == "many_to_many" { - // if many to many relations, get association fields name from association foreign keys - associationScope := scope.New(reflect.New(field.Type()).Interface()) - for idx, dbName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(dbName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) - } - } - } else { - // If has one/many relations, use primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, field.DBName) - } - } - - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) - - if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) - } - } - - if relationship.Kind == "many_to_many" { - // if many to many relations, delete related relations from join table - var sourceForeignFieldNames []string - - for _, dbName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) - } - } - - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { - newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) - var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - return association -} - -// Delete remove relationship between source & passed arguments, but won't delete those arguments -func (association *Association) Delete(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - if len(values) == 0 { - return association - } - - var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } - - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) - - if relationship.Kind == "many_to_many" { - // source value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - // get association's foreign fields name - var associationScope = scope.New(reflect.New(field.Type()).Interface()) - var associationForeignFieldNames []string - for _, associationDBName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(associationDBName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - - // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } else { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - - if relationship.Kind == "belongs_to" { - // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // set foreign key to be null if there are some records affected - modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap) - } - } else { - association.setErr(results.Error) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // only include those deleting relations - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), - toQueryValues(deletingPrimaryKeys)..., - ) - - // set matched relation's foreign key to be null - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - - // Remove deleted records from source's field - if association.Error == nil { - if field.Kind() == reflect.Slice { - leftValues := reflect.Zero(field.Type()) - - for i := 0; i < field.Len(); i++ { - reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var isDeleted = false - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - isDeleted = true - break - } - } - if !isDeleted { - leftValues = reflect.Append(leftValues, reflectValue) - } - } - - association.field.Set(leftValues) - } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - association.field.Set(reflect.Zero(field.Type())) - break - } - } - } - } - - return association -} - -// Clear remove relationship between source & current associations, won't delete those associations -func (association *Association) Clear() *Association { - return association.Replace() -} - -// Count return the count of current associations -func (association *Association) Count() int { - var ( - count = 0 - relationship = association.field.Relationship - scope = association.scope - fieldValue = association.field.Field.Interface() - query = scope.DB() - ) - - switch relationship.Kind { - case "many_to_many": - query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - case "has_many", "has_one": - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - case "belongs_to": - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } - - if relationship.PolymorphicType != "" { - query = query.Where( - fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - relationship.PolymorphicValue, - ) - } - - if err := query.Model(fieldValue).Count(&count).Error; err != nil { - association.Error = err - } - return count -} - -// saveAssociations save passed values as associations -func (association *Association) saveAssociations(values ...interface{}) *Association { - var ( - scope = association.scope - field = association.field - relationship = field.Relationship - ) - - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr - } - - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) - } - } - - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) - } - } - - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) - - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) - } - } - } - - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) - } - } else { - association.setErr(errors.New("invalid value type")) - } - } - return association -} - -// setErr set error when the error is not nil. And return Association. -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err - } - return association -} diff --git a/vendor/github.com/jinzhu/gorm/callback.go b/vendor/github.com/jinzhu/gorm/callback.go deleted file mode 100644 index 1f0e3c79c..000000000 --- a/vendor/github.com/jinzhu/gorm/callback.go +++ /dev/null @@ -1,250 +0,0 @@ -package gorm - -import "fmt" - -// DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{logger: nopLogger{}} - -// Callback is a struct that contains all CRUD callbacks -// Field `creates` contains callbacks will be call when creating object -// Field `updates` contains callbacks will be call when updating object -// Field `deletes` contains callbacks will be call when deleting object -// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... -// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... -// Field `processors` contains all callback processors, will be used to generate above callbacks in order -type Callback struct { - logger logger - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) - processors []*CallbackProcessor -} - -// CallbackProcessor contains callback informations -type CallbackProcessor struct { - logger logger - name string // current callback's name - before string // register current callback before a callback - after string // register current callback after a callback - replace bool // replace callbacks with same name - remove bool // delete callbacks with same name - kind string // callback type: create, update, delete, query, row_query - processor *func(scope *Scope) // callback handler - parent *Callback -} - -func (c *Callback) clone(logger logger) *Callback { - return &Callback{ - logger: logger, - creates: c.creates, - updates: c.updates, - deletes: c.deletes, - queries: c.queries, - rowQueries: c.rowQueries, - processors: c.processors, - } -} - -// Create could be used to register callbacks for creating object -// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { -// // business logic -// ... -// -// // set error if some thing wrong happened, will rollback the creating -// scope.Err(errors.New("error")) -// }) -func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "create", parent: c} -} - -// Update could be used to register callbacks for updating object, refer `Create` for usage -func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "update", parent: c} -} - -// Delete could be used to register callbacks for deleting object, refer `Create` for usage -func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c} -} - -// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... -// Refer `Create` for usage -func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "query", parent: c} -} - -// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage -func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c} -} - -// After insert a new callback after callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { - cp.after = callbackName - return cp -} - -// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { - cp.before = callbackName - return cp -} - -// Register a new callback, refer `Callbacks.Create` -func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { - if cp.kind == "row_query" { - if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) - cp.before = "gorm:row_query" - } - } - - cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.processor = &callback - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Remove a registered callback -// db.Callback().Create().Remove("gorm:update_time_stamp_when_create") -func (cp *CallbackProcessor) Remove(callbackName string) { - cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.remove = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Replace a registered callback with new callback -// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("CreatedAt", now) -// scope.SetColumn("UpdatedAt", now) -// }) -func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.processor = &callback - cp.replace = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Get registered callback -// db.Callback().Create().Get("gorm:create") -func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { - for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind { - if p.remove { - callback = nil - } else { - callback = *p.processor - } - } - } - return -} - -// getRIndex get right index from string slice -func getRIndex(strs []string, str string) int { - for i := len(strs) - 1; i >= 0; i-- { - if strs[i] == str { - return i - } - } - return -1 -} - -// sortProcessors sort callback processors based on its before, after, remove, replace -func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { - var ( - allNames, sortedNames []string - sortCallbackProcessor func(c *CallbackProcessor) - ) - - for _, cp := range cps { - // show warning message the callback name already exists - if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum())) - } - allNames = append(allNames, cp.name) - } - - sortCallbackProcessor = func(c *CallbackProcessor) { - if getRIndex(sortedNames, c.name) == -1 { // if not sorted - if c.before != "" { // if defined before callback - if index := getRIndex(sortedNames, c.before); index != -1 { - // if before callback already sorted, append current callback just after it - sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) - } else if index := getRIndex(allNames, c.before); index != -1 { - // if before callback exists but haven't sorted, append current callback to last - sortedNames = append(sortedNames, c.name) - sortCallbackProcessor(cps[index]) - } - } - - if c.after != "" { // if defined after callback - if index := getRIndex(sortedNames, c.after); index != -1 { - // if after callback already sorted, append current callback just before it - sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getRIndex(allNames, c.after); index != -1 { - // if after callback exists but haven't sorted - cp := cps[index] - // set after callback's before callback to current callback - if cp.before == "" { - cp.before = c.name - } - sortCallbackProcessor(cp) - } - } - - // if current callback haven't been sorted, append it to last - if getRIndex(sortedNames, c.name) == -1 { - sortedNames = append(sortedNames, c.name) - } - } - } - - for _, cp := range cps { - sortCallbackProcessor(cp) - } - - var sortedFuncs []*func(scope *Scope) - for _, name := range sortedNames { - if index := getRIndex(allNames, name); !cps[index].remove { - sortedFuncs = append(sortedFuncs, cps[index].processor) - } - } - - return sortedFuncs -} - -// reorder all registered processors, and reset CRUD callbacks -func (c *Callback) reorder() { - var creates, updates, deletes, queries, rowQueries []*CallbackProcessor - - for _, processor := range c.processors { - if processor.name != "" { - switch processor.kind { - case "create": - creates = append(creates, processor) - case "update": - updates = append(updates, processor) - case "delete": - deletes = append(deletes, processor) - case "query": - queries = append(queries, processor) - case "row_query": - rowQueries = append(rowQueries, processor) - } - } - } - - c.creates = sortProcessors(creates) - c.updates = sortProcessors(updates) - c.deletes = sortProcessors(deletes) - c.queries = sortProcessors(queries) - c.rowQueries = sortProcessors(rowQueries) -} diff --git a/vendor/github.com/jinzhu/gorm/callback_create.go b/vendor/github.com/jinzhu/gorm/callback_create.go deleted file mode 100644 index c4d25f372..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_create.go +++ /dev/null @@ -1,197 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for creating -func init() { - DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) - DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) - DefaultCallback.Create().Register("gorm:create", createCallback) - DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) - DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) - DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating -func beforeCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeCreate") - } -} - -// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating -func updateTimeStampForCreateCallback(scope *Scope) { - if !scope.HasError() { - now := scope.db.nowFunc() - - if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { - if createdAtField.IsBlank { - createdAtField.Set(now) - } - } - - if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { - if updatedAtField.IsBlank { - updatedAtField.Set(now) - } - } - } -} - -// createCallback the callback used to insert data into database -func createCallback(scope *Scope) { - if !scope.HasError() { - defer scope.trace(NowFunc()) - - var ( - columns, placeholders []string - blankColumnsWithDefaultValue []string - ) - - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if field.IsNormal && !field.IsIgnored { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else if !field.IsPrimaryKey || !field.IsBlank { - columns = append(columns, scope.Quote(field.DBName)) - placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) - } - } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { - for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - columns = append(columns, scope.Quote(foreignField.DBName)) - placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) - } - } - } - } - } - - var ( - returningColumn = "*" - quotedTableName = scope.QuotedTableName() - primaryField = scope.PrimaryField() - extraOption string - insertModifier string - ) - - if str, ok := scope.Get("gorm:insert_option"); ok { - extraOption = fmt.Sprint(str) - } - if str, ok := scope.Get("gorm:insert_modifier"); ok { - insertModifier = strings.ToUpper(fmt.Sprint(str)) - if insertModifier == "INTO" { - insertModifier = "" - } - } - - if primaryField != nil { - returningColumn = scope.Quote(primaryField.DBName) - } - - lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) - var lastInsertIDReturningSuffix string - if lastInsertIDOutputInterstitial == "" { - lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) - } - - if len(columns) == 0 { - scope.Raw(fmt.Sprintf( - "INSERT%v INTO %v %v%v%v", - addExtraSpaceIfExist(insertModifier), - quotedTableName, - scope.Dialect().DefaultValueStr(), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } else { - scope.Raw(fmt.Sprintf( - "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v", - addExtraSpaceIfExist(insertModifier), - scope.QuotedTableName(), - strings.Join(columns, ","), - addExtraSpaceIfExist(lastInsertIDOutputInterstitial), - strings.Join(placeholders, ","), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } - - // execute create sql: no primaryField - if primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - return - } - - // execute create sql: lastInsertID implemention for majority of dialects - if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - return - } - - // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) - if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - primaryField.IsBlank = false - scope.db.RowsAffected = 1 - } - } else { - scope.Err(ErrUnaddressable) - } - return - } -} - -// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object -func forceReloadAfterCreateCallback(scope *Scope) { - if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { - db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) - for _, field := range scope.Fields() { - if field.IsPrimaryKey && !field.IsBlank { - db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) - } - } - db.Scan(scope.Value) - } -} - -// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating -func afterCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterCreate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_delete.go b/vendor/github.com/jinzhu/gorm/callback_delete.go deleted file mode 100644 index 48b97acbf..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_delete.go +++ /dev/null @@ -1,63 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" -) - -// Define callbacks for deleting -func init() { - DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) - DefaultCallback.Delete().Register("gorm:delete", deleteCallback) - DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) - DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeDeleteCallback will invoke `BeforeDelete` method before deleting -func beforeDeleteCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("missing WHERE clause while deleting")) - return - } - if !scope.HasError() { - scope.CallMethod("BeforeDelete") - } -} - -// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) -func deleteCallback(scope *Scope) { - if !scope.HasError() { - var extraOption string - if str, ok := scope.Get("gorm:delete_option"); ok { - extraOption = fmt.Sprint(str) - } - - deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt") - - if !scope.Search.Unscoped && hasDeletedAtField { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v=%v%v%v", - scope.QuotedTableName(), - scope.Quote(deletedAtField.DBName), - scope.AddToVars(scope.db.nowFunc()), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } else { - scope.Raw(fmt.Sprintf( - "DELETE FROM %v%v%v", - scope.QuotedTableName(), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterDeleteCallback will invoke `AfterDelete` method after deleting -func afterDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterDelete") - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_query.go b/vendor/github.com/jinzhu/gorm/callback_query.go deleted file mode 100644 index 544afd631..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_query.go +++ /dev/null @@ -1,109 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Define callbacks for querying -func init() { - DefaultCallback.Query().Register("gorm:query", queryCallback) - DefaultCallback.Query().Register("gorm:preload", preloadCallback) - DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) -} - -// queryCallback used to query data from database -func queryCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - //we are only preloading relations, dont touch base model - if _, skip := scope.InstanceGet("gorm:only_preload"); skip { - return - } - - defer scope.trace(NowFunc()) - - var ( - isSlice, isPtr bool - resultType reflect.Type - results = scope.IndirectValue() - ) - - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if primaryField := scope.PrimaryField(); primaryField != nil { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) - } - } - - if value, ok := scope.Get("gorm:query_destination"); ok { - results = indirect(reflect.ValueOf(value)) - } - - if kind := results.Kind(); kind == reflect.Slice { - isSlice = true - resultType = results.Type().Elem() - results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - - if resultType.Kind() == reflect.Ptr { - isPtr = true - resultType = resultType.Elem() - } - } else if kind != reflect.Struct { - scope.Err(errors.New("unsupported destination, should be slice or struct")) - return - } - - scope.prepareQuerySQL() - - if !scope.HasError() { - scope.db.RowsAffected = 0 - - if str, ok := scope.Get("gorm:query_hint"); ok { - scope.SQL = fmt.Sprint(str) + scope.SQL - } - - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ - - elem := results - if isSlice { - elem = reflect.New(resultType).Elem() - } - - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - - if isSlice { - if isPtr { - results.Set(reflect.Append(results, elem.Addr())) - } else { - results.Set(reflect.Append(results, elem)) - } - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } else if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(ErrRecordNotFound) - } - } - } -} - -// afterQueryCallback will invoke `AfterFind` method after querying -func afterQueryCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterFind") - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_query_preload.go b/vendor/github.com/jinzhu/gorm/callback_query_preload.go deleted file mode 100644 index a936180ad..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_query_preload.go +++ /dev/null @@ -1,410 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strconv" - "strings" -) - -// preloadCallback used to preload associations -func preloadCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - if ap, ok := scope.Get("gorm:auto_preload"); ok { - // If gorm:auto_preload IS NOT a bool then auto preload. - // Else if it IS a bool, use the value - if apb, ok := ap.(bool); !ok { - autoPreload(scope) - } else if apb { - autoPreload(scope) - } - } - - if scope.Search.preload == nil || scope.HasError() { - return - } - - var ( - preloadedMap = map[string]bool{} - fields = scope.Fields() - ) - - for _, preload := range scope.Search.preload { - var ( - preloadFields = strings.Split(preload.schema, ".") - currentScope = scope - currentFields = fields - ) - - for idx, preloadField := range preloadFields { - var currentPreloadConditions []interface{} - - if currentScope == nil { - continue - } - - // if not preloaded - if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { - - // assign search conditions to last preload - if idx == len(preloadFields)-1 { - currentPreloadConditions = preload.conditions - } - - for _, field := range currentFields { - if field.Name != preloadField || field.Relationship == nil { - continue - } - - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, currentPreloadConditions) - case "has_many": - currentScope.handleHasManyPreload(field, currentPreloadConditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, currentPreloadConditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions) - default: - scope.Err(errors.New("unsupported relation")) - } - - preloadedMap[preloadKey] = true - break - } - - if !preloadedMap[preloadKey] { - scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) - return - } - } - - // preload next level - if idx < len(preloadFields)-1 { - currentScope = currentScope.getColumnAsScope(preloadField) - if currentScope != nil { - currentFields = currentScope.Fields() - } - } - } - } -} - -func autoPreload(scope *Scope) { - for _, field := range scope.Fields() { - if field.Relationship == nil { - continue - } - - if val, ok := field.TagSettingsGet("PRELOAD"); ok { - if preload, err := strconv.ParseBool(val); err != nil { - scope.Err(errors.New("invalid preload option")) - return - } else if !preload { - continue - } - } - - scope.Search.Preload(field.Name) - } -} - -func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { - var ( - preloadDB = scope.NewDB() - preloadConditions []interface{} - ) - - for _, condition := range conditions { - if scopes, ok := condition.(func(*DB) *DB); ok { - preloadDB = scopes(preloadDB) - } else { - preloadConditions = append(preloadConditions, condition) - } - } - - return preloadDB, preloadConditions -} - -// handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - foreignValuesToResults := make(map[string]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) - foreignValuesToResults[foreignValues] = result - } - for j := 0; j < indirectScopeValue.Len(); j++ { - indirectValue := indirect(indirectScopeValue.Index(j)) - valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) - if result, found := foreignValuesToResults[valueString]; found { - indirectValue.FieldByName(field.Name).Set(result) - } - } - } else { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - scope.Err(field.Set(result)) - } - } -} - -// handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - preloadMap := make(map[string][]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) - } - - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) - f := object.FieldByName(field.Name) - if results, ok := preloadMap[toString(objectRealValue)]; ok { - f.Set(reflect.Append(f, results...)) - } else { - f.Set(reflect.MakeSlice(f.Type(), 0, 0)) - } - } - } else { - scope.Err(field.Set(resultsValue)) - } -} - -// handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - foreignFieldToObjects := make(map[string][]*reflect.Value) - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) - foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) - } - } - - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) - if objects, found := foreignFieldToObjects[valueString]; found { - for _, object := range objects { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.Err(field.Set(result)) - } - } -} - -// handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - var ( - relation = field.Relationship - joinTableHandler = relation.JoinTableHandler - fieldType = field.Struct.Type.Elem() - foreignKeyValue interface{} - foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() - linkHash = map[string][]reflect.Value{} - isPtr bool - ) - - if fieldType.Kind() == reflect.Ptr { - isPtr = true - fieldType = fieldType.Elem() - } - - var sourceKeys = []string{} - for _, key := range joinTableHandler.SourceForeignKeys() { - sourceKeys = append(sourceKeys, key.DBName) - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // generate query with join table - newScope := scope.New(reflect.New(fieldType).Interface()) - preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) - - if len(preloadDB.search.selects) == 0 { - preloadDB = preloadDB.Select("*") - } - - preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) - - // preload inline conditions - if len(preloadConditions) > 0 { - preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) - } - - rows, err := preloadDB.Rows() - - if scope.Err(err) != nil { - return - } - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - var ( - elem = reflect.New(fieldType).Elem() - fields = scope.New(elem.Addr().Interface()).Fields() - ) - - // register foreign keys in join tables - var joinTableFields []*Field - for _, sourceKey := range sourceKeys { - joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) - } - - scope.scan(rows, columns, append(fields, joinTableFields...)) - - scope.New(elem.Addr().Interface()). - InstanceSet("gorm:skip_query_callback", true). - callCallbacks(scope.db.parent.callbacks.queries) - - var foreignKeys = make([]interface{}, len(sourceKeys)) - // generate hashed forkey keys in join table - for idx, joinTableField := range joinTableFields { - if !joinTableField.Field.IsNil() { - foreignKeys[idx] = joinTableField.Field.Elem().Interface() - } - } - hashedSourceKeys := toString(foreignKeys) - - if isPtr { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) - } else { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - - // assign find results - var ( - indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string][]reflect.Value{} - foreignFieldNames = []string{} - ) - - for _, dbName := range relation.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - key := toString(getValueFromFields(object, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) - } - } else if indirectScopeValue.IsValid() { - key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) - } - - for source, fields := range fieldsSourceMap { - for _, f := range fields { - //If not 0 this means Value is a pointer and we already added preloaded models to it - if f.Len() != 0 { - continue - } - - v := reflect.MakeSlice(f.Type(), 0, 0) - if len(linkHash[source]) > 0 { - v = reflect.Append(f, linkHash[source]...) - } - - f.Set(v) - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_row_query.go b/vendor/github.com/jinzhu/gorm/callback_row_query.go deleted file mode 100644 index 323b16054..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_row_query.go +++ /dev/null @@ -1,41 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" -) - -// Define callbacks for row query -func init() { - DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback) -} - -type RowQueryResult struct { - Row *sql.Row -} - -type RowsQueryResult struct { - Rows *sql.Rows - Error error -} - -// queryCallback used to query data from database -func rowQueryCallback(scope *Scope) { - if result, ok := scope.InstanceGet("row_query_result"); ok { - scope.prepareQuerySQL() - - if str, ok := scope.Get("gorm:query_hint"); ok { - scope.SQL = fmt.Sprint(str) + scope.SQL - } - - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rowResult, ok := result.(*RowQueryResult); ok { - rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) - } else if rowsResult, ok := result.(*RowsQueryResult); ok { - rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_save.go b/vendor/github.com/jinzhu/gorm/callback_save.go deleted file mode 100644 index 3b4e05895..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_save.go +++ /dev/null @@ -1,170 +0,0 @@ -package gorm - -import ( - "reflect" - "strings" -) - -func beginTransactionCallback(scope *Scope) { - scope.Begin() -} - -func commitOrRollbackTransactionCallback(scope *Scope) { - scope.CommitOrRollback() -} - -func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { - checkTruth := func(value interface{}) bool { - if v, ok := value.(bool); ok && !v { - return false - } - - if v, ok := value.(string); ok { - v = strings.ToLower(v) - return v == "true" - } - - return true - } - - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if r = field.Relationship; r != nil { - autoUpdate, autoCreate, saveReference = true, true, true - - if value, ok := scope.Get("gorm:save_associations"); ok { - autoUpdate = checkTruth(value) - autoCreate = autoUpdate - saveReference = autoUpdate - } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { - autoUpdate = checkTruth(value) - autoCreate = autoUpdate - saveReference = autoUpdate - } - - if value, ok := scope.Get("gorm:association_autoupdate"); ok { - autoUpdate = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok { - autoUpdate = checkTruth(value) - } - - if value, ok := scope.Get("gorm:association_autocreate"); ok { - autoCreate = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok { - autoCreate = checkTruth(value) - } - - if value, ok := scope.Get("gorm:association_save_reference"); ok { - saveReference = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok { - saveReference = checkTruth(value) - } - } - } - - return -} - -func saveBeforeAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - newScope := scope.New(fieldValue) - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } - } -} - -func saveAfterAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := field.Field - - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) - - if saveReference { - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(newDB.Save(elem).Error) - } - } else if autoUpdate { - scope.Err(newDB.Save(elem).Error) - } - - if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_update.go b/vendor/github.com/jinzhu/gorm/callback_update.go deleted file mode 100644 index 699e534b9..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_update.go +++ /dev/null @@ -1,121 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "sort" - "strings" -) - -// Define callbacks for updating -func init() { - DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) - DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) - DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) - DefaultCallback.Update().Register("gorm:update", updateCallback) - DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) - DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// assignUpdatingAttributesCallback assign updating attributes to model -func assignUpdatingAttributesCallback(scope *Scope) { - if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { - scope.InstanceSet("gorm:update_attrs", updateMaps) - } else { - scope.SkipLeft() - } - } -} - -// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating -func beforeUpdateCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("missing WHERE clause while updating")) - return - } - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeUpdate") - } - } -} - -// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating -func updateTimeStampForUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", scope.db.nowFunc()) - } -} - -// updateCallback the callback used to update data to database -func updateCallback(scope *Scope) { - if !scope.HasError() { - var sqls []string - - if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - // Sort the column names so that the generated SQL is the same every time. - updateMap := updateAttrs.(map[string]interface{}) - var columns []string - for c := range updateMap { - columns = append(columns, c) - } - sort.Strings(columns) - - for _, column := range columns { - value := updateMap[column] - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) - } - } else { - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) { - if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, foreignKey := range relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - sqls = append(sqls, - fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) - } - } - } - } - } - } - - var extraOption string - if str, ok := scope.Get("gorm:update_option"); ok { - extraOption = fmt.Sprint(str) - } - - if len(sqls) > 0 { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v%v%v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating -func afterUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("AfterUpdate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/dialect.go b/vendor/github.com/jinzhu/gorm/dialect.go deleted file mode 100644 index 749587f44..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect.go +++ /dev/null @@ -1,147 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" - "reflect" - "strconv" - "strings" -) - -// Dialect interface contains behaviors that differ across SQL database -type Dialect interface { - // GetName get dialect's name - GetName() string - - // SetDB set db for dialect - SetDB(db SQLCommon) - - // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 - BindVar(i int) string - // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name - Quote(key string) string - // DataTypeOf return data's sql type - DataTypeOf(field *StructField) string - - // HasIndex check has index or not - HasIndex(tableName string, indexName string) bool - // HasForeignKey check has foreign key or not - HasForeignKey(tableName string, foreignKeyName string) bool - // RemoveIndex remove index - RemoveIndex(tableName string, indexName string) error - // HasTable check has table or not - HasTable(tableName string) bool - // HasColumn check has column or not - HasColumn(tableName string, columnName string) bool - // ModifyColumn modify column's type - ModifyColumn(tableName string, columnName string, typ string) error - - // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset interface{}) (string, error) - // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` - SelectFromDummyTable() string - // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` - LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string - // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` - LastInsertIDReturningSuffix(tableName, columnName string) string - // DefaultValueStr - DefaultValueStr() string - - // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference - BuildKeyName(kind, tableName string, fields ...string) string - - // NormalizeIndexAndColumn returns valid index name and column name depending on each dialect - NormalizeIndexAndColumn(indexName, columnName string) (string, string) - - // CurrentDatabase return current database name - CurrentDatabase() string -} - -var dialectsMap = map[string]Dialect{} - -func newDialect(name string, db SQLCommon) Dialect { - if value, ok := dialectsMap[name]; ok { - dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) - dialect.SetDB(db) - return dialect - } - - fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) - commontDialect := &commonDialect{} - commontDialect.SetDB(db) - return commontDialect -} - -// RegisterDialect register new dialect -func RegisterDialect(name string, dialect Dialect) { - dialectsMap[name] = dialect -} - -// GetDialect gets the dialect for the specified dialect name -func GetDialect(name string) (dialect Dialect, ok bool) { - dialect, ok = dialectsMap[name] - return -} - -// ParseFieldStructForDialect get field's sql data type -var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { - // Get redirected field type - var ( - reflectType = field.Struct.Type - dataType, _ = field.TagSettingsGet("TYPE") - ) - - for reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Get redirected field value - fieldValue = reflect.Indirect(reflect.New(reflectType)) - - if gormDataType, ok := fieldValue.Interface().(interface { - GormDataType(Dialect) string - }); ok { - dataType = gormDataType.GormDataType(dialect) - } - - // Get scanner's real value - if dataType == "" { - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) - } - } - getScannerValue(fieldValue) - } - - // Default Size - if num, ok := field.TagSettingsGet("SIZE"); ok { - size, _ = strconv.Atoi(num) - } else { - size = 255 - } - - // Default type from tag setting - notNull, _ := field.TagSettingsGet("NOT NULL") - unique, _ := field.TagSettingsGet("UNIQUE") - additionalType = notNull + " " + unique - if value, ok := field.TagSettingsGet("DEFAULT"); ok { - additionalType = additionalType + " DEFAULT " + value - } - - if value, ok := field.TagSettingsGet("COMMENT"); ok { - additionalType = additionalType + " COMMENT " + value - } - - return fieldValue, dataType, size, strings.TrimSpace(additionalType) -} - -func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { - if strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - return splitStrings[0], splitStrings[1] - } - return dialect.CurrentDatabase(), tableName -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_common.go b/vendor/github.com/jinzhu/gorm/dialect_common.go deleted file mode 100644 index d549510cc..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_common.go +++ /dev/null @@ -1,196 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+") - -// DefaultForeignKeyNamer contains the default foreign key name generator method -type DefaultForeignKeyNamer struct { -} - -type commonDialect struct { - db SQLCommon - DefaultForeignKeyNamer -} - -func init() { - RegisterDialect("common", &commonDialect{}) -} - -func (commonDialect) GetName() string { - return "common" -} - -func (s *commonDialect) SetDB(db SQLCommon) { - s.db = db -} - -func (commonDialect) BindVar(i int) string { - return "$$$" // ? -} - -func (commonDialect) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - return strings.ToLower(value) != "false" - } - return field.IsPrimaryKey -} - -func (s *commonDialect) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "BOOLEAN" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - sqlType = "INTEGER AUTO_INCREMENT" - } else { - sqlType = "INTEGER" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - sqlType = "BIGINT AUTO_INCREMENT" - } else { - sqlType = "BIGINT" - } - case reflect.Float32, reflect.Float64: - sqlType = "FLOAT" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("VARCHAR(%d)", size) - } else { - sqlType = "VARCHAR(65532)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "TIMESTAMP" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("BINARY(%d)", size) - } else { - sqlType = "BINARY(65532)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s commonDialect) HasIndex(tableName string, indexName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) - return count > 0 -} - -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) - return err -} - -func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s commonDialect) HasTable(tableName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) - return count > 0 -} - -func (s commonDialect) HasColumn(tableName string, columnName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) - return count > 0 -} - -func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) - return err -} - -func (s commonDialect) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -// LimitAndOffsetSQL return generated SQL with Limit and Offset -func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - if limit != nil { - if parsedLimit, err := s.parseInt(limit); err != nil { - return "", err - } else if parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - } - } - if offset != nil { - if parsedOffset, err := s.parseInt(offset); err != nil { - return "", err - } else if parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - return -} - -func (commonDialect) SelectFromDummyTable() string { - return "" -} - -func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { - return "" -} - -func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} - -func (commonDialect) DefaultValueStr() string { - return "DEFAULT VALUES" -} - -// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference -func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { - keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) - keyName = keyNameRegex.ReplaceAllString(keyName, "_") - return keyName -} - -// NormalizeIndexAndColumn returns argument's index name and column name without doing anything -func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { - return indexName, columnName -} - -func (commonDialect) parseInt(value interface{}) (int64, error) { - return strconv.ParseInt(fmt.Sprint(value), 0, 0) -} - -// IsByteArrayOrSlice returns true of the reflected value is an array or slice -func IsByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_mysql.go b/vendor/github.com/jinzhu/gorm/dialect_mysql.go deleted file mode 100644 index b4467ffa1..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_mysql.go +++ /dev/null @@ -1,246 +0,0 @@ -package gorm - -import ( - "crypto/sha1" - "database/sql" - "fmt" - "reflect" - "regexp" - "strings" - "time" - "unicode/utf8" -) - -var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`) - -type mysql struct { - commonDialect -} - -func init() { - RegisterDialect("mysql", &mysql{}) -} - -func (mysql) GetName() string { - return "mysql" -} - -func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) -} - -// Get Data Type for MySQL Dialect -func (s *mysql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - // MySQL allows only one auto increment column per table, and it must - // be a KEY column. - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { - field.TagSettingsDelete("AUTO_INCREMENT") - } - } - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int8: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "tinyint AUTO_INCREMENT" - } else { - sqlType = "tinyint" - } - case reflect.Int, reflect.Int16, reflect.Int32: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int AUTO_INCREMENT" - } else { - sqlType = "int" - } - case reflect.Uint8: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "tinyint unsigned AUTO_INCREMENT" - } else { - sqlType = "tinyint unsigned" - } - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int unsigned AUTO_INCREMENT" - } else { - sqlType = "int unsigned" - } - case reflect.Int64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint AUTO_INCREMENT" - } else { - sqlType = "bigint" - } - case reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint unsigned AUTO_INCREMENT" - } else { - sqlType = "bigint unsigned" - } - case reflect.Float32, reflect.Float64: - sqlType = "double" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "longtext" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - precision := "" - if p, ok := field.TagSettingsGet("PRECISION"); ok { - precision = fmt.Sprintf("(%s)", p) - } - - if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey { - sqlType = fmt.Sprintf("DATETIME%v", precision) - } else { - sqlType = fmt.Sprintf("DATETIME%v NULL", precision) - } - } - default: - if IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "longblob" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name)) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) - return err -} - -func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - if limit != nil { - parsedLimit, err := s.parseInt(limit) - if err != nil { - return "", err - } - if parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - - if offset != nil { - parsedOffset, err := s.parseInt(offset) - if err != nil { - return "", err - } - if parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - } - } - return -} - -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s mysql) HasTable(tableName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - var name string - // allow mysql database name with '-' character - if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { - if err == sql.ErrNoRows { - return false - } - panic(err) - } else { - return true - } -} - -func (s mysql) HasIndex(tableName string, indexName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { - panic(err) - } else { - defer rows.Close() - return rows.Next() - } -} - -func (s mysql) HasColumn(tableName string, columnName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { - panic(err) - } else { - defer rows.Close() - return rows.Next() - } -} - -func (s mysql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { - keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) - if utf8.RuneCountInString(keyName) <= 64 { - return keyName - } - h := sha1.New() - h.Write([]byte(keyName)) - bs := h.Sum(nil) - - // sha1 is 40 characters, keep first 24 characters of destination - destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_")) - if len(destRunes) > 24 { - destRunes = destRunes[:24] - } - - return fmt.Sprintf("%s%x", string(destRunes), bs) -} - -// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed -func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { - submatch := mysqlIndexRegex.FindStringSubmatch(indexName) - if len(submatch) != 3 { - return indexName, columnName - } - indexName = submatch[1] - columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2]) - return indexName, columnName -} - -func (mysql) DefaultValueStr() string { - return "VALUES()" -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_postgres.go b/vendor/github.com/jinzhu/gorm/dialect_postgres.go deleted file mode 100644 index d2df31318..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_postgres.go +++ /dev/null @@ -1,147 +0,0 @@ -package gorm - -import ( - "encoding/json" - "fmt" - "reflect" - "strings" - "time" -) - -type postgres struct { - commonDialect -} - -func init() { - RegisterDialect("postgres", &postgres{}) - RegisterDialect("cloudsqlpostgres", &postgres{}) -} - -func (postgres) GetName() string { - return "postgres" -} - -func (postgres) BindVar(i int) string { - return fmt.Sprintf("$%v", i) -} - -func (s *postgres) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "serial" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint32, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigserial" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "numeric" - case reflect.String: - if _, ok := field.TagSettingsGet("SIZE"); !ok { - size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different - } - - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "timestamp with time zone" - } - case reflect.Map: - if dataValue.Type().Name() == "Hstore" { - sqlType = "hstore" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "bytea" - - if isUUID(dataValue) { - sqlType = "uuid" - } - - if isJSON(dataValue) { - sqlType = "jsonb" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s postgres) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) - return count > 0 -} - -func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s postgres) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) - return count > 0 -} - -func (s postgres) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) - return count > 0 -} - -func (s postgres) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) - return -} - -func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { - return "" -} - -func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) -} - -func (postgres) SupportLastInsertID() bool { - return false -} - -func isUUID(value reflect.Value) bool { - if value.Kind() != reflect.Array || value.Type().Len() != 16 { - return false - } - typename := value.Type().Name() - lower := strings.ToLower(typename) - return "uuid" == lower || "guid" == lower -} - -func isJSON(value reflect.Value) bool { - _, ok := value.Interface().(json.RawMessage) - return ok -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go b/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go deleted file mode 100644 index 5f96c363a..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go +++ /dev/null @@ -1,107 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type sqlite3 struct { - commonDialect -} - -func init() { - RegisterDialect("sqlite3", &sqlite3{}) -} - -func (sqlite3) GetName() string { - return "sqlite3" -} - -// Get Data Type for Sqlite Dialect -func (s *sqlite3) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "integer primary key autoincrement" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "integer primary key autoincrement" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "real" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetime" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "blob" - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s sqlite3) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) CurrentDatabase() (name string) { - var ( - ifaces = make([]interface{}, 3) - pointers = make([]*string, 3) - i int - ) - for i = 0; i < 3; i++ { - ifaces[i] = &pointers[i] - } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { - return - } - if pointers[1] != nil { - name = *pointers[1] - } - return -} diff --git a/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go b/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go deleted file mode 100644 index e6c088b1c..000000000 --- a/vendor/github.com/jinzhu/gorm/dialects/postgres/postgres.go +++ /dev/null @@ -1,81 +0,0 @@ -package postgres - -import ( - "database/sql" - "database/sql/driver" - - "encoding/json" - "errors" - "fmt" - - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" -) - -type Hstore map[string]*string - -// Value get value of Hstore -func (h Hstore) Value() (driver.Value, error) { - hstore := hstore.Hstore{Map: map[string]sql.NullString{}} - if len(h) == 0 { - return nil, nil - } - - for key, value := range h { - var s sql.NullString - if value != nil { - s.String = *value - s.Valid = true - } - hstore.Map[key] = s - } - return hstore.Value() -} - -// Scan scan value into Hstore -func (h *Hstore) Scan(value interface{}) error { - hstore := hstore.Hstore{} - - if err := hstore.Scan(value); err != nil { - return err - } - - if len(hstore.Map) == 0 { - return nil - } - - *h = Hstore{} - for k := range hstore.Map { - if hstore.Map[k].Valid { - s := hstore.Map[k].String - (*h)[k] = &s - } else { - (*h)[k] = nil - } - } - - return nil -} - -// Jsonb Postgresql's JSONB data type -type Jsonb struct { - json.RawMessage -} - -// Value get value of Jsonb -func (j Jsonb) Value() (driver.Value, error) { - if len(j.RawMessage) == 0 { - return nil, nil - } - return j.MarshalJSON() -} - -// Scan scan value into Jsonb -func (j *Jsonb) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) - } - - return json.Unmarshal(bytes, j) -} diff --git a/vendor/github.com/jinzhu/gorm/docker-compose.yml b/vendor/github.com/jinzhu/gorm/docker-compose.yml deleted file mode 100644 index 79bf5fc39..000000000 --- a/vendor/github.com/jinzhu/gorm/docker-compose.yml +++ /dev/null @@ -1,30 +0,0 @@ -version: '3' - -services: - mysql: - image: 'mysql:latest' - ports: - - 9910:3306 - environment: - - MYSQL_DATABASE=gorm - - MYSQL_USER=gorm - - MYSQL_PASSWORD=gorm - - MYSQL_RANDOM_ROOT_PASSWORD="yes" - postgres: - image: 'postgres:latest' - ports: - - 9920:5432 - environment: - - POSTGRES_USER=gorm - - POSTGRES_DB=gorm - - POSTGRES_PASSWORD=gorm - mssql: - image: 'mcmoe/mssqldocker:latest' - ports: - - 9930:1433 - environment: - - ACCEPT_EULA=Y - - SA_PASSWORD=LoremIpsum86 - - MSSQL_DB=gorm - - MSSQL_USER=gorm - - MSSQL_PASSWORD=LoremIpsum86 diff --git a/vendor/github.com/jinzhu/gorm/errors.go b/vendor/github.com/jinzhu/gorm/errors.go deleted file mode 100644 index d5ef8d571..000000000 --- a/vendor/github.com/jinzhu/gorm/errors.go +++ /dev/null @@ -1,72 +0,0 @@ -package gorm - -import ( - "errors" - "strings" -) - -var ( - // ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error - ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL occurs when you attempt a query with invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") - // ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") - // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` - ErrCantStartTransaction = errors.New("can't start transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") -) - -// Errors contains all happened errors -type Errors []error - -// IsRecordNotFoundError returns true if error contains a RecordNotFound error -func IsRecordNotFoundError(err error) bool { - if errs, ok := err.(Errors); ok { - for _, err := range errs { - if err == ErrRecordNotFound { - return true - } - } - } - return err == ErrRecordNotFound -} - -// GetErrors gets all errors that have occurred and returns a slice of errors (Error type) -func (errs Errors) GetErrors() []error { - return errs -} - -// Add adds an error to a given slice of errors -func (errs Errors) Add(newErrors ...error) Errors { - for _, err := range newErrors { - if err == nil { - continue - } - - if errors, ok := err.(Errors); ok { - errs = errs.Add(errors...) - } else { - ok = true - for _, e := range errs { - if err == e { - ok = false - } - } - if ok { - errs = append(errs, err) - } - } - } - return errs -} - -// Error takes a slice of all errors that have occurred and returns it as a formatted string -func (errs Errors) Error() string { - var errors = []string{} - for _, e := range errs { - errors = append(errors, e.Error()) - } - return strings.Join(errors, "; ") -} diff --git a/vendor/github.com/jinzhu/gorm/field.go b/vendor/github.com/jinzhu/gorm/field.go deleted file mode 100644 index acd06e20d..000000000 --- a/vendor/github.com/jinzhu/gorm/field.go +++ /dev/null @@ -1,66 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" -) - -// Field model field definition -type Field struct { - *StructField - IsBlank bool - Field reflect.Value -} - -// Set set a value to the field -func (field *Field) Set(value interface{}) (err error) { - if !field.Field.IsValid() { - return errors.New("field value not valid") - } - - if !field.Field.CanAddr() { - return ErrUnaddressable - } - - reflectValue, ok := value.(reflect.Value) - if !ok { - reflectValue = reflect.ValueOf(value) - } - - fieldValue := field.Field - if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else { - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.Struct.Type.Elem())) - } - fieldValue = fieldValue.Elem() - } - - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - v := reflectValue.Interface() - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = scanner.Scan(v) - } - } else { - err = scanner.Scan(v) - } - } else { - err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) - } - } - } else { - field.Field.Set(reflect.Zero(field.Field.Type())) - } - - field.IsBlank = isBlank(field.Field) - return err -} diff --git a/vendor/github.com/jinzhu/gorm/go.mod b/vendor/github.com/jinzhu/gorm/go.mod deleted file mode 100644 index 6e923b9dc..000000000 --- a/vendor/github.com/jinzhu/gorm/go.mod +++ /dev/null @@ -1,15 +0,0 @@ -module github.com/jinzhu/gorm - -go 1.12 - -require ( - github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd - github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 - github.com/go-sql-driver/mysql v1.4.1 - github.com/jinzhu/inflection v1.0.0 - github.com/jinzhu/now v1.0.1 - github.com/lib/pq v1.1.1 - github.com/mattn/go-sqlite3 v2.0.1+incompatible - golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd // indirect - google.golang.org/appengine v1.4.0 // indirect -) diff --git a/vendor/github.com/jinzhu/gorm/go.sum b/vendor/github.com/jinzhu/gorm/go.sum deleted file mode 100644 index 915b4c215..000000000 --- a/vendor/github.com/jinzhu/gorm/go.sum +++ /dev/null @@ -1,29 +0,0 @@ -github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM= -github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= -github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= -github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= -github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= -github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw= -github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM= -golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= diff --git a/vendor/github.com/jinzhu/gorm/interface.go b/vendor/github.com/jinzhu/gorm/interface.go deleted file mode 100644 index fe6492314..000000000 --- a/vendor/github.com/jinzhu/gorm/interface.go +++ /dev/null @@ -1,24 +0,0 @@ -package gorm - -import ( - "context" - "database/sql" -) - -// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. -type SQLCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -type sqlDb interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -type sqlTx interface { - Commit() error - Rollback() error -} diff --git a/vendor/github.com/jinzhu/gorm/join_table_handler.go b/vendor/github.com/jinzhu/gorm/join_table_handler.go deleted file mode 100644 index a036d46d2..000000000 --- a/vendor/github.com/jinzhu/gorm/join_table_handler.go +++ /dev/null @@ -1,211 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// JoinTableHandlerInterface is an interface for how to handle many2many relations -type JoinTableHandlerInterface interface { - // initialize join table handler - Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) - // Table return join table's table name - Table(db *DB) string - // Add create relationship in join table for source and destination - Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error - // Delete delete relationship in join table for sources - Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error - // JoinWith query with `Join` conditions - JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB - // SourceForeignKeys return source foreign keys - SourceForeignKeys() []JoinTableForeignKey - // DestinationForeignKeys return destination foreign keys - DestinationForeignKeys() []JoinTableForeignKey -} - -// JoinTableForeignKey join table foreign key struct -type JoinTableForeignKey struct { - DBName string - AssociationDBName string -} - -// JoinTableSource is a struct that contains model type and foreign keys -type JoinTableSource struct { - ModelType reflect.Type - ForeignKeys []JoinTableForeignKey -} - -// JoinTableHandler default join table handler -type JoinTableHandler struct { - TableName string `sql:"-"` - Source JoinTableSource `sql:"-"` - Destination JoinTableSource `sql:"-"` -} - -// SourceForeignKeys return source foreign keys -func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { - return s.Source.ForeignKeys -} - -// DestinationForeignKeys return destination foreign keys -func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { - return s.Destination.ForeignKeys -} - -// Setup initialize a default join table handler -func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { - s.TableName = tableName - - s.Source = JoinTableSource{ModelType: source} - s.Source.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.ForeignFieldNames { - s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBNames[idx], - AssociationDBName: dbName, - }) - } - - s.Destination = JoinTableSource{ModelType: destination} - s.Destination.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.AssociationForeignFieldNames { - s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBNames[idx], - AssociationDBName: dbName, - }) - } -} - -// Table return join table's table name -func (s JoinTableHandler) Table(db *DB) string { - return DefaultTableNameHandler(db, s.TableName) -} - -func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { - for _, source := range sources { - scope := db.NewScope(source) - modelType := scope.GetModelStruct().ModelType - - for _, joinTableSource := range joinTableSources { - if joinTableSource.ModelType == modelType { - for _, foreignKey := range joinTableSource.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - conditionMap[foreignKey.DBName] = field.Field.Interface() - } - } - break - } - } - } -} - -// Add create relationship in join table for source and destination -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - var ( - scope = db.NewScope("") - conditionMap = map[string]interface{}{} - ) - - // Update condition map for source - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) - - // Update condition map for destination - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) - - var assignColumns, binVars, conditions []string - var values []interface{} - for key, value := range conditionMap { - assignColumns = append(assignColumns, scope.Quote(key)) - binVars = append(binVars, `?`) - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - for _, value := range values { - values = append(values, value) - } - - quotedTable := scope.Quote(handler.Table(db)) - sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", - quotedTable, - strings.Join(assignColumns, ","), - strings.Join(binVars, ","), - scope.Dialect().SelectFromDummyTable(), - quotedTable, - strings.Join(conditions, " AND "), - ) - - return db.Exec(sql, values...).Error -} - -// Delete delete relationship in join table for sources -func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { - var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} - conditionMap = map[string]interface{}{} - ) - - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) - - for key, value := range conditionMap { - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error -} - -// JoinWith query with `Join` conditions -func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { - var ( - scope = db.NewScope(source) - tableName = handler.Table(db) - quotedTableName = scope.Quote(tableName) - joinConditions []string - values []interface{} - ) - - if s.Source.ModelType == scope.GetModelStruct().ModelType { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() - for _, foreignKey := range s.Destination.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) - } - - var foreignDBNames []string - var foreignFieldNames []string - - for _, foreignKey := range s.Source.ForeignKeys { - foreignDBNames = append(foreignDBNames, foreignKey.DBName) - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) - - var condString string - if len(foreignFieldValues) > 0 { - var quotedForeignDBNames []string - for _, dbName := range foreignDBNames { - quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) - } - - condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) - values = append(values, toQueryValues(keys)) - } else { - condString = fmt.Sprintf("1 <> 1") - } - - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). - Where(condString, toQueryValues(foreignFieldValues)...) - } - - db.Error = errors.New("wrong source type for join table handler") - return db -} diff --git a/vendor/github.com/jinzhu/gorm/logger.go b/vendor/github.com/jinzhu/gorm/logger.go deleted file mode 100644 index 88e167dd6..000000000 --- a/vendor/github.com/jinzhu/gorm/logger.go +++ /dev/null @@ -1,141 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "log" - "os" - "reflect" - "regexp" - "strconv" - "time" - "unicode" -) - -var ( - defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - sqlRegexp = regexp.MustCompile(`\?`) - numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) -) - -func isPrintable(s string) bool { - for _, r := range s { - if !unicode.IsPrint(r) { - return false - } - } - return true -} - -var LogFormatter = func(values ...interface{}) (messages []interface{}) { - if len(values) > 1 { - var ( - sql string - formattedValues []string - level = values[0] - currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" - source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) - ) - - messages = []interface{}{source, currentTime} - - if len(values) == 2 { - //remove the line break - currentTime = currentTime[1:] - //remove the brackets - source = fmt.Sprintf("\033[35m%v\033[0m", values[1]) - - messages = []interface{}{currentTime, source} - } - - if level == "sql" { - // duration - messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) - // sql - - for _, value := range values[4].([]interface{}) { - indirectValue := reflect.Indirect(reflect.ValueOf(value)) - if indirectValue.IsValid() { - value = indirectValue.Interface() - if t, ok := value.(time.Time); ok { - if t.IsZero() { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00")) - } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) - } - } else if b, ok := value.([]byte); ok { - if str := string(b); isPrintable(str) { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) - } else { - formattedValues = append(formattedValues, "''") - } - } else if r, ok := value.(driver.Valuer); ok { - if value, err := r.Value(); err == nil && value != nil { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } else { - formattedValues = append(formattedValues, "NULL") - } - } else { - switch value.(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: - formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) - default: - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } - } - } else { - formattedValues = append(formattedValues, "NULL") - } - } - - // differentiate between $n placeholders or else treat like ? - if numericPlaceHolderRegexp.MatchString(values[3].(string)) { - sql = values[3].(string) - for index, value := range formattedValues { - placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) - sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1") - } - } else { - formattedValuesLength := len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { - sql += value - if index < formattedValuesLength { - sql += formattedValues[index] - } - } - } - - messages = append(messages, sql) - messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) - } else { - messages = append(messages, "\033[31;1m") - messages = append(messages, values[2:]...) - messages = append(messages, "\033[0m") - } - } - - return -} - -type logger interface { - Print(v ...interface{}) -} - -// LogWriter log writer interface -type LogWriter interface { - Println(v ...interface{}) -} - -// Logger default logger -type Logger struct { - LogWriter -} - -// Print format & print log -func (logger Logger) Print(values ...interface{}) { - logger.Println(LogFormatter(values...)...) -} - -type nopLogger struct{} - -func (nopLogger) Print(values ...interface{}) {} diff --git a/vendor/github.com/jinzhu/gorm/main.go b/vendor/github.com/jinzhu/gorm/main.go deleted file mode 100644 index 3db87870c..000000000 --- a/vendor/github.com/jinzhu/gorm/main.go +++ /dev/null @@ -1,881 +0,0 @@ -package gorm - -import ( - "context" - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "sync" - "time" -) - -// DB contains information for current db connection -type DB struct { - sync.RWMutex - Value interface{} - Error error - RowsAffected int64 - - // single db - db SQLCommon - blockGlobalUpdate bool - logMode logModeValue - logger logger - search *search - values sync.Map - - // global db - parent *DB - callbacks *Callback - dialect Dialect - singularTable bool - - // function to be used to override the creating of a new timestamp - nowFuncOverride func() time.Time -} - -type logModeValue int - -const ( - defaultLogMode logModeValue = iota - noLogMode - detailedLogMode -) - -// Open initialize a new db connection, need to import driver first, e.g: -// -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } -// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (db *DB, err error) { - if len(args) == 0 { - err = errors.New("invalid database source") - return nil, err - } - var source string - var dbSQL SQLCommon - var ownDbSQL bool - - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, err = sql.Open(driver, source) - ownDbSQL = true - case SQLCommon: - dbSQL = value - ownDbSQL = false - default: - return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) - } - - db = &DB{ - db: dbSQL, - logger: defaultLogger, - callbacks: DefaultCallback, - dialect: newDialect(dialect, dbSQL), - } - db.parent = db - if err != nil { - return - } - // Send a ping to make sure the database connection is alive. - if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil && ownDbSQL { - d.Close() - } - } - return -} - -// New clone a new db connection without search conditions -func (s *DB) New() *DB { - clone := s.clone() - clone.search = nil - clone.Value = nil - return clone -} - -type closer interface { - Close() error -} - -// Close close current db connection. If database connection is not an io.Closer, returns an error. -func (s *DB) Close() error { - if db, ok := s.parent.db.(closer); ok { - return db.Close() - } - return errors.New("can't close current db") -} - -// DB get `*sql.DB` from current connection -// If the underlying database connection is not a *sql.DB, returns nil -func (s *DB) DB() *sql.DB { - db, ok := s.db.(*sql.DB) - if !ok { - panic("can't support full GORM on currently status, maybe this is a TX instance.") - } - return db -} - -// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. -func (s *DB) CommonDB() SQLCommon { - return s.db -} - -// Dialect get dialect -func (s *DB) Dialect() Dialect { - return s.dialect -} - -// Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) -// Refer https://jinzhu.github.io/gorm/development.html#callbacks -func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone(s.logger) - return s.parent.callbacks -} - -// SetLogger replace default logger -func (s *DB) SetLogger(log logger) { - s.logger = log -} - -// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs -func (s *DB) LogMode(enable bool) *DB { - if enable { - s.logMode = detailedLogMode - } else { - s.logMode = noLogMode - } - return s -} - -// SetNowFuncOverride set the function to be used when creating a new timestamp -func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { - s.nowFuncOverride = nowFuncOverride - return s -} - -// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set, -// otherwise defaults to the global NowFunc() -func (s *DB) nowFunc() time.Time { - if s.nowFuncOverride != nil { - return s.nowFuncOverride() - } - - return NowFunc() -} - -// BlockGlobalUpdate if true, generates an error on update/delete without where clause. -// This is to prevent eventual error with empty objects updates/deletions -func (s *DB) BlockGlobalUpdate(enable bool) *DB { - s.blockGlobalUpdate = enable - return s -} - -// HasBlockGlobalUpdate return state of block -func (s *DB) HasBlockGlobalUpdate() bool { - return s.blockGlobalUpdate -} - -// SingularTable use singular table by default -func (s *DB) SingularTable(enable bool) { - s.parent.Lock() - defer s.parent.Unlock() - s.parent.singularTable = enable -} - -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - scope := &Scope{db: dbClone, Value: value} - if s.search != nil { - scope.Search = s.search.clone() - } else { - scope.Search = &search{} - } - return scope -} - -// QueryExpr returns the query as SqlExpr object -func (s *DB) QueryExpr() *SqlExpr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(scope.SQL, scope.SQLVars...) -} - -// SubQuery returns the query as sub query -func (s *DB) SubQuery() *SqlExpr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) -} - -// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query -func (s *DB) Where(query interface{}, args ...interface{}) *DB { - return s.clone().search.Where(query, args...).db -} - -// Or filter records that match before conditions or this one, similar to `Where` -func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.Or(query, args...).db -} - -// Not filter records that don't match current conditions, similar to `Where` -func (s *DB) Not(query interface{}, args ...interface{}) *DB { - return s.clone().search.Not(query, args...).db -} - -// Limit specify the number of records to be retrieved -func (s *DB) Limit(limit interface{}) *DB { - return s.clone().search.Limit(limit).db -} - -// Offset specify the number of records to skip before starting to return the records -func (s *DB) Offset(offset interface{}) *DB { - return s.clone().search.Offset(offset).db -} - -// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -// db.Order("name DESC") -// db.Order("name DESC", true) // reorder -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression -func (s *DB) Order(value interface{}, reorder ...bool) *DB { - return s.clone().search.Order(value, reorder...).db -} - -// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; -// When creating/updating, specify fields that you want to save to database -func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.Select(query, args...).db -} - -// Omit specify fields that you want to ignore when saving to database for creating, updating -func (s *DB) Omit(columns ...string) *DB { - return s.clone().search.Omit(columns...).db -} - -// Group specify the group method on the find -func (s *DB) Group(query string) *DB { - return s.clone().search.Group(query).db -} - -// Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query interface{}, values ...interface{}) *DB { - return s.clone().search.Having(query, values...).db -} - -// Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { - return s.clone().search.Joins(query, args...).db -} - -// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } -// -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } -// -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/crud.html#scopes -func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { - for _, f := range funcs { - s = f(s) - } - return s -} - -// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete -func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db -} - -// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.Attrs(attrs...).db -} - -// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.Assign(attrs...).db -} - -// First find first record that match given conditions, order by primary key -func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - - return newScope.Set("gorm:order_by_primary_key", "ASC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Take return a record that match given conditions, the order will depend on the database implementation -func (s *DB) Take(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Last find last record that match given conditions, order by primary key -func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "DESC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Find find records that match given conditions -func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -//Preloads preloads relations, don`t touch out -func (s *DB) Preloads(out interface{}) *DB { - return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db -} - -// Scan scan value to a struct -func (s *DB) Scan(dest interface{}) *DB { - return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db -} - -// Row return `*sql.Row` with given conditions -func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() -} - -// Rows return `*sql.Rows` with given conditions -func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() -} - -// ScanRows scan `*sql.Rows` to give struct -func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { - var ( - scope = s.NewScope(result) - clone = scope.db - columns, err = rows.Columns() - ) - - if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.Fields()) - } - - return clone.Error -} - -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -// Count get how many records for a model -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - -// Related get related associations -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.NewScope(s.Value).related(value, foreignKeys...).db -} - -// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorinit -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - c.NewScope(out).inlineCondition(where...).initialize() - } else { - c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) - } - return c -} - -// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := s.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db - } else if len(c.search.assignAttrs) > 0 { - return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db - } - return c -} - -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -// WARNING when update with struct, GORM will not update fields that with zero value -func (s *DB) Update(attrs ...interface{}) *DB { - return s.Updates(toSearchableMap(attrs...), true) -} - -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.NewScope(s.Value). - Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumn(attrs ...interface{}) *DB { - return s.UpdateColumns(toSearchableMap(attrs...)) -} - -// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumns(values interface{}) *DB { - return s.NewScope(s.Value). - Set("gorm:update_column", true). - Set("gorm:save_associations", false). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (s *DB) Save(value interface{}) *DB { - scope := s.NewScope(value) - if !scope.PrimaryKeyZero() { - newDB := scope.callCallbacks(s.parent.callbacks.updates).db - if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().Table(scope.TableName()).FirstOrCreate(value) - } - return newDB - } - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Create insert the value into database -func (s *DB) Create(value interface{}) *DB { - scope := s.NewScope(value) - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time -func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db -} - -// Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -func (s *DB) Raw(sql string, values ...interface{}) *DB { - return s.clone().search.Raw(true).Where(sql, values...).db -} - -// Exec execute raw sql -func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.NewScope(nil) - generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) - generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") - scope.Raw(generatedSQL) - return scope.Exec().db -} - -// Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") -func (s *DB) Model(value interface{}) *DB { - c := s.clone() - c.Value = value - return c -} - -// Table specify the table you would like to run db operations -func (s *DB) Table(name string) *DB { - clone := s.clone() - clone.search.Table(name) - clone.Value = nil - return clone -} - -// Debug start debug mode -func (s *DB) Debug() *DB { - return s.clone().LogMode(true) -} - -// Transaction start a transaction as a block, -// return error will rollback, otherwise to commit. -func (s *DB) Transaction(fc func(tx *DB) error) (err error) { - panicked := true - tx := s.Begin() - defer func() { - // Make sure to rollback when panic, Block error or Commit error - if panicked || err != nil { - tx.Rollback() - } - }() - - err = fc(tx) - - if err == nil { - err = tx.Commit().Error - } - - panicked = false - return -} - -// Begin begins a transaction -func (s *DB) Begin() *DB { - return s.BeginTx(context.Background(), &sql.TxOptions{}) -} - -// BeginTx begins a transaction with options -func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { - c := s.clone() - if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.BeginTx(ctx, opts) - c.db = interface{}(tx).(SQLCommon) - - c.dialect.SetDB(c.db) - c.AddError(err) - } else { - c.AddError(ErrCantStartTransaction) - } - return c -} - -// Commit commit a transaction -func (s *DB) Commit() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Commit()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// Rollback rollback a transaction -func (s *DB) Rollback() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - if err := db.Rollback(); err != nil && err != sql.ErrTxDone { - s.AddError(err) - } - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// RollbackUnlessCommitted rollback a transaction if it has not yet been -// committed. -func (s *DB) RollbackUnlessCommitted() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - err := db.Rollback() - // Ignore the error indicating that the transaction has already - // been committed. - if err != sql.ErrTxDone { - s.AddError(err) - } - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// NewRecord check if value's primary key is blank -func (s *DB) NewRecord(value interface{}) bool { - return s.NewScope(value).PrimaryKeyZero() -} - -// RecordNotFound check if returning ErrRecordNotFound error -func (s *DB) RecordNotFound() bool { - for _, err := range s.GetErrors() { - if err == ErrRecordNotFound { - return true - } - } - return false -} - -// CreateTable create table for models -func (s *DB) CreateTable(models ...interface{}) *DB { - db := s.Unscoped() - for _, model := range models { - db = db.NewScope(model).createTable().db - } - return db -} - -// DropTable drop table for models -func (s *DB) DropTable(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if tableName, ok := value.(string); ok { - db = db.Table(tableName) - } - - db = db.NewScope(value).dropTable().db - } - return db -} - -// DropTableIfExists drop table if it is exist -func (s *DB) DropTableIfExists(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if s.HasTable(value) { - db.AddError(s.DropTable(value).Error) - } - } - return db -} - -// HasTable check has table or not -func (s *DB) HasTable(value interface{}) bool { - var ( - scope = s.NewScope(value) - tableName string - ) - - if name, ok := value.(string); ok { - tableName = name - } else { - tableName = scope.TableName() - } - - has := scope.Dialect().HasTable(tableName) - s.AddError(scope.db.Error) - return has -} - -// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data -func (s *DB) AutoMigrate(values ...interface{}) *DB { - db := s.Unscoped() - for _, value := range values { - db = db.NewScope(value).autoMigrate().db - } - return db -} - -// ModifyColumn modify column to type -func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.NewScope(s.Value) - scope.modifyColumn(column, typ) - return scope.db -} - -// DropColumn drop a column -func (s *DB) DropColumn(column string) *DB { - scope := s.NewScope(s.Value) - scope.dropColumn(column) - return scope.db -} - -// AddIndex add index for columns with given name -func (s *DB) AddIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, columns...) - return scope.db -} - -// AddUniqueIndex add unique index for columns with given name -func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(true, indexName, columns...) - return scope.db -} - -// RemoveIndex remove index with name -func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.NewScope(s.Value) - scope.removeIndex(indexName) - return scope.db -} - -// AddForeignKey Add foreign key to the given scope, e.g: -// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") -func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.NewScope(s.Value) - scope.addForeignKey(field, dest, onDelete, onUpdate) - return scope.db -} - -// RemoveForeignKey Remove foreign key from the given scope, e.g: -// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") -func (s *DB) RemoveForeignKey(field string, dest string) *DB { - scope := s.clone().NewScope(s.Value) - scope.removeForeignKey(field, dest) - return scope.db -} - -// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode -func (s *DB) Association(column string) *Association { - var err error - var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value) - - if primaryField := scope.PrimaryField(); primaryField.IsBlank { - err = errors.New("primary key can't be nil") - } else { - if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { - err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) - } else { - return &Association{scope: scope, column: column, field: field} - } - } else { - err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) - } - } - - return &Association{Error: err} -} - -// Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (s *DB) Preload(column string, conditions ...interface{}) *DB { - return s.clone().search.Preload(column, conditions...).db -} - -// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting -func (s *DB) Set(name string, value interface{}) *DB { - return s.clone().InstantSet(name, value) -} - -// InstantSet instant set setting, will affect current db -func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values.Store(name, value) - return s -} - -// Get get setting by name -func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values.Load(name) - return -} - -// SetJoinTableHandler set a model's join table handler for a relation -func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { - scope := s.NewScope(source) - for _, field := range scope.GetModelStruct().StructFields { - if field.Name == column || field.DBName == column { - if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { - source := (&Scope{Value: source}).GetModelStruct().ModelType - destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType - handler.Setup(field.Relationship, many2many, source, destination) - field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(table) { - s.Table(table).AutoMigrate(handler) - } - } - } - } -} - -// AddError add error to the db -func (s *DB) AddError(err error) error { - if err != nil { - if err != ErrRecordNotFound { - if s.logMode == defaultLogMode { - go s.print("error", fileWithLineNum(), err) - } else { - s.log(err) - } - - errors := Errors(s.GetErrors()) - errors = errors.Add(err) - if len(errors) > 1 { - err = errors - } - } - - s.Error = err - } - return err -} - -// GetErrors get happened errors from the db -func (s *DB) GetErrors() []error { - if errs, ok := s.Error.(Errors); ok { - return errs - } else if s.Error != nil { - return []error{s.Error} - } - return []error{} -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For DB -//////////////////////////////////////////////////////////////////////////////// - -func (s *DB) clone() *DB { - db := &DB{ - db: s.db, - parent: s.parent, - logger: s.logger, - logMode: s.logMode, - Value: s.Value, - Error: s.Error, - blockGlobalUpdate: s.blockGlobalUpdate, - dialect: newDialect(s.dialect.GetName(), s.db), - nowFuncOverride: s.nowFuncOverride, - } - - s.values.Range(func(k, v interface{}) bool { - db.values.Store(k, v) - return true - }) - - if s.search == nil { - db.search = &search{limit: -1, offset: -1} - } else { - db.search = s.search.clone() - } - - db.search.db = db - return db -} - -func (s *DB) print(v ...interface{}) { - s.logger.Print(v...) -} - -func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == detailedLogMode { - s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) - } -} - -func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == detailedLogMode { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) - } -} diff --git a/vendor/github.com/jinzhu/gorm/model.go b/vendor/github.com/jinzhu/gorm/model.go deleted file mode 100644 index f37ff7eaa..000000000 --- a/vendor/github.com/jinzhu/gorm/model.go +++ /dev/null @@ -1,14 +0,0 @@ -package gorm - -import "time" - -// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models -// type User struct { -// gorm.Model -// } -type Model struct { - ID uint `gorm:"primary_key"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `sql:"index"` -} diff --git a/vendor/github.com/jinzhu/gorm/model_struct.go b/vendor/github.com/jinzhu/gorm/model_struct.go deleted file mode 100644 index d9e2e90f4..000000000 --- a/vendor/github.com/jinzhu/gorm/model_struct.go +++ /dev/null @@ -1,671 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "go/ast" - "reflect" - "strings" - "sync" - "time" - - "github.com/jinzhu/inflection" -) - -// DefaultTableNameHandler default table name handler -var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { - return defaultTableName -} - -// lock for mutating global cached model metadata -var structsLock sync.Mutex - -// global cache of model metadata -var modelStructsMap sync.Map - -// ModelStruct model definition -type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - - defaultTableName string - l sync.Mutex -} - -// TableName returns model's table name -func (s *ModelStruct) TableName(db *DB) string { - s.l.Lock() - defer s.l.Unlock() - - if s.defaultTableName == "" && db != nil && s.ModelType != nil { - // Set default table name - if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { - s.defaultTableName = tabler.TableName() - } else { - tableName := ToTableName(s.ModelType.Name()) - db.parent.RLock() - if db == nil || (db.parent != nil && !db.parent.singularTable) { - tableName = inflection.Plural(tableName) - } - db.parent.RUnlock() - s.defaultTableName = tableName - } - } - - return DefaultTableNameHandler(db, s.defaultTableName) -} - -// StructField model field's struct definition -type StructField struct { - DBName string - Name string - Names []string - IsPrimaryKey bool - IsNormal bool - IsIgnored bool - IsScanner bool - HasDefaultValue bool - Tag reflect.StructTag - TagSettings map[string]string - Struct reflect.StructField - IsForeignKey bool - Relationship *Relationship - - tagSettingsLock sync.RWMutex -} - -// TagSettingsSet Sets a tag in the tag settings map -func (sf *StructField) TagSettingsSet(key, val string) { - sf.tagSettingsLock.Lock() - defer sf.tagSettingsLock.Unlock() - sf.TagSettings[key] = val -} - -// TagSettingsGet returns a tag from the tag settings -func (sf *StructField) TagSettingsGet(key string) (string, bool) { - sf.tagSettingsLock.RLock() - defer sf.tagSettingsLock.RUnlock() - val, ok := sf.TagSettings[key] - return val, ok -} - -// TagSettingsDelete deletes a tag -func (sf *StructField) TagSettingsDelete(key string) { - sf.tagSettingsLock.Lock() - defer sf.tagSettingsLock.Unlock() - delete(sf.TagSettings, key) -} - -func (sf *StructField) clone() *StructField { - clone := &StructField{ - DBName: sf.DBName, - Name: sf.Name, - Names: sf.Names, - IsPrimaryKey: sf.IsPrimaryKey, - IsNormal: sf.IsNormal, - IsIgnored: sf.IsIgnored, - IsScanner: sf.IsScanner, - HasDefaultValue: sf.HasDefaultValue, - Tag: sf.Tag, - TagSettings: map[string]string{}, - Struct: sf.Struct, - IsForeignKey: sf.IsForeignKey, - } - - if sf.Relationship != nil { - relationship := *sf.Relationship - clone.Relationship = &relationship - } - - // copy the struct field tagSettings, they should be read-locked while they are copied - sf.tagSettingsLock.Lock() - defer sf.tagSettingsLock.Unlock() - for key, value := range sf.TagSettings { - clone.TagSettings[key] = value - } - - return clone -} - -// Relationship described the relationship between models -type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - PolymorphicValue string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface -} - -func getForeignField(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { - return field - } - } - return nil -} - -// GetModelStruct get value's model struct, relationships based on struct and tag definition -func (scope *Scope) GetModelStruct() *ModelStruct { - var modelStruct ModelStruct - // Scope value can't be nil - if scope.Value == nil { - return &modelStruct - } - - reflectType := reflect.ValueOf(scope.Value).Type() - for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Scope value need to be a struct - if reflectType.Kind() != reflect.Struct { - return &modelStruct - } - - // Get Cached model struct - isSingularTable := false - if scope.db != nil && scope.db.parent != nil { - scope.db.parent.RLock() - isSingularTable = scope.db.parent.singularTable - scope.db.parent.RUnlock() - } - - hashKey := struct { - singularTable bool - reflectType reflect.Type - }{isSingularTable, reflectType} - if value, ok := modelStructsMap.Load(hashKey); ok && value != nil { - return value.(*ModelStruct) - } - - modelStruct.ModelType = reflectType - - // Get all fields - for i := 0; i < reflectType.NumField(); i++ { - if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { - field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, - TagSettings: parseTagSetting(fieldStruct.Tag), - } - - // is ignored field - if _, ok := field.TagSettingsGet("-"); ok { - field.IsIgnored = true - } else { - if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - indirectType := fieldStruct.Type - for indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - fieldValue := reflect.New(indirectType).Interface() - if _, isScanner := fieldValue.(sql.Scanner); isScanner { - // is scanner - field.IsScanner, field.IsNormal = true, true - if indirectType.Kind() == reflect.Struct { - for i := 0; i < indirectType.NumField(); i++ { - for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettingsGet(key); !ok { - field.TagSettingsSet(key, value) - } - } - } - } - } else if _, isTime := fieldValue.(*time.Time); isTime { - // is time - field.IsNormal = true - } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { - // is embedded struct - for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { - subField = subField.clone() - subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { - subField.DBName = prefix + subField.DBName - } - - if subField.IsPrimaryKey { - if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) - } else { - subField.IsPrimaryKey = false - } - } - - if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { - if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { - newJoinTableHandler := &JoinTableHandler{} - newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) - subField.Relationship.JoinTableHandler = newJoinTableHandler - } - } - - modelStruct.StructFields = append(modelStruct.StructFields, subField) - } - continue - } else { - // build relationships - switch indirectType.Kind() { - case reflect.Slice: - defer func(field *StructField) { - var ( - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - foreignKeys []string - associationForeignKeys []string - elemType = field.Struct.Type - ) - - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - foreignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } - - for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { - relationship.Kind = "many_to_many" - - { // Foreign Keys for Source - joinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { - joinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - - // setup join table foreign keys for source - if len(joinTableDBNames) > idx { - // if defined join table's foreign key - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) - } else { - defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) - } - } - } - } - - { // Foreign Keys for Association (Destination) - associationJoinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { - associationJoinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for idx, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - - // setup join table foreign keys for association - if len(associationJoinTableDBNames) > idx { - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) - } else { - // join table foreign keys for association - joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) - relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - // User has many comments, associationType is User, comment use UserID as foreign key - var associationType = reflectType.Name() - var toFields = toScope.GetStructFields() - relationship.Kind = "has_many" - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Dog has many toys, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('dogs') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+field.Name) - associationForeignKeys = append(associationForeignKeys, field.Name) - } - } else { - // generate foreign keys from defined association foreign keys - for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - }(field) - case reflect.Struct: - defer func(field *StructField) { - var ( - // user has one profile, associationType is User, profile use UserID as foreign key - // user belongs to profile, associationType is Profile, user use ProfileID as foreign key - associationType = reflectType.Name() - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - toFields = toScope.GetStructFields() - tagForeignKeys []string - tagAssociationForeignKeys []string - ) - - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - tagForeignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Cat has one toy, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('cats') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // Has One - { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, primaryField := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys form association foreign keys - for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" - field.Relationship = relationship - } else { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - - if len(foreignKeys) == 0 { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, primaryField := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys with association foreign keys - for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - foreignKeys = append(foreignKeys, field.Name+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, field.Name) { - associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{toScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // source foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "belongs_to" - field.Relationship = relationship - } - } - }(field) - default: - field.IsNormal = true - } - } - } - - // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettingsGet("COLUMN"); ok { - field.DBName = value - } else { - field.DBName = ToColumnName(fieldStruct.Name) - } - - modelStruct.StructFields = append(modelStruct.StructFields, field) - } - } - - if len(modelStruct.PrimaryFields) == 0 { - if field := getForeignField("id", modelStruct.StructFields); field != nil { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - } - - modelStructsMap.Store(hashKey, &modelStruct) - - return &modelStruct -} - -// GetStructFields get model's field structs -func (scope *Scope) GetStructFields() (fields []*StructField) { - return scope.GetModelStruct().StructFields -} - -func parseTagSetting(tags reflect.StructTag) map[string]string { - setting := map[string]string{} - for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { - if str == "" { - continue - } - tags := strings.Split(str, ";") - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k - } - } - } - return setting -} diff --git a/vendor/github.com/jinzhu/gorm/naming.go b/vendor/github.com/jinzhu/gorm/naming.go deleted file mode 100644 index 6b0a4fddb..000000000 --- a/vendor/github.com/jinzhu/gorm/naming.go +++ /dev/null @@ -1,124 +0,0 @@ -package gorm - -import ( - "bytes" - "strings" -) - -// Namer is a function type which is given a string and return a string -type Namer func(string) string - -// NamingStrategy represents naming strategies -type NamingStrategy struct { - DB Namer - Table Namer - Column Namer -} - -// TheNamingStrategy is being initialized with defaultNamingStrategy -var TheNamingStrategy = &NamingStrategy{ - DB: defaultNamer, - Table: defaultNamer, - Column: defaultNamer, -} - -// AddNamingStrategy sets the naming strategy -func AddNamingStrategy(ns *NamingStrategy) { - if ns.DB == nil { - ns.DB = defaultNamer - } - if ns.Table == nil { - ns.Table = defaultNamer - } - if ns.Column == nil { - ns.Column = defaultNamer - } - TheNamingStrategy = ns -} - -// DBName alters the given name by DB -func (ns *NamingStrategy) DBName(name string) string { - return ns.DB(name) -} - -// TableName alters the given name by Table -func (ns *NamingStrategy) TableName(name string) string { - return ns.Table(name) -} - -// ColumnName alters the given name by Column -func (ns *NamingStrategy) ColumnName(name string) string { - return ns.Column(name) -} - -// ToDBName convert string to db name -func ToDBName(name string) string { - return TheNamingStrategy.DBName(name) -} - -// ToTableName convert string to table name -func ToTableName(name string) string { - return TheNamingStrategy.TableName(name) -} - -// ToColumnName convert string to db name -func ToColumnName(name string) string { - return TheNamingStrategy.ColumnName(name) -} - -var smap = newSafeMap() - -func defaultNamer(name string) string { - const ( - lower = false - upper = true - ) - - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase, nextNumber bool - ) - - for i, v := range value[:len(value)-1] { - nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') - nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') - - if i > 0 { - if currCase == upper { - if lastCase == upper && (nextCase == upper || nextNumber == upper) { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} diff --git a/vendor/github.com/jinzhu/gorm/scope.go b/vendor/github.com/jinzhu/gorm/scope.go deleted file mode 100644 index d82cadbc8..000000000 --- a/vendor/github.com/jinzhu/gorm/scope.go +++ /dev/null @@ -1,1421 +0,0 @@ -package gorm - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" - "regexp" - "strings" - "time" -) - -// Scope contain current operation's information when you perform any operation on the database -type Scope struct { - Search *search - Value interface{} - SQL string - SQLVars []interface{} - db *DB - instanceID string - primaryKeyField *Field - skipLeft bool - fields *[]*Field - selectAttrs *[]string -} - -// IndirectValue return scope's reflect value's indirect value -func (scope *Scope) IndirectValue() reflect.Value { - return indirect(reflect.ValueOf(scope.Value)) -} - -// New create a new Scope without search information -func (scope *Scope) New(value interface{}) *Scope { - return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} -} - -//////////////////////////////////////////////////////////////////////////////// -// Scope DB -//////////////////////////////////////////////////////////////////////////////// - -// DB return scope's DB connection -func (scope *Scope) DB() *DB { - return scope.db -} - -// NewDB create a new DB without search information -func (scope *Scope) NewDB() *DB { - if scope.db != nil { - db := scope.db.clone() - db.search = nil - db.Value = nil - return db - } - return nil -} - -// SQLDB return *sql.DB -func (scope *Scope) SQLDB() SQLCommon { - return scope.db.db -} - -// Dialect get dialect -func (scope *Scope) Dialect() Dialect { - return scope.db.dialect -} - -// Quote used to quote string to escape them for database -func (scope *Scope) Quote(str string) string { - if strings.Contains(str, ".") { - newStrs := []string{} - for _, str := range strings.Split(str, ".") { - newStrs = append(newStrs, scope.Dialect().Quote(str)) - } - return strings.Join(newStrs, ".") - } - - return scope.Dialect().Quote(str) -} - -// Err add error to Scope -func (scope *Scope) Err(err error) error { - if err != nil { - scope.db.AddError(err) - } - return err -} - -// HasError check if there are any error -func (scope *Scope) HasError() bool { - return scope.db.Error != nil -} - -// Log print log message -func (scope *Scope) Log(v ...interface{}) { - scope.db.log(v...) -} - -// SkipLeft skip remaining callbacks -func (scope *Scope) SkipLeft() { - scope.skipLeft = true -} - -// Fields get value's fields -func (scope *Scope) Fields() []*Field { - if scope.fields == nil { - var ( - fields []*Field - indirectScopeValue = scope.IndirectValue() - isStruct = indirectScopeValue.Kind() == reflect.Struct - ) - - for _, structField := range scope.GetModelStruct().StructFields { - if isStruct { - fieldValue := indirectScopeValue - for _, name := range structField.Names { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } - fieldValue = reflect.Indirect(fieldValue).FieldByName(name) - } - fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) - } else { - fields = append(fields, &Field{StructField: structField, IsBlank: true}) - } - } - scope.fields = &fields - } - - return *scope.fields -} - -// FieldByName find `gorm.Field` with field name or db name -func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { - var ( - dbName = ToColumnName(name) - mostMatchedField *Field - ) - - for _, field := range scope.Fields() { - if field.Name == name || field.DBName == name { - return field, true - } - if field.DBName == dbName { - mostMatchedField = field - } - } - return mostMatchedField, mostMatchedField != nil -} - -// PrimaryFields return scope's primary fields -func (scope *Scope) PrimaryFields() (fields []*Field) { - for _, field := range scope.Fields() { - if field.IsPrimaryKey { - fields = append(fields, field) - } - } - return fields -} - -// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one -func (scope *Scope) PrimaryField() *Field { - if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { - if len(primaryFields) > 1 { - if field, ok := scope.FieldByName("id"); ok { - return field - } - } - return scope.PrimaryFields()[0] - } - return nil -} - -// PrimaryKey get main primary field's db name -func (scope *Scope) PrimaryKey() string { - if field := scope.PrimaryField(); field != nil { - return field.DBName - } - return "" -} - -// PrimaryKeyZero check main primary field's value is blank or not -func (scope *Scope) PrimaryKeyZero() bool { - field := scope.PrimaryField() - return field == nil || field.IsBlank -} - -// PrimaryKeyValue get the primary key's value -func (scope *Scope) PrimaryKeyValue() interface{} { - if field := scope.PrimaryField(); field != nil && field.Field.IsValid() { - return field.Field.Interface() - } - return 0 -} - -// HasColumn to check if has column -func (scope *Scope) HasColumn(column string) bool { - for _, field := range scope.GetStructFields() { - if field.IsNormal && (field.Name == column || field.DBName == column) { - return true - } - } - return false -} - -// SetColumn to set the column's value, column could be field or field's name/dbname -func (scope *Scope) SetColumn(column interface{}, value interface{}) error { - var updateAttrs = map[string]interface{}{} - if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - updateAttrs = attrs.(map[string]interface{}) - defer scope.InstanceSet("gorm:update_attrs", updateAttrs) - } - - if field, ok := column.(*Field); ok { - updateAttrs[field.DBName] = value - return field.Set(value) - } else if name, ok := column.(string); ok { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - for _, field := range scope.Fields() { - if field.DBName == value { - updateAttrs[field.DBName] = value - return field.Set(value) - } - if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) { - mostMatchedField = field - } - } - - if mostMatchedField != nil { - updateAttrs[mostMatchedField.DBName] = value - return mostMatchedField.Set(value) - } - } - return errors.New("could not convert column to field") -} - -// CallMethod call scope value's method, if it is a slice, will call its element's method one by one -func (scope *Scope) CallMethod(methodName string) { - if scope.Value == nil { - return - } - - if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { - for i := 0; i < indirectScopeValue.Len(); i++ { - scope.callMethod(methodName, indirectScopeValue.Index(i)) - } - } else { - scope.callMethod(methodName, indirectScopeValue) - } -} - -// AddToVars add value as sql's vars, used to prevent SQL injection -func (scope *Scope) AddToVars(value interface{}) string { - _, skipBindVar := scope.InstanceGet("skip_bindvar") - - if expr, ok := value.(*SqlExpr); ok { - exp := expr.expr - for _, arg := range expr.args { - if skipBindVar { - scope.AddToVars(arg) - } else { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - } - return exp - } - - scope.SQLVars = append(scope.SQLVars, value) - - if skipBindVar { - return "?" - } - return scope.Dialect().BindVar(len(scope.SQLVars)) -} - -// SelectAttrs return selected attributes -func (scope *Scope) SelectAttrs() []string { - if scope.selectAttrs == nil { - attrs := []string{} - for _, value := range scope.Search.selects { - if str, ok := value.(string); ok { - attrs = append(attrs, str) - } else if strs, ok := value.([]string); ok { - attrs = append(attrs, strs...) - } else if strs, ok := value.([]interface{}); ok { - for _, str := range strs { - attrs = append(attrs, fmt.Sprintf("%v", str)) - } - } - } - scope.selectAttrs = &attrs - } - return *scope.selectAttrs -} - -// OmitAttrs return omitted attributes -func (scope *Scope) OmitAttrs() []string { - return scope.Search.omits -} - -type tabler interface { - TableName() string -} - -type dbTabler interface { - TableName(*DB) string -} - -// TableName return table name -func (scope *Scope) TableName() string { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - return scope.Search.tableName - } - - if tabler, ok := scope.Value.(tabler); ok { - return tabler.TableName() - } - - if tabler, ok := scope.Value.(dbTabler); ok { - return tabler.TableName(scope.db) - } - - return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) -} - -// QuotedTableName return quoted table name -func (scope *Scope) QuotedTableName() (name string) { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Contains(scope.Search.tableName, " ") { - return scope.Search.tableName - } - return scope.Quote(scope.Search.tableName) - } - - return scope.Quote(scope.TableName()) -} - -// CombinedConditionSql return combined condition sql -func (scope *Scope) CombinedConditionSql() string { - joinSQL := scope.joinsSQL() - whereSQL := scope.whereSQL() - if scope.Search.raw { - whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")") - } - return joinSQL + whereSQL + scope.groupSQL() + - scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() -} - -// Raw set raw sql -func (scope *Scope) Raw(sql string) *Scope { - scope.SQL = strings.Replace(sql, "$$$", "?", -1) - return scope -} - -// Exec perform generated SQL -func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) - - if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); scope.Err(err) == nil { - scope.db.RowsAffected = count - } - } - } - return scope -} - -// Set set value by name -func (scope *Scope) Set(name string, value interface{}) *Scope { - scope.db.InstantSet(name, value) - return scope -} - -// Get get setting by name -func (scope *Scope) Get(name string) (interface{}, bool) { - return scope.db.Get(name) -} - -// InstanceID get InstanceID for scope -func (scope *Scope) InstanceID() string { - if scope.instanceID == "" { - scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) - } - return scope.instanceID -} - -// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback -func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { - return scope.Set(name+scope.InstanceID(), value) -} - -// InstanceGet get instance setting from current operation -func (scope *Scope) InstanceGet(name string) (interface{}, bool) { - return scope.Get(name + scope.InstanceID()) -} - -// Begin start a transaction -func (scope *Scope) Begin() *Scope { - if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); scope.Err(err) == nil { - scope.db.db = interface{}(tx).(SQLCommon) - scope.InstanceSet("gorm:started_transaction", true) - } - } - return scope -} - -// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it -func (scope *Scope) CommitOrRollback() *Scope { - if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { - if db, ok := scope.db.db.(sqlTx); ok { - if scope.HasError() { - db.Rollback() - } else { - scope.Err(db.Commit()) - } - scope.db.db = scope.db.parent.db - } - } - return scope -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.Scope -//////////////////////////////////////////////////////////////////////////////// - -func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { - // Only get address from non-pointer - if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { - reflectValue = reflectValue.Addr() - } - - if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { - switch method := methodValue.Interface().(type) { - case func(): - method() - case func(*Scope): - method(scope) - case func(*DB): - newDB := scope.NewDB() - method(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(method()) - case func(*Scope) error: - scope.Err(method(scope)) - case func(*DB) error: - newDB := scope.NewDB() - scope.Err(method(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", methodName)) - } - } -} - -var ( - columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` - isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number - comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") - countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") -) - -func (scope *Scope) quoteIfPossible(str string) string { - if columnRegexp.MatchString(str) { - return scope.Quote(str) - } - return str -} - -func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { - var ( - ignored interface{} - values = make([]interface{}, len(columns)) - selectFields []*Field - selectedColumnsMap = map[string]int{} - resetFields = map[int]*Field{} - ) - - for index, column := range columns { - values[index] = &ignored - - selectFields = fields - offset := 0 - if idx, ok := selectedColumnsMap[column]; ok { - offset = idx + 1 - selectFields = selectFields[offset:] - } - - for fieldIndex, field := range selectFields { - if field.DBName == column { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) - reflectValue.Elem().Set(field.Field.Addr()) - values[index] = reflectValue.Interface() - resetFields[index] = field - } - - selectedColumnsMap[column] = offset + fieldIndex - - if field.IsNormal { - break - } - } - } - } - - scope.Err(rows.Scan(values...)) - - for index, field := range resetFields { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) - } - } -} - -func (scope *Scope) primaryCondition(value interface{}) string { - return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) -} - -func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { - var ( - quotedTableName = scope.QuotedTableName() - quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) - equalSQL = "=" - inSQL = "IN" - ) - - // If building not conditions - if !include { - equalSQL = "<>" - inSQL = "NOT IN" - } - - switch value := clause["query"].(type) { - case sql.NullInt64: - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - if !include && reflect.ValueOf(value).Len() == 0 { - return - } - str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) - clause["args"] = []interface{}{value} - case string: - if isNumberRegexp.MatchString(value) { - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) - } - - if value != "" { - if !include { - if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) - } - } else { - str = fmt.Sprintf("(%v)", value) - } - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) - } else { - if !include { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) - } - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - newScope := scope.New(value) - - if len(newScope.Fields()) == 0 { - scope.Err(fmt.Errorf("invalid query condition: %v", value)) - return - } - scopeQuotedTableName := newScope.QuotedTableName() - for _, field := range newScope.Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - default: - scope.Err(fmt.Errorf("invalid query condition: %v", value)) - return - } - - replacements := []string{} - args := clause["args"].([]interface{}) - for _, arg := range args { - var err error - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - replacements = append(replacements, scope.AddToVars(arg)) - } else if b, ok := arg.([]byte); ok { - replacements = append(replacements, scope.AddToVars(b)) - } else if as, ok := arg.([][]interface{}); ok { - var tempMarks []string - for _, a := range as { - var arrayMarks []string - for _, v := range a { - arrayMarks = append(arrayMarks, scope.AddToVars(v)) - } - - if len(arrayMarks) > 0 { - tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) - } - } - - if len(tempMarks) > 0 { - replacements = append(replacements, strings.Join(tempMarks, ",")) - } - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - } else { - replacements = append(replacements, scope.AddToVars(Expr("NULL"))) - } - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = valuer.Value() - } - - replacements = append(replacements, scope.AddToVars(arg)) - } - - if err != nil { - scope.Err(err) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - for _, s := range str { - if s == '?' && len(replacements) > i { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(s) - } - } - - str = buff.String() - - return -} - -func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - str = value - case []string: - str = strings.Join(value, ", ") - } - - args := clause["args"].([]interface{}) - replacements := []string{} - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - replacements = append(replacements, scope.AddToVars(arg)) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - for pos, char := range str { - if str[pos] == '?' { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(char) - } - } - - str = buff.String() - - return -} - -func (scope *Scope) whereSQL() (sql string) { - var ( - quotedTableName = scope.QuotedTableName() - deletedAtField, hasDeletedAtField = scope.FieldByName("DeletedAt") - primaryConditions, andConditions, orConditions []string - ) - - if !scope.Search.Unscoped && hasDeletedAtField { - sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName)) - primaryConditions = append(primaryConditions, sql) - } - - if !scope.PrimaryKeyZero() { - for _, field := range scope.PrimaryFields() { - sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) - primaryConditions = append(primaryConditions, sql) - } - } - - for _, clause := range scope.Search.whereConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - andConditions = append(andConditions, sql) - } - } - - for _, clause := range scope.Search.orConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - orConditions = append(orConditions, sql) - } - } - - for _, clause := range scope.Search.notConditions { - if sql := scope.buildCondition(clause, false); sql != "" { - andConditions = append(andConditions, sql) - } - } - - orSQL := strings.Join(orConditions, " OR ") - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) > 0 { - if len(orSQL) > 0 { - combinedSQL = combinedSQL + " OR " + orSQL - } - } else { - combinedSQL = orSQL - } - - if len(primaryConditions) > 0 { - sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSQL) > 0 { - sql = sql + " AND (" + combinedSQL + ")" - } - } else if len(combinedSQL) > 0 { - sql = "WHERE " + combinedSQL - } - return -} - -func (scope *Scope) selectSQL() string { - if len(scope.Search.selects) == 0 { - if len(scope.Search.joinConditions) > 0 { - return fmt.Sprintf("%v.*", scope.QuotedTableName()) - } - return "*" - } - return scope.buildSelectQuery(scope.Search.selects) -} - -func (scope *Scope) orderSQL() string { - if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery { - return "" - } - - var orders []string - for _, order := range scope.Search.orders { - if str, ok := order.(string); ok { - orders = append(orders, scope.quoteIfPossible(str)) - } else if expr, ok := order.(*SqlExpr); ok { - exp := expr.expr - for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - orders = append(orders, exp) - } - } - return " ORDER BY " + strings.Join(orders, ",") -} - -func (scope *Scope) limitAndOffsetSQL() string { - sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) - scope.Err(err) - return sql -} - -func (scope *Scope) groupSQL() string { - if len(scope.Search.group) == 0 { - return "" - } - return " GROUP BY " + scope.Search.group -} - -func (scope *Scope) havingSQL() string { - if len(scope.Search.havingConditions) == 0 { - return "" - } - - var andConditions []string - for _, clause := range scope.Search.havingConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - andConditions = append(andConditions, sql) - } - } - - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) == 0 { - return "" - } - - return " HAVING " + combinedSQL -} - -func (scope *Scope) joinsSQL() string { - var joinConditions []string - for _, clause := range scope.Search.joinConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) - } - } - - return strings.Join(joinConditions, " ") + " " -} - -func (scope *Scope) prepareQuerySQL() { - if scope.Search.raw { - scope.Raw(scope.CombinedConditionSql()) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) - } - return -} - -func (scope *Scope) inlineCondition(values ...interface{}) *Scope { - if len(values) > 0 { - scope.Search.Where(values[0], values[1:]...) - } - return scope -} - -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { - defer func() { - if err := recover(); err != nil { - if db, ok := scope.db.db.(sqlTx); ok { - db.Rollback() - } - panic(err) - } - }() - for _, f := range funcs { - (*f)(scope) - if scope.skipLeft { - break - } - } - return scope -} - -func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} { - var attrs = map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - return value - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v, withIgnoredField, db) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - for _, field := range (&Scope{Value: values, db: db}).Fields() { - if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - -func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { - if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value, false, scope.db), true - } - - results = map[string]interface{}{} - - for key, value := range convertInterfaceToMap(value, true, scope.db) { - if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*SqlExpr); ok { - hasUpdate = true - results[field.DBName] = value - } else { - err := field.Set(value) - if field.IsNormal && !field.IsIgnored { - hasUpdate = true - if err == ErrUnaddressable { - results[field.DBName] = value - } else { - results[field.DBName] = field.Field.Interface() - } - } - } - } - } - return -} - -func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) - - result := &RowQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Row -} - -func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) - - result := &RowsQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Rows, result.Error -} - -func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(clause["query"]) - } - scope.updatedAttrsWithValues(scope.Search.initAttrs) - scope.updatedAttrsWithValues(scope.Search.assignAttrs) - return scope -} - -func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { - queryStr := strings.ToLower(fmt.Sprint(query)) - if queryStr == column { - return true - } - - if strings.HasSuffix(queryStr, "as "+column) { - return true - } - - if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) { - return true - } - - return false -} - -func (scope *Scope) pluck(column string, value interface{}) *Scope { - dest := reflect.Indirect(reflect.ValueOf(value)) - if dest.Kind() != reflect.Slice { - scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) - return scope - } - - if dest.Len() > 0 { - dest.Set(reflect.Zero(dest.Type())) - } - - if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { - scope.Search.Select(column) - } - - rows, err := scope.rows() - if scope.Err(err) == nil { - defer rows.Close() - for rows.Next() { - elem := reflect.New(dest.Type().Elem()).Interface() - scope.Err(rows.Scan(elem)) - dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - } - return scope -} - -func (scope *Scope) count(value interface{}) *Scope { - if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { - if len(scope.Search.group) != 0 { - if len(scope.Search.havingConditions) != 0 { - scope.prepareQuerySQL() - scope.Search = &search{} - scope.Search.Select("count(*)") - scope.Search.Table(fmt.Sprintf("( %s ) AS count_table", scope.SQL)) - } else { - scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") - scope.Search.group += " ) AS count_table" - } - } else { - scope.Search.Select("count(*)") - } - } - scope.Search.ignoreOrderQuery = true - scope.Err(scope.row().Scan(value)) - return scope -} - -func (scope *Scope) typeName() string { - typ := scope.IndirectValue().Type() - - for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - return typ.Name() -} - -// trace print sql log -func (scope *Scope) trace(t time.Time) { - if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) - } -} - -func (scope *Scope) changeableField(field *Field) bool { - if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if field.Name == attr || field.DBName == attr { - return true - } - } - return false - } - - for _, attr := range scope.OmitAttrs() { - if field.Name == attr || field.DBName == attr { - return false - } - } - - return true -} - -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.db.NewScope(value) - tx := scope.db.Set("gorm:association:source", scope.Value) - - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - fromField, _ := scope.FieldByName(foreignKey) - toField, _ := toScope.FieldByName(foreignKey) - - if fromField != nil { - if relationship := fromField.Relationship; relationship != nil { - if relationship.Kind == "many_to_many" { - joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) - } else if relationship.Kind == "belongs_to" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(foreignKey); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) - } - } - scope.Err(tx.Find(value).Error) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - scope.Err(tx.Find(value).Error) - } - } else { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) - } - return scope - } else if toField != nil { - sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) - return scope - } - } - - scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) - return scope -} - -// getTableOptions return the table options string or an empty string if the table options does not exist -func (scope *Scope) getTableOptions() string { - tableOptions, ok := scope.Get("gorm:table_options") - if !ok { - return "" - } - return " " + tableOptions.(string) -} - -func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTableHandler := relationship.JoinTableHandler - joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { - toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} - - var sqlTypes, primaryKeys []string - for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) - } - } - - for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) - } - } - - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) - } - scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) - } -} - -func (scope *Scope) createTable() *Scope { - var tags []string - var primaryKeys []string - var primaryKeyInColumnType = false - for _, field := range scope.GetModelStruct().StructFields { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - - // Check if the primary key constraint was specified as - // part of the column type. If so, we can only support - // one column as the primary key. - if strings.Contains(strings.ToLower(sqlTag), "primary key") { - primaryKeyInColumnType = true - } - - tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) - } - - if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) - } - scope.createJoinTable(field) - } - - var primaryKeyStr string - if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) - } - - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() - - scope.autoIndex() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { - return - } - - var columns []string - for _, name := range column { - columns = append(columns, scope.quoteIfPossible(name)) - } - - sqlCreate := "CREATE INDEX" - if unique { - sqlCreate = "CREATE UNIQUE INDEX" - } - - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() -} - -func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - // Compatible with old generated key - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() -} - -func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var mysql mysql - var query string - if scope.Dialect().GetName() == mysql.GetName() { - query = `ALTER TABLE %s DROP FOREIGN KEY %s;` - } else { - query = `ALTER TABLE %s DROP CONSTRAINT %s;` - } - - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) -} - -func (scope *Scope) autoMigrate() *Scope { - tableName := scope.TableName() - quotedTableName := scope.QuotedTableName() - - if !scope.Dialect().HasTable(tableName) { - scope.createTable() - } else { - for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() - } - } - scope.createJoinTable(field) - } - scope.autoIndex() - } - return scope -} - -func (scope *Scope) autoIndex() *Scope { - var indexes = map[string][]string{} - var uniqueIndexes = map[string][]string{} - - for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettingsGet("INDEX"); ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "INDEX" || name == "" { - name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) - } - name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) - indexes[name] = append(indexes[name], column) - } - } - - if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "UNIQUE_INDEX" || name == "" { - name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) - } - name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) - uniqueIndexes[name] = append(uniqueIndexes[name], column) - } - } - } - - for name, columns := range indexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil { - scope.db.AddError(db.Error) - } - } - - for name, columns := range uniqueIndexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { - scope.db.AddError(db.Error) - } - } - - return scope -} - -func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { - resultMap := make(map[string][]interface{}) - for _, value := range values { - indirectValue := indirect(reflect.ValueOf(value)) - - switch indirectValue.Kind() { - case reflect.Slice: - for i := 0; i < indirectValue.Len(); i++ { - var result []interface{} - var object = indirect(indirectValue.Index(i)) - var hasValue = false - for _, column := range columns { - field := object.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - h := fmt.Sprint(result...) - if _, exist := resultMap[h]; !exist { - resultMap[h] = result - } - } - } - case reflect.Struct: - var result []interface{} - var hasValue = false - for _, column := range columns { - field := indirectValue.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - h := fmt.Sprint(result...) - if _, exist := resultMap[h]; !exist { - resultMap[h] = result - } - } - } - } - for _, v := range resultMap { - results = append(results, v) - } - return -} - -func (scope *Scope) getColumnAsScope(column string) *Scope { - indirectScopeValue := scope.IndirectValue() - - switch indirectScopeValue.Kind() { - case reflect.Slice: - if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { - fieldType := fieldStruct.Type - if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - - resultsMap := map[interface{}]bool{} - results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() - - for i := 0; i < indirectScopeValue.Len(); i++ { - result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) - - if result.Kind() == reflect.Slice { - for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { - resultsMap[elem.Addr()] = true - results = reflect.Append(results, elem.Addr()) - } - } - } else if result.CanAddr() && resultsMap[result.Addr()] != true { - resultsMap[result.Addr()] = true - results = reflect.Append(results, result.Addr()) - } - } - return scope.New(results.Interface()) - } - case reflect.Struct: - if field := indirectScopeValue.FieldByName(column); field.CanAddr() { - return scope.New(field.Addr().Interface()) - } - } - return nil -} - -func (scope *Scope) hasConditions() bool { - return !scope.PrimaryKeyZero() || - len(scope.Search.whereConditions) > 0 || - len(scope.Search.orConditions) > 0 || - len(scope.Search.notConditions) > 0 -} diff --git a/vendor/github.com/jinzhu/gorm/search.go b/vendor/github.com/jinzhu/gorm/search.go deleted file mode 100644 index 7c4cc184a..000000000 --- a/vendor/github.com/jinzhu/gorm/search.go +++ /dev/null @@ -1,153 +0,0 @@ -package gorm - -import ( - "fmt" -) - -type search struct { - db *DB - whereConditions []map[string]interface{} - orConditions []map[string]interface{} - notConditions []map[string]interface{} - havingConditions []map[string]interface{} - joinConditions []map[string]interface{} - initAttrs []interface{} - assignAttrs []interface{} - selects map[string]interface{} - omits []string - orders []interface{} - preload []searchPreload - offset interface{} - limit interface{} - group string - tableName string - raw bool - Unscoped bool - ignoreOrderQuery bool -} - -type searchPreload struct { - schema string - conditions []interface{} -} - -func (s *search) clone() *search { - clone := *s - return &clone -} - -func (s *search) Where(query interface{}, values ...interface{}) *search { - s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Not(query interface{}, values ...interface{}) *search { - s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Or(query interface{}, values ...interface{}) *search { - s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Attrs(attrs ...interface{}) *search { - s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Assign(attrs ...interface{}) *search { - s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Order(value interface{}, reorder ...bool) *search { - if len(reorder) > 0 && reorder[0] { - s.orders = []interface{}{} - } - - if value != nil && value != "" { - s.orders = append(s.orders, value) - } - return s -} - -func (s *search) Select(query interface{}, args ...interface{}) *search { - s.selects = map[string]interface{}{"query": query, "args": args} - return s -} - -func (s *search) Omit(columns ...string) *search { - s.omits = columns - return s -} - -func (s *search) Limit(limit interface{}) *search { - s.limit = limit - return s -} - -func (s *search) Offset(offset interface{}) *search { - s.offset = offset - return s -} - -func (s *search) Group(query string) *search { - s.group = s.getInterfaceAsSQL(query) - return s -} - -func (s *search) Having(query interface{}, values ...interface{}) *search { - if val, ok := query.(*SqlExpr); ok { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) - } else { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) - } - return s -} - -func (s *search) Joins(query string, values ...interface{}) *search { - s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Preload(schema string, values ...interface{}) *search { - var preloads []searchPreload - for _, preload := range s.preload { - if preload.schema != schema { - preloads = append(preloads, preload) - } - } - preloads = append(preloads, searchPreload{schema, values}) - s.preload = preloads - return s -} - -func (s *search) Raw(b bool) *search { - s.raw = b - return s -} - -func (s *search) unscoped() *search { - s.Unscoped = true - return s -} - -func (s *search) Table(name string) *search { - s.tableName = name - return s -} - -func (s *search) getInterfaceAsSQL(value interface{}) (str string) { - switch value.(type) { - case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - str = fmt.Sprintf("%v", value) - default: - s.db.AddError(ErrInvalidSQL) - } - - if str == "-1" { - return "" - } - return -} diff --git a/vendor/github.com/jinzhu/gorm/test_all.sh b/vendor/github.com/jinzhu/gorm/test_all.sh deleted file mode 100755 index 5cfb3321a..000000000 --- a/vendor/github.com/jinzhu/gorm/test_all.sh +++ /dev/null @@ -1,5 +0,0 @@ -dialects=("postgres" "mysql" "mssql" "sqlite") - -for dialect in "${dialects[@]}" ; do - DEBUG=false GORM_DIALECT=${dialect} go test -done diff --git a/vendor/github.com/jinzhu/gorm/utils.go b/vendor/github.com/jinzhu/gorm/utils.go deleted file mode 100644 index d2ae9465d..000000000 --- a/vendor/github.com/jinzhu/gorm/utils.go +++ /dev/null @@ -1,226 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "runtime" - "strings" - "sync" - "time" -) - -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs, e.g: -// gorm.NowFunc = func() time.Time { -// return time.Now().UTC() -// } -var NowFunc = func() time.Time { - return time.Now() -} - -// Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} -var commonInitialismsReplacer *strings.Replacer - -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) - -func init() { - var commonInitialismsForReplacer []string - for _, initialism := range commonInitialisms { - commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) - } - commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) -} - -type safeMap struct { - m map[string]string - l *sync.RWMutex -} - -func (s *safeMap) Set(key string, value string) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeMap) Get(key string) string { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newSafeMap() *safeMap { - return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} -} - -// SQL expression -type SqlExpr struct { - expr string - args []interface{} -} - -// Expr generate raw SQL expression, for example: -// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -func Expr(expression string, args ...interface{}) *SqlExpr { - return &SqlExpr{expr: expression, args: args} -} - -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } - return reflectValue -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(values [][]interface{}) (results []interface{}) { - for _, value := range values { - for _, v := range value { - results = append(results, v) - } - } - return -} - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - switch value.Kind() { - case reflect.String: - return value.Len() == 0 - case reflect.Bool: - return !value.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return value.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return value.Uint() == 0 - case reflect.Float32, reflect.Float64: - return value.Float() == 0 - case reflect.Interface, reflect.Ptr: - return value.IsNil() - } - - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func makeSlice(elemType reflect.Type) interface{} { - if elemType.Kind() == reflect.Slice { - elemType = elemType.Elem() - } - sliceType := reflect.SliceOf(elemType) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -// getValueFromFields return given fields's value -func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { - for _, fieldName := range fieldNames { - if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { - result := fieldValue.Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func addExtraSpaceIfExist(str string) string { - if str != "" { - return " " + str - } - return "" -} diff --git a/vendor/github.com/jinzhu/gorm/wercker.yml b/vendor/github.com/jinzhu/gorm/wercker.yml deleted file mode 100644 index c74fa4d4b..000000000 --- a/vendor/github.com/jinzhu/gorm/wercker.yml +++ /dev/null @@ -1,154 +0,0 @@ -# use the default golang container from Docker Hub -box: golang - -services: - - name: mariadb - id: mariadb:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql - id: mysql:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql57 - id: mysql:5.7 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql56 - id: mysql:5.6 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: postgres - id: postgres:latest - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres96 - id: postgres:9.6 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres95 - id: postgres:9.5 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres94 - id: postgres:9.4 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres93 - id: postgres:9.3 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: mssql - id: mcmoe/mssqldocker:latest - env: - ACCEPT_EULA: Y - SA_PASSWORD: LoremIpsum86 - MSSQL_DB: gorm - MSSQL_USER: gorm - MSSQL_PASSWORD: LoremIpsum86 - -# The steps that will be executed in the build pipeline -build: - # The steps that will be executed on build - steps: - # Sets the go workspace and places you package - # at the right place in the workspace tree - - setup-go-workspace - - # Gets the dependencies - - script: - name: go get - code: | - cd $WERCKER_SOURCE_DIR - go version - go get -t -v ./... - - # Build the project - - script: - name: go build - code: | - go build ./... - - # Test the project - - script: - name: test sqlite - code: | - go test -race -v ./... - - - script: - name: test mariadb - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql5.7 - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql5.6 - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test postgres - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres96 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres95 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres94 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres93 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test mssql - code: | - GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test -race ./... - - - script: - name: codecov - code: | - go test -race -coverprofile=coverage.txt -covermode=atomic ./... - bash <(curl -s https://codecov.io/bash) diff --git a/vendor/github.com/jinzhu/now/Guardfile b/vendor/github.com/jinzhu/now/Guardfile new file mode 100644 index 000000000..0b860b065 --- /dev/null +++ b/vendor/github.com/jinzhu/now/Guardfile @@ -0,0 +1,3 @@ +guard 'gotest' do + watch(%r{\.go$}) +end diff --git a/vendor/github.com/jinzhu/gorm/License b/vendor/github.com/jinzhu/now/License similarity index 100% rename from vendor/github.com/jinzhu/gorm/License rename to vendor/github.com/jinzhu/now/License diff --git a/vendor/github.com/jinzhu/now/README.md b/vendor/github.com/jinzhu/now/README.md new file mode 100644 index 000000000..3add6bfa5 --- /dev/null +++ b/vendor/github.com/jinzhu/now/README.md @@ -0,0 +1,134 @@ +## Now + +Now is a time toolkit for golang + +[![wercker status](https://app.wercker.com/status/a350da4eae6cb28a35687ba41afb565a/s/master "wercker status")](https://app.wercker.com/project/byKey/a350da4eae6cb28a35687ba41afb565a) + +## Install + +``` +go get -u github.com/jinzhu/now +``` + +## Usage + +Calculating time based on current time + +```go +import "github.com/jinzhu/now" + +time.Now() // 2013-11-18 17:51:49.123456789 Mon + +now.BeginningOfMinute() // 2013-11-18 17:51:00 Mon +now.BeginningOfHour() // 2013-11-18 17:00:00 Mon +now.BeginningOfDay() // 2013-11-18 00:00:00 Mon +now.BeginningOfWeek() // 2013-11-17 00:00:00 Sun +now.BeginningOfMonth() // 2013-11-01 00:00:00 Fri +now.BeginningOfQuarter() // 2013-10-01 00:00:00 Tue +now.BeginningOfYear() // 2013-01-01 00:00:00 Tue + +now.WeekStartDay = time.Monday // Set Monday as first day, default is Sunday +now.BeginningOfWeek() // 2013-11-18 00:00:00 Mon + +now.EndOfMinute() // 2013-11-18 17:51:59.999999999 Mon +now.EndOfHour() // 2013-11-18 17:59:59.999999999 Mon +now.EndOfDay() // 2013-11-18 23:59:59.999999999 Mon +now.EndOfWeek() // 2013-11-23 23:59:59.999999999 Sat +now.EndOfMonth() // 2013-11-30 23:59:59.999999999 Sat +now.EndOfQuarter() // 2013-12-31 23:59:59.999999999 Tue +now.EndOfYear() // 2013-12-31 23:59:59.999999999 Tue + +now.WeekStartDay = time.Monday // Set Monday as first day, default is Sunday +now.EndOfWeek() // 2013-11-24 23:59:59.999999999 Sun +``` + +Calculating time based on another time + +```go +t := time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.Now().Location()) +now.With(t).EndOfMonth() // 2013-02-28 23:59:59.999999999 Thu +``` + +Calculating time based on configuration + +```go +location, err := time.LoadLocation("Asia/Shanghai") + +myConfig := &now.Config{ + WeekStartDay: time.Monday, + TimeLocation: location, + TimeFormats: []string{"2006-01-02 15:04:05"}, +} + +t := time.Date(2013, 11, 18, 17, 51, 49, 123456789, time.Now().Location()) // // 2013-11-18 17:51:49.123456789 Mon +myConfig.With(t).BeginningOfWeek() // 2013-11-18 00:00:00 Mon + +myConfig.Parse("2002-10-12 22:14:01") // 2002-10-12 22:14:01 +myConfig.Parse("2002-10-12 22:14") // returns error 'can't parse string as time: 2002-10-12 22:14' +``` + +### Monday/Sunday + +Don't be bothered with the `WeekStartDay` setting, you can use `Monday`, `Sunday` + +```go +now.Monday() // 2013-11-18 00:00:00 Mon +now.Sunday() // 2013-11-24 00:00:00 Sun (Next Sunday) +now.EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of next Sunday) + +t := time.Date(2013, 11, 24, 17, 51, 49, 123456789, time.Now().Location()) // 2013-11-24 17:51:49.123456789 Sun +now.With(t).Monday() // 2013-11-18 00:00:00 Sun (Last Monday if today is Sunday) +now.With(t).Sunday() // 2013-11-24 00:00:00 Sun (Beginning Of Today if today is Sunday) +now.With(t).EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of Today if today is Sunday) +``` + +### Parse String to Time + +```go +time.Now() // 2013-11-18 17:51:49.123456789 Mon + +// Parse(string) (time.Time, error) +t, err := now.Parse("2017") // 2017-01-01 00:00:00, nil +t, err := now.Parse("2017-10") // 2017-10-01 00:00:00, nil +t, err := now.Parse("2017-10-13") // 2017-10-13 00:00:00, nil +t, err := now.Parse("1999-12-12 12") // 1999-12-12 12:00:00, nil +t, err := now.Parse("1999-12-12 12:20") // 1999-12-12 12:20:00, nil +t, err := now.Parse("1999-12-12 12:20:21") // 1999-12-12 12:20:00, nil +t, err := now.Parse("10-13") // 2013-10-13 00:00:00, nil +t, err := now.Parse("12:20") // 2013-11-18 12:20:00, nil +t, err := now.Parse("12:20:13") // 2013-11-18 12:20:13, nil +t, err := now.Parse("14") // 2013-11-18 14:00:00, nil +t, err := now.Parse("99:99") // 2013-11-18 12:20:00, Can't parse string as time: 99:99 + +// MustParse must parse string to time or it will panic +now.MustParse("2013-01-13") // 2013-01-13 00:00:00 +now.MustParse("02-17") // 2013-02-17 00:00:00 +now.MustParse("2-17") // 2013-02-17 00:00:00 +now.MustParse("8") // 2013-11-18 08:00:00 +now.MustParse("2002-10-12 22:14") // 2002-10-12 22:14:00 +now.MustParse("99:99") // panic: Can't parse string as time: 99:99 +``` + +Extend `now` to support more formats is quite easy, just update `now.TimeFormats` with other time layouts, e.g: + +```go +now.TimeFormats = append(now.TimeFormats, "02 Jan 2006 15:04") +``` + +Please send me pull requests if you want a format to be supported officially + +## Contributing + +You can help to make the project better, check out [http://gorm.io/contribute.html](http://gorm.io/contribute.html) for things you can do. + +# Author + +**jinzhu** + +* +* +* + +## License + +Released under the [MIT License](http://www.opensource.org/licenses/MIT). diff --git a/vendor/github.com/jinzhu/now/go.mod b/vendor/github.com/jinzhu/now/go.mod new file mode 100644 index 000000000..018d266a3 --- /dev/null +++ b/vendor/github.com/jinzhu/now/go.mod @@ -0,0 +1,3 @@ +module github.com/jinzhu/now + +go 1.12 diff --git a/vendor/github.com/jinzhu/now/main.go b/vendor/github.com/jinzhu/now/main.go new file mode 100644 index 000000000..c996a8a5c --- /dev/null +++ b/vendor/github.com/jinzhu/now/main.go @@ -0,0 +1,194 @@ +// Package now is a time toolkit for golang. +// +// More details README here: https://github.com/jinzhu/now +// +// import "github.com/jinzhu/now" +// +// now.BeginningOfMinute() // 2013-11-18 17:51:00 Mon +// now.BeginningOfDay() // 2013-11-18 00:00:00 Mon +// now.EndOfDay() // 2013-11-18 23:59:59.999999999 Mon +package now + +import "time" + +// WeekStartDay set week start day, default is sunday +var WeekStartDay = time.Sunday + +// TimeFormats default time formats will be parsed as +var TimeFormats = []string{ + "2006", "2006-1", "2006-1-2", "2006-1-2 15", "2006-1-2 15:4", "2006-1-2 15:4:5", "1-2", + "15:4:5", "15:4", "15", + "15:4:5 Jan 2, 2006 MST", "2006-01-02 15:04:05.999999999 -0700 MST", "2006-01-02T15:04:05-07:00", + "2006.1.2", "2006.1.2 15:04:05", "2006.01.02", "2006.01.02 15:04:05", + "1/2/2006", "1/2/2006 15:4:5", "2006/01/02", "2006/01/02 15:04:05", + time.ANSIC, time.UnixDate, time.RubyDate, time.RFC822, time.RFC822Z, time.RFC850, + time.RFC1123, time.RFC1123Z, time.RFC3339, time.RFC3339Nano, + time.Kitchen, time.Stamp, time.StampMilli, time.StampMicro, time.StampNano, +} + +// Config configuration for now package +type Config struct { + WeekStartDay time.Weekday + TimeLocation *time.Location + TimeFormats []string +} + +// DefaultConfig default config +var DefaultConfig *Config + +// New initialize Now based on configuration +func (config *Config) With(t time.Time) *Now { + return &Now{Time: t, Config: config} +} + +// Parse parse string to time based on configuration +func (config *Config) Parse(strs ...string) (time.Time, error) { + if config.TimeLocation == nil { + return config.With(time.Now()).Parse(strs...) + } else { + return config.With(time.Now().In(config.TimeLocation)).Parse(strs...) + } +} + +// MustParse must parse string to time or will panic +func (config *Config) MustParse(strs ...string) time.Time { + if config.TimeLocation == nil { + return config.With(time.Now()).MustParse(strs...) + } else { + return config.With(time.Now().In(config.TimeLocation)).MustParse(strs...) + } +} + +// Now now struct +type Now struct { + time.Time + *Config +} + +// With initialize Now with time +func With(t time.Time) *Now { + config := DefaultConfig + if config == nil { + config = &Config{ + WeekStartDay: WeekStartDay, + TimeFormats: TimeFormats, + } + } + + return &Now{Time: t, Config: config} +} + +// New initialize Now with time +func New(t time.Time) *Now { + return With(t) +} + +// BeginningOfMinute beginning of minute +func BeginningOfMinute() time.Time { + return With(time.Now()).BeginningOfMinute() +} + +// BeginningOfHour beginning of hour +func BeginningOfHour() time.Time { + return With(time.Now()).BeginningOfHour() +} + +// BeginningOfDay beginning of day +func BeginningOfDay() time.Time { + return With(time.Now()).BeginningOfDay() +} + +// BeginningOfWeek beginning of week +func BeginningOfWeek() time.Time { + return With(time.Now()).BeginningOfWeek() +} + +// BeginningOfMonth beginning of month +func BeginningOfMonth() time.Time { + return With(time.Now()).BeginningOfMonth() +} + +// BeginningOfQuarter beginning of quarter +func BeginningOfQuarter() time.Time { + return With(time.Now()).BeginningOfQuarter() +} + +// BeginningOfYear beginning of year +func BeginningOfYear() time.Time { + return With(time.Now()).BeginningOfYear() +} + +// EndOfMinute end of minute +func EndOfMinute() time.Time { + return With(time.Now()).EndOfMinute() +} + +// EndOfHour end of hour +func EndOfHour() time.Time { + return With(time.Now()).EndOfHour() +} + +// EndOfDay end of day +func EndOfDay() time.Time { + return With(time.Now()).EndOfDay() +} + +// EndOfWeek end of week +func EndOfWeek() time.Time { + return With(time.Now()).EndOfWeek() +} + +// EndOfMonth end of month +func EndOfMonth() time.Time { + return With(time.Now()).EndOfMonth() +} + +// EndOfQuarter end of quarter +func EndOfQuarter() time.Time { + return With(time.Now()).EndOfQuarter() +} + +// EndOfYear end of year +func EndOfYear() time.Time { + return With(time.Now()).EndOfYear() +} + +// Monday monday +func Monday() time.Time { + return With(time.Now()).Monday() +} + +// Sunday sunday +func Sunday() time.Time { + return With(time.Now()).Sunday() +} + +// EndOfSunday end of sunday +func EndOfSunday() time.Time { + return With(time.Now()).EndOfSunday() +} + +// Parse parse string to time +func Parse(strs ...string) (time.Time, error) { + return With(time.Now()).Parse(strs...) +} + +// ParseInLocation parse string to time in location +func ParseInLocation(loc *time.Location, strs ...string) (time.Time, error) { + return With(time.Now().In(loc)).Parse(strs...) +} + +// MustParse must parse string to time or will panic +func MustParse(strs ...string) time.Time { + return With(time.Now()).MustParse(strs...) +} + +// MustParseInLocation must parse string to time in location or will panic +func MustParseInLocation(loc *time.Location, strs ...string) time.Time { + return With(time.Now().In(loc)).MustParse(strs...) +} + +// Between check now between the begin, end time or not +func Between(time1, time2 string) bool { + return With(time.Now()).Between(time1, time2) +} diff --git a/vendor/github.com/jinzhu/now/now.go b/vendor/github.com/jinzhu/now/now.go new file mode 100644 index 000000000..353835a37 --- /dev/null +++ b/vendor/github.com/jinzhu/now/now.go @@ -0,0 +1,213 @@ +package now + +import ( + "errors" + "regexp" + "time" +) + +// BeginningOfMinute beginning of minute +func (now *Now) BeginningOfMinute() time.Time { + return now.Truncate(time.Minute) +} + +// BeginningOfHour beginning of hour +func (now *Now) BeginningOfHour() time.Time { + y, m, d := now.Date() + return time.Date(y, m, d, now.Time.Hour(), 0, 0, 0, now.Time.Location()) +} + +// BeginningOfDay beginning of day +func (now *Now) BeginningOfDay() time.Time { + y, m, d := now.Date() + return time.Date(y, m, d, 0, 0, 0, 0, now.Time.Location()) +} + +// BeginningOfWeek beginning of week +func (now *Now) BeginningOfWeek() time.Time { + t := now.BeginningOfDay() + weekday := int(t.Weekday()) + + if now.WeekStartDay != time.Sunday { + weekStartDayInt := int(now.WeekStartDay) + + if weekday < weekStartDayInt { + weekday = weekday + 7 - weekStartDayInt + } else { + weekday = weekday - weekStartDayInt + } + } + return t.AddDate(0, 0, -weekday) +} + +// BeginningOfMonth beginning of month +func (now *Now) BeginningOfMonth() time.Time { + y, m, _ := now.Date() + return time.Date(y, m, 1, 0, 0, 0, 0, now.Location()) +} + +// BeginningOfQuarter beginning of quarter +func (now *Now) BeginningOfQuarter() time.Time { + month := now.BeginningOfMonth() + offset := (int(month.Month()) - 1) % 3 + return month.AddDate(0, -offset, 0) +} + +// BeginningOfHalf beginning of half year +func (now *Now) BeginningOfHalf() time.Time { + month := now.BeginningOfMonth() + offset := (int(month.Month()) - 1) % 6 + return month.AddDate(0, -offset, 0) +} + +// BeginningOfYear BeginningOfYear beginning of year +func (now *Now) BeginningOfYear() time.Time { + y, _, _ := now.Date() + return time.Date(y, time.January, 1, 0, 0, 0, 0, now.Location()) +} + +// EndOfMinute end of minute +func (now *Now) EndOfMinute() time.Time { + return now.BeginningOfMinute().Add(time.Minute - time.Nanosecond) +} + +// EndOfHour end of hour +func (now *Now) EndOfHour() time.Time { + return now.BeginningOfHour().Add(time.Hour - time.Nanosecond) +} + +// EndOfDay end of day +func (now *Now) EndOfDay() time.Time { + y, m, d := now.Date() + return time.Date(y, m, d, 23, 59, 59, int(time.Second-time.Nanosecond), now.Location()) +} + +// EndOfWeek end of week +func (now *Now) EndOfWeek() time.Time { + return now.BeginningOfWeek().AddDate(0, 0, 7).Add(-time.Nanosecond) +} + +// EndOfMonth end of month +func (now *Now) EndOfMonth() time.Time { + return now.BeginningOfMonth().AddDate(0, 1, 0).Add(-time.Nanosecond) +} + +// EndOfQuarter end of quarter +func (now *Now) EndOfQuarter() time.Time { + return now.BeginningOfQuarter().AddDate(0, 3, 0).Add(-time.Nanosecond) +} + +// EndOfHalf end of half year +func (now *Now) EndOfHalf() time.Time { + return now.BeginningOfHalf().AddDate(0, 6, 0).Add(-time.Nanosecond) +} + +// EndOfYear end of year +func (now *Now) EndOfYear() time.Time { + return now.BeginningOfYear().AddDate(1, 0, 0).Add(-time.Nanosecond) +} + +// Monday monday +func (now *Now) Monday() time.Time { + t := now.BeginningOfDay() + weekday := int(t.Weekday()) + if weekday == 0 { + weekday = 7 + } + return t.AddDate(0, 0, -weekday+1) +} + +// Sunday sunday +func (now *Now) Sunday() time.Time { + t := now.BeginningOfDay() + weekday := int(t.Weekday()) + if weekday == 0 { + return t + } + return t.AddDate(0, 0, (7 - weekday)) +} + +// EndOfSunday end of sunday +func (now *Now) EndOfSunday() time.Time { + return New(now.Sunday()).EndOfDay() +} + +func (now *Now) parseWithFormat(str string, location *time.Location) (t time.Time, err error) { + for _, format := range now.TimeFormats { + t, err = time.ParseInLocation(format, str, location) + + if err == nil { + return + } + } + err = errors.New("Can't parse string as time: " + str) + return +} + +var hasTimeRegexp = regexp.MustCompile(`(\s+|^\s*)\d{1,2}((:\d{1,2})*|((:\d{1,2}){2}\.(\d{3}|\d{6}|\d{9})))\s*$`) // match 15:04:05, 15:04:05.000, 15:04:05.000000 15, 2017-01-01 15:04, etc +var onlyTimeRegexp = regexp.MustCompile(`^\s*\d{1,2}((:\d{1,2})*|((:\d{1,2}){2}\.(\d{3}|\d{6}|\d{9})))\s*$`) // match 15:04:05, 15, 15:04:05.000, 15:04:05.000000, etc + +// Parse parse string to time +func (now *Now) Parse(strs ...string) (t time.Time, err error) { + var ( + setCurrentTime bool + parseTime []int + currentTime = []int{now.Nanosecond(), now.Second(), now.Minute(), now.Hour(), now.Day(), int(now.Month()), now.Year()} + currentLocation = now.Location() + onlyTimeInStr = true + ) + + for _, str := range strs { + hasTimeInStr := hasTimeRegexp.MatchString(str) // match 15:04:05, 15 + onlyTimeInStr = hasTimeInStr && onlyTimeInStr && onlyTimeRegexp.MatchString(str) + if t, err = now.parseWithFormat(str, currentLocation); err == nil { + location := t.Location() + + parseTime = []int{t.Nanosecond(), t.Second(), t.Minute(), t.Hour(), t.Day(), int(t.Month()), t.Year()} + + for i, v := range parseTime { + // Don't reset hour, minute, second if current time str including time + if hasTimeInStr && i <= 3 { + continue + } + + // If value is zero, replace it with current time + if v == 0 { + if setCurrentTime { + parseTime[i] = currentTime[i] + } + } else { + setCurrentTime = true + } + + // if current time only includes time, should change day, month to current time + if onlyTimeInStr { + if i == 4 || i == 5 { + parseTime[i] = currentTime[i] + continue + } + } + } + + t = time.Date(parseTime[6], time.Month(parseTime[5]), parseTime[4], parseTime[3], parseTime[2], parseTime[1], parseTime[0], location) + currentTime = []int{t.Nanosecond(), t.Second(), t.Minute(), t.Hour(), t.Day(), int(t.Month()), t.Year()} + } + } + return +} + +// MustParse must parse string to time or it will panic +func (now *Now) MustParse(strs ...string) (t time.Time) { + t, err := now.Parse(strs...) + if err != nil { + panic(err) + } + return t +} + +// Between check time between the begin, end time or not +func (now *Now) Between(begin, end string) bool { + beginTime := now.MustParse(begin) + endTime := now.MustParse(end) + return now.After(beginTime) && now.Before(endTime) +} diff --git a/vendor/github.com/jinzhu/now/wercker.yml b/vendor/github.com/jinzhu/now/wercker.yml new file mode 100644 index 000000000..5e6ce981d --- /dev/null +++ b/vendor/github.com/jinzhu/now/wercker.yml @@ -0,0 +1,23 @@ +box: golang + +build: + steps: + - setup-go-workspace + + # Gets the dependencies + - script: + name: go get + code: | + go get + + # Build the project + - script: + name: go build + code: | + go build ./... + + # Test the project + - script: + name: go test + code: | + go test ./... diff --git a/vendor/github.com/lib/pq/hstore/hstore.go b/vendor/github.com/lib/pq/hstore/hstore.go deleted file mode 100644 index f1470db14..000000000 --- a/vendor/github.com/lib/pq/hstore/hstore.go +++ /dev/null @@ -1,118 +0,0 @@ -package hstore - -import ( - "database/sql" - "database/sql/driver" - "strings" -) - -// Hstore is a wrapper for transferring Hstore values back and forth easily. -type Hstore struct { - Map map[string]sql.NullString -} - -// escapes and quotes hstore keys/values -// s should be a sql.NullString or string -func hQuote(s interface{}) string { - var str string - switch v := s.(type) { - case sql.NullString: - if !v.Valid { - return "NULL" - } - str = v.String - case string: - str = v - default: - panic("not a string or sql.NullString") - } - - str = strings.Replace(str, "\\", "\\\\", -1) - return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"` -} - -// Scan implements the Scanner interface. -// -// Note h.Map is reallocated before the scan to clear existing values. If the -// hstore column's database value is NULL, then h.Map is set to nil instead. -func (h *Hstore) Scan(value interface{}) error { - if value == nil { - h.Map = nil - return nil - } - h.Map = make(map[string]sql.NullString) - var b byte - pair := [][]byte{{}, {}} - pi := 0 - inQuote := false - didQuote := false - sawSlash := false - bindex := 0 - for bindex, b = range value.([]byte) { - if sawSlash { - pair[pi] = append(pair[pi], b) - sawSlash = false - continue - } - - switch b { - case '\\': - sawSlash = true - continue - case '"': - inQuote = !inQuote - if !didQuote { - didQuote = true - } - continue - default: - if !inQuote { - switch b { - case ' ', '\t', '\n', '\r': - continue - case '=': - continue - case '>': - pi = 1 - didQuote = false - continue - case ',': - s := string(pair[1]) - if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { - h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} - } else { - h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} - } - pair[0] = []byte{} - pair[1] = []byte{} - pi = 0 - continue - } - } - } - pair[pi] = append(pair[pi], b) - } - if bindex > 0 { - s := string(pair[1]) - if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { - h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} - } else { - h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} - } - } - return nil -} - -// Value implements the driver Valuer interface. Note if h.Map is nil, the -// database column value will be set to NULL. -func (h Hstore) Value() (driver.Value, error) { - if h.Map == nil { - return nil, nil - } - parts := []string{} - for key, val := range h.Map { - thispart := hQuote(key) + "=>" + hQuote(val) - parts = append(parts, thispart) - } - return []byte(strings.Join(parts, ",")), nil -} diff --git a/vendor/gorm.io/driver/postgres/License b/vendor/gorm.io/driver/postgres/License new file mode 100644 index 000000000..037e1653e --- /dev/null +++ b/vendor/gorm.io/driver/postgres/License @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2013-NOW Jinzhu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/gorm.io/driver/postgres/README.md b/vendor/gorm.io/driver/postgres/README.md new file mode 100644 index 000000000..1220bf82d --- /dev/null +++ b/vendor/gorm.io/driver/postgres/README.md @@ -0,0 +1,16 @@ +# GORM PostgreSQL Driver + +## USAGE + +```go +import ( + "gorm.io/driver/postgres" + "gorm.io/gorm" +) + +// https://github.com/lib/pq +dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" +db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) +``` + +Checkout [https://gorm.io](https://gorm.io) for details. diff --git a/vendor/gorm.io/driver/postgres/go.mod b/vendor/gorm.io/driver/postgres/go.mod new file mode 100644 index 000000000..00b10cf5c --- /dev/null +++ b/vendor/gorm.io/driver/postgres/go.mod @@ -0,0 +1,5 @@ +module gorm.io/driver/postgres + +go 1.14 + +require github.com/lib/pq v1.6.0 diff --git a/vendor/gorm.io/driver/postgres/migrator.go b/vendor/gorm.io/driver/postgres/migrator.go new file mode 100644 index 000000000..243a75d7b --- /dev/null +++ b/vendor/gorm.io/driver/postgres/migrator.go @@ -0,0 +1,151 @@ +package postgres + +import ( + "fmt" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" +) + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) + return +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Raw( + "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + createIndexSQL += " ON ??" + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + + return fmt.Errorf("failed to create index with name %v", name) + }) +} + +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( + "ALTER INDEX ? RENAME TO ?", + clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error + }) +} + +func (m Migrator) HasTable(value interface{}) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + tx := m.DB.Session(&gorm.Session{}) + for i := len(values) - 1; i >= 0; i-- { + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err + } + } + return nil +} + +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", + stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND constraint_name = ?", + stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} diff --git a/vendor/gorm.io/driver/postgres/postgres.go b/vendor/gorm.io/driver/postgres/postgres.go new file mode 100644 index 000000000..8d0b829ab --- /dev/null +++ b/vendor/gorm.io/driver/postgres/postgres.go @@ -0,0 +1,127 @@ +package postgres + +import ( + "database/sql" + "fmt" + "regexp" + "strconv" + "strings" + + _ "github.com/lib/pq" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" +) + +type Dialector struct { + DSN string +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{DSN: dsn} +} + +func (dialector Dialector) Name() string { + return "postgres" +} + +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { + // register callbacks + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + WithReturning: true, + }) + db.ConnPool, err = sql.Open("postgres", dialector.DSN) + return +} + +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + CreateIndexAfterCreateTable: true, + }}} +} + +func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression { + return clause.Expr{SQL: "DEFAULT"} +} + +func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('$') + writer.WriteString(strconv.Itoa(len(stmt.Vars))) +} + +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('"') + if strings.Contains(str, ".") { + for idx, str := range strings.Split(str, ".") { + if idx > 0 { + writer.WriteString(`."`) + } + writer.WriteString(str) + writer.WriteByte('"') + } + } else { + writer.WriteString(str) + writer.WriteByte('"') + } +} + +var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") + +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) +} + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "boolean" + case schema.Int, schema.Uint: + if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { + switch { + case field.Size < 16: + return "smallserial" + case field.Size < 31: + return "serial" + default: + return "bigserial" + } + } else { + switch { + case field.Size < 16: + return "smallint" + case field.Size < 31: + return "integer" + default: + return "bigint" + } + } + case schema.Float: + return "decimal" + case schema.String: + if field.Size > 0 { + return fmt.Sprintf("varchar(%d)", field.Size) + } + return "text" + case schema.Time: + return "timestamptz" + case schema.Bytes: + return "bytea" + } + + return string(field.DataType) +} + +func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error { + tx.Exec("SAVEPOINT " + name) + return nil +} + +func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error { + tx.Exec("ROLLBACK TO SAVEPOINT " + name) + return nil +} diff --git a/vendor/github.com/jinzhu/gorm/.gitignore b/vendor/gorm.io/gorm/.gitignore similarity index 82% rename from vendor/github.com/jinzhu/gorm/.gitignore rename to vendor/gorm.io/gorm/.gitignore index 117f92f52..c14d60050 100644 --- a/vendor/github.com/jinzhu/gorm/.gitignore +++ b/vendor/gorm.io/gorm/.gitignore @@ -1,3 +1,4 @@ +TODO* documents coverage.txt _book diff --git a/vendor/gorm.io/gorm/License b/vendor/gorm.io/gorm/License new file mode 100644 index 000000000..037e1653e --- /dev/null +++ b/vendor/gorm.io/gorm/License @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2013-NOW Jinzhu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/jinzhu/gorm/README.md b/vendor/gorm.io/gorm/README.md similarity index 54% rename from vendor/github.com/jinzhu/gorm/README.md rename to vendor/gorm.io/gorm/README.md index 6d2311037..9c0aded05 100644 --- a/vendor/github.com/jinzhu/gorm/README.md +++ b/vendor/gorm.io/gorm/README.md @@ -2,27 +2,28 @@ The fantastic ORM library for Golang, aims to be developer friendly. -[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) -[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) -[![codecov](https://codecov.io/gh/jinzhu/gorm/branch/master/graph/badge.svg)](https://codecov.io/gh/jinzhu/gorm) +[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) +[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) -[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) +[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) ## Overview -* Full-Featured ORM (almost) -* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) +* Full-Featured ORM +* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance) * Hooks (Before/After Create/Save/Update/Delete/Find) -* Preloading (eager loading) -* Transactions +* Eager loading with `Preload`, `Joins` +* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point +* Context, Prepared Statment Mode, DryRun Mode +* Batch Insert, FindInBatches, Find To Map +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key -* SQL Builder * Auto Migrations * Logger -* Extendable, write Plugins based on GORM callbacks +* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus… * Every feature comes with tests * Developer Friendly @@ -38,4 +39,4 @@ The fantastic ORM library for Golang, aims to be developer friendly. © Jinzhu, 2013~time.Now -Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License) +Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) diff --git a/vendor/gorm.io/gorm/association.go b/vendor/gorm.io/gorm/association.go new file mode 100644 index 000000000..7adb8c914 --- /dev/null +++ b/vendor/gorm.io/gorm/association.go @@ -0,0 +1,482 @@ +package gorm + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// Association Mode contains some helper methods to handle relationship things easily. +type Association struct { + DB *DB + Relationship *schema.Relationship + Error error +} + +func (db *DB) Association(column string) *Association { + association := &Association{DB: db} + table := db.Statement.Table + + if err := db.Statement.Parse(db.Statement.Model); err == nil { + db.Statement.Table = table + association.Relationship = db.Statement.Schema.Relationships.Relations[column] + + if association.Relationship == nil { + association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) + } + + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() + } + } else { + association.Error = err + } + + return association +} + +func (association *Association) Find(out interface{}, conds ...interface{}) error { + if association.Error == nil { + association.Error = association.buildCondition().Find(out, conds...).Error + } + return association.Error +} + +func (association *Association) Append(values ...interface{}) error { + if association.Error == nil { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + if len(values) > 0 { + association.Error = association.Replace(values...) + } + default: + association.saveAssociation( /*clear*/ false, values...) + } + } + + return association.Error +} + +func (association *Association) Replace(values ...interface{}) error { + if association.Error == nil { + // save associations + association.saveAssociation( /*clear*/ true, values...) + + // set old associations's foreign key to null + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + switch rel.Type { + case schema.BelongsTo: + if len(values) == 0 { + updateMap := map[string]interface{}{} + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + } + case reflect.Struct: + association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + } + + for _, ref := range rel.References { + updateMap[ref.ForeignKey.DBName] = nil + } + + association.Error = association.DB.UpdateColumns(updateMap).Error + } + case schema.HasOne, schema.HasMany: + var ( + primaryFields []*schema.Field + foreignKeys []string + updateMap = map[string]interface{}{} + relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { + if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { + tx.Not(clause.IN{Column: column, Values: values}) + } + } + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateMap[ref.ForeignKey.DBName] = nil + } else if ref.PrimaryValue != "" { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) + association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + } + case schema.Many2Many: + var ( + primaryFields, relPrimaryFields []*schema.Field + joinPrimaryKeys, joinRelPrimaryKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } + } else { + tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { + tx.Where(clause.IN{Column: column, Values: values}) + } else { + return ErrPrimaryKeyRequired + } + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { + tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) + } + + association.Error = tx.Delete(modelValue).Error + } + } + return association.Error +} + +func (association *Association) Delete(values ...interface{}) error { + if association.Error == nil { + var ( + reflectValue = association.DB.Statement.ReflectValue + rel = association.Relationship + primaryFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} + conds []clause.Expression + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateAttrs[ref.ForeignKey.DBName] = nil + } else { + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + switch rel.Type { + case schema.BelongsTo: + tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) + + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) + relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + case schema.HasOne, schema.HasMany: + tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) + + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + case schema.Many2Many: + var ( + primaryFields, relPrimaryFields []*schema.Field + joinPrimaryKeys, joinRelPrimaryKeys []string + joinValue = reflect.New(rel.JoinTable.ModelType).Interface() + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } + } else { + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error + } + + if association.Error == nil { + // clean up deleted values's foreign key + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + + cleanUpDeletedRelations := func(data reflect.Value) { + if _, zero := rel.Field.ValueOf(data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) + + switch fieldValue.Kind() { + case reflect.Slice, reflect.Array: + validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) + for i := 0; i < fieldValue.Len(); i++ { + for idx, field := range rel.FieldSchema.PrimaryFields { + primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) + } + + if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { + validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i)) + } + } + + association.Error = rel.Field.Set(data, validFieldValues.Interface()) + case reflect.Struct: + for idx, field := range rel.FieldSchema.PrimaryFields { + primaryValues[idx], _ = field.ValueOf(fieldValue) + } + + if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { + if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + break + } + + if rel.JoinTable == nil { + for _, ref := range rel.References { + if ref.OwnPrimaryKey || ref.PrimaryValue != "" { + association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } else { + association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } + } + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i))) + } + case reflect.Struct: + cleanUpDeletedRelations(reflectValue) + } + } + } + + return association.Error +} + +func (association *Association) Clear() error { + return association.Replace() +} + +func (association *Association) Count() (count int64) { + if association.Error == nil { + association.Error = association.buildCondition().Count(&count).Error + } + return +} + +type assignBack struct { + Source reflect.Value + Index int + Dest reflect.Value +} + +func (association *Association) saveAssociation(clear bool, values ...interface{}) { + var ( + reflectValue = association.DB.Statement.ReflectValue + assignBacks []assignBack // assign association values back to arguments after save + ) + + appendToRelations := func(source, rv reflect.Value, clear bool) { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + switch rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() > 0 { + association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) + + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) + } + } + case reflect.Struct: + association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) + + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) + } + } + case schema.HasMany, schema.Many2Many: + elemType := association.Relationship.Field.IndirectFieldType.Elem() + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) + if clear { + fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() + } + + appendToFieldValues := func(ev reflect.Value) { + if ev.Type().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev) + } else if ev.Type().Elem().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev.Elem()) + } else { + association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) + } + + if elemType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()}) + } + } + + switch rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + } + case reflect.Struct: + appendToFieldValues(rv.Addr()) + } + + if association.Error == nil { + association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) + } + } + } + + selectedSaveColumns := []string{association.Relationship.Name} + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey { + selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) + } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(values) != reflectValue.Len() { + // clear old data + if clear && len(values) == 0 { + for i := 0; i < reflectValue.Len(); i++ { + if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + association.Error = err + break + } + + if association.Relationship.JoinTable == nil { + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { + if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + association.Error = err + break + } + } + } + } + } + break + } + + association.Error = errors.New("invalid association values, length doesn't match") + return + } + + for i := 0; i < reflectValue.Len(); i++ { + appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) + + // TODO support save slice data, sql with case? + association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error + } + case reflect.Struct: + // clear old data + if clear && len(values) == 0 { + association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + + if association.Relationship.JoinTable == nil && association.Error == nil { + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { + association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } + } + } + + for idx, value := range values { + rv := reflect.Indirect(reflect.ValueOf(value)) + appendToRelations(reflectValue, rv, clear && idx == 0) + } + + if len(values) > 0 { + association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error + } + } + + for _, assignBack := range assignBacks { + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) + if assignBack.Index > 0 { + reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) + } else { + reflect.Indirect(assignBack.Dest).Set(fieldValue) + } + } +} + +func (association *Association) buildCondition() *DB { + var ( + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) + } + joinStmt.Build("WHERE") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } + + tx.Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) + } + + return tx +} diff --git a/vendor/gorm.io/gorm/callbacks.go b/vendor/gorm.io/gorm/callbacks.go new file mode 100644 index 000000000..e21e0718f --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks.go @@ -0,0 +1,298 @@ +package gorm + +import ( + "context" + "errors" + "fmt" + "reflect" + "sort" + "time" + + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func initializeCallbacks(db *DB) *callbacks { + return &callbacks{ + processors: map[string]*processor{ + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, + }, + } +} + +// callbacks gorm callbacks manager +type callbacks struct { + processors map[string]*processor +} + +type processor struct { + db *DB + fns []func(*DB) + callbacks []*callback +} + +type callback struct { + name string + before string + after string + remove bool + replace bool + match func(*DB) bool + handler func(*DB) + processor *processor +} + +func (cs *callbacks) Create() *processor { + return cs.processors["create"] +} + +func (cs *callbacks) Query() *processor { + return cs.processors["query"] +} + +func (cs *callbacks) Update() *processor { + return cs.processors["update"] +} + +func (cs *callbacks) Delete() *processor { + return cs.processors["delete"] +} + +func (cs *callbacks) Row() *processor { + return cs.processors["row"] +} + +func (cs *callbacks) Raw() *processor { + return cs.processors["raw"] +} + +func (p *processor) Execute(db *DB) { + curTime := time.Now() + stmt := db.Statement + + if stmt.Model == nil { + stmt.Model = stmt.Dest + } else if stmt.Dest == nil { + stmt.Dest = stmt.Model + } + + if stmt.Model != nil { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { + if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { + db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) + } else { + db.AddError(err) + } + } + } + + if stmt.Dest != nil { + stmt.ReflectValue = reflect.ValueOf(stmt.Dest) + for stmt.ReflectValue.Kind() == reflect.Ptr { + stmt.ReflectValue = stmt.ReflectValue.Elem() + } + if !stmt.ReflectValue.IsValid() { + db.AddError(fmt.Errorf("invalid value")) + } + } + + for _, f := range p.fns { + f(db) + } + + db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected + }, db.Error) + + if !stmt.DB.DryRun { + stmt.SQL.Reset() + stmt.Vars = nil + } +} + +func (p *processor) Get(name string) func(*DB) { + for i := len(p.callbacks) - 1; i >= 0; i-- { + if v := p.callbacks[i]; v.name == name && !v.remove { + return v.handler + } + } + return nil +} + +func (p *processor) Before(name string) *callback { + return &callback{before: name, processor: p} +} + +func (p *processor) After(name string) *callback { + return &callback{after: name, processor: p} +} + +func (p *processor) Match(fc func(*DB) bool) *callback { + return &callback{match: fc, processor: p} +} + +func (p *processor) Register(name string, fn func(*DB)) error { + return (&callback{processor: p}).Register(name, fn) +} + +func (p *processor) Remove(name string) error { + return (&callback{processor: p}).Remove(name) +} + +func (p *processor) Replace(name string, fn func(*DB)) error { + return (&callback{processor: p}).Replace(name, fn) +} + +func (p *processor) compile() (err error) { + var callbacks []*callback + for _, callback := range p.callbacks { + if callback.match == nil || callback.match(p.db) { + callbacks = append(callbacks, callback) + } + } + p.callbacks = callbacks + + if p.fns, err = sortCallbacks(p.callbacks); err != nil { + p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err) + } + return +} + +func (c *callback) Before(name string) *callback { + c.before = name + return c +} + +func (c *callback) After(name string) *callback { + c.after = name + return c +} + +func (c *callback) Register(name string, fn func(*DB)) error { + c.name = name + c.handler = fn + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile() +} + +func (c *callback) Remove(name string) error { + c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.name = name + c.remove = true + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile() +} + +func (c *callback) Replace(name string, fn func(*DB)) error { + c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.name = name + c.handler = fn + c.replace = true + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile() +} + +// getRIndex get right index from string slice +func getRIndex(strs []string, str string) int { + for i := len(strs) - 1; i >= 0; i-- { + if strs[i] == str { + return i + } + } + return -1 +} + +func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { + var ( + names, sorted []string + sortCallback func(*callback) error + ) + sort.Slice(cs, func(i, j int) bool { + return cs[j].before == "*" || cs[j].after == "*" + }) + + for _, c := range cs { + // show warning message the callback name already exists + if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { + c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + } + names = append(names, c.name) + } + + sortCallback = func(c *callback) error { + if c.before != "" { // if defined before callback + if c.before == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append([]string{c.name}, sorted...) + } + } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + // if before callback already sorted, append current callback just after it + sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) + } else if curIdx > sortedIdx { + return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) + } + } else if idx := getRIndex(names, c.before); idx != -1 { + // if before callback exists + cs[idx].after = c.name + } + } + + if c.after != "" { // if defined after callback + if c.after == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append(sorted, c.name) + } + } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + // if after callback sorted, append current callback to last + sorted = append(sorted, c.name) + } else if curIdx < sortedIdx { + return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) + } + } else if idx := getRIndex(names, c.after); idx != -1 { + // if after callback exists but haven't sorted + // set after callback's before callback to current callback + after := cs[idx] + + if after.before == "" { + after.before = c.name + } + + if err := sortCallback(after); err != nil { + return err + } + + if err := sortCallback(c); err != nil { + return err + } + } + } + + // if current callback haven't been sorted, append it to last + if getRIndex(sorted, c.name) == -1 { + sorted = append(sorted, c.name) + } + + return nil + } + + for _, c := range cs { + if err = sortCallback(c); err != nil { + return + } + } + + for _, name := range sorted { + if idx := getRIndex(names, name); !cs[idx].remove { + fns = append(fns, cs[idx].handler) + } + } + + return +} diff --git a/vendor/gorm.io/gorm/callbacks/associations.go b/vendor/gorm.io/gorm/callbacks/associations.go new file mode 100644 index 000000000..9e767e5e1 --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/associations.go @@ -0,0 +1,369 @@ +package callbacks + +import ( + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +func SaveBeforeAssociations(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + + // Save Belongs To associations + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + setupReferences := func(obj reflect.Value, elem reflect.Value) { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(elem) + db.AddError(ref.ForeignKey.Set(obj, pv)) + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + dest[ref.ForeignKey.DBName] = pv + if _, ok := dest[rel.Name]; ok { + dest[rel.Name] = elem.Interface() + } + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + objs []reflect.Value + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } + } + } else { + break + } + } + + if elems.Len() > 0 { + if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { + for i := 0; i < elems.Len(); i++ { + setupReferences(objs[i], elems.Index(i)) + } + } + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + + if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { + setupReferences(db.Statement.ReflectValue, rv) + } + } + } + } + } +} + +func SaveAfterAssociations(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + db.AddError(ref.ForeignKey.Set(rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } + } + } + + if elems.Len() > 0 { + assignmentColumns := []string{} + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if f.Kind() != reflect.Ptr { + f = f.Addr() + } + + assignmentColumns := []string{} + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) + ref.ForeignKey.Set(f, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(f, ref.PrimaryValue) + } + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + } + } + } + + // Save Has Many associations + for _, rel := range db.Statement.Schema.Relationships.HasMany { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(v) + ref.ForeignKey.Set(elem, pv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(elem, ref.PrimaryValue) + } + } + + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + + if elems.Len() > 0 { + assignmentColumns := []string{} + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + } + } + + // Save Many2Many associations + for _, rel := range db.Statement.Schema.Relationships.Many2Many { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) + objs := []reflect.Value{} + + appendToJoins := func(obj reflect.Value, elem reflect.Value) { + joinValue := reflect.New(rel.JoinTable.ModelType) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + ref.ForeignKey.Set(joinValue, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(joinValue, ref.PrimaryValue) + } else { + fv, _ := ref.PrimaryKey.ValueOf(elem) + ref.ForeignKey.Set(joinValue, fv) + } + } + joins = reflect.Append(joins, joinValue) + } + + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + + objs = append(objs, v) + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + + if elems.Len() > 0 { + if v, ok := selectColumns[rel.Name+".*"]; !ok || v { + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) + } + + for i := 0; i < elems.Len(); i++ { + appendToJoins(objs[i], elems.Index(i)) + } + } + + if joins.Len() > 0 { + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) + } + } + } +} + +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict { + if stmt.DB.FullSaveAssociations { + defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) + for _, dbName := range s.DBNames { + if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) { + continue + } + + if !s.LookUpField(dbName).PrimaryKey { + defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) + } + } + } + + if len(defaultUpdatingColumns) > 0 { + var columns []clause.Column + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) + } + + return clause.OnConflict{ + Columns: columns, + DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns), + } + } + + return clause.OnConflict{DoNothing: true} +} + +func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + var ( + selects, omits []string + onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) + refName = rel.Name + "." + ) + + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, refName) { + columnName = strings.TrimPrefix(name, refName) + } else if strings.HasPrefix(name, clause.Associations) { + columnName = name + } + + if columnName != "" { + if ok { + selects = append(selects, columnName) + } else { + omits = append(omits, columnName) + } + } + } + + tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) + + if len(selects) > 0 { + tx = tx.Select(selects) + } + + if len(omits) > 0 { + tx = tx.Omit(omits...) + } + + return db.AddError(tx.Create(values).Error) +} diff --git a/vendor/gorm.io/gorm/callbacks/callbacks.go b/vendor/gorm.io/gorm/callbacks/callbacks.go new file mode 100644 index 000000000..dda4b0466 --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/callbacks.go @@ -0,0 +1,51 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +type Config struct { + LastInsertIDReversed bool + WithReturning bool +} + +func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { + enableTransaction := func(db *gorm.DB) bool { + return !db.SkipDefaultTransaction + } + + createCallback := db.Callback().Create() + createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + createCallback.Register("gorm:before_create", BeforeCreate) + createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + createCallback.Register("gorm:create", Create(config)) + createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + createCallback.Register("gorm:after_create", AfterCreate) + createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + + queryCallback := db.Callback().Query() + queryCallback.Register("gorm:query", Query) + queryCallback.Register("gorm:preload", Preload) + queryCallback.Register("gorm:after_query", AfterQuery) + + deleteCallback := db.Callback().Delete() + deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + deleteCallback.Register("gorm:before_delete", BeforeDelete) + deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) + deleteCallback.Register("gorm:delete", Delete) + deleteCallback.Register("gorm:after_delete", AfterDelete) + deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + + updateCallback := db.Callback().Update() + updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) + updateCallback.Register("gorm:before_update", BeforeUpdate) + updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + updateCallback.Register("gorm:update", Update) + updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + updateCallback.Register("gorm:after_update", AfterUpdate) + updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + + db.Callback().Row().Register("gorm:row", RowQuery) + db.Callback().Raw().Register("gorm:raw", RawExec) +} diff --git a/vendor/gorm.io/gorm/callbacks/callmethod.go b/vendor/gorm.io/gorm/callbacks/callmethod.go new file mode 100644 index 000000000..bcaa03f3d --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/callmethod.go @@ -0,0 +1,23 @@ +package callbacks + +import ( + "reflect" + + "gorm.io/gorm" +) + +func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { + tx := db.Session(&gorm.Session{NewDB: true}) + if called := fc(db.Statement.ReflectValue.Interface(), tx); !called { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + db.Statement.CurDestIndex = 0 + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) + db.Statement.CurDestIndex++ + } + case reflect.Struct: + fc(db.Statement.ReflectValue.Addr().Interface(), tx) + } + } +} diff --git a/vendor/gorm.io/gorm/callbacks/create.go b/vendor/gorm.io/gorm/callbacks/create.go new file mode 100644 index 000000000..3ca56d733 --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/create.go @@ -0,0 +1,361 @@ +package callbacks + +import ( + "fmt" + "reflect" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +func BeforeCreate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.BeforeSave { + if i, ok := value.(BeforeSaveInterface); ok { + called = true + db.AddError(i.BeforeSave(tx)) + } + } + + if db.Statement.Schema.BeforeCreate { + if i, ok := value.(BeforeCreateInterface); ok { + called = true + db.AddError(i.BeforeCreate(tx)) + } + } + return called + }) + } +} + +func Create(config *Config) func(db *gorm.DB) { + if config.WithReturning { + return CreateWithReturning + } else { + return func(db *gorm.DB) { + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Insert{}) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + } + + if !db.DryRun && db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + + if db.RowsAffected > 0 { + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID-- + } + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID++ + } + } + } + case reflect.Struct: + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } + } + } else { + db.AddError(err) + } + } + } + } else { + db.AddError(err) + } + } + } + } + } +} + +func CreateWithReturning(db *gorm.DB) { + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{}) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + } + + if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { + db.Statement.WriteString(" RETURNING ") + + var ( + fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) + values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) + ) + + for idx, field := range sch.FieldsWithDefaultDBValue { + if idx > 0 { + db.Statement.WriteByte(',') + } + + fields[idx] = field + db.Statement.WriteQuoted(field.DBName) + } + + if !db.DryRun && db.Error == nil { + db.RowsAffected = 0 + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + defer rows.Close() + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + c := db.Statement.Clauses["ON CONFLICT"] + onConflict, _ := c.Expression.(clause.OnConflict) + + for rows.Next() { + BEGIN: + reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) + if reflect.Indirect(reflectValue).Kind() != reflect.Struct { + break + } + + for idx, field := range fields { + fieldValue := field.ReflectValueOf(reflectValue) + + if onConflict.DoNothing && !fieldValue.IsZero() { + db.RowsAffected++ + + if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() { + return + } + + goto BEGIN + } + + values[idx] = fieldValue.Addr().Interface() + } + + db.RowsAffected++ + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + } + case reflect.Struct: + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } + + if rows.Next() { + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + } + } + } else { + db.AddError(err) + } + } + } else if !db.DryRun && db.Error == nil { + if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } + } +} + +func AfterCreate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { + called = true + db.AddError(i.AfterSave(tx)) + } + } + + if db.Statement.Schema.AfterCreate { + if i, ok := value.(AfterCreateInterface); ok { + called = true + db.AddError(i.AfterCreate(tx)) + } + } + return called + }) + } +} + +// ConvertToCreateValues convert to create values +func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { + switch value := stmt.Dest.(type) { + case map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, value) + case *map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, *value) + case []map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, value) + case *[]map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, *value) + default: + var ( + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + curTime = stmt.DB.NowFunc() + isZero bool + ) + values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} + + for _, db := range stmt.Schema.DBNames { + if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { + if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + values.Columns = append(values.Columns, clause.Column{Name: db}) + } + } + } + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + stmt.SQL.Grow(stmt.ReflectValue.Len() * 18) + values.Values = make([][]interface{}, stmt.ReflectValue.Len()) + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + if stmt.ReflectValue.Len() == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + + for i := 0; i < stmt.ReflectValue.Len(); i++ { + rv := reflect.Indirect(stmt.ReflectValue.Index(i)) + if !rv.IsValid() { + stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) + return + } + + values.Values[i] = make([]interface{}, len(values.Columns)) + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { + if field.DefaultValueInterface != nil { + values.Values[i][idx] = field.DefaultValueInterface + field.Set(rv, field.DefaultValueInterface) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + field.Set(rv, curTime) + values.Values[i][idx], _ = field.ValueOf(rv) + } + } + } + + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, isZero := field.ValueOf(rv); !isZero { + if len(defaultValueFieldsHavingValue[field]) == 0 { + defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len()) + } + defaultValueFieldsHavingValue[field][i] = v + } + } + } + } + + for field, vs := range defaultValueFieldsHavingValue { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } + } + } + case reflect.Struct: + values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { + if field.DefaultValueInterface != nil { + values.Values[0][idx] = field.DefaultValueInterface + field.Set(stmt.ReflectValue, field.DefaultValueInterface) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + field.Set(stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + } + } + } + + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + values.Values[0] = append(values.Values[0], v) + } + } + } + default: + stmt.AddError(gorm.ErrInvalidData) + } + } + + if c, ok := stmt.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { + if stmt.Schema != nil && len(values.Columns) > 1 { + columns := make([]string, 0, len(values.Columns)-1) + for _, column := range values.Columns { + if field := stmt.Schema.LookUpField(column.Name); field != nil { + if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { + columns = append(columns, column.Name) + } + } + } + + onConflict := clause.OnConflict{ + Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), + DoUpdates: clause.AssignmentColumns(columns), + } + + for idx, field := range stmt.Schema.PrimaryFields { + onConflict.Columns[idx] = clause.Column{Name: field.DBName} + } + + stmt.AddClause(onConflict) + } + } + } + + return values +} diff --git a/vendor/gorm.io/gorm/callbacks/delete.go b/vendor/gorm.io/gorm/callbacks/delete.go new file mode 100644 index 000000000..867aa6970 --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/delete.go @@ -0,0 +1,165 @@ +package callbacks + +import ( + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +func BeforeDelete(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(BeforeDeleteInterface); ok { + db.AddError(i.BeforeDelete(tx)) + return true + } + + return false + }) + } +} + +func DeleteBeforeAssociations(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + + if restricted { + for column, v := range selectColumns { + if v { + if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) + withoutConditions := false + + if len(db.Statement.Selects) > 0 { + var selects []string + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if strings.HasPrefix(s, column+".") { + selects = append(selects, strings.TrimPrefix(s, column+".")) + } + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions { + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + case schema.Many2Many: + var ( + queryConds []clause.Expression + foreignFields []*schema.Field + relForeignKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + } + } + } + } + } +} + +func Delete(db *gorm.DB) { + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(100) + db.Statement.AddClauseIfNotExists(clause.Delete{}) + + if db.Statement.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + } + + db.Statement.AddClauseIfNotExists(clause.From{}) + db.Statement.Build("DELETE", "FROM", "WHERE") + } + + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { + db.AddError(gorm.ErrMissingWhereClause) + return + } + + if !db.DryRun && db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } + } +} + +func AfterDelete(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(AfterDeleteInterface); ok { + db.AddError(i.AfterDelete(tx)) + return true + } + return false + }) + } +} diff --git a/vendor/gorm.io/gorm/callbacks/helper.go b/vendor/gorm.io/gorm/callbacks/helper.go new file mode 100644 index 000000000..3ac63fa19 --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/helper.go @@ -0,0 +1,90 @@ +package callbacks + +import ( + "sort" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// ConvertMapToValuesForCreate convert map to values +func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { + values.Columns = make([]clause.Column, 0, len(mapValue)) + selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) + + var keys = make([]string, 0, len(mapValue)) + for k := range mapValue { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + value := mapValue[k] + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + values.Columns = append(values.Columns, clause.Column{Name: k}) + if len(values.Values) == 0 { + values.Values = [][]interface{}{{}} + } + + values.Values[0] = append(values.Values[0], value) + } + } + return +} + +// ConvertSliceOfMapToValuesForCreate convert slice of map to values +func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { + var ( + columns = make([]string, 0, len(mapValues)) + result = map[string][]interface{}{} + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + ) + + if len(mapValues) == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + + for idx, mapValue := range mapValues { + for k, v := range mapValue { + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + } + + if _, ok := result[k]; !ok { + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + result[k] = make([]interface{}, len(mapValues)) + columns = append(columns, k) + } else { + continue + } + } + + result[k][idx] = v + } + } + + sort.Strings(columns) + values.Values = make([][]interface{}, len(mapValues)) + values.Columns = make([]clause.Column, len(columns)) + for idx, column := range columns { + values.Columns[idx] = clause.Column{Name: column} + + for i, v := range result[column] { + if len(values.Values[i]) == 0 { + values.Values[i] = make([]interface{}, len(columns)) + } + + values.Values[i][idx] = v + } + } + return +} diff --git a/vendor/gorm.io/gorm/callbacks/interfaces.go b/vendor/gorm.io/gorm/callbacks/interfaces.go new file mode 100644 index 000000000..2302470fc --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/interfaces.go @@ -0,0 +1,39 @@ +package callbacks + +import "gorm.io/gorm" + +type BeforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} + +type AfterCreateInterface interface { + AfterCreate(*gorm.DB) error +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*gorm.DB) error +} + +type AfterUpdateInterface interface { + AfterUpdate(*gorm.DB) error +} + +type BeforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type AfterSaveInterface interface { + AfterSave(*gorm.DB) error +} + +type BeforeDeleteInterface interface { + BeforeDelete(*gorm.DB) error +} + +type AfterDeleteInterface interface { + AfterDelete(*gorm.DB) error +} + +type AfterFindInterface interface { + AfterFind(*gorm.DB) error +} diff --git a/vendor/gorm.io/gorm/callbacks/preload.go b/vendor/gorm.io/gorm/callbacks/preload.go new file mode 100644 index 000000000..682427c9e --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/preload.go @@ -0,0 +1,155 @@ +package callbacks + +import ( + "reflect" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { + var ( + reflectValue = db.Statement.ReflectValue + rel = rels[len(rels)-1] + tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) + relForeignKeys []string + relForeignFields []*schema.Field + foreignFields []*schema.Field + foreignValues [][]interface{} + identityMap = map[string][]reflect.Value{} + inlineConds []interface{} + ) + + if len(rels) > 1 { + reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1]) + } + + if rel.JoinTable != nil { + var joinForeignFields, joinRelForeignFields []*schema.Field + var joinForeignKeys []string + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) + joinForeignFields = append(joinForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey) + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + } + } + + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + if len(joinForeignValues) == 0 { + return + } + + joinResults := rel.JoinTable.MakeSlice().Elem() + column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) + + // convert join identity map to relation identity map + fieldValues := make([]interface{}, len(joinForeignFields)) + joinFieldValues := make([]interface{}, len(joinRelForeignFields)) + for i := 0; i < joinResults.Len(); i++ { + for idx, field := range joinForeignFields { + fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) + } + + for idx, field := range joinRelForeignFields { + joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) + } + + if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { + joinKey := utils.ToStringKey(joinFieldValues...) + identityMap[joinKey] = append(identityMap[joinKey], results...) + } + } + + _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + relForeignFields = append(relForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + if len(foreignValues) == 0 { + return + } + } + + reflectResults := rel.FieldSchema.MakeSlice().Elem() + column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) + + for _, cond := range conds { + if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { + tx = fc(tx) + } else { + inlineConds = append(inlineConds, cond) + } + } + + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) + + fieldValues := make([]interface{}, len(relForeignFields)) + + // clean up old values before preloading + switch reflectValue.Kind() { + case reflect.Struct: + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + default: + rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + } + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + default: + rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + } + } + } + + for i := 0; i < reflectResults.Len(); i++ { + elem := reflectResults.Index(i) + for idx, field := range relForeignFields { + fieldValues[idx], _ = field.ValueOf(elem) + } + + for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { + reflectFieldValue := rel.Field.ReflectValueOf(data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } + + reflectFieldValue = reflect.Indirect(reflectFieldValue) + switch reflectFieldValue.Kind() { + case reflect.Struct: + rel.Field.Set(data, reflectResults.Index(i).Interface()) + case reflect.Slice, reflect.Array: + if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + } else { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + } + } + } + } +} diff --git a/vendor/gorm.io/gorm/callbacks/query.go b/vendor/gorm.io/gorm/callbacks/query.go new file mode 100644 index 000000000..aa4629a2e --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/query.go @@ -0,0 +1,228 @@ +package callbacks + +import ( + "fmt" + "reflect" + "sort" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +func Query(db *gorm.DB) { + if db.Error == nil { + BuildQuerySQL(db) + + if !db.DryRun && db.Error == nil { + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + defer rows.Close() + + gorm.Scan(rows, db, false) + } + } +} + +func BuildQuerySQL(db *gorm.DB) { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.QueryClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(100) + clauseSelect := clause.Select{Distinct: db.Statement.Distinct} + + if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { + var conds []clause.Expression + for _, primaryField := range db.Statement.Schema.PrimaryFields { + if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) + } + } + + if len(conds) > 0 { + db.Statement.AddClause(clause.Where{Exprs: conds}) + } + } + + if len(db.Statement.Selects) > 0 { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) + for idx, name := range db.Statement.Selects { + if db.Statement.Schema == nil { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } else if f := db.Statement.Schema.LookUpField(name); f != nil { + clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} + } else { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } + } + } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { + selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) + clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) + for _, dbName := range db.Statement.Schema.DBNames { + if v, ok := selectColumns[dbName]; (ok && v) || !ok { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName}) + } + } + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { + queryFields := db.QueryFields + if !queryFields { + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + case reflect.Slice: + queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType + } + } + + if queryFields { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } + } + } + } + + // inline joins + if len(db.Statement.Joins) != 0 { + if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) + for idx, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } + } + + joins := []clause.Join{} + for _, join := range db.Statement.Joins { + if db.Statement.Schema == nil { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, + }) + } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { + tableAliasName := relation.Name + + for _, s := range relation.FieldSchema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: tableAliasName + "__" + s, + }) + } + + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } + } + } + } + + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) + } else { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, + }) + } + } + + db.Statement.AddClause(clause.From{Joins: joins}) + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) + } + + db.Statement.AddClauseIfNotExists(clauseSelect) + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + } +} + +func Preload(db *gorm.DB) { + if db.Error == nil && len(db.Statement.Preloads) > 0 { + preloadMap := map[string][]string{} + for name := range db.Statement.Preloads { + if name == clause.Associations { + for _, rel := range db.Statement.Schema.Relationships.Relations { + if rel.Schema == db.Statement.Schema { + preloadMap[rel.Name] = []string{rel.Name} + } + } + } else { + preloadFields := strings.Split(name, ".") + for idx := range preloadFields { + preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + } + } + } + + preloadNames := make([]string, len(preloadMap)) + idx := 0 + for key := range preloadMap { + preloadNames[idx] = key + idx++ + } + sort.Strings(preloadNames) + + for _, name := range preloadNames { + var ( + curSchema = db.Statement.Schema + preloadFields = preloadMap[name] + rels = make([]*schema.Relationship, len(preloadFields)) + ) + + for idx, preloadField := range preloadFields { + if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { + rels[idx] = rel + curSchema = rel.FieldSchema + } else { + db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) + } + } + + if db.Error == nil { + preload(db, rels, db.Statement.Preloads[name]) + } + } + } +} + +func AfterQuery(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(AfterFindInterface); ok { + db.AddError(i.AfterFind(tx)) + return true + } + return false + }) + } +} diff --git a/vendor/gorm.io/gorm/callbacks/raw.go b/vendor/gorm.io/gorm/callbacks/raw.go new file mode 100644 index 000000000..d594ab391 --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/raw.go @@ -0,0 +1,16 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func RawExec(db *gorm.DB) { + if db.Error == nil && !db.DryRun { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + } else { + db.RowsAffected, _ = result.RowsAffected() + } + } +} diff --git a/vendor/gorm.io/gorm/callbacks/row.go b/vendor/gorm.io/gorm/callbacks/row.go new file mode 100644 index 000000000..10e880e12 --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/row.go @@ -0,0 +1,21 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func RowQuery(db *gorm.DB) { + if db.Error == nil { + BuildQuerySQL(db) + + if !db.DryRun { + if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } + + db.RowsAffected = -1 + } + } +} diff --git a/vendor/gorm.io/gorm/callbacks/transaction.go b/vendor/gorm.io/gorm/callbacks/transaction.go new file mode 100644 index 000000000..3171b5bba --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/transaction.go @@ -0,0 +1,29 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func BeginTransaction(db *gorm.DB) { + if !db.Config.SkipDefaultTransaction { + if tx := db.Begin(); tx.Error == nil { + db.Statement.ConnPool = tx.Statement.ConnPool + db.InstanceSet("gorm:started_transaction", true) + } else { + tx.Error = nil + } + } +} + +func CommitOrRollbackTransaction(db *gorm.DB) { + if !db.Config.SkipDefaultTransaction { + if _, ok := db.InstanceGet("gorm:started_transaction"); ok { + if db.Error == nil { + db.Commit() + } else { + db.Rollback() + } + db.Statement.ConnPool = db.ConnPool + } + } +} diff --git a/vendor/gorm.io/gorm/callbacks/update.go b/vendor/gorm.io/gorm/callbacks/update.go new file mode 100644 index 000000000..c8f3922eb --- /dev/null +++ b/vendor/gorm.io/gorm/callbacks/update.go @@ -0,0 +1,263 @@ +package callbacks + +import ( + "reflect" + "sort" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +func SetupUpdateReflectValue(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() + } + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if _, ok := dest[rel.Name]; ok { + rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) + } + } + } + } + } +} + +func BeforeUpdate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.BeforeSave { + if i, ok := value.(BeforeSaveInterface); ok { + called = true + db.AddError(i.BeforeSave(tx)) + } + } + + if db.Statement.Schema.BeforeUpdate { + if i, ok := value.(BeforeUpdateInterface); ok { + called = true + db.AddError(i.BeforeUpdate(tx)) + } + } + + return called + }) + } +} + +func Update(db *gorm.DB) { + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Update{}) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } + db.Statement.Build("UPDATE", "SET", "WHERE") + } + + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + + if !db.DryRun && db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } + } +} + +func AfterUpdate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { + called = true + db.AddError(i.AfterSave(tx)) + } + } + + if db.Statement.Schema.AfterUpdate { + if i, ok := value.(AfterUpdateInterface); ok { + called = true + db.AddError(i.AfterUpdate(tx)) + } + } + return called + }) + } +} + +// ConvertToAssignments convert to update assignments +func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { + var ( + selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) + assignValue func(field *schema.Field, value interface{}) + ) + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + assignValue = func(field *schema.Field, value interface{}) { + for i := 0; i < stmt.ReflectValue.Len(); i++ { + field.Set(stmt.ReflectValue.Index(i), value) + } + } + case reflect.Struct: + assignValue = func(field *schema.Field, value interface{}) { + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.ReflectValue, value) + } + } + default: + assignValue = func(field *schema.Field, value interface{}) { + } + } + + updatingValue := reflect.ValueOf(stmt.Dest) + for updatingValue.Kind() == reflect.Ptr { + updatingValue = updatingValue.Elem() + } + + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var primaryKeyExprs []clause.Expression + for i := 0; i < stmt.ReflectValue.Len(); i++ { + var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) + var notZero bool + for idx, field := range stmt.Schema.PrimaryFields { + value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) + exprs[idx] = clause.Eq{Column: field.DBName, Value: value} + notZero = notZero || !isZero + } + if notZero { + primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) + } + } + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) + case reflect.Struct: + for _, field := range stmt.Schema.PrimaryFields { + if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + } + + switch value := updatingValue.Interface().(type) { + case map[string]interface{}: + set = make([]clause.Assignment, 0, len(value)) + + keys := make([]string, 0, len(value)) + for k := range value { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + kv := value[k] + if _, ok := kv.(*gorm.DB); ok { + kv = []interface{}{kv} + } + + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + if field.DBName != "" { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) + assignValue(field, value[k]) + } + } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { + assignValue(field, value[k]) + } + continue + } + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) + } + } + + if !stmt.SkipHooks && stmt.Schema != nil { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) + if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + now := stmt.DB.NowFunc() + assignValue(field, now) + + if field.AutoUpdateTime == schema.UnixNanosecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) + } else if field.AutoUpdateTime == schema.UnixMillisecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) + } else if field.GORMDataType == schema.Time { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } + } + } + } + } + default: + switch updatingValue.Kind() { + case reflect.Struct: + set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) + if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + value, isZero := field.ValueOf(updatingValue) + if !stmt.SkipHooks { + if field.AutoUpdateTime > 0 { + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 + } else if field.GORMDataType == schema.Time { + value = stmt.DB.NowFunc() + } else { + value = stmt.DB.NowFunc().Unix() + } + isZero = false + } + } + + if ok || !isZero { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + assignValue(field, value) + } + } + } else { + if value, isZero := field.ValueOf(updatingValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + default: + stmt.AddError(gorm.ErrInvalidData) + } + } + + return +} diff --git a/vendor/gorm.io/gorm/chainable_api.go b/vendor/gorm.io/gorm/chainable_api.go new file mode 100644 index 000000000..c3a02d205 --- /dev/null +++ b/vendor/gorm.io/gorm/chainable_api.go @@ -0,0 +1,292 @@ +package gorm + +import ( + "fmt" + "regexp" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" +) + +// Model specify the model you would like to run db operations +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") +func (db *DB) Model(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Model = value + return +} + +// Clauses Add clauses +func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { + tx = db.getInstance() + var whereConds []interface{} + + for _, cond := range conds { + if c, ok := cond.(clause.Interface); ok { + tx.Statement.AddClause(c) + } else if optimizer, ok := cond.(StatementModifier); ok { + optimizer.ModifyStatement(tx.Statement) + } else { + whereConds = append(whereConds, cond) + } + } + + if len(whereConds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)}) + } + return +} + +var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) + +// Table specify the table you would like to run db operations +func (db *DB) Table(name string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { + tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} + if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { + tx.Statement.Table = results[1] + return + } + } else if tables := strings.Split(name, "."); len(tables) == 2 { + tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} + tx.Statement.Table = tables[1] + return + } + + tx.Statement.Table = name + return +} + +// Distinct specify distinct fields that you want querying +func (db *DB) Distinct(args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Distinct = true + if len(args) > 0 { + tx = tx.Select(args[0], args[1:]...) + } + return +} + +// Select specify fields that you want when querying, creating, updating +func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + + switch v := query.(type) { + case []string: + tx.Statement.Selects = v + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) + default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + return + } + } + delete(tx.Statement.Clauses, "SELECT") + case string: + fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) + + // normal field names + if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { + tx.Statement.Selects = []string{v} + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) + default: + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + return + } + } + + delete(tx.Statement.Clauses, "SELECT") + } else { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + } + default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + } + + return +} + +// Omit specify fields that you want to ignore when creating, updating and querying +func (db *DB) Omit(columns ...string) (tx *DB) { + tx = db.getInstance() + + if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + } else { + tx.Statement.Omits = columns + } + return +} + +// Where add conditions +func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: conds}) + } + return +} + +// Not add NOT conditions +func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}}) + } + return +} + +// Or add OR conditions +func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}}) + } + return +} + +// Joins specify Joins conditions +// db.Joins("Account").Find(&user) +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) + return +} + +// Group specify the group method on the find +func (db *DB) Group(name string) (tx *DB) { + tx = db.getInstance() + + fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + tx.Statement.AddClause(clause.GroupBy{ + Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, + }) + return +} + +// Having specify HAVING conditions for GROUP BY +func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.GroupBy{ + Having: tx.Statement.BuildCondition(query, args...), + }) + return +} + +// Order specify order when retrieve records from database +// db.Order("name DESC") +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) +func (db *DB) Order(value interface{}) (tx *DB) { + tx = db.getInstance() + + switch v := value.(type) { + case clause.OrderByColumn: + tx.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{v}, + }) + default: + tx.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{{ + Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, + }}, + }) + } + return +} + +// Limit specify the number of records to be retrieved +func (db *DB) Limit(limit int) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.Limit{Limit: limit}) + return +} + +// Offset specify the number of records to skip before starting to return the records +func (db *DB) Offset(offset int) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.Limit{Offset: offset}) + return +} + +// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { + for _, f := range funcs { + db = f(db) + } + return db +} + +// Preload preload associations with given conditions +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Preloads == nil { + tx.Statement.Preloads = map[string][]interface{}{} + } + tx.Statement.Preloads[query] = args + return +} + +func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.attrs = attrs + return +} + +func (db *DB) Assign(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.assigns = attrs + return +} + +func (db *DB) Unscoped() (tx *DB) { + tx = db.getInstance() + tx.Statement.Unscoped = true + return +} + +func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.SQL = strings.Builder{} + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } + return +} diff --git a/vendor/gorm.io/gorm/clause/clause.go b/vendor/gorm.io/gorm/clause/clause.go new file mode 100644 index 000000000..d413d0ee2 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/clause.go @@ -0,0 +1,88 @@ +package clause + +// Interface clause interface +type Interface interface { + Name() string + Build(Builder) + MergeClause(*Clause) +} + +// ClauseBuilder clause builder, allows to custmize how to build clause +type ClauseBuilder func(Clause, Builder) + +type Writer interface { + WriteByte(byte) error + WriteString(string) (int, error) +} + +// Builder builder interface +type Builder interface { + Writer + WriteQuoted(field interface{}) + AddVar(Writer, ...interface{}) +} + +// Clause +type Clause struct { + Name string // WHERE + BeforeExpression Expression + AfterNameExpression Expression + AfterExpression Expression + Expression Expression + Builder ClauseBuilder +} + +// Build build clause +func (c Clause) Build(builder Builder) { + if c.Builder != nil { + c.Builder(c, builder) + } else if c.Expression != nil { + if c.BeforeExpression != nil { + c.BeforeExpression.Build(builder) + builder.WriteByte(' ') + } + + if c.Name != "" { + builder.WriteString(c.Name) + builder.WriteByte(' ') + } + + if c.AfterNameExpression != nil { + c.AfterNameExpression.Build(builder) + builder.WriteByte(' ') + } + + c.Expression.Build(builder) + + if c.AfterExpression != nil { + builder.WriteByte(' ') + c.AfterExpression.Build(builder) + } + } +} + +const ( + PrimaryKey string = "@@@py@@@" // primary key + CurrentTable string = "@@@ct@@@" // current table + Associations string = "@@@as@@@" // associations +) + +var ( + currentTable = Table{Name: CurrentTable} + PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey} +) + +// Column quote with name +type Column struct { + Table string + Name string + Alias string + Raw bool +} + +// Table quote with name +type Table struct { + Name string + Alias string + Raw bool +} diff --git a/vendor/gorm.io/gorm/clause/delete.go b/vendor/gorm.io/gorm/clause/delete.go new file mode 100644 index 000000000..fc462cd7f --- /dev/null +++ b/vendor/gorm.io/gorm/clause/delete.go @@ -0,0 +1,23 @@ +package clause + +type Delete struct { + Modifier string +} + +func (d Delete) Name() string { + return "DELETE" +} + +func (d Delete) Build(builder Builder) { + builder.WriteString("DELETE") + + if d.Modifier != "" { + builder.WriteByte(' ') + builder.WriteString(d.Modifier) + } +} + +func (d Delete) MergeClause(clause *Clause) { + clause.Name = "" + clause.Expression = d +} diff --git a/vendor/gorm.io/gorm/clause/expression.go b/vendor/gorm.io/gorm/clause/expression.go new file mode 100644 index 000000000..b30c46b03 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/expression.go @@ -0,0 +1,301 @@ +package clause + +import ( + "database/sql" + "database/sql/driver" + "go/ast" + "reflect" +) + +// Expression expression interface +type Expression interface { + Build(builder Builder) +} + +// NegationExpressionBuilder negation expression builder +type NegationExpressionBuilder interface { + NegationBuild(builder Builder) +} + +// Expr raw expression +type Expr struct { + SQL string + Vars []interface{} + WithoutParentheses bool +} + +// Build build raw expression +func (expr Expr) Build(builder Builder) { + var ( + afterParenthesis bool + idx int + ) + + for _, v := range []byte(expr.SQL) { + if v == '?' && len(expr.Vars) > idx { + if afterParenthesis || expr.WithoutParentheses { + if _, ok := expr.Vars[idx].(driver.Valuer); ok { + builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + + idx++ + } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } + builder.WriteByte(v) + } + } +} + +// NamedExpr raw expression for named expr +type NamedExpr struct { + SQL string + Vars []interface{} +} + +// Build build raw expression +func (expr NamedExpr) Build(builder Builder) { + var ( + idx int + inName bool + namedMap = make(map[string]interface{}, len(expr.Vars)) + ) + + for _, v := range expr.Vars { + switch value := v.(type) { + case sql.NamedArg: + namedMap[value.Name] = value.Value + case map[string]interface{}: + for k, v := range value { + namedMap[k] = v + } + default: + var appendFieldsToMap func(reflect.Value) + appendFieldsToMap = func(reflectValue reflect.Value) { + reflectValue = reflect.Indirect(reflectValue) + switch reflectValue.Kind() { + case reflect.Struct: + modelType := reflectValue.Type() + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + + if fieldStruct.Anonymous { + appendFieldsToMap(reflectValue.Field(i)) + } + } + } + } + } + + appendFieldsToMap(reflect.ValueOf(value)) + } + } + + name := make([]byte, 0, 10) + + for _, v := range []byte(expr.SQL) { + if v == '@' && !inName { + inName = true + name = []byte{} + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' { + if inName { + if nv, ok := namedMap[string(name)]; ok { + builder.AddVar(builder, nv) + } else { + builder.WriteByte('@') + builder.WriteString(string(name)) + } + inName = false + } + + builder.WriteByte(v) + } else if v == '?' && len(expr.Vars) > idx { + builder.AddVar(builder, expr.Vars[idx]) + idx++ + } else if inName { + name = append(name, v) + } else { + builder.WriteByte(v) + } + } + + if inName { + builder.AddVar(builder, namedMap[string(name)]) + } +} + +// IN Whether a value is within a set of values +type IN struct { + Column interface{} + Values []interface{} +} + +func (in IN) Build(builder Builder) { + builder.WriteQuoted(in.Column) + + switch len(in.Values) { + case 0: + builder.WriteString(" IN (NULL)") + case 1: + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteString(" = ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough + default: + builder.WriteString(" IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') + } +} + +func (in IN) NegationBuild(builder Builder) { + switch len(in.Values) { + case 0: + case 1: + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteQuoted(in.Column) + builder.WriteString(" <> ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough + default: + builder.WriteQuoted(in.Column) + builder.WriteString(" NOT IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') + } +} + +// Eq equal to for where +type Eq struct { + Column interface{} + Value interface{} +} + +func (eq Eq) Build(builder Builder) { + builder.WriteQuoted(eq.Column) + + if eq.Value == nil { + builder.WriteString(" IS NULL") + } else { + builder.WriteString(" = ") + builder.AddVar(builder, eq.Value) + } +} + +func (eq Eq) NegationBuild(builder Builder) { + Neq(eq).Build(builder) +} + +// Neq not equal to for where +type Neq Eq + +func (neq Neq) Build(builder Builder) { + builder.WriteQuoted(neq.Column) + + if neq.Value == nil { + builder.WriteString(" IS NOT NULL") + } else { + builder.WriteString(" <> ") + builder.AddVar(builder, neq.Value) + } +} + +func (neq Neq) NegationBuild(builder Builder) { + Eq(neq).Build(builder) +} + +// Gt greater than for where +type Gt Eq + +func (gt Gt) Build(builder Builder) { + builder.WriteQuoted(gt.Column) + builder.WriteString(" > ") + builder.AddVar(builder, gt.Value) +} + +func (gt Gt) NegationBuild(builder Builder) { + Lte(gt).Build(builder) +} + +// Gte greater than or equal to for where +type Gte Eq + +func (gte Gte) Build(builder Builder) { + builder.WriteQuoted(gte.Column) + builder.WriteString(" >= ") + builder.AddVar(builder, gte.Value) +} + +func (gte Gte) NegationBuild(builder Builder) { + Lt(gte).Build(builder) +} + +// Lt less than for where +type Lt Eq + +func (lt Lt) Build(builder Builder) { + builder.WriteQuoted(lt.Column) + builder.WriteString(" < ") + builder.AddVar(builder, lt.Value) +} + +func (lt Lt) NegationBuild(builder Builder) { + Gte(lt).Build(builder) +} + +// Lte less than or equal to for where +type Lte Eq + +func (lte Lte) Build(builder Builder) { + builder.WriteQuoted(lte.Column) + builder.WriteString(" <= ") + builder.AddVar(builder, lte.Value) +} + +func (lte Lte) NegationBuild(builder Builder) { + Gt(lte).Build(builder) +} + +// Like whether string matches regular expression +type Like Eq + +func (like Like) Build(builder Builder) { + builder.WriteQuoted(like.Column) + builder.WriteString(" LIKE ") + builder.AddVar(builder, like.Value) +} + +func (like Like) NegationBuild(builder Builder) { + builder.WriteQuoted(like.Column) + builder.WriteString(" NOT LIKE ") + builder.AddVar(builder, like.Value) +} diff --git a/vendor/gorm.io/gorm/clause/from.go b/vendor/gorm.io/gorm/clause/from.go new file mode 100644 index 000000000..1ea2d5951 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/from.go @@ -0,0 +1,37 @@ +package clause + +// From from clause +type From struct { + Tables []Table + Joins []Join +} + +// Name from clause name +func (from From) Name() string { + return "FROM" +} + +// Build build from clause +func (from From) Build(builder Builder) { + if len(from.Tables) > 0 { + for idx, table := range from.Tables { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(table) + } + } else { + builder.WriteQuoted(currentTable) + } + + for _, join := range from.Joins { + builder.WriteByte(' ') + join.Build(builder) + } +} + +// MergeClause merge from clause +func (from From) MergeClause(clause *Clause) { + clause.Expression = from +} diff --git a/vendor/gorm.io/gorm/clause/group_by.go b/vendor/gorm.io/gorm/clause/group_by.go new file mode 100644 index 000000000..882319169 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/group_by.go @@ -0,0 +1,42 @@ +package clause + +// GroupBy group by clause +type GroupBy struct { + Columns []Column + Having []Expression +} + +// Name from clause name +func (groupBy GroupBy) Name() string { + return "GROUP BY" +} + +// Build build group by clause +func (groupBy GroupBy) Build(builder Builder) { + for idx, column := range groupBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } + + if len(groupBy.Having) > 0 { + builder.WriteString(" HAVING ") + Where{Exprs: groupBy.Having}.Build(builder) + } +} + +// MergeClause merge group by clause +func (groupBy GroupBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(GroupBy); ok { + copiedColumns := make([]Column, len(v.Columns)) + copy(copiedColumns, v.Columns) + groupBy.Columns = append(copiedColumns, groupBy.Columns...) + + copiedHaving := make([]Expression, len(v.Having)) + copy(copiedHaving, v.Having) + groupBy.Having = append(copiedHaving, groupBy.Having...) + } + clause.Expression = groupBy +} diff --git a/vendor/gorm.io/gorm/clause/insert.go b/vendor/gorm.io/gorm/clause/insert.go new file mode 100644 index 000000000..8efaa0352 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/insert.go @@ -0,0 +1,39 @@ +package clause + +type Insert struct { + Table Table + Modifier string +} + +// Name insert clause name +func (insert Insert) Name() string { + return "INSERT" +} + +// Build build insert clause +func (insert Insert) Build(builder Builder) { + if insert.Modifier != "" { + builder.WriteString(insert.Modifier) + builder.WriteByte(' ') + } + + builder.WriteString("INTO ") + if insert.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(insert.Table) + } +} + +// MergeClause merge insert clause +func (insert Insert) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Insert); ok { + if insert.Modifier == "" { + insert.Modifier = v.Modifier + } + if insert.Table.Name == "" { + insert.Table = v.Table + } + } + clause.Expression = insert +} diff --git a/vendor/gorm.io/gorm/clause/joins.go b/vendor/gorm.io/gorm/clause/joins.go new file mode 100644 index 000000000..f3e373f26 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/joins.go @@ -0,0 +1,47 @@ +package clause + +type JoinType string + +const ( + CrossJoin JoinType = "CROSS" + InnerJoin JoinType = "INNER" + LeftJoin JoinType = "LEFT" + RightJoin JoinType = "RIGHT" +) + +// Join join clause for from +type Join struct { + Type JoinType + Table Table + ON Where + Using []string + Expression Expression +} + +func (join Join) Build(builder Builder) { + if join.Expression != nil { + join.Expression.Build(builder) + } else { + if join.Type != "" { + builder.WriteString(string(join.Type)) + builder.WriteByte(' ') + } + + builder.WriteString("JOIN ") + builder.WriteQuoted(join.Table) + + if len(join.ON.Exprs) > 0 { + builder.WriteString(" ON ") + join.ON.Build(builder) + } else if len(join.Using) > 0 { + builder.WriteString(" USING (") + for idx, c := range join.Using { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(c) + } + builder.WriteByte(')') + } + } +} diff --git a/vendor/gorm.io/gorm/clause/limit.go b/vendor/gorm.io/gorm/clause/limit.go new file mode 100644 index 000000000..2082f4d98 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/limit.go @@ -0,0 +1,48 @@ +package clause + +import "strconv" + +// Limit limit clause +type Limit struct { + Limit int + Offset int +} + +// Name where clause name +func (limit Limit) Name() string { + return "LIMIT" +} + +// Build build where clause +func (limit Limit) Build(builder Builder) { + if limit.Limit > 0 { + builder.WriteString("LIMIT ") + builder.WriteString(strconv.Itoa(limit.Limit)) + } + if limit.Offset > 0 { + if limit.Limit > 0 { + builder.WriteString(" ") + } + builder.WriteString("OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) + } +} + +// MergeClause merge order by clauses +func (limit Limit) MergeClause(clause *Clause) { + clause.Name = "" + + if v, ok := clause.Expression.(Limit); ok { + if limit.Limit == 0 && v.Limit != 0 { + limit.Limit = v.Limit + } + + if limit.Offset == 0 && v.Offset > 0 { + limit.Offset = v.Offset + } else if limit.Offset < 0 { + limit.Offset = 0 + } + } + + clause.Expression = limit +} diff --git a/vendor/gorm.io/gorm/clause/locking.go b/vendor/gorm.io/gorm/clause/locking.go new file mode 100644 index 000000000..290aac92b --- /dev/null +++ b/vendor/gorm.io/gorm/clause/locking.go @@ -0,0 +1,31 @@ +package clause + +type Locking struct { + Strength string + Table Table + Options string +} + +// Name where clause name +func (locking Locking) Name() string { + return "FOR" +} + +// Build build where clause +func (locking Locking) Build(builder Builder) { + builder.WriteString(locking.Strength) + if locking.Table.Name != "" { + builder.WriteString(" OF ") + builder.WriteQuoted(locking.Table) + } + + if locking.Options != "" { + builder.WriteByte(' ') + builder.WriteString(locking.Options) + } +} + +// MergeClause merge order by clauses +func (locking Locking) MergeClause(clause *Clause) { + clause.Expression = locking +} diff --git a/vendor/gorm.io/gorm/clause/on_conflict.go b/vendor/gorm.io/gorm/clause/on_conflict.go new file mode 100644 index 000000000..47fe169ca --- /dev/null +++ b/vendor/gorm.io/gorm/clause/on_conflict.go @@ -0,0 +1,45 @@ +package clause + +type OnConflict struct { + Columns []Column + Where Where + DoNothing bool + DoUpdates Set + UpdateAll bool +} + +func (OnConflict) Name() string { + return "ON CONFLICT" +} + +// Build build onConflict clause +func (onConflict OnConflict) Build(builder Builder) { + if len(onConflict.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) + } + + if len(onConflict.Where.Exprs) > 0 { + builder.WriteString("WHERE ") + onConflict.Where.Build(builder) + builder.WriteByte(' ') + } + + if onConflict.DoNothing { + builder.WriteString("DO NOTHING") + } else { + builder.WriteString("DO UPDATE SET ") + onConflict.DoUpdates.Build(builder) + } +} + +// MergeClause merge onConflict clauses +func (onConflict OnConflict) MergeClause(clause *Clause) { + clause.Expression = onConflict +} diff --git a/vendor/gorm.io/gorm/clause/order_by.go b/vendor/gorm.io/gorm/clause/order_by.go new file mode 100644 index 000000000..412180255 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/order_by.go @@ -0,0 +1,54 @@ +package clause + +type OrderByColumn struct { + Column Column + Desc bool + Reorder bool +} + +type OrderBy struct { + Columns []OrderByColumn + Expression Expression +} + +// Name where clause name +func (orderBy OrderBy) Name() string { + return "ORDER BY" +} + +// Build build where clause +func (orderBy OrderBy) Build(builder Builder) { + if orderBy.Expression != nil { + orderBy.Expression.Build(builder) + } else { + for idx, column := range orderBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column.Column) + if column.Desc { + builder.WriteString(" DESC") + } + } + } +} + +// MergeClause merge order by clauses +func (orderBy OrderBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(OrderBy); ok { + for i := len(orderBy.Columns) - 1; i >= 0; i-- { + if orderBy.Columns[i].Reorder { + orderBy.Columns = orderBy.Columns[i:] + clause.Expression = orderBy + return + } + } + + copiedColumns := make([]OrderByColumn, len(v.Columns)) + copy(copiedColumns, v.Columns) + orderBy.Columns = append(copiedColumns, orderBy.Columns...) + } + + clause.Expression = orderBy +} diff --git a/vendor/gorm.io/gorm/clause/returning.go b/vendor/gorm.io/gorm/clause/returning.go new file mode 100644 index 000000000..04bc96dab --- /dev/null +++ b/vendor/gorm.io/gorm/clause/returning.go @@ -0,0 +1,30 @@ +package clause + +type Returning struct { + Columns []Column +} + +// Name where clause name +func (returning Returning) Name() string { + return "RETURNING" +} + +// Build build where clause +func (returning Returning) Build(builder Builder) { + for idx, column := range returning.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } +} + +// MergeClause merge order by clauses +func (returning Returning) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Returning); ok { + returning.Columns = append(v.Columns, returning.Columns...) + } + + clause.Expression = returning +} diff --git a/vendor/gorm.io/gorm/clause/select.go b/vendor/gorm.io/gorm/clause/select.go new file mode 100644 index 000000000..b93b87690 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/select.go @@ -0,0 +1,45 @@ +package clause + +// Select select attrs when querying, updating, creating +type Select struct { + Distinct bool + Columns []Column + Expression Expression +} + +func (s Select) Name() string { + return "SELECT" +} + +func (s Select) Build(builder Builder) { + if len(s.Columns) > 0 { + if s.Distinct { + builder.WriteString("DISTINCT ") + } + + for idx, column := range s.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') + } +} + +func (s Select) MergeClause(clause *Clause) { + if s.Expression != nil { + if s.Distinct { + if expr, ok := s.Expression.(Expr); ok { + expr.SQL = "DISTINCT " + expr.SQL + clause.Expression = expr + return + } + } + + clause.Expression = s.Expression + } else { + clause.Expression = s + } +} diff --git a/vendor/gorm.io/gorm/clause/set.go b/vendor/gorm.io/gorm/clause/set.go new file mode 100644 index 000000000..6a885711a --- /dev/null +++ b/vendor/gorm.io/gorm/clause/set.go @@ -0,0 +1,60 @@ +package clause + +import "sort" + +type Set []Assignment + +type Assignment struct { + Column Column + Value interface{} +} + +func (set Set) Name() string { + return "SET" +} + +func (set Set) Build(builder Builder) { + if len(set) > 0 { + for idx, assignment := range set { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(assignment.Column) + builder.WriteByte('=') + builder.AddVar(builder, assignment.Value) + } + } else { + builder.WriteQuoted(PrimaryColumn) + builder.WriteByte('=') + builder.WriteQuoted(PrimaryColumn) + } +} + +// MergeClause merge assignments clauses +func (set Set) MergeClause(clause *Clause) { + copiedAssignments := make([]Assignment, len(set)) + copy(copiedAssignments, set) + clause.Expression = Set(copiedAssignments) +} + +func Assignments(values map[string]interface{}) Set { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + sort.Strings(keys) + + assignments := make([]Assignment, len(keys)) + for idx, key := range keys { + assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]} + } + return assignments +} + +func AssignmentColumns(values []string) Set { + assignments := make([]Assignment, len(values)) + for idx, value := range values { + assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}} + } + return assignments +} diff --git a/vendor/gorm.io/gorm/clause/update.go b/vendor/gorm.io/gorm/clause/update.go new file mode 100644 index 000000000..f9d68ac67 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/update.go @@ -0,0 +1,38 @@ +package clause + +type Update struct { + Modifier string + Table Table +} + +// Name update clause name +func (update Update) Name() string { + return "UPDATE" +} + +// Build build update clause +func (update Update) Build(builder Builder) { + if update.Modifier != "" { + builder.WriteString(update.Modifier) + builder.WriteByte(' ') + } + + if update.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(update.Table) + } +} + +// MergeClause merge update clause +func (update Update) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Update); ok { + if update.Modifier == "" { + update.Modifier = v.Modifier + } + if update.Table.Name == "" { + update.Table = v.Table + } + } + clause.Expression = update +} diff --git a/vendor/gorm.io/gorm/clause/values.go b/vendor/gorm.io/gorm/clause/values.go new file mode 100644 index 000000000..b2f5421be --- /dev/null +++ b/vendor/gorm.io/gorm/clause/values.go @@ -0,0 +1,45 @@ +package clause + +type Values struct { + Columns []Column + Values [][]interface{} +} + +// Name from clause name +func (Values) Name() string { + return "VALUES" +} + +// Build build from clause +func (values Values) Build(builder Builder) { + if len(values.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range values.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteByte(')') + + builder.WriteString(" VALUES ") + + for idx, value := range values.Values { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteByte('(') + builder.AddVar(builder, value...) + builder.WriteByte(')') + } + } else { + builder.WriteString("DEFAULT VALUES") + } +} + +// MergeClause merge values clauses +func (values Values) MergeClause(clause *Clause) { + clause.Name = "" + clause.Expression = values +} diff --git a/vendor/gorm.io/gorm/clause/where.go b/vendor/gorm.io/gorm/clause/where.go new file mode 100644 index 000000000..00b1a40e9 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/where.go @@ -0,0 +1,177 @@ +package clause + +import ( + "strings" +) + +// Where where clause +type Where struct { + Exprs []Expression +} + +// Name where clause name +func (where Where) Name() string { + return "WHERE" +} + +// Build build where clause +func (where Where) Build(builder Builder) { + // Switch position if the first query expression is a single Or condition + for idx, expr := range where.Exprs { + if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { + if idx != 0 { + where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0] + } + break + } + } + + buildExprs(where.Exprs, builder, " AND ") +} + +func buildExprs(exprs []Expression, builder Builder, joinCond string) { + wrapInParentheses := false + + for idx, expr := range exprs { + if idx > 0 { + if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { + builder.WriteString(" OR ") + } else { + builder.WriteString(joinCond) + } + } + + if len(exprs) > 1 { + switch v := expr.(type) { + case OrConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToLower(e.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + case AndConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToLower(e.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + case Expr: + sql := strings.ToLower(v.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + + if wrapInParentheses { + builder.WriteString(`(`) + expr.Build(builder) + builder.WriteString(`)`) + wrapInParentheses = false + } else { + expr.Build(builder) + } + } +} + +// MergeClause merge where clauses +func (where Where) MergeClause(clause *Clause) { + if w, ok := clause.Expression.(Where); ok { + exprs := make([]Expression, len(w.Exprs)+len(where.Exprs)) + copy(exprs, w.Exprs) + copy(exprs[len(w.Exprs):], where.Exprs) + where.Exprs = exprs + } + + clause.Expression = where +} + +func And(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } else if len(exprs) == 1 { + return exprs[0] + } + return AndConditions{Exprs: exprs} +} + +type AndConditions struct { + Exprs []Expression +} + +func (and AndConditions) Build(builder Builder) { + if len(and.Exprs) > 1 { + builder.WriteByte('(') + buildExprs(and.Exprs, builder, " AND ") + builder.WriteByte(')') + } else { + buildExprs(and.Exprs, builder, " AND ") + } +} + +func Or(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return OrConditions{Exprs: exprs} +} + +type OrConditions struct { + Exprs []Expression +} + +func (or OrConditions) Build(builder Builder) { + if len(or.Exprs) > 1 { + builder.WriteByte('(') + buildExprs(or.Exprs, builder, " OR ") + builder.WriteByte(')') + } else { + buildExprs(or.Exprs, builder, " OR ") + } +} + +func Not(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return NotConditions{Exprs: exprs} +} + +type NotConditions struct { + Exprs []Expression +} + +func (not NotConditions) Build(builder Builder) { + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(" AND ") + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToLower(e.SQL) + if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } +} diff --git a/vendor/gorm.io/gorm/clause/with.go b/vendor/gorm.io/gorm/clause/with.go new file mode 100644 index 000000000..7e9eaef17 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/with.go @@ -0,0 +1,4 @@ +package clause + +type With struct { +} diff --git a/vendor/gorm.io/gorm/errors.go b/vendor/gorm.io/gorm/errors.go new file mode 100644 index 000000000..087550832 --- /dev/null +++ b/vendor/gorm.io/gorm/errors.go @@ -0,0 +1,34 @@ +package gorm + +import ( + "errors" +) + +var ( + // ErrRecordNotFound record not found error + ErrRecordNotFound = errors.New("record not found") + // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` + ErrInvalidTransaction = errors.New("no valid transaction") + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("not implemented") + // ErrMissingWhereClause missing where clause + ErrMissingWhereClause = errors.New("WHERE conditions required") + // ErrUnsupportedRelation unsupported relations + ErrUnsupportedRelation = errors.New("unsupported relations") + // ErrPrimaryKeyRequired primary keys required + ErrPrimaryKeyRequired = errors.New("primary key required") + // ErrModelValueRequired model value required + ErrModelValueRequired = errors.New("model value required") + // ErrInvalidData unsupported data + ErrInvalidData = errors.New("unsupported data") + // ErrUnsupportedDriver unsupported driver + ErrUnsupportedDriver = errors.New("unsupported driver") + // ErrRegistered registered + ErrRegistered = errors.New("registered") + // ErrInvalidField invalid field + ErrInvalidField = errors.New("invalid field") + // ErrEmptySlice empty slice found + ErrEmptySlice = errors.New("empty slice found") + // ErrDryRunModeUnsupported dry run mode unsupported + ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") +) diff --git a/vendor/gorm.io/gorm/finisher_api.go b/vendor/gorm.io/gorm/finisher_api.go new file mode 100644 index 000000000..d36dc754f --- /dev/null +++ b/vendor/gorm.io/gorm/finisher_api.go @@ -0,0 +1,605 @@ +package gorm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// Create insert the value into database +func (db *DB) Create(value interface{}) (tx *DB) { + if db.CreateBatchSize > 0 { + return db.CreateInBatches(value, db.CreateBatchSize) + } + + tx = db.getInstance() + tx.Statement.Dest = value + tx.callbacks.Create().Execute(tx) + return +} + +// CreateInBatches insert the value in batches into database +func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + var rowsAffected int64 + tx = db.getInstance() + tx.AddError(tx.Transaction(func(tx *DB) error { + for i := 0; i < reflectValue.Len(); i += batchSize { + ends := i + batchSize + if ends > reflectValue.Len() { + ends = reflectValue.Len() + } + + subtx := tx.getInstance() + subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface() + subtx.callbacks.Create().Execute(subtx) + if subtx.Error != nil { + return subtx.Error + } + rowsAffected += subtx.RowsAffected + } + return nil + })) + tx.RowsAffected = rowsAffected + default: + tx = db.getInstance() + tx.Statement.Dest = value + tx.callbacks.Create().Execute(tx) + } + return +} + +// Save update value in database, if the value doesn't have primary key, will insert it +func (db *DB) Save(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = value + + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { + tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) + } + tx.callbacks.Create().Execute(tx) + case reflect.Struct: + if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { + for _, pf := range tx.Statement.Schema.PrimaryFields { + if _, isZero := pf.ValueOf(reflectValue); isZero { + tx.callbacks.Create().Execute(tx) + return + } + } + } + + fallthrough + default: + selectedUpdate := len(tx.Statement.Selects) != 0 + // when updating, use all fields including those zero-value fields + if !selectedUpdate { + tx.Statement.Selects = append(tx.Statement.Selects, "*") + } + + tx.callbacks.Update().Execute(tx) + + if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { + result := reflect.New(tx.Statement.Schema.ModelType).Interface() + if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) { + return tx.Create(value) + } + } + } + + return +} + +// First find first record that match given conditions, order by primary key +func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + +// Take return a record that match given conditions, the order will depend on the database implementation +func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1) + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + +// Last find last record that match given conditions, order by primary key +func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Desc: true, + }) + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + +// Find find records that match given conditions +func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.getInstance() + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + } + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + +// FindInBatches find records in batches +func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { + var ( + tx = db.Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }).Session(&Session{}) + queryDB = tx + rowsAffected int64 + batch int + ) + + for { + result := queryDB.Limit(batchSize).Find(dest) + rowsAffected += result.RowsAffected + batch++ + + if result.Error == nil && result.RowsAffected != 0 { + tx.AddError(fc(result, batch)) + } + + if tx.Error != nil || int(result.RowsAffected) < batchSize { + break + } else { + resultsValue := reflect.Indirect(reflect.ValueOf(dest)) + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) + } + } + + tx.RowsAffected = rowsAffected + return tx +} + +func (tx *DB) assignInterfacesToValue(values ...interface{}) { + for _, value := range values { + switch v := value.(type) { + case []clause.Expression: + for _, expr := range v { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := tx.Statement.Schema.LookUpField(column); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Column: + if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + } + } else if andCond, ok := expr.(clause.AndConditions); ok { + tx.assignInterfacesToValue(andCond.Exprs) + } + } + case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: + exprs := tx.Statement.BuildCondition(value) + tx.assignInterfacesToValue(exprs) + default: + if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + for _, f := range s.Fields { + if f.Readable { + if v, isZero := f.ValueOf(reflectValue); !isZero { + if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + } + } + } + } + } + } else if len(values) > 0 { + exprs := tx.Statement.BuildCondition(values[0], values[1:]...) + tx.assignInterfacesToValue(exprs) + return + } + } + } +} + +func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + tx.assignInterfacesToValue(tx.Statement.attrs...) + } + } + + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + tx.assignInterfacesToValue(tx.Statement.assigns...) + } + return +} + +func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + tx.assignInterfacesToValue(tx.Statement.attrs...) + } + + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + tx.assignInterfacesToValue(tx.Statement.assigns...) + } + + return tx.Create(dest) + } else if len(db.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + default: + } + } + } + + return tx.Model(dest).Updates(assigns) + } + + return db +} + +// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +func (db *DB) Update(column string, value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + tx.callbacks.Update().Execute(tx) + return +} + +// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +func (db *DB) Updates(values interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = values + tx.callbacks.Update().Execute(tx) + return +} + +func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + tx.Statement.SkipHooks = true + tx.callbacks.Update().Execute(tx) + return +} + +func (db *DB) UpdateColumns(values interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = values + tx.Statement.SkipHooks = true + tx.callbacks.Update().Execute(tx) + return +} + +// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { + tx = db.getInstance() + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + } + tx.Statement.Dest = value + tx.callbacks.Delete().Execute(tx) + return +} + +func (db *DB) Count(count *int64) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Model == nil { + tx.Statement.Model = tx.Statement.Dest + defer func() { + tx.Statement.Model = nil + }() + } + + if len(tx.Statement.Selects) == 0 { + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) + defer delete(tx.Statement.Clauses, "SELECT") + } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { + expr := clause.Expr{SQL: "count(1)"} + + if len(tx.Statement.Selects) == 1 { + dbName := tx.Statement.Selects[0] + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(dbName); f != nil { + dbName = f.DBName + } + } + + if tx.Statement.Distinct { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} + } else { + expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + } + } + + tx.Statement.AddClause(clause.Select{Expression: expr}) + defer delete(tx.Statement.Clauses, "SELECT") + } + + if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { + if _, ok := db.Statement.Clauses["GROUP BY"]; !ok { + delete(db.Statement.Clauses, "ORDER BY") + defer func() { + db.Statement.Clauses["ORDER BY"] = orderByClause + }() + } + } + + tx.Statement.Dest = count + tx.callbacks.Query().Execute(tx) + if tx.RowsAffected != 1 { + *count = tx.RowsAffected + } + return +} + +func (db *DB) Row() *sql.Row { + tx := db.getInstance().InstanceSet("rows", false) + tx.callbacks.Row().Execute(tx) + row, ok := tx.Statement.Dest.(*sql.Row) + if !ok && tx.DryRun { + db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) + } + return row +} + +func (db *DB) Rows() (*sql.Rows, error) { + tx := db.getInstance().InstanceSet("rows", true) + tx.callbacks.Row().Execute(tx) + rows, ok := tx.Statement.Dest.(*sql.Rows) + if !ok && tx.DryRun && tx.Error == nil { + tx.Error = ErrDryRunModeUnsupported + } + return rows, tx.Error +} + +// Scan scan value to a struct +func (db *DB) Scan(dest interface{}) (tx *DB) { + config := *db.Config + currentLogger, newLogger := config.Logger, logger.Recorder.New() + config.Logger = newLogger + + tx = db.getInstance() + tx.Config = &config + + if rows, err := tx.Rows(); err != nil { + tx.AddError(err) + } else { + defer rows.Close() + if rows.Next() { + tx.ScanRows(rows, dest) + } + } + + currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { + return newLogger.SQL, tx.RowsAffected + }, tx.Error) + tx.Logger = currentLogger + return +} + +// Pluck used to query single column from a model as a map +// var ages []int64 +// db.Find(&users).Pluck("age", &ages) +func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Model != nil { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(column); f != nil { + column = f.DBName + } + } + } else if tx.Statement.Table == "" { + tx.AddError(ErrModelValueRequired) + } + + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + tx.Statement.AddClauseIfNotExists(clause.Select{ + Distinct: tx.Statement.Distinct, + Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, + }) + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + +func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { + tx := db.getInstance() + if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) { + tx.AddError(err) + } + tx.Statement.Dest = dest + tx.Statement.ReflectValue = reflect.ValueOf(dest) + for tx.Statement.ReflectValue.Kind() == reflect.Ptr { + tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() + } + Scan(rows, tx, true) + return tx.Error +} + +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. +func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { + panicked := true + + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + // nested transaction + err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + db.RollbackTo(fmt.Sprintf("sp%p", fc)) + } + }() + + if err == nil { + err = fc(db.Session(&Session{})) + } + } else { + tx := db.Begin(opts...) + + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + tx.Rollback() + } + }() + + if err = tx.Error; err == nil { + err = fc(tx) + } + + if err == nil { + err = tx.Commit().Error + } + } + + panicked = false + return +} + +// Begin begins a transaction +func (db *DB) Begin(opts ...*sql.TxOptions) *DB { + var ( + // clone statement + tx = db.Session(&Session{Context: db.Statement.Context}) + opt *sql.TxOptions + err error + ) + + if len(opts) > 0 { + opt = opts[0] + } + + if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else { + err = ErrInvalidTransaction + } + + if err != nil { + tx.AddError(err) + } + + return tx +} + +// Commit commit a transaction +func (db *DB) Commit() *DB { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Commit()) + } else { + db.AddError(ErrInvalidTransaction) + } + return db +} + +// Rollback rollback a transaction +func (db *DB) Rollback() *DB { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + if !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Rollback()) + } + } else { + db.AddError(ErrInvalidTransaction) + } + return db +} + +func (db *DB) SavePoint(name string) *DB { + if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + db.AddError(savePointer.SavePoint(db, name)) + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + +func (db *DB) RollbackTo(name string) *DB { + if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + db.AddError(savePointer.RollbackTo(db, name)) + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + +// Exec execute raw sql +func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.SQL = strings.Builder{} + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } + + tx.callbacks.Raw().Execute(tx) + return +} diff --git a/vendor/gorm.io/gorm/go.mod b/vendor/gorm.io/gorm/go.mod new file mode 100644 index 000000000..faf63a46b --- /dev/null +++ b/vendor/gorm.io/gorm/go.mod @@ -0,0 +1,8 @@ +module gorm.io/gorm + +go 1.14 + +require ( + github.com/jinzhu/inflection v1.0.0 + github.com/jinzhu/now v1.1.1 +) diff --git a/vendor/gorm.io/gorm/go.sum b/vendor/gorm.io/gorm/go.sum new file mode 100644 index 000000000..148bd6f53 --- /dev/null +++ b/vendor/gorm.io/gorm/go.sum @@ -0,0 +1,4 @@ +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.1 h1:g39TucaRWyV3dwDO++eEc6qf8TVIQ/Da48WmqjZ3i7E= +github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/vendor/gorm.io/gorm/gorm.go b/vendor/gorm.io/gorm/gorm.go new file mode 100644 index 000000000..ae1cf2c9b --- /dev/null +++ b/vendor/gorm.io/gorm/gorm.go @@ -0,0 +1,386 @@ +package gorm + +import ( + "context" + "database/sql" + "errors" + "fmt" + "sync" + "time" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +// Config GORM config +type Config struct { + // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity + // You can disable it by setting `SkipDefaultTransaction` to true + SkipDefaultTransaction bool + // NamingStrategy tables, columns naming strategy + NamingStrategy schema.Namer + // FullSaveAssociations full save associations + FullSaveAssociations bool + // Logger + Logger logger.Interface + // NowFunc the function to be used when creating a new timestamp + NowFunc func() time.Time + // DryRun generate sql without execute + DryRun bool + // PrepareStmt executes the given query in cached statement + PrepareStmt bool + // DisableAutomaticPing + DisableAutomaticPing bool + // DisableForeignKeyConstraintWhenMigrating + DisableForeignKeyConstraintWhenMigrating bool + // AllowGlobalUpdate allow global update + AllowGlobalUpdate bool + // QueryFields executes the SQL query with all fields of the table + QueryFields bool + // CreateBatchSize default create batch size + CreateBatchSize int + + // ClauseBuilders clause builder + ClauseBuilders map[string]clause.ClauseBuilder + // ConnPool db conn pool + ConnPool ConnPool + // Dialector database dialector + Dialector + // Plugins registered plugins + Plugins map[string]Plugin + + callbacks *callbacks + cacheStore *sync.Map +} + +// DB GORM DB definition +type DB struct { + *Config + Error error + RowsAffected int64 + Statement *Statement + clone int +} + +// Session session config when create session with Session() method +type Session struct { + DryRun bool + PrepareStmt bool + NewDB bool + SkipHooks bool + SkipDefaultTransaction bool + AllowGlobalUpdate bool + FullSaveAssociations bool + QueryFields bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time + CreateBatchSize int +} + +// Open initialize db session based on dialector +func Open(dialector Dialector, config *Config) (db *DB, err error) { + if config == nil { + config = &Config{} + } + + if config.NamingStrategy == nil { + config.NamingStrategy = schema.NamingStrategy{} + } + + if config.Logger == nil { + config.Logger = logger.Default + } + + if config.NowFunc == nil { + config.NowFunc = func() time.Time { return time.Now().Local() } + } + + if dialector != nil { + config.Dialector = dialector + } + + if config.Plugins == nil { + config.Plugins = map[string]Plugin{} + } + + if config.cacheStore == nil { + config.cacheStore = &sync.Map{} + } + + db = &DB{Config: config, clone: 1} + + db.callbacks = initializeCallbacks(db) + + if config.ClauseBuilders == nil { + config.ClauseBuilders = map[string]clause.ClauseBuilder{} + } + + if config.Dialector != nil { + err = config.Dialector.Initialize(db) + } + + preparedStmt := &PreparedStmtDB{ + ConnPool: db.ConnPool, + Stmts: map[string]*sql.Stmt{}, + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), + } + db.cacheStore.Store("preparedStmt", preparedStmt) + + if config.PrepareStmt { + db.ConnPool = preparedStmt + } + + db.Statement = &Statement{ + DB: db, + ConnPool: db.ConnPool, + Context: context.Background(), + Clauses: map[string]clause.Clause{}, + } + + if err == nil && !config.DisableAutomaticPing { + if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { + err = pinger.Ping() + } + } + + if err != nil { + config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err) + } + + return +} + +// Session create new db session +func (db *DB) Session(config *Session) *DB { + var ( + txConfig = *db.Config + tx = &DB{ + Config: &txConfig, + Statement: db.Statement, + clone: 1, + } + ) + + if config.CreateBatchSize > 0 { + tx.Config.CreateBatchSize = config.CreateBatchSize + } + + if config.SkipDefaultTransaction { + tx.Config.SkipDefaultTransaction = true + } + + if config.AllowGlobalUpdate { + txConfig.AllowGlobalUpdate = true + } + + if config.FullSaveAssociations { + txConfig.FullSaveAssociations = true + } + + if config.Context != nil || config.PrepareStmt || config.SkipHooks { + tx.Statement = tx.Statement.clone() + tx.Statement.DB = tx + } + + if config.Context != nil { + tx.Statement.Context = config.Context + } + + if config.PrepareStmt { + if v, ok := db.cacheStore.Load("preparedStmt"); ok { + preparedStmt := v.(*PreparedStmtDB) + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true + } + } + + if config.SkipHooks { + tx.Statement.SkipHooks = true + } + + if !config.NewDB { + tx.clone = 2 + } + + if config.DryRun { + tx.Config.DryRun = true + } + + if config.QueryFields { + tx.Config.QueryFields = true + } + + if config.Logger != nil { + tx.Config.Logger = config.Logger + } + + if config.NowFunc != nil { + tx.Config.NowFunc = config.NowFunc + } + + return tx +} + +// WithContext change current instance db's context to ctx +func (db *DB) WithContext(ctx context.Context) *DB { + return db.Session(&Session{Context: ctx}) +} + +// Debug start debug mode +func (db *DB) Debug() (tx *DB) { + return db.Session(&Session{ + Logger: db.Logger.LogMode(logger.Info), + }) +} + +// Set store value with key into current db instance's context +func (db *DB) Set(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(key, value) + return tx +} + +// Get get value with key from current db instance's context +func (db *DB) Get(key string) (interface{}, bool) { + return db.Statement.Settings.Load(key) +} + +// InstanceSet store value with key into current db instance's context +func (db *DB) InstanceSet(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value) + return tx +} + +// InstanceGet get value with key from current db instance's context +func (db *DB) InstanceGet(key string) (interface{}, bool) { + return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) +} + +// Callback returns callback manager +func (db *DB) Callback() *callbacks { + return db.callbacks +} + +// AddError add error to db +func (db *DB) AddError(err error) error { + if db.Error == nil { + db.Error = err + } else if err != nil { + db.Error = fmt.Errorf("%v; %w", db.Error, err) + } + return db.Error +} + +// DB returns `*sql.DB` +func (db *DB) DB() (*sql.DB, error) { + connPool := db.ConnPool + + if stmtDB, ok := connPool.(*PreparedStmtDB); ok { + connPool = stmtDB.ConnPool + } + + if sqldb, ok := connPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, errors.New("invalid db") +} + +func (db *DB) getInstance() *DB { + if db.clone > 0 { + tx := &DB{Config: db.Config} + + if db.clone == 1 { + // clone with new statement + tx.Statement = &Statement{ + DB: tx, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), + } + } else { + // with clone statement + tx.Statement = db.Statement.clone() + tx.Statement.DB = tx + } + + return tx + } + + return db +} + +func Expr(expr string, args ...interface{}) clause.Expr { + return clause.Expr{SQL: expr, Vars: args} +} + +func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { + var ( + tx = db.getInstance() + stmt = tx.Statement + modelSchema, joinSchema *schema.Schema + ) + + if err := stmt.Parse(model); err == nil { + modelSchema = stmt.Schema + } else { + return err + } + + if err := stmt.Parse(joinTable); err == nil { + joinSchema = stmt.Schema + } else { + return err + } + + if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { + for _, ref := range relation.References { + if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { + f.DataType = ref.ForeignKey.DataType + f.GORMDataType = ref.ForeignKey.GORMDataType + if f.Size == 0 { + f.Size = ref.ForeignKey.Size + } + ref.ForeignKey = f + } else { + return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) + } + } + + for name, rel := range relation.JoinTable.Relationships.Relations { + if _, ok := joinSchema.Relationships.Relations[name]; !ok { + rel.Schema = joinSchema + joinSchema.Relationships.Relations[name] = rel + } + } + + relation.JoinTable = joinSchema + } else { + return fmt.Errorf("failed to found relation: %v", field) + } + + return nil +} + +func (db *DB) Use(plugin Plugin) (err error) { + name := plugin.Name() + if _, ok := db.Plugins[name]; !ok { + if err = plugin.Initialize(db); err == nil { + db.Plugins[name] = plugin + } + } else { + return ErrRegistered + } + + return err +} diff --git a/vendor/gorm.io/gorm/interfaces.go b/vendor/gorm.io/gorm/interfaces.go new file mode 100644 index 000000000..e933952bb --- /dev/null +++ b/vendor/gorm.io/gorm/interfaces.go @@ -0,0 +1,59 @@ +package gorm + +import ( + "context" + "database/sql" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +// Dialector GORM database dialector +type Dialector interface { + Name() string + Initialize(*DB) error + Migrator(db *DB) Migrator + DataTypeOf(*schema.Field) string + DefaultValueOf(*schema.Field) clause.Expression + BindVarTo(writer clause.Writer, stmt *Statement, v interface{}) + QuoteTo(clause.Writer, string) + Explain(sql string, vars ...interface{}) string +} + +// Plugin GORM plugin interface +type Plugin interface { + Name() string + Initialize(*DB) error +} + +// ConnPool db conns pool interface +type ConnPool interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +// SavePointerDialectorInterface save pointer interface +type SavePointerDialectorInterface interface { + SavePoint(tx *DB, name string) error + RollbackTo(tx *DB, name string) error +} + +type TxBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +type ConnPoolBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) +} + +type TxCommitter interface { + Commit() error + Rollback() error +} + +// Valuer gorm valuer interface +type Valuer interface { + GormValue(context.Context, *DB) clause.Expr +} diff --git a/vendor/gorm.io/gorm/logger/logger.go b/vendor/gorm.io/gorm/logger/logger.go new file mode 100644 index 000000000..11619c92f --- /dev/null +++ b/vendor/gorm.io/gorm/logger/logger.go @@ -0,0 +1,183 @@ +package logger + +import ( + "context" + "fmt" + "io/ioutil" + "log" + "os" + "time" + + "gorm.io/gorm/utils" +) + +// Colors +const ( + Reset = "\033[0m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Magenta = "\033[35m" + Cyan = "\033[36m" + White = "\033[37m" + BlueBold = "\033[34;1m" + MagentaBold = "\033[35;1m" + RedBold = "\033[31;1m" + YellowBold = "\033[33;1m" +) + +// LogLevel +type LogLevel int + +const ( + Silent LogLevel = iota + 1 + Error + Warn + Info +) + +// Writer log writer interface +type Writer interface { + Printf(string, ...interface{}) +} + +type Config struct { + SlowThreshold time.Duration + Colorful bool + LogLevel LogLevel +} + +// Interface logger interface +type Interface interface { + LogMode(LogLevel) Interface + Info(context.Context, string, ...interface{}) + Warn(context.Context, string, ...interface{}) + Error(context.Context, string, ...interface{}) + Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) +} + +var ( + Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: Warn, + Colorful: true, + }) + Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} +) + +func New(writer Writer, config Config) Interface { + var ( + infoStr = "%s\n[info] " + warnStr = "%s\n[warn] " + errStr = "%s\n[error] " + traceStr = "%s\n[%.3fms] [rows:%v] %s" + traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s" + traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s" + ) + + if config.Colorful { + infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset + warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset + errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + } + + return &logger{ + Writer: writer, + Config: config, + infoStr: infoStr, + warnStr: warnStr, + errStr: errStr, + traceStr: traceStr, + traceWarnStr: traceWarnStr, + traceErrStr: traceErrStr, + } +} + +type logger struct { + Writer + Config + infoStr, warnStr, errStr string + traceStr, traceErrStr, traceWarnStr string +} + +// LogMode log mode +func (l *logger) LogMode(level LogLevel) Interface { + newlogger := *l + newlogger.LogLevel = level + return &newlogger +} + +// Info print info +func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + } +} + +// Warn print warn messages +func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + } +} + +// Error print error messages +func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + } +} + +// Trace print sql message +func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + if l.LogLevel > 0 { + elapsed := time.Since(begin) + switch { + case err != nil && l.LogLevel >= Error: + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: + sql, rows := fc() + slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) + if rows == -1 { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case l.LogLevel >= Info: + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + } + } +} + +type traceRecorder struct { + Interface + BeginAt time.Time + SQL string + RowsAffected int64 + Err error +} + +func (l traceRecorder) New() *traceRecorder { + return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} +} + +func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + l.BeginAt = begin + l.SQL, l.RowsAffected = fc() + l.Err = err +} diff --git a/vendor/gorm.io/gorm/logger/sql.go b/vendor/gorm.io/gorm/logger/sql.go new file mode 100644 index 000000000..d080def21 --- /dev/null +++ b/vendor/gorm.io/gorm/logger/sql.go @@ -0,0 +1,127 @@ +package logger + +import ( + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "strconv" + "strings" + "time" + "unicode" + + "gorm.io/gorm/utils" +) + +func isPrintable(s []byte) bool { + for _, r := range s { + if !unicode.IsPrint(rune(r)) { + return false + } + } + return true +} + +var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} + +func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { + var convertParams func(interface{}, int) + var vars = make([]string, len(avars)) + + convertParams = func(v interface{}, idx int) { + switch v := v.(type) { + case bool: + vars[idx] = strconv.FormatBool(v) + case time.Time: + if v.IsZero() { + vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + } else { + vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper + } + case *time.Time: + if v != nil { + if v.IsZero() { + vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + } else { + vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper + } + } else { + vars[idx] = "NULL" + } + case fmt.Stringer: + reflectValue := reflect.ValueOf(v) + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = "NULL" + } + case driver.Valuer: + reflectValue := reflect.ValueOf(v) + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + r, _ := v.Value() + convertParams(r, idx) + } else { + vars[idx] = "NULL" + } + case []byte: + if isPrintable(v) { + vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = escaper + "" + escaper + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + vars[idx] = utils.ToString(v) + case float64, float32: + vars[idx] = fmt.Sprintf("%.6f", v) + case string: + vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper + default: + rv := reflect.ValueOf(v) + if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { + vars[idx] = "NULL" + } else if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + convertParams(v, idx) + } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { + convertParams(reflect.Indirect(rv).Interface(), idx) + } else { + for _, t := range convertableTypes { + if rv.Type().ConvertibleTo(t) { + convertParams(rv.Convert(t).Interface(), idx) + return + } + } + vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + } + } + } + + for idx, v := range avars { + convertParams(v, idx) + } + + if numericPlaceholder == nil { + var idx int + var newSQL strings.Builder + + for _, v := range []byte(sql) { + if v == '?' { + if len(vars) > idx { + newSQL.WriteString(vars[idx]) + idx++ + continue + } + } + newSQL.WriteByte(v) + } + + sql = newSQL.String() + } else { + sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") + for idx, v := range vars { + sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1) + } + } + + return sql +} diff --git a/vendor/gorm.io/gorm/migrator.go b/vendor/gorm.io/gorm/migrator.go new file mode 100644 index 000000000..28ac35e7d --- /dev/null +++ b/vendor/gorm.io/gorm/migrator.go @@ -0,0 +1,70 @@ +package gorm + +import ( + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +// Migrator returns migrator +func (db *DB) Migrator() Migrator { + return db.Dialector.Migrator(db.Session(&Session{})) +} + +// AutoMigrate run auto migration for given models +func (db *DB) AutoMigrate(dst ...interface{}) error { + return db.Migrator().AutoMigrate(dst...) +} + +// ViewOption view option +type ViewOption struct { + Replace bool + CheckOption string + Query *DB +} + +type ColumnType interface { + Name() string + DatabaseTypeName() string + Length() (length int64, ok bool) + DecimalSize() (precision int64, scale int64, ok bool) + Nullable() (nullable bool, ok bool) +} + +type Migrator interface { + // AutoMigrate + AutoMigrate(dst ...interface{}) error + + // Database + CurrentDatabase() string + FullDataTypeOf(*schema.Field) clause.Expr + + // Tables + CreateTable(dst ...interface{}) error + DropTable(dst ...interface{}) error + HasTable(dst interface{}) bool + RenameTable(oldName, newName interface{}) error + + // Columns + AddColumn(dst interface{}, field string) error + DropColumn(dst interface{}, field string) error + AlterColumn(dst interface{}, field string) error + MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error + HasColumn(dst interface{}, field string) bool + RenameColumn(dst interface{}, oldName, field string) error + ColumnTypes(dst interface{}) ([]ColumnType, error) + + // Views + CreateView(name string, option ViewOption) error + DropView(name string) error + + // Constraints + CreateConstraint(dst interface{}, name string) error + DropConstraint(dst interface{}, name string) error + HasConstraint(dst interface{}, name string) bool + + // Indexes + CreateIndex(dst interface{}, name string) error + DropIndex(dst interface{}, name string) error + HasIndex(dst interface{}, name string) bool + RenameIndex(dst interface{}, oldName, newName string) error +} diff --git a/vendor/gorm.io/gorm/migrator/migrator.go b/vendor/gorm.io/gorm/migrator/migrator.go new file mode 100644 index 000000000..084d430f0 --- /dev/null +++ b/vendor/gorm.io/gorm/migrator/migrator.go @@ -0,0 +1,706 @@ +package migrator + +import ( + "context" + "fmt" + "reflect" + "regexp" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +// Migrator m struct +type Migrator struct { + Config +} + +// Config schema config +type Config struct { + CreateIndexAfterCreateTable bool + DB *gorm.DB + gorm.Dialector +} + +type GormDataTypeInterface interface { + GormDBDataType(*gorm.DB, *schema.Field) string +} + +func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { + stmt := &gorm.Statement{DB: m.DB} + if m.DB.Statement != nil { + stmt.Table = m.DB.Statement.Table + stmt.TableExpr = m.DB.Statement.TableExpr + } + + if table, ok := value.(string); ok { + stmt.Table = table + } else if err := stmt.Parse(value); err != nil { + return err + } + + return fc(stmt) +} + +func (m Migrator) DataTypeOf(field *schema.Field) string { + fieldValue := reflect.New(field.IndirectFieldType) + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { + if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" { + return dataType + } + } + + return m.Dialector.DataTypeOf(field) +} + +func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { + expr.SQL = m.DataTypeOf(field) + + if field.NotNull { + expr.SQL += " NOT NULL" + } + + if field.Unique { + expr.SQL += " UNIQUE" + } + + if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { + if field.DefaultValueInterface != nil { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) + expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) + } else if field.DefaultValue != "(-)" { + expr.SQL += " DEFAULT " + field.DefaultValue + } + } + + return +} + +// AutoMigrate +func (m Migrator) AutoMigrate(values ...interface{}) error { + for _, value := range m.ReorderModels(values, true) { + tx := m.DB.Session(&gorm.Session{NewDB: true}) + if !tx.Migrator().HasTable(value) { + if err := tx.Migrator().CreateTable(value); err != nil { + return err + } + } else { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + + for _, field := range stmt.Schema.FieldsByDBName { + var foundColumn gorm.ColumnType + + for _, columnType := range columnTypes { + if columnType.Name() == field.DBName { + foundColumn = columnType + break + } + } + + if foundColumn == nil { + // not found, add column + if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { + return err + } + } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + // found, smart migrate + return err + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { + if constraint := rel.ParseConstraint(); constraint != nil { + if constraint.Schema == stmt.Schema { + if !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err + } + } + } + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if !tx.Migrator().HasConstraint(value, chk.Name) { + if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err + } + } + } + } + + for _, idx := range stmt.Schema.ParseIndexes() { + if !tx.Migrator().HasIndex(value, idx.Name) { + if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + + return nil + }); err != nil { + return err + } + } + } + + return nil +} + +func (m Migrator) CreateTable(values ...interface{}) error { + for _, value := range m.ReorderModels(values, false) { + tx := m.DB.Session(&gorm.Session{NewDB: true}) + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + var ( + createTableSQL = "CREATE TABLE ? (" + values = []interface{}{m.CurrentTable(stmt)} + hasPrimaryKeyInDataType bool + ) + + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[dbName] + createTableSQL += "? ?" + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) + createTableSQL += "," + } + + if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { + createTableSQL += "PRIMARY KEY ?," + primaryKeys := []interface{}{} + for _, field := range stmt.Schema.PrimaryFields { + primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) + } + + values = append(values, primaryKeys) + } + + for _, idx := range stmt.Schema.ParseIndexes() { + if m.CreateIndexAfterCreateTable { + defer func(value interface{}, name string) { + errr = tx.Migrator().CreateIndex(value, name) + }(value, idx.Name) + } else { + if idx.Class != "" { + createTableSQL += idx.Class + " " + } + createTableSQL += "INDEX ? ?" + + if idx.Option != "" { + createTableSQL += " " + idx.Option + } + + createTableSQL += "," + values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if !m.DB.DisableForeignKeyConstraintWhenMigrating { + if constraint := rel.ParseConstraint(); constraint != nil { + if constraint.Schema == stmt.Schema { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) + } + } + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + createTableSQL += "CONSTRAINT ? CHECK (?)," + values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) + } + + createTableSQL = strings.TrimSuffix(createTableSQL, ",") + + createTableSQL += ")" + + if tableOption, ok := m.DB.Get("gorm:table_options"); ok { + createTableSQL += fmt.Sprint(tableOption) + } + + errr = tx.Exec(createTableSQL, values...).Error + return errr + }); err != nil { + return err + } + } + return nil +} + +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + for i := len(values) - 1; i >= 0; i-- { + tx := m.DB.Session(&gorm.Session{NewDB: true}) + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error + }); err != nil { + return err + } + } + return nil +} + +func (m Migrator) HasTable(value interface{}) bool { + var count int64 + + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) RenameTable(oldName, newName interface{}) error { + var oldTable, newTable interface{} + if v, ok := oldName.(string); ok { + oldTable = clause.Table{Name: v} + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(oldName); err == nil { + oldTable = m.CurrentTable(stmt) + } else { + return err + } + } + + if v, ok := newName.(string); ok { + newTable = clause.Table{Name: v} + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(newName); err == nil { + newTable = m.CurrentTable(stmt) + } else { + return err + } + } + + return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error +} + +func (m Migrator) AddColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return m.DB.Exec( + "ALTER TABLE ? ADD ? ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), + ).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (m Migrator) DropColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + return m.DB.Exec( + "ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name}, + ).Error + }) +} + +func (m Migrator) AlterColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + fileType := clause.Expr{SQL: m.DataTypeOf(field)} + return m.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? TYPE ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, + ).Error + + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", + currentDatabase, stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName + } + + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } + + return m.DB.Exec( + "ALTER TABLE ? RENAME COLUMN ? TO ?", + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + // found, smart migrate + fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) + realDataType := strings.ToLower(columnType.DatabaseTypeName()) + + alterColumn := false + + // check size + if length, _ := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + matches := regexp.MustCompile(`[^\d](\d+)[^\d]?`).FindAllStringSubmatch(realDataType, -1) + matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]?`).FindAllStringSubmatch(fullDataType, -1) + if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if strings.Contains(m.DataTypeOf(field), fmt.Sprint(field.Precision)) { + alterColumn = true + } + } + + // check nullable + if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { + // not primary key & database is nullable + if !field.PrimaryKey && nullable { + alterColumn = true + } + } + + if alterColumn { + return m.DB.Migrator().AlterColumn(value, field.Name) + } + + return nil +} + +func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { + columnTypes = make([]gorm.ColumnType, 0) + err = m.RunWithValue(value, func(stmt *gorm.Statement) error { + rows, err := m.DB.Session(&gorm.Session{NewDB: true}).Table(stmt.Table).Limit(1).Rows() + if err == nil { + defer rows.Close() + rawColumnTypes, err := rows.ColumnTypes() + if err == nil { + for _, c := range rawColumnTypes { + columnTypes = append(columnTypes, c) + } + } + } + return err + }) + return +} + +func (m Migrator) CreateView(name string, option gorm.ViewOption) error { + return gorm.ErrNotImplemented +} + +func (m Migrator) DropView(name string) error { + return gorm.ErrNotImplemented +} + +func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" + if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + var foreignKeys, references []interface{} + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} + +func (m Migrator) CreateConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + checkConstraints := stmt.Schema.ParseCheckConstraints() + if chk, ok := checkConstraints[name]; ok { + return m.DB.Exec( + "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", + m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, + ).Error + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { + sql, values := buildConstraint(constraint) + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{m.CurrentTable(stmt)}, values...)...).Error + } + } + + err := fmt.Errorf("failed to create constraint with name %v", name) + if field := stmt.Schema.LookUpField(name); field != nil { + for _, cc := range checkConstraints { + if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil { + return err + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { + if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil { + return err + } + } + } + } + + return err + }) +} + +func (m Migrator) DropConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( + "ALTER TABLE ? DROP CONSTRAINT ?", + m.CurrentTable(stmt), clause.Column{Name: name}, + ).Error + }) +} + +func (m Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", + currentDatabase, stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } else if opt.Length > 0 { + str += fmt.Sprintf("(%d)", opt.Length) + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +type BuildIndexOptionsInterface interface { + BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ? ON ??" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + + if idx.Option != "" { + createIndexSQL += " " + idx.Option + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + + return fmt.Errorf("failed to create index with name %v", name) + }) +} + +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error + }) +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Raw( + "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", + currentDatabase, stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( + "ALTER TABLE ? RENAME INDEX ? TO ?", + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) + return +} + +// ReorderModels reorder models according to constraint dependencies +func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) { + type Dependency struct { + *gorm.Statement + Depends []*schema.Schema + } + + var ( + modelNames, orderedModelNames []string + orderedModelNamesMap = map[string]bool{} + parsedSchemas = map[*schema.Schema]bool{} + valuesMap = map[string]Dependency{} + insertIntoOrderedList func(name string) + parseDependence func(value interface{}, addToList bool) + ) + + parseDependence = func(value interface{}, addToList bool) { + dep := Dependency{ + Statement: &gorm.Statement{DB: m.DB, Dest: value}, + } + beDependedOn := map[*schema.Schema]bool{} + if err := dep.Parse(value); err != nil { + m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) + } + if _, ok := parsedSchemas[dep.Statement.Schema]; ok { + return + } + parsedSchemas[dep.Statement.Schema] = true + + for _, rel := range dep.Schema.Relationships.Relations { + if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { + dep.Depends = append(dep.Depends, c.ReferenceSchema) + } + + if rel.Type == schema.HasOne || rel.Type == schema.HasMany { + beDependedOn[rel.FieldSchema] = true + } + + if rel.JoinTable != nil { + // append join value + defer func(rel *schema.Relationship, joinValue interface{}) { + if !beDependedOn[rel.FieldSchema] { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } else { + fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() + parseDependence(fieldValue, autoAdd) + } + parseDependence(joinValue, autoAdd) + }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) + } + } + + valuesMap[dep.Schema.Table] = dep + + if addToList { + modelNames = append(modelNames, dep.Schema.Table) + } + } + + insertIntoOrderedList = func(name string) { + if _, ok := orderedModelNamesMap[name]; ok { + return // avoid loop + } + orderedModelNamesMap[name] = true + + dep := valuesMap[name] + for _, d := range dep.Depends { + if _, ok := valuesMap[d.Table]; ok { + insertIntoOrderedList(d.Table) + } else if autoAdd { + parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) + insertIntoOrderedList(d.Table) + } + } + + orderedModelNames = append(orderedModelNames, name) + } + + for _, value := range values { + if v, ok := value.(string); ok { + results = append(results, v) + } else { + parseDependence(value, true) + } + } + + for _, name := range modelNames { + insertIntoOrderedList(name) + } + + for _, name := range orderedModelNames { + results = append(results, valuesMap[name].Statement.Dest) + } + return +} + +func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { + if stmt.TableExpr != nil { + return *stmt.TableExpr + } + return clause.Table{Name: stmt.Table} +} diff --git a/vendor/gorm.io/gorm/model.go b/vendor/gorm.io/gorm/model.go new file mode 100644 index 000000000..3334d17cb --- /dev/null +++ b/vendor/gorm.io/gorm/model.go @@ -0,0 +1,15 @@ +package gorm + +import "time" + +// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt +// It may be embedded into your model or you may build your own model without it +// type User struct { +// gorm.Model +// } +type Model struct { + ID uint `gorm:"primarykey"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt DeletedAt `gorm:"index"` +} diff --git a/vendor/gorm.io/gorm/prepare_stmt.go b/vendor/gorm.io/gorm/prepare_stmt.go new file mode 100644 index 000000000..eddee1f2b --- /dev/null +++ b/vendor/gorm.io/gorm/prepare_stmt.go @@ -0,0 +1,150 @@ +package gorm + +import ( + "context" + "database/sql" + "sync" +) + +type PreparedStmtDB struct { + Stmts map[string]*sql.Stmt + PreparedSQL []string + Mux *sync.RWMutex + ConnPool +} + +func (db *PreparedStmtDB) Close() { + db.Mux.Lock() + for _, query := range db.PreparedSQL { + if stmt, ok := db.Stmts[query]; ok { + delete(db.Stmts, query) + stmt.Close() + } + } + + db.Mux.Unlock() +} + +func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) { + db.Mux.RLock() + if stmt, ok := db.Stmts[query]; ok { + db.Mux.RUnlock() + return stmt, nil + } + db.Mux.RUnlock() + + db.Mux.Lock() + // double check + if stmt, ok := db.Stmts[query]; ok { + db.Mux.Unlock() + return stmt, nil + } + + stmt, err := db.ConnPool.PrepareContext(ctx, query) + if err == nil { + db.Stmts[query] = stmt + db.PreparedSQL = append(db.PreparedSQL, query) + } + db.Mux.Unlock() + + return stmt, err +} + +func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { + if beginner, ok := db.ConnPool.(TxBeginner); ok { + tx, err := beginner.BeginTx(ctx, opt) + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err + } + return nil, ErrInvalidTransaction +} + +func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { + stmt, err := db.prepare(ctx, query) + if err == nil { + result, err = stmt.ExecContext(ctx, args...) + if err != nil { + db.Mux.Lock() + stmt.Close() + delete(db.Stmts, query) + db.Mux.Unlock() + } + } + return result, err +} + +func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { + stmt, err := db.prepare(ctx, query) + if err == nil { + rows, err = stmt.QueryContext(ctx, args...) + if err != nil { + db.Mux.Lock() + stmt.Close() + delete(db.Stmts, query) + db.Mux.Unlock() + } + } + return rows, err +} + +func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := db.prepare(ctx, query) + if err == nil { + return stmt.QueryRowContext(ctx, args...) + } + return &sql.Row{} +} + +type PreparedStmtTX struct { + *sql.Tx + PreparedStmtDB *PreparedStmtDB +} + +func (tx *PreparedStmtTX) Commit() error { + if tx.Tx != nil { + return tx.Tx.Commit() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) Rollback() error { + if tx.Tx != nil { + return tx.Tx.Rollback() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) + if err == nil { + result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.Mux.Lock() + stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.Mux.Unlock() + } + } + return result, err +} + +func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) + if err == nil { + rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.Mux.Lock() + stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.Mux.Unlock() + } + } + return rows, err +} + +func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) + if err == nil { + return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) + } + return &sql.Row{} +} diff --git a/vendor/gorm.io/gorm/scan.go b/vendor/gorm.io/gorm/scan.go new file mode 100644 index 000000000..0416489d9 --- /dev/null +++ b/vendor/gorm.io/gorm/scan.go @@ -0,0 +1,247 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "reflect" + "strings" + "time" + + "gorm.io/gorm/schema" +) + +func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { + if db.Statement.Schema != nil { + for idx, name := range columns { + if field := db.Statement.Schema.LookUpField(name); field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + continue + } + values[idx] = new(interface{}) + } + } else if len(columnTypes) > 0 { + for idx, columnType := range columnTypes { + if columnType.ScanType() != nil { + values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() + } else { + values[idx] = new(interface{}) + } + } + } else { + for idx := range columns { + values[idx] = new(interface{}) + } + } +} + +func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { + for idx, column := range columns { + if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { + mapValue[column] = reflectValue.Interface() + if valuer, ok := mapValue[column].(driver.Valuer); ok { + mapValue[column], _ = valuer.Value() + } else if b, ok := mapValue[column].(sql.RawBytes); ok { + mapValue[column] = string(b) + } + } else { + mapValue[column] = nil + } + } +} + +func Scan(rows *sql.Rows, db *DB, initialized bool) { + columns, _ := rows.Columns() + values := make([]interface{}, len(columns)) + db.RowsAffected = 0 + + switch dest := db.Statement.Dest.(type) { + case map[string]interface{}, *map[string]interface{}: + if initialized || rows.Next() { + columnTypes, _ := rows.ColumnTypes() + prepareValues(values, db, columnTypes, columns) + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + mapValue, ok := dest.(map[string]interface{}) + if !ok { + if v, ok := dest.(*map[string]interface{}); ok { + mapValue = *v + } + } + scanIntoMap(mapValue, values, columns) + } + case *[]map[string]interface{}: + columnTypes, _ := rows.ColumnTypes() + for initialized || rows.Next() { + prepareValues(values, db, columnTypes, columns) + + initialized = false + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + mapValue := map[string]interface{}{} + scanIntoMap(mapValue, values, columns) + *dest = append(*dest, mapValue) + } + case *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, + *float32, *float64, + *bool, *string, *time.Time, + *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, + *sql.NullBool, *sql.NullString, *sql.NullTime: + for initialized || rows.Next() { + initialized = false + db.RowsAffected++ + db.AddError(rows.Scan(dest)) + } + default: + Schema := db.Statement.Schema + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + reflectValueType = db.Statement.ReflectValue.Type().Elem() + isPtr = reflectValueType.Kind() == reflect.Ptr + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + ) + + if isPtr { + reflectValueType = reflectValueType.Elem() + } + + db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) + + if Schema != nil { + if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + + for idx, column := range columns { + if field := Schema.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + } + + // pluck values into slice of data + isPluck := false + if len(fields) == 1 { + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner + reflectValueType.Kind() != reflect.Struct || // is not struct + Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time + isPluck = true + } + } + + for initialized || rows.Next() { + initialized = false + db.RowsAffected++ + + elem := reflect.New(reflectValueType) + if isPluck { + db.AddError(rows.Scan(elem.Interface())) + } else { + for idx, field := range fields { + if field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + } + } + + db.AddError(rows.Scan(values...)) + + for idx, field := range fields { + if len(joinFields) != 0 && joinFields[idx][0] != nil { + value := reflect.ValueOf(values[idx]).Elem() + relValue := joinFields[idx][0].ReflectValueOf(elem) + + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + field.Set(relValue, values[idx]) + } else if field != nil { + field.Set(elem, values[idx]) + } + } + } + + if isPtr { + db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) + } else { + db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) + } + } + case reflect.Struct: + if db.Statement.ReflectValue.Type() != Schema.ModelType { + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + + if initialized || rows.Next() { + for idx, column := range columns { + if field := Schema.LookUpField(column); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + for idx, column := range columns { + if field := Schema.LookUpField(column); field != nil && field.Readable { + field.Set(db.Statement.ReflectValue, values[idx]) + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + value := reflect.ValueOf(values[idx]).Elem() + + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + field.Set(relValue, values[idx]) + } + } + } + } + } + } + } + + if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { + db.AddError(ErrRecordNotFound) + } +} diff --git a/vendor/gorm.io/gorm/schema/check.go b/vendor/gorm.io/gorm/schema/check.go new file mode 100644 index 000000000..7d31ec70e --- /dev/null +++ b/vendor/gorm.io/gorm/schema/check.go @@ -0,0 +1,32 @@ +package schema + +import ( + "regexp" + "strings" +) + +type Check struct { + Name string + Constraint string // length(phone) >= 10 + *Field +} + +// ParseCheckConstraints parse schema check constraints +func (schema *Schema) ParseCheckConstraints() map[string]Check { + var checks = map[string]Check{} + for _, field := range schema.FieldsByDBName { + if chk := field.TagSettings["CHECK"]; chk != "" { + names := strings.Split(chk, ",") + if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) { + checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} + } else { + if names[0] == "" { + chk = strings.Join(names[1:], ",") + } + name := schema.namer.CheckerName(schema.Table, field.DBName) + checks[name] = Check{Name: name, Constraint: chk, Field: field} + } + } + } + return checks +} diff --git a/vendor/gorm.io/gorm/schema/field.go b/vendor/gorm.io/gorm/schema/field.go new file mode 100644 index 000000000..86b4a0610 --- /dev/null +++ b/vendor/gorm.io/gorm/schema/field.go @@ -0,0 +1,819 @@ +package schema + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/jinzhu/now" + "gorm.io/gorm/utils" +) + +type DataType string + +type TimeType int64 + +var TimeReflectType = reflect.TypeOf(time.Time{}) + +const ( + UnixSecond TimeType = 1 + UnixMillisecond TimeType = 2 + UnixNanosecond TimeType = 3 +) + +const ( + Bool DataType = "bool" + Int DataType = "int" + Uint DataType = "uint" + Float DataType = "float" + String DataType = "string" + Time DataType = "time" + Bytes DataType = "bytes" +) + +type Field struct { + Name string + DBName string + BindNames []string + DataType DataType + GORMDataType DataType + PrimaryKey bool + AutoIncrement bool + Creatable bool + Updatable bool + Readable bool + HasDefaultValue bool + AutoCreateTime TimeType + AutoUpdateTime TimeType + DefaultValue string + DefaultValueInterface interface{} + NotNull bool + Unique bool + Comment string + Size int + Precision int + Scale int + FieldType reflect.Type + IndirectFieldType reflect.Type + StructField reflect.StructField + Tag reflect.StructTag + TagSettings map[string]string + Schema *Schema + EmbeddedSchema *Schema + OwnerSchema *Schema + ReflectValueOf func(reflect.Value) reflect.Value + ValueOf func(reflect.Value) (value interface{}, zero bool) + Set func(reflect.Value, interface{}) error +} + +func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { + var err error + + field := &Field{ + Name: fieldStruct.Name, + BindNames: []string{fieldStruct.Name}, + FieldType: fieldStruct.Type, + IndirectFieldType: fieldStruct.Type, + StructField: fieldStruct, + Creatable: true, + Updatable: true, + Readable: true, + Tag: fieldStruct.Tag, + TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), + Schema: schema, + } + + for field.IndirectFieldType.Kind() == reflect.Ptr { + field.IndirectFieldType = field.IndirectFieldType.Elem() + } + + fieldValue := reflect.New(field.IndirectFieldType) + // if field is valuer, used its value or first fields as data type + valuer, isValuer := fieldValue.Interface().(driver.Valuer) + if isValuer { + if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { + if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil { + fieldValue = reflect.ValueOf(v) + } + + var getRealFieldValue func(reflect.Value) + getRealFieldValue = func(v reflect.Value) { + rv := reflect.Indirect(v) + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { + for i := 0; i < rv.Type().NumField(); i++ { + newFieldType := rv.Type().Field(i).Type + for newFieldType.Kind() == reflect.Ptr { + newFieldType = newFieldType.Elem() + } + + fieldValue = reflect.New(newFieldType) + + if rv.Type() != reflect.Indirect(fieldValue).Type() { + getRealFieldValue(fieldValue) + } + + if fieldValue.IsValid() { + return + } + + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } + } + } + } + } + + getRealFieldValue(fieldValue) + } + } + + if dbName, ok := field.TagSettings["COLUMN"]; ok { + field.DBName = dbName + } + + if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { + field.PrimaryKey = true + } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { + field.PrimaryKey = true + } + + if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { + field.AutoIncrement = true + field.HasDefaultValue = true + } + + if v, ok := field.TagSettings["DEFAULT"]; ok { + field.HasDefaultValue = true + field.DefaultValue = v + } + + if num, ok := field.TagSettings["SIZE"]; ok { + if field.Size, err = strconv.Atoi(num); err != nil { + field.Size = -1 + } + } + + if p, ok := field.TagSettings["PRECISION"]; ok { + field.Precision, _ = strconv.Atoi(p) + } + + if s, ok := field.TagSettings["SCALE"]; ok { + field.Scale, _ = strconv.Atoi(s) + } + + if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { + field.NotNull = true + } else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) { + field.NotNull = true + } + + if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { + field.Unique = true + } + + if val, ok := field.TagSettings["COMMENT"]; ok { + field.Comment = val + } + + // default value is function or null or blank (primary keys) + skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" + switch reflect.Indirect(fieldValue).Kind() { + case reflect.Bool: + field.DataType = Bool + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) + } + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.DataType = Int + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) + } + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.DataType = Uint + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) + } + } + case reflect.Float32, reflect.Float64: + field.DataType = Float + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) + } + } + case reflect.String: + field.DataType = String + + if field.HasDefaultValue && !skipParseDefaultValue { + field.DefaultValue = strings.Trim(field.DefaultValue, "'") + field.DefaultValue = strings.Trim(field.DefaultValue, "\"") + field.DefaultValueInterface = field.DefaultValue + } + case reflect.Struct: + if _, ok := fieldValue.Interface().(*time.Time); ok { + field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { + field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { + field.DataType = Time + } + case reflect.Array, reflect.Slice: + if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { + field.DataType = Bytes + } + } + + field.GORMDataType = field.DataType + + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { + field.DataType = DataType(dataTyper.GormDataType()) + } + + if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if strings.ToUpper(v) == "NANO" { + field.AutoCreateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoCreateTime = UnixMillisecond + } else { + field.AutoCreateTime = UnixSecond + } + } + + if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if strings.ToUpper(v) == "NANO" { + field.AutoUpdateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoUpdateTime = UnixMillisecond + } else { + field.AutoUpdateTime = UnixSecond + } + } + + if val, ok := field.TagSettings["TYPE"]; ok { + switch DataType(strings.ToLower(val)) { + case Bool, Int, Uint, Float, String, Time, Bytes: + field.DataType = DataType(strings.ToLower(val)) + default: + field.DataType = DataType(val) + } + } + + if field.GORMDataType == "" { + field.GORMDataType = field.DataType + } + + if field.Size == 0 { + switch reflect.Indirect(fieldValue).Kind() { + case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: + field.Size = 64 + case reflect.Int8, reflect.Uint8: + field.Size = 8 + case reflect.Int16, reflect.Uint16: + field.Size = 16 + case reflect.Int32, reflect.Uint32, reflect.Float32: + field.Size = 32 + } + } + + // setup permission + if _, ok := field.TagSettings["-"]; ok { + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + } + + if v, ok := field.TagSettings["->"]; ok { + field.Creatable = false + field.Updatable = false + if strings.ToLower(v) == "false" { + field.Readable = false + } else { + field.Readable = true + } + } + + if v, ok := field.TagSettings["<-"]; ok { + field.Creatable = true + field.Updatable = true + + if v != "<-" { + if !strings.Contains(v, "create") { + field.Creatable = false + } + + if !strings.Contains(v, "update") { + field.Updatable = false + } + } + } + + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { + if reflect.Indirect(fieldValue).Kind() == reflect.Struct { + var err error + field.Creatable = false + field.Updatable = false + field.Readable = false + + cacheStore := &sync.Map{} + cacheStore.Store(embeddedCacheKey, true) + if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { + schema.err = err + } + + for _, ef := range field.EmbeddedSchema.Fields { + ef.Schema = schema + ef.OwnerSchema = field.EmbeddedSchema + ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + // index is negative means is pointer + if field.FieldType.Kind() == reflect.Struct { + ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) + } else { + ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) + } + + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" { + ef.DBName = prefix + ef.DBName + } + + if ef.PrimaryKey { + if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else { + ef.PrimaryKey = false + + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } + + if ef.DefaultValue == "" { + ef.HasDefaultValue = false + } + } + } + + for k, v := range field.TagSettings { + ef.TagSettings[k] = v + } + } + } else { + schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) + } + } + + return field +} + +// create valuer, setter when parse struct +func (field *Field) setupValuerAndSetter() { + // ValueOf + switch { + case len(field.StructField.Index) == 1: + field.ValueOf = func(value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) + return fieldValue.Interface(), fieldValue.IsZero() + } + case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: + field.ValueOf = func(value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + return fieldValue.Interface(), fieldValue.IsZero() + } + default: + field.ValueOf = func(value reflect.Value) (interface{}, bool) { + v := reflect.Indirect(value) + + for _, idx := range field.StructField.Index { + if idx >= 0 { + v = v.Field(idx) + } else { + v = v.Field(-idx - 1) + + if v.Type().Elem().Kind() == reflect.Struct { + if !v.IsNil() { + v = v.Elem() + } else { + return nil, true + } + } else { + return nil, true + } + } + } + return v.Interface(), v.IsZero() + } + } + + // ReflectValueOf + switch { + case len(field.StructField.Index) == 1: + if field.FieldType.Kind() == reflect.Ptr { + field.ReflectValueOf = func(value reflect.Value) reflect.Value { + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) + return fieldValue + } + } else { + field.ReflectValueOf = func(value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(field.StructField.Index[0]) + } + } + case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: + field.ReflectValueOf = func(value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + } + default: + field.ReflectValueOf = func(value reflect.Value) reflect.Value { + v := reflect.Indirect(value) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + } + + if v.Kind() == reflect.Ptr { + if v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + } + + if idx < len(field.StructField.Index)-1 { + v = v.Elem() + } + } + } + return v + } + } + + fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + if v == nil { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + reflectV := reflect.ValueOf(v) + + if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + return + } else if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + return + } else if field.FieldType.Kind() == reflect.Ptr { + fieldValue := field.ReflectValueOf(value) + + if reflectV.Type().AssignableTo(field.FieldType.Elem()) { + if !fieldValue.IsValid() { + fieldValue = reflect.New(field.FieldType.Elem()) + } else if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflectV) + return + } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + + fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) + return + } + } + + if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + err = setter(value, reflectV.Elem().Interface()) + } + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = setter(value, v) + } + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + } + + return + } + + // Set + switch field.FieldType.Kind() { + case reflect.Bool: + field.Set = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case bool: + field.ReflectValueOf(value).SetBool(data) + case *bool: + if data != nil { + field.ReflectValueOf(value).SetBool(*data) + } else { + field.ReflectValueOf(value).SetBool(false) + } + case int64: + if data > 0 { + field.ReflectValueOf(value).SetBool(true) + } else { + field.ReflectValueOf(value).SetBool(false) + } + case string: + b, _ := strconv.ParseBool(data) + field.ReflectValueOf(value).SetBool(b) + default: + return fallbackSetter(value, v, field.Set) + } + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.Set = func(value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case int64: + field.ReflectValueOf(value).SetInt(data) + case int: + field.ReflectValueOf(value).SetInt(int64(data)) + case int8: + field.ReflectValueOf(value).SetInt(int64(data)) + case int16: + field.ReflectValueOf(value).SetInt(int64(data)) + case int32: + field.ReflectValueOf(value).SetInt(int64(data)) + case uint: + field.ReflectValueOf(value).SetInt(int64(data)) + case uint8: + field.ReflectValueOf(value).SetInt(int64(data)) + case uint16: + field.ReflectValueOf(value).SetInt(int64(data)) + case uint32: + field.ReflectValueOf(value).SetInt(int64(data)) + case uint64: + field.ReflectValueOf(value).SetInt(int64(data)) + case float32: + field.ReflectValueOf(value).SetInt(int64(data)) + case float64: + field.ReflectValueOf(value).SetInt(int64(data)) + case []byte: + return field.Set(value, string(data)) + case string: + if i, err := strconv.ParseInt(data, 0, 64); err == nil { + field.ReflectValueOf(value).SetInt(i) + } else { + return err + } + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + } else { + field.ReflectValueOf(value).SetInt(data.Unix()) + } + case *time.Time: + if data != nil { + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + } else { + field.ReflectValueOf(value).SetInt(data.Unix()) + } + } else { + field.ReflectValueOf(value).SetInt(0) + } + default: + return fallbackSetter(value, v, field.Set) + } + return err + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.Set = func(value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case uint64: + field.ReflectValueOf(value).SetUint(data) + case uint: + field.ReflectValueOf(value).SetUint(uint64(data)) + case uint8: + field.ReflectValueOf(value).SetUint(uint64(data)) + case uint16: + field.ReflectValueOf(value).SetUint(uint64(data)) + case uint32: + field.ReflectValueOf(value).SetUint(uint64(data)) + case int64: + field.ReflectValueOf(value).SetUint(uint64(data)) + case int: + field.ReflectValueOf(value).SetUint(uint64(data)) + case int8: + field.ReflectValueOf(value).SetUint(uint64(data)) + case int16: + field.ReflectValueOf(value).SetUint(uint64(data)) + case int32: + field.ReflectValueOf(value).SetUint(uint64(data)) + case float32: + field.ReflectValueOf(value).SetUint(uint64(data)) + case float64: + field.ReflectValueOf(value).SetUint(uint64(data)) + case []byte: + return field.Set(value, string(data)) + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + } else { + field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + } + case string: + if i, err := strconv.ParseUint(data, 0, 64); err == nil { + field.ReflectValueOf(value).SetUint(i) + } else { + return err + } + default: + return fallbackSetter(value, v, field.Set) + } + return err + } + case reflect.Float32, reflect.Float64: + field.Set = func(value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case float64: + field.ReflectValueOf(value).SetFloat(data) + case float32: + field.ReflectValueOf(value).SetFloat(float64(data)) + case int64: + field.ReflectValueOf(value).SetFloat(float64(data)) + case int: + field.ReflectValueOf(value).SetFloat(float64(data)) + case int8: + field.ReflectValueOf(value).SetFloat(float64(data)) + case int16: + field.ReflectValueOf(value).SetFloat(float64(data)) + case int32: + field.ReflectValueOf(value).SetFloat(float64(data)) + case uint: + field.ReflectValueOf(value).SetFloat(float64(data)) + case uint8: + field.ReflectValueOf(value).SetFloat(float64(data)) + case uint16: + field.ReflectValueOf(value).SetFloat(float64(data)) + case uint32: + field.ReflectValueOf(value).SetFloat(float64(data)) + case uint64: + field.ReflectValueOf(value).SetFloat(float64(data)) + case []byte: + return field.Set(value, string(data)) + case string: + if i, err := strconv.ParseFloat(data, 64); err == nil { + field.ReflectValueOf(value).SetFloat(i) + } else { + return err + } + default: + return fallbackSetter(value, v, field.Set) + } + return err + } + case reflect.String: + field.Set = func(value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case string: + field.ReflectValueOf(value).SetString(data) + case []byte: + field.ReflectValueOf(value).SetString(string(data)) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + field.ReflectValueOf(value).SetString(utils.ToString(data)) + case float64, float32: + field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + default: + return fallbackSetter(value, v, field.Set) + } + return err + } + default: + fieldValue := reflect.New(field.FieldType) + switch fieldValue.Elem().Interface().(type) { + case time.Time: + field.Set = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case time.Time: + field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + case *time.Time: + if data != nil { + field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem()) + } else { + field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{})) + } + case string: + if t, err := now.Parse(data); err == nil { + field.ReflectValueOf(value).Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fallbackSetter(value, v, field.Set) + } + return nil + } + case *time.Time: + field.Set = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case time.Time: + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(v)) + case *time.Time: + field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + case string: + if t, err := now.Parse(data); err == nil { + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + if v == "" { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fallbackSetter(value, v, field.Set) + } + return nil + } + default: + if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + // pointer scanner + field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } + } else { + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = fieldValue.Interface().(sql.Scanner).Scan(v) + } + return + } + } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner + field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } + } else { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } + } else { + field.Set = func(value reflect.Value, v interface{}) (err error) { + return fallbackSetter(value, v, field.Set) + } + } + } + } +} diff --git a/vendor/gorm.io/gorm/schema/index.go b/vendor/gorm.io/gorm/schema/index.go new file mode 100644 index 000000000..b54e08ad2 --- /dev/null +++ b/vendor/gorm.io/gorm/schema/index.go @@ -0,0 +1,141 @@ +package schema + +import ( + "sort" + "strconv" + "strings" +) + +type Index struct { + Name string + Class string // UNIQUE | FULLTEXT | SPATIAL + Type string // btree, hash, gist, spgist, gin, and brin + Where string + Comment string + Option string // WITH PARSER parser_name + Fields []IndexOption +} + +type IndexOption struct { + *Field + Expression string + Sort string // DESC, ASC + Collate string + Length int + priority int +} + +// ParseIndexes parse schema indexes +func (schema *Schema) ParseIndexes() map[string]Index { + var indexes = map[string]Index{} + + for _, field := range schema.Fields { + if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { + for _, index := range parseFieldIndexes(field) { + idx := indexes[index.Name] + idx.Name = index.Name + if idx.Class == "" { + idx.Class = index.Class + } + if idx.Type == "" { + idx.Type = index.Type + } + if idx.Where == "" { + idx.Where = index.Where + } + if idx.Comment == "" { + idx.Comment = index.Comment + } + if idx.Option == "" { + idx.Option = index.Option + } + + idx.Fields = append(idx.Fields, index.Fields...) + sort.Slice(idx.Fields, func(i, j int) bool { + return idx.Fields[i].priority < idx.Fields[j].priority + }) + + indexes[index.Name] = idx + } + } + } + + return indexes +} + +func (schema *Schema) LookIndex(name string) *Index { + if schema != nil { + indexes := schema.ParseIndexes() + for _, index := range indexes { + if index.Name == name { + return &index + } + + for _, field := range index.Fields { + if field.Name == name { + return &index + } + } + } + } + + return nil +} + +func parseFieldIndexes(field *Field) (indexes []Index) { + for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { + if value != "" { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if k == "INDEX" || k == "UNIQUEINDEX" { + var ( + name string + tag = strings.Join(v[1:], ":") + idx = strings.Index(tag, ",") + settings = ParseTagSetting(tag, ",") + length, _ = strconv.Atoi(settings["LENGTH"]) + ) + + if idx == -1 { + idx = len(tag) + } + + if idx != -1 { + name = tag[0:idx] + } + + if name == "" { + name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) + } + + if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { + settings["CLASS"] = "UNIQUE" + } + + priority, err := strconv.Atoi(settings["PRIORITY"]) + if err != nil { + priority = 10 + } + + indexes = append(indexes, Index{ + Name: name, + Class: settings["CLASS"], + Type: settings["TYPE"], + Where: settings["WHERE"], + Comment: settings["COMMENT"], + Option: settings["OPTION"], + Fields: []IndexOption{{ + Field: field, + Expression: settings["EXPRESSION"], + Sort: settings["SORT"], + Collate: settings["COLLATE"], + Length: length, + priority: priority, + }}, + }) + } + } + } + + return +} diff --git a/vendor/gorm.io/gorm/schema/interfaces.go b/vendor/gorm.io/gorm/schema/interfaces.go new file mode 100644 index 000000000..98abffbd4 --- /dev/null +++ b/vendor/gorm.io/gorm/schema/interfaces.go @@ -0,0 +1,25 @@ +package schema + +import ( + "gorm.io/gorm/clause" +) + +type GormDataTypeInterface interface { + GormDataType() string +} + +type CreateClausesInterface interface { + CreateClauses(*Field) []clause.Interface +} + +type QueryClausesInterface interface { + QueryClauses(*Field) []clause.Interface +} + +type UpdateClausesInterface interface { + UpdateClauses(*Field) []clause.Interface +} + +type DeleteClausesInterface interface { + DeleteClauses(*Field) []clause.Interface +} diff --git a/vendor/gorm.io/gorm/schema/naming.go b/vendor/gorm.io/gorm/schema/naming.go new file mode 100644 index 000000000..63296967c --- /dev/null +++ b/vendor/gorm.io/gorm/schema/naming.go @@ -0,0 +1,145 @@ +package schema + +import ( + "crypto/sha1" + "fmt" + "strings" + "sync" + "unicode/utf8" + + "github.com/jinzhu/inflection" +) + +// Namer namer interface +type Namer interface { + TableName(table string) string + ColumnName(table, column string) string + JoinTableName(joinTable string) string + RelationshipFKName(Relationship) string + CheckerName(table, column string) string + IndexName(table, column string) string +} + +// NamingStrategy tables, columns naming strategy +type NamingStrategy struct { + TablePrefix string + SingularTable bool + NameReplacer *strings.Replacer +} + +// TableName convert string to table name +func (ns NamingStrategy) TableName(str string) string { + if ns.SingularTable { + return ns.TablePrefix + ns.toDBName(str) + } + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) +} + +// ColumnName convert string to column name +func (ns NamingStrategy) ColumnName(table, column string) string { + return ns.toDBName(column) +} + +// JoinTableName convert string to join table name +func (ns NamingStrategy) JoinTableName(str string) string { + if strings.ToLower(str) == str { + return ns.TablePrefix + str + } + + if ns.SingularTable { + return ns.TablePrefix + ns.toDBName(str) + } + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) +} + +// RelationshipFKName generate fk name for relation +func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { + return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, ns.toDBName(rel.Name)), ".", "_", -1) +} + +// CheckerName generate checker name +func (ns NamingStrategy) CheckerName(table, column string) string { + return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1) +} + +// IndexName generate index name +func (ns NamingStrategy) IndexName(table, column string) string { + idxName := fmt.Sprintf("idx_%v_%v", table, ns.toDBName(column)) + idxName = strings.Replace(idxName, ".", "_", -1) + + if utf8.RuneCountInString(idxName) > 64 { + h := sha1.New() + h.Write([]byte(idxName)) + bs := h.Sum(nil) + + idxName = fmt.Sprintf("idx%v%v", table, column)[0:56] + string(bs)[:8] + } + return idxName +} + +var ( + smap sync.Map + // https://github.com/golang/lint/blob/master/lint.go#L770 + commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} + commonInitialismsReplacer *strings.Replacer +) + +func init() { + var commonInitialismsForReplacer []string + for _, initialism := range commonInitialisms { + commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + } + commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) +} + +func (ns NamingStrategy) toDBName(name string) string { + if name == "" { + return "" + } else if v, ok := smap.Load(name); ok { + return v.(string) + } + + if ns.NameReplacer != nil { + name = ns.NameReplacer.Replace(name) + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf strings.Builder + lastCase, nextCase, nextNumber bool // upper case == true + curCase = value[0] <= 'Z' && value[0] >= 'A' + ) + + for i, v := range value[:len(value)-1] { + nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A' + nextNumber = value[i+1] >= '0' && value[i+1] <= '9' + + if curCase { + if lastCase && (nextCase || nextNumber) { + buf.WriteRune(v + 32) + } else { + if i > 0 && value[i-1] != '_' && value[i+1] != '_' { + buf.WriteByte('_') + } + buf.WriteRune(v + 32) + } + } else { + buf.WriteRune(v) + } + + lastCase = curCase + curCase = nextCase + } + + if curCase { + if !lastCase && len(value) > 1 { + buf.WriteByte('_') + } + buf.WriteByte(value[len(value)-1] + 32) + } else { + buf.WriteByte(value[len(value)-1]) + } + ret := buf.String() + smap.Store(name, ret) + return ret +} diff --git a/vendor/gorm.io/gorm/schema/relationship.go b/vendor/gorm.io/gorm/schema/relationship.go new file mode 100644 index 000000000..9cfc10bed --- /dev/null +++ b/vendor/gorm.io/gorm/schema/relationship.go @@ -0,0 +1,566 @@ +package schema + +import ( + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/jinzhu/inflection" + "gorm.io/gorm/clause" +) + +// RelationshipType relationship type +type RelationshipType string + +const ( + HasOne RelationshipType = "has_one" // HasOneRel has one relationship + HasMany RelationshipType = "has_many" // HasManyRel has many relationship + BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship + Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship +) + +type Relationships struct { + HasOne []*Relationship + BelongsTo []*Relationship + HasMany []*Relationship + Many2Many []*Relationship + Relations map[string]*Relationship +} + +type Relationship struct { + Name string + Type RelationshipType + Field *Field + Polymorphic *Polymorphic + References []*Reference + Schema *Schema + FieldSchema *Schema + JoinTable *Schema + foreignKeys, primaryKeys []string +} + +type Polymorphic struct { + PolymorphicID *Field + PolymorphicType *Field + Value string +} + +type Reference struct { + PrimaryKey *Field + PrimaryValue string + ForeignKey *Field + OwnPrimaryKey bool +} + +func (schema *Schema) parseRelation(field *Field) { + var ( + err error + fieldValue = reflect.New(field.IndirectFieldType).Interface() + relation = &Relationship{ + Name: field.Name, + Field: field, + Schema: schema, + foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), + primaryKeys: toColumns(field.TagSettings["REFERENCES"]), + } + ) + + cacheStore := schema.cacheStore + if field.OwnerSchema != nil { + cacheStore = field.OwnerSchema.cacheStore + } + + if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + schema.err = err + return + } + + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + schema.buildPolymorphicRelation(relation, field, polymorphic) + } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + schema.buildMany2ManyRelation(relation, field, many2many) + } else { + switch field.IndirectFieldType.Kind() { + case reflect.Struct: + schema.guessRelation(relation, field, guessBelongs) + case reflect.Slice: + schema.guessRelation(relation, field, guessHas) + default: + schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) + } + } + + if relation.Type == "has" { + if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil { + relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation + } + + switch field.IndirectFieldType.Kind() { + case reflect.Struct: + relation.Type = HasOne + case reflect.Slice: + relation.Type = HasMany + } + } + + if schema.err == nil { + schema.Relationships.Relations[relation.Name] = relation + switch relation.Type { + case HasOne: + schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) + case HasMany: + schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation) + case BelongsTo: + schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation) + case Many2Many: + schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation) + } + } +} + +// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` +// type User struct { +// Toys []Toy `gorm:"polymorphic:Owner;"` +// } +// type Pet struct { +// Toy Toy `gorm:"polymorphic:Owner;"` +// } +// type Toy struct { +// OwnerID int +// OwnerType string +// } +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { + relation.Polymorphic = &Polymorphic{ + Value: schema.Table, + PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], + PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + } + + if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { + relation.Polymorphic.Value = strings.TrimSpace(value) + } + + if relation.Polymorphic.PolymorphicType == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + } + + if relation.Polymorphic.PolymorphicID == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + } + + if schema.err == nil { + relation.References = append(relation.References, &Reference{ + PrimaryValue: relation.Polymorphic.Value, + ForeignKey: relation.Polymorphic.PolymorphicType, + }) + + primaryKeyField := schema.PrioritizedPrimaryField + if len(relation.foreignKeys) > 0 { + if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) + } + } + + // use same data type for foreign keys + relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType + relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType + if relation.Polymorphic.PolymorphicID.Size == 0 { + relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: primaryKeyField, + ForeignKey: relation.Polymorphic.PolymorphicID, + OwnPrimaryKey: true, + }) + } + + relation.Type = "has" +} + +func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) { + relation.Type = Many2Many + + var ( + err error + joinTableFields []reflect.StructField + fieldsMap = map[string]*Field{} + ownFieldsMap = map[string]bool{} // fix self join many2many + joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) + joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) + ) + + ownForeignFields := schema.PrimaryFields + refForeignFields := relation.FieldSchema.PrimaryFields + + if len(relation.foreignKeys) > 0 { + ownForeignFields = []*Field{} + for _, foreignKey := range relation.foreignKeys { + if field := schema.LookUpField(foreignKey); field != nil { + ownForeignFields = append(ownForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + return + } + } + } + + if len(relation.primaryKeys) > 0 { + refForeignFields = []*Field{} + for _, foreignKey := range relation.primaryKeys { + if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { + refForeignFields = append(refForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + return + } + } + } + + for idx, ownField := range ownForeignFields { + joinFieldName := schema.Name + ownField.Name + if len(joinForeignKeys) > idx { + joinFieldName = strings.Title(joinForeignKeys[idx]) + } + + ownFieldsMap[joinFieldName] = true + fieldsMap[joinFieldName] = ownField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: ownField.StructField.PkgPath, + Type: ownField.StructField.Type, + Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } + + for idx, relField := range refForeignFields { + joinFieldName := relation.FieldSchema.Name + relField.Name + if len(joinReferences) > idx { + joinFieldName = strings.Title(joinReferences[idx]) + } + + if _, ok := ownFieldsMap[joinFieldName]; ok { + if field.Name != relation.FieldSchema.Name { + joinFieldName = inflection.Singular(field.Name) + relField.Name + } else { + joinFieldName += "Reference" + } + } + + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } + + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: schema.Name + field.Name, + Type: schema.ModelType, + Tag: `gorm:"-"`, + }) + + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + schema.err = err + } + relation.JoinTable.Name = many2many + relation.JoinTable.Table = schema.namer.JoinTableName(many2many) + relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields)) + + relName := relation.Schema.Name + relRefName := relation.FieldSchema.Name + if relName == relRefName { + relRefName = relation.Field.Name + } + + if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok { + relation.JoinTable.Relationships.Relations[relName] = &Relationship{ + Name: relName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.Schema, + } + } else { + relation.JoinTable.Relationships.Relations[relName].References = []*Reference{} + } + + if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok { + relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{ + Name: relRefName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.FieldSchema, + } + } else { + relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{} + } + + // build references + for _, f := range relation.JoinTable.Fields { + if f.Creatable || f.Readable || f.Updatable { + // use same data type for foreign keys + f.DataType = fieldsMap[f.Name].DataType + f.GORMDataType = fieldsMap[f.Name].GORMDataType + if f.Size == 0 { + f.Size = fieldsMap[f.Name].Size + } + relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) + ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] + + if ownPriamryField { + joinRel := relation.JoinTable.Relationships.Relations[relName] + joinRel.Field = relation.Field + joinRel.References = append(joinRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } else { + joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] + if joinRefRel.Field == nil { + joinRefRel.Field = relation.Field + } + joinRefRel.References = append(joinRefRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + OwnPrimaryKey: ownPriamryField, + }) + } + } +} + +type guessLevel int + +const ( + guessBelongs guessLevel = iota + guessEmbeddedBelongs + guessHas + guessEmbeddedHas +) + +func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { + var ( + primaryFields, foreignFields []*Field + primarySchema, foreignSchema = schema, relation.FieldSchema + ) + + reguessOrErr := func() { + switch gl { + case guessBelongs: + schema.guessRelation(relation, field, guessEmbeddedBelongs) + case guessEmbeddedBelongs: + schema.guessRelation(relation, field, guessHas) + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + // case guessEmbeddedHas: + default: + schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) + } + } + + switch gl { + case guessBelongs: + primarySchema, foreignSchema = relation.FieldSchema, schema + case guessEmbeddedBelongs: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema + } else { + reguessOrErr() + return + } + case guessHas: + case guessEmbeddedHas: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema + } else { + reguessOrErr() + return + } + } + + if len(relation.foreignKeys) > 0 { + for _, foreignKey := range relation.foreignKeys { + if f := foreignSchema.LookUpField(foreignKey); f != nil { + foreignFields = append(foreignFields, f) + } else { + reguessOrErr() + return + } + } + } else { + for _, primaryField := range primarySchema.PrimaryFields { + lookUpName := primarySchema.Name + primaryField.Name + if gl == guessBelongs { + lookUpName = field.Name + primaryField.Name + } + + if f := foreignSchema.LookUpField(lookUpName); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + } + } + } + + if len(foreignFields) == 0 { + reguessOrErr() + return + } else if len(relation.primaryKeys) > 0 { + for idx, primaryKey := range relation.primaryKeys { + if f := primarySchema.LookUpField(primaryKey); f != nil { + if len(primaryFields) < idx+1 { + primaryFields = append(primaryFields, f) + } else if f != primaryFields[idx] { + reguessOrErr() + return + } + } else { + reguessOrErr() + return + } + } + } else if len(primaryFields) == 0 { + if len(foreignFields) == 1 { + primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) + } else if len(primarySchema.PrimaryFields) == len(foreignFields) { + primaryFields = append(primaryFields, primarySchema.PrimaryFields...) + } else { + reguessOrErr() + return + } + } + + // build references + for idx, foreignField := range foreignFields { + // use same data type for foreign keys + foreignField.DataType = primaryFields[idx].DataType + foreignField.GORMDataType = primaryFields[idx].GORMDataType + if foreignField.Size == 0 { + foreignField.Size = primaryFields[idx].Size + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: primaryFields[idx], + ForeignKey: foreignField, + OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas), + }) + } + + if gl == guessHas || gl == guessEmbeddedHas { + relation.Type = "has" + } else { + relation.Type = BelongsTo + } +} + +type Constraint struct { + Name string + Field *Field + Schema *Schema + ForeignKeys []*Field + ReferenceSchema *Schema + References []*Field + OnDelete string + OnUpdate string +} + +func (rel *Relationship) ParseConstraint() *Constraint { + str := rel.Field.TagSettings["CONSTRAINT"] + if str == "-" { + return nil + } + + var ( + name string + idx = strings.Index(str, ",") + settings = ParseTagSetting(str, ",") + ) + + if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) { + name = str[0:idx] + } else { + name = rel.Schema.namer.RelationshipFKName(*rel) + } + + constraint := Constraint{ + Name: name, + Field: rel.Field, + OnUpdate: settings["ONUPDATE"], + OnDelete: settings["ONDELETE"], + } + + for _, ref := range rel.References { + if ref.PrimaryKey != nil { + constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) + constraint.References = append(constraint.References, ref.PrimaryKey) + + if ref.OwnPrimaryKey { + constraint.Schema = ref.ForeignKey.Schema + constraint.ReferenceSchema = rel.Schema + } else { + constraint.Schema = rel.Schema + constraint.ReferenceSchema = ref.PrimaryKey.Schema + } + } + } + + if rel.JoinTable != nil { + return nil + } + + return &constraint +} + +func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { + table := rel.FieldSchema.Table + foreignFields := []*Field{} + relForeignKeys := []string{} + + if rel.JoinTable != nil { + table = rel.JoinTable.Table + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + } + + _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) + column, values := ToQueryValues(table, relForeignKeys, foreignValues) + + conds = append(conds, clause.IN{Column: column, Values: values}) + return +} diff --git a/vendor/gorm.io/gorm/schema/schema.go b/vendor/gorm.io/gorm/schema/schema.go new file mode 100644 index 000000000..da4be3050 --- /dev/null +++ b/vendor/gorm.io/gorm/schema/schema.go @@ -0,0 +1,280 @@ +package schema + +import ( + "context" + "errors" + "fmt" + "go/ast" + "reflect" + "sync" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" +) + +// ErrUnsupportedDataType unsupported data type +var ErrUnsupportedDataType = errors.New("unsupported data type") + +type Schema struct { + Name string + ModelType reflect.Type + Table string + PrioritizedPrimaryField *Field + DBNames []string + PrimaryFields []*Field + PrimaryFieldDBNames []string + Fields []*Field + FieldsByName map[string]*Field + FieldsByDBName map[string]*Field + FieldsWithDefaultDBValue []*Field // fields with default value assigned by database + Relationships Relationships + CreateClauses []clause.Interface + QueryClauses []clause.Interface + UpdateClauses []clause.Interface + DeleteClauses []clause.Interface + BeforeCreate, AfterCreate bool + BeforeUpdate, AfterUpdate bool + BeforeDelete, AfterDelete bool + BeforeSave, AfterSave bool + AfterFind bool + err error + initialized chan struct{} + namer Namer + cacheStore *sync.Map +} + +func (schema Schema) String() string { + if schema.ModelType.Name() == "" { + return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) + } + return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) +} + +func (schema Schema) MakeSlice() reflect.Value { + slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20) + results := reflect.New(slice.Type()) + results.Elem().Set(slice) + return results +} + +func (schema Schema) LookUpField(name string) *Field { + if field, ok := schema.FieldsByDBName[name]; ok { + return field + } + if field, ok := schema.FieldsByName[name]; ok { + return field + } + return nil +} + +type Tabler interface { + TableName() string +} + +// get data type from dialector +func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + if dest == nil { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + + modelType := reflect.ValueOf(dest).Type() + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + } + + if v, ok := cacheStore.Load(modelType); ok { + s := v.(*Schema) + <-s.initialized + return s, nil + } + + modelValue := reflect.New(modelType) + tableName := namer.TableName(modelType.Name()) + if tabler, ok := modelValue.Interface().(Tabler); ok { + tableName = tabler.TableName() + } + if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } + + schema := &Schema{ + Name: modelType.Name(), + ModelType: modelType, + Table: tableName, + FieldsByName: map[string]*Field{}, + FieldsByDBName: map[string]*Field{}, + Relationships: Relationships{Relations: map[string]*Relationship{}}, + cacheStore: cacheStore, + namer: namer, + initialized: make(chan struct{}), + } + + defer func() { + if schema.err != nil { + logger.Default.Error(context.Background(), schema.err.Error()) + cacheStore.Delete(modelType) + } + }() + + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { + schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) + } else { + schema.Fields = append(schema.Fields, field) + } + } + } + + for _, field := range schema.Fields { + if field.DBName == "" && field.DataType != "" { + field.DBName = namer.ColumnName(schema.Table, field.Name) + } + + if field.DBName != "" { + // nonexistence or shortest path or first appear prioritized if has permission + if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { + if _, ok := schema.FieldsByDBName[field.DBName]; !ok { + schema.DBNames = append(schema.DBNames, field.DBName) + } + schema.FieldsByDBName[field.DBName] = field + schema.FieldsByName[field.Name] = field + + if v != nil && v.PrimaryKey { + for idx, f := range schema.PrimaryFields { + if f == v { + schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) + } + } + } + + if field.PrimaryKey { + schema.PrimaryFields = append(schema.PrimaryFields, field) + } + } + } + + if _, ok := schema.FieldsByName[field.Name]; !ok { + schema.FieldsByName[field.Name] = field + } + + field.setupValuerAndSetter() + } + + prioritizedPrimaryField := schema.LookUpField("id") + if prioritizedPrimaryField == nil { + prioritizedPrimaryField = schema.LookUpField("ID") + } + + if prioritizedPrimaryField != nil { + if prioritizedPrimaryField.PrimaryKey { + schema.PrioritizedPrimaryField = prioritizedPrimaryField + } else if len(schema.PrimaryFields) == 0 { + prioritizedPrimaryField.PrimaryKey = true + schema.PrioritizedPrimaryField = prioritizedPrimaryField + schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField) + } + } + + if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { + schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + } + + for _, field := range schema.PrimaryFields { + schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) + } + + for _, field := range schema.FieldsByDBName { + if field.HasDefaultValue && field.DefaultValueInterface == nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + } + + if field := schema.PrioritizedPrimaryField; field != nil { + switch field.GORMDataType { + case Int, Uint: + if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok { + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + + field.HasDefaultValue = true + field.AutoIncrement = true + } + } + } + + callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} + for _, name := range callbacks { + if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": // TODO hack + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) + } + } + } + + if s, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { + for _, field := range schema.Fields { + if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } + } + + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } + } + close(schema.initialized) + } + } else { + <-s.(*Schema).initialized + return s.(*Schema), nil + } + + return schema, schema.err +} + +func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + modelType := reflect.ValueOf(dest).Type() + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + } + + if v, ok := cacheStore.Load(modelType); ok { + return v.(*Schema), nil + } + + return Parse(dest, cacheStore, namer) +} diff --git a/vendor/gorm.io/gorm/schema/utils.go b/vendor/gorm.io/gorm/schema/utils.go new file mode 100644 index 000000000..6e5fd5284 --- /dev/null +++ b/vendor/gorm.io/gorm/schema/utils.go @@ -0,0 +1,197 @@ +package schema + +import ( + "reflect" + "regexp" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" +) + +var embeddedCacheKey = "embedded_cache_store" + +func ParseTagSetting(str string, sep string) map[string]string { + settings := map[string]string{} + names := strings.Split(str, sep) + + for i := 0; i < len(names); i++ { + j := i + if len(names[j]) > 0 { + for { + if names[j][len(names[j])-1] == '\\' { + i++ + names[j] = names[j][0:len(names[j])-1] + sep + names[i] + names[i] = "" + } else { + break + } + } + } + + values := strings.Split(names[j], ":") + k := strings.TrimSpace(strings.ToUpper(values[0])) + + if len(values) >= 2 { + settings[k] = strings.Join(values[1:], ":") + } else if k != "" { + settings[k] = k + } + } + + return settings +} + +func toColumns(val string) (results []string) { + if val != "" { + for _, v := range strings.Split(val, ",") { + results = append(results, strings.TrimSpace(v)) + } + } + return +} + +func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag { + for _, name := range names { + tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) + } + return tag +} + +// GetRelationsValues get relations's values from a reflect value +func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { + for _, rel := range rels { + reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) + + appendToResults := func(value reflect.Value) { + if _, isZero := rel.Field.ValueOf(value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + switch result.Kind() { + case reflect.Struct: + reflectResults = reflect.Append(reflectResults, result.Addr()) + case reflect.Slice, reflect.Array: + for i := 0; i < result.Len(); i++ { + if result.Index(i).Kind() == reflect.Ptr { + reflectResults = reflect.Append(reflectResults, result.Index(i)) + } else { + reflectResults = reflect.Append(reflectResults, result.Index(i).Addr()) + } + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Struct: + appendToResults(reflectValue) + case reflect.Slice: + for i := 0; i < reflectValue.Len(); i++ { + appendToResults(reflectValue.Index(i)) + } + } + + reflectValue = reflectResults + } + + return +} + +// GetIdentityFieldValuesMap get identity map from fields +func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + var ( + results = [][]interface{}{} + dataResults = map[string][]reflect.Value{} + loaded = map[interface{}]bool{} + notZero, zero bool + ) + + switch reflectValue.Kind() { + case reflect.Struct: + results = [][]interface{}{make([]interface{}, len(fields))} + + for idx, field := range fields { + results[0][idx], zero = field.ValueOf(reflectValue) + notZero = notZero || !zero + } + + if !notZero { + return nil, nil + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + elem := reflectValue.Index(i) + elemKey := elem.Interface() + if elem.Kind() != reflect.Ptr { + elemKey = elem.Addr().Interface() + } + + if _, ok := loaded[elemKey]; ok { + continue + } + loaded[elemKey] = true + + fieldValues := make([]interface{}, len(fields)) + notZero = false + for idx, field := range fields { + fieldValues[idx], zero = field.ValueOf(elem) + notZero = notZero || !zero + } + + if notZero { + dataKey := utils.ToStringKey(fieldValues...) + if _, ok := dataResults[dataKey]; !ok { + results = append(results, fieldValues[:]) + dataResults[dataKey] = []reflect.Value{elem} + } else { + dataResults[dataKey] = append(dataResults[dataKey], elem) + } + } + } + } + + return dataResults, results +} + +// GetIdentityFieldValuesMapFromValues get identity map from fields +func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + resultsMap := map[string][]reflect.Value{} + results := [][]interface{}{} + + for _, v := range values { + rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) + for k, v := range rm { + resultsMap[k] = append(resultsMap[k], v...) + } + results = append(results, rs...) + } + return resultsMap, results +} + +// ToQueryValues to query values +func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { + queryValues := make([]interface{}, len(foreignValues)) + if len(foreignKeys) == 1 { + for idx, r := range foreignValues { + queryValues[idx] = r[0] + } + + return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues + } else { + columns := make([]clause.Column, len(foreignKeys)) + for idx, key := range foreignKeys { + columns[idx] = clause.Column{Table: table, Name: key} + } + + for idx, r := range foreignValues { + queryValues[idx] = r + } + return columns, queryValues + } +} + +type embeddedNamer struct { + Table string + Namer +} diff --git a/vendor/gorm.io/gorm/soft_delete.go b/vendor/gorm.io/gorm/soft_delete.go new file mode 100644 index 000000000..cb56035d6 --- /dev/null +++ b/vendor/gorm.io/gorm/soft_delete.go @@ -0,0 +1,136 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "reflect" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +type DeletedAt sql.NullTime + +// Scan implements the Scanner interface. +func (n *DeletedAt) Scan(value interface{}) error { + return (*sql.NullTime)(n).Scan(value) +} + +// Value implements the driver Valuer interface. +func (n DeletedAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time, nil +} + +func (n DeletedAt) MarshalJSON() ([]byte, error) { + if n.Valid { + return json.Marshal(n.Time) + } + return json.Marshal(nil) +} + +func (n *DeletedAt) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + n.Valid = false + return nil + } + err := json.Unmarshal(b, &n.Time) + if err == nil { + n.Valid = true + } + return err +} + +func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteQueryClause{Field: f}} +} + +type SoftDeleteQueryClause struct { + Field *schema.Field +} + +func (sd SoftDeleteQueryClause) Name() string { + return "" +} + +func (sd SoftDeleteQueryClause) Build(clause.Builder) { +} + +func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + if c, ok := stmt.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 { + for _, expr := range where.Exprs { + if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { + where.Exprs = []clause.Expression{clause.And(where.Exprs...)} + c.Expression = where + stmt.Clauses["WHERE"] = c + break + } + } + } + } + + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + }}) + stmt.Clauses["soft_delete_enabled"] = clause.Clause{} + } +} + +func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteDeleteClause{Field: f}} +} + +type SoftDeleteDeleteClause struct { + Field *schema.Field +} + +func (sd SoftDeleteDeleteClause) Name() string { + return "" +} + +func (sd SoftDeleteDeleteClause) Build(clause.Builder) { +} + +func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { + if stmt.SQL.String() == "" { + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}}) + + if stmt.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + } + + if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { + stmt.DB.AddError(ErrMissingWhereClause) + } else { + SoftDeleteQueryClause{Field: sd.Field}.ModifyStatement(stmt) + } + + stmt.AddClauseIfNotExists(clause.Update{}) + stmt.Build("UPDATE", "SET", "WHERE") + } +} diff --git a/vendor/gorm.io/gorm/statement.go b/vendor/gorm.io/gorm/statement.go new file mode 100644 index 000000000..27edf9da8 --- /dev/null +++ b/vendor/gorm.io/gorm/statement.go @@ -0,0 +1,594 @@ +package gorm + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "sync" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// Statement statement +type Statement struct { + *DB + TableExpr *clause.Expr + Table string + Model interface{} + Unscoped bool + Dest interface{} + ReflectValue reflect.Value + Clauses map[string]clause.Clause + Distinct bool + Selects []string // selected columns + Omits []string // omit columns + Joins []join + Preloads map[string][]interface{} + Settings sync.Map + ConnPool ConnPool + Schema *schema.Schema + Context context.Context + RaiseErrorOnNotFound bool + SkipHooks bool + SQL strings.Builder + Vars []interface{} + CurDestIndex int + attrs []interface{} + assigns []interface{} +} + +type join struct { + Name string + Conds []interface{} +} + +// StatementModifier statement modifier interface +type StatementModifier interface { + ModifyStatement(*Statement) +} + +// Write write string +func (stmt *Statement) WriteString(str string) (int, error) { + return stmt.SQL.WriteString(str) +} + +// Write write string +func (stmt *Statement) WriteByte(c byte) error { + return stmt.SQL.WriteByte(c) +} + +// WriteQuoted write quoted value +func (stmt *Statement) WriteQuoted(value interface{}) { + stmt.QuoteTo(&stmt.SQL, value) +} + +// QuoteTo write quoted value to writer +func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { + switch v := field.(type) { + case clause.Table: + if v.Name == clause.CurrentTable { + if stmt.TableExpr != nil { + stmt.TableExpr.Build(stmt) + } else { + stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + } + } else if v.Raw { + writer.WriteString(v.Name) + } else { + stmt.DB.Dialector.QuoteTo(writer, v.Name) + } + + if v.Alias != "" { + writer.WriteByte(' ') + stmt.DB.Dialector.QuoteTo(writer, v.Alias) + } + case clause.Column: + if v.Table != "" { + if v.Table == clause.CurrentTable { + stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + } else { + stmt.DB.Dialector.QuoteTo(writer, v.Table) + } + writer.WriteByte('.') + } + + if v.Name == clause.PrimaryKey { + if stmt.Schema == nil { + stmt.DB.AddError(ErrModelValueRequired) + } else if stmt.Schema.PrioritizedPrimaryField != nil { + stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) + } else if len(stmt.Schema.DBNames) > 0 { + stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0]) + } + } else if v.Raw { + writer.WriteString(v.Name) + } else { + stmt.DB.Dialector.QuoteTo(writer, v.Name) + } + + if v.Alias != "" { + writer.WriteString(" AS ") + stmt.DB.Dialector.QuoteTo(writer, v.Alias) + } + case []clause.Column: + writer.WriteByte('(') + for idx, d := range v { + if idx > 0 { + writer.WriteString(",") + } + stmt.QuoteTo(writer, d) + } + writer.WriteByte(')') + case string: + stmt.DB.Dialector.QuoteTo(writer, v) + case []string: + writer.WriteByte('(') + for idx, d := range v { + if idx > 0 { + writer.WriteString(",") + } + stmt.DB.Dialector.QuoteTo(writer, d) + } + writer.WriteByte(')') + default: + stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) + } +} + +// Quote returns quoted value +func (stmt *Statement) Quote(field interface{}) string { + var builder strings.Builder + stmt.QuoteTo(&builder, field) + return builder.String() +} + +// Write write string +func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { + for idx, v := range vars { + if idx > 0 { + writer.WriteByte(',') + } + + switch v := v.(type) { + case sql.NamedArg: + stmt.Vars = append(stmt.Vars, v.Value) + case clause.Column, clause.Table: + stmt.QuoteTo(writer, v) + case Valuer: + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) + case clause.Expr: + var varStr strings.Builder + var sql = v.SQL + for _, arg := range v.Vars { + stmt.Vars = append(stmt.Vars, arg) + stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg) + sql = strings.Replace(sql, "?", varStr.String(), 1) + varStr.Reset() + } + + writer.WriteString(sql) + case driver.Valuer: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + case []byte: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + case []interface{}: + if len(v) > 0 { + writer.WriteByte('(') + stmt.AddVar(writer, v...) + writer.WriteByte(')') + } else { + writer.WriteString("(NULL)") + } + case *DB: + subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() + subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) + subdb.callbacks.Query().Execute(subdb) + writer.WriteString(subdb.Statement.SQL.String()) + stmt.Vars = subdb.Statement.Vars + default: + switch rv := reflect.ValueOf(v); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + writer.WriteString("(NULL)") + } else { + writer.WriteByte('(') + for i := 0; i < rv.Len(); i++ { + if i > 0 { + writer.WriteByte(',') + } + stmt.AddVar(writer, rv.Index(i).Interface()) + } + writer.WriteByte(')') + } + default: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + } + } + } +} + +// AddClause add clause +func (stmt *Statement) AddClause(v clause.Interface) { + if optimizer, ok := v.(StatementModifier); ok { + optimizer.ModifyStatement(stmt) + } else { + name := v.Name() + c := stmt.Clauses[name] + c.Name = name + v.MergeClause(&c) + stmt.Clauses[name] = c + } +} + +// AddClauseIfNotExists add clause if not exists +func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { + if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil { + stmt.AddClause(v) + } +} + +// BuildCondition build condition +func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression { + if s, ok := query.(string); ok { + // if it is a number, then treats it as primary key + if _, err := strconv.Atoi(s); err != nil { + if s == "" && len(args) == 0 { + return nil + } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { + // looks like a where condition + return []clause.Expression{clause.Expr{SQL: s, Vars: args}} + } else if len(args) > 0 && strings.Contains(s, "@") { + // looks like a named query + return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} + } else if len(args) == 1 { + return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} + } + } + } + + conds := make([]clause.Expression, 0, 4) + args = append([]interface{}{query}, args...) + for _, arg := range args { + if valuer, ok := arg.(driver.Valuer); ok { + arg, _ = valuer.Value() + } + + switch v := arg.(type) { + case clause.Expression: + conds = append(conds, v) + case *DB: + if cs, ok := v.Statement.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + conds = append(conds, clause.And(where.Exprs...)) + } else if cs.Expression != nil { + conds = append(conds, cs.Expression) + } + } + case map[interface{}]interface{}: + for i, j := range v { + conds = append(conds, clause.Eq{Column: i, Value: j}) + } + case map[string]string: + var keys = make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } + case map[string]interface{}: + var keys = make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if _, ok := v[key].(driver.Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else if _, ok := v[key].(Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else { + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } + + conds = append(conds, clause.IN{Column: key, Values: values}) + } + default: + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } + } + default: + reflectValue := reflect.Indirect(reflect.ValueOf(arg)) + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + switch reflectValue.Kind() { + case reflect.Struct: + for _, field := range s.Fields { + if field.Readable { + if v, isZero := field.ValueOf(reflectValue); !isZero { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } + } + } + } + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + for _, field := range s.Fields { + if field.Readable { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } + } + } + } + } + } + } else if len(conds) == 0 { + if len(args) == 1 { + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } + + if len(values) > 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + } + return conds + } + } + + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) + } + } + } + + return conds +} + +// Build build sql with clauses names +func (stmt *Statement) Build(clauses ...string) { + var firstClauseWritten bool + + for _, name := range clauses { + if c, ok := stmt.Clauses[name]; ok { + if firstClauseWritten { + stmt.WriteByte(' ') + } + + firstClauseWritten = true + if b, ok := stmt.DB.ClauseBuilders[name]; ok { + b(c, stmt) + } else { + c.Build(stmt) + } + } + } +} + +func (stmt *Statement) Parse(value interface{}) (err error) { + if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { + stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} + stmt.Table = tables[1] + return + } + + stmt.Table = stmt.Schema.Table + } + return err +} + +func (stmt *Statement) clone() *Statement { + newStmt := &Statement{ + TableExpr: stmt.TableExpr, + Table: stmt.Table, + Model: stmt.Model, + Unscoped: stmt.Unscoped, + Dest: stmt.Dest, + ReflectValue: stmt.ReflectValue, + Clauses: map[string]clause.Clause{}, + Distinct: stmt.Distinct, + Selects: stmt.Selects, + Omits: stmt.Omits, + Preloads: map[string][]interface{}{}, + ConnPool: stmt.ConnPool, + Schema: stmt.Schema, + Context: stmt.Context, + RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, + SkipHooks: stmt.SkipHooks, + } + + for k, c := range stmt.Clauses { + newStmt.Clauses[k] = c + } + + for k, p := range stmt.Preloads { + newStmt.Preloads[k] = p + } + + if len(stmt.Joins) > 0 { + newStmt.Joins = make([]join, len(stmt.Joins)) + copy(newStmt.Joins, stmt.Joins) + } + + stmt.Settings.Range(func(k, v interface{}) bool { + newStmt.Settings.Store(k, v) + return true + }) + + return newStmt +} + +// Helpers +// SetColumn set column's value +func (stmt *Statement) SetColumn(name string, value interface{}) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + v[name] = value + } else if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + if stmt.ReflectValue != destValue { + if !destValue.CanAddr() { + destValueCanAddr := reflect.New(destValue.Type()) + destValueCanAddr.Elem().Set(destValue) + stmt.Dest = destValueCanAddr.Interface() + destValue = destValueCanAddr.Elem() + } + + switch destValue.Kind() { + case reflect.Struct: + field.Set(destValue, value) + default: + stmt.AddError(ErrInvalidData) + } + } + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + case reflect.Struct: + field.Set(stmt.ReflectValue, value) + } + } else { + stmt.AddError(ErrInvalidField) + } + } else { + stmt.AddError(ErrInvalidField) + } +} + +// Changed check model changed or not when updating +func (stmt *Statement) Changed(fields ...string) bool { + modelValue := stmt.ReflectValue + switch modelValue.Kind() { + case reflect.Slice, reflect.Array: + modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) + } + + selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) + changed := func(field *schema.Field) bool { + fieldValue, _ := field.ValueOf(modelValue) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + if fv, ok := v[field.Name]; ok { + return !utils.AssertEqual(fv, fieldValue) + } else if fv, ok := v[field.DBName]; ok { + return !utils.AssertEqual(fv, fieldValue) + } + } else { + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + changedValue, zero := field.ValueOf(destValue) + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } + } + return false + } + + if len(fields) == 0 { + for _, field := range stmt.Schema.FieldsByDBName { + if changed(field) { + return true + } + } + } else { + for _, name := range fields { + if field := stmt.Schema.LookUpField(name); field != nil { + if changed(field) { + return true + } + } + } + } + + return false +} + +// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false +func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { + results := map[string]bool{} + notRestricted := false + + // select columns + for _, column := range stmt.Selects { + if column == "*" { + notRestricted = true + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = true + } + } else if column == clause.Associations && stmt.Schema != nil { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = true + } + } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { + results[field.DBName] = true + } else { + results[column] = true + } + } + + // omit columns + for _, omit := range stmt.Omits { + if omit == clause.Associations { + if stmt.Schema != nil { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = false + } + } + } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { + results[field.DBName] = false + } else { + results[omit] = false + } + } + + if stmt.Schema != nil { + for _, field := range stmt.Schema.Fields { + name := field.DBName + if name == "" { + name = field.Name + } + + if requireCreate && !field.Creatable { + results[name] = false + } else if requireUpdate && !field.Updatable { + results[name] = false + } + } + } + + return results, !notRestricted && len(stmt.Selects) > 0 +} diff --git a/vendor/gorm.io/gorm/utils/utils.go b/vendor/gorm.io/gorm/utils/utils.go new file mode 100644 index 000000000..ecba7fb93 --- /dev/null +++ b/vendor/gorm.io/gorm/utils/utils.go @@ -0,0 +1,113 @@ +package utils + +import ( + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "runtime" + "strconv" + "strings" + "unicode" +) + +var gormSourceDir string + +func init() { + _, file, _, _ := runtime.Caller(0) + gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") +} + +func FileWithLineNum() string { + for i := 2; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + + if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { + return file + ":" + strconv.FormatInt(int64(line), 10) + } + } + return "" +} + +func IsValidDBNameChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' +} + +func CheckTruth(val interface{}) bool { + if v, ok := val.(bool); ok { + return v + } + + if v, ok := val.(string); ok { + v = strings.ToLower(v) + return v != "false" + } + + return !reflect.ValueOf(val).IsZero() +} + +func ToStringKey(values ...interface{}) string { + results := make([]string, len(values)) + + for idx, value := range values { + if valuer, ok := value.(driver.Valuer); ok { + value, _ = valuer.Value() + } + + switch v := value.(type) { + case string: + results[idx] = v + case []byte: + results[idx] = string(v) + case uint: + results[idx] = strconv.FormatUint(uint64(v), 10) + default: + results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface()) + } + } + + return strings.Join(results, "_") +} + +func AssertEqual(src, dst interface{}) bool { + if !reflect.DeepEqual(src, dst) { + if valuer, ok := src.(driver.Valuer); ok { + src, _ = valuer.Value() + } + + if valuer, ok := dst.(driver.Valuer); ok { + dst, _ = valuer.Value() + } + + return reflect.DeepEqual(src, dst) + } + return true +} + +func ToString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case int: + return strconv.FormatInt(int64(v), 10) + case int8: + return strconv.FormatInt(int64(v), 10) + case int16: + return strconv.FormatInt(int64(v), 10) + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case uint: + return strconv.FormatUint(uint64(v), 10) + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint16: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + } + return "" +}