Skip to content

Commit

Permalink
Merge remote-tracking branch 'giteaofficial/main'
Browse files Browse the repository at this point in the history
* giteaofficial/main:
  Don't include link of deleted branch when listing branches (go-gitea#31028)
  [skip ci] Updated translations via Crowdin
  Refactor sha1 and time-limited code (go-gitea#31023)
  Return `access_denied` error when an OAuth2 request is denied (go-gitea#30974)
  Avoid 500 panic error when uploading invalid maven package file (go-gitea#31014)
  Fix incorrect "blob excerpt" link when comparing files (go-gitea#31013)
  Fix project column title overflow (go-gitea#31011)
  Fix data-race during testing (go-gitea#30999)
  • Loading branch information
zjjhot committed May 21, 2024
2 parents ab16ce1 + 1007ce7 commit d07ae65
Show file tree
Hide file tree
Showing 26 changed files with 260 additions and 151 deletions.
26 changes: 20 additions & 6 deletions models/unit/unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"strings"
"sync/atomic"

"code.gitea.io/gitea/models/perm"
"code.gitea.io/gitea/modules/container"
Expand Down Expand Up @@ -106,10 +107,23 @@ var (
TypeExternalTracker,
}

// DisabledRepoUnits contains the units that have been globally disabled
DisabledRepoUnits = []Type{}
disabledRepoUnitsAtomic atomic.Pointer[[]Type] // the units that have been globally disabled
)

// DisabledRepoUnitsGet returns the globally disabled units, it is a quick patch to fix data-race during testing.
// Because the queue worker might read when a test is mocking the value. FIXME: refactor to a clear solution later.
func DisabledRepoUnitsGet() []Type {
v := disabledRepoUnitsAtomic.Load()
if v == nil {
return nil
}
return *v
}

func DisabledRepoUnitsSet(v []Type) {
disabledRepoUnitsAtomic.Store(&v)
}

// Get valid set of default repository units from settings
func validateDefaultRepoUnits(defaultUnits, settingDefaultUnits []Type) []Type {
units := defaultUnits
Expand All @@ -127,7 +141,7 @@ func validateDefaultRepoUnits(defaultUnits, settingDefaultUnits []Type) []Type {
}

// Remove disabled units
for _, disabledUnit := range DisabledRepoUnits {
for _, disabledUnit := range DisabledRepoUnitsGet() {
for i, unit := range units {
if unit == disabledUnit {
units = append(units[:i], units[i+1:]...)
Expand All @@ -140,11 +154,11 @@ func validateDefaultRepoUnits(defaultUnits, settingDefaultUnits []Type) []Type {

// LoadUnitConfig load units from settings
func LoadUnitConfig() error {
var invalidKeys []string
DisabledRepoUnits, invalidKeys = FindUnitTypes(setting.Repository.DisabledRepoUnits...)
disabledRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DisabledRepoUnits...)
if len(invalidKeys) > 0 {
log.Warn("Invalid keys in disabled repo units: %s", strings.Join(invalidKeys, ", "))
}
DisabledRepoUnitsSet(disabledRepoUnits)

setDefaultRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DefaultRepoUnits...)
if len(invalidKeys) > 0 {
Expand All @@ -167,7 +181,7 @@ func LoadUnitConfig() error {

// UnitGlobalDisabled checks if unit type is global disabled
func (u Type) UnitGlobalDisabled() bool {
for _, ud := range DisabledRepoUnits {
for _, ud := range DisabledRepoUnitsGet() {
if u == ud {
return true
}
Expand Down
24 changes: 12 additions & 12 deletions models/unit/unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
func TestLoadUnitConfig(t *testing.T) {
t.Run("regular", func(t *testing.T) {
defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []Type) {
DisabledRepoUnits = disabledRepoUnits
DisabledRepoUnitsSet(disabledRepoUnits)
DefaultRepoUnits = defaultRepoUnits
DefaultForkRepoUnits = defaultForkRepoUnits
}(DisabledRepoUnits, DefaultRepoUnits, DefaultForkRepoUnits)
}(DisabledRepoUnitsGet(), DefaultRepoUnits, DefaultForkRepoUnits)
defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []string) {
setting.Repository.DisabledRepoUnits = disabledRepoUnits
setting.Repository.DefaultRepoUnits = defaultRepoUnits
Expand All @@ -28,16 +28,16 @@ func TestLoadUnitConfig(t *testing.T) {
setting.Repository.DefaultRepoUnits = []string{"repo.code", "repo.releases", "repo.issues", "repo.pulls"}
setting.Repository.DefaultForkRepoUnits = []string{"repo.releases"}
assert.NoError(t, LoadUnitConfig())
assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnits)
assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet())
assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits)
assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits)
})
t.Run("invalid", func(t *testing.T) {
defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []Type) {
DisabledRepoUnits = disabledRepoUnits
DisabledRepoUnitsSet(disabledRepoUnits)
DefaultRepoUnits = defaultRepoUnits
DefaultForkRepoUnits = defaultForkRepoUnits
}(DisabledRepoUnits, DefaultRepoUnits, DefaultForkRepoUnits)
}(DisabledRepoUnitsGet(), DefaultRepoUnits, DefaultForkRepoUnits)
defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []string) {
setting.Repository.DisabledRepoUnits = disabledRepoUnits
setting.Repository.DefaultRepoUnits = defaultRepoUnits
Expand All @@ -48,16 +48,16 @@ func TestLoadUnitConfig(t *testing.T) {
setting.Repository.DefaultRepoUnits = []string{"repo.code", "invalid.2", "repo.releases", "repo.issues", "repo.pulls"}
setting.Repository.DefaultForkRepoUnits = []string{"invalid.3", "repo.releases"}
assert.NoError(t, LoadUnitConfig())
assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnits)
assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet())
assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits)
assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits)
})
t.Run("duplicate", func(t *testing.T) {
defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []Type) {
DisabledRepoUnits = disabledRepoUnits
DisabledRepoUnitsSet(disabledRepoUnits)
DefaultRepoUnits = defaultRepoUnits
DefaultForkRepoUnits = defaultForkRepoUnits
}(DisabledRepoUnits, DefaultRepoUnits, DefaultForkRepoUnits)
}(DisabledRepoUnitsGet(), DefaultRepoUnits, DefaultForkRepoUnits)
defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []string) {
setting.Repository.DisabledRepoUnits = disabledRepoUnits
setting.Repository.DefaultRepoUnits = defaultRepoUnits
Expand All @@ -68,16 +68,16 @@ func TestLoadUnitConfig(t *testing.T) {
setting.Repository.DefaultRepoUnits = []string{"repo.code", "repo.releases", "repo.issues", "repo.pulls", "repo.code"}
setting.Repository.DefaultForkRepoUnits = []string{"repo.releases", "repo.releases"}
assert.NoError(t, LoadUnitConfig())
assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnits)
assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet())
assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits)
assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits)
})
t.Run("empty_default", func(t *testing.T) {
defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []Type) {
DisabledRepoUnits = disabledRepoUnits
DisabledRepoUnitsSet(disabledRepoUnits)
DefaultRepoUnits = defaultRepoUnits
DefaultForkRepoUnits = defaultForkRepoUnits
}(DisabledRepoUnits, DefaultRepoUnits, DefaultForkRepoUnits)
}(DisabledRepoUnitsGet(), DefaultRepoUnits, DefaultForkRepoUnits)
defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []string) {
setting.Repository.DisabledRepoUnits = disabledRepoUnits
setting.Repository.DefaultRepoUnits = defaultRepoUnits
Expand All @@ -88,7 +88,7 @@ func TestLoadUnitConfig(t *testing.T) {
setting.Repository.DefaultRepoUnits = []string{}
setting.Repository.DefaultForkRepoUnits = []string{"repo.releases", "repo.releases"}
assert.NoError(t, LoadUnitConfig())
assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnits)
assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet())
assert.ElementsMatch(t, []Type{TypeCode, TypePullRequests, TypeReleases, TypeWiki, TypePackages, TypeProjects, TypeActions}, DefaultRepoUnits)
assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits)
})
Expand Down
5 changes: 2 additions & 3 deletions models/user/email_address.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/mail"
"regexp"
"strings"
"time"

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/base"
Expand Down Expand Up @@ -353,14 +354,12 @@ func ChangeInactivePrimaryEmail(ctx context.Context, uid int64, oldEmailAddr, ne

// VerifyActiveEmailCode verifies active email code when active account
func VerifyActiveEmailCode(ctx context.Context, code, email string) *EmailAddress {
minutes := setting.Service.ActiveCodeLives

if user := GetVerifyUser(ctx, code); user != nil {
// time limit code
prefix := code[:base.TimeLimitCodeLength]
data := fmt.Sprintf("%d%s%s%s%s", user.ID, email, user.LowerName, user.Passwd, user.Rands)

if base.VerifyTimeLimitCode(data, minutes, prefix) {
if base.VerifyTimeLimitCode(time.Now(), data, setting.Service.ActiveCodeLives, prefix) {
emailAddress := &EmailAddress{UID: user.ID, Email: email}
if has, _ := db.GetEngine(ctx).Get(emailAddress); has {
return emailAddress
Expand Down
7 changes: 2 additions & 5 deletions models/user/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ func (u *User) OrganisationLink() string {
func (u *User) GenerateEmailActivateCode(email string) string {
code := base.CreateTimeLimitCode(
fmt.Sprintf("%d%s%s%s%s", u.ID, email, u.LowerName, u.Passwd, u.Rands),
setting.Service.ActiveCodeLives, nil)
setting.Service.ActiveCodeLives, time.Now(), nil)

// Add tail hex username
code += hex.EncodeToString([]byte(u.LowerName))
Expand Down Expand Up @@ -791,14 +791,11 @@ func GetVerifyUser(ctx context.Context, code string) (user *User) {

// VerifyUserActiveCode verifies active code when active account
func VerifyUserActiveCode(ctx context.Context, code string) (user *User) {
minutes := setting.Service.ActiveCodeLives

if user = GetVerifyUser(ctx, code); user != nil {
// time limit code
prefix := code[:base.TimeLimitCodeLength]
data := fmt.Sprintf("%d%s%s%s%s", user.ID, user.Email, user.LowerName, user.Passwd, user.Rands)

if base.VerifyTimeLimitCode(data, minutes, prefix) {
if base.VerifyTimeLimitCode(time.Now(), data, setting.Service.ActiveCodeLives, prefix) {
return user
}
}
Expand Down
85 changes: 40 additions & 45 deletions modules/base/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
package base

import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"hash"
"os"
"path/filepath"
"runtime"
Expand All @@ -25,13 +28,6 @@ import (
"github.com/dustin/go-humanize"
)

// EncodeSha1 string to sha1 hex value.
func EncodeSha1(str string) string {
h := sha1.New()
_, _ = h.Write([]byte(str))
return hex.EncodeToString(h.Sum(nil))
}

// EncodeSha256 string to sha256 hex value.
func EncodeSha256(str string) string {
h := sha256.New()
Expand Down Expand Up @@ -62,63 +58,62 @@ func BasicAuthDecode(encoded string) (string, string, error) {
}

// VerifyTimeLimitCode verify time limit code
func VerifyTimeLimitCode(data string, minutes int, code string) bool {
func VerifyTimeLimitCode(now time.Time, data string, minutes int, code string) bool {
if len(code) <= 18 {
return false
}

// split code
start := code[:12]
lives := code[12:18]
if d, err := strconv.ParseInt(lives, 10, 0); err == nil {
minutes = int(d)
}
startTimeStr := code[:12]
aliveTimeStr := code[12:18]
aliveTime, _ := strconv.Atoi(aliveTimeStr) // no need to check err, if anything wrong, the following code check will fail soon

// right active code
retCode := CreateTimeLimitCode(data, minutes, start)
if retCode == code && minutes > 0 {
// check time is expired or not
before, _ := time.ParseInLocation("200601021504", start, time.Local)
now := time.Now()
if before.Add(time.Minute*time.Duration(minutes)).Unix() > now.Unix() {
return true
// check code
retCode := CreateTimeLimitCode(data, aliveTime, startTimeStr, nil)
if subtle.ConstantTimeCompare([]byte(retCode), []byte(code)) != 1 {
retCode = CreateTimeLimitCode(data, aliveTime, startTimeStr, sha1.New()) // TODO: this is only for the support of legacy codes, remove this in/after 1.23
if subtle.ConstantTimeCompare([]byte(retCode), []byte(code)) != 1 {
return false
}
}

return false
// check time is expired or not: startTime <= now && now < startTime + minutes
startTime, _ := time.ParseInLocation("200601021504", startTimeStr, time.Local)
return (startTime.Before(now) || startTime.Equal(now)) && now.Before(startTime.Add(time.Minute*time.Duration(minutes)))
}

// TimeLimitCodeLength default value for time limit code
const TimeLimitCodeLength = 12 + 6 + 40

// CreateTimeLimitCode create a time limit code
// code format: 12 length date time string + 6 minutes string + 40 sha1 encoded string
func CreateTimeLimitCode(data string, minutes int, startInf any) string {
format := "200601021504"

var start, end time.Time
var startStr, endStr string
// CreateTimeLimitCode create a time-limited code.
// Format: 12 length date time string + 6 minutes string (not used) + 40 hash string, some other code depends on this fixed length
// If h is nil, then use the default hmac hash.
func CreateTimeLimitCode[T time.Time | string](data string, minutes int, startTimeGeneric T, h hash.Hash) string {
const format = "200601021504"

if startInf == nil {
// Use now time create code
start = time.Now()
startStr = start.Format(format)
var start time.Time
var startTimeAny any = startTimeGeneric
if t, ok := startTimeAny.(time.Time); ok {
start = t
} else {
// use start string create code
startStr = startInf.(string)
start, _ = time.ParseInLocation(format, startStr, time.Local)
startStr = start.Format(format)
var err error
start, err = time.ParseInLocation(format, startTimeAny.(string), time.Local)
if err != nil {
return "" // return an invalid code because the "parse" failed
}
}
startStr := start.Format(format)
end := start.Add(time.Minute * time.Duration(minutes))

end = start.Add(time.Minute * time.Duration(minutes))
endStr = end.Format(format)

// create sha1 encode string
sh := sha1.New()
_, _ = sh.Write([]byte(fmt.Sprintf("%s%s%s%s%d", data, hex.EncodeToString(setting.GetGeneralTokenSigningSecret()), startStr, endStr, minutes)))
encoded := hex.EncodeToString(sh.Sum(nil))
if h == nil {
h = hmac.New(sha1.New, setting.GetGeneralTokenSigningSecret())
}
_, _ = fmt.Fprintf(h, "%s%s%s%s%d", data, hex.EncodeToString(setting.GetGeneralTokenSigningSecret()), startStr, end.Format(format), minutes)
encoded := hex.EncodeToString(h.Sum(nil))

code := fmt.Sprintf("%s%06d%s", startStr, minutes, encoded)
if len(code) != TimeLimitCodeLength {
panic("there is a hard requirement for the length of time-limited code") // it shouldn't happen
}
return code
}

Expand Down
Loading

0 comments on commit d07ae65

Please sign in to comment.