363 lines
9.6 KiB
Go
363 lines
9.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"gorm.io/gorm"
|
|
|
|
"customer-support-system/internal/database"
|
|
"customer-support-system/internal/models"
|
|
"customer-support-system/pkg/config"
|
|
"customer-support-system/pkg/logger"
|
|
)
|
|
|
|
// Sentinel errors produced by the authentication layer. Callers should match
// them with errors.Is rather than comparing error strings.
var (
	// ErrInvalidCredentials reports that the supplied password did not match
	// the stored one.
	ErrInvalidCredentials = errors.New("invalid credentials")

	// ErrUserNotFound reports that no user matched the requested username or ID.
	ErrUserNotFound = errors.New("user not found")

	// ErrUserInactive reports that the user exists but is marked inactive.
	ErrUserInactive = errors.New("user is inactive")

	// ErrAccountLocked reports a locked account.
	// NOTE(review): nothing in this file returns it; confirm whether lockout
	// is implemented elsewhere.
	ErrAccountLocked = errors.New("account is locked")

	// ErrTokenExpired reports that a JWT's expiry time has passed.
	ErrTokenExpired = errors.New("token has expired")

	// ErrInvalidToken reports a JWT that failed validation or lacks the
	// expected claims.
	ErrInvalidToken = errors.New("invalid token")
)
|
|
|
|
// AuthService handles authentication operations
|
|
type AuthService struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewAuthService creates a new authentication service
|
|
func NewAuthService(db *gorm.DB) *AuthService {
|
|
return &AuthService{db: db}
|
|
}
|
|
|
|
// Login authenticates a user and returns a JWT token
|
|
func (s *AuthService) Login(username, password, clientIP string) (*models.LoginResponse, error) {
|
|
// Find user by username
|
|
var user models.User
|
|
if err := s.db.Where("username = ?", username).First(&user).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
logger.WithField("username", username).Warn("Login attempt with non-existent username")
|
|
return nil, ErrUserNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to find user: %w", err)
|
|
}
|
|
|
|
// Check if user is active
|
|
if !user.Active {
|
|
logger.WithField("user_id", user.ID).Warn("Login attempt by inactive user")
|
|
return nil, ErrUserInactive
|
|
}
|
|
|
|
// Check password
|
|
if !user.ComparePassword(password) {
|
|
logger.WithField("user_id", user.ID).Warn("Login attempt with incorrect password")
|
|
return nil, ErrInvalidCredentials
|
|
}
|
|
|
|
// Generate JWT token
|
|
token, err := s.GenerateJWTToken(user.ID)
|
|
if err != nil {
|
|
logger.WithError(err).WithField("user_id", user.ID).Error("Failed to generate JWT token")
|
|
return nil, fmt.Errorf("failed to generate token: %w", err)
|
|
}
|
|
|
|
// Log successful login
|
|
logger.LogAuthEvent("login", fmt.Sprintf("%d", user.ID), clientIP, true, nil)
|
|
|
|
return &models.LoginResponse{
|
|
Token: token,
|
|
User: user.ToSafeUser(),
|
|
}, nil
|
|
}
|
|
|
|
// Register creates a new user
|
|
func (s *AuthService) Register(req *models.CreateUserRequest) (*models.SafeUser, error) {
|
|
// Check if username already exists
|
|
var existingUser models.User
|
|
if err := s.db.Where("username = ?", req.Username).First(&existingUser).Error; err == nil {
|
|
return nil, fmt.Errorf("username already exists")
|
|
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, fmt.Errorf("failed to check username: %w", err)
|
|
}
|
|
|
|
// Check if email already exists
|
|
if err := s.db.Where("email = ?", req.Email).First(&existingUser).Error; err == nil {
|
|
return nil, fmt.Errorf("email already exists")
|
|
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, fmt.Errorf("failed to check email: %w", err)
|
|
}
|
|
|
|
// Create new user
|
|
user := models.User{
|
|
Username: req.Username,
|
|
Email: req.Email,
|
|
Password: req.Password,
|
|
FirstName: req.FirstName,
|
|
LastName: req.LastName,
|
|
Role: req.Role,
|
|
Active: true,
|
|
}
|
|
|
|
if err := s.db.Create(&user).Error; err != nil {
|
|
logger.WithError(err).WithField("username", req.Username).Error("Failed to create user")
|
|
return nil, fmt.Errorf("failed to create user: %w", err)
|
|
}
|
|
|
|
// Log user registration
|
|
logger.WithField("user_id", user.ID).Info("User registered successfully")
|
|
|
|
safeUser := user.ToSafeUser()
|
|
return &safeUser, nil
|
|
}
|
|
|
|
// GenerateJWTToken generates a JWT token for a user
|
|
func (s *AuthService) GenerateJWTToken(userID uint) (string, error) {
|
|
// Get JWT configuration
|
|
jwtConfig := config.AppConfig.JWT
|
|
|
|
// Create claims
|
|
claims := jwt.MapClaims{
|
|
"user_id": userID,
|
|
"exp": time.Now().Add(time.Hour * time.Duration(jwtConfig.ExpirationHours)).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"iss": jwtConfig.Issuer,
|
|
"aud": jwtConfig.Audience,
|
|
}
|
|
|
|
// Create token
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
|
|
// Sign token
|
|
tokenString, err := token.SignedString(jwtConfig.GetJWTSigningKey())
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sign token: %w", err)
|
|
}
|
|
|
|
return tokenString, nil
|
|
}
|
|
|
|
// ValidateJWTToken validates a JWT token and returns the user ID
|
|
func (s *AuthService) ValidateJWTToken(tokenString string) (uint, error) {
|
|
// Get JWT configuration
|
|
jwtConfig := config.AppConfig.JWT
|
|
|
|
// Parse token
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
// Validate signing method
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return jwtConfig.GetJWTSigningKey(), nil
|
|
})
|
|
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to parse token: %w", err)
|
|
}
|
|
|
|
// Validate claims
|
|
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
|
// Check expiration
|
|
if exp, ok := claims["exp"].(float64); ok {
|
|
if time.Now().Unix() > int64(exp) {
|
|
return 0, ErrTokenExpired
|
|
}
|
|
}
|
|
|
|
// Get user ID
|
|
userID, ok := claims["user_id"].(float64)
|
|
if !ok {
|
|
return 0, ErrInvalidToken
|
|
}
|
|
|
|
return uint(userID), nil
|
|
}
|
|
|
|
return 0, ErrInvalidToken
|
|
}
|
|
|
|
// GetUserByID returns a user by ID
|
|
func (s *AuthService) GetUserByID(userID uint) (*models.User, error) {
|
|
var user models.User
|
|
if err := s.db.First(&user, userID).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// UpdateUser updates a user
|
|
func (s *AuthService) UpdateUser(userID uint, req *models.UpdateUserRequest) (*models.SafeUser, error) {
|
|
var user models.User
|
|
if err := s.db.First(&user, userID).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
|
|
// Update fields
|
|
if req.FirstName != "" {
|
|
user.FirstName = req.FirstName
|
|
}
|
|
if req.LastName != "" {
|
|
user.LastName = req.LastName
|
|
}
|
|
if req.Email != "" {
|
|
user.Email = req.Email
|
|
}
|
|
if req.Active != nil {
|
|
user.Active = *req.Active
|
|
}
|
|
if req.Role != "" {
|
|
user.Role = req.Role
|
|
}
|
|
|
|
if err := s.db.Save(&user).Error; err != nil {
|
|
logger.WithError(err).WithField("user_id", userID).Error("Failed to update user")
|
|
return nil, fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
|
|
// Log user update
|
|
logger.WithField("user_id", userID).Info("User updated successfully")
|
|
|
|
safeUser := user.ToSafeUser()
|
|
return &safeUser, nil
|
|
}
|
|
|
|
// ChangePassword changes a user's password
|
|
func (s *AuthService) ChangePassword(userID uint, req *models.ChangePasswordRequest) error {
|
|
var user models.User
|
|
if err := s.db.First(&user, userID).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrUserNotFound
|
|
}
|
|
return fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
|
|
// Check current password
|
|
if !user.ComparePassword(req.CurrentPassword) {
|
|
logger.WithField("user_id", userID).Warn("Password change attempt with incorrect current password")
|
|
return ErrInvalidCredentials
|
|
}
|
|
|
|
// Update password
|
|
user.Password = req.NewPassword
|
|
if err := s.db.Save(&user).Error; err != nil {
|
|
logger.WithError(err).WithField("user_id", userID).Error("Failed to change password")
|
|
return fmt.Errorf("failed to change password: %w", err)
|
|
}
|
|
|
|
// Log password change
|
|
logger.WithField("user_id", userID).Info("Password changed successfully")
|
|
|
|
return nil
|
|
}
|
|
|
|
// HashPassword hashes a password using bcrypt
|
|
func HashPassword(password string) (string, error) {
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to hash password: %w", err)
|
|
}
|
|
return string(hashedPassword), nil
|
|
}
|
|
|
|
// CheckPassword checks if a password matches a hashed password
|
|
func CheckPassword(password, hashedPassword string) bool {
|
|
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
|
return err == nil
|
|
}
|
|
|
|
// AuthMiddleware returns a gin middleware for authentication
|
|
func AuthMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Get token from Authorization header
|
|
tokenString := c.GetHeader("Authorization")
|
|
if tokenString == "" {
|
|
c.JSON(401, gin.H{"error": "Authorization header is required"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Remove "Bearer " prefix
|
|
if len(tokenString) > 7 && tokenString[:7] == "Bearer " {
|
|
tokenString = tokenString[7:]
|
|
}
|
|
|
|
// Validate token
|
|
authService := NewAuthService(database.GetDB())
|
|
userID, err := authService.ValidateJWTToken(tokenString)
|
|
if err != nil {
|
|
c.JSON(401, gin.H{"error": "Invalid or expired token"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Get user
|
|
user, err := authService.GetUserByID(userID)
|
|
if err != nil {
|
|
c.JSON(401, gin.H{"error": "User not found"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Check if user is active
|
|
if !user.Active {
|
|
c.JSON(401, gin.H{"error": "User account is inactive"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Set user ID in context
|
|
c.Set("userID", userID)
|
|
c.Set("userRole", user.Role)
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RoleMiddleware returns a gin middleware for role-based authorization
|
|
func RoleMiddleware(roles ...string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Get user role from context
|
|
userRole, exists := c.Get("userRole")
|
|
if !exists {
|
|
c.JSON(401, gin.H{"error": "User not authenticated"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Check if user has required role
|
|
roleStr, ok := userRole.(string)
|
|
if !ok {
|
|
c.JSON(500, gin.H{"error": "Invalid user role"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Check if user role is in allowed roles
|
|
allowed := false
|
|
for _, role := range roles {
|
|
if roleStr == role {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !allowed {
|
|
c.JSON(403, gin.H{"error": "Insufficient permissions"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|