feat(handlers): replace boltdb with sqlite in handlers

Author: hexxa, 2022-09-03 20:24:45 +08:00 (committed by Hexxa)
Parent: 791848f75c
Commit: 085a3e4e10
14 changed files with 342 additions and 307 deletions

View file

@@ -1,6 +1,7 @@
 package main

 import (
+    "context"
     "fmt"
     "os"
@@ -17,7 +18,8 @@ func main() {
         panic(err)
     }

-    cfg, err := serverPkg.LoadCfg(args)
+    ctx := context.TODO()
+    cfg, err := serverPkg.LoadCfg(ctx, args)
     if err != nil {
         fmt.Printf("failed to load config: %s", err)
         os.Exit(1)
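Note: the entry point now owns a context and hands it to LoadCfg. A minimal sketch of one way the context.TODO() placeholder could eventually be wired; signal.NotifyContext is a standard-library call (needs the os/signal and syscall imports), and using it here is an assumption, not part of this commit:

// Hypothetical refinement: derive the root context from OS signals instead of
// context.TODO(), so config loading and later DB calls stop on shutdown.
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

cfg, err := serverPkg.LoadCfg(ctx, args)
if err != nil {
    fmt.Printf("failed to load config: %s", err)
    os.Exit(1)
}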

View file

@@ -25,8 +25,11 @@ type IDB interface {
     // Stats() DBStats
 }

-type IDBFunctions interface {
+type IDBQuickshare interface {
     Init(ctx context.Context, adminName, adminPwd string, config *SiteConfig) error
+    InitUserTable(ctx context.Context, rootName, rootPwd string) error
+    InitFileTables(ctx context.Context) error
+    InitConfigTable(ctx context.Context, cfg *SiteConfig) error
     IDBLockable
     IUserDB
     IFileDB
@@ -43,7 +46,6 @@ type IDBLockable interface {
 }

 type IUserDB interface {
-    IsInited() bool
     AddUser(ctx context.Context, user *User) error
     DelUser(ctx context.Context, id uint64) error
     GetUser(ctx context.Context, id uint64) (*User, error)
@@ -60,9 +62,14 @@ type IUserDB interface {
     ListRoles() (map[string]bool, error)
 }

+type IFilesFunctions interface {
+    IFileDB
+    IUploadDB
+    ISharingDB
+}
+
 type IFileDB interface {
     AddFileInfo(ctx context.Context, userId uint64, itemPath string, info *FileInfo) error
+    // DelFileInfo(ctx context.Context, itemPath string) error
     DelFileInfo(ctx context.Context, userId uint64, itemPath string) error
     GetFileInfo(ctx context.Context, userId uint64, itemPath string) (*FileInfo, error)
     SetSha1(ctx context.Context, userId uint64, itemPath, sign string) error
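Note: IDBQuickshare (renamed from IDBFunctions) bundles the table-initialization methods with the per-domain interfaces, while IFilesFunctions groups the file/upload/sharing views the file handlers depend on. A compile-time sketch of checking a concrete store against both contracts; *sqlite.SQLiteStore is assumed to be the type behind sqlite.NewSQLiteStore and may differ:

package sqlite_test // placement is illustrative

import (
    "github.com/ihexxa/quickshare/src/db"
    "github.com/ihexxa/quickshare/src/db/rdb/sqlite"
)

// Compile-time checks: the sqlite-backed store must satisfy the full
// IDBQuickshare contract and the narrower IFilesFunctions view.
var (
    _ db.IDBQuickshare   = (*sqlite.SQLiteStore)(nil)
    _ db.IFilesFunctions = (*sqlite.SQLiteStore)(nil)
)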

View file

@@ -14,14 +14,8 @@ import (
     "github.com/ihexxa/quickshare/src/db/rdb/sqlite"
 )

-type IFilesFunctions interface {
-    db.IFileDB
-    db.IUploadDB
-    db.ISharingDB
-}
-
 func TestFileStore(t *testing.T) {
-    testSharingMethods := func(t *testing.T, store db.IDBFunctions) {
+    testSharingMethods := func(t *testing.T, store db.IDBQuickshare) {
         dirPaths := []string{"admin/path1", "admin/path1/path2"}
         var err error
@@ -125,7 +119,7 @@ func TestFileStore(t *testing.T) {
         }
     }

-    testFileInfoMethods := func(t *testing.T, store db.IDBFunctions) {
+    testFileInfoMethods := func(t *testing.T, store db.IDBQuickshare) {
         pathInfos := map[string]*db.FileInfo{
             "admin/origin/item1": &db.FileInfo{
                 // Shared: false, // deprecated
@@ -274,7 +268,7 @@ func TestFileStore(t *testing.T) {
         }
     }

-    testUploadingMethods := func(t *testing.T, store db.IDBFunctions) {
+    testUploadingMethods := func(t *testing.T, store db.IDBQuickshare) {
         pathInfos := map[string]*db.FileInfo{
             "admin/origin/item1": &db.FileInfo{
                 // Shared: false, // deprecated
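Note: with the local IFilesFunctions interface gone, every sub-test takes db.IDBQuickshare, so the same assertions can run against any backend. A rough sketch of a helper (hypothetical, not in the diff) that feeds these sub-tests a sqlite-backed store using the constructors introduced elsewhere in this commit; it relies on the test file's existing db and sqlite imports plus "context", "path/filepath", and "testing":

func newTestStore(t *testing.T) db.IDBQuickshare {
    t.Helper()

    dbPath := filepath.Join(t.TempDir(), "quickshare.sqlite")
    sqliteDB, err := sqlite.NewSQLite(dbPath)
    if err != nil {
        t.Fatalf("open sqlite: %s", err)
    }
    store, err := sqlite.NewSQLiteStore(sqliteDB)
    if err != nil {
        t.Fatalf("new sqlite store: %s", err)
    }
    // File-related tables must exist before the file/upload/sharing tests run.
    if err := store.InitFileTables(context.TODO()); err != nil {
        t.Fatalf("init file tables: %s", err)
    }
    return store
}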

View file

@@ -6,10 +6,10 @@ import (
     "github.com/ihexxa/quickshare/src/cron"
     "github.com/ihexxa/quickshare/src/cryptoutil"
-    "github.com/ihexxa/quickshare/src/db/boltstore"
-    "github.com/ihexxa/quickshare/src/db/fileinfostore"
-    "github.com/ihexxa/quickshare/src/db/rdb"
-    "github.com/ihexxa/quickshare/src/db/sitestore"
+    // "github.com/ihexxa/quickshare/src/db/boltstore"
+    // "github.com/ihexxa/quickshare/src/db/fileinfostore"
+    "github.com/ihexxa/quickshare/src/db"
+    // "github.com/ihexxa/quickshare/src/db/sitestore"
     "github.com/ihexxa/quickshare/src/fs"
     "github.com/ihexxa/quickshare/src/idgen"
     "github.com/ihexxa/quickshare/src/iolimiter"
@@ -30,17 +30,17 @@ type Deps struct {
     fs        fs.ISimpleFS
     token     cryptoutil.ITokenEncDec
     kv        kvstore.IKVStore
-    users     db.IUserStore
-    fileInfos fileinfostore.IFileInfoStore
-    siteStore sitestore.ISiteStore
+    // users     db.IUserDB
+    // fileInfos db.IFileDB
+    // siteStore db.IConfigDB
+    // boltStore *boltstore.BoltStore
     id        idgen.IIDGen
     logger    *zap.SugaredLogger
     limiter   iolimiter.ILimiter
     workers   worker.IWorkerPool
-    boltStore *boltstore.BoltStore
     cron      cron.ICron
     fileIndex fileindex.IFileIndex
-    db        rdb.IDB
+    db        db.IDBQuickshare
 }

 func NewDeps(cfg gocfg.ICfg) *Deps {
@@ -87,28 +87,16 @@ func (deps *Deps) SetLog(logger *zap.SugaredLogger) {
     deps.logger = logger
 }

-func (deps *Deps) Users() db.IUserStore {
-    return deps.users
+func (deps *Deps) Users() db.IUserDB {
+    return deps.db
 }

-func (deps *Deps) SetUsers(users db.IUserStore) {
-    deps.users = users
+func (deps *Deps) FileInfos() db.IFilesFunctions {
+    return deps.db
 }

-func (deps *Deps) FileInfos() fileinfostore.IFileInfoStore {
-    return deps.fileInfos
-}
-
-func (deps *Deps) SetFileInfos(fileInfos fileinfostore.IFileInfoStore) {
-    deps.fileInfos = fileInfos
-}
-
-func (deps *Deps) SiteStore() sitestore.ISiteStore {
-    return deps.siteStore
-}
-
-func (deps *Deps) SetSiteStore(siteStore sitestore.ISiteStore) {
-    deps.siteStore = siteStore
+func (deps *Deps) SiteStore() db.IConfigDB {
+    return deps.db
 }

 func (deps *Deps) Limiter() iolimiter.ILimiter {
@@ -127,13 +115,13 @@ func (deps *Deps) SetWorkers(workers worker.IWorkerPool) {
     deps.workers = workers
 }

-func (deps *Deps) BoltStore() *boltstore.BoltStore {
-    return deps.boltStore
-}
+// func (deps *Deps) BoltStore() *boltstore.BoltStore {
+//     return deps.boltStore
+// }

-func (deps *Deps) SetBoltStore(boltStore *boltstore.BoltStore) {
-    deps.boltStore = boltStore
-}
+// func (deps *Deps) SetBoltStore(boltStore *boltstore.BoltStore) {
+//     deps.boltStore = boltStore
+// }

 func (deps *Deps) Cron() cron.ICron {
     return deps.cron
@@ -151,10 +139,10 @@ func (deps *Deps) SetFileIndex(index fileindex.IFileIndex) {
     deps.fileIndex = index
 }

-func (deps *Deps) DB() rdb.IDB {
+func (deps *Deps) DB() db.IDBQuickshare {
     return deps.db
 }

-func (deps *Deps) SetDB(db rdb.IDB) {
-    deps.db = db
+func (deps *Deps) SetDB(rdb db.IDBQuickshare) {
+    deps.db = rdb
 }
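Note: after this change a single sqlite-backed store sits behind every accessor; Users(), FileInfos() and SiteStore() are just typed views of the same db field. Condensed from the hunks above (surrounding fields and setters omitted):

type Deps struct {
    // ...other fields as above...
    db db.IDBQuickshare
}

// One concrete store, three narrowed views of it.
func (deps *Deps) Users() db.IUserDB             { return deps.db }
func (deps *Deps) FileInfos() db.IFilesFunctions { return deps.db }
func (deps *Deps) SiteStore() db.IConfigDB       { return deps.db }

func (deps *Deps) DB() db.IDBQuickshare       { return deps.db }
func (deps *Deps) SetDB(rdb db.IDBQuickshare) { deps.db = rdb }

Callers keep depending on the small interfaces, which is what lets the bolt-era setters disappear without touching handler code.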

View file

@@ -1,6 +1,7 @@
 package fileshdr

 import (
+    "context"
     "crypto/sha1"
     "encoding/json"
     "fmt"
@@ -18,6 +19,7 @@ const (

 type Sha1Params struct {
     FilePath string
+    UserId   uint64
 }

 func (h *FileHandlers) genSha1(msg worker.IMsg) error {
@@ -46,7 +48,8 @@ func (h *FileHandlers) genSha1(msg worker.IMsg) error {
     }
     sha1Sign := fmt.Sprintf("%x", hasher.Sum(nil))

-    err = h.deps.FileInfos().SetSha1(taskInputs.FilePath, sha1Sign)
+    err = h.deps.FileInfos().
+        SetSha1(context.TODO(), taskInputs.UserId, taskInputs.FilePath, sha1Sign) // TODO: use source context
     if err != nil {
         return fmt.Errorf("fail to set sha1: %s", err)
     }
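Note: SetSha1 is now keyed by user in the DB layer, so the background task must know who owns the file; the id travels inside the worker message. A producer-side sketch that mirrors the handler hunks later in this commit:

// Enqueue a sha1 task carrying both the owner id and the path; genSha1 above
// unmarshals it and calls SetSha1(ctx, userId, filePath, sign). The worker has
// no request context, hence the context.TODO() placeholder in the diff.
msg, err := json.Marshal(Sha1Params{
    UserId:   userId,
    FilePath: fsFilePath,
})
if err != nil {
    return err
}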

View file

@ -1,6 +1,7 @@
package fileshdr package fileshdr
import ( import (
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
@ -9,7 +10,6 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"time" "time"
@ -96,7 +96,7 @@ func (lk *AutoLocker) Exec(handler func()) error {
} }
// related elements: role, user, action(listing, downloading)/sharing // related elements: role, user, action(listing, downloading)/sharing
func (h *FileHandlers) canAccess(userName, role, op, sharedPath string) bool { func (h *FileHandlers) canAccess(ctx context.Context, userId uint64, userName, role, op, sharedPath string) bool {
if role == db.AdminRole { if role == db.AdminRole {
return true return true
} }
@ -115,8 +115,8 @@ func (h *FileHandlers) canAccess(userName, role, op, sharedPath string) bool {
return false return false
} }
isSharing, ok := h.deps.FileInfos().GetSharing(sharedPath) isSharing := h.deps.FileInfos().IsSharing(ctx, userId, sharedPath)
return isSharing && ok return isSharing
} }
type CreateReq struct { type CreateReq struct {
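Note: canAccess now receives the request context and the caller's numeric id so the sharing check can be answered by the database; IsSharing returns a single bool instead of the old (value, ok) pair from the in-memory store. A condensed shape of the new function (the middle role/op checks are elided here):

func (h *FileHandlers) canAccess(ctx context.Context, userId uint64, userName, role, op, sharedPath string) bool {
    if role == db.AdminRole {
        return true
    }
    // ...owner and op-specific checks as before (elided)...

    // Sharing is resolved through the DB-backed store; a missing row simply yields false.
    return h.deps.FileInfos().IsSharing(ctx, userId, sharedPath)
}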
@ -131,9 +131,13 @@ func (h *FileHandlers) Create(c *gin.Context) {
return return
} }
userID := c.MustGet(q.UserIDParam).(string) userID, err := q.GetUserId(c)
fsFilePath, err := h.getFSFilePath(userID, req.Path) if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
fsFilePath, err := h.getFSFilePath(fmt.Sprint(userID), req.Path)
if err != nil { if err != nil {
if errors.Is(err, os.ErrExist) { if errors.Is(err, os.ErrExist) {
c.JSON(q.ErrResp(c, 400, err)) c.JSON(q.ErrResp(c, 400, err))
@ -145,22 +149,15 @@ func (h *FileHandlers) Create(c *gin.Context) {
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
if !h.canAccess(userName, role, "create", fsFilePath) { if !h.canAccess(c, userID, userName, role, "create", fsFilePath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied)) c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return return
} }
userIDInt, err := strconv.ParseUint(userID, 10, 64)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
tmpFilePath := q.UploadPath(userName, fsFilePath) tmpFilePath := q.UploadPath(userName, fsFilePath)
if req.FileSize == 0 { if req.FileSize == 0 {
// TODO: limit the number of files with 0 byte // TODO: limit the number of files with 0 byte
err = h.deps.BoltStore().AddUploadInfos(userIDInt, tmpFilePath, fsFilePath, &db.FileInfo{ err = h.deps.FileInfos().AddUploadInfos(c, userID, tmpFilePath, fsFilePath, &db.FileInfo{
Size: req.FileSize, Size: req.FileSize,
}) })
if err != nil { if err != nil {
@ -171,7 +168,7 @@ func (h *FileHandlers) Create(c *gin.Context) {
} }
return return
} }
err = h.deps.BoltStore().MoveUploadingInfos(userIDInt, tmpFilePath, fsFilePath) err = h.deps.FileInfos().MoveFileInfos(c, userID, tmpFilePath, fsFilePath, false)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -194,6 +191,7 @@ func (h *FileHandlers) Create(c *gin.Context) {
} }
msg, err := json.Marshal(Sha1Params{ msg, err := json.Marshal(Sha1Params{
UserId: userID,
FilePath: fsFilePath, FilePath: fsFilePath,
}) })
if err != nil { if err != nil {
@ -223,7 +221,7 @@ func (h *FileHandlers) Create(c *gin.Context) {
return return
} }
err = h.deps.BoltStore().AddUploadInfos(userIDInt, tmpFilePath, fsFilePath, &db.FileInfo{ err = h.deps.FileInfos().AddUploadInfos(c, userID, tmpFilePath, fsFilePath, &db.FileInfo{
Size: req.FileSize, Size: req.FileSize,
}) })
if err != nil { if err != nil {
@ -277,17 +275,16 @@ func (h *FileHandlers) Delete(c *gin.Context) {
return return
} }
role := c.MustGet(q.RoleParam).(string) userId, err := q.GetUserId(c)
userName := c.MustGet(q.UserParam).(string) if err != nil {
if !h.canAccess(userName, role, "delete", filePath) { c.JSON(q.ErrResp(c, 500, err))
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return return
} }
userID := c.MustGet(q.UserIDParam).(string) role := c.MustGet(q.RoleParam).(string)
userIDInt, err := strconv.ParseUint(userID, 10, 64) userName := c.MustGet(q.UserParam).(string)
if err != nil { if !h.canAccess(c, userId, userName, role, "delete", filePath) {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return return
} }
@ -300,7 +297,7 @@ func (h *FileHandlers) Delete(c *gin.Context) {
return return
} }
err = h.deps.BoltStore().DelInfos(userIDInt, filePath) err = h.deps.FileInfos().DelFileInfo(c, userId, filePath)
if err != nil { if err != nil {
txErr = err txErr = err
return return
@ -338,9 +335,16 @@ func (h *FileHandlers) Metadata(c *gin.Context) {
c.JSON(q.ErrResp(c, 400, errors.New("invalid file path"))) c.JSON(q.ErrResp(c, 400, errors.New("invalid file path")))
return return
} }
userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
if !h.canAccess(userName, role, "metadata", filePath) { if !h.canAccess(c, userId, userName, role, "metadata", filePath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied)) c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return return
} }
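Note: the same prologue now repeats across Create, Delete, Metadata, Mkdir, Move and the upload/sharing handlers: resolve the numeric user id once with q.GetUserId (defined in a later file of this commit), then pass the gin context, which satisfies context.Context, straight into canAccess and the store calls. The shared shape:

userId, err := q.GetUserId(c) // parses the id stored under q.UserIDParam, presumably by the auth middleware
if err != nil {
    c.JSON(q.ErrResp(c, 500, err))
    return
}
role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string)
if !h.canAccess(c, userId, userName, role, "metadata", filePath) {
    c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
    return
}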
@ -374,15 +378,21 @@ func (h *FileHandlers) Mkdir(c *gin.Context) {
return return
} }
userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
dirPath := filepath.Clean(req.Path) dirPath := filepath.Clean(req.Path)
if !h.canAccess(userName, role, "mkdir", dirPath) { if !h.canAccess(c, userId, userName, role, "mkdir", dirPath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied)) c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return return
} }
err := h.deps.FS().MkdirAll(dirPath) err = h.deps.FS().MkdirAll(dirPath)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -407,21 +417,21 @@ func (h *FileHandlers) Move(c *gin.Context) {
c.JSON(q.ErrResp(c, 400, err)) c.JSON(q.ErrResp(c, 400, err))
return return
} }
userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
userID := c.MustGet(q.UserIDParam).(string)
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
oldPath := filepath.Clean(req.OldPath) oldPath := filepath.Clean(req.OldPath)
newPath := filepath.Clean(req.NewPath) newPath := filepath.Clean(req.NewPath)
if !h.canAccess(userName, role, "move", oldPath) || !h.canAccess(userName, role, "move", newPath) { if !h.canAccess(c, userId, userName, role, "move", oldPath) ||
!h.canAccess(c, userId, userName, role, "move", newPath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied)) c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return return
} }
userIDInt, err := strconv.ParseUint(userID, 10, 64)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
itemInfo, err := h.deps.FS().Stat(oldPath) itemInfo, err := h.deps.FS().Stat(oldPath)
if err != nil { if err != nil {
@ -438,7 +448,7 @@ func (h *FileHandlers) Move(c *gin.Context) {
return return
} }
err = h.deps.BoltStore().MoveInfos(userIDInt, oldPath, newPath, itemInfo.IsDir()) err = h.deps.FileInfos().MoveFileInfos(c, userId, oldPath, newPath, itemInfo.IsDir())
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -477,22 +487,21 @@ func (h *FileHandlers) UploadChunk(c *gin.Context) {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
} }
role := c.MustGet(q.RoleParam).(string)
userID := c.MustGet(q.UserIDParam).(string)
userName := c.MustGet(q.UserParam).(string)
filePath := filepath.Clean(req.Path)
if !h.canAccess(userName, role, "upload.chunk", filePath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return
}
userIDInt, err := strconv.ParseUint(userID, 10, 64) userId, err := q.GetUserId(c)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
} }
role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string)
filePath := filepath.Clean(req.Path)
if !h.canAccess(c, userId, userName, role, "upload.chunk", filePath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return
}
ok, err := h.deps.Limiter().CanWrite(userIDInt, len([]byte(req.Content))) ok, err := h.deps.Limiter().CanWrite(userId, len([]byte(req.Content)))
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -509,7 +518,7 @@ func (h *FileHandlers) UploadChunk(c *gin.Context) {
lockErr := locker.Exec(func() { lockErr := locker.Exec(func() {
var err error var err error
fsFilePath, fileSize, uploaded, err = h.deps.FileInfos().GetUploadInfo(userID, tmpFilePath) fsFilePath, fileSize, uploaded, err = h.deps.FileInfos().GetUploadInfo(c, userId, tmpFilePath)
if err != nil { if err != nil {
txErr, statusCode = err, 500 txErr, statusCode = err, 500
return return
@ -530,7 +539,7 @@ func (h *FileHandlers) UploadChunk(c *gin.Context) {
return return
} }
err = h.deps.FileInfos().SetUploadInfo(userID, tmpFilePath, req.Offset+int64(wrote)) err = h.deps.FileInfos().SetUploadInfo(c, userId, tmpFilePath, req.Offset+int64(wrote))
if err != nil { if err != nil {
txErr, statusCode = err, 500 txErr, statusCode = err, 500
return return
@ -538,7 +547,7 @@ func (h *FileHandlers) UploadChunk(c *gin.Context) {
// move the file from uploading dir to uploaded dir // move the file from uploading dir to uploaded dir
if uploaded+int64(wrote) == fileSize { if uploaded+int64(wrote) == fileSize {
err = h.deps.BoltStore().MoveUploadingInfos(userIDInt, tmpFilePath, fsFilePath) err = h.deps.FileInfos().MoveFileInfos(c, userId, tmpFilePath, fsFilePath, false)
if err != nil { if err != nil {
txErr, statusCode = err, 500 txErr, statusCode = err, 500
return return
@ -551,6 +560,7 @@ func (h *FileHandlers) UploadChunk(c *gin.Context) {
} }
msg, err := json.Marshal(Sha1Params{ msg, err := json.Marshal(Sha1Params{
UserId: userId,
FilePath: fsFilePath, FilePath: fsFilePath,
}) })
if err != nil { if err != nil {
@ -643,21 +653,27 @@ func (h *FileHandlers) UploadStatus(c *gin.Context) {
if filePath == "" { if filePath == "" {
c.JSON(q.ErrResp(c, 400, errors.New("invalid file name"))) c.JSON(q.ErrResp(c, 400, errors.New("invalid file name")))
} }
userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
if !h.canAccess(userName, role, "upload.status", filePath) { if !h.canAccess(c, userId, userName, role, "upload.status", filePath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied)) c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return return
} }
userID := c.MustGet(q.UserIDParam).(string)
tmpFilePath := q.UploadPath(userName, filePath) tmpFilePath := q.UploadPath(userName, filePath)
locker := h.NewAutoLocker(c, lockName(tmpFilePath)) locker := h.NewAutoLocker(c, lockName(tmpFilePath))
fileSize, uploaded := int64(0), int64(0) fileSize, uploaded := int64(0), int64(0)
var txErr error var txErr error
lockErr := locker.Exec(func() { lockErr := locker.Exec(func() {
var err error var err error
_, fileSize, uploaded, err = h.deps.FileInfos().GetUploadInfo(userID, tmpFilePath) _, fileSize, uploaded, err = h.deps.FileInfos().GetUploadInfo(c, userId, tmpFilePath)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
c.JSON(q.ErrResp(c, 404, err)) c.JSON(q.ErrResp(c, 404, err))
@ -695,25 +711,26 @@ func (h *FileHandlers) Download(c *gin.Context) {
c.JSON(q.ErrResp(c, 400, errors.New("invalid file name"))) c.JSON(q.ErrResp(c, 400, errors.New("invalid file name")))
return return
} }
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
dirPath := filepath.Dir(filePath) dirPath := filepath.Dir(filePath)
if !h.canAccess(userName, role, "download", dirPath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return
}
var err error var err error
userIDInt := userstore.VisitorID userId := userstore.VisitorID
if role != db.VisitorRole { if role != db.VisitorRole {
userID := c.MustGet(q.UserIDParam).(string) userId, err = q.GetUserId(c)
userIDInt, err = strconv.ParseUint(userID, 10, 64)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
} }
} }
if !h.canAccess(c, userId, userName, role, "download", dirPath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return
}
// TODO: when sharing is introduced, move following logics to a separeted method // TODO: when sharing is introduced, move following logics to a separeted method
// concurrently file accessing is managed by os // concurrently file accessing is managed by os
info, err := h.deps.FS().Stat(filePath) info, err := h.deps.FS().Stat(filePath)
@ -757,7 +774,7 @@ func (h *FileHandlers) Download(c *gin.Context) {
// respond to normal requests // respond to normal requests
if ifRangeVal != "" || rangeVal == "" { if ifRangeVal != "" || rangeVal == "" {
limitedReader, err := h.GetStreamReader(userIDInt, fd) limitedReader, err := h.GetStreamReader(userId, fd)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -797,7 +814,7 @@ func (h *FileHandlers) Download(c *gin.Context) {
// TODO: reader will be closed by multipart response writer // TODO: reader will be closed by multipart response writer
go mr.Start() go mr.Start()
limitedReader, err := h.GetStreamReader(userIDInt, mr) limitedReader, err := h.GetStreamReader(userId, mr)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -818,7 +835,7 @@ type ListResp struct {
Metadatas []*MetadataResp `json:"metadatas"` Metadatas []*MetadataResp `json:"metadatas"`
} }
func (h *FileHandlers) MergeFileInfos(dirPath string, infos []os.FileInfo) ([]*MetadataResp, error) { func (h *FileHandlers) MergeFileInfos(ctx *gin.Context, dirPath string, infos []os.FileInfo) ([]*MetadataResp, error) {
filePaths := []string{} filePaths := []string{}
metadatas := []*MetadataResp{} metadatas := []*MetadataResp{}
for _, info := range infos { for _, info := range infos {
@ -833,7 +850,7 @@ func (h *FileHandlers) MergeFileInfos(dirPath string, infos []os.FileInfo) ([]*M
}) })
} }
dbInfos, err := h.deps.FileInfos().GetInfos(filePaths) dbInfos, err := h.deps.FileInfos().ListFileInfos(ctx, filePaths)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -856,9 +873,15 @@ func (h *FileHandlers) List(c *gin.Context) {
c.JSON(q.ErrResp(c, 400, errors.New("incorrect path name"))) c.JSON(q.ErrResp(c, 400, errors.New("incorrect path name")))
return return
} }
userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
if !h.canAccess(userName, role, "list", dirPath) { if !h.canAccess(c, userId, userName, role, "list", dirPath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied)) c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return return
} }
@ -869,7 +892,7 @@ func (h *FileHandlers) List(c *gin.Context) {
return return
} }
metadatas, err := h.MergeFileInfos(dirPath, infos) metadatas, err := h.MergeFileInfos(c, dirPath, infos)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -891,7 +914,7 @@ func (h *FileHandlers) ListHome(c *gin.Context) {
return return
} }
metadatas, err := h.MergeFileInfos(fsPath, infos) metadatas, err := h.MergeFileInfos(c, fsPath, infos)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -920,9 +943,13 @@ type ListUploadingsResp struct {
} }
func (h *FileHandlers) ListUploadings(c *gin.Context) { func (h *FileHandlers) ListUploadings(c *gin.Context) {
userID := c.MustGet(q.UserIDParam).(string) userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
infos, err := h.deps.FileInfos().ListUploadInfo(userID) infos, err := h.deps.FileInfos().ListUploadInfos(c, userId)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -941,21 +968,19 @@ func (h *FileHandlers) DelUploading(c *gin.Context) {
return return
} }
userID := c.MustGet(q.UserIDParam).(string) userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
// op is empty, because users must be admin, or the path belongs to this user // op is empty, because users must be admin, or the path belongs to this user
if !h.canAccess(userName, role, "", filePath) { if !h.canAccess(c, userId, userName, role, "", filePath) {
c.JSON(q.ErrResp(c, 403, errors.New("forbidden"))) c.JSON(q.ErrResp(c, 403, errors.New("forbidden")))
return return
} }
userIDInt, err := strconv.ParseUint(userID, 10, 64)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
var txErr error var txErr error
var statusCode int var statusCode int
tmpFilePath := q.UploadPath(userName, filePath) tmpFilePath := q.UploadPath(userName, filePath)
@ -985,7 +1010,7 @@ func (h *FileHandlers) DelUploading(c *gin.Context) {
c.JSON(q.ErrResp(c, statusCode, txErr)) c.JSON(q.ErrResp(c, statusCode, txErr))
return return
} }
err = h.deps.BoltStore().DelUploadingInfos(userIDInt, tmpFilePath) err = h.deps.FileInfos().DelUploadingInfos(c, userId, filePath)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -1004,12 +1029,18 @@ func (h *FileHandlers) AddSharing(c *gin.Context) {
return return
} }
userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
sharingPath := filepath.Clean(req.SharingPath) sharingPath := filepath.Clean(req.SharingPath)
// TODO: move canAccess to authedFS // TODO: move canAccess to authedFS
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
// op is empty, because users must be admin, or the path belongs to this user // op is empty, because users must be admin, or the path belongs to this user
if !h.canAccess(userName, role, "", sharingPath) { if !h.canAccess(c, userId, userName, role, "", sharingPath) {
c.JSON(q.ErrResp(c, 403, errors.New("forbidden"))) c.JSON(q.ErrResp(c, 403, errors.New("forbidden")))
return return
} }
@ -1028,7 +1059,7 @@ func (h *FileHandlers) AddSharing(c *gin.Context) {
return return
} }
err = h.deps.FileInfos().AddSharing(sharingPath) err = h.deps.FileInfos().AddSharing(c, userId, sharingPath)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -1044,15 +1075,21 @@ func (h *FileHandlers) DelSharing(c *gin.Context) {
return return
} }
userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
// TODO: move canAccess to authedFS // TODO: move canAccess to authedFS
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
if !h.canAccess(userName, role, "", dirPath) { if !h.canAccess(c, userId, userName, role, "", dirPath) {
c.JSON(q.ErrResp(c, 403, errors.New("forbidden"))) c.JSON(q.ErrResp(c, 403, errors.New("forbidden")))
return return
} }
err := h.deps.FileInfos().DelSharing(dirPath) err = h.deps.FileInfos().DelSharing(c, userId, dirPath)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -1068,8 +1105,14 @@ func (h *FileHandlers) IsSharing(c *gin.Context) {
return return
} }
exist, ok := h.deps.FileInfos().GetSharing(dirPath) userId, err := q.GetUserId(c)
if exist && ok { if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
exist := h.deps.FileInfos().IsSharing(c, userId, dirPath)
if exist {
c.JSON(q.Resp(200)) c.JSON(q.Resp(200))
} else { } else {
c.JSON(q.Resp(404)) c.JSON(q.Resp(404))
@ -1083,9 +1126,13 @@ type SharingResp struct {
// Deprecated: use ListSharingIDs instead // Deprecated: use ListSharingIDs instead
func (h *FileHandlers) ListSharings(c *gin.Context) { func (h *FileHandlers) ListSharings(c *gin.Context) {
// TODO: move canAccess to authedFS // TODO: move canAccess to authedFS
userName := c.MustGet(q.UserParam).(string) userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
sharingDirs, err := h.deps.FileInfos().ListSharings(q.FsRootPath(userName, "/")) sharingDirs, err := h.deps.FileInfos().ListUserSharings(c, userId)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -1103,10 +1150,13 @@ type SharingIDsResp struct {
} }
func (h *FileHandlers) ListSharingIDs(c *gin.Context) { func (h *FileHandlers) ListSharingIDs(c *gin.Context) {
// TODO: move canAccess to authedFS userId, err := q.GetUserId(c)
userName := c.MustGet(q.UserParam).(string) if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
dirToID, err := h.deps.FileInfos().ListSharings(q.FsRootPath(userName, "/")) dirToID, err := h.deps.FileInfos().ListUserSharings(c, userId)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
@ -1131,14 +1181,21 @@ func (h *FileHandlers) GenerateHash(c *gin.Context) {
return return
} }
userId, err := q.GetUserId(c)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
role := c.MustGet(q.RoleParam).(string) role := c.MustGet(q.RoleParam).(string)
userName := c.MustGet(q.UserParam).(string) userName := c.MustGet(q.UserParam).(string)
if !h.canAccess(userName, role, "hash.gen", filePath) { if !h.canAccess(c, userId, userName, role, "hash.gen", filePath) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied)) c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return return
} }
msg, err := json.Marshal(Sha1Params{ msg, err := json.Marshal(Sha1Params{
UserId: userId,
FilePath: filePath, FilePath: filePath,
}) })
if err != nil { if err != nil {
@ -1171,7 +1228,7 @@ func (h *FileHandlers) GetSharingDir(c *gin.Context) {
return return
} }
dirPath, err := h.deps.FileInfos().GetSharingDir(shareID) dirPath, err := h.deps.FileInfos().GetSharingDir(c, shareID)
if err != nil { if err != nil {
c.JSON(q.ErrResp(c, 500, err)) c.JSON(q.ErrResp(c, 500, err))
return return
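Note: taken together, the sharing handlers now talk to a user-scoped, context-aware API: AddSharing/DelSharing/IsSharing take (ctx, userId, path), listings come from ListUserSharings, and GetSharingDir maps a share id back to a directory. The view below is inferred purely from the call sites in this diff; parameter and return types (especially the map and share-id types) are assumptions, and the real ISharingDB definition may differ:

// Call surface used by the handlers above (inferred, not the actual interface definition).
type sharingCalls interface {
    AddSharing(ctx context.Context, userId uint64, dirPath string) error
    DelSharing(ctx context.Context, userId uint64, dirPath string) error
    IsSharing(ctx context.Context, userId uint64, dirPath string) bool
    ListUserSharings(ctx context.Context, userId uint64) (map[string]string, error)
    GetSharingDir(ctx context.Context, shareId string) (string, error)
}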

View file

@@ -1,6 +1,7 @@
 package multiusers

 import (
+    "context"
     "encoding/json"
     "fmt"
     "path/filepath"
@@ -44,5 +45,5 @@ func (h *MultiUsersSvc) resetUsedSpace(msg worker.IMsg) error {
         }
     }

-    return h.deps.Users().ResetUsed(params.UserID, usedSpace)
+    return h.deps.Users().ResetUsed(context.TODO(), params.UserID, usedSpace) // TODO: use source context
 }

View file

@@ -1,6 +1,7 @@
 package multiusers

 import (
+    "context"
     "encoding/json"
     "errors"
     "fmt"
@@ -144,7 +145,7 @@ func NewMultiUsersSvc(cfg gocfg.ICfg, deps *depidx.Deps) (*MultiUsersSvc, error)
     return handlers, nil
 }

-func (h *MultiUsersSvc) Init(adminName, adminPwd string) (string, error) {
+func (h *MultiUsersSvc) Init(ctx context.Context, adminName, adminPwd string) (string, error) {
     var err error

     fsPath := q.FsRootPath(adminName, "/")
@@ -156,12 +157,6 @@ func (h *MultiUsersSvc) Init(adminName, adminPwd string) (string, error) {
         return "", err
     }

-    // TODO: return "" for being compatible with singleuser service, should remove this
-    err = h.deps.Users().Init(c, adminName, adminPwd)
-    if err != nil {
-        return "", err
-    }
-
     usersInterface, ok := h.cfg.Slice("Users.PredefinedUsers")
     spaceLimit := int64(h.cfg.IntOr("Users.SpaceLimit", 100*1024*1024))
     uploadSpeedLimit := h.cfg.IntOr("Users.UploadSpeedLimit", 100*1024)
@@ -205,7 +200,7 @@ func (h *MultiUsersSvc) Init(adminName, adminPwd string) (string, error) {
             Preferences: &preferences,
         }

-        err = h.deps.Users().AddUser(c, user)
+        err = h.deps.Users().AddUser(ctx, user)
         if err != nil {
             h.deps.Log().Warn("warning: failed to add user(%s): %s", user, err)
             return "", err
@@ -216,10 +211,6 @@ func (h *MultiUsersSvc) Init(adminName, adminPwd string) (string, error) {
     return "", nil
 }

-func (h *MultiUsersSvc) IsInited() bool {
-    return h.deps.Users().IsInited()
-}
-
 type LoginReq struct {
     User string `json:"user"`
     Pwd  string `json:"pwd"`
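Note: Init now takes the caller's context and no longer bootstraps the admin account or checks IsInited; table creation and admin setup move into the server's initDB (later in this commit). A hypothetical caller sketch, adapted from the call that used to live in initHandlers; whether any caller still invokes Init after this commit is not shown here:

// Hypothetical: whoever still needs to (re)run Init owns the context now.
ctx := context.TODO()
if _, err := userHdrs.Init(ctx, adminName, string(pwdHash)); err != nil {
    return nil, fmt.Errorf("init admin error: %w", err)
}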

View file

@@ -35,7 +35,7 @@ type ClientCfgMsg struct {

 func (h *SettingsSvc) GetClientCfg(c *gin.Context) {
     // TODO: add cache
-    siteCfg, err := h.deps.SiteStore().GetCfg()
+    siteCfg, err := h.deps.SiteStore().GetCfg(c)
     if err != nil {
         c.JSON(q.ErrResp(c, 500, err))
         return
@@ -74,7 +74,7 @@ func (h *SettingsSvc) SetClientCfg(c *gin.Context) {
     h.cfg.SetBool("Site.ClientCfg.AllowSetBg", req.ClientCfg.AllowSetBg)
     h.cfg.SetBool("Site.ClientCfg.AutoTheme", req.ClientCfg.AutoTheme)

-    err = h.deps.SiteStore().SetClientCfg(clientCfg)
+    err = h.deps.SiteStore().SetClientCfg(c, clientCfg)
     if err != nil {
         c.JSON(q.ErrResp(c, 500, err))
         return
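Note: the site-config store now takes a context as its first argument; the handlers pass the *gin.Context itself, which works because it implements context.Context (Deadline/Done/Err/Value) in recent gin releases, so request cancellation can reach the DB layer. A compile-time statement of that assumption:

// If this fails to compile with the gin version in use, pass c.Request.Context() instead.
var _ context.Context = (*gin.Context)(nil)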

View file

@@ -5,6 +5,7 @@ import (
     "errors"
     "fmt"
     "path/filepath"
+    "strconv"
     "github.com/gin-gonic/gin"
@@ -164,3 +165,12 @@ func GetUserInfo(tokenStr string, tokenEncDec cryptoutil.ITokenEncDec) (map[stri

     return claims, nil
 }
+
+func GetUserId(ctx *gin.Context) (uint64, error) {
+    userID, ok := ctx.MustGet(UserIDParam).(string)
+    if !ok {
+        return 0, errors.New("user id not found")
+    }
+    return strconv.ParseUint(userID, 10, 64)
+}

View file

@@ -1,11 +1,11 @@
 package iolimiter

 import (
+    "context"
     "fmt"
     "sync"

     "github.com/ihexxa/quickshare/src/db"
-    "github.com/ihexxa/quickshare/src/db/userstore"
     "github.com/ihexxa/quickshare/src/golimiter"
 )
@@ -20,11 +20,11 @@ type IOLimiter struct {
     mtx             *sync.Mutex
     UploadLimiter   *golimiter.Limiter
     DownloadLimiter *golimiter.Limiter
-    users           userstore.IUserStore
+    users           db.IUserDB
     quotaCache      map[uint64]*db.Quota
 }

-func NewIOLimiter(cap, cyc int, users userstore.IUserStore) *IOLimiter {
+func NewIOLimiter(cap, cyc int, users db.IUserDB) *IOLimiter {
     return &IOLimiter{
         mtx:           &sync.Mutex{},
         UploadLimiter: golimiter.New(cap, cyc),
@@ -40,7 +40,7 @@ func (lm *IOLimiter) CanWrite(id uint64, chunkSize int) (bool, error) {
     quota, ok := lm.quotaCache[id]
     if !ok {
-        user, err := lm.users.GetUser(id)
+        user, err := lm.users.GetUser(context.TODO(), id) // TODO: add context
         if err != nil {
             return false, err
         }
@@ -64,7 +64,7 @@ func (lm *IOLimiter) CanRead(id uint64, chunkSize int) (bool, error) {
     quota, ok := lm.quotaCache[id]
     if !ok {
-        user, err := lm.users.GetUser(id)
+        user, err := lm.users.GetUser(context.TODO(), id) // TODO: add context
         if err != nil {
             return false, err
         }
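Note: the limiter now depends on db.IUserDB instead of the bolt-backed userstore, and quota lookups call GetUser with a context; quotas are still cached per user id, so the DB is hit only on a cache miss. A sketch of that shared miss path as a helper; the helper itself is hypothetical, and it assumes db.User exposes its Quota and that the caller holds lm.mtx, as CanWrite/CanRead do:

func (lm *IOLimiter) getQuota(ctx context.Context, id uint64) (*db.Quota, error) {
    if quota, ok := lm.quotaCache[id]; ok {
        return quota, nil
    }
    user, err := lm.users.GetUser(ctx, id) // one DB round trip per uncached user
    if err != nil {
        return nil, err
    }
    lm.quotaCache[id] = user.Quota
    return user.Quota, nil
}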

View file

@@ -1,6 +1,7 @@
 package server

 import (
+    "context"
     "encoding/json"
     "errors"
     "fmt"
@@ -23,7 +24,7 @@ type Args struct {
 // LoadCfg loads the default config, the config in database, config files and arguments in order.
 // All config values will be merged into one, and the latter overwrites the former.
 // Each config can be part of the whole ServerCfg
-func LoadCfg(args *Args) (*gocfg.Cfg, error) {
+func LoadCfg(ctx context.Context, args *Args) (*gocfg.Cfg, error) {
     defaultCfg, err := DefaultConfig()
     if err != nil {
         return nil, err
@@ -40,7 +41,7 @@ func LoadCfg(args *Args) (*gocfg.Cfg, error) {
     }
     _, err = os.Stat(dbPath)
     if err == nil {
-        cfg, err = mergeDbConfig(cfg, dbPath)
+        cfg, err = mergeDbConfig(ctx, cfg, dbPath)
         if err != nil {
             return nil, err
         }
@@ -52,7 +53,7 @@ func LoadCfg(args *Args) (*gocfg.Cfg, error) {
         }
     }

-    cfg, err = mergeConfigFiles(cfg, args.Configs)
+    cfg, err = mergeConfigFiles(ctx, cfg, args.Configs)
     if err != nil {
         return nil, err
     }
@@ -60,7 +61,7 @@ func LoadCfg(args *Args) (*gocfg.Cfg, error) {
     return mergeArgs(cfg, args)
 }

-func mergeDbConfig(cfg *gocfg.Cfg, dbPath string) (*gocfg.Cfg, error) {
+func mergeDbConfig(ctx context.Context, cfg *gocfg.Cfg, dbPath string) (*gocfg.Cfg, error) {
     kv := boltdbpvd.New(dbPath, 1024)
     defer kv.Close()
@@ -69,7 +70,7 @@ func mergeDbConfig(cfg *gocfg.Cfg, dbPath string) (*gocfg.Cfg, error) {
         return nil, fmt.Errorf("fail to new site config store: %s", err)
     }

-    clientCfg, err := siteStore.GetCfg()
+    clientCfg, err := siteStore.GetCfg(ctx)
     if err != nil {
         if errors.Is(err, sitestore.ErrNotFound) {
             return cfg, nil
@@ -111,7 +112,7 @@ func getDbPath(cfg *gocfg.Cfg, configPaths []string, argDbPath string) (string,
     return cfg.GrabString("Db.DbPath"), nil
 }

-func mergeConfigFiles(cfg *gocfg.Cfg, configPaths []string) (*gocfg.Cfg, error) {
+func mergeConfigFiles(ctx context.Context, cfg *gocfg.Cfg, configPaths []string) (*gocfg.Cfg, error) {
     var err error
     for _, configPath := range configPaths {
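Note: the context now flows LoadCfg -> mergeDbConfig -> siteStore.GetCfg and LoadCfg -> mergeConfigFiles, so a cancelled or timed-out startup context can abort reading the stored site config. A usage sketch (hypothetical caller, standard-library calls only):

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

cfg, err := LoadCfg(ctx, args)
if err != nil {
    return fmt.Errorf("load config: %w", err)
}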

View file

@@ -1,6 +1,7 @@
 package server

 import (
+    "context"
     "encoding/json"
     "fmt"
     "reflect"
@@ -381,7 +382,7 @@ func TestLoadCfg(t *testing.T) {
     testLoadCfg := func(t *testing.T) {
         for i, args := range argsList {
-            gotCfg, err := LoadCfg(args)
+            gotCfg, err := LoadCfg(context.TODO(), args)
             if err != nil {
                 t.Fatal(err)
             }

View file

@ -4,13 +4,13 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/json" // "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path"
"strconv" "strconv"
"syscall" "syscall"
"time" "time"
@ -25,11 +25,11 @@ import (
"github.com/ihexxa/quickshare/src/cryptoutil/jwt" "github.com/ihexxa/quickshare/src/cryptoutil/jwt"
"github.com/ihexxa/quickshare/src/db" "github.com/ihexxa/quickshare/src/db"
"github.com/ihexxa/quickshare/src/db/boltstore" // "github.com/ihexxa/quickshare/src/db/boltstore"
"github.com/ihexxa/quickshare/src/db/fileinfostore" // "github.com/ihexxa/quickshare/src/db/fileinfostore"
"github.com/ihexxa/quickshare/src/db/rdb/sqlite" "github.com/ihexxa/quickshare/src/db/rdb/sqlite"
"github.com/ihexxa/quickshare/src/db/sitestore" // "github.com/ihexxa/quickshare/src/db/sitestore"
"github.com/ihexxa/quickshare/src/db/userstore" // "github.com/ihexxa/quickshare/src/db/userstore"
"github.com/ihexxa/quickshare/src/depidx" "github.com/ihexxa/quickshare/src/depidx"
"github.com/ihexxa/quickshare/src/fs" "github.com/ihexxa/quickshare/src/fs"
"github.com/ihexxa/quickshare/src/fs/local" "github.com/ihexxa/quickshare/src/fs/local"
@ -39,7 +39,7 @@ import (
"github.com/ihexxa/quickshare/src/idgen/simpleidgen" "github.com/ihexxa/quickshare/src/idgen/simpleidgen"
"github.com/ihexxa/quickshare/src/iolimiter" "github.com/ihexxa/quickshare/src/iolimiter"
"github.com/ihexxa/quickshare/src/kvstore" "github.com/ihexxa/quickshare/src/kvstore"
"github.com/ihexxa/quickshare/src/kvstore/boltdbpvd" // "github.com/ihexxa/quickshare/src/kvstore/boltdbpvd"
"github.com/ihexxa/quickshare/src/search/fileindex" "github.com/ihexxa/quickshare/src/search/fileindex"
"github.com/ihexxa/quickshare/src/worker/localworker" "github.com/ihexxa/quickshare/src/worker/localworker"
qsstatic "github.com/ihexxa/quickshare/static" qsstatic "github.com/ihexxa/quickshare/static"
@ -64,11 +64,6 @@ func NewServer(cfg gocfg.ICfg) (*Server, error) {
return nil, fmt.Errorf("init handlers error: %w", err) return nil, fmt.Errorf("init handlers error: %w", err)
} }
err = checkCompatibility(deps)
if err != nil {
return nil, fmt.Errorf("failed to check compatibility: %w", err)
}
port := cfg.GrabInt("Server.Port") port := cfg.GrabInt("Server.Port")
portStr, ok := cfg.String("ENV.PORT") portStr, ok := cfg.String("ENV.PORT")
if ok && portStr != "" { if ok && portStr != "" {
@ -93,21 +88,6 @@ func NewServer(cfg gocfg.ICfg) (*Server, error) {
}, nil }, nil
} }
func checkCompatibility(deps *depidx.Deps) error {
users, err := deps.Users().ListUsers()
if err != nil {
return err
}
for _, user := range users {
if user.Preferences == nil {
deps.Users().SetPreferences(user.ID, &db.DefaultPreferences)
}
}
return nil
}
func mkRoot(rootPath string) { func mkRoot(rootPath string) {
info, err := os.Stat(rootPath) info, err := os.Stat(rootPath)
if err != nil { if err != nil {
@ -131,7 +111,7 @@ func initDeps(cfg gocfg.ICfg) *depidx.Deps {
secret, ok := cfg.String("ENV.TOKENSECRET") secret, ok := cfg.String("ENV.TOKENSECRET")
if !ok { if !ok {
secret = makeRandToken() secret = makeRandToken()
logger.Info("warning: TOKENSECRET is not given, using generated token") logger.Info("warning: TOKENSECRET is not set, will generate token")
} }
rootPath := cfg.GrabString("Fs.Root") rootPath := cfg.GrabString("Fs.Root")
@ -144,45 +124,132 @@ func initDeps(cfg gocfg.ICfg) *depidx.Deps {
filesystem := local.NewLocalFS(rootPath, 0660, opensLimit, openTTL, readerTTL, ider) filesystem := local.NewLocalFS(rootPath, 0660, opensLimit, openTTL, readerTTL, ider)
jwtEncDec := jwt.NewJWTEncDec(secret) jwtEncDec := jwt.NewJWTEncDec(secret)
// kv := boltdbpvd.New(dbPath, 1024)
// users, err := userstore.NewKVUserStore(kv)
// if err != nil {
// panic(fmt.Sprintf("failed to init user store: %s", err))
// }
// fileInfos, err := fileinfostore.NewFileInfoStore(kv)
// if err != nil {
// panic(fmt.Sprintf("failed to init file info store: %s", err))
// }
// siteStore, err := sitestore.NewSiteStore(kv)
// if err != nil {
// panic(fmt.Sprintf("failed to init site config store: %s", err))
// }
// boltDB, err := boltstore.NewBoltStore(kv.Bolt())
// if err != nil {
// panic(fmt.Sprintf("failed to init bolt store: %s", err))
// }
quickshareDb, err := initDB(cfg, filesystem)
if err != nil {
logger.Errorf("failed to init DB: %s", err)
os.Exit(1)
}
limiterCap := cfg.IntOr("Users.LimiterCapacity", 10000)
limiterCyc := cfg.IntOr("Users.LimiterCyc", 1000)
limiter := iolimiter.NewIOLimiter(limiterCap, limiterCyc, quickshareDb)
deps := depidx.NewDeps(cfg)
deps.SetDB(quickshareDb)
deps.SetFS(filesystem)
deps.SetToken(jwtEncDec)
// deps.SetKV(kv)
// deps.SetUsers(users)
// deps.SetFileInfos(fileInfos)
// deps.SetSiteStore(siteStore)
// deps.SetBoltStore(boltDB)
deps.SetID(ider)
deps.SetLog(logger)
deps.SetLimiter(limiter)
queueSize := cfg.GrabInt("Workers.QueueSize")
sleepCyc := cfg.GrabInt("Workers.SleepCyc")
workerCount := cfg.GrabInt("Workers.WorkerCount")
workers := localworker.NewWorkerPool(queueSize, sleepCyc, workerCount, logger)
workers.Start()
deps.SetWorkers(workers)
searchResultLimit := cfg.GrabInt("Server.SearchResultLimit")
// initFileIndex := cfg.GrabBool("Server.InitFileIndex")
fileIndex := fileindex.NewFileTreeIndex(filesystem, "/", searchResultLimit)
indexInfo, err := filesystem.Stat(fileIndexPath)
indexInited := false
if err != nil {
if !os.IsNotExist(err) {
logger.Warnf("failed to detect file index: %s", err)
} else {
logger.Warnf("no file index found")
}
} else if indexInfo.IsDir() {
logger.Warnf("file index is folder, not file: %s", fileIndexPath)
} else {
err = fileIndex.ReadFrom(fileIndexPath)
if err != nil {
logger.Infof("failed to load file index: %s", err)
} else {
indexInited = true
}
}
logger.Infof("file index inited(%t)", indexInited)
deps.SetFileIndex(fileIndex)
return deps
}
func initDB(cfg gocfg.ICfg, filesystem fs.ISimpleFS) (db.IDBQuickshare, error) {
dbPath := cfg.GrabString("Db.DbPath") dbPath := cfg.GrabString("Db.DbPath")
dbDir := filepath.Dir(dbPath) dbDir := path.Dir(dbPath)
if err = filesystem.MkdirAll(dbDir); err != nil {
panic(fmt.Sprintf("failed to create path for db: %s", err)) err := filesystem.MkdirAll(dbDir)
if err != nil {
return nil, fmt.Errorf("failed to create path for db: %w", err)
} }
kv := boltdbpvd.New(dbPath, 1024) sqliteDB, err := sqlite.NewSQLite(dbPath)
users, err := userstore.NewKVUserStore(kv)
if err != nil { if err != nil {
panic(fmt.Sprintf("failed to init user store: %s", err)) return nil, fmt.Errorf("failed to create path for db: %w", err)
} }
fileInfos, err := fileinfostore.NewFileInfoStore(kv) dbQuickshare, err := sqlite.NewSQLiteStore(sqliteDB)
if err != nil { if err != nil {
panic(fmt.Sprintf("failed to init file info store: %s", err)) return nil, fmt.Errorf("failed to create quickshare db: %w", err)
}
siteStore, err := sitestore.NewSiteStore(kv)
if err != nil {
panic(fmt.Sprintf("failed to init site config store: %s", err))
}
boltDB, err := boltstore.NewBoltStore(kv.Bolt())
if err != nil {
panic(fmt.Sprintf("failed to init bolt store: %s", err))
} }
rdbPath := cfg.GrabString("Db.RdbPath") var ok bool
if rdbPath == "" { var adminName string
panic("rdbPath is blank") var pwdHash []byte
} if cfg.BoolOr("Users.EnableAuth", true) {
rdbDir := filepath.Dir(rdbPath) adminName, ok = cfg.String("ENV.DEFAULTADMIN")
if err = filesystem.MkdirAll(rdbDir); err != nil { if !ok || adminName == "" {
panic(fmt.Sprintf("failed to create path for rdb: %s", err)) fmt.Println("Please input admin name: ")
fmt.Scanf("%s", &adminName)
} }
rdb, err := sqlite.NewSQLite(rdbPath) adminPwd, _ := cfg.String("ENV.DEFAULTADMINPWD")
if adminPwd == "" {
adminPwd, err = generatePwd()
if err != nil { if err != nil {
panic(fmt.Sprintf("failed to open sqlite: %s", err)) return nil, fmt.Errorf("generate password error: %w", err)
}
fmt.Printf("password is generated: %s, please update it immediately after login\n", adminPwd)
} }
err = siteStore.Init(&db.SiteConfig{ pwdHash, err = bcrypt.GenerateFromPassword([]byte(adminPwd), 10)
if err != nil {
return nil, fmt.Errorf("hashing password error: %w", err)
}
}
err = dbQuickshare.InitUserTable(context.TODO(), adminName, string(pwdHash))
if err != nil {
return nil, fmt.Errorf("failed to init user table: %w", err)
}
err = dbQuickshare.InitConfigTable(
context.TODO(),
&db.SiteConfig{
ClientCfg: &db.ClientConfig{ ClientCfg: &db.ClientConfig{
SiteName: cfg.StringOr("Site.ClientCfg.SiteName", "Quickshare"), SiteName: cfg.StringOr("Site.ClientCfg.SiteName", "Quickshare"),
SiteDesc: cfg.StringOr("Site.ClientCfg.SiteDesc", "Quick and simple file sharing"), SiteDesc: cfg.StringOr("Site.ClientCfg.SiteDesc", "Quick and simple file sharing"),
@ -194,73 +261,16 @@ func initDeps(cfg gocfg.ICfg) *depidx.Deps {
BgColor: cfg.StringOr("Site.ClientCfg.Bg.BgColor", ""), BgColor: cfg.StringOr("Site.ClientCfg.Bg.BgColor", ""),
}, },
}, },
}) },
if err != nil {
panic(fmt.Sprintf("failed to init site config store: %s", err))
}
limiterCap := cfg.IntOr("Users.LimiterCapacity", 10000)
limiterCyc := cfg.IntOr("Users.LimiterCyc", 1000)
limiter := iolimiter.NewIOLimiter(limiterCap, limiterCyc, users)
deps := depidx.NewDeps(cfg)
deps.SetFS(filesystem)
deps.SetToken(jwtEncDec)
deps.SetKV(kv)
deps.SetUsers(users)
deps.SetFileInfos(fileInfos)
deps.SetSiteStore(siteStore)
deps.SetBoltStore(boltDB)
deps.SetID(ider)
deps.SetLog(logger)
deps.SetLimiter(limiter)
deps.SetDB(rdb)
queueSize := cfg.GrabInt("Workers.QueueSize")
sleepCyc := cfg.GrabInt("Workers.SleepCyc")
workerCount := cfg.GrabInt("Workers.WorkerCount")
workers := localworker.NewWorkerPool(queueSize, sleepCyc, workerCount, logger)
workers.Start()
deps.SetWorkers(workers)
searchResultLimit := cfg.GrabInt("Server.SearchResultLimit")
initFileIndex := cfg.GrabBool("Server.InitFileIndex")
fileIndex := fileindex.NewFileTreeIndex(filesystem, "/", searchResultLimit)
indexInfo, err := filesystem.Stat(fileIndexPath)
inited := false
if err != nil {
if !os.IsNotExist(err) {
logger.Infof("failed to detect file index: %s", err)
} else {
logger.Info("warning: no file index found")
}
} else if indexInfo.IsDir() {
logger.Infof("file index is folder, not file: %s", fileIndexPath)
} else {
err = fileIndex.ReadFrom(fileIndexPath)
if err != nil {
logger.Infof("failed to load file index: %s", err)
} else {
inited = true
}
}
if !inited && initFileIndex {
msg, _ := json.Marshal(fileshdr.IndexingParams{})
err = deps.Workers().TryPut(
localworker.NewMsg(
deps.ID().Gen(),
map[string]string{localworker.MsgTypeKey: fileshdr.MsgTypeIndexing},
string(msg),
),
) )
if err != nil { if err != nil {
logger.Infof("failed to reindex file index: %s", err) return nil, fmt.Errorf("failed to init config table: %w", err)
} }
err = dbQuickshare.InitFileTables(context.TODO())
if err != nil {
return nil, fmt.Errorf("failed to init files tables: %w", err)
} }
deps.SetFileIndex(fileIndex) return dbQuickshare, nil
return deps
} }
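Note: initDeps now delegates persistence setup to initDB, which opens the sqlite file, resolves the admin name and password (ENV.DEFAULTADMIN / ENV.DEFAULTADMINPWD, a prompt, or a generated password), hashes the password with bcrypt, and creates the user, config, and file tables. A trimmed sketch of that initialization tail; the helper and its signature are illustrative, but the three Init* calls and error messages match the diff:

func initTables(ctx context.Context, store db.IDBQuickshare, adminName string, pwdHash []byte, siteCfg *db.SiteConfig) error {
    if err := store.InitUserTable(ctx, adminName, string(pwdHash)); err != nil {
        return fmt.Errorf("failed to init user table: %w", err)
    }
    if err := store.InitConfigTable(ctx, siteCfg); err != nil {
        return fmt.Errorf("failed to init config table: %w", err)
    }
    if err := store.InitFileTables(ctx); err != nil {
        return fmt.Errorf("failed to init files tables: %w", err)
    }
    return nil
}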
func initHandlers(router *gin.Engine, cfg gocfg.ICfg, deps *depidx.Deps) (*gin.Engine, error) { func initHandlers(router *gin.Engine, cfg gocfg.ICfg, deps *depidx.Deps) (*gin.Engine, error) {
@ -269,40 +279,10 @@ func initHandlers(router *gin.Engine, cfg gocfg.ICfg, deps *depidx.Deps) (*gin.E
if err != nil { if err != nil {
return nil, fmt.Errorf("new users svc error: %w", err) return nil, fmt.Errorf("new users svc error: %w", err)
} }
if cfg.BoolOr("Users.EnableAuth", true) && !userHdrs.IsInited() {
adminName, ok := cfg.String("ENV.DEFAULTADMIN")
if !ok || adminName == "" {
// only write to stdout
deps.Log().Info("Please input admin name: ")
fmt.Scanf("%s", &adminName)
}
adminPwd, _ := cfg.String("ENV.DEFAULTADMINPWD")
if adminPwd == "" {
adminPwd, err = generatePwd()
if err != nil {
return nil, fmt.Errorf("generate pwd error: %w", err)
}
// only write to stdout
fmt.Printf("password is generated: %s, please update it after login\n", adminPwd)
}
pwdHash, err := bcrypt.GenerateFromPassword([]byte(adminPwd), 10)
if err != nil {
return nil, fmt.Errorf("generate pwd error: %w", err)
}
if _, err := userHdrs.Init(adminName, string(pwdHash)); err != nil {
return nil, fmt.Errorf("init admin error: %w", err)
}
deps.Log().Infof("admin(%s) is created", adminName)
}
fileHdrs, err := fileshdr.NewFileHandlers(cfg, deps) fileHdrs, err := fileshdr.NewFileHandlers(cfg, deps)
if err != nil { if err != nil {
return nil, fmt.Errorf("new files service error: %w", err) return nil, fmt.Errorf("new files service error: %w", err)
} }
settingsSvc, err := settings.NewSettingsSvc(cfg, deps) settingsSvc, err := settings.NewSettingsSvc(cfg, deps)
if err != nil { if err != nil {
return nil, fmt.Errorf("new setting service error: %w", err) return nil, fmt.Errorf("new setting service error: %w", err)
@ -394,7 +374,7 @@ func initHandlers(router *gin.Engine, cfg gocfg.ICfg, deps *depidx.Deps) (*gin.E
func initLogger(cfg gocfg.ICfg) *zap.SugaredLogger { func initLogger(cfg gocfg.ICfg) *zap.SugaredLogger {
fileWriter := zapcore.AddSync(&lumberjack.Logger{ fileWriter := zapcore.AddSync(&lumberjack.Logger{
Filename: filepath.Join(cfg.GrabString("Fs.Root"), "quickshare.log"), Filename: path.Join(cfg.GrabString("Fs.Root"), "quickshare.log"),
MaxSize: cfg.IntOr("Log.MaxSize", 50), // megabytes MaxSize: cfg.IntOr("Log.MaxSize", 50), // megabytes
MaxBackups: cfg.IntOr("Log.MaxBackups", 2), MaxBackups: cfg.IntOr("Log.MaxBackups", 2),
MaxAge: cfg.IntOr("Log.MaxAge", 31), // days MaxAge: cfg.IntOr("Log.MaxAge", 31), // days