Files
AI-Support/backend/internal/ai/ai.go
2025-09-13 06:48:55 +03:00

395 lines
11 KiB
Go

package ai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"customer-support-system/internal/models"
"customer-support-system/pkg/config"
"customer-support-system/pkg/logger"
)
// OpenAIRequest is the JSON body POSTed to the OpenAI Chat Completions
// endpoint.
//
// MaxTokens, Temperature, TopP and Stream are tagged omitempty, so their zero
// values are left out of the JSON entirely (encoding/json semantics). The
// service presumably applies its own defaults for missing fields, which means
// a deliberately configured Temperature or TopP of 0 is never transmitted.
type OpenAIRequest struct {
	Model       string          `json:"model"`
	Messages    []OpenAIMessage `json:"messages"`
	MaxTokens   int             `json:"max_tokens,omitempty"`
	Temperature float64         `json:"temperature,omitempty"`
	TopP        float64         `json:"top_p,omitempty"`
	Stream      bool            `json:"stream,omitempty"`
}

// OpenAIMessage is one chat turn: a role ("system", "user" or "assistant" in
// this package) and its text.
type OpenAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// OpenAIResponse is the decoded body of a successful Chat Completions call.
type OpenAIResponse struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"` // presumably Unix seconds — confirm against API docs
	Model   string         `json:"model"`
	Choices []OpenAIChoice `json:"choices"`
	Usage   OpenAIUsage    `json:"usage"`
}

// OpenAIChoice is a single candidate completion inside an OpenAIResponse.
type OpenAIChoice struct {
	Index        int           `json:"index"`
	Message      OpenAIMessage `json:"message"`
	FinishReason string        `json:"finish_reason"` // value set is defined by the API (e.g. "stop")
}

// OpenAIUsage reports the token accounting for a Chat Completions call.
type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// OllamaRequest is the JSON body POSTed to an Ollama server's /api/generate
// endpoint.
//
// Stream deliberately has no omitempty: Ollama treats a missing "stream"
// field as true and replies with a stream of newline-delimited JSON objects,
// so a request built with Stream: false must carry "stream":false explicitly
// to get a single JSON reply.
type OllamaRequest struct {
	Model   string         `json:"model"`
	Prompt  string         `json:"prompt"`
	Stream  bool           `json:"stream"`
	Options map[string]any `json:"options,omitempty"` // model parameters, e.g. temperature, top_p, num_predict
}

// OllamaResponse is one JSON object returned by /api/generate. For a
// non-streamed request it carries the full reply in Response with Done set.
type OllamaResponse struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	Response  string `json:"response"`
	Done      bool   `json:"done"`
}
// AIService sends customer-support prompts either to the OpenAI Chat
// Completions API or to a local Ollama server and returns the generated reply.
type AIService struct {
	openAIConfig config.APIConfig // model and sampling settings for OpenAI
	localConfig  config.APIConfig // model, sampling settings and Endpoint for the local Ollama server
	httpClient   *http.Client     // single client reused for every outbound request
}
// NewAIService builds an AIService from the process-wide configuration.
//
// The OpenAI and Local APIConfig values are copied out of config.AppConfig.AI
// at call time, so the configuration presumably has to be loaded before this
// is called. Every request made by the service goes through one http.Client
// whose overall per-request time limit is 30 seconds.
func NewAIService() *AIService {
	client := &http.Client{Timeout: 30 * time.Second}
	aiCfg := config.AppConfig.AI

	return &AIService{
		openAIConfig: aiCfg.OpenAI,
		localConfig:  aiCfg.Local,
		httpClient:   client,
	}
}
// QueryOpenAI sends prompt, preceded by conversationHistory, to the OpenAI
// Chat Completions endpoint and returns the content of the first choice.
//
// Model, max tokens, temperature and top_p come from the OpenAI section of the
// service configuration, and the configured API key is sent as a Bearer token.
// Any failure is logged and returned as an error. This covers building or
// sending the request, reading the body, a non-200 status, an unparsable body
// and an empty choices list. A successful call is recorded via
// logger.LogAIInteraction together with its elapsed time.
func (s *AIService) QueryOpenAI(ctx context.Context, prompt string, conversationHistory []models.Message) (string, error) {
	const chatCompletionsURL = "https://api.openai.com/v1/chat/completions"
	began := time.Now()

	payload, err := json.Marshal(OpenAIRequest{
		Model:       s.openAIConfig.Model,
		Messages:    s.convertToOpenAIMessages(prompt, conversationHistory),
		MaxTokens:   s.openAIConfig.MaxTokens,
		Temperature: s.openAIConfig.Temperature,
		TopP:        s.openAIConfig.TopP,
	})
	if err != nil {
		logger.WithError(err).Error("Failed to marshal OpenAI request")
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, chatCompletionsURL, bytes.NewReader(payload))
	if err != nil {
		logger.WithError(err).Error("Failed to create OpenAI HTTP request")
		return "", fmt.Errorf("failed to create HTTP request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.openAIConfig.APIKey))

	resp, err := s.httpClient.Do(httpReq)
	if err != nil {
		logger.WithError(err).Error("Failed to send OpenAI request")
		return "", fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.WithError(err).Error("Failed to read OpenAI response")
		return "", fmt.Errorf("failed to read response: %w", err)
	}

	// The raw body is kept in both the log and the error so API-side error
	// payloads are visible to the caller.
	if resp.StatusCode != http.StatusOK {
		logger.WithField("status_code", resp.StatusCode).
			WithField("response", string(body)).
			Error("OpenAI API returned non-200 status code")
		return "", fmt.Errorf("OpenAI API returned status code %d: %s", resp.StatusCode, string(body))
	}

	var parsed OpenAIResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		logger.WithError(err).Error("Failed to parse OpenAI response")
		return "", fmt.Errorf("failed to parse response: %w", err)
	}
	if len(parsed.Choices) == 0 {
		logger.Error("OpenAI API returned no choices")
		return "", fmt.Errorf("OpenAI API returned no choices")
	}

	answer := parsed.Choices[0].Message.Content
	logger.LogAIInteraction(s.openAIConfig.Model, len(prompt), len(answer), time.Since(began), true, nil)
	return answer, nil
}
// QueryOllama sends prompt, preceded by conversationHistory, to the local
// Ollama server's /api/generate endpoint and returns the generated text.
//
// Model, temperature, top_p and the token limit (sent as "num_predict") come
// from the local section of the service configuration. A trailing "/" on the
// configured Endpoint is tolerated.
//
// The request asks for a single, non-streamed reply. The body is nevertheless
// decoded as a sequence of JSON objects, so a newline-delimited streamed reply
// also works: its "response" fragments are concatenated in order. Any failure
// is logged and returned as an error. A successful call is recorded via
// logger.LogAIInteraction together with its elapsed time.
func (s *AIService) QueryOllama(ctx context.Context, prompt string, conversationHistory []models.Message) (string, error) {
	startTime := time.Now()

	req := OllamaRequest{
		Model:  s.localConfig.Model,
		Prompt: s.buildOllamaPrompt(prompt, conversationHistory),
		Stream: false,
		Options: map[string]any{
			"temperature": s.localConfig.Temperature,
			"top_p":       s.localConfig.TopP,
			"num_predict": s.localConfig.MaxTokens,
		},
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		logger.WithError(err).Error("Failed to marshal Ollama request")
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}

	// Trim trailing slashes. Otherwise an endpoint configured as
	// "http://host:11434/" yields ".../​/api/generate", which the server may
	// fail to route.
	endpoint := strings.TrimRight(s.localConfig.Endpoint, "/") + "/api/generate"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
	if err != nil {
		logger.WithError(err).Error("Failed to create Ollama HTTP request")
		return "", fmt.Errorf("failed to create HTTP request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(httpReq)
	if err != nil {
		logger.WithError(err).Error("Failed to send Ollama request")
		return "", fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.WithError(err).Error("Failed to read Ollama response")
		return "", fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		logger.WithField("status_code", resp.StatusCode).
			WithField("response", string(respBody)).
			Error("Ollama API returned non-200 status code")
		return "", fmt.Errorf("Ollama API returned status code %d: %s", resp.StatusCode, string(respBody))
	}

	responseText, err := decodeOllamaBody(respBody)
	if err != nil {
		logger.WithError(err).Error("Failed to parse Ollama response")
		return "", fmt.Errorf("failed to parse response: %w", err)
	}

	logger.LogAIInteraction(
		s.localConfig.Model,
		len(prompt),
		len(responseText),
		time.Since(startTime),
		true,
		nil,
	)
	return responseText, nil
}

// decodeOllamaBody concatenates the "response" fields of the JSON objects in
// body. It stops after the first object whose "done" flag is set, or at the
// end of input. This accepts a single non-streamed object as well as
// newline-delimited streamed objects.
//
// An error is returned in three cases:
//   - the body is empty;
//   - the JSON is malformed;
//   - an object carries a non-empty "error" field.
func decodeOllamaBody(body []byte) (string, error) {
	dec := json.NewDecoder(bytes.NewReader(body))
	var text strings.Builder
	decoded := 0
	for {
		// Fresh value per object so fields from an earlier chunk never leak
		// into a later one that omits them.
		var chunk struct {
			OllamaResponse
			Error string `json:"error"`
		}
		err := dec.Decode(&chunk)
		if err == io.EOF {
			if decoded == 0 {
				return "", fmt.Errorf("empty response body")
			}
			break
		}
		if err != nil {
			return "", err
		}
		decoded++
		if chunk.Error != "" {
			return "", fmt.Errorf("ollama error: %s", chunk.Error)
		}
		text.WriteString(chunk.Response)
		if chunk.Done {
			break
		}
	}
	return text.String(), nil
}
// openAIComplexityThreshold is the lowest complexity score at which Query
// routes a prompt to OpenAI rather than to the local model.
const openAIComplexityThreshold = 7

// Query answers prompt using a backend chosen by complexity. Scores of
// openAIComplexityThreshold (7) or more go to OpenAI via QueryOpenAI. Lower
// scores go to the local Ollama model via QueryOllama. The chosen backend's
// reply and error are returned unchanged; there is no fallback to the other
// backend on failure.
func (s *AIService) Query(ctx context.Context, prompt string, conversationHistory []models.Message, complexity int) (string, error) {
	if complexity >= openAIComplexityThreshold {
		return s.QueryOpenAI(ctx, prompt, conversationHistory)
	}
	return s.QueryOllama(ctx, prompt, conversationHistory)
}
// AnalyzeComplexity scores prompt on a 0–10 scale using a simple heuristic.
// The score is the sum of the following, capped at 10:
//   - 2 points if the prompt is longer than 100 characters, and 2 more if it
//     is longer than 200 characters (characters are Unicode code points, not
//     bytes);
//   - 1 point for every "?" in the prompt;
//   - 1 point for each technical term from a fixed list that appears anywhere
//     in the prompt, ignoring case.
func (s *AIService) AnalyzeComplexity(prompt string) int {
	complexity := 0

	// Count characters rather than bytes. len(prompt) would count each
	// Cyrillic or CJK character as 2–3, inflating scores for non-English
	// prompts. The compiler lowers len([]rune(s)) to a rune count without
	// allocating the slice.
	length := len([]rune(prompt))
	if length > 100 {
		complexity += 2
	}
	if length > 200 {
		complexity += 2
	}

	// One point for asking a question at all plus one per additional
	// question equals exactly one point per question mark.
	complexity += strings.Count(prompt, "?")

	// Plain substring matching, so a term inside a longer word (e.g. "api"
	// in "capital") also counts. The prompt is lowercased once, outside the
	// loop.
	lower := strings.ToLower(prompt)
	technicalTerms := []string{"api", "database", "server", "code", "programming", "software", "algorithm"}
	for _, term := range technicalTerms {
		if strings.Contains(lower, term) {
			complexity++
		}
	}

	if complexity > 10 {
		complexity = 10
	}
	return complexity
}
// convertToOpenAIMessages lays out the chat transcript sent to OpenAI, in this
// order:
//   - a fixed "system" instruction;
//   - every entry of conversationHistory, with role "assistant" when IsAI is
//     set and "user" otherwise;
//   - prompt, as a final "user" message.
//
// The returned slice is never nil.
func (s *AIService) convertToOpenAIMessages(prompt string, conversationHistory []models.Message) []OpenAIMessage {
	out := make([]OpenAIMessage, 0, len(conversationHistory)+2)
	out = append(out, OpenAIMessage{
		Role:    "system",
		Content: "You are a helpful customer support assistant. Provide clear, concise, and accurate answers to customer questions.",
	})

	for i := range conversationHistory {
		speaker := "user"
		if conversationHistory[i].IsAI {
			speaker = "assistant"
		}
		out = append(out, OpenAIMessage{Role: speaker, Content: conversationHistory[i].Content})
	}

	return append(out, OpenAIMessage{Role: "user", Content: prompt})
}
// buildOllamaPrompt flattens the conversation into one plain-text prompt for
// Ollama's /api/generate. The prompt is built as follows:
//   - it starts with the fixed support-assistant instruction;
//   - each history entry follows as "Assistant: <text>" (when IsAI is set) or
//     "User: <text>";
//   - prompt is added last as a "User: " turn;
//   - the text ends with an open "Assistant: " cue for the model to continue.
//
// Every part is separated by a blank line.
func (s *AIService) buildOllamaPrompt(prompt string, conversationHistory []models.Message) string {
	var sb strings.Builder
	sb.WriteString("You are a helpful customer support assistant. Provide clear, concise, and accurate answers to customer questions.\n\n")

	for _, turn := range conversationHistory {
		label := "User: "
		if turn.IsAI {
			label = "Assistant: "
		}
		sb.WriteString(label)
		sb.WriteString(turn.Content)
		sb.WriteString("\n\n")
	}

	sb.WriteString("User: ")
	sb.WriteString(prompt)
	sb.WriteString("\n\nAssistant: ")
	return sb.String()
}
// GetAvailableModels describes the two backends this service can route to:
//   - the hosted OpenAI model (Priority 2);
//   - the local model served through the configured Endpoint (Priority 1).
//
// Model and sampling settings are copied from the service configuration, and
// both entries are marked Active.
func (s *AIService) GetAvailableModels() []models.AIModel {
	// NOTE(review): Name and Description are fixed strings that do not follow
	// the configured Model, so they may not match the model actually used.
	hosted := models.AIModel{
		Name:        "OpenAI GPT-4",
		Type:        "openai",
		Model:       s.openAIConfig.Model,
		MaxTokens:   s.openAIConfig.MaxTokens,
		Temperature: s.openAIConfig.Temperature,
		TopP:        s.openAIConfig.TopP,
		Active:      true,
		Priority:    2,
		Description: "OpenAI's GPT-4 model for complex queries",
	}

	local := models.AIModel{
		Name:        "Local LLaMA",
		Type:        "local",
		Model:       s.localConfig.Model,
		Endpoint:    s.localConfig.Endpoint,
		MaxTokens:   s.localConfig.MaxTokens,
		Temperature: s.localConfig.Temperature,
		TopP:        s.localConfig.TopP,
		Active:      true,
		Priority:    1,
		Description: "Local LLaMA model for simple queries",
	}

	return []models.AIModel{hosted, local}
}