parent
30c963a5f0
commit
61a1c93f0f
89 changed files with 15859 additions and 2 deletions
251
server/libs/cfg/cfg.go
Normal file
251
server/libs/cfg/cfg.go
Normal file
|
@ -0,0 +1,251 @@
|
|||
package cfg
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Config holds every runtime setting of the quickshare server:
// credentials, HTTP tuning, cookie attributes, request key names,
// operation ids, URL paths and rate-limiter parameters.
type Config struct {
	AppName   string
	AdminId   string
	AdminPwd  string
	SecretKey string
	// SecretKeyByte is derived from SecretKey after loading; omitted
	// when the config is marshalled back to JSON.
	SecretKeyByte []byte `json:",omitempty"`
	// server
	Production bool
	HostName   string
	Port       int
	// performance
	MaxUpBytesPerSec   int64
	MaxDownBytesPerSec int64
	MaxRangeLength     int64
	Timeout            int // millisecond
	ReadTimeout        int
	WriteTimeout       int
	IdleTimeout        int
	WorkerPoolSize     int
	TaskQueueSize      int
	QueueSize          int
	ParseFormBufSize   int64
	MaxHeaderBytes     int
	DownLimit          int
	MaxShares          int
	LocalFileLimit     int
	// Cookie
	CookieDomain   string
	CookieHttpOnly bool
	CookieMaxAge   int
	CookiePath     string
	CookieSecure   bool
	// keys (request parameter / cookie names)
	KeyAdminId   string
	KeyAdminPwd  string
	KeyToken     string
	KeyFileName  string
	KeyFileSize  string
	KeyShareId   string
	KeyStart     string
	KeyLen       string
	KeyChunk     string
	KeyAct       string
	KeyExpires   string
	KeyDownLimit string
	// action names carried in the act parameter
	ActStartUpload   string
	ActUpload        string
	ActFinishUpload  string
	ActLogin         string
	ActLogout        string
	ActShadowId      string
	ActPublishId     string
	ActSetDownLimit  string
	ActAddLocalFiles string
	// resource id
	AllUsers string
	// opIds
	OpIdIpVisit  int16
	OpIdUpload   int16
	OpIdDownload int16
	OpIdLogin    int16
	OpIdGetFInfo int16
	OpIdDelFInfo int16
	OpIdOpFInfo  int16
	// local
	PathLocal         string
	PathLogin         string
	PathDownloadLogin string
	PathDownload      string
	PathUpload        string
	PathStartUpload   string
	PathFinishUpload  string
	PathFileInfo      string
	PathClient        string
	// rate Limiter
	LimiterCap int64
	LimiterTtl int32
	LimiterCyc int32
	BucketCap  int16
	// SpecialCapsStr mirrors SpecialCaps for JSON decoding (object keys
	// are strings); NewConfigFrom converts it into SpecialCaps.
	SpecialCapsStr map[string]int16
	SpecialCaps    map[int16]int16
}
|
||||
|
||||
// NewConfig returns the built-in default configuration.
func NewConfig() *Config {
	config := &Config{
		// secrets
		AppName:       "qs",
		AdminId:       "admin",
		AdminPwd:      "qs",
		SecretKey:     "qs",
		SecretKeyByte: []byte("qs"),
		// server
		Production: true,
		HostName:   "localhost",
		Port:       8888,
		// performance
		MaxUpBytesPerSec:   500 * 1000,
		MaxDownBytesPerSec: 500 * 1000,
		MaxRangeLength:     10 * 1024 * 1024,
		Timeout:            500, // millisecond
		ReadTimeout:        500,
		WriteTimeout:       43200000,
		IdleTimeout:        10000,
		WorkerPoolSize:     2,
		TaskQueueSize:      2,
		QueueSize:          2,
		ParseFormBufSize:   600,
		MaxHeaderBytes:     1 << 15, // 32KB
		DownLimit:          -1,
		MaxShares:          1 << 31, // NOTE(review): constant overflows int on 32-bit platforms — confirm target arch
		LocalFileLimit:     -1,
		// Cookie
		CookieDomain:   "",
		CookieHttpOnly: false,
		CookieMaxAge:   3600 * 24 * 30, // 30 days (original comment said "one week")
		CookiePath:     "/",
		CookieSecure:   false,
		// keys
		KeyAdminId:   "adminid",
		KeyAdminPwd:  "adminpwd",
		KeyToken:     "token",
		KeyFileName:  "fname",
		KeyFileSize:  "size",
		KeyShareId:   "shareid",
		KeyStart:     "start",
		KeyLen:       "len",
		KeyChunk:     "chunk",
		KeyAct:       "act",
		KeyExpires:   "expires",
		KeyDownLimit: "downlimit",
		// action names
		ActStartUpload:   "startupload",
		ActUpload:        "upload",
		ActFinishUpload:  "finishupload",
		ActLogin:         "login",
		ActLogout:        "logout",
		ActShadowId:      "shadowid",
		ActPublishId:     "publishid",
		ActSetDownLimit:  "setdownlimit",
		ActAddLocalFiles: "addlocalfiles",
		AllUsers:         "allusers",
		// opIds
		OpIdIpVisit:  0,
		OpIdUpload:   1,
		OpIdDownload: 2,
		OpIdLogin:    3,
		OpIdGetFInfo: 4,
		OpIdDelFInfo: 5,
		OpIdOpFInfo:  6,
		// local
		PathLocal:         "files",
		PathLogin:         "/login",
		PathDownloadLogin: "/download-login",
		PathDownload:      "/download",
		PathUpload:        "/upload",
		PathStartUpload:   "/startupload",
		PathFinishUpload:  "/finishupload",
		PathFileInfo:      "/fileinfo",
		PathClient:        "/",
		// rate Limiter
		LimiterCap: 256,  // how many op supported for each user
		LimiterTtl: 3600, // second
		LimiterCyc: 1,    // second
		BucketCap:  3,    // how many op can do per LimiterCyc sec
		SpecialCaps: map[int16]int16{
			0: 5, // ip
			1: 1, // upload
			2: 1, // download
			3: 1, // login
		},
	}

	return config
}
|
||||
|
||||
func NewConfigFrom(path string) *Config {
|
||||
configBytes, readErr := ioutil.ReadFile(path)
|
||||
if readErr != nil {
|
||||
panic(fmt.Sprintf("config file not found: %s", path))
|
||||
}
|
||||
|
||||
config := &Config{}
|
||||
marshalErr := json.Unmarshal(configBytes, config)
|
||||
|
||||
// TODO: look for a better solution
|
||||
config.SpecialCaps = make(map[int16]int16)
|
||||
for strKey, value := range config.SpecialCapsStr {
|
||||
key, parseKeyErr := strconv.ParseInt(strKey, 10, 16)
|
||||
if parseKeyErr != nil {
|
||||
panic("fail to parse SpecialCapsStr, its type should be map[int16]int16")
|
||||
}
|
||||
config.SpecialCaps[int16(key)] = value
|
||||
}
|
||||
|
||||
if marshalErr != nil {
|
||||
panic("config file format is incorrect")
|
||||
}
|
||||
|
||||
config.SecretKeyByte = []byte(config.SecretKey)
|
||||
if config.HostName == "" {
|
||||
hostName, err := GetLocalAddr()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
config.HostName = hostName.String()
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// GetLocalAddr returns the IPv4 address of the first non-loopback,
// non-docker network interface. It is used when config.HostName is
// empty to choose a listening address automatically.
func GetLocalAddr() (net.IP, error) {
	fmt.Println(`config.HostName is empty(""), choose one IP for listening automatically.`)
	infs, err := net.Interfaces()
	if err != nil {
		panic("fail to get net interfaces")
	}

	for _, inf := range infs {
		// The original tested inf.Flags&4 != 4; 4 is net.FlagLoopback,
		// so use the named flag. Docker bridges are also skipped.
		if inf.Flags&net.FlagLoopback != 0 || strings.Contains(inf.Name, "docker") {
			continue
		}

		addrs, err := inf.Addrs()
		if err != nil {
			panic("fail to get addrs of interface")
		}
		for _, addr := range addrs {
			// Both address kinds carried identical handling in the
			// original; extract the IP once instead.
			var ip net.IP
			switch v := addr.(type) {
			case *net.IPAddr:
				ip = v.IP
			case *net.IPNet:
				ip = v.IP
			default:
				continue
			}
			// IPv6 textual form contains ':'; only IPv4 is accepted.
			if !strings.Contains(ip.String(), ":") {
				return ip, nil
			}
		}
	}

	return nil, errors.New("no addr found")
}
|
17
server/libs/encrypt/encrypter_hmac.go
Normal file
17
server/libs/encrypt/encrypter_hmac.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package encrypt
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
type HmacEncryptor struct {
|
||||
Key []byte
|
||||
}
|
||||
|
||||
func (encryptor *HmacEncryptor) Encrypt(content []byte) string {
|
||||
mac := hmac.New(sha256.New, encryptor.Key)
|
||||
mac.Write(content)
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
5
server/libs/encrypt/encryptor.go
Normal file
5
server/libs/encrypt/encryptor.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package encrypt
|
||||
|
||||
// Encryptor hashes/signs a byte slice into a printable string.
type Encryptor interface {
	Encrypt(content []byte) string
}
|
53
server/libs/encrypt/jwt.go
Normal file
53
server/libs/encrypt/jwt.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package encrypt
|
||||
|
||||
import (
|
||||
"github.com/robbert229/jwt"
|
||||
)
|
||||
|
||||
// JwtEncrypterMaker builds a TokenEncrypter backed by HMAC-SHA256
// signed JWTs using secret.
func JwtEncrypterMaker(secret string) TokenEncrypter {
	return &JwtEncrypter{
		alg:    jwt.HmacSha256(secret),
		claims: jwt.NewClaim(),
	}
}

// JwtEncrypter adapts github.com/robbert229/jwt to the TokenEncrypter
// interface.
type JwtEncrypter struct {
	alg    jwt.Algorithm
	claims *jwt.Claims
}

// Add stores a string claim under key. It always reports true.
func (encrypter *JwtEncrypter) Add(key string, value string) bool {
	encrypter.claims.Set(key, value)
	return true
}

// FromStr decodes token and replaces the current claims with its
// contents; it reports false on any decode failure.
func (encrypter *JwtEncrypter) FromStr(token string) bool {
	claims, err := encrypter.alg.Decode(token)
	// TODO: should return error or error info will lost
	if err != nil {
		return false
	}

	encrypter.claims = claims
	return true
}

// Get returns the claim stored under key.
// NOTE(review): the type assertion assumes every claim is a string
// (only Add ever writes them here) — a non-string claim would panic.
func (encrypter *JwtEncrypter) Get(key string) (string, bool) {
	iValue, err := encrypter.claims.Get(key)
	// TODO: should return error or error info will lost
	if err != nil {
		return "", false
	}

	return iValue.(string), true
}

// ToStr signs the current claims and returns the encoded token.
func (encrypter *JwtEncrypter) ToStr() (string, bool) {
	token, err := encrypter.alg.Encode(encrypter.claims)

	// TODO: should return error or error info will lost
	if err != nil {
		return "", false
	}
	return token, true
}
|
11
server/libs/encrypt/token_encrypter.go
Normal file
11
server/libs/encrypt/token_encrypter.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package encrypt
|
||||
|
||||
// EncrypterMaker builds a TokenEncrypter from a secret string.
type EncrypterMaker func(string) TokenEncrypter

// TODO: name should be Encrypter?
// TokenEncrypter is a key/value token that can be serialized to and
// parsed from its string form; each method reports success via bool.
type TokenEncrypter interface {
	Add(string, string) bool
	FromStr(string) bool
	Get(string) (string, bool)
	ToStr() (string, bool)
}
|
59
server/libs/errutil/ettutil.go
Normal file
59
server/libs/errutil/ettutil.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package errutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/logutil"
|
||||
)
|
||||
|
||||
// ErrUtil centralizes error checking and logging so call sites can
// stay one-liners.
type ErrUtil interface {
	IsErr(err error) bool
	IsFatalErr(err error) bool
	RecoverPanic()
}

// NewErrChecker returns an ErrUtil that logs through logger and, when
// logStack is true, also logs a stack trace with each event.
func NewErrChecker(logStack bool, logger logutil.LogUtil) ErrUtil {
	return &ErrChecker{logStack: logStack, log: logger}
}

// ErrChecker is the default ErrUtil implementation.
type ErrChecker struct {
	log      logutil.LogUtil
	logStack bool
}

// IsErr checks if error occurs; a non-nil err is logged and true is
// returned.
func (e *ErrChecker) IsErr(err error) bool {
	if err != nil {
		e.log.Printf("Error:%q\n", err)
		if e.logStack {
			e.log.Println(debug.Stack())
		}
		return true
	}
	return false
}

// IsFatalErr logs fe and terminates the process with exit code 1 when
// fe is non-nil; it only ever returns (false) for a nil error.
// (The original comment named this "IsFatalPanic"; the method is
// IsFatalErr.)
func (e *ErrChecker) IsFatalErr(fe error) bool {
	if fe != nil {
		e.log.Printf("Panic:%q", fe)
		if e.logStack {
			e.log.Println(debug.Stack())
		}
		os.Exit(1)
	}
	return false
}

// RecoverPanic catches an in-flight panic and logs it; it must be
// invoked via defer to have any effect.
func (e *ErrChecker) RecoverPanic() {
	if r := recover(); r != nil {
		e.log.Printf("Recovered:%v", r)
		if e.logStack {
			e.log.Println(debug.Stack())
		}
	}
}
|
177
server/libs/fileidx/file_idx.go
Normal file
177
server/libs/fileidx/file_idx.go
Normal file
|
@ -0,0 +1,177 @@
|
|||
package fileidx
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
	// StateStarted = after startUpload before upload
	StateStarted = "started"
	// StateUploading = after upload before finishUpload
	StateUploading = "uploading"
	// StateDone = after finishedUpload
	StateDone = "done"
)

// FileInfo describes one shared file and its upload/download state.
type FileInfo struct {
	Id        string // share id, also the index key
	DownLimit int    // remaining downloads; negative means unlimited (see DecrDownLimit)
	ModTime   int64  // modification time in UnixNano
	PathLocal string // file name on local disk
	State     string // one of the State* constants above
	Uploaded  int64  // bytes received so far
}

// FileIndex is a concurrency-safe registry of FileInfo keyed by id.
type FileIndex interface {
	// Add returns 0 on success, 1 when the index is full and -1 when
	// the id already exists.
	Add(fileInfo *FileInfo) int
	Del(id string)
	SetId(id string, newId string) bool
	SetDownLimit(id string, downLimit int) bool
	DecrDownLimit(id string) (int, bool)
	SetState(id string, state string) bool
	IncrUploaded(id string, uploaded int64) int64
	Get(id string) (*FileInfo, bool)
	List() map[string]*FileInfo
}

// NewMemFileIndex returns an empty in-memory index holding at most
// cap entries.
func NewMemFileIndex(cap int) *MemFileIndex {
	return &MemFileIndex{
		cap:   cap,
		infos: make(map[string]*FileInfo, 0),
	}
}

// NewMemFileIndexWithMap wraps an existing map in an index; the caller
// should stop using infos directly afterwards.
func NewMemFileIndexWithMap(cap int, infos map[string]*FileInfo) *MemFileIndex {
	return &MemFileIndex{
		cap:   cap,
		infos: infos,
	}
}

// MemFileIndex is the in-memory FileIndex implementation, guarded by
// a read-write mutex.
type MemFileIndex struct {
	cap   int
	infos map[string]*FileInfo
	mux   sync.RWMutex
}
||||
|
||||
// Add inserts fileInfo keyed by its Id.
// Returns 0 on success, 1 when the index is full, -1 on duplicate id.
func (idx *MemFileIndex) Add(fileInfo *FileInfo) int {
	idx.mux.Lock()
	defer idx.mux.Unlock()

	if len(idx.infos) >= idx.cap {
		return 1
	}

	if _, found := idx.infos[fileInfo.Id]; found {
		return -1
	}

	idx.infos[fileInfo.Id] = fileInfo
	return 0
}

// Del removes id; deleting a missing id is a no-op.
func (idx *MemFileIndex) Del(id string) {
	idx.mux.Lock()
	defer idx.mux.Unlock()

	delete(idx.infos, id)
}

// SetId renames an entry from id to newId. It fails when id is missing
// or newId is already taken; renaming to itself trivially succeeds.
func (idx *MemFileIndex) SetId(id string, newId string) bool {
	if id == newId {
		return true
	}

	idx.mux.Lock()
	defer idx.mux.Unlock()

	info, found := idx.infos[id]
	if !found {
		return false
	}

	if _, foundNewId := idx.infos[newId]; foundNewId {
		return false
	}

	idx.infos[newId] = info
	idx.infos[newId].Id = newId
	delete(idx.infos, id)
	return true
}

// SetDownLimit overwrites the remaining-download counter of id.
func (idx *MemFileIndex) SetDownLimit(id string, downLimit int) bool {
	idx.mux.Lock()
	defer idx.mux.Unlock()

	info, found := idx.infos[id]
	if !found {
		return false
	}

	info.DownLimit = downLimit
	return true
}

// DecrDownLimit consumes one download of id. The bool reports whether
// the download may proceed; the int is 0 only when id is missing or
// the upload is not finished yet.
func (idx *MemFileIndex) DecrDownLimit(id string) (int, bool) {
	idx.mux.Lock()
	defer idx.mux.Unlock()

	info, found := idx.infos[id]
	if !found || info.State != StateDone {
		return 0, false
	}

	if info.DownLimit == 0 {
		return 1, false
	}

	if info.DownLimit > 0 {
		// a negative info.DownLimit means unlimited, so it is left untouched
		info.DownLimit = info.DownLimit - 1
	}
	return 1, true
}

// SetState moves id to one of the State* constants.
func (idx *MemFileIndex) SetState(id string, state string) bool {
	idx.mux.Lock()
	defer idx.mux.Unlock()

	info, found := idx.infos[id]
	if !found {
		return false
	}

	info.State = state
	return true
}

// IncrUploaded adds uploaded bytes to id's progress and returns the
// new total (0 when id is missing).
func (idx *MemFileIndex) IncrUploaded(id string, uploaded int64) int64 {
	idx.mux.Lock()
	defer idx.mux.Unlock()

	info, found := idx.infos[id]
	if !found {
		return 0
	}

	info.Uploaded = info.Uploaded + uploaded
	return info.Uploaded
}

// Get returns the FileInfo for id and whether it exists.
func (idx *MemFileIndex) Get(id string) (*FileInfo, bool) {
	idx.mux.RLock()
	defer idx.mux.RUnlock()

	infos, found := idx.infos[id]
	return infos, found
}

// List returns the internal map itself.
// NOTE(review): the map is not copied, so callers iterating or
// mutating it outside the lock race with the index.
func (idx *MemFileIndex) List() map[string]*FileInfo {
	idx.mux.RLock()
	defer idx.mux.RUnlock()

	return idx.infos
}

// TODO: add unit tests
|
118
server/libs/fsutil/fsutil.go
Normal file
118
server/libs/fsutil/fsutil.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
package fsutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/errutil"
|
||||
"quickshare/server/libs/fileidx"
|
||||
"quickshare/server/libs/qtube"
|
||||
)
|
||||
|
||||
// FsUtil abstracts the filesystem operations the upload/download
// handlers need, so they can be swapped out in tests.
type FsUtil interface {
	CreateFile(fullPath string) error
	CopyChunkN(fullPath string, chunk io.Reader, start int64, length int64) bool
	DelFile(fullPath string) bool
	Open(fullPath string) (qtube.ReadSeekCloser, error)
	MkdirAll(path string, mode os.FileMode) bool
	Readdir(dirName string, n int) ([]*fileidx.FileInfo, error)
}

// NewSimpleFs returns an FsUtil backed by the os package that reports
// failures through errUtil.
func NewSimpleFs(errUtil errutil.ErrUtil) FsUtil {
	return &SimpleFs{
		Err: errUtil,
	}
}

// SimpleFs implements FsUtil directly on top of the local filesystem.
type SimpleFs struct {
	Err errutil.ErrUtil
}

var (
	// ErrExists is returned by CreateFile when the target already exists.
	ErrExists = errors.New("file exists")
	// ErrUnknown hides any other failure cause from the caller.
	ErrUnknown = errors.New("unknown error")
)
|
||||
|
||||
func (sfs *SimpleFs) CreateFile(fullPath string) error {
|
||||
flag := os.O_CREATE | os.O_EXCL | os.O_RDONLY
|
||||
perm := os.FileMode(0644)
|
||||
newFile, err := os.OpenFile(fullPath, flag, perm)
|
||||
defer newFile.Close()
|
||||
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if os.IsExist(err) {
|
||||
return ErrExists
|
||||
} else {
|
||||
return ErrUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func (sfs *SimpleFs) CopyChunkN(fullPath string, chunk io.Reader, start int64, length int64) bool {
|
||||
flag := os.O_WRONLY
|
||||
perm := os.FileMode(0644)
|
||||
file, openErr := os.OpenFile(fullPath, flag, perm)
|
||||
|
||||
defer file.Close()
|
||||
if sfs.Err.IsErr(openErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, err := file.Seek(start, io.SeekStart); sfs.Err.IsErr(err) {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, err := io.CopyN(file, chunk, length); sfs.Err.IsErr(err) && err != io.EOF {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// DelFile removes fullPath and reports success.
func (sfs *SimpleFs) DelFile(fullPath string) bool {
	return !sfs.Err.IsErr(os.Remove(fullPath))
}

// MkdirAll creates the directory path (with parents) and reports
// success.
func (sfs *SimpleFs) MkdirAll(path string, mode os.FileMode) bool {
	err := os.MkdirAll(path, mode)
	return !sfs.Err.IsErr(err)
}

// Readdir returns FileInfos for up to n regular files inside dirName;
// io.EOF from the directory read is tolerated and passed through.
// TODO: not support read from last seek position
// NOTE(review): dir.Close is deferred before the open error is
// checked — Close on a nil *os.File returns os.ErrInvalid rather than
// panicking, but reordering would still be cleaner.
func (sfs *SimpleFs) Readdir(dirName string, n int) ([]*fileidx.FileInfo, error) {
	dir, openErr := os.Open(dirName)
	defer dir.Close()

	if sfs.Err.IsErr(openErr) {
		return []*fileidx.FileInfo{}, openErr
	}

	osFileInfos, readErr := dir.Readdir(n)
	if sfs.Err.IsErr(readErr) && readErr != io.EOF {
		return []*fileidx.FileInfo{}, readErr
	}

	// Keep only regular files; directories and specials are skipped.
	fileInfos := make([]*fileidx.FileInfo, 0)
	for _, osFileInfo := range osFileInfos {
		if osFileInfo.Mode().IsRegular() {
			fileInfos = append(
				fileInfos,
				&fileidx.FileInfo{
					ModTime:   osFileInfo.ModTime().UnixNano(),
					PathLocal: osFileInfo.Name(),
					Uploaded:  osFileInfo.Size(),
				},
			)
		}
	}

	return fileInfos, readErr
}

// Open opens fullPath for reading.
// the associated file descriptor has mode O_RDONLY as using os.Open
func (sfs *SimpleFs) Open(fullPath string) (qtube.ReadSeekCloser, error) {
	return os.Open(fullPath)
}
|
84
server/libs/httputil/httputil.go
Normal file
84
server/libs/httputil/httputil.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/errutil"
|
||||
)
|
||||
|
||||
// MsgRes is the JSON body used for simple status responses.
type MsgRes struct {
	Code int
	Msg  string
}

// Canned responses for the common HTTP status codes.
var (
	Err400 = MsgRes{Code: http.StatusBadRequest, Msg: "Bad Request"}
	Err401 = MsgRes{Code: http.StatusUnauthorized, Msg: "Unauthorized"}
	Err404 = MsgRes{Code: http.StatusNotFound, Msg: "Not Found"}
	Err412 = MsgRes{Code: http.StatusPreconditionFailed, Msg: "Precondition Failed"}
	Err429 = MsgRes{Code: http.StatusTooManyRequests, Msg: "Too Many Requests"}
	Err500 = MsgRes{Code: http.StatusInternalServerError, Msg: "Internal Server Error"}
	Err503 = MsgRes{Code: http.StatusServiceUnavailable, Msg: "Service Unavailable"}
	Err504 = MsgRes{Code: http.StatusGatewayTimeout, Msg: "Gateway Timeout"}
	Ok200  = MsgRes{Code: http.StatusOK, Msg: "OK"}
)

// HttpUtil groups cookie handling and JSON response writing.
type HttpUtil interface {
	GetCookie(cookies []*http.Cookie, key string) string
	SetCookie(res http.ResponseWriter, key string, val string)
	Fill(msg interface{}, res http.ResponseWriter) int
}

// QHttpUtil implements HttpUtil; the Cookie* fields configure every
// cookie it writes.
type QHttpUtil struct {
	CookieDomain   string
	CookieHttpOnly bool
	CookieMaxAge   int
	CookiePath     string
	CookieSecure   bool
	Err            errutil.ErrUtil
}
|
||||
|
||||
func (q *QHttpUtil) GetCookie(cookies []*http.Cookie, key string) string {
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == key {
|
||||
return cookie.Value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (q *QHttpUtil) SetCookie(res http.ResponseWriter, key string, val string) {
|
||||
cookie := http.Cookie{
|
||||
Name: key,
|
||||
Value: val,
|
||||
Domain: q.CookieDomain,
|
||||
Expires: time.Now().Add(time.Duration(q.CookieMaxAge) * time.Second),
|
||||
HttpOnly: q.CookieHttpOnly,
|
||||
MaxAge: q.CookieMaxAge,
|
||||
Secure: q.CookieSecure,
|
||||
Path: q.CookiePath,
|
||||
}
|
||||
|
||||
res.Header().Set("Set-Cookie", cookie.String())
|
||||
}
|
||||
|
||||
// Fill serializes msg as JSON into res and returns the number of
// bytes written (0 when msg is nil or on any marshal/write error).
func (q *QHttpUtil) Fill(msg interface{}, res http.ResponseWriter) int {
	if msg == nil {
		return 0
	}

	msgBytes, marsErr := json.Marshal(msg)
	if q.Err.IsErr(marsErr) {
		return 0
	}

	wrote, writeErr := res.Write(msgBytes)
	if q.Err.IsErr(writeErr) {
		return 0
	}
	return wrote
}
|
130
server/libs/httpworker/worker.go
Normal file
130
server/libs/httpworker/worker.go
Normal file
|
@ -0,0 +1,130 @@
|
|||
package httpworker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/logutil"
|
||||
)
|
||||
|
||||
var (
	// ErrWorkerNotFound is acked when a task arrives without a Do func.
	ErrWorkerNotFound = errors.New("worker not found")
	// ErrTimeout is returned by IsInTime when the deadline passes first.
	ErrTimeout = errors.New("timeout")
)

// DoFunc is the unit of work a worker executes for one request.
type DoFunc func(http.ResponseWriter, *http.Request)

// Task couples a handler with its request/response pair; the outcome
// is reported back on Ack.
type Task struct {
	Ack chan error
	Do  DoFunc
	Res http.ResponseWriter
	Req *http.Request
}

// Workers is a bounded pool executing Tasks asynchronously.
type Workers interface {
	Put(*Task) bool
	IsInTime(ack chan error, msec time.Duration) error
}

// WorkerPool fans tasks out over a fixed set of goroutines sharing
// one queue.
type WorkerPool struct {
	queue   chan *Task
	size    int
	workers []*Worker
	log     logutil.LogUtil // TODO: should not pass log here
}

// NewWorkerPool starts poolSize worker goroutines reading from a
// shared queue of capacity queueSize.
func NewWorkerPool(poolSize int, queueSize int, log logutil.LogUtil) Workers {
	queue := make(chan *Task, queueSize)
	workers := make([]*Worker, 0, poolSize)

	for i := 0; i < poolSize; i++ {
		worker := &Worker{
			Id:    uint64(i),
			queue: queue,
			log:   log,
		}

		go worker.Start()
		workers = append(workers, worker)
	}

	return &WorkerPool{
		queue:   queue,
		size:    poolSize,
		workers: workers,
		log:     log,
	}
}
|
||||
|
||||
func (pool *WorkerPool) Put(task *Task) bool {
|
||||
if len(pool.queue) >= pool.size {
|
||||
return false
|
||||
}
|
||||
|
||||
pool.queue <- task
|
||||
return true
|
||||
}
|
||||
|
||||
func (pool *WorkerPool) IsInTime(ack chan error, msec time.Duration) error {
|
||||
start := time.Now().UnixNano()
|
||||
timeout := make(chan error)
|
||||
|
||||
go func() {
|
||||
time.Sleep(msec)
|
||||
timeout <- ErrTimeout
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ack:
|
||||
if err == nil {
|
||||
pool.log.Printf(
|
||||
"finish cost: %d usec",
|
||||
(time.Now().UnixNano()-start)/1000,
|
||||
)
|
||||
} else {
|
||||
pool.log.Printf(
|
||||
"finish with error cost: %d usec",
|
||||
(time.Now().UnixNano()-start)/1000,
|
||||
)
|
||||
}
|
||||
return err
|
||||
case errTimeout := <-timeout:
|
||||
pool.log.Printf("timeout cost: %d usec", (time.Now().UnixNano()-start)/1000)
|
||||
return errTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// Worker is one pool goroutine consuming tasks from the shared queue.
type Worker struct {
	Id    uint64
	queue chan *Task
	log   logutil.LogUtil
}

// RecoverPanic logs a panic raised by a task and restarts the worker
// loop; it must be installed with defer.
// NOTE(review): restarting from inside the deferred recover recurses,
// so repeated panics deepen the stack.
func (worker *Worker) RecoverPanic() {
	if r := recover(); r != nil {
		worker.log.Printf("Recovered:%v stack: %v", r, debug.Stack())
		// restart worker and IsInTime will return timeout error for last task
		worker.Start()
	}
}

// Start loops forever, executing each queued task and acking its
// result (nil on success, ErrWorkerNotFound when Do is missing).
func (worker *Worker) Start() {
	defer worker.RecoverPanic()

	for {
		task := <-worker.queue
		if task.Do != nil {
			task.Do(task.Res, task.Req)
			task.Ack <- nil
		} else {
			task.Ack <- ErrWorkerNotFound
		}
	}
}

// ServiceFunc lets you return struct directly
type ServiceFunc func(http.ResponseWriter, *http.Request) interface{}
|
5
server/libs/limiter/limiter.go
Normal file
5
server/libs/limiter/limiter.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package limiter

// Limiter grants or denies one operation (second arg, an op id) for
// one caller (first arg, an item id).
type Limiter interface {
	Access(string, int16) bool
}
|
220
server/libs/limiter/rate_limiter.go
Normal file
220
server/libs/limiter/rate_limiter.go
Normal file
|
@ -0,0 +1,220 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// now returns the current unix time truncated to int32 seconds.
func now() int32 {
	return int32(time.Now().Unix())
}

// afterCyc returns the unix second at which the current refresh cycle
// ends. (Shares the now() helper; the original re-derived the clock
// in each function.)
func afterCyc(cyc int32) int32 {
	return now() + cyc
}

// afterTtl returns the unix second at which an item expires.
func afterTtl(ttl int32) int32 {
	return now() + ttl
}
|
||||
|
||||
// Bucket is a token bucket for one operation; it refills once the
// Refresh unix second has passed.
type Bucket struct {
	Refresh int32
	Tokens  int16
}

// NewBucket returns a bucket holding cap tokens that refreshes after
// cyc seconds.
func NewBucket(cyc int32, cap int16) *Bucket {
	return &Bucket{
		Refresh: afterCyc(cyc),
		Tokens:  cap,
	}
}

// Item tracks all buckets of a single caller until Expired.
type Item struct {
	Expired int32
	Buckets map[int16]*Bucket
}

// NewItem returns an empty item that expires after ttl seconds.
func NewItem(ttl int32) *Item {
	return &Item{
		Expired: afterTtl(ttl),
		Buckets: make(map[int16]*Bucket),
	}
}

// RateLimiter is a token-bucket Limiter with per-operation custom
// capacities and periodic cleanup of expired items.
type RateLimiter struct {
	items      map[string]*Item
	bucketCap  int16
	customCaps map[int16]int16
	cap        int64
	cyc        int32 // how much time, item autoclean will be executed, bucket will be refreshed
	ttl        int32 // how much time, item will be expired(but not cleaned)
	mux        sync.RWMutex
	snapshot   map[string]map[int16]*Bucket
}

// NewRateLimiter builds a limiter for at most cap callers; customCaps
// overrides bucketCap per operation id. It panics on non-positive
// parameters and starts the autoClean goroutine.
func NewRateLimiter(cap int64, ttl int32, cyc int32, bucketCap int16, customCaps map[int16]int16) Limiter {
	if cap < 1 || ttl < 1 || cyc < 1 || bucketCap < 1 {
		panic("cap | bucketCap | ttl | cycle cant be less than 1")
	}

	limiter := &RateLimiter{
		items:      make(map[string]*Item, cap),
		bucketCap:  bucketCap,
		customCaps: customCaps,
		cap:        cap,
		ttl:        ttl,
		cyc:        cyc,
	}

	go limiter.autoClean()

	return limiter
}
|
||||
|
||||
func (limiter *RateLimiter) getBucketCap(opId int16) int16 {
|
||||
bucketCap, existed := limiter.customCaps[opId]
|
||||
if !existed {
|
||||
return limiter.bucketCap
|
||||
}
|
||||
return bucketCap
|
||||
}
|
||||
|
||||
// Access consumes one token for (itemId, opId). It returns false when
// the limiter is full of callers, or the caller's bucket is empty for
// the current cycle.
func (limiter *RateLimiter) Access(itemId string, opId int16) bool {
	limiter.mux.Lock()
	defer limiter.mux.Unlock()

	item, itemExisted := limiter.items[itemId]
	if !itemExisted {
		// New caller: reject when at capacity, otherwise create an
		// item plus a bucket with one token already spent.
		if int64(len(limiter.items)) >= limiter.cap {
			return false
		}

		limiter.items[itemId] = NewItem(limiter.ttl)
		limiter.items[itemId].Buckets[opId] = NewBucket(limiter.cyc, limiter.getBucketCap(opId)-1)
		return true
	}

	bucket, bucketExisted := item.Buckets[opId]
	if !bucketExisted {
		// First use of this op by an existing caller.
		item.Buckets[opId] = NewBucket(limiter.cyc, limiter.getBucketCap(opId)-1)
		return true
	}

	// Still inside the current cycle: spend a token if any remain.
	if bucket.Refresh > now() {
		if bucket.Tokens > 0 {
			bucket.Tokens--
			return true
		}
		return false
	}

	// Cycle elapsed: refill the bucket and take one token.
	bucket.Refresh = afterCyc(limiter.cyc)
	bucket.Tokens = limiter.getBucketCap(opId) - 1
	return true
}

// GetCap returns the maximum number of callers tracked.
func (limiter *RateLimiter) GetCap() int64 {
	return limiter.cap
}

// GetSize returns the number of callers currently tracked.
func (limiter *RateLimiter) GetSize() int64 {
	limiter.mux.RLock()
	defer limiter.mux.RUnlock()
	return int64(len(limiter.items))
}
|
||||
|
||||
func (limiter *RateLimiter) ExpandCap(cap int64) bool {
|
||||
limiter.mux.RLock()
|
||||
defer limiter.mux.RUnlock()
|
||||
|
||||
if cap <= int64(len(limiter.items)) {
|
||||
return false
|
||||
}
|
||||
|
||||
limiter.cap = cap
|
||||
return true
|
||||
}
|
||||
|
||||
// GetTTL returns the item time-to-live in seconds.
func (limiter *RateLimiter) GetTTL() int32 {
	return limiter.ttl
}

// UpdateTTL sets the item TTL; values below 1 are rejected.
// NOTE(review): ttl is written without holding mux — racy when called
// concurrently with Access/clean; confirm intended usage.
func (limiter *RateLimiter) UpdateTTL(ttl int32) bool {
	if ttl < 1 {
		return false
	}

	limiter.ttl = ttl
	return true
}

// GetCyc returns the refresh/clean cycle length in seconds.
func (limiter *RateLimiter) GetCyc() int32 {
	return limiter.cyc
}
|
||||
|
||||
func (limiter *RateLimiter) UpdateCyc(cyc int32) bool {
|
||||
if limiter.cyc < 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
limiter.cyc = cyc
|
||||
return true
|
||||
}
|
||||
|
||||
// Snapshot returns the bucket view built by the last clean() pass.
// NOTE(review): read without locking — treat as an approximate view.
func (limiter *RateLimiter) Snapshot() map[string]map[int16]*Bucket {
	return limiter.snapshot
}

// autoClean runs clean() every cyc seconds for the life of the
// limiter; a zero cyc stops the loop.
func (limiter *RateLimiter) autoClean() {
	for {
		if limiter.cyc == 0 {
			break
		}
		// cyc seconds expressed in nanoseconds.
		time.Sleep(time.Duration(int64(limiter.cyc) * 1000000000))
		limiter.clean()
	}
}
|
||||
|
||||
// clean may add affect other operations, do frequently?
|
||||
func (limiter *RateLimiter) clean() {
|
||||
limiter.snapshot = make(map[string]map[int16]*Bucket)
|
||||
now := now()
|
||||
|
||||
limiter.mux.RLock()
|
||||
defer limiter.mux.RUnlock()
|
||||
for key, item := range limiter.items {
|
||||
if item.Expired <= now {
|
||||
delete(limiter.items, key)
|
||||
} else {
|
||||
limiter.snapshot[key] = item.Buckets
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only for test
|
||||
// Only for test: exist reports whether id is currently tracked.
func (limiter *RateLimiter) exist(id string) bool {
	limiter.mux.RLock()
	defer limiter.mux.RUnlock()

	_, existed := limiter.items[id]
	return existed
}
|
||||
|
||||
// Only for test
|
||||
func (limiter *RateLimiter) truncate() {
|
||||
limiter.mux.RLock()
|
||||
defer limiter.mux.RUnlock()
|
||||
|
||||
for key, _ := range limiter.items {
|
||||
delete(limiter.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Only for test
|
||||
// Only for test: get returns the tracked item for id, if any.
func (limiter *RateLimiter) get(id string) (*Item, bool) {
	limiter.mux.RLock()
	defer limiter.mux.RUnlock()

	item, existed := limiter.items[id]
	return item, existed
}
|
161
server/libs/limiter/rate_limiter_test.go
Normal file
161
server/libs/limiter/rate_limiter_test.go
Normal file
|
@ -0,0 +1,161 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Shared randomness for the tests.
var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))

const rndCap = 10000
const addCap = 1

// how to set time
// extend: wait can be greater than ttl/2
// cyc is smaller than ttl and wait, then it can be clean in time
const cap = 40
const ttl = 3
const cyc = 1
const bucketCap = 2
const id1 = "id1"
const id2 = "id2"
const op1 int16 = 0
const op2 int16 = 1

// op2 gets a large custom capacity so it never throttles in tests.
var customCaps = map[int16]int16{
	op2: 1000,
}

const wait = 1

// The limiter under test, shared by every test in this file.
var limiter = NewRateLimiter(cap, ttl, cyc, bucketCap, customCaps).(*RateLimiter)
|
||||
|
||||
func printItem(id string) {
|
||||
item, existed := limiter.get(id1)
|
||||
if existed {
|
||||
fmt.Println("expired, now, existed", item.Expired, now(), existed)
|
||||
for id, bucket := range item.Buckets {
|
||||
fmt.Println("\tid, bucket", id, bucket)
|
||||
}
|
||||
} else {
|
||||
fmt.Println("not existed")
|
||||
}
|
||||
}
|
||||
|
||||
// idSeed is the monotonically increasing counter backing randId.
var idSeed = 0

// randId returns a process-unique string id: "1", "2", ...
func randId() string {
	idSeed++
	return fmt.Sprint(idSeed)
}
|
||||
|
||||
func TestAccess(t *testing.T) {
|
||||
func(t *testing.T) {
|
||||
canAccess := limiter.Access(id1, op1)
|
||||
if !canAccess {
|
||||
t.Fatal("access: fail")
|
||||
}
|
||||
|
||||
for i := 0; i < bucketCap; i++ {
|
||||
canAccess = limiter.Access(id1, op1)
|
||||
}
|
||||
|
||||
if canAccess {
|
||||
t.Fatal("access: fail to deny access")
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(limiter.GetCyc()) * time.Second)
|
||||
|
||||
canAccess = limiter.Access(id1, op1)
|
||||
if !canAccess {
|
||||
t.Fatal("access: fail to refresh tokens")
|
||||
}
|
||||
}(t)
|
||||
}
|
||||
|
||||
// TestCap exercises ExpandCap and the limiter's item-count cap. It mutates
// the shared limiter, filling it to capacity, and truncates it at the end.
func TestCap(t *testing.T) {
	originalCap := limiter.GetCap()
	fmt.Printf("cap:info: %d\n", originalCap)

	// Growing the cap must succeed.
	ok := limiter.ExpandCap(originalCap + addCap)

	if !ok || limiter.GetCap() != originalCap+addCap {
		t.Fatal("cap: fail to expand")
	}

	// Shrinking below the current size must be rejected.
	ok = limiter.ExpandCap(limiter.GetSize() - addCap)
	if ok {
		t.Fatal("cap: shrink cap")
	}

	// Fill the limiter with fresh ids until size reaches cap.
	ids := []string{}
	for limiter.GetSize() < limiter.GetCap() {
		id := randId()
		ids = append(ids, id)

		ok := limiter.Access(id, 0)
		if !ok {
			t.Fatal("cap: not full")
		}
	}

	if limiter.GetSize() != limiter.GetCap() {
		t.Fatal("cap: incorrect size")
	}

	// One more new id must be rejected once the cap is reached.
	if limiter.Access(randId(), 0) {
		t.Fatal("cap: more than cap")
	}

	// Reset shared state for later tests.
	limiter.truncate()
}
|
||||
|
||||
func TestTtl(t *testing.T) {
|
||||
var addTtl int32 = 1
|
||||
originalTTL := limiter.GetTTL()
|
||||
fmt.Printf("ttl:info: %d\n", originalTTL)
|
||||
|
||||
limiter.UpdateTTL(originalTTL + addTtl)
|
||||
if limiter.GetTTL() != originalTTL+addTtl {
|
||||
t.Fatal("ttl: update fail")
|
||||
}
|
||||
}
|
||||
|
||||
// cycTest checks that UpdateCyc takes effect.
// NOTE(review): lowercase name, so `go test` never runs this — presumably
// disabled on purpose; confirm before renaming to TestCyc.
func cycTest(t *testing.T) {
	var addCyc int32 = 1
	originalCyc := limiter.GetCyc()
	fmt.Printf("cyc:info: %d\n", originalCyc)

	limiter.UpdateCyc(originalCyc + addCyc)
	if limiter.GetCyc() != originalCyc+addCyc {
		t.Fatal("cyc: update fail")
	}
}
|
||||
|
||||
// autoCleanTest verifies that items disappear after their TTL elapses.
// NOTE(review): lowercase name, so `go test` never runs this.
func autoCleanTest(t *testing.T) {
	ids := []string{
		randId(),
		randId(),
	}

	for _, id := range ids {
		ok := limiter.Access(id, 0)
		// NOTE(review): failing when ok is true looks inverted for an
		// "add fail" message — confirm the intended Access semantics
		// before enabling this test.
		if ok {
			t.Fatal("autoClean: warning: add fail")
		}
	}

	// Sleep past the TTL (plus settle time) so the cleaner can run.
	time.Sleep(time.Duration(limiter.GetTTL()+wait) * time.Second)

	for _, id := range ids {
		_, exist := limiter.get(id)
		if exist {
			t.Fatal("autoClean: item still exist")
		}
	}
}
|
||||
|
||||
// func snapshotTest(t *testing.T) {
|
||||
// }
|
7
server/libs/logutil/logutil.go
Normal file
7
server/libs/logutil/logutil.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
package logutil
|
||||
|
||||
// LogUtil is the minimal logging surface used by the server; it is
// satisfied by *log.Logger.
type LogUtil interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
	Println(v ...interface{})
}
|
12
server/libs/logutil/slogger.go
Normal file
12
server/libs/logutil/slogger.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package logutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
)
|
||||
|
||||
// NewSlog returns a standard-library logger writing to out with the given
// prefix, stamped with date, time and short file name.
func NewSlog(out io.Writer, prefix string) LogUtil {
	return log.New(out, prefix, log.Ldate|log.Ltime|log.Lshortfile)
}

// Slog is a pointer-to-Logger type.
// NOTE(review): defined but not referenced in this file — confirm it is
// used elsewhere before keeping it.
type Slog *log.Logger
|
13
server/libs/qtube/downloader.go
Normal file
13
server/libs/qtube/downloader.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package qtube
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/fileidx"
|
||||
)
|
||||
|
||||
// Downloader serves a stored file (whole or by byte range) over HTTP.
type Downloader interface {
	ServeFile(res http.ResponseWriter, req *http.Request, fileInfo *fileidx.FileInfo) error
}
|
280
server/libs/qtube/qtube.go
Normal file
280
server/libs/qtube/qtube.go
Normal file
|
@ -0,0 +1,280 @@
|
|||
package qtube
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/fileidx"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by ServeFile and its helpers.
var (
	ErrCopy    = errors.New("ServeFile: copy error")
	ErrUnknown = errors.New("ServeFile: unknown error")
)
|
||||
|
||||
// httpRange describes one byte range of a file: [start, start+length).
type httpRange struct {
	start, length int64
}

// GetStart returns the offset of the first byte of the range.
func (ra *httpRange) GetStart() int64 {
	return ra.start
}

// GetLength returns the number of bytes in the range.
func (ra *httpRange) GetLength() int64 {
	return ra.length
}

// SetStart sets the offset of the first byte of the range.
func (ra *httpRange) SetStart(start int64) {
	ra.start = start
}

// SetLength sets the number of bytes in the range.
func (ra *httpRange) SetLength(length int64) {
	ra.length = length
}
|
||||
|
||||
func NewQTube(root string, copySpeed, maxRangeLen int64, filer FileReadSeekCloser) Downloader {
|
||||
return &QTube{
|
||||
Root: root,
|
||||
BytesPerSec: copySpeed,
|
||||
MaxRangeLen: maxRangeLen,
|
||||
Filer: filer,
|
||||
}
|
||||
}
|
||||
|
||||
// QTube streams files from Root, throttled to BytesPerSec and serving at
// most MaxRangeLen bytes per requested range.
type QTube struct {
	Root        string
	BytesPerSec int64
	MaxRangeLen int64
	Filer       FileReadSeekCloser
}

// FileReadSeekCloser opens a file path as a seekable stream.
type FileReadSeekCloser interface {
	Open(filePath string) (ReadSeekCloser, error)
}

// ReadSeekCloser combines reading, seeking and closing.
type ReadSeekCloser interface {
	io.Reader
	io.Seeker
	io.Closer
}

// Error messages returned to clients for bad Range requests.
const (
	ErrorInvalidRange = "ServeFile: invalid Range"
	ErrorInvalidSize  = "ServeFile: invalid Range total size"
)
|
||||
|
||||
func (tb *QTube) ServeFile(res http.ResponseWriter, req *http.Request, fileInfo *fileidx.FileInfo) error {
|
||||
headerRange := req.Header.Get("Range")
|
||||
|
||||
switch {
|
||||
case req.Method == http.MethodHead:
|
||||
res.Header().Set("Accept-Ranges", "bytes")
|
||||
res.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Uploaded))
|
||||
res.Header().Set("Content-Type", "application/octet-stream")
|
||||
res.WriteHeader(http.StatusOK)
|
||||
|
||||
return nil
|
||||
case headerRange == "":
|
||||
return tb.serveAll(res, fileInfo)
|
||||
default:
|
||||
return tb.serveRanges(res, headerRange, fileInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func (tb *QTube) serveAll(res http.ResponseWriter, fileInfo *fileidx.FileInfo) error {
|
||||
res.Header().Set("Accept-Ranges", "bytes")
|
||||
res.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filepath.Base(fileInfo.PathLocal)))
|
||||
res.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Uploaded))
|
||||
res.Header().Set("Content-Type", "application/octet-stream")
|
||||
res.Header().Set("Last-Modified", time.Unix(fileInfo.ModTime, 0).UTC().Format(http.TimeFormat))
|
||||
res.WriteHeader(http.StatusOK)
|
||||
|
||||
// TODO: need verify path
|
||||
file, openErr := tb.Filer.Open(filepath.Join(tb.Root, fileInfo.PathLocal))
|
||||
defer file.Close()
|
||||
if openErr != nil {
|
||||
return openErr
|
||||
}
|
||||
|
||||
copyErr := tb.throttledCopyN(res, file, fileInfo.Uploaded)
|
||||
if copyErr != nil && copyErr != io.EOF {
|
||||
return copyErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tb *QTube) serveRanges(res http.ResponseWriter, headerRange string, fileInfo *fileidx.FileInfo) error {
|
||||
ranges, rangeErr := getRanges(headerRange, fileInfo.Uploaded)
|
||||
if rangeErr != nil {
|
||||
http.Error(res, rangeErr.Error(), http.StatusRequestedRangeNotSatisfiable)
|
||||
return errors.New(rangeErr.Error())
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(ranges) == 1 || len(ranges) > 1:
|
||||
if tb.copyRange(res, ranges[0], fileInfo) != nil {
|
||||
return ErrCopy
|
||||
}
|
||||
default:
|
||||
// TODO: add support for multiple ranges
|
||||
return ErrUnknown
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRanges(headerRange string, size int64) ([]httpRange, error) {
|
||||
ranges, raParseErr := parseRange(headerRange, size)
|
||||
// TODO: check max number of ranges, range start end
|
||||
if len(ranges) <= 0 || raParseErr != nil {
|
||||
return nil, errors.New(ErrorInvalidRange)
|
||||
}
|
||||
if sumRangesSize(ranges) > size {
|
||||
return nil, errors.New(ErrorInvalidSize)
|
||||
}
|
||||
|
||||
return ranges, nil
|
||||
}
|
||||
|
||||
func (tb *QTube) copyRange(res http.ResponseWriter, ra httpRange, fileInfo *fileidx.FileInfo) error {
|
||||
// TODO: comfirm this wont cause problem
|
||||
if ra.GetLength() > tb.MaxRangeLen {
|
||||
ra.SetLength(tb.MaxRangeLen)
|
||||
}
|
||||
|
||||
// TODO: add headers(ETag): https://tools.ietf.org/html/rfc7233#section-4.1 p11 2nd paragraph
|
||||
res.Header().Set("Accept-Ranges", "bytes")
|
||||
res.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filepath.Base(fileInfo.PathLocal)))
|
||||
res.Header().Set("Content-Type", "application/octet-stream")
|
||||
res.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, fileInfo.Uploaded))
|
||||
res.Header().Set("Content-Length", strconv.FormatInt(ra.GetLength(), 10))
|
||||
res.Header().Set("Last-Modified", time.Unix(fileInfo.ModTime, 0).UTC().Format(http.TimeFormat))
|
||||
res.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
// TODO: need verify path
|
||||
file, openErr := tb.Filer.Open(filepath.Join(tb.Root, fileInfo.PathLocal))
|
||||
defer file.Close()
|
||||
if openErr != nil {
|
||||
return openErr
|
||||
}
|
||||
|
||||
if _, seekErr := file.Seek(ra.start, io.SeekStart); seekErr != nil {
|
||||
return seekErr
|
||||
}
|
||||
|
||||
copyErr := tb.throttledCopyN(res, file, ra.length)
|
||||
if copyErr != nil && copyErr != io.EOF {
|
||||
return copyErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// throttledCopyN copies length bytes from src to dst at roughly
// tb.BytesPerSec bytes per second: each iteration copies at most one
// second's budget, then sleeps out the remainder of that second.
// Returns the first copy error (which may be io.EOF from io.CopyN).
func (tb *QTube) throttledCopyN(dst io.Writer, src io.Reader, length int64) error {
	sum := int64(0)
	timeSlot := time.Duration(1 * time.Second)

	for sum < length {
		start := time.Now()
		// Copy the remaining bytes, capped at one second's budget.
		chunkSize := length - sum
		if length-sum > tb.BytesPerSec {
			chunkSize = tb.BytesPerSec
		}

		copied, err := io.CopyN(dst, src, chunkSize)
		if err != nil {
			return err
		}

		sum += copied
		// If the chunk finished early, sleep out the rest of the slot.
		end := time.Now()
		if end.Before(start.Add(timeSlot)) {
			time.Sleep(start.Add(timeSlot).Sub(end))
		}
	}

	return nil
}
|
||||
|
||||
// parseRange parses a "Range: bytes=..." header against a file of the
// given size, returning the requested byte ranges. It mirrors the private
// parseRange helper in Go's net/http package. A nil, nil return means the
// header was absent.
func parseRange(headerRange string, size int64) ([]httpRange, error) {
	if headerRange == "" {
		return nil, nil // header not present
	}

	const keyByte = "bytes="
	if !strings.HasPrefix(headerRange, keyByte) {
		return nil, errors.New("byte= not found")
	}

	var ranges []httpRange
	// noOverlap records that at least one range started past EOF.
	noOverlap := false
	for _, ra := range strings.Split(headerRange[len(keyByte):], ",") {
		ra = strings.TrimSpace(ra)
		if ra == "" {
			continue
		}

		i := strings.Index(ra, "-")
		if i < 0 {
			return nil, errors.New("- not found")
		}

		start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:])
		var r httpRange
		if start == "" {
			// Suffix form "-N": the last N bytes of the file.
			i, err := strconv.ParseInt(end, 10, 64)
			if err != nil {
				return nil, errors.New("invalid range")
			}
			if i > size {
				i = size
			}
			r.start = size - i
			r.length = size - r.start
		} else {
			i, err := strconv.ParseInt(start, 10, 64)
			if err != nil || i < 0 {
				return nil, errors.New("invalid range")
			}
			if i >= size {
				// If the range begins after the size of the content,
				// then it does not overlap.
				noOverlap = true
				continue
			}
			r.start = i
			if end == "" {
				// If no end is specified, range extends to end of the file.
				r.length = size - r.start
			} else {
				i, err := strconv.ParseInt(end, 10, 64)
				if err != nil || r.start > i {
					return nil, errors.New("invalid range")
				}
				// Clamp an end past EOF to the last byte.
				if i >= size {
					i = size - 1
				}
				r.length = i - r.start + 1
			}
		}
		ranges = append(ranges, r)
	}
	if noOverlap && len(ranges) == 0 {
		// The specified ranges did not overlap with the content.
		return nil, errors.New("parseRanges: no overlap")
	}
	return ranges, nil
}
|
||||
|
||||
func sumRangesSize(ranges []httpRange) (size int64) {
|
||||
for _, ra := range ranges {
|
||||
size += ra.length
|
||||
}
|
||||
return
|
||||
}
|
354
server/libs/qtube/qtube_test.go
Normal file
354
server/libs/qtube/qtube_test.go
Normal file
|
@ -0,0 +1,354 @@
|
|||
package qtube
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/fileidx"
|
||||
)
|
||||
|
||||
// Range format examples:
// Range: <unit>=<range-start>-
// Range: <unit>=<range-start>-<range-end>
// Range: <unit>=<range-start>-<range-end>, <range-start>-<range-end>
// Range: <unit>=<range-start>-<range-end>, <range-start>-<range-end>, <range-start>-<range-end>

// TestGetRanges drives getRanges through table cases: a malformed header,
// ranges whose total exceeds the file size, and a valid multi-range header.
func TestGetRanges(t *testing.T) {
	type Input struct {
		HeaderRange string
		Size        int64
	}
	type Output struct {
		Ranges   []httpRange
		ErrorMsg string // expected error message; "" means success expected
	}
	type testCase struct {
		Desc string
		Input
		Output
	}

	testCases := []testCase{
		testCase{
			Desc: "invalid range",
			Input: Input{
				HeaderRange: "bytes=start-invalid end",
				Size:        0,
			},
			Output: Output{
				ErrorMsg: ErrorInvalidRange,
			},
		},
		testCase{
			Desc: "invalid range total size",
			Input: Input{
				HeaderRange: "bytes=0-1, 2-3, 0-1, 0-2",
				Size:        3,
			},
			Output: Output{
				ErrorMsg: ErrorInvalidSize,
			},
		},
		testCase{
			Desc: "range ok",
			Input: Input{
				HeaderRange: "bytes=0-1, 2-3",
				Size:        4,
			},
			Output: Output{
				Ranges: []httpRange{
					httpRange{start: 0, length: 2},
					httpRange{start: 2, length: 2},
				},
				ErrorMsg: "",
			},
		},
	}

	for _, tCase := range testCases {
		ranges, err := getRanges(tCase.HeaderRange, tCase.Size)
		if err != nil {
			// An error is only acceptable when it matches the expected
			// message and no ranges were expected.
			if err.Error() != tCase.ErrorMsg || len(tCase.Ranges) != 0 {
				t.Fatalf("getRanges: incorrect errorMsg want: %v got: %v", tCase.ErrorMsg, err.Error())
			} else {
				continue
			}
		} else {
			for id, ra := range ranges {
				if ra.GetStart() != tCase.Ranges[id].GetStart() {
					t.Fatalf("getRanges: incorrect range start, got: %v want: %v", ra.GetStart(), tCase.Ranges[id])
				}
				if ra.GetLength() != tCase.Ranges[id].GetLength() {
					t.Fatalf("getRanges: incorrect range length, got: %v want: %v", ra.GetLength(), tCase.Ranges[id])
				}
			}
		}
	}
}
|
||||
|
||||
// TestThrottledCopyN checks the pacing of throttledCopyN: with a budget of
// 5 bytes/sec and a 10-byte source, the destination should hold 5 bytes
// during the first second and all 10 after the second.
// NOTE(review): verifyDsts runs in a goroutine and reads dst while
// throttledCopyN writes it — a data race under -race — and it panics
// instead of calling t.Error; consider restructuring.
func TestThrottledCopyN(t *testing.T) {
	type Init struct {
		BytesPerSec int64
		MaxRangeLen int64
	}
	type Input struct {
		Src    string
		Length int64
	}
	// after starting throttledCopyN by DstAtTime.AtMS milliseconds,
	// the copied value should equal DstAtTime.Dst.
	type DstAtTime struct {
		AtMS int
		Dst  string
	}
	type Output struct {
		ExpectDsts []DstAtTime
	}
	type testCase struct {
		Desc string
		Init
		Input
		Output
	}

	// verifyDsts polls dst at the scheduled offsets and panics on mismatch.
	verifyDsts := func(dst *bytes.Buffer, expectDsts []DstAtTime) {
		for _, expectDst := range expectDsts {
			time.Sleep(time.Duration(expectDst.AtMS) * time.Millisecond)
			dstStr := string(dst.Bytes())
			if dstStr != expectDst.Dst {
				panic(
					fmt.Sprintf(
						"throttledCopyN want: <%s> | got: <%s> | at: %d",
						expectDst.Dst,
						dstStr,
						expectDst.AtMS,
					),
				)
			}
		}
	}

	testCases := []testCase{
		testCase{
			Desc: "4 byte per sec",
			Init: Init{
				BytesPerSec: 5,
				MaxRangeLen: 10,
			},
			Input: Input{
				Src:    "aaaa_aaaa_",
				Length: 10,
			},
			Output: Output{
				ExpectDsts: []DstAtTime{
					DstAtTime{AtMS: 200, Dst: "aaaa_"},
					DstAtTime{AtMS: 200, Dst: "aaaa_"},
					DstAtTime{AtMS: 200, Dst: "aaaa_"},
					DstAtTime{AtMS: 600, Dst: "aaaa_aaaa_"},
					DstAtTime{AtMS: 200, Dst: "aaaa_aaaa_"},
					DstAtTime{AtMS: 200, Dst: "aaaa_aaaa_"},
				},
			},
		},
	}

	for _, tCase := range testCases {
		tb := NewQTube("", tCase.BytesPerSec, tCase.MaxRangeLen, &stubFiler{}).(*QTube)
		dst := bytes.NewBuffer(make([]byte, len(tCase.Src)))
		dst.Reset()

		go verifyDsts(dst, tCase.ExpectDsts)
		tb.throttledCopyN(dst, strings.NewReader(tCase.Src), tCase.Length)
	}
}
|
||||
|
||||
// TODO: using same stub with testhelper
|
||||
type stubWriter struct {
|
||||
Headers http.Header
|
||||
Response []byte
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (w *stubWriter) Header() http.Header {
|
||||
return w.Headers
|
||||
}
|
||||
|
||||
func (w *stubWriter) Write(body []byte) (int, error) {
|
||||
w.Response = append(w.Response, body...)
|
||||
return len(body), nil
|
||||
}
|
||||
|
||||
func (w *stubWriter) WriteHeader(statusCode int) {
|
||||
w.StatusCode = statusCode
|
||||
}
|
||||
|
||||
func TestCopyRange(t *testing.T) {
|
||||
type Init struct {
|
||||
Content string
|
||||
}
|
||||
type Input struct {
|
||||
Range httpRange
|
||||
Info fileidx.FileInfo
|
||||
}
|
||||
type Output struct {
|
||||
StatusCode int
|
||||
Headers map[string][]string
|
||||
Body string
|
||||
}
|
||||
type testCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "copy ok",
|
||||
Init: Init{
|
||||
Content: "abcd_abcd_",
|
||||
},
|
||||
Input: Input{
|
||||
Range: httpRange{
|
||||
start: 6,
|
||||
length: 3,
|
||||
},
|
||||
Info: fileidx.FileInfo{
|
||||
ModTime: 0,
|
||||
Uploaded: 10,
|
||||
PathLocal: "filename.jpg",
|
||||
},
|
||||
},
|
||||
Output: Output{
|
||||
StatusCode: 206,
|
||||
Headers: map[string][]string{
|
||||
"Accept-Ranges": []string{"bytes"},
|
||||
"Content-Disposition": []string{`attachment; filename="filename.jpg"`},
|
||||
"Content-Type": []string{"application/octet-stream"},
|
||||
"Content-Range": []string{"bytes 6-8/10"},
|
||||
"Content-Length": []string{"3"},
|
||||
"Last-Modified": []string{time.Unix(0, 0).UTC().Format(http.TimeFormat)},
|
||||
},
|
||||
Body: "abc",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range testCases {
|
||||
filer := &stubFiler{
|
||||
&StubFile{
|
||||
Content: tCase.Content,
|
||||
Offset: 0,
|
||||
},
|
||||
}
|
||||
tb := NewQTube("", 100, 100, filer).(*QTube)
|
||||
res := &stubWriter{
|
||||
Headers: make(map[string][]string),
|
||||
Response: make([]byte, 0),
|
||||
}
|
||||
err := tb.copyRange(res, tCase.Range, &tCase.Info)
|
||||
if err != nil {
|
||||
t.Fatalf("copyRange: %v", err)
|
||||
}
|
||||
if res.StatusCode != tCase.Output.StatusCode {
|
||||
t.Fatalf("copyRange: statusCode not match got: %v want: %v", res.StatusCode, tCase.Output.StatusCode)
|
||||
}
|
||||
if string(res.Response) != tCase.Output.Body {
|
||||
t.Fatalf("copyRange: body not match \ngot: %v \nwant: %v", string(res.Response), tCase.Output.Body)
|
||||
}
|
||||
for key, vals := range tCase.Output.Headers {
|
||||
if res.Header().Get(key) != vals[0] {
|
||||
t.Fatalf("copyRange: header not match %v got: %v want: %v", key, res.Header().Get(key), vals[0])
|
||||
}
|
||||
}
|
||||
if res.StatusCode != tCase.Output.StatusCode {
|
||||
t.Fatalf("copyRange: statusCodes are not match %v", res.StatusCode, tCase.Output.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeAll(t *testing.T) {
|
||||
type Init struct {
|
||||
Content string
|
||||
}
|
||||
type Input struct {
|
||||
Info fileidx.FileInfo
|
||||
}
|
||||
type Output struct {
|
||||
StatusCode int
|
||||
Headers map[string][]string
|
||||
Body string
|
||||
}
|
||||
type testCase struct {
|
||||
Desc string
|
||||
Init
|
||||
Input
|
||||
Output
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
testCase{
|
||||
Desc: "copy ok",
|
||||
Init: Init{
|
||||
Content: "abcd_abcd_",
|
||||
},
|
||||
Input: Input{
|
||||
Info: fileidx.FileInfo{
|
||||
ModTime: 0,
|
||||
Uploaded: 10,
|
||||
PathLocal: "filename.jpg",
|
||||
},
|
||||
},
|
||||
Output: Output{
|
||||
StatusCode: 200,
|
||||
Headers: map[string][]string{
|
||||
"Accept-Ranges": []string{"bytes"},
|
||||
"Content-Disposition": []string{`attachment; filename="filename.jpg"`},
|
||||
"Content-Type": []string{"application/octet-stream"},
|
||||
"Content-Length": []string{"10"},
|
||||
"Last-Modified": []string{time.Unix(0, 0).UTC().Format(http.TimeFormat)},
|
||||
},
|
||||
Body: "abcd_abcd_",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range testCases {
|
||||
filer := &stubFiler{
|
||||
&StubFile{
|
||||
Content: tCase.Content,
|
||||
Offset: 0,
|
||||
},
|
||||
}
|
||||
tb := NewQTube("", 100, 100, filer).(*QTube)
|
||||
res := &stubWriter{
|
||||
Headers: make(map[string][]string),
|
||||
Response: make([]byte, 0),
|
||||
}
|
||||
err := tb.serveAll(res, &tCase.Info)
|
||||
if err != nil {
|
||||
t.Fatalf("serveAll: %v", err)
|
||||
}
|
||||
if res.StatusCode != tCase.Output.StatusCode {
|
||||
t.Fatalf("serveAll: statusCode not match got: %v want: %v", res.StatusCode, tCase.Output.StatusCode)
|
||||
}
|
||||
if string(res.Response) != tCase.Output.Body {
|
||||
t.Fatalf("serveAll: body not match \ngot: %v \nwant: %v", string(res.Response), tCase.Output.Body)
|
||||
}
|
||||
for key, vals := range tCase.Output.Headers {
|
||||
if res.Header().Get(key) != vals[0] {
|
||||
t.Fatalf("serveAll: header not match %v got: %v want: %v", key, res.Header().Get(key), vals[0])
|
||||
}
|
||||
}
|
||||
if res.StatusCode != tCase.Output.StatusCode {
|
||||
t.Fatalf("serveAll: statusCodes are not match %v", res.StatusCode, tCase.Output.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
28
server/libs/qtube/test_helper.go
Normal file
28
server/libs/qtube/test_helper.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package qtube
|
||||
|
||||
type StubFile struct {
|
||||
Content string
|
||||
Offset int64
|
||||
}
|
||||
|
||||
func (file *StubFile) Read(p []byte) (int, error) {
|
||||
copied := copy(p[:], []byte(file.Content)[:len(p)])
|
||||
return copied, nil
|
||||
}
|
||||
|
||||
func (file *StubFile) Seek(offset int64, whence int) (int64, error) {
|
||||
file.Offset = offset
|
||||
return offset, nil
|
||||
}
|
||||
|
||||
func (file *StubFile) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubFiler struct {
|
||||
file *StubFile
|
||||
}
|
||||
|
||||
func (filer *stubFiler) Open(filePath string) (ReadSeekCloser, error) {
|
||||
return filer.file, nil
|
||||
}
|
102
server/libs/walls/access_walls.go
Normal file
102
server/libs/walls/access_walls.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package walls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/cfg"
|
||||
"quickshare/server/libs/encrypt"
|
||||
"quickshare/server/libs/limiter"
|
||||
)
|
||||
|
||||
// AccessWalls implements Walls using two rate limiters (per-IP and
// per-operation) and an encrypter for admin login tokens.
type AccessWalls struct {
	cf             *cfg.Config
	IpLimiter      limiter.Limiter
	OpLimiter      limiter.Limiter
	EncrypterMaker encrypt.EncrypterMaker
}
|
||||
|
||||
func NewAccessWalls(
|
||||
cf *cfg.Config,
|
||||
ipLimiter limiter.Limiter,
|
||||
opLimiter limiter.Limiter,
|
||||
encrypterMaker encrypt.EncrypterMaker,
|
||||
) Walls {
|
||||
return &AccessWalls{
|
||||
cf: cf,
|
||||
IpLimiter: ipLimiter,
|
||||
OpLimiter: opLimiter,
|
||||
EncrypterMaker: encrypterMaker,
|
||||
}
|
||||
}
|
||||
|
||||
func (walls *AccessWalls) PassIpLimit(remoteAddr string) bool {
|
||||
if !walls.cf.Production {
|
||||
return true
|
||||
}
|
||||
return walls.IpLimiter.Access(remoteAddr, walls.cf.OpIdIpVisit)
|
||||
|
||||
}
|
||||
|
||||
func (walls *AccessWalls) PassOpLimit(resourceId string, opId int16) bool {
|
||||
if !walls.cf.Production {
|
||||
return true
|
||||
}
|
||||
return walls.OpLimiter.Access(resourceId, opId)
|
||||
}
|
||||
|
||||
// PassLoginCheck reports whether tokenStr is a valid admin login token.
// req is unused here but kept to satisfy the Walls interface.
// Outside production it always passes.
func (walls *AccessWalls) PassLoginCheck(tokenStr string, req *http.Request) bool {
	if !walls.cf.Production {
		return true
	}

	return walls.passLoginCheck(tokenStr)
}
|
||||
|
||||
func (walls *AccessWalls) passLoginCheck(tokenStr string) bool {
|
||||
token, getLoginTokenOk := walls.GetLoginToken(tokenStr)
|
||||
return getLoginTokenOk && token.AdminId == walls.cf.AdminId
|
||||
}
|
||||
|
||||
// GetLoginToken decodes and validates tokenStr. It succeeds only when the
// token parses with the configured secret, carries the configured admin
// id, and has an expiry (unix seconds) still in the future.
func (walls *AccessWalls) GetLoginToken(tokenStr string) (*LoginToken, bool) {
	tokenMaker := walls.EncrypterMaker(string(walls.cf.SecretKeyByte))
	if !tokenMaker.FromStr(tokenStr) {
		return nil, false // malformed or badly-signed token
	}

	adminIdFromToken, adminIdOk := tokenMaker.Get(walls.cf.KeyAdminId)
	expiresStr, expiresStrOk := tokenMaker.Get(walls.cf.KeyExpires)
	if !adminIdOk || !expiresStrOk {
		return nil, false // required claims missing
	}

	expires, expiresParseErr := strconv.ParseInt(expiresStr, 10, 64)
	if expiresParseErr != nil ||
		adminIdFromToken != walls.cf.AdminId ||
		expires <= time.Now().Unix() {
		return nil, false
	}

	return &LoginToken{
		AdminId: adminIdFromToken,
		Expires: expires,
	}, true
}
|
||||
|
||||
func (walls *AccessWalls) MakeLoginToken(userId string) string {
|
||||
expires := time.Now().Add(time.Duration(walls.cf.CookieMaxAge) * time.Second).Unix()
|
||||
|
||||
tokenMaker := walls.EncrypterMaker(string(walls.cf.SecretKeyByte))
|
||||
tokenMaker.Add(walls.cf.KeyAdminId, userId)
|
||||
tokenMaker.Add(walls.cf.KeyExpires, fmt.Sprintf("%d", expires))
|
||||
|
||||
tokenStr, ok := tokenMaker.ToStr()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return tokenStr
|
||||
}
|
145
server/libs/walls/access_walls_test.go
Normal file
145
server/libs/walls/access_walls_test.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
package walls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"quickshare/server/libs/cfg"
|
||||
"quickshare/server/libs/encrypt"
|
||||
"quickshare/server/libs/limiter"
|
||||
)
|
||||
|
||||
// newAccessWalls builds an AccessWalls in production mode (so the limits
// are actually enforced) backed by fresh rate limiters.
func newAccessWalls(limiterCap int64, limiterTtl int32, limiterCyc int32, bucketCap int16) *AccessWalls {
	config := cfg.NewConfig()
	config.Production = true // limits are bypassed outside production
	config.LimiterCap = limiterCap
	config.LimiterTtl = limiterTtl
	config.LimiterCyc = limiterCyc
	config.BucketCap = bucketCap
	encrypterMaker := encrypt.JwtEncrypterMaker
	ipLimiter := limiter.NewRateLimiter(config.LimiterCap, config.LimiterTtl, config.LimiterCyc, config.BucketCap, map[int16]int16{})
	opLimiter := limiter.NewRateLimiter(config.LimiterCap, config.LimiterTtl, config.LimiterCyc, config.BucketCap, map[int16]int16{})

	return NewAccessWalls(config, ipLimiter, opLimiter, encrypterMaker).(*AccessWalls)
}
|
||||
// TestIpLimit checks the per-IP limit boundary, then checks it again after
// one refill cycle.
func TestIpLimit(t *testing.T) {
	ip := "0.0.0.0"
	limit := int16(10)
	ttl := int32(60)
	cyc := int32(5)
	walls := newAccessWalls(1000, ttl, cyc, limit)

	testIpLimit(t, walls, ip, limit)
	// wait for tokens are re-fullfilled
	time.Sleep(time.Duration(cyc) * time.Second)
	testIpLimit(t, walls, ip, limit)

	fmt.Println("ip limit: passed")
}
|
||||
|
||||
func testIpLimit(t *testing.T, walls Walls, ip string, limit int16) {
|
||||
for i := int16(0); i < limit; i++ {
|
||||
if !walls.PassIpLimit(ip) {
|
||||
t.Fatalf("ipLimiter: should be passed", time.Now().Unix())
|
||||
}
|
||||
}
|
||||
|
||||
if walls.PassIpLimit(ip) {
|
||||
t.Fatalf("ipLimiter: should not be passed", time.Now().Unix())
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpLimit checks that each (resource, op) pair has its own bucket and
// that buckets refill after a cycle.
// NOTE(review): newAccessWalls's parameters are (cap, ttl, cyc, bucketCap)
// but this call passes (1000, 5, ttl, limit), so the local `ttl` lands in
// the cycle slot — confirm the argument order is intended.
func TestOpLimit(t *testing.T) {
	resourceId := "id"
	op1 := int16(1)
	op2 := int16(2)
	limit := int16(10)
	ttl := int32(1)
	walls := newAccessWalls(1000, 5, ttl, limit)

	testOpLimit(t, walls, resourceId, op1, limit)
	testOpLimit(t, walls, resourceId, op2, limit)
	// wait for tokens are re-fullfilled
	time.Sleep(time.Duration(ttl) * time.Second)
	testOpLimit(t, walls, resourceId, op1, limit)
	testOpLimit(t, walls, resourceId, op2, limit)

	fmt.Println("op limit: passed")
}
|
||||
|
||||
func testOpLimit(t *testing.T, walls Walls, resourceId string, op int16, limit int16) {
|
||||
for i := int16(0); i < limit; i++ {
|
||||
if !walls.PassOpLimit(resourceId, op) {
|
||||
t.Fatalf("opLimiter: should be passed")
|
||||
}
|
||||
}
|
||||
|
||||
if walls.PassOpLimit(resourceId, op) {
|
||||
t.Fatalf("opLimiter: should not be passed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginCheck(t *testing.T) {
|
||||
walls := newAccessWalls(1000, 5, 1, 10)
|
||||
|
||||
testValidToken(t, walls)
|
||||
testInvalidAdminIdToken(t, walls)
|
||||
testExpiredToken(t, walls)
|
||||
}
|
||||
|
||||
// testValidToken builds a correctly-signed, unexpired admin token and
// expects the login check to pass.
func testValidToken(t *testing.T, walls *AccessWalls) {
	config := cfg.NewConfig()

	tokenMaker := encrypt.JwtEncrypterMaker(string(config.SecretKeyByte))
	tokenMaker.Add(config.KeyAdminId, config.AdminId)
	// Expiry 10 seconds in the future.
	tokenMaker.Add(config.KeyExpires, fmt.Sprintf("%d", time.Now().Unix()+int64(10)))
	tokenStr, getTokenOk := tokenMaker.ToStr()
	if !getTokenOk {
		t.Fatalf("passLoginCheck: fail to generate token")
	}

	if !walls.passLoginCheck(tokenStr) {
		t.Fatalf("loginCheck: should be passed")
	}

	fmt.Println("loginCheck: valid token passed")
}
|
||||
|
||||
// testInvalidAdminIdToken builds an unexpired token carrying the wrong
// admin id and expects the login check to fail.
func testInvalidAdminIdToken(t *testing.T, walls *AccessWalls) {
	config := cfg.NewConfig()

	tokenMaker := encrypt.JwtEncrypterMaker(string(config.SecretKeyByte))
	tokenMaker.Add(config.KeyAdminId, "invalid admin id")
	tokenMaker.Add(config.KeyExpires, fmt.Sprintf("%d", time.Now().Unix()+int64(10)))
	tokenStr, getTokenOk := tokenMaker.ToStr()
	if !getTokenOk {
		t.Fatalf("passLoginCheck: fail to generate token")
	}

	if walls.passLoginCheck(tokenStr) {
		t.Fatalf("loginCheck: should not be passed")
	}

	fmt.Println("loginCheck: invalid admin id passed")
}
|
||||
|
||||
// testExpiredToken builds a correctly-signed token that expired one second
// ago and expects the login check to fail.
func testExpiredToken(t *testing.T, walls *AccessWalls) {
	config := cfg.NewConfig()

	tokenMaker := encrypt.JwtEncrypterMaker(string(config.SecretKeyByte))
	tokenMaker.Add(config.KeyAdminId, config.AdminId)
	tokenMaker.Add(config.KeyExpires, fmt.Sprintf("%d", time.Now().Unix()-int64(1)))
	tokenStr, getTokenOk := tokenMaker.ToStr()
	if !getTokenOk {
		t.Fatalf("passLoginCheck: fail to generate token")
	}

	if walls.passLoginCheck(tokenStr) {
		t.Fatalf("loginCheck: should not be passed")
	}

	fmt.Println("loginCheck: expired token passed")
}
|
17
server/libs/walls/walls.go
Normal file
17
server/libs/walls/walls.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package walls
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Walls gates incoming requests: per-IP rate limits, per-resource
// operation limits, and admin login-token checks.
type Walls interface {
	PassIpLimit(remoteAddr string) bool
	PassOpLimit(resourceId string, opId int16) bool
	PassLoginCheck(tokenStr string, req *http.Request) bool
	MakeLoginToken(uid string) string
}

// LoginToken is a decoded admin login token.
type LoginToken struct {
	AdminId string
	Expires int64 // unix seconds
}
|
Loading…
Add table
Add a link
Reference in a new issue