LFG
Some checks failed
CI/CD Pipeline / Run Tests (push) Has been cancelled
CI/CD Pipeline / Build Application (push) Has been cancelled
CI/CD Pipeline / Build Docker Image (push) Has been cancelled
CI/CD Pipeline / Security Scan (push) Has been cancelled
CI/CD Pipeline / Create Release (push) Has been cancelled

This commit is contained in:
Dev
2025-09-11 18:59:15 +03:00
commit 5440884b85
20 changed files with 3074 additions and 0 deletions

170
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,170 @@
name: CI/CD Pipeline
on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main ]
env:
GO_VERSION: 1.19
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
jobs:
test:
name: Run Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: ${{ env.GO_VERSION }}
cache: true
- name: Download dependencies
run: make deps
- name: Run linter
uses: golangci/golangci-lint-action@v3
with:
version: latest
args: --timeout=5m
- name: Run tests
run: make test
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.out
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
build:
name: Build Application
runs-on: ubuntu-latest
needs: test
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: ${{ env.GO_VERSION }}
cache: true
- name: Download dependencies
run: make deps
- name: Build application
run: make build
- name: Upload build artifacts
uses: actions/upload-artifact@v3
with:
name: gorz-binary
path: gorz
docker-build:
name: Build Docker Image
runs-on: ubuntu-latest
needs: build
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/develop'
permissions:
contents: read
packages: write
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata
id: meta
uses: docker/metadata-action@v4
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=pr
type=sha,prefix={{branch}}-
type=raw,value=latest,enable={{is_default_branch}}
- name: Build and push Docker image
uses: docker/build-push-action@v4
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
security-scan:
name: Security Scan
runs-on: ubuntu-latest
needs: build
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
scan-type: 'fs'
scan-ref: '.'
format: 'sarif'
output: 'trivy-results.sarif'
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@v2
with:
sarif_file: 'trivy-results.sarif'
release:
name: Create Release
runs-on: ubuntu-latest
needs: [test, build, docker-build, security-scan]
if: github.ref == 'refs/heads/main' && startsWith(github.ref, 'refs/tags/v')
permissions:
contents: write
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Download build artifacts
uses: actions/download-artifact@v3
with:
name: gorz-binary
- name: Create Release
uses: softprops/action-gh-release@v1
with:
files: gorz
generate_release_notes: true
draft: false
prerelease: false

48
.gitignore vendored Normal file
View File

@@ -0,0 +1,48 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
gorz
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool
*.out
coverage.html
# Go workspace files
go.work
# Dependency directories
vendor/
# IDE specific files
.idea/
.vscode/
*.swp
*.swo
*~
# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Configuration files with sensitive data
config.yaml
*.key
*.pem
*.crt
# Log files
*.log
# Docker compose override files
docker-compose.override.yml

54
Dockerfile Normal file
View File

@@ -0,0 +1,54 @@
# Build stage
FROM golang:1.19-alpine AS builder
# Install git and certificates
RUN apk add --no-cache git ca-certificates
# Set working directory
WORKDIR /app
# Copy go mod files
COPY go.mod go.sum ./
# Download dependencies
RUN go mod download
# Copy source code
COPY . .
# Build the application
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -ldflags '-extldflags "-static"' -o /gorz ./cmd/server
# Final stage
FROM alpine:latest
# Install ca-certificates for HTTPS
RUN apk --no-cache add ca-certificates
# Create non-root user
RUN addgroup -g 1000 appgroup && adduser -u 1000 -G appgroup -s /bin/sh -D appuser
# Set working directory
WORKDIR /app
# Copy binary from builder
COPY --from=builder /gorz .
# Copy config file if it exists
COPY config.yaml .
# Change ownership
RUN chown -R appuser:appgroup /app
# Switch to non-root user
USER appuser
# Expose ports
EXPOSE 8080 9090
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:8080/health || exit 1
# Run the application
CMD ["./gorz"]

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 iwasforcedtobehere
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

81
Makefile Normal file
View File

@@ -0,0 +1,81 @@
.PHONY: build run test clean deps fmt vet lint docker-build docker-run docker-clean help
# Variables
BINARY_NAME=gorz
VERSION=$(shell git describe --tags --always --dirty --match="v*")
BUILD_TIME=$(shell date -u '+%Y-%m-%d %H:%M:%S')
GIT_COMMIT=$(shell git rev-parse --short HEAD)
LDFLAGS=-ldflags "-X main.Version=${VERSION} -X main.BuildTime='${BUILD_TIME}' -X main.GitCommit=${GIT_COMMIT}"
# Default target
all: deps fmt vet lint build
# Build the binary
build:
go build ${LDFLAGS} -o ${BINARY_NAME} ./cmd/server
# Run the binary
run: build
./${BINARY_NAME}
# Run tests
test:
go test -v -race -coverprofile=coverage.out ./...
go tool cover -html=coverage.out -o coverage.html
# Clean build artifacts
clean:
go clean
rm -f ${BINARY_NAME}
rm -f coverage.out coverage.html
# Download dependencies
deps:
go mod download
go mod tidy
# Format code
fmt:
go fmt ./...
# Vet checks for errors
vet:
go vet ./...
# Run linter
lint:
golangci-lint run
# Build Docker image
docker-build:
docker build -t ${BINARY_NAME}:${VERSION} .
# Run Docker container
docker-run: docker-build
docker run -p 8080:8080 -p 9090:9090 -v $(PWD)/config.yaml:/app/config.yaml ${BINARY_NAME}:${VERSION}
# Clean Docker images
docker-clean:
docker rmi ${BINARY_NAME}:${VERSION} || true
# Generate default config
config:
go run ./cmd/server -config=config.yaml && echo "Default configuration created at config.yaml"
# Show help
help:
@echo "Available targets:"
@echo " all - Run deps, fmt, vet, lint, and build"
@echo " build - Build the binary"
@echo " run - Build and run the binary"
@echo " test - Run tests and generate coverage report"
@echo " clean - Clean build artifacts"
@echo " deps - Download dependencies"
@echo " fmt - Format code"
@echo " vet - Run go vet"
@echo " lint - Run linter"
@echo " docker-build - Build Docker image"
@echo " docker-run - Build and run Docker container"
@echo " docker-clean - Clean Docker images"
@echo " config - Generate default configuration"
@echo " help - Show this help message"

179
README.md Normal file
View File

@@ -0,0 +1,179 @@
# goRZ - Fast Reverse Proxy Server
[![Go Report Card](https://goreportcard.com/badge/github.com/iwasforcedtobehere/goRZ)](https://goreportcard.com/report/github.com/iwasforcedtobehere/goRZ)
[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
[![Release](https://img.shields.io/github/v/release/iwasforcedtobehere/goRZ)](https://github.com/iwasforcedtobehere/goRZ/releases)
goRZ is a blazingly fast reverse proxy server designed for local development with NAT/firewall traversal capabilities. It's the perfect tool for developers who enjoy watching their requests bounce around like a confused pinball before reaching their destination.
## Features
-**High Performance**: Built with Go, because who has time to wait for slow proxies?
- 🔄 **Multiple Load Balancing Strategies**: Round Robin, Least Connections, and Random (for when you're feeling adventurous)
- 🩺 **Health Checking**: Monitors your targets like an overprotective parent
- 🔍 **Monitoring**: Built-in metrics endpoint because numbers make us feel smart
- 🌐 **NAT/Firewall Traversal**: Punches through NATs and firewalls like they're made of paper
- 🐳 **Docker Support**: Containerized for your convenience (and our sanity)
- ⚙️ **Flexible Configuration**: YAML configuration that's almost as flexible as a yoga instructor
## Installation
### From Source
```bash
git clone https://github.com/iwasforcedtobehere/goRZ.git
cd goRZ
make build
```
### Using Docker
```bash
docker pull ghcr.io/iwasforcedtobehere/gorz:latest
```
## Quick Start
1. Create a configuration file:
```bash
make config
```
2. Edit the generated `config.yaml` to match your needs (or just use the defaults, we won't judge)
3. Run the server:
```bash
make run
```
4. Marvel at your requests being proxied with the speed of a caffeinated cheetah
## Configuration
goRZ uses a YAML configuration file that's so intuitive, you'll think you've been using it for years (even if it's your first time):
```yaml
server:
port: 8080
read_timeout: 30
write_timeout: 30
idle_timeout: 60
proxy:
targets:
- name: "my-awesome-app"
address: "http://localhost:3000"
protocol: "http"
weight: 1
load_balancer: "roundrobin"
health_check_path: "/health"
health_check_interval: 30
nat:
enabled: false
stun_server: "stun:stun.l.google.com:19302"
logging:
level: "info"
format: "json"
output: "stdout"
monitor:
enabled: true
port: 9090
path: "/metrics"
auth: false
```
## Load Balancing
goRZ supports multiple load balancing strategies because we believe in giving you choices (even if you'll probably just stick with round-robin):
### Round Robin
The classic "take turns" approach. Fair, predictable, and about as exciting as a beige wall.
### Least Connections
For when you want to distribute load based on which server is least busy. It's like being a traffic controller, but with less stress and more coffee.
### Random
For when you're feeling lucky or just enjoy chaos. It's surprisingly effective, which says something about the universe.
## NAT/Firewall Traversal
goRZ can help you traverse NATs and firewalls using STUN and TURN protocols. It's like having a secret tunnel that bypasses all those annoying network restrictions:
```yaml
nat:
enabled: true
stun_server: "stun:stun.l.google.com:19302"
turn_server: "turn:your-turn-server:3478"
turn_username: "your-username"
turn_password: "your-password"
```
## Monitoring
goRZ includes a built-in monitoring endpoint that provides metrics about the proxy's performance. It's like having a fitness tracker for your proxy:
```bash
curl http://localhost:9090/metrics
```
You'll get a beautiful JSON response with all sorts of numbers and statistics that you can pretend to understand while nodding thoughtfully.
## Development
### Prerequisites
- Go 1.19 or later
- Make (for the Makefile targets)
- Docker (optional, for containerized development)
### Building
```bash
make build
```
### Testing
```bash
make test
```
### Linting
```bash
make lint
```
## Contributing
Contributions are welcome! If you'd like to contribute, please:
1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request
We'll review your contribution with the enthusiasm of someone who just discovered coffee.
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. It's so permissive, you could probably use it to power a small country if you wanted to.
## Acknowledgments
- The Go team for creating a language that makes us feel productive
- The authors of all the libraries we're using (you're the real MVPs)
- Coffee, for fueling the development of this project
- Our users, for being brave enough to trust a proxy with a name that sounds like a monster from a 1950s B-movie
## Disclaimer
This software is provided "as is", without warranty of any kind, express or implied. In no event shall the authors be liable for any claim, damages or other liability, whether in an action of contract, tort or otherwise, arising from, out of or in connection with the software or the use or other dealings in the software.
Basically, if it breaks, you get to keep both pieces. We're programmers, not lawyers.

89
cmd/server/main.go Normal file
View File

@@ -0,0 +1,89 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
"github.com/iwasforcedtobehere/goRZ/internal/logger"
"github.com/iwasforcedtobehere/goRZ/internal/proxy"
)
var (
configPath = flag.String("config", "config.yaml", "Path to configuration file")
version = flag.Bool("version", false, "Show version information")
)
// Version information set at build time
var (
Version = "dev"
BuildTime = "unknown"
GitCommit = "unknown"
)
func main() {
flag.Parse()
if *version {
fmt.Printf("goRZ %s\n", Version)
fmt.Printf("Build time: %s\n", BuildTime)
fmt.Printf("Git commit: %s\n", GitCommit)
os.Exit(0)
}
// Initialize logger
appLogger := logger.NewLogger()
appLogger.Info("Starting goRZ reverse proxy server")
// Load configuration
cfg, err := config.Load(*configPath)
if err != nil {
appLogger.Fatal("Failed to load configuration", logger.Error(err))
}
// Create proxy server
proxyServer, err := proxy.NewServer(cfg, appLogger)
if err != nil {
appLogger.Fatal("Failed to create proxy server", logger.Error(err))
}
// Create HTTP server
server := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
Handler: proxyServer,
ReadTimeout: time.Duration(cfg.Server.ReadTimeout) * time.Second,
WriteTimeout: time.Duration(cfg.Server.WriteTimeout) * time.Second,
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
}
// Start server in a goroutine
go func() {
appLogger.Info("Server starting", logger.String("address", server.Addr))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
appLogger.Fatal("Server failed", logger.Error(err))
}
}()
// Wait for interrupt signal
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
appLogger.Info("Shutting down server...")
// Graceful shutdown with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
appLogger.Error("Server forced to shutdown", logger.Error(err))
}
appLogger.Info("Server exited")
}

89
docker-compose.yml Normal file
View File

@@ -0,0 +1,89 @@
version: '3.8'
services:
gorz:
build: .
ports:
- "8080:8080"
- "9090:9090"
volumes:
- ./examples/config.yaml:/app/config.yaml
environment:
- GORZ_CONFIG=/app/config.yaml
restart: unless-stopped
healthcheck:
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:8080/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 40s
networks:
- gorz-network
# Example web application to proxy to
webapp:
image: nginxdemos/hello:plain-text
ports:
- "3000:80"
networks:
- gorz-network
healthcheck:
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:80"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
# Another example application
api:
image: nginxdemos/hello:plain-text
ports:
- "3001:80"
networks:
- gorz-network
healthcheck:
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:80"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
# Prometheus for metrics collection
prometheus:
image: prom/prometheus:latest
ports:
- "9091:9090"
volumes:
- ./examples/prometheus.yml:/etc/prometheus/prometheus.yml
- prometheus-data:/prometheus
command:
- '--config.file=/etc/prometheus/prometheus.yml'
- '--storage.tsdb.path=/prometheus'
- '--web.console.libraries=/etc/prometheus/console_libraries'
- '--web.console.templates=/etc/prometheus/consoles'
- '--storage.tsdb.retention.time=200h'
- '--web.enable-lifecycle'
networks:
- gorz-network
restart: unless-stopped
# Grafana for metrics visualization
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
volumes:
- grafana-data:/var/lib/grafana
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin
networks:
- gorz-network
restart: unless-stopped
networks:
gorz-network:
driver: bridge
volumes:
prometheus-data:
grafana-data:

30
examples/prometheus.yml Normal file
View File

@@ -0,0 +1,30 @@
global:
scrape_interval: 15s
evaluation_interval: 15s
rule_files:
# - "first_rules.yml"
# - "second_rules.yml"
scrape_configs:
- job_name: 'prometheus'
static_configs:
- targets: ['localhost:9090']
- job_name: 'gorz'
static_configs:
- targets: ['gorz:9090']
metrics_path: '/metrics'
scrape_interval: 5s
- job_name: 'webapp'
static_configs:
- targets: ['webapp:80']
metrics_path: '/metrics'
scrape_interval: 5s
- job_name: 'api'
static_configs:
- targets: ['api:80']
metrics_path: '/metrics'
scrape_interval: 5s

21
go.mod Normal file
View File

@@ -0,0 +1,21 @@
module github.com/iwasforcedtobehere/goRZ
go 1.19
require (
github.com/pion/stun v1.0.0
github.com/pion/turn/v2 v2.0.10
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns v0.0.5 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/transport v0.13.1 // indirect
github.com/pion/udp v0.1.1 // indirect
golang.org/x/crypto v0.0.0-20220321153916-1c453e9389ed // indirect
golang.org/x/net v0.0.0-20220722155237-a158d28d115b // indirect
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f // indirect
golang.org/x/text v0.3.7 // indirect
)

206
internal/config/config.go Normal file
View File

@@ -0,0 +1,206 @@
package config
import (
"fmt"
"os"
"path/filepath"
"gopkg.in/yaml.v3"
)
// Config represents the application configuration
type Config struct {
Server ServerConfig `yaml:"server"`
Proxy ProxyConfig `yaml:"proxy"`
NAT NATConfig `yaml:"nat"`
Logging LoggingConfig `yaml:"logging"`
Monitor MonitorConfig `yaml:"monitor"`
}
// ServerConfig represents server configuration
type ServerConfig struct {
Port int `yaml:"port"`
ReadTimeout int `yaml:"read_timeout"`
WriteTimeout int `yaml:"write_timeout"`
IdleTimeout int `yaml:"idle_timeout"`
TLSCertFile string `yaml:"tls_cert_file,omitempty"`
TLSKeyFile string `yaml:"tls_key_file,omitempty"`
}
// ProxyConfig represents reverse proxy configuration
type ProxyConfig struct {
Targets []TargetConfig `yaml:"targets"`
LoadBalancer string `yaml:"load_balancer"` // "roundrobin", "leastconn", "random"
HealthCheckPath string `yaml:"health_check_path"`
HealthCheckInterval int `yaml:"health_check_interval"`
}
// TargetConfig represents a proxy target
type TargetConfig struct {
Name string `yaml:"name"`
Address string `yaml:"address"`
Protocol string `yaml:"protocol"` // "http", "https"
Weight int `yaml:"weight"` // for weighted load balancing
Healthy bool `yaml:"-"` // health status
}
// NATConfig represents NAT traversal configuration
type NATConfig struct {
Enabled bool `yaml:"enabled"`
STUNServer string `yaml:"stun_server,omitempty"`
TURNServer string `yaml:"turn_server,omitempty"`
TURNUsername string `yaml:"turn_username,omitempty"`
TURNPassword string `yaml:"turn_password,omitempty"`
}
// LoggingConfig represents logging configuration
type LoggingConfig struct {
Level string `yaml:"level"` // "debug", "info", "warn", "error"
Format string `yaml:"format"` // "json", "text"
Output string `yaml:"output"` // "stdout", "file"
File string `yaml:"file,omitempty"`
}
// MonitorConfig represents monitoring configuration
type MonitorConfig struct {
Enabled bool `yaml:"enabled"`
Port int `yaml:"port"`
Path string `yaml:"path"`
Auth bool `yaml:"auth"`
Username string `yaml:"username,omitempty"`
Password string `yaml:"password,omitempty"`
}
// Load loads configuration from file
func Load(path string) (*Config, error) {
// Check if file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
return nil, fmt.Errorf("configuration file not found: %s", path)
}
// Read file
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read configuration file: %w", err)
}
// Parse YAML
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse configuration: %w", err)
}
// Set defaults
setDefaults(&config)
return &config, nil
}
// setDefaults sets default values for configuration
func setDefaults(c *Config) {
// Server defaults
if c.Server.Port == 0 {
c.Server.Port = 8080
}
if c.Server.ReadTimeout == 0 {
c.Server.ReadTimeout = 30
}
if c.Server.WriteTimeout == 0 {
c.Server.WriteTimeout = 30
}
if c.Server.IdleTimeout == 0 {
c.Server.IdleTimeout = 60
}
// Proxy defaults
if c.Proxy.LoadBalancer == "" {
c.Proxy.LoadBalancer = "roundrobin"
}
if c.Proxy.HealthCheckPath == "" {
c.Proxy.HealthCheckPath = "/health"
}
if c.Proxy.HealthCheckInterval == 0 {
c.Proxy.HealthCheckInterval = 30
}
// NAT defaults
if c.NAT.Enabled && c.NAT.STUNServer == "" {
c.NAT.STUNServer = "stun:stun.l.google.com:19302"
}
// Logging defaults
if c.Logging.Level == "" {
c.Logging.Level = "info"
}
if c.Logging.Format == "" {
c.Logging.Format = "json"
}
if c.Logging.Output == "" {
c.Logging.Output = "stdout"
}
// Monitor defaults
if c.Monitor.Enabled && c.Monitor.Port == 0 {
c.Monitor.Port = 9090
}
if c.Monitor.Enabled && c.Monitor.Path == "" {
c.Monitor.Path = "/metrics"
}
}
// CreateDefaultConfig creates a default configuration file
func CreateDefaultConfig(path string) error {
config := Config{
Server: ServerConfig{
Port: 8080,
ReadTimeout: 30,
WriteTimeout: 30,
IdleTimeout: 60,
},
Proxy: ProxyConfig{
LoadBalancer: "roundrobin",
HealthCheckPath: "/health",
HealthCheckInterval: 30,
Targets: []TargetConfig{
{
Name: "example",
Address: "http://localhost:3000",
Protocol: "http",
Weight: 1,
},
},
},
NAT: NATConfig{
Enabled: false,
STUNServer: "stun:stun.l.google.com:19302",
},
Logging: LoggingConfig{
Level: "info",
Format: "json",
Output: "stdout",
},
Monitor: MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: false,
},
}
data, err := yaml.Marshal(config)
if err != nil {
return fmt.Errorf("failed to marshal default configuration: %w", err)
}
// Create directory if it doesn't exist
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return fmt.Errorf("failed to create configuration directory: %w", err)
}
// Write configuration file
if err := os.WriteFile(path, data, 0644); err != nil {
return fmt.Errorf("failed to write configuration file: %w", err)
}
return nil
}

View File

@@ -0,0 +1,243 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestLoad(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "test-config.yaml")
// Create a test configuration file
testConfig := `server:
port: 9090
read_timeout: 60
write_timeout: 60
idle_timeout: 120
proxy:
targets:
- name: "test-target"
address: "http://localhost:8080"
protocol: "http"
weight: 1
load_balancer: "leastconn"
health_check_path: "/healthz"
health_check_interval: 60
nat:
enabled: true
stun_server: "stun:stun.example.com:3478"
turn_server: "turn:turn.example.com:3478"
turn_username: "testuser"
turn_password: "testpass"
logging:
level: "debug"
format: "text"
output: "file"
file: "/var/log/gorz.log"
monitor:
enabled: true
port: 8081
path: "/stats"
auth: true
username: "admin"
password: "secret"
`
err := os.WriteFile(configPath, []byte(testConfig), 0644)
if err != nil {
t.Fatalf("Failed to write test config file: %v", err)
}
// Test loading the configuration
cfg, err := Load(configPath)
if err != nil {
t.Fatalf("Failed to load configuration: %v", err)
}
// Verify server configuration
if cfg.Server.Port != 9090 {
t.Errorf("Expected port 9090, got %d", cfg.Server.Port)
}
if cfg.Server.ReadTimeout != 60 {
t.Errorf("Expected read timeout 60, got %d", cfg.Server.ReadTimeout)
}
// Verify proxy configuration
if cfg.Proxy.LoadBalancer != "leastconn" {
t.Errorf("Expected load balancer 'leastconn', got %s", cfg.Proxy.LoadBalancer)
}
if len(cfg.Proxy.Targets) != 1 {
t.Errorf("Expected 1 target, got %d", len(cfg.Proxy.Targets))
}
if cfg.Proxy.Targets[0].Name != "test-target" {
t.Errorf("Expected target name 'test-target', got %s", cfg.Proxy.Targets[0].Name)
}
// Verify NAT configuration
if !cfg.NAT.Enabled {
t.Error("Expected NAT enabled to be true")
}
if cfg.NAT.STUNServer != "stun:stun.example.com:3478" {
t.Errorf("Expected STUN server 'stun:stun.example.com:3478', got %s", cfg.NAT.STUNServer)
}
// Verify logging configuration
if cfg.Logging.Level != "debug" {
t.Errorf("Expected log level 'debug', got %s", cfg.Logging.Level)
}
if cfg.Logging.Output != "file" {
t.Errorf("Expected log output 'file', got %s", cfg.Logging.Output)
}
// Verify monitor configuration
if !cfg.Monitor.Enabled {
t.Error("Expected monitor enabled to be true")
}
if cfg.Monitor.Port != 8081 {
t.Errorf("Expected monitor port 8081, got %d", cfg.Monitor.Port)
}
if !cfg.Monitor.Auth {
t.Error("Expected monitor auth to be true")
}
}
func TestLoadNonexistentFile(t *testing.T) {
_, err := Load("/nonexistent/path/config.yaml")
if err == nil {
t.Error("Expected error for nonexistent file, got nil")
}
}
func TestLoadInvalidYAML(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "invalid-config.yaml")
// Create an invalid YAML file
invalidYAML := `server:
port: 8080
read_timeout: "not a number"
`
err := os.WriteFile(configPath, []byte(invalidYAML), 0644)
if err != nil {
t.Fatalf("Failed to write invalid config file: %v", err)
}
_, err = Load(configPath)
if err == nil {
t.Error("Expected error for invalid YAML, got nil")
}
}
func TestSetDefaults(t *testing.T) {
cfg := &Config{}
// Apply defaults
setDefaults(cfg)
// Verify default server values
if cfg.Server.Port != 8080 {
t.Errorf("Expected default port 8080, got %d", cfg.Server.Port)
}
if cfg.Server.ReadTimeout != 30 {
t.Errorf("Expected default read timeout 30, got %d", cfg.Server.ReadTimeout)
}
// Verify default proxy values
if cfg.Proxy.LoadBalancer != "roundrobin" {
t.Errorf("Expected default load balancer 'roundrobin', got %s", cfg.Proxy.LoadBalancer)
}
if cfg.Proxy.HealthCheckPath != "/health" {
t.Errorf("Expected default health check path '/health', got %s", cfg.Proxy.HealthCheckPath)
}
// Verify default NAT values
if cfg.NAT.Enabled {
t.Error("Expected default NAT enabled to be false")
}
if cfg.NAT.STUNServer != "stun:stun.l.google.com:19302" {
t.Errorf("Expected default STUN server 'stun:stun.l.google.com:19302', got %s", cfg.NAT.STUNServer)
}
// Verify default logging values
if cfg.Logging.Level != "info" {
t.Errorf("Expected default log level 'info', got %s", cfg.Logging.Level)
}
if cfg.Logging.Format != "json" {
t.Errorf("Expected default log format 'json', got %s", cfg.Logging.Format)
}
// Verify default monitor values
if !cfg.Monitor.Enabled {
t.Error("Expected default monitor enabled to be true")
}
if cfg.Monitor.Port != 9090 {
t.Errorf("Expected default monitor port 9090, got %d", cfg.Monitor.Port)
}
}
func TestCreateDefaultConfig(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "default-config.yaml")
// Create default configuration
err := CreateDefaultConfig(configPath)
if err != nil {
t.Fatalf("Failed to create default configuration: %v", err)
}
// Verify the file was created
if _, err := os.Stat(configPath); os.IsNotExist(err) {
t.Error("Expected config file to be created")
}
// Load and verify the configuration
cfg, err := Load(configPath)
if err != nil {
t.Fatalf("Failed to load created configuration: %v", err)
}
// Verify some default values
if cfg.Server.Port != 8080 {
t.Errorf("Expected default port 8080, got %d", cfg.Server.Port)
}
if cfg.Proxy.LoadBalancer != "roundrobin" {
t.Errorf("Expected default load balancer 'roundrobin', got %s", cfg.Proxy.LoadBalancer)
}
if cfg.Logging.Level != "info" {
t.Errorf("Expected default log level 'info', got %s", cfg.Logging.Level)
}
}
func TestCreateDefaultConfigDirectoryCreation(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
nestedDir := filepath.Join(tempDir, "nested", "directory")
configPath := filepath.Join(nestedDir, "config.yaml")
// Create default configuration in a nested directory
err := CreateDefaultConfig(configPath)
if err != nil {
t.Fatalf("Failed to create default configuration: %v", err)
}
// Verify the directory was created
if _, err := os.Stat(nestedDir); os.IsNotExist(err) {
t.Error("Expected nested directory to be created")
}
// Verify the file was created
if _, err := os.Stat(configPath); os.IsNotExist(err) {
t.Error("Expected config file to be created")
}
}

124
internal/logger/logger.go Normal file
View File

@@ -0,0 +1,124 @@
package logger
import (
"log"
"os"
)
// Logger represents the application logger
type Logger struct {
debugLogger *log.Logger
infoLogger *log.Logger
warnLogger *log.Logger
errorLogger *log.Logger
}
// Field represents a log field
type Field struct {
Key string
Value interface{}
}
// Option represents a logger option
type Option func(*Logger)
// NewLogger creates a new logger with default settings
func NewLogger(opts ...Option) *Logger {
logger := &Logger{
debugLogger: log.New(os.Stdout, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile),
infoLogger: log.New(os.Stdout, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile),
warnLogger: log.New(os.Stdout, "WARN: ", log.Ldate|log.Ltime|log.Lshortfile),
errorLogger: log.New(os.Stdout, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile),
}
// Apply options
for _, opt := range opts {
opt(logger)
}
return logger
}
// Debug logs a debug message
func (l *Logger) Debug(msg string, fields ...Field) {
l.debugLogger.Printf(formatMessage(msg, fields...))
}
// Info logs an info message
func (l *Logger) Info(msg string, fields ...Field) {
l.infoLogger.Printf(formatMessage(msg, fields...))
}
// Warn logs a warning message
func (l *Logger) Warn(msg string, fields ...Field) {
l.warnLogger.Printf(formatMessage(msg, fields...))
}
// Error logs an error message
func (l *Logger) Error(msg string, fields ...Field) {
l.errorLogger.Printf(formatMessage(msg, fields...))
}
// Fatal logs a fatal message and exits
func (l *Logger) Fatal(msg string, fields ...Field) {
l.errorLogger.Printf(formatMessage(msg, fields...))
os.Exit(1)
}
// String creates a string field
func String(key, value string) Field {
return Field{Key: key, Value: value}
}
// Int creates an int field
func Int(key string, value int) Field {
return Field{Key: key, Value: value}
}
// Bool creates a bool field
func Bool(key string, value bool) Field {
return Field{Key: key, Value: value}
}
// Error creates an error field
func Error(err error) Field {
return Field{Key: "error", Value: err}
}
// formatMessage formats a log message with fields
func formatMessage(msg string, fields ...Field) string {
if len(fields) == 0 {
return msg
}
result := msg + " ["
for i, field := range fields {
if i > 0 {
result += ", "
}
result += field.Key + "="
result += toString(field.Value)
}
result += "]"
return result
}
// toString converts a value to string
func toString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case int:
return string(v)
case bool:
if v {
return "true"
}
return "false"
case error:
return v.Error()
default:
return "unknown"
}
}

View File

@@ -0,0 +1,185 @@
package logger
import (
"bytes"
"log"
"os"
"strings"
"testing"
)
func TestNewLogger(t *testing.T) {
logger := NewLogger()
if logger == nil {
t.Error("Expected logger to be created, got nil")
}
}
func TestLoggerDebug(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
logger.Debug("Debug message", String("key", "value"))
output := buf.String()
if !strings.Contains(output, "DEBUG: Debug message") {
t.Errorf("Expected log output to contain 'DEBUG: Debug message', got %s", output)
}
if !strings.Contains(output, "key=value") {
t.Errorf("Expected log output to contain 'key=value', got %s", output)
}
}
func TestLoggerInfo(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
logger.Info("Info message", Int("number", 42))
output := buf.String()
if !strings.Contains(output, "INFO: Info message") {
t.Errorf("Expected log output to contain 'INFO: Info message', got %s", output)
}
if !strings.Contains(output, "number=42") {
t.Errorf("Expected log output to contain 'number=42', got %s", output)
}
}
func TestLoggerWarn(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
logger.Warn("Warning message", Bool("flag", true))
output := buf.String()
if !strings.Contains(output, "WARN: Warning message") {
t.Errorf("Expected log output to contain 'WARN: Warning message', got %s", output)
}
if !strings.Contains(output, "flag=true") {
t.Errorf("Expected log output to contain 'flag=true', got %s", output)
}
}
func TestLoggerError(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
err := os.ErrNotExist
logger.Error("Error message", Error(err))
output := buf.String()
if !strings.Contains(output, "ERROR: Error message") {
t.Errorf("Expected log output to contain 'ERROR: Error message', got %s", output)
}
if !strings.Contains(output, "error=file does not exist") {
t.Errorf("Expected log output to contain 'error=file does not exist', got %s", output)
}
}
func TestLoggerFatal(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
// Mock os.Exit to prevent the test from exiting
exitCalled := false
exitFunc := func(code int) {
exitCalled = true
}
osExit = exitFunc
defer func() {
osExit = realOsExit
}()
logger := NewLogger()
logger.Fatal("Fatal message", String("reason", "testing"))
output := buf.String()
if !strings.Contains(output, "ERROR: Fatal message") {
t.Errorf("Expected log output to contain 'ERROR: Fatal message', got %s", output)
}
if !strings.Contains(output, "reason=testing") {
t.Errorf("Expected log output to contain 'reason=testing', got %s", output)
}
if !exitCalled {
t.Error("Expected os.Exit to be called")
}
}
func TestLoggerMultipleFields(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
logger.Info("Multiple fields",
String("string", "value"),
Int("int", 123),
Bool("bool", false))
output := buf.String()
if !strings.Contains(output, "INFO: Multiple fields") {
t.Errorf("Expected log output to contain 'INFO: Multiple fields', got %s", output)
}
if !strings.Contains(output, "string=value") {
t.Errorf("Expected log output to contain 'string=value', got %s", output)
}
if !strings.Contains(output, "int=123") {
t.Errorf("Expected log output to contain 'int=123', got %s", output)
}
if !strings.Contains(output, "bool=false") {
t.Errorf("Expected log output to contain 'bool=false', got %s", output)
}
}
func TestLoggerNoFields(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
logger := NewLogger()
logger.Info("No fields")
output := buf.String()
if !strings.Contains(output, "INFO: No fields") {
t.Errorf("Expected log output to contain 'INFO: No fields', got %s", output)
}
if strings.Contains(output, "[") {
t.Error("Expected log output to not contain field brackets when no fields are provided")
}
}
// Mock os.Exit for testing
var (
osExit = func(code int) { os.Exit(code) }
realOsExit = osExit
)

View File

@@ -0,0 +1,266 @@
package monitoring
import (
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
"github.com/iwasforcedtobehere/goRZ/internal/logger"
)
// Metrics represents the application metrics
type Metrics struct {
RequestsTotal int64 `json:"requests_total"`
RequestsActive int64 `json:"requests_active"`
ResponsesByStatus map[string]int64 `json:"responses_by_status"`
TargetMetrics map[string]*TargetMetric `json:"target_metrics"`
StartTime time.Time `json:"start_time"`
LastUpdated time.Time `json:"last_updated"`
mu sync.RWMutex `json:"-"`
}
// TargetMetric represents metrics for a specific target
type TargetMetric struct {
RequestsTotal int64 `json:"requests_total"`
ResponsesByStatus map[string]int64 `json:"responses_by_status"`
ResponseTimes []time.Duration `json:"response_times"`
AvgResponseTime time.Duration `json:"avg_response_time"`
Healthy bool `json:"healthy"`
LastChecked time.Time `json:"last_checked"`
}
// Monitor represents the monitoring service
type Monitor struct {
config *config.Config
logger *logger.Logger
metrics *Metrics
server *http.Server
authHandler http.Handler
}
// NewMonitor creates a new monitoring service
func NewMonitor(cfg *config.Config, logger *logger.Logger) *Monitor {
metrics := &Metrics{
ResponsesByStatus: make(map[string]int64),
TargetMetrics: make(map[string]*TargetMetric),
StartTime: time.Now(),
LastUpdated: time.Now(),
}
// Initialize target metrics
for _, target := range cfg.Proxy.Targets {
metrics.TargetMetrics[target.Name] = &TargetMetric{
ResponsesByStatus: make(map[string]int64),
ResponseTimes: make([]time.Duration, 0),
Healthy: target.Healthy,
LastChecked: time.Now(),
}
}
monitor := &Monitor{
config: cfg,
logger: logger,
metrics: metrics,
}
// Set up authentication if enabled
if cfg.Monitor.Auth {
monitor.authHandler = monitor.basicAuthHandler(monitor.metricsHandler)
} else {
monitor.authHandler = monitor.metricsHandler
}
return monitor
}
// Start starts the monitoring service
func (m *Monitor) Start() error {
if !m.config.Monitor.Enabled {
m.logger.Info("Monitoring is disabled")
return nil
}
// Create HTTP server
m.server = &http.Server{
Addr: fmt.Sprintf(":%d", m.config.Monitor.Port),
Handler: m.authHandler,
}
// Start server in a goroutine
go func() {
m.logger.Info("Monitoring server starting", logger.String("address", m.server.Addr))
if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
m.logger.Error("Monitoring server failed", logger.Error(err))
}
}()
return nil
}
// Stop stops the monitoring service
func (m *Monitor) Stop() {
if m.server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
m.server.Shutdown(ctx)
m.logger.Info("Monitoring server stopped")
}
}
// metricsHandler handles HTTP requests for metrics
func (m *Monitor) metricsHandler(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != m.config.Monitor.Path {
http.NotFound(w, r)
return
}
m.metrics.mu.RLock()
defer m.metrics.mu.RUnlock()
// Update last updated time
m.metrics.LastUpdated = time.Now()
// Set content type
w.Header().Set("Content-Type", "application/json")
// Encode metrics as JSON
if err := json.NewEncoder(w).Encode(m.metrics); err != nil {
m.logger.Error("Failed to encode metrics", logger.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
}
// basicAuthHandler wraps a handler with basic authentication
func (m *Monitor) basicAuthHandler(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok || username != m.config.Monitor.Username || password != m.config.Monitor.Password {
w.Header().Set("WWW-Authenticate", `Basic realm="goRZ Monitor"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
next(w, r)
}
}
// IncrementRequest increments the total request count
func (m *Monitor) IncrementRequest() {
m.metrics.mu.Lock()
defer m.metrics.mu.Unlock()
m.metrics.RequestsTotal++
m.metrics.LastUpdated = time.Now()
}
// IncrementActiveRequest increments the active request count
func (m *Monitor) IncrementActiveRequest() {
m.metrics.mu.Lock()
defer m.metrics.mu.Unlock()
m.metrics.RequestsActive++
m.metrics.LastUpdated = time.Now()
}
// DecrementActiveRequest decrements the active request count
func (m *Monitor) DecrementActiveRequest() {
m.metrics.mu.Lock()
defer m.metrics.mu.Unlock()
m.metrics.RequestsActive--
if m.metrics.RequestsActive < 0 {
m.metrics.RequestsActive = 0
}
m.metrics.LastUpdated = time.Now()
}
// RecordResponse records a response with the given status code
func (m *Monitor) RecordResponse(statusCode int, targetName string, responseTime time.Duration) {
m.metrics.mu.Lock()
defer m.metrics.mu.Unlock()
status := fmt.Sprintf("%d", statusCode)
m.metrics.ResponsesByStatus[status]++
// Update target metrics
if target, exists := m.metrics.TargetMetrics[targetName]; exists {
target.RequestsTotal++
target.ResponsesByStatus[status]++
// Keep only the last 100 response times for average calculation
if len(target.ResponseTimes) >= 100 {
target.ResponseTimes = target.ResponseTimes[1:]
}
target.ResponseTimes = append(target.ResponseTimes, responseTime)
// Calculate average response time
var total time.Duration
for _, rt := range target.ResponseTimes {
total += rt
}
target.AvgResponseTime = total / time.Duration(len(target.ResponseTimes))
}
m.metrics.LastUpdated = time.Now()
}
// UpdateTargetHealth updates the health status of a target
func (m *Monitor) UpdateTargetHealth(targetName string, healthy bool) {
m.metrics.mu.Lock()
defer m.metrics.mu.Unlock()
if target, exists := m.metrics.TargetMetrics[targetName]; exists {
target.Healthy = healthy
target.LastChecked = time.Now()
}
m.metrics.LastUpdated = time.Now()
}
// GetMetrics returns a copy of the current metrics
func (m *Monitor) GetMetrics() Metrics {
m.metrics.mu.RLock()
defer m.metrics.mu.RUnlock()
// Create a deep copy of the metrics
metrics := Metrics{
RequestsTotal: m.metrics.RequestsTotal,
RequestsActive: m.metrics.RequestsActive,
ResponsesByStatus: make(map[string]int64),
TargetMetrics: make(map[string]*TargetMetric),
StartTime: m.metrics.StartTime,
LastUpdated: m.metrics.LastUpdated,
}
// Copy response status counts
for k, v := range m.metrics.ResponsesByStatus {
metrics.ResponsesByStatus[k] = v
}
// Copy target metrics
for k, v := range m.metrics.TargetMetrics {
targetMetric := &TargetMetric{
RequestsTotal: v.RequestsTotal,
ResponsesByStatus: make(map[string]int64),
ResponseTimes: make([]time.Duration, len(v.ResponseTimes)),
AvgResponseTime: v.AvgResponseTime,
Healthy: v.Healthy,
LastChecked: v.LastChecked,
}
// Copy response status counts for target
for rk, rv := range v.ResponsesByStatus {
targetMetric.ResponsesByStatus[rk] = rv
}
// Copy response times
copy(targetMetric.ResponseTimes, v.ResponseTimes)
metrics.TargetMetrics[k] = targetMetric
}
return metrics
}

View File

@@ -0,0 +1,312 @@
package monitoring
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
"github.com/iwasforcedtobehere/goRZ/internal/logger"
)
func TestMonitorStartStop(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: false,
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Start the monitor
err := monitor.Start()
if err != nil {
t.Fatalf("Failed to start monitor: %v", err)
}
// Stop the monitor
monitor.Stop()
// Test that we can start and stop again
err = monitor.Start()
if err != nil {
t.Fatalf("Failed to start monitor: %v", err)
}
monitor.Stop()
}
func TestMonitorDisabled(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: false,
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Start the monitor
err := monitor.Start()
if err != nil {
t.Fatalf("Failed to start monitor: %v", err)
}
// Stop the monitor
monitor.Stop()
}
func TestMonitorMetricsHandler(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: false,
},
Proxy: config.ProxyConfig{
Targets: []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
},
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Create a test HTTP server
server := httptest.NewServer(monitor.authHandler)
defer server.Close()
// Test metrics endpoint
resp, err := http.Get(server.URL + cfg.Monitor.Path)
if err != nil {
t.Fatalf("Failed to get metrics: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
// Decode and check metrics
var metrics Metrics
if err := json.NewDecoder(resp.Body).Decode(&metrics); err != nil {
t.Fatalf("Failed to decode metrics: %v", err)
}
if metrics.RequestsTotal != 0 {
t.Errorf("Expected requests total 0, got %d", metrics.RequestsTotal)
}
if metrics.RequestsActive != 0 {
t.Errorf("Expected requests active 0, got %d", metrics.RequestsActive)
}
if len(metrics.TargetMetrics) != 2 {
t.Errorf("Expected 2 target metrics, got %d", len(metrics.TargetMetrics))
}
if !metrics.TargetMetrics["target1"].Healthy {
t.Error("Expected target1 to be healthy")
}
if metrics.TargetMetrics["target2"].Healthy {
t.Error("Expected target2 to be unhealthy")
}
}
func TestMonitorBasicAuth(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: true,
Username: "admin",
Password: "secret",
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Create a test HTTP server
server := httptest.NewServer(monitor.authHandler)
defer server.Close()
// Test without authentication
resp, err := http.Get(server.URL + cfg.Monitor.Path)
if err != nil {
t.Fatalf("Failed to get metrics: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
// Test with incorrect authentication
req, err := http.NewRequest("GET", server.URL+cfg.Monitor.Path, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.SetBasicAuth("wrong", "credentials")
client := &http.Client{}
resp, err = client.Do(req)
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
// Test with correct authentication
req, err = http.NewRequest("GET", server.URL+cfg.Monitor.Path, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.SetBasicAuth(cfg.Monitor.Username, cfg.Monitor.Password)
resp, err = client.Do(req)
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
}
func TestMonitorMetricsTracking(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: false,
},
Proxy: config.ProxyConfig{
Targets: []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
},
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Test incrementing request count
monitor.IncrementRequest()
metrics := monitor.GetMetrics()
if metrics.RequestsTotal != 1 {
t.Errorf("Expected requests total 1, got %d", metrics.RequestsTotal)
}
// Test incrementing active request count
monitor.IncrementActiveRequest()
metrics = monitor.GetMetrics()
if metrics.RequestsActive != 1 {
t.Errorf("Expected requests active 1, got %d", metrics.RequestsActive)
}
// Test decrementing active request count
monitor.DecrementActiveRequest()
metrics = monitor.GetMetrics()
if metrics.RequestsActive != 0 {
t.Errorf("Expected requests active 0, got %d", metrics.RequestsActive)
}
// Test recording response
monitor.RecordResponse(http.StatusOK, "target1", 100*time.Millisecond)
metrics = monitor.GetMetrics()
if metrics.ResponsesByStatus["200"] != 1 {
t.Errorf("Expected 1 response with status 200, got %d", metrics.ResponsesByStatus["200"])
}
if metrics.TargetMetrics["target1"].RequestsTotal != 1 {
t.Errorf("Expected target1 to have 1 request, got %d", metrics.TargetMetrics["target1"].RequestsTotal)
}
if metrics.TargetMetrics["target1"].ResponsesByStatus["200"] != 1 {
t.Errorf("Expected target1 to have 1 response with status 200, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["200"])
}
if len(metrics.TargetMetrics["target1"].ResponseTimes) != 1 {
t.Errorf("Expected target1 to have 1 response time, got %d", len(metrics.TargetMetrics["target1"].ResponseTimes))
}
if metrics.TargetMetrics["target1"].AvgResponseTime != 100*time.Millisecond {
t.Errorf("Expected target1 to have average response time 100ms, got %v", metrics.TargetMetrics["target1"].AvgResponseTime)
}
// Test updating target health
monitor.UpdateTargetHealth("target1", false)
metrics = monitor.GetMetrics()
if metrics.TargetMetrics["target1"].Healthy {
t.Error("Expected target1 to be unhealthy")
}
}
func TestMonitorMultipleResponses(t *testing.T) {
cfg := &config.Config{
Monitor: config.MonitorConfig{
Enabled: true,
Port: 9090,
Path: "/metrics",
Auth: false,
},
Proxy: config.ProxyConfig{
Targets: []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
},
},
}
logger := logger.NewLogger()
monitor := NewMonitor(cfg, logger)
// Record multiple responses with different status codes and response times
monitor.RecordResponse(http.StatusOK, "target1", 100*time.Millisecond)
monitor.RecordResponse(http.StatusOK, "target1", 200*time.Millisecond)
monitor.RecordResponse(http.StatusNotFound, "target1", 50*time.Millisecond)
monitor.RecordResponse(http.StatusInternalServerError, "target1", 300*time.Millisecond)
metrics := monitor.GetMetrics()
// Check overall metrics
if metrics.RequestsTotal != 0 {
t.Errorf("Expected requests total 0, got %d", metrics.RequestsTotal)
}
if metrics.ResponsesByStatus["200"] != 2 {
t.Errorf("Expected 2 responses with status 200, got %d", metrics.ResponsesByStatus["200"])
}
if metrics.ResponsesByStatus["404"] != 1 {
t.Errorf("Expected 1 response with status 404, got %d", metrics.ResponsesByStatus["404"])
}
if metrics.ResponsesByStatus["500"] != 1 {
t.Errorf("Expected 1 response with status 500, got %d", metrics.ResponsesByStatus["500"])
}
// Check target metrics
if metrics.TargetMetrics["target1"].RequestsTotal != 4 {
t.Errorf("Expected target1 to have 4 requests, got %d", metrics.TargetMetrics["target1"].RequestsTotal)
}
if metrics.TargetMetrics["target1"].ResponsesByStatus["200"] != 2 {
t.Errorf("Expected target1 to have 2 responses with status 200, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["200"])
}
if metrics.TargetMetrics["target1"].ResponsesByStatus["404"] != 1 {
t.Errorf("Expected target1 to have 1 response with status 404, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["404"])
}
if metrics.TargetMetrics["target1"].ResponsesByStatus["500"] != 1 {
t.Errorf("Expected target1 to have 1 response with status 500, got %d", metrics.TargetMetrics["target1"].ResponsesByStatus["500"])
}
// Check average response time
expectedAvg := time.Duration((100 + 200 + 50 + 300) / 4)
if metrics.TargetMetrics["target1"].AvgResponseTime != expectedAvg {
t.Errorf("Expected target1 to have average response time %v, got %v", expectedAvg, metrics.TargetMetrics["target1"].AvgResponseTime)
}
}

231
internal/nat/nat.go Normal file
View File

@@ -0,0 +1,231 @@
package nat
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
"github.com/iwasforcedtobehere/goRZ/internal/logger"
"github.com/pion/stun"
"github.com/pion/turn/v2"
)
// NATTraversal handles NAT traversal for the proxy server
type NATTraversal struct {
config *config.Config
logger *logger.Logger
conn net.PacketConn
externalIP net.IP
externalPort int
mu sync.RWMutex
running bool
}
// NewNATTraversal creates a new NAT traversal instance
func NewNATTraversal(cfg *config.Config, logger *logger.Logger) *NATTraversal {
return &NATTraversal{
config: cfg,
logger: logger,
}
}
// Start starts the NAT traversal process
func (n *NATTraversal) Start() error {
if !n.config.NAT.Enabled {
n.logger.Info("NAT traversal is disabled")
return nil
}
n.mu.Lock()
defer n.mu.Unlock()
if n.running {
return fmt.Errorf("NAT traversal is already running")
}
// Create a UDP listener
conn, err := net.ListenPacket("udp4", "0.0.0.0:0")
if err != nil {
return fmt.Errorf("failed to create UDP listener: %w", err)
}
n.conn = conn
// Get external IP and port using STUN
if err := n.discoverExternalAddress(); err != nil {
conn.Close()
return fmt.Errorf("failed to discover external address: %w", err)
}
// Start TURN client if configured
if n.config.NAT.TURNServer != "" {
if err := n.startTURNClient(); err != nil {
conn.Close()
return fmt.Errorf("failed to start TURN client: %w", err)
}
}
n.running = true
n.logger.Info("NAT traversal started",
logger.String("external_ip", n.externalIP.String()),
logger.Int("external_port", n.externalPort))
return nil
}
// Stop stops the NAT traversal process
func (n *NATTraversal) Stop() {
n.mu.Lock()
defer n.mu.Unlock()
if !n.running {
return
}
if n.conn != nil {
n.conn.Close()
n.conn = nil
}
n.running = false
n.logger.Info("NAT traversal stopped")
}
// discoverExternalAddress discovers the external IP and port using STUN
func (n *NATTraversal) discoverExternalAddress() error {
// Parse STUN server URL
stunURL, err := stun.ParseURI(n.config.NAT.STUNServer)
if err != nil {
return fmt.Errorf("failed to parse STUN server URL: %w", err)
}
// Create STUN client
client, err := stun.NewClient("udp4", n.conn)
if err != nil {
return fmt.Errorf("failed to create STUN client: %w", err)
}
defer client.Close()
// Send binding request
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
mappedAddr, err := client.Request(ctx, stunURL, stun.BindingRequest)
if err != nil {
return fmt.Errorf("STUN binding request failed: %w", err)
}
n.externalIP = mappedAddr.IP
n.externalPort = mappedAddr.Port
return nil
}
// startTURNClient starts a TURN client for relay connections
func (n *NATTraversal) startTURNClient() error {
// Parse TURN server URL
turnURL, err := turn.ParseURI(n.config.NAT.TURNServer)
if err != nil {
return fmt.Errorf("failed to parse TURN server URL: %w", err)
}
// Create TURN client config
cfg := &turn.ClientConfig{
STUNServerAddr: n.config.NAT.STUNServer,
TURNServerAddr: n.config.NAT.TURNServer,
Username: n.config.NAT.TURNUsername,
Credential: n.config.NAT.TURNPassword,
LoggerFactory: nil, // We'll use our own logger
}
// Create TURN client
client, err := turn.NewClient(cfg)
if err != nil {
return fmt.Errorf("failed to create TURN client: %w", err)
}
defer client.Close()
// Listen on provided conn
n.logger.Info("TURN client started", logger.String("server", n.config.NAT.TURNServer))
return nil
}
// GetExternalAddress returns the external IP and port
func (n *NATTraversal) GetExternalAddress() (net.IP, int) {
n.mu.RLock()
defer n.mu.RUnlock()
return n.externalIP, n.externalPort
}
// IsRunning returns whether NAT traversal is running
func (n *NATTraversal) IsRunning() bool {
n.mu.RLock()
defer n.mu.RUnlock()
return n.running
}
// CreateHolePunch attempts to create a hole in the NAT for direct peer-to-peer connections
func (n *NATTraversal) CreateHolePunch(peerAddr string) (net.Conn, error) {
if !n.running {
return nil, fmt.Errorf("NAT traversal is not running")
}
// Parse peer address
udpAddr, err := net.ResolveUDPAddr("udp4", peerAddr)
if err != nil {
return nil, fmt.Errorf("failed to resolve peer address: %w", err)
}
// Create UDP connection
conn, err := net.DialUDP("udp4", nil, udpAddr)
if err != nil {
return nil, fmt.Errorf("failed to create UDP connection: %w", err)
}
// Send a small packet to punch a hole in the NAT
_, err = conn.Write([]byte("punch"))
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to punch hole: %w", err)
}
n.logger.Debug("Hole punch attempted", logger.String("peer", peerAddr))
return conn, nil
}
// GetNATType attempts to determine the type of NAT
func (n *NATTraversal) GetNATType() (string, error) {
if !n.running {
return "", fmt.Errorf("NAT traversal is not running")
}
// This is a simplified NAT type detection
// In a real implementation, you would use more sophisticated methods
// like the one described in RFC 3489 and RFC 5780
// For now, we'll return a generic response
return "Unknown", nil
}
// CreateRelayConnection creates a relay connection through TURN
func (n *NATTraversal) CreateRelayConnection(peerAddr string) (net.Conn, error) {
if !n.running {
return nil, fmt.Errorf("NAT traversal is not running")
}
if n.config.NAT.TURNServer == "" {
return nil, fmt.Errorf("TURN server not configured")
}
// This is a placeholder for TURN relay connection creation
// In a real implementation, you would use the TURN client to allocate
// a relay and create a connection to the peer
return nil, fmt.Errorf("not implemented")
}

View File

@@ -0,0 +1,217 @@
package proxy
import (
"math/rand"
"sync"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
)
// RoundRobinLoadBalancer implements round-robin load balancing
type RoundRobinLoadBalancer struct {
targets []*config.TargetConfig
current int
mu sync.Mutex
}
// NewRoundRobinLoadBalancer creates a new round-robin load balancer
func NewRoundRobinLoadBalancer(targets []config.TargetConfig) *RoundRobinLoadBalancer {
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
return &RoundRobinLoadBalancer{
targets: t,
current: 0,
}
}
// NextTarget returns the next target using round-robin algorithm
func (lb *RoundRobinLoadBalancer) NextTarget() (*config.TargetConfig, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
// Filter healthy targets
healthyTargets := make([]*config.TargetConfig, 0)
for _, target := range lb.targets {
if target.Healthy {
healthyTargets = append(healthyTargets, target)
}
}
if len(healthyTargets) == 0 {
return nil, ErrNoHealthyTargets
}
// Get next target
target := healthyTargets[lb.current%len(healthyTargets)]
lb.current = (lb.current + 1) % len(healthyTargets)
return target, nil
}
// UpdateTargets updates the targets list
func (lb *RoundRobinLoadBalancer) UpdateTargets(targets []config.TargetConfig) {
lb.mu.Lock()
defer lb.mu.Unlock()
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
lb.targets = t
lb.current = 0
}
// RandomLoadBalancer implements random load balancing
type RandomLoadBalancer struct {
targets []*config.TargetConfig
rand *rand.Rand
mu sync.Mutex
}
// NewRandomLoadBalancer creates a new random load balancer
func NewRandomLoadBalancer(targets []config.TargetConfig) *RandomLoadBalancer {
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
return &RandomLoadBalancer{
targets: t,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// NextTarget returns a random target
func (lb *RandomLoadBalancer) NextTarget() (*config.TargetConfig, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
// Filter healthy targets
healthyTargets := make([]*config.TargetConfig, 0)
for _, target := range lb.targets {
if target.Healthy {
healthyTargets = append(healthyTargets, target)
}
}
if len(healthyTargets) == 0 {
return nil, ErrNoHealthyTargets
}
// Get random target
index := lb.rand.Intn(len(healthyTargets))
return healthyTargets[index], nil
}
// UpdateTargets updates the targets list
func (lb *RandomLoadBalancer) UpdateTargets(targets []config.TargetConfig) {
lb.mu.Lock()
defer lb.mu.Unlock()
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
lb.targets = t
}
// LeastConnectionsLoadBalancer implements least connections load balancing
type LeastConnectionsLoadBalancer struct {
targets []*config.TargetConfig
connections map[string]int
mu sync.Mutex
}
// NewLeastConnectionsLoadBalancer creates a new least connections load balancer
func NewLeastConnectionsLoadBalancer(targets []config.TargetConfig) *LeastConnectionsLoadBalancer {
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
connections := make(map[string]int)
for _, target := range t {
connections[target.Name] = 0
}
return &LeastConnectionsLoadBalancer{
targets: t,
connections: connections,
}
}
// NextTarget returns the target with the least connections
func (lb *LeastConnectionsLoadBalancer) NextTarget() (*config.TargetConfig, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
// Filter healthy targets and find the one with least connections
var selectedTarget *config.TargetConfig
minConnections := -1
for _, target := range lb.targets {
if target.Healthy {
connections := lb.connections[target.Name]
if minConnections == -1 || connections < minConnections {
minConnections = connections
selectedTarget = target
}
}
}
if selectedTarget == nil {
return nil, ErrNoHealthyTargets
}
// Increment connection count
lb.connections[selectedTarget.Name]++
return selectedTarget, nil
}
// ReleaseConnection decrements the connection count for a target
func (lb *LeastConnectionsLoadBalancer) ReleaseConnection(targetName string) {
lb.mu.Lock()
defer lb.mu.Unlock()
if count, exists := lb.connections[targetName]; exists && count > 0 {
lb.connections[targetName] = count - 1
}
}
// UpdateTargets updates the targets list
func (lb *LeastConnectionsLoadBalancer) UpdateTargets(targets []config.TargetConfig) {
lb.mu.Lock()
defer lb.mu.Unlock()
t := make([]*config.TargetConfig, len(targets))
for i := range targets {
t[i] = &targets[i]
}
// Update connections map
connections := make(map[string]int)
for _, target := range t {
// Preserve existing connection count if target exists
if count, exists := lb.connections[target.Name]; exists {
connections[target.Name] = count
} else {
connections[target.Name] = 0
}
}
lb.targets = t
lb.connections = connections
}
// ErrNoHealthyTargets is returned when no healthy targets are available
var ErrNoHealthyTargets = errorString("no healthy targets available")
// errorString is a simple string-based error type
type errorString string
func (e errorString) Error() string {
return string(e)
}

View File

@@ -0,0 +1,270 @@
package proxy
import (
"testing"
"github.com/iwasforcedtobehere/goRZ/internal/config"
)
func TestRoundRobinLoadBalancer(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewRoundRobinLoadBalancer(targets)
// Test that targets are selected in round-robin order
for i := 0; i < 6; i++ {
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
expectedTarget := targets[i%3]
if target.Name != expectedTarget.Name {
t.Errorf("Expected target %s, got %s", expectedTarget.Name, target.Name)
}
}
}
func TestRoundRobinLoadBalancerWithUnhealthyTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewRoundRobinLoadBalancer(targets)
// Test that only healthy targets are selected
for i := 0; i < 4; i++ {
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name == "target2" {
t.Errorf("Selected unhealthy target: %s", target.Name)
}
// Should alternate between target1 and target3
if i%2 == 0 && target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
if i%2 == 1 && target.Name != "target3" {
t.Errorf("Expected target3, got %s", target.Name)
}
}
}
func TestRoundRobinLoadBalancerNoHealthyTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: false},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
}
lb := NewRoundRobinLoadBalancer(targets)
_, err := lb.NextTarget()
if err != ErrNoHealthyTargets {
t.Errorf("Expected ErrNoHealthyTargets, got %v", err)
}
}
func TestRoundRobinLoadBalancerUpdateTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewRoundRobinLoadBalancer(targets)
// Get first target
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
// Update targets
newTargets := []config.TargetConfig{
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb.UpdateTargets(newTargets)
// Test that new targets are selected
for i := 0; i < 4; i++ {
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
expectedTarget := newTargets[i%2]
if target.Name != expectedTarget.Name {
t.Errorf("Expected target %s, got %s", expectedTarget.Name, target.Name)
}
}
}
func TestRandomLoadBalancer(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewRandomLoadBalancer(targets)
// Test that targets are selected randomly
// We'll just check that we don't get errors and that all targets are eventually selected
selected := make(map[string]bool)
for i := 0; i < 100; i++ {
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
selected[target.Name] = true
}
// Check that all targets were selected at least once
for _, target := range targets {
if !selected[target.Name] {
t.Errorf("Target %s was never selected", target.Name)
}
}
}
func TestRandomLoadBalancerWithUnhealthyTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewRandomLoadBalancer(targets)
// Test that only healthy targets are selected
for i := 0; i < 100; i++ {
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name == "target2" {
t.Errorf("Selected unhealthy target: %s", target.Name)
}
}
}
func TestLeastConnectionsLoadBalancer(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewLeastConnectionsLoadBalancer(targets)
// Initially, both targets should have 0 connections
// The first target should be selected
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
// Now target1 should have 1 connection, target2 should have 0
// So target2 should be selected
target, err = lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target2" {
t.Errorf("Expected target2, got %s", target.Name)
}
// Now both targets should have 1 connection
// target1 should be selected again
target, err = lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
// Release a connection from target1
lb.(*LeastConnectionsLoadBalancer).ReleaseConnection("target1")
// Now target1 should have 1 connection, target2 should have 1 connection
// But since we released from target1, it should be selected
target, err = lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
}
func TestLeastConnectionsLoadBalancerUpdateTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: true},
}
lb := NewLeastConnectionsLoadBalancer(targets)
// Get first target
target, err := lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target1" {
t.Errorf("Expected target1, got %s", target.Name)
}
// Update targets
newTargets := []config.TargetConfig{
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: true},
{Name: "target3", Address: "http://localhost:8083", Protocol: "http", Weight: 1, Healthy: true},
}
lb.UpdateTargets(newTargets)
// Test that new targets are selected
// The first new target should be selected
target, err = lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target2" {
t.Errorf("Expected target2, got %s", target.Name)
}
// The second new target should be selected
target, err = lb.NextTarget()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if target.Name != "target3" {
t.Errorf("Expected target3, got %s", target.Name)
}
}
func TestLeastConnectionsLoadBalancerNoHealthyTargets(t *testing.T) {
targets := []config.TargetConfig{
{Name: "target1", Address: "http://localhost:8081", Protocol: "http", Weight: 1, Healthy: false},
{Name: "target2", Address: "http://localhost:8082", Protocol: "http", Weight: 1, Healthy: false},
}
lb := NewLeastConnectionsLoadBalancer(targets)
_, err := lb.NextTarget()
if err != ErrNoHealthyTargets {
t.Errorf("Expected ErrNoHealthyTargets, got %v", err)
}
}

238
internal/proxy/proxy.go Normal file
View File

@@ -0,0 +1,238 @@
package proxy
import (
"context"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"time"
"github.com/iwasforcedtobehere/goRZ/internal/config"
"github.com/iwasforcedtobehere/goRZ/internal/logger"
)
// Server represents the reverse proxy server
type Server struct {
config *config.Config
logger *logger.Logger
loadBalancer LoadBalancer
healthChecker *HealthChecker
}
// LoadBalancer defines the interface for load balancing strategies
type LoadBalancer interface {
NextTarget() (*config.TargetConfig, error)
UpdateTargets(targets []config.TargetConfig)
}
// HealthChecker handles health checking for proxy targets
type HealthChecker struct {
config *config.Config
logger *logger.Logger
targets map[string]*config.TargetConfig
httpClient *http.Client
stopChan chan struct{}
mu sync.RWMutex
}
// NewServer creates a new reverse proxy server
func NewServer(cfg *config.Config, logger *logger.Logger) (*Server, error) {
// Create load balancer based on configuration
var lb LoadBalancer
switch cfg.Proxy.LoadBalancer {
case "roundrobin":
lb = NewRoundRobinLoadBalancer(cfg.Proxy.Targets)
case "leastconn":
lb = NewLeastConnectionsLoadBalancer(cfg.Proxy.Targets)
case "random":
lb = NewRandomLoadBalancer(cfg.Proxy.Targets)
default:
return nil, fmt.Errorf("unsupported load balancer: %s", cfg.Proxy.LoadBalancer)
}
// Create health checker
healthChecker := NewHealthChecker(cfg, logger)
return &Server{
config: cfg,
logger: logger,
loadBalancer: lb,
healthChecker: healthChecker,
}, nil
}
// ServeHTTP handles HTTP requests
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Get next target based on load balancing strategy
target, err := s.loadBalancer.NextTarget()
if err != nil {
s.logger.Error("Failed to get target", logger.Error(err))
http.Error(w, "Service unavailable", http.StatusServiceUnavailable)
return
}
// Check if target is healthy
if !target.Healthy {
s.logger.Warn("Target is unhealthy", logger.String("target", target.Name))
http.Error(w, "Service unavailable", http.StatusServiceUnavailable)
return
}
// Create reverse proxy
targetURL, err := url.Parse(target.Address)
if err != nil {
s.logger.Error("Failed to parse target URL",
logger.String("target", target.Name),
logger.String("address", target.Address),
logger.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
proxy := httputil.NewSingleHostReverseProxy(targetURL)
// Set up custom director to modify request
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
originalDirector(req)
// Add custom headers
req.Header.Set("X-Forwarded-Host", req.Host)
req.Header.Set("X-Forwarded-Proto", "https")
if req.TLS != nil {
req.Header.Set("X-Forwarded-Ssl", "on")
}
// Add proxy identification
req.Header.Set("X-Proxy-Server", "goRZ")
}
// Set up error handler
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.logger.Error("Proxy error", logger.Error(err))
http.Error(w, "Proxy error", http.StatusBadGateway)
}
// Set up modify response handler
proxy.ModifyResponse = func(resp *http.Response) error {
// Add custom headers to response
resp.Header.Set("X-Served-By", "goRZ")
resp.Header.Set("X-Target", target.Name)
return nil
}
// Serve the request
s.logger.Debug("Proxying request",
logger.String("method", r.Method),
logger.String("path", r.URL.Path),
logger.String("target", target.Name))
proxy.ServeHTTP(w, r)
}
// Start starts the proxy server
func (s *Server) Start() error {
// Start health checker
if err := s.healthChecker.Start(); err != nil {
return fmt.Errorf("failed to start health checker: %w", err)
}
return nil
}
// Stop stops the proxy server
func (s *Server) Stop() {
s.healthChecker.Stop()
}
// NewHealthChecker creates a new health checker
func NewHealthChecker(cfg *config.Config, logger *logger.Logger) *HealthChecker {
targets := make(map[string]*config.TargetConfig)
for i := range cfg.Proxy.Targets {
targets[cfg.Proxy.Targets[i].Name] = &cfg.Proxy.Targets[i]
}
return &HealthChecker{
config: cfg,
logger: logger,
targets: targets,
httpClient: &http.Client{Timeout: 5 * time.Second},
stopChan: make(chan struct{}),
}
}
// Start starts the health checker
func (h *HealthChecker) Start() error {
// Initial health check
h.checkAllTargets()
// Start periodic health checks
go func() {
ticker := time.NewTicker(time.Duration(h.config.Proxy.HealthCheckInterval) * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
h.checkAllTargets()
case <-h.stopChan:
return
}
}
}()
return nil
}
// Stop stops the health checker
func (h *HealthChecker) Stop() {
close(h.stopChan)
}
// checkAllTargets checks the health of all targets
func (h *HealthChecker) checkAllTargets() {
h.mu.Lock()
defer h.mu.Unlock()
for name, target := range h.targets {
healthy := h.checkTargetHealth(target)
if target.Healthy != healthy {
h.logger.Info("Target health status changed",
logger.String("target", name),
logger.Bool("healthy", healthy))
target.Healthy = healthy
}
}
}
// checkTargetHealth checks the health of a single target
func (h *HealthChecker) checkTargetHealth(target *config.TargetConfig) bool {
targetURL, err := url.Parse(target.Address)
if err != nil {
h.logger.Error("Failed to parse target URL for health check",
logger.String("target", target.Name),
logger.Error(err))
return false
}
healthURL := *targetURL
healthURL.Path = h.config.Proxy.HealthCheckPath
req, err := http.NewRequest("GET", healthURL.String(), nil)
if err != nil {
h.logger.Error("Failed to create health check request",
logger.String("target", target.Name),
logger.Error(err))
return false
}
resp, err := h.httpClient.Do(req)
if err != nil {
h.logger.Error("Health check request failed",
logger.String("target", target.Name),
logger.Error(err))
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}