diff --git a/src/db/rdb/base/configs.go b/src/db/rdb/base/configs.go new file mode 100644 index 0000000..ba63783 --- /dev/null +++ b/src/db/rdb/base/configs.go @@ -0,0 +1,94 @@ +package base + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/ihexxa/quickshare/src/db" +) + +func (st *BaseStore) getCfg(ctx context.Context, tx *sql.Tx) (*db.SiteConfig, error) { + var configStr string + err := tx.QueryRowContext( + ctx, + `select config + from t_config + where id=0`, + ).Scan(&configStr) + if err != nil { + return nil, err + } + + config := &db.SiteConfig{} + err = json.Unmarshal([]byte(configStr), config) + if err != nil { + return nil, err + } + + if err = db.CheckSiteCfg(config, true); err != nil { + return nil, err + } + return config, nil +} + +func (st *BaseStore) setCfg(ctx context.Context, tx *sql.Tx, cfg *db.SiteConfig) error { + if err := db.CheckSiteCfg(cfg, false); err != nil { + return err + } + + cfgBytes, err := json.Marshal(cfg) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `update t_config + set config=? + where id=0`, + string(cfgBytes), + ) + return err +} + +func (st *BaseStore) SetClientCfg(ctx context.Context, cfg *db.ClientConfig) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + siteCfg, err := st.getCfg(ctx, tx) + if err != nil { + return err + } + siteCfg.ClientCfg = cfg + + err = st.setCfg(ctx, tx, siteCfg) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) GetCfg(ctx context.Context) (*db.SiteConfig, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return nil, err + } + defer tx.Rollback() + + siteConfig, err := st.getCfg(ctx, tx) + if err != nil { + return nil, err + } + + err = tx.Commit() + if err != nil { + return nil, err + } + + return siteConfig, nil +} diff --git a/src/db/rdb/base/files.go b/src/db/rdb/base/files.go new file mode 100644 index 0000000..13b289c --- /dev/null +++ b/src/db/rdb/base/files.go @@ -0,0 +1,326 @@ +package base + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "path" + "strings" + + "github.com/ihexxa/quickshare/src/db" +) + +func (st *BaseStore) getFileInfo(ctx context.Context, tx *sql.Tx, itemPath string) (*db.FileInfo, error) { + var infoStr string + fInfo := &db.FileInfo{} + var id uint64 + var isDir bool + var size int64 + var shareId string + err := tx.QueryRowContext( + ctx, + `select id, is_dir, size, share_id, info + from t_file_info + where path=?`, + itemPath, + ).Scan( + &id, + &isDir, + &size, + &shareId, + &infoStr, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, db.ErrFileInfoNotFound + } + return nil, err + } + + err = json.Unmarshal([]byte(infoStr), &fInfo) + if err != nil { + return nil, err + } + fInfo.Id = id + fInfo.IsDir = isDir + fInfo.Size = size + fInfo.ShareID = shareId + fInfo.Shared = shareId != "" + return fInfo, nil +} + +func (st *BaseStore) GetFileInfo(ctx context.Context, itemPath string) (*db.FileInfo, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return nil, err + } + defer tx.Rollback() + + info, err := st.getFileInfo(ctx, tx, itemPath) + if err != nil { + return nil, err + } + + err = tx.Commit() + if err != nil { + return nil, err + } + return info, err +} + +func (st *BaseStore) ListFileInfos(ctx context.Context, itemPaths []string) (map[string]*db.FileInfo, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return nil, err + } + defer tx.Rollback() + + // TODO: add pagination + placeholders := []string{} + values := []any{} + for i := 0; i < len(itemPaths); i++ { + placeholders = append(placeholders, "?") + values = append(values, itemPaths[i]) + } + rows, err := tx.QueryContext( + ctx, + fmt.Sprintf( + `select id, path, is_dir, size, share_id, info + from t_file_info + where path in (%s) + `, + strings.Join(placeholders, ","), + ), + values..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var fInfoStr, itemPath, shareId string + var isDir bool + var size int64 + var id uint64 + fInfos := map[string]*db.FileInfo{} + for rows.Next() { + fInfo := &db.FileInfo{} + + err = rows.Scan(&id, &itemPath, &isDir, &size, &shareId, &fInfoStr) + if err != nil { + return nil, err + } + + err = json.Unmarshal([]byte(fInfoStr), fInfo) + if err != nil { + return nil, err + } + fInfo.Id = id + fInfo.IsDir = isDir + fInfo.Size = size + fInfo.ShareID = shareId + fInfo.Shared = shareId != "" + fInfos[itemPath] = fInfo + } + if rows.Err() != nil { + return nil, rows.Err() + } + + err = tx.Commit() + if err != nil { + return nil, err + } + return fInfos, nil +} + +func (st *BaseStore) addFileInfo(ctx context.Context, tx *sql.Tx, infoId, userId uint64, itemPath string, info *db.FileInfo) error { + infoStr, err := json.Marshal(info) + if err != nil { + return err + } + + location, err := getLocation(itemPath) + if err != nil { + return err + } + + dirPath, itemName := path.Split(itemPath) + _, err = tx.ExecContext( + ctx, + `insert into t_file_info ( + id, path, user, location, parent, name, + is_dir, size, share_id, info + ) + values ( + ?, ?, ?, ?, ?, ?, + ?, ?, ?, ? + )`, + infoId, itemPath, userId, location, dirPath, itemName, + info.IsDir, info.Size, info.ShareID, infoStr, + ) + return err +} + +func (st *BaseStore) AddFileInfo(ctx context.Context, infoId, userId uint64, itemPath string, info *db.FileInfo) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + err = st.addFileInfo(ctx, tx, infoId, userId, itemPath, info) + if err != nil { + return err + } + + // increase used space + err = st.setUsed(ctx, tx, userId, true, info.Size) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) delFileInfo(ctx context.Context, tx *sql.Tx, itemPath string) error { + _, err := tx.ExecContext( + ctx, + `delete from t_file_info + where path=? + `, + itemPath, + ) + return err +} + +func (st *BaseStore) SetSha1(ctx context.Context, itemPath, sign string) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + info, err := st.getFileInfo(ctx, tx, itemPath) + if err != nil { + return err + } + info.Sha1 = sign + + infoStr, err := json.Marshal(info) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `update t_file_info + set info=? + where path=?`, + infoStr, + itemPath, + ) + if err != nil { + return err + } + return tx.Commit() +} + +func (st *BaseStore) DelFileInfo(ctx context.Context, userID uint64, itemPath string) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + // get all children and size + rows, err := tx.QueryContext( + ctx, + `select path, size + from t_file_info + where path = ? or path like ? + `, + itemPath, + fmt.Sprintf("%s/%%", itemPath), + ) + if err != nil { + return err + } + defer rows.Close() + + var childrenPath string + var itemSize int64 + placeholders := []string{} + values := []any{} + decrSize := int64(0) + for rows.Next() { + err = rows.Scan(&childrenPath, &itemSize) + if err != nil { + return err + } + placeholders = append(placeholders, "?") + values = append(values, childrenPath) + decrSize += itemSize + } + + // decrease used space + err = st.setUsed(ctx, tx, userID, false, decrSize) + if err != nil { + return err + } + + // delete file info entries + _, err = tx.ExecContext( + ctx, + fmt.Sprintf( + `delete from t_file_info + where path in (%s)`, + strings.Join(placeholders, ","), + ), + values..., + ) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) MoveFileInfo(ctx context.Context, userId uint64, oldPath, newPath string, isDir bool) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + info, err := st.getFileInfo(ctx, tx, oldPath) + if err != nil { + if errors.Is(err, db.ErrFileInfoNotFound) { + // info for file does not exist so no need to move it + // e.g. folder info is not created before + // TODO: but sometimes it could be a bug + return nil + } + return err + } + err = st.delFileInfo(ctx, tx, oldPath) + if err != nil { + return err + } + err = st.addFileInfo(ctx, tx, info.Id, userId, newPath, info) + if err != nil { + return err + } + + return tx.Commit() +} + +func getLocation(itemPath string) (string, error) { + // location is taken from item path + itemPathParts := strings.Split(itemPath, "/") + if len(itemPathParts) == 0 { + return "", fmt.Errorf("invalid item path '%s'", itemPath) + } + return itemPathParts[0], nil +} diff --git a/src/db/rdb/base/files_sharings.go b/src/db/rdb/base/files_sharings.go new file mode 100644 index 0000000..1b56853 --- /dev/null +++ b/src/db/rdb/base/files_sharings.go @@ -0,0 +1,219 @@ +package base + +import ( + "context" + "crypto/sha1" + "database/sql" + "encoding/json" + "errors" + "fmt" + "io" + "path" + "time" + + "github.com/ihexxa/quickshare/src/db" +) + +func (st *BaseStore) generateShareID(payload string) (string, error) { + if len(payload) == 0 { + return "", db.ErrEmpty + } + + msg := fmt.Sprintf("%s-%d", payload, time.Now().Unix()) + h := sha1.New() + _, err := io.WriteString(h, msg) + if err != nil { + return "", err + } + + return fmt.Sprintf("%x", h.Sum(nil))[:7], nil +} + +func (st *BaseStore) IsSharing(ctx context.Context, dirPath string) (bool, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return false, err + } + defer tx.Rollback() + + var shareId string + err = tx.QueryRowContext( + ctx, + `select share_id + from t_file_info + where path=?`, + dirPath, + ).Scan( + &shareId, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, db.ErrFileInfoNotFound + } + return false, err + } + + err = tx.Commit() + if err != nil { + return false, err + } + return shareId != "", nil +} + +func (st *BaseStore) GetSharingDir(ctx context.Context, hashID string) (string, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return "", err + } + defer tx.Rollback() + + var sharedPath string + err = tx.QueryRowContext( + ctx, + `select path + from t_file_info + where share_id=? + `, + hashID, + ).Scan( + &sharedPath, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", db.ErrSharingNotFound + } + return "", err + } + + err = tx.Commit() + if err != nil { + return "", err + } + return sharedPath, nil +} + +func (st *BaseStore) AddSharing(ctx context.Context, infoId, userId uint64, dirPath string) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + shareID, err := st.generateShareID(dirPath) + if err != nil { + return err + } + + location, err := getLocation(dirPath) + if err != nil { + return err + } + + _, err = st.getFileInfo(ctx, tx, dirPath) + if err != nil && !errors.Is(err, db.ErrFileInfoNotFound) { + return err + } + + if errors.Is(err, db.ErrFileInfoNotFound) { + // insert new + parentPath, name := path.Split(dirPath) + info := &db.FileInfo{Shared: true} // TODO: deprecate shared in info + infoStr, err := json.Marshal(info) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `insert into t_file_info ( + id, path, user, + location, parent, name, + is_dir, size, share_id, info + ) + values ( + ?, ?, ?, + ?, ?, ?, + ?, ?, ?, ? + )`, + infoId, dirPath, userId, + location, parentPath, name, + true, 0, shareID, infoStr, + ) + if err != nil { + return err + } + } + + _, err = tx.ExecContext( + ctx, + `update t_file_info + set share_id=? + where path=?`, + shareID, dirPath, + ) + if err != nil { + return err + } + return tx.Commit() +} + +func (st *BaseStore) DelSharing(ctx context.Context, userId uint64, dirPath string) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext( + ctx, + `update t_file_info + set share_id='' + where path=?`, + dirPath, + ) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) ListSharingsByLocation(ctx context.Context, location string) (map[string]string, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return nil, err + } + defer tx.Rollback() + + rows, err := tx.QueryContext( + ctx, + `select path, share_id + from t_file_info + where share_id<>'' and location=?`, + location, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var pathname, shareId string + pathToShareId := map[string]string{} + for rows.Next() { + err = rows.Scan(&pathname, &shareId) + if err != nil { + return nil, err + } + pathToShareId[pathname] = shareId + } + if rows.Err() != nil { + return nil, rows.Err() + } + + err = tx.Commit() + if err != nil { + return nil, err + } + + return pathToShareId, nil +} diff --git a/src/db/rdb/base/files_uploadings.go b/src/db/rdb/base/files_uploadings.go new file mode 100644 index 0000000..26bec3c --- /dev/null +++ b/src/db/rdb/base/files_uploadings.go @@ -0,0 +1,244 @@ +package base + +import ( + "context" + "database/sql" + "errors" + + "github.com/ihexxa/quickshare/src/db" +) + +func (st *BaseStore) addUploadInfoOnly(ctx context.Context, tx *sql.Tx, uploadId, userId uint64, tmpPath, filePath string, fileSize int64) error { + _, err := tx.ExecContext( + ctx, + `insert into t_file_uploading ( + id, real_path, tmp_path, user, size, uploaded + ) + values ( + ?, ?, ?, ?, ?, ? + )`, + uploadId, filePath, tmpPath, userId, fileSize, 0, + ) + return err +} + +func (st *BaseStore) AddUploadInfos(ctx context.Context, uploadId, userId uint64, tmpPath, filePath string, info *db.FileInfo) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + userInfo, err := st.getUser(ctx, tx, userId) + if err != nil { + return err + } else if userInfo.UsedSpace+info.Size > int64(userInfo.Quota.SpaceLimit) { + return db.ErrQuota + } + + _, _, _, err = st.getUploadInfo(ctx, tx, userId, filePath) + if err == nil { + return db.ErrKeyExisting + } else if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + + userInfo.UsedSpace += info.Size + err = st.setUser(ctx, tx, userInfo) + if err != nil { + return err + } + + err = st.addUploadInfoOnly(ctx, tx, uploadId, userId, tmpPath, filePath, info.Size) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) DelUploadingInfos(ctx context.Context, userId uint64, realPath string) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + err = st.delUploadingInfos(ctx, tx, userId, realPath) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) delUploadingInfos(ctx context.Context, tx *sql.Tx, userId uint64, realPath string) error { + _, size, _, err := st.getUploadInfo(ctx, tx, userId, realPath) + if err != nil { + // info may not exist + return err + } + + err = st.delUploadInfoOnly(ctx, tx, userId, realPath) + if err != nil { + return err + } + + userInfo, err := st.getUser(ctx, tx, userId) + if err != nil { + return err + } + userInfo.UsedSpace -= size + return st.setUser(ctx, tx, userInfo) +} + +func (st *BaseStore) delUploadInfoOnly(ctx context.Context, tx *sql.Tx, userId uint64, filePath string) error { + _, err := tx.ExecContext( + ctx, + `delete from t_file_uploading + where real_path=? and user=?`, + filePath, userId, + ) + return err +} + +func (st *BaseStore) MoveUploadingInfos(ctx context.Context, infoId, userId uint64, uploadPath, itemPath string) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + _, size, _, err := st.getUploadInfo(ctx, tx, userId, itemPath) + if err != nil { + return err + } + err = st.delUploadInfoOnly(ctx, tx, userId, itemPath) + if err != nil { + return err + } + err = st.addFileInfo(ctx, tx, infoId, userId, itemPath, &db.FileInfo{ + Size: size, + }) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) SetUploadInfo(ctx context.Context, userId uint64, filePath string, newUploaded int64) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + var size, uploaded int64 + err = tx.QueryRowContext( + ctx, + `select size, uploaded + from t_file_uploading + where real_path=? and user=?`, + filePath, userId, + ).Scan(&size, &uploaded) + if err != nil { + return err + } else if newUploaded > size { + return db.ErrGreaterThanSize + } + + _, err = tx.ExecContext( + ctx, + `update t_file_uploading + set uploaded=? + where real_path=? and user=?`, + newUploaded, filePath, userId, + ) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) getUploadInfo(ctx context.Context, tx *sql.Tx, userId uint64, filePath string) (string, int64, int64, error) { + var size, uploaded int64 + err := tx.QueryRowContext( + ctx, + `select size, uploaded + from t_file_uploading + where real_path=? and user=?`, + filePath, userId, + ).Scan(&size, &uploaded) + if err != nil { + return "", 0, 0, err + } + + return filePath, size, uploaded, nil +} + +func (st *BaseStore) GetUploadInfo(ctx context.Context, userId uint64, filePath string) (string, int64, int64, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return "", 0, 0, err + } + defer tx.Rollback() + + filePath, size, uploaded, err := st.getUploadInfo(ctx, tx, userId, filePath) + if err != nil { + return filePath, size, uploaded, err + } + + err = tx.Commit() + return filePath, size, uploaded, err +} + +func (st *BaseStore) ListUploadInfos(ctx context.Context, userId uint64) ([]*db.UploadInfo, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return nil, err + } + defer tx.Rollback() + + rows, err := tx.QueryContext( + ctx, + `select real_path, size, uploaded + from t_file_uploading + where user=?`, + userId, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var pathname string + var size, uploaded int64 + infos := []*db.UploadInfo{} + for rows.Next() { + err = rows.Scan( + &pathname, + &size, + &uploaded, + ) + if err != nil { + return nil, err + } + + infos = append(infos, &db.UploadInfo{ + RealFilePath: pathname, + Size: size, + Uploaded: uploaded, + }) + } + if rows.Err() != nil { + return nil, rows.Err() + } + + err = tx.Commit() + if err != nil { + return nil, err + } + return infos, nil +} diff --git a/src/db/rdb/base/init.go b/src/db/rdb/base/init.go new file mode 100644 index 0000000..ed43bbc --- /dev/null +++ b/src/db/rdb/base/init.go @@ -0,0 +1,233 @@ +package base + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "time" + + "github.com/ihexxa/quickshare/src/db" + _ "github.com/mattn/go-sqlite3" +) + +var ( + txOpts = &sql.TxOptions{} +) + +type BaseStore struct { + db db.IDB +} + +func NewBaseStore(db db.IDB) *BaseStore { + return &BaseStore{ + db: db, + } +} + +func (st *BaseStore) Db() db.IDB { + return st.db +} + +func (st *BaseStore) Close() error { + return st.db.Close() +} + +func (st *BaseStore) IsInited() bool { + // always try to init the db + return false +} + +func (st *BaseStore) Init(ctx context.Context, rootName, rootPwd string, cfg *db.SiteConfig) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + err = st.InitUserTable(ctx, tx, rootName, rootPwd) + if err != nil { + return err + } + + if err = st.InitFileTables(ctx, tx); err != nil { + return err + } + + if err = st.InitConfigTable(ctx, tx, cfg); err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) InitUserTable(ctx context.Context, tx *sql.Tx, rootName, rootPwd string) error { + _, err := tx.ExecContext( + ctx, + `create table if not exists t_user ( + id bigint not null, + name varchar not null unique, + pwd varchar not null, + role integer not null, + used_space bigint not null, + quota varchar not null, + preference varchar not null, + primary key(id) + )`, + ) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `create index if not exists i_user_name on t_user (name)`, + ) + if err != nil { + return err + } + + admin := &db.User{ + ID: 0, + Name: rootName, + Pwd: rootPwd, + Role: db.AdminRole, + Quota: &db.Quota{ + SpaceLimit: db.DefaultSpaceLimit, + UploadSpeedLimit: db.DefaultUploadSpeedLimit, + DownloadSpeedLimit: db.DefaultDownloadSpeedLimit, + }, + Preferences: &db.DefaultPreferences, + } + visitor := &db.User{ + ID: db.VisitorID, + Name: db.VisitorName, + Pwd: rootPwd, + Role: db.VisitorRole, + Quota: &db.Quota{ + SpaceLimit: 0, + UploadSpeedLimit: db.VisitorUploadSpeedLimit, + DownloadSpeedLimit: db.VisitorDownloadSpeedLimit, + }, + Preferences: &db.DefaultPreferences, + } + for _, user := range []*db.User{admin, visitor} { + _, err := st.getUser(ctx, tx, user.ID) + if err != nil { + if errors.Is(err, db.ErrUserNotFound) { + err = st.addUser(ctx, tx, user) + if err != nil { + return err + } + } else { + return err + } + } + } + + return nil +} + +func (st *BaseStore) InitFileTables(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext( + ctx, + `create table if not exists t_file_info ( + id bigint not null, + path varchar not null, + user bigint not null, + location varchar not null, + parent varchar not null, + name varchar not null, + is_dir boolean not null, + size bigint not null, + share_id varchar not null, + info varchar not null, + primary key(id) + )`, + ) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `create index if not exists t_file_path on t_file_info (path, location)`, + ) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `create index if not exists t_file_share on t_file_info (share_id, location)`, + ) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `create table if not exists t_file_uploading ( + id bigint not null, + real_path varchar not null, + tmp_path varchar not null unique, + user bigint not null, + size bigint not null, + uploaded bigint not null, + primary key(id) + )`, + ) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `create index if not exists t_file_uploading_path on t_file_uploading (real_path, user)`, + ) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `create index if not exists t_file_uploading_user on t_file_uploading (user)`, + ) + return err +} + +func (st *BaseStore) InitConfigTable(ctx context.Context, tx *sql.Tx, cfg *db.SiteConfig) error { + _, err := tx.ExecContext( + ctx, + `create table if not exists t_config ( + id bigint not null, + config varchar not null, + modified datetime not null, + primary key(id) + )`, + ) + if err != nil { + return err + } + + cfgStr, err := json.Marshal(cfg) + if err != nil { + return err + } + + _, err = st.getCfg(ctx, tx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + _, err = tx.ExecContext( + ctx, + `insert into t_config + (id, config, modified) values (?, ?, ?)`, + 0, cfgStr, time.Now(), + ) + return err + } + return err + } + + return nil +} diff --git a/src/db/rdb/base/users.go b/src/db/rdb/base/users.go new file mode 100644 index 0000000..69d4590 --- /dev/null +++ b/src/db/rdb/base/users.go @@ -0,0 +1,424 @@ +package base + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + + "github.com/ihexxa/quickshare/src/db" +) + +func (st *BaseStore) setUser(ctx context.Context, tx *sql.Tx, user *db.User) error { + var err error + if err = db.CheckUser(user, false); err != nil { + return err + } + + quotaStr, err := json.Marshal(user.Quota) + if err != nil { + return err + } + preferencesStr, err := json.Marshal(user.Preferences) + if err != nil { + return err + } + _, err = tx.ExecContext( + ctx, + `update t_user + set name=?, pwd=?, role=?, used_space=?, quota=?, preference=? + where id=?`, + user.Name, + user.Pwd, + user.Role, + user.UsedSpace, + quotaStr, + preferencesStr, + user.ID, + ) + return err +} + +func (st *BaseStore) getUser(ctx context.Context, tx *sql.Tx, id uint64) (*db.User, error) { + user := &db.User{} + var quotaStr, preferenceStr string + err := tx.QueryRowContext( + ctx, + `select id, name, pwd, role, used_space, quota, preference + from t_user + where id=?`, + id, + ).Scan( + &user.ID, + &user.Name, + &user.Pwd, + &user.Role, + &user.UsedSpace, + "aStr, + &preferenceStr, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, db.ErrUserNotFound + } + return nil, err + } + + err = json.Unmarshal([]byte(quotaStr), &user.Quota) + if err != nil { + return nil, err + } + err = json.Unmarshal([]byte(preferenceStr), &user.Preferences) + if err != nil { + return nil, err + } + return user, nil +} + +func (st *BaseStore) addUser(ctx context.Context, tx *sql.Tx, user *db.User) error { + quotaStr, err := json.Marshal(user.Quota) + if err != nil { + return err + } + preferenceStr, err := json.Marshal(user.Preferences) + if err != nil { + return err + } + _, err = tx.ExecContext( + ctx, + `insert into t_user (id, name, pwd, role, used_space, quota, preference) values (?, ?, ?, ?, ?, ?, ?)`, + user.ID, + user.Name, + user.Pwd, + user.Role, + user.UsedSpace, + quotaStr, + preferenceStr, + ) + return err +} + +func (st *BaseStore) AddUser(ctx context.Context, user *db.User) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + err = st.addUser(ctx, tx, user) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) DelUser(ctx context.Context, id uint64) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext( + ctx, + `delete from t_user where id=?`, + id, + ) + if err != nil { + return err + } + return tx.Commit() +} + +func (st *BaseStore) GetUser(ctx context.Context, id uint64) (*db.User, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return nil, err + } + defer tx.Rollback() + + user, err := st.getUser(ctx, tx, id) + if err != nil { + return nil, err + } + err = tx.Commit() + if err != nil { + return nil, err + } + return user, err +} + +func (st *BaseStore) GetUserByName(ctx context.Context, name string) (*db.User, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return nil, err + } + defer tx.Rollback() + + user := &db.User{} + var quotaStr, preferenceStr string + err = tx.QueryRowContext( + ctx, + `select id, name, pwd, role, used_space, quota, preference + from t_user + where name=?`, + name, + ).Scan( + &user.ID, + &user.Name, + &user.Pwd, + &user.Role, + &user.UsedSpace, + "aStr, + &preferenceStr, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, db.ErrUserNotFound + } + return nil, err + } + + err = json.Unmarshal([]byte(quotaStr), &user.Quota) + if err != nil { + return nil, err + } + err = json.Unmarshal([]byte(preferenceStr), &user.Preferences) + if err != nil { + return nil, err + } + + err = tx.Commit() + if err != nil { + return nil, err + } + return user, nil +} + +func (st *BaseStore) SetPwd(ctx context.Context, id uint64, pwd string) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext( + ctx, + `update t_user + set pwd=? + where id=?`, + pwd, + id, + ) + if err != nil { + return err + } + + return tx.Commit() +} + +// role + quota +func (st *BaseStore) SetInfo(ctx context.Context, id uint64, user *db.User) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + quotaStr, err := json.Marshal(user.Quota) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `update t_user + set role=?, quota=? + where id=?`, + user.Role, quotaStr, + id, + ) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) SetPreferences(ctx context.Context, id uint64, prefers *db.Preferences) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + preferenceStr, err := json.Marshal(prefers) + if err != nil { + return err + } + + _, err = tx.ExecContext( + ctx, + `update t_user + set preference=? + where id=?`, + preferenceStr, + id, + ) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) SetUsed(ctx context.Context, id uint64, incr bool, capacity int64) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + err = st.setUsed(ctx, tx, id, incr, capacity) + if err != nil { + return err + } + + return tx.Commit() +} + +func (st *BaseStore) setUsed(ctx context.Context, tx *sql.Tx, id uint64, incr bool, capacity int64) error { + gotUser, err := st.getUser(ctx, tx, id) + if err != nil { + return err + } + + if incr && gotUser.UsedSpace+capacity > int64(gotUser.Quota.SpaceLimit) { + return db.ErrReachedLimit + } + + if incr { + gotUser.UsedSpace = gotUser.UsedSpace + capacity + } else { + if gotUser.UsedSpace-capacity < 0 { + return db.ErrNegtiveUsedSpace + } + gotUser.UsedSpace = gotUser.UsedSpace - capacity + } + + _, err = tx.ExecContext( + ctx, + `update t_user + set used_space=? + where id=?`, + gotUser.UsedSpace, + gotUser.ID, + ) + return err +} + +func (st *BaseStore) ResetUsed(ctx context.Context, id uint64, used int64) error { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext( + ctx, + `update t_user + set used_space=? + where id=?`, + used, + id, + ) + if err != nil { + return err + } + return tx.Commit() +} + +func (st *BaseStore) ListUsers(ctx context.Context) ([]*db.User, error) { + tx, err := st.db.BeginTx(ctx, txOpts) + if err != nil { + return nil, err + } + defer tx.Rollback() + + // TODO: support pagination + rows, err := tx.QueryContext( + ctx, + `select id, name, role, used_space, quota, preference + from t_user`, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, db.ErrUserNotFound + } + return nil, err + } + defer rows.Close() // TODO: check error + + users := []*db.User{} + for rows.Next() { + user := &db.User{} + var quotaStr, preferenceStr string + err = rows.Scan( + &user.ID, + &user.Name, + &user.Role, + &user.UsedSpace, + "aStr, + &preferenceStr, + ) + err = json.Unmarshal([]byte(quotaStr), &user.Quota) + if err != nil { + return nil, err + } + err = json.Unmarshal([]byte(preferenceStr), &user.Preferences) + if err != nil { + return nil, err + } + + users = append(users, user) + } + if rows.Err() != nil { + return nil, rows.Err() + } + + err = tx.Commit() + if err != nil { + return nil, err + } + return users, nil +} + +func (st *BaseStore) ListUserIDs(ctx context.Context) (map[string]string, error) { + users, err := st.ListUsers(ctx) + if err != nil { + return nil, err + } + + nameToId := map[string]string{} + for _, user := range users { + nameToId[user.Name] = fmt.Sprint(user.ID) + } + return nameToId, nil +} + +func (st *BaseStore) AddRole(role string) error { + // TODO: implement this after adding grant/revoke + panic("not implemented") +} + +func (st *BaseStore) DelRole(role string) error { + // TODO: implement this after adding grant/revoke + panic("not implemented") +} + +func (st *BaseStore) ListRoles() (map[string]bool, error) { + // TODO: implement this after adding grant/revoke + panic("not implemented") +}