256 lines
5.5 KiB
Go
256 lines
5.5 KiB
Go
package engine
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"git.gostacks.org/iwasforcedtobehere/stroke/pkg/config"
|
||
"git.gostacks.org/iwasforcedtobehere/stroke/pkg/metrics"
|
||
)
|
||
|
||
// Engine represents the main stress testing engine
|
||
type Engine struct {
|
||
config *config.Config
|
||
client *http.Client
|
||
metrics *metrics.Collector
|
||
ctx context.Context
|
||
cancel context.CancelFunc
|
||
}
|
||
|
||
// Result holds the execution results
|
||
type Result struct {
|
||
TotalRequests int64
|
||
SuccessRequests int64
|
||
FailedRequests int64
|
||
TotalDuration time.Duration
|
||
RequestsPerSec float64
|
||
Metrics *metrics.Results
|
||
}
|
||
|
||
// Worker represents a single worker goroutine
|
||
type Worker struct {
|
||
id int
|
||
engine *Engine
|
||
wg *sync.WaitGroup
|
||
}
|
||
|
||
// New creates a new stress testing engine
|
||
func New(cfg *config.Config) *Engine {
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
|
||
// Configure HTTP client with reasonable defaults
|
||
client := &http.Client{
|
||
Timeout: time.Duration(cfg.Target.Timeout) * time.Second,
|
||
Transport: &http.Transport{
|
||
MaxIdleConns: 100,
|
||
MaxIdleConnsPerHost: 100,
|
||
IdleConnTimeout: 90 * time.Second,
|
||
DisableCompression: false,
|
||
},
|
||
}
|
||
|
||
return &Engine{
|
||
config: cfg,
|
||
client: client,
|
||
metrics: metrics.NewCollector(),
|
||
ctx: ctx,
|
||
cancel: cancel,
|
||
}
|
||
}
|
||
|
||
// Run executes the stress test
|
||
func (e *Engine) Run() (*Result, error) {
|
||
fmt.Printf("🚀 Starting stress test against %s\n", e.config.Target.URL)
|
||
fmt.Printf("Workers: %d | Requests: %d | Duration: %v\n",
|
||
e.config.Load.Concurrency, e.config.Load.Requests, e.config.Load.Duration)
|
||
|
||
startTime := time.Now()
|
||
|
||
// Create worker pool
|
||
var wg sync.WaitGroup
|
||
requestChan := make(chan struct{}, e.config.Load.Requests)
|
||
|
||
// Start workers
|
||
for i := 0; i < e.config.Load.Concurrency; i++ {
|
||
wg.Add(1)
|
||
worker := &Worker{
|
||
id: i,
|
||
engine: e,
|
||
wg: &wg,
|
||
}
|
||
go worker.run(requestChan)
|
||
}
|
||
|
||
// Feed requests to workers
|
||
go e.feedRequests(requestChan)
|
||
|
||
// Wait for completion or timeout
|
||
done := make(chan struct{})
|
||
go func() {
|
||
wg.Wait()
|
||
close(done)
|
||
}()
|
||
|
||
select {
|
||
case <-done:
|
||
// All workers finished
|
||
case <-time.After(e.config.Load.Duration):
|
||
// Timeout reached
|
||
e.cancel()
|
||
wg.Wait()
|
||
case <-e.ctx.Done():
|
||
// Cancelled
|
||
wg.Wait()
|
||
}
|
||
|
||
endTime := time.Now()
|
||
duration := endTime.Sub(startTime)
|
||
|
||
// Collect results
|
||
metricsResults := e.metrics.GetResults()
|
||
|
||
result := &Result{
|
||
TotalRequests: metricsResults.TotalRequests,
|
||
SuccessRequests: metricsResults.SuccessRequests,
|
||
FailedRequests: metricsResults.FailedRequests,
|
||
TotalDuration: duration,
|
||
RequestsPerSec: float64(metricsResults.TotalRequests) / duration.Seconds(),
|
||
Metrics: metricsResults,
|
||
}
|
||
|
||
e.printResults(result)
|
||
return result, nil
|
||
}
|
||
|
||
// feedRequests sends requests to the worker pool
|
||
func (e *Engine) feedRequests(requestChan chan<- struct{}) {
|
||
defer close(requestChan)
|
||
|
||
if e.config.Load.Requests > 0 {
|
||
// Fixed number of requests
|
||
for i := 0; i < e.config.Load.Requests; i++ {
|
||
select {
|
||
case requestChan <- struct{}{}:
|
||
case <-e.ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
} else {
|
||
// Duration-based requests
|
||
ticker := time.NewTicker(time.Duration(1000/e.config.Load.RequestsPerSecond) * time.Millisecond)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
select {
|
||
case requestChan <- struct{}{}:
|
||
case <-e.ctx.Done():
|
||
return
|
||
}
|
||
case <-e.ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// run executes the worker loop
|
||
func (w *Worker) run(requestChan <-chan struct{}) {
|
||
defer w.wg.Done()
|
||
|
||
for {
|
||
select {
|
||
case <-requestChan:
|
||
w.executeRequest()
|
||
case <-w.engine.ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// executeRequest performs a single HTTP request
|
||
func (w *Worker) executeRequest() {
|
||
startTime := time.Now()
|
||
|
||
// Create request
|
||
var body io.Reader
|
||
if w.engine.config.Target.Body != "" {
|
||
body = strings.NewReader(w.engine.config.Target.Body)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(
|
||
w.engine.ctx,
|
||
w.engine.config.Target.Method,
|
||
w.engine.config.Target.URL,
|
||
body,
|
||
)
|
||
if err != nil {
|
||
w.engine.metrics.RecordRequest(time.Since(startTime), 0, err)
|
||
return
|
||
}
|
||
|
||
// Add headers
|
||
for key, value := range w.engine.config.Target.Headers {
|
||
req.Header.Set(key, value)
|
||
}
|
||
|
||
// Execute request
|
||
resp, err := w.engine.client.Do(req)
|
||
duration := time.Since(startTime)
|
||
|
||
if err != nil {
|
||
w.engine.metrics.RecordRequest(duration, 0, err)
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// Read response body (to ensure proper connection handling)
|
||
io.Copy(io.Discard, resp.Body)
|
||
|
||
// Record metrics
|
||
w.engine.metrics.RecordRequest(duration, resp.StatusCode, nil)
|
||
}
|
||
|
||
// Stop gracefully stops the engine
|
||
func (e *Engine) Stop() {
|
||
e.cancel()
|
||
}
|
||
|
||
// printResults displays the test results
|
||
func (e *Engine) printResults(result *Result) {
|
||
fmt.Printf("\n📊 Test Results:\n")
|
||
fmt.Printf("Duration: %.2fs | RPS: %.2f | Total: %d | Success: %d | Failed: %d\n",
|
||
result.TotalDuration.Seconds(),
|
||
result.RequestsPerSec,
|
||
result.TotalRequests,
|
||
result.SuccessRequests,
|
||
result.FailedRequests,
|
||
)
|
||
|
||
fmt.Printf("\nResponse Times:\n")
|
||
fmt.Printf(" Min: %v | Max: %v | Avg: %v\n",
|
||
result.Metrics.MinResponseTime,
|
||
result.Metrics.MaxResponseTime,
|
||
result.Metrics.AvgResponseTime,
|
||
)
|
||
|
||
fmt.Printf(" p50: %v | p90: %v | p95: %v | p99: %v\n",
|
||
result.Metrics.P50,
|
||
result.Metrics.P90,
|
||
result.Metrics.P95,
|
||
result.Metrics.P99,
|
||
)
|
||
|
||
if result.FailedRequests == 0 {
|
||
fmt.Printf("\n🎉 Fuck yeah! Your API handled it like a champ! 💪\n")
|
||
} else {
|
||
fmt.Printf("\n⚠️ Your API had some hiccups. Time to optimize! 🔧\n")
|
||
}
|
||
}
|