LFG
Some checks failed
CI/CD Pipeline / Run Tests (push) Has been cancelled
CI/CD Pipeline / Build Application (push) Has been cancelled
CI/CD Pipeline / Build Docker Image (push) Has been cancelled
CI/CD Pipeline / Security Scan (push) Has been cancelled
CI/CD Pipeline / Create Release (push) Has been cancelled
Some checks failed
CI/CD Pipeline / Run Tests (push) Has been cancelled
CI/CD Pipeline / Build Application (push) Has been cancelled
CI/CD Pipeline / Build Docker Image (push) Has been cancelled
CI/CD Pipeline / Security Scan (push) Has been cancelled
CI/CD Pipeline / Create Release (push) Has been cancelled
This commit is contained in:
206
internal/config/config.go
Normal file
206
internal/config/config.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config represents the application configuration
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Proxy ProxyConfig `yaml:"proxy"`
|
||||
NAT NATConfig `yaml:"nat"`
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
Monitor MonitorConfig `yaml:"monitor"`
|
||||
}
|
||||
|
||||
// ServerConfig represents server configuration
|
||||
type ServerConfig struct {
|
||||
Port int `yaml:"port"`
|
||||
ReadTimeout int `yaml:"read_timeout"`
|
||||
WriteTimeout int `yaml:"write_timeout"`
|
||||
IdleTimeout int `yaml:"idle_timeout"`
|
||||
TLSCertFile string `yaml:"tls_cert_file,omitempty"`
|
||||
TLSKeyFile string `yaml:"tls_key_file,omitempty"`
|
||||
}
|
||||
|
||||
// ProxyConfig represents reverse proxy configuration
|
||||
type ProxyConfig struct {
|
||||
Targets []TargetConfig `yaml:"targets"`
|
||||
LoadBalancer string `yaml:"load_balancer"` // "roundrobin", "leastconn", "random"
|
||||
HealthCheckPath string `yaml:"health_check_path"`
|
||||
HealthCheckInterval int `yaml:"health_check_interval"`
|
||||
}
|
||||
|
||||
// TargetConfig represents a proxy target
|
||||
type TargetConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Address string `yaml:"address"`
|
||||
Protocol string `yaml:"protocol"` // "http", "https"
|
||||
Weight int `yaml:"weight"` // for weighted load balancing
|
||||
Healthy bool `yaml:"-"` // health status
|
||||
}
|
||||
|
||||
// NATConfig represents NAT traversal configuration
|
||||
type NATConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
STUNServer string `yaml:"stun_server,omitempty"`
|
||||
TURNServer string `yaml:"turn_server,omitempty"`
|
||||
TURNUsername string `yaml:"turn_username,omitempty"`
|
||||
TURNPassword string `yaml:"turn_password,omitempty"`
|
||||
}
|
||||
|
||||
// LoggingConfig represents logging configuration
|
||||
type LoggingConfig struct {
|
||||
Level string `yaml:"level"` // "debug", "info", "warn", "error"
|
||||
Format string `yaml:"format"` // "json", "text"
|
||||
Output string `yaml:"output"` // "stdout", "file"
|
||||
File string `yaml:"file,omitempty"`
|
||||
}
|
||||
|
||||
// MonitorConfig represents monitoring configuration
|
||||
type MonitorConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Port int `yaml:"port"`
|
||||
Path string `yaml:"path"`
|
||||
Auth bool `yaml:"auth"`
|
||||
Username string `yaml:"username,omitempty"`
|
||||
Password string `yaml:"password,omitempty"`
|
||||
}
|
||||
|
||||
// Load loads configuration from file
|
||||
func Load(path string) (*Config, error) {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("configuration file not found: %s", path)
|
||||
}
|
||||
|
||||
// Read file
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read configuration file: %w", err)
|
||||
}
|
||||
|
||||
// Parse YAML
|
||||
var config Config
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse configuration: %w", err)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
setDefaults(&config)
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// setDefaults sets default values for configuration
|
||||
func setDefaults(c *Config) {
|
||||
// Server defaults
|
||||
if c.Server.Port == 0 {
|
||||
c.Server.Port = 8080
|
||||
}
|
||||
if c.Server.ReadTimeout == 0 {
|
||||
c.Server.ReadTimeout = 30
|
||||
}
|
||||
if c.Server.WriteTimeout == 0 {
|
||||
c.Server.WriteTimeout = 30
|
||||
}
|
||||
if c.Server.IdleTimeout == 0 {
|
||||
c.Server.IdleTimeout = 60
|
||||
}
|
||||
|
||||
// Proxy defaults
|
||||
if c.Proxy.LoadBalancer == "" {
|
||||
c.Proxy.LoadBalancer = "roundrobin"
|
||||
}
|
||||
if c.Proxy.HealthCheckPath == "" {
|
||||
c.Proxy.HealthCheckPath = "/health"
|
||||
}
|
||||
if c.Proxy.HealthCheckInterval == 0 {
|
||||
c.Proxy.HealthCheckInterval = 30
|
||||
}
|
||||
|
||||
// NAT defaults
|
||||
if c.NAT.Enabled && c.NAT.STUNServer == "" {
|
||||
c.NAT.STUNServer = "stun:stun.l.google.com:19302"
|
||||
}
|
||||
|
||||
// Logging defaults
|
||||
if c.Logging.Level == "" {
|
||||
c.Logging.Level = "info"
|
||||
}
|
||||
if c.Logging.Format == "" {
|
||||
c.Logging.Format = "json"
|
||||
}
|
||||
if c.Logging.Output == "" {
|
||||
c.Logging.Output = "stdout"
|
||||
}
|
||||
|
||||
// Monitor defaults
|
||||
if c.Monitor.Enabled && c.Monitor.Port == 0 {
|
||||
c.Monitor.Port = 9090
|
||||
}
|
||||
if c.Monitor.Enabled && c.Monitor.Path == "" {
|
||||
c.Monitor.Path = "/metrics"
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDefaultConfig creates a default configuration file
|
||||
func CreateDefaultConfig(path string) error {
|
||||
config := Config{
|
||||
Server: ServerConfig{
|
||||
Port: 8080,
|
||||
ReadTimeout: 30,
|
||||
WriteTimeout: 30,
|
||||
IdleTimeout: 60,
|
||||
},
|
||||
Proxy: ProxyConfig{
|
||||
LoadBalancer: "roundrobin",
|
||||
HealthCheckPath: "/health",
|
||||
HealthCheckInterval: 30,
|
||||
Targets: []TargetConfig{
|
||||
{
|
||||
Name: "example",
|
||||
Address: "http://localhost:3000",
|
||||
Protocol: "http",
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
NAT: NATConfig{
|
||||
Enabled: false,
|
||||
STUNServer: "stun:stun.l.google.com:19302",
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
},
|
||||
Monitor: MonitorConfig{
|
||||
Enabled: true,
|
||||
Port: 9090,
|
||||
Path: "/metrics",
|
||||
Auth: false,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal default configuration: %w", err)
|
||||
}
|
||||
|
||||
// Create directory if it doesn't exist
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create configuration directory: %w", err)
|
||||
}
|
||||
|
||||
// Write configuration file
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write configuration file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
243
internal/config/config_test.go
Normal file
243
internal/config/config_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "test-config.yaml")
|
||||
|
||||
// Create a test configuration file
|
||||
testConfig := `server:
|
||||
port: 9090
|
||||
read_timeout: 60
|
||||
write_timeout: 60
|
||||
idle_timeout: 120
|
||||
|
||||
proxy:
|
||||
targets:
|
||||
- name: "test-target"
|
||||
address: "http://localhost:8080"
|
||||
protocol: "http"
|
||||
weight: 1
|
||||
load_balancer: "leastconn"
|
||||
health_check_path: "/healthz"
|
||||
health_check_interval: 60
|
||||
|
||||
nat:
|
||||
enabled: true
|
||||
stun_server: "stun:stun.example.com:3478"
|
||||
turn_server: "turn:turn.example.com:3478"
|
||||
turn_username: "testuser"
|
||||
turn_password: "testpass"
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
format: "text"
|
||||
output: "file"
|
||||
file: "/var/log/gorz.log"
|
||||
|
||||
monitor:
|
||||
enabled: true
|
||||
port: 8081
|
||||
path: "/stats"
|
||||
auth: true
|
||||
username: "admin"
|
||||
password: "secret"
|
||||
`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(testConfig), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config file: %v", err)
|
||||
}
|
||||
|
||||
// Test loading the configuration
|
||||
cfg, err := Load(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load configuration: %v", err)
|
||||
}
|
||||
|
||||
// Verify server configuration
|
||||
if cfg.Server.Port != 9090 {
|
||||
t.Errorf("Expected port 9090, got %d", cfg.Server.Port)
|
||||
}
|
||||
if cfg.Server.ReadTimeout != 60 {
|
||||
t.Errorf("Expected read timeout 60, got %d", cfg.Server.ReadTimeout)
|
||||
}
|
||||
|
||||
// Verify proxy configuration
|
||||
if cfg.Proxy.LoadBalancer != "leastconn" {
|
||||
t.Errorf("Expected load balancer 'leastconn', got %s", cfg.Proxy.LoadBalancer)
|
||||
}
|
||||
if len(cfg.Proxy.Targets) != 1 {
|
||||
t.Errorf("Expected 1 target, got %d", len(cfg.Proxy.Targets))
|
||||
}
|
||||
if cfg.Proxy.Targets[0].Name != "test-target" {
|
||||
t.Errorf("Expected target name 'test-target', got %s", cfg.Proxy.Targets[0].Name)
|
||||
}
|
||||
|
||||
// Verify NAT configuration
|
||||
if !cfg.NAT.Enabled {
|
||||
t.Error("Expected NAT enabled to be true")
|
||||
}
|
||||
if cfg.NAT.STUNServer != "stun:stun.example.com:3478" {
|
||||
t.Errorf("Expected STUN server 'stun:stun.example.com:3478', got %s", cfg.NAT.STUNServer)
|
||||
}
|
||||
|
||||
// Verify logging configuration
|
||||
if cfg.Logging.Level != "debug" {
|
||||
t.Errorf("Expected log level 'debug', got %s", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Logging.Output != "file" {
|
||||
t.Errorf("Expected log output 'file', got %s", cfg.Logging.Output)
|
||||
}
|
||||
|
||||
// Verify monitor configuration
|
||||
if !cfg.Monitor.Enabled {
|
||||
t.Error("Expected monitor enabled to be true")
|
||||
}
|
||||
if cfg.Monitor.Port != 8081 {
|
||||
t.Errorf("Expected monitor port 8081, got %d", cfg.Monitor.Port)
|
||||
}
|
||||
if !cfg.Monitor.Auth {
|
||||
t.Error("Expected monitor auth to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadNonexistentFile(t *testing.T) {
|
||||
_, err := Load("/nonexistent/path/config.yaml")
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadInvalidYAML(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "invalid-config.yaml")
|
||||
|
||||
// Create an invalid YAML file
|
||||
invalidYAML := `server:
|
||||
port: 8080
|
||||
read_timeout: "not a number"
|
||||
`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(invalidYAML), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write invalid config file: %v", err)
|
||||
}
|
||||
|
||||
_, err = Load(configPath)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid YAML, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetDefaults(t *testing.T) {
|
||||
cfg := &Config{}
|
||||
|
||||
// Apply defaults
|
||||
setDefaults(cfg)
|
||||
|
||||
// Verify default server values
|
||||
if cfg.Server.Port != 8080 {
|
||||
t.Errorf("Expected default port 8080, got %d", cfg.Server.Port)
|
||||
}
|
||||
if cfg.Server.ReadTimeout != 30 {
|
||||
t.Errorf("Expected default read timeout 30, got %d", cfg.Server.ReadTimeout)
|
||||
}
|
||||
|
||||
// Verify default proxy values
|
||||
if cfg.Proxy.LoadBalancer != "roundrobin" {
|
||||
t.Errorf("Expected default load balancer 'roundrobin', got %s", cfg.Proxy.LoadBalancer)
|
||||
}
|
||||
if cfg.Proxy.HealthCheckPath != "/health" {
|
||||
t.Errorf("Expected default health check path '/health', got %s", cfg.Proxy.HealthCheckPath)
|
||||
}
|
||||
|
||||
// Verify default NAT values
|
||||
if cfg.NAT.Enabled {
|
||||
t.Error("Expected default NAT enabled to be false")
|
||||
}
|
||||
if cfg.NAT.STUNServer != "stun:stun.l.google.com:19302" {
|
||||
t.Errorf("Expected default STUN server 'stun:stun.l.google.com:19302', got %s", cfg.NAT.STUNServer)
|
||||
}
|
||||
|
||||
// Verify default logging values
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Errorf("Expected default log level 'info', got %s", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Logging.Format != "json" {
|
||||
t.Errorf("Expected default log format 'json', got %s", cfg.Logging.Format)
|
||||
}
|
||||
|
||||
// Verify default monitor values
|
||||
if !cfg.Monitor.Enabled {
|
||||
t.Error("Expected default monitor enabled to be true")
|
||||
}
|
||||
if cfg.Monitor.Port != 9090 {
|
||||
t.Errorf("Expected default monitor port 9090, got %d", cfg.Monitor.Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDefaultConfig(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "default-config.yaml")
|
||||
|
||||
// Create default configuration
|
||||
err := CreateDefaultConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create default configuration: %v", err)
|
||||
}
|
||||
|
||||
// Verify the file was created
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
t.Error("Expected config file to be created")
|
||||
}
|
||||
|
||||
// Load and verify the configuration
|
||||
cfg, err := Load(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load created configuration: %v", err)
|
||||
}
|
||||
|
||||
// Verify some default values
|
||||
if cfg.Server.Port != 8080 {
|
||||
t.Errorf("Expected default port 8080, got %d", cfg.Server.Port)
|
||||
}
|
||||
if cfg.Proxy.LoadBalancer != "roundrobin" {
|
||||
t.Errorf("Expected default load balancer 'roundrobin', got %s", cfg.Proxy.LoadBalancer)
|
||||
}
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Errorf("Expected default log level 'info', got %s", cfg.Logging.Level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDefaultConfigDirectoryCreation(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
nestedDir := filepath.Join(tempDir, "nested", "directory")
|
||||
configPath := filepath.Join(nestedDir, "config.yaml")
|
||||
|
||||
// Create default configuration in a nested directory
|
||||
err := CreateDefaultConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create default configuration: %v", err)
|
||||
}
|
||||
|
||||
// Verify the directory was created
|
||||
if _, err := os.Stat(nestedDir); os.IsNotExist(err) {
|
||||
t.Error("Expected nested directory to be created")
|
||||
}
|
||||
|
||||
// Verify the file was created
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
t.Error("Expected config file to be created")
|
||||
}
|
||||
}
|
124
internal/logger/logger.go
Normal file
124
internal/logger/logger.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Logger represents the application logger
|
||||
type Logger struct {
|
||||
debugLogger *log.Logger
|
||||
infoLogger *log.Logger
|
||||
warnLogger *log.Logger
|
||||
errorLogger *log.Logger
|
||||
}
|
||||
|
||||
// Field represents a log field
|
||||
type Field struct {
|
||||
Key string
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// Option represents a logger option
|
||||
type Option func(*Logger)
|
||||
|
||||
// NewLogger creates a new logger with default settings
|
||||
func NewLogger(opts ...Option) *Logger {
|
||||
logger := &Logger{
|
||||
debugLogger: log.New(os.Stdout, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile),
|
||||
infoLogger: log.New(os.Stdout, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile),
|
||||
warnLogger: log.New(os.Stdout, "WARN: ", log.Ldate|log.Ltime|log.Lshortfile),
|
||||
errorLogger: log.New(os.Stdout, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile),
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
opt(logger)
|
||||
}
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
func (l *Logger) Debug(msg string, fields ...Field) {
|
||||
l.debugLogger.Printf(formatMessage(msg, fields...))
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
func (l *Logger) Info(msg string, fields ...Field) {
|
||||
l.infoLogger.Printf(formatMessage(msg, fields...))
|
||||
}
|
||||
|
||||
// Warn logs a warning message
|
||||
func (l *Logger) Warn(msg string, fields ...Field) {
|
||||
l.warnLogger.Printf(formatMessage(msg, fields...))
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func (l *Logger) Error(msg string, fields ...Field) {
|
||||
l.errorLogger.Printf(formatMessage(msg, fields...))
|
||||
}
|
||||
|
||||
// Fatal logs a fatal message and exits
|
||||
func (l *Logger) Fatal(msg string, fields ...Field) {
|
||||
l.errorLogger.Printf(formatMessage(msg, fields...))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// String creates a string field
|
||||
func String(key, value string) Field {
|
||||
return Field{Key: key, Value: value}
|
||||
}
|
||||
|
||||
// Int creates an int field
|
||||
func Int(key string, value int) Field {
|
||||
return Field{Key: key, Value: value}
|
||||
}
|
||||
|
||||
// Bool creates a bool field
|
||||
func Bool(key string, value bool) Field {
|
||||
return Field{Key: key, Value: value}
|
||||
}
|
||||
|
||||
// Error creates an error field
|
||||
func Error(err error) Field {
|
||||
return Field{Key: "error", Value: err}
|
||||
}
|
||||
|
||||
// formatMessage formats a log message with fields
|
||||
func formatMessage(msg string, fields ...Field) string {
|
||||
if len(fields) == 0 {
|
||||
return msg
|
||||
}
|
||||
|
||||
result := msg + " ["
|
||||
for i, field := range fields {
|
||||
if i > 0 {
|
||||
result += ", "
|
||||
}
|
||||
result += field.Key + "="
|
||||
result += toString(field.Value)
|
||||
}
|
||||
result += "]"
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// toString converts a value to string
|
||||
func toString(value interface{}) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case int:
|
||||
return string(v)
|
||||
case bool:
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
case error:
|
||||
return v.Error()
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
185
internal/logger/logger_test.go
Normal file
185
internal/logger/logger_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewLogger(t *testing.T) {
|
||||
logger := NewLogger()
|
||||
if logger == nil {
|
||||
t.Error("Expected logger to be created, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerDebug(t *testing.T) {
|
||||
// Create a buffer to capture log output
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
|
||||
logger := NewLogger()
|
||||
logger.Debug("Debug message", String("key", "value"))
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "DEBUG: Debug message") {
|
||||
t.Errorf("Expected log output to contain 'DEBUG: Debug message', got %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "key=value") {
|
||||
t.Errorf("Expected log output to contain 'key=value', got %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerInfo(t *testing.T) {
|
||||
// Create a buffer to capture log output
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
|
||||
logger := NewLogger()
|
||||
logger.Info("Info message", Int("number", 42))
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "INFO: Info message") {
|
||||
t.Errorf("Expected log output to contain 'INFO: Info message', got %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "number=42") {
|
||||
t.Errorf("Expected log output to contain 'number=42', got %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerWarn(t *testing.T) {
|
||||
// Create a buffer to capture log output
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
|
||||
logger := NewLogger()
|
||||
logger.Warn("Warning message", Bool("flag", true))
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "WARN: Warning message") {
|
||||
t.Errorf("Expected log output to contain 'WARN: Warning message', got %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "flag=true") {
|
||||
t.Errorf("Expected log output to contain 'flag=true', got %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerError(t *testing.T) {
|
||||
// Create a buffer to capture log output
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
|
||||
logger := NewLogger()
|
||||
err := os.ErrNotExist
|
||||
logger.Error("Error message", Error(err))
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "ERROR: Error message") {
|
||||
t.Errorf("Expected log output to contain 'ERROR: Error message', got %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "error=file does not exist") {
|
||||
t.Errorf("Expected log output to contain 'error=file does not exist', got %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerFatal(t *testing.T) {
|
||||
// Create a buffer to capture log output
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
|
||||
// Mock os.Exit to prevent the test from exiting
|
||||
exitCalled := false
|
||||
exitFunc := func(code int) {
|
||||
exitCalled = true
|
||||
}
|
||||
osExit = exitFunc
|
||||
defer func() {
|
||||
osExit = realOsExit
|
||||
}()
|
||||
|
||||
logger := NewLogger()
|
||||
logger.Fatal("Fatal message", String("reason", "testing"))
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "ERROR: Fatal message") {
|
||||
t.Errorf("Expected log output to contain 'ERROR: Fatal message', got %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "reason=testing") {
|
||||
t.Errorf("Expected log output to contain 'reason=testing', got %s", output)
|
||||
}
|
||||
if !exitCalled {
|
||||
t.Error("Expected os.Exit to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerMultipleFields(t *testing.T) {
|
||||
// Create a buffer to capture log output
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
|
||||
logger := NewLogger()
|
||||
logger.Info("Multiple fields",
|
||||
String("string", "value"),
|
||||
Int("int", 123),
|
||||
Bool("bool", false))
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "INFO: Multiple fields") {
|
||||
t.Errorf("Expected log output to contain 'INFO: Multiple fields', got %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "string=value") {
|
||||
t.Errorf("Expected log output to contain 'string=value', got %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "int=123") {
|
||||
t.Errorf("Expected log output to contain 'int=123', got %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "bool=false") {
|
||||
t.Errorf("Expected log output to contain 'bool=false', got %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerNoFields(t *testing.T) {
|
||||
// Create a buffer to capture log output
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
|
||||
logger := NewLogger()
|
||||
logger.Info("No fields")
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "INFO: No fields") {
|
||||
t.Errorf("Expected log output to contain 'INFO: No fields', got %s", output)
|
||||
}
|
||||
if strings.Contains(output, "[") {
|
||||
t.Error("Expected log output to not contain field brackets when no fields are provided")
|
||||
}
|
||||
}
|
||||
|
||||
// Mock os.Exit for testing
|
||||
var (
|
||||
osExit = func(code int) { os.Exit(code) }
|
||||
realOsExit = osExit
|
||||
)
|
266
internal/monitoring/monitoring.go
Normal file
266
internal/monitoring/monitoring.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/iwasforcedtobehere/goRZ/internal/config"
|
||||
"github.com/iwasforcedtobehere/goRZ/internal/logger"
|
||||
)
|
||||
|
||||
// Metrics represents the application metrics
|
||||
type Metrics struct {
|
||||
RequestsTotal int64 `json:"requests_total"`
|
||||
RequestsActive int64 `json:"requests_active"`
|
||||
ResponsesByStatus map[string]int64 `json:"responses_by_status"`
|
||||
TargetMetrics map[string]*TargetMetric `json:"target_metrics"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
mu sync.RWMutex `json:"-"`
|
||||
}
|
||||
|
||||
// TargetMetric represents metrics for a specific target
|
||||
type TargetMetric struct {
|
||||
RequestsTotal int64 `json:"requests_total"`
|
||||
ResponsesByStatus map[string]int64 `json:"responses_by_status"`
|
||||
ResponseTimes []time.Duration `json:"response_times"`
|
||||
AvgResponseTime time.Duration `json:"avg_response_time"`
|
||||
Healthy bool `json:"healthy"`
|
||||
LastChecked time.Time `json:"last_checked"`
|
||||
}
|
||||
|
||||
// Monitor represents the monitoring service
|
||||
type Monitor struct {
|
||||
config *config.Config
|
||||
logger *logger.Logger
|
||||
metrics *Metrics
|
||||
server *http.Server
|
||||
authHandler http.Handler
|
||||
}
|
||||
|
||||
// NewMonitor creates a new monitoring service
|
||||
func NewMonitor(cfg *config.Config, logger *logger.Logger) *Monitor {
|
||||
metrics := &Metrics{
|
||||
ResponsesByStatus: make(map[string]int64),
|
||||
TargetMetrics: make(map[string]*TargetMetric),
|
||||
StartTime: time.Now(),
|
||||
LastUpdated: time.Now(),
|
||||
}
|
||||
|
||||
// Initialize target metrics
|
||||
for _, target := range cfg.Proxy.Targets {
|
||||
metrics.TargetMetrics[target.Name] = &TargetMetric{
|
||||
ResponsesByStatus: make(map[string]int64),
|
||||
ResponseTimes: make([]time.Duration, 0),
|
||||
Healthy: target.Healthy,
|
||||
LastChecked: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
monitor := &Monitor{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
metrics: metrics,
|
||||
}
|
||||
|
||||
// Set up authentication if enabled
|
||||
if cfg.Monitor.Auth {
|
||||
monitor.authHandler = monitor.basicAuthHandler(monitor.metricsHandler)
|
||||
} else {
|
||||
monitor.authHandler = monitor.metricsHandler
|
||||
}
|
||||
|
||||
return monitor
|
||||
}
|
||||
|
||||
// Start starts the monitoring service
|
||||
func (m *Monitor) Start() error {
|
||||
if !m.config.Monitor.Enabled {
|
||||
m.logger.Info("Monitoring is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create HTTP server
|
||||
m.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", m.config.Monitor.Port),
|
||||
Handler: m.authHandler,
|
||||
}
|
||||
|
||||
// Start server in a goroutine
|
||||
go func() {
|
||||
m.logger.Info("Monitoring server starting", logger.String("address", m.server.Addr))
|
||||
if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
m.logger.Error("Monitoring server failed", logger.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the monitoring service
|
||||
func (m *Monitor) Stop() {
|
||||
if m.server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
m.server.Shutdown(ctx)
|
||||
m.logger.Info("Monitoring server stopped")
|
||||
}
|
||||
}
|
||||
|
||||
// metricsHandler handles HTTP requests for metrics
|
||||
func (m *Monitor) metricsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != m.config.Monitor.Path {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
m.metrics.mu.RLock()
|
||||
defer m.metrics.mu.RUnlock()
|
||||
|
||||
// Update last updated time
|
||||
m.metrics.LastUpdated = time.Now()
|
||||
|
||||
// Set content type
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Encode metrics as JSON
|
||||
if err := json.NewEncoder(w).Encode(m.metrics); err != nil {
|
||||
m.logger.Error("Failed to encode metrics", logger.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// basicAuthHandler wraps a handler with basic authentication
|
||||
func (m *Monitor) basicAuthHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok || username != m.config.Monitor.Username || password != m.config.Monitor.Password {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="goRZ Monitor"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementRequest increments the total request count
|
||||
func (m *Monitor) IncrementRequest() {
|
||||
m.metrics.mu.Lock()
|
||||
defer m.metrics.mu.Unlock()
|
||||
|
||||
m.metrics.RequestsTotal++
|
||||
m.metrics.LastUpdated = time.Now()
|
||||
}
|
||||
|
||||
// IncrementActiveRequest increments the active request count
|
||||
func (m *Monitor) IncrementActiveRequest() {
|
||||
m.metrics.mu.Lock()
|
||||
defer m.metrics.mu.Unlock()
|
||||
|
||||
m.metrics.RequestsActive++
|
||||
m.metrics.LastUpdated = time.Now()
|
||||
}
|
||||
|
||||
// DecrementActiveRequest decrements the active request count
|
||||
func (m *Monitor) DecrementActiveRequest() {
|
||||
m.metrics.mu.Lock()
|
||||
defer m.metrics.mu.Unlock()
|
||||
|
||||
m.metrics.RequestsActive--
|
||||
if m.metrics.RequestsActive < 0 {
|
||||
m.metrics.RequestsActive = 0
|
||||
}
|
||||
m.metrics.LastUpdated = time.Now()
|
||||
}
|
||||
|
||||
// RecordResponse records a response with the given status code
|
||||
func (m *Monitor) RecordResponse(statusCode int, targetName string, responseTime time.Duration) {
|
||||
m.metrics.mu.Lock()
|
||||
defer m.metrics.mu.Unlock()
|
||||
|
||||
status := fmt.Sprintf("%d", statusCode)
|
||||
m.metrics.ResponsesByStatus[status]++
|
||||
|
||||
// Update target metrics
|
||||
if target, exists := m.metrics.TargetMetrics[targetName]; exists {
|
||||
target.RequestsTotal++
|
||||
target.ResponsesByStatus[status]++
|
||||
|
||||
// Keep only the last 100 response times for average calculation
|
||||
if len(target.ResponseTimes) >= 100 {
|
||||
target.ResponseTimes = target.ResponseTimes[1:]
|
||||
}
|
||||
target.ResponseTimes = append(target.ResponseTimes, responseTime)
|
||||
|
||||
// Calculate average response time
|
||||
var total time.Duration
|
||||
for _, rt := range target.ResponseTimes {
|
||||
total += rt
|
||||
}
|
||||
target.AvgResponseTime = total / time.Duration(len(target.ResponseTimes))
|
||||
}
|
||||
|
||||
m.metrics.LastUpdated = time.Now()
|
||||
}
|
||||
|
||||
// UpdateTargetHealth updates the health status of a target
|
||||
func (m *Monitor) UpdateTargetHealth(targetName string, healthy bool) {
|
||||
m.metrics.mu.Lock()
|
||||
defer m.metrics.mu.Unlock()
|
||||
|
||||
if target, exists := m.metrics.TargetMetrics[targetName]; exists {
|
||||
target.Healthy = healthy
|
||||
target.LastChecked = time.Now()
|
||||
}
|
||||
|
||||
m.metrics.LastUpdated = time.Now()
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the current metrics
|
||||
func (m *Monitor) GetMetrics() Metrics {
|
||||
m.metrics.mu.RLock()
|
||||
defer m.metrics.mu.RUnlock()
|
||||
|
||||
// Create a deep copy of the metrics
|
||||
metrics := Metrics{
|
||||
RequestsTotal: m.metrics.RequestsTotal,
|
||||
RequestsActive: m.metrics.RequestsActive,
|
||||
ResponsesByStatus: make(map[string]int64),
|
||||
TargetMetrics: make(map[string]*TargetMetric),
|
||||
StartTime: m.metrics.StartTime,
|
||||
LastUpdated: m.metrics.LastUpdated,
|
||||
}
|
||||
|
||||
// Copy response status counts
|
||||
for k, v := range m.metrics.ResponsesByStatus {
|
||||
metrics.ResponsesByStatus[k] = v
|
||||
}
|
||||
|
||||
// Copy target metrics
|
||||
for k, v := range m.metrics.TargetMetrics {
|
||||
targetMetric := &TargetMetric{
|
||||
RequestsTotal: v.RequestsTotal,
|
||||
ResponsesByStatus: make(map[string]int64),
|
||||
ResponseTimes: make([]time.Duration, len(v.ResponseTimes)),
|
||||
AvgResponseTime: v.AvgResponseTime,
|
||||
Healthy: v.Healthy,
|
||||
LastChecked: v.LastChecked,
|
||||
}
|
||||
|
||||
// Copy response status counts for target
|
||||
for rk, rv := range v.ResponsesByStatus {
|
||||
targetMetric.ResponsesByStatus[rk] = rv
|
||||
}
|
||||
|
||||
// Copy response times
|
||||
copy(targetMetric.ResponseTimes, v.ResponseTimes)
|
||||
|
||||
metrics.TargetMetrics[k] = targetMetric
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
312
internal/monitoring/monitoring_test.go
Normal file
312
internal/monitoring/monitoring_test.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/iwasforcedtobehere/goRZ/internal/config"
|
||||
"github.com/iwasforcedtobehere/goRZ/internal/logger"
|
||||
)
|
||||
|
||||
func TestMonitorStartStop(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Monitor: config.MonitorConfig{
|
||||
Enabled: true,
|
||||
Port: 9090,
|
||||
Path: "/metrics",
|
||||
Auth: false,
|
||||
},
|
||||
}
|
||||
|
||||
logger := logger.NewLogger()
|
||||
monitor := NewMonitor(cfg, logger)
|
||||
|
||||
// Start the monitor
|
||||
err := monitor.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start monitor: %v", err)
|
||||
}
|
||||
|
||||
// Stop the monitor
|
||||
monitor.Stop()
|
||||
|
||||
// Test that we can start and stop again
|
||||
err = monitor.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start monitor: %v", err)
|
||||
}
|
||||
monitor.Stop()
|
||||
}
|
||||
|
||||
func TestMonitorDisabled(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Monitor: config.MonitorConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
logger := logger.NewLogger()
|
||||
monitor := NewMonitor(cfg, logger)
|
||||
|
||||
// Start the monitor
|
||||
err := monitor.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start monitor: %v", err)
|
||||
}
|
||||
|
||||
// Stop the monitor
|
||||
monitor.Stop()
|
||||
}
|
||||
|
||||
func TestMonitorMetricsHandler(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Monitor: config.MonitorConfig{
|
||||
Enabled: true,
|
||||
Port: 9090,
|
||||
Path: "/metrics",
|
||||
Auth: false,
|
||||
},
|
||||
Proxy: config.ProxyConfig{
|
||||
Targets: []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
|
||||
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logger := logger.NewLogger()
|
||||
monitor := NewMonitor(cfg, logger)
|
||||
|
||||
// Create a test HTTP server
|
||||
server := httptest.NewServer(monitor.authHandler)
|
||||
defer server.Close()
|
||||
|
||||
// Test metrics endpoint
|
||||
resp, err := http.Get(server.URL + cfg.Monitor.Path)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get metrics: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Decode and check metrics
|
||||
var metrics Metrics
|
||||
if err := json.NewDecoder(resp.Body).Decode(&metrics); err != nil {
|
||||
t.Fatalf("Failed to decode metrics: %v", err)
|
||||
}
|
||||
|
||||
if metrics.RequestsTotal != 0 {
|
||||
t.Errorf("Expected requests total 0, got %d", metrics.RequestsTotal)
|
||||
}
|
||||
if metrics.RequestsActive != 0 {
|
||||
t.Errorf("Expected requests active 0, got %d", metrics.RequestsActive)
|
||||
}
|
||||
if len(metrics.TargetMetrics) != 2 {
|
||||
t.Errorf("Expected 2 target metrics, got %d", len(metrics.TargetMetrics))
|
||||
}
|
||||
if !metrics.TargetMetrics["target1"].Healthy {
|
||||
t.Error("Expected target1 to be healthy")
|
||||
}
|
||||
if metrics.TargetMetrics["target2"].Healthy {
|
||||
t.Error("Expected target2 to be unhealthy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitorBasicAuth(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Monitor: config.MonitorConfig{
|
||||
Enabled: true,
|
||||
Port: 9090,
|
||||
Path: "/metrics",
|
||||
Auth: true,
|
||||
Username: "admin",
|
||||
Password: "secret",
|
||||
},
|
||||
}
|
||||
|
||||
logger := logger.NewLogger()
|
||||
monitor := NewMonitor(cfg, logger)
|
||||
|
||||
// Create a test HTTP server
|
||||
server := httptest.NewServer(monitor.authHandler)
|
||||
defer server.Close()
|
||||
|
||||
// Test without authentication
|
||||
resp, err := http.Get(server.URL + cfg.Monitor.Path)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get metrics: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Test with incorrect authentication
|
||||
req, err := http.NewRequest("GET", server.URL+cfg.Monitor.Path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.SetBasicAuth("wrong", "credentials")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Test with correct authentication
|
||||
req, err = http.NewRequest("GET", server.URL+cfg.Monitor.Path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.SetBasicAuth(cfg.Monitor.Username, cfg.Monitor.Password)
|
||||
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitorMetricsTracking(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Monitor: config.MonitorConfig{
|
||||
Enabled: true,
|
||||
Port: 9090,
|
||||
Path: "/metrics",
|
||||
Auth: false,
|
||||
},
|
||||
Proxy: config.ProxyConfig{
|
||||
Targets: []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logger := logger.NewLogger()
|
||||
monitor := NewMonitor(cfg, logger)
|
||||
|
||||
// Test incrementing request count
|
||||
monitor.IncrementRequest()
|
||||
metrics := monitor.GetMetrics()
|
||||
if metrics.RequestsTotal != 1 {
|
||||
t.Errorf("Expected requests total 1, got %d", metrics.RequestsTotal)
|
||||
}
|
||||
|
||||
// Test incrementing active request count
|
||||
monitor.IncrementActiveRequest()
|
||||
metrics = monitor.GetMetrics()
|
||||
if metrics.RequestsActive != 1 {
|
||||
t.Errorf("Expected requests active 1, got %d", metrics.RequestsActive)
|
||||
}
|
||||
|
||||
// Test decrementing active request count
|
||||
monitor.DecrementActiveRequest()
|
||||
metrics = monitor.GetMetrics()
|
||||
if metrics.RequestsActive != 0 {
|
||||
t.Errorf("Expected requests active 0, got %d", metrics.RequestsActive)
|
||||
}
|
||||
|
||||
// Test recording response
|
||||
monitor.RecordResponse(http.StatusOK, "target1", 100*time.Millisecond)
|
||||
metrics = monitor.GetMetrics()
|
||||
if metrics.ResponsesByStatus["200"] != 1 {
|
||||
t.Errorf("Expected 1 response with status 200, got %d", metrics.ResponsesByStatus["200"])
|
||||
}
|
||||
if metrics.TargetMetrics["target1"].RequestsTotal != 1 {
|
||||
t.Errorf("Expected target1 to have 1 request, got %d", metrics.TargetMetrics["target1"].RequestsTotal)
|
||||
}
|
||||
if metrics.TargetMetrics["target1"].ResponsesByStatus["200"] != 1 {
|
||||
t.Errorf("Expected target1 to have 1 response with status 200, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["200"])
|
||||
}
|
||||
if len(metrics.TargetMetrics["target1"].ResponseTimes) != 1 {
|
||||
t.Errorf("Expected target1 to have 1 response time, got %d", len(metrics.TargetMetrics["target1"].ResponseTimes))
|
||||
}
|
||||
if metrics.TargetMetrics["target1"].AvgResponseTime != 100*time.Millisecond {
|
||||
t.Errorf("Expected target1 to have average response time 100ms, got %v", metrics.TargetMetrics["target1"].AvgResponseTime)
|
||||
}
|
||||
|
||||
// Test updating target health
|
||||
monitor.UpdateTargetHealth("target1", false)
|
||||
metrics = monitor.GetMetrics()
|
||||
if metrics.TargetMetrics["target1"].Healthy {
|
||||
t.Error("Expected target1 to be unhealthy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitorMultipleResponses(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Monitor: config.MonitorConfig{
|
||||
Enabled: true,
|
||||
Port: 9090,
|
||||
Path: "/metrics",
|
||||
Auth: false,
|
||||
},
|
||||
Proxy: config.ProxyConfig{
|
||||
Targets: []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logger := logger.NewLogger()
|
||||
monitor := NewMonitor(cfg, logger)
|
||||
|
||||
// Record multiple responses with different status codes and response times
|
||||
monitor.RecordResponse(http.StatusOK, "target1", 100*time.Millisecond)
|
||||
monitor.RecordResponse(http.StatusOK, "target1", 200*time.Millisecond)
|
||||
monitor.RecordResponse(http.StatusNotFound, "target1", 50*time.Millisecond)
|
||||
monitor.RecordResponse(http.StatusInternalServerError, "target1", 300*time.Millisecond)
|
||||
|
||||
metrics := monitor.GetMetrics()
|
||||
|
||||
// Check overall metrics
|
||||
if metrics.RequestsTotal != 0 {
|
||||
t.Errorf("Expected requests total 0, got %d", metrics.RequestsTotal)
|
||||
}
|
||||
if metrics.ResponsesByStatus["200"] != 2 {
|
||||
t.Errorf("Expected 2 responses with status 200, got %d", metrics.ResponsesByStatus["200"])
|
||||
}
|
||||
if metrics.ResponsesByStatus["404"] != 1 {
|
||||
t.Errorf("Expected 1 response with status 404, got %d", metrics.ResponsesByStatus["404"])
|
||||
}
|
||||
if metrics.ResponsesByStatus["500"] != 1 {
|
||||
t.Errorf("Expected 1 response with status 500, got %d", metrics.ResponsesByStatus["500"])
|
||||
}
|
||||
|
||||
// Check target metrics
|
||||
if metrics.TargetMetrics["target1"].RequestsTotal != 4 {
|
||||
t.Errorf("Expected target1 to have 4 requests, got %d", metrics.TargetMetrics["target1"].RequestsTotal)
|
||||
}
|
||||
if metrics.TargetMetrics["target1"].ResponsesByStatus["200"] != 2 {
|
||||
t.Errorf("Expected target1 to have 2 responses with status 200, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["200"])
|
||||
}
|
||||
if metrics.TargetMetrics["target1"].ResponsesByStatus["404"] != 1 {
|
||||
t.Errorf("Expected target1 to have 1 response with status 404, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["404"])
|
||||
}
|
||||
if metrics.TargetMetrics["target1"].ResponsesByStatus["500"] != 1 {
|
||||
t.Errorf("Expected target1 to have 1 response with status 500, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["500"])
|
||||
}
|
||||
|
||||
// Check average response time
|
||||
expectedAvg := time.Duration((100 + 200 + 50 + 300) / 4)
|
||||
if metrics.TargetMetrics["target1"].AvgResponseTime != expectedAvg {
|
||||
t.Errorf("Expected target1 to have average response time %v, got %v", expectedAvg, metrics.TargetMetrics["target1"].AvgResponseTime)
|
||||
}
|
||||
}
|
231
internal/nat/nat.go
Normal file
231
internal/nat/nat.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package nat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/iwasforcedtobehere/goRZ/internal/config"
|
||||
"github.com/iwasforcedtobehere/goRZ/internal/logger"
|
||||
"github.com/pion/stun"
|
||||
"github.com/pion/turn/v2"
|
||||
)
|
||||
|
||||
// NATTraversal handles NAT traversal for the proxy server
|
||||
type NATTraversal struct {
|
||||
config *config.Config
|
||||
logger *logger.Logger
|
||||
conn net.PacketConn
|
||||
externalIP net.IP
|
||||
externalPort int
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// NewNATTraversal creates a new NAT traversal instance
|
||||
func NewNATTraversal(cfg *config.Config, logger *logger.Logger) *NATTraversal {
|
||||
return &NATTraversal{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the NAT traversal process
|
||||
func (n *NATTraversal) Start() error {
|
||||
if !n.config.NAT.Enabled {
|
||||
n.logger.Info("NAT traversal is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
|
||||
if n.running {
|
||||
return fmt.Errorf("NAT traversal is already running")
|
||||
}
|
||||
|
||||
// Create a UDP listener
|
||||
conn, err := net.ListenPacket("udp4", "0.0.0.0:0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create UDP listener: %w", err)
|
||||
}
|
||||
n.conn = conn
|
||||
|
||||
// Get external IP and port using STUN
|
||||
if err := n.discoverExternalAddress(); err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("failed to discover external address: %w", err)
|
||||
}
|
||||
|
||||
// Start TURN client if configured
|
||||
if n.config.NAT.TURNServer != "" {
|
||||
if err := n.startTURNClient(); err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("failed to start TURN client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n.running = true
|
||||
n.logger.Info("NAT traversal started",
|
||||
logger.String("external_ip", n.externalIP.String()),
|
||||
logger.Int("external_port", n.externalPort))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the NAT traversal process
|
||||
func (n *NATTraversal) Stop() {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
|
||||
if !n.running {
|
||||
return
|
||||
}
|
||||
|
||||
if n.conn != nil {
|
||||
n.conn.Close()
|
||||
n.conn = nil
|
||||
}
|
||||
|
||||
n.running = false
|
||||
n.logger.Info("NAT traversal stopped")
|
||||
}
|
||||
|
||||
// discoverExternalAddress discovers the external IP and port using STUN
|
||||
func (n *NATTraversal) discoverExternalAddress() error {
|
||||
// Parse STUN server URL
|
||||
stunURL, err := stun.ParseURI(n.config.NAT.STUNServer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse STUN server URL: %w", err)
|
||||
}
|
||||
|
||||
// Create STUN client
|
||||
client, err := stun.NewClient("udp4", n.conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create STUN client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Send binding request
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mappedAddr, err := client.Request(ctx, stunURL, stun.BindingRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("STUN binding request failed: %w", err)
|
||||
}
|
||||
|
||||
n.externalIP = mappedAddr.IP
|
||||
n.externalPort = mappedAddr.Port
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startTURNClient starts a TURN client for relay connections
|
||||
func (n *NATTraversal) startTURNClient() error {
|
||||
// Parse TURN server URL
|
||||
turnURL, err := turn.ParseURI(n.config.NAT.TURNServer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse TURN server URL: %w", err)
|
||||
}
|
||||
|
||||
// Create TURN client config
|
||||
cfg := &turn.ClientConfig{
|
||||
STUNServerAddr: n.config.NAT.STUNServer,
|
||||
TURNServerAddr: n.config.NAT.TURNServer,
|
||||
Username: n.config.NAT.TURNUsername,
|
||||
Credential: n.config.NAT.TURNPassword,
|
||||
LoggerFactory: nil, // We'll use our own logger
|
||||
}
|
||||
|
||||
// Create TURN client
|
||||
client, err := turn.NewClient(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create TURN client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Listen on provided conn
|
||||
n.logger.Info("TURN client started", logger.String("server", n.config.NAT.TURNServer))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExternalAddress returns the external IP and port
|
||||
func (n *NATTraversal) GetExternalAddress() (net.IP, int) {
|
||||
n.mu.RLock()
|
||||
defer n.mu.RUnlock()
|
||||
|
||||
return n.externalIP, n.externalPort
|
||||
}
|
||||
|
||||
// IsRunning returns whether NAT traversal is running
|
||||
func (n *NATTraversal) IsRunning() bool {
|
||||
n.mu.RLock()
|
||||
defer n.mu.RUnlock()
|
||||
|
||||
return n.running
|
||||
}
|
||||
|
||||
// CreateHolePunch attempts to create a hole in the NAT for direct peer-to-peer connections
|
||||
func (n *NATTraversal) CreateHolePunch(peerAddr string) (net.Conn, error) {
|
||||
if !n.running {
|
||||
return nil, fmt.Errorf("NAT traversal is not running")
|
||||
}
|
||||
|
||||
// Parse peer address
|
||||
udpAddr, err := net.ResolveUDPAddr("udp4", peerAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve peer address: %w", err)
|
||||
}
|
||||
|
||||
// Create UDP connection
|
||||
conn, err := net.DialUDP("udp4", nil, udpAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create UDP connection: %w", err)
|
||||
}
|
||||
|
||||
// Send a small packet to punch a hole in the NAT
|
||||
_, err = conn.Write([]byte("punch"))
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("failed to punch hole: %w", err)
|
||||
}
|
||||
|
||||
n.logger.Debug("Hole punch attempted", logger.String("peer", peerAddr))
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// GetNATType attempts to determine the type of NAT
|
||||
func (n *NATTraversal) GetNATType() (string, error) {
|
||||
if !n.running {
|
||||
return "", fmt.Errorf("NAT traversal is not running")
|
||||
}
|
||||
|
||||
// This is a simplified NAT type detection
|
||||
// In a real implementation, you would use more sophisticated methods
|
||||
// like the one described in RFC 3489 and RFC 5780
|
||||
|
||||
// For now, we'll return a generic response
|
||||
return "Unknown", nil
|
||||
}
|
||||
|
||||
// CreateRelayConnection creates a relay connection through TURN
|
||||
func (n *NATTraversal) CreateRelayConnection(peerAddr string) (net.Conn, error) {
|
||||
if !n.running {
|
||||
return nil, fmt.Errorf("NAT traversal is not running")
|
||||
}
|
||||
|
||||
if n.config.NAT.TURNServer == "" {
|
||||
return nil, fmt.Errorf("TURN server not configured")
|
||||
}
|
||||
|
||||
// This is a placeholder for TURN relay connection creation
|
||||
// In a real implementation, you would use the TURN client to allocate
|
||||
// a relay and create a connection to the peer
|
||||
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
217
internal/proxy/loadbalancer.go
Normal file
217
internal/proxy/loadbalancer.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/iwasforcedtobehere/goRZ/internal/config"
|
||||
)
|
||||
|
||||
// RoundRobinLoadBalancer implements round-robin load balancing
|
||||
type RoundRobinLoadBalancer struct {
|
||||
targets []*config.TargetConfig
|
||||
current int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewRoundRobinLoadBalancer creates a new round-robin load balancer
|
||||
func NewRoundRobinLoadBalancer(targets []config.TargetConfig) *RoundRobinLoadBalancer {
|
||||
t := make([]*config.TargetConfig, len(targets))
|
||||
for i := range targets {
|
||||
t[i] = &targets[i]
|
||||
}
|
||||
return &RoundRobinLoadBalancer{
|
||||
targets: t,
|
||||
current: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// NextTarget returns the next target using round-robin algorithm
|
||||
func (lb *RoundRobinLoadBalancer) NextTarget() (*config.TargetConfig, error) {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
// Filter healthy targets
|
||||
healthyTargets := make([]*config.TargetConfig, 0)
|
||||
for _, target := range lb.targets {
|
||||
if target.Healthy {
|
||||
healthyTargets = append(healthyTargets, target)
|
||||
}
|
||||
}
|
||||
|
||||
if len(healthyTargets) == 0 {
|
||||
return nil, ErrNoHealthyTargets
|
||||
}
|
||||
|
||||
// Get next target
|
||||
target := healthyTargets[lb.current%len(healthyTargets)]
|
||||
lb.current = (lb.current + 1) % len(healthyTargets)
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// UpdateTargets updates the targets list
|
||||
func (lb *RoundRobinLoadBalancer) UpdateTargets(targets []config.TargetConfig) {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
t := make([]*config.TargetConfig, len(targets))
|
||||
for i := range targets {
|
||||
t[i] = &targets[i]
|
||||
}
|
||||
lb.targets = t
|
||||
lb.current = 0
|
||||
}
|
||||
|
||||
// RandomLoadBalancer implements random load balancing
|
||||
type RandomLoadBalancer struct {
|
||||
targets []*config.TargetConfig
|
||||
rand *rand.Rand
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewRandomLoadBalancer creates a new random load balancer
|
||||
func NewRandomLoadBalancer(targets []config.TargetConfig) *RandomLoadBalancer {
|
||||
t := make([]*config.TargetConfig, len(targets))
|
||||
for i := range targets {
|
||||
t[i] = &targets[i]
|
||||
}
|
||||
return &RandomLoadBalancer{
|
||||
targets: t,
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// NextTarget returns a random target
|
||||
func (lb *RandomLoadBalancer) NextTarget() (*config.TargetConfig, error) {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
// Filter healthy targets
|
||||
healthyTargets := make([]*config.TargetConfig, 0)
|
||||
for _, target := range lb.targets {
|
||||
if target.Healthy {
|
||||
healthyTargets = append(healthyTargets, target)
|
||||
}
|
||||
}
|
||||
|
||||
if len(healthyTargets) == 0 {
|
||||
return nil, ErrNoHealthyTargets
|
||||
}
|
||||
|
||||
// Get random target
|
||||
index := lb.rand.Intn(len(healthyTargets))
|
||||
return healthyTargets[index], nil
|
||||
}
|
||||
|
||||
// UpdateTargets updates the targets list
|
||||
func (lb *RandomLoadBalancer) UpdateTargets(targets []config.TargetConfig) {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
t := make([]*config.TargetConfig, len(targets))
|
||||
for i := range targets {
|
||||
t[i] = &targets[i]
|
||||
}
|
||||
lb.targets = t
|
||||
}
|
||||
|
||||
// LeastConnectionsLoadBalancer implements least connections load balancing
|
||||
type LeastConnectionsLoadBalancer struct {
|
||||
targets []*config.TargetConfig
|
||||
connections map[string]int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewLeastConnectionsLoadBalancer creates a new least connections load balancer
|
||||
func NewLeastConnectionsLoadBalancer(targets []config.TargetConfig) *LeastConnectionsLoadBalancer {
|
||||
t := make([]*config.TargetConfig, len(targets))
|
||||
for i := range targets {
|
||||
t[i] = &targets[i]
|
||||
}
|
||||
|
||||
connections := make(map[string]int)
|
||||
for _, target := range t {
|
||||
connections[target.Name] = 0
|
||||
}
|
||||
|
||||
return &LeastConnectionsLoadBalancer{
|
||||
targets: t,
|
||||
connections: connections,
|
||||
}
|
||||
}
|
||||
|
||||
// NextTarget returns the target with the least connections
|
||||
func (lb *LeastConnectionsLoadBalancer) NextTarget() (*config.TargetConfig, error) {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
// Filter healthy targets and find the one with least connections
|
||||
var selectedTarget *config.TargetConfig
|
||||
minConnections := -1
|
||||
|
||||
for _, target := range lb.targets {
|
||||
if target.Healthy {
|
||||
connections := lb.connections[target.Name]
|
||||
if minConnections == -1 || connections < minConnections {
|
||||
minConnections = connections
|
||||
selectedTarget = target
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if selectedTarget == nil {
|
||||
return nil, ErrNoHealthyTargets
|
||||
}
|
||||
|
||||
// Increment connection count
|
||||
lb.connections[selectedTarget.Name]++
|
||||
|
||||
return selectedTarget, nil
|
||||
}
|
||||
|
||||
// ReleaseConnection decrements the connection count for a target
|
||||
func (lb *LeastConnectionsLoadBalancer) ReleaseConnection(targetName string) {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
if count, exists := lb.connections[targetName]; exists && count > 0 {
|
||||
lb.connections[targetName] = count - 1
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTargets updates the targets list
|
||||
func (lb *LeastConnectionsLoadBalancer) UpdateTargets(targets []config.TargetConfig) {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
t := make([]*config.TargetConfig, len(targets))
|
||||
for i := range targets {
|
||||
t[i] = &targets[i]
|
||||
}
|
||||
|
||||
// Update connections map
|
||||
connections := make(map[string]int)
|
||||
for _, target := range t {
|
||||
// Preserve existing connection count if target exists
|
||||
if count, exists := lb.connections[target.Name]; exists {
|
||||
connections[target.Name] = count
|
||||
} else {
|
||||
connections[target.Name] = 0
|
||||
}
|
||||
}
|
||||
|
||||
lb.targets = t
|
||||
lb.connections = connections
|
||||
}
|
||||
|
||||
// ErrNoHealthyTargets is returned when no healthy targets are available
|
||||
var ErrNoHealthyTargets = errorString("no healthy targets available")
|
||||
|
||||
// errorString is a simple string-based error type
|
||||
type errorString string
|
||||
|
||||
func (e errorString) Error() string {
|
||||
return string(e)
|
||||
}
|
270
internal/proxy/loadbalancer_test.go
Normal file
270
internal/proxy/loadbalancer_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/iwasforcedtobehere/goRZ/internal/config"
|
||||
)
|
||||
|
||||
func TestRoundRobinLoadBalancer(t *testing.T) {
|
||||
targets := []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
|
||||
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
|
||||
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
|
||||
}
|
||||
|
||||
lb := NewRoundRobinLoadBalancer(targets)
|
||||
|
||||
// Test that targets are selected in round-robin order
|
||||
for i := 0; i < 6; i++ {
|
||||
target, err := lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedTarget := targets[i%3]
|
||||
if target.Name != expectedTarget.Name {
|
||||
t.Errorf("Expected target %s, got %s", expectedTarget.Name, target.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinLoadBalancerWithUnhealthyTargets(t *testing.T) {
|
||||
targets := []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
|
||||
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
|
||||
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
|
||||
}
|
||||
|
||||
lb := NewRoundRobinLoadBalancer(targets)
|
||||
|
||||
// Test that only healthy targets are selected
|
||||
for i := 0; i < 4; i++ {
|
||||
target, err := lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if target.Name == "target2" {
|
||||
t.Errorf("Selected unhealthy target: %s", target.Name)
|
||||
}
|
||||
|
||||
// Should alternate between target1 and target3
|
||||
if i%2 == 0 && target.Name != "target1" {
|
||||
t.Errorf("Expected target1, got %s", target.Name)
|
||||
}
|
||||
if i%2 == 1 && target.Name != "target3" {
|
||||
t.Errorf("Expected target3, got %s", target.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinLoadBalancerNoHealthyTargets(t *testing.T) {
|
||||
targets := []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: false},
|
||||
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
|
||||
}
|
||||
|
||||
lb := NewRoundRobinLoadBalancer(targets)
|
||||
|
||||
_, err := lb.NextTarget()
|
||||
if err != ErrNoHealthyTargets {
|
||||
t.Errorf("Expected ErrNoHealthyTargets, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinLoadBalancerUpdateTargets(t *testing.T) {
|
||||
targets := []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
|
||||
}
|
||||
|
||||
lb := NewRoundRobinLoadBalancer(targets)
|
||||
|
||||
// Get first target
|
||||
target, err := lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if target.Name != "target1" {
|
||||
t.Errorf("Expected target1, got %s", target.Name)
|
||||
}
|
||||
|
||||
// Update targets
|
||||
newTargets := []config.TargetConfig{
|
||||
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
|
||||
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
|
||||
}
|
||||
lb.UpdateTargets(newTargets)
|
||||
|
||||
// Test that new targets are selected
|
||||
for i := 0; i < 4; i++ {
|
||||
target, err := lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedTarget := newTargets[i%2]
|
||||
if target.Name != expectedTarget.Name {
|
||||
t.Errorf("Expected target %s, got %s", expectedTarget.Name, target.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomLoadBalancer(t *testing.T) {
|
||||
targets := []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
|
||||
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
|
||||
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
|
||||
}
|
||||
|
||||
lb := NewRandomLoadBalancer(targets)
|
||||
|
||||
// Test that targets are selected randomly
|
||||
// We'll just check that we don't get errors and that all targets are eventually selected
|
||||
selected := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
target, err := lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
selected[target.Name] = true
|
||||
}
|
||||
|
||||
// Check that all targets were selected at least once
|
||||
for _, target := range targets {
|
||||
if !selected[target.Name] {
|
||||
t.Errorf("Target %s was never selected", target.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomLoadBalancerWithUnhealthyTargets(t *testing.T) {
|
||||
targets := []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
|
||||
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
|
||||
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
|
||||
}
|
||||
|
||||
lb := NewRandomLoadBalancer(targets)
|
||||
|
||||
// Test that only healthy targets are selected
|
||||
for i := 0; i < 100; i++ {
|
||||
target, err := lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if target.Name == "target2" {
|
||||
t.Errorf("Selected unhealthy target: %s", target.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastConnectionsLoadBalancer(t *testing.T) {
|
||||
targets := []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
|
||||
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
|
||||
}
|
||||
|
||||
lb := NewLeastConnectionsLoadBalancer(targets)
|
||||
|
||||
// Initially, both targets should have 0 connections
|
||||
// The first target should be selected
|
||||
target, err := lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if target.Name != "target1" {
|
||||
t.Errorf("Expected target1, got %s", target.Name)
|
||||
}
|
||||
|
||||
// Now target1 should have 1 connection, target2 should have 0
|
||||
// So target2 should be selected
|
||||
target, err = lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if target.Name != "target2" {
|
||||
t.Errorf("Expected target2, got %s", target.Name)
|
||||
}
|
||||
|
||||
// Now both targets should have 1 connection
|
||||
// target1 should be selected again
|
||||
target, err = lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if target.Name != "target1" {
|
||||
t.Errorf("Expected target1, got %s", target.Name)
|
||||
}
|
||||
|
||||
// Release a connection from target1
|
||||
lb.(*LeastConnectionsLoadBalancer).ReleaseConnection("target1")
|
||||
|
||||
// Now target1 should have 1 connection, target2 should have 1 connection
|
||||
// But since we released from target1, it should be selected
|
||||
target, err = lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if target.Name != "target1" {
|
||||
t.Errorf("Expected target1, got %s", target.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastConnectionsLoadBalancerUpdateTargets(t *testing.T) {
|
||||
targets := []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
|
||||
}
|
||||
|
||||
lb := NewLeastConnectionsLoadBalancer(targets)
|
||||
|
||||
// Get first target
|
||||
target, err := lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if target.Name != "target1" {
|
||||
t.Errorf("Expected target1, got %s", target.Name)
|
||||
}
|
||||
|
||||
// Update targets
|
||||
newTargets := []config.TargetConfig{
|
||||
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
|
||||
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
|
||||
}
|
||||
lb.UpdateTargets(newTargets)
|
||||
|
||||
// Test that new targets are selected
|
||||
// The first new target should be selected
|
||||
target, err = lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if target.Name != "target2" {
|
||||
t.Errorf("Expected target2, got %s", target.Name)
|
||||
}
|
||||
|
||||
// The second new target should be selected
|
||||
target, err = lb.NextTarget()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if target.Name != "target3" {
|
||||
t.Errorf("Expected target3, got %s", target.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastConnectionsLoadBalancerNoHealthyTargets(t *testing.T) {
|
||||
targets := []config.TargetConfig{
|
||||
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: false},
|
||||
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
|
||||
}
|
||||
|
||||
lb := NewLeastConnectionsLoadBalancer(targets)
|
||||
|
||||
_, err := lb.NextTarget()
|
||||
if err != ErrNoHealthyTargets {
|
||||
t.Errorf("Expected ErrNoHealthyTargets, got %v", err)
|
||||
}
|
||||
}
|
238
internal/proxy/proxy.go
Normal file
238
internal/proxy/proxy.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/iwasforcedtobehere/goRZ/internal/config"
|
||||
"github.com/iwasforcedtobehere/goRZ/internal/logger"
|
||||
)
|
||||
|
||||
// Server represents the reverse proxy server
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
logger *logger.Logger
|
||||
loadBalancer LoadBalancer
|
||||
healthChecker *HealthChecker
|
||||
}
|
||||
|
||||
// LoadBalancer defines the interface for load balancing strategies
|
||||
type LoadBalancer interface {
|
||||
NextTarget() (*config.TargetConfig, error)
|
||||
UpdateTargets(targets []config.TargetConfig)
|
||||
}
|
||||
|
||||
// HealthChecker handles health checking for proxy targets
|
||||
type HealthChecker struct {
|
||||
config *config.Config
|
||||
logger *logger.Logger
|
||||
targets map[string]*config.TargetConfig
|
||||
httpClient *http.Client
|
||||
stopChan chan struct{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewServer creates a new reverse proxy server
|
||||
func NewServer(cfg *config.Config, logger *logger.Logger) (*Server, error) {
|
||||
// Create load balancer based on configuration
|
||||
var lb LoadBalancer
|
||||
switch cfg.Proxy.LoadBalancer {
|
||||
case "roundrobin":
|
||||
lb = NewRoundRobinLoadBalancer(cfg.Proxy.Targets)
|
||||
case "leastconn":
|
||||
lb = NewLeastConnectionsLoadBalancer(cfg.Proxy.Targets)
|
||||
case "random":
|
||||
lb = NewRandomLoadBalancer(cfg.Proxy.Targets)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported load balancer: %s", cfg.Proxy.LoadBalancer)
|
||||
}
|
||||
|
||||
// Create health checker
|
||||
healthChecker := NewHealthChecker(cfg, logger)
|
||||
|
||||
return &Server{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
loadBalancer: lb,
|
||||
healthChecker: healthChecker,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ServeHTTP handles HTTP requests
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get next target based on load balancing strategy
|
||||
target, err := s.loadBalancer.NextTarget()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get target", logger.Error(err))
|
||||
http.Error(w, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if target is healthy
|
||||
if !target.Healthy {
|
||||
s.logger.Warn("Target is unhealthy", logger.String("target", target.Name))
|
||||
http.Error(w, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Create reverse proxy
|
||||
targetURL, err := url.Parse(target.Address)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to parse target URL",
|
||||
logger.String("target", target.Name),
|
||||
logger.String("address", target.Address),
|
||||
logger.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
|
||||
// Set up custom director to modify request
|
||||
originalDirector := proxy.Director
|
||||
proxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Add custom headers
|
||||
req.Header.Set("X-Forwarded-Host", req.Host)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
if req.TLS != nil {
|
||||
req.Header.Set("X-Forwarded-Ssl", "on")
|
||||
}
|
||||
// Add proxy identification
|
||||
req.Header.Set("X-Proxy-Server", "goRZ")
|
||||
}
|
||||
|
||||
// Set up error handler
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
s.logger.Error("Proxy error", logger.Error(err))
|
||||
http.Error(w, "Proxy error", http.StatusBadGateway)
|
||||
}
|
||||
|
||||
// Set up modify response handler
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// Add custom headers to response
|
||||
resp.Header.Set("X-Served-By", "goRZ")
|
||||
resp.Header.Set("X-Target", target.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve the request
|
||||
s.logger.Debug("Proxying request",
|
||||
logger.String("method", r.Method),
|
||||
logger.String("path", r.URL.Path),
|
||||
logger.String("target", target.Name))
|
||||
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Start starts the proxy server
|
||||
func (s *Server) Start() error {
|
||||
// Start health checker
|
||||
if err := s.healthChecker.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start health checker: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the proxy server
|
||||
func (s *Server) Stop() {
|
||||
s.healthChecker.Stop()
|
||||
}
|
||||
|
||||
// NewHealthChecker creates a new health checker
|
||||
func NewHealthChecker(cfg *config.Config, logger *logger.Logger) *HealthChecker {
|
||||
targets := make(map[string]*config.TargetConfig)
|
||||
for i := range cfg.Proxy.Targets {
|
||||
targets[cfg.Proxy.Targets[i].Name] = &cfg.Proxy.Targets[i]
|
||||
}
|
||||
|
||||
return &HealthChecker{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
targets: targets,
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the health checker
|
||||
func (h *HealthChecker) Start() error {
|
||||
// Initial health check
|
||||
h.checkAllTargets()
|
||||
|
||||
// Start periodic health checks
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Duration(h.config.Proxy.HealthCheckInterval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
h.checkAllTargets()
|
||||
case <-h.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the health checker
|
||||
func (h *HealthChecker) Stop() {
|
||||
close(h.stopChan)
|
||||
}
|
||||
|
||||
// checkAllTargets checks the health of all targets
|
||||
func (h *HealthChecker) checkAllTargets() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for name, target := range h.targets {
|
||||
healthy := h.checkTargetHealth(target)
|
||||
if target.Healthy != healthy {
|
||||
h.logger.Info("Target health status changed",
|
||||
logger.String("target", name),
|
||||
logger.Bool("healthy", healthy))
|
||||
target.Healthy = healthy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkTargetHealth checks the health of a single target
|
||||
func (h *HealthChecker) checkTargetHealth(target *config.TargetConfig) bool {
|
||||
targetURL, err := url.Parse(target.Address)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to parse target URL for health check",
|
||||
logger.String("target", target.Name),
|
||||
logger.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
healthURL := *targetURL
|
||||
healthURL.Path = h.config.Proxy.HealthCheckPath
|
||||
|
||||
req, err := http.NewRequest("GET", healthURL.String(), nil)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to create health check request",
|
||||
logger.String("target", target.Name),
|
||||
logger.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
h.logger.Error("Health check request failed",
|
||||
logger.String("target", target.Name),
|
||||
logger.Error(err))
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
Reference in New Issue
Block a user