From cdd15be4aa7f131eab842f8cb8b97a3fa73346f1 Mon Sep 17 00:00:00 2001 From: hexxa Date: Sun, 8 Aug 2021 16:29:22 +0800 Subject: [PATCH] feat(files): enable limiter for download --- src/handlers/fileshdr/handlers.go | 61 +++++++++++++++++++++++++++---- src/iolimiter/iolimiter.go | 49 +++++++++++++++++++------ 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/src/handlers/fileshdr/handlers.go b/src/handlers/fileshdr/handlers.go index dadb650..fed1c68 100644 --- a/src/handlers/fileshdr/handlers.go +++ b/src/handlers/fileshdr/handlers.go @@ -4,7 +4,6 @@ import ( "encoding/base64" "errors" "fmt" - "github.com/ihexxa/quickshare/src/userstore" "io" "net/http" "os" @@ -20,6 +19,7 @@ import ( "github.com/ihexxa/quickshare/src/depidx" q "github.com/ihexxa/quickshare/src/handlers" + "github.com/ihexxa/quickshare/src/userstore" ) var ( @@ -296,7 +296,7 @@ func (h *FileHandlers) UploadChunk(c *gin.Context) { return } - ok, err := h.deps.Limiter().CanUpload(userIDInt, len([]byte(req.Content))) + ok, err := h.deps.Limiter().CanWrite(userIDInt, len([]byte(req.Content))) if err != nil { c.JSON(q.ErrResp(c, 500, err)) return @@ -453,6 +453,11 @@ func (h *FileHandlers) Download(c *gin.Context) { c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied)) return } + userIDInt, err := strconv.ParseUint(userID, 10, 64) + if err != nil { + c.JSON(q.ErrResp(c, 500, err)) + return + } // TODO: when sharing is introduced, move following logics to a separeted method // concurrently file accessing is managed by os @@ -479,12 +484,11 @@ func (h *FileHandlers) Download(c *gin.Context) { } contentType := http.DetectContentType(fileHeadBuf[:read]) - r, err := h.deps.FS().GetFileReader(filePath) + fd, err := h.deps.FS().GetFileReader(filePath) if err != nil { c.JSON(q.ErrResp(c, 500, err)) return } - // reader will be closed by multipart response writer extraHeaders := map[string]string{ "Content-Disposition": fmt.Sprintf(`attachment; filename="%s"`, info.Name()), @@ -492,7 +496,13 @@ func (h *FileHandlers) Download(c *gin.Context) { // respond to normal requests if ifRangeVal != "" || rangeVal == "" { - c.DataFromReader(200, info.Size(), contentType, r, extraHeaders) + limitedReader, err := h.GetStreamReader(userIDInt, fd) + if err != nil { + c.JSON(q.ErrResp(c, 500, err)) + return + } + + c.DataFromReader(200, info.Size(), contentType, limitedReader, extraHeaders) return } @@ -503,16 +513,23 @@ func (h *FileHandlers) Download(c *gin.Context) { return } - mw, contentLength, err := multipart.NewResponseWriter(r, parts, false) + mw, contentLength, err := multipart.NewResponseWriter(fd, parts, false) if err != nil { c.JSON(q.ErrResp(c, 500, err)) return } + // TODO: reader will be closed by multipart response writer? go mw.Write() + limitedReader, err := h.GetStreamReader(userIDInt, mw) + if err != nil { + c.JSON(q.ErrResp(c, 500, err)) + return + } + // it takes the \r\n before body into account, so contentLength+2 - c.DataFromReader(206, contentLength+2, contentType, mw, extraHeaders) + c.DataFromReader(206, contentLength+2, contentType, limitedReader, extraHeaders) } type ListResp struct { @@ -631,3 +648,33 @@ func (h *FileHandlers) DelUploading(c *gin.Context) { c.JSON(q.Resp(200)) }) } + +func (h *FileHandlers) GetStreamReader(userID uint64, fd io.Reader) (io.Reader, error) { + pr, pw := io.Pipe() + chunkSize := 100 * 1024 // notice: it can not be greater than limiter's token count + + go func() { + defer pw.Close() + + for { + ok, err := h.deps.Limiter().CanRead(userID, chunkSize) + if err != nil { + pw.CloseWithError(err) + break + } else if !ok { + time.Sleep(time.Duration(1) * time.Second) + continue + } + + _, err = io.CopyN(pw, fd, int64(chunkSize)) + if err != nil { + if err != io.EOF { + pw.CloseWithError(err) + } + break + } + } + }() + + return pr, nil +} diff --git a/src/iolimiter/iolimiter.go b/src/iolimiter/iolimiter.go index 3ca0988..a2c366c 100644 --- a/src/iolimiter/iolimiter.go +++ b/src/iolimiter/iolimiter.go @@ -11,26 +11,29 @@ import ( const cacheSizeLimit = 1024 type ILimiter interface { - CanUpload(id uint64, chunkSize int) (bool, error) + CanWrite(userID uint64, chunkSize int) (bool, error) + CanRead(userID uint64, chunkSize int) (bool, error) } type IOLimiter struct { - mtx *sync.Mutex - UploadLimiter *golimiter.Limiter - users userstore.IUserStore - quotaCache map[uint64]*userstore.Quota + mtx *sync.Mutex + UploadLimiter *golimiter.Limiter + DownloadLimiter *golimiter.Limiter + users userstore.IUserStore + quotaCache map[uint64]*userstore.Quota } -func NewIOLimiter(upCap, upCyc int, users userstore.IUserStore) *IOLimiter { +func NewIOLimiter(cap, cyc int, users userstore.IUserStore) *IOLimiter { return &IOLimiter{ - mtx: &sync.Mutex{}, - UploadLimiter: golimiter.New(upCap, upCyc), - users: users, - quotaCache: map[uint64]*userstore.Quota{}, + mtx: &sync.Mutex{}, + UploadLimiter: golimiter.New(cap, cyc), + DownloadLimiter: golimiter.New(cap, cyc), + users: users, + quotaCache: map[uint64]*userstore.Quota{}, } } -func (lm *IOLimiter) CanUpload(id uint64, chunkSize int) (bool, error) { +func (lm *IOLimiter) CanWrite(id uint64, chunkSize int) (bool, error) { lm.mtx.Lock() defer lm.mtx.Unlock() @@ -54,6 +57,30 @@ func (lm *IOLimiter) CanUpload(id uint64, chunkSize int) (bool, error) { ), nil } +func (lm *IOLimiter) CanRead(id uint64, chunkSize int) (bool, error) { + lm.mtx.Lock() + defer lm.mtx.Unlock() + + quota, ok := lm.quotaCache[id] + if !ok { + user, err := lm.users.GetUser(id) + if err != nil { + return false, err + } + quota = user.Quota + lm.quotaCache[id] = quota + } + if len(lm.quotaCache) > cacheSizeLimit { + lm.clean() + } + + return lm.DownloadLimiter.Access( + fmt.Sprint(id), + quota.DownloadSpeedLimit, + chunkSize, + ), nil +} + func (lm *IOLimiter) clean() { count := 0 for key := range lm.quotaCache {