package nat

import (
	"errors"
	"fmt"
	"net"
	"strconv"
	"sync"
	"time"

	"github.com/iwasforcedtobehere/goRZ/internal/config"
	"github.com/iwasforcedtobehere/goRZ/internal/logger"
	"github.com/pion/stun"
	"github.com/pion/turn/v2"
)

const (
	// stunTimeout bounds the whole STUN Binding exchange, retransmits included.
	stunTimeout = 5 * time.Second
	// stunInitialRTO is the first retransmission timeout; it doubles after each
	// unanswered request (RFC 5389 §7.2.1 recommends starting at 500ms).
	stunInitialRTO = 500 * time.Millisecond
	// stunReadBufferSize comfortably holds a Binding response at Ethernet MTU.
	stunReadBufferSize = 1500
)

// NATTraversal handles NAT traversal for the proxy server. It owns a single
// UDP socket whose public (server-reflexive) address is discovered via STUN;
// the same socket is handed to an optional TURN client.
type NATTraversal struct {
	config *config.Config
	logger *logger.Logger

	// mu guards every field below.
	mu           sync.RWMutex
	conn         net.PacketConn
	turnClient   *turn.Client
	externalIP   net.IP
	externalPort int
	running      bool
}

// NewNATTraversal creates a new, stopped NAT traversal instance.
func NewNATTraversal(cfg *config.Config, logger *logger.Logger) *NATTraversal {
	return &NATTraversal{
		config: cfg,
		logger: logger,
	}
}

// Start opens a UDP socket, discovers its external address through the
// configured STUN server and, when a TURN server is configured, starts a TURN
// client on that same socket.
//
// It is a no-op (returning nil) when NAT traversal is disabled in the config,
// and returns an error if traversal is already running. On any failure every
// resource acquired so far is released, leaving the instance stopped.
func (n *NATTraversal) Start() error {
	if !n.config.NAT.Enabled {
		n.logger.Info("NAT traversal is disabled")
		return nil
	}

	n.mu.Lock()
	defer n.mu.Unlock()

	if n.running {
		return errors.New("NAT traversal is already running")
	}

	conn, err := net.ListenPacket("udp4", "0.0.0.0:0")
	if err != nil {
		return fmt.Errorf("failed to create UDP listener: %w", err)
	}
	n.conn = conn

	if err := n.discoverExternalAddress(); err != nil {
		n.releaseLocked()
		return fmt.Errorf("failed to discover external address: %w", err)
	}

	if n.config.NAT.TURNServer != "" {
		if err := n.startTURNClient(); err != nil {
			n.releaseLocked()
			return fmt.Errorf("failed to start TURN client: %w", err)
		}
	}

	n.running = true
	n.logger.Info("NAT traversal started",
		logger.String("external_ip", n.externalIP.String()),
		logger.Int("external_port", n.externalPort))
	return nil
}

// Stop closes the TURN client (if any) and the UDP socket. Calling Stop when
// traversal is not running is a no-op.
func (n *NATTraversal) Stop() {
	n.mu.Lock()
	defer n.mu.Unlock()

	if !n.running {
		return
	}

	n.releaseLocked()
	n.running = false
	n.logger.Info("NAT traversal stopped")
}

// releaseLocked closes the TURN client and the UDP socket and clears the
// references to them. The caller must hold n.mu for writing.
func (n *NATTraversal) releaseLocked() {
	// Close the TURN client first: its receive loop reads from n.conn.
	if n.turnClient != nil {
		n.turnClient.Close()
		n.turnClient = nil
	}
	if n.conn != nil {
		// Best-effort teardown: nothing useful can be done if Close fails.
		_ = n.conn.Close()
		n.conn = nil
	}
}

// discoverExternalAddress performs a STUN Binding request on n.conn and
// records the mapped address in n.externalIP / n.externalPort. Using n.conn
// itself (rather than a separate socket) ensures the discovered mapping is
// the one belonging to the socket we keep. The request is retransmitted with
// a doubling timeout until stunTimeout elapses. The caller must hold n.mu
// for writing.
func (n *NATTraversal) discoverExternalAddress() error {
	serverHostPort, err := stunHostPort(n.config.NAT.STUNServer)
	if err != nil {
		return fmt.Errorf("failed to parse STUN server URL: %w", err)
	}
	serverAddr, err := net.ResolveUDPAddr("udp4", serverHostPort)
	if err != nil {
		return fmt.Errorf("failed to resolve STUN server address: %w", err)
	}

	request, err := stun.Build(stun.TransactionID, stun.BindingRequest)
	if err != nil {
		return fmt.Errorf("failed to build STUN binding request: %w", err)
	}

	// Always clear the read deadline afterwards: the TURN client later reads
	// from this same socket and would otherwise hit immediate timeouts.
	defer func() { _ = n.conn.SetReadDeadline(time.Time{}) }()

	deadline := time.Now().Add(stunTimeout)
	rto := stunInitialRTO
	buf := make([]byte, stunReadBufferSize)

	for {
		if _, err := n.conn.WriteTo(request.Raw, serverAddr); err != nil {
			return fmt.Errorf("failed to send STUN binding request: %w", err)
		}

		attemptDeadline := time.Now().Add(rto)
		if attemptDeadline.After(deadline) {
			attemptDeadline = deadline
		}
		if err := n.conn.SetReadDeadline(attemptDeadline); err != nil {
			return fmt.Errorf("failed to set STUN read deadline: %w", err)
		}

		ip, port, err := n.readBindingResponse(buf, request.TransactionID)
		switch {
		case err == nil:
			n.externalIP = ip
			n.externalPort = port
			return nil
		case !isNetTimeout(err) || !time.Now().Before(deadline):
			// Either a hard failure, or the overall budget is exhausted.
			return fmt.Errorf("STUN binding request failed: %w", err)
		}
		rto *= 2 // UDP may have dropped the request or response: retransmit.
	}
}

// readBindingResponse reads from n.conn until it sees a STUN message carrying
// txID, then extracts the mapped address from it. Unrelated datagrams and
// non-STUN traffic are skipped. Read errors (including deadline expiry) are
// returned unchanged so the caller can decide whether to retransmit.
func (n *NATTraversal) readBindingResponse(buf []byte, txID [stun.TransactionIDSize]byte) (net.IP, int, error) {
	for {
		nr, _, err := n.conn.ReadFrom(buf)
		if err != nil {
			return nil, 0, err
		}

		// Copy so the decoded message does not alias the reusable buffer.
		res := &stun.Message{Raw: append([]byte(nil), buf[:nr]...)}
		if err := res.Decode(); err != nil {
			continue // not a STUN message
		}
		if res.TransactionID != txID {
			continue // stale or unrelated transaction
		}
		if res.Type != stun.BindingSuccess {
			return nil, 0, fmt.Errorf("unexpected STUN response type %v", res.Type)
		}

		// Prefer XOR-MAPPED-ADDRESS; fall back to the legacy MAPPED-ADDRESS
		// attribute that older (RFC 3489) servers send.
		var xorAddr stun.XORMappedAddress
		if err := xorAddr.GetFrom(res); err == nil {
			return xorAddr.IP, xorAddr.Port, nil
		}
		var mapped stun.MappedAddress
		if err := mapped.GetFrom(res); err != nil {
			return nil, 0, fmt.Errorf("STUN response has no mapped address: %w", err)
		}
		return mapped.IP, mapped.Port, nil
	}
}

// startTURNClient creates a TURN client that shares n.conn, starts its receive
// loop and stores it so that Stop can close it. The caller must hold n.mu for
// writing.
func (n *NATTraversal) startTURNClient() error {
	turnAddr, err := stunHostPort(n.config.NAT.TURNServer)
	if err != nil {
		return fmt.Errorf("failed to parse TURN server URL: %w", err)
	}
	stunAddr, err := stunHostPort(n.config.NAT.STUNServer)
	if err != nil {
		return fmt.Errorf("failed to parse STUN server URL: %w", err)
	}

	// NOTE: LoggerFactory is left unset, so pion uses its own default logging;
	// the project logger is not wired into the TURN client.
	client, err := turn.NewClient(&turn.ClientConfig{
		STUNServerAddr: stunAddr,
		TURNServerAddr: turnAddr,
		Username:       n.config.NAT.TURNUsername,
		Password:       n.config.NAT.TURNPassword,
		Conn:           n.conn,
	})
	if err != nil {
		return fmt.Errorf("failed to create TURN client: %w", err)
	}
	if err := client.Listen(); err != nil {
		client.Close()
		return fmt.Errorf("failed to start TURN client listener: %w", err)
	}

	n.turnClient = client
	n.logger.Info("TURN client started",
		logger.String("server", n.config.NAT.TURNServer))
	return nil
}

// stunHostPort parses a STUN/TURN URI such as "stun:host:3478" or
// "turn:host?transport=udp" and returns it as "host:port", the form expected
// by the net dialing APIs and by pion's client config.
func stunHostPort(rawURI string) (string, error) {
	uri, err := stun.ParseURI(rawURI)
	if err != nil {
		return "", err
	}
	return net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port)), nil
}

// isNetTimeout reports whether err is a timeout reported by the net package,
// such as an expired read deadline.
func isNetTimeout(err error) bool {
	var netErr net.Error
	return errors.As(err, &netErr) && netErr.Timeout()
}

// GetExternalAddress returns the external IP and port discovered by Start.
// The returned IP is a copy and may be modified freely by the caller.
func (n *NATTraversal) GetExternalAddress() (net.IP, int) {
	n.mu.RLock()
	defer n.mu.RUnlock()

	var ip net.IP
	if n.externalIP != nil {
		ip = append(net.IP(nil), n.externalIP...)
	}
	return ip, n.externalPort
}

// IsRunning reports whether NAT traversal is running.
func (n *NATTraversal) IsRunning() bool {
	n.mu.RLock()
	defer n.mu.RUnlock()
	return n.running
}

// CreateHolePunch dials peerAddr over UDP and sends a small "punch" datagram
// so the local NAT creates an outbound mapping toward the peer. The dialed
// connection is returned to the caller, who owns and must close it.
//
// NOTE(review): the punch is sent from a freshly dialed socket, not from the
// STUN-mapped n.conn, so on most NATs the mapping it opens is not the one
// reported by GetExternalAddress. Confirm this is the intended behaviour.
func (n *NATTraversal) CreateHolePunch(peerAddr string) (net.Conn, error) {
	if !n.IsRunning() {
		return nil, errors.New("NAT traversal is not running")
	}

	udpAddr, err := net.ResolveUDPAddr("udp4", peerAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve peer address: %w", err)
	}

	conn, err := net.DialUDP("udp4", nil, udpAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to create UDP connection: %w", err)
	}

	if _, err := conn.Write([]byte("punch")); err != nil {
		conn.Close()
		return nil, fmt.Errorf("failed to punch hole: %w", err)
	}

	n.logger.Debug("Hole punch attempted", logger.String("peer", peerAddr))
	return conn, nil
}

// GetNATType is meant to classify the local NAT. Detection is not implemented
// yet (a real version would follow RFC 3489 / RFC 5780 behaviour discovery),
// so it always returns "Unknown" while traversal is running.
func (n *NATTraversal) GetNATType() (string, error) {
	if !n.IsRunning() {
		return "", errors.New("NAT traversal is not running")
	}
	return "Unknown", nil
}

// CreateRelayConnection is intended to create a TURN-relayed connection to
// peerAddr. It is currently a placeholder: it validates preconditions and then
// always returns a "not implemented" error.
func (n *NATTraversal) CreateRelayConnection(peerAddr string) (net.Conn, error) {
	if !n.IsRunning() {
		return nil, errors.New("NAT traversal is not running")
	}
	if n.config.NAT.TURNServer == "" {
		return nil, errors.New("TURN server not configured")
	}
	// TODO: allocate a relay via n.turnClient and connect it to peerAddr.
	return nil, errors.New("not implemented")
}