internal/ratelimit/ratelimit.go (new file, 259 lines)
@@ -0,0 +1,259 @@
package ratelimit

import (
	"context"
	"sync"
	"time"
)

// TokenBucket implements a token bucket rate limiter
type TokenBucket struct {
	capacity   int64      // Maximum number of tokens
	tokens     int64      // Current number of tokens
	refillRate int64      // Tokens added per second
	lastRefill time.Time  // Last refill time
	mu         sync.Mutex // Protects token count
}

// NewTokenBucket creates a new token bucket rate limiter
func NewTokenBucket(capacity, refillRate int64) *TokenBucket {
	return &TokenBucket{
		capacity:   capacity,
		tokens:     capacity,
		refillRate: refillRate,
		lastRefill: time.Now(),
	}
}

// Allow checks if a request is allowed and consumes a token if available
func (tb *TokenBucket) Allow() bool {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	tb.refill()

	if tb.tokens > 0 {
		tb.tokens--
		return true
	}

	return false
}

// Wait blocks until a token becomes available or context is cancelled
func (tb *TokenBucket) Wait(ctx context.Context) error {
	for {
		if tb.Allow() {
			return nil
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(time.Millisecond * 10):
			// Small delay before retrying
		}
	}
}

// refill adds tokens based on elapsed time (caller must hold mutex)
func (tb *TokenBucket) refill() {
	now := time.Now()
	elapsed := now.Sub(tb.lastRefill)

	tokensToAdd := int64(elapsed.Seconds()) * tb.refillRate
	if tokensToAdd > 0 {
		tb.tokens = min(tb.capacity, tb.tokens+tokensToAdd)
		tb.lastRefill = now
	}
}

// min returns the minimum of two int64 values
func min(a, b int64) int64 {
	if a < b {
		return a
	}
	return b
}

// Limiter interface for different rate limiting strategies
type Limiter interface {
	Allow() bool
	Wait(ctx context.Context) error
}

// FixedWindowLimiter implements fixed window rate limiting
type FixedWindowLimiter struct {
	limit       int64
	window      time.Duration
	requests    int64
	windowStart time.Time
	mu          sync.Mutex
}

// NewFixedWindowLimiter creates a new fixed window rate limiter
func NewFixedWindowLimiter(limit int64, window time.Duration) *FixedWindowLimiter {
	return &FixedWindowLimiter{
		limit:       limit,
		window:      window,
		windowStart: time.Now(),
	}
}

// Allow checks if a request is allowed within the current window
func (fwl *FixedWindowLimiter) Allow() bool {
	fwl.mu.Lock()
	defer fwl.mu.Unlock()

	now := time.Now()

	// Check if we need to start a new window
	if now.Sub(fwl.windowStart) >= fwl.window {
		fwl.requests = 0
		fwl.windowStart = now
	}

	if fwl.requests < fwl.limit {
		fwl.requests++
		return true
	}

	return false
}

// Wait blocks until a request is allowed or context is cancelled
func (fwl *FixedWindowLimiter) Wait(ctx context.Context) error {
	for {
		if fwl.Allow() {
			return nil
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(time.Millisecond * 10):
			// Small delay before retrying
		}
	}
}

// AdaptiveLimiter adjusts rate based on response times and error rates
type AdaptiveLimiter struct {
	baseLimiter   Limiter
	targetLatency time.Duration
	maxErrorRate  float64
	currentRate   int64
	measurements  []measurement
	mu            sync.RWMutex
}

type measurement struct {
	timestamp    time.Time
	responseTime time.Duration
	success      bool
}

// NewAdaptiveLimiter creates a new adaptive rate limiter
func NewAdaptiveLimiter(baseRate int64, targetLatency time.Duration, maxErrorRate float64) *AdaptiveLimiter {
	return &AdaptiveLimiter{
		baseLimiter:   NewTokenBucket(baseRate, baseRate),
		targetLatency: targetLatency,
		maxErrorRate:  maxErrorRate,
		currentRate:   baseRate,
		measurements:  make([]measurement, 0, 100),
	}
}

// Allow checks if a request is allowed
func (al *AdaptiveLimiter) Allow() bool {
	return al.baseLimiter.Allow()
}

// Wait blocks until a request is allowed
func (al *AdaptiveLimiter) Wait(ctx context.Context) error {
	return al.baseLimiter.Wait(ctx)
}

// RecordResponse records a response for adaptive adjustment
func (al *AdaptiveLimiter) RecordResponse(responseTime time.Duration, success bool) {
	al.mu.Lock()
	defer al.mu.Unlock()

	// Add measurement
	m := measurement{
		timestamp:    time.Now(),
		responseTime: responseTime,
		success:      success,
	}
	al.measurements = append(al.measurements, m)

	// Keep only recent measurements (last 100 or last minute)
	cutoff := time.Now().Add(-time.Minute)
	for i, measurement := range al.measurements {
		if measurement.timestamp.After(cutoff) {
			al.measurements = al.measurements[i:]
			break
		}
	}

	// Adjust rate if we have enough data
	if len(al.measurements) >= 10 {
		al.adjustRate()
	}
}

// adjustRate adjusts the rate based on recent measurements
func (al *AdaptiveLimiter) adjustRate() {
	if len(al.measurements) == 0 {
		return
	}

	// Calculate average response time and error rate
	var totalResponseTime time.Duration
	var successCount int64

	for _, m := range al.measurements {
		totalResponseTime += m.responseTime
		if m.success {
			successCount++
		}
	}

	avgResponseTime := totalResponseTime / time.Duration(len(al.measurements))
	errorRate := 1.0 - float64(successCount)/float64(len(al.measurements))

	// Adjust rate based on metrics
	adjustmentFactor := 1.0

	if avgResponseTime > al.targetLatency {
		// Response time too high, decrease rate
		adjustmentFactor = 0.9
	} else if avgResponseTime < al.targetLatency/2 {
		// Response time very good, increase rate
		adjustmentFactor = 1.1
	}

	if errorRate > al.maxErrorRate {
		// Error rate too high, decrease rate more aggressively
		adjustmentFactor *= 0.8
	}

	// Apply adjustment
	newRate := int64(float64(al.currentRate) * adjustmentFactor)
	if newRate < 1 {
		newRate = 1
	}

	if newRate != al.currentRate {
		al.currentRate = newRate
		// Update the base limiter with new rate
		al.baseLimiter = NewTokenBucket(newRate, newRate)
	}
}

// GetCurrentRate returns the current rate limit
func (al *AdaptiveLimiter) GetCurrentRate() int64 {
	al.mu.RLock()
	defer al.mu.RUnlock()
	return al.currentRate
}
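For context, a minimal caller-side sketch of how the new package could be wired into an outbound request path. This is not part of the commit: the module path in the import and the callBackend helper are placeholders, and the numbers are arbitrary.

package main

import (
	"context"
	"fmt"
	"time"

	"example.com/yourmodule/internal/ratelimit" // placeholder module path
)

// callBackend stands in for whatever call the limiter is protecting.
func callBackend(ctx context.Context) error {
	select {
	case <-time.After(30 * time.Millisecond):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	// Start at 20 req/s, target 100ms latency, tolerate a 5% error rate.
	limiter := ratelimit.NewAdaptiveLimiter(20, 100*time.Millisecond, 0.05)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	for i := 0; i < 10; i++ {
		// Block until a slot is available or the context expires.
		if err := limiter.Wait(ctx); err != nil {
			fmt.Println("gave up waiting:", err)
			return
		}

		start := time.Now()
		err := callBackend(ctx)

		// Feed latency and success back so the limiter can adapt its rate.
		limiter.RecordResponse(time.Since(start), err == nil)
	}

	fmt.Println("adaptive rate after 10 calls:", limiter.GetCurrentRate())
}

All three limiters satisfy the Limiter interface, so a caller written against ratelimit.Limiter can swap TokenBucket, FixedWindowLimiter, or AdaptiveLimiter without changing call sites; RecordResponse and GetCurrentRate are specific to AdaptiveLimiter.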
internal/ratelimit/ratelimit_test.go (new file, 200 lines)
@@ -0,0 +1,200 @@
package ratelimit

import (
	"context"
	"testing"
	"time"
)

func TestTokenBucket_Allow(t *testing.T) {
	bucket := NewTokenBucket(5, 1) // 5 capacity, 1 token per second

	// Should allow 5 requests initially
	for i := 0; i < 5; i++ {
		if !bucket.Allow() {
			t.Errorf("Request %d should be allowed", i+1)
		}
	}

	// 6th request should be denied
	if bucket.Allow() {
		t.Error("6th request should be denied")
	}
}

func TestTokenBucket_Refill(t *testing.T) {
	bucket := NewTokenBucket(2, 2) // 2 capacity, 2 tokens per second

	// Consume all tokens
	bucket.Allow()
	bucket.Allow()

	// Should be empty now
	if bucket.Allow() {
		t.Error("Bucket should be empty")
	}

	// Wait for refill
	time.Sleep(1100 * time.Millisecond) // Wait a bit more than 1 second

	// Should have tokens again
	if !bucket.Allow() {
		t.Error("Should have tokens after refill")
	}
}

func TestTokenBucket_Wait(t *testing.T) {
	bucket := NewTokenBucket(1, 1) // 1 capacity, 1 token per second

	// Consume the token
	if !bucket.Allow() {
		t.Error("First request should be allowed")
	}

	// Test wait with timeout
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	start := time.Now()
	err := bucket.Wait(ctx)
	duration := time.Since(start)

	if err != context.DeadlineExceeded {
		t.Errorf("Expected context.DeadlineExceeded, got %v", err)
	}

	if duration < 90*time.Millisecond || duration > 150*time.Millisecond {
		t.Errorf("Wait duration should be around 100ms, got %v", duration)
	}
}

func TestFixedWindowLimiter_Allow(t *testing.T) {
	limiter := NewFixedWindowLimiter(3, 1*time.Second) // 3 requests per second

	// Should allow 3 requests
	for i := 0; i < 3; i++ {
		if !limiter.Allow() {
			t.Errorf("Request %d should be allowed", i+1)
		}
	}

	// 4th request should be denied
	if limiter.Allow() {
		t.Error("4th request should be denied")
	}
}

func TestFixedWindowLimiter_WindowReset(t *testing.T) {
	limiter := NewFixedWindowLimiter(2, 500*time.Millisecond) // 2 requests per 500ms

	// Consume all requests
	limiter.Allow()
	limiter.Allow()

	// Should be at limit
	if limiter.Allow() {
		t.Error("Should be at limit")
	}

	// Wait for window reset
	time.Sleep(600 * time.Millisecond)

	// Should allow requests again
	if !limiter.Allow() {
		t.Error("Should allow requests after window reset")
	}
}

func TestAdaptiveLimiter_Basic(t *testing.T) {
	limiter := NewAdaptiveLimiter(10, 100*time.Millisecond, 0.05) // 10 RPS, 100ms target, 5% max error

	// Should allow initial requests
	for i := 0; i < 5; i++ {
		if !limiter.Allow() {
			t.Errorf("Request %d should be allowed", i+1)
		}
	}
}

func TestAdaptiveLimiter_RecordResponse(t *testing.T) {
	limiter := NewAdaptiveLimiter(10, 100*time.Millisecond, 0.05)

	// Record some fast responses
	for i := 0; i < 15; i++ {
		limiter.RecordResponse(50*time.Millisecond, true)
	}

	// Check that rate might have increased (or at least not decreased significantly)
	initialRate := limiter.GetCurrentRate()
	if initialRate < 8 { // Should be at least close to original rate
		t.Errorf("Rate should not decrease significantly with good responses, got %d", initialRate)
	}
}

func TestAdaptiveLimiter_SlowResponses(t *testing.T) {
	limiter := NewAdaptiveLimiter(10, 100*time.Millisecond, 0.05)

	// Record some slow responses
	for i := 0; i < 15; i++ {
		limiter.RecordResponse(500*time.Millisecond, true) // 5x target latency
	}

	// Rate should decrease
	finalRate := limiter.GetCurrentRate()
	if finalRate >= 10 {
		t.Errorf("Rate should decrease with slow responses, got %d", finalRate)
	}
}

func TestAdaptiveLimiter_HighErrorRate(t *testing.T) {
	limiter := NewAdaptiveLimiter(10, 100*time.Millisecond, 0.05)

	// Record responses with high error rate
	for i := 0; i < 15; i++ {
		success := i < 5 // Only first 5 are successful (33% success rate)
		limiter.RecordResponse(50*time.Millisecond, success)
	}

	// Rate should decrease due to high error rate
	finalRate := limiter.GetCurrentRate()
	if finalRate >= 10 {
		t.Errorf("Rate should decrease with high error rate, got %d", finalRate)
	}
}

// Benchmark tests
func BenchmarkTokenBucket_Allow(b *testing.B) {
	bucket := NewTokenBucket(1000, 1000)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		bucket.Allow()
	}
}

func BenchmarkFixedWindowLimiter_Allow(b *testing.B) {
	limiter := NewFixedWindowLimiter(1000, 1*time.Second)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		limiter.Allow()
	}
}

func BenchmarkAdaptiveLimiter_Allow(b *testing.B) {
	limiter := NewAdaptiveLimiter(1000, 100*time.Millisecond, 0.05)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		limiter.Allow()
	}
}

func BenchmarkAdaptiveLimiter_RecordResponse(b *testing.B) {
	limiter := NewAdaptiveLimiter(1000, 100*time.Millisecond, 0.05)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		limiter.RecordResponse(50*time.Millisecond, true)
	}
}
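The tests above drive each limiter from a single goroutine and rely on wall-clock sleeps, so the refill and window-reset cases take roughly two seconds to run (go test ./internal/ratelimit/ -v; add -bench=. for the benchmarks). Since the limiters guard their state with mutexes, a concurrency check along the following lines could complement the suite; it is a sketch, not part of the commit, and would additionally need the sync and sync/atomic imports:

func TestTokenBucket_ConcurrentAllow(t *testing.T) {
	// 100 tokens, 1 token/s refill: refill is negligible for the duration of this test.
	bucket := NewTokenBucket(100, 1)

	var wg sync.WaitGroup
	var allowed int64

	// 10 goroutines make 50 attempts each, competing for the 100 tokens.
	for g := 0; g < 10; g++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := 0; i < 50; i++ {
				if bucket.Allow() {
					atomic.AddInt64(&allowed, 1)
				}
			}
		}()
	}
	wg.Wait()

	// Assumes the goroutines finish well within a second, so no refill occurs.
	if allowed != 100 {
		t.Errorf("allowed %d requests, expected exactly 100", allowed)
	}
}

Running the package with go test -race ./internal/ratelimit/ would then also exercise the mutexes under the race detector.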