parent
30c963a5f0
commit
61a1c93f0f
89 changed files with 15859 additions and 2 deletions
105
server/apis/auth.go
Normal file
105
server/apis/auth.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/httputil"
|
||||
"quickshare/server/libs/httpworker"
|
||||
)
|
||||
|
||||
func (srv *SrvShare) LoginHandler(res http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
srv.Http.Fill(httputil.Err404, res)
|
||||
return
|
||||
}
|
||||
|
||||
act := req.FormValue(srv.Conf.KeyAct)
|
||||
todo := func(res http.ResponseWriter, req *http.Request) interface{} { return httputil.Err404 }
|
||||
switch act {
|
||||
case srv.Conf.ActLogin:
|
||||
todo = srv.Login
|
||||
case srv.Conf.ActLogout:
|
||||
todo = srv.Logout
|
||||
default:
|
||||
srv.Http.Fill(httputil.Err404, res)
|
||||
return
|
||||
}
|
||||
|
||||
ack := make(chan error, 1)
|
||||
ok := srv.WorkerPool.Put(&httpworker.Task{
|
||||
Ack: ack,
|
||||
Do: srv.Wrap(todo),
|
||||
Res: res,
|
||||
Req: req,
|
||||
})
|
||||
if !ok {
|
||||
srv.Http.Fill(httputil.Err503, res)
|
||||
return
|
||||
}
|
||||
|
||||
execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
|
||||
if srv.Err.IsErr(execErr) {
|
||||
srv.Http.Fill(httputil.Err500, res)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) Login(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
// all users need to pass same wall to login
|
||||
if !srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr)) ||
|
||||
!srv.Walls.PassOpLimit(srv.Conf.AllUsers, srv.Conf.OpIdLogin) {
|
||||
return httputil.Err504
|
||||
}
|
||||
|
||||
return srv.login(
|
||||
req.FormValue(srv.Conf.KeyAdminId),
|
||||
req.FormValue(srv.Conf.KeyAdminPwd),
|
||||
res,
|
||||
)
|
||||
}
|
||||
|
||||
func (srv *SrvShare) login(adminId string, adminPwd string, res http.ResponseWriter) interface{} {
|
||||
if adminId != srv.Conf.AdminId ||
|
||||
adminPwd != srv.Conf.AdminPwd {
|
||||
return httputil.Err401
|
||||
}
|
||||
|
||||
token := srv.Walls.MakeLoginToken(srv.Conf.AdminId)
|
||||
if token == "" {
|
||||
return httputil.Err500
|
||||
}
|
||||
|
||||
srv.Http.SetCookie(res, srv.Conf.KeyToken, token)
|
||||
return httputil.Ok200
|
||||
}
|
||||
|
||||
func (srv *SrvShare) Logout(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
srv.Http.SetCookie(res, srv.Conf.KeyToken, "-")
|
||||
return httputil.Ok200
|
||||
}
|
||||
|
||||
func (srv *SrvShare) IsValidLength(length int64) bool {
|
||||
return length > 0 && length <= srv.Conf.MaxUpBytesPerSec
|
||||
}
|
||||
|
||||
func (srv *SrvShare) IsValidStart(start, expectStart int64) bool {
|
||||
return start == expectStart
|
||||
}
|
||||
|
||||
func (srv *SrvShare) IsValidShareId(shareId string) bool {
|
||||
// id could be 0 for dev environment
|
||||
if srv.Conf.Production {
|
||||
return len(shareId) == 64
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (srv *SrvShare) IsValidDownLimit(limit int) bool {
|
||||
return limit >= -1
|
||||
}
|
||||
|
||||
func IsValidFileName(fileName string) bool {
|
||||
return fileName != "" && len(fileName) < 240
|
||||
}
|
78
server/apis/auth_test.go
Normal file
78
server/apis/auth_test.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/cfg"
|
||||
"quickshare/server/libs/encrypt"
|
||||
"quickshare/server/libs/httputil"
|
||||
)
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
|
||||
type testCase struct {
|
||||
Desc string
|
||||
AdminId string
|
||||
AdminPwd string
|
||||
Result interface{}
|
||||
VerifyToken bool
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "invalid input",
|
||||
AdminId: "",
|
||||
AdminPwd: "",
|
||||
Result: httputil.Err401,
|
||||
VerifyToken: false,
|
||||
},
|
||||
testCase{
|
||||
Desc: "account not match",
|
||||
AdminId: "unknown",
|
||||
AdminPwd: "unknown",
|
||||
Result: httputil.Err401,
|
||||
VerifyToken: false,
|
||||
},
|
||||
testCase{
|
||||
Desc: "succeed to login",
|
||||
AdminId: conf.AdminId,
|
||||
AdminPwd: conf.AdminPwd,
|
||||
Result: httputil.Ok200,
|
||||
VerifyToken: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := NewSrvShare(conf)
|
||||
res := &stubWriter{Headers: map[string][]string{}}
|
||||
ret := srv.login(testCase.AdminId, testCase.AdminPwd, res)
|
||||
|
||||
if ret != testCase.Result {
|
||||
t.Fatalf("login: reponse=%v testCase=%v", ret, testCase.Result)
|
||||
}
|
||||
|
||||
// verify cookie (only token.adminid part))
|
||||
if testCase.VerifyToken {
|
||||
cookieVal := strings.Replace(
|
||||
res.Header().Get("Set-Cookie"),
|
||||
fmt.Sprintf("%s=", conf.KeyToken),
|
||||
"",
|
||||
1,
|
||||
)
|
||||
|
||||
gotTokenStr := strings.Split(cookieVal, ";")[0]
|
||||
token := encrypt.JwtEncrypterMaker(conf.SecretKey)
|
||||
token.FromStr(gotTokenStr)
|
||||
gotToken, found := token.Get(conf.KeyAdminId)
|
||||
if !found || conf.AdminId != gotToken {
|
||||
t.Fatalf("login: token admin id unmatch got=%v expect=%v", gotToken, conf.AdminId)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
66
server/apis/client.go
Normal file
66
server/apis/client.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/httputil"
|
||||
"quickshare/server/libs/httpworker"
|
||||
)
|
||||
|
||||
func (srv *SrvShare) ClientHandler(res http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodGet {
|
||||
srv.Http.Fill(httputil.Err404, res)
|
||||
return
|
||||
}
|
||||
|
||||
ack := make(chan error, 1)
|
||||
ok := srv.WorkerPool.Put(&httpworker.Task{
|
||||
Ack: ack,
|
||||
Do: srv.Wrap(srv.GetClient),
|
||||
Res: res,
|
||||
Req: req,
|
||||
})
|
||||
if !ok {
|
||||
srv.Http.Fill(httputil.Err503, res)
|
||||
return
|
||||
}
|
||||
|
||||
execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
|
||||
if srv.Err.IsErr(execErr) {
|
||||
srv.Http.Fill(httputil.Err500, res)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) GetClient(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
if !srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr)) {
|
||||
return httputil.Err504
|
||||
}
|
||||
|
||||
return srv.getClient(res, req, req.URL.EscapedPath())
|
||||
}
|
||||
|
||||
func (srv *SrvShare) getClient(res http.ResponseWriter, req *http.Request, relPath string) interface{} {
|
||||
if strings.HasSuffix(relPath, "/") {
|
||||
relPath = relPath + "index.html"
|
||||
}
|
||||
if !IsValidClientPath(relPath) {
|
||||
return httputil.Err400
|
||||
}
|
||||
|
||||
fullPath := filepath.Clean(filepath.Join("./public", relPath))
|
||||
http.ServeFile(res, req, fullPath)
|
||||
return 0
|
||||
}
|
||||
|
||||
func IsValidClientPath(fullPath string) bool {
|
||||
if strings.Contains(fullPath, "..") {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
69
server/apis/download.go
Normal file
69
server/apis/download.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/fileidx"
|
||||
"quickshare/server/libs/httputil"
|
||||
"quickshare/server/libs/httpworker"
|
||||
)
|
||||
|
||||
func (srv *SrvShare) DownloadHandler(res http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodGet {
|
||||
srv.Http.Fill(httputil.Err404, res)
|
||||
}
|
||||
|
||||
ack := make(chan error, 1)
|
||||
ok := srv.WorkerPool.Put(&httpworker.Task{
|
||||
Ack: ack,
|
||||
Do: srv.Wrap(srv.Download),
|
||||
Res: res,
|
||||
Req: req,
|
||||
})
|
||||
if !ok {
|
||||
srv.Http.Fill(httputil.Err503, res)
|
||||
}
|
||||
|
||||
// using WriteTimeout instead of Timeout
|
||||
// After timeout, connection will be lost, and worker will fail to write and return
|
||||
execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.WriteTimeout)*time.Millisecond)
|
||||
if srv.Err.IsErr(execErr) {
|
||||
srv.Http.Fill(httputil.Err500, res)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) Download(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
shareId := req.FormValue(srv.Conf.KeyShareId)
|
||||
if !srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr)) ||
|
||||
!srv.Walls.PassOpLimit(shareId, srv.Conf.OpIdDownload) {
|
||||
return httputil.Err429
|
||||
}
|
||||
|
||||
return srv.download(shareId, res, req)
|
||||
}
|
||||
|
||||
func (srv *SrvShare) download(shareId string, res http.ResponseWriter, req *http.Request) interface{} {
|
||||
if !srv.IsValidShareId(shareId) {
|
||||
return httputil.Err400
|
||||
}
|
||||
|
||||
fileInfo, found := srv.Index.Get(shareId)
|
||||
switch {
|
||||
case !found || fileInfo.State != fileidx.StateDone:
|
||||
return httputil.Err404
|
||||
case fileInfo.DownLimit == 0:
|
||||
return httputil.Err412
|
||||
default:
|
||||
updated, _ := srv.Index.DecrDownLimit(shareId)
|
||||
if updated != 1 {
|
||||
return httputil.Err500
|
||||
}
|
||||
}
|
||||
|
||||
err := srv.Downloader.ServeFile(res, req, fileInfo)
|
||||
srv.Err.IsErr(err)
|
||||
return 0
|
||||
}
|
271
server/apis/download_test.go
Normal file
271
server/apis/download_test.go
Normal file
|
@ -0,0 +1,271 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/cfg"
|
||||
"quickshare/server/libs/errutil"
|
||||
"quickshare/server/libs/fileidx"
|
||||
"quickshare/server/libs/httputil"
|
||||
"quickshare/server/libs/logutil"
|
||||
"quickshare/server/libs/qtube"
|
||||
)
|
||||
|
||||
func initServiceForDownloadTest(config *cfg.Config, indexMap map[string]*fileidx.FileInfo, content string) *SrvShare {
|
||||
setDownloader := func(srv *SrvShare) {
|
||||
srv.Downloader = stubDownloader{Content: content}
|
||||
}
|
||||
|
||||
setIndex := func(srv *SrvShare) {
|
||||
srv.Index = fileidx.NewMemFileIndexWithMap(len(indexMap), indexMap)
|
||||
}
|
||||
|
||||
setFs := func(srv *SrvShare) {
|
||||
srv.Fs = &stubFsUtil{
|
||||
MockFile: &qtube.StubFile{
|
||||
Content: content,
|
||||
Offset: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
logger := logutil.NewSlog(os.Stdout, config.AppName)
|
||||
setLog := func(srv *SrvShare) {
|
||||
srv.Log = logger
|
||||
}
|
||||
|
||||
setErr := func(srv *SrvShare) {
|
||||
srv.Err = errutil.NewErrChecker(!config.Production, logger)
|
||||
}
|
||||
|
||||
return InitSrvShare(config, setDownloader, setIndex, setFs, setLog, setErr)
|
||||
}
|
||||
|
||||
func TestDownload(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
conf.Production = false
|
||||
|
||||
type Init struct {
|
||||
Content string
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
}
|
||||
type Input struct {
|
||||
ShareId string
|
||||
}
|
||||
type Output struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
Response interface{}
|
||||
Body string
|
||||
}
|
||||
type testCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "empty file index",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
Response: httputil.Err404,
|
||||
},
|
||||
},
|
||||
testCase{
|
||||
Desc: "file info not found",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"1": &fileidx.FileInfo{},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"1": &fileidx.FileInfo{},
|
||||
},
|
||||
Response: httputil.Err404,
|
||||
},
|
||||
},
|
||||
testCase{
|
||||
Desc: "file not found because of state=uploading",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
DownLimit: 1,
|
||||
ModTime: time.Now().UnixNano(),
|
||||
PathLocal: "path",
|
||||
State: fileidx.StateUploading,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
DownLimit: 1,
|
||||
ModTime: time.Now().UnixNano(),
|
||||
PathLocal: "path",
|
||||
State: fileidx.StateUploading,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
Response: httputil.Err404,
|
||||
},
|
||||
},
|
||||
testCase{
|
||||
Desc: "download failed because download limit = 0",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
DownLimit: 0,
|
||||
ModTime: time.Now().UnixNano(),
|
||||
PathLocal: "path",
|
||||
State: fileidx.StateDone,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
DownLimit: 0,
|
||||
ModTime: time.Now().UnixNano(),
|
||||
PathLocal: "path",
|
||||
State: fileidx.StateDone,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
Response: httputil.Err412,
|
||||
},
|
||||
},
|
||||
testCase{
|
||||
Desc: "succeed to download",
|
||||
Init: Init{
|
||||
Content: "content",
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
DownLimit: 1,
|
||||
ModTime: time.Now().UnixNano(),
|
||||
PathLocal: "path",
|
||||
State: fileidx.StateDone,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
DownLimit: 0,
|
||||
ModTime: time.Now().UnixNano(),
|
||||
PathLocal: "path",
|
||||
State: fileidx.StateDone,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
Response: 0,
|
||||
Body: "content",
|
||||
},
|
||||
},
|
||||
testCase{
|
||||
Desc: "succeed to download DownLimit == -1",
|
||||
Init: Init{
|
||||
Content: "content",
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
DownLimit: -1,
|
||||
ModTime: time.Now().UnixNano(),
|
||||
PathLocal: "path",
|
||||
State: fileidx.StateDone,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
DownLimit: -1,
|
||||
ModTime: time.Now().UnixNano(),
|
||||
PathLocal: "path",
|
||||
State: fileidx.StateDone,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
Response: 0,
|
||||
Body: "content",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := initServiceForDownloadTest(conf, testCase.Init.IndexMap, testCase.Content)
|
||||
writer := &stubWriter{Headers: map[string][]string{}}
|
||||
response := srv.download(
|
||||
testCase.ShareId,
|
||||
writer,
|
||||
&http.Request{},
|
||||
)
|
||||
|
||||
// verify downlimit
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
info, _ := srv.Index.Get(testCase.ShareId)
|
||||
t.Fatalf(
|
||||
"download: index incorrect got=%v want=%v",
|
||||
info,
|
||||
testCase.Output.IndexMap[testCase.ShareId],
|
||||
)
|
||||
}
|
||||
|
||||
// verify response
|
||||
if response != testCase.Output.Response {
|
||||
t.Fatalf(
|
||||
"download: response incorrect response=%v testCase=%v",
|
||||
response,
|
||||
testCase.Output.Response,
|
||||
)
|
||||
}
|
||||
|
||||
// verify writerContent
|
||||
if string(writer.Response) != testCase.Output.Body {
|
||||
t.Fatalf(
|
||||
"download: body incorrect got=%v want=%v",
|
||||
string(writer.Response),
|
||||
testCase.Output.Body,
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
234
server/apis/file_info.go
Normal file
234
server/apis/file_info.go
Normal file
|
@ -0,0 +1,234 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/fileidx"
|
||||
"quickshare/server/libs/httputil"
|
||||
"quickshare/server/libs/httpworker"
|
||||
)
|
||||
|
||||
func (srv *SrvShare) FileInfoHandler(res http.ResponseWriter, req *http.Request) {
|
||||
tokenStr := srv.Http.GetCookie(req.Cookies(), srv.Conf.KeyToken)
|
||||
if !srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr)) ||
|
||||
!srv.Walls.PassLoginCheck(tokenStr, req) {
|
||||
srv.Http.Fill(httputil.Err429, res)
|
||||
return
|
||||
}
|
||||
|
||||
todo := func(res http.ResponseWriter, req *http.Request) interface{} { return httputil.Err404 }
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
todo = srv.List
|
||||
case http.MethodDelete:
|
||||
todo = srv.Del
|
||||
case http.MethodPatch:
|
||||
act := req.FormValue(srv.Conf.KeyAct)
|
||||
switch act {
|
||||
case srv.Conf.ActShadowId:
|
||||
todo = srv.ShadowId
|
||||
case srv.Conf.ActPublishId:
|
||||
todo = srv.PublishId
|
||||
case srv.Conf.ActSetDownLimit:
|
||||
todo = srv.SetDownLimit
|
||||
case srv.Conf.ActAddLocalFiles:
|
||||
todo = srv.AddLocalFiles
|
||||
default:
|
||||
srv.Http.Fill(httputil.Err404, res)
|
||||
return
|
||||
}
|
||||
default:
|
||||
srv.Http.Fill(httputil.Err404, res)
|
||||
return
|
||||
}
|
||||
|
||||
ack := make(chan error, 1)
|
||||
ok := srv.WorkerPool.Put(&httpworker.Task{
|
||||
Ack: ack,
|
||||
Do: srv.Wrap(todo),
|
||||
Res: res,
|
||||
Req: req,
|
||||
})
|
||||
if !ok {
|
||||
srv.Http.Fill(httputil.Err503, res)
|
||||
}
|
||||
|
||||
execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
|
||||
if srv.Err.IsErr(execErr) {
|
||||
srv.Http.Fill(httputil.Err500, res)
|
||||
}
|
||||
}
|
||||
|
||||
type ResInfos struct {
|
||||
List []*fileidx.FileInfo
|
||||
}
|
||||
|
||||
func (srv *SrvShare) List(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
if !srv.Walls.PassOpLimit(srv.Conf.AllUsers, srv.Conf.OpIdGetFInfo) {
|
||||
return httputil.Err429
|
||||
}
|
||||
|
||||
return srv.list()
|
||||
}
|
||||
|
||||
func (srv *SrvShare) list() interface{} {
|
||||
infos := make([]*fileidx.FileInfo, 0)
|
||||
for _, info := range srv.Index.List() {
|
||||
infos = append(infos, info)
|
||||
}
|
||||
|
||||
return &ResInfos{List: infos}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) Del(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
shareId := req.FormValue(srv.Conf.KeyShareId)
|
||||
if !srv.Walls.PassOpLimit(shareId, srv.Conf.OpIdDelFInfo) {
|
||||
return httputil.Err504
|
||||
}
|
||||
|
||||
return srv.del(shareId)
|
||||
}
|
||||
|
||||
func (srv *SrvShare) del(shareId string) interface{} {
|
||||
if !srv.IsValidShareId(shareId) {
|
||||
return httputil.Err400
|
||||
}
|
||||
|
||||
fileInfo, found := srv.Index.Get(shareId)
|
||||
if !found {
|
||||
return httputil.Err404
|
||||
}
|
||||
|
||||
srv.Index.Del(shareId)
|
||||
fullPath := filepath.Join(srv.Conf.PathLocal, fileInfo.PathLocal)
|
||||
if !srv.Fs.DelFile(fullPath) {
|
||||
// TODO: may log file name because file not exist or delete is not authenticated
|
||||
return httputil.Err500
|
||||
}
|
||||
|
||||
return httputil.Ok200
|
||||
}
|
||||
|
||||
func (srv *SrvShare) ShadowId(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
if !srv.Walls.PassOpLimit(srv.Conf.AllUsers, srv.Conf.OpIdOpFInfo) {
|
||||
return httputil.Err429
|
||||
}
|
||||
|
||||
shareId := req.FormValue(srv.Conf.KeyShareId)
|
||||
return srv.shadowId(shareId)
|
||||
}
|
||||
|
||||
func (srv *SrvShare) shadowId(shareId string) interface{} {
|
||||
if !srv.IsValidShareId(shareId) {
|
||||
return httputil.Err400
|
||||
}
|
||||
|
||||
info, found := srv.Index.Get(shareId)
|
||||
if !found {
|
||||
return httputil.Err404
|
||||
}
|
||||
|
||||
secretId := srv.Encryptor.Encrypt(
|
||||
[]byte(fmt.Sprintf("%s%s", info.PathLocal, genPwd())),
|
||||
)
|
||||
if !srv.Index.SetId(info.Id, secretId) {
|
||||
return httputil.Err412
|
||||
}
|
||||
|
||||
return &ShareInfo{ShareId: secretId}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) PublishId(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
if !srv.Walls.PassOpLimit(srv.Conf.AllUsers, srv.Conf.OpIdOpFInfo) {
|
||||
return httputil.Err429
|
||||
}
|
||||
|
||||
shareId := req.FormValue(srv.Conf.KeyShareId)
|
||||
return srv.publishId(shareId)
|
||||
}
|
||||
|
||||
func (srv *SrvShare) publishId(shareId string) interface{} {
|
||||
if !srv.IsValidShareId(shareId) {
|
||||
return httputil.Err400
|
||||
}
|
||||
|
||||
info, found := srv.Index.Get(shareId)
|
||||
if !found {
|
||||
return httputil.Err404
|
||||
}
|
||||
|
||||
publicId := srv.Encryptor.Encrypt([]byte(info.PathLocal))
|
||||
if !srv.Index.SetId(info.Id, publicId) {
|
||||
return httputil.Err412
|
||||
}
|
||||
|
||||
return &ShareInfo{ShareId: publicId}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) SetDownLimit(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
if !srv.Walls.PassOpLimit(srv.Conf.AllUsers, srv.Conf.OpIdOpFInfo) {
|
||||
return httputil.Err429
|
||||
}
|
||||
|
||||
shareId := req.FormValue(srv.Conf.KeyShareId)
|
||||
downLimit64, downLimitParseErr := strconv.ParseInt(req.FormValue(srv.Conf.KeyDownLimit), 10, 32)
|
||||
downLimit := int(downLimit64)
|
||||
if srv.Err.IsErr(downLimitParseErr) {
|
||||
return httputil.Err400
|
||||
}
|
||||
|
||||
return srv.setDownLimit(shareId, downLimit)
|
||||
}
|
||||
|
||||
func (srv *SrvShare) setDownLimit(shareId string, downLimit int) interface{} {
|
||||
if !srv.IsValidShareId(shareId) || !srv.IsValidDownLimit(downLimit) {
|
||||
return httputil.Err400
|
||||
}
|
||||
|
||||
if !srv.Index.SetDownLimit(shareId, downLimit) {
|
||||
return httputil.Err404
|
||||
}
|
||||
return httputil.Ok200
|
||||
}
|
||||
|
||||
func (srv *SrvShare) AddLocalFiles(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
return srv.AddLocalFilesImp()
|
||||
}
|
||||
|
||||
func (srv *SrvShare) AddLocalFilesImp() interface{} {
|
||||
infos, err := srv.Fs.Readdir(srv.Conf.PathLocal, srv.Conf.LocalFileLimit)
|
||||
if srv.Err.IsErr(err) {
|
||||
panic(fmt.Sprintf("fail to readdir: %v", err))
|
||||
}
|
||||
|
||||
for _, info := range infos {
|
||||
info.DownLimit = srv.Conf.DownLimit
|
||||
info.State = fileidx.StateDone
|
||||
info.PathLocal = info.PathLocal
|
||||
info.Id = srv.Encryptor.Encrypt([]byte(info.PathLocal))
|
||||
|
||||
addRet := srv.Index.Add(info)
|
||||
switch {
|
||||
case addRet == 0 || addRet == -1:
|
||||
// TODO: return files not added
|
||||
continue
|
||||
case addRet == 1:
|
||||
break
|
||||
default:
|
||||
return httputil.Err500
|
||||
}
|
||||
}
|
||||
|
||||
return httputil.Ok200
|
||||
}
|
||||
|
||||
func genPwd() string {
|
||||
return fmt.Sprintf("%d%d%d%d", rand.Intn(10), rand.Intn(10), rand.Intn(10), rand.Intn(10))
|
||||
}
|
584
server/apis/file_info_test.go
Normal file
584
server/apis/file_info_test.go
Normal file
|
@ -0,0 +1,584 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/cfg"
|
||||
"quickshare/server/libs/errutil"
|
||||
"quickshare/server/libs/fileidx"
|
||||
"quickshare/server/libs/httputil"
|
||||
"quickshare/server/libs/logutil"
|
||||
)
|
||||
|
||||
const mockShadowId = "shadowId"
|
||||
const mockPublicId = "publicId"
|
||||
|
||||
func initServiceForFileInfoTest(
|
||||
config *cfg.Config,
|
||||
indexMap map[string]*fileidx.FileInfo,
|
||||
useShadowEnc bool,
|
||||
localFileInfos []*fileidx.FileInfo,
|
||||
) *SrvShare {
|
||||
setIndex := func(srv *SrvShare) {
|
||||
srv.Index = fileidx.NewMemFileIndexWithMap(len(indexMap), indexMap)
|
||||
}
|
||||
|
||||
setFs := func(srv *SrvShare) {
|
||||
srv.Fs = &stubFsUtil{MockLocalFileInfos: localFileInfos}
|
||||
}
|
||||
|
||||
logger := logutil.NewSlog(os.Stdout, config.AppName)
|
||||
setLog := func(srv *SrvShare) {
|
||||
srv.Log = logger
|
||||
}
|
||||
|
||||
errChecker := errutil.NewErrChecker(!config.Production, logger)
|
||||
setErr := func(srv *SrvShare) {
|
||||
srv.Err = errChecker
|
||||
}
|
||||
|
||||
var setEncryptor AddDep
|
||||
if useShadowEnc {
|
||||
setEncryptor = func(srv *SrvShare) {
|
||||
srv.Encryptor = &stubEncryptor{MockResult: mockShadowId}
|
||||
}
|
||||
} else {
|
||||
setEncryptor = func(srv *SrvShare) {
|
||||
srv.Encryptor = &stubEncryptor{MockResult: mockPublicId}
|
||||
}
|
||||
}
|
||||
|
||||
return InitSrvShare(config, setIndex, setFs, setEncryptor, setLog, setErr)
|
||||
}
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
conf.Production = false
|
||||
|
||||
type Output struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
}
|
||||
type TestCase struct {
|
||||
Desc string
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []TestCase{
|
||||
TestCase{
|
||||
Desc: "success",
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
},
|
||||
"1": &fileidx.FileInfo{
|
||||
Id: "1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := initServiceForFileInfoTest(conf, testCase.Output.IndexMap, true, []*fileidx.FileInfo{})
|
||||
response := srv.list()
|
||||
resInfos := response.(*ResInfos)
|
||||
|
||||
for _, info := range resInfos.List {
|
||||
infoFromSrv, found := srv.Index.Get(info.Id)
|
||||
if !found || infoFromSrv.Id != info.Id {
|
||||
t.Fatalf("list: file infos are not identical")
|
||||
}
|
||||
}
|
||||
|
||||
if len(resInfos.List) != len(srv.Index.List()) {
|
||||
t.Fatalf("list: file infos are not identical")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDel(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
conf.Production = false
|
||||
|
||||
type Init struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
}
|
||||
type Input struct {
|
||||
ShareId string
|
||||
}
|
||||
type Output struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
Response httputil.MsgRes
|
||||
}
|
||||
type TestCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []TestCase{
|
||||
TestCase{
|
||||
Desc: "success",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
},
|
||||
"1": &fileidx.FileInfo{
|
||||
Id: "1",
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"1": &fileidx.FileInfo{
|
||||
Id: "1",
|
||||
},
|
||||
},
|
||||
Response: httputil.Ok200,
|
||||
},
|
||||
},
|
||||
TestCase{
|
||||
Desc: "not found",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"1": &fileidx.FileInfo{
|
||||
Id: "1",
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"1": &fileidx.FileInfo{
|
||||
Id: "1",
|
||||
},
|
||||
},
|
||||
Response: httputil.Err404,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := initServiceForFileInfoTest(conf, testCase.Init.IndexMap, true, []*fileidx.FileInfo{})
|
||||
response := srv.del(testCase.ShareId)
|
||||
res := response.(httputil.MsgRes)
|
||||
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
t.Fatalf("del: index incorrect")
|
||||
}
|
||||
|
||||
if res != testCase.Output.Response {
|
||||
t.Fatalf("del: response incorrect got: %v, want: %v", res, testCase.Output.Response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShadowId(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
conf.Production = false
|
||||
|
||||
type Init struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
}
|
||||
type Input struct {
|
||||
ShareId string
|
||||
}
|
||||
type Output struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
Response interface{}
|
||||
}
|
||||
type TestCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []TestCase{
|
||||
TestCase{
|
||||
Desc: "success",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
mockShadowId: &fileidx.FileInfo{
|
||||
Id: mockShadowId,
|
||||
},
|
||||
},
|
||||
Response: &ShareInfo{
|
||||
ShareId: mockShadowId,
|
||||
},
|
||||
},
|
||||
},
|
||||
TestCase{
|
||||
Desc: "original id not exists",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
Response: httputil.Err404,
|
||||
},
|
||||
},
|
||||
TestCase{
|
||||
Desc: "dest id exists",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
},
|
||||
mockShadowId: &fileidx.FileInfo{
|
||||
Id: mockShadowId,
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
},
|
||||
mockShadowId: &fileidx.FileInfo{
|
||||
Id: mockShadowId,
|
||||
},
|
||||
},
|
||||
Response: httputil.Err412,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := initServiceForFileInfoTest(conf, testCase.Init.IndexMap, true, []*fileidx.FileInfo{})
|
||||
response := srv.shadowId(testCase.ShareId)
|
||||
|
||||
switch response.(type) {
|
||||
case *ShareInfo:
|
||||
res := response.(*ShareInfo)
|
||||
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
info, found := srv.Index.Get(mockShadowId)
|
||||
t.Fatalf(
|
||||
"shadowId: index incorrect got %v found: %v want %v",
|
||||
info,
|
||||
found,
|
||||
testCase.Output.IndexMap[mockShadowId],
|
||||
)
|
||||
}
|
||||
|
||||
if res.ShareId != mockShadowId {
|
||||
t.Fatalf("shadowId: mockId incorrect")
|
||||
}
|
||||
|
||||
case httputil.MsgRes:
|
||||
res := response.(httputil.MsgRes)
|
||||
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
t.Fatalf("shadowId: map not identical")
|
||||
}
|
||||
|
||||
if res != testCase.Output.Response {
|
||||
t.Fatalf("shadowId: response incorrect")
|
||||
}
|
||||
default:
|
||||
t.Fatalf("shadowId: return type not found")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishId(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
conf.Production = false
|
||||
|
||||
type Init struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
}
|
||||
type Input struct {
|
||||
ShareId string
|
||||
}
|
||||
type Output struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
Response interface{}
|
||||
}
|
||||
type TestCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []TestCase{
|
||||
TestCase{
|
||||
Desc: "success",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
mockShadowId: &fileidx.FileInfo{
|
||||
Id: mockShadowId,
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: mockShadowId,
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
mockPublicId: &fileidx.FileInfo{
|
||||
Id: mockPublicId,
|
||||
},
|
||||
},
|
||||
Response: &ShareInfo{
|
||||
ShareId: mockPublicId,
|
||||
},
|
||||
},
|
||||
},
|
||||
TestCase{
|
||||
Desc: "original id not exists",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
Response: httputil.Err404,
|
||||
},
|
||||
},
|
||||
TestCase{
|
||||
Desc: "dest id exists",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
mockShadowId: &fileidx.FileInfo{
|
||||
Id: mockShadowId,
|
||||
},
|
||||
mockPublicId: &fileidx.FileInfo{
|
||||
Id: mockPublicId,
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: mockShadowId,
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
mockShadowId: &fileidx.FileInfo{
|
||||
Id: mockShadowId,
|
||||
},
|
||||
mockPublicId: &fileidx.FileInfo{
|
||||
Id: mockPublicId,
|
||||
},
|
||||
},
|
||||
Response: httputil.Err412,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := initServiceForFileInfoTest(conf, testCase.Init.IndexMap, false, []*fileidx.FileInfo{})
|
||||
response := srv.publishId(testCase.ShareId)
|
||||
|
||||
switch response.(type) {
|
||||
case *ShareInfo:
|
||||
res := response.(*ShareInfo)
|
||||
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
info, found := srv.Index.Get(mockPublicId)
|
||||
t.Fatalf(
|
||||
"shadowId: index incorrect got %v found: %v want %v",
|
||||
info,
|
||||
found,
|
||||
testCase.Output.IndexMap[mockPublicId],
|
||||
)
|
||||
}
|
||||
|
||||
if res.ShareId != mockPublicId {
|
||||
t.Fatalf("shadowId: mockId incorrect", res.ShareId, mockPublicId)
|
||||
}
|
||||
|
||||
case httputil.MsgRes:
|
||||
res := response.(httputil.MsgRes)
|
||||
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
t.Fatalf("shadowId: map not identical")
|
||||
}
|
||||
|
||||
if res != testCase.Output.Response {
|
||||
t.Fatalf("shadowId: response incorrect got: %v want: %v", res, testCase.Output.Response)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("shadowId: return type not found")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetDownLimit(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
conf.Production = false
|
||||
mockDownLimit := 100
|
||||
|
||||
type Init struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
}
|
||||
type Input struct {
|
||||
ShareId string
|
||||
DownLimit int
|
||||
}
|
||||
type Output struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
Response httputil.MsgRes
|
||||
}
|
||||
type TestCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []TestCase{
|
||||
TestCase{
|
||||
Desc: "success",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
DownLimit: mockDownLimit,
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
"0": &fileidx.FileInfo{
|
||||
Id: "0",
|
||||
DownLimit: mockDownLimit,
|
||||
},
|
||||
},
|
||||
Response: httputil.Ok200,
|
||||
},
|
||||
},
|
||||
TestCase{
|
||||
Desc: "not found",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: "0",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
Response: httputil.Err404,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := initServiceForFileInfoTest(conf, testCase.Init.IndexMap, true, []*fileidx.FileInfo{})
|
||||
response := srv.setDownLimit(testCase.ShareId, mockDownLimit)
|
||||
res := response.(httputil.MsgRes)
|
||||
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
info, _ := srv.Index.Get(testCase.ShareId)
|
||||
t.Fatalf(
|
||||
"setDownLimit: index incorrect got: %v want: %v",
|
||||
info,
|
||||
testCase.Output.IndexMap[testCase.ShareId],
|
||||
)
|
||||
}
|
||||
|
||||
if res != testCase.Output.Response {
|
||||
t.Fatalf("setDownLimit: response incorrect got: %v, want: %v", res, testCase.Output.Response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddLocalFiles(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
conf.Production = false
|
||||
|
||||
type Init struct {
|
||||
Infos []*fileidx.FileInfo
|
||||
}
|
||||
type Output struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
Response httputil.MsgRes
|
||||
}
|
||||
type TestCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []TestCase{
|
||||
TestCase{
|
||||
Desc: "success",
|
||||
Init: Init{
|
||||
Infos: []*fileidx.FileInfo{
|
||||
&fileidx.FileInfo{
|
||||
Id: "",
|
||||
DownLimit: 0,
|
||||
ModTime: 13,
|
||||
PathLocal: "filename1",
|
||||
State: "",
|
||||
Uploaded: 13,
|
||||
},
|
||||
},
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
mockPublicId: &fileidx.FileInfo{
|
||||
Id: mockPublicId,
|
||||
DownLimit: conf.DownLimit,
|
||||
ModTime: 13,
|
||||
PathLocal: filepath.Join(conf.PathLocal, "filename1"),
|
||||
State: fileidx.StateDone,
|
||||
Uploaded: 13,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := initServiceForFileInfoTest(conf, testCase.Output.IndexMap, false, testCase.Init.Infos)
|
||||
response := srv.AddLocalFilesImp()
|
||||
res := response.(httputil.MsgRes)
|
||||
|
||||
if res.Code != 200 {
|
||||
t.Fatalf("addLocalFiles: code not correct")
|
||||
}
|
||||
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
t.Fatalf(
|
||||
"addLocalFiles: indexes not identical got: %v want: %v",
|
||||
srv.Index.List(),
|
||||
testCase.Output.IndexMap,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
145
server/apis/service.go
Normal file
145
server/apis/service.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/cfg"
|
||||
"quickshare/server/libs/encrypt"
|
||||
"quickshare/server/libs/errutil"
|
||||
"quickshare/server/libs/fileidx"
|
||||
"quickshare/server/libs/fsutil"
|
||||
"quickshare/server/libs/httputil"
|
||||
"quickshare/server/libs/httpworker"
|
||||
"quickshare/server/libs/limiter"
|
||||
"quickshare/server/libs/logutil"
|
||||
"quickshare/server/libs/qtube"
|
||||
"quickshare/server/libs/walls"
|
||||
)
|
||||
|
||||
type AddDep func(*SrvShare)
|
||||
|
||||
func NewSrvShare(config *cfg.Config) *SrvShare {
|
||||
logger := logutil.NewSlog(os.Stdout, config.AppName)
|
||||
setLog := func(srv *SrvShare) {
|
||||
srv.Log = logger
|
||||
}
|
||||
|
||||
errChecker := errutil.NewErrChecker(!config.Production, logger)
|
||||
setErr := func(srv *SrvShare) {
|
||||
srv.Err = errChecker
|
||||
}
|
||||
|
||||
setWorkerPool := func(srv *SrvShare) {
|
||||
workerPoolSize := config.WorkerPoolSize
|
||||
taskQueueSize := config.TaskQueueSize
|
||||
srv.WorkerPool = httpworker.NewWorkerPool(workerPoolSize, taskQueueSize, logger)
|
||||
}
|
||||
|
||||
setWalls := func(srv *SrvShare) {
|
||||
encrypterMaker := encrypt.JwtEncrypterMaker
|
||||
ipLimiter := limiter.NewRateLimiter(
|
||||
config.LimiterCap,
|
||||
config.LimiterTtl,
|
||||
config.LimiterCyc,
|
||||
config.BucketCap,
|
||||
config.SpecialCaps,
|
||||
)
|
||||
opLimiter := limiter.NewRateLimiter(
|
||||
config.LimiterCap,
|
||||
config.LimiterTtl,
|
||||
config.LimiterCyc,
|
||||
config.BucketCap,
|
||||
config.SpecialCaps,
|
||||
)
|
||||
srv.Walls = walls.NewAccessWalls(config, ipLimiter, opLimiter, encrypterMaker)
|
||||
}
|
||||
|
||||
setIndex := func(srv *SrvShare) {
|
||||
srv.Index = fileidx.NewMemFileIndex(config.MaxShares)
|
||||
}
|
||||
|
||||
fs := fsutil.NewSimpleFs(errChecker)
|
||||
setFs := func(srv *SrvShare) {
|
||||
srv.Fs = fs
|
||||
}
|
||||
|
||||
setDownloader := func(srv *SrvShare) {
|
||||
srv.Downloader = qtube.NewQTube(
|
||||
config.PathLocal,
|
||||
config.MaxDownBytesPerSec,
|
||||
config.MaxRangeLength,
|
||||
fs,
|
||||
)
|
||||
}
|
||||
|
||||
setEncryptor := func(srv *SrvShare) {
|
||||
srv.Encryptor = &encrypt.HmacEncryptor{Key: config.SecretKeyByte}
|
||||
}
|
||||
|
||||
setHttp := func(srv *SrvShare) {
|
||||
srv.Http = &httputil.QHttpUtil{
|
||||
CookieDomain: config.CookieDomain,
|
||||
CookieHttpOnly: config.CookieHttpOnly,
|
||||
CookieMaxAge: config.CookieMaxAge,
|
||||
CookiePath: config.CookiePath,
|
||||
CookieSecure: config.CookieSecure,
|
||||
Err: errChecker,
|
||||
}
|
||||
}
|
||||
|
||||
return InitSrvShare(config, setIndex, setWalls, setWorkerPool, setFs, setDownloader, setEncryptor, setLog, setErr, setHttp)
|
||||
}
|
||||
|
||||
func InitSrvShare(config *cfg.Config, addDeps ...AddDep) *SrvShare {
|
||||
srv := &SrvShare{}
|
||||
srv.Conf = config
|
||||
for _, addDep := range addDeps {
|
||||
addDep(srv)
|
||||
}
|
||||
|
||||
if !srv.Fs.MkdirAll(srv.Conf.PathLocal, os.FileMode(0775)) {
|
||||
panic("fail to make ./files/ folder")
|
||||
}
|
||||
|
||||
if res := srv.AddLocalFilesImp(); res != httputil.Ok200 {
|
||||
panic("fail to add local files")
|
||||
}
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
type SrvShare struct {
|
||||
Conf *cfg.Config
|
||||
Encryptor encrypt.Encryptor
|
||||
Err errutil.ErrUtil
|
||||
Downloader qtube.Downloader
|
||||
Http httputil.HttpUtil
|
||||
Index fileidx.FileIndex
|
||||
Fs fsutil.FsUtil
|
||||
Log logutil.LogUtil
|
||||
Walls walls.Walls
|
||||
WorkerPool httpworker.Workers
|
||||
}
|
||||
|
||||
func (srv *SrvShare) Wrap(serviceFunc httpworker.ServiceFunc) httpworker.DoFunc {
|
||||
return func(res http.ResponseWriter, req *http.Request) {
|
||||
body := serviceFunc(res, req)
|
||||
|
||||
if body != nil && body != 0 && srv.Http.Fill(body, res) <= 0 {
|
||||
log.Println("Wrap: fail to fill body", body, res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetRemoteIp(addr string) string {
|
||||
addrParts := strings.Split(addr, ":")
|
||||
if len(addrParts) > 0 {
|
||||
return addrParts[0]
|
||||
}
|
||||
return "unknown ip"
|
||||
}
|
117
server/apis/test_helper.go
Normal file
117
server/apis/test_helper.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/fileidx"
|
||||
"quickshare/server/libs/qtube"
|
||||
)
|
||||
|
||||
type stubFsUtil struct {
|
||||
MockLocalFileInfos []*fileidx.FileInfo
|
||||
MockFile *qtube.StubFile
|
||||
}
|
||||
|
||||
var expectCreateFileName = ""
|
||||
|
||||
func (fs *stubFsUtil) CreateFile(fileName string) error {
|
||||
if fileName != expectCreateFileName {
|
||||
panic(
|
||||
fmt.Sprintf("CreateFile: got: %s expect: %s", fileName, expectCreateFileName),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fs *stubFsUtil) CopyChunkN(fullPath string, chunk io.Reader, start int64, len int64) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (fs *stubFsUtil) ServeFile(res http.ResponseWriter, req *http.Request, fileName string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (fs *stubFsUtil) DelFile(fullPath string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (fs *stubFsUtil) MkdirAll(path string, mode os.FileMode) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (fs *stubFsUtil) Readdir(dirname string, n int) ([]*fileidx.FileInfo, error) {
|
||||
return fs.MockLocalFileInfos, nil
|
||||
}
|
||||
|
||||
func (fs *stubFsUtil) Open(filePath string) (qtube.ReadSeekCloser, error) {
|
||||
return fs.MockFile, nil
|
||||
}
|
||||
|
||||
type stubWriter struct {
|
||||
Headers http.Header
|
||||
Response []byte
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (w *stubWriter) Header() http.Header {
|
||||
return w.Headers
|
||||
}
|
||||
|
||||
func (w *stubWriter) Write(body []byte) (int, error) {
|
||||
w.Response = append(w.Response, body...)
|
||||
return len(body), nil
|
||||
}
|
||||
|
||||
func (w *stubWriter) WriteHeader(statusCode int) {
|
||||
w.StatusCode = statusCode
|
||||
}
|
||||
|
||||
type stubDownloader struct {
|
||||
Content string
|
||||
}
|
||||
|
||||
func (d stubDownloader) ServeFile(w http.ResponseWriter, r *http.Request, fileInfo *fileidx.FileInfo) error {
|
||||
_, err := w.Write([]byte(d.Content))
|
||||
return err
|
||||
}
|
||||
|
||||
func sameInfoWithoutTime(info1, info2 *fileidx.FileInfo) bool {
|
||||
return info1.Id == info2.Id &&
|
||||
info1.DownLimit == info2.DownLimit &&
|
||||
info1.PathLocal == info2.PathLocal &&
|
||||
info1.State == info2.State &&
|
||||
info1.Uploaded == info2.Uploaded
|
||||
}
|
||||
|
||||
func sameMap(map1, map2 map[string]*fileidx.FileInfo) bool {
|
||||
for key, info1 := range map1 {
|
||||
info2, found := map2[key]
|
||||
if !found || !sameInfoWithoutTime(info1, info2) {
|
||||
fmt.Printf("infos are not same: \n%v \n%v", info1, info2)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for key, info2 := range map2 {
|
||||
info1, found := map1[key]
|
||||
if !found || !sameInfoWithoutTime(info1, info2) {
|
||||
fmt.Printf("infos are not same: \n%v \n%v", info1, info2)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
type stubEncryptor struct {
|
||||
MockResult string
|
||||
}
|
||||
|
||||
func (enc *stubEncryptor) Encrypt(content []byte) string {
|
||||
return enc.MockResult
|
||||
}
|
253
server/apis/upload.go
Normal file
253
server/apis/upload.go
Normal file
|
@ -0,0 +1,253 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/encrypt"
|
||||
"quickshare/server/libs/fileidx"
|
||||
"quickshare/server/libs/fsutil"
|
||||
httpUtil "quickshare/server/libs/httputil"
|
||||
worker "quickshare/server/libs/httpworker"
|
||||
)
|
||||
|
||||
const DefaultId = "0"
|
||||
|
||||
type ByteRange struct {
|
||||
ShareId string
|
||||
Start int64
|
||||
Length int64
|
||||
}
|
||||
|
||||
type ShareInfo struct {
|
||||
ShareId string
|
||||
}
|
||||
|
||||
func (srv *SrvShare) StartUploadHandler(res http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
srv.Http.Fill(httpUtil.Err404, res)
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := srv.Http.GetCookie(req.Cookies(), srv.Conf.KeyToken)
|
||||
ipPass := srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr))
|
||||
loginPass := srv.Walls.PassLoginCheck(tokenStr, req)
|
||||
opPass := srv.Walls.PassOpLimit(GetRemoteIp(req.RemoteAddr), srv.Conf.OpIdUpload)
|
||||
if !ipPass || !loginPass || !opPass {
|
||||
srv.Http.Fill(httpUtil.Err429, res)
|
||||
return
|
||||
}
|
||||
|
||||
ack := make(chan error, 1)
|
||||
ok := srv.WorkerPool.Put(&worker.Task{
|
||||
Ack: ack,
|
||||
Do: srv.Wrap(srv.StartUpload),
|
||||
Res: res,
|
||||
Req: req,
|
||||
})
|
||||
if !ok {
|
||||
srv.Http.Fill(httpUtil.Err503, res)
|
||||
}
|
||||
|
||||
execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
|
||||
if srv.Err.IsErr(execErr) {
|
||||
srv.Http.Fill(httpUtil.Err500, res)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) UploadHandler(res http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
srv.Http.Fill(httpUtil.Err404, res)
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := srv.Http.GetCookie(req.Cookies(), srv.Conf.KeyToken)
|
||||
ipPass := srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr))
|
||||
loginPass := srv.Walls.PassLoginCheck(tokenStr, req)
|
||||
opPass := srv.Walls.PassOpLimit(GetRemoteIp(req.RemoteAddr), srv.Conf.OpIdUpload)
|
||||
if !ipPass || !loginPass || !opPass {
|
||||
srv.Http.Fill(httpUtil.Err429, res)
|
||||
return
|
||||
}
|
||||
|
||||
multiFormErr := req.ParseMultipartForm(srv.Conf.ParseFormBufSize)
|
||||
if srv.Err.IsErr(multiFormErr) {
|
||||
srv.Http.Fill(httpUtil.Err400, res)
|
||||
return
|
||||
}
|
||||
|
||||
srv.Log.Println("form", req.Form)
|
||||
srv.Log.Println("pform", req.PostForm)
|
||||
srv.Log.Println("mform", req.MultipartForm)
|
||||
ack := make(chan error, 1)
|
||||
ok := srv.WorkerPool.Put(&worker.Task{
|
||||
Ack: ack,
|
||||
Do: srv.Wrap(srv.Upload),
|
||||
Res: res,
|
||||
Req: req,
|
||||
})
|
||||
if !ok {
|
||||
srv.Http.Fill(httpUtil.Err503, res)
|
||||
}
|
||||
|
||||
execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
|
||||
if srv.Err.IsErr(execErr) {
|
||||
srv.Http.Fill(httpUtil.Err500, res)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) FinishUploadHandler(res http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
srv.Http.Fill(httpUtil.Err404, res)
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := srv.Http.GetCookie(req.Cookies(), srv.Conf.KeyToken)
|
||||
ipPass := srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr))
|
||||
loginPass := srv.Walls.PassLoginCheck(tokenStr, req)
|
||||
opPass := srv.Walls.PassOpLimit(GetRemoteIp(req.RemoteAddr), srv.Conf.OpIdUpload)
|
||||
if !ipPass || !loginPass || !opPass {
|
||||
srv.Http.Fill(httpUtil.Err429, res)
|
||||
return
|
||||
}
|
||||
|
||||
ack := make(chan error, 1)
|
||||
ok := srv.WorkerPool.Put(&worker.Task{
|
||||
Ack: ack,
|
||||
Do: srv.Wrap(srv.FinishUpload),
|
||||
Res: res,
|
||||
Req: req,
|
||||
})
|
||||
if !ok {
|
||||
srv.Http.Fill(httpUtil.Err503, res)
|
||||
}
|
||||
|
||||
execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
|
||||
if srv.Err.IsErr(execErr) {
|
||||
srv.Http.Fill(httpUtil.Err500, res)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) StartUpload(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
return srv.startUpload(req.FormValue(srv.Conf.KeyFileName))
|
||||
}
|
||||
|
||||
func (srv *SrvShare) startUpload(fileName string) interface{} {
|
||||
if !IsValidFileName(fileName) {
|
||||
return httpUtil.Err400
|
||||
}
|
||||
|
||||
id := DefaultId
|
||||
if srv.Conf.Production {
|
||||
id = genInfoId(fileName, srv.Conf.SecretKeyByte)
|
||||
}
|
||||
|
||||
info := &fileidx.FileInfo{
|
||||
Id: id,
|
||||
DownLimit: srv.Conf.DownLimit,
|
||||
ModTime: time.Now().UnixNano(),
|
||||
PathLocal: fileName,
|
||||
Uploaded: 0,
|
||||
State: fileidx.StateStarted,
|
||||
}
|
||||
|
||||
switch srv.Index.Add(info) {
|
||||
case 0:
|
||||
// go on
|
||||
case -1:
|
||||
return httpUtil.Err412
|
||||
case 1:
|
||||
return httpUtil.Err500 // TODO: use correct status code
|
||||
default:
|
||||
srv.Index.Del(id)
|
||||
return httpUtil.Err500
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(srv.Conf.PathLocal, info.PathLocal)
|
||||
createFileErr := srv.Fs.CreateFile(fullPath)
|
||||
switch {
|
||||
case createFileErr == fsutil.ErrExists:
|
||||
srv.Index.Del(id)
|
||||
return httpUtil.Err412
|
||||
case createFileErr == fsutil.ErrUnknown:
|
||||
srv.Index.Del(id)
|
||||
return httpUtil.Err500
|
||||
default:
|
||||
srv.Index.SetState(id, fileidx.StateUploading)
|
||||
return &ByteRange{
|
||||
ShareId: id,
|
||||
Start: 0,
|
||||
Length: srv.Conf.MaxUpBytesPerSec,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) Upload(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
shareId := req.FormValue(srv.Conf.KeyShareId)
|
||||
start, startErr := strconv.ParseInt(req.FormValue(srv.Conf.KeyStart), 10, 64)
|
||||
length, lengthErr := strconv.ParseInt(req.FormValue(srv.Conf.KeyLen), 10, 64)
|
||||
chunk, _, chunkErr := req.FormFile(srv.Conf.KeyChunk)
|
||||
|
||||
if srv.Err.IsErr(startErr) ||
|
||||
srv.Err.IsErr(lengthErr) ||
|
||||
srv.Err.IsErr(chunkErr) {
|
||||
return httpUtil.Err400
|
||||
}
|
||||
|
||||
return srv.upload(shareId, start, length, chunk)
|
||||
}
|
||||
|
||||
func (srv *SrvShare) upload(shareId string, start int64, length int64, chunk io.Reader) interface{} {
|
||||
if !srv.IsValidShareId(shareId) {
|
||||
return httpUtil.Err400
|
||||
}
|
||||
|
||||
fileInfo, found := srv.Index.Get(shareId)
|
||||
if !found {
|
||||
return httpUtil.Err404
|
||||
}
|
||||
|
||||
if !srv.IsValidStart(start, fileInfo.Uploaded) || !srv.IsValidLength(length) {
|
||||
return httpUtil.Err400
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(srv.Conf.PathLocal, fileInfo.PathLocal)
|
||||
if !srv.Fs.CopyChunkN(fullPath, chunk, start, length) {
|
||||
return httpUtil.Err500
|
||||
}
|
||||
|
||||
if srv.Index.IncrUploaded(shareId, length) == 0 {
|
||||
return httpUtil.Err404
|
||||
}
|
||||
|
||||
return &ByteRange{
|
||||
ShareId: shareId,
|
||||
Start: start + length,
|
||||
Length: srv.Conf.MaxUpBytesPerSec,
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *SrvShare) FinishUpload(res http.ResponseWriter, req *http.Request) interface{} {
|
||||
shareId := req.FormValue(srv.Conf.KeyShareId)
|
||||
return srv.finishUpload(shareId)
|
||||
}
|
||||
|
||||
func (srv *SrvShare) finishUpload(shareId string) interface{} {
|
||||
if !srv.Index.SetState(shareId, fileidx.StateDone) {
|
||||
return httpUtil.Err404
|
||||
}
|
||||
|
||||
return &ShareInfo{
|
||||
ShareId: shareId,
|
||||
}
|
||||
}
|
||||
|
||||
func genInfoId(content string, key []byte) string {
|
||||
encrypter := encrypt.HmacEncryptor{Key: key}
|
||||
return encrypter.Encrypt([]byte(content))
|
||||
}
|
368
server/apis/upload_test.go
Normal file
368
server/apis/upload_test.go
Normal file
|
@ -0,0 +1,368 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/cfg"
|
||||
"quickshare/server/libs/encrypt"
|
||||
"quickshare/server/libs/errutil"
|
||||
"quickshare/server/libs/fileidx"
|
||||
"quickshare/server/libs/httputil"
|
||||
"quickshare/server/libs/httpworker"
|
||||
"quickshare/server/libs/limiter"
|
||||
"quickshare/server/libs/logutil"
|
||||
"quickshare/server/libs/walls"
|
||||
)
|
||||
|
||||
const testCap = 3
|
||||
|
||||
func initServiceForUploadTest(config *cfg.Config, indexMap map[string]*fileidx.FileInfo) *SrvShare {
|
||||
logger := logutil.NewSlog(os.Stdout, config.AppName)
|
||||
setLog := func(srv *SrvShare) {
|
||||
srv.Log = logger
|
||||
}
|
||||
|
||||
setWorkerPool := func(srv *SrvShare) {
|
||||
workerPoolSize := config.WorkerPoolSize
|
||||
taskQueueSize := config.TaskQueueSize
|
||||
srv.WorkerPool = httpworker.NewWorkerPool(workerPoolSize, taskQueueSize, logger)
|
||||
}
|
||||
|
||||
setWalls := func(srv *SrvShare) {
|
||||
encrypterMaker := encrypt.JwtEncrypterMaker
|
||||
ipLimiter := limiter.NewRateLimiter(config.LimiterCap, config.LimiterTtl, config.LimiterCyc, config.BucketCap, map[int16]int16{})
|
||||
opLimiter := limiter.NewRateLimiter(config.LimiterCap, config.LimiterTtl, config.LimiterCyc, config.BucketCap, map[int16]int16{})
|
||||
srv.Walls = walls.NewAccessWalls(config, ipLimiter, opLimiter, encrypterMaker)
|
||||
}
|
||||
|
||||
setIndex := func(srv *SrvShare) {
|
||||
srv.Index = fileidx.NewMemFileIndexWithMap(len(indexMap)+testCap, indexMap)
|
||||
}
|
||||
|
||||
setFs := func(srv *SrvShare) {
|
||||
srv.Fs = &stubFsUtil{}
|
||||
}
|
||||
|
||||
setEncryptor := func(srv *SrvShare) {
|
||||
srv.Encryptor = &encrypt.HmacEncryptor{Key: config.SecretKeyByte}
|
||||
}
|
||||
|
||||
errChecker := errutil.NewErrChecker(!config.Production, logger)
|
||||
setErr := func(srv *SrvShare) {
|
||||
srv.Err = errChecker
|
||||
}
|
||||
|
||||
return InitSrvShare(config, setIndex, setWalls, setWorkerPool, setFs, setEncryptor, setLog, setErr)
|
||||
}
|
||||
|
||||
func TestStartUpload(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
conf.Production = false
|
||||
|
||||
type Init struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
}
|
||||
type Input struct {
|
||||
FileName string
|
||||
}
|
||||
type Output struct {
|
||||
Response interface{}
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
}
|
||||
type testCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "invalid file name",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
},
|
||||
Input: Input{
|
||||
FileName: "",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
Response: httputil.Err400,
|
||||
},
|
||||
},
|
||||
testCase{
|
||||
Desc: "succeed to start uploading",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
},
|
||||
Input: Input{
|
||||
FileName: "filename",
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
DefaultId: &fileidx.FileInfo{
|
||||
Id: DefaultId,
|
||||
DownLimit: conf.DownLimit,
|
||||
ModTime: time.Now().UnixNano(),
|
||||
PathLocal: "filename",
|
||||
Uploaded: 0,
|
||||
State: fileidx.StateUploading,
|
||||
},
|
||||
},
|
||||
Response: &ByteRange{
|
||||
ShareId: DefaultId,
|
||||
Start: 0,
|
||||
Length: conf.MaxUpBytesPerSec,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := initServiceForUploadTest(conf, testCase.Init.IndexMap)
|
||||
|
||||
// verify CreateFile
|
||||
expectCreateFileName = filepath.Join(conf.PathLocal, testCase.FileName)
|
||||
|
||||
response := srv.startUpload(testCase.FileName)
|
||||
|
||||
// verify index
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
t.Fatalf("startUpload: index not equal got: %v, %v, expect: %v", srv.Index.List(), response, testCase.Output.IndexMap)
|
||||
}
|
||||
|
||||
// verify response
|
||||
switch expectRes := testCase.Output.Response.(type) {
|
||||
case *ByteRange:
|
||||
res := response.(*ByteRange)
|
||||
if res.ShareId != expectRes.ShareId ||
|
||||
res.Start != expectRes.Start ||
|
||||
res.Length != expectRes.Length {
|
||||
t.Fatalf(fmt.Sprintf("startUpload: res=%v expect=%v", res, expectRes))
|
||||
}
|
||||
case httputil.MsgRes:
|
||||
if response != expectRes {
|
||||
t.Fatalf(fmt.Sprintf("startUpload: reponse=%v expectRes=%v", response, expectRes))
|
||||
}
|
||||
default:
|
||||
t.Fatalf(fmt.Sprintf("startUpload: type not found: %T %T", testCase.Output.Response, httputil.Err400))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpload(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
conf.Production = false
|
||||
|
||||
type Init struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
}
|
||||
type Input struct {
|
||||
ShareId string
|
||||
Start int64
|
||||
Len int64
|
||||
Chunk io.Reader
|
||||
}
|
||||
type Output struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
Response interface{}
|
||||
}
|
||||
type testCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "shareid does not exist",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: DefaultId,
|
||||
Start: 0,
|
||||
Len: 1,
|
||||
Chunk: strings.NewReader(""),
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
Response: httputil.Err404,
|
||||
},
|
||||
},
|
||||
testCase{
|
||||
Desc: "succeed",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
DefaultId: &fileidx.FileInfo{
|
||||
Id: DefaultId,
|
||||
DownLimit: conf.MaxShares,
|
||||
PathLocal: "path/filename",
|
||||
State: fileidx.StateUploading,
|
||||
Uploaded: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: DefaultId,
|
||||
Start: 0,
|
||||
Len: 1,
|
||||
Chunk: strings.NewReader("a"),
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
DefaultId: &fileidx.FileInfo{
|
||||
Id: DefaultId,
|
||||
DownLimit: conf.MaxShares,
|
||||
PathLocal: "path/filename",
|
||||
State: fileidx.StateUploading,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
Response: &ByteRange{
|
||||
ShareId: DefaultId,
|
||||
Start: 1,
|
||||
Length: conf.MaxUpBytesPerSec,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := initServiceForUploadTest(conf, testCase.Init.IndexMap)
|
||||
|
||||
response := srv.upload(
|
||||
testCase.Input.ShareId,
|
||||
testCase.Input.Start,
|
||||
testCase.Input.Len,
|
||||
testCase.Input.Chunk,
|
||||
)
|
||||
|
||||
// TODO: not verified copyChunk
|
||||
|
||||
// verify index
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
t.Fatalf("upload: index not identical got: %v want: %v", srv.Index.List(), testCase.Output.IndexMap)
|
||||
}
|
||||
// verify response
|
||||
switch response.(type) {
|
||||
case *ByteRange:
|
||||
br := testCase.Output.Response.(*ByteRange)
|
||||
res := response.(*ByteRange)
|
||||
if res.ShareId != br.ShareId || res.Start != br.Start || res.Length != br.Length {
|
||||
t.Fatalf(fmt.Sprintf("upload: response=%v expectRes=%v", res, br))
|
||||
}
|
||||
default:
|
||||
if response != testCase.Output.Response {
|
||||
t.Fatalf(fmt.Sprintf("upload: response=%v expectRes=%v", response, testCase.Output.Response))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishUpload(t *testing.T) {
|
||||
conf := cfg.NewConfig()
|
||||
conf.Production = false
|
||||
|
||||
type Init struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
}
|
||||
type Input struct {
|
||||
ShareId string
|
||||
Start int64
|
||||
Len int64
|
||||
Chunk io.Reader
|
||||
}
|
||||
type Output struct {
|
||||
IndexMap map[string]*fileidx.FileInfo
|
||||
Response interface{}
|
||||
}
|
||||
type testCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "success",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
DefaultId: &fileidx.FileInfo{
|
||||
Id: DefaultId,
|
||||
DownLimit: conf.MaxShares,
|
||||
PathLocal: "path/filename",
|
||||
State: fileidx.StateUploading,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: DefaultId,
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{
|
||||
DefaultId: &fileidx.FileInfo{
|
||||
Id: DefaultId,
|
||||
DownLimit: conf.MaxShares,
|
||||
PathLocal: "path/filename",
|
||||
State: fileidx.StateDone,
|
||||
Uploaded: 1,
|
||||
},
|
||||
},
|
||||
Response: &ShareInfo{
|
||||
ShareId: DefaultId,
|
||||
},
|
||||
},
|
||||
},
|
||||
testCase{
|
||||
Desc: "shareId exists",
|
||||
Init: Init{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
},
|
||||
Input: Input{
|
||||
ShareId: DefaultId,
|
||||
},
|
||||
Output: Output{
|
||||
IndexMap: map[string]*fileidx.FileInfo{},
|
||||
Response: httputil.Err404,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv := initServiceForUploadTest(conf, testCase.Init.IndexMap)
|
||||
|
||||
response := srv.finishUpload(testCase.ShareId)
|
||||
|
||||
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
|
||||
t.Fatalf("finishUpload: index not identical got: %v, want: %v", srv.Index.List(), testCase.Output.IndexMap)
|
||||
}
|
||||
|
||||
switch res := response.(type) {
|
||||
case httputil.MsgRes:
|
||||
expectRes := testCase.Output.Response.(httputil.MsgRes)
|
||||
if res != expectRes {
|
||||
t.Fatalf(fmt.Sprintf("finishUpload: reponse=%v expectRes=%v", res, expectRes))
|
||||
}
|
||||
case *ShareInfo:
|
||||
info, found := testCase.Output.IndexMap[res.ShareId]
|
||||
if !found || info.State != fileidx.StateDone {
|
||||
// TODO: should use isValidUrl or better to verify result
|
||||
t.Fatalf(fmt.Sprintf("finishUpload: share info is not correct: received: %v expect: %v", res.ShareId, testCase.ShareId))
|
||||
}
|
||||
default:
|
||||
t.Fatalf(fmt.Sprintf("finishUpload: type not found: %T %T", response, testCase.Output.Response))
|
||||
}
|
||||
}
|
||||
}
|
251
server/libs/cfg/cfg.go
Normal file
251
server/libs/cfg/cfg.go
Normal file
|
@ -0,0 +1,251 @@
|
|||
package cfg
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
AppName string
|
||||
AdminId string
|
||||
AdminPwd string
|
||||
SecretKey string
|
||||
SecretKeyByte []byte `json:",omitempty"`
|
||||
// server
|
||||
Production bool
|
||||
HostName string
|
||||
Port int
|
||||
// performance
|
||||
MaxUpBytesPerSec int64
|
||||
MaxDownBytesPerSec int64
|
||||
MaxRangeLength int64
|
||||
Timeout int // millisecond
|
||||
ReadTimeout int
|
||||
WriteTimeout int
|
||||
IdleTimeout int
|
||||
WorkerPoolSize int
|
||||
TaskQueueSize int
|
||||
QueueSize int
|
||||
ParseFormBufSize int64
|
||||
MaxHeaderBytes int
|
||||
DownLimit int
|
||||
MaxShares int
|
||||
LocalFileLimit int
|
||||
// Cookie
|
||||
CookieDomain string
|
||||
CookieHttpOnly bool
|
||||
CookieMaxAge int
|
||||
CookiePath string
|
||||
CookieSecure bool
|
||||
// keys
|
||||
KeyAdminId string
|
||||
KeyAdminPwd string
|
||||
KeyToken string
|
||||
KeyFileName string
|
||||
KeyFileSize string
|
||||
KeyShareId string
|
||||
KeyStart string
|
||||
KeyLen string
|
||||
KeyChunk string
|
||||
KeyAct string
|
||||
KeyExpires string
|
||||
KeyDownLimit string
|
||||
ActStartUpload string
|
||||
ActUpload string
|
||||
ActFinishUpload string
|
||||
ActLogin string
|
||||
ActLogout string
|
||||
ActShadowId string
|
||||
ActPublishId string
|
||||
ActSetDownLimit string
|
||||
ActAddLocalFiles string
|
||||
// resource id
|
||||
AllUsers string
|
||||
// opIds
|
||||
OpIdIpVisit int16
|
||||
OpIdUpload int16
|
||||
OpIdDownload int16
|
||||
OpIdLogin int16
|
||||
OpIdGetFInfo int16
|
||||
OpIdDelFInfo int16
|
||||
OpIdOpFInfo int16
|
||||
// local
|
||||
PathLocal string
|
||||
PathLogin string
|
||||
PathDownloadLogin string
|
||||
PathDownload string
|
||||
PathUpload string
|
||||
PathStartUpload string
|
||||
PathFinishUpload string
|
||||
PathFileInfo string
|
||||
PathClient string
|
||||
// rate Limiter
|
||||
LimiterCap int64
|
||||
LimiterTtl int32
|
||||
LimiterCyc int32
|
||||
BucketCap int16
|
||||
SpecialCapsStr map[string]int16
|
||||
SpecialCaps map[int16]int16
|
||||
}
|
||||
|
||||
func NewConfig() *Config {
|
||||
config := &Config{
|
||||
// secrets
|
||||
AppName: "qs",
|
||||
AdminId: "admin",
|
||||
AdminPwd: "qs",
|
||||
SecretKey: "qs",
|
||||
SecretKeyByte: []byte("qs"),
|
||||
// server
|
||||
Production: true,
|
||||
HostName: "localhost",
|
||||
Port: 8888,
|
||||
// performance
|
||||
MaxUpBytesPerSec: 500 * 1000,
|
||||
MaxDownBytesPerSec: 500 * 1000,
|
||||
MaxRangeLength: 10 * 1024 * 1024,
|
||||
Timeout: 500, // millisecond,
|
||||
ReadTimeout: 500,
|
||||
WriteTimeout: 43200000,
|
||||
IdleTimeout: 10000,
|
||||
WorkerPoolSize: 2,
|
||||
TaskQueueSize: 2,
|
||||
QueueSize: 2,
|
||||
ParseFormBufSize: 600,
|
||||
MaxHeaderBytes: 1 << 15, // 32KB
|
||||
DownLimit: -1,
|
||||
MaxShares: 1 << 31,
|
||||
LocalFileLimit: -1,
|
||||
// Cookie
|
||||
CookieDomain: "",
|
||||
CookieHttpOnly: false,
|
||||
CookieMaxAge: 3600 * 24 * 30, // one week,
|
||||
CookiePath: "/",
|
||||
CookieSecure: false,
|
||||
// keys
|
||||
KeyAdminId: "adminid",
|
||||
KeyAdminPwd: "adminpwd",
|
||||
KeyToken: "token",
|
||||
KeyFileName: "fname",
|
||||
KeyFileSize: "size",
|
||||
KeyShareId: "shareid",
|
||||
KeyStart: "start",
|
||||
KeyLen: "len",
|
||||
KeyChunk: "chunk",
|
||||
KeyAct: "act",
|
||||
KeyExpires: "expires",
|
||||
KeyDownLimit: "downlimit",
|
||||
ActStartUpload: "startupload",
|
||||
ActUpload: "upload",
|
||||
ActFinishUpload: "finishupload",
|
||||
ActLogin: "login",
|
||||
ActLogout: "logout",
|
||||
ActShadowId: "shadowid",
|
||||
ActPublishId: "publishid",
|
||||
ActSetDownLimit: "setdownlimit",
|
||||
ActAddLocalFiles: "addlocalfiles",
|
||||
AllUsers: "allusers",
|
||||
// opIds
|
||||
OpIdIpVisit: 0,
|
||||
OpIdUpload: 1,
|
||||
OpIdDownload: 2,
|
||||
OpIdLogin: 3,
|
||||
OpIdGetFInfo: 4,
|
||||
OpIdDelFInfo: 5,
|
||||
OpIdOpFInfo: 6,
|
||||
// local
|
||||
PathLocal: "files",
|
||||
PathLogin: "/login",
|
||||
PathDownloadLogin: "/download-login",
|
||||
PathDownload: "/download",
|
||||
PathUpload: "/upload",
|
||||
PathStartUpload: "/startupload",
|
||||
PathFinishUpload: "/finishupload",
|
||||
PathFileInfo: "/fileinfo",
|
||||
PathClient: "/",
|
||||
// rate Limiter
|
||||
LimiterCap: 256, // how many op supported for each user
|
||||
LimiterTtl: 3600, // second
|
||||
LimiterCyc: 1, // second
|
||||
BucketCap: 3, // how many op can do per LimiterCyc sec
|
||||
SpecialCaps: map[int16]int16{
|
||||
0: 5, // ip
|
||||
1: 1, // upload
|
||||
2: 1, // download
|
||||
3: 1, // login
|
||||
},
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func NewConfigFrom(path string) *Config {
|
||||
configBytes, readErr := ioutil.ReadFile(path)
|
||||
if readErr != nil {
|
||||
panic(fmt.Sprintf("config file not found: %s", path))
|
||||
}
|
||||
|
||||
config := &Config{}
|
||||
marshalErr := json.Unmarshal(configBytes, config)
|
||||
|
||||
// TODO: look for a better solution
|
||||
config.SpecialCaps = make(map[int16]int16)
|
||||
for strKey, value := range config.SpecialCapsStr {
|
||||
key, parseKeyErr := strconv.ParseInt(strKey, 10, 16)
|
||||
if parseKeyErr != nil {
|
||||
panic("fail to parse SpecialCapsStr, its type should be map[int16]int16")
|
||||
}
|
||||
config.SpecialCaps[int16(key)] = value
|
||||
}
|
||||
|
||||
if marshalErr != nil {
|
||||
panic("config file format is incorrect")
|
||||
}
|
||||
|
||||
config.SecretKeyByte = []byte(config.SecretKey)
|
||||
if config.HostName == "" {
|
||||
hostName, err := GetLocalAddr()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
config.HostName = hostName.String()
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func GetLocalAddr() (net.IP, error) {
|
||||
fmt.Println(`config.HostName is empty(""), choose one IP for listening automatically.`)
|
||||
infs, err := net.Interfaces()
|
||||
if err != nil {
|
||||
panic("fail to get net interfaces")
|
||||
}
|
||||
|
||||
for _, inf := range infs {
|
||||
if inf.Flags&4 != 4 && !strings.Contains(inf.Name, "docker") {
|
||||
addrs, err := inf.Addrs()
|
||||
if err != nil {
|
||||
panic("fail to get addrs of interface")
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
switch v := addr.(type) {
|
||||
case *net.IPAddr:
|
||||
if !strings.Contains(v.IP.String(), ":") {
|
||||
return v.IP, nil
|
||||
}
|
||||
case *net.IPNet:
|
||||
if !strings.Contains(v.IP.String(), ":") {
|
||||
return v.IP, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("no addr found")
|
||||
}
|
17
server/libs/encrypt/encrypter_hmac.go
Normal file
17
server/libs/encrypt/encrypter_hmac.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package encrypt
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
type HmacEncryptor struct {
|
||||
Key []byte
|
||||
}
|
||||
|
||||
func (encryptor *HmacEncryptor) Encrypt(content []byte) string {
|
||||
mac := hmac.New(sha256.New, encryptor.Key)
|
||||
mac.Write(content)
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
5
server/libs/encrypt/encryptor.go
Normal file
5
server/libs/encrypt/encryptor.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package encrypt
|
||||
|
||||
type Encryptor interface {
|
||||
Encrypt(content []byte) string
|
||||
}
|
53
server/libs/encrypt/jwt.go
Normal file
53
server/libs/encrypt/jwt.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package encrypt
|
||||
|
||||
import (
|
||||
"github.com/robbert229/jwt"
|
||||
)
|
||||
|
||||
func JwtEncrypterMaker(secret string) TokenEncrypter {
|
||||
return &JwtEncrypter{
|
||||
alg: jwt.HmacSha256(secret),
|
||||
claims: jwt.NewClaim(),
|
||||
}
|
||||
}
|
||||
|
||||
type JwtEncrypter struct {
|
||||
alg jwt.Algorithm
|
||||
claims *jwt.Claims
|
||||
}
|
||||
|
||||
func (encrypter *JwtEncrypter) Add(key string, value string) bool {
|
||||
encrypter.claims.Set(key, value)
|
||||
return true
|
||||
}
|
||||
|
||||
func (encrypter *JwtEncrypter) FromStr(token string) bool {
|
||||
claims, err := encrypter.alg.Decode(token)
|
||||
// TODO: should return error or error info will lost
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
encrypter.claims = claims
|
||||
return true
|
||||
}
|
||||
|
||||
func (encrypter *JwtEncrypter) Get(key string) (string, bool) {
|
||||
iValue, err := encrypter.claims.Get(key)
|
||||
// TODO: should return error or error info will lost
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return iValue.(string), true
|
||||
}
|
||||
|
||||
func (encrypter *JwtEncrypter) ToStr() (string, bool) {
|
||||
token, err := encrypter.alg.Encode(encrypter.claims)
|
||||
|
||||
// TODO: should return error or error info will lost
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return token, true
|
||||
}
|
11
server/libs/encrypt/token_encrypter.go
Normal file
11
server/libs/encrypt/token_encrypter.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package encrypt
|
||||
|
||||
type EncrypterMaker func(string) TokenEncrypter
|
||||
|
||||
// TODO: name should be Encrypter?
|
||||
type TokenEncrypter interface {
|
||||
Add(string, string) bool
|
||||
FromStr(string) bool
|
||||
Get(string) (string, bool)
|
||||
ToStr() (string, bool)
|
||||
}
|
59
server/libs/errutil/ettutil.go
Normal file
59
server/libs/errutil/ettutil.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package errutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/logutil"
|
||||
)
|
||||
|
||||
type ErrUtil interface {
|
||||
IsErr(err error) bool
|
||||
IsFatalErr(err error) bool
|
||||
RecoverPanic()
|
||||
}
|
||||
|
||||
func NewErrChecker(logStack bool, logger logutil.LogUtil) ErrUtil {
|
||||
return &ErrChecker{logStack: logStack, log: logger}
|
||||
}
|
||||
|
||||
type ErrChecker struct {
|
||||
log logutil.LogUtil
|
||||
logStack bool
|
||||
}
|
||||
|
||||
// IsErr checks if error occurs
|
||||
func (e *ErrChecker) IsErr(err error) bool {
|
||||
if err != nil {
|
||||
e.log.Printf("Error:%q\n", err)
|
||||
if e.logStack {
|
||||
e.log.Println(debug.Stack())
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsFatalPanic should be used with defer
|
||||
func (e *ErrChecker) IsFatalErr(fe error) bool {
|
||||
if fe != nil {
|
||||
e.log.Printf("Panic:%q", fe)
|
||||
if e.logStack {
|
||||
e.log.Println(debug.Stack())
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RecoverPanic catchs the panic and logs panic information
|
||||
func (e *ErrChecker) RecoverPanic() {
|
||||
if r := recover(); r != nil {
|
||||
e.log.Printf("Recovered:%v", r)
|
||||
if e.logStack {
|
||||
e.log.Println(debug.Stack())
|
||||
}
|
||||
}
|
||||
}
|
177
server/libs/fileidx/file_idx.go
Normal file
177
server/libs/fileidx/file_idx.go
Normal file
|
@ -0,0 +1,177 @@
|
|||
package fileidx
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// StateStarted = after startUpload before upload
|
||||
StateStarted = "started"
|
||||
// StateUploading =after upload before finishUpload
|
||||
StateUploading = "uploading"
|
||||
// StateDone = after finishedUpload
|
||||
StateDone = "done"
|
||||
)
|
||||
|
||||
type FileInfo struct {
|
||||
Id string
|
||||
DownLimit int
|
||||
ModTime int64
|
||||
PathLocal string
|
||||
State string
|
||||
Uploaded int64
|
||||
}
|
||||
|
||||
type FileIndex interface {
|
||||
Add(fileInfo *FileInfo) int
|
||||
Del(id string)
|
||||
SetId(id string, newId string) bool
|
||||
SetDownLimit(id string, downLimit int) bool
|
||||
DecrDownLimit(id string) (int, bool)
|
||||
SetState(id string, state string) bool
|
||||
IncrUploaded(id string, uploaded int64) int64
|
||||
Get(id string) (*FileInfo, bool)
|
||||
List() map[string]*FileInfo
|
||||
}
|
||||
|
||||
func NewMemFileIndex(cap int) *MemFileIndex {
|
||||
return &MemFileIndex{
|
||||
cap: cap,
|
||||
infos: make(map[string]*FileInfo, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func NewMemFileIndexWithMap(cap int, infos map[string]*FileInfo) *MemFileIndex {
|
||||
return &MemFileIndex{
|
||||
cap: cap,
|
||||
infos: infos,
|
||||
}
|
||||
}
|
||||
|
||||
type MemFileIndex struct {
|
||||
cap int
|
||||
infos map[string]*FileInfo
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
func (idx *MemFileIndex) Add(fileInfo *FileInfo) int {
|
||||
idx.mux.Lock()
|
||||
defer idx.mux.Unlock()
|
||||
|
||||
if len(idx.infos) >= idx.cap {
|
||||
return 1
|
||||
}
|
||||
|
||||
if _, found := idx.infos[fileInfo.Id]; found {
|
||||
return -1
|
||||
}
|
||||
|
||||
idx.infos[fileInfo.Id] = fileInfo
|
||||
return 0
|
||||
}
|
||||
|
||||
func (idx *MemFileIndex) Del(id string) {
|
||||
idx.mux.Lock()
|
||||
defer idx.mux.Unlock()
|
||||
|
||||
delete(idx.infos, id)
|
||||
}
|
||||
|
||||
func (idx *MemFileIndex) SetId(id string, newId string) bool {
|
||||
if id == newId {
|
||||
return true
|
||||
}
|
||||
|
||||
idx.mux.Lock()
|
||||
defer idx.mux.Unlock()
|
||||
|
||||
info, found := idx.infos[id]
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, foundNewId := idx.infos[newId]; foundNewId {
|
||||
return false
|
||||
}
|
||||
|
||||
idx.infos[newId] = info
|
||||
idx.infos[newId].Id = newId
|
||||
delete(idx.infos, id)
|
||||
return true
|
||||
}
|
||||
|
||||
func (idx *MemFileIndex) SetDownLimit(id string, downLimit int) bool {
|
||||
idx.mux.Lock()
|
||||
defer idx.mux.Unlock()
|
||||
|
||||
info, found := idx.infos[id]
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
|
||||
info.DownLimit = downLimit
|
||||
return true
|
||||
}
|
||||
|
||||
func (idx *MemFileIndex) DecrDownLimit(id string) (int, bool) {
|
||||
idx.mux.Lock()
|
||||
defer idx.mux.Unlock()
|
||||
|
||||
info, found := idx.infos[id]
|
||||
if !found || info.State != StateDone {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if info.DownLimit == 0 {
|
||||
return 1, false
|
||||
}
|
||||
|
||||
if info.DownLimit > 0 {
|
||||
// info.DownLimit means unlimited
|
||||
info.DownLimit = info.DownLimit - 1
|
||||
}
|
||||
return 1, true
|
||||
}
|
||||
|
||||
func (idx *MemFileIndex) SetState(id string, state string) bool {
|
||||
idx.mux.Lock()
|
||||
defer idx.mux.Unlock()
|
||||
|
||||
info, found := idx.infos[id]
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
|
||||
info.State = state
|
||||
return true
|
||||
}
|
||||
|
||||
func (idx *MemFileIndex) IncrUploaded(id string, uploaded int64) int64 {
|
||||
idx.mux.Lock()
|
||||
defer idx.mux.Unlock()
|
||||
|
||||
info, found := idx.infos[id]
|
||||
if !found {
|
||||
return 0
|
||||
}
|
||||
|
||||
info.Uploaded = info.Uploaded + uploaded
|
||||
return info.Uploaded
|
||||
}
|
||||
|
||||
func (idx *MemFileIndex) Get(id string) (*FileInfo, bool) {
|
||||
idx.mux.RLock()
|
||||
defer idx.mux.RUnlock()
|
||||
|
||||
infos, found := idx.infos[id]
|
||||
return infos, found
|
||||
}
|
||||
|
||||
func (idx *MemFileIndex) List() map[string]*FileInfo {
|
||||
idx.mux.RLock()
|
||||
defer idx.mux.RUnlock()
|
||||
|
||||
return idx.infos
|
||||
}
|
||||
|
||||
// TODO: add unit tests
|
118
server/libs/fsutil/fsutil.go
Normal file
118
server/libs/fsutil/fsutil.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
package fsutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/errutil"
|
||||
"quickshare/server/libs/fileidx"
|
||||
"quickshare/server/libs/qtube"
|
||||
)
|
||||
|
||||
type FsUtil interface {
|
||||
CreateFile(fullPath string) error
|
||||
CopyChunkN(fullPath string, chunk io.Reader, start int64, length int64) bool
|
||||
DelFile(fullPath string) bool
|
||||
Open(fullPath string) (qtube.ReadSeekCloser, error)
|
||||
MkdirAll(path string, mode os.FileMode) bool
|
||||
Readdir(dirName string, n int) ([]*fileidx.FileInfo, error)
|
||||
}
|
||||
|
||||
func NewSimpleFs(errUtil errutil.ErrUtil) FsUtil {
|
||||
return &SimpleFs{
|
||||
Err: errUtil,
|
||||
}
|
||||
}
|
||||
|
||||
type SimpleFs struct {
|
||||
Err errutil.ErrUtil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrExists = errors.New("file exists")
|
||||
ErrUnknown = errors.New("unknown error")
|
||||
)
|
||||
|
||||
func (sfs *SimpleFs) CreateFile(fullPath string) error {
|
||||
flag := os.O_CREATE | os.O_EXCL | os.O_RDONLY
|
||||
perm := os.FileMode(0644)
|
||||
newFile, err := os.OpenFile(fullPath, flag, perm)
|
||||
defer newFile.Close()
|
||||
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if os.IsExist(err) {
|
||||
return ErrExists
|
||||
} else {
|
||||
return ErrUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func (sfs *SimpleFs) CopyChunkN(fullPath string, chunk io.Reader, start int64, length int64) bool {
|
||||
flag := os.O_WRONLY
|
||||
perm := os.FileMode(0644)
|
||||
file, openErr := os.OpenFile(fullPath, flag, perm)
|
||||
|
||||
defer file.Close()
|
||||
if sfs.Err.IsErr(openErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, err := file.Seek(start, io.SeekStart); sfs.Err.IsErr(err) {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, err := io.CopyN(file, chunk, length); sfs.Err.IsErr(err) && err != io.EOF {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (sfs *SimpleFs) DelFile(fullPath string) bool {
|
||||
return !sfs.Err.IsErr(os.Remove(fullPath))
|
||||
}
|
||||
|
||||
func (sfs *SimpleFs) MkdirAll(path string, mode os.FileMode) bool {
|
||||
err := os.MkdirAll(path, mode)
|
||||
return !sfs.Err.IsErr(err)
|
||||
}
|
||||
|
||||
// TODO: not support read from last seek position
|
||||
func (sfs *SimpleFs) Readdir(dirName string, n int) ([]*fileidx.FileInfo, error) {
|
||||
dir, openErr := os.Open(dirName)
|
||||
defer dir.Close()
|
||||
|
||||
if sfs.Err.IsErr(openErr) {
|
||||
return []*fileidx.FileInfo{}, openErr
|
||||
}
|
||||
|
||||
osFileInfos, readErr := dir.Readdir(n)
|
||||
if sfs.Err.IsErr(readErr) && readErr != io.EOF {
|
||||
return []*fileidx.FileInfo{}, readErr
|
||||
}
|
||||
|
||||
fileInfos := make([]*fileidx.FileInfo, 0)
|
||||
for _, osFileInfo := range osFileInfos {
|
||||
if osFileInfo.Mode().IsRegular() {
|
||||
fileInfos = append(
|
||||
fileInfos,
|
||||
&fileidx.FileInfo{
|
||||
ModTime: osFileInfo.ModTime().UnixNano(),
|
||||
PathLocal: osFileInfo.Name(),
|
||||
Uploaded: osFileInfo.Size(),
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return fileInfos, readErr
|
||||
}
|
||||
|
||||
// the associated file descriptor has mode O_RDONLY as using os.Open
|
||||
func (sfs *SimpleFs) Open(fullPath string) (qtube.ReadSeekCloser, error) {
|
||||
return os.Open(fullPath)
|
||||
}
|
84
server/libs/httputil/httputil.go
Normal file
84
server/libs/httputil/httputil.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/errutil"
|
||||
)
|
||||
|
||||
type MsgRes struct {
|
||||
Code int
|
||||
Msg string
|
||||
}
|
||||
|
||||
var (
|
||||
Err400 = MsgRes{Code: http.StatusBadRequest, Msg: "Bad Request"}
|
||||
Err401 = MsgRes{Code: http.StatusUnauthorized, Msg: "Unauthorized"}
|
||||
Err404 = MsgRes{Code: http.StatusNotFound, Msg: "Not Found"}
|
||||
Err412 = MsgRes{Code: http.StatusPreconditionFailed, Msg: "Precondition Failed"}
|
||||
Err429 = MsgRes{Code: http.StatusTooManyRequests, Msg: "Too Many Requests"}
|
||||
Err500 = MsgRes{Code: http.StatusInternalServerError, Msg: "Internal Server Error"}
|
||||
Err503 = MsgRes{Code: http.StatusServiceUnavailable, Msg: "Service Unavailable"}
|
||||
Err504 = MsgRes{Code: http.StatusGatewayTimeout, Msg: "Gateway Timeout"}
|
||||
Ok200 = MsgRes{Code: http.StatusOK, Msg: "OK"}
|
||||
)
|
||||
|
||||
type HttpUtil interface {
|
||||
GetCookie(cookies []*http.Cookie, key string) string
|
||||
SetCookie(res http.ResponseWriter, key string, val string)
|
||||
Fill(msg interface{}, res http.ResponseWriter) int
|
||||
}
|
||||
|
||||
type QHttpUtil struct {
|
||||
CookieDomain string
|
||||
CookieHttpOnly bool
|
||||
CookieMaxAge int
|
||||
CookiePath string
|
||||
CookieSecure bool
|
||||
Err errutil.ErrUtil
|
||||
}
|
||||
|
||||
func (q *QHttpUtil) GetCookie(cookies []*http.Cookie, key string) string {
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == key {
|
||||
return cookie.Value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (q *QHttpUtil) SetCookie(res http.ResponseWriter, key string, val string) {
|
||||
cookie := http.Cookie{
|
||||
Name: key,
|
||||
Value: val,
|
||||
Domain: q.CookieDomain,
|
||||
Expires: time.Now().Add(time.Duration(q.CookieMaxAge) * time.Second),
|
||||
HttpOnly: q.CookieHttpOnly,
|
||||
MaxAge: q.CookieMaxAge,
|
||||
Secure: q.CookieSecure,
|
||||
Path: q.CookiePath,
|
||||
}
|
||||
|
||||
res.Header().Set("Set-Cookie", cookie.String())
|
||||
}
|
||||
|
||||
func (q *QHttpUtil) Fill(msg interface{}, res http.ResponseWriter) int {
|
||||
if msg == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
msgBytes, marsErr := json.Marshal(msg)
|
||||
if q.Err.IsErr(marsErr) {
|
||||
return 0
|
||||
}
|
||||
|
||||
wrote, writeErr := res.Write(msgBytes)
|
||||
if q.Err.IsErr(writeErr) {
|
||||
return 0
|
||||
}
|
||||
return wrote
|
||||
}
|
130
server/libs/httpworker/worker.go
Normal file
130
server/libs/httpworker/worker.go
Normal file
|
@ -0,0 +1,130 @@
|
|||
package httpworker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/logutil"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrWorkerNotFound = errors.New("worker not found")
|
||||
ErrTimeout = errors.New("timeout")
|
||||
)
|
||||
|
||||
type DoFunc func(http.ResponseWriter, *http.Request)
|
||||
|
||||
type Task struct {
|
||||
Ack chan error
|
||||
Do DoFunc
|
||||
Res http.ResponseWriter
|
||||
Req *http.Request
|
||||
}
|
||||
|
||||
type Workers interface {
|
||||
Put(*Task) bool
|
||||
IsInTime(ack chan error, msec time.Duration) error
|
||||
}
|
||||
|
||||
type WorkerPool struct {
|
||||
queue chan *Task
|
||||
size int
|
||||
workers []*Worker
|
||||
log logutil.LogUtil // TODO: should not pass log here
|
||||
}
|
||||
|
||||
func NewWorkerPool(poolSize int, queueSize int, log logutil.LogUtil) Workers {
|
||||
queue := make(chan *Task, queueSize)
|
||||
workers := make([]*Worker, 0, poolSize)
|
||||
|
||||
for i := 0; i < poolSize; i++ {
|
||||
worker := &Worker{
|
||||
Id: uint64(i),
|
||||
queue: queue,
|
||||
log: log,
|
||||
}
|
||||
|
||||
go worker.Start()
|
||||
workers = append(workers, worker)
|
||||
}
|
||||
|
||||
return &WorkerPool{
|
||||
queue: queue,
|
||||
size: poolSize,
|
||||
workers: workers,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (pool *WorkerPool) Put(task *Task) bool {
|
||||
if len(pool.queue) >= pool.size {
|
||||
return false
|
||||
}
|
||||
|
||||
pool.queue <- task
|
||||
return true
|
||||
}
|
||||
|
||||
func (pool *WorkerPool) IsInTime(ack chan error, msec time.Duration) error {
|
||||
start := time.Now().UnixNano()
|
||||
timeout := make(chan error)
|
||||
|
||||
go func() {
|
||||
time.Sleep(msec)
|
||||
timeout <- ErrTimeout
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ack:
|
||||
if err == nil {
|
||||
pool.log.Printf(
|
||||
"finish cost: %d usec",
|
||||
(time.Now().UnixNano()-start)/1000,
|
||||
)
|
||||
} else {
|
||||
pool.log.Printf(
|
||||
"finish with error cost: %d usec",
|
||||
(time.Now().UnixNano()-start)/1000,
|
||||
)
|
||||
}
|
||||
return err
|
||||
case errTimeout := <-timeout:
|
||||
pool.log.Printf("timeout cost: %d usec", (time.Now().UnixNano()-start)/1000)
|
||||
return errTimeout
|
||||
}
|
||||
}
|
||||
|
||||
type Worker struct {
|
||||
Id uint64
|
||||
queue chan *Task
|
||||
log logutil.LogUtil
|
||||
}
|
||||
|
||||
func (worker *Worker) RecoverPanic() {
|
||||
if r := recover(); r != nil {
|
||||
worker.log.Printf("Recovered:%v stack: %v", r, debug.Stack())
|
||||
// restart worker and IsInTime will return timeout error for last task
|
||||
worker.Start()
|
||||
}
|
||||
}
|
||||
|
||||
func (worker *Worker) Start() {
|
||||
defer worker.RecoverPanic()
|
||||
|
||||
for {
|
||||
task := <-worker.queue
|
||||
if task.Do != nil {
|
||||
task.Do(task.Res, task.Req)
|
||||
task.Ack <- nil
|
||||
} else {
|
||||
task.Ack <- ErrWorkerNotFound
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ServiceFunc lets you return struct directly
|
||||
type ServiceFunc func(http.ResponseWriter, *http.Request) interface{}
|
5
server/libs/limiter/limiter.go
Normal file
5
server/libs/limiter/limiter.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package limiter
|
||||
|
||||
type Limiter interface {
|
||||
Access(string, int16) bool
|
||||
}
|
220
server/libs/limiter/rate_limiter.go
Normal file
220
server/libs/limiter/rate_limiter.go
Normal file
|
@ -0,0 +1,220 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func now() int32 {
|
||||
return int32(time.Now().Unix())
|
||||
}
|
||||
|
||||
func afterCyc(cyc int32) int32 {
|
||||
return int32(time.Now().Unix()) + cyc
|
||||
}
|
||||
|
||||
func afterTtl(ttl int32) int32 {
|
||||
return int32(time.Now().Unix()) + ttl
|
||||
}
|
||||
|
||||
type Bucket struct {
|
||||
Refresh int32
|
||||
Tokens int16
|
||||
}
|
||||
|
||||
func NewBucket(cyc int32, cap int16) *Bucket {
|
||||
return &Bucket{
|
||||
Refresh: afterCyc(cyc),
|
||||
Tokens: cap,
|
||||
}
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Expired int32
|
||||
Buckets map[int16]*Bucket
|
||||
}
|
||||
|
||||
func NewItem(ttl int32) *Item {
|
||||
return &Item{
|
||||
Expired: afterTtl(ttl),
|
||||
Buckets: make(map[int16]*Bucket),
|
||||
}
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
items map[string]*Item
|
||||
bucketCap int16
|
||||
customCaps map[int16]int16
|
||||
cap int64
|
||||
cyc int32 // how much time, item autoclean will be executed, bucket will be refreshed
|
||||
ttl int32 // how much time, item will be expired(but not cleaned)
|
||||
mux sync.RWMutex
|
||||
snapshot map[string]map[int16]*Bucket
|
||||
}
|
||||
|
||||
func NewRateLimiter(cap int64, ttl int32, cyc int32, bucketCap int16, customCaps map[int16]int16) Limiter {
|
||||
if cap < 1 || ttl < 1 || cyc < 1 || bucketCap < 1 {
|
||||
panic("cap | bucketCap | ttl | cycle cant be less than 1")
|
||||
}
|
||||
|
||||
limiter := &RateLimiter{
|
||||
items: make(map[string]*Item, cap),
|
||||
bucketCap: bucketCap,
|
||||
customCaps: customCaps,
|
||||
cap: cap,
|
||||
ttl: ttl,
|
||||
cyc: cyc,
|
||||
}
|
||||
|
||||
go limiter.autoClean()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) getBucketCap(opId int16) int16 {
|
||||
bucketCap, existed := limiter.customCaps[opId]
|
||||
if !existed {
|
||||
return limiter.bucketCap
|
||||
}
|
||||
return bucketCap
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) Access(itemId string, opId int16) bool {
|
||||
limiter.mux.Lock()
|
||||
defer limiter.mux.Unlock()
|
||||
|
||||
item, itemExisted := limiter.items[itemId]
|
||||
if !itemExisted {
|
||||
if int64(len(limiter.items)) >= limiter.cap {
|
||||
return false
|
||||
}
|
||||
|
||||
limiter.items[itemId] = NewItem(limiter.ttl)
|
||||
limiter.items[itemId].Buckets[opId] = NewBucket(limiter.cyc, limiter.getBucketCap(opId)-1)
|
||||
return true
|
||||
}
|
||||
|
||||
bucket, bucketExisted := item.Buckets[opId]
|
||||
if !bucketExisted {
|
||||
item.Buckets[opId] = NewBucket(limiter.cyc, limiter.getBucketCap(opId)-1)
|
||||
return true
|
||||
}
|
||||
|
||||
if bucket.Refresh > now() {
|
||||
if bucket.Tokens > 0 {
|
||||
bucket.Tokens--
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
bucket.Refresh = afterCyc(limiter.cyc)
|
||||
bucket.Tokens = limiter.getBucketCap(opId) - 1
|
||||
return true
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) GetCap() int64 {
|
||||
return limiter.cap
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) GetSize() int64 {
|
||||
limiter.mux.RLock()
|
||||
defer limiter.mux.RUnlock()
|
||||
return int64(len(limiter.items))
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) ExpandCap(cap int64) bool {
|
||||
limiter.mux.RLock()
|
||||
defer limiter.mux.RUnlock()
|
||||
|
||||
if cap <= int64(len(limiter.items)) {
|
||||
return false
|
||||
}
|
||||
|
||||
limiter.cap = cap
|
||||
return true
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) GetTTL() int32 {
|
||||
return limiter.ttl
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) UpdateTTL(ttl int32) bool {
|
||||
if ttl < 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
limiter.ttl = ttl
|
||||
return true
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) GetCyc() int32 {
|
||||
return limiter.cyc
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) UpdateCyc(cyc int32) bool {
|
||||
if limiter.cyc < 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
limiter.cyc = cyc
|
||||
return true
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) Snapshot() map[string]map[int16]*Bucket {
|
||||
return limiter.snapshot
|
||||
}
|
||||
|
||||
func (limiter *RateLimiter) autoClean() {
|
||||
for {
|
||||
if limiter.cyc == 0 {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Duration(int64(limiter.cyc) * 1000000000))
|
||||
limiter.clean()
|
||||
}
|
||||
}
|
||||
|
||||
// clean may add affect other operations, do frequently?
|
||||
func (limiter *RateLimiter) clean() {
|
||||
limiter.snapshot = make(map[string]map[int16]*Bucket)
|
||||
now := now()
|
||||
|
||||
limiter.mux.RLock()
|
||||
defer limiter.mux.RUnlock()
|
||||
for key, item := range limiter.items {
|
||||
if item.Expired <= now {
|
||||
delete(limiter.items, key)
|
||||
} else {
|
||||
limiter.snapshot[key] = item.Buckets
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only for test
|
||||
func (limiter *RateLimiter) exist(id string) bool {
|
||||
limiter.mux.RLock()
|
||||
defer limiter.mux.RUnlock()
|
||||
|
||||
_, existed := limiter.items[id]
|
||||
return existed
|
||||
}
|
||||
|
||||
// Only for test
|
||||
func (limiter *RateLimiter) truncate() {
|
||||
limiter.mux.RLock()
|
||||
defer limiter.mux.RUnlock()
|
||||
|
||||
for key, _ := range limiter.items {
|
||||
delete(limiter.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Only for test
|
||||
func (limiter *RateLimiter) get(id string) (*Item, bool) {
|
||||
limiter.mux.RLock()
|
||||
defer limiter.mux.RUnlock()
|
||||
|
||||
item, existed := limiter.items[id]
|
||||
return item, existed
|
||||
}
|
161
server/libs/limiter/rate_limiter_test.go
Normal file
161
server/libs/limiter/rate_limiter_test.go
Normal file
|
@ -0,0 +1,161 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
const rndCap = 10000
|
||||
const addCap = 1
|
||||
|
||||
// how to set time
|
||||
// extend: wait can be greater than ttl/2
|
||||
// cyc is smaller than ttl and wait, then it can be clean in time
|
||||
const cap = 40
|
||||
const ttl = 3
|
||||
const cyc = 1
|
||||
const bucketCap = 2
|
||||
const id1 = "id1"
|
||||
const id2 = "id2"
|
||||
const op1 int16 = 0
|
||||
const op2 int16 = 1
|
||||
|
||||
var customCaps = map[int16]int16{
|
||||
op2: 1000,
|
||||
}
|
||||
|
||||
const wait = 1
|
||||
|
||||
var limiter = NewRateLimiter(cap, ttl, cyc, bucketCap, customCaps).(*RateLimiter)
|
||||
|
||||
func printItem(id string) {
|
||||
item, existed := limiter.get(id1)
|
||||
if existed {
|
||||
fmt.Println("expired, now, existed", item.Expired, now(), existed)
|
||||
for id, bucket := range item.Buckets {
|
||||
fmt.Println("\tid, bucket", id, bucket)
|
||||
}
|
||||
} else {
|
||||
fmt.Println("not existed")
|
||||
}
|
||||
}
|
||||
|
||||
var idSeed = 0
|
||||
|
||||
func randId() string {
|
||||
idSeed++
|
||||
return fmt.Sprintf("%d", idSeed)
|
||||
}
|
||||
|
||||
func TestAccess(t *testing.T) {
|
||||
func(t *testing.T) {
|
||||
canAccess := limiter.Access(id1, op1)
|
||||
if !canAccess {
|
||||
t.Fatal("access: fail")
|
||||
}
|
||||
|
||||
for i := 0; i < bucketCap; i++ {
|
||||
canAccess = limiter.Access(id1, op1)
|
||||
}
|
||||
|
||||
if canAccess {
|
||||
t.Fatal("access: fail to deny access")
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(limiter.GetCyc()) * time.Second)
|
||||
|
||||
canAccess = limiter.Access(id1, op1)
|
||||
if !canAccess {
|
||||
t.Fatal("access: fail to refresh tokens")
|
||||
}
|
||||
}(t)
|
||||
}
|
||||
|
||||
func TestCap(t *testing.T) {
|
||||
originalCap := limiter.GetCap()
|
||||
fmt.Printf("cap:info: %d\n", originalCap)
|
||||
|
||||
ok := limiter.ExpandCap(originalCap + addCap)
|
||||
|
||||
if !ok || limiter.GetCap() != originalCap+addCap {
|
||||
t.Fatal("cap: fail to expand")
|
||||
}
|
||||
|
||||
ok = limiter.ExpandCap(limiter.GetSize() - addCap)
|
||||
if ok {
|
||||
t.Fatal("cap: shrink cap")
|
||||
}
|
||||
|
||||
ids := []string{}
|
||||
for limiter.GetSize() < limiter.GetCap() {
|
||||
id := randId()
|
||||
ids = append(ids, id)
|
||||
|
||||
ok := limiter.Access(id, 0)
|
||||
if !ok {
|
||||
t.Fatal("cap: not full")
|
||||
}
|
||||
}
|
||||
|
||||
if limiter.GetSize() != limiter.GetCap() {
|
||||
t.Fatal("cap: incorrect size")
|
||||
}
|
||||
|
||||
if limiter.Access(randId(), 0) {
|
||||
t.Fatal("cap: more than cap")
|
||||
}
|
||||
|
||||
limiter.truncate()
|
||||
}
|
||||
|
||||
func TestTtl(t *testing.T) {
|
||||
var addTtl int32 = 1
|
||||
originalTTL := limiter.GetTTL()
|
||||
fmt.Printf("ttl:info: %d\n", originalTTL)
|
||||
|
||||
limiter.UpdateTTL(originalTTL + addTtl)
|
||||
if limiter.GetTTL() != originalTTL+addTtl {
|
||||
t.Fatal("ttl: update fail")
|
||||
}
|
||||
}
|
||||
|
||||
func cycTest(t *testing.T) {
|
||||
var addCyc int32 = 1
|
||||
originalCyc := limiter.GetCyc()
|
||||
fmt.Printf("cyc:info: %d\n", originalCyc)
|
||||
|
||||
limiter.UpdateCyc(originalCyc + addCyc)
|
||||
if limiter.GetCyc() != originalCyc+addCyc {
|
||||
t.Fatal("cyc: update fail")
|
||||
}
|
||||
}
|
||||
|
||||
func autoCleanTest(t *testing.T) {
|
||||
ids := []string{
|
||||
randId(),
|
||||
randId(),
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
ok := limiter.Access(id, 0)
|
||||
if ok {
|
||||
t.Fatal("autoClean: warning: add fail")
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(limiter.GetTTL()+wait) * time.Second)
|
||||
|
||||
for _, id := range ids {
|
||||
_, exist := limiter.get(id)
|
||||
if exist {
|
||||
t.Fatal("autoClean: item still exist")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// func snapshotTest(t *testing.T) {
|
||||
// }
|
7
server/libs/logutil/logutil.go
Normal file
7
server/libs/logutil/logutil.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
package logutil
|
||||
|
||||
type LogUtil interface {
|
||||
Print(v ...interface{})
|
||||
Printf(format string, v ...interface{})
|
||||
Println(v ...interface{})
|
||||
}
|
12
server/libs/logutil/slogger.go
Normal file
12
server/libs/logutil/slogger.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package logutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
)
|
||||
|
||||
func NewSlog(out io.Writer, prefix string) LogUtil {
|
||||
return log.New(out, prefix, log.Ldate|log.Ltime|log.Lshortfile)
|
||||
}
|
||||
|
||||
type Slog *log.Logger
|
13
server/libs/qtube/downloader.go
Normal file
13
server/libs/qtube/downloader.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package qtube
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/fileidx"
|
||||
)
|
||||
|
||||
type Downloader interface {
|
||||
ServeFile(res http.ResponseWriter, req *http.Request, fileInfo *fileidx.FileInfo) error
|
||||
}
|
280
server/libs/qtube/qtube.go
Normal file
280
server/libs/qtube/qtube.go
Normal file
|
@ -0,0 +1,280 @@
|
|||
package qtube
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/fileidx"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCopy = errors.New("ServeFile: copy error")
|
||||
ErrUnknown = errors.New("ServeFile: unknown error")
|
||||
)
|
||||
|
||||
type httpRange struct {
|
||||
start, length int64
|
||||
}
|
||||
|
||||
func (ra *httpRange) GetStart() int64 {
|
||||
return ra.start
|
||||
}
|
||||
func (ra *httpRange) GetLength() int64 {
|
||||
return ra.length
|
||||
}
|
||||
func (ra *httpRange) SetStart(start int64) {
|
||||
ra.start = start
|
||||
}
|
||||
func (ra *httpRange) SetLength(length int64) {
|
||||
ra.length = length
|
||||
}
|
||||
|
||||
func NewQTube(root string, copySpeed, maxRangeLen int64, filer FileReadSeekCloser) Downloader {
|
||||
return &QTube{
|
||||
Root: root,
|
||||
BytesPerSec: copySpeed,
|
||||
MaxRangeLen: maxRangeLen,
|
||||
Filer: filer,
|
||||
}
|
||||
}
|
||||
|
||||
type QTube struct {
|
||||
Root string
|
||||
BytesPerSec int64
|
||||
MaxRangeLen int64
|
||||
Filer FileReadSeekCloser
|
||||
}
|
||||
|
||||
type FileReadSeekCloser interface {
|
||||
Open(filePath string) (ReadSeekCloser, error)
|
||||
}
|
||||
|
||||
type ReadSeekCloser interface {
|
||||
io.Reader
|
||||
io.Seeker
|
||||
io.Closer
|
||||
}
|
||||
|
||||
const (
|
||||
ErrorInvalidRange = "ServeFile: invalid Range"
|
||||
ErrorInvalidSize = "ServeFile: invalid Range total size"
|
||||
)
|
||||
|
||||
func (tb *QTube) ServeFile(res http.ResponseWriter, req *http.Request, fileInfo *fileidx.FileInfo) error {
|
||||
headerRange := req.Header.Get("Range")
|
||||
|
||||
switch {
|
||||
case req.Method == http.MethodHead:
|
||||
res.Header().Set("Accept-Ranges", "bytes")
|
||||
res.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Uploaded))
|
||||
res.Header().Set("Content-Type", "application/octet-stream")
|
||||
res.WriteHeader(http.StatusOK)
|
||||
|
||||
return nil
|
||||
case headerRange == "":
|
||||
return tb.serveAll(res, fileInfo)
|
||||
default:
|
||||
return tb.serveRanges(res, headerRange, fileInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func (tb *QTube) serveAll(res http.ResponseWriter, fileInfo *fileidx.FileInfo) error {
|
||||
res.Header().Set("Accept-Ranges", "bytes")
|
||||
res.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filepath.Base(fileInfo.PathLocal)))
|
||||
res.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Uploaded))
|
||||
res.Header().Set("Content-Type", "application/octet-stream")
|
||||
res.Header().Set("Last-Modified", time.Unix(fileInfo.ModTime, 0).UTC().Format(http.TimeFormat))
|
||||
res.WriteHeader(http.StatusOK)
|
||||
|
||||
// TODO: need verify path
|
||||
file, openErr := tb.Filer.Open(filepath.Join(tb.Root, fileInfo.PathLocal))
|
||||
defer file.Close()
|
||||
if openErr != nil {
|
||||
return openErr
|
||||
}
|
||||
|
||||
copyErr := tb.throttledCopyN(res, file, fileInfo.Uploaded)
|
||||
if copyErr != nil && copyErr != io.EOF {
|
||||
return copyErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tb *QTube) serveRanges(res http.ResponseWriter, headerRange string, fileInfo *fileidx.FileInfo) error {
|
||||
ranges, rangeErr := getRanges(headerRange, fileInfo.Uploaded)
|
||||
if rangeErr != nil {
|
||||
http.Error(res, rangeErr.Error(), http.StatusRequestedRangeNotSatisfiable)
|
||||
return errors.New(rangeErr.Error())
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(ranges) == 1 || len(ranges) > 1:
|
||||
if tb.copyRange(res, ranges[0], fileInfo) != nil {
|
||||
return ErrCopy
|
||||
}
|
||||
default:
|
||||
// TODO: add support for multiple ranges
|
||||
return ErrUnknown
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRanges(headerRange string, size int64) ([]httpRange, error) {
|
||||
ranges, raParseErr := parseRange(headerRange, size)
|
||||
// TODO: check max number of ranges, range start end
|
||||
if len(ranges) <= 0 || raParseErr != nil {
|
||||
return nil, errors.New(ErrorInvalidRange)
|
||||
}
|
||||
if sumRangesSize(ranges) > size {
|
||||
return nil, errors.New(ErrorInvalidSize)
|
||||
}
|
||||
|
||||
return ranges, nil
|
||||
}
|
||||
|
||||
func (tb *QTube) copyRange(res http.ResponseWriter, ra httpRange, fileInfo *fileidx.FileInfo) error {
|
||||
// TODO: comfirm this wont cause problem
|
||||
if ra.GetLength() > tb.MaxRangeLen {
|
||||
ra.SetLength(tb.MaxRangeLen)
|
||||
}
|
||||
|
||||
// TODO: add headers(ETag): https://tools.ietf.org/html/rfc7233#section-4.1 p11 2nd paragraph
|
||||
res.Header().Set("Accept-Ranges", "bytes")
|
||||
res.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filepath.Base(fileInfo.PathLocal)))
|
||||
res.Header().Set("Content-Type", "application/octet-stream")
|
||||
res.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, fileInfo.Uploaded))
|
||||
res.Header().Set("Content-Length", strconv.FormatInt(ra.GetLength(), 10))
|
||||
res.Header().Set("Last-Modified", time.Unix(fileInfo.ModTime, 0).UTC().Format(http.TimeFormat))
|
||||
res.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
// TODO: need verify path
|
||||
file, openErr := tb.Filer.Open(filepath.Join(tb.Root, fileInfo.PathLocal))
|
||||
defer file.Close()
|
||||
if openErr != nil {
|
||||
return openErr
|
||||
}
|
||||
|
||||
if _, seekErr := file.Seek(ra.start, io.SeekStart); seekErr != nil {
|
||||
return seekErr
|
||||
}
|
||||
|
||||
copyErr := tb.throttledCopyN(res, file, ra.length)
|
||||
if copyErr != nil && copyErr != io.EOF {
|
||||
return copyErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tb *QTube) throttledCopyN(dst io.Writer, src io.Reader, length int64) error {
|
||||
sum := int64(0)
|
||||
timeSlot := time.Duration(1 * time.Second)
|
||||
|
||||
for sum < length {
|
||||
start := time.Now()
|
||||
chunkSize := length - sum
|
||||
if length-sum > tb.BytesPerSec {
|
||||
chunkSize = tb.BytesPerSec
|
||||
}
|
||||
|
||||
copied, err := io.CopyN(dst, src, chunkSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sum += copied
|
||||
end := time.Now()
|
||||
if end.Before(start.Add(timeSlot)) {
|
||||
time.Sleep(start.Add(timeSlot).Sub(end))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRange(headerRange string, size int64) ([]httpRange, error) {
|
||||
if headerRange == "" {
|
||||
return nil, nil // header not present
|
||||
}
|
||||
|
||||
const keyByte = "bytes="
|
||||
if !strings.HasPrefix(headerRange, keyByte) {
|
||||
return nil, errors.New("byte= not found")
|
||||
}
|
||||
|
||||
var ranges []httpRange
|
||||
noOverlap := false
|
||||
for _, ra := range strings.Split(headerRange[len(keyByte):], ",") {
|
||||
ra = strings.TrimSpace(ra)
|
||||
if ra == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
i := strings.Index(ra, "-")
|
||||
if i < 0 {
|
||||
return nil, errors.New("- not found")
|
||||
}
|
||||
|
||||
start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:])
|
||||
var r httpRange
|
||||
if start == "" {
|
||||
i, err := strconv.ParseInt(end, 10, 64)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid range")
|
||||
}
|
||||
if i > size {
|
||||
i = size
|
||||
}
|
||||
r.start = size - i
|
||||
r.length = size - r.start
|
||||
} else {
|
||||
i, err := strconv.ParseInt(start, 10, 64)
|
||||
if err != nil || i < 0 {
|
||||
return nil, errors.New("invalid range")
|
||||
}
|
||||
if i >= size {
|
||||
// If the range begins after the size of the content,
|
||||
// then it does not overlap.
|
||||
noOverlap = true
|
||||
continue
|
||||
}
|
||||
r.start = i
|
||||
if end == "" {
|
||||
// If no end is specified, range extends to end of the file.
|
||||
r.length = size - r.start
|
||||
} else {
|
||||
i, err := strconv.ParseInt(end, 10, 64)
|
||||
if err != nil || r.start > i {
|
||||
return nil, errors.New("invalid range")
|
||||
}
|
||||
if i >= size {
|
||||
i = size - 1
|
||||
}
|
||||
r.length = i - r.start + 1
|
||||
}
|
||||
}
|
||||
ranges = append(ranges, r)
|
||||
}
|
||||
if noOverlap && len(ranges) == 0 {
|
||||
// The specified ranges did not overlap with the content.
|
||||
return nil, errors.New("parseRanges: no overlap")
|
||||
}
|
||||
return ranges, nil
|
||||
}
|
||||
|
||||
func sumRangesSize(ranges []httpRange) (size int64) {
|
||||
for _, ra := range ranges {
|
||||
size += ra.length
|
||||
}
|
||||
return
|
||||
}
|
354
server/libs/qtube/qtube_test.go
Normal file
354
server/libs/qtube/qtube_test.go
Normal file
|
@ -0,0 +1,354 @@
|
|||
package qtube
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/fileidx"
|
||||
)
|
||||
|
||||
// Range format examples:
|
||||
// Range: <unit>=<range-start>-
|
||||
// Range: <unit>=<range-start>-<range-end>
|
||||
// Range: <unit>=<range-start>-<range-end>, <range-start>-<range-end>
|
||||
// Range: <unit>=<range-start>-<range-end>, <range-start>-<range-end>, <range-start>-<range-end>
|
||||
func TestGetRanges(t *testing.T) {
|
||||
type Input struct {
|
||||
HeaderRange string
|
||||
Size int64
|
||||
}
|
||||
type Output struct {
|
||||
Ranges []httpRange
|
||||
ErrorMsg string
|
||||
}
|
||||
type testCase struct {
|
||||
Desc string
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "invalid range",
|
||||
Input: Input{
|
||||
HeaderRange: "bytes=start-invalid end",
|
||||
Size: 0,
|
||||
},
|
||||
Output: Output{
|
||||
ErrorMsg: ErrorInvalidRange,
|
||||
},
|
||||
},
|
||||
testCase{
|
||||
Desc: "invalid range total size",
|
||||
Input: Input{
|
||||
HeaderRange: "bytes=0-1, 2-3, 0-1, 0-2",
|
||||
Size: 3,
|
||||
},
|
||||
Output: Output{
|
||||
ErrorMsg: ErrorInvalidSize,
|
||||
},
|
||||
},
|
||||
testCase{
|
||||
Desc: "range ok",
|
||||
Input: Input{
|
||||
HeaderRange: "bytes=0-1, 2-3",
|
||||
Size: 4,
|
||||
},
|
||||
Output: Output{
|
||||
Ranges: []httpRange{
|
||||
httpRange{start: 0, length: 2},
|
||||
httpRange{start: 2, length: 2},
|
||||
},
|
||||
ErrorMsg: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range testCases {
|
||||
ranges, err := getRanges(tCase.HeaderRange, tCase.Size)
|
||||
if err != nil {
|
||||
if err.Error() != tCase.ErrorMsg || len(tCase.Ranges) != 0 {
|
||||
t.Fatalf("getRanges: incorrect errorMsg want: %v got: %v", tCase.ErrorMsg, err.Error())
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
for id, ra := range ranges {
|
||||
if ra.GetStart() != tCase.Ranges[id].GetStart() {
|
||||
t.Fatalf("getRanges: incorrect range start, got: %v want: %v", ra.GetStart(), tCase.Ranges[id])
|
||||
}
|
||||
if ra.GetLength() != tCase.Ranges[id].GetLength() {
|
||||
t.Fatalf("getRanges: incorrect range length, got: %v want: %v", ra.GetLength(), tCase.Ranges[id])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestThrottledCopyN(t *testing.T) {
|
||||
type Init struct {
|
||||
BytesPerSec int64
|
||||
MaxRangeLen int64
|
||||
}
|
||||
type Input struct {
|
||||
Src string
|
||||
Length int64
|
||||
}
|
||||
// after starting throttledCopyN by DstAtTime.AtMs millisecond,
|
||||
// copied valueshould equal to DstAtTime.Dst.
|
||||
type DstAtTime struct {
|
||||
AtMS int
|
||||
Dst string
|
||||
}
|
||||
type Output struct {
|
||||
ExpectDsts []DstAtTime
|
||||
}
|
||||
type testCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
verifyDsts := func(dst *bytes.Buffer, expectDsts []DstAtTime) {
|
||||
for _, expectDst := range expectDsts {
|
||||
// fmt.Printf("sleep: %d\n", time.Now().UnixNano())
|
||||
time.Sleep(time.Duration(expectDst.AtMS) * time.Millisecond)
|
||||
dstStr := string(dst.Bytes())
|
||||
// fmt.Printf("check: %d\n", time.Now().UnixNano())
|
||||
if dstStr != expectDst.Dst {
|
||||
panic(
|
||||
fmt.Sprintf(
|
||||
"throttledCopyN want: <%s> | got: <%s> | at: %d",
|
||||
expectDst.Dst,
|
||||
dstStr,
|
||||
expectDst.AtMS,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "4 byte per sec",
|
||||
Init: Init{
|
||||
BytesPerSec: 5,
|
||||
MaxRangeLen: 10,
|
||||
},
|
||||
Input: Input{
|
||||
Src: "aaaa_aaaa_",
|
||||
Length: 10,
|
||||
},
|
||||
Output: Output{
|
||||
ExpectDsts: []DstAtTime{
|
||||
DstAtTime{AtMS: 200, Dst: "aaaa_"},
|
||||
DstAtTime{AtMS: 200, Dst: "aaaa_"},
|
||||
DstAtTime{AtMS: 200, Dst: "aaaa_"},
|
||||
DstAtTime{AtMS: 600, Dst: "aaaa_aaaa_"},
|
||||
DstAtTime{AtMS: 200, Dst: "aaaa_aaaa_"},
|
||||
DstAtTime{AtMS: 200, Dst: "aaaa_aaaa_"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range testCases {
|
||||
tb := NewQTube("", tCase.BytesPerSec, tCase.MaxRangeLen, &stubFiler{}).(*QTube)
|
||||
dst := bytes.NewBuffer(make([]byte, len(tCase.Src)))
|
||||
dst.Reset()
|
||||
|
||||
go verifyDsts(dst, tCase.ExpectDsts)
|
||||
tb.throttledCopyN(dst, strings.NewReader(tCase.Src), tCase.Length)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: using same stub with testhelper
|
||||
type stubWriter struct {
|
||||
Headers http.Header
|
||||
Response []byte
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (w *stubWriter) Header() http.Header {
|
||||
return w.Headers
|
||||
}
|
||||
|
||||
func (w *stubWriter) Write(body []byte) (int, error) {
|
||||
w.Response = append(w.Response, body...)
|
||||
return len(body), nil
|
||||
}
|
||||
|
||||
func (w *stubWriter) WriteHeader(statusCode int) {
|
||||
w.StatusCode = statusCode
|
||||
}
|
||||
|
||||
func TestCopyRange(t *testing.T) {
|
||||
type Init struct {
|
||||
Content string
|
||||
}
|
||||
type Input struct {
|
||||
Range httpRange
|
||||
Info fileidx.FileInfo
|
||||
}
|
||||
type Output struct {
|
||||
StatusCode int
|
||||
Headers map[string][]string
|
||||
Body string
|
||||
}
|
||||
type testCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "copy ok",
|
||||
Init: Init{
|
||||
Content: "abcd_abcd_",
|
||||
},
|
||||
Input: Input{
|
||||
Range: httpRange{
|
||||
start: 6,
|
||||
length: 3,
|
||||
},
|
||||
Info: fileidx.FileInfo{
|
||||
ModTime: 0,
|
||||
Uploaded: 10,
|
||||
PathLocal: "filename.jpg",
|
||||
},
|
||||
},
|
||||
Output: Output{
|
||||
StatusCode: 206,
|
||||
Headers: map[string][]string{
|
||||
"Accept-Ranges": []string{"bytes"},
|
||||
"Content-Disposition": []string{`attachment; filename="filename.jpg"`},
|
||||
"Content-Type": []string{"application/octet-stream"},
|
||||
"Content-Range": []string{"bytes 6-8/10"},
|
||||
"Content-Length": []string{"3"},
|
||||
"Last-Modified": []string{time.Unix(0, 0).UTC().Format(http.TimeFormat)},
|
||||
},
|
||||
Body: "abc",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range testCases {
|
||||
filer := &stubFiler{
|
||||
&StubFile{
|
||||
Content: tCase.Content,
|
||||
Offset: 0,
|
||||
},
|
||||
}
|
||||
tb := NewQTube("", 100, 100, filer).(*QTube)
|
||||
res := &stubWriter{
|
||||
Headers: make(map[string][]string),
|
||||
Response: make([]byte, 0),
|
||||
}
|
||||
err := tb.copyRange(res, tCase.Range, &tCase.Info)
|
||||
if err != nil {
|
||||
t.Fatalf("copyRange: %v", err)
|
||||
}
|
||||
if res.StatusCode != tCase.Output.StatusCode {
|
||||
t.Fatalf("copyRange: statusCode not match got: %v want: %v", res.StatusCode, tCase.Output.StatusCode)
|
||||
}
|
||||
if string(res.Response) != tCase.Output.Body {
|
||||
t.Fatalf("copyRange: body not match \ngot: %v \nwant: %v", string(res.Response), tCase.Output.Body)
|
||||
}
|
||||
for key, vals := range tCase.Output.Headers {
|
||||
if res.Header().Get(key) != vals[0] {
|
||||
t.Fatalf("copyRange: header not match %v got: %v want: %v", key, res.Header().Get(key), vals[0])
|
||||
}
|
||||
}
|
||||
if res.StatusCode != tCase.Output.StatusCode {
|
||||
t.Fatalf("copyRange: statusCodes are not match %v", res.StatusCode, tCase.Output.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeAll(t *testing.T) {
|
||||
type Init struct {
|
||||
Content string
|
||||
}
|
||||
type Input struct {
|
||||
Info fileidx.FileInfo
|
||||
}
|
||||
type Output struct {
|
||||
StatusCode int
|
||||
Headers map[string][]string
|
||||
Body string
|
||||
}
|
||||
type testCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "copy ok",
|
||||
Init: Init{
|
||||
Content: "abcd_abcd_",
|
||||
},
|
||||
Input: Input{
|
||||
Info: fileidx.FileInfo{
|
||||
ModTime: 0,
|
||||
Uploaded: 10,
|
||||
PathLocal: "filename.jpg",
|
||||
},
|
||||
},
|
||||
Output: Output{
|
||||
StatusCode: 200,
|
||||
Headers: map[string][]string{
|
||||
"Accept-Ranges": []string{"bytes"},
|
||||
"Content-Disposition": []string{`attachment; filename="filename.jpg"`},
|
||||
"Content-Type": []string{"application/octet-stream"},
|
||||
"Content-Length": []string{"10"},
|
||||
"Last-Modified": []string{time.Unix(0, 0).UTC().Format(http.TimeFormat)},
|
||||
},
|
||||
Body: "abcd_abcd_",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range testCases {
|
||||
filer := &stubFiler{
|
||||
&StubFile{
|
||||
Content: tCase.Content,
|
||||
Offset: 0,
|
||||
},
|
||||
}
|
||||
tb := NewQTube("", 100, 100, filer).(*QTube)
|
||||
res := &stubWriter{
|
||||
Headers: make(map[string][]string),
|
||||
Response: make([]byte, 0),
|
||||
}
|
||||
err := tb.serveAll(res, &tCase.Info)
|
||||
if err != nil {
|
||||
t.Fatalf("serveAll: %v", err)
|
||||
}
|
||||
if res.StatusCode != tCase.Output.StatusCode {
|
||||
t.Fatalf("serveAll: statusCode not match got: %v want: %v", res.StatusCode, tCase.Output.StatusCode)
|
||||
}
|
||||
if string(res.Response) != tCase.Output.Body {
|
||||
t.Fatalf("serveAll: body not match \ngot: %v \nwant: %v", string(res.Response), tCase.Output.Body)
|
||||
}
|
||||
for key, vals := range tCase.Output.Headers {
|
||||
if res.Header().Get(key) != vals[0] {
|
||||
t.Fatalf("serveAll: header not match %v got: %v want: %v", key, res.Header().Get(key), vals[0])
|
||||
}
|
||||
}
|
||||
if res.StatusCode != tCase.Output.StatusCode {
|
||||
t.Fatalf("serveAll: statusCodes are not match %v", res.StatusCode, tCase.Output.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
28
server/libs/qtube/test_helper.go
Normal file
28
server/libs/qtube/test_helper.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package qtube
|
||||
|
||||
type StubFile struct {
|
||||
Content string
|
||||
Offset int64
|
||||
}
|
||||
|
||||
func (file *StubFile) Read(p []byte) (int, error) {
|
||||
copied := copy(p[:], []byte(file.Content)[:len(p)])
|
||||
return copied, nil
|
||||
}
|
||||
|
||||
func (file *StubFile) Seek(offset int64, whence int) (int64, error) {
|
||||
file.Offset = offset
|
||||
return offset, nil
|
||||
}
|
||||
|
||||
func (file *StubFile) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubFiler struct {
|
||||
file *StubFile
|
||||
}
|
||||
|
||||
func (filer *stubFiler) Open(filePath string) (ReadSeekCloser, error) {
|
||||
return filer.file, nil
|
||||
}
|
102
server/libs/walls/access_walls.go
Normal file
102
server/libs/walls/access_walls.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package walls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/cfg"
|
||||
"quickshare/server/libs/encrypt"
|
||||
"quickshare/server/libs/limiter"
|
||||
)
|
||||
|
||||
type AccessWalls struct {
|
||||
cf *cfg.Config
|
||||
IpLimiter limiter.Limiter
|
||||
OpLimiter limiter.Limiter
|
||||
EncrypterMaker encrypt.EncrypterMaker
|
||||
}
|
||||
|
||||
func NewAccessWalls(
|
||||
cf *cfg.Config,
|
||||
ipLimiter limiter.Limiter,
|
||||
opLimiter limiter.Limiter,
|
||||
encrypterMaker encrypt.EncrypterMaker,
|
||||
) Walls {
|
||||
return &AccessWalls{
|
||||
cf: cf,
|
||||
IpLimiter: ipLimiter,
|
||||
OpLimiter: opLimiter,
|
||||
EncrypterMaker: encrypterMaker,
|
||||
}
|
||||
}
|
||||
|
||||
func (walls *AccessWalls) PassIpLimit(remoteAddr string) bool {
|
||||
if !walls.cf.Production {
|
||||
return true
|
||||
}
|
||||
return walls.IpLimiter.Access(remoteAddr, walls.cf.OpIdIpVisit)
|
||||
|
||||
}
|
||||
|
||||
func (walls *AccessWalls) PassOpLimit(resourceId string, opId int16) bool {
|
||||
if !walls.cf.Production {
|
||||
return true
|
||||
}
|
||||
return walls.OpLimiter.Access(resourceId, opId)
|
||||
}
|
||||
|
||||
func (walls *AccessWalls) PassLoginCheck(tokenStr string, req *http.Request) bool {
|
||||
if !walls.cf.Production {
|
||||
return true
|
||||
}
|
||||
|
||||
return walls.passLoginCheck(tokenStr)
|
||||
}
|
||||
|
||||
func (walls *AccessWalls) passLoginCheck(tokenStr string) bool {
|
||||
token, getLoginTokenOk := walls.GetLoginToken(tokenStr)
|
||||
return getLoginTokenOk && token.AdminId == walls.cf.AdminId
|
||||
}
|
||||
|
||||
func (walls *AccessWalls) GetLoginToken(tokenStr string) (*LoginToken, bool) {
|
||||
tokenMaker := walls.EncrypterMaker(string(walls.cf.SecretKeyByte))
|
||||
if !tokenMaker.FromStr(tokenStr) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
adminIdFromToken, adminIdOk := tokenMaker.Get(walls.cf.KeyAdminId)
|
||||
expiresStr, expiresStrOk := tokenMaker.Get(walls.cf.KeyExpires)
|
||||
if !adminIdOk || !expiresStrOk {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
expires, expiresParseErr := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if expiresParseErr != nil ||
|
||||
adminIdFromToken != walls.cf.AdminId ||
|
||||
expires <= time.Now().Unix() {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return &LoginToken{
|
||||
AdminId: adminIdFromToken,
|
||||
Expires: expires,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (walls *AccessWalls) MakeLoginToken(userId string) string {
|
||||
expires := time.Now().Add(time.Duration(walls.cf.CookieMaxAge) * time.Second).Unix()
|
||||
|
||||
tokenMaker := walls.EncrypterMaker(string(walls.cf.SecretKeyByte))
|
||||
tokenMaker.Add(walls.cf.KeyAdminId, userId)
|
||||
tokenMaker.Add(walls.cf.KeyExpires, fmt.Sprintf("%d", expires))
|
||||
|
||||
tokenStr, ok := tokenMaker.ToStr()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return tokenStr
|
||||
}
|
145
server/libs/walls/access_walls_test.go
Normal file
145
server/libs/walls/access_walls_test.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
package walls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/cfg"
|
||||
"quickshare/server/libs/encrypt"
|
||||
"quickshare/server/libs/limiter"
|
||||
)
|
||||
|
||||
func newAccessWalls(limiterCap int64, limiterTtl int32, limiterCyc int32, bucketCap int16) *AccessWalls {
|
||||
config := cfg.NewConfig()
|
||||
config.Production = true
|
||||
config.LimiterCap = limiterCap
|
||||
config.LimiterTtl = limiterTtl
|
||||
config.LimiterCyc = limiterCyc
|
||||
config.BucketCap = bucketCap
|
||||
encrypterMaker := encrypt.JwtEncrypterMaker
|
||||
ipLimiter := limiter.NewRateLimiter(config.LimiterCap, config.LimiterTtl, config.LimiterCyc, config.BucketCap, map[int16]int16{})
|
||||
opLimiter := limiter.NewRateLimiter(config.LimiterCap, config.LimiterTtl, config.LimiterCyc, config.BucketCap, map[int16]int16{})
|
||||
|
||||
return NewAccessWalls(config, ipLimiter, opLimiter, encrypterMaker).(*AccessWalls)
|
||||
}
|
||||
func TestIpLimit(t *testing.T) {
|
||||
ip := "0.0.0.0"
|
||||
limit := int16(10)
|
||||
ttl := int32(60)
|
||||
cyc := int32(5)
|
||||
walls := newAccessWalls(1000, ttl, cyc, limit)
|
||||
|
||||
testIpLimit(t, walls, ip, limit)
|
||||
// wait for tokens are re-fullfilled
|
||||
time.Sleep(time.Duration(cyc) * time.Second)
|
||||
testIpLimit(t, walls, ip, limit)
|
||||
|
||||
fmt.Println("ip limit: passed")
|
||||
}
|
||||
|
||||
func testIpLimit(t *testing.T, walls Walls, ip string, limit int16) {
|
||||
for i := int16(0); i < limit; i++ {
|
||||
if !walls.PassIpLimit(ip) {
|
||||
t.Fatalf("ipLimiter: should be passed", time.Now().Unix())
|
||||
}
|
||||
}
|
||||
|
||||
if walls.PassIpLimit(ip) {
|
||||
t.Fatalf("ipLimiter: should not be passed", time.Now().Unix())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpLimit(t *testing.T) {
|
||||
resourceId := "id"
|
||||
op1 := int16(1)
|
||||
op2 := int16(2)
|
||||
limit := int16(10)
|
||||
ttl := int32(1)
|
||||
walls := newAccessWalls(1000, 5, ttl, limit)
|
||||
|
||||
testOpLimit(t, walls, resourceId, op1, limit)
|
||||
testOpLimit(t, walls, resourceId, op2, limit)
|
||||
// wait for tokens are re-fullfilled
|
||||
time.Sleep(time.Duration(ttl) * time.Second)
|
||||
testOpLimit(t, walls, resourceId, op1, limit)
|
||||
testOpLimit(t, walls, resourceId, op2, limit)
|
||||
|
||||
fmt.Println("op limit: passed")
|
||||
}
|
||||
|
||||
func testOpLimit(t *testing.T, walls Walls, resourceId string, op int16, limit int16) {
|
||||
for i := int16(0); i < limit; i++ {
|
||||
if !walls.PassOpLimit(resourceId, op) {
|
||||
t.Fatalf("opLimiter: should be passed")
|
||||
}
|
||||
}
|
||||
|
||||
if walls.PassOpLimit(resourceId, op) {
|
||||
t.Fatalf("opLimiter: should not be passed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginCheck(t *testing.T) {
|
||||
walls := newAccessWalls(1000, 5, 1, 10)
|
||||
|
||||
testValidToken(t, walls)
|
||||
testInvalidAdminIdToken(t, walls)
|
||||
testExpiredToken(t, walls)
|
||||
}
|
||||
|
||||
func testValidToken(t *testing.T, walls *AccessWalls) {
|
||||
config := cfg.NewConfig()
|
||||
|
||||
tokenMaker := encrypt.JwtEncrypterMaker(string(config.SecretKeyByte))
|
||||
tokenMaker.Add(config.KeyAdminId, config.AdminId)
|
||||
tokenMaker.Add(config.KeyExpires, fmt.Sprintf("%d", time.Now().Unix()+int64(10)))
|
||||
tokenStr, getTokenOk := tokenMaker.ToStr()
|
||||
if !getTokenOk {
|
||||
t.Fatalf("passLoginCheck: fail to generate token")
|
||||
}
|
||||
|
||||
if !walls.passLoginCheck(tokenStr) {
|
||||
t.Fatalf("loginCheck: should be passed")
|
||||
}
|
||||
|
||||
fmt.Println("loginCheck: valid token passed")
|
||||
}
|
||||
|
||||
func testInvalidAdminIdToken(t *testing.T, walls *AccessWalls) {
|
||||
config := cfg.NewConfig()
|
||||
|
||||
tokenMaker := encrypt.JwtEncrypterMaker(string(config.SecretKeyByte))
|
||||
tokenMaker.Add(config.KeyAdminId, "invalid admin id")
|
||||
tokenMaker.Add(config.KeyExpires, fmt.Sprintf("%d", time.Now().Unix()+int64(10)))
|
||||
tokenStr, getTokenOk := tokenMaker.ToStr()
|
||||
if !getTokenOk {
|
||||
t.Fatalf("passLoginCheck: fail to generate token")
|
||||
}
|
||||
|
||||
if walls.passLoginCheck(tokenStr) {
|
||||
t.Fatalf("loginCheck: should not be passed")
|
||||
}
|
||||
|
||||
fmt.Println("loginCheck: invalid admin id passed")
|
||||
}
|
||||
|
||||
func testExpiredToken(t *testing.T, walls *AccessWalls) {
|
||||
config := cfg.NewConfig()
|
||||
|
||||
tokenMaker := encrypt.JwtEncrypterMaker(string(config.SecretKeyByte))
|
||||
tokenMaker.Add(config.KeyAdminId, config.AdminId)
|
||||
tokenMaker.Add(config.KeyExpires, fmt.Sprintf("%d", time.Now().Unix()-int64(1)))
|
||||
tokenStr, getTokenOk := tokenMaker.ToStr()
|
||||
if !getTokenOk {
|
||||
t.Fatalf("passLoginCheck: fail to generate token")
|
||||
}
|
||||
|
||||
if walls.passLoginCheck(tokenStr) {
|
||||
t.Fatalf("loginCheck: should not be passed")
|
||||
}
|
||||
|
||||
fmt.Println("loginCheck: expired token passed")
|
||||
}
|
17
server/libs/walls/walls.go
Normal file
17
server/libs/walls/walls.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package walls
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Walls interface {
|
||||
PassIpLimit(remoteAddr string) bool
|
||||
PassOpLimit(resourceId string, opId int16) bool
|
||||
PassLoginCheck(tokenStr string, req *http.Request) bool
|
||||
MakeLoginToken(uid string) string
|
||||
}
|
||||
|
||||
type LoginToken struct {
|
||||
AdminId string
|
||||
Expires int64
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue