From db9922412611f6d546f723586ec67ce587666478 Mon Sep 17 00:00:00 2001 From: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> Date: Wed, 25 Dec 2024 21:08:22 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20Speed=20=E2=80=8B=E2=80=8Bof=20database?= =?UTF-8?q?=20initialization=20(#7694)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * perf: 优化非sqlite3数据库时初始化慢的问题 * refactor --- internal/bootstrap/data/setting.go | 48 ++++++++++++++++++------------ internal/bootstrap/db.go | 18 +++++------ internal/op/setting.go | 8 ++--- server/common/base.go | 8 ++--- 4 files changed, 46 insertions(+), 36 deletions(-) diff --git a/internal/bootstrap/data/setting.go b/internal/bootstrap/data/setting.go index 206273b41ac..bcb64f792d7 100644 --- a/internal/bootstrap/data/setting.go +++ b/internal/bootstrap/data/setting.go @@ -3,6 +3,7 @@ package data import ( "github.com/alist-org/alist/v3/cmd/flags" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/offline_download/tool" "github.com/alist-org/alist/v3/internal/op" @@ -21,17 +22,19 @@ func initSettings() { if err != nil { utils.Log.Fatalf("failed get settings: %+v", err) } - for i := range settings { - if !isActive(settings[i].Key) && settings[i].Flag != model.DEPRECATED { - settings[i].Flag = model.DEPRECATED - err = op.SaveSettingItem(&settings[i]) + settingMap := map[string]*model.SettingItem{} + for _, v := range settings { + if !isActive(v.Key) && v.Flag != model.DEPRECATED { + v.Flag = model.DEPRECATED + err = op.SaveSettingItem(&v) if err != nil { utils.Log.Fatalf("failed save setting: %+v", err) } } + settingMap[v.Key] = &v } - // create or save setting + save := false for i := range initialSettingItems { item := &initialSettingItems[i] item.Index = uint(i) @@ -39,26 +42,33 @@ func initSettings() { item.PreDefault = item.Value } // err - stored, err := op.GetSettingItemByKey(item.Key) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - utils.Log.Fatalf("failed get setting: %+v", err) - continue + stored, ok := settingMap[item.Key] + if !ok { + stored, err = op.GetSettingItemByKey(item.Key) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + utils.Log.Fatalf("failed get setting: %+v", err) + continue + } } - // save if stored != nil && item.Key != conf.VERSION && stored.Value != item.PreDefault { item.Value = stored.Value } + _, err = op.HandleSettingItemHook(item) + if err != nil { + utils.Log.Errorf("failed to execute hook on %s: %+v", item.Key, err) + continue + } + // save if stored == nil || *item != *stored { - err = op.SaveSettingItem(item) - if err != nil { - utils.Log.Fatalf("failed save setting: %+v", err) - } + save = true + } + } + if save { + err = db.SaveSettingItems(initialSettingItems) + if err != nil { + utils.Log.Fatalf("failed save setting: %+v", err) } else { - // Not save so needs to execute hook - _, err = op.HandleSettingItemHook(item) - if err != nil { - utils.Log.Errorf("failed to execute hook on %s: %+v", item.Key, err) - } + op.SettingCacheUpdate() } } } diff --git a/internal/bootstrap/db.go b/internal/bootstrap/db.go index 5dfa2820d18..39b659b78f1 100644 --- a/internal/bootstrap/db.go +++ b/internal/bootstrap/db.go @@ -56,20 +56,20 @@ func InitDB() { } case "mysql": { - //[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local&tls=%s", - database.User, database.Password, database.Host, database.Port, database.Name, database.SSLMode) - if database.DSN != "" { - dsn = database.DSN + dsn := database.DSN + if dsn == "" { + //[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] + dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local&tls=%s", + database.User, database.Password, database.Host, database.Port, database.Name, database.SSLMode) } dB, err = gorm.Open(mysql.Open(dsn), gormConfig) } case "postgres": { - dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=%s TimeZone=Asia/Shanghai", - database.Host, database.User, database.Password, database.Name, database.Port, database.SSLMode) - if database.DSN != "" { - dsn = database.DSN + dsn := database.DSN + if dsn == "" { + dsn = fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=%s TimeZone=Asia/Shanghai", + database.Host, database.User, database.Password, database.Name, database.Port, database.SSLMode) } dB, err = gorm.Open(postgres.Open(dsn), gormConfig) } diff --git a/internal/op/setting.go b/internal/op/setting.go index 83d19c12fbe..50eba3f744e 100644 --- a/internal/op/setting.go +++ b/internal/op/setting.go @@ -26,7 +26,7 @@ var settingGroupCacheF = func(key string, item []model.SettingItem) { settingGroupCache.Set(key, item, cache.WithEx[[]model.SettingItem](time.Hour)) } -func settingCacheUpdate() { +func SettingCacheUpdate() { settingCache.Clear() settingGroupCache.Clear() } @@ -167,7 +167,7 @@ func SaveSettingItems(items []model.SettingItem) error { } } if len(errs) < len(items)-len(noHookItems)+1 { - settingCacheUpdate() + SettingCacheUpdate() } return utils.MergeErrors(errs...) } @@ -181,7 +181,7 @@ func SaveSettingItem(item *model.SettingItem) (err error) { if err = db.SaveSettingItem(item); err != nil { return err } - settingCacheUpdate() + SettingCacheUpdate() return nil } @@ -193,6 +193,6 @@ func DeleteSettingItemByKey(key string) error { if !old.IsDeprecated() { return errors.Errorf("setting [%s] is not deprecated", key) } - settingCacheUpdate() + SettingCacheUpdate() return db.DeleteSettingItemByKey(key) } diff --git a/server/common/base.go b/server/common/base.go index eb6ef2b8ac2..11a28d25039 100644 --- a/server/common/base.go +++ b/server/common/base.go @@ -12,16 +12,16 @@ import ( func GetApiUrl(r *http.Request) string { api := conf.Conf.SiteURL if strings.HasPrefix(api, "http") { - return api + return strings.TrimSuffix(api, "/") } if r != nil { protocol := "http" if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { protocol = "https" } - host := r.Host - if r.Header.Get("X-Forwarded-Host") != "" { - host = r.Header.Get("X-Forwarded-Host") + host := r.Header.Get("X-Forwarded-Host") + if host == "" { + host = r.Host } api = fmt.Sprintf("%s://%s", protocol, stdpath.Join(host, api)) }