package security import ( "crypto/aes" "crypto/cipher" "crypto/rand" "crypto/sha256" "encoding/base64" "fmt" "io" "os" "path/filepath" "syscall" "golang.org/x/term" ) // KeyStore handles secure storage and retrieval of API keys type KeyStore struct { configDir string } // NewKeyStore creates a new KeyStore instance func NewKeyStore() (*KeyStore, error) { homeDir, err := os.UserHomeDir() if err != nil { return nil, fmt.Errorf("failed to get user home directory: %w", err) } configDir := filepath.Join(homeDir, ".config", "autotlp") if err := os.MkdirAll(configDir, 0700); err != nil { return nil, fmt.Errorf("failed to create config directory: %w", err) } return &KeyStore{ configDir: configDir, }, nil } // StoreAPIKey securely stores an API key func (ks *KeyStore) StoreAPIKey(provider, apiKey string) error { if apiKey == "" { return fmt.Errorf("API key cannot be empty") } // Get or create master password masterPassword, err := ks.getMasterPassword() if err != nil { return fmt.Errorf("failed to get master password: %w", err) } // Encrypt the API key encryptedKey, err := ks.encrypt(apiKey, masterPassword) if err != nil { return fmt.Errorf("failed to encrypt API key: %w", err) } // Store encrypted key keyFile := filepath.Join(ks.configDir, provider+".key") if err := os.WriteFile(keyFile, []byte(encryptedKey), 0600); err != nil { return fmt.Errorf("failed to write key file: %w", err) } return nil } // RetrieveAPIKey retrieves and decrypts an API key func (ks *KeyStore) RetrieveAPIKey(provider string) (string, error) { keyFile := filepath.Join(ks.configDir, provider+".key") // Check if key file exists if _, err := os.Stat(keyFile); os.IsNotExist(err) { return "", fmt.Errorf("API key not found for provider: %s", provider) } // Read encrypted key encryptedKey, err := os.ReadFile(keyFile) if err != nil { return "", fmt.Errorf("failed to read key file: %w", err) } // Get master password masterPassword, err := ks.getMasterPassword() if err != nil { return "", fmt.Errorf("failed to get master password: %w", err) } // Decrypt the API key apiKey, err := ks.decrypt(string(encryptedKey), masterPassword) if err != nil { return "", fmt.Errorf("failed to decrypt API key: %w", err) } return apiKey, nil } // DeleteAPIKey removes a stored API key func (ks *KeyStore) DeleteAPIKey(provider string) error { keyFile := filepath.Join(ks.configDir, provider+".key") if err := os.Remove(keyFile); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to delete key file: %w", err) } return nil } // ListProviders returns a list of providers with stored keys func (ks *KeyStore) ListProviders() ([]string, error) { files, err := os.ReadDir(ks.configDir) if err != nil { return nil, fmt.Errorf("failed to read config directory: %w", err) } var providers []string for _, file := range files { if !file.IsDir() && filepath.Ext(file.Name()) == ".key" { provider := file.Name()[:len(file.Name())-4] // Remove .key extension providers = append(providers, provider) } } return providers, nil } // getMasterPassword gets or creates a master password for encryption func (ks *KeyStore) getMasterPassword() (string, error) { passwordFile := filepath.Join(ks.configDir, ".master") // Check if master password file exists if _, err := os.Stat(passwordFile); os.IsNotExist(err) { // Create new master password fmt.Print("Create a master password to secure your API keys: ") password, err := ks.readPassword() if err != nil { return "", fmt.Errorf("failed to read password: %w", err) } fmt.Print("Confirm master password: ") confirmPassword, err := ks.readPassword() if err != nil { return "", fmt.Errorf("failed to read confirmation password: %w", err) } if password != confirmPassword { return "", fmt.Errorf("passwords do not match") } // Hash and store the password hashedPassword := ks.hashPassword(password) if err := os.WriteFile(passwordFile, []byte(hashedPassword), 0600); err != nil { return "", fmt.Errorf("failed to store master password: %w", err) } return password, nil } // Master password exists, prompt for it fmt.Print("Enter master password: ") password, err := ks.readPassword() if err != nil { return "", fmt.Errorf("failed to read password: %w", err) } // Verify password storedHash, err := os.ReadFile(passwordFile) if err != nil { return "", fmt.Errorf("failed to read stored password hash: %w", err) } if ks.hashPassword(password) != string(storedHash) { return "", fmt.Errorf("incorrect master password") } return password, nil } // readPassword reads a password from stdin without echoing func (ks *KeyStore) readPassword() (string, error) { password, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { return "", err } fmt.Println() // Add newline after password input return string(password), nil } // hashPassword creates a hash of the password for verification func (ks *KeyStore) hashPassword(password string) string { hash := sha256.Sum256([]byte(password)) return base64.StdEncoding.EncodeToString(hash[:]) } // encrypt encrypts data using AES-GCM func (ks *KeyStore) encrypt(plaintext, password string) (string, error) { // Create cipher key := sha256.Sum256([]byte(password)) block, err := aes.NewCipher(key[:]) if err != nil { return "", err } // Create GCM gcm, err := cipher.NewGCM(block) if err != nil { return "", err } // Generate nonce nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return "", err } // Encrypt ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil) return base64.StdEncoding.EncodeToString(ciphertext), nil } // decrypt decrypts data using AES-GCM func (ks *KeyStore) decrypt(ciphertext, password string) (string, error) { // Decode base64 data, err := base64.StdEncoding.DecodeString(ciphertext) if err != nil { return "", err } // Create cipher key := sha256.Sum256([]byte(password)) block, err := aes.NewCipher(key[:]) if err != nil { return "", err } // Create GCM gcm, err := cipher.NewGCM(block) if err != nil { return "", err } // Extract nonce nonceSize := gcm.NonceSize() if len(data) < nonceSize { return "", fmt.Errorf("ciphertext too short") } nonce, ciphertext_bytes := data[:nonceSize], data[nonceSize:] // Decrypt plaintext, err := gcm.Open(nil, nonce, ciphertext_bytes, nil) if err != nil { return "", err } return string(plaintext), nil } // SecureInput prompts for secure input without echoing func SecureInput(prompt string) (string, error) { fmt.Print(prompt) input, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { return "", err } fmt.Println() // Add newline return string(input), nil }