package ratelimit

import (
	"context"
	"sync"
	"time"
)

// TokenBucket implements a token bucket rate limiter: the bucket holds up to
// capacity tokens, refills at refillRate tokens per second, and each allowed
// request consumes one token.
type TokenBucket struct {
	capacity   int64      // maximum number of tokens
	tokens     int64      // current number of tokens
	refillRate int64      // tokens added per second
	lastRefill time.Time  // last time tokens were credited
	mu         sync.Mutex // protects all fields above
}

// NewTokenBucket creates a token bucket rate limiter that starts full.
// capacity and refillRate are assumed positive; this is not validated here.
func NewTokenBucket(capacity, refillRate int64) *TokenBucket {
	return &TokenBucket{
		capacity:   capacity,
		tokens:     capacity,
		refillRate: refillRate,
		lastRefill: time.Now(),
	}
}

// Allow reports whether a request is allowed, consuming one token when it is.
func (tb *TokenBucket) Allow() bool {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	tb.refill()
	if tb.tokens > 0 {
		tb.tokens--
		return true
	}
	return false
}

// Wait blocks until a token becomes available or ctx is cancelled.
func (tb *TokenBucket) Wait(ctx context.Context) error {
	return waitFor(ctx, tb.Allow)
}

// refill credits tokens for the time elapsed since lastRefill.
// The caller must hold tb.mu.
func (tb *TokenBucket) refill() {
	now := time.Now()
	elapsed := now.Sub(tb.lastRefill)
	if elapsed <= 0 {
		return
	}

	// Use fractional seconds so sub-second elapsed time is not truncated to
	// zero tokens. The original computed int64(elapsed.Seconds())*rate and
	// then reset lastRefill to now, discarding the fractional remainder on
	// every refill.
	tokensToAdd := int64(elapsed.Seconds() * float64(tb.refillRate))
	if tokensToAdd <= 0 {
		return
	}

	tb.tokens = min(tb.capacity, tb.tokens+tokensToAdd)
	// Advance lastRefill only by the time actually converted into tokens so
	// the leftover fraction keeps accruing toward the next token.
	consumed := time.Duration(float64(tokensToAdd) / float64(tb.refillRate) * float64(time.Second))
	tb.lastRefill = tb.lastRefill.Add(consumed)
}

// min returns the smaller of two int64 values. Kept as a package helper so
// the file still builds on Go versions before the built-in min (Go 1.21).
func min(a, b int64) int64 {
	if a < b {
		return a
	}
	return b
}

// waitFor polls allow until it succeeds or ctx is cancelled. It is the shared
// implementation behind every Wait method in this package; a single reused
// Ticker replaces the per-iteration time.After timer allocation.
func waitFor(ctx context.Context, allow func() bool) error {
	if allow() {
		return nil
	}
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			if allow() {
				return nil
			}
		}
	}
}

// Limiter is the interface implemented by the rate limiting strategies in
// this package.
type Limiter interface {
	Allow() bool
	Wait(ctx context.Context) error
}

// FixedWindowLimiter allows at most limit requests per fixed window of time.
type FixedWindowLimiter struct {
	limit       int64         // maximum requests per window
	window      time.Duration // window length
	requests    int64         // requests counted in the current window
	windowStart time.Time     // start of the current window
	mu          sync.Mutex    // protects requests and windowStart
}

// NewFixedWindowLimiter creates a fixed window rate limiter.
func NewFixedWindowLimiter(limit int64, window time.Duration) *FixedWindowLimiter {
	return &FixedWindowLimiter{
		limit:       limit,
		window:      window,
		windowStart: time.Now(),
	}
}

// Allow reports whether a request fits in the current window, counting it
// when it does.
func (fwl *FixedWindowLimiter) Allow() bool {
	fwl.mu.Lock()
	defer fwl.mu.Unlock()

	now := time.Now()
	// Start a fresh window once the current one has fully elapsed.
	if now.Sub(fwl.windowStart) >= fwl.window {
		fwl.requests = 0
		fwl.windowStart = now
	}

	if fwl.requests < fwl.limit {
		fwl.requests++
		return true
	}
	return false
}

// Wait blocks until a request is allowed or ctx is cancelled.
func (fwl *FixedWindowLimiter) Wait(ctx context.Context) error {
	return waitFor(ctx, fwl.Allow)
}

// AdaptiveLimiter wraps a token bucket and adjusts its rate based on observed
// response times and error rates.
type AdaptiveLimiter struct {
	baseLimiter   Limiter       // current underlying limiter; swapped on rate change
	targetLatency time.Duration // desired average response time
	maxErrorRate  float64       // tolerated error fraction in [0, 1]
	currentRate   int64         // current rate in tokens per second
	measurements  []measurement // recent samples, newest last
	mu            sync.RWMutex  // protects baseLimiter, currentRate, measurements
}

// measurement is one recorded request outcome.
type measurement struct {
	timestamp    time.Time
	responseTime time.Duration
	success      bool
}

// Tuning constants for AdaptiveLimiter's sample window.
const (
	maxMeasurements   = 100         // cap on retained samples
	measurementWindow = time.Minute // samples older than this are dropped
	minSampleSize     = 10          // minimum samples before adjusting the rate
)

// NewAdaptiveLimiter creates an adaptive rate limiter starting at baseRate
// tokens per second.
func NewAdaptiveLimiter(baseRate int64, targetLatency time.Duration, maxErrorRate float64) *AdaptiveLimiter {
	return &AdaptiveLimiter{
		baseLimiter:   NewTokenBucket(baseRate, baseRate),
		targetLatency: targetLatency,
		maxErrorRate:  maxErrorRate,
		currentRate:   baseRate,
		measurements:  make([]measurement, 0, maxMeasurements),
	}
}

// Allow reports whether a request is allowed at the current rate.
func (al *AdaptiveLimiter) Allow() bool {
	return al.limiter().Allow()
}

// Wait blocks until a request is allowed or ctx is cancelled. Note that a
// rate change during the wait does not interrupt the limiter snapshot taken
// here.
func (al *AdaptiveLimiter) Wait(ctx context.Context) error {
	return al.limiter().Wait(ctx)
}

// limiter returns the current base limiter under the read lock. The original
// code read al.baseLimiter with no synchronization, racing with the swap in
// adjustRate.
func (al *AdaptiveLimiter) limiter() Limiter {
	al.mu.RLock()
	defer al.mu.RUnlock()
	return al.baseLimiter
}

// RecordResponse feeds one request outcome into the adaptive loop. Once at
// least minSampleSize recent samples exist, the rate is re-evaluated.
func (al *AdaptiveLimiter) RecordResponse(responseTime time.Duration, success bool) {
	al.mu.Lock()
	defer al.mu.Unlock()

	al.measurements = append(al.measurements, measurement{
		timestamp:    time.Now(),
		responseTime: responseTime,
		success:      success,
	})

	// Drop samples older than the window. The original loop did no trimming
	// at all when it never found a fresh sample to break on, and never
	// enforced the advertised 100-sample cap.
	cutoff := time.Now().Add(-measurementWindow)
	keepFrom := len(al.measurements) // drop everything if all are stale
	for i, m := range al.measurements {
		if m.timestamp.After(cutoff) {
			keepFrom = i
			break
		}
	}
	al.measurements = al.measurements[keepFrom:]

	// Enforce the cap of maxMeasurements samples, keeping the newest.
	if excess := len(al.measurements) - maxMeasurements; excess > 0 {
		al.measurements = al.measurements[excess:]
	}

	if len(al.measurements) >= minSampleSize {
		al.adjustRate()
	}
}

// adjustRate recomputes the rate from recent measurements and, when it
// changes, swaps in a fresh token bucket at the new rate (which also resets
// any accumulated tokens). The caller must hold al.mu for writing.
func (al *AdaptiveLimiter) adjustRate() {
	if len(al.measurements) == 0 {
		return
	}

	var totalResponseTime time.Duration
	var successCount int64
	for _, m := range al.measurements {
		totalResponseTime += m.responseTime
		if m.success {
			successCount++
		}
	}
	avgResponseTime := totalResponseTime / time.Duration(len(al.measurements))
	errorRate := 1.0 - float64(successCount)/float64(len(al.measurements))

	// Nudge the rate down when latency exceeds the target, up when latency is
	// comfortably below it, and down harder when errors are too frequent.
	adjustmentFactor := 1.0
	switch {
	case avgResponseTime > al.targetLatency:
		adjustmentFactor = 0.9
	case avgResponseTime < al.targetLatency/2:
		adjustmentFactor = 1.1
	}
	if errorRate > al.maxErrorRate {
		adjustmentFactor *= 0.8
	}

	newRate := int64(float64(al.currentRate) * adjustmentFactor)
	if newRate < 1 {
		newRate = 1
	}
	if newRate != al.currentRate {
		al.currentRate = newRate
		al.baseLimiter = NewTokenBucket(newRate, newRate)
	}
}

// GetCurrentRate returns the current rate limit in tokens per second.
func (al *AdaptiveLimiter) GetCurrentRate() int64 {
	al.mu.RLock()
	defer al.mu.RUnlock()
	return al.currentRate
}