!1 Merge back to master

Merge pull request !1 from dev branch
This commit is contained in:
hekk 2018-05-27 21:32:55 +08:00
parent 30c963a5f0
commit 61a1c93f0f
89 changed files with 15859 additions and 2 deletions

105
server/apis/auth.go Normal file
View file

@ -0,0 +1,105 @@
package apis
import (
"net/http"
"time"
)
import (
"quickshare/server/libs/httputil"
"quickshare/server/libs/httpworker"
)
// LoginHandler serves POST auth requests. It picks Login or Logout based on
// the "act" form value, runs the chosen handler on the shared worker pool,
// and maps pool-full / timeout failures to 503 / 500.
func (srv *SrvShare) LoginHandler(res http.ResponseWriter, req *http.Request) {
	// Only POST is accepted; everything else is answered with 404.
	if req.Method != http.MethodPost {
		srv.Http.Fill(httputil.Err404, res)
		return
	}
	act := req.FormValue(srv.Conf.KeyAct)
	// Default handler returns 404 for unknown actions (defensive; the
	// switch default below already rejects them).
	todo := func(res http.ResponseWriter, req *http.Request) interface{} { return httputil.Err404 }
	switch act {
	case srv.Conf.ActLogin:
		todo = srv.Login
	case srv.Conf.ActLogout:
		todo = srv.Logout
	default:
		srv.Http.Fill(httputil.Err404, res)
		return
	}
	// Buffered so the worker can ack without blocking.
	ack := make(chan error, 1)
	ok := srv.WorkerPool.Put(&httpworker.Task{
		Ack: ack,
		Do:  srv.Wrap(todo),
		Res: res,
		Req: req,
	})
	if !ok {
		// Task queue is full.
		srv.Http.Fill(httputil.Err503, res)
		return
	}
	// Wait up to Conf.Timeout milliseconds for the worker to finish.
	execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
	if srv.Err.IsErr(execErr) {
		srv.Http.Fill(httputil.Err500, res)
	}
}
// Login throttles the caller by IP and by the shared login-op limit, then
// delegates credential checking to login.
func (srv *SrvShare) Login(res http.ResponseWriter, req *http.Request) interface{} {
	// Every user must pass the same wall to log in; the checks stay
	// sequential so the op limiter is not consumed when the IP limit fails.
	if !srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr)) {
		return httputil.Err504
	}
	if !srv.Walls.PassOpLimit(srv.Conf.AllUsers, srv.Conf.OpIdLogin) {
		return httputil.Err504
	}
	adminId := req.FormValue(srv.Conf.KeyAdminId)
	adminPwd := req.FormValue(srv.Conf.KeyAdminPwd)
	return srv.login(adminId, adminPwd, res)
}
// login validates the admin credentials and, on success, issues a login
// token as a cookie. It returns an httputil response value (401 bad
// credentials, 500 token-maker failure, 200 success).
// NOTE(review): this is a plain, non-constant-time string comparison
// against a plaintext password held in config — consider
// crypto/subtle.ConstantTimeCompare and hashed storage; verify threat model.
func (srv *SrvShare) login(adminId string, adminPwd string, res http.ResponseWriter) interface{} {
	if adminId != srv.Conf.AdminId ||
		adminPwd != srv.Conf.AdminPwd {
		return httputil.Err401
	}
	token := srv.Walls.MakeLoginToken(srv.Conf.AdminId)
	// An empty token signals that token creation failed.
	if token == "" {
		return httputil.Err500
	}
	srv.Http.SetCookie(res, srv.Conf.KeyToken, token)
	return httputil.Ok200
}
// Logout invalidates the session by overwriting the token cookie with a
// placeholder value.
func (srv *SrvShare) Logout(res http.ResponseWriter, req *http.Request) interface{} {
	const clearedToken = "-"
	srv.Http.SetCookie(res, srv.Conf.KeyToken, clearedToken)
	return httputil.Ok200
}
// IsValidLength reports whether a content/chunk length is positive and
// within the configured cap.
// NOTE(review): the cap used is MaxUpBytesPerSec, which by its name is an
// upload-rate setting — confirm it is also intended as the per-request
// length limit.
func (srv *SrvShare) IsValidLength(length int64) bool {
	return length > 0 && length <= srv.Conf.MaxUpBytesPerSec
}
// IsValidStart reports whether the client-supplied offset matches the
// offset the server expects next.
func (srv *SrvShare) IsValidStart(start, expectStart int64) bool {
	return start == expectStart
}
// IsValidShareId reports whether shareId is acceptable. In production a
// share id must be exactly 64 characters; outside production any id
// (including an empty or "0" dev id) is accepted.
func (srv *SrvShare) IsValidShareId(shareId string) bool {
	return !srv.Conf.Production || len(shareId) == 64
}
// IsValidDownLimit reports whether a download-limit value is acceptable:
// any value >= -1 (download() treats -1 as "no limit" and 0 as exhausted).
func (srv *SrvShare) IsValidDownLimit(limit int) bool {
	return limit >= -1
}
// IsValidFileName reports whether fileName is non-empty and shorter than
// 240 bytes (len counts bytes, not runes).
func IsValidFileName(fileName string) bool {
	if fileName == "" {
		return false
	}
	return len(fileName) < 240
}

78
server/apis/auth_test.go Normal file
View file

@ -0,0 +1,78 @@
package apis
import (
"fmt"
"strings"
"testing"
)
import (
"quickshare/server/libs/cfg"
"quickshare/server/libs/encrypt"
"quickshare/server/libs/httputil"
)
// TestLogin table-tests SrvShare.login: bad or unknown credentials must
// yield 401; correct credentials must yield 200 and set a login-token
// cookie whose admin-id claim matches the configured admin id.
//
// Fixes: "reponse" typo in the failure message; the loop variable no
// longer shadows the testCase type.
func TestLogin(t *testing.T) {
	conf := cfg.NewConfig()
	type testCase struct {
		Desc        string
		AdminId     string
		AdminPwd    string
		Result      interface{}
		VerifyToken bool
	}
	testCases := []testCase{
		{
			Desc:        "invalid input",
			AdminId:     "",
			AdminPwd:    "",
			Result:      httputil.Err401,
			VerifyToken: false,
		},
		{
			Desc:        "account not match",
			AdminId:     "unknown",
			AdminPwd:    "unknown",
			Result:      httputil.Err401,
			VerifyToken: false,
		},
		{
			Desc:        "succeed to login",
			AdminId:     conf.AdminId,
			AdminPwd:    conf.AdminPwd,
			Result:      httputil.Ok200,
			VerifyToken: true,
		},
	}
	for _, tc := range testCases {
		srv := NewSrvShare(conf)
		res := &stubWriter{Headers: map[string][]string{}}
		ret := srv.login(tc.AdminId, tc.AdminPwd, res)
		if ret != tc.Result {
			t.Fatalf("login: response=%v testCase=%v", ret, tc.Result)
		}
		// Verify only the token.adminId part of the Set-Cookie value.
		if tc.VerifyToken {
			cookieVal := strings.Replace(
				res.Header().Get("Set-Cookie"),
				fmt.Sprintf("%s=", conf.KeyToken),
				"",
				1,
			)
			gotTokenStr := strings.Split(cookieVal, ";")[0]
			token := encrypt.JwtEncrypterMaker(conf.SecretKey)
			token.FromStr(gotTokenStr)
			gotToken, found := token.Get(conf.KeyAdminId)
			if !found || conf.AdminId != gotToken {
				t.Fatalf("login: token admin id unmatch got=%v expect=%v", gotToken, conf.AdminId)
			}
		}
	}
}

66
server/apis/client.go Normal file
View file

@ -0,0 +1,66 @@
package apis
import (
"net/http"
"path/filepath"
"strings"
"time"
)
import (
"quickshare/server/libs/httputil"
"quickshare/server/libs/httpworker"
)
// ClientHandler serves GET requests for static client assets by running
// GetClient on the shared worker pool; pool-full / timeout failures map
// to 503 / 500.
func (srv *SrvShare) ClientHandler(res http.ResponseWriter, req *http.Request) {
	// Static assets are read-only: only GET is accepted.
	if req.Method != http.MethodGet {
		srv.Http.Fill(httputil.Err404, res)
		return
	}
	// Buffered so the worker can ack without blocking.
	ack := make(chan error, 1)
	ok := srv.WorkerPool.Put(&httpworker.Task{
		Ack: ack,
		Do:  srv.Wrap(srv.GetClient),
		Res: res,
		Req: req,
	})
	if !ok {
		// Task queue is full.
		srv.Http.Fill(httputil.Err503, res)
		return
	}
	// Wait up to Conf.Timeout milliseconds for the worker to finish.
	execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
	if srv.Err.IsErr(execErr) {
		srv.Http.Fill(httputil.Err500, res)
	}
}
// GetClient rate-limits by caller IP, then serves the requested static
// asset identified by the escaped URL path.
func (srv *SrvShare) GetClient(res http.ResponseWriter, req *http.Request) interface{} {
	if srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr)) {
		return srv.getClient(res, req, req.URL.EscapedPath())
	}
	return httputil.Err504
}
// getClient serves a file from ./public. Directory paths (trailing "/")
// fall back to index.html; paths containing ".." are rejected with 400.
// Returns the sentinel 0 because ServeFile writes the response itself.
func (srv *SrvShare) getClient(res http.ResponseWriter, req *http.Request, relPath string) interface{} {
	const indexPage = "index.html"
	if strings.HasSuffix(relPath, "/") {
		relPath += indexPage
	}
	if !IsValidClientPath(relPath) {
		return httputil.Err400
	}
	// Anchor under the public root and normalize before serving.
	fullPath := filepath.Clean(filepath.Join("./public", relPath))
	http.ServeFile(res, req, fullPath)
	return 0
}
// IsValidClientPath reports whether fullPath is safe to serve: any path
// containing ".." is rejected to block directory traversal.
func IsValidClientPath(fullPath string) bool {
	return !strings.Contains(fullPath, "..")
}

69
server/apis/download.go Normal file
View file

@ -0,0 +1,69 @@
package apis
import (
"net/http"
"time"
)
import (
"quickshare/server/libs/fileidx"
"quickshare/server/libs/httputil"
"quickshare/server/libs/httpworker"
)
// DownloadHandler serves GET download requests by scheduling Download on
// the shared worker pool.
//
// Fixes: the original fell through after writing the 404/503 responses —
// a non-GET request was still enqueued, and a failed enqueue still waited
// on an ack that would never arrive and then wrote a second (500)
// response. Both paths now return immediately (matching LoginHandler).
func (srv *SrvShare) DownloadHandler(res http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodGet {
		srv.Http.Fill(httputil.Err404, res)
		return
	}
	// Buffered so the worker can ack without blocking.
	ack := make(chan error, 1)
	ok := srv.WorkerPool.Put(&httpworker.Task{
		Ack: ack,
		Do:  srv.Wrap(srv.Download),
		Res: res,
		Req: req,
	})
	if !ok {
		// Task queue is full.
		srv.Http.Fill(httputil.Err503, res)
		return
	}
	// Use WriteTimeout instead of Timeout: after the write deadline the
	// connection is gone, so the worker will fail its write and return.
	execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.WriteTimeout)*time.Millisecond)
	if srv.Err.IsErr(execErr) {
		srv.Http.Fill(httputil.Err500, res)
	}
}
// Download throttles by caller IP and by per-share download op, then
// hands off to download. The checks stay sequential so the op limiter is
// not consumed when the IP limit already failed.
func (srv *SrvShare) Download(res http.ResponseWriter, req *http.Request) interface{} {
	shareId := req.FormValue(srv.Conf.KeyShareId)
	if !srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr)) {
		return httputil.Err429
	}
	if !srv.Walls.PassOpLimit(shareId, srv.Conf.OpIdDownload) {
		return httputil.Err429
	}
	return srv.download(shareId, res, req)
}
// download streams the file identified by shareId to the client.
// Responses: 400 invalid id, 404 unknown/incomplete file, 412 download
// limit exhausted, 500 limit-decrement failure; 0 means the body was
// written by the downloader.
func (srv *SrvShare) download(shareId string, res http.ResponseWriter, req *http.Request) interface{} {
	if !srv.IsValidShareId(shareId) {
		return httputil.Err400
	}
	fileInfo, found := srv.Index.Get(shareId)
	switch {
	case !found || fileInfo.State != fileidx.StateDone:
		// Unknown id, or the upload has not finished yet.
		return httputil.Err404
	case fileInfo.DownLimit == 0:
		// Limit exhausted (negative limits mean unlimited, see tests).
		return httputil.Err412
	default:
		// Consume one download credit before serving.
		updated, _ := srv.Index.DecrDownLimit(shareId)
		if updated != 1 {
			return httputil.Err500
		}
	}
	// Streaming errors are only recorded via Err.IsErr — the client
	// connection may already be broken at this point, so there is no
	// meaningful HTTP status left to send.
	err := srv.Downloader.ServeFile(res, req, fileInfo)
	srv.Err.IsErr(err)
	return 0
}

View file

@ -0,0 +1,271 @@
package apis
import (
"net/http"
"os"
"testing"
"time"
)
import (
"quickshare/server/libs/cfg"
"quickshare/server/libs/errutil"
"quickshare/server/libs/fileidx"
"quickshare/server/libs/httputil"
"quickshare/server/libs/logutil"
"quickshare/server/libs/qtube"
)
// initServiceForDownloadTest builds a SrvShare whose downloader, index,
// and filesystem are stubbed so download tests run without real I/O.
func initServiceForDownloadTest(config *cfg.Config, indexMap map[string]*fileidx.FileInfo, content string) *SrvShare {
	setDownloader := func(srv *SrvShare) {
		srv.Downloader = stubDownloader{Content: content}
	}
	setIndex := func(srv *SrvShare) {
		srv.Index = fileidx.NewMemFileIndexWithMap(len(indexMap), indexMap)
	}
	setFs := func(srv *SrvShare) {
		srv.Fs = &stubFsUtil{
			MockFile: &qtube.StubFile{
				Content: content,
				Offset:  0,
			},
		}
	}
	logger := logutil.NewSlog(os.Stdout, config.AppName)
	setLog := func(srv *SrvShare) {
		srv.Log = logger
	}
	setErr := func(srv *SrvShare) {
		// Verbose error checking outside production.
		srv.Err = errutil.NewErrChecker(!config.Production, logger)
	}
	return InitSrvShare(config, setDownloader, setIndex, setFs, setLog, setErr)
}
func TestDownload(t *testing.T) {
conf := cfg.NewConfig()
conf.Production = false
type Init struct {
Content string
IndexMap map[string]*fileidx.FileInfo
}
type Input struct {
ShareId string
}
type Output struct {
IndexMap map[string]*fileidx.FileInfo
Response interface{}
Body string
}
type testCase struct {
Desc string
Init
Input
Output
}
testCases := []testCase{
testCase{
Desc: "empty file index",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{},
Response: httputil.Err404,
},
},
testCase{
Desc: "file info not found",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{
"1": &fileidx.FileInfo{},
},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
"1": &fileidx.FileInfo{},
},
Response: httputil.Err404,
},
},
testCase{
Desc: "file not found because of state=uploading",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
DownLimit: 1,
ModTime: time.Now().UnixNano(),
PathLocal: "path",
State: fileidx.StateUploading,
Uploaded: 1,
},
},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
DownLimit: 1,
ModTime: time.Now().UnixNano(),
PathLocal: "path",
State: fileidx.StateUploading,
Uploaded: 1,
},
},
Response: httputil.Err404,
},
},
testCase{
Desc: "download failed because download limit = 0",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
DownLimit: 0,
ModTime: time.Now().UnixNano(),
PathLocal: "path",
State: fileidx.StateDone,
Uploaded: 1,
},
},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
DownLimit: 0,
ModTime: time.Now().UnixNano(),
PathLocal: "path",
State: fileidx.StateDone,
Uploaded: 1,
},
},
Response: httputil.Err412,
},
},
testCase{
Desc: "succeed to download",
Init: Init{
Content: "content",
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
DownLimit: 1,
ModTime: time.Now().UnixNano(),
PathLocal: "path",
State: fileidx.StateDone,
Uploaded: 1,
},
},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
DownLimit: 0,
ModTime: time.Now().UnixNano(),
PathLocal: "path",
State: fileidx.StateDone,
Uploaded: 1,
},
},
Response: 0,
Body: "content",
},
},
testCase{
Desc: "succeed to download DownLimit == -1",
Init: Init{
Content: "content",
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
DownLimit: -1,
ModTime: time.Now().UnixNano(),
PathLocal: "path",
State: fileidx.StateDone,
Uploaded: 1,
},
},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
DownLimit: -1,
ModTime: time.Now().UnixNano(),
PathLocal: "path",
State: fileidx.StateDone,
Uploaded: 1,
},
},
Response: 0,
Body: "content",
},
},
}
for _, testCase := range testCases {
srv := initServiceForDownloadTest(conf, testCase.Init.IndexMap, testCase.Content)
writer := &stubWriter{Headers: map[string][]string{}}
response := srv.download(
testCase.ShareId,
writer,
&http.Request{},
)
// verify downlimit
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
info, _ := srv.Index.Get(testCase.ShareId)
t.Fatalf(
"download: index incorrect got=%v want=%v",
info,
testCase.Output.IndexMap[testCase.ShareId],
)
}
// verify response
if response != testCase.Output.Response {
t.Fatalf(
"download: response incorrect response=%v testCase=%v",
response,
testCase.Output.Response,
)
}
// verify writerContent
if string(writer.Response) != testCase.Output.Body {
t.Fatalf(
"download: body incorrect got=%v want=%v",
string(writer.Response),
testCase.Output.Body,
)
}
}
}

234
server/apis/file_info.go Normal file
View file

@ -0,0 +1,234 @@
package apis
import (
"fmt"
"math/rand"
"net/http"
"path/filepath"
"strconv"
"time"
)
import (
"quickshare/server/libs/fileidx"
"quickshare/server/libs/httputil"
"quickshare/server/libs/httpworker"
)
// FileInfoHandler dispatches authenticated file-info operations — list
// (GET), delete (DELETE), and the PATCH actions shadow-id, publish-id,
// set-down-limit, add-local-files — onto the shared worker pool.
//
// Fix: after a failed pool enqueue (503) the original fell through,
// waited on an ack that would never arrive, and wrote a second (500)
// response; it now returns immediately (matching LoginHandler).
func (srv *SrvShare) FileInfoHandler(res http.ResponseWriter, req *http.Request) {
	// All file-info operations require passing the IP limiter and a valid
	// login token. NOTE(review): an auth failure is reported as 429
	// rather than 401 — confirm this is intentional.
	tokenStr := srv.Http.GetCookie(req.Cookies(), srv.Conf.KeyToken)
	if !srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr)) ||
		!srv.Walls.PassLoginCheck(tokenStr, req) {
		srv.Http.Fill(httputil.Err429, res)
		return
	}
	todo := func(res http.ResponseWriter, req *http.Request) interface{} { return httputil.Err404 }
	switch req.Method {
	case http.MethodGet:
		todo = srv.List
	case http.MethodDelete:
		todo = srv.Del
	case http.MethodPatch:
		act := req.FormValue(srv.Conf.KeyAct)
		switch act {
		case srv.Conf.ActShadowId:
			todo = srv.ShadowId
		case srv.Conf.ActPublishId:
			todo = srv.PublishId
		case srv.Conf.ActSetDownLimit:
			todo = srv.SetDownLimit
		case srv.Conf.ActAddLocalFiles:
			todo = srv.AddLocalFiles
		default:
			srv.Http.Fill(httputil.Err404, res)
			return
		}
	default:
		srv.Http.Fill(httputil.Err404, res)
		return
	}
	// Buffered so the worker can ack without blocking.
	ack := make(chan error, 1)
	ok := srv.WorkerPool.Put(&httpworker.Task{
		Ack: ack,
		Do:  srv.Wrap(todo),
		Res: res,
		Req: req,
	})
	if !ok {
		// Task queue is full.
		srv.Http.Fill(httputil.Err503, res)
		return
	}
	// Wait up to Conf.Timeout milliseconds for the worker to finish.
	execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
	if srv.Err.IsErr(execErr) {
		srv.Http.Fill(httputil.Err500, res)
	}
}
// ResInfos is the response body returned by List: every file info
// currently held in the share index.
type ResInfos struct {
	List []*fileidx.FileInfo
}
// List throttles the shared get-file-info op, then returns all infos.
func (srv *SrvShare) List(res http.ResponseWriter, req *http.Request) interface{} {
	if srv.Walls.PassOpLimit(srv.Conf.AllUsers, srv.Conf.OpIdGetFInfo) {
		return srv.list()
	}
	return httputil.Err429
}
// list snapshots the current file index into a ResInfos response.
func (srv *SrvShare) list() interface{} {
	all := srv.Index.List()
	infos := make([]*fileidx.FileInfo, 0, len(all))
	for _, info := range all {
		infos = append(infos, info)
	}
	return &ResInfos{List: infos}
}
// Del rate-limits per share id, then deletes the share and its file.
// NOTE(review): a limiter failure maps to 504 here while sibling
// handlers use 429 — confirm which is intended.
func (srv *SrvShare) Del(res http.ResponseWriter, req *http.Request) interface{} {
	shareId := req.FormValue(srv.Conf.KeyShareId)
	if srv.Walls.PassOpLimit(shareId, srv.Conf.OpIdDelFInfo) {
		return srv.del(shareId)
	}
	return httputil.Err504
}
// del removes shareId from the index and deletes the backing file.
// Responses: 400 invalid id, 404 unknown id, 500 file deletion failed,
// 200 success.
// NOTE(review): the index entry is removed before the file deletion is
// attempted, so a failed DelFile (500) leaves an orphaned file that is no
// longer tracked — confirm this ordering is intended.
func (srv *SrvShare) del(shareId string) interface{} {
	if !srv.IsValidShareId(shareId) {
		return httputil.Err400
	}
	fileInfo, found := srv.Index.Get(shareId)
	if !found {
		return httputil.Err404
	}
	srv.Index.Del(shareId)
	fullPath := filepath.Join(srv.Conf.PathLocal, fileInfo.PathLocal)
	if !srv.Fs.DelFile(fullPath) {
		// TODO: may log file name because file not exist or delete is not authenticated
		return httputil.Err500
	}
	return httputil.Ok200
}
// ShadowId throttles the shared file-info op, then re-keys the share
// under a hard-to-guess id.
func (srv *SrvShare) ShadowId(res http.ResponseWriter, req *http.Request) interface{} {
	if srv.Walls.PassOpLimit(srv.Conf.AllUsers, srv.Conf.OpIdOpFInfo) {
		return srv.shadowId(req.FormValue(srv.Conf.KeyShareId))
	}
	return httputil.Err429
}
// shadowId replaces a share's id with one encrypted from its local path
// plus a random 4-digit suffix, making the id unguessable.
// Responses: 400 invalid id, 404 unknown id, 412 re-keying refused
// (e.g. the new id already exists), otherwise the new share id.
func (srv *SrvShare) shadowId(shareId string) interface{} {
	if !srv.IsValidShareId(shareId) {
		return httputil.Err400
	}
	info, found := srv.Index.Get(shareId)
	if !found {
		return httputil.Err404
	}
	// genPwd adds entropy so the shadow id is not derivable from the
	// path alone (see NOTE on genPwd about randomness quality).
	secretId := srv.Encryptor.Encrypt(
		[]byte(fmt.Sprintf("%s%s", info.PathLocal, genPwd())),
	)
	if !srv.Index.SetId(info.Id, secretId) {
		return httputil.Err412
	}
	return &ShareInfo{ShareId: secretId}
}
// PublishId throttles the shared file-info op, then re-keys the share
// under its deterministic public id.
func (srv *SrvShare) PublishId(res http.ResponseWriter, req *http.Request) interface{} {
	if srv.Walls.PassOpLimit(srv.Conf.AllUsers, srv.Conf.OpIdOpFInfo) {
		return srv.publishId(req.FormValue(srv.Conf.KeyShareId))
	}
	return httputil.Err429
}
// publishId derives the share's public id by encrypting its local path
// (no random suffix, so the mapping is stable) and re-keys the index.
// Responses: 400 invalid id, 404 unknown id, 412 re-keying refused,
// otherwise the new share id.
func (srv *SrvShare) publishId(shareId string) interface{} {
	if !srv.IsValidShareId(shareId) {
		return httputil.Err400
	}
	info, found := srv.Index.Get(shareId)
	if !found {
		return httputil.Err404
	}
	publicId := srv.Encryptor.Encrypt([]byte(info.PathLocal))
	if srv.Index.SetId(info.Id, publicId) {
		return &ShareInfo{ShareId: publicId}
	}
	return httputil.Err412
}
// SetDownLimit parses the new download limit from the form and applies
// it to the share; an unparseable limit yields 400.
func (srv *SrvShare) SetDownLimit(res http.ResponseWriter, req *http.Request) interface{} {
	if !srv.Walls.PassOpLimit(srv.Conf.AllUsers, srv.Conf.OpIdOpFInfo) {
		return httputil.Err429
	}
	shareId := req.FormValue(srv.Conf.KeyShareId)
	limit64, parseErr := strconv.ParseInt(req.FormValue(srv.Conf.KeyDownLimit), 10, 32)
	if srv.Err.IsErr(parseErr) {
		return httputil.Err400
	}
	return srv.setDownLimit(shareId, int(limit64))
}
// setDownLimit validates shareId and downLimit, then updates the index.
// Responses: 400 invalid input, 404 unknown id, 200 success.
func (srv *SrvShare) setDownLimit(shareId string, downLimit int) interface{} {
	switch {
	case !srv.IsValidShareId(shareId), !srv.IsValidDownLimit(downLimit):
		return httputil.Err400
	case !srv.Index.SetDownLimit(shareId, downLimit):
		return httputil.Err404
	default:
		return httputil.Ok200
	}
}
// AddLocalFiles is the HTTP-facing wrapper for AddLocalFilesImp; the
// request carries no parameters.
func (srv *SrvShare) AddLocalFiles(res http.ResponseWriter, req *http.Request) interface{} {
	return srv.AddLocalFilesImp()
}
// AddLocalFilesImp scans Conf.PathLocal and registers each file found
// there as a share whose id is the encrypted local path.
//
// Fixes: removed the no-op self-assignment `info.PathLocal = info.PathLocal`
// and the pointless `break` inside a switch case (a Go `break` there does
// nothing). NOTE(review): TestAddLocalFiles expects PathLocal to be joined
// under Conf.PathLocal — the removed self-assignment may have been an
// unfinished filepath.Join; verify what Fs.Readdir returns.
//
// It panics when the directory cannot be read because it is invoked from
// InitSrvShare during start-up.
func (srv *SrvShare) AddLocalFilesImp() interface{} {
	infos, err := srv.Fs.Readdir(srv.Conf.PathLocal, srv.Conf.LocalFileLimit)
	if srv.Err.IsErr(err) {
		panic(fmt.Sprintf("fail to readdir: %v", err))
	}
	for _, info := range infos {
		info.DownLimit = srv.Conf.DownLimit
		info.State = fileidx.StateDone
		info.Id = srv.Encryptor.Encrypt([]byte(info.PathLocal))
		switch srv.Index.Add(info) {
		case 0, -1:
			// Not added (full index or duplicate id).
			// TODO: return files not added
			continue
		case 1:
			// Added successfully.
		default:
			return httputil.Err500
		}
	}
	return httputil.Ok200
}
// genPwd returns a random 4-digit numeric string used to salt shadow ids.
// NOTE(review): math/rand is never seeded here, so the sequence repeats
// across process restarts, and math/rand is not cryptographically secure —
// consider crypto/rand if these suffixes must be unpredictable.
func genPwd() string {
	return fmt.Sprintf("%d%d%d%d", rand.Intn(10), rand.Intn(10), rand.Intn(10), rand.Intn(10))
}

View file

@ -0,0 +1,584 @@
package apis
import (
"os"
"path/filepath"
"testing"
)
import (
"quickshare/server/libs/cfg"
"quickshare/server/libs/errutil"
"quickshare/server/libs/fileidx"
"quickshare/server/libs/httputil"
"quickshare/server/libs/logutil"
)
const mockShadowId = "shadowId"
const mockPublicId = "publicId"
// initServiceForFileInfoTest builds a SrvShare with a stub filesystem and
// a stub encryptor that always returns mockShadowId (useShadowEnc=true)
// or mockPublicId, so id re-keying tests are deterministic.
func initServiceForFileInfoTest(
	config *cfg.Config,
	indexMap map[string]*fileidx.FileInfo,
	useShadowEnc bool,
	localFileInfos []*fileidx.FileInfo,
) *SrvShare {
	setIndex := func(srv *SrvShare) {
		srv.Index = fileidx.NewMemFileIndexWithMap(len(indexMap), indexMap)
	}
	setFs := func(srv *SrvShare) {
		srv.Fs = &stubFsUtil{MockLocalFileInfos: localFileInfos}
	}
	logger := logutil.NewSlog(os.Stdout, config.AppName)
	setLog := func(srv *SrvShare) {
		srv.Log = logger
	}
	errChecker := errutil.NewErrChecker(!config.Production, logger)
	setErr := func(srv *SrvShare) {
		srv.Err = errChecker
	}
	var setEncryptor AddDep
	if useShadowEnc {
		setEncryptor = func(srv *SrvShare) {
			srv.Encryptor = &stubEncryptor{MockResult: mockShadowId}
		}
	} else {
		setEncryptor = func(srv *SrvShare) {
			srv.Encryptor = &stubEncryptor{MockResult: mockPublicId}
		}
	}
	return InitSrvShare(config, setIndex, setFs, setEncryptor, setLog, setErr)
}
func TestList(t *testing.T) {
conf := cfg.NewConfig()
conf.Production = false
type Output struct {
IndexMap map[string]*fileidx.FileInfo
}
type TestCase struct {
Desc string
Output
}
testCases := []TestCase{
TestCase{
Desc: "success",
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
},
"1": &fileidx.FileInfo{
Id: "1",
},
},
},
},
}
for _, testCase := range testCases {
srv := initServiceForFileInfoTest(conf, testCase.Output.IndexMap, true, []*fileidx.FileInfo{})
response := srv.list()
resInfos := response.(*ResInfos)
for _, info := range resInfos.List {
infoFromSrv, found := srv.Index.Get(info.Id)
if !found || infoFromSrv.Id != info.Id {
t.Fatalf("list: file infos are not identical")
}
}
if len(resInfos.List) != len(srv.Index.List()) {
t.Fatalf("list: file infos are not identical")
}
}
}
func TestDel(t *testing.T) {
conf := cfg.NewConfig()
conf.Production = false
type Init struct {
IndexMap map[string]*fileidx.FileInfo
}
type Input struct {
ShareId string
}
type Output struct {
IndexMap map[string]*fileidx.FileInfo
Response httputil.MsgRes
}
type TestCase struct {
Desc string
Init
Input
Output
}
testCases := []TestCase{
TestCase{
Desc: "success",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
},
"1": &fileidx.FileInfo{
Id: "1",
},
},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
"1": &fileidx.FileInfo{
Id: "1",
},
},
Response: httputil.Ok200,
},
},
TestCase{
Desc: "not found",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{
"1": &fileidx.FileInfo{
Id: "1",
},
},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
"1": &fileidx.FileInfo{
Id: "1",
},
},
Response: httputil.Err404,
},
},
}
for _, testCase := range testCases {
srv := initServiceForFileInfoTest(conf, testCase.Init.IndexMap, true, []*fileidx.FileInfo{})
response := srv.del(testCase.ShareId)
res := response.(httputil.MsgRes)
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
t.Fatalf("del: index incorrect")
}
if res != testCase.Output.Response {
t.Fatalf("del: response incorrect got: %v, want: %v", res, testCase.Output.Response)
}
}
}
func TestShadowId(t *testing.T) {
conf := cfg.NewConfig()
conf.Production = false
type Init struct {
IndexMap map[string]*fileidx.FileInfo
}
type Input struct {
ShareId string
}
type Output struct {
IndexMap map[string]*fileidx.FileInfo
Response interface{}
}
type TestCase struct {
Desc string
Init
Input
Output
}
testCases := []TestCase{
TestCase{
Desc: "success",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
},
},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
mockShadowId: &fileidx.FileInfo{
Id: mockShadowId,
},
},
Response: &ShareInfo{
ShareId: mockShadowId,
},
},
},
TestCase{
Desc: "original id not exists",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{},
Response: httputil.Err404,
},
},
TestCase{
Desc: "dest id exists",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
},
mockShadowId: &fileidx.FileInfo{
Id: mockShadowId,
},
},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
},
mockShadowId: &fileidx.FileInfo{
Id: mockShadowId,
},
},
Response: httputil.Err412,
},
},
}
for _, testCase := range testCases {
srv := initServiceForFileInfoTest(conf, testCase.Init.IndexMap, true, []*fileidx.FileInfo{})
response := srv.shadowId(testCase.ShareId)
switch response.(type) {
case *ShareInfo:
res := response.(*ShareInfo)
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
info, found := srv.Index.Get(mockShadowId)
t.Fatalf(
"shadowId: index incorrect got %v found: %v want %v",
info,
found,
testCase.Output.IndexMap[mockShadowId],
)
}
if res.ShareId != mockShadowId {
t.Fatalf("shadowId: mockId incorrect")
}
case httputil.MsgRes:
res := response.(httputil.MsgRes)
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
t.Fatalf("shadowId: map not identical")
}
if res != testCase.Output.Response {
t.Fatalf("shadowId: response incorrect")
}
default:
t.Fatalf("shadowId: return type not found")
}
}
}
// TestPublishId table-tests SrvShare.publishId: re-keying a share to its
// deterministic public id, plus the not-found and id-collision cases.
//
// Fixes: one t.Fatalf used a format string without verbs while passing
// arguments (a go-vet printf error), and every failure message said
// "shadowId" in the publishId test (copy-paste).
func TestPublishId(t *testing.T) {
	conf := cfg.NewConfig()
	conf.Production = false
	type Init struct {
		IndexMap map[string]*fileidx.FileInfo
	}
	type Input struct {
		ShareId string
	}
	type Output struct {
		IndexMap map[string]*fileidx.FileInfo
		Response interface{}
	}
	type TestCase struct {
		Desc string
		Init
		Input
		Output
	}
	testCases := []TestCase{
		{
			Desc: "success",
			Init: Init{
				IndexMap: map[string]*fileidx.FileInfo{
					mockShadowId: {
						Id: mockShadowId,
					},
				},
			},
			Input: Input{
				ShareId: mockShadowId,
			},
			Output: Output{
				IndexMap: map[string]*fileidx.FileInfo{
					mockPublicId: {
						Id: mockPublicId,
					},
				},
				Response: &ShareInfo{
					ShareId: mockPublicId,
				},
			},
		},
		{
			Desc: "original id not exists",
			Init: Init{
				IndexMap: map[string]*fileidx.FileInfo{},
			},
			Input: Input{
				ShareId: "0",
			},
			Output: Output{
				IndexMap: map[string]*fileidx.FileInfo{},
				Response: httputil.Err404,
			},
		},
		{
			Desc: "dest id exists",
			Init: Init{
				IndexMap: map[string]*fileidx.FileInfo{
					mockShadowId: {
						Id: mockShadowId,
					},
					mockPublicId: {
						Id: mockPublicId,
					},
				},
			},
			Input: Input{
				ShareId: mockShadowId,
			},
			Output: Output{
				IndexMap: map[string]*fileidx.FileInfo{
					mockShadowId: {
						Id: mockShadowId,
					},
					mockPublicId: {
						Id: mockPublicId,
					},
				},
				Response: httputil.Err412,
			},
		},
	}
	for _, tc := range testCases {
		srv := initServiceForFileInfoTest(conf, tc.Init.IndexMap, false, []*fileidx.FileInfo{})
		response := srv.publishId(tc.ShareId)
		switch res := response.(type) {
		case *ShareInfo:
			if !sameMap(srv.Index.List(), tc.Output.IndexMap) {
				info, found := srv.Index.Get(mockPublicId)
				t.Fatalf(
					"publishId: index incorrect got %v found: %v want %v",
					info,
					found,
					tc.Output.IndexMap[mockPublicId],
				)
			}
			if res.ShareId != mockPublicId {
				t.Fatalf("publishId: mockId incorrect got: %v want: %v", res.ShareId, mockPublicId)
			}
		case httputil.MsgRes:
			if !sameMap(srv.Index.List(), tc.Output.IndexMap) {
				t.Fatalf("publishId: map not identical")
			}
			if res != tc.Output.Response {
				t.Fatalf("publishId: response incorrect got: %v want: %v", res, tc.Output.Response)
			}
		default:
			t.Fatalf("publishId: return type not found")
		}
	}
}
func TestSetDownLimit(t *testing.T) {
conf := cfg.NewConfig()
conf.Production = false
mockDownLimit := 100
type Init struct {
IndexMap map[string]*fileidx.FileInfo
}
type Input struct {
ShareId string
DownLimit int
}
type Output struct {
IndexMap map[string]*fileidx.FileInfo
Response httputil.MsgRes
}
type TestCase struct {
Desc string
Init
Input
Output
}
testCases := []TestCase{
TestCase{
Desc: "success",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
},
},
},
Input: Input{
ShareId: "0",
DownLimit: mockDownLimit,
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
"0": &fileidx.FileInfo{
Id: "0",
DownLimit: mockDownLimit,
},
},
Response: httputil.Ok200,
},
},
TestCase{
Desc: "not found",
Init: Init{
IndexMap: map[string]*fileidx.FileInfo{},
},
Input: Input{
ShareId: "0",
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{},
Response: httputil.Err404,
},
},
}
for _, testCase := range testCases {
srv := initServiceForFileInfoTest(conf, testCase.Init.IndexMap, true, []*fileidx.FileInfo{})
response := srv.setDownLimit(testCase.ShareId, mockDownLimit)
res := response.(httputil.MsgRes)
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
info, _ := srv.Index.Get(testCase.ShareId)
t.Fatalf(
"setDownLimit: index incorrect got: %v want: %v",
info,
testCase.Output.IndexMap[testCase.ShareId],
)
}
if res != testCase.Output.Response {
t.Fatalf("setDownLimit: response incorrect got: %v, want: %v", res, testCase.Output.Response)
}
}
}
func TestAddLocalFiles(t *testing.T) {
conf := cfg.NewConfig()
conf.Production = false
type Init struct {
Infos []*fileidx.FileInfo
}
type Output struct {
IndexMap map[string]*fileidx.FileInfo
Response httputil.MsgRes
}
type TestCase struct {
Desc string
Init
Output
}
testCases := []TestCase{
TestCase{
Desc: "success",
Init: Init{
Infos: []*fileidx.FileInfo{
&fileidx.FileInfo{
Id: "",
DownLimit: 0,
ModTime: 13,
PathLocal: "filename1",
State: "",
Uploaded: 13,
},
},
},
Output: Output{
IndexMap: map[string]*fileidx.FileInfo{
mockPublicId: &fileidx.FileInfo{
Id: mockPublicId,
DownLimit: conf.DownLimit,
ModTime: 13,
PathLocal: filepath.Join(conf.PathLocal, "filename1"),
State: fileidx.StateDone,
Uploaded: 13,
},
},
},
},
}
for _, testCase := range testCases {
srv := initServiceForFileInfoTest(conf, testCase.Output.IndexMap, false, testCase.Init.Infos)
response := srv.AddLocalFilesImp()
res := response.(httputil.MsgRes)
if res.Code != 200 {
t.Fatalf("addLocalFiles: code not correct")
}
if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
t.Fatalf(
"addLocalFiles: indexes not identical got: %v want: %v",
srv.Index.List(),
testCase.Output.IndexMap,
)
}
}
}

145
server/apis/service.go Normal file
View file

@ -0,0 +1,145 @@
package apis
import (
	"log"
	"net"
	"net/http"
	"os"
	"strings"
)
import (
"quickshare/server/libs/cfg"
"quickshare/server/libs/encrypt"
"quickshare/server/libs/errutil"
"quickshare/server/libs/fileidx"
"quickshare/server/libs/fsutil"
"quickshare/server/libs/httputil"
"quickshare/server/libs/httpworker"
"quickshare/server/libs/limiter"
"quickshare/server/libs/logutil"
"quickshare/server/libs/qtube"
"quickshare/server/libs/walls"
)
type AddDep func(*SrvShare)
// NewSrvShare wires a production SrvShare: logger, error checker, worker
// pool, rate-limit walls, in-memory file index, filesystem helpers,
// rate-limited downloader, HMAC id encryptor, and cookie-aware HTTP util,
// all driven by config.
func NewSrvShare(config *cfg.Config) *SrvShare {
	logger := logutil.NewSlog(os.Stdout, config.AppName)
	setLog := func(srv *SrvShare) {
		srv.Log = logger
	}
	// Verbose error checking outside production.
	errChecker := errutil.NewErrChecker(!config.Production, logger)
	setErr := func(srv *SrvShare) {
		srv.Err = errChecker
	}
	setWorkerPool := func(srv *SrvShare) {
		workerPoolSize := config.WorkerPoolSize
		taskQueueSize := config.TaskQueueSize
		srv.WorkerPool = httpworker.NewWorkerPool(workerPoolSize, taskQueueSize, logger)
	}
	setWalls := func(srv *SrvShare) {
		// Separate limiters: one keyed by caller IP, one by operation id.
		encrypterMaker := encrypt.JwtEncrypterMaker
		ipLimiter := limiter.NewRateLimiter(
			config.LimiterCap,
			config.LimiterTtl,
			config.LimiterCyc,
			config.BucketCap,
			config.SpecialCaps,
		)
		opLimiter := limiter.NewRateLimiter(
			config.LimiterCap,
			config.LimiterTtl,
			config.LimiterCyc,
			config.BucketCap,
			config.SpecialCaps,
		)
		srv.Walls = walls.NewAccessWalls(config, ipLimiter, opLimiter, encrypterMaker)
	}
	setIndex := func(srv *SrvShare) {
		srv.Index = fileidx.NewMemFileIndex(config.MaxShares)
	}
	// The same fs instance backs both srv.Fs and the downloader.
	fs := fsutil.NewSimpleFs(errChecker)
	setFs := func(srv *SrvShare) {
		srv.Fs = fs
	}
	setDownloader := func(srv *SrvShare) {
		srv.Downloader = qtube.NewQTube(
			config.PathLocal,
			config.MaxDownBytesPerSec,
			config.MaxRangeLength,
			fs,
		)
	}
	setEncryptor := func(srv *SrvShare) {
		srv.Encryptor = &encrypt.HmacEncryptor{Key: config.SecretKeyByte}
	}
	setHttp := func(srv *SrvShare) {
		srv.Http = &httputil.QHttpUtil{
			CookieDomain:   config.CookieDomain,
			CookieHttpOnly: config.CookieHttpOnly,
			CookieMaxAge:   config.CookieMaxAge,
			CookiePath:     config.CookiePath,
			CookieSecure:   config.CookieSecure,
			Err:            errChecker,
		}
	}
	return InitSrvShare(config, setIndex, setWalls, setWorkerPool, setFs, setDownloader, setEncryptor, setLog, setErr, setHttp)
}
// InitSrvShare constructs a SrvShare from config plus dependency-injection
// hooks, ensures the local storage folder exists, and imports any files
// already present there. It panics on either setup failure, so it must
// only be called during process start-up.
func InitSrvShare(config *cfg.Config, addDeps ...AddDep) *SrvShare {
	srv := &SrvShare{}
	srv.Conf = config
	// Hooks run in order, so later deps may rely on earlier ones.
	for _, addDep := range addDeps {
		addDep(srv)
	}
	if !srv.Fs.MkdirAll(srv.Conf.PathLocal, os.FileMode(0775)) {
		panic("fail to make ./files/ folder")
	}
	if res := srv.AddLocalFilesImp(); res != httputil.Ok200 {
		panic("fail to add local files")
	}
	return srv
}
// SrvShare bundles every dependency of the share service; fields are
// populated via AddDep hooks passed to InitSrvShare.
type SrvShare struct {
	Conf       *cfg.Config        // runtime configuration
	Encryptor  encrypt.Encryptor  // derives share ids from local paths
	Err        errutil.ErrUtil    // error checking / logging helper
	Downloader qtube.Downloader   // rate-limited file streamer
	Http       httputil.HttpUtil  // response and cookie helper
	Index      fileidx.FileIndex  // in-memory share index
	Fs         fsutil.FsUtil      // filesystem operations
	Log        logutil.LogUtil    // logger
	Walls      walls.Walls        // rate limiting + login checks
	WorkerPool httpworker.Workers // bounded request-execution pool
}
// Wrap adapts a ServiceFunc (which returns a response body value) into a
// worker DoFunc that serializes the body into the ResponseWriter.
// A nil body or the sentinel 0 means "response already written / nothing
// to send" and is skipped (see download/getClient, which return 0).
func (srv *SrvShare) Wrap(serviceFunc httpworker.ServiceFunc) httpworker.DoFunc {
	return func(res http.ResponseWriter, req *http.Request) {
		body := serviceFunc(res, req)
		// Fill returns the number of bytes written; <= 0 means it failed.
		if body != nil && body != 0 && srv.Http.Fill(body, res) <= 0 {
			log.Println("Wrap: fail to fill body", body, res)
		}
	}
}
// GetRemoteIp extracts the IP part of a remote address such as
// "127.0.0.1:8080" or "[::1]:8080", falling back to the raw value for
// addresses without a port and to "unknown ip" for an empty address.
//
// Fixes: the original split on the first ":" and therefore mangled IPv6
// addresses ("[::1]:8080" became "["), and its "unknown ip" branch was
// unreachable (strings.Split always yields at least one element).
func GetRemoteIp(addr string) string {
	// net.SplitHostPort handles IPv4, IPv6-with-brackets, and hostnames.
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	// No parsable port part: preserve the old first-colon behavior.
	if idx := strings.Index(addr, ":"); idx != -1 {
		return addr[:idx]
	}
	if addr != "" {
		return addr
	}
	return "unknown ip"
}

117
server/apis/test_helper.go Normal file
View file

@ -0,0 +1,117 @@
package apis
import (
"fmt"
"io"
"net/http"
"os"
)
import (
"quickshare/server/libs/fileidx"
"quickshare/server/libs/qtube"
)
// stubFsUtil is a test double for fsutil.FsUtil; it serves canned file
// infos and a canned file instead of touching the real filesystem.
type stubFsUtil struct {
	MockLocalFileInfos []*fileidx.FileInfo
	MockFile           *qtube.StubFile
}
var expectCreateFileName = ""
// CreateFile panics unless fileName equals the package-level
// expectCreateFileName, letting tests assert the exact path created.
func (fs *stubFsUtil) CreateFile(fileName string) error {
	if fileName != expectCreateFileName {
		panic(
			fmt.Sprintf("CreateFile: got: %s expect: %s", fileName, expectCreateFileName),
		)
	}
	return nil
}
// CopyChunkN pretends the chunk copy always succeeds.
func (fs *stubFsUtil) CopyChunkN(fullPath string, chunk io.Reader, start int64, len int64) bool {
	return true
}
// ServeFile is a no-op; download tests use stubDownloader instead.
func (fs *stubFsUtil) ServeFile(res http.ResponseWriter, req *http.Request, fileName string) {
	return
}
// DelFile pretends deletion always succeeds.
func (fs *stubFsUtil) DelFile(fullPath string) bool {
	return true
}
func (fs *stubFsUtil) MkdirAll(path string, mode os.FileMode) bool {
return true
}
func (fs *stubFsUtil) Readdir(dirname string, n int) ([]*fileidx.FileInfo, error) {
return fs.MockLocalFileInfos, nil
}
func (fs *stubFsUtil) Open(filePath string) (qtube.ReadSeekCloser, error) {
return fs.MockFile, nil
}
type stubWriter struct {
Headers http.Header
Response []byte
StatusCode int
}
func (w *stubWriter) Header() http.Header {
return w.Headers
}
func (w *stubWriter) Write(body []byte) (int, error) {
w.Response = append(w.Response, body...)
return len(body), nil
}
func (w *stubWriter) WriteHeader(statusCode int) {
w.StatusCode = statusCode
}
// stubDownloader is a test double for qtube.Downloader that writes a
// fixed payload instead of a real file.
type stubDownloader struct {
	Content string
}

// ServeFile writes Content to w, ignoring the request and file info.
func (d stubDownloader) ServeFile(w http.ResponseWriter, r *http.Request, fileInfo *fileidx.FileInfo) error {
	_, err := w.Write([]byte(d.Content))
	return err
}
// sameInfoWithoutTime reports whether two FileInfos are equal on every
// field except ModTime (which tests cannot predict).
func sameInfoWithoutTime(info1, info2 *fileidx.FileInfo) bool {
	if info1.Id != info2.Id || info1.DownLimit != info2.DownLimit {
		return false
	}
	if info1.PathLocal != info2.PathLocal || info1.State != info2.State {
		return false
	}
	return info1.Uploaded == info2.Uploaded
}
// sameMap reports whether both maps hold the same set of ids with
// pairwise-equal FileInfos (ModTime ignored). Mismatches are printed to
// stdout to aid debugging failing tests.
func sameMap(map1, map2 map[string]*fileidx.FileInfo) bool {
	// first direction: every entry of map1 must match in map2
	for key, info1 := range map1 {
		info2, found := map2[key]
		if !found || !sameInfoWithoutTime(info1, info2) {
			fmt.Printf("infos are not same: \n%v \n%v", info1, info2)
			return false
		}
	}
	// second direction: map2 must not contain extra or different entries
	for key, info2 := range map2 {
		info1, found := map1[key]
		if !found || !sameInfoWithoutTime(info1, info2) {
			fmt.Printf("infos are not same: \n%v \n%v", info1, info2)
			return false
		}
	}
	return true
}
// stubEncryptor is a test double for encrypt.Encryptor returning a
// canned value regardless of input.
type stubEncryptor struct {
	MockResult string
}

// Encrypt ignores content and returns MockResult.
func (enc *stubEncryptor) Encrypt(content []byte) string {
	return enc.MockResult
}

253
server/apis/upload.go Normal file
View file

@ -0,0 +1,253 @@
package apis
import (
"io"
"net/http"
"path/filepath"
"strconv"
"time"
)
import (
"quickshare/server/libs/encrypt"
"quickshare/server/libs/fileidx"
"quickshare/server/libs/fsutil"
httpUtil "quickshare/server/libs/httputil"
worker "quickshare/server/libs/httpworker"
)
// DefaultId is the share id used when Production is false, so tests get
// deterministic ids (see startUpload).
const DefaultId = "0"

// ByteRange tells the client which byte range of the file to upload next.
type ByteRange struct {
	ShareId string
	Start   int64
	Length  int64
}

// ShareInfo is the response once an upload has been finished.
type ShareInfo struct {
	ShareId string
}
// StartUploadHandler accepts POST requests that open an upload session.
// It enforces the ip/login/op walls, then runs StartUpload on the worker
// pool with the configured timeout.
func (srv *SrvShare) StartUploadHandler(res http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodPost {
		srv.Http.Fill(httpUtil.Err404, res)
		return
	}
	tokenStr := srv.Http.GetCookie(req.Cookies(), srv.Conf.KeyToken)
	ipPass := srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr))
	loginPass := srv.Walls.PassLoginCheck(tokenStr, req)
	opPass := srv.Walls.PassOpLimit(GetRemoteIp(req.RemoteAddr), srv.Conf.OpIdUpload)
	if !ipPass || !loginPass || !opPass {
		srv.Http.Fill(httpUtil.Err429, res)
		return
	}
	ack := make(chan error, 1)
	ok := srv.WorkerPool.Put(&worker.Task{
		Ack: ack,
		Do:  srv.Wrap(srv.StartUpload),
		Res: res,
		Req: req,
	})
	if !ok {
		srv.Http.Fill(httpUtil.Err503, res)
		// fix: without this return (present in LoginHandler) the handler
		// kept waiting for an ack that never comes, then appended a second
		// (500) body after the 503.
		return
	}
	execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
	if srv.Err.IsErr(execErr) {
		srv.Http.Fill(httpUtil.Err500, res)
	}
}
// UploadHandler accepts POST requests carrying one multipart chunk of an
// ongoing upload. It enforces the ip/login/op walls, parses the multipart
// form, then runs Upload on the worker pool.
func (srv *SrvShare) UploadHandler(res http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodPost {
		srv.Http.Fill(httpUtil.Err404, res)
		return
	}
	tokenStr := srv.Http.GetCookie(req.Cookies(), srv.Conf.KeyToken)
	ipPass := srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr))
	loginPass := srv.Walls.PassLoginCheck(tokenStr, req)
	opPass := srv.Walls.PassOpLimit(GetRemoteIp(req.RemoteAddr), srv.Conf.OpIdUpload)
	if !ipPass || !loginPass || !opPass {
		srv.Http.Fill(httpUtil.Err429, res)
		return
	}
	// parse here (not in the worker) so malformed bodies are rejected early
	multiFormErr := req.ParseMultipartForm(srv.Conf.ParseFormBufSize)
	if srv.Err.IsErr(multiFormErr) {
		srv.Http.Fill(httpUtil.Err400, res)
		return
	}
	srv.Log.Println("form", req.Form)
	srv.Log.Println("pform", req.PostForm)
	srv.Log.Println("mform", req.MultipartForm)
	ack := make(chan error, 1)
	ok := srv.WorkerPool.Put(&worker.Task{
		Ack: ack,
		Do:  srv.Wrap(srv.Upload),
		Res: res,
		Req: req,
	})
	if !ok {
		srv.Http.Fill(httpUtil.Err503, res)
		// fix: missing return caused a timeout wait plus a second (500)
		// body after the 503.
		return
	}
	execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
	if srv.Err.IsErr(execErr) {
		srv.Http.Fill(httpUtil.Err500, res)
	}
}
// FinishUploadHandler accepts POST requests that close an upload session.
// It enforces the ip/login/op walls, then runs FinishUpload on the worker
// pool.
func (srv *SrvShare) FinishUploadHandler(res http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodPost {
		srv.Http.Fill(httpUtil.Err404, res)
		return
	}
	tokenStr := srv.Http.GetCookie(req.Cookies(), srv.Conf.KeyToken)
	ipPass := srv.Walls.PassIpLimit(GetRemoteIp(req.RemoteAddr))
	loginPass := srv.Walls.PassLoginCheck(tokenStr, req)
	opPass := srv.Walls.PassOpLimit(GetRemoteIp(req.RemoteAddr), srv.Conf.OpIdUpload)
	if !ipPass || !loginPass || !opPass {
		srv.Http.Fill(httpUtil.Err429, res)
		return
	}
	ack := make(chan error, 1)
	ok := srv.WorkerPool.Put(&worker.Task{
		Ack: ack,
		Do:  srv.Wrap(srv.FinishUpload),
		Res: res,
		Req: req,
	})
	if !ok {
		srv.Http.Fill(httpUtil.Err503, res)
		// fix: missing return caused a timeout wait plus a second (500)
		// body after the 503.
		return
	}
	execErr := srv.WorkerPool.IsInTime(ack, time.Duration(srv.Conf.Timeout)*time.Millisecond)
	if srv.Err.IsErr(execErr) {
		srv.Http.Fill(httpUtil.Err500, res)
	}
}
// StartUpload extracts the file name from the form and delegates to
// startUpload.
func (srv *SrvShare) StartUpload(res http.ResponseWriter, req *http.Request) interface{} {
	return srv.startUpload(req.FormValue(srv.Conf.KeyFileName))
}
// startUpload registers a new FileInfo for fileName and creates the
// backing file on disk. On success it returns the first ByteRange the
// client should upload; otherwise an error response value.
func (srv *SrvShare) startUpload(fileName string) interface{} {
	if !IsValidFileName(fileName) {
		return httpUtil.Err400
	}
	// outside production every share gets DefaultId so tests are deterministic
	id := DefaultId
	if srv.Conf.Production {
		id = genInfoId(fileName, srv.Conf.SecretKeyByte)
	}
	info := &fileidx.FileInfo{
		Id:        id,
		DownLimit: srv.Conf.DownLimit,
		ModTime:   time.Now().UnixNano(),
		PathLocal: fileName,
		Uploaded:  0,
		State:     fileidx.StateStarted,
	}
	switch srv.Index.Add(info) {
	case 0:
		// go on
	case -1:
		// id already registered
		return httpUtil.Err412
	case 1:
		// index is at capacity
		return httpUtil.Err500 // TODO: use correct status code
	default:
		srv.Index.Del(id)
		return httpUtil.Err500
	}
	fullPath := filepath.Join(srv.Conf.PathLocal, info.PathLocal)
	createFileErr := srv.Fs.CreateFile(fullPath)
	switch {
	case createFileErr == fsutil.ErrExists:
		// roll back the index entry when the file is already on disk
		srv.Index.Del(id)
		return httpUtil.Err412
	case createFileErr == fsutil.ErrUnknown:
		srv.Index.Del(id)
		return httpUtil.Err500
	default:
		srv.Index.SetState(id, fileidx.StateUploading)
		return &ByteRange{
			ShareId: id,
			Start:   0,
			Length:  srv.Conf.MaxUpBytesPerSec,
		}
	}
}
// Upload parses share id, start offset, chunk length and the multipart
// chunk from the request and delegates to upload.
func (srv *SrvShare) Upload(res http.ResponseWriter, req *http.Request) interface{} {
	shareId := req.FormValue(srv.Conf.KeyShareId)
	start, startErr := strconv.ParseInt(req.FormValue(srv.Conf.KeyStart), 10, 64)
	length, lengthErr := strconv.ParseInt(req.FormValue(srv.Conf.KeyLen), 10, 64)
	chunk, _, chunkErr := req.FormFile(srv.Conf.KeyChunk)
	if srv.Err.IsErr(startErr) ||
		srv.Err.IsErr(lengthErr) ||
		srv.Err.IsErr(chunkErr) {
		return httpUtil.Err400
	}
	return srv.upload(shareId, start, length, chunk)
}
// upload writes one chunk at the given offset into the share's backing
// file and advances the uploaded counter. It returns the next ByteRange
// to upload, or an error response value.
func (srv *SrvShare) upload(shareId string, start int64, length int64, chunk io.Reader) interface{} {
	if !srv.IsValidShareId(shareId) {
		return httpUtil.Err400
	}
	fileInfo, found := srv.Index.Get(shareId)
	if !found {
		return httpUtil.Err404
	}
	// start must match what has been uploaded so far; rejects gaps/overlaps
	if !srv.IsValidStart(start, fileInfo.Uploaded) || !srv.IsValidLength(length) {
		return httpUtil.Err400
	}
	fullPath := filepath.Join(srv.Conf.PathLocal, fileInfo.PathLocal)
	if !srv.Fs.CopyChunkN(fullPath, chunk, start, length) {
		return httpUtil.Err500
	}
	// IncrUploaded returns 0 when the id vanished concurrently
	if srv.Index.IncrUploaded(shareId, length) == 0 {
		return httpUtil.Err404
	}
	return &ByteRange{
		ShareId: shareId,
		Start:   start + length,
		Length:  srv.Conf.MaxUpBytesPerSec,
	}
}
// FinishUpload extracts the share id from the form and delegates to
// finishUpload.
func (srv *SrvShare) FinishUpload(res http.ResponseWriter, req *http.Request) interface{} {
	shareId := req.FormValue(srv.Conf.KeyShareId)
	return srv.finishUpload(shareId)
}
// finishUpload marks the share as done and returns its ShareInfo, or 404
// when the id is unknown.
func (srv *SrvShare) finishUpload(shareId string) interface{} {
	if !srv.Index.SetState(shareId, fileidx.StateDone) {
		return httpUtil.Err404
	}
	return &ShareInfo{
		ShareId: shareId,
	}
}
// genInfoId derives a share id as the HMAC of content under key, so ids
// are stable per file name but not guessable without the secret.
func genInfoId(content string, key []byte) string {
	encrypter := encrypt.HmacEncryptor{Key: key}
	return encrypter.Encrypt([]byte(content))
}

368
server/apis/upload_test.go Normal file
View file

@ -0,0 +1,368 @@
package apis
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
import (
"quickshare/server/libs/cfg"
"quickshare/server/libs/encrypt"
"quickshare/server/libs/errutil"
"quickshare/server/libs/fileidx"
"quickshare/server/libs/httputil"
"quickshare/server/libs/httpworker"
"quickshare/server/libs/limiter"
"quickshare/server/libs/logutil"
"quickshare/server/libs/walls"
)
// testCap is extra index capacity on top of the seeded entries.
const testCap = 3

// initServiceForUploadTest builds a SrvShare with stub fs/encryptor and a
// memory index seeded from indexMap.
// NOTE(review): unlike the production wiring, no Http or Downloader dep is
// injected here — presumably the exercised code paths never touch them;
// confirm if tests start panicking on nil fields.
func initServiceForUploadTest(config *cfg.Config, indexMap map[string]*fileidx.FileInfo) *SrvShare {
	logger := logutil.NewSlog(os.Stdout, config.AppName)
	setLog := func(srv *SrvShare) {
		srv.Log = logger
	}
	setWorkerPool := func(srv *SrvShare) {
		workerPoolSize := config.WorkerPoolSize
		taskQueueSize := config.TaskQueueSize
		srv.WorkerPool = httpworker.NewWorkerPool(workerPoolSize, taskQueueSize, logger)
	}
	setWalls := func(srv *SrvShare) {
		encrypterMaker := encrypt.JwtEncrypterMaker
		ipLimiter := limiter.NewRateLimiter(config.LimiterCap, config.LimiterTtl, config.LimiterCyc, config.BucketCap, map[int16]int16{})
		opLimiter := limiter.NewRateLimiter(config.LimiterCap, config.LimiterTtl, config.LimiterCyc, config.BucketCap, map[int16]int16{})
		srv.Walls = walls.NewAccessWalls(config, ipLimiter, opLimiter, encrypterMaker)
	}
	setIndex := func(srv *SrvShare) {
		srv.Index = fileidx.NewMemFileIndexWithMap(len(indexMap)+testCap, indexMap)
	}
	setFs := func(srv *SrvShare) {
		srv.Fs = &stubFsUtil{}
	}
	setEncryptor := func(srv *SrvShare) {
		srv.Encryptor = &encrypt.HmacEncryptor{Key: config.SecretKeyByte}
	}
	errChecker := errutil.NewErrChecker(!config.Production, logger)
	setErr := func(srv *SrvShare) {
		srv.Err = errChecker
	}
	return InitSrvShare(config, setIndex, setWalls, setWorkerPool, setFs, setEncryptor, setLog, setErr)
}
// TestStartUpload drives startUpload through table cases and verifies the
// resulting index state and response value (ModTime is ignored by sameMap).
func TestStartUpload(t *testing.T) {
	conf := cfg.NewConfig()
	conf.Production = false
	type Init struct {
		IndexMap map[string]*fileidx.FileInfo
	}
	type Input struct {
		FileName string
	}
	type Output struct {
		Response interface{}
		IndexMap map[string]*fileidx.FileInfo
	}
	type testCase struct {
		Desc string
		Init
		Input
		Output
	}
	testCases := []testCase{
		testCase{
			Desc: "invalid file name",
			Init: Init{
				IndexMap: map[string]*fileidx.FileInfo{},
			},
			Input: Input{
				FileName: "",
			},
			Output: Output{
				IndexMap: map[string]*fileidx.FileInfo{},
				Response: httputil.Err400,
			},
		},
		testCase{
			Desc: "succeed to start uploading",
			Init: Init{
				IndexMap: map[string]*fileidx.FileInfo{},
			},
			Input: Input{
				FileName: "filename",
			},
			Output: Output{
				IndexMap: map[string]*fileidx.FileInfo{
					DefaultId: &fileidx.FileInfo{
						Id:        DefaultId,
						DownLimit: conf.DownLimit,
						ModTime:   time.Now().UnixNano(),
						PathLocal: "filename",
						Uploaded:  0,
						State:     fileidx.StateUploading,
					},
				},
				Response: &ByteRange{
					ShareId: DefaultId,
					Start:   0,
					Length:  conf.MaxUpBytesPerSec,
				},
			},
		},
	}
	for _, testCase := range testCases {
		srv := initServiceForUploadTest(conf, testCase.Init.IndexMap)
		// verify CreateFile (the stub panics on any other name)
		expectCreateFileName = filepath.Join(conf.PathLocal, testCase.FileName)
		response := srv.startUpload(testCase.FileName)
		// verify index
		if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
			t.Fatalf("startUpload: index not equal got: %v, %v, expect: %v", srv.Index.List(), response, testCase.Output.IndexMap)
		}
		// verify response
		switch expectRes := testCase.Output.Response.(type) {
		case *ByteRange:
			res := response.(*ByteRange)
			if res.ShareId != expectRes.ShareId ||
				res.Start != expectRes.Start ||
				res.Length != expectRes.Length {
				t.Fatalf(fmt.Sprintf("startUpload: res=%v expect=%v", res, expectRes))
			}
		case httputil.MsgRes:
			if response != expectRes {
				// fix: message typo "reponse" -> "response"
				t.Fatalf(fmt.Sprintf("startUpload: response=%v expectRes=%v", response, expectRes))
			}
		default:
			t.Fatalf(fmt.Sprintf("startUpload: type not found: %T %T", testCase.Output.Response, httputil.Err400))
		}
	}
}
// TestUpload drives upload through table cases and verifies the index
// state and the returned ByteRange/error value.
func TestUpload(t *testing.T) {
	conf := cfg.NewConfig()
	conf.Production = false
	type Init struct {
		IndexMap map[string]*fileidx.FileInfo
	}
	type Input struct {
		ShareId string
		Start   int64
		Len     int64
		Chunk   io.Reader
	}
	type Output struct {
		IndexMap map[string]*fileidx.FileInfo
		Response interface{}
	}
	type testCase struct {
		Desc string
		Init
		Input
		Output
	}
	testCases := []testCase{
		testCase{
			Desc: "shareid does not exist",
			Init: Init{
				IndexMap: map[string]*fileidx.FileInfo{},
			},
			Input: Input{
				ShareId: DefaultId,
				Start:   0,
				Len:     1,
				Chunk:   strings.NewReader(""),
			},
			Output: Output{
				IndexMap: map[string]*fileidx.FileInfo{},
				Response: httputil.Err404,
			},
		},
		testCase{
			Desc: "succeed",
			Init: Init{
				IndexMap: map[string]*fileidx.FileInfo{
					DefaultId: &fileidx.FileInfo{
						Id:        DefaultId,
						DownLimit: conf.MaxShares,
						PathLocal: "path/filename",
						State:     fileidx.StateUploading,
						Uploaded:  0,
					},
				},
			},
			Input: Input{
				ShareId: DefaultId,
				Start:   0,
				Len:     1,
				Chunk:   strings.NewReader("a"),
			},
			Output: Output{
				IndexMap: map[string]*fileidx.FileInfo{
					DefaultId: &fileidx.FileInfo{
						Id:        DefaultId,
						DownLimit: conf.MaxShares,
						PathLocal: "path/filename",
						State:     fileidx.StateUploading,
						Uploaded:  1,
					},
				},
				Response: &ByteRange{
					ShareId: DefaultId,
					Start:   1,
					Length:  conf.MaxUpBytesPerSec,
				},
			},
		},
	}
	for _, testCase := range testCases {
		srv := initServiceForUploadTest(conf, testCase.Init.IndexMap)
		response := srv.upload(
			testCase.Input.ShareId,
			testCase.Input.Start,
			testCase.Input.Len,
			testCase.Input.Chunk,
		)
		// TODO: not verified copyChunk
		// verify index
		if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
			t.Fatalf("upload: index not identical got: %v want: %v", srv.Index.List(), testCase.Output.IndexMap)
		}
		// verify response
		switch response.(type) {
		case *ByteRange:
			br := testCase.Output.Response.(*ByteRange)
			res := response.(*ByteRange)
			if res.ShareId != br.ShareId || res.Start != br.Start || res.Length != br.Length {
				t.Fatalf(fmt.Sprintf("upload: response=%v expectRes=%v", res, br))
			}
		default:
			if response != testCase.Output.Response {
				t.Fatalf(fmt.Sprintf("upload: response=%v expectRes=%v", response, testCase.Output.Response))
			}
		}
	}
}
// TestFinishUpload drives finishUpload through table cases and verifies
// the index state transition to StateDone and the response value.
func TestFinishUpload(t *testing.T) {
	conf := cfg.NewConfig()
	conf.Production = false
	type Init struct {
		IndexMap map[string]*fileidx.FileInfo
	}
	type Input struct {
		ShareId string
		Start   int64
		Len     int64
		Chunk   io.Reader
	}
	type Output struct {
		IndexMap map[string]*fileidx.FileInfo
		Response interface{}
	}
	type testCase struct {
		Desc string
		Init
		Input
		Output
	}
	testCases := []testCase{
		testCase{
			Desc: "success",
			Init: Init{
				IndexMap: map[string]*fileidx.FileInfo{
					DefaultId: &fileidx.FileInfo{
						Id:        DefaultId,
						DownLimit: conf.MaxShares,
						PathLocal: "path/filename",
						State:     fileidx.StateUploading,
						Uploaded:  1,
					},
				},
			},
			Input: Input{
				ShareId: DefaultId,
			},
			Output: Output{
				IndexMap: map[string]*fileidx.FileInfo{
					DefaultId: &fileidx.FileInfo{
						Id:        DefaultId,
						DownLimit: conf.MaxShares,
						PathLocal: "path/filename",
						State:     fileidx.StateDone,
						Uploaded:  1,
					},
				},
				Response: &ShareInfo{
					ShareId: DefaultId,
				},
			},
		},
		testCase{
			// fix: this case exercises an EMPTY index, so the old
			// description "shareId exists" was backwards
			Desc: "shareId does not exist",
			Init: Init{
				IndexMap: map[string]*fileidx.FileInfo{},
			},
			Input: Input{
				ShareId: DefaultId,
			},
			Output: Output{
				IndexMap: map[string]*fileidx.FileInfo{},
				Response: httputil.Err404,
			},
		},
	}
	for _, testCase := range testCases {
		srv := initServiceForUploadTest(conf, testCase.Init.IndexMap)
		response := srv.finishUpload(testCase.ShareId)
		if !sameMap(srv.Index.List(), testCase.Output.IndexMap) {
			t.Fatalf("finishUpload: index not identical got: %v, want: %v", srv.Index.List(), testCase.Output.IndexMap)
		}
		switch res := response.(type) {
		case httputil.MsgRes:
			expectRes := testCase.Output.Response.(httputil.MsgRes)
			if res != expectRes {
				// fix: message typo "reponse" -> "response"
				t.Fatalf(fmt.Sprintf("finishUpload: response=%v expectRes=%v", res, expectRes))
			}
		case *ShareInfo:
			info, found := testCase.Output.IndexMap[res.ShareId]
			if !found || info.State != fileidx.StateDone {
				// TODO: should use isValidUrl or better to verify result
				t.Fatalf(fmt.Sprintf("finishUpload: share info is not correct: received: %v expect: %v", res.ShareId, testCase.ShareId))
			}
		default:
			t.Fatalf(fmt.Sprintf("finishUpload: type not found: %T %T", response, testCase.Output.Response))
		}
	}
}

251
server/libs/cfg/cfg.go Normal file
View file

@ -0,0 +1,251 @@
package cfg
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"strconv"
"strings"
)
// Config aggregates every tunable of the quickshare server. It is either
// built with defaults (NewConfig) or loaded from JSON (NewConfigFrom).
type Config struct {
	AppName       string
	AdminId       string
	AdminPwd      string
	SecretKey     string
	SecretKeyByte []byte `json:",omitempty"`
	// server
	Production bool
	HostName   string
	Port       int
	// performance
	MaxUpBytesPerSec   int64
	MaxDownBytesPerSec int64
	MaxRangeLength     int64
	Timeout            int // millisecond
	ReadTimeout        int
	WriteTimeout       int
	IdleTimeout        int
	WorkerPoolSize     int
	TaskQueueSize      int
	QueueSize          int
	ParseFormBufSize   int64
	MaxHeaderBytes     int
	// NOTE(review): NewConfig sets MaxShares to 1 << 31, which overflows
	// int on 32-bit platforms — confirm target platforms or shrink it.
	DownLimit      int
	MaxShares      int
	LocalFileLimit int
	// Cookie
	CookieDomain   string
	CookieHttpOnly bool
	CookieMaxAge   int
	CookiePath     string
	CookieSecure   bool
	// keys (form field / cookie names)
	KeyAdminId   string
	KeyAdminPwd  string
	KeyToken     string
	KeyFileName  string
	KeyFileSize  string
	KeyShareId   string
	KeyStart     string
	KeyLen       string
	KeyChunk     string
	KeyAct       string
	KeyExpires   string
	KeyDownLimit string
	// action names for the "act" form field
	ActStartUpload   string
	ActUpload        string
	ActFinishUpload  string
	ActLogin         string
	ActLogout        string
	ActShadowId      string
	ActPublishId     string
	ActSetDownLimit  string
	ActAddLocalFiles string
	// resource id
	AllUsers string
	// opIds
	OpIdIpVisit  int16
	OpIdUpload   int16
	OpIdDownload int16
	OpIdLogin    int16
	OpIdGetFInfo int16
	OpIdDelFInfo int16
	OpIdOpFInfo  int16
	// local
	PathLocal         string
	PathLogin         string
	PathDownloadLogin string
	PathDownload      string
	PathUpload        string
	PathStartUpload   string
	PathFinishUpload  string
	PathFileInfo      string
	PathClient        string
	// rate Limiter
	LimiterCap int64
	LimiterTtl int32
	LimiterCyc int32
	BucketCap  int16
	// SpecialCapsStr mirrors SpecialCaps with string keys because JSON
	// objects only allow string keys; NewConfigFrom converts it.
	SpecialCapsStr map[string]int16
	SpecialCaps    map[int16]int16
}
// NewConfig returns a Config populated with the built-in defaults.
func NewConfig() *Config {
	config := &Config{
		// secrets
		AppName:       "qs",
		AdminId:       "admin",
		AdminPwd:      "qs",
		SecretKey:     "qs",
		SecretKeyByte: []byte("qs"),
		// server
		Production: true,
		HostName:   "localhost",
		Port:       8888,
		// performance
		MaxUpBytesPerSec:   500 * 1000,
		MaxDownBytesPerSec: 500 * 1000,
		MaxRangeLength:     10 * 1024 * 1024,
		Timeout:            500, // millisecond,
		ReadTimeout:        500,
		WriteTimeout:       43200000,
		IdleTimeout:        10000,
		WorkerPoolSize:     2,
		TaskQueueSize:      2,
		QueueSize:          2,
		ParseFormBufSize:   600,
		MaxHeaderBytes:     1 << 15, // 32KB
		DownLimit:          -1,
		// NOTE(review): 1 << 31 overflows int on 32-bit builds
		MaxShares:      1 << 31,
		LocalFileLimit: -1,
		// Cookie
		CookieDomain:   "",
		CookieHttpOnly: false,
		CookieMaxAge:   3600 * 24 * 30, // 30 days (was mislabeled "one week")
		CookiePath:     "/",
		CookieSecure:   false,
		// keys
		KeyAdminId:       "adminid",
		KeyAdminPwd:      "adminpwd",
		KeyToken:         "token",
		KeyFileName:      "fname",
		KeyFileSize:      "size",
		KeyShareId:       "shareid",
		KeyStart:         "start",
		KeyLen:           "len",
		KeyChunk:         "chunk",
		KeyAct:           "act",
		KeyExpires:       "expires",
		KeyDownLimit:     "downlimit",
		ActStartUpload:   "startupload",
		ActUpload:        "upload",
		ActFinishUpload:  "finishupload",
		ActLogin:         "login",
		ActLogout:        "logout",
		ActShadowId:      "shadowid",
		ActPublishId:     "publishid",
		ActSetDownLimit:  "setdownlimit",
		ActAddLocalFiles: "addlocalfiles",
		AllUsers:         "allusers",
		// opIds
		OpIdIpVisit:  0,
		OpIdUpload:   1,
		OpIdDownload: 2,
		OpIdLogin:    3,
		OpIdGetFInfo: 4,
		OpIdDelFInfo: 5,
		OpIdOpFInfo:  6,
		// local
		PathLocal:         "files",
		PathLogin:         "/login",
		PathDownloadLogin: "/download-login",
		PathDownload:      "/download",
		PathUpload:        "/upload",
		PathStartUpload:   "/startupload",
		PathFinishUpload:  "/finishupload",
		PathFileInfo:      "/fileinfo",
		PathClient:        "/",
		// rate Limiter
		LimiterCap: 256,  // how many op supported for each user
		LimiterTtl: 3600, // second
		LimiterCyc: 1,    // second
		BucketCap:  3,    // how many op can do per LimiterCyc sec
		SpecialCaps: map[int16]int16{
			0: 5, // ip
			1: 1, // upload
			2: 1, // download
			3: 1, // login
		},
	}
	return config
}
// NewConfigFrom loads a Config from the JSON file at path, converts the
// string-keyed SpecialCapsStr into SpecialCaps, derives SecretKeyByte, and
// auto-picks a listen address when HostName is empty. It panics on any
// failure, matching the original behavior (config errors are fatal).
func NewConfigFrom(path string) *Config {
	configBytes, readErr := ioutil.ReadFile(path)
	if readErr != nil {
		panic(fmt.Sprintf("config file not found: %s", path))
	}
	config := &Config{}
	// fix: check the unmarshal error BEFORE using the config; previously
	// the SpecialCapsStr loop ran on a possibly half-populated struct.
	if marshalErr := json.Unmarshal(configBytes, config); marshalErr != nil {
		panic("config file format is incorrect")
	}
	// TODO: look for a better solution
	config.SpecialCaps = make(map[int16]int16)
	for strKey, value := range config.SpecialCapsStr {
		key, parseKeyErr := strconv.ParseInt(strKey, 10, 16)
		if parseKeyErr != nil {
			panic("fail to parse SpecialCapsStr, its type should be map[int16]int16")
		}
		config.SpecialCaps[int16(key)] = value
	}
	config.SecretKeyByte = []byte(config.SecretKey)
	if config.HostName == "" {
		hostName, err := GetLocalAddr()
		if err != nil {
			panic(err)
		}
		config.HostName = hostName.String()
	}
	return config
}
// GetLocalAddr returns the first IPv4 address of a non-loopback,
// non-docker network interface, for use as the listen host when none is
// configured. It panics when interfaces cannot be enumerated (matching
// the original behavior) and returns an error only when no address fits.
func GetLocalAddr() (net.IP, error) {
	fmt.Println(`config.HostName is empty(""), choose one IP for listening automatically.`)
	infs, err := net.Interfaces()
	if err != nil {
		panic("fail to get net interfaces")
	}
	for _, inf := range infs {
		// fix: name the magic flag — the old code tested `inf.Flags&4 != 4`,
		// which is exactly net.FlagLoopback
		if inf.Flags&net.FlagLoopback != 0 || strings.Contains(inf.Name, "docker") {
			continue
		}
		addrs, err := inf.Addrs()
		if err != nil {
			panic("fail to get addrs of interface")
		}
		for _, addr := range addrs {
			var ip net.IP
			switch v := addr.(type) {
			case *net.IPAddr:
				ip = v.IP
			case *net.IPNet:
				ip = v.IP
			}
			// an IPv4 address has no ":" in its string form
			if ip != nil && !strings.Contains(ip.String(), ":") {
				return ip, nil
			}
		}
	}
	return nil, errors.New("no addr found")
}

View file

@ -0,0 +1,17 @@
package encrypt
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
)
// HmacEncryptor signs content with HMAC-SHA256 under Key.
type HmacEncryptor struct {
	Key []byte
}

// Encrypt returns the hex-encoded HMAC-SHA256 digest of content.
func (encryptor *HmacEncryptor) Encrypt(content []byte) string {
	digest := hmac.New(sha256.New, encryptor.Key)
	digest.Write(content)
	return hex.EncodeToString(digest.Sum(nil))
}

View file

@ -0,0 +1,5 @@
package encrypt
// Encryptor produces a string digest/signature for arbitrary content.
type Encryptor interface {
	Encrypt(content []byte) string
}

View file

@ -0,0 +1,53 @@
package encrypt
import (
"github.com/robbert229/jwt"
)
// JwtEncrypterMaker is an EncrypterMaker building HMAC-SHA256-signed JWT
// token encrypters from a secret.
func JwtEncrypterMaker(secret string) TokenEncrypter {
	return &JwtEncrypter{
		alg:    jwt.HmacSha256(secret),
		claims: jwt.NewClaim(),
	}
}
// JwtEncrypter wraps the jwt library behind the TokenEncrypter interface,
// accumulating claims and encoding/decoding them as a signed token.
type JwtEncrypter struct {
	alg    jwt.Algorithm
	claims *jwt.Claims
}

// Add stores a string claim under key. It always reports success.
func (encrypter *JwtEncrypter) Add(key string, value string) bool {
	encrypter.claims.Set(key, value)
	return true
}

// FromStr replaces the current claims with those decoded from token,
// reporting false on an invalid/unverifiable token.
func (encrypter *JwtEncrypter) FromStr(token string) bool {
	claims, err := encrypter.alg.Decode(token)
	// TODO: should return error or error info will lost
	if err != nil {
		return false
	}
	encrypter.claims = claims
	return true
}

// Get returns the string claim stored under key, or ("", false).
func (encrypter *JwtEncrypter) Get(key string) (string, bool) {
	iValue, err := encrypter.claims.Get(key)
	// TODO: should return error or error info will lost
	if err != nil {
		return "", false
	}
	return iValue.(string), true
}

// ToStr encodes the current claims into a signed token string.
func (encrypter *JwtEncrypter) ToStr() (string, bool) {
	token, err := encrypter.alg.Encode(encrypter.claims)
	// TODO: should return error or error info will lost
	if err != nil {
		return "", false
	}
	return token, true
}

View file

@ -0,0 +1,11 @@
package encrypt
// EncrypterMaker builds a TokenEncrypter from a secret string.
type EncrypterMaker func(string) TokenEncrypter

// TODO: name should be Encrypter?
// TokenEncrypter accumulates key/value claims and converts them to and
// from a signed token string.
type TokenEncrypter interface {
	Add(string, string) bool
	FromStr(string) bool
	Get(string) (string, bool)
	ToStr() (string, bool)
}

View file

@ -0,0 +1,59 @@
package errutil
import (
"os"
"runtime/debug"
)
import (
"quickshare/server/libs/logutil"
)
// ErrUtil centralizes error checking, logging and panic recovery.
type ErrUtil interface {
	IsErr(err error) bool
	IsFatalErr(err error) bool
	RecoverPanic()
}
// NewErrChecker returns an ErrUtil that logs through logger; when
// logStack is true a stack trace accompanies every logged error.
func NewErrChecker(logStack bool, logger logutil.LogUtil) ErrUtil {
	return &ErrChecker{logStack: logStack, log: logger}
}

// ErrChecker is the default ErrUtil implementation.
type ErrChecker struct {
	log      logutil.LogUtil
	logStack bool // also log a stack trace with each error
}
// IsErr reports whether err is non-nil, logging it (and, when configured,
// the stack trace) as a side effect.
func (e *ErrChecker) IsErr(err error) bool {
	if err != nil {
		e.log.Printf("Error:%q\n", err)
		if e.logStack {
			// fix: debug.Stack() is a []byte; passing it to Println raw
			// printed a list of byte values instead of the trace text
			e.log.Println(string(debug.Stack()))
		}
		return true
	}
	return false
}
// IsFatalErr exits the process after logging when fe is non-nil; it only
// returns (false) when fe is nil.
func (e *ErrChecker) IsFatalErr(fe error) bool {
	if fe != nil {
		e.log.Printf("Panic:%q", fe)
		if e.logStack {
			// fix: render the []byte stack as text, not byte values
			e.log.Println(string(debug.Stack()))
		}
		os.Exit(1)
	}
	return false
}
// RecoverPanic catches a panic and logs it; it must be invoked via defer
// in the goroutine to be protected.
func (e *ErrChecker) RecoverPanic() {
	if r := recover(); r != nil {
		e.log.Printf("Recovered:%v", r)
		if e.logStack {
			// fix: render the []byte stack as text, not byte values
			e.log.Println(string(debug.Stack()))
		}
	}
}

View file

@ -0,0 +1,177 @@
package fileidx
import (
"sync"
)
const (
	// StateStarted = after startUpload before upload
	StateStarted = "started"
	// StateUploading =after upload before finishUpload
	StateUploading = "uploading"
	// StateDone = after finishedUpload
	StateDone = "done"
)

// FileInfo describes one shared file tracked by the index.
type FileInfo struct {
	Id        string // share id (HMAC of the name in production)
	DownLimit int    // remaining downloads; negative means unlimited
	ModTime   int64  // creation/modification time in ns
	PathLocal string // file name relative to the local folder
	State     string // one of the State* constants above
	Uploaded  int64  // bytes received so far
}

// FileIndex is the store of shared-file metadata keyed by share id.
type FileIndex interface {
	Add(fileInfo *FileInfo) int
	Del(id string)
	SetId(id string, newId string) bool
	SetDownLimit(id string, downLimit int) bool
	DecrDownLimit(id string) (int, bool)
	SetState(id string, state string) bool
	IncrUploaded(id string, uploaded int64) int64
	Get(id string) (*FileInfo, bool)
	List() map[string]*FileInfo
}
// NewMemFileIndex returns an empty in-memory index holding at most cap
// entries.
func NewMemFileIndex(cap int) *MemFileIndex {
	return &MemFileIndex{
		cap:   cap,
		infos: make(map[string]*FileInfo, 0),
	}
}

// NewMemFileIndexWithMap returns an index pre-seeded with infos (used by
// tests); the map is adopted, not copied.
func NewMemFileIndexWithMap(cap int, infos map[string]*FileInfo) *MemFileIndex {
	return &MemFileIndex{
		cap:   cap,
		infos: infos,
	}
}
// MemFileIndex is a mutex-guarded, fixed-capacity, in-memory FileIndex.
type MemFileIndex struct {
	cap   int
	infos map[string]*FileInfo
	mux   sync.RWMutex
}

// Add inserts fileInfo keyed by its Id.
// Returns 0 on success, 1 when the index is full, -1 when the id exists.
func (idx *MemFileIndex) Add(fileInfo *FileInfo) int {
	idx.mux.Lock()
	defer idx.mux.Unlock()
	if len(idx.infos) >= idx.cap {
		return 1
	}
	if _, found := idx.infos[fileInfo.Id]; found {
		return -1
	}
	idx.infos[fileInfo.Id] = fileInfo
	return 0
}

// Del removes id; removing an absent id is a no-op.
func (idx *MemFileIndex) Del(id string) {
	idx.mux.Lock()
	defer idx.mux.Unlock()
	delete(idx.infos, id)
}

// SetId rekeys the entry from id to newId, failing when id is missing or
// newId is already taken.
func (idx *MemFileIndex) SetId(id string, newId string) bool {
	if id == newId {
		return true
	}
	idx.mux.Lock()
	defer idx.mux.Unlock()
	info, found := idx.infos[id]
	if !found {
		return false
	}
	if _, foundNewId := idx.infos[newId]; foundNewId {
		return false
	}
	idx.infos[newId] = info
	idx.infos[newId].Id = newId
	delete(idx.infos, id)
	return true
}

// SetDownLimit overwrites the download limit of id.
func (idx *MemFileIndex) SetDownLimit(id string, downLimit int) bool {
	idx.mux.Lock()
	defer idx.mux.Unlock()
	info, found := idx.infos[id]
	if !found {
		return false
	}
	info.DownLimit = downLimit
	return true
}

// DecrDownLimit consumes one download for id. The first result is 1 when
// the entry exists and is done; the second reports whether a download was
// granted (false when the limit is exhausted).
func (idx *MemFileIndex) DecrDownLimit(id string) (int, bool) {
	idx.mux.Lock()
	defer idx.mux.Unlock()
	info, found := idx.infos[id]
	if !found || info.State != StateDone {
		return 0, false
	}
	if info.DownLimit == 0 {
		return 1, false
	}
	if info.DownLimit > 0 {
		// a negative info.DownLimit means unlimited, so only positive
		// limits are decremented
		info.DownLimit = info.DownLimit - 1
	}
	return 1, true
}

// SetState overwrites the state of id.
func (idx *MemFileIndex) SetState(id string, state string) bool {
	idx.mux.Lock()
	defer idx.mux.Unlock()
	info, found := idx.infos[id]
	if !found {
		return false
	}
	info.State = state
	return true
}

// IncrUploaded adds uploaded bytes to id's counter and returns the new
// total, or 0 when the id is unknown.
func (idx *MemFileIndex) IncrUploaded(id string, uploaded int64) int64 {
	idx.mux.Lock()
	defer idx.mux.Unlock()
	info, found := idx.infos[id]
	if !found {
		return 0
	}
	info.Uploaded = info.Uploaded + uploaded
	return info.Uploaded
}

// Get returns the info stored for id.
func (idx *MemFileIndex) Get(id string) (*FileInfo, bool) {
	idx.mux.RLock()
	defer idx.mux.RUnlock()
	infos, found := idx.infos[id]
	return infos, found
}

// List returns the internal map itself (not a copy).
// NOTE(review): callers must not mutate it, and iterating it outside the
// lock can race with concurrent writers — confirm call sites.
func (idx *MemFileIndex) List() map[string]*FileInfo {
	idx.mux.RLock()
	defer idx.mux.RUnlock()
	return idx.infos
}
// TODO: add unit tests

View file

@ -0,0 +1,118 @@
package fsutil
import (
"errors"
"io"
"os"
)
import (
"quickshare/server/libs/errutil"
"quickshare/server/libs/fileidx"
"quickshare/server/libs/qtube"
)
// FsUtil abstracts the disk operations the service needs, so tests can
// stub them out.
type FsUtil interface {
	CreateFile(fullPath string) error
	CopyChunkN(fullPath string, chunk io.Reader, start int64, length int64) bool
	DelFile(fullPath string) bool
	Open(fullPath string) (qtube.ReadSeekCloser, error)
	MkdirAll(path string, mode os.FileMode) bool
	Readdir(dirName string, n int) ([]*fileidx.FileInfo, error)
}

// NewSimpleFs returns the default disk-backed FsUtil.
func NewSimpleFs(errUtil errutil.ErrUtil) FsUtil {
	return &SimpleFs{
		Err: errUtil,
	}
}

// SimpleFs implements FsUtil directly on top of the os package.
type SimpleFs struct {
	Err errutil.ErrUtil // logs and classifies errors
}

// Sentinel errors returned by CreateFile.
var (
	ErrExists  = errors.New("file exists")
	ErrUnknown = errors.New("unknown error")
)
// CreateFile creates an empty file at fullPath, failing with ErrExists
// when the file is already there (O_EXCL) and ErrUnknown otherwise.
func (sfs *SimpleFs) CreateFile(fullPath string) error {
	flag := os.O_CREATE | os.O_EXCL | os.O_RDONLY
	perm := os.FileMode(0644)
	newFile, err := os.OpenFile(fullPath, flag, perm)
	// fix: only close on success; the old code deferred Close before the
	// error check, invoking Close on a nil *os.File when OpenFile failed
	if err == nil {
		newFile.Close()
		return nil
	}
	if os.IsExist(err) {
		return ErrExists
	}
	return ErrUnknown
}
// CopyChunkN writes up to length bytes from chunk into fullPath starting
// at offset start, reporting success.
func (sfs *SimpleFs) CopyChunkN(fullPath string, chunk io.Reader, start int64, length int64) bool {
	flag := os.O_WRONLY
	perm := os.FileMode(0644)
	file, openErr := os.OpenFile(fullPath, flag, perm)
	if sfs.Err.IsErr(openErr) {
		return false
	}
	// fix: defer Close only after a successful open (was deferred on a
	// nil handle when OpenFile failed)
	defer file.Close()
	if _, err := file.Seek(start, io.SeekStart); sfs.Err.IsErr(err) {
		return false
	}
	// fix: test for io.EOF first so a short (but valid) final chunk is not
	// logged as an error by IsErr
	if _, err := io.CopyN(file, chunk, length); err != io.EOF && sfs.Err.IsErr(err) {
		return false
	}
	return true
}
// DelFile removes fullPath, reporting success.
func (sfs *SimpleFs) DelFile(fullPath string) bool {
	return !sfs.Err.IsErr(os.Remove(fullPath))
}

// MkdirAll creates path (and parents) with mode, reporting success.
func (sfs *SimpleFs) MkdirAll(path string, mode os.FileMode) bool {
	err := os.MkdirAll(path, mode)
	return !sfs.Err.IsErr(err)
}
// TODO: not support read from last seek position
// Readdir lists up to n regular files in dirName as FileInfos carrying
// only ModTime, name and size. Like os.File.Readdir, it may return io.EOF
// alongside results when the directory is exhausted.
func (sfs *SimpleFs) Readdir(dirName string, n int) ([]*fileidx.FileInfo, error) {
	dir, openErr := os.Open(dirName)
	if sfs.Err.IsErr(openErr) {
		return []*fileidx.FileInfo{}, openErr
	}
	// fix: defer Close only after a successful open (was deferred on a
	// nil handle when Open failed)
	defer dir.Close()
	osFileInfos, readErr := dir.Readdir(n)
	if sfs.Err.IsErr(readErr) && readErr != io.EOF {
		return []*fileidx.FileInfo{}, readErr
	}
	// pre-size: at most one entry per directory listing result
	fileInfos := make([]*fileidx.FileInfo, 0, len(osFileInfos))
	for _, osFileInfo := range osFileInfos {
		if osFileInfo.Mode().IsRegular() {
			fileInfos = append(
				fileInfos,
				&fileidx.FileInfo{
					ModTime:   osFileInfo.ModTime().UnixNano(),
					PathLocal: osFileInfo.Name(),
					Uploaded:  osFileInfo.Size(),
				},
			)
		}
	}
	return fileInfos, readErr
}
// Open opens fullPath for reading.
// the associated file descriptor has mode O_RDONLY as using os.Open
func (sfs *SimpleFs) Open(fullPath string) (qtube.ReadSeekCloser, error) {
	return os.Open(fullPath)
}

View file

@ -0,0 +1,84 @@
package httputil
import (
"encoding/json"
"net/http"
"time"
)
import (
"quickshare/server/libs/errutil"
)
// MsgRes is the canonical JSON response body: an HTTP status code plus a
// human-readable message.
type MsgRes struct {
	Code int
	Msg  string
}

// Canned responses keyed by HTTP status, shared across handlers.
var (
	Err400 = MsgRes{Code: http.StatusBadRequest, Msg: "Bad Request"}
	Err401 = MsgRes{Code: http.StatusUnauthorized, Msg: "Unauthorized"}
	Err404 = MsgRes{Code: http.StatusNotFound, Msg: "Not Found"}
	Err412 = MsgRes{Code: http.StatusPreconditionFailed, Msg: "Precondition Failed"}
	Err429 = MsgRes{Code: http.StatusTooManyRequests, Msg: "Too Many Requests"}
	Err500 = MsgRes{Code: http.StatusInternalServerError, Msg: "Internal Server Error"}
	Err503 = MsgRes{Code: http.StatusServiceUnavailable, Msg: "Service Unavailable"}
	Err504 = MsgRes{Code: http.StatusGatewayTimeout, Msg: "Gateway Timeout"}
	Ok200  = MsgRes{Code: http.StatusOK, Msg: "OK"}
)
// HttpUtil groups cookie handling and JSON response writing.
type HttpUtil interface {
	GetCookie(cookies []*http.Cookie, key string) string
	SetCookie(res http.ResponseWriter, key string, val string)
	Fill(msg interface{}, res http.ResponseWriter) int
}

// QHttpUtil implements HttpUtil using cookie attributes copied from the
// server configuration.
type QHttpUtil struct {
	CookieDomain   string
	CookieHttpOnly bool
	CookieMaxAge   int
	CookiePath     string
	CookieSecure   bool
	Err            errutil.ErrUtil
}
// GetCookie returns the value of the cookie named key, or "" when no
// such cookie is present.
func (q *QHttpUtil) GetCookie(cookies []*http.Cookie, key string) string {
	for _, c := range cookies {
		if c.Name != key {
			continue
		}
		return c.Value
	}
	return ""
}
// SetCookie attaches a cookie (built from the configured attributes) to
// res. It uses http.SetCookie, which ADDS a Set-Cookie header; the old
// Header().Set call silently clobbered any cookie set earlier on the
// same response.
func (q *QHttpUtil) SetCookie(res http.ResponseWriter, key string, val string) {
	cookie := &http.Cookie{
		Name:     key,
		Value:    val,
		Domain:   q.CookieDomain,
		Expires:  time.Now().Add(time.Duration(q.CookieMaxAge) * time.Second),
		HttpOnly: q.CookieHttpOnly,
		MaxAge:   q.CookieMaxAge,
		Secure:   q.CookieSecure,
		Path:     q.CookiePath,
	}
	http.SetCookie(res, cookie)
}
// Fill marshals msg to JSON and writes it to res, returning the number
// of bytes written; 0 signals nil msg, marshal failure or write failure.
func (q *QHttpUtil) Fill(msg interface{}, res http.ResponseWriter) int {
	if msg == nil {
		return 0
	}
	msgBytes, marsErr := json.Marshal(msg)
	if q.Err.IsErr(marsErr) {
		return 0
	}
	wrote, writeErr := res.Write(msgBytes)
	if q.Err.IsErr(writeErr) {
		return 0
	}
	return wrote
}

View file

@ -0,0 +1,130 @@
package httpworker
import (
"errors"
"net/http"
"runtime/debug"
"time"
)
import (
"quickshare/server/libs/logutil"
)
var (
	// ErrWorkerNotFound is acked when a Task has no Do function.
	ErrWorkerNotFound = errors.New("worker not found")
	// ErrTimeout is returned by IsInTime when the ack does not arrive.
	ErrTimeout = errors.New("timeout")
)

// DoFunc is the unit of work a Worker executes.
type DoFunc func(http.ResponseWriter, *http.Request)

// Task couples one handler invocation with the channel used to ack its
// completion.
type Task struct {
	Ack chan error
	Do  DoFunc
	Res http.ResponseWriter
	Req *http.Request
}

// Workers is a bounded pool executing Tasks asynchronously.
type Workers interface {
	Put(*Task) bool
	IsInTime(ack chan error, msec time.Duration) error
}

// WorkerPool is the default Workers: a shared queue drained by a fixed
// set of goroutines.
type WorkerPool struct {
	queue   chan *Task
	size    int
	workers []*Worker
	log     logutil.LogUtil // TODO: should not pass log here
}
// NewWorkerPool starts poolSize worker goroutines draining a shared
// queue of capacity queueSize.
func NewWorkerPool(poolSize int, queueSize int, log logutil.LogUtil) Workers {
	queue := make(chan *Task, queueSize)
	workers := make([]*Worker, 0, poolSize)
	for i := 0; i < poolSize; i++ {
		worker := &Worker{
			Id:    uint64(i),
			queue: queue,
			log:   log,
		}
		go worker.Start()
		workers = append(workers, worker)
	}
	return &WorkerPool{
		queue:   queue,
		size:    poolSize,
		workers: workers,
		log:     log,
	}
}
// Put submits a task to the shared queue without blocking. It returns
// false when the queue is full, signalling the caller to shed load.
func (pool *WorkerPool) Put(task *Task) bool {
	// A non-blocking send replaces the original len(queue) >= pool.size
	// check, which (a) compared against the worker count rather than the
	// queue capacity and (b) raced with concurrent producers between the
	// length check and the send.
	select {
	case pool.queue <- task:
		return true
	default:
		return false
	}
}
// IsInTime waits up to msec for a completion signal on ack, logging the
// elapsed time. It returns the worker's error, or ErrTimeout when no
// ack arrives within the deadline.
func (pool *WorkerPool) IsInTime(ack chan error, msec time.Duration) error {
	start := time.Now().UnixNano()
	// time.After replaces the original hand-rolled sleeper goroutine,
	// which sent on an unbuffered channel nobody would ever read when
	// the ack won the race — leaking one goroutine per fast request.
	select {
	case err := <-ack:
		if err == nil {
			pool.log.Printf(
				"finish cost: %d usec",
				(time.Now().UnixNano()-start)/1000,
			)
		} else {
			pool.log.Printf(
				"finish with error cost: %d usec",
				(time.Now().UnixNano()-start)/1000,
			)
		}
		return err
	case <-time.After(msec):
		pool.log.Printf("timeout cost: %d usec", (time.Now().UnixNano()-start)/1000)
		return ErrTimeout
	}
}
// Worker is one consumer of the pool's shared task queue.
type Worker struct {
	Id    uint64
	queue chan *Task
	log   logutil.LogUtil
}
// RecoverPanic is deferred by Start: it logs a recovered panic with its
// stack and restarts the worker loop so the pool keeps its full worker
// count.
func (worker *Worker) RecoverPanic() {
	if r := recover(); r != nil {
		worker.log.Printf("Recovered:%v stack: %v", r, debug.Stack())
		// restart worker and IsInTime will return timeout error for last task
		// NOTE(review): the panicking task's Ack is never sent, and each
		// recovered panic adds one frame (Start -> deferred RecoverPanic
		// -> Start) to the stack — confirm this is acceptable.
		worker.Start()
	}
}
// Start drains the shared queue forever: each task's handler runs and
// completion is signalled on its Ack channel. A panicking handler is
// caught by the deferred RecoverPanic, which restarts the loop.
func (worker *Worker) Start() {
	defer worker.RecoverPanic()
	for {
		task := <-worker.queue
		if task.Do == nil {
			task.Ack <- ErrWorkerNotFound
			continue
		}
		task.Do(task.Res, task.Req)
		task.Ack <- nil
	}
}
// ServiceFunc lets you return struct directly: a handler produces a
// value that the wrapping layer serializes into the HTTP response.
type ServiceFunc func(http.ResponseWriter, *http.Request) interface{}

View file

@ -0,0 +1,5 @@
package limiter
// Limiter grants or denies one operation (identified by an op id) for a
// resource id — e.g. a token-bucket rate limiter.
type Limiter interface {
	Access(string, int16) bool
}

View file

@ -0,0 +1,220 @@
package limiter
import (
"sync"
"time"
)
func now() int32 {
return int32(time.Now().Unix())
}
func afterCyc(cyc int32) int32 {
return int32(time.Now().Unix()) + cyc
}
func afterTtl(ttl int32) int32 {
return int32(time.Now().Unix()) + ttl
}
// Bucket is a token bucket: Tokens may be spent until the Refresh
// deadline (unix seconds), after which the bucket is refilled.
type Bucket struct {
	Refresh int32
	Tokens  int16
}

// NewBucket returns a bucket holding cap tokens whose refresh deadline
// is cyc seconds from now.
func NewBucket(cyc int32, cap int16) *Bucket {
	bucket := &Bucket{}
	bucket.Refresh = afterCyc(cyc)
	bucket.Tokens = cap
	return bucket
}
// Item is one rate-limited subject: its expiry time (unix seconds) plus
// one token bucket per operation id.
type Item struct {
	Expired int32
	Buckets map[int16]*Bucket
}

// NewItem returns an empty item that expires ttl seconds from now.
func NewItem(ttl int32) *Item {
	item := &Item{}
	item.Expired = afterTtl(ttl)
	item.Buckets = map[int16]*Bucket{}
	return item
}
// RateLimiter is a capacity-bounded, token-bucket Limiter with periodic
// background cleanup of expired items.
type RateLimiter struct {
	items      map[string]*Item
	bucketCap  int16           // default tokens per bucket per cycle
	customCaps map[int16]int16 // per-op overrides of bucketCap
	cap        int64           // max number of tracked items
	cyc        int32 // how much time, item autoclean will be executed, bucket will be refreshed
	ttl        int32 // how much time, item will be expired(but not cleaned)
	mux        sync.RWMutex
	snapshot   map[string]map[int16]*Bucket // last view captured by clean()
}
// NewRateLimiter builds a limiter tracking at most cap items: items
// expire after ttl seconds, buckets refresh every cyc seconds, and
// customCaps overrides bucketCap per operation id. It panics on
// non-positive sizes and starts the background cleaner goroutine.
func NewRateLimiter(cap int64, ttl int32, cyc int32, bucketCap int16, customCaps map[int16]int16) Limiter {
	if cap < 1 || ttl < 1 || cyc < 1 || bucketCap < 1 {
		panic("cap | bucketCap | ttl | cycle cant be less than 1")
	}
	rl := &RateLimiter{
		items:      make(map[string]*Item, cap),
		bucketCap:  bucketCap,
		customCaps: customCaps,
		cap:        cap,
		ttl:        ttl,
		cyc:        cyc,
	}
	go rl.autoClean()
	return rl
}
// getBucketCap returns the bucket capacity for opId, preferring a
// custom per-op capacity over the default.
func (limiter *RateLimiter) getBucketCap(opId int16) int16 {
	if custom, ok := limiter.customCaps[opId]; ok {
		return custom
	}
	return limiter.bucketCap
}
// Access spends one token from itemId's bucket for opId, creating the
// item and/or bucket on first sight (pre-charged for this access). It
// returns false when the limiter is full (unknown item) or the bucket
// has no tokens left in the current cycle.
func (limiter *RateLimiter) Access(itemId string, opId int16) bool {
	limiter.mux.Lock()
	defer limiter.mux.Unlock()
	item, itemExisted := limiter.items[itemId]
	if !itemExisted {
		// Unknown item: refuse at capacity, otherwise create it with a
		// bucket that already paid for this access (cap-1 tokens left).
		if int64(len(limiter.items)) >= limiter.cap {
			return false
		}
		limiter.items[itemId] = NewItem(limiter.ttl)
		limiter.items[itemId].Buckets[opId] = NewBucket(limiter.cyc, limiter.getBucketCap(opId)-1)
		return true
	}
	bucket, bucketExisted := item.Buckets[opId]
	if !bucketExisted {
		// First operation of this kind for an existing item.
		item.Buckets[opId] = NewBucket(limiter.cyc, limiter.getBucketCap(opId)-1)
		return true
	}
	if bucket.Refresh > now() {
		// Still inside the current cycle: spend a token if any remain.
		if bucket.Tokens > 0 {
			bucket.Tokens--
			return true
		}
		return false
	}
	// Cycle elapsed: refill (minus this access) and start a new cycle.
	// NOTE(review): item.Expired is not extended on access — confirm
	// items are meant to expire regardless of activity.
	bucket.Refresh = afterCyc(limiter.cyc)
	bucket.Tokens = limiter.getBucketCap(opId) - 1
	return true
}
// GetCap returns the maximum number of items the limiter tracks.
// NOTE(review): read without the lock while ExpandCap may write — verify.
func (limiter *RateLimiter) GetCap() int64 {
	return limiter.cap
}
// GetSize returns the current number of tracked items.
func (limiter *RateLimiter) GetSize() int64 {
	limiter.mux.RLock()
	defer limiter.mux.RUnlock()
	return int64(len(limiter.items))
}
// ExpandCap raises the item capacity to cap; it refuses values at or
// below the current item count and reports whether the cap changed.
func (limiter *RateLimiter) ExpandCap(cap int64) bool {
	// Write lock: this method mutates limiter.cap. The original took
	// only RLock, making the write a data race against readers.
	limiter.mux.Lock()
	defer limiter.mux.Unlock()
	if cap <= int64(len(limiter.items)) {
		return false
	}
	limiter.cap = cap
	return true
}
// GetTTL returns the item time-to-live in seconds.
func (limiter *RateLimiter) GetTTL() int32 {
	return limiter.ttl
}
// UpdateTTL sets the item time-to-live in seconds; values below 1 are
// rejected and it reports whether the update was applied.
func (limiter *RateLimiter) UpdateTTL(ttl int32) bool {
	if ttl >= 1 {
		limiter.ttl = ttl
		return true
	}
	return false
}
// GetCyc returns the refresh/cleanup cycle length in seconds.
func (limiter *RateLimiter) GetCyc() int32 {
	return limiter.cyc
}
// UpdateCyc sets the refresh/cleanup cycle in seconds; values below 1
// are rejected and it reports whether the update was applied.
func (limiter *RateLimiter) UpdateCyc(cyc int32) bool {
	// Validate the incoming argument. The original tested limiter.cyc
	// (always >= 1 after construction), so invalid values were accepted.
	if cyc < 1 {
		return false
	}
	limiter.cyc = cyc
	return true
}
// Snapshot returns the bucket view captured by the most recent clean()
// pass. NOTE(review): read without the lock while clean() replaces the
// map — confirm callers tolerate the race.
func (limiter *RateLimiter) Snapshot() map[string]map[int16]*Bucket {
	return limiter.snapshot
}
// autoClean evicts expired items every cyc seconds for the limiter's
// lifetime. cyc is validated >= 1 at construction, so the zero guard is
// effectively a safety net.
func (limiter *RateLimiter) autoClean() {
	for limiter.cyc != 0 {
		time.Sleep(time.Duration(limiter.cyc) * time.Second)
		limiter.clean()
	}
}
// clean drops expired items and snapshots the buckets of the survivors.
// It runs from autoClean every cycle; clean may affect other operations,
// so consider how frequently it should run.
func (limiter *RateLimiter) clean() {
	limiter.snapshot = make(map[string]map[int16]*Bucket)
	now := now()
	// Write lock: deleting map entries is a mutation. The original held
	// only RLock here, a data race against concurrent Access writers.
	limiter.mux.Lock()
	defer limiter.mux.Unlock()
	for key, item := range limiter.items {
		if item.Expired <= now {
			delete(limiter.items, key)
		} else {
			limiter.snapshot[key] = item.Buckets
		}
	}
}
// exist reports whether id is currently tracked. Only for test.
func (limiter *RateLimiter) exist(id string) bool {
	limiter.mux.RLock()
	defer limiter.mux.RUnlock()
	_, existed := limiter.items[id]
	return existed
}
// truncate removes every tracked item. Only for test.
func (limiter *RateLimiter) truncate() {
	// Write lock: deleting entries mutates the map. The original held
	// only RLock here — a data race against concurrent accessors.
	limiter.mux.Lock()
	defer limiter.mux.Unlock()
	for key := range limiter.items {
		delete(limiter.items, key)
	}
}
// get returns the item for id and whether it exists. Only for test.
func (limiter *RateLimiter) get(id string) (*Item, bool) {
	limiter.mux.RLock()
	defer limiter.mux.RUnlock()
	item, existed := limiter.items[id]
	return item, existed
}

View file

@ -0,0 +1,161 @@
package limiter
import (
"fmt"
"math/rand"
"testing"
"time"
)
// Shared fixtures for the limiter tests.
// NOTE(review): rnd and rndCap appear unused in this file — confirm.
var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))

const rndCap = 10000
const addCap = 1

// how to set time
// extend: wait can be greater than ttl/2
// cyc is smaller than ttl and wait, then it can be clean in time
const cap = 40 // NOTE(review): shadows the builtin cap inside this package
const ttl = 3
const cyc = 1
const bucketCap = 2

const id1 = "id1"
const id2 = "id2"

const op1 int16 = 0
const op2 int16 = 1

// op2 gets a large custom capacity so it effectively never throttles.
var customCaps = map[int16]int16{
	op2: 1000,
}

const wait = 1

// limiter is the single shared instance exercised by all tests below.
var limiter = NewRateLimiter(cap, ttl, cyc, bucketCap, customCaps).(*RateLimiter)
// printItem dumps the limiter state of the given id to stdout (debug aid).
func printItem(id string) {
	// Look up the id argument; the original accidentally hard-coded id1,
	// so every call printed the same item.
	item, existed := limiter.get(id)
	if existed {
		fmt.Println("expired, now, existed", item.Expired, now(), existed)
		for id, bucket := range item.Buckets {
			fmt.Println("\tid, bucket", id, bucket)
		}
	} else {
		fmt.Println("not existed")
	}
}
// idCounter backs randId; despite the name, ids are sequential, which
// guarantees uniqueness within one test run.
var idCounter = 0

// randId returns a fresh id string for limiter entries.
func randId() string {
	idCounter++
	return fmt.Sprintf("%d", idCounter)
}
// TestAccess verifies the token-bucket behavior of Access: the first
// call passes, the bucket then drains to denial, and tokens refill
// after one cycle.
func TestAccess(t *testing.T) {
	allowed := limiter.Access(id1, op1)
	if !allowed {
		t.Fatal("access: fail")
	}
	// Drain the remaining tokens (and one extra attempt).
	for i := 0; i < bucketCap; i++ {
		allowed = limiter.Access(id1, op1)
	}
	if allowed {
		t.Fatal("access: fail to deny access")
	}
	// After a full cycle the bucket must be refilled.
	time.Sleep(time.Duration(limiter.GetCyc()) * time.Second)
	allowed = limiter.Access(id1, op1)
	if !allowed {
		t.Fatal("access: fail to refresh tokens")
	}
}
// TestCap expands the capacity, rejects shrinking below the live size,
// fills the limiter to capacity, and checks the next new item is denied.
func TestCap(t *testing.T) {
	originalCap := limiter.GetCap()
	fmt.Printf("cap:info: %d\n", originalCap)
	ok := limiter.ExpandCap(originalCap + addCap)
	if !ok || limiter.GetCap() != originalCap+addCap {
		t.Fatal("cap: fail to expand")
	}
	// Shrinking below the current item count must be refused.
	ok = limiter.ExpandCap(limiter.GetSize() - addCap)
	if ok {
		t.Fatal("cap: shrink cap")
	}
	// Fill to capacity with fresh ids.
	ids := []string{}
	for limiter.GetSize() < limiter.GetCap() {
		id := randId()
		ids = append(ids, id)
		ok := limiter.Access(id, 0)
		if !ok {
			t.Fatal("cap: not full")
		}
	}
	if limiter.GetSize() != limiter.GetCap() {
		t.Fatal("cap: incorrect size")
	}
	// One more brand-new id must be rejected at capacity.
	if limiter.Access(randId(), 0) {
		t.Fatal("cap: more than cap")
	}
	limiter.truncate()
}
// TestTtl verifies UpdateTTL persists a new ttl value.
func TestTtl(t *testing.T) {
	var addTtl int32 = 1
	originalTTL := limiter.GetTTL()
	fmt.Printf("ttl:info: %d\n", originalTTL)
	limiter.UpdateTTL(originalTTL + addTtl)
	if limiter.GetTTL() != originalTTL+addTtl {
		t.Fatal("ttl: update fail")
	}
}
// cycTest verifies UpdateCyc persists a new cycle value.
// NOTE(review): the name is not Test-prefixed, so `go test` never runs
// this function — confirm whether that is intentional.
func cycTest(t *testing.T) {
	var addCyc int32 = 1
	originalCyc := limiter.GetCyc()
	fmt.Printf("cyc:info: %d\n", originalCyc)
	limiter.UpdateCyc(originalCyc + addCyc)
	if limiter.GetCyc() != originalCyc+addCyc {
		t.Fatal("cyc: update fail")
	}
}
// autoCleanTest adds fresh items, waits past the TTL plus slack, and
// verifies the background cleaner removed them.
// NOTE(review): not Test-prefixed, so `go test` never runs it — confirm.
func autoCleanTest(t *testing.T) {
	ids := []string{
		randId(),
		randId(),
	}
	for _, id := range ids {
		ok := limiter.Access(id, 0)
		// Fixed inverted check: Access returns true on success, so the
		// original fataled exactly when the add succeeded.
		if !ok {
			t.Fatal("autoClean: warning: add fail")
		}
	}
	time.Sleep(time.Duration(limiter.GetTTL()+wait) * time.Second)
	for _, id := range ids {
		_, exist := limiter.get(id)
		if exist {
			t.Fatal("autoClean: item still exist")
		}
	}
}
// func snapshotTest(t *testing.T) {
// }

View file

@ -0,0 +1,7 @@
package logutil
// LogUtil is the logging surface used across the server packages; the
// method set mirrors the standard library *log.Logger.
type LogUtil interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
	Println(v ...interface{})
}

View file

@ -0,0 +1,12 @@
package logutil
import (
"io"
"log"
)
// NewSlog returns a standard-library logger writing to out with the
// given prefix, stamping date, time, and short file name on each line.
func NewSlog(out io.Writer, prefix string) LogUtil {
	flags := log.Ldate | log.Ltime | log.Lshortfile
	return log.New(out, prefix, flags)
}
// Slog is a named pointer-to-Logger type.
// NOTE(review): unusual declaration (pointer type, not an alias) and it
// appears unused in this file — confirm it is still needed.
type Slog *log.Logger

View file

@ -0,0 +1,13 @@
package qtube
import (
"net/http"
)
import (
"quickshare/server/libs/fileidx"
)
// Downloader streams a stored file to an HTTP response, honoring the
// request method and Range header.
type Downloader interface {
	ServeFile(res http.ResponseWriter, req *http.Request, fileInfo *fileidx.FileInfo) error
}

280
server/libs/qtube/qtube.go Normal file
View file

@ -0,0 +1,280 @@
package qtube
import (
"errors"
"fmt"
"io"
"net/http"
"path/filepath"
"strconv"
"strings"
"time"
)
import (
"quickshare/server/libs/fileidx"
)
var (
	// ErrCopy reports a failure while streaming a byte range to the client.
	ErrCopy = errors.New("ServeFile: copy error")
	// ErrUnknown reports an unsupported case (e.g. multi-range requests).
	ErrUnknown = errors.New("ServeFile: unknown error")
)
// httpRange is one requested byte range [start, start+length).
type httpRange struct {
	start, length int64
}

// GetStart returns the first byte offset of the range.
func (ra *httpRange) GetStart() int64 {
	return ra.start
}

// GetLength returns the number of bytes in the range.
func (ra *httpRange) GetLength() int64 {
	return ra.length
}

// SetStart overwrites the first byte offset.
func (ra *httpRange) SetStart(start int64) {
	ra.start = start
}

// SetLength overwrites the byte count (used to clamp to MaxRangeLen).
func (ra *httpRange) SetLength(length int64) {
	ra.length = length
}
// NewQTube builds a Downloader rooted at root that streams at copySpeed
// bytes per second and caps a single range response at maxRangeLen bytes.
func NewQTube(root string, copySpeed, maxRangeLen int64, filer FileReadSeekCloser) Downloader {
	tube := &QTube{}
	tube.Root = root
	tube.BytesPerSec = copySpeed
	tube.MaxRangeLen = maxRangeLen
	tube.Filer = filer
	return tube
}
// QTube serves stored files over HTTP with throttling and range support.
type QTube struct {
	Root        string // directory all served paths are joined under
	BytesPerSec int64  // copy throttle per one-second slot
	MaxRangeLen int64  // clamp for a single range response
	Filer       FileReadSeekCloser
}
// FileReadSeekCloser opens a file path for seekable reading (test seam
// over the filesystem).
type FileReadSeekCloser interface {
	Open(filePath string) (ReadSeekCloser, error)
}

// ReadSeekCloser combines read, seek, and close; equivalent to the
// io.ReadSeekCloser added in Go 1.16.
type ReadSeekCloser interface {
	io.Reader
	io.Seeker
	io.Closer
}
const (
	// ErrorInvalidRange is returned for malformed or empty Range headers.
	ErrorInvalidRange = "ServeFile: invalid Range"
	// ErrorInvalidSize is returned when the ranges exceed the file size.
	ErrorInvalidSize = "ServeFile: invalid Range total size"
)
// ServeFile answers HEAD requests with size metadata only, serves the
// whole file when no Range header is present, and otherwise serves the
// requested byte range(s).
func (tb *QTube) ServeFile(res http.ResponseWriter, req *http.Request, fileInfo *fileidx.FileInfo) error {
	if req.Method == http.MethodHead {
		res.Header().Set("Accept-Ranges", "bytes")
		res.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Uploaded))
		res.Header().Set("Content-Type", "application/octet-stream")
		res.WriteHeader(http.StatusOK)
		return nil
	}
	if headerRange := req.Header.Get("Range"); headerRange != "" {
		return tb.serveRanges(res, headerRange, fileInfo)
	}
	return tb.serveAll(res, fileInfo)
}
// serveAll streams the whole file as an attachment at the throttled rate.
func (tb *QTube) serveAll(res http.ResponseWriter, fileInfo *fileidx.FileInfo) error {
	res.Header().Set("Accept-Ranges", "bytes")
	res.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filepath.Base(fileInfo.PathLocal)))
	res.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Uploaded))
	res.Header().Set("Content-Type", "application/octet-stream")
	res.Header().Set("Last-Modified", time.Unix(fileInfo.ModTime, 0).UTC().Format(http.TimeFormat))
	res.WriteHeader(http.StatusOK)
	// TODO: need verify path
	file, openErr := tb.Filer.Open(filepath.Join(tb.Root, fileInfo.PathLocal))
	if openErr != nil {
		return openErr
	}
	// Close only after the error check: the original deferred file.Close()
	// before inspecting openErr, so a failed Open (nil file) panicked.
	defer file.Close()
	copyErr := tb.throttledCopyN(res, file, fileInfo.Uploaded)
	if copyErr != nil && copyErr != io.EOF {
		return copyErr
	}
	return nil
}
// serveRanges validates the Range header and serves the FIRST requested
// range only; multi-range responses are not implemented yet.
func (tb *QTube) serveRanges(res http.ResponseWriter, headerRange string, fileInfo *fileidx.FileInfo) error {
	ranges, rangeErr := getRanges(headerRange, fileInfo.Uploaded)
	if rangeErr != nil {
		http.Error(res, rangeErr.Error(), http.StatusRequestedRangeNotSatisfiable)
		return errors.New(rangeErr.Error())
	}
	// getRanges guarantees at least one range on success; the guard is
	// kept as a safety net. TODO: add support for multiple ranges.
	if len(ranges) == 0 {
		return ErrUnknown
	}
	if copyErr := tb.copyRange(res, ranges[0], fileInfo); copyErr != nil {
		return ErrCopy
	}
	return nil
}
// getRanges parses a Range header against the file size, rejecting
// malformed headers, empty range sets, and sets that cover more bytes
// than the file contains.
func getRanges(headerRange string, size int64) ([]httpRange, error) {
	ranges, parseErr := parseRange(headerRange, size)
	// TODO: check max number of ranges, range start end
	if parseErr != nil || len(ranges) == 0 {
		return nil, errors.New(ErrorInvalidRange)
	}
	if sumRangesSize(ranges) > size {
		return nil, errors.New(ErrorInvalidSize)
	}
	return ranges, nil
}
// copyRange streams one byte range (clamped to MaxRangeLen) as a 206
// partial-content response at the throttled rate.
func (tb *QTube) copyRange(res http.ResponseWriter, ra httpRange, fileInfo *fileidx.FileInfo) error {
	// TODO: comfirm this wont cause problem
	if ra.GetLength() > tb.MaxRangeLen {
		ra.SetLength(tb.MaxRangeLen)
	}
	// TODO: add headers(ETag): https://tools.ietf.org/html/rfc7233#section-4.1 p11 2nd paragraph
	res.Header().Set("Accept-Ranges", "bytes")
	res.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filepath.Base(fileInfo.PathLocal)))
	res.Header().Set("Content-Type", "application/octet-stream")
	res.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, fileInfo.Uploaded))
	res.Header().Set("Content-Length", strconv.FormatInt(ra.GetLength(), 10))
	res.Header().Set("Last-Modified", time.Unix(fileInfo.ModTime, 0).UTC().Format(http.TimeFormat))
	res.WriteHeader(http.StatusPartialContent)
	// TODO: need verify path
	file, openErr := tb.Filer.Open(filepath.Join(tb.Root, fileInfo.PathLocal))
	if openErr != nil {
		return openErr
	}
	// Close only after the error check: the original deferred file.Close()
	// before inspecting openErr, so a failed Open (nil file) panicked.
	defer file.Close()
	if _, seekErr := file.Seek(ra.start, io.SeekStart); seekErr != nil {
		return seekErr
	}
	copyErr := tb.throttledCopyN(res, file, ra.length)
	if copyErr != nil && copyErr != io.EOF {
		return copyErr
	}
	return nil
}
// throttledCopyN copies length bytes from src to dst at roughly
// BytesPerSec: each chunk of up to BytesPerSec bytes is padded out to
// one second of wall time before the next chunk begins.
func (tb *QTube) throttledCopyN(dst io.Writer, src io.Reader, length int64) error {
	sum := int64(0)
	timeSlot := time.Duration(1 * time.Second)
	for sum < length {
		start := time.Now()
		// The final chunk may be smaller than the per-second budget.
		chunkSize := length - sum
		if length-sum > tb.BytesPerSec {
			chunkSize = tb.BytesPerSec
		}
		copied, err := io.CopyN(dst, src, chunkSize)
		if err != nil {
			return err
		}
		sum += copied
		// Sleep away the remainder of the one-second slot.
		// NOTE(review): this also sleeps after the final chunk, adding
		// up to a second of latency — confirm that is acceptable.
		end := time.Now()
		if end.Before(start.Add(timeSlot)) {
			time.Sleep(start.Add(timeSlot).Sub(end))
		}
	}
	return nil
}
// parseRange parses a "bytes=..." Range header into concrete ranges
// over a resource of the given size. It mirrors net/http's internal
// parseRange: suffix ranges ("-N") map to the last N bytes, open-ended
// ranges ("N-") run to end of file, and ranges starting past the size
// are skipped as non-overlapping.
func parseRange(headerRange string, size int64) ([]httpRange, error) {
	if headerRange == "" {
		return nil, nil // header not present
	}
	const keyByte = "bytes="
	if !strings.HasPrefix(headerRange, keyByte) {
		return nil, errors.New("byte= not found")
	}
	var ranges []httpRange
	noOverlap := false
	for _, ra := range strings.Split(headerRange[len(keyByte):], ",") {
		ra = strings.TrimSpace(ra)
		if ra == "" {
			continue
		}
		i := strings.Index(ra, "-")
		if i < 0 {
			return nil, errors.New("- not found")
		}
		start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:])
		var r httpRange
		if start == "" {
			// Suffix form "-N": the last N bytes, clamped to the size.
			i, err := strconv.ParseInt(end, 10, 64)
			if err != nil {
				return nil, errors.New("invalid range")
			}
			if i > size {
				i = size
			}
			r.start = size - i
			r.length = size - r.start
		} else {
			i, err := strconv.ParseInt(start, 10, 64)
			if err != nil || i < 0 {
				return nil, errors.New("invalid range")
			}
			if i >= size {
				// If the range begins after the size of the content,
				// then it does not overlap.
				noOverlap = true
				continue
			}
			r.start = i
			if end == "" {
				// If no end is specified, range extends to end of the file.
				r.length = size - r.start
			} else {
				i, err := strconv.ParseInt(end, 10, 64)
				if err != nil || r.start > i {
					return nil, errors.New("invalid range")
				}
				if i >= size {
					i = size - 1
				}
				r.length = i - r.start + 1
			}
		}
		ranges = append(ranges, r)
	}
	if noOverlap && len(ranges) == 0 {
		// The specified ranges did not overlap with the content.
		return nil, errors.New("parseRanges: no overlap")
	}
	return ranges, nil
}
// sumRangesSize returns the total number of bytes covered by ranges.
func sumRangesSize(ranges []httpRange) int64 {
	var total int64
	for i := range ranges {
		total += ranges[i].length
	}
	return total
}

View file

@ -0,0 +1,354 @@
package qtube
import (
"bytes"
"fmt"
"net/http"
"strings"
"testing"
"time"
)
import (
"quickshare/server/libs/fileidx"
)
// Range format examples:
// Range: <unit>=<range-start>-
// Range: <unit>=<range-start>-<range-end>
// Range: <unit>=<range-start>-<range-end>, <range-start>-<range-end>
// Range: <unit>=<range-start>-<range-end>, <range-start>-<range-end>, <range-start>-<range-end>

// TestGetRanges drives getRanges through malformed headers, oversized
// range sets, and a valid multi-range header.
func TestGetRanges(t *testing.T) {
	type Input struct {
		HeaderRange string
		Size        int64
	}
	type Output struct {
		Ranges   []httpRange
		ErrorMsg string // expected error text; "" means success expected
	}
	type testCase struct {
		Desc string
		Input
		Output
	}
	testCases := []testCase{
		testCase{
			Desc: "invalid range",
			Input: Input{
				HeaderRange: "bytes=start-invalid end",
				Size:        0,
			},
			Output: Output{
				ErrorMsg: ErrorInvalidRange,
			},
		},
		testCase{
			Desc: "invalid range total size",
			Input: Input{
				HeaderRange: "bytes=0-1, 2-3, 0-1, 0-2",
				Size:        3,
			},
			Output: Output{
				ErrorMsg: ErrorInvalidSize,
			},
		},
		testCase{
			Desc: "range ok",
			Input: Input{
				HeaderRange: "bytes=0-1, 2-3",
				Size:        4,
			},
			Output: Output{
				Ranges: []httpRange{
					httpRange{start: 0, length: 2},
					httpRange{start: 2, length: 2},
				},
				ErrorMsg: "",
			},
		},
	}
	for _, tCase := range testCases {
		ranges, err := getRanges(tCase.HeaderRange, tCase.Size)
		if err != nil {
			// Error path: the message must match and no ranges expected.
			if err.Error() != tCase.ErrorMsg || len(tCase.Ranges) != 0 {
				t.Fatalf("getRanges: incorrect errorMsg want: %v got: %v", tCase.ErrorMsg, err.Error())
			} else {
				continue
			}
		} else {
			// Success path: every returned range must match expectations.
			for id, ra := range ranges {
				if ra.GetStart() != tCase.Ranges[id].GetStart() {
					t.Fatalf("getRanges: incorrect range start, got: %v want: %v", ra.GetStart(), tCase.Ranges[id])
				}
				if ra.GetLength() != tCase.Ranges[id].GetLength() {
					t.Fatalf("getRanges: incorrect range length, got: %v want: %v", ra.GetLength(), tCase.Ranges[id])
				}
			}
		}
	}
}
// TestThrottledCopyN checks the copy throttle by sampling the
// destination buffer at fixed delays while the copy runs concurrently.
// NOTE(review): verifyDsts panics from a goroutine rather than calling
// t.Fatal, and samples a shared buffer while the copier writes to it —
// timing-sensitive and racy; consider hardening.
func TestThrottledCopyN(t *testing.T) {
	type Init struct {
		BytesPerSec int64
		MaxRangeLen int64
	}
	type Input struct {
		Src    string
		Length int64
	}
	// after starting throttledCopyN by DstAtTime.AtMs millisecond,
	// copied valueshould equal to DstAtTime.Dst.
	type DstAtTime struct {
		AtMS int
		Dst  string
	}
	type Output struct {
		ExpectDsts []DstAtTime
	}
	type testCase struct {
		Desc string
		Init
		Input
		Output
	}
	verifyDsts := func(dst *bytes.Buffer, expectDsts []DstAtTime) {
		for _, expectDst := range expectDsts {
			// fmt.Printf("sleep: %d\n", time.Now().UnixNano())
			time.Sleep(time.Duration(expectDst.AtMS) * time.Millisecond)
			dstStr := string(dst.Bytes())
			// fmt.Printf("check: %d\n", time.Now().UnixNano())
			if dstStr != expectDst.Dst {
				panic(
					fmt.Sprintf(
						"throttledCopyN want: <%s> | got: <%s> | at: %d",
						expectDst.Dst,
						dstStr,
						expectDst.AtMS,
					),
				)
			}
		}
	}
	testCases := []testCase{
		testCase{
			// NOTE(review): desc says 4 bytes/sec but BytesPerSec is 5.
			Desc: "4 byte per sec",
			Init: Init{
				BytesPerSec: 5,
				MaxRangeLen: 10,
			},
			Input: Input{
				Src:    "aaaa_aaaa_",
				Length: 10,
			},
			Output: Output{
				ExpectDsts: []DstAtTime{
					DstAtTime{AtMS: 200, Dst: "aaaa_"},
					DstAtTime{AtMS: 200, Dst: "aaaa_"},
					DstAtTime{AtMS: 200, Dst: "aaaa_"},
					DstAtTime{AtMS: 600, Dst: "aaaa_aaaa_"},
					DstAtTime{AtMS: 200, Dst: "aaaa_aaaa_"},
					DstAtTime{AtMS: 200, Dst: "aaaa_aaaa_"},
				},
			},
		},
	}
	for _, tCase := range testCases {
		tb := NewQTube("", tCase.BytesPerSec, tCase.MaxRangeLen, &stubFiler{}).(*QTube)
		dst := bytes.NewBuffer(make([]byte, len(tCase.Src)))
		dst.Reset()
		go verifyDsts(dst, tCase.ExpectDsts)
		tb.throttledCopyN(dst, strings.NewReader(tCase.Src), tCase.Length)
	}
}
// TODO: using same stub with testhelper

// stubWriter is an in-memory http.ResponseWriter capturing headers,
// body, and status code for assertions.
type stubWriter struct {
	Headers    http.Header
	Response   []byte
	StatusCode int
}

// Header returns the captured header map.
func (w *stubWriter) Header() http.Header {
	return w.Headers
}

// Write appends body to the captured response.
func (w *stubWriter) Write(body []byte) (int, error) {
	w.Response = append(w.Response, body...)
	return len(body), nil
}

// WriteHeader records the status code.
func (w *stubWriter) WriteHeader(statusCode int) {
	w.StatusCode = statusCode
}
// TestCopyRange verifies copyRange emits the expected 206 status,
// headers, and body slice for a mid-file range.
func TestCopyRange(t *testing.T) {
	type Init struct {
		Content string
	}
	type Input struct {
		Range httpRange
		Info  fileidx.FileInfo
	}
	type Output struct {
		StatusCode int
		Headers    map[string][]string
		Body       string
	}
	type testCase struct {
		Desc string
		Init
		Input
		Output
	}
	testCases := []testCase{
		testCase{
			Desc: "copy ok",
			Init: Init{
				Content: "abcd_abcd_",
			},
			Input: Input{
				Range: httpRange{
					start:  6,
					length: 3,
				},
				Info: fileidx.FileInfo{
					ModTime:   0,
					Uploaded:  10,
					PathLocal: "filename.jpg",
				},
			},
			Output: Output{
				StatusCode: 206,
				Headers: map[string][]string{
					"Accept-Ranges":       []string{"bytes"},
					"Content-Disposition": []string{`attachment; filename="filename.jpg"`},
					"Content-Type":        []string{"application/octet-stream"},
					"Content-Range":       []string{"bytes 6-8/10"},
					"Content-Length":      []string{"3"},
					"Last-Modified":       []string{time.Unix(0, 0).UTC().Format(http.TimeFormat)},
				},
				Body: "abc",
			},
		},
	}
	for _, tCase := range testCases {
		filer := &stubFiler{
			&StubFile{
				Content: tCase.Content,
				Offset:  0,
			},
		}
		tb := NewQTube("", 100, 100, filer).(*QTube)
		res := &stubWriter{
			Headers:  make(map[string][]string),
			Response: make([]byte, 0),
		}
		err := tb.copyRange(res, tCase.Range, &tCase.Info)
		if err != nil {
			t.Fatalf("copyRange: %v", err)
		}
		if res.StatusCode != tCase.Output.StatusCode {
			t.Fatalf("copyRange: statusCode not match got: %v want: %v", res.StatusCode, tCase.Output.StatusCode)
		}
		if string(res.Response) != tCase.Output.Body {
			t.Fatalf("copyRange: body not match \ngot: %v \nwant: %v", string(res.Response), tCase.Output.Body)
		}
		for key, vals := range tCase.Output.Headers {
			if res.Header().Get(key) != vals[0] {
				t.Fatalf("copyRange: header not match %v got: %v want: %v", key, res.Header().Get(key), vals[0])
			}
		}
		// Removed a duplicated status-code check whose Fatalf format had
		// one verb for two arguments (a go vet error); the status code
		// is already asserted above.
	}
}
// TestServeAll verifies serveAll emits a 200 response with full-file
// headers and the complete body.
func TestServeAll(t *testing.T) {
	type Init struct {
		Content string
	}
	type Input struct {
		Info fileidx.FileInfo
	}
	type Output struct {
		StatusCode int
		Headers    map[string][]string
		Body       string
	}
	type testCase struct {
		Desc string
		Init
		Input
		Output
	}
	testCases := []testCase{
		testCase{
			Desc: "copy ok",
			Init: Init{
				Content: "abcd_abcd_",
			},
			Input: Input{
				Info: fileidx.FileInfo{
					ModTime:   0,
					Uploaded:  10,
					PathLocal: "filename.jpg",
				},
			},
			Output: Output{
				StatusCode: 200,
				Headers: map[string][]string{
					"Accept-Ranges":       []string{"bytes"},
					"Content-Disposition": []string{`attachment; filename="filename.jpg"`},
					"Content-Type":        []string{"application/octet-stream"},
					"Content-Length":      []string{"10"},
					"Last-Modified":       []string{time.Unix(0, 0).UTC().Format(http.TimeFormat)},
				},
				Body: "abcd_abcd_",
			},
		},
	}
	for _, tCase := range testCases {
		filer := &stubFiler{
			&StubFile{
				Content: tCase.Content,
				Offset:  0,
			},
		}
		tb := NewQTube("", 100, 100, filer).(*QTube)
		res := &stubWriter{
			Headers:  make(map[string][]string),
			Response: make([]byte, 0),
		}
		err := tb.serveAll(res, &tCase.Info)
		if err != nil {
			t.Fatalf("serveAll: %v", err)
		}
		if res.StatusCode != tCase.Output.StatusCode {
			t.Fatalf("serveAll: statusCode not match got: %v want: %v", res.StatusCode, tCase.Output.StatusCode)
		}
		if string(res.Response) != tCase.Output.Body {
			t.Fatalf("serveAll: body not match \ngot: %v \nwant: %v", string(res.Response), tCase.Output.Body)
		}
		for key, vals := range tCase.Output.Headers {
			if res.Header().Get(key) != vals[0] {
				t.Fatalf("serveAll: header not match %v got: %v want: %v", key, res.Header().Get(key), vals[0])
			}
		}
		// Removed a duplicated status-code check whose Fatalf format had
		// one verb for two arguments (a go vet error); the status code
		// is already asserted above.
	}
}

View file

@ -0,0 +1,28 @@
package qtube
// StubFile is an in-memory ReadSeekCloser used to stub the qtube Filer
// in tests.
type StubFile struct {
	Content string
	Offset  int64
}

// Read copies as much of Content as fits into p and returns the number
// of bytes copied. The original sliced Content by len(p), which panicked
// whenever the read buffer was larger than the content.
// NOTE(review): Offset is ignored and io.EOF is never returned —
// presumably sufficient for the current tests; confirm.
func (file *StubFile) Read(p []byte) (int, error) {
	copied := copy(p, file.Content)
	return copied, nil
}

// Seek records the requested offset and returns it; whence is ignored.
func (file *StubFile) Seek(offset int64, whence int) (int64, error) {
	file.Offset = offset
	return offset, nil
}

// Close is a no-op.
func (file *StubFile) Close() error {
	return nil
}
// stubFiler returns the same StubFile for every path; it stubs
// FileReadSeekCloser in tests.
type stubFiler struct {
	file *StubFile
}

// Open ignores filePath and hands back the preconfigured stub file.
func (filer *stubFiler) Open(filePath string) (ReadSeekCloser, error) {
	return filer.file, nil
}

View file

@ -0,0 +1,102 @@
package walls
import (
"fmt"
"net/http"
"strconv"
"time"
)
import (
"quickshare/server/libs/cfg"
"quickshare/server/libs/encrypt"
"quickshare/server/libs/limiter"
)
// AccessWalls gates requests behind per-IP and per-operation rate
// limits and validates admin login tokens.
type AccessWalls struct {
	cf             *cfg.Config
	IpLimiter      limiter.Limiter
	OpLimiter      limiter.Limiter
	EncrypterMaker encrypt.EncrypterMaker
}
// NewAccessWalls wires the rate limiters and the token encrypter into a
// Walls implementation configured by cf.
func NewAccessWalls(
	cf *cfg.Config,
	ipLimiter limiter.Limiter,
	opLimiter limiter.Limiter,
	encrypterMaker encrypt.EncrypterMaker,
) Walls {
	walls := &AccessWalls{}
	walls.cf = cf
	walls.IpLimiter = ipLimiter
	walls.OpLimiter = opLimiter
	walls.EncrypterMaker = encrypterMaker
	return walls
}
// PassIpLimit reports whether remoteAddr is within its visit quota.
// Outside production the wall is disabled and always passes.
func (walls *AccessWalls) PassIpLimit(remoteAddr string) bool {
	if walls.cf.Production {
		return walls.IpLimiter.Access(remoteAddr, walls.cf.OpIdIpVisit)
	}
	return true
}
// PassOpLimit reports whether resourceId may perform opId again within
// the current cycle. Outside production the wall always passes.
func (walls *AccessWalls) PassOpLimit(resourceId string, opId int16) bool {
	if walls.cf.Production {
		return walls.OpLimiter.Access(resourceId, opId)
	}
	return true
}
// PassLoginCheck reports whether tokenStr is a valid, unexpired admin
// login token. Outside production the check is bypassed.
// NOTE(review): the req parameter is unused here — presumably kept for
// the Walls interface signature; confirm.
func (walls *AccessWalls) PassLoginCheck(tokenStr string, req *http.Request) bool {
	if !walls.cf.Production {
		return true
	}
	return walls.passLoginCheck(tokenStr)
}
// passLoginCheck reports whether tokenStr decodes to a valid token
// whose admin id matches the configured admin.
func (walls *AccessWalls) passLoginCheck(tokenStr string) bool {
	token, ok := walls.GetLoginToken(tokenStr)
	if !ok {
		return false
	}
	return token.AdminId == walls.cf.AdminId
}
// GetLoginToken decrypts tokenStr and returns the embedded login token.
// It fails when decryption fails, required fields are missing, the admin
// id does not match the configured admin, or the token has expired.
func (walls *AccessWalls) GetLoginToken(tokenStr string) (*LoginToken, bool) {
	tokenMaker := walls.EncrypterMaker(string(walls.cf.SecretKeyByte))
	if !tokenMaker.FromStr(tokenStr) {
		return nil, false
	}
	adminIdFromToken, adminIdOk := tokenMaker.Get(walls.cf.KeyAdminId)
	expiresStr, expiresStrOk := tokenMaker.Get(walls.cf.KeyExpires)
	if !adminIdOk || !expiresStrOk {
		return nil, false
	}
	// Reject unparsable expiry, foreign admin ids, and expired tokens.
	expires, expiresParseErr := strconv.ParseInt(expiresStr, 10, 64)
	if expiresParseErr != nil ||
		adminIdFromToken != walls.cf.AdminId ||
		expires <= time.Now().Unix() {
		return nil, false
	}
	return &LoginToken{
		AdminId: adminIdFromToken,
		Expires: expires,
	}, true
}
// MakeLoginToken encrypts userId plus an expiry (now + CookieMaxAge
// seconds) into a token string; it returns "" on encryption failure.
func (walls *AccessWalls) MakeLoginToken(userId string) string {
	maker := walls.EncrypterMaker(string(walls.cf.SecretKeyByte))
	expires := time.Now().Add(time.Duration(walls.cf.CookieMaxAge) * time.Second).Unix()
	maker.Add(walls.cf.KeyAdminId, userId)
	maker.Add(walls.cf.KeyExpires, fmt.Sprintf("%d", expires))
	if tokenStr, ok := maker.ToStr(); ok {
		return tokenStr
	}
	return ""
}

View file

@ -0,0 +1,145 @@
package walls
import (
"fmt"
"testing"
"time"
)
import (
"quickshare/server/libs/cfg"
"quickshare/server/libs/encrypt"
"quickshare/server/libs/limiter"
)
// newAccessWalls builds an AccessWalls with Production enabled (so no
// wall is bypassed) and fresh rate limiters sized by the arguments.
func newAccessWalls(limiterCap int64, limiterTtl int32, limiterCyc int32, bucketCap int16) *AccessWalls {
	config := cfg.NewConfig()
	config.Production = true
	config.LimiterCap = limiterCap
	config.LimiterTtl = limiterTtl
	config.LimiterCyc = limiterCyc
	config.BucketCap = bucketCap
	encrypterMaker := encrypt.JwtEncrypterMaker
	ipLimiter := limiter.NewRateLimiter(config.LimiterCap, config.LimiterTtl, config.LimiterCyc, config.BucketCap, map[int16]int16{})
	opLimiter := limiter.NewRateLimiter(config.LimiterCap, config.LimiterTtl, config.LimiterCyc, config.BucketCap, map[int16]int16{})
	return NewAccessWalls(config, ipLimiter, opLimiter, encrypterMaker).(*AccessWalls)
}
// TestIpLimit drains an IP's quota, waits one refill cycle, and drains
// it again to confirm tokens were replenished.
func TestIpLimit(t *testing.T) {
	ip := "0.0.0.0"
	limit := int16(10)
	ttl := int32(60)
	cyc := int32(5)
	walls := newAccessWalls(1000, ttl, cyc, limit)
	testIpLimit(t, walls, ip, limit)
	// wait for tokens are re-fullfilled
	time.Sleep(time.Duration(cyc) * time.Second)
	testIpLimit(t, walls, ip, limit)
	fmt.Println("ip limit: passed")
}
// testIpLimit exhausts exactly `limit` accesses for ip and asserts the
// next one is denied.
func testIpLimit(t *testing.T, walls Walls, ip string, limit int16) {
	for i := int16(0); i < limit; i++ {
		if !walls.PassIpLimit(ip) {
			// A %d verb was added: the original Fatalf passed the
			// timestamp argument with no verb in the format string
			// (a go vet error).
			t.Fatalf("ipLimiter: should be passed at %d", time.Now().Unix())
		}
	}
	if walls.PassIpLimit(ip) {
		t.Fatalf("ipLimiter: should not be passed at %d", time.Now().Unix())
	}
}
// TestOpLimit drains two independent op buckets for one resource, waits
// one refill cycle, and drains them again.
// NOTE(review): the local named `ttl` is actually passed as the cycle
// argument of newAccessWalls(cap, ttl, cyc, bucket) — the real ttl is 5;
// consider renaming.
func TestOpLimit(t *testing.T) {
	resourceId := "id"
	op1 := int16(1)
	op2 := int16(2)
	limit := int16(10)
	ttl := int32(1)
	walls := newAccessWalls(1000, 5, ttl, limit)
	testOpLimit(t, walls, resourceId, op1, limit)
	testOpLimit(t, walls, resourceId, op2, limit)
	// wait for tokens are re-fullfilled
	time.Sleep(time.Duration(ttl) * time.Second)
	testOpLimit(t, walls, resourceId, op1, limit)
	testOpLimit(t, walls, resourceId, op2, limit)
	fmt.Println("op limit: passed")
}
// testOpLimit exhausts exactly `limit` accesses of op for resourceId and
// asserts the next access is denied.
func testOpLimit(t *testing.T, walls Walls, resourceId string, op int16, limit int16) {
	for count := int16(0); count < limit; count++ {
		if !walls.PassOpLimit(resourceId, op) {
			t.Fatalf("opLimiter: should be passed")
		}
	}
	if walls.PassOpLimit(resourceId, op) {
		t.Fatalf("opLimiter: should not be passed")
	}
}
// TestLoginCheck runs the token validation sub-cases: a valid token, a
// wrong admin id, and an expired token.
func TestLoginCheck(t *testing.T) {
	walls := newAccessWalls(1000, 5, 1, 10)
	testValidToken(t, walls)
	testInvalidAdminIdToken(t, walls)
	testExpiredToken(t, walls)
}
// testValidToken builds a token for the configured admin expiring 10s
// from now and asserts the login check accepts it.
func testValidToken(t *testing.T, walls *AccessWalls) {
	config := cfg.NewConfig()
	tokenMaker := encrypt.JwtEncrypterMaker(string(config.SecretKeyByte))
	tokenMaker.Add(config.KeyAdminId, config.AdminId)
	tokenMaker.Add(config.KeyExpires, fmt.Sprintf("%d", time.Now().Unix()+int64(10)))
	tokenStr, getTokenOk := tokenMaker.ToStr()
	if !getTokenOk {
		t.Fatalf("passLoginCheck: fail to generate token")
	}
	if !walls.passLoginCheck(tokenStr) {
		t.Fatalf("loginCheck: should be passed")
	}
	fmt.Println("loginCheck: valid token passed")
}
// testInvalidAdminIdToken builds an otherwise-valid token carrying a
// foreign admin id and asserts the login check rejects it.
func testInvalidAdminIdToken(t *testing.T, walls *AccessWalls) {
	config := cfg.NewConfig()
	tokenMaker := encrypt.JwtEncrypterMaker(string(config.SecretKeyByte))
	tokenMaker.Add(config.KeyAdminId, "invalid admin id")
	tokenMaker.Add(config.KeyExpires, fmt.Sprintf("%d", time.Now().Unix()+int64(10)))
	tokenStr, getTokenOk := tokenMaker.ToStr()
	if !getTokenOk {
		t.Fatalf("passLoginCheck: fail to generate token")
	}
	if walls.passLoginCheck(tokenStr) {
		t.Fatalf("loginCheck: should not be passed")
	}
	fmt.Println("loginCheck: invalid admin id passed")
}
// testExpiredToken builds a token for the configured admin that expired
// one second ago and asserts the login check rejects it.
func testExpiredToken(t *testing.T, walls *AccessWalls) {
	config := cfg.NewConfig()
	tokenMaker := encrypt.JwtEncrypterMaker(string(config.SecretKeyByte))
	tokenMaker.Add(config.KeyAdminId, config.AdminId)
	tokenMaker.Add(config.KeyExpires, fmt.Sprintf("%d", time.Now().Unix()-int64(1)))
	tokenStr, getTokenOk := tokenMaker.ToStr()
	if !getTokenOk {
		t.Fatalf("passLoginCheck: fail to generate token")
	}
	if walls.passLoginCheck(tokenStr) {
		t.Fatalf("loginCheck: should not be passed")
	}
	fmt.Println("loginCheck: expired token passed")
}

View file

@ -0,0 +1,17 @@
package walls
import (
"net/http"
)
// Walls gates requests: per-IP and per-operation rate limits plus admin
// login token creation and verification.
type Walls interface {
	PassIpLimit(remoteAddr string) bool
	PassOpLimit(resourceId string, opId int16) bool
	PassLoginCheck(tokenStr string, req *http.Request) bool
	MakeLoginToken(uid string) string
}
// LoginToken is the decrypted content of an admin login token: the
// admin's id and the expiry as unix seconds.
type LoginToken struct {
	AdminId string
	Expires int64
}