diff --git a/src/db/interfaces.go b/src/db/interfaces.go index 4645f82..5a3b8a5 100644 --- a/src/db/interfaces.go +++ b/src/db/interfaces.go @@ -27,9 +27,9 @@ type IDB interface { type IDBQuickshare interface { Init(ctx context.Context, adminName, adminPwd string, config *SiteConfig) error - InitUserTable(ctx context.Context, rootName, rootPwd string) error - InitFileTables(ctx context.Context) error - InitConfigTable(ctx context.Context, cfg *SiteConfig) error + InitUserTable(ctx context.Context, tx *sql.Tx, rootName, rootPwd string) error + InitFileTables(ctx context.Context, tx *sql.Tx) error + InitConfigTable(ctx context.Context, tx *sql.Tx, cfg *SiteConfig) error Close() error IDBLockable IUserDB diff --git a/src/db/rdb/sqlite/configs.go b/src/db/rdb/sqlite/configs.go index 1d69f64..aa50f69 100644 --- a/src/db/rdb/sqlite/configs.go +++ b/src/db/rdb/sqlite/configs.go @@ -2,71 +2,20 @@ package sqlite import ( "context" - "encoding/json" "github.com/ihexxa/quickshare/src/db" ) -func (st *SQLiteStore) getCfg(ctx context.Context) (*db.SiteConfig, error) { - var configStr string - err := st.db.QueryRowContext( - ctx, - `select config - from t_config - where id=0`, - ).Scan(&configStr) - if err != nil { - return nil, err - } - - config := &db.SiteConfig{} - err = json.Unmarshal([]byte(configStr), config) - if err != nil { - return nil, err - } - - if err = db.CheckSiteCfg(config, true); err != nil { - return nil, err - } - return config, nil -} - -func (st *SQLiteStore) setCfg(ctx context.Context, cfg *db.SiteConfig) error { - if err := db.CheckSiteCfg(cfg, false); err != nil { - return err - } - - cfgBytes, err := json.Marshal(cfg) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `update t_config - set config=? - where id=0`, - string(cfgBytes), - ) - return err -} - func (st *SQLiteStore) SetClientCfg(ctx context.Context, cfg *db.ClientConfig) error { st.Lock() defer st.Unlock() - siteCfg, err := st.getCfg(ctx) - if err != nil { - return err - } - siteCfg.ClientCfg = cfg - - return st.setCfg(ctx, siteCfg) + return st.store.SetClientCfg(ctx, cfg) } func (st *SQLiteStore) GetCfg(ctx context.Context) (*db.SiteConfig, error) { st.RLock() defer st.RUnlock() - return st.getCfg(ctx) + return st.store.GetCfg(ctx) } diff --git a/src/db/rdb/sqlite/files.go b/src/db/rdb/sqlite/files.go index 2388845..0215ab5 100644 --- a/src/db/rdb/sqlite/files.go +++ b/src/db/rdb/sqlite/files.go @@ -2,277 +2,48 @@ package sqlite import ( "context" - "database/sql" - "encoding/json" - "errors" - "fmt" - "path" - "strings" "github.com/ihexxa/quickshare/src/db" ) -func (st *SQLiteStore) getFileInfo(ctx context.Context, itemPath string) (*db.FileInfo, error) { - var infoStr string - fInfo := &db.FileInfo{} - var id uint64 - var isDir bool - var size int64 - var shareId string - err := st.db.QueryRowContext( - ctx, - `select id, is_dir, size, share_id, info - from t_file_info - where path=?`, - itemPath, - ).Scan( - &id, - &isDir, - &size, - &shareId, - &infoStr, - ) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, db.ErrFileInfoNotFound - } - return nil, err - } - - err = json.Unmarshal([]byte(infoStr), &fInfo) - if err != nil { - return nil, err - } - fInfo.Id = id - fInfo.IsDir = isDir - fInfo.Size = size - fInfo.ShareID = shareId - fInfo.Shared = shareId != "" - return fInfo, nil -} - func (st *SQLiteStore) GetFileInfo(ctx context.Context, itemPath string) (*db.FileInfo, error) { st.RLock() defer st.RUnlock() - return st.getFileInfo(ctx, itemPath) + return st.store.GetFileInfo(ctx, itemPath) } func (st *SQLiteStore) ListFileInfos(ctx context.Context, itemPaths []string) (map[string]*db.FileInfo, error) { st.RLock() defer st.RUnlock() - // TODO: add pagination - placeholders := []string{} - values := []any{} - for i := 0; i < len(itemPaths); i++ { - placeholders = append(placeholders, "?") - values = append(values, itemPaths[i]) - } - rows, err := st.db.QueryContext( - ctx, - fmt.Sprintf( - `select id, path, is_dir, size, share_id, info - from t_file_info - where path in (%s) - `, - strings.Join(placeholders, ","), - ), - values..., - ) - if err != nil { - return nil, err - } - defer rows.Close() - - var fInfoStr, itemPath, shareId string - var isDir bool - var size int64 - var id uint64 - fInfos := map[string]*db.FileInfo{} - for rows.Next() { - fInfo := &db.FileInfo{} - - err = rows.Scan(&id, &itemPath, &isDir, &size, &shareId, &fInfoStr) - if err != nil { - return nil, err - } - - err = json.Unmarshal([]byte(fInfoStr), fInfo) - if err != nil { - return nil, err - } - fInfo.Id = id - fInfo.IsDir = isDir - fInfo.Size = size - fInfo.ShareID = shareId - fInfo.Shared = shareId != "" - fInfos[itemPath] = fInfo - } - if rows.Err() != nil { - return nil, rows.Err() - } - - return fInfos, nil -} - -func (st *SQLiteStore) addFileInfo(ctx context.Context, infoId, userId uint64, itemPath string, info *db.FileInfo) error { - infoStr, err := json.Marshal(info) - if err != nil { - return err - } - - location, err := getLocation(itemPath) - if err != nil { - return err - } - - dirPath, itemName := path.Split(itemPath) - _, err = st.db.ExecContext( - ctx, - `insert into t_file_info ( - id, path, user, location, parent, name, - is_dir, size, share_id, info - ) - values ( - ?, ?, ?, ?, ?, ?, - ?, ?, ?, ? - )`, - infoId, itemPath, userId, location, dirPath, itemName, - info.IsDir, info.Size, info.ShareID, infoStr, - ) - return err + return st.store.ListFileInfos(ctx, itemPaths) } func (st *SQLiteStore) AddFileInfo(ctx context.Context, infoId, userId uint64, itemPath string, info *db.FileInfo) error { st.Lock() defer st.Unlock() - err := st.addFileInfo(ctx, infoId, userId, itemPath, info) - if err != nil { - return err - } - - // increase used space - return st.setUsed(ctx, userId, true, info.Size) -} - -func (st *SQLiteStore) delFileInfo(ctx context.Context, itemPath string) error { - _, err := st.db.ExecContext( - ctx, - `delete from t_file_info - where path=? - `, - itemPath, - ) - return err + return st.store.AddFileInfo(ctx, infoId, userId, itemPath, info) } func (st *SQLiteStore) SetSha1(ctx context.Context, itemPath, sign string) error { st.Lock() defer st.Unlock() - info, err := st.getFileInfo(ctx, itemPath) - if err != nil { - return err - } - info.Sha1 = sign - - infoStr, err := json.Marshal(info) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `update t_file_info - set info=? - where path=?`, - infoStr, - itemPath, - ) - return err + return st.store.SetSha1(ctx, itemPath, sign) } func (st *SQLiteStore) DelFileInfo(ctx context.Context, userID uint64, itemPath string) error { st.Lock() defer st.Unlock() - // get all children and size - rows, err := st.db.QueryContext( - ctx, - `select path, size - from t_file_info - where path = ? or path like ? - `, - itemPath, - fmt.Sprintf("%s/%%", itemPath), - ) - if err != nil { - return err - } - defer rows.Close() - - var childrenPath string - var itemSize int64 - placeholders := []string{} - values := []any{} - decrSize := int64(0) - for rows.Next() { - err = rows.Scan(&childrenPath, &itemSize) - if err != nil { - return err - } - placeholders = append(placeholders, "?") - values = append(values, childrenPath) - decrSize += itemSize - } - - // decrease used space - err = st.setUsed(ctx, userID, false, decrSize) - if err != nil { - return err - } - - // delete file info entries - _, err = st.db.ExecContext( - ctx, - fmt.Sprintf( - `delete from t_file_info - where path in (%s)`, - strings.Join(placeholders, ","), - ), - values..., - ) - return err + return st.store.DelFileInfo(ctx, userID, itemPath) } func (st *SQLiteStore) MoveFileInfo(ctx context.Context, userId uint64, oldPath, newPath string, isDir bool) error { st.Lock() defer st.Unlock() - info, err := st.getFileInfo(ctx, oldPath) - if err != nil { - if errors.Is(err, db.ErrFileInfoNotFound) { - // info for file does not exist so no need to move it - // e.g. folder info is not created before - // TODO: but sometimes it could be a bug - return nil - } - return err - } - err = st.delFileInfo(ctx, oldPath) - if err != nil { - return err - } - return st.addFileInfo(ctx, info.Id, userId, newPath, info) -} - -func getLocation(itemPath string) (string, error) { - // location is taken from item path - itemPathParts := strings.Split(itemPath, "/") - if len(itemPathParts) == 0 { - return "", fmt.Errorf("invalid item path '%s'", itemPath) - } - return itemPathParts[0], nil + return st.store.MoveFileInfo(ctx, userId, oldPath, newPath, isDir) } diff --git a/src/db/rdb/sqlite/files_sharings.go b/src/db/rdb/sqlite/files_sharings.go index 019f192..5a1f39e 100644 --- a/src/db/rdb/sqlite/files_sharings.go +++ b/src/db/rdb/sqlite/files_sharings.go @@ -2,181 +2,39 @@ package sqlite import ( "context" - "crypto/sha1" - "database/sql" - "encoding/json" - "errors" - "fmt" - "io" - "path" - "time" - - "github.com/ihexxa/quickshare/src/db" ) -func (st *SQLiteStore) generateShareID(payload string) (string, error) { - if len(payload) == 0 { - return "", db.ErrEmpty - } - - msg := fmt.Sprintf("%s-%d", payload, time.Now().Unix()) - h := sha1.New() - _, err := io.WriteString(h, msg) - if err != nil { - return "", err - } - - return fmt.Sprintf("%x", h.Sum(nil))[:7], nil -} - func (st *SQLiteStore) IsSharing(ctx context.Context, dirPath string) (bool, error) { st.RLock() defer st.RUnlock() - var shareId string - err := st.db.QueryRowContext( - ctx, - `select share_id - from t_file_info - where path=?`, - dirPath, - ).Scan( - &shareId, - ) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return false, db.ErrFileInfoNotFound - } - return false, err - } - - return shareId != "", nil + return st.store.IsSharing(ctx, dirPath) } func (st *SQLiteStore) GetSharingDir(ctx context.Context, hashID string) (string, error) { st.RLock() defer st.RUnlock() - var sharedPath string - err := st.db.QueryRowContext( - ctx, - `select path - from t_file_info - where share_id=? - `, - hashID, - ).Scan( - &sharedPath, - ) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return "", db.ErrSharingNotFound - } - return "", err - } - - return sharedPath, nil + return st.store.GetSharingDir(ctx, hashID) } func (st *SQLiteStore) AddSharing(ctx context.Context, infoId, userId uint64, dirPath string) error { st.Lock() defer st.Unlock() - shareID, err := st.generateShareID(dirPath) - if err != nil { - return err - } - - location, err := getLocation(dirPath) - if err != nil { - return err - } - - _, err = st.getFileInfo(ctx, dirPath) - if err != nil && !errors.Is(err, db.ErrFileInfoNotFound) { - return err - } - - if errors.Is(err, db.ErrFileInfoNotFound) { - // insert new - parentPath, name := path.Split(dirPath) - info := &db.FileInfo{Shared: true} // TODO: deprecate shared in info - infoStr, err := json.Marshal(info) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `insert into t_file_info ( - id, path, user, - location, parent, name, - is_dir, size, share_id, info - ) - values ( - ?, ?, ?, - ?, ?, ?, - ?, ?, ?, ? - )`, - infoId, dirPath, userId, - location, parentPath, name, - true, 0, shareID, infoStr, - ) - return err - } - - _, err = st.db.ExecContext( - ctx, - `update t_file_info - set share_id=? - where path=?`, - shareID, dirPath, - ) - return err + return st.store.AddSharing(ctx, infoId, userId, dirPath) } func (st *SQLiteStore) DelSharing(ctx context.Context, userId uint64, dirPath string) error { st.Lock() defer st.Unlock() - _, err := st.db.ExecContext( - ctx, - `update t_file_info - set share_id='' - where path=?`, - dirPath, - ) - return err + return st.store.DelSharing(ctx, userId, dirPath) } func (st *SQLiteStore) ListSharingsByLocation(ctx context.Context, location string) (map[string]string, error) { st.RLock() defer st.RUnlock() - rows, err := st.db.QueryContext( - ctx, - `select path, share_id - from t_file_info - where share_id<>'' and location=?`, - location, - ) - if err != nil { - return nil, err - } - defer rows.Close() - - var pathname, shareId string - pathToShareId := map[string]string{} - for rows.Next() { - err = rows.Scan(&pathname, &shareId) - if err != nil { - return nil, err - } - pathToShareId[pathname] = shareId - } - if rows.Err() != nil { - return nil, rows.Err() - } - - return pathToShareId, nil + return st.store.ListSharingsByLocation(ctx, location) } diff --git a/src/db/rdb/sqlite/files_uploadings.go b/src/db/rdb/sqlite/files_uploadings.go index 7650520..d686b3b 100644 --- a/src/db/rdb/sqlite/files_uploadings.go +++ b/src/db/rdb/sqlite/files_uploadings.go @@ -2,195 +2,48 @@ package sqlite import ( "context" - "database/sql" - "errors" "github.com/ihexxa/quickshare/src/db" ) -func (st *SQLiteStore) addUploadInfoOnly(ctx context.Context, uploadId, userId uint64, tmpPath, filePath string, fileSize int64) error { - _, err := st.db.ExecContext( - ctx, - `insert into t_file_uploading ( - id, real_path, tmp_path, user, size, uploaded - ) - values ( - ?, ?, ?, ?, ?, ? - )`, - uploadId, filePath, tmpPath, userId, fileSize, 0, - ) - return err -} - func (st *SQLiteStore) AddUploadInfos(ctx context.Context, uploadId, userId uint64, tmpPath, filePath string, info *db.FileInfo) error { st.Lock() defer st.Unlock() - userInfo, err := st.getUser(ctx, userId) - if err != nil { - return err - } else if userInfo.UsedSpace+info.Size > int64(userInfo.Quota.SpaceLimit) { - return db.ErrQuota - } - - _, _, _, err = st.getUploadInfo(ctx, userId, filePath) - if err == nil { - return db.ErrKeyExisting - } else if err != nil && !errors.Is(err, sql.ErrNoRows) { - return err - } - - userInfo.UsedSpace += info.Size - err = st.setUser(ctx, userInfo) - if err != nil { - return err - } - - return st.addUploadInfoOnly(ctx, uploadId, userId, tmpPath, filePath, info.Size) + return st.store.AddUploadInfos(ctx, uploadId, userId, tmpPath, filePath, info) } func (st *SQLiteStore) DelUploadingInfos(ctx context.Context, userId uint64, realPath string) error { st.Lock() defer st.Unlock() - return st.delUploadingInfos(ctx, userId, realPath) -} - -func (st *SQLiteStore) delUploadingInfos(ctx context.Context, userId uint64, realPath string) error { - _, size, _, err := st.getUploadInfo(ctx, userId, realPath) - if err != nil { - // info may not exist - return err - } - - err = st.delUploadInfoOnly(ctx, userId, realPath) - if err != nil { - return err - } - - userInfo, err := st.getUser(ctx, userId) - if err != nil { - return err - } - userInfo.UsedSpace -= size - return st.setUser(ctx, userInfo) -} - -func (st *SQLiteStore) delUploadInfoOnly(ctx context.Context, userId uint64, filePath string) error { - _, err := st.db.ExecContext( - ctx, - `delete from t_file_uploading - where real_path=? and user=?`, - filePath, userId, - ) - return err + return st.store.DelUploadingInfos(ctx, userId, realPath) } func (st *SQLiteStore) MoveUploadingInfos(ctx context.Context, infoId, userId uint64, uploadPath, itemPath string) error { st.Lock() defer st.Unlock() - _, size, _, err := st.getUploadInfo(ctx, userId, itemPath) - if err != nil { - return err - } - err = st.delUploadInfoOnly(ctx, userId, itemPath) - if err != nil { - return err - } - return st.addFileInfo(ctx, infoId, userId, itemPath, &db.FileInfo{ - Size: size, - }) + return st.store.MoveUploadingInfos(ctx, infoId, userId, uploadPath, itemPath) } func (st *SQLiteStore) SetUploadInfo(ctx context.Context, userId uint64, filePath string, newUploaded int64) error { st.Lock() defer st.Unlock() - var size, uploaded int64 - err := st.db.QueryRowContext( - ctx, - `select size, uploaded - from t_file_uploading - where real_path=? and user=?`, - filePath, userId, - ).Scan(&size, &uploaded) - if err != nil { - return err - } else if newUploaded > size { - return db.ErrGreaterThanSize - } - - _, err = st.db.ExecContext( - ctx, - `update t_file_uploading - set uploaded=? - where real_path=? and user=?`, - newUploaded, filePath, userId, - ) - return err -} - -func (st *SQLiteStore) getUploadInfo(ctx context.Context, userId uint64, filePath string) (string, int64, int64, error) { - var size, uploaded int64 - err := st.db.QueryRowContext( - ctx, - `select size, uploaded - from t_file_uploading - where real_path=? and user=?`, - filePath, userId, - ).Scan(&size, &uploaded) - if err != nil { - return "", 0, 0, err - } - - return filePath, size, uploaded, nil + return st.store.SetUploadInfo(ctx, userId, filePath, newUploaded) } func (st *SQLiteStore) GetUploadInfo(ctx context.Context, userId uint64, filePath string) (string, int64, int64, error) { st.RLock() defer st.RUnlock() - return st.getUploadInfo(ctx, userId, filePath) + + return st.store.GetUploadInfo(ctx, userId, filePath) } func (st *SQLiteStore) ListUploadInfos(ctx context.Context, userId uint64) ([]*db.UploadInfo, error) { st.RLock() defer st.RUnlock() - rows, err := st.db.QueryContext( - ctx, - `select real_path, size, uploaded - from t_file_uploading - where user=?`, - userId, - ) - if err != nil { - return nil, err - } - defer rows.Close() - - var pathname string - var size, uploaded int64 - infos := []*db.UploadInfo{} - for rows.Next() { - err = rows.Scan( - &pathname, - &size, - &uploaded, - ) - if err != nil { - return nil, err - } - - infos = append(infos, &db.UploadInfo{ - RealFilePath: pathname, - Size: size, - Uploaded: uploaded, - }) - } - if rows.Err() != nil { - return nil, rows.Err() - } - - return infos, nil + return st.store.ListUploadInfos(ctx, userId) } diff --git a/src/db/rdb/sqlite/init.go b/src/db/rdb/sqlite/init.go index 67e1596..4d053cb 100644 --- a/src/db/rdb/sqlite/init.go +++ b/src/db/rdb/sqlite/init.go @@ -3,20 +3,17 @@ package sqlite import ( "context" "database/sql" - "encoding/json" - "errors" "fmt" "sync" - "time" "github.com/ihexxa/quickshare/src/db" + "github.com/ihexxa/quickshare/src/db/rdb/base" _ "github.com/mattn/go-sqlite3" ) type SQLite struct { db.IDB dbPath string - mtx *sync.RWMutex } func NewSQLite(dbPath string) (*SQLite, error) { @@ -31,15 +28,20 @@ func NewSQLite(dbPath string) (*SQLite, error) { }, nil } +type SQLiteStore struct { + store *base.BaseStore + mtx *sync.RWMutex +} + func NewSQLiteStore(db db.IDB) (*SQLiteStore, error) { return &SQLiteStore{ - db: db, - mtx: &sync.RWMutex{}, + store: base.NewBaseStore(db), + mtx: &sync.RWMutex{}, }, nil } func (st *SQLiteStore) Close() error { - return st.db.Close() + return st.store.Close() } func (st *SQLiteStore) Lock() { @@ -64,189 +66,20 @@ func (st *SQLiteStore) IsInited() bool { } func (st *SQLiteStore) Init(ctx context.Context, rootName, rootPwd string, cfg *db.SiteConfig) error { - err := st.InitUserTable(ctx, rootName, rootPwd) - if err != nil { - return err - } - - if err = st.InitFileTables(ctx); err != nil { - return err - } - - return st.InitConfigTable(ctx, cfg) -} - -func (st *SQLiteStore) InitUserTable(ctx context.Context, rootName, rootPwd string) error { - _, err := st.db.ExecContext( - ctx, - `create table if not exists t_user ( - id bigint not null, - name varchar not null unique, - pwd varchar not null, - role integer not null, - used_space bigint not null, - quota varchar not null, - preference varchar not null, - primary key(id) - )`, - ) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `create index if not exists i_user_name on t_user (name)`, - ) - if err != nil { - return err - } - - admin := &db.User{ - ID: 0, - Name: rootName, - Pwd: rootPwd, - Role: db.AdminRole, - Quota: &db.Quota{ - SpaceLimit: db.DefaultSpaceLimit, - UploadSpeedLimit: db.DefaultUploadSpeedLimit, - DownloadSpeedLimit: db.DefaultDownloadSpeedLimit, - }, - Preferences: &db.DefaultPreferences, - } - visitor := &db.User{ - ID: db.VisitorID, - Name: db.VisitorName, - Pwd: rootPwd, - Role: db.VisitorRole, - Quota: &db.Quota{ - SpaceLimit: 0, - UploadSpeedLimit: db.VisitorUploadSpeedLimit, - DownloadSpeedLimit: db.VisitorDownloadSpeedLimit, - }, - Preferences: &db.DefaultPreferences, - } - for _, user := range []*db.User{admin, visitor} { - // TODO: not atomic - _, err := st.GetUser(ctx, user.ID) - if err != nil { - if errors.Is(err, db.ErrUserNotFound) { - err = st.AddUser(ctx, user) - if err != nil { - return err - } - } else { - return err - } - } - } - - return nil -} - -func (st *SQLiteStore) InitFileTables(ctx context.Context) error { - _, err := st.db.ExecContext( - ctx, - `create table if not exists t_file_info ( - id bigint not null, - path varchar not null, - user bigint not null, - location varchar not null, - parent varchar not null, - name varchar not null, - is_dir boolean not null, - size bigint not null, - share_id varchar not null, - info varchar not null, - primary key(id) - )`, - ) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `create index if not exists t_file_path on t_file_info (path, location)`, - ) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `create index if not exists t_file_share on t_file_info (share_id, location)`, - ) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `create table if not exists t_file_uploading ( - id bigint not null, - real_path varchar not null, - tmp_path varchar not null unique, - user bigint not null, - size bigint not null, - uploaded bigint not null, - primary key(id) - )`, - ) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `create index if not exists t_file_uploading_path on t_file_uploading (real_path, user)`, - ) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `create index if not exists t_file_uploading_user on t_file_uploading (user)`, - ) - return err -} - -func (st *SQLiteStore) InitConfigTable(ctx context.Context, cfg *db.SiteConfig) error { st.Lock() defer st.Unlock() - _, err := st.db.ExecContext( - ctx, - `create table if not exists t_config ( - id bigint not null, - config varchar not null, - modified datetime not null, - primary key(id) - )`, - ) - if err != nil { - return err - } - - cfgStr, err := json.Marshal(cfg) - if err != nil { - return err - } - - _, err = st.getCfg(ctx) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - _, err = st.db.ExecContext( - ctx, - `insert into t_config - (id, config, modified) values (?, ?, ?)`, - 0, cfgStr, time.Now(), - ) - return err - } - return err - } - - return nil + return st.store.Init(ctx, rootName, rootPwd, cfg) +} + +func (st *SQLiteStore) InitUserTable(ctx context.Context, tx *sql.Tx, rootName, rootPwd string) error { + return st.store.InitUserTable(ctx, tx, rootName, rootPwd) +} + +func (st *SQLiteStore) InitFileTables(ctx context.Context, tx *sql.Tx) error { + return st.store.InitFileTables(ctx, tx) +} + +func (st *SQLiteStore) InitConfigTable(ctx context.Context, tx *sql.Tx, cfg *db.SiteConfig) error { + return st.store.InitConfigTable(ctx, tx, cfg) } diff --git a/src/db/rdb/sqlite/users.go b/src/db/rdb/sqlite/users.go index 6a68dd1..5359252 100644 --- a/src/db/rdb/sqlite/users.go +++ b/src/db/rdb/sqlite/users.go @@ -2,188 +2,43 @@ package sqlite import ( "context" - "database/sql" - "encoding/json" - "errors" - "fmt" - "sync" "github.com/ihexxa/quickshare/src/db" ) -type SQLiteStore struct { - db db.IDB - mtx *sync.RWMutex -} - -func (st *SQLiteStore) setUser(ctx context.Context, user *db.User) error { - var err error - if err = db.CheckUser(user, false); err != nil { - return err - } - - quotaStr, err := json.Marshal(user.Quota) - if err != nil { - return err - } - preferencesStr, err := json.Marshal(user.Preferences) - if err != nil { - return err - } - _, err = st.db.ExecContext( - ctx, - `update t_user - set name=?, pwd=?, role=?, used_space=?, quota=?, preference=? - where id=?`, - user.Name, - user.Pwd, - user.Role, - user.UsedSpace, - quotaStr, - preferencesStr, - user.ID, - ) - return err -} - -func (st *SQLiteStore) getUser(ctx context.Context, id uint64) (*db.User, error) { - user := &db.User{} - var quotaStr, preferenceStr string - err := st.db.QueryRowContext( - ctx, - `select id, name, pwd, role, used_space, quota, preference - from t_user - where id=?`, - id, - ).Scan( - &user.ID, - &user.Name, - &user.Pwd, - &user.Role, - &user.UsedSpace, - "aStr, - &preferenceStr, - ) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, db.ErrUserNotFound - } - return nil, err - } - - err = json.Unmarshal([]byte(quotaStr), &user.Quota) - if err != nil { - return nil, err - } - err = json.Unmarshal([]byte(preferenceStr), &user.Preferences) - if err != nil { - return nil, err - } - return user, nil -} - func (st *SQLiteStore) AddUser(ctx context.Context, user *db.User) error { st.Lock() defer st.Unlock() - quotaStr, err := json.Marshal(user.Quota) - if err != nil { - return err - } - preferenceStr, err := json.Marshal(user.Preferences) - if err != nil { - return err - } - _, err = st.db.ExecContext( - ctx, - `insert into t_user (id, name, pwd, role, used_space, quota, preference) values (?, ?, ?, ?, ?, ?, ?)`, - user.ID, - user.Name, - user.Pwd, - user.Role, - user.UsedSpace, - quotaStr, - preferenceStr, - ) - return err + return st.store.AddUser(ctx, user) } func (st *SQLiteStore) DelUser(ctx context.Context, id uint64) error { st.Lock() defer st.Unlock() - _, err := st.db.ExecContext( - ctx, - `delete from t_user where id=?`, - id, - ) - return err + return st.store.DelUser(ctx, id) } func (st *SQLiteStore) GetUser(ctx context.Context, id uint64) (*db.User, error) { st.RLock() defer st.RUnlock() - user, err := st.getUser(ctx, id) - if err != nil { - return nil, err - } - - return user, err + return st.store.GetUser(ctx, id) } func (st *SQLiteStore) GetUserByName(ctx context.Context, name string) (*db.User, error) { st.RLock() defer st.RUnlock() - user := &db.User{} - var quotaStr, preferenceStr string - err := st.db.QueryRowContext( - ctx, - `select id, name, pwd, role, used_space, quota, preference - from t_user - where name=?`, - name, - ).Scan( - &user.ID, - &user.Name, - &user.Pwd, - &user.Role, - &user.UsedSpace, - "aStr, - &preferenceStr, - ) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, db.ErrUserNotFound - } - return nil, err - } - - err = json.Unmarshal([]byte(quotaStr), &user.Quota) - if err != nil { - return nil, err - } - err = json.Unmarshal([]byte(preferenceStr), &user.Preferences) - if err != nil { - return nil, err - } - return user, nil + return st.store.GetUserByName(ctx, name) } func (st *SQLiteStore) SetPwd(ctx context.Context, id uint64, pwd string) error { st.Lock() defer st.Unlock() - _, err := st.db.ExecContext( - ctx, - `update t_user - set pwd=? - where id=?`, - pwd, - id, - ) - return err + return st.store.SetPwd(ctx, id, pwd) } // role + quota @@ -191,158 +46,42 @@ func (st *SQLiteStore) SetInfo(ctx context.Context, id uint64, user *db.User) er st.Lock() defer st.Unlock() - quotaStr, err := json.Marshal(user.Quota) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `update t_user - set role=?, quota=? - where id=?`, - user.Role, quotaStr, - id, - ) - return err + return st.store.SetInfo(ctx, id, user) } func (st *SQLiteStore) SetPreferences(ctx context.Context, id uint64, prefers *db.Preferences) error { st.Lock() defer st.Unlock() - preferenceStr, err := json.Marshal(prefers) - if err != nil { - return err - } - - _, err = st.db.ExecContext( - ctx, - `update t_user - set preference=? - where id=?`, - preferenceStr, - id, - ) - return err + return st.store.SetPreferences(ctx, id, prefers) } func (st *SQLiteStore) SetUsed(ctx context.Context, id uint64, incr bool, capacity int64) error { st.Lock() defer st.Unlock() - return st.setUsed(ctx, id, incr, capacity) -} -func (st *SQLiteStore) setUsed(ctx context.Context, id uint64, incr bool, capacity int64) error { - gotUser, err := st.getUser(ctx, id) - if err != nil { - return err - } - - if incr && gotUser.UsedSpace+capacity > int64(gotUser.Quota.SpaceLimit) { - return db.ErrReachedLimit - } - - if incr { - gotUser.UsedSpace = gotUser.UsedSpace + capacity - } else { - if gotUser.UsedSpace-capacity < 0 { - return db.ErrNegtiveUsedSpace - } - gotUser.UsedSpace = gotUser.UsedSpace - capacity - } - - _, err = st.db.ExecContext( - ctx, - `update t_user - set used_space=? - where id=?`, - gotUser.UsedSpace, - gotUser.ID, - ) - if err != nil { - return err - } - - return nil + return st.store.SetUsed(ctx, id, incr, capacity) } func (st *SQLiteStore) ResetUsed(ctx context.Context, id uint64, used int64) error { st.Lock() defer st.Unlock() - _, err := st.db.ExecContext( - ctx, - `update t_user - set used_space=? - where id=?`, - used, - id, - ) - return err + return st.store.ResetUsed(ctx, id, used) } func (st *SQLiteStore) ListUsers(ctx context.Context) ([]*db.User, error) { st.RLock() defer st.RUnlock() - // TODO: support pagination - rows, err := st.db.QueryContext( - ctx, - `select id, name, role, used_space, quota, preference - from t_user`, - ) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, db.ErrUserNotFound - } - return nil, err - } - defer rows.Close() // TODO: check error - - users := []*db.User{} - for rows.Next() { - user := &db.User{} - var quotaStr, preferenceStr string - err = rows.Scan( - &user.ID, - &user.Name, - &user.Role, - &user.UsedSpace, - "aStr, - &preferenceStr, - ) - err = json.Unmarshal([]byte(quotaStr), &user.Quota) - if err != nil { - return nil, err - } - err = json.Unmarshal([]byte(preferenceStr), &user.Preferences) - if err != nil { - return nil, err - } - - users = append(users, user) - } - if rows.Err() != nil { - return nil, rows.Err() - } - return users, nil + return st.store.ListUsers(ctx) } func (st *SQLiteStore) ListUserIDs(ctx context.Context) (map[string]string, error) { st.RLock() defer st.RUnlock() - users, err := st.ListUsers(ctx) - if err != nil { - return nil, err - } - - nameToId := map[string]string{} - for _, user := range users { - nameToId[user.Name] = fmt.Sprint(user.ID) - } - return nameToId, nil + return st.store.ListUserIDs(ctx) } func (st *SQLiteStore) AddRole(role string) error {