committtttt
This commit is contained in:
362
backend/internal/auth/auth.go
Normal file
362
backend/internal/auth/auth.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"customer-support-system/internal/database"
|
||||
"customer-support-system/internal/models"
|
||||
"customer-support-system/pkg/config"
|
||||
"customer-support-system/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserInactive = errors.New("user is inactive")
|
||||
ErrAccountLocked = errors.New("account is locked")
|
||||
ErrTokenExpired = errors.New("token has expired")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
)
|
||||
|
||||
// AuthService handles authentication operations
|
||||
type AuthService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAuthService creates a new authentication service
|
||||
func NewAuthService(db *gorm.DB) *AuthService {
|
||||
return &AuthService{db: db}
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a JWT token
|
||||
func (s *AuthService) Login(username, password, clientIP string) (*models.LoginResponse, error) {
|
||||
// Find user by username
|
||||
var user models.User
|
||||
if err := s.db.Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
logger.WithField("username", username).Warn("Login attempt with non-existent username")
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.Active {
|
||||
logger.WithField("user_id", user.ID).Warn("Login attempt by inactive user")
|
||||
return nil, ErrUserInactive
|
||||
}
|
||||
|
||||
// Check password
|
||||
if !user.ComparePassword(password) {
|
||||
logger.WithField("user_id", user.ID).Warn("Login attempt with incorrect password")
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := s.GenerateJWTToken(user.ID)
|
||||
if err != nil {
|
||||
logger.WithError(err).WithField("user_id", user.ID).Error("Failed to generate JWT token")
|
||||
return nil, fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
// Log successful login
|
||||
logger.LogAuthEvent("login", fmt.Sprintf("%d", user.ID), clientIP, true, nil)
|
||||
|
||||
return &models.LoginResponse{
|
||||
Token: token,
|
||||
User: user.ToSafeUser(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Register creates a new user
|
||||
func (s *AuthService) Register(req *models.CreateUserRequest) (*models.SafeUser, error) {
|
||||
// Check if username already exists
|
||||
var existingUser models.User
|
||||
if err := s.db.Where("username = ?", req.Username).First(&existingUser).Error; err == nil {
|
||||
return nil, fmt.Errorf("username already exists")
|
||||
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("failed to check username: %w", err)
|
||||
}
|
||||
|
||||
// Check if email already exists
|
||||
if err := s.db.Where("email = ?", req.Email).First(&existingUser).Error; err == nil {
|
||||
return nil, fmt.Errorf("email already exists")
|
||||
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("failed to check email: %w", err)
|
||||
}
|
||||
|
||||
// Create new user
|
||||
user := models.User{
|
||||
Username: req.Username,
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
FirstName: req.FirstName,
|
||||
LastName: req.LastName,
|
||||
Role: req.Role,
|
||||
Active: true,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&user).Error; err != nil {
|
||||
logger.WithError(err).WithField("username", req.Username).Error("Failed to create user")
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// Log user registration
|
||||
logger.WithField("user_id", user.ID).Info("User registered successfully")
|
||||
|
||||
safeUser := user.ToSafeUser()
|
||||
return &safeUser, nil
|
||||
}
|
||||
|
||||
// GenerateJWTToken generates a JWT token for a user
|
||||
func (s *AuthService) GenerateJWTToken(userID uint) (string, error) {
|
||||
// Get JWT configuration
|
||||
jwtConfig := config.AppConfig.JWT
|
||||
|
||||
// Create claims
|
||||
claims := jwt.MapClaims{
|
||||
"user_id": userID,
|
||||
"exp": time.Now().Add(time.Hour * time.Duration(jwtConfig.ExpirationHours)).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"iss": jwtConfig.Issuer,
|
||||
"aud": jwtConfig.Audience,
|
||||
}
|
||||
|
||||
// Create token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
// Sign token
|
||||
tokenString, err := token.SignedString(jwtConfig.GetJWTSigningKey())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// ValidateJWTToken validates a JWT token and returns the user ID
|
||||
func (s *AuthService) ValidateJWTToken(tokenString string) (uint, error) {
|
||||
// Get JWT configuration
|
||||
jwtConfig := config.AppConfig.JWT
|
||||
|
||||
// Parse token
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// Validate signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return jwtConfig.GetJWTSigningKey(), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
// Validate claims
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
// Check expiration
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
return 0, ErrTokenExpired
|
||||
}
|
||||
}
|
||||
|
||||
// Get user ID
|
||||
userID, ok := claims["user_id"].(float64)
|
||||
if !ok {
|
||||
return 0, ErrInvalidToken
|
||||
}
|
||||
|
||||
return uint(userID), nil
|
||||
}
|
||||
|
||||
return 0, ErrInvalidToken
|
||||
}
|
||||
|
||||
// GetUserByID returns a user by ID
|
||||
func (s *AuthService) GetUserByID(userID uint) (*models.User, error) {
|
||||
var user models.User
|
||||
if err := s.db.First(&user, userID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates a user
|
||||
func (s *AuthService) UpdateUser(userID uint, req *models.UpdateUserRequest) (*models.SafeUser, error) {
|
||||
var user models.User
|
||||
if err := s.db.First(&user, userID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Update fields
|
||||
if req.FirstName != "" {
|
||||
user.FirstName = req.FirstName
|
||||
}
|
||||
if req.LastName != "" {
|
||||
user.LastName = req.LastName
|
||||
}
|
||||
if req.Email != "" {
|
||||
user.Email = req.Email
|
||||
}
|
||||
if req.Active != nil {
|
||||
user.Active = *req.Active
|
||||
}
|
||||
if req.Role != "" {
|
||||
user.Role = req.Role
|
||||
}
|
||||
|
||||
if err := s.db.Save(&user).Error; err != nil {
|
||||
logger.WithError(err).WithField("user_id", userID).Error("Failed to update user")
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
// Log user update
|
||||
logger.WithField("user_id", userID).Info("User updated successfully")
|
||||
|
||||
safeUser := user.ToSafeUser()
|
||||
return &safeUser, nil
|
||||
}
|
||||
|
||||
// ChangePassword changes a user's password
|
||||
func (s *AuthService) ChangePassword(userID uint, req *models.ChangePasswordRequest) error {
|
||||
var user models.User
|
||||
if err := s.db.First(&user, userID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Check current password
|
||||
if !user.ComparePassword(req.CurrentPassword) {
|
||||
logger.WithField("user_id", userID).Warn("Password change attempt with incorrect current password")
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Update password
|
||||
user.Password = req.NewPassword
|
||||
if err := s.db.Save(&user).Error; err != nil {
|
||||
logger.WithError(err).WithField("user_id", userID).Error("Failed to change password")
|
||||
return fmt.Errorf("failed to change password: %w", err)
|
||||
}
|
||||
|
||||
// Log password change
|
||||
logger.WithField("user_id", userID).Info("Password changed successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HashPassword hashes a password using bcrypt
|
||||
func HashPassword(password string) (string, error) {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
return string(hashedPassword), nil
|
||||
}
|
||||
|
||||
// CheckPassword checks if a password matches a hashed password
|
||||
func CheckPassword(password, hashedPassword string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// AuthMiddleware returns a gin middleware for authentication
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Get token from Authorization header
|
||||
tokenString := c.GetHeader("Authorization")
|
||||
if tokenString == "" {
|
||||
c.JSON(401, gin.H{"error": "Authorization header is required"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Remove "Bearer " prefix
|
||||
if len(tokenString) > 7 && tokenString[:7] == "Bearer " {
|
||||
tokenString = tokenString[7:]
|
||||
}
|
||||
|
||||
// Validate token
|
||||
authService := NewAuthService(database.GetDB())
|
||||
userID, err := authService.ValidateJWTToken(tokenString)
|
||||
if err != nil {
|
||||
c.JSON(401, gin.H{"error": "Invalid or expired token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := authService.GetUserByID(userID)
|
||||
if err != nil {
|
||||
c.JSON(401, gin.H{"error": "User not found"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.Active {
|
||||
c.JSON(401, gin.H{"error": "User account is inactive"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Set user ID in context
|
||||
c.Set("userID", userID)
|
||||
c.Set("userRole", user.Role)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RoleMiddleware returns a gin middleware for role-based authorization
|
||||
func RoleMiddleware(roles ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Get user role from context
|
||||
userRole, exists := c.Get("userRole")
|
||||
if !exists {
|
||||
c.JSON(401, gin.H{"error": "User not authenticated"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user has required role
|
||||
roleStr, ok := userRole.(string)
|
||||
if !ok {
|
||||
c.JSON(500, gin.H{"error": "Invalid user role"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user role is in allowed roles
|
||||
allowed := false
|
||||
for _, role := range roles {
|
||||
if roleStr == role {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
c.JSON(403, gin.H{"error": "Insufficient permissions"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user