feat(files): enable limiter for download

This commit is contained in:
hexxa 2021-08-08 16:29:22 +08:00 committed by Hexxa
parent fd5da3db37
commit cdd15be4aa
2 changed files with 92 additions and 18 deletions

View file

@ -4,7 +4,6 @@ import (
"encoding/base64"
"errors"
"fmt"
"github.com/ihexxa/quickshare/src/userstore"
"io"
"net/http"
"os"
@ -20,6 +19,7 @@ import (
"github.com/ihexxa/quickshare/src/depidx"
q "github.com/ihexxa/quickshare/src/handlers"
"github.com/ihexxa/quickshare/src/userstore"
)
var (
@ -296,7 +296,7 @@ func (h *FileHandlers) UploadChunk(c *gin.Context) {
return
}
ok, err := h.deps.Limiter().CanUpload(userIDInt, len([]byte(req.Content)))
ok, err := h.deps.Limiter().CanWrite(userIDInt, len([]byte(req.Content)))
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
@ -453,6 +453,11 @@ func (h *FileHandlers) Download(c *gin.Context) {
c.JSON(q.ErrResp(c, 403, q.ErrAccessDenied))
return
}
userIDInt, err := strconv.ParseUint(userID, 10, 64)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
// TODO: when sharing is introduced, move following logics to a separeted method
// concurrently file accessing is managed by os
@ -479,12 +484,11 @@ func (h *FileHandlers) Download(c *gin.Context) {
}
contentType := http.DetectContentType(fileHeadBuf[:read])
r, err := h.deps.FS().GetFileReader(filePath)
fd, err := h.deps.FS().GetFileReader(filePath)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
// reader will be closed by multipart response writer
extraHeaders := map[string]string{
"Content-Disposition": fmt.Sprintf(`attachment; filename="%s"`, info.Name()),
@ -492,7 +496,13 @@ func (h *FileHandlers) Download(c *gin.Context) {
// respond to normal requests
if ifRangeVal != "" || rangeVal == "" {
c.DataFromReader(200, info.Size(), contentType, r, extraHeaders)
limitedReader, err := h.GetStreamReader(userIDInt, fd)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
c.DataFromReader(200, info.Size(), contentType, limitedReader, extraHeaders)
return
}
@ -503,16 +513,23 @@ func (h *FileHandlers) Download(c *gin.Context) {
return
}
mw, contentLength, err := multipart.NewResponseWriter(r, parts, false)
mw, contentLength, err := multipart.NewResponseWriter(fd, parts, false)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
// TODO: reader will be closed by multipart response writer
go mw.Write()
limitedReader, err := h.GetStreamReader(userIDInt, mw)
if err != nil {
c.JSON(q.ErrResp(c, 500, err))
return
}
// it takes the \r\n before body into account, so contentLength+2
c.DataFromReader(206, contentLength+2, contentType, mw, extraHeaders)
c.DataFromReader(206, contentLength+2, contentType, limitedReader, extraHeaders)
}
type ListResp struct {
@ -631,3 +648,33 @@ func (h *FileHandlers) DelUploading(c *gin.Context) {
c.JSON(q.Resp(200))
})
}
func (h *FileHandlers) GetStreamReader(userID uint64, fd io.Reader) (io.Reader, error) {
pr, pw := io.Pipe()
chunkSize := 100 * 1024 // notice: it can not be greater than limiter's token count
go func() {
defer pw.Close()
for {
ok, err := h.deps.Limiter().CanRead(userID, chunkSize)
if err != nil {
pw.CloseWithError(err)
break
} else if !ok {
time.Sleep(time.Duration(1) * time.Second)
continue
}
_, err = io.CopyN(pw, fd, int64(chunkSize))
if err != nil {
if err != io.EOF {
pw.CloseWithError(err)
}
break
}
}
}()
return pr, nil
}

View file

@ -11,26 +11,29 @@ import (
const cacheSizeLimit = 1024
type ILimiter interface {
CanUpload(id uint64, chunkSize int) (bool, error)
CanWrite(userID uint64, chunkSize int) (bool, error)
CanRead(userID uint64, chunkSize int) (bool, error)
}
type IOLimiter struct {
mtx *sync.Mutex
UploadLimiter *golimiter.Limiter
DownloadLimiter *golimiter.Limiter
users userstore.IUserStore
quotaCache map[uint64]*userstore.Quota
}
func NewIOLimiter(upCap, upCyc int, users userstore.IUserStore) *IOLimiter {
func NewIOLimiter(cap, cyc int, users userstore.IUserStore) *IOLimiter {
return &IOLimiter{
mtx: &sync.Mutex{},
UploadLimiter: golimiter.New(upCap, upCyc),
UploadLimiter: golimiter.New(cap, cyc),
DownloadLimiter: golimiter.New(cap, cyc),
users: users,
quotaCache: map[uint64]*userstore.Quota{},
}
}
func (lm *IOLimiter) CanUpload(id uint64, chunkSize int) (bool, error) {
func (lm *IOLimiter) CanWrite(id uint64, chunkSize int) (bool, error) {
lm.mtx.Lock()
defer lm.mtx.Unlock()
@ -54,6 +57,30 @@ func (lm *IOLimiter) CanUpload(id uint64, chunkSize int) (bool, error) {
), nil
}
func (lm *IOLimiter) CanRead(id uint64, chunkSize int) (bool, error) {
lm.mtx.Lock()
defer lm.mtx.Unlock()
quota, ok := lm.quotaCache[id]
if !ok {
user, err := lm.users.GetUser(id)
if err != nil {
return false, err
}
quota = user.Quota
lm.quotaCache[id] = quota
}
if len(lm.quotaCache) > cacheSizeLimit {
lm.clean()
}
return lm.DownloadLimiter.Access(
fmt.Sprint(id),
quota.DownloadSpeedLimit,
chunkSize,
), nil
}
func (lm *IOLimiter) clean() {
count := 0
for key := range lm.quotaCache {