feat(pkg): add limiter
This commit is contained in:
parent
e73947de0d
commit
e01f5f8351
2 changed files with 307 additions and 0 deletions
122
src/golimiter/limiter.go
Normal file
122
src/golimiter/limiter.go
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
package golimiter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
// "math"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const expiredCycCount = 3
|
||||||
|
|
||||||
|
type Bucket struct {
|
||||||
|
refreshedAt time.Time
|
||||||
|
token int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBucket(token int) *Bucket {
|
||||||
|
return &Bucket{
|
||||||
|
refreshedAt: time.Now(),
|
||||||
|
token: token,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bucket) Access(cyc, incr, decr int) bool {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if decr > incr {
|
||||||
|
return false
|
||||||
|
} else if b.token >= decr {
|
||||||
|
b.token -= decr
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.refreshedAt.
|
||||||
|
Add(time.Duration(cyc) * time.Millisecond).
|
||||||
|
After(now) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
fmt.Println(4)
|
||||||
|
b.token = incr - decr
|
||||||
|
b.refreshedAt = now
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type Limiter struct {
|
||||||
|
buckets map[string]*Bucket
|
||||||
|
cap int
|
||||||
|
cyc int
|
||||||
|
cleanBatch int
|
||||||
|
mtx *sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(cap, cyc int) *Limiter {
|
||||||
|
if cap <= 0 {
|
||||||
|
panic("limiter: invalid cap <= 0")
|
||||||
|
}
|
||||||
|
if cyc <= 0 {
|
||||||
|
panic(fmt.Sprintf("limiter: invalid cyc=%d", cyc))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Limiter{
|
||||||
|
buckets: make(map[string]*Bucket),
|
||||||
|
cap: cap,
|
||||||
|
cyc: cyc,
|
||||||
|
cleanBatch: 10,
|
||||||
|
mtx: &sync.RWMutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// func NewWithcleanBatch(cap, cyc, cleanBatch int64, refill int) *Limiter {
|
||||||
|
// limiter := New(cap, cyc, refill)
|
||||||
|
// limiter.cleanBatch = cleanBatch
|
||||||
|
// return limiter
|
||||||
|
// }
|
||||||
|
|
||||||
|
func (l *Limiter) Access(id string, incr, decr int) bool {
|
||||||
|
l.mtx.Lock()
|
||||||
|
defer l.mtx.Unlock()
|
||||||
|
|
||||||
|
b, ok := l.buckets[id]
|
||||||
|
if !ok {
|
||||||
|
size := len(l.buckets)
|
||||||
|
if size > l.cap/2 {
|
||||||
|
l.clean()
|
||||||
|
}
|
||||||
|
|
||||||
|
size = len(l.buckets)
|
||||||
|
if size+1 > l.cap || incr < decr {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
l.buckets[id] = NewBucket(incr - decr)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return b.Access(l.cyc, incr, decr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Limiter) clean() {
|
||||||
|
count := 0
|
||||||
|
|
||||||
|
for key, bucket := range l.buckets {
|
||||||
|
if bucket.refreshedAt.
|
||||||
|
Add(time.Duration(l.cyc*expiredCycCount) * time.Millisecond).
|
||||||
|
Before(time.Now()) {
|
||||||
|
delete(l.buckets, key)
|
||||||
|
}
|
||||||
|
if count++; count >= 10 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Limiter) GetCap() int {
|
||||||
|
l.mtx.RLock()
|
||||||
|
defer l.mtx.RUnlock()
|
||||||
|
return l.cap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Limiter) GetCyc() int {
|
||||||
|
l.mtx.RLock()
|
||||||
|
defer l.mtx.RUnlock()
|
||||||
|
return l.cyc
|
||||||
|
}
|
185
src/golimiter/limiter_test.go
Normal file
185
src/golimiter/limiter_test.go
Normal file
|
@ -0,0 +1,185 @@
|
||||||
|
package golimiter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
// "math/rand"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
// "time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLimiter(t *testing.T) {
|
||||||
|
t.Run("access count is limited", func(t *testing.T) {
|
||||||
|
clientCount := 3
|
||||||
|
tokenMaxCount := 3
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
counts := make([]int, clientCount)
|
||||||
|
|
||||||
|
limiter := New(clientCount, 2000)
|
||||||
|
client := func(id int) {
|
||||||
|
lid := fmt.Sprint(id)
|
||||||
|
for i := 0; i < tokenMaxCount*2; i++ {
|
||||||
|
ok := limiter.Access(lid, tokenMaxCount, 1)
|
||||||
|
if ok {
|
||||||
|
counts[id]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < clientCount; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go client(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for id := range counts {
|
||||||
|
if counts[id] != tokenMaxCount {
|
||||||
|
t.Fatalf("id(%d): accessed(%d) tokenMaxCount(%d) don't match", id, counts[id], tokenMaxCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
|
||||||
|
// const rndCap = 10000
|
||||||
|
// const addCap = 1
|
||||||
|
|
||||||
|
// // how to set time
|
||||||
|
// // extend: wait can be greater than ttl/2
|
||||||
|
// // cyc is smaller than ttl and wait, then it can be clean in time
|
||||||
|
// const cap = 40
|
||||||
|
// const ttl = 3
|
||||||
|
// const cyc = 1
|
||||||
|
// const bucketCap = 2
|
||||||
|
// const id1 = "id1"
|
||||||
|
// const id2 = "id2"
|
||||||
|
// const op1 int16 = 0
|
||||||
|
// const op2 int16 = 1
|
||||||
|
|
||||||
|
// var customCaps = map[int16]int16{
|
||||||
|
// op2: 1000,
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const wait = 1
|
||||||
|
|
||||||
|
// var limiter = New(cap, cyc, bucketCap)
|
||||||
|
|
||||||
|
// var idSeed = 0
|
||||||
|
|
||||||
|
// func randId() string {
|
||||||
|
// idSeed++
|
||||||
|
// return fmt.Sprintf("%d", idSeed)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestAccess(t *testing.T) {
|
||||||
|
// func(t *testing.T) {
|
||||||
|
// canAccess := limiter.Access(id1)
|
||||||
|
// if !canAccess {
|
||||||
|
// t.Fatal("access: fail")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for i := 0; i < bucketCap; i++ {
|
||||||
|
// canAccess = limiter.Access(id1)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if canAccess {
|
||||||
|
// t.Fatal("access: fail to deny access")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// time.Sleep(time.Duration(limiter.GetCyc()) * time.Second)
|
||||||
|
|
||||||
|
// canAccess = limiter.Access(id1)
|
||||||
|
// if !canAccess {
|
||||||
|
// t.Fatal("access: fail to refresh tokens")
|
||||||
|
// }
|
||||||
|
// }(t)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestCap(t *testing.T) {
|
||||||
|
// originalCap := limiter.GetCap()
|
||||||
|
// fmt.Printf("cap:info: %d\n", originalCap)
|
||||||
|
|
||||||
|
// ok := limiter.ExpandCap(originalCap + addCap)
|
||||||
|
|
||||||
|
// if !ok || limiter.GetCap() != originalCap+addCap {
|
||||||
|
// t.Fatal("cap: fail to expand")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// ok = limiter.ExpandCap(limiter.GetSize() - addCap)
|
||||||
|
// if ok {
|
||||||
|
// t.Fatal("cap: shrink cap")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// ids := []string{}
|
||||||
|
// for limiter.GetSize() < limiter.GetCap() {
|
||||||
|
// id := randId()
|
||||||
|
// ids = append(ids, id)
|
||||||
|
|
||||||
|
// ok := limiter.Access(id, 0)
|
||||||
|
// if !ok {
|
||||||
|
// t.Fatal("cap: not full")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if limiter.GetSize() != limiter.GetCap() {
|
||||||
|
// t.Fatal("cap: incorrect size")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if limiter.Access(randId(), 0) {
|
||||||
|
// t.Fatal("cap: more than cap")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// limiter.truncate()
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestTtl(t *testing.T) {
|
||||||
|
// var addTtl int32 = 1
|
||||||
|
// originalTTL := limiter.GetTTL()
|
||||||
|
// fmt.Printf("ttl:info: %d\n", originalTTL)
|
||||||
|
|
||||||
|
// limiter.UpdateTTL(originalTTL + addTtl)
|
||||||
|
// if limiter.GetTTL() != originalTTL+addTtl {
|
||||||
|
// t.Fatal("ttl: update fail")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func cycTest(t *testing.T) {
|
||||||
|
// var addCyc int32 = 1
|
||||||
|
// originalCyc := limiter.GetCyc()
|
||||||
|
// fmt.Printf("cyc:info: %d\n", originalCyc)
|
||||||
|
|
||||||
|
// limiter.UpdateCyc(originalCyc + addCyc)
|
||||||
|
// if limiter.GetCyc() != originalCyc+addCyc {
|
||||||
|
// t.Fatal("cyc: update fail")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func autoCleanTest(t *testing.T) {
|
||||||
|
// ids := []string{
|
||||||
|
// randId(),
|
||||||
|
// randId(),
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for _, id := range ids {
|
||||||
|
// ok := limiter.Access(id, 0)
|
||||||
|
// if ok {
|
||||||
|
// t.Fatal("autoClean: warning: add fail")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// time.Sleep(time.Duration(limiter.GetTTL()+wait) * time.Second)
|
||||||
|
|
||||||
|
// for _, id := range ids {
|
||||||
|
// _, exist := limiter.get(id)
|
||||||
|
// if exist {
|
||||||
|
// t.Fatal("autoClean: item still exist")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func snapshotTest(t *testing.T) {
|
||||||
|
// }
|
Loading…
Add table
Add a link
Reference in a new issue