package proxy import ( "context" "fmt" "net/http" "net/http/httputil" "net/url" "sync" "time" "github.com/iwasforcedtobehere/goRZ/internal/config" "github.com/iwasforcedtobehere/goRZ/internal/logger" ) // Server represents the reverse proxy server type Server struct { config *config.Config logger *logger.Logger loadBalancer LoadBalancer healthChecker *HealthChecker } // LoadBalancer defines the interface for load balancing strategies type LoadBalancer interface { NextTarget() (*config.TargetConfig, error) UpdateTargets(targets []config.TargetConfig) } // HealthChecker handles health checking for proxy targets type HealthChecker struct { config *config.Config logger *logger.Logger targets map[string]*config.TargetConfig httpClient *http.Client stopChan chan struct{} mu sync.RWMutex } // NewServer creates a new reverse proxy server func NewServer(cfg *config.Config, logger *logger.Logger) (*Server, error) { // Create load balancer based on configuration var lb LoadBalancer switch cfg.Proxy.LoadBalancer { case "roundrobin": lb = NewRoundRobinLoadBalancer(cfg.Proxy.Targets) case "leastconn": lb = NewLeastConnectionsLoadBalancer(cfg.Proxy.Targets) case "random": lb = NewRandomLoadBalancer(cfg.Proxy.Targets) default: return nil, fmt.Errorf("unsupported load balancer: %s", cfg.Proxy.LoadBalancer) } // Create health checker healthChecker := NewHealthChecker(cfg, logger) return &Server{ config: cfg, logger: logger, loadBalancer: lb, healthChecker: healthChecker, }, nil } // ServeHTTP handles HTTP requests func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Get next target based on load balancing strategy target, err := s.loadBalancer.NextTarget() if err != nil { s.logger.Error("Failed to get target", logger.Error(err)) http.Error(w, "Service unavailable", http.StatusServiceUnavailable) return } // Check if target is healthy if !target.Healthy { s.logger.Warn("Target is unhealthy", logger.String("target", target.Name)) http.Error(w, "Service unavailable", http.StatusServiceUnavailable) return } // Create reverse proxy targetURL, err := url.Parse(target.Address) if err != nil { s.logger.Error("Failed to parse target URL", logger.String("target", target.Name), logger.String("address", target.Address), logger.Error(err)) http.Error(w, "Internal server error", http.StatusInternalServerError) return } proxy := httputil.NewSingleHostReverseProxy(targetURL) // Set up custom director to modify request originalDirector := proxy.Director proxy.Director = func(req *http.Request) { originalDirector(req) // Add custom headers req.Header.Set("X-Forwarded-Host", req.Host) req.Header.Set("X-Forwarded-Proto", "https") if req.TLS != nil { req.Header.Set("X-Forwarded-Ssl", "on") } // Add proxy identification req.Header.Set("X-Proxy-Server", "goRZ") } // Set up error handler proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { s.logger.Error("Proxy error", logger.Error(err)) http.Error(w, "Proxy error", http.StatusBadGateway) } // Set up modify response handler proxy.ModifyResponse = func(resp *http.Response) error { // Add custom headers to response resp.Header.Set("X-Served-By", "goRZ") resp.Header.Set("X-Target", target.Name) return nil } // Serve the request s.logger.Debug("Proxying request", logger.String("method", r.Method), logger.String("path", r.URL.Path), logger.String("target", target.Name)) proxy.ServeHTTP(w, r) } // Start starts the proxy server func (s *Server) Start() error { // Start health checker if err := s.healthChecker.Start(); err != nil { return fmt.Errorf("failed to start health checker: %w", err) } return nil } // Stop stops the proxy server func (s *Server) Stop() { s.healthChecker.Stop() } // NewHealthChecker creates a new health checker func NewHealthChecker(cfg *config.Config, logger *logger.Logger) *HealthChecker { targets := make(map[string]*config.TargetConfig) for i := range cfg.Proxy.Targets { targets[cfg.Proxy.Targets[i].Name] = &cfg.Proxy.Targets[i] } return &HealthChecker{ config: cfg, logger: logger, targets: targets, httpClient: &http.Client{Timeout: 5 * time.Second}, stopChan: make(chan struct{}), } } // Start starts the health checker func (h *HealthChecker) Start() error { // Initial health check h.checkAllTargets() // Start periodic health checks go func() { ticker := time.NewTicker(time.Duration(h.config.Proxy.HealthCheckInterval) * time.Second) defer ticker.Stop() for { select { case <-ticker.C: h.checkAllTargets() case <-h.stopChan: return } } }() return nil } // Stop stops the health checker func (h *HealthChecker) Stop() { close(h.stopChan) } // checkAllTargets checks the health of all targets func (h *HealthChecker) checkAllTargets() { h.mu.Lock() defer h.mu.Unlock() for name, target := range h.targets { healthy := h.checkTargetHealth(target) if target.Healthy != healthy { h.logger.Info("Target health status changed", logger.String("target", name), logger.Bool("healthy", healthy)) target.Healthy = healthy } } } // checkTargetHealth checks the health of a single target func (h *HealthChecker) checkTargetHealth(target *config.TargetConfig) bool { targetURL, err := url.Parse(target.Address) if err != nil { h.logger.Error("Failed to parse target URL for health check", logger.String("target", target.Name), logger.Error(err)) return false } healthURL := *targetURL healthURL.Path = h.config.Proxy.HealthCheckPath req, err := http.NewRequest("GET", healthURL.String(), nil) if err != nil { h.logger.Error("Failed to create health check request", logger.String("target", target.Name), logger.Error(err)) return false } resp, err := h.httpClient.Do(req) if err != nil { h.logger.Error("Health check request failed", logger.String("target", target.Name), logger.Error(err)) return false } defer resp.Body.Close() return resp.StatusCode == http.StatusOK }