feat(db): replace boltdb with sqlite

This commit is contained in:
hexxa 2022-09-03 18:40:55 +08:00 committed by Hexxa
parent 59a39efc4a
commit 791848f75c
16 changed files with 1749 additions and 543 deletions

View file

@ -1,25 +0,0 @@
package rdb
import (
"context"
"database/sql"
_ "github.com/mattn/go-sqlite3"
)
// TODO: expose more APIs if needed
type IDB interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
Close() error
PingContext(ctx context.Context) error
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
// Conn(ctx context.Context) (*Conn, error)
// Driver() driver.Driver
// SetConnMaxIdleTime(d time.Duration)
// SetConnMaxLifetime(d time.Duration)
// SetMaxIdleConns(n int)
// SetMaxOpenConns(n int)
// Stats() DBStats
}

View file

@ -0,0 +1,72 @@
package sqlite
import (
"context"
"encoding/json"
"github.com/ihexxa/quickshare/src/db"
)
func (st *SQLiteStore) getCfg(ctx context.Context) (*db.SiteConfig, error) {
var configStr string
err := st.db.QueryRowContext(
ctx,
`select config
from t_config
where id=0`,
).Scan(&configStr)
if err != nil {
return nil, err
}
config := &db.SiteConfig{}
err = json.Unmarshal([]byte(configStr), config)
if err != nil {
return nil, err
}
if err = db.CheckSiteCfg(config, true); err != nil {
return nil, err
}
return config, nil
}
func (st *SQLiteStore) setCfg(ctx context.Context, cfg *db.SiteConfig) error {
if err := db.CheckSiteCfg(cfg, false); err != nil {
return err
}
cfgBytes, err := json.Marshal(cfg)
if err != nil {
return err
}
_, err = st.db.ExecContext(
ctx,
`update t_config
set config=?
where id=0`,
string(cfgBytes),
)
return err
}
func (st *SQLiteStore) SetClientCfg(ctx context.Context, cfg *db.ClientConfig) error {
st.Lock()
defer st.Unlock()
siteCfg, err := st.getCfg(ctx)
if err != nil {
return err
}
siteCfg.ClientCfg = cfg
return st.setCfg(ctx, siteCfg)
}
func (st *SQLiteStore) GetCfg(ctx context.Context) (*db.SiteConfig, error) {
st.RLock()
defer st.RUnlock()
return st.getCfg(ctx)
}

276
src/db/rdb/sqlite/files.go Normal file
View file

@ -0,0 +1,276 @@
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"path"
"strings"
"github.com/ihexxa/quickshare/src/db"
)
const (
InitNs = "Init"
InitTimeKey = "initTime"
SchemaVerKey = "SchemaVersion"
SchemaV1 = "v1"
)
var (
maxHashingTime = 10
)
func (st *SQLiteStore) getFileInfo(ctx context.Context, userId uint64, itemPath string) (*db.FileInfo, error) {
var infoStr string
fInfo := &db.FileInfo{}
var isDir bool
var size int64
var shareId string
err := st.db.QueryRowContext(
ctx,
`select is_dir, size, share_id, info
from t_file_info
where path=? and user=?
`,
itemPath,
userId,
).Scan(
&isDir,
&size,
&shareId,
&infoStr,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, db.ErrFileInfoNotFound
}
return nil, err
}
err = json.Unmarshal([]byte(infoStr), &fInfo)
if err != nil {
return nil, err
}
fInfo.IsDir = isDir
fInfo.Size = size
fInfo.ShareID = shareId
fInfo.Shared = shareId != ""
return fInfo, nil
}
func (st *SQLiteStore) GetFileInfo(ctx context.Context, userId uint64, itemPath string) (*db.FileInfo, error) {
st.RLock()
defer st.RUnlock()
return st.getFileInfo(ctx, userId, itemPath)
}
func (st *SQLiteStore) ListFileInfos(ctx context.Context, itemPaths []string) (map[string]*db.FileInfo, error) {
st.RLock()
defer st.RUnlock()
// TODO: add pagination
placeholders := []string{}
values := []any{}
for i := 0; i < len(itemPaths); i++ {
placeholders = append(placeholders, "?")
values = append(values, itemPaths[i])
}
rows, err := st.db.QueryContext(
ctx,
fmt.Sprintf(
`select path, is_dir, size, share_id, info
from t_file_info
where path in (%s)
`,
strings.Join(placeholders, ","),
),
values...,
)
if err != nil {
return nil, err
}
defer rows.Close()
var fInfoStr, itemPath, shareId string
var isDir bool
var size int64
fInfos := map[string]*db.FileInfo{}
for rows.Next() {
fInfo := &db.FileInfo{}
err = rows.Scan(&itemPath, &isDir, &size, &shareId, &fInfoStr)
if err != nil {
return nil, err
}
err = json.Unmarshal([]byte(fInfoStr), fInfo)
if err != nil {
return nil, err
}
fInfo.IsDir = isDir
fInfo.Size = size
fInfo.ShareID = shareId
fInfo.Shared = shareId != ""
fInfos[itemPath] = fInfo
}
if rows.Err() != nil {
return nil, rows.Err()
}
return fInfos, nil
}
func (st *SQLiteStore) addFileInfo(ctx context.Context, userId uint64, itemPath string, info *db.FileInfo) error {
infoStr, err := json.Marshal(info)
if err != nil {
return err
}
dirPath, itemName := path.Split(itemPath)
_, err = st.db.ExecContext(
ctx,
`insert into t_file_info
(path, user, parent, name, is_dir, size, share_id, info) values (?, ?, ?, ?, ?, ?, ?, ?)`,
itemPath,
userId,
dirPath,
itemName,
info.IsDir,
info.Size,
info.ShareID,
infoStr,
)
return err
}
func (st *SQLiteStore) AddFileInfo(ctx context.Context, userId uint64, itemPath string, info *db.FileInfo) error {
st.Lock()
defer st.Unlock()
err := st.addFileInfo(ctx, userId, itemPath, info)
if err != nil {
return err
}
// increase used space
return st.setUsed(ctx, userId, true, info.Size)
}
func (st *SQLiteStore) delFileInfo(ctx context.Context, userId uint64, itemPath string) error {
_, err := st.db.ExecContext(
ctx,
`delete from t_file_info
where path=? and user=?
`,
itemPath,
userId,
)
return err
}
// func (st *SQLiteStore) DelFileInfo(ctx context.Context, itemPath string) error {
// st.Lock()
// defer st.Unlock()
// return st.delFileInfo(ctx, itemPath)
// }
// sharings
func (st *SQLiteStore) SetSha1(ctx context.Context, userId uint64, itemPath, sign string) error {
st.Lock()
defer st.Unlock()
info, err := st.getFileInfo(ctx, userId, itemPath)
if err != nil {
return err
}
info.Sha1 = sign
infoStr, err := json.Marshal(info)
if err != nil {
return err
}
_, err = st.db.ExecContext(
ctx,
`update t_file_info
set info=?
where path=? and user=?`,
infoStr,
itemPath,
userId,
)
return err
}
func (st *SQLiteStore) DelFileInfo(ctx context.Context, userID uint64, itemPath string) error {
st.Lock()
defer st.Unlock()
// get all children and size
rows, err := st.db.QueryContext(
ctx,
`select path, size
from t_file_info
where path = ? or path like ?
`,
itemPath,
fmt.Sprintf("%s/%%", itemPath),
)
if err != nil {
return err
}
defer rows.Close()
var childrenPath string
var itemSize int64
placeholders := []string{}
values := []any{}
decrSize := int64(0)
for rows.Next() {
err = rows.Scan(&childrenPath, &itemSize)
if err != nil {
return err
}
placeholders = append(placeholders, "?")
values = append(values, childrenPath)
decrSize += itemSize
}
// decrease used space
err = st.setUsed(ctx, userID, false, decrSize)
if err != nil {
return err
}
// delete file info entries
_, err = st.db.ExecContext(
ctx,
fmt.Sprintf(
`delete from t_file_info
where path in (%s)`,
strings.Join(placeholders, ","),
),
values...,
)
return err
}
func (st *SQLiteStore) MoveFileInfos(ctx context.Context, userId uint64, oldPath, newPath string, isDir bool) error {
st.Lock()
defer st.Unlock()
info, err := st.getFileInfo(ctx, userId, oldPath)
if err != nil {
return err
}
err = st.delFileInfo(ctx, userId, oldPath)
if err != nil {
return err
}
return st.addFileInfo(ctx, userId, newPath, info)
}

View file

@ -0,0 +1,156 @@
package sqlite
import (
"context"
"crypto/sha1"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"path"
"time"
"github.com/ihexxa/quickshare/src/db"
)
func (st *SQLiteStore) generateShareID(payload string) (string, error) {
if len(payload) == 0 {
return "", db.ErrEmpty
}
msg := fmt.Sprintf("%s-%d", payload, time.Now().Unix())
h := sha1.New()
_, err := io.WriteString(h, msg)
if err != nil {
return "", err
}
return fmt.Sprintf("%x", h.Sum(nil))[:7], nil
}
func (st *SQLiteStore) IsSharing(ctx context.Context, userId uint64, dirPath string) bool {
st.RLock()
defer st.RUnlock()
// TODO: differentiate error and not exist
info, err := st.getFileInfo(ctx, userId, dirPath)
if err != nil {
return false
}
return info.ShareID != ""
}
func (st *SQLiteStore) GetSharingDir(ctx context.Context, hashID string) (string, error) {
st.RLock()
defer st.RUnlock()
var sharedPath string
err := st.db.QueryRowContext(
ctx,
`select path
from t_file_info
where share_id=?
`,
hashID,
).Scan(
&sharedPath,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", db.ErrSharingNotFound
}
return "", err
}
return sharedPath, nil
}
func (st *SQLiteStore) AddSharing(ctx context.Context, userId uint64, dirPath string) error {
st.Lock()
defer st.Unlock()
shareID, err := st.generateShareID(dirPath)
if err != nil {
return err
}
_, err = st.getFileInfo(ctx, userId, dirPath)
if err != nil && !errors.Is(err, db.ErrFileInfoNotFound) {
return err
}
if errors.Is(err, db.ErrFileInfoNotFound) {
// insert new
parentPath, name := path.Split(dirPath)
info := &db.FileInfo{Shared: true} // TODO: deprecate shared in info
infoStr, err := json.Marshal(info)
if err != nil {
return err
}
_, err = st.db.ExecContext(
ctx,
`insert into t_file_info
(path, user, parent, name, is_dir, size, share_id, info) values (?, ?, ?, ?, ?, ?, ?, ?)`,
dirPath, userId, parentPath, name, true, 0, shareID, infoStr,
)
return err
}
_, err = st.db.ExecContext(
ctx,
`update t_file_info
set share_id=?
where path=? and user=?`,
shareID, dirPath, userId,
)
return err
}
func (st *SQLiteStore) DelSharing(ctx context.Context, userId uint64, dirPath string) error {
st.Lock()
defer st.Unlock()
_, err := st.db.ExecContext(
ctx,
`update t_file_info
set share_id=''
where path=? and user=?`,
dirPath,
userId,
)
return err
}
func (st *SQLiteStore) ListUserSharings(ctx context.Context, userId uint64) (map[string]string, error) {
st.RLock()
defer st.RUnlock()
rows, err := st.db.QueryContext(
ctx,
`select path, share_id
from t_file_info
where user=? and share_id <> ''`,
userId,
)
if err != nil {
return nil, err
}
defer rows.Close()
var pathname, shareId string
pathToShareId := map[string]string{}
for rows.Next() {
err = rows.Scan(&pathname, &shareId)
if err != nil {
return nil, err
}
pathToShareId[pathname] = shareId
}
if rows.Err() != nil {
return nil, rows.Err()
}
return pathToShareId, nil
}

View file

@ -0,0 +1,192 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"github.com/ihexxa/quickshare/src/db"
)
func (st *SQLiteStore) addUploadInfoOnly(ctx context.Context, userId uint64, filePath, tmpPath string, fileSize int64) error {
_, err := st.db.ExecContext(
ctx,
`insert into t_file_uploading
(real_path, tmp_path, user, size, uploaded) values (?, ?, ?, ?, ?)`,
filePath, tmpPath, userId, fileSize, 0,
)
return err
}
func (st *SQLiteStore) AddUploadInfos(ctx context.Context, userId uint64, tmpPath, filePath string, info *db.FileInfo) error {
st.Lock()
defer st.Unlock()
userInfo, err := st.getUser(ctx, userId)
if err != nil {
return err
} else if userInfo.UsedSpace+info.Size > int64(userInfo.Quota.SpaceLimit) {
return db.ErrQuota
}
_, _, _, err = st.getUploadInfo(ctx, userId, filePath)
if err == nil {
return db.ErrKeyExisting
} else if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
userInfo.UsedSpace += info.Size
err = st.setUser(ctx, userInfo)
if err != nil {
return err
}
return st.addUploadInfoOnly(ctx, userId, filePath, tmpPath, info.Size)
}
func (st *SQLiteStore) DelUploadingInfos(ctx context.Context, userId uint64, realPath string) error {
st.Lock()
defer st.Unlock()
return st.delUploadingInfos(ctx, userId, realPath)
}
func (st *SQLiteStore) delUploadingInfos(ctx context.Context, userId uint64, realPath string) error {
_, size, _, err := st.getUploadInfo(ctx, userId, realPath)
if err != nil {
// info may not exist
return err
}
err = st.delUploadInfoOnly(ctx, userId, realPath)
if err != nil {
return err
}
userInfo, err := st.getUser(ctx, userId)
if err != nil {
return err
}
userInfo.UsedSpace -= size
return st.setUser(ctx, userInfo)
}
func (st *SQLiteStore) delUploadInfoOnly(ctx context.Context, userId uint64, filePath string) error {
_, err := st.db.ExecContext(
ctx,
`delete from t_file_uploading
where real_path=? and user=?`,
filePath, userId,
)
return err
}
// func (st *SQLiteStore) MoveUploadingInfos(ctx context.Context, userId uint64, uploadPath, itemPath string) error {
// st.Lock()
// defer st.Unlock()
// _, size, _, err := st.getUploadInfo(ctx, userId, itemPath)
// if err != nil {
// return err
// }
// err = st.delUploadInfoOnly(ctx, userId, itemPath)
// if err != nil {
// return err
// }
// return st.addFileInfo(ctx, userId, itemPath, &db.FileInfo{
// Size: size,
// })
// }
func (st *SQLiteStore) SetUploadInfo(ctx context.Context, userId uint64, filePath string, newUploaded int64) error {
st.Lock()
defer st.Unlock()
var size, uploaded int64
err := st.db.QueryRowContext(
ctx,
`select size, uploaded
from t_file_uploading
where real_path=? and user=?`,
filePath, userId,
).Scan(&size, &uploaded)
if err != nil {
return err
} else if newUploaded > size {
return db.ErrGreaterThanSize
}
_, err = st.db.ExecContext(
ctx,
`update t_file_uploading
set uploaded=?
where real_path=? and user=?`,
newUploaded, filePath, userId,
)
return err
}
func (st *SQLiteStore) getUploadInfo(ctx context.Context, userId uint64, filePath string) (string, int64, int64, error) {
var size, uploaded int64
err := st.db.QueryRowContext(
ctx,
`select size, uploaded
from t_file_uploading
where user=? and real_path=?`,
userId, filePath,
).Scan(&size, &uploaded)
if err != nil {
return "", 0, 0, err
}
return filePath, size, uploaded, nil
}
func (st *SQLiteStore) GetUploadInfo(ctx context.Context, userId uint64, filePath string) (string, int64, int64, error) {
st.RLock()
defer st.RUnlock()
return st.getUploadInfo(ctx, userId, filePath)
}
func (st *SQLiteStore) ListUploadInfos(ctx context.Context, userId uint64) ([]*db.UploadInfo, error) {
st.RLock()
defer st.RUnlock()
rows, err := st.db.QueryContext(
ctx,
`select real_path, size, uploaded
from t_file_uploading
where user=?`,
userId,
)
if err != nil {
return nil, err
}
defer rows.Close()
var pathname string
var size, uploaded int64
infos := []*db.UploadInfo{}
for rows.Next() {
err = rows.Scan(
&pathname,
&size,
&uploaded,
)
if err != nil {
return nil, err
}
infos = append(infos, &db.UploadInfo{
RealFilePath: pathname,
Size: size,
Uploaded: uploaded,
})
}
if rows.Err() != nil {
return nil, rows.Err()
}
return infos, nil
}

View file

@ -1,16 +1,21 @@
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"fmt"
_ "github.com/mattn/go-sqlite3"
"sync"
"time"
"github.com/ihexxa/quickshare/src/db/rdb"
"github.com/ihexxa/quickshare/src/db"
_ "github.com/mattn/go-sqlite3"
)
type SQLite struct {
rdb.IDB
db.IDB
dbPath string
mtx *sync.RWMutex
}
func NewSQLite(dbPath string) (*SQLite, error) {
@ -24,3 +29,159 @@ func NewSQLite(dbPath string) (*SQLite, error) {
dbPath: dbPath,
}, nil
}
func NewSQLiteStore(db db.IDB) (*SQLiteStore, error) {
return &SQLiteStore{
db: db,
mtx: &sync.RWMutex{},
}, nil
}
func (st *SQLiteStore) Lock() {
st.mtx.Lock()
}
func (st *SQLiteStore) Unlock() {
st.mtx.Unlock()
}
func (st *SQLiteStore) RLock() {
st.mtx.RLock()
}
func (st *SQLiteStore) RUnlock() {
st.mtx.RUnlock()
}
func (st *SQLiteStore) IsInited() bool {
// always try to init the db
return false
}
func (st *SQLiteStore) Init(ctx context.Context, rootName, rootPwd string, cfg *db.SiteConfig) error {
err := st.InitUserTable(ctx, rootName, rootPwd)
if err != nil {
return err
}
if err = st.InitFileTables(ctx); err != nil {
return err
}
return st.InitConfigTable(ctx, cfg)
}
func (st *SQLiteStore) InitUserTable(ctx context.Context, rootName, rootPwd string) error {
_, err := st.db.ExecContext(
ctx,
`create table if not exists t_user (
id bigint not null,
name varchar not null,
pwd varchar not null,
role integer not null,
used_space bigint not null,
quota varchar not null,
preference varchar not null,
primary key(id)
)`,
)
if err != nil {
return err
}
admin := &db.User{
ID: 0,
Name: rootName,
Pwd: rootPwd,
Role: db.AdminRole,
Quota: &db.Quota{
SpaceLimit: db.DefaultSpaceLimit,
UploadSpeedLimit: db.DefaultUploadSpeedLimit,
DownloadSpeedLimit: db.DefaultDownloadSpeedLimit,
},
Preferences: &db.DefaultPreferences,
}
visitor := &db.User{
ID: db.VisitorID,
Name: db.VisitorName,
Pwd: rootPwd,
Role: db.VisitorRole,
Quota: &db.Quota{
SpaceLimit: 0,
UploadSpeedLimit: db.VisitorUploadSpeedLimit,
DownloadSpeedLimit: db.VisitorDownloadSpeedLimit,
},
Preferences: &db.DefaultPreferences,
}
for _, user := range []*db.User{admin, visitor} {
err = st.AddUser(ctx, user)
if err != nil {
return err
}
}
return nil
}
func (st *SQLiteStore) InitFileTables(ctx context.Context) error {
_, err := st.db.ExecContext(
ctx,
`create table if not exists t_file_info (
path varchar not null,
user bigint not null,
parent varchar not null,
name varchar not null,
is_dir boolean not null,
size bigint not null,
share_id varchar not null,
info varchar not null,
primary key(path)
)`,
)
if err != nil {
return err
}
_, err = st.db.ExecContext(
ctx,
`create table if not exists t_file_uploading (
real_path varchar not null,
tmp_path varchar not null,
user bigint not null,
size bigint not null,
uploaded bigint not null,
primary key(real_path)
)`,
)
return err
}
func (st *SQLiteStore) InitConfigTable(ctx context.Context, cfg *db.SiteConfig) error {
st.Lock()
defer st.Unlock()
_, err := st.db.ExecContext(
ctx,
`create table if not exists t_config (
id bigint not null,
config varchar not null,
modified datetime not null,
primary key(id)
)`,
)
if err != nil {
return err
}
cfgStr, err := json.Marshal(cfg)
if err != nil {
return err
}
_, err = st.db.ExecContext(
ctx,
`insert into t_config
(id, config, modified) values (?, ?, ?)`,
0, cfgStr, time.Now(),
)
return err
}

View file

@ -5,97 +5,18 @@ import (
"database/sql"
"encoding/json"
"errors"
// "errors"
"fmt"
// "sync"
// "time"
"sync"
"github.com/ihexxa/quickshare/src/db"
"github.com/ihexxa/quickshare/src/db/rdb"
// "github.com/ihexxa/quickshare/src/kvstore"
)
// TODO: use sync.Pool instead
const (
VisitorID = uint64(1)
VisitorName = "visitor"
)
var (
ErrReachedLimit = errors.New("reached space limit")
ErrUserNotFound = errors.New("user not found")
ErrNegtiveUsedSpace = errors.New("used space can not be negative")
)
type SQLiteUsers struct {
db rdb.IDB
type SQLiteStore struct {
db db.IDB
mtx *sync.RWMutex
}
func NewSQLiteUsers(db rdb.IDB) (*SQLiteUsers, error) {
return &SQLiteUsers{db: db}, nil
}
func (u *SQLiteUsers) Init(ctx context.Context, rootName, rootPwd string) error {
_, err := u.db.ExecContext(
ctx,
`create table if not exists t_user (
id bigint not null,
name varchar not null,
pwd varchar not null,
role integer not null,
used_space bigint not null,
quota varchar not null,
preference varchar not null,
primary key(id)
)`,
)
if err != nil {
return err
}
admin := &db.User{
ID: 0,
Name: rootName,
Pwd: rootPwd,
Role: db.AdminRole,
Quota: &db.Quota{
SpaceLimit: db.DefaultSpaceLimit,
UploadSpeedLimit: db.DefaultUploadSpeedLimit,
DownloadSpeedLimit: db.DefaultDownloadSpeedLimit,
},
Preferences: &db.DefaultPreferences,
}
visitor := &db.User{
ID: VisitorID,
Name: VisitorName,
Pwd: rootPwd,
Role: db.VisitorRole,
Quota: &db.Quota{
SpaceLimit: 0,
UploadSpeedLimit: db.VisitorUploadSpeedLimit,
DownloadSpeedLimit: db.VisitorDownloadSpeedLimit,
},
Preferences: &db.DefaultPreferences,
}
for _, user := range []*db.User{admin, visitor} {
err = u.AddUser(ctx, user)
if err != nil {
return err
}
}
return nil
}
func (u *SQLiteUsers) IsInited() bool {
// always try to init the db
return false
}
// t_users
// id, name, pwd, role, used_space, config
func (u *SQLiteUsers) setUser(ctx context.Context, tx *sql.Tx, user *db.User) error {
func (st *SQLiteStore) setUser(ctx context.Context, user *db.User) error {
var err error
if err = db.CheckUser(user, false); err != nil {
return err
@ -109,7 +30,7 @@ func (u *SQLiteUsers) setUser(ctx context.Context, tx *sql.Tx, user *db.User) er
if err != nil {
return err
}
_, err = tx.ExecContext(
_, err = st.db.ExecContext(
ctx,
`update t_user
set name=?, pwd=?, role=?, used_space=?, quota=?, preference=?
@ -120,16 +41,15 @@ func (u *SQLiteUsers) setUser(ctx context.Context, tx *sql.Tx, user *db.User) er
user.UsedSpace,
quotaStr,
preferencesStr,
user.ID,
)
return err
}
func (u *SQLiteUsers) getUser(ctx context.Context, tx *sql.Tx, id uint64) (*db.User, error) {
var err error
func (st *SQLiteStore) getUser(ctx context.Context, id uint64) (*db.User, error) {
user := &db.User{}
var quotaStr, preferenceStr string
err = tx.QueryRowContext(
err := st.db.QueryRowContext(
ctx,
`select id, name, pwd, role, used_space, quota, preference
from t_user
@ -146,7 +66,7 @@ func (u *SQLiteUsers) getUser(ctx context.Context, tx *sql.Tx, id uint64) (*db.U
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrUserNotFound
return nil, db.ErrUserNotFound
}
return nil, err
}
@ -162,7 +82,10 @@ func (u *SQLiteUsers) getUser(ctx context.Context, tx *sql.Tx, id uint64) (*db.U
return user, nil
}
func (u *SQLiteUsers) AddUser(ctx context.Context, user *db.User) error {
func (st *SQLiteStore) AddUser(ctx context.Context, user *db.User) error {
st.Lock()
defer st.Unlock()
quotaStr, err := json.Marshal(user.Quota)
if err != nil {
return err
@ -171,7 +94,7 @@ func (u *SQLiteUsers) AddUser(ctx context.Context, user *db.User) error {
if err != nil {
return err
}
_, err = u.db.ExecContext(
_, err = st.db.ExecContext(
ctx,
`insert into t_user (id, name, pwd, role, used_space, quota, preference) values (?, ?, ?, ?, ?, ?, ?)`,
user.ID,
@ -185,8 +108,11 @@ func (u *SQLiteUsers) AddUser(ctx context.Context, user *db.User) error {
return err
}
func (u *SQLiteUsers) DelUser(ctx context.Context, id uint64) error {
_, err := u.db.ExecContext(
func (st *SQLiteStore) DelUser(ctx context.Context, id uint64) error {
st.Lock()
defer st.Unlock()
_, err := st.db.ExecContext(
ctx,
`delete from t_user where id=?`,
id,
@ -194,36 +120,34 @@ func (u *SQLiteUsers) DelUser(ctx context.Context, id uint64) error {
return err
}
func (u *SQLiteUsers) GetUser(ctx context.Context, id uint64) (*db.User, error) {
tx, err := u.db.BeginTx(ctx, &sql.TxOptions{})
func (st *SQLiteStore) GetUser(ctx context.Context, id uint64) (*db.User, error) {
st.RLock()
defer st.RUnlock()
user, err := st.getUser(ctx, id)
if err != nil {
return nil, err
}
user, err := u.getUser(ctx, tx, id)
if err != nil {
return nil, err
}
err = tx.Commit()
if err != nil {
return nil, err
}
return user, err
}
func (u *SQLiteUsers) GetUserByName(ctx context.Context, name string) (*db.User, error) {
func (st *SQLiteStore) GetUserByName(ctx context.Context, name string) (*db.User, error) {
st.RLock()
defer st.RUnlock()
user := &db.User{}
var quotaStr, preferenceStr string
err := u.db.QueryRowContext(
err := st.db.QueryRowContext(
ctx,
`select id, name, role, used_space, quota, preference
`select id, name, pwd, role, used_space, quota, preference
from t_user
where name=?`,
name,
).Scan(
&user.ID,
&user.Name,
&user.Pwd,
&user.Role,
&user.UsedSpace,
&quotaStr,
@ -231,7 +155,7 @@ func (u *SQLiteUsers) GetUserByName(ctx context.Context, name string) (*db.User,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrUserNotFound
return nil, db.ErrUserNotFound
}
return nil, err
}
@ -247,8 +171,11 @@ func (u *SQLiteUsers) GetUserByName(ctx context.Context, name string) (*db.User,
return user, nil
}
func (u *SQLiteUsers) SetPwd(ctx context.Context, id uint64, pwd string) error {
_, err := u.db.ExecContext(
func (st *SQLiteStore) SetPwd(ctx context.Context, id uint64, pwd string) error {
st.Lock()
defer st.Unlock()
_, err := st.db.ExecContext(
ctx,
`update t_user
set pwd=?
@ -260,13 +187,16 @@ func (u *SQLiteUsers) SetPwd(ctx context.Context, id uint64, pwd string) error {
}
// role + quota
func (u *SQLiteUsers) SetInfo(ctx context.Context, id uint64, user *db.User) error {
func (st *SQLiteStore) SetInfo(ctx context.Context, id uint64, user *db.User) error {
st.Lock()
defer st.Unlock()
quotaStr, err := json.Marshal(user.Quota)
if err != nil {
return err
}
_, err = u.db.ExecContext(
_, err = st.db.ExecContext(
ctx,
`update t_user
set role=?, quota=?
@ -277,13 +207,16 @@ func (u *SQLiteUsers) SetInfo(ctx context.Context, id uint64, user *db.User) err
return err
}
func (u *SQLiteUsers) SetPreferences(ctx context.Context, id uint64, prefers *db.Preferences) error {
func (st *SQLiteStore) SetPreferences(ctx context.Context, id uint64, prefers *db.Preferences) error {
st.Lock()
defer st.Unlock()
preferenceStr, err := json.Marshal(prefers)
if err != nil {
return err
}
_, err = u.db.ExecContext(
_, err = st.db.ExecContext(
ctx,
`update t_user
set preference=?
@ -294,31 +227,32 @@ func (u *SQLiteUsers) SetPreferences(ctx context.Context, id uint64, prefers *db
return err
}
func (u *SQLiteUsers) SetUsed(ctx context.Context, id uint64, incr bool, capacity int64) error {
tx, err := u.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return err
}
func (st *SQLiteStore) SetUsed(ctx context.Context, id uint64, incr bool, capacity int64) error {
st.Lock()
defer st.Unlock()
return st.setUsed(ctx, id, incr, capacity)
}
gotUser, err := u.getUser(ctx, tx, id)
func (st *SQLiteStore) setUsed(ctx context.Context, id uint64, incr bool, capacity int64) error {
gotUser, err := st.getUser(ctx, id)
if err != nil {
return err
}
if incr && gotUser.UsedSpace+capacity > int64(gotUser.Quota.SpaceLimit) {
return ErrReachedLimit
return db.ErrReachedLimit
}
if incr {
gotUser.UsedSpace = gotUser.UsedSpace + capacity
} else {
if gotUser.UsedSpace-capacity < 0 {
return ErrNegtiveUsedSpace
return db.ErrNegtiveUsedSpace
}
gotUser.UsedSpace = gotUser.UsedSpace - capacity
}
_, err = tx.ExecContext(
_, err = st.db.ExecContext(
ctx,
`update t_user
set used_space=?
@ -330,11 +264,14 @@ func (u *SQLiteUsers) SetUsed(ctx context.Context, id uint64, incr bool, capacit
return err
}
return tx.Commit()
return nil
}
func (u *SQLiteUsers) ResetUsed(ctx context.Context, id uint64, used int64) error {
_, err := u.db.ExecContext(
func (st *SQLiteStore) ResetUsed(ctx context.Context, id uint64, used int64) error {
st.Lock()
defer st.Unlock()
_, err := st.db.ExecContext(
ctx,
`update t_user
set used_space=?
@ -345,16 +282,19 @@ func (u *SQLiteUsers) ResetUsed(ctx context.Context, id uint64, used int64) erro
return err
}
func (u *SQLiteUsers) ListUsers(ctx context.Context) ([]*db.User, error) {
func (st *SQLiteStore) ListUsers(ctx context.Context) ([]*db.User, error) {
st.RLock()
defer st.RUnlock()
// TODO: support pagination
rows, err := u.db.QueryContext(
rows, err := st.db.QueryContext(
ctx,
`select id, name, role, used_space, quota, preference
from t_user`,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrUserNotFound
return nil, db.ErrUserNotFound
}
return nil, err
}
@ -389,8 +329,11 @@ func (u *SQLiteUsers) ListUsers(ctx context.Context) ([]*db.User, error) {
return users, nil
}
func (u *SQLiteUsers) ListUserIDs(ctx context.Context) (map[string]string, error) {
users, err := u.ListUsers(ctx)
func (st *SQLiteStore) ListUserIDs(ctx context.Context) (map[string]string, error) {
st.RLock()
defer st.RUnlock()
users, err := st.ListUsers(ctx)
if err != nil {
return nil, err
}
@ -402,17 +345,17 @@ func (u *SQLiteUsers) ListUserIDs(ctx context.Context) (map[string]string, error
return nameToId, nil
}
func (u *SQLiteUsers) AddRole(role string) error {
func (st *SQLiteStore) AddRole(role string) error {
// TODO: implement this after adding grant/revoke
panic("not implemented")
}
func (u *SQLiteUsers) DelRole(role string) error {
func (st *SQLiteStore) DelRole(role string) error {
// TODO: implement this after adding grant/revoke
panic("not implemented")
}
func (u *SQLiteUsers) ListRoles() (map[string]bool, error) {
func (st *SQLiteStore) ListRoles() (map[string]bool, error) {
// TODO: implement this after adding grant/revoke
panic("not implemented")
}

View file

@ -1,293 +0,0 @@
package sqlite
import (
"context"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/ihexxa/quickshare/src/db"
)
func TestUserStores(t *testing.T) {
rootName, rootPwd := "root", "rootPwd"
testUserMethods := func(t *testing.T, store db.IUserStore) {
ctx := context.TODO()
root, err := store.GetUser(ctx, 0)
if err != nil {
t.Fatal(err)
}
if root.Name != rootName {
t.Fatal("root user not found")
}
if root.Pwd != rootPwd {
t.Fatalf("passwords not match (%s) (%s)", root.Pwd, rootPwd)
}
if root.Role != db.AdminRole {
t.Fatalf("incorrect root role")
}
if root.Quota.SpaceLimit != db.DefaultSpaceLimit {
t.Fatalf("incorrect root SpaceLimit")
}
if root.Quota.UploadSpeedLimit != db.DefaultUploadSpeedLimit {
t.Fatalf("incorrect root UploadSpeedLimit")
}
if root.Quota.DownloadSpeedLimit != db.DefaultDownloadSpeedLimit {
t.Fatalf("incorrect root DownloadSpeedLimit")
}
if !db.ComparePreferences(root.Preferences, &db.DefaultPreferences) {
t.Fatalf("incorrect preference %v %v", root.Preferences, db.DefaultPreferences)
}
visitor, err := store.GetUser(ctx, 1)
if err != nil {
t.Fatal(err)
}
if visitor.Name != VisitorName {
t.Fatal("visitor not found")
}
if visitor.Pwd != rootPwd {
t.Fatalf("passwords not match %s", err)
}
if visitor.Role != db.VisitorRole {
t.Fatalf("incorrect visitor role")
}
if visitor.Quota.SpaceLimit != 0 {
t.Fatalf("incorrect visitor SpaceLimit")
}
if visitor.Quota.UploadSpeedLimit != db.VisitorUploadSpeedLimit {
t.Fatalf("incorrect visitor UploadSpeedLimit")
}
if visitor.Quota.DownloadSpeedLimit != db.VisitorDownloadSpeedLimit {
t.Fatalf("incorrect visitor DownloadSpeedLimit")
}
if !db.ComparePreferences(visitor.Preferences, &db.DefaultPreferences) {
t.Fatalf("incorrect preference")
}
id, name1 := uint64(2), "test_user1"
pwd1, pwd2 := "666", "888"
role1, role2 := db.UserRole, db.AdminRole
spaceLimit1, upLimit1, downLimit1 := int64(17), 5, 7
spaceLimit2, upLimit2, downLimit2 := int64(19), 13, 17
err = store.AddUser(ctx, &db.User{
ID: id,
Name: name1,
Pwd: pwd1,
Role: role1,
Quota: &db.Quota{
SpaceLimit: spaceLimit1,
UploadSpeedLimit: upLimit1,
DownloadSpeedLimit: downLimit1,
},
Preferences: &db.DefaultPreferences,
})
if err != nil {
t.Fatal("there should be no error")
}
user, err := store.GetUser(ctx, id)
if err != nil {
t.Fatal(err)
}
if user.Name != name1 {
t.Fatalf("names not matched %s %s", name1, user.Name)
}
if user.Pwd != pwd1 {
t.Fatalf("passwords not match %s", err)
}
if user.Role != role1 {
t.Fatalf("roles not matched %s %s", role1, user.Role)
}
if user.Quota.SpaceLimit != spaceLimit1 {
t.Fatalf("space limit not matched %d %d", spaceLimit1, user.Quota.SpaceLimit)
}
if user.Quota.UploadSpeedLimit != upLimit1 {
t.Fatalf("up limit not matched %d %d", upLimit1, user.Quota.UploadSpeedLimit)
}
if user.Quota.DownloadSpeedLimit != downLimit1 {
t.Fatalf("down limit not matched %d %d", downLimit1, user.Quota.DownloadSpeedLimit)
}
users, err := store.ListUsers(ctx)
if err != nil {
t.Fatal(err)
}
if len(users) != 3 {
t.Fatalf("users size should be 3 (%d)", len(users))
}
for _, user := range users {
if user.ID == 0 {
if user.Name != rootName || user.Role != db.AdminRole {
t.Fatalf("incorrect root info %v", user)
}
}
if user.ID == id {
if user.Name != name1 || user.Role != role1 {
t.Fatalf("incorrect user info %v", user)
}
}
if user.Pwd != "" {
t.Fatalf("password must be empty")
}
}
err = store.SetPwd(ctx, id, pwd2)
if err != nil {
t.Fatal(err)
}
store.SetInfo(ctx, id, &db.User{
ID: id,
Role: role2,
Quota: &db.Quota{
SpaceLimit: spaceLimit2,
UploadSpeedLimit: upLimit2,
DownloadSpeedLimit: downLimit2,
},
})
usedIncr, usedDecr := int64(spaceLimit2), int64(7)
err = store.SetUsed(ctx, id, true, usedIncr)
if err != nil {
t.Fatal(err)
}
err = store.SetUsed(ctx, id, false, usedDecr)
if err != nil {
t.Fatal(err)
}
err = store.SetUsed(ctx, id, true, int64(spaceLimit2)-(usedIncr-usedDecr)+1)
if err == nil || !strings.Contains(err.Error(), "reached space limit") {
t.Fatal("should reject big file")
} else {
err = nil
}
user, err = store.GetUser(ctx, id)
if err != nil {
t.Fatal(err)
}
if user.Pwd != pwd2 {
t.Fatalf("passwords not match %s %s", user.Pwd, pwd2)
}
if user.Role != role2 {
t.Fatalf("roles not matched %s %s", role2, user.Role)
}
if user.Quota.SpaceLimit != spaceLimit2 {
t.Fatalf("space limit not matched %d %d", spaceLimit2, user.Quota.SpaceLimit)
}
if user.Quota.UploadSpeedLimit != upLimit2 {
t.Fatalf("up limit not matched %d %d", upLimit2, user.Quota.UploadSpeedLimit)
}
if user.Quota.DownloadSpeedLimit != downLimit2 {
t.Fatalf("down limit not matched %d %d", downLimit2, user.Quota.DownloadSpeedLimit)
}
if user.UsedSpace != usedIncr-usedDecr {
t.Fatalf("used space not matched %d %d", user.UsedSpace, usedIncr-usedDecr)
}
time.Sleep(5 * time.Second)
newPrefer := &db.Preferences{
Bg: &db.BgConfig{
Url: "/url",
Repeat: "repeat",
Position: "center",
Align: "fixed",
BgColor: "#333",
},
CSSURL: "/cssurl",
LanPackURL: "lanPackURL",
Lan: "zhCN",
Theme: "dark",
Avatar: "/avatar",
Email: "foo@gmail.com",
}
err = store.SetPreferences(ctx, id, newPrefer)
if err != nil {
t.Fatal(err)
}
user, err = store.GetUserByName(ctx, name1)
if err != nil {
t.Fatal(err)
}
if user.ID != id {
t.Fatalf("ids not matched %d %d", id, user.ID)
}
if user.Pwd != pwd2 {
t.Fatalf("passwords not match %s", err)
}
if user.Role != role2 {
t.Fatalf("roles not matched %s %s", role2, user.Role)
}
if user.Quota.SpaceLimit != spaceLimit2 {
t.Fatalf("space limit not matched %d %d", spaceLimit2, user.Quota.SpaceLimit)
}
if user.Quota.UploadSpeedLimit != upLimit2 {
t.Fatalf("up limit not matched %d %d", upLimit2, user.Quota.UploadSpeedLimit)
}
if user.Quota.DownloadSpeedLimit != downLimit2 {
t.Fatalf("down limit not matched %d %d", downLimit2, user.Quota.DownloadSpeedLimit)
}
if !db.ComparePreferences(user.Preferences, newPrefer) {
t.Fatalf("preferences not matched %v %v", user.Preferences, newPrefer)
}
err = store.DelUser(ctx, id)
if err != nil {
t.Fatal(err)
}
users, err = store.ListUsers(ctx)
if err != nil {
t.Fatal(err)
}
if len(users) != 2 {
t.Fatalf("users size should be 2 (%d)", len(users))
}
for _, user := range users {
if user.ID == 0 && user.Name != rootName && user.Role != db.AdminRole {
t.Fatalf("incorrect root info %v", user)
}
if user.ID == VisitorID && user.Name != VisitorName && user.Role != db.VisitorRole {
t.Fatalf("incorrect visitor info %v", user)
}
}
nameToID, err := store.ListUserIDs(ctx)
if err != nil {
t.Fatal(err)
}
if len(nameToID) != len(users) {
t.Fatalf("nameToID size (%d) should be same as (%d)", len(nameToID), len(users))
}
}
t.Run("testing UserStore sqlite", func(t *testing.T) {
rootPath, err := ioutil.TempDir("./", "quickshare_userstore_test_")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(rootPath)
dbPath := filepath.Join(rootPath, "quickshare.sqlite")
sqliteDB, err := NewSQLite(dbPath)
if err != nil {
t.Fatal(err)
}
defer sqliteDB.Close()
store, err := NewSQLiteUsers(sqliteDB)
if err != nil {
t.Fatal("fail to new user store", err)
}
if err = store.Init(context.TODO(), rootName, rootPwd); err != nil {
t.Fatal("fail to init user store", err)
}
testUserMethods(t, store)
})
}