260 lines
5.9 KiB
Go
260 lines
5.9 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// TokenBucket is a token-bucket rate limiter. The bucket holds at most
// capacity tokens, refill credits refillRate tokens per elapsed second, and
// Allow spends one token for every request it admits.
//
// The zero value admits nothing (its capacity is 0); build one with
// NewTokenBucket.
type TokenBucket struct {
	// Bucket size and current fill level, both counted in tokens.
	capacity int64
	tokens   int64

	// Refill bookkeeping: tokens credited per second, and the moment credit
	// was last applied.
	refillRate int64
	lastRefill time.Time

	// mu serializes every access to the fields above.
	mu sync.Mutex
}
|
|
|
|
// NewTokenBucket creates a new token bucket rate limiter
|
|
func NewTokenBucket(capacity, refillRate int64) *TokenBucket {
|
|
return &TokenBucket{
|
|
capacity: capacity,
|
|
tokens: capacity,
|
|
refillRate: refillRate,
|
|
lastRefill: time.Now(),
|
|
}
|
|
}
|
|
|
|
// Allow checks if a request is allowed and consumes a token if available
|
|
func (tb *TokenBucket) Allow() bool {
|
|
tb.mu.Lock()
|
|
defer tb.mu.Unlock()
|
|
|
|
tb.refill()
|
|
|
|
if tb.tokens > 0 {
|
|
tb.tokens--
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// Wait blocks until a token becomes available or context is cancelled
|
|
func (tb *TokenBucket) Wait(ctx context.Context) error {
|
|
for {
|
|
if tb.Allow() {
|
|
return nil
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(time.Millisecond * 10):
|
|
// Small delay before retrying
|
|
}
|
|
}
|
|
}
|
|
|
|
// refill adds tokens based on elapsed time (caller must hold mutex)
|
|
func (tb *TokenBucket) refill() {
|
|
now := time.Now()
|
|
elapsed := now.Sub(tb.lastRefill)
|
|
|
|
tokensToAdd := int64(elapsed.Seconds()) * tb.refillRate
|
|
if tokensToAdd > 0 {
|
|
tb.tokens = min(tb.capacity, tb.tokens+tokensToAdd)
|
|
tb.lastRefill = now
|
|
}
|
|
}
|
|
|
|
// min returns the smaller of a and b (b when they are equal).
//
// NOTE(review): Go 1.21+ ships a generic built-in min that this package-level
// helper shadows; it can be removed once the module targets Go 1.21 or later.
func min(a, b int64) int64 {
	if b <= a {
		return b
	}
	return a
}
|
|
|
|
// Limiter is the common surface of the rate-limiting strategies in this file
// (TokenBucket, FixedWindowLimiter and AdaptiveLimiter all provide it).
type Limiter interface {
	// Allow reports, without waiting, whether one request may proceed now.
	// The implementations here consume capacity when they return true.
	Allow() bool

	// Wait blocks until a request may proceed (returning nil) or ctx is done
	// (returning ctx.Err()).
	Wait(ctx context.Context) error
}
|
|
|
|
// FixedWindowLimiter admits at most limit requests per window. The current
// window opened at windowStart. The first Allow call made at least window
// after that opens a new window at that moment and resets the count.
type FixedWindowLimiter struct {
	limit    int64         // requests admitted per window
	window   time.Duration // length of each window
	requests int64         // requests admitted so far in the current window

	windowStart time.Time  // when the current window opened
	mu          sync.Mutex // serializes Allow's reads and writes of the fields above
}
|
|
|
|
// NewFixedWindowLimiter creates a new fixed window rate limiter
|
|
func NewFixedWindowLimiter(limit int64, window time.Duration) *FixedWindowLimiter {
|
|
return &FixedWindowLimiter{
|
|
limit: limit,
|
|
window: window,
|
|
windowStart: time.Now(),
|
|
}
|
|
}
|
|
|
|
// Allow checks if a request is allowed within the current window
|
|
func (fwl *FixedWindowLimiter) Allow() bool {
|
|
fwl.mu.Lock()
|
|
defer fwl.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
|
|
// Check if we need to start a new window
|
|
if now.Sub(fwl.windowStart) >= fwl.window {
|
|
fwl.requests = 0
|
|
fwl.windowStart = now
|
|
}
|
|
|
|
if fwl.requests < fwl.limit {
|
|
fwl.requests++
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// Wait blocks until a request is allowed or context is cancelled
|
|
func (fwl *FixedWindowLimiter) Wait(ctx context.Context) error {
|
|
for {
|
|
if fwl.Allow() {
|
|
return nil
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(time.Millisecond * 10):
|
|
// Small delay before retrying
|
|
}
|
|
}
|
|
}
|
|
|
|
// AdaptiveLimiter adjusts rate based on response times and error rates
|
|
type AdaptiveLimiter struct {
|
|
baseLimiter Limiter
|
|
targetLatency time.Duration
|
|
maxErrorRate float64
|
|
currentRate int64
|
|
measurements []measurement
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// measurement is one observed response, as recorded by
// AdaptiveLimiter.RecordResponse.
type measurement struct {
	timestamp    time.Time     // when the response was recorded
	responseTime time.Duration // latency reported by the caller
	success      bool          // whether the caller reported success
}
|
|
|
|
// NewAdaptiveLimiter creates a new adaptive rate limiter
|
|
func NewAdaptiveLimiter(baseRate int64, targetLatency time.Duration, maxErrorRate float64) *AdaptiveLimiter {
|
|
return &AdaptiveLimiter{
|
|
baseLimiter: NewTokenBucket(baseRate, baseRate),
|
|
targetLatency: targetLatency,
|
|
maxErrorRate: maxErrorRate,
|
|
currentRate: baseRate,
|
|
measurements: make([]measurement, 0, 100),
|
|
}
|
|
}
|
|
|
|
// Allow checks if a request is allowed
|
|
func (al *AdaptiveLimiter) Allow() bool {
|
|
return al.baseLimiter.Allow()
|
|
}
|
|
|
|
// Wait blocks until a request is allowed
|
|
func (al *AdaptiveLimiter) Wait(ctx context.Context) error {
|
|
return al.baseLimiter.Wait(ctx)
|
|
}
|
|
|
|
// RecordResponse records a response for adaptive adjustment
|
|
func (al *AdaptiveLimiter) RecordResponse(responseTime time.Duration, success bool) {
|
|
al.mu.Lock()
|
|
defer al.mu.Unlock()
|
|
|
|
// Add measurement
|
|
m := measurement{
|
|
timestamp: time.Now(),
|
|
responseTime: responseTime,
|
|
success: success,
|
|
}
|
|
al.measurements = append(al.measurements, m)
|
|
|
|
// Keep only recent measurements (last 100 or last minute)
|
|
cutoff := time.Now().Add(-time.Minute)
|
|
for i, measurement := range al.measurements {
|
|
if measurement.timestamp.After(cutoff) {
|
|
al.measurements = al.measurements[i:]
|
|
break
|
|
}
|
|
}
|
|
|
|
// Adjust rate if we have enough data
|
|
if len(al.measurements) >= 10 {
|
|
al.adjustRate()
|
|
}
|
|
}
|
|
|
|
// adjustRate adjusts the rate based on recent measurements
|
|
func (al *AdaptiveLimiter) adjustRate() {
|
|
if len(al.measurements) == 0 {
|
|
return
|
|
}
|
|
|
|
// Calculate average response time and error rate
|
|
var totalResponseTime time.Duration
|
|
var successCount int64
|
|
|
|
for _, m := range al.measurements {
|
|
totalResponseTime += m.responseTime
|
|
if m.success {
|
|
successCount++
|
|
}
|
|
}
|
|
|
|
avgResponseTime := totalResponseTime / time.Duration(len(al.measurements))
|
|
errorRate := 1.0 - float64(successCount)/float64(len(al.measurements))
|
|
|
|
// Adjust rate based on metrics
|
|
adjustmentFactor := 1.0
|
|
|
|
if avgResponseTime > al.targetLatency {
|
|
// Response time too high, decrease rate
|
|
adjustmentFactor = 0.9
|
|
} else if avgResponseTime < al.targetLatency/2 {
|
|
// Response time very good, increase rate
|
|
adjustmentFactor = 1.1
|
|
}
|
|
|
|
if errorRate > al.maxErrorRate {
|
|
// Error rate too high, decrease rate more aggressively
|
|
adjustmentFactor *= 0.8
|
|
}
|
|
|
|
// Apply adjustment
|
|
newRate := int64(float64(al.currentRate) * adjustmentFactor)
|
|
if newRate < 1 {
|
|
newRate = 1
|
|
}
|
|
|
|
if newRate != al.currentRate {
|
|
al.currentRate = newRate
|
|
// Update the base limiter with new rate
|
|
al.baseLimiter = NewTokenBucket(newRate, newRate)
|
|
}
|
|
}
|
|
|
|
// GetCurrentRate returns the current rate limit
|
|
func (al *AdaptiveLimiter) GetCurrentRate() int64 {
|
|
al.mu.RLock()
|
|
defer al.mu.RUnlock()
|
|
return al.currentRate
|
|
}
|