diff --git a/src/client/singleuser.go b/src/client/singleuser.go index adc14fb..2f12c52 100644 --- a/src/client/singleuser.go +++ b/src/client/singleuser.go @@ -61,11 +61,50 @@ func (cl *SingleUserClient) AddUser(name, pwd, role string, token *http.Cookie) }). End() + if len(errs) > 0 { + return nil, nil, errs + } + auResp := &multiusers.AddUserResp{} err := json.Unmarshal([]byte(body), auResp) if err != nil { errs = append(errs, err) return nil, nil, errs } - return resp, auResp, nil + return resp, auResp, errs +} + +func (cl *SingleUserClient) AddRole(role string, token *http.Cookie) (*http.Response, string, []error) { + return cl.r.Post(cl.url("/v1/roles/")). + AddCookie(token). + Send(multiusers.AddRoleReq{ + Role: role, + }). + End() +} + +func (cl *SingleUserClient) DelRole(role string, token *http.Cookie) (*http.Response, string, []error) { + return cl.r.Delete(cl.url("/v1/roles/")). + AddCookie(token). + Send(multiusers.DelRoleReq{ + Role: role, + }). + End() +} + +func (cl *SingleUserClient) ListRoles(token *http.Cookie) (*http.Response, *multiusers.ListRolesResp, []error) { + resp, body, errs := cl.r.Get(cl.url("/v1/roles/")). + AddCookie(token). + End() + if len(errs) > 0 { + return nil, nil, errs + } + + lsResp := &multiusers.ListRolesResp{} + err := json.Unmarshal([]byte(body), lsResp) + if err != nil { + errs = append(errs, err) + return nil, nil, errs + } + return resp, lsResp, errs } diff --git a/src/handlers/multiusers/handlers.go b/src/handlers/multiusers/handlers.go index 769e156..88b88ed 100644 --- a/src/handlers/multiusers/handlers.go +++ b/src/handlers/multiusers/handlers.go @@ -176,7 +176,6 @@ func (h *MultiUsersSvc) AddUser(c *gin.Context) { c.JSON(q.ErrResp(c, 400, err)) return } - // TODO: check privilege? // TODO: do more comprehensive validation // Role and duplicated name will be validated by the store @@ -209,6 +208,73 @@ func (h *MultiUsersSvc) AddUser(c *gin.Context) { c.JSON(200, &AddUserResp{ID: fmt.Sprint(uid)}) } +type AddRoleReq struct { + Role string `json:"role"` +} + +func (h *MultiUsersSvc) AddRole(c *gin.Context) { + req := &AddRoleReq{} + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(q.ErrResp(c, 400, err)) + return + } + + // TODO: do more comprehensive validation + if len(req.Role) < 2 { + c.JSON(q.ErrResp(c, 400, errors.New("name length must be greater than 2"))) + return + } + + err := h.deps.Users().AddRole(req.Role) + if err != nil { + c.JSON(q.ErrResp(c, 500, err)) + return + } + + c.JSON(q.Resp(200)) +} + +type DelRoleReq struct { + Role string `json:"role"` +} + +func (h *MultiUsersSvc) DelRole(c *gin.Context) { + req := &DelRoleReq{} + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(q.ErrResp(c, 400, err)) + return + } + + // TODO: do more comprehensive validation + if len(req.Role) < 2 { + c.JSON(q.ErrResp(c, 400, errors.New("name length must be greater than 2"))) + return + } + + err := h.deps.Users().DelRole(req.Role) + if err != nil { + c.JSON(q.ErrResp(c, 500, err)) + return + } + + c.JSON(q.Resp(200)) +} + +type ListRolesReq struct{} +type ListRolesResp struct { + Roles map[string]bool `json:"roles"` +} + +func (h *MultiUsersSvc) ListRoles(c *gin.Context) { + roles, err := h.deps.Users().ListRoles() + if err != nil { + c.JSON(q.ErrResp(c, 500, err)) + return + } + + c.JSON(200, &ListRolesResp{Roles: roles}) +} + func (h *MultiUsersSvc) getUserInfo(c *gin.Context) (map[string]string, error) { tokenStr, err := c.Cookie(TokenCookie) if err != nil { diff --git a/src/kvstore/boltdbpvd/provider.go b/src/kvstore/boltdbpvd/provider.go index 783e479..f1c4a46 100644 --- a/src/kvstore/boltdbpvd/provider.go +++ b/src/kvstore/boltdbpvd/provider.go @@ -77,10 +77,14 @@ func (bp *BoltPvd) Close() error { } func (bp *BoltPvd) GetBool(key string) (bool, bool) { + return bp.GetBoolIn("bools", key) +} + +func (bp *BoltPvd) GetBoolIn(ns, key string) (bool, bool) { buf, ok := make([]byte, 1), false bp.db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("bools")) + b := tx.Bucket([]byte(ns)) v := b.Get([]byte(key)) copy(buf, v) ok = v != nil @@ -92,23 +96,52 @@ func (bp *BoltPvd) GetBool(key string) (bool, bool) { } func (bp *BoltPvd) SetBool(key string, val bool) error { + return bp.SetBoolIn("bools", key, val) +} + +func (bp *BoltPvd) SetBoolIn(ns, key string, val bool) error { var bVal byte = 0 if val { bVal = 1 } return bp.db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("bools")) + b := tx.Bucket([]byte(ns)) return b.Put([]byte(key), []byte{bVal}) }) } func (bp *BoltPvd) DelBool(key string) error { + return bp.DelBoolIn("bools", key) +} + +func (bp *BoltPvd) DelBoolIn(ns, key string) error { return bp.db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("bools")) + b := tx.Bucket([]byte(ns)) return b.Delete([]byte(key)) }) } +func (bp *BoltPvd) ListBools() (map[string]bool, error) { + return bp.ListBoolsIn("bools") +} + +func (bp *BoltPvd) ListBoolsIn(ns string) (map[string]bool, error) { + list := map[string]bool{} + err := bp.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(ns)) + if b == nil { + return ErrBucketNotFound + } + + b.ForEach(func(k, v []byte) error { + list[string(k)] = (v[0] == 1) + return nil + }) + return nil + }) + return list, err +} + func (bp *BoltPvd) GetInt(key string) (int, bool) { x, ok := bp.GetInt64(key) return int(x), ok @@ -195,7 +228,7 @@ func (bp *BoltPvd) ListInt64sIn(ns string) (map[string]int64, error) { if n < 0 { return fmt.Errorf("fail to parse int64 for key (%s)", k) } - list[fmt.Sprintf("%s", k)] = x + list[string(k)] = x return nil }) return nil @@ -340,7 +373,7 @@ func (bp *BoltPvd) ListStringsIn(ns string) (map[string]string, error) { } b.ForEach(func(k, v []byte) error { - kv[fmt.Sprintf("%s", k)] = fmt.Sprintf("%s", v) + kv[string(k)] = string(v) return nil }) return nil diff --git a/src/kvstore/kvstore_interface.go b/src/kvstore/kvstore_interface.go index 4294c70..212b384 100644 --- a/src/kvstore/kvstore_interface.go +++ b/src/kvstore/kvstore_interface.go @@ -9,8 +9,13 @@ type IKVStore interface { AddNamespace(nsName string) error DelNamespace(nsName string) error GetBool(key string) (bool, bool) + GetBoolIn(ns, key string) (bool, bool) SetBool(key string, val bool) error + SetBoolIn(ns, key string, val bool) error DelBool(key string) error + DelBoolIn(ns, key string) error + ListBools() (map[string]bool, error) + ListBoolsIn(ns string) (map[string]bool, error) GetInt(key string) (int, bool) SetInt(key string, val int) error DelInt(key string) error diff --git a/src/kvstore/test/provider_test.go b/src/kvstore/test/provider_test.go index 6b1d957..76010e8 100644 --- a/src/kvstore/test/provider_test.go +++ b/src/kvstore/test/provider_test.go @@ -15,6 +15,7 @@ func TestKVStoreProviders(t *testing.T) { var err error var ok bool key, boolV, intV, int64V, floatV, stringV := "key", true, 2027, int64(2027), 3.1415, "foobar" + key2, boolV2 := "key2", false kvstoreTest := func(store kvstore.IKVStore, t *testing.T) { // test bools @@ -26,6 +27,19 @@ func TestKVStoreProviders(t *testing.T) { if err != nil { t.Errorf("there should be no error %v", err) } + err = store.SetBool(key2, boolV2) + if err != nil { + t.Errorf("there should be no error %v", err) + } + boolList, err := store.ListBools() + if err != nil { + t.Errorf("there should be no error %v", err) + } + if boolList[key] != boolV { + t.Error("listBool incorrect val1") + } else if boolList[key2] != boolV2 { + t.Error("listBool incorrect val2") + } boolVGot, ok := store.GetBool(key) if !ok { t.Error("value should exit") diff --git a/src/server/server.go b/src/server/server.go index 63fa5ea..7669ad8 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -179,6 +179,11 @@ func initHandlers(router *gin.Engine, cfg gocfg.ICfg, deps *depidx.Deps) (*gin.E usersAPI.PATCH("/pwd", userHdrs.SetPwd) usersAPI.POST("/", userHdrs.AddUser) + rolesAPI := v1.Group("/roles") + rolesAPI.POST("/", userHdrs.AddRole) + rolesAPI.DELETE("/", userHdrs.DelRole) + rolesAPI.GET("/", userHdrs.ListRoles) + filesAPI := v1.Group("/fs") filesAPI.POST("/files", fileHdrs.Create) filesAPI.DELETE("/files", fileHdrs.Delete) diff --git a/src/server/server_files_test.go b/src/server/server_files_test.go index 07ce733..213164e 100644 --- a/src/server/server_files_test.go +++ b/src/server/server_files_test.go @@ -17,7 +17,7 @@ import ( "github.com/ihexxa/quickshare/src/handlers/fileshdr" ) -func xTestFileHandlers(t *testing.T) { +func TestFileHandlers(t *testing.T) { addr := "http://127.0.0.1:8686" root := "testData" config := `{ diff --git a/src/server/server_users_test.go b/src/server/server_users_test.go index ca2ae62..3d43e37 100644 --- a/src/server/server_users_test.go +++ b/src/server/server_users_test.go @@ -112,4 +112,62 @@ func TestSingleUserHandlers(t *testing.T) { t.Fatal(resp.StatusCode) } }) + + t.Run("test roles APIs: Login-AddRole-ListRoles-DelRole-ListRoles", func(t *testing.T) { + resp, _, errs := usersCl.Login(adminName, adminNewPwd) + if len(errs) > 0 { + t.Fatal(errs) + } else if resp.StatusCode != 200 { + t.Fatal(resp.StatusCode) + } + + token := client.GetCookie(resp.Cookies(), su.TokenCookie) + roles := []string{"role1", "role2"} + + for _, role := range roles { + resp, _, errs := usersCl.AddRole(role, token) + if len(errs) > 0 { + t.Fatal(errs) + } else if resp.StatusCode != 200 { + t.Fatal(resp.StatusCode) + } + } + + resp, lsResp, errs := usersCl.ListRoles(token) + if len(errs) > 0 { + t.Fatal(errs) + } else if resp.StatusCode != 200 { + t.Fatal(resp.StatusCode) + } + for _, role := range append(roles, []string{ + userstore.AdminRole, + userstore.UserRole, + userstore.VisitorRole, + }...) { + if !lsResp.Roles[role] { + t.Fatalf("role(%s) not found", role) + } + } + + for _, role := range roles { + resp, _, errs := usersCl.DelRole(role, token) + if len(errs) > 0 { + t.Fatal(errs) + } else if resp.StatusCode != 200 { + t.Fatal(resp.StatusCode) + } + } + + resp, lsResp, errs = usersCl.ListRoles(token) + if len(errs) > 0 { + t.Fatal(errs) + } else if resp.StatusCode != 200 { + t.Fatal(resp.StatusCode) + } + for _, role := range roles { + if lsResp.Roles[role] { + t.Fatalf("role(%s) should not exist", role) + } + } + }) } diff --git a/src/userstore/user_store.go b/src/userstore/user_store.go index 3b698cf..1f725eb 100644 --- a/src/userstore/user_store.go +++ b/src/userstore/user_store.go @@ -21,6 +21,7 @@ const ( NamesNs = "users" PwdsNs = "pwds" RolesNs = "roles" + RoleListNs = "roleList" InitTimeKey = "initTime" ) @@ -40,6 +41,9 @@ type IUserStore interface { SetName(id uint64, name string) error SetPwd(id uint64, pwd string) error SetRole(id uint64, role string) error + AddRole(role string) error + DelRole(role string) error + ListRoles() (map[string]bool, error) } type KVUserStore struct { @@ -57,6 +61,7 @@ func NewKVUserStore(store kvstore.IKVStore) (*KVUserStore, error) { PwdsNs, RolesNs, InitNs, + RoleListNs, } { if err = store.AddNamespace(nsName); err != nil { return nil, err @@ -71,7 +76,8 @@ func NewKVUserStore(store kvstore.IKVStore) (*KVUserStore, error) { } func (us *KVUserStore) Init(rootName, rootPwd string) error { - err := us.AddUser(&User{ + var err error + err = us.AddUser(&User{ ID: 0, Name: rootName, Pwd: rootPwd, @@ -81,6 +87,13 @@ func (us *KVUserStore) Init(rootName, rootPwd string) error { return err } + for _, role := range []string{AdminRole, UserRole, VisitorRole} { + err = us.AddRole(role) + if err != nil { + return err + } + } + return us.store.SetStringIn(InitNs, InitTimeKey, fmt.Sprintf("%d", time.Now().Unix())) } @@ -244,3 +257,33 @@ func (us *KVUserStore) SetRole(id uint64, role string) error { return us.store.SetStringIn(RolesNs, userID, role) } + +func (us *KVUserStore) AddRole(role string) error { + us.mtx.Lock() + defer us.mtx.Unlock() + + _, ok := us.store.GetBoolIn(RoleListNs, role) + if ok { + return fmt.Errorf("role (%s) exists", role) + } + + return us.store.SetBoolIn(RoleListNs, role, true) +} + +func (us *KVUserStore) DelRole(role string) error { + us.mtx.Lock() + defer us.mtx.Unlock() + + if role == AdminRole || role == UserRole || role == VisitorRole { + return errors.New("predefined roles can not be deleted") + } + + return us.store.DelBoolIn(RoleListNs, role) +} + +func (us *KVUserStore) ListRoles() (map[string]bool, error) { + us.mtx.Lock() + defer us.mtx.Unlock() + + return us.store.ListBoolsIn(RoleListNs) +} diff --git a/src/userstore/user_store_test.go b/src/userstore/user_store_test.go index e95d59b..dea570f 100644 --- a/src/userstore/user_store_test.go +++ b/src/userstore/user_store_test.go @@ -11,7 +11,7 @@ import ( func TestUserStores(t *testing.T) { rootName, rootPwd := "root", "rootPwd" - testUserStore := func(t *testing.T, store IUserStore) { + testUserMethods := func(t *testing.T, store IUserStore) { root, err := store.GetUser(0) if err != nil { t.Fatal(err) @@ -93,6 +93,47 @@ func TestUserStores(t *testing.T) { } } + testRoleMethods := func(t *testing.T, store IUserStore) { + roles := []string{"role1", "role2"} + var err error + for _, role := range roles { + err = store.AddRole(role) + if err != nil { + t.Fatal(err) + } + } + + roleMap, err := store.ListRoles() + if err != nil { + t.Fatal(err) + } + + for _, role := range append(roles, []string{ + AdminRole, UserRole, VisitorRole, + }...) { + if !roleMap[role] { + t.Fatalf("role(%s) not found", role) + } + } + + for _, role := range roles { + err = store.DelRole(role) + if err != nil { + t.Fatal(err) + } + } + + roleMap, err = store.ListRoles() + if err != nil { + t.Fatal(err) + } + for _, role := range roles { + if roleMap[role] { + t.Fatalf("role(%s) should not exist", role) + } + } + } + t.Run("testing KVUserStore", func(t *testing.T) { rootPath, err := ioutil.TempDir("./", "quickshare_userstore_test_") if err != nil { @@ -111,6 +152,7 @@ func TestUserStores(t *testing.T) { t.Fatal("fail to init kvstore", err) } - testUserStore(t, store) + testUserMethods(t, store) + testRoleMethods(t, store) }) }