395 lines
11 KiB
Go
395 lines
11 KiB
Go
package ai
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
	"unicode/utf8"

	"customer-support-system/internal/models"
	"customer-support-system/pkg/config"
	"customer-support-system/pkg/logger"
)
|
|
|
|
// OpenAIRequest represents a request to the OpenAI API
|
|
type OpenAIRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []OpenAIMessage `json:"messages"`
|
|
MaxTokens int `json:"max_tokens,omitempty"`
|
|
Temperature float64 `json:"temperature,omitempty"`
|
|
TopP float64 `json:"top_p,omitempty"`
|
|
Stream bool `json:"stream,omitempty"`
|
|
}
|
|
|
|
// OpenAIMessage is a single chat turn exchanged with the OpenAI API.
type OpenAIMessage struct {
	// Role identifies the speaker; this package uses "system", "user" and "assistant".
	Role string `json:"role"`
	// Content is the text of the turn.
	Content string `json:"content"`
}
|
|
|
|
// OpenAIResponse represents a response from the OpenAI API
|
|
type OpenAIResponse struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
Choices []OpenAIChoice `json:"choices"`
|
|
Usage OpenAIUsage `json:"usage"`
|
|
}
|
|
|
|
// OpenAIChoice represents a choice in the OpenAI API response
|
|
type OpenAIChoice struct {
|
|
Index int `json:"index"`
|
|
Message OpenAIMessage `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
}
|
|
|
|
// OpenAIUsage carries the token counts the OpenAI API reports for a request.
type OpenAIUsage struct {
	// PromptTokens is the number of tokens consumed by the input messages.
	PromptTokens int `json:"prompt_tokens"`
	// CompletionTokens is the number of tokens generated in the reply.
	CompletionTokens int `json:"completion_tokens"`
	// TotalTokens is the overall token count reported by the API.
	TotalTokens int `json:"total_tokens"`
}
|
|
|
|
// OllamaRequest is the JSON body posted to Ollama's /api/generate endpoint.
type OllamaRequest struct {
	// Model names the local model to run.
	Model string `json:"model"`
	// Prompt is the complete prompt text, conversation history included.
	Prompt string `json:"prompt"`
	// Stream is always serialized (deliberately no omitempty). Ollama streams
	// its reply as newline-delimited JSON when the field is absent, so
	// "stream": false must be sent explicitly to get one JSON response object.
	Stream bool `json:"stream"`
	// Options carries model parameters such as "temperature"; omitted when empty.
	Options map[string]any `json:"options,omitempty"`
}
|
|
|
|
// OllamaResponse is one decoded JSON object returned by Ollama's
// /api/generate endpoint.
type OllamaResponse struct {
	// Model is the model that produced the text.
	Model string `json:"model"`
	// CreatedAt is the timestamp string reported by the server.
	CreatedAt string `json:"created_at"`
	// Response is the generated text carried by this object.
	Response string `json:"response"`
	// Done reports whether generation has finished.
	Done bool `json:"done"`
}
|
|
|
|
// AIService handles AI operations
|
|
type AIService struct {
|
|
openAIConfig config.APIConfig
|
|
localConfig config.APIConfig
|
|
httpClient *http.Client
|
|
}
|
|
|
|
// NewAIService creates a new AI service
|
|
func NewAIService() *AIService {
|
|
return &AIService{
|
|
openAIConfig: config.AppConfig.AI.OpenAI,
|
|
localConfig: config.AppConfig.AI.Local,
|
|
httpClient: &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
},
|
|
}
|
|
}
|
|
|
|
// QueryOpenAI sends a query to the OpenAI API
|
|
func (s *AIService) QueryOpenAI(ctx context.Context, prompt string, conversationHistory []models.Message) (string, error) {
|
|
startTime := time.Now()
|
|
|
|
// Convert conversation history to OpenAI messages
|
|
messages := s.convertToOpenAIMessages(prompt, conversationHistory)
|
|
|
|
// Create request
|
|
req := OpenAIRequest{
|
|
Model: s.openAIConfig.Model,
|
|
Messages: messages,
|
|
MaxTokens: s.openAIConfig.MaxTokens,
|
|
Temperature: s.openAIConfig.Temperature,
|
|
TopP: s.openAIConfig.TopP,
|
|
}
|
|
|
|
// Marshal request to JSON
|
|
reqBody, err := json.Marshal(req)
|
|
if err != nil {
|
|
logger.WithError(err).Error("Failed to marshal OpenAI request")
|
|
return "", fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
// Create HTTP request
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(reqBody))
|
|
if err != nil {
|
|
logger.WithError(err).Error("Failed to create OpenAI HTTP request")
|
|
return "", fmt.Errorf("failed to create HTTP request: %w", err)
|
|
}
|
|
|
|
// Set headers
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.openAIConfig.APIKey))
|
|
|
|
// Send request
|
|
resp, err := s.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
logger.WithError(err).Error("Failed to send OpenAI request")
|
|
return "", fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Read response body
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
logger.WithError(err).Error("Failed to read OpenAI response")
|
|
return "", fmt.Errorf("failed to read response: %w", err)
|
|
}
|
|
|
|
// Check status code
|
|
if resp.StatusCode != http.StatusOK {
|
|
logger.WithField("status_code", resp.StatusCode).
|
|
WithField("response", string(respBody)).
|
|
Error("OpenAI API returned non-200 status code")
|
|
return "", fmt.Errorf("OpenAI API returned status code %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
// Parse response
|
|
var openAIResp OpenAIResponse
|
|
if err := json.Unmarshal(respBody, &openAIResp); err != nil {
|
|
logger.WithError(err).Error("Failed to parse OpenAI response")
|
|
return "", fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
// Extract response text
|
|
if len(openAIResp.Choices) == 0 {
|
|
logger.Error("OpenAI API returned no choices")
|
|
return "", fmt.Errorf("OpenAI API returned no choices")
|
|
}
|
|
|
|
responseText := openAIResp.Choices[0].Message.Content
|
|
|
|
// Log AI interaction
|
|
duration := time.Since(startTime)
|
|
logger.LogAIInteraction(
|
|
s.openAIConfig.Model,
|
|
len(prompt),
|
|
len(responseText),
|
|
duration,
|
|
true,
|
|
nil,
|
|
)
|
|
|
|
return responseText, nil
|
|
}
|
|
|
|
// QueryOllama sends a query to the Ollama API
|
|
func (s *AIService) QueryOllama(ctx context.Context, prompt string, conversationHistory []models.Message) (string, error) {
|
|
startTime := time.Now()
|
|
|
|
// Create request
|
|
req := OllamaRequest{
|
|
Model: s.localConfig.Model,
|
|
Prompt: s.buildOllamaPrompt(prompt, conversationHistory),
|
|
Stream: false,
|
|
Options: map[string]any{
|
|
"temperature": s.localConfig.Temperature,
|
|
"top_p": s.localConfig.TopP,
|
|
"num_predict": s.localConfig.MaxTokens,
|
|
},
|
|
}
|
|
|
|
// Marshal request to JSON
|
|
reqBody, err := json.Marshal(req)
|
|
if err != nil {
|
|
logger.WithError(err).Error("Failed to marshal Ollama request")
|
|
return "", fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
// Create HTTP request
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/api/generate", s.localConfig.Endpoint), bytes.NewBuffer(reqBody))
|
|
if err != nil {
|
|
logger.WithError(err).Error("Failed to create Ollama HTTP request")
|
|
return "", fmt.Errorf("failed to create HTTP request: %w", err)
|
|
}
|
|
|
|
// Set headers
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
// Send request
|
|
resp, err := s.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
logger.WithError(err).Error("Failed to send Ollama request")
|
|
return "", fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Read response body
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
logger.WithError(err).Error("Failed to read Ollama response")
|
|
return "", fmt.Errorf("failed to read response: %w", err)
|
|
}
|
|
|
|
// Check status code
|
|
if resp.StatusCode != http.StatusOK {
|
|
logger.WithField("status_code", resp.StatusCode).
|
|
WithField("response", string(respBody)).
|
|
Error("Ollama API returned non-200 status code")
|
|
return "", fmt.Errorf("Ollama API returned status code %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
// Parse response
|
|
var ollamaResp OllamaResponse
|
|
if err := json.Unmarshal(respBody, &ollamaResp); err != nil {
|
|
logger.WithError(err).Error("Failed to parse Ollama response")
|
|
return "", fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
// Extract response text
|
|
responseText := ollamaResp.Response
|
|
|
|
// Log AI interaction
|
|
duration := time.Since(startTime)
|
|
logger.LogAIInteraction(
|
|
s.localConfig.Model,
|
|
len(prompt),
|
|
len(responseText),
|
|
duration,
|
|
true,
|
|
nil,
|
|
)
|
|
|
|
return responseText, nil
|
|
}
|
|
|
|
// Query sends a query to the appropriate AI model based on complexity
|
|
func (s *AIService) Query(ctx context.Context, prompt string, conversationHistory []models.Message, complexity int) (string, error) {
|
|
// Determine which AI model to use based on complexity
|
|
if complexity >= 7 { // High complexity, use OpenAI
|
|
return s.QueryOpenAI(ctx, prompt, conversationHistory)
|
|
} else { // Low to medium complexity, use local LLM
|
|
return s.QueryOllama(ctx, prompt, conversationHistory)
|
|
}
|
|
}
|
|
|
|
// AnalyzeComplexity analyzes the complexity of a prompt
|
|
func (s *AIService) AnalyzeComplexity(prompt string) int {
|
|
// Simple heuristic for complexity analysis
|
|
// In a real implementation, this would be more sophisticated
|
|
|
|
complexity := 0
|
|
|
|
// Length factor
|
|
if len(prompt) > 100 {
|
|
complexity += 2
|
|
}
|
|
if len(prompt) > 200 {
|
|
complexity += 2
|
|
}
|
|
|
|
// Question type factor
|
|
if strings.Contains(prompt, "?") {
|
|
complexity += 1
|
|
}
|
|
|
|
// Technical terms factor
|
|
technicalTerms := []string{"API", "database", "server", "code", "programming", "software", "algorithm"}
|
|
for _, term := range technicalTerms {
|
|
if strings.Contains(strings.ToLower(prompt), strings.ToLower(term)) {
|
|
complexity += 1
|
|
}
|
|
}
|
|
|
|
// Multiple questions factor
|
|
questionCount := strings.Count(prompt, "?")
|
|
if questionCount > 1 {
|
|
complexity += questionCount - 1
|
|
}
|
|
|
|
// Cap complexity at 10
|
|
if complexity > 10 {
|
|
complexity = 10
|
|
}
|
|
|
|
return complexity
|
|
}
|
|
|
|
// convertToOpenAIMessages converts conversation history to OpenAI messages
|
|
func (s *AIService) convertToOpenAIMessages(prompt string, conversationHistory []models.Message) []OpenAIMessage {
|
|
messages := []OpenAIMessage{}
|
|
|
|
// Add system message
|
|
messages = append(messages, OpenAIMessage{
|
|
Role: "system",
|
|
Content: "You are a helpful customer support assistant. Provide clear, concise, and accurate answers to customer questions.",
|
|
})
|
|
|
|
// Add conversation history
|
|
for _, msg := range conversationHistory {
|
|
role := "user"
|
|
if msg.IsAI {
|
|
role = "assistant"
|
|
}
|
|
messages = append(messages, OpenAIMessage{
|
|
Role: role,
|
|
Content: msg.Content,
|
|
})
|
|
}
|
|
|
|
// Add current prompt
|
|
messages = append(messages, OpenAIMessage{
|
|
Role: "user",
|
|
Content: prompt,
|
|
})
|
|
|
|
return messages
|
|
}
|
|
|
|
// buildOllamaPrompt builds a prompt for Ollama from conversation history
|
|
func (s *AIService) buildOllamaPrompt(prompt string, conversationHistory []models.Message) string {
|
|
var builder strings.Builder
|
|
|
|
// Add system instruction
|
|
builder.WriteString("You are a helpful customer support assistant. Provide clear, concise, and accurate answers to customer questions.\n\n")
|
|
|
|
// Add conversation history
|
|
for _, msg := range conversationHistory {
|
|
if msg.IsAI {
|
|
builder.WriteString("Assistant: ")
|
|
} else {
|
|
builder.WriteString("User: ")
|
|
}
|
|
builder.WriteString(msg.Content)
|
|
builder.WriteString("\n\n")
|
|
}
|
|
|
|
// Add current prompt
|
|
builder.WriteString("User: ")
|
|
builder.WriteString(prompt)
|
|
builder.WriteString("\n\nAssistant: ")
|
|
|
|
return builder.String()
|
|
}
|
|
|
|
// GetAvailableModels returns the available AI models
|
|
func (s *AIService) GetAvailableModels() []models.AIModel {
|
|
return []models.AIModel{
|
|
{
|
|
Name: "OpenAI GPT-4",
|
|
Type: "openai",
|
|
Model: s.openAIConfig.Model,
|
|
MaxTokens: s.openAIConfig.MaxTokens,
|
|
Temperature: s.openAIConfig.Temperature,
|
|
TopP: s.openAIConfig.TopP,
|
|
Active: true,
|
|
Priority: 2,
|
|
Description: "OpenAI's GPT-4 model for complex queries",
|
|
},
|
|
{
|
|
Name: "Local LLaMA",
|
|
Type: "local",
|
|
Model: s.localConfig.Model,
|
|
Endpoint: s.localConfig.Endpoint,
|
|
MaxTokens: s.localConfig.MaxTokens,
|
|
Temperature: s.localConfig.Temperature,
|
|
TopP: s.localConfig.TopP,
|
|
Active: true,
|
|
Priority: 1,
|
|
Description: "Local LLaMA model for simple queries",
|
|
},
|
|
}
|
|
}
|