-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(sftp-server): public key login (#7668)
- Loading branch information
Showing
7 changed files
with
289 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package db | ||
|
||
import ( | ||
"github.com/alist-org/alist/v3/internal/model" | ||
"github.com/pkg/errors" | ||
) | ||
|
||
func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) { | ||
keyDB := db.Model(&model.SSHPublicKey{}) | ||
query := model.SSHPublicKey{UserId: userId} | ||
if err := keyDB.Where(query).Count(&count).Error; err != nil { | ||
return nil, 0, errors.Wrapf(err, "failed get user's keys count") | ||
} | ||
if err := keyDB.Where(query).Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&keys).Error; err != nil { | ||
return nil, 0, errors.Wrapf(err, "failed get find user's keys") | ||
} | ||
return keys, count, nil | ||
} | ||
|
||
func GetSSHPublicKeyById(id uint) (*model.SSHPublicKey, error) { | ||
var k model.SSHPublicKey | ||
if err := db.First(&k, id).Error; err != nil { | ||
return nil, errors.Wrapf(err, "failed get old key") | ||
} | ||
return &k, nil | ||
} | ||
|
||
func GetSSHPublicKeyByUserTitle(userId uint, title string) (*model.SSHPublicKey, error) { | ||
key := model.SSHPublicKey{UserId: userId, Title: title} | ||
if err := db.Where(key).First(&key).Error; err != nil { | ||
return nil, errors.Wrapf(err, "failed find key with title of user") | ||
} | ||
return &key, nil | ||
} | ||
|
||
func CreateSSHPublicKey(k *model.SSHPublicKey) error { | ||
return errors.WithStack(db.Create(k).Error) | ||
} | ||
|
||
func UpdateSSHPublicKey(k *model.SSHPublicKey) error { | ||
return errors.WithStack(db.Save(k).Error) | ||
} | ||
|
||
func GetSSHPublicKeys(pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) { | ||
keyDB := db.Model(&model.SSHPublicKey{}) | ||
if err := keyDB.Count(&count).Error; err != nil { | ||
return nil, 0, errors.Wrapf(err, "failed get keys count") | ||
} | ||
if err := keyDB.Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&keys).Error; err != nil { | ||
return nil, 0, errors.Wrapf(err, "failed get find keys") | ||
} | ||
return keys, count, nil | ||
} | ||
|
||
func DeleteSSHPublicKeyById(id uint) error { | ||
return errors.WithStack(db.Delete(&model.SSHPublicKey{}, id).Error) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package model | ||
|
||
import ( | ||
"golang.org/x/crypto/ssh" | ||
"time" | ||
) | ||
|
||
type SSHPublicKey struct { | ||
ID uint `json:"id" gorm:"primaryKey"` | ||
UserId uint `json:"-"` | ||
Title string `json:"title"` | ||
Fingerprint string `json:"fingerprint"` | ||
KeyStr string `gorm:"type:text" json:"-"` | ||
AddedTime time.Time `json:"added_time"` | ||
LastUsedTime time.Time `json:"last_used_time"` | ||
} | ||
|
||
func (k *SSHPublicKey) GetKey() (ssh.PublicKey, error) { | ||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.KeyStr)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return pubKey, nil | ||
} | ||
|
||
func (k *SSHPublicKey) UpdateLastUsedTime() { | ||
k.LastUsedTime = time.Now() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package op | ||
|
||
import ( | ||
"github.com/alist-org/alist/v3/internal/db" | ||
"github.com/alist-org/alist/v3/internal/model" | ||
"github.com/pkg/errors" | ||
"golang.org/x/crypto/ssh" | ||
"time" | ||
) | ||
|
||
func CreateSSHPublicKey(k *model.SSHPublicKey) (error, bool) { | ||
_, err := db.GetSSHPublicKeyByUserTitle(k.UserId, k.Title) | ||
if err == nil { | ||
return errors.New("key with the same title already exists"), true | ||
} | ||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.KeyStr)) | ||
if err != nil { | ||
return err, false | ||
} | ||
k.KeyStr = string(pubKey.Marshal()) | ||
k.Fingerprint = ssh.FingerprintSHA256(pubKey) | ||
k.AddedTime = time.Now() | ||
k.LastUsedTime = k.AddedTime | ||
return db.CreateSSHPublicKey(k), true | ||
} | ||
|
||
func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) { | ||
return db.GetSSHPublicKeyByUserId(userId, pageIndex, pageSize) | ||
} | ||
|
||
func GetSSHPublicKeyByIdAndUserId(id uint, userId uint) (*model.SSHPublicKey, error) { | ||
key, err := db.GetSSHPublicKeyById(id) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if key.UserId != userId { | ||
return nil, errors.Wrapf(err, "failed get old key") | ||
} | ||
return key, nil | ||
} | ||
|
||
func UpdateSSHPublicKey(k *model.SSHPublicKey) error { | ||
return db.UpdateSSHPublicKey(k) | ||
} | ||
|
||
func DeleteSSHPublicKeyById(keyId uint) error { | ||
return db.DeleteSSHPublicKeyById(keyId) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
package handles | ||
|
||
import ( | ||
"github.com/alist-org/alist/v3/internal/model" | ||
"github.com/alist-org/alist/v3/internal/op" | ||
"github.com/alist-org/alist/v3/server/common" | ||
"github.com/gin-gonic/gin" | ||
"strconv" | ||
) | ||
|
||
type SSHKeyAddReq struct { | ||
Title string `json:"title" binding:"required"` | ||
Key string `json:"key" binding:"required"` | ||
} | ||
|
||
func AddMyPublicKey(c *gin.Context) { | ||
userObj, ok := c.Value("user").(*model.User) | ||
if !ok || userObj.IsGuest() { | ||
common.ErrorStrResp(c, "user invalid", 401) | ||
return | ||
} | ||
var req SSHKeyAddReq | ||
if err := c.ShouldBind(&req); err != nil { | ||
common.ErrorStrResp(c, "request invalid", 400) | ||
return | ||
} | ||
if req.Title == "" { | ||
common.ErrorStrResp(c, "request invalid", 400) | ||
return | ||
} | ||
key := &model.SSHPublicKey{ | ||
Title: req.Title, | ||
KeyStr: req.Key, | ||
UserId: userObj.ID, | ||
} | ||
err, parsed := op.CreateSSHPublicKey(key) | ||
if !parsed { | ||
common.ErrorStrResp(c, "provided key invalid", 400) | ||
return | ||
} else if err != nil { | ||
common.ErrorResp(c, err, 500, true) | ||
return | ||
} | ||
common.SuccessResp(c) | ||
} | ||
|
||
func ListMyPublicKey(c *gin.Context) { | ||
userObj, ok := c.Value("user").(*model.User) | ||
if !ok || userObj.IsGuest() { | ||
common.ErrorStrResp(c, "user invalid", 401) | ||
return | ||
} | ||
list(c, userObj) | ||
} | ||
|
||
func DeleteMyPublicKey(c *gin.Context) { | ||
userObj, ok := c.Value("user").(*model.User) | ||
if !ok || userObj.IsGuest() { | ||
common.ErrorStrResp(c, "user invalid", 401) | ||
return | ||
} | ||
keyId, err := strconv.Atoi(c.Query("id")) | ||
if err != nil { | ||
common.ErrorStrResp(c, "id format invalid", 400) | ||
return | ||
} | ||
key, err := op.GetSSHPublicKeyByIdAndUserId(uint(keyId), userObj.ID) | ||
if err != nil { | ||
common.ErrorStrResp(c, "failed to get public key", 404) | ||
return | ||
} | ||
err = op.DeleteSSHPublicKeyById(key.ID) | ||
if err != nil { | ||
common.ErrorResp(c, err, 500, true) | ||
return | ||
} | ||
common.SuccessResp(c) | ||
} | ||
|
||
func ListPublicKeys(c *gin.Context) { | ||
userId, err := strconv.Atoi(c.Query("uid")) | ||
if err != nil { | ||
common.ErrorStrResp(c, "user id format invalid", 400) | ||
return | ||
} | ||
userObj, err := op.GetUserById(uint(userId)) | ||
if err != nil { | ||
common.ErrorStrResp(c, "user invalid", 404) | ||
return | ||
} | ||
list(c, userObj) | ||
} | ||
|
||
func DeletePublicKey(c *gin.Context) { | ||
keyId, err := strconv.Atoi(c.Query("id")) | ||
if err != nil { | ||
common.ErrorStrResp(c, "id format invalid", 400) | ||
return | ||
} | ||
err = op.DeleteSSHPublicKeyById(uint(keyId)) | ||
if err != nil { | ||
common.ErrorResp(c, err, 500, true) | ||
return | ||
} | ||
common.SuccessResp(c) | ||
} | ||
|
||
func list(c *gin.Context, userObj *model.User) { | ||
var req model.PageReq | ||
if err := c.ShouldBind(&req); err != nil { | ||
common.ErrorResp(c, err, 400) | ||
return | ||
} | ||
req.Validate() | ||
keys, total, err := op.GetSSHPublicKeyByUserId(userObj.ID, req.Page, req.PerPage) | ||
if err != nil { | ||
common.ErrorResp(c, err, 500, true) | ||
return | ||
} | ||
common.SuccessResp(c, common.PageResp{ | ||
Content: keys, | ||
Total: total, | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters