test(server): fix bugs and tests

This commit is contained in:
hexxa 2022-09-03 23:32:32 +08:00 committed by Hexxa
parent 4265ab593e
commit ce77eb7534
13 changed files with 266 additions and 288 deletions

View file

@ -71,22 +71,22 @@ type IFilesFunctions interface {
type IFileDB interface { type IFileDB interface {
AddFileInfo(ctx context.Context, userId uint64, itemPath string, info *FileInfo) error AddFileInfo(ctx context.Context, userId uint64, itemPath string, info *FileInfo) error
DelFileInfo(ctx context.Context, userId uint64, itemPath string) error DelFileInfo(ctx context.Context, userId uint64, itemPath string) error
GetFileInfo(ctx context.Context, userId uint64, itemPath string) (*FileInfo, error) GetFileInfo(ctx context.Context, itemPath string) (*FileInfo, error)
SetSha1(ctx context.Context, userId uint64, itemPath, sign string) error SetSha1(ctx context.Context, itemPath, sign string) error
MoveFileInfos(ctx context.Context, userID uint64, oldPath, newPath string, isDir bool) error MoveFileInfos(ctx context.Context, userId uint64, oldPath, newPath string, isDir bool) error
ListFileInfos(ctx context.Context, itemPaths []string) (map[string]*FileInfo, error) ListFileInfos(ctx context.Context, itemPaths []string) (map[string]*FileInfo, error)
} }
type IUploadDB interface { type IUploadDB interface {
AddUploadInfos(ctx context.Context, userId uint64, tmpPath, filePath string, info *FileInfo) error AddUploadInfos(ctx context.Context, userId uint64, tmpPath, filePath string, info *FileInfo) error
DelUploadingInfos(ctx context.Context, userId uint64, realPath string) error DelUploadingInfos(ctx context.Context, userId uint64, realPath string) error
// MoveUploadingInfos(ctx context.Context, userId uint64, uploadPath, itemPath string) error MoveUploadingInfos(ctx context.Context, userId uint64, uploadPath, itemPath string) error
SetUploadInfo(ctx context.Context, user uint64, filePath string, newUploaded int64) error SetUploadInfo(ctx context.Context, user uint64, filePath string, newUploaded int64) error
GetUploadInfo(ctx context.Context, userId uint64, filePath string) (string, int64, int64, error) GetUploadInfo(ctx context.Context, userId uint64, filePath string) (string, int64, int64, error)
ListUploadInfos(ctx context.Context, user uint64) ([]*UploadInfo, error) ListUploadInfos(ctx context.Context, user uint64) ([]*UploadInfo, error)
} }
type ISharingDB interface { type ISharingDB interface {
IsSharing(ctx context.Context, userId uint64, dirPath string) bool IsSharing(ctx context.Context, userId uint64, dirPath string) (bool, error)
GetSharingDir(ctx context.Context, hashID string) (string, error) GetSharingDir(ctx context.Context, hashID string) (string, error)
AddSharing(ctx context.Context, userId uint64, dirPath string) error AddSharing(ctx context.Context, userId uint64, dirPath string) error
DelSharing(ctx context.Context, userId uint64, dirPath string) error DelSharing(ctx context.Context, userId uint64, dirPath string) error

View file

@ -23,7 +23,7 @@ var (
maxHashingTime = 10 maxHashingTime = 10
) )
func (st *SQLiteStore) getFileInfo(ctx context.Context, userId uint64, itemPath string) (*db.FileInfo, error) { func (st *SQLiteStore) getFileInfo(ctx context.Context, itemPath string) (*db.FileInfo, error) {
var infoStr string var infoStr string
fInfo := &db.FileInfo{} fInfo := &db.FileInfo{}
var isDir bool var isDir bool
@ -33,10 +33,8 @@ func (st *SQLiteStore) getFileInfo(ctx context.Context, userId uint64, itemPath
ctx, ctx,
`select is_dir, size, share_id, info `select is_dir, size, share_id, info
from t_file_info from t_file_info
where path=? and user=? where path=?`,
`,
itemPath, itemPath,
userId,
).Scan( ).Scan(
&isDir, &isDir,
&size, &size,
@ -61,11 +59,11 @@ func (st *SQLiteStore) getFileInfo(ctx context.Context, userId uint64, itemPath
return fInfo, nil return fInfo, nil
} }
func (st *SQLiteStore) GetFileInfo(ctx context.Context, userId uint64, itemPath string) (*db.FileInfo, error) { func (st *SQLiteStore) GetFileInfo(ctx context.Context, itemPath string) (*db.FileInfo, error) {
st.RLock() st.RLock()
defer st.RUnlock() defer st.RUnlock()
return st.getFileInfo(ctx, userId, itemPath) return st.getFileInfo(ctx, itemPath)
} }
func (st *SQLiteStore) ListFileInfos(ctx context.Context, itemPaths []string) (map[string]*db.FileInfo, error) { func (st *SQLiteStore) ListFileInfos(ctx context.Context, itemPaths []string) (map[string]*db.FileInfo, error) {
@ -160,31 +158,22 @@ func (st *SQLiteStore) AddFileInfo(ctx context.Context, userId uint64, itemPath
return st.setUsed(ctx, userId, true, info.Size) return st.setUsed(ctx, userId, true, info.Size)
} }
func (st *SQLiteStore) delFileInfo(ctx context.Context, userId uint64, itemPath string) error { func (st *SQLiteStore) delFileInfo(ctx context.Context, itemPath string) error {
_, err := st.db.ExecContext( _, err := st.db.ExecContext(
ctx, ctx,
`delete from t_file_info `delete from t_file_info
where path=? and user=? where path=?
`, `,
itemPath, itemPath,
userId,
) )
return err return err
} }
// func (st *SQLiteStore) DelFileInfo(ctx context.Context, itemPath string) error { func (st *SQLiteStore) SetSha1(ctx context.Context, itemPath, sign string) error {
// st.Lock()
// defer st.Unlock()
// return st.delFileInfo(ctx, itemPath)
// }
// sharings
func (st *SQLiteStore) SetSha1(ctx context.Context, userId uint64, itemPath, sign string) error {
st.Lock() st.Lock()
defer st.Unlock() defer st.Unlock()
info, err := st.getFileInfo(ctx, userId, itemPath) info, err := st.getFileInfo(ctx, itemPath)
if err != nil { if err != nil {
return err return err
} }
@ -199,10 +188,9 @@ func (st *SQLiteStore) SetSha1(ctx context.Context, userId uint64, itemPath, sig
ctx, ctx,
`update t_file_info `update t_file_info
set info=? set info=?
where path=? and user=?`, where path=?`,
infoStr, infoStr,
itemPath, itemPath,
userId,
) )
return err return err
} }
@ -264,11 +252,17 @@ func (st *SQLiteStore) MoveFileInfos(ctx context.Context, userId uint64, oldPath
st.Lock() st.Lock()
defer st.Unlock() defer st.Unlock()
info, err := st.getFileInfo(ctx, userId, oldPath) info, err := st.getFileInfo(ctx, oldPath)
if err != nil { if err != nil {
if errors.Is(err, db.ErrFileInfoNotFound) {
// info for file does not exist so no need to move it
// e.g. folder info is not created before
// TODO: but sometimes it could be a bug
return nil
}
return err return err
} }
err = st.delFileInfo(ctx, userId, oldPath) err = st.delFileInfo(ctx, oldPath)
if err != nil { if err != nil {
return err return err
} }

View file

@ -29,16 +29,29 @@ func (st *SQLiteStore) generateShareID(payload string) (string, error) {
return fmt.Sprintf("%x", h.Sum(nil))[:7], nil return fmt.Sprintf("%x", h.Sum(nil))[:7], nil
} }
func (st *SQLiteStore) IsSharing(ctx context.Context, userId uint64, dirPath string) bool { func (st *SQLiteStore) IsSharing(ctx context.Context, userId uint64, dirPath string) (bool, error) {
st.RLock() st.RLock()
defer st.RUnlock() defer st.RUnlock()
// TODO: differentiate error and not exist // TODO: userId is not used, becauser it is searcher's userId
info, err := st.getFileInfo(ctx, userId, dirPath) var shareId string
err := st.db.QueryRowContext(
ctx,
`select share_id
from t_file_info
where path=?`,
dirPath,
).Scan(
&shareId,
)
if err != nil { if err != nil {
return false if errors.Is(err, sql.ErrNoRows) {
return false, db.ErrFileInfoNotFound
}
return false, err
} }
return info.ShareID != ""
return shareId != "", nil
} }
func (st *SQLiteStore) GetSharingDir(ctx context.Context, hashID string) (string, error) { func (st *SQLiteStore) GetSharingDir(ctx context.Context, hashID string) (string, error) {
@ -75,7 +88,7 @@ func (st *SQLiteStore) AddSharing(ctx context.Context, userId uint64, dirPath st
return err return err
} }
_, err = st.getFileInfo(ctx, userId, dirPath) _, err = st.getFileInfo(ctx, dirPath)
if err != nil && !errors.Is(err, db.ErrFileInfoNotFound) { if err != nil && !errors.Is(err, db.ErrFileInfoNotFound) {
return err return err
} }
@ -102,8 +115,8 @@ func (st *SQLiteStore) AddSharing(ctx context.Context, userId uint64, dirPath st
ctx, ctx,
`update t_file_info `update t_file_info
set share_id=? set share_id=?
where path=? and user=?`, where path=?`,
shareID, dirPath, userId, shareID, dirPath,
) )
return err return err
} }
@ -116,9 +129,8 @@ func (st *SQLiteStore) DelSharing(ctx context.Context, userId uint64, dirPath st
ctx, ctx,
`update t_file_info `update t_file_info
set share_id='' set share_id=''
where path=? and user=?`, where path=?`,
dirPath, dirPath,
userId,
) )
return err return err
} }

View file

@ -8,7 +8,7 @@ import (
"github.com/ihexxa/quickshare/src/db" "github.com/ihexxa/quickshare/src/db"
) )
func (st *SQLiteStore) addUploadInfoOnly(ctx context.Context, userId uint64, filePath, tmpPath string, fileSize int64) error { func (st *SQLiteStore) addUploadInfoOnly(ctx context.Context, userId uint64, tmpPath, filePath string, fileSize int64) error {
_, err := st.db.ExecContext( _, err := st.db.ExecContext(
ctx, ctx,
`insert into t_file_uploading `insert into t_file_uploading
@ -42,7 +42,7 @@ func (st *SQLiteStore) AddUploadInfos(ctx context.Context, userId uint64, tmpPat
return err return err
} }
return st.addUploadInfoOnly(ctx, userId, filePath, tmpPath, info.Size) return st.addUploadInfoOnly(ctx, userId, tmpPath, filePath, info.Size)
} }
func (st *SQLiteStore) DelUploadingInfos(ctx context.Context, userId uint64, realPath string) error { func (st *SQLiteStore) DelUploadingInfos(ctx context.Context, userId uint64, realPath string) error {
@ -82,22 +82,22 @@ func (st *SQLiteStore) delUploadInfoOnly(ctx context.Context, userId uint64, fil
return err return err
} }
// func (st *SQLiteStore) MoveUploadingInfos(ctx context.Context, userId uint64, uploadPath, itemPath string) error { func (st *SQLiteStore) MoveUploadingInfos(ctx context.Context, userId uint64, uploadPath, itemPath string) error {
// st.Lock() st.Lock()
// defer st.Unlock() defer st.Unlock()
// _, size, _, err := st.getUploadInfo(ctx, userId, itemPath) _, size, _, err := st.getUploadInfo(ctx, userId, itemPath)
// if err != nil { if err != nil {
// return err return err
// } }
// err = st.delUploadInfoOnly(ctx, userId, itemPath) err = st.delUploadInfoOnly(ctx, userId, itemPath)
// if err != nil { if err != nil {
// return err return err
// } }
// return st.addFileInfo(ctx, userId, itemPath, &db.FileInfo{ return st.addFileInfo(ctx, userId, itemPath, &db.FileInfo{
// Size: size, Size: size,
// }) })
// } }
func (st *SQLiteStore) SetUploadInfo(ctx context.Context, userId uint64, filePath string, newUploaded int64) error { func (st *SQLiteStore) SetUploadInfo(ctx context.Context, userId uint64, filePath string, newUploaded int64) error {
st.Lock() st.Lock()

View file

@ -49,7 +49,7 @@ func (h *FileHandlers) genSha1(msg worker.IMsg) error {
sha1Sign := fmt.Sprintf("%x", hasher.Sum(nil)) sha1Sign := fmt.Sprintf("%x", hasher.Sum(nil))
err = h.deps.FileInfos(). err = h.deps.FileInfos().
SetSha1(context.TODO(), taskInputs.UserId, taskInputs.FilePath, sha1Sign) // TODO: use source context SetSha1(context.TODO(), taskInputs.FilePath, sha1Sign) // TODO: use source context
if err != nil { if err != nil {
return fmt.Errorf("fail to set sha1: %s", err) return fmt.Errorf("fail to set sha1: %s", err)
} }

View file

@ -11,6 +11,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -41,14 +42,16 @@ const (
) )
type FileHandlers struct { type FileHandlers struct {
cfg gocfg.ICfg cfg gocfg.ICfg
deps *depidx.Deps deps *depidx.Deps
lockedPaths *sync.Map
} }
func NewFileHandlers(cfg gocfg.ICfg, deps *depidx.Deps) (*FileHandlers, error) { func NewFileHandlers(cfg gocfg.ICfg, deps *depidx.Deps) (*FileHandlers, error) {
handlers := &FileHandlers{ handlers := &FileHandlers{
cfg: cfg, cfg: cfg,
deps: deps, deps: deps,
lockedPaths: &sync.Map{},
} }
deps.Workers().AddHandler(MsgTypeSha1, handlers.genSha1) deps.Workers().AddHandler(MsgTypeSha1, handlers.genSha1)
deps.Workers().AddHandler(MsgTypeIndexing, handlers.indexingItems) deps.Workers().AddHandler(MsgTypeIndexing, handlers.indexingItems)
@ -56,43 +59,26 @@ func NewFileHandlers(cfg gocfg.ICfg, deps *depidx.Deps) (*FileHandlers, error) {
return handlers, nil return handlers, nil
} }
type AutoLocker struct { func (h *FileHandlers) lock(key string, code *int, err *error, execution func() (int, error)) {
h *FileHandlers var loaded bool
c *gin.Context
key string
}
func (h *FileHandlers) NewAutoLocker(c *gin.Context, key string) *AutoLocker {
return &AutoLocker{
h: h,
c: c,
key: key,
}
}
func (lk *AutoLocker) Exec(handler func()) error {
var err error
kv := lk.h.deps.KV()
locked := false
defer func() { defer func() {
if p := recover(); p != nil { if p := recover(); p != nil {
lk.h.deps.Log().Error(p) h.deps.Log().Error(p)
*code, *err = 500, fmt.Errorf("%s", p)
} }
if locked { if !loaded {
if err = kv.Unlock(lk.key); err != nil { h.lockedPaths.Delete(key)
lk.h.deps.Log().Error(err)
}
} }
}() }()
if err = kv.TryLock(lk.key); err != nil { _, loaded = h.lockedPaths.LoadOrStore(key, true)
return errors.New("fail to lock the file") if loaded {
*code, *err = 429, fmt.Errorf("failed to lock: %s", key)
return
} }
locked = true *code, *err = execution()
handler()
return nil
} }
// related elements: role, user, action(listing, downloading)/sharing // related elements: role, user, action(listing, downloading)/sharing
@ -115,7 +101,10 @@ func (h *FileHandlers) canAccess(ctx context.Context, userId uint64, userName, r
return false return false
} }
isSharing := h.deps.FileInfos().IsSharing(ctx, userId, sharedPath) isSharing, err := h.deps.FileInfos().IsSharing(ctx, userId, sharedPath)
if err != nil {
return false // TODO: return error
}
return isSharing return isSharing
} }
@ -168,7 +157,7 @@ func (h *FileHandlers) Create(c *gin.Context) {
} }
return return
} }
err = h.deps.FileInfos().MoveFileInfos(c, userID, tmpFilePath, fsFilePath, false) err = h.deps.FileInfos().MoveUploadingInfos(c, userID, tmpFilePath, fsFilePath)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -233,35 +222,25 @@ func (h *FileHandlers) Create(c *gin.Context) {
return return
} }
var txErr error var code int
locker := h.NewAutoLocker(c, lockName(tmpFilePath)) h.lock(lockName(tmpFilePath), &code, &err, func() (int, error) {
lockErr := locker.Exec(func() { err := h.deps.FS().Create(tmpFilePath)
err = h.deps.FS().Create(tmpFilePath)
if err != nil { if err != nil {
if os.IsExist(err) { if os.IsExist(err) {
createErr := fmt.Errorf("file(%s) exists", tmpFilePath) createErr := fmt.Errorf("file(%s) exists", tmpFilePath)
c.JSON(q.ErrResp(c, 304, createErr)) return 304, createErr
txErr = createErr
} else {
c.JSON(q.ErrResp(c, 500, err))
txErr = err
} }
return return 500, err
} }
err = h.deps.FS().MkdirAll(filepath.Dir(req.Path)) err = h.deps.FS().MkdirAll(filepath.Dir(req.Path))
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) return 500, err
txErr = err
return
} }
return 200, nil
}) })
if lockErr != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, lockErr)) c.JSON(q.ErrResp(c, code, err))
return
}
if txErr != nil {
c.JSON(q.ErrResp(c, 500, txErr))
return return
} }
c.JSON(q.Resp(200)) c.JSON(q.Resp(200))
@ -288,33 +267,28 @@ func (h *FileHandlers) Delete(c *gin.Context) {
return return
} }
var txErr error // var txErr error
locker := h.NewAutoLocker(c, lockName(filePath)) // locker := h.NewAutoLocker(c, lockName(filePath))
lockErr := locker.Exec(func() { var code int
err = h.deps.FS().Remove(filePath) h.lock(lockName(filePath), &code, &err, func() (int, error) {
err := h.deps.FS().Remove(filePath)
if err != nil { if err != nil {
txErr = err return 500, err
return
} }
err = h.deps.FileInfos().DelFileInfo(c, userId, filePath) err = h.deps.FileInfos().DelFileInfo(c, userId, filePath)
if err != nil { if err != nil {
txErr = err return 500, err
return
} }
err = h.deps.FileIndex().DelPath(filePath) err = h.deps.FileIndex().DelPath(filePath)
if err != nil && !errors.Is(err, fsearch.ErrNotFound) { if err != nil && !errors.Is(err, fsearch.ErrNotFound) {
txErr = err return 500, err
return
} }
return 200, nil
}) })
if lockErr != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, lockErr)) c.JSON(q.ErrResp(c, code, err))
return
}
if txErr != nil {
c.JSON(q.ErrResp(c, 500, txErr))
return return
} }
c.JSON(q.Resp(200)) c.JSON(q.Resp(200))
@ -510,53 +484,48 @@ func (h *FileHandlers) UploadChunk(c *gin.Context) {
return return
} }
var txErr error // var txErr error
var statusCode int // var statusCode int
// locker := h.NewAutoLocker(c, lockName(tmpFilePath))
tmpFilePath := q.UploadPath(userName, filePath) tmpFilePath := q.UploadPath(userName, filePath)
locker := h.NewAutoLocker(c, lockName(tmpFilePath)) var code int
fsFilePath, fileSize, uploaded, wrote := "", int64(0), int64(0), 0 fsFilePath, fileSize, uploaded, wrote := "", int64(0), int64(0), 0
lockErr := locker.Exec(func() { h.lock(lockName(tmpFilePath), &code, &err, func() (int, error) {
// lockErr := locker.Exec(func() {
var err error var err error
fsFilePath, fileSize, uploaded, err = h.deps.FileInfos().GetUploadInfo(c, userId, tmpFilePath) fsFilePath, fileSize, uploaded, err = h.deps.FileInfos().GetUploadInfo(c, userId, filePath)
if err != nil { if err != nil {
txErr, statusCode = err, 500 return 500, err
return
} else if uploaded != req.Offset { } else if uploaded != req.Offset {
txErr, statusCode = errors.New("offset != uploaded"), 500 return 500, errors.New("offset != uploaded")
return
} }
content, err := base64.StdEncoding.DecodeString(req.Content) content, err := base64.StdEncoding.DecodeString(req.Content)
if err != nil { if err != nil {
txErr, statusCode = err, 500 return 500, err
return
} }
wrote, err = h.deps.FS().WriteAt(tmpFilePath, []byte(content), req.Offset) wrote, err = h.deps.FS().WriteAt(tmpFilePath, []byte(content), req.Offset)
if err != nil { if err != nil {
txErr, statusCode = err, 500 return 500, err
return
} }
err = h.deps.FileInfos().SetUploadInfo(c, userId, tmpFilePath, req.Offset+int64(wrote)) err = h.deps.FileInfos().SetUploadInfo(c, userId, filePath, req.Offset+int64(wrote))
if err != nil { if err != nil {
txErr, statusCode = err, 500 return 500, err
return
} }
// move the file from uploading dir to uploaded dir // move the file from uploading dir to uploaded dir
if uploaded+int64(wrote) == fileSize { if uploaded+int64(wrote) == fileSize {
err = h.deps.FileInfos().MoveFileInfos(c, userId, tmpFilePath, fsFilePath, false) err = h.deps.FileInfos().MoveUploadingInfos(c, userId, tmpFilePath, fsFilePath)
if err != nil { if err != nil {
txErr, statusCode = err, 500 return 500, err
return
} }
err = h.deps.FS().Rename(tmpFilePath, fsFilePath) err = h.deps.FS().Rename(tmpFilePath, fsFilePath)
if err != nil { if err != nil {
txErr, statusCode = fmt.Errorf("%s error: %w", fsFilePath, err), 500 return 500, fmt.Errorf("%s error: %w", fsFilePath, err)
return
} }
msg, err := json.Marshal(Sha1Params{ msg, err := json.Marshal(Sha1Params{
@ -564,8 +533,7 @@ func (h *FileHandlers) UploadChunk(c *gin.Context) {
FilePath: fsFilePath, FilePath: fsFilePath,
}) })
if err != nil { if err != nil {
txErr, statusCode = err, 500 return 500, err
return
} }
err = h.deps.Workers().TryPut( err = h.deps.Workers().TryPut(
@ -576,24 +544,21 @@ func (h *FileHandlers) UploadChunk(c *gin.Context) {
), ),
) )
if err != nil { if err != nil {
txErr, statusCode = err, 500 return 500, err
return
} }
err = h.deps.FileIndex().AddPath(fsFilePath) err = h.deps.FileIndex().AddPath(fsFilePath)
if err != nil { if err != nil {
txErr, statusCode = err, 500 return 500, err
return
} }
} }
return 200, nil
}) })
if lockErr != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, code, err))
}
if txErr != nil {
c.JSON(q.ErrResp(c, statusCode, txErr))
return return
} }
c.JSON(200, &UploadStatusResp{ c.JSON(200, &UploadStatusResp{
Path: fsFilePath, Path: fsFilePath,
IsDir: false, IsDir: false,
@ -667,32 +632,28 @@ func (h *FileHandlers) UploadStatus(c *gin.Context) {
return return
} }
// locker := h.NewAutoLocker(c, lockName(tmpFilePath))
// var txErr error
tmpFilePath := q.UploadPath(userName, filePath) tmpFilePath := q.UploadPath(userName, filePath)
locker := h.NewAutoLocker(c, lockName(tmpFilePath))
fileSize, uploaded := int64(0), int64(0) fileSize, uploaded := int64(0), int64(0)
var txErr error var code int
lockErr := locker.Exec(func() { h.lock(lockName(tmpFilePath), &code, &err, func() (int, error) {
// lockErr := locker.Exec(func() {
var err error var err error
_, fileSize, uploaded, err = h.deps.FileInfos().GetUploadInfo(c, userId, tmpFilePath) _, fileSize, uploaded, err = h.deps.FileInfos().GetUploadInfo(c, userId, filePath)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
c.JSON(q.ErrResp(c, 404, err)) return 404, err
txErr = err
} else {
c.JSON(q.ErrResp(c, 500, err))
txErr = err
} }
return return 500, err
} }
return 200, nil
}) })
if lockErr != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, lockErr)) c.JSON(q.ErrResp(c, code, err))
return
}
if txErr != nil {
c.JSON(q.ErrResp(c, 500, txErr))
return return
} }
c.JSON(200, &UploadStatusResp{ c.JSON(200, &UploadStatusResp{
Path: filePath, Path: filePath,
IsDir: false, IsDir: false,
@ -981,35 +942,32 @@ func (h *FileHandlers) DelUploading(c *gin.Context) {
return return
} }
var txErr error // var txErr error
var statusCode int // var statusCode int
tmpFilePath := q.UploadPath(userName, filePath) tmpFilePath := q.UploadPath(userName, filePath)
locker := h.NewAutoLocker(c, lockName(tmpFilePath)) // locker := h.NewAutoLocker(c, lockName(tmpFilePath))
lockErr := locker.Exec(func() { // lockErr := locker.Exec(func() {
var code int
h.lock(lockName(tmpFilePath), &code, &err, func() (int, error) {
_, err = h.deps.FS().Stat(tmpFilePath) _, err = h.deps.FS().Stat(tmpFilePath)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
// no op // no op
} else { } else {
txErr, statusCode = err, 500 return 500, err
return
}
} else {
err = h.deps.FS().Remove(tmpFilePath)
if err != nil {
txErr, statusCode = err, 500
return
} }
} }
err = h.deps.FS().Remove(tmpFilePath)
if err != nil {
return 500, err
}
return 200, nil
}) })
if lockErr != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, lockErr)) c.JSON(q.ErrResp(c, code, err))
return
}
if txErr != nil {
c.JSON(q.ErrResp(c, statusCode, txErr))
return return
} }
err = h.deps.FileInfos().DelUploadingInfos(c, userId, filePath) err = h.deps.FileInfos().DelUploadingInfos(c, userId, filePath)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
@ -1111,8 +1069,14 @@ func (h *FileHandlers) IsSharing(c *gin.Context) {
return return
} }
exist := h.deps.FileInfos().IsSharing(c, userId, dirPath) exist, err := h.deps.FileInfos().IsSharing(c, userId, dirPath)
if exist { if err != nil {
if errors.Is(err, db.ErrFileInfoNotFound) {
c.JSON(q.Resp(404))
} else {
c.JSON(q.Resp(500))
}
} else if exist {
c.JSON(q.Resp(200)) c.JSON(q.Resp(200))
} else { } else {
c.JSON(q.Resp(404)) c.JSON(q.Resp(404))

View file

@ -145,7 +145,7 @@ func NewMultiUsersSvc(cfg gocfg.ICfg, deps *depidx.Deps) (*MultiUsersSvc, error)
return handlers, nil return handlers, nil
} }
func (h *MultiUsersSvc) Init(ctx context.Context, adminName, adminPwd string) (string, error) { func (h *MultiUsersSvc) Init(ctx context.Context, adminName string) (string, error) {
var err error var err error
fsPath := q.FsRootPath(adminName, "/") fsPath := q.FsRootPath(adminName, "/")

View file

@ -8,8 +8,8 @@ import (
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/ihexxa/quickshare/src/cryptoutil" "github.com/ihexxa/quickshare/src/cryptoutil"
"github.com/ihexxa/quickshare/src/db"
) )
var ( var (
@ -171,6 +171,8 @@ func GetUserId(ctx *gin.Context) (uint64, error) {
if !ok { if !ok {
return 0, errors.New("user id not found") return 0, errors.New("user id not found")
} }
if userID == "" {
return db.VisitorID, nil
}
return strconv.ParseUint(userID, 10, 64) return strconv.ParseUint(userID, 10, 64)
} }

View file

@ -9,8 +9,7 @@ import (
const fileIndexPath = "/fileindex.jsonl" const fileIndexPath = "/fileindex.jsonl"
type DbConfig struct { type DbConfig struct {
DbPath string `json:"dbPath" yaml:"dbPath"` DbPath string `json:"dbPath" yaml:"dbPath"`
RdbPath string `json:"rdbPath" yaml:"rdbPath"` // valid values: rdb, kv
} }
type FSConfig struct { type FSConfig struct {
@ -141,8 +140,7 @@ func DefaultConfigStruct() *Config {
}, },
}, },
Db: &DbConfig{ Db: &DbConfig{
DbPath: "quickshare.db", DbPath: "quickshare.sqlite",
RdbPath: "quickshare.sqlite",
}, },
} }
} }

View file

@ -57,9 +57,9 @@ func NewServer(cfg gocfg.ICfg) (*Server, error) {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
deps := initDeps(cfg) deps, adminName := initDeps(cfg)
router := gin.Default() router := gin.Default()
router, err := initHandlers(router, cfg, deps) router, err := initHandlers(router, adminName, cfg, deps)
if err != nil { if err != nil {
return nil, fmt.Errorf("init handlers error: %w", err) return nil, fmt.Errorf("init handlers error: %w", err)
} }
@ -104,7 +104,7 @@ func mkRoot(rootPath string) {
} }
} }
func initDeps(cfg gocfg.ICfg) *depidx.Deps { func initDeps(cfg gocfg.ICfg) (*depidx.Deps, string) {
var err error var err error
logger := initLogger(cfg) logger := initLogger(cfg)
@ -142,7 +142,7 @@ func initDeps(cfg gocfg.ICfg) *depidx.Deps {
// panic(fmt.Sprintf("failed to init bolt store: %s", err)) // panic(fmt.Sprintf("failed to init bolt store: %s", err))
// } // }
quickshareDb, err := initDB(cfg, filesystem) quickshareDb, adminName, err := initDB(cfg, filesystem)
if err != nil { if err != nil {
logger.Errorf("failed to init DB: %s", err) logger.Errorf("failed to init DB: %s", err)
os.Exit(1) os.Exit(1)
@ -197,25 +197,25 @@ func initDeps(cfg gocfg.ICfg) *depidx.Deps {
logger.Infof("file index inited(%t)", indexInited) logger.Infof("file index inited(%t)", indexInited)
deps.SetFileIndex(fileIndex) deps.SetFileIndex(fileIndex)
return deps return deps, adminName
} }
func initDB(cfg gocfg.ICfg, filesystem fs.ISimpleFS) (db.IDBQuickshare, error) { func initDB(cfg gocfg.ICfg, filesystem fs.ISimpleFS) (db.IDBQuickshare, string, error) {
dbPath := cfg.GrabString("Db.DbPath") dbPath := cfg.GrabString("Db.DbPath")
dbDir := path.Dir(dbPath) dbDir := path.Dir(dbPath)
err := filesystem.MkdirAll(dbDir) err := filesystem.MkdirAll(dbDir)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create path for db: %w", err) return nil, "", fmt.Errorf("failed to create path for db: %w", err)
} }
sqliteDB, err := sqlite.NewSQLite(dbPath) sqliteDB, err := sqlite.NewSQLite(dbPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create path for db: %w", err) return nil, "", fmt.Errorf("failed to create path for db: %w", err)
} }
dbQuickshare, err := sqlite.NewSQLiteStore(sqliteDB) dbQuickshare, err := sqlite.NewSQLiteStore(sqliteDB)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create quickshare db: %w", err) return nil, "", fmt.Errorf("failed to create quickshare db: %w", err)
} }
var ok bool var ok bool
@ -232,20 +232,20 @@ func initDB(cfg gocfg.ICfg, filesystem fs.ISimpleFS) (db.IDBQuickshare, error) {
if adminPwd == "" { if adminPwd == "" {
adminPwd, err = generatePwd() adminPwd, err = generatePwd()
if err != nil { if err != nil {
return nil, fmt.Errorf("generate password error: %w", err) return nil, "", fmt.Errorf("generate password error: %w", err)
} }
fmt.Printf("password is generated: %s, please update it immediately after login\n", adminPwd) fmt.Printf("password is generated: %s, please update it immediately after login\n", adminPwd)
} }
pwdHash, err = bcrypt.GenerateFromPassword([]byte(adminPwd), 10) pwdHash, err = bcrypt.GenerateFromPassword([]byte(adminPwd), 10)
if err != nil { if err != nil {
return nil, fmt.Errorf("hashing password error: %w", err) return nil, "", fmt.Errorf("hashing password error: %w", err)
} }
} }
err = dbQuickshare.InitUserTable(context.TODO(), adminName, string(pwdHash)) err = dbQuickshare.InitUserTable(context.TODO(), adminName, string(pwdHash))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to init user table: %w", err) return nil, "", fmt.Errorf("failed to init user table: %w", err)
} }
err = dbQuickshare.InitConfigTable( err = dbQuickshare.InitConfigTable(
context.TODO(), context.TODO(),
@ -264,21 +264,26 @@ func initDB(cfg gocfg.ICfg, filesystem fs.ISimpleFS) (db.IDBQuickshare, error) {
}, },
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to init config table: %w", err) return nil, "", fmt.Errorf("failed to init config table: %w", err)
} }
err = dbQuickshare.InitFileTables(context.TODO()) err = dbQuickshare.InitFileTables(context.TODO())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to init files tables: %w", err) return nil, "", fmt.Errorf("failed to init files tables: %w", err)
} }
return dbQuickshare, nil return dbQuickshare, adminName, nil
} }
func initHandlers(router *gin.Engine, cfg gocfg.ICfg, deps *depidx.Deps) (*gin.Engine, error) { func initHandlers(router *gin.Engine, adminName string, cfg gocfg.ICfg, deps *depidx.Deps) (*gin.Engine, error) {
// handlers // handlers
userHdrs, err := multiusers.NewMultiUsersSvc(cfg, deps) userHdrs, err := multiusers.NewMultiUsersSvc(cfg, deps)
if err != nil { if err != nil {
return nil, fmt.Errorf("new users svc error: %w", err) return nil, fmt.Errorf("new users svc error: %w", err)
} }
_, err = userHdrs.Init(context.TODO(), adminName)
if err != nil {
return nil, fmt.Errorf("failed to init user handlers: %w", err)
}
fileHdrs, err := fileshdr.NewFileHandlers(cfg, deps) fileHdrs, err := fileshdr.NewFileHandlers(cfg, deps)
if err != nil { if err != nil {
return nil, fmt.Errorf("new files service error: %w", err) return nil, fmt.Errorf("new files service error: %w", err)
@ -323,10 +328,10 @@ func initHandlers(router *gin.Engine, cfg gocfg.ICfg, deps *depidx.Deps) (*gin.E
usersAPI.PATCH("/preferences", userHdrs.SetPreferences) usersAPI.PATCH("/preferences", userHdrs.SetPreferences)
usersAPI.PUT("/used-space", userHdrs.ResetUsedSpace) usersAPI.PUT("/used-space", userHdrs.ResetUsedSpace)
rolesAPI := v1.Group("/roles") // rolesAPI := v1.Group("/roles")
rolesAPI.POST("/", userHdrs.AddRole) // rolesAPI.POST("/", userHdrs.AddRole)
rolesAPI.DELETE("/", userHdrs.DelRole) // rolesAPI.DELETE("/", userHdrs.DelRole)
rolesAPI.GET("/list", userHdrs.ListRoles) // rolesAPI.GET("/list", userHdrs.ListRoles)
captchaAPI := v1.Group("/captchas") captchaAPI := v1.Group("/captchas")
captchaAPI.GET("/", userHdrs.GetCaptchaID) captchaAPI.GET("/", userHdrs.GetCaptchaID)

View file

@ -452,6 +452,8 @@ func TestFileHandlers(t *testing.T) {
t.Fatal(errs) t.Fatal(errs)
} else if res.StatusCode != 200 { } else if res.StatusCode != 200 {
t.Fatal(res.StatusCode) t.Fatal(res.StatusCode)
} else if len(shRes.IDs) != len(sharedPaths) {
t.Fatal("shared size not match")
} }
for dirPath, shareID := range shRes.IDs { for dirPath, shareID := range shRes.IDs {
if !sharedPaths[dirPath] { if !sharedPaths[dirPath] {
@ -469,6 +471,7 @@ func TestFileHandlers(t *testing.T) {
} }
} }
fmt.Println("\n\n\n", shRes.IDs)
// check isSharing // check isSharing
for dirPath := range sharedPaths { for dirPath := range sharedPaths {
res, _, errs := userFilesCl.IsSharing(dirPath) res, _, errs := userFilesCl.IsSharing(dirPath)
@ -673,7 +676,7 @@ func TestFileHandlers(t *testing.T) {
} else if res.StatusCode != 200 { } else if res.StatusCode != 200 {
t.Fatal(res.StatusCode) t.Fatal(res.StatusCode)
} else if len(lResp.UploadInfos) != 0 { } else if len(lResp.UploadInfos) != 0 {
t.Fatalf("info is not deleted, info len(%d)", len(lResp.UploadInfos)) t.Fatalf("info is not deleted, info len(%+v)", lResp.UploadInfos)
} }
}) })
@ -828,7 +831,7 @@ func TestFileHandlers(t *testing.T) {
res, _, errs = adminFilesClient.IsSharing(dstDir) res, _, errs = adminFilesClient.IsSharing(dstDir)
if len(errs) > 0 { if len(errs) > 0 {
t.Fatal(errs) t.Fatal(errs)
} else if res.StatusCode != 404 { // should not be in sharing } else if res.StatusCode != 200 { // should still be in sharing
t.Fatal(res.StatusCode) t.Fatal(res.StatusCode)
} }

View file

@ -112,7 +112,7 @@ func TestPermissions(t *testing.T) {
} }
tmpUser, tmpPwd, tmpRole := "tmpUser", "1234", "user" tmpUser, tmpPwd, tmpRole := "tmpUser", "1234", "user"
tmpAdmin, tmpAdminPwd := "tmpAdmin", "1234" tmpAdmin, tmpAdminPwd := "tmpAdmin", "1234"
tmpNewRole := "tmpNewRole" // tmpNewRole := "tmpNewRole"
cl := client.NewUsersClient(addr) cl := client.NewUsersClient(addr)
// token := &http.Cookie{} // token := &http.Cookie{}
@ -204,14 +204,14 @@ func TestPermissions(t *testing.T) {
assertResp(t, resp, errs, expectedCodes["DelUserAdmin"], fmt.Sprintf("%s-%s", desc, "DelUserAdmin")) assertResp(t, resp, errs, expectedCodes["DelUserAdmin"], fmt.Sprintf("%s-%s", desc, "DelUserAdmin"))
// role management // role management
resp, _, errs = cl.AddRole(tmpNewRole) // resp, _, errs = cl.AddRole(tmpNewRole)
assertResp(t, resp, errs, expectedCodes["AddRole"], fmt.Sprintf("%s-%s", desc, "AddRole")) // assertResp(t, resp, errs, expectedCodes["AddRole"], fmt.Sprintf("%s-%s", desc, "AddRole"))
resp, _, errs = cl.ListRoles() // resp, _, errs = cl.ListRoles()
assertResp(t, resp, errs, expectedCodes["ListRoles"], fmt.Sprintf("%s-%s", desc, "ListRoles")) // assertResp(t, resp, errs, expectedCodes["ListRoles"], fmt.Sprintf("%s-%s", desc, "ListRoles"))
resp, _, errs = cl.DelRole(tmpNewRole) // resp, _, errs = cl.DelRole(tmpNewRole)
assertResp(t, resp, errs, expectedCodes["DelRole"], fmt.Sprintf("%s-%s", desc, "DelRole")) // assertResp(t, resp, errs, expectedCodes["DelRole"], fmt.Sprintf("%s-%s", desc, "DelRole"))
if requireAuth { if requireAuth {
resp, _, errs := cl.Logout() resp, _, errs := cl.Logout()

View file

@ -384,71 +384,71 @@ func TestUsersHandlers(t *testing.T) {
} }
}) })
t.Run("test roles APIs: Login-AddRole-ListRoles-DelRole-ListRoles-Logout", func(t *testing.T) { // t.Run("test roles APIs: Login-AddRole-ListRoles-DelRole-ListRoles-Logout", func(t *testing.T) {
adminUsersCli := client.NewUsersClient(addr) // adminUsersCli := client.NewUsersClient(addr)
resp, _, errs := adminUsersCli.Login(adminName, adminNewPwd) // resp, _, errs := adminUsersCli.Login(adminName, adminNewPwd)
if len(errs) > 0 { // if len(errs) > 0 {
t.Fatal(errs) // t.Fatal(errs)
} else if resp.StatusCode != 200 { // } else if resp.StatusCode != 200 {
t.Fatal(resp.StatusCode) // t.Fatal(resp.StatusCode)
} // }
// token := client.GetCookie(resp.Cookies(), su.TokenCookie) // // token := client.GetCookie(resp.Cookies(), su.TokenCookie)
roles := []string{"role1", "role2"} // roles := []string{"role1", "role2"}
for _, role := range roles { // for _, role := range roles {
resp, _, errs := adminUsersCli.AddRole(role) // resp, _, errs := adminUsersCli.AddRole(role)
if len(errs) > 0 { // if len(errs) > 0 {
t.Fatal(errs) // t.Fatal(errs)
} else if resp.StatusCode != 200 { // } else if resp.StatusCode != 200 {
t.Fatal(resp.StatusCode) // t.Fatal(resp.StatusCode)
} // }
} // }
resp, lsResp, errs := adminUsersCli.ListRoles() // resp, lsResp, errs := adminUsersCli.ListRoles()
if len(errs) > 0 { // if len(errs) > 0 {
t.Fatal(errs) // t.Fatal(errs)
} else if resp.StatusCode != 200 { // } else if resp.StatusCode != 200 {
t.Fatal(resp.StatusCode) // t.Fatal(resp.StatusCode)
} // }
for _, role := range append(roles, []string{ // for _, role := range append(roles, []string{
db.AdminRole, // db.AdminRole,
db.UserRole, // db.UserRole,
db.VisitorRole, // db.VisitorRole,
}...) { // }...) {
if !lsResp.Roles[role] { // if !lsResp.Roles[role] {
t.Fatalf("role(%s) not found", role) // t.Fatalf("role(%s) not found", role)
} // }
} // }
for _, role := range roles { // for _, role := range roles {
resp, _, errs := adminUsersCli.DelRole(role) // resp, _, errs := adminUsersCli.DelRole(role)
if len(errs) > 0 { // if len(errs) > 0 {
t.Fatal(errs) // t.Fatal(errs)
} else if resp.StatusCode != 200 { // } else if resp.StatusCode != 200 {
t.Fatal(resp.StatusCode) // t.Fatal(resp.StatusCode)
} // }
} // }
resp, lsResp, errs = adminUsersCli.ListRoles() // resp, lsResp, errs = adminUsersCli.ListRoles()
if len(errs) > 0 { // if len(errs) > 0 {
t.Fatal(errs) // t.Fatal(errs)
} else if resp.StatusCode != 200 { // } else if resp.StatusCode != 200 {
t.Fatal(resp.StatusCode) // t.Fatal(resp.StatusCode)
} // }
for _, role := range roles { // for _, role := range roles {
if lsResp.Roles[role] { // if lsResp.Roles[role] {
t.Fatalf("role(%s) should not exist", role) // t.Fatalf("role(%s) should not exist", role)
} // }
} // }
resp, _, errs = adminUsersCli.Logout() // resp, _, errs = adminUsersCli.Logout()
if len(errs) > 0 { // if len(errs) > 0 {
t.Fatal(errs) // t.Fatal(errs)
} else if resp.StatusCode != 200 { // } else if resp.StatusCode != 200 {
t.Fatal(resp.StatusCode) // t.Fatal(resp.StatusCode)
} // }
}) // })
t.Run("Login, SetPreferences, Self, Logout", func(t *testing.T) { t.Run("Login, SetPreferences, Self, Logout", func(t *testing.T) {
adminUsersCli := client.NewUsersClient(addr) adminUsersCli := client.NewUsersClient(addr)