令牌桶的简单实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package utils

import (
"sync"
"time"
)

// TokenBucket 令牌桶算法实现
type TokenBucket struct {
capacity int64 // 桶容量(最大令牌数)
tokens int64 // 当前令牌数
refillRate int64 // 令牌补充速率(每秒补充多少个令牌)
refillPeriod time.Duration // 补充周期
lastRefill time.Time // 上次补充时间
mutex sync.Mutex // 互斥锁
stopCh chan struct{} // 停止信号
isRunning bool // 是否正在运行
}

// NewTokenBucket 创建新的令牌桶
// capacity: 桶容量
// refillRate: 每秒补充的令牌数
func NewTokenBucket(capacity int64, refillRate int64) *TokenBucket {
bucket := &TokenBucket{
capacity: capacity,
tokens: capacity, // 初始时桶是满的
refillRate: refillRate,
refillPeriod: time.Second / time.Duration(refillRate), // 计算每个令牌的补充间隔
lastRefill: time.Now(),
stopCh: make(chan struct{}),
isRunning: false,
}

// 启动令牌补充协程
bucket.start()
return bucket
}

// start 启动令牌补充协程
func (tb *TokenBucket) start() {
tb.mutex.Lock()
if tb.isRunning {
tb.mutex.Unlock()
return
}
tb.isRunning = true
tb.mutex.Unlock()

go func() {
ticker := time.NewTicker(tb.refillPeriod)
defer ticker.Stop()

for {
select {
case <-ticker.C:
tb.refill()
case <-tb.stopCh:
return
}
}
}()
}

// refill 补充令牌
func (tb *TokenBucket) refill() {
tb.mutex.Lock()
defer tb.mutex.Unlock()

if tb.tokens < tb.capacity {
tb.tokens++
tb.lastRefill = time.Now()
}
}

// Allow 尝试获取一个令牌
func (tb *TokenBucket) Allow() bool {
return tb.AllowN(1)
}

// AllowN 尝试获取 n 个令牌
func (tb *TokenBucket) AllowN(n int64) bool {
tb.mutex.Lock()
defer tb.mutex.Unlock()

if tb.tokens >= n {
tb.tokens -= n
return true
}
return false
}

// WaitN 等待 n 个令牌,或者超时
// TODO
func (tb *TokenBucket) WaitN(n int64, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)

// 简单的轮询等待,实际生产中可以使用条件变量或 channel 优化
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()

for {
if tb.AllowN(n) {
return true
}
if time.Now().After(deadline) {
return false
}
<-ticker.C
}
}

// GetStatus 获取当前桶的状态
func (tb *TokenBucket) GetStatus() (current int64, capacity int64) {
tb.mutex.Lock()
defer tb.mutex.Unlock()

return tb.tokens, tb.capacity
}

// Stop 停止令牌桶
func (tb *TokenBucket) Stop() {
tb.mutex.Lock()
defer tb.mutex.Unlock()

if tb.isRunning {
close(tb.stopCh)
tb.isRunning = false
}
}

Reference

图解各类限流算法|固定窗口/计数器、滑动窗口、漏桶算法、令牌桶算法

token_bucket.go