LFG
Some checks failed
CI/CD Pipeline / Run Tests (push) Has been cancelled
CI/CD Pipeline / Build Application (push) Has been cancelled
CI/CD Pipeline / Build Docker Image (push) Has been cancelled
CI/CD Pipeline / Security Scan (push) Has been cancelled
CI/CD Pipeline / Create Release (push) Has been cancelled

This commit is contained in:
Dev
2025-09-11 18:59:15 +03:00
commit 5440884b85
20 changed files with 3074 additions and 0 deletions

206
internal/config/config.go Normal file
View File

@@ -0,0 +1,206 @@
package config
import (
"fmt"
"os"
"path/filepath"
"gopkg.in/yaml.v3"
)
// Config represents the application configuration
type Config struct {
Server ServerConfig `yaml:"server"`
Proxy ProxyConfig `yaml:"proxy"`
NAT NATConfig `yaml:"nat"`
Logging LoggingConfig `yaml:"logging"`
Monitor MonitorConfig `yaml:"monitor"`
}
// ServerConfig represents server configuration
type ServerConfig struct {
Port int `yaml:"port"`
ReadTimeout int `yaml:"read_timeout"`
WriteTimeout int `yaml:"write_timeout"`
IdleTimeout int `yaml:"idle_timeout"`
TLSCertFile string `yaml:"tls_cert_file,omitempty"`
TLSKeyFile string `yaml:"tls_key_file,omitempty"`
}
// ProxyConfig represents reverse proxy configuration
type ProxyConfig struct {
Targets []TargetConfig `yaml:"targets"`
LoadBalancer string `yaml:"load_balancer"` // "roundrobin", "leastconn", "random"
HealthCheckPath string `yaml:"health_check_path"`
HealthCheckInterval int `yaml:"health_check_interval"`
}
// TargetConfig represents a proxy target
type TargetConfig struct {
Name string `yaml:"name"`
Address string `yaml:"address"`
Protocol string `yaml:"protocol"` // "http", "https"
Weight int `yaml:"weight"` // for weighted load balancing
Healthy bool `yaml:"-"` // health status
}
// NATConfig represents NAT traversal configuration
type NATConfig struct {
Enabled bool `yaml:"enabled"`
STUNServer string `yaml:"stun_server,omitempty"`
TURNServer string `yaml:"turn_server,omitempty"`
TURNUsername string `yaml:"turn_username,omitempty"`
TURNPassword string `yaml:"turn_password,omitempty"`
}
// LoggingConfig represents logging configuration
type LoggingConfig struct {
Level string `yaml:"level"` // "debug", "info", "warn", "error"
Format string `yaml:"format"` // "json", "text"
Output string `yaml:"output"` // "stdout", "file"
File string `yaml:"file,omitempty"`
}
// MonitorConfig represents monitoring configuration
type MonitorConfig struct {
Enabled bool `yaml:"enabled"`
Port int `yaml:"port"`
Path string `yaml:"path"`
Auth bool `yaml:"auth"`
Username string `yaml:"username,omitempty"`
Password string `yaml:"password,omitempty"`
}
// Load loads configuration from file
func Load(path string) (*Config, error) {
// Check if file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
return nil, fmt.Errorf("configuration file not found: %s", path)
}
// Read file
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read configuration file: %w", err)
}
// Parse YAML
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse configuration: %w", err)
}
// Set defaults
setDefaults(&config)
return &config, nil
}
// setDefaults sets default values for configuration
func setDefaults(c *Config) {
// Server defaults
if c.Server.Port == 0 {
c.Server.Port = 8080
}
if c.Server.ReadTimeout == 0 {
c.Server.ReadTimeout = 30
}
if c.Server.WriteTimeout == 0 {
c.Server.WriteTimeout = 30
}
if c.Server.IdleTimeout == 0 {
c.Server.IdleTimeout = 60
}
// Proxy defaults
if c.Proxy.LoadBalancer == "" {
c.Proxy.LoadBalancer = "roundrobin"
}
if c.Proxy.HealthCheckPath == "" {
c.Proxy.HealthCheckPath = "/health"
}
if c.Proxy.HealthCheckInterval == 0 {
c.Proxy.HealthCheckInterval = 30
}
// NAT defaults
if c.NAT.Enabled && c.NAT.STUNServer == "" {
c.NAT.STUNServer = "stun:stun.l.google.com:19302"
}
// Logging defaults
if c.Logging.Level == "" {
c.Logging.Level = "info"
}
if c.Logging.Format == "" {
c.Logging.Format = "json"
}
if c.Logging.Output == "" {
c.Logging.Output = "stdout"
}
// Monitor defaults
if c.Monitor.Enabled && c.Monitor.Port == 0 {
c.Monitor.Port = 9090
}
if c.Monitor.Enabled && c.Monitor.Path == "" {
c.Monitor.Path = "/metrics"
}
}
// CreateDefaultConfig creates a default configuration file
func CreateDefaultConfig(path string) error {
config := Config{
Server: ServerConfig{
Port: 8080,
ReadTimeout: 30,
WriteTimeout: 30,
IdleTimeout: 60,
},
Proxy: ProxyConfig{
LoadBalancer: "roundrobin",
HealthCheckPath: "/health",
HealthCheckInterval: 30,
Targets: []TargetConfig{
{
Name: "example",
Address: "http://localhost:3000",
Protocol: "http",
Weight: 1,
},
},
},
NAT: NATConfig{
Enabled: false,
STUNServer: "stun:stun.l.google.com:19302",
},
Logging: LoggingConfig{
Level: "info",
Format: "json",
Output: "stdout",
},
Monitor: MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: false,
},
}
data, err := yaml.Marshal(config)
if err != nil {
return fmt.Errorf("failed to marshal default configuration: %w", err)
}
// Create directory if it doesn't exist
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return fmt.Errorf("failed to create configuration directory: %w", err)
}
// Write configuration file
if err := os.WriteFile(path, data, 0644); err != nil {
return fmt.Errorf("failed to write configuration file: %w", err)
}
return nil
}

View File

@@ -0,0 +1,243 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestLoad(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "test-config.yaml")
// Create a test configuration file
testConfig := `server:
port: 9090
read_timeout: 60
write_timeout: 60
idle_timeout: 120
proxy:
targets:
- name: "test-target"
address: "http://localhost:8080"
protocol: "http"
weight: 1
load_balancer: "leastconn"
health_check_path: "/healthz"
health_check_interval: 60
nat:
enabled: true
stun_server: "stun:stun.example.com:3478"
turn_server: "turn:turn.example.com:3478"
turn_username: "testuser"
turn_password: "testpass"
logging:
level: "debug"
format: "text"
output: "file"
file: "/var/log/gorz.log"
monitor:
enabled: true
port: 8081
path: "/stats"
auth: true
username: "admin"
password: "secret"
`
err := os.WriteFile(configPath, []byte(testConfig), 0644)
if err != nil {
t.Fatalf("Failed to write test config file: %v", err)
}
// Test loading the configuration
cfg, err := Load(configPath)
if err != nil {
t.Fatalf("Failed to load configuration: %v", err)
}
// Verify server configuration
if cfg.Server.Port != 9090 {
t.Errorf("Expected port 9090, got %d", cfg.Server.Port)
}
if cfg.Server.ReadTimeout != 60 {
t.Errorf("Expected read timeout 60, got %d", cfg.Server.ReadTimeout)
}
// Verify proxy configuration
if cfg.Proxy.LoadBalancer != "leastconn" {
t.Errorf("Expected load balancer 'leastconn', got %s", cfg.Proxy.LoadBalancer)
}
if len(cfg.Proxy.Targets) != 1 {
t.Errorf("Expected 1 target, got %d", len(cfg.Proxy.Targets))
}
if cfg.Proxy.Targets[0].Name != "test-target" {
t.Errorf("Expected target name 'test-target', got %s", cfg.Proxy.Targets[0].Name)
}
// Verify NAT configuration
if !cfg.NAT.Enabled {
t.Error("Expected NAT enabled to be true")
}
if cfg.NAT.STUNServer != "stun:stun.example.com:3478" {
t.Errorf("Expected STUN server 'stun:stun.example.com:3478', got %s", cfg.NAT.STUNServer)
}
// Verify logging configuration
if cfg.Logging.Level != "debug" {
t.Errorf("Expected log level 'debug', got %s", cfg.Logging.Level)
}
if cfg.Logging.Output != "file" {
t.Errorf("Expected log output 'file', got %s", cfg.Logging.Output)
}
// Verify monitor configuration
if !cfg.Monitor.Enabled {
t.Error("Expected monitor enabled to be true")
}
if cfg.Monitor.Port != 8081 {
t.Errorf("Expected monitor port 8081, got %d", cfg.Monitor.Port)
}
if !cfg.Monitor.Auth {
t.Error("Expected monitor auth to be true")
}
}
func TestLoadNonexistentFile(t *testing.T) {
_, err := Load("/nonexistent/path/config.yaml")
if err == nil {
t.Error("Expected error for nonexistent file, got nil")
}
}
func TestLoadInvalidYAML(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "invalid-config.yaml")
// Create an invalid YAML file
invalidYAML := `server:
port: 8080
read_timeout: "not a number"
`
err := os.WriteFile(configPath, []byte(invalidYAML), 0644)
if err != nil {
t.Fatalf("Failed to write invalid config file: %v", err)
}
_, err = Load(configPath)
if err == nil {
t.Error("Expected error for invalid YAML, got nil")
}
}
func TestSetDefaults(t *testing.T) {
cfg := &Config{}
// Apply defaults
setDefaults(cfg)
// Verify default server values
if cfg.Server.Port != 8080 {
t.Errorf("Expected default port 8080, got %d", cfg.Server.Port)
}
if cfg.Server.ReadTimeout != 30 {
t.Errorf("Expected default read timeout 30, got %d", cfg.Server.ReadTimeout)
}
// Verify default proxy values
if cfg.Proxy.LoadBalancer != "roundrobin" {
t.Errorf("Expected default load balancer 'roundrobin', got %s", cfg.Proxy.LoadBalancer)
}
if cfg.Proxy.HealthCheckPath != "/health" {
t.Errorf("Expected default health check path '/health', got %s", cfg.Proxy.HealthCheckPath)
}
// Verify default NAT values
if cfg.NAT.Enabled {
t.Error("Expected default NAT enabled to be false")
}
if cfg.NAT.STUNServer != "stun:stun.l.google.com:19302" {
t.Errorf("Expected default STUN server 'stun:stun.l.google.com:19302', got %s", cfg.NAT.STUNServer)
}
// Verify default logging values
if cfg.Logging.Level != "info" {
t.Errorf("Expected default log level 'info', got %s", cfg.Logging.Level)
}
if cfg.Logging.Format != "json" {
t.Errorf("Expected default log format 'json', got %s", cfg.Logging.Format)
}
// Verify default monitor values
if !cfg.Monitor.Enabled {
t.Error("Expected default monitor enabled to be true")
}
if cfg.Monitor.Port != 9090 {
t.Errorf("Expected default monitor port 9090, got %d", cfg.Monitor.Port)
}
}
func TestCreateDefaultConfig(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "default-config.yaml")
// Create default configuration
err := CreateDefaultConfig(configPath)
if err != nil {
t.Fatalf("Failed to create default configuration: %v", err)
}
// Verify the file was created
if _, err := os.Stat(configPath); os.IsNotExist(err) {
t.Error("Expected config file to be created")
}
// Load and verify the configuration
cfg, err := Load(configPath)
if err != nil {
t.Fatalf("Failed to load created configuration: %v", err)
}
// Verify some default values
if cfg.Server.Port != 8080 {
t.Errorf("Expected default port 8080, got %d", cfg.Server.Port)
}
if cfg.Proxy.LoadBalancer != "roundrobin" {
t.Errorf("Expected default load balancer 'roundrobin', got %s", cfg.Proxy.LoadBalancer)
}
if cfg.Logging.Level != "info" {
t.Errorf("Expected default log level 'info', got %s", cfg.Logging.Level)
}
}
func TestCreateDefaultConfigDirectoryCreation(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
nestedDir := filepath.Join(tempDir, "nested", "directory")
configPath := filepath.Join(nestedDir, "config.yaml")
// Create default configuration in a nested directory
err := CreateDefaultConfig(configPath)
if err != nil {
t.Fatalf("Failed to create default configuration: %v", err)
}
// Verify the directory was created
if _, err := os.Stat(nestedDir); os.IsNotExist(err) {
t.Error("Expected nested directory to be created")
}
// Verify the file was created
if _, err := os.Stat(configPath); os.IsNotExist(err) {
t.Error("Expected config file to be created")
}
}

124
internal/logger/logger.go Normal file
View File

@@ -0,0 +1,124 @@
package logger
import (
"log"
"os"
)
// Logger represents the application logger
type Logger struct {
debugLogger *log.Logger
infoLogger *log.Logger
warnLogger *log.Logger
errorLogger *log.Logger
}
// Field represents a log field
type Field struct {
Key string
Value interface{}
}
// Option represents a logger option
type Option func(*Logger)
// NewLogger creates a new logger with default settings
func NewLogger(opts ...Option) *Logger {
logger := &Logger{
debugLogger: log.New(os.Stdout, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile),
infoLogger: log.New(os.Stdout, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile),
warnLogger: log.New(os.Stdout, "WARN: ", log.Ldate|log.Ltime|log.Lshortfile),
errorLogger: log.New(os.Stdout, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile),
}
// Apply options
for _, opt := range opts {
opt(logger)
}
return logger
}
// Debug logs a debug message
func (l *Logger) Debug(msg string, fields ...Field) {
l.debugLogger.Printf(formatMessage(msg, fields...))
}
// Info logs an info message
func (l *Logger) Info(msg string, fields ...Field) {
l.infoLogger.Printf(formatMessage(msg, fields...))
}
// Warn logs a warning message
func (l *Logger) Warn(msg string, fields ...Field) {
l.warnLogger.Printf(formatMessage(msg, fields...))
}
// Error logs an error message
func (l *Logger) Error(msg string, fields ...Field) {
l.errorLogger.Printf(formatMessage(msg, fields...))
}
// Fatal logs a fatal message and exits
func (l *Logger) Fatal(msg string, fields ...Field) {
l.errorLogger.Printf(formatMessage(msg, fields...))
os.Exit(1)
}
// String creates a string field
func String(key, value string) Field {
return Field{Key: key, Value: value}
}
// Int creates an int field
func Int(key string, value int) Field {
return Field{Key: key, Value: value}
}
// Bool creates a bool field
func Bool(key string, value bool) Field {
return Field{Key: key, Value: value}
}
// Error creates an error field
func Error(err error) Field {
return Field{Key: "error", Value: err}
}
// formatMessage formats a log message with fields
func formatMessage(msg string, fields ...Field) string {
if len(fields) == 0 {
return msg
}
result := msg + " ["
for i, field := range fields {
if i > 0 {
result += ", "
}
result += field.Key + "="
result += toString(field.Value)
}
result += "]"
return result
}
// toString converts a value to string
func toString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case int:
return string(v)
case bool:
if v {
return "true"
}
return "false"
case error:
return v.Error()
default:
return "unknown"
}
}

View File

@@ -0,0 +1,185 @@
package logger
import (
"bytes"
"log"
"os"
"strings"
"testing"
)
func TestNewLogger(t *testing.T) {
logger := NewLogger()
if logger == nil {
t.Error("Expected logger to be created, got nil")
}
}
func TestLoggerDebug(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
logger.Debug("Debug message", String("key", "value"))
output := buf.String()
if !strings.Contains(output, "DEBUG: Debug message") {
t.Errorf("Expected log output to contain 'DEBUG: Debug message', got %s", output)
}
if !strings.Contains(output, "key=value") {
t.Errorf("Expected log output to contain 'key=value', got %s", output)
}
}
func TestLoggerInfo(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
logger.Info("Info message", Int("number", 42))
output := buf.String()
if !strings.Contains(output, "INFO: Info message") {
t.Errorf("Expected log output to contain 'INFO: Info message', got %s", output)
}
if !strings.Contains(output, "number=42") {
t.Errorf("Expected log output to contain 'number=42', got %s", output)
}
}
func TestLoggerWarn(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
logger.Warn("Warning message", Bool("flag", true))
output := buf.String()
if !strings.Contains(output, "WARN: Warning message") {
t.Errorf("Expected log output to contain 'WARN: Warning message', got %s", output)
}
if !strings.Contains(output, "flag=true") {
t.Errorf("Expected log output to contain 'flag=true', got %s", output)
}
}
func TestLoggerError(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
err := os.ErrNotExist
logger.Error("Error message", Error(err))
output := buf.String()
if !strings.Contains(output, "ERROR: Error message") {
t.Errorf("Expected log output to contain 'ERROR: Error message', got %s", output)
}
if !strings.Contains(output, "error=file does not exist") {
t.Errorf("Expected log output to contain 'error=file does not exist', got %s", output)
}
}
func TestLoggerFatal(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
// Mock os.Exit to prevent the test from exiting
exitCalled := false
exitFunc := func(code int) {
exitCalled = true
}
osExit = exitFunc
defer func() {
osExit = realOsExit
}()
logger := NewLogger()
logger.Fatal("Fatal message", String("reason", "testing"))
output := buf.String()
if !strings.Contains(output, "ERROR: Fatal message") {
t.Errorf("Expected log output to contain 'ERROR: Fatal message', got %s", output)
}
if !strings.Contains(output, "reason=testing") {
t.Errorf("Expected log output to contain 'reason=testing', got %s", output)
}
if !exitCalled {
t.Error("Expected os.Exit to be called")
}
}
func TestLoggerMultipleFields(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
logger.Info("Multiple fields",
String("string", "value"),
Int("int", 123),
Bool("bool", false))
output := buf.String()
if !strings.Contains(output, "INFO: Multiple fields") {
t.Errorf("Expected log output to contain 'INFO: Multiple fields', got %s", output)
}
if !strings.Contains(output, "string=value") {
t.Errorf("Expected log output to contain 'string=value', got %s", output)
}
if !strings.Contains(output, "int=123") {
t.Errorf("Expected log output to contain 'int=123', got %s", output)
}
if !strings.Contains(output, "bool=false") {
t.Errorf("Expected log output to contain 'bool=false', got %s", output)
}
}
func TestLoggerNoFields(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
logger.Info("No fields")
output := buf.String()
if !strings.Contains(output, "INFO: No fields") {
t.Errorf("Expected log output to contain 'INFO: No fields', got %s", output)
}
if strings.Contains(output, "[") {
t.Error("Expected log output to not contain field brackets when no fields are provided")
}
}
// Mock os.Exit for testing
var (
osExit = func(code int) { os.Exit(code) }
realOsExit = osExit
)

View File

@@ -0,0 +1,266 @@
package monitoring
import (
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
"github.com/iwasforcedtobehere/goRZ/internal/logger"
)
// Metrics represents the application metrics
type Metrics struct {
RequestsTotal int64 `json:"requests_total"`
RequestsActive int64 `json:"requests_active"`
ResponsesByStatus map[string]int64 `json:"responses_by_status"`
TargetMetrics map[string]*TargetMetric `json:"target_metrics"`
StartTime time.Time `json:"start_time"`
LastUpdated time.Time `json:"last_updated"`
mu sync.RWMutex `json:"-"`
}
// TargetMetric represents metrics for a specific target
type TargetMetric struct {
RequestsTotal int64 `json:"requests_total"`
ResponsesByStatus map[string]int64 `json:"responses_by_status"`
ResponseTimes []time.Duration `json:"response_times"`
AvgResponseTime time.Duration `json:"avg_response_time"`
Healthy bool `json:"healthy"`
LastChecked time.Time `json:"last_checked"`
}
// Monitor represents the monitoring service
type Monitor struct {
config *config.Config
logger *logger.Logger
metrics *Metrics
server *http.Server
authHandler http.Handler
}
// NewMonitor creates a new monitoring service
func NewMonitor(cfg *config.Config, logger *logger.Logger) *Monitor {
metrics := &Metrics{
ResponsesByStatus: make(map[string]int64),
TargetMetrics: make(map[string]*TargetMetric),
StartTime: time.Now(),
LastUpdated: time.Now(),
}
// Initialize target metrics
for _, target := range cfg.Proxy.Targets {
metrics.TargetMetrics[target.Name] = &TargetMetric{
ResponsesByStatus: make(map[string]int64),
ResponseTimes: make([]time.Duration, 0),
Healthy: target.Healthy,
LastChecked: time.Now(),
}
}
monitor := &Monitor{
config: cfg,
logger: logger,
metrics: metrics,
}
// Set up authentication if enabled
if cfg.Monitor.Auth {
monitor.authHandler = monitor.basicAuthHandler(monitor.metricsHandler)
} else {
monitor.authHandler = monitor.metricsHandler
}
return monitor
}
// Start starts the monitoring service
func (m *Monitor) Start() error {
if !m.config.Monitor.Enabled {
m.logger.Info("Monitoring is disabled")
return nil
}
// Create HTTP server
m.server = &http.Server{
Addr: fmt.Sprintf(":%d", m.config.Monitor.Port),
Handler: m.authHandler,
}
// Start server in a goroutine
go func() {
m.logger.Info("Monitoring server starting", logger.String("address", m.server.Addr))
if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
m.logger.Error("Monitoring server failed", logger.Error(err))
}
}()
return nil
}
// Stop stops the monitoring service
func (m *Monitor) Stop() {
if m.server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
m.server.Shutdown(ctx)
m.logger.Info("Monitoring server stopped")
}
}
// metricsHandler handles HTTP requests for metrics
func (m *Monitor) metricsHandler(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != m.config.Monitor.Path {
http.NotFound(w, r)
return
}
m.metrics.mu.RLock()
defer m.metrics.mu.RUnlock()
// Update last updated time
m.metrics.LastUpdated = time.Now()
// Set content type
w.Header().Set("Content-Type", "application/json")
// Encode metrics as JSON
if err := json.NewEncoder(w).Encode(m.metrics); err != nil {
m.logger.Error("Failed to encode metrics", logger.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
}
// basicAuthHandler wraps a handler with basic authentication
func (m *Monitor) basicAuthHandler(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok || username != m.config.Monitor.Username || password != m.config.Monitor.Password {
w.Header().Set("WWW-Authenticate", `Basic realm="goRZ Monitor"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
next(w, r)
}
}
// IncrementRequest increments the total request count
func (m *Monitor) IncrementRequest() {
m.metrics.mu.Lock()
defer m.metrics.mu.Unlock()
m.metrics.RequestsTotal++
m.metrics.LastUpdated = time.Now()
}
// IncrementActiveRequest increments the active request count
func (m *Monitor) IncrementActiveRequest() {
m.metrics.mu.Lock()
defer m.metrics.mu.Unlock()
m.metrics.RequestsActive++
m.metrics.LastUpdated = time.Now()
}
// DecrementActiveRequest decrements the active request count
func (m *Monitor) DecrementActiveRequest() {
m.metrics.mu.Lock()
defer m.metrics.mu.Unlock()
m.metrics.RequestsActive--
if m.metrics.RequestsActive < 0 {
m.metrics.RequestsActive = 0
}
m.metrics.LastUpdated = time.Now()
}
// RecordResponse records a response with the given status code
func (m *Monitor) RecordResponse(statusCode int, targetName string, responseTime time.Duration) {
m.metrics.mu.Lock()
defer m.metrics.mu.Unlock()
status := fmt.Sprintf("%d", statusCode)
m.metrics.ResponsesByStatus[status]++
// Update target metrics
if target, exists := m.metrics.TargetMetrics[targetName]; exists {
target.RequestsTotal++
target.ResponsesByStatus[status]++
// Keep only the last 100 response times for average calculation
if len(target.ResponseTimes) >= 100 {
target.ResponseTimes = target.ResponseTimes[1:]
}
target.ResponseTimes = append(target.ResponseTimes, responseTime)
// Calculate average response time
var total time.Duration
for _, rt := range target.ResponseTimes {
total += rt
}
target.AvgResponseTime = total / time.Duration(len(target.ResponseTimes))
}
m.metrics.LastUpdated = time.Now()
}
// UpdateTargetHealth updates the health status of a target
func (m *Monitor) UpdateTargetHealth(targetName string, healthy bool) {
m.metrics.mu.Lock()
defer m.metrics.mu.Unlock()
if target, exists := m.metrics.TargetMetrics[targetName]; exists {
target.Healthy = healthy
target.LastChecked = time.Now()
}
m.metrics.LastUpdated = time.Now()
}
// GetMetrics returns a copy of the current metrics
func (m *Monitor) GetMetrics() Metrics {
m.metrics.mu.RLock()
defer m.metrics.mu.RUnlock()
// Create a deep copy of the metrics
metrics := Metrics{
RequestsTotal: m.metrics.RequestsTotal,
RequestsActive: m.metrics.RequestsActive,
ResponsesByStatus: make(map[string]int64),
TargetMetrics: make(map[string]*TargetMetric),
StartTime: m.metrics.StartTime,
LastUpdated: m.metrics.LastUpdated,
}
// Copy response status counts
for k, v := range m.metrics.ResponsesByStatus {
metrics.ResponsesByStatus[k] = v
}
// Copy target metrics
for k, v := range m.metrics.TargetMetrics {
targetMetric := &TargetMetric{
RequestsTotal: v.RequestsTotal,
ResponsesByStatus: make(map[string]int64),
ResponseTimes: make([]time.Duration, len(v.ResponseTimes)),
AvgResponseTime: v.AvgResponseTime,
Healthy: v.Healthy,
LastChecked: v.LastChecked,
}
// Copy response status counts for target
for rk, rv := range v.ResponsesByStatus {
targetMetric.ResponsesByStatus[rk] = rv
}
// Copy response times
copy(targetMetric.ResponseTimes, v.ResponseTimes)
metrics.TargetMetrics[k] = targetMetric
}
return metrics
}

View File

@@ -0,0 +1,312 @@
package monitoring
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
"github.com/iwasforcedtobehere/goRZ/internal/logger"
)
func TestMonitorStartStop(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: false,
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Start the monitor
err := monitor.Start()
if err != nil {
t.Fatalf("Failed to start monitor: %v", err)
}
// Stop the monitor
monitor.Stop()
// Test that we can start and stop again
err = monitor.Start()
if err != nil {
t.Fatalf("Failed to start monitor: %v", err)
}
monitor.Stop()
}
func TestMonitorDisabled(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: false,
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Start the monitor
err := monitor.Start()
if err != nil {
t.Fatalf("Failed to start monitor: %v", err)
}
// Stop the monitor
monitor.Stop()
}
func TestMonitorMetricsHandler(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: false,
},
Proxy: config.ProxyConfig{
Targets: []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
},
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Create a test HTTP server
server := httptest.NewServer(monitor.authHandler)
defer server.Close()
// Test metrics endpoint
resp, err := http.Get(server.URL + cfg.Monitor.Path)
if err != nil {
t.Fatalf("Failed to get metrics: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
// Decode and check metrics
var metrics Metrics
if err := json.NewDecoder(resp.Body).Decode(&metrics); err != nil {
t.Fatalf("Failed to decode metrics: %v", err)
}
if metrics.RequestsTotal != 0 {
t.Errorf("Expected requests total 0, got %d", metrics.RequestsTotal)
}
if metrics.RequestsActive != 0 {
t.Errorf("Expected requests active 0, got %d", metrics.RequestsActive)
}
if len(metrics.TargetMetrics) != 2 {
t.Errorf("Expected 2 target metrics, got %d", len(metrics.TargetMetrics))
}
if !metrics.TargetMetrics["target1"].Healthy {
t.Error("Expected target1 to be healthy")
}
if metrics.TargetMetrics["target2"].Healthy {
t.Error("Expected target2 to be unhealthy")
}
}
func TestMonitorBasicAuth(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: true,
Username: "admin",
Password: "secret",
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Create a test HTTP server
server := httptest.NewServer(monitor.authHandler)
defer server.Close()
// Test without authentication
resp, err := http.Get(server.URL + cfg.Monitor.Path)
if err != nil {
t.Fatalf("Failed to get metrics: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
// Test with incorrect authentication
req, err := http.NewRequest("GET", server.URL+cfg.Monitor.Path, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.SetBasicAuth("wrong", "credentials")
client := &http.Client{}
resp, err = client.Do(req)
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
// Test with correct authentication
req, err = http.NewRequest("GET", server.URL+cfg.Monitor.Path, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.SetBasicAuth(cfg.Monitor.Username, cfg.Monitor.Password)
resp, err = client.Do(req)
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
}
func TestMonitorMetricsTracking(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: false,
},
Proxy: config.ProxyConfig{
Targets: []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
},
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Test incrementing request count
monitor.IncrementRequest()
metrics := monitor.GetMetrics()
if metrics.RequestsTotal != 1 {
t.Errorf("Expected requests total 1, got %d", metrics.RequestsTotal)
}
// Test incrementing active request count
monitor.IncrementActiveRequest()
metrics = monitor.GetMetrics()
if metrics.RequestsActive != 1 {
t.Errorf("Expected requests active 1, got %d", metrics.RequestsActive)
}
// Test decrementing active request count
monitor.DecrementActiveRequest()
metrics = monitor.GetMetrics()
if metrics.RequestsActive != 0 {
t.Errorf("Expected requests active 0, got %d", metrics.RequestsActive)
}
// Test recording response
monitor.RecordResponse(http.StatusOK, "target1", 100*time.Millisecond)
metrics = monitor.GetMetrics()
if metrics.ResponsesByStatus["200"] != 1 {
t.Errorf("Expected 1 response with status 200, got %d", metrics.ResponsesByStatus["200"])
}
if metrics.TargetMetrics["target1"].RequestsTotal != 1 {
t.Errorf("Expected target1 to have 1 request, got %d", metrics.TargetMetrics["target1"].RequestsTotal)
}
if metrics.TargetMetrics["target1"].ResponsesByStatus["200"] != 1 {
t.Errorf("Expected target1 to have 1 response with status 200, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["200"])
}
if len(metrics.TargetMetrics["target1"].ResponseTimes) != 1 {
t.Errorf("Expected target1 to have 1 response time, got %d", len(metrics.TargetMetrics["target1"].ResponseTimes))
}
if metrics.TargetMetrics["target1"].AvgResponseTime != 100*time.Millisecond {
t.Errorf("Expected target1 to have average response time 100ms, got %v", metrics.TargetMetrics["target1"].AvgResponseTime)
}
// Test updating target health
monitor.UpdateTargetHealth("target1", false)
metrics = monitor.GetMetrics()
if metrics.TargetMetrics["target1"].Healthy {
t.Error("Expected target1 to be unhealthy")
}
}
func TestMonitorMultipleResponses(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: false,
},
Proxy: config.ProxyConfig{
Targets: []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
},
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Record multiple responses with different status codes and response times
monitor.RecordResponse(http.StatusOK, "target1", 100*time.Millisecond)
monitor.RecordResponse(http.StatusOK, "target1", 200*time.Millisecond)
monitor.RecordResponse(http.StatusNotFound, "target1", 50*time.Millisecond)
monitor.RecordResponse(http.StatusInternalServerError, "target1", 300*time.Millisecond)
metrics := monitor.GetMetrics()
// Check overall metrics
if metrics.RequestsTotal != 0 {
t.Errorf("Expected requests total 0, got %d", metrics.RequestsTotal)
}
if metrics.ResponsesByStatus["200"] != 2 {
t.Errorf("Expected 2 responses with status 200, got %d", metrics.ResponsesByStatus["200"])
}
if metrics.ResponsesByStatus["404"] != 1 {
t.Errorf("Expected 1 response with status 404, got %d", metrics.ResponsesByStatus["404"])
}
if metrics.ResponsesByStatus["500"] != 1 {
t.Errorf("Expected 1 response with status 500, got %d", metrics.ResponsesByStatus["500"])
}
// Check target metrics
if metrics.TargetMetrics["target1"].RequestsTotal != 4 {
t.Errorf("Expected target1 to have 4 requests, got %d", metrics.TargetMetrics["target1"].RequestsTotal)
}
if metrics.TargetMetrics["target1"].ResponsesByStatus["200"] != 2 {
t.Errorf("Expected target1 to have 2 responses with status 200, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["200"])
}
if metrics.TargetMetrics["target1"].ResponsesByStatus["404"] != 1 {
t.Errorf("Expected target1 to have 1 response with status 404, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["404"])
}
if metrics.TargetMetrics["target1"].ResponsesByStatus["500"] != 1 {
t.Errorf("Expected target1 to have 1 response with status 500, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["500"])
}
// Check average response time
expectedAvg := time.Duration((100 + 200 + 50 + 300) / 4)
if metrics.TargetMetrics["target1"].AvgResponseTime != expectedAvg {
t.Errorf("Expected target1 to have average response time %v, got %v", expectedAvg, metrics.TargetMetrics["target1"].AvgResponseTime)
}
}

231
internal/nat/nat.go Normal file
View File

@@ -0,0 +1,231 @@
package nat
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
"github.com/iwasforcedtobehere/goRZ/internal/logger"
"github.com/pion/stun"
"github.com/pion/turn/v2"
)
// NATTraversal handles NAT traversal for the proxy server
type NATTraversal struct {
config *config.Config
logger *logger.Logger
conn net.PacketConn
externalIP net.IP
externalPort int
mu sync.RWMutex
running bool
}
// NewNATTraversal creates a new NAT traversal instance
func NewNATTraversal(cfg *config.Config, logger *logger.Logger) *NATTraversal {
return &NATTraversal{
config: cfg,
logger: logger,
}
}
// Start starts the NAT traversal process
func (n *NATTraversal) Start() error {
if !n.config.NAT.Enabled {
n.logger.Info("NAT traversal is disabled")
return nil
}
n.mu.Lock()
defer n.mu.Unlock()
if n.running {
return fmt.Errorf("NAT traversal is already running")
}
// Create a UDP listener
conn, err := net.ListenPacket("udp4", "0.0.0.0:0")
if err != nil {
return fmt.Errorf("failed to create UDP listener: %w", err)
}
n.conn = conn
// Get external IP and port using STUN
if err := n.discoverExternalAddress(); err != nil {
conn.Close()
return fmt.Errorf("failed to discover external address: %w", err)
}
// Start TURN client if configured
if n.config.NAT.TURNServer != "" {
if err := n.startTURNClient(); err != nil {
conn.Close()
return fmt.Errorf("failed to start TURN client: %w", err)
}
}
n.running = true
n.logger.Info("NAT traversal started",
logger.String("external_ip", n.externalIP.String()),
logger.Int("external_port", n.externalPort))
return nil
}
// Stop stops the NAT traversal process
func (n *NATTraversal) Stop() {
n.mu.Lock()
defer n.mu.Unlock()
if !n.running {
return
}
if n.conn != nil {
n.conn.Close()
n.conn = nil
}
n.running = false
n.logger.Info("NAT traversal stopped")
}
// discoverExternalAddress discovers the external IP and port using STUN
func (n *NATTraversal) discoverExternalAddress() error {
// Parse STUN server URL
stunURL, err := stun.ParseURI(n.config.NAT.STUNServer)
if err != nil {
return fmt.Errorf("failed to parse STUN server URL: %w", err)
}
// Create STUN client
client, err := stun.NewClient("udp4", n.conn)
if err != nil {
return fmt.Errorf("failed to create STUN client: %w", err)
}
defer client.Close()
// Send binding request
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
mappedAddr, err := client.Request(ctx, stunURL, stun.BindingRequest)
if err != nil {
return fmt.Errorf("STUN binding request failed: %w", err)
}
n.externalIP = mappedAddr.IP
n.externalPort = mappedAddr.Port
return nil
}
// startTURNClient starts a TURN client for relay connections
func (n *NATTraversal) startTURNClient() error {
// Parse TURN server URL
turnURL, err := turn.ParseURI(n.config.NAT.TURNServer)
if err != nil {
return fmt.Errorf("failed to parse TURN server URL: %w", err)
}
// Create TURN client config
cfg := &turn.ClientConfig{
STUNServerAddr: n.config.NAT.STUNServer,
TURNServerAddr: n.config.NAT.TURNServer,
Username: n.config.NAT.TURNUsername,
Credential: n.config.NAT.TURNPassword,
LoggerFactory: nil, // We'll use our own logger
}
// Create TURN client
client, err := turn.NewClient(cfg)
if err != nil {
return fmt.Errorf("failed to create TURN client: %w", err)
}
defer client.Close()
// Listen on provided conn
n.logger.Info("TURN client started", logger.String("server", n.config.NAT.TURNServer))
return nil
}
// GetExternalAddress returns the external IP and port
func (n *NATTraversal) GetExternalAddress() (net.IP, int) {
n.mu.RLock()
defer n.mu.RUnlock()
return n.externalIP, n.externalPort
}
// IsRunning returns whether NAT traversal is running
func (n *NATTraversal) IsRunning() bool {
n.mu.RLock()
defer n.mu.RUnlock()
return n.running
}
// CreateHolePunch attempts to create a hole in the NAT for direct peer-to-peer connections
func (n *NATTraversal) CreateHolePunch(peerAddr string) (net.Conn, error) {
if !n.running {
return nil, fmt.Errorf("NAT traversal is not running")
}
// Parse peer address
udpAddr, err := net.ResolveUDPAddr("udp4", peerAddr)
if err != nil {
return nil, fmt.Errorf("failed to resolve peer address: %w", err)
}
// Create UDP connection
conn, err := net.DialUDP("udp4", nil, udpAddr)
if err != nil {
return nil, fmt.Errorf("failed to create UDP connection: %w", err)
}
// Send a small packet to punch a hole in the NAT
_, err = conn.Write([]byte("punch"))
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to punch hole: %w", err)
}
n.logger.Debug("Hole punch attempted", logger.String("peer", peerAddr))
return conn, nil
}
// GetNATType attempts to determine the type of NAT
func (n *NATTraversal) GetNATType() (string, error) {
if !n.running {
return "", fmt.Errorf("NAT traversal is not running")
}
// This is a simplified NAT type detection
// In a real implementation, you would use more sophisticated methods
// like the one described in RFC 3489 and RFC 5780
// For now, we'll return a generic response
return "Unknown", nil
}
// CreateRelayConnection creates a relay connection through TURN
func (n *NATTraversal) CreateRelayConnection(peerAddr string) (net.Conn, error) {
if !n.running {
return nil, fmt.Errorf("NAT traversal is not running")
}
if n.config.NAT.TURNServer == "" {
return nil, fmt.Errorf("TURN server not configured")
}
// This is a placeholder for TURN relay connection creation
// In a real implementation, you would use the TURN client to allocate
// a relay and create a connection to the peer
return nil, fmt.Errorf("not implemented")
}

View File

@@ -0,0 +1,217 @@
package proxy
import (
"math/rand"
"sync"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
)
// RoundRobinLoadBalancer implements round-robin load balancing
type RoundRobinLoadBalancer struct {
targets []*config.TargetConfig
current int
mu sync.Mutex
}
// NewRoundRobinLoadBalancer creates a new round-robin load balancer
func NewRoundRobinLoadBalancer(targets []config.TargetConfig) *RoundRobinLoadBalancer {
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
return &RoundRobinLoadBalancer{
targets: t,
current: 0,
}
}
// NextTarget returns the next target using round-robin algorithm
func (lb *RoundRobinLoadBalancer) NextTarget() (*config.TargetConfig, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
// Filter healthy targets
healthyTargets := make([]*config.TargetConfig, 0)
for _, target := range lb.targets {
if target.Healthy {
healthyTargets = append(healthyTargets, target)
}
}
if len(healthyTargets) == 0 {
return nil, ErrNoHealthyTargets
}
// Get next target
target := healthyTargets[lb.current%len(healthyTargets)]
lb.current = (lb.current + 1) % len(healthyTargets)
return target, nil
}
// UpdateTargets updates the targets list
func (lb *RoundRobinLoadBalancer) UpdateTargets(targets []config.TargetConfig) {
lb.mu.Lock()
defer lb.mu.Unlock()
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
lb.targets = t
lb.current = 0
}
// RandomLoadBalancer implements random load balancing
type RandomLoadBalancer struct {
targets []*config.TargetConfig
rand *rand.Rand
mu sync.Mutex
}
// NewRandomLoadBalancer creates a new random load balancer
func NewRandomLoadBalancer(targets []config.TargetConfig) *RandomLoadBalancer {
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
return &RandomLoadBalancer{
targets: t,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// NextTarget returns a random target
func (lb *RandomLoadBalancer) NextTarget() (*config.TargetConfig, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
// Filter healthy targets
healthyTargets := make([]*config.TargetConfig, 0)
for _, target := range lb.targets {
if target.Healthy {
healthyTargets = append(healthyTargets, target)
}
}
if len(healthyTargets) == 0 {
return nil, ErrNoHealthyTargets
}
// Get random target
index := lb.rand.Intn(len(healthyTargets))
return healthyTargets[index], nil
}
// UpdateTargets updates the targets list
func (lb *RandomLoadBalancer) UpdateTargets(targets []config.TargetConfig) {
lb.mu.Lock()
defer lb.mu.Unlock()
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
lb.targets = t
}
// LeastConnectionsLoadBalancer implements least connections load balancing
type LeastConnectionsLoadBalancer struct {
targets []*config.TargetConfig
connections map[string]int
mu sync.Mutex
}
// NewLeastConnectionsLoadBalancer creates a new least connections load balancer
func NewLeastConnectionsLoadBalancer(targets []config.TargetConfig) *LeastConnectionsLoadBalancer {
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
connections := make(map[string]int)
for _, target := range t {
connections[target.Name] = 0
}
return &LeastConnectionsLoadBalancer{
targets: t,
connections: connections,
}
}
// NextTarget returns the target with the least connections
func (lb *LeastConnectionsLoadBalancer) NextTarget() (*config.TargetConfig, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
// Filter healthy targets and find the one with least connections
var selectedTarget *config.TargetConfig
minConnections := -1
for _, target := range lb.targets {
if target.Healthy {
connections := lb.connections[target.Name]
if minConnections == -1 || connections < minConnections {
minConnections = connections
selectedTarget = target
}
}
}
if selectedTarget == nil {
return nil, ErrNoHealthyTargets
}
// Increment connection count
lb.connections[selectedTarget.Name]++
return selectedTarget, nil
}
// ReleaseConnection decrements the connection count for a target
func (lb *LeastConnectionsLoadBalancer) ReleaseConnection(targetName string) {
lb.mu.Lock()
defer lb.mu.Unlock()
if count, exists := lb.connections[targetName]; exists && count > 0 {
lb.connections[targetName] = count - 1
}
}
// UpdateTargets updates the targets list
func (lb *LeastConnectionsLoadBalancer) UpdateTargets(targets []config.TargetConfig) {
lb.mu.Lock()
defer lb.mu.Unlock()
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
// Update connections map
connections := make(map[string]int)
for _, target := range t {
// Preserve existing connection count if target exists
if count, exists := lb.connections[target.Name]; exists {
connections[target.Name] = count
} else {
connections[target.Name] = 0
}
}
lb.targets = t
lb.connections = connections
}
// ErrNoHealthyTargets is returned when no healthy targets are available
var ErrNoHealthyTargets = errorString("no healthy targets available")
// errorString is a simple string-based error type
type errorString string
func (e errorString) Error() string {
return string(e)
}

View File

@@ -0,0 +1,270 @@
package proxy
import (
"testing"
"github.com/iwasforcedtobehere/goRZ/internal/config"
)
func TestRoundRobinLoadBalancer(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewRoundRobinLoadBalancer(targets)
// Test that targets are selected in round-robin order
for i := 0; i < 6; i++ {
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
expectedTarget := targets[i%3]
if target.Name != expectedTarget.Name {
t.Errorf("Expected target %s, got %s", expectedTarget.Name, target.Name)
}
}
}
func TestRoundRobinLoadBalancerWithUnhealthyTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewRoundRobinLoadBalancer(targets)
// Test that only healthy targets are selected
for i := 0; i < 4; i++ {
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name == "target2" {
t.Errorf("Selected unhealthy target: %s", target.Name)
}
// Should alternate between target1 and target3
if i%2 == 0 && target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
if i%2 == 1 && target.Name != "target3" {
t.Errorf("Expected target3, got %s", target.Name)
}
}
}
func TestRoundRobinLoadBalancerNoHealthyTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: false},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
}
lb := NewRoundRobinLoadBalancer(targets)
_, err := lb.NextTarget()
if err != ErrNoHealthyTargets {
t.Errorf("Expected ErrNoHealthyTargets, got %v", err)
}
}
func TestRoundRobinLoadBalancerUpdateTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewRoundRobinLoadBalancer(targets)
// Get first target
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
// Update targets
newTargets := []config.TargetConfig{
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb.UpdateTargets(newTargets)
// Test that new targets are selected
for i := 0; i < 4; i++ {
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
expectedTarget := newTargets[i%2]
if target.Name != expectedTarget.Name {
t.Errorf("Expected target %s, got %s", expectedTarget.Name, target.Name)
}
}
}
func TestRandomLoadBalancer(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewRandomLoadBalancer(targets)
// Test that targets are selected randomly
// We'll just check that we don't get errors and that all targets are eventually selected
selected := make(map[string]bool)
for i := 0; i < 100; i++ {
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
selected[target.Name] = true
}
// Check that all targets were selected at least once
for _, target := range targets {
if !selected[target.Name] {
t.Errorf("Target %s was never selected", target.Name)
}
}
}
func TestRandomLoadBalancerWithUnhealthyTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewRandomLoadBalancer(targets)
// Test that only healthy targets are selected
for i := 0; i < 100; i++ {
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name == "target2" {
t.Errorf("Selected unhealthy target: %s", target.Name)
}
}
}
func TestLeastConnectionsLoadBalancer(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewLeastConnectionsLoadBalancer(targets)
// Initially, both targets should have 0 connections
// The first target should be selected
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
// Now target1 should have 1 connection, target2 should have 0
// So target2 should be selected
target, err = lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target2" {
t.Errorf("Expected target2, got %s", target.Name)
}
// Now both targets should have 1 connection
// target1 should be selected again
target, err = lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
// Release a connection from target1
lb.(*LeastConnectionsLoadBalancer).ReleaseConnection("target1")
// Now target1 should have 1 connection, target2 should have 1 connection
// But since we released from target1, it should be selected
target, err = lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
}
func TestLeastConnectionsLoadBalancerUpdateTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewLeastConnectionsLoadBalancer(targets)
// Get first target
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
// Update targets
newTargets := []config.TargetConfig{
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb.UpdateTargets(newTargets)
// Test that new targets are selected
// The first new target should be selected
target, err = lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target2" {
t.Errorf("Expected target2, got %s", target.Name)
}
// The second new target should be selected
target, err = lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target3" {
t.Errorf("Expected target3, got %s", target.Name)
}
}
func TestLeastConnectionsLoadBalancerNoHealthyTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: false},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
}
lb := NewLeastConnectionsLoadBalancer(targets)
_, err := lb.NextTarget()
if err != ErrNoHealthyTargets {
t.Errorf("Expected ErrNoHealthyTargets, got %v", err)
}
}

238
internal/proxy/proxy.go Normal file
View File

@@ -0,0 +1,238 @@
package proxy
import (
"context"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
"github.com/iwasforcedtobehere/goRZ/internal/logger"
)
// Server represents the reverse proxy server
type Server struct {
config *config.Config
logger *logger.Logger
loadBalancer LoadBalancer
healthChecker *HealthChecker
}
// LoadBalancer defines the interface for load balancing strategies
type LoadBalancer interface {
NextTarget() (*config.TargetConfig, error)
UpdateTargets(targets []config.TargetConfig)
}
// HealthChecker handles health checking for proxy targets
type HealthChecker struct {
config *config.Config
logger *logger.Logger
targets map[string]*config.TargetConfig
httpClient *http.Client
stopChan chan struct{}
mu sync.RWMutex
}
// NewServer creates a new reverse proxy server
func NewServer(cfg *config.Config, logger *logger.Logger) (*Server, error) {
// Create load balancer based on configuration
var lb LoadBalancer
switch cfg.Proxy.LoadBalancer {
case "roundrobin":
lb = NewRoundRobinLoadBalancer(cfg.Proxy.Targets)
case "leastconn":
lb = NewLeastConnectionsLoadBalancer(cfg.Proxy.Targets)
case "random":
lb = NewRandomLoadBalancer(cfg.Proxy.Targets)
default:
return nil, fmt.Errorf("unsupported load balancer: %s", cfg.Proxy.LoadBalancer)
}
// Create health checker
healthChecker := NewHealthChecker(cfg, logger)
return &Server{
config: cfg,
logger: logger,
loadBalancer: lb,
healthChecker: healthChecker,
}, nil
}
// ServeHTTP handles HTTP requests
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Get next target based on load balancing strategy
target, err := s.loadBalancer.NextTarget()
if err != nil {
s.logger.Error("Failed to get target", logger.Error(err))
http.Error(w, "Service unavailable", http.StatusServiceUnavailable)
return
}
// Check if target is healthy
if !target.Healthy {
s.logger.Warn("Target is unhealthy", logger.String("target", target.Name))
http.Error(w, "Service unavailable", http.StatusServiceUnavailable)
return
}
// Create reverse proxy
targetURL, err := url.Parse(target.Address)
if err != nil {
s.logger.Error("Failed to parse target URL",
logger.String("target", target.Name),
logger.String("address", target.Address),
logger.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
proxy := httputil.NewSingleHostReverseProxy(targetURL)
// Set up custom director to modify request
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
originalDirector(req)
// Add custom headers
req.Header.Set("X-Forwarded-Host", req.Host)
req.Header.Set("X-Forwarded-Proto", "https")
if req.TLS != nil {
req.Header.Set("X-Forwarded-Ssl", "on")
}
// Add proxy identification
req.Header.Set("X-Proxy-Server", "goRZ")
}
// Set up error handler
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.logger.Error("Proxy error", logger.Error(err))
http.Error(w, "Proxy error", http.StatusBadGateway)
}
// Set up modify response handler
proxy.ModifyResponse = func(resp *http.Response) error {
// Add custom headers to response
resp.Header.Set("X-Served-By", "goRZ")
resp.Header.Set("X-Target", target.Name)
return nil
}
// Serve the request
s.logger.Debug("Proxying request",
logger.String("method", r.Method),
logger.String("path", r.URL.Path),
logger.String("target", target.Name))
proxy.ServeHTTP(w, r)
}
// Start starts the proxy server
func (s *Server) Start() error {
// Start health checker
if err := s.healthChecker.Start(); err != nil {
return fmt.Errorf("failed to start health checker: %w", err)
}
return nil
}
// Stop stops the proxy server
func (s *Server) Stop() {
s.healthChecker.Stop()
}
// NewHealthChecker creates a new health checker
func NewHealthChecker(cfg *config.Config, logger *logger.Logger) *HealthChecker {
targets := make(map[string]*config.TargetConfig)
for i := range cfg.Proxy.Targets {
targets[cfg.Proxy.Targets[i].Name] = &cfg.Proxy.Targets[i]
}
return &HealthChecker{
config: cfg,
logger: logger,
targets: targets,
httpClient: &http.Client{Timeout: 5 * time.Second},
stopChan: make(chan struct{}),
}
}
// Start starts the health checker
func (h *HealthChecker) Start() error {
// Initial health check
h.checkAllTargets()
// Start periodic health checks
go func() {
ticker := time.NewTicker(time.Duration(h.config.Proxy.HealthCheckInterval) * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
h.checkAllTargets()
case <-h.stopChan:
return
}
}
}()
return nil
}
// Stop stops the health checker
func (h *HealthChecker) Stop() {
close(h.stopChan)
}
// checkAllTargets checks the health of all targets
func (h *HealthChecker) checkAllTargets() {
h.mu.Lock()
defer h.mu.Unlock()
for name, target := range h.targets {
healthy := h.checkTargetHealth(target)
if target.Healthy != healthy {
h.logger.Info("Target health status changed",
logger.String("target", name),
logger.Bool("healthy", healthy))
target.Healthy = healthy
}
}
}
// checkTargetHealth checks the health of a single target
func (h *HealthChecker) checkTargetHealth(target *config.TargetConfig) bool {
targetURL, err := url.Parse(target.Address)
if err != nil {
h.logger.Error("Failed to parse target URL for health check",
logger.String("target", target.Name),
logger.Error(err))
return false
}
healthURL := *targetURL
healthURL.Path = h.config.Proxy.HealthCheckPath
req, err := http.NewRequest("GET", healthURL.String(), nil)
if err != nil {
h.logger.Error("Failed to create health check request",
logger.String("target", target.Name),
logger.Error(err))
return false
}
resp, err := h.httpClient.Do(req)
if err != nil {
h.logger.Error("Health check request failed",
logger.String("target", target.Name),
logger.Error(err))
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}