Files
WiseTLP/internal/security/keystore.go
2025-09-16 14:27:34 +03:00

268 lines
6.7 KiB
Go

package security
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"os"
"path/filepath"
"syscall"
"golang.org/x/term"
)
// KeyStore keeps provider API keys on disk, encrypted under a user-chosen
// master password.
type KeyStore struct {
	// configDir holds one "<provider>.key" file per stored key plus the
	// ".master" password-check file.
	configDir string
}
// NewKeyStore builds a KeyStore whose files live in <home>/.config/autotlp,
// where <home> is taken from os.UserHomeDir.
//
// The directory is created with mode 0700 (subject to umask) when missing;
// os.MkdirAll leaves an already-existing directory, including its
// permissions, untouched.
//
// NOTE(review): the directory is named "autotlp" rather than after the
// project; presumably kept for compatibility with existing installs. Confirm
// before renaming, since stored keys and the master-password file live there.
func NewKeyStore() (*KeyStore, error) {
	home, homeErr := os.UserHomeDir()
	if homeErr != nil {
		return nil, fmt.Errorf("failed to get user home directory: %w", homeErr)
	}

	dir := filepath.Join(home, ".config", "autotlp")
	if mkErr := os.MkdirAll(dir, 0700); mkErr != nil {
		return nil, fmt.Errorf("failed to create config directory: %w", mkErr)
	}

	store := &KeyStore{configDir: dir}
	return store, nil
}
// StoreAPIKey encrypts apiKey with the master password and writes the result
// to <configDir>/<provider>.key.
//
// provider must be a single, non-empty path element. It is checked before
// the master password is requested, so a bad name never triggers a prompt or
// touches the filesystem.
//
// An error is returned for an empty apiKey, an invalid provider, or any
// failure while obtaining the master password, encrypting, or writing.
func (ks *KeyStore) StoreAPIKey(provider, apiKey string) error {
	if apiKey == "" {
		return fmt.Errorf("API key cannot be empty")
	}
	// Rejects empty names, a lone separator, and anything with a directory
	// component. Without this, "../x" or "a/b" would place the key file
	// outside configDir.
	if provider == "" || provider == string(filepath.Separator) || filepath.Base(provider) != provider {
		return fmt.Errorf("invalid provider name: %q", provider)
	}

	masterPassword, err := ks.getMasterPassword()
	if err != nil {
		return fmt.Errorf("failed to get master password: %w", err)
	}

	encryptedKey, err := ks.encrypt(apiKey, masterPassword)
	if err != nil {
		return fmt.Errorf("failed to encrypt API key: %w", err)
	}

	keyFile := filepath.Join(ks.configDir, provider+".key")
	// The 0600 mode applies only when the file is created; os.WriteFile keeps
	// the permissions of an existing file.
	if err := os.WriteFile(keyFile, []byte(encryptedKey), 0600); err != nil {
		return fmt.Errorf("failed to write key file: %w", err)
	}
	return nil
}
// RetrieveAPIKey returns the decrypted API key stored for provider.
//
// provider must be a single, non-empty path element. The key file is read
// before the master password is requested, so a missing key is reported
// without prompting.
//
// An error is returned for an invalid provider, a missing or unreadable key
// file, a failure obtaining the master password, or a decryption failure.
func (ks *KeyStore) RetrieveAPIKey(provider string) (string, error) {
	// Rejects empty names, a lone separator, and anything with a directory
	// component (e.g. "../x"), so reads stay inside configDir.
	if provider == "" || provider == string(filepath.Separator) || filepath.Base(provider) != provider {
		return "", fmt.Errorf("invalid provider name: %q", provider)
	}
	keyFile := filepath.Join(ks.configDir, provider+".key")

	// Reading directly, instead of Stat followed by ReadFile, avoids a race
	// between the two calls and reports other errors (e.g. permissions) with
	// their real cause.
	encryptedKey, err := os.ReadFile(keyFile)
	if os.IsNotExist(err) {
		return "", fmt.Errorf("API key not found for provider: %s", provider)
	}
	if err != nil {
		return "", fmt.Errorf("failed to read key file: %w", err)
	}

	masterPassword, err := ks.getMasterPassword()
	if err != nil {
		return "", fmt.Errorf("failed to get master password: %w", err)
	}

	apiKey, err := ks.decrypt(string(encryptedKey), masterPassword)
	if err != nil {
		return "", fmt.Errorf("failed to decrypt API key: %w", err)
	}
	return apiKey, nil
}
// DeleteAPIKey removes the stored key file for provider. Deleting a provider
// that has no stored key is not an error.
//
// provider must be a single, non-empty path element, so the removed path is
// always directly inside the config directory.
func (ks *KeyStore) DeleteAPIKey(provider string) error {
	// Rejects empty names, a lone separator, and anything with a directory
	// component. Without this, "../x" would delete x.key outside configDir.
	if provider == "" || provider == string(filepath.Separator) || filepath.Base(provider) != provider {
		return fmt.Errorf("invalid provider name: %q", provider)
	}

	keyFile := filepath.Join(ks.configDir, provider+".key")
	if err := os.Remove(keyFile); err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("failed to delete key file: %w", err)
	}
	return nil
}
// ListProviders returns the provider names that have a stored key, i.e. the
// file-name stems of the non-directory "*.key" entries in the config
// directory.
//
// Names come back in os.ReadDir order, which is sorted by file name. The
// result is a nil slice when no keys are stored. An error is returned only
// if the directory cannot be read.
func (ks *KeyStore) ListProviders() ([]string, error) {
	const keyExt = ".key"

	entries, err := os.ReadDir(ks.configDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read config directory: %w", err)
	}

	var providers []string
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || filepath.Ext(name) != keyExt {
			continue
		}
		// Strip the extension using its real length rather than a hard-coded 4.
		providers = append(providers, name[:len(name)-len(keyExt)])
	}
	return providers, nil
}
// Master-password verifier format.
//
// Older versions stored base64(SHA-256(password)) in .master. That is
// byte-for-byte the AES key encrypt/decrypt derive from the password, so
// anyone able to read .master could decrypt every stored API key. The v2
// format instead stores:
//
//	"v2$" + base64(salt || SHA-256(salt || domain || password))
//
// with a random salt of masterVerifierSaltLen bytes.
const (
	masterVerifierPrefix  = "v2$"
	masterVerifierSaltLen = 16
	masterVerifierDomain  = "wisetlp/master-password-verifier/v2\x00"
)

// getMasterPassword prompts for the master password and returns it.
//
// If the .master file does not exist yet, the user is asked to create and
// confirm a new, non-empty password, and a v2 verifier for it is written.
// Otherwise the user is prompted once and the answer is checked against the
// stored verifier.
//
// A legacy (pre-v2) verifier is still accepted. After a successful check it
// is replaced with a v2 verifier on a best-effort basis.
//
// NOTE(review): both the verifier and the AES key are single SHA-256
// derivations, so a weak password can still be brute-forced offline (the .key
// files alone allow that). Moving to a slow KDF would require re-encrypting
// the stored keys.
func (ks *KeyStore) getMasterPassword() (string, error) {
	passwordFile := filepath.Join(ks.configDir, ".master")

	// Read the stored verifier before prompting, so I/O problems surface
	// without first asking for a password.
	stored, err := os.ReadFile(passwordFile)
	if os.IsNotExist(err) {
		return ks.createMasterPassword(passwordFile)
	}
	if err != nil {
		return "", fmt.Errorf("failed to read stored password hash: %w", err)
	}
	verifier := string(stored)

	fmt.Print("Enter master password: ")
	password, err := ks.readPassword()
	if err != nil {
		return "", fmt.Errorf("failed to read password: %w", err)
	}

	if isMasterVerifierV2(verifier) {
		if !matchMasterVerifier(verifier, password) {
			return "", fmt.Errorf("incorrect master password")
		}
		return password, nil
	}

	// Legacy verifier written by older versions.
	if ks.hashPassword(password) != verifier {
		return "", fmt.Errorf("incorrect master password")
	}
	// The password is correct, so upgrade the file and stop exposing the AES
	// key. A failure is not fatal because the legacy file keeps working.
	if err := writeMasterVerifier(passwordFile, password); err != nil {
		fmt.Fprintf(os.Stderr, "warning: could not upgrade master password file: %v\n", err)
	}
	return password, nil
}

// createMasterPassword asks the user to choose and confirm a new master
// password, stores its v2 verifier at passwordFile, and returns the password.
func (ks *KeyStore) createMasterPassword(passwordFile string) (string, error) {
	fmt.Print("Create a master password to secure your API keys: ")
	password, err := ks.readPassword()
	if err != nil {
		return "", fmt.Errorf("failed to read password: %w", err)
	}
	// An empty password would make the AES key the well-known SHA-256("").
	if password == "" {
		return "", fmt.Errorf("master password cannot be empty")
	}

	fmt.Print("Confirm master password: ")
	confirmPassword, err := ks.readPassword()
	if err != nil {
		return "", fmt.Errorf("failed to read confirmation password: %w", err)
	}
	if password != confirmPassword {
		return "", fmt.Errorf("passwords do not match")
	}

	if err := writeMasterVerifier(passwordFile, password); err != nil {
		return "", fmt.Errorf("failed to store master password: %w", err)
	}
	return password, nil
}

// isMasterVerifierV2 reports whether verifier carries the v2 format prefix.
// Legacy verifiers are plain base64, which never contains '$'.
func isMasterVerifierV2(verifier string) bool {
	n := len(masterVerifierPrefix)
	return len(verifier) >= n && verifier[:n] == masterVerifierPrefix
}

// matchMasterVerifier reports whether password matches the v2 verifier.
// Malformed verifiers never match.
func matchMasterVerifier(verifier, password string) bool {
	if !isMasterVerifierV2(verifier) {
		return false
	}
	payload, err := base64.StdEncoding.DecodeString(verifier[len(masterVerifierPrefix):])
	if err != nil || len(payload) != masterVerifierSaltLen+sha256.Size {
		return false
	}
	salt, want := payload[:masterVerifierSaltLen], payload[masterVerifierSaltLen:]
	return string(masterVerifierDigest(salt, password)) == string(want)
}

// masterVerifierDigest computes SHA-256(salt || domain || password). The
// domain string keeps the result distinct from the SHA-256(password) AES key.
func masterVerifierDigest(salt []byte, password string) []byte {
	h := sha256.New()
	// hash.Hash.Write never returns an error.
	h.Write(salt)
	h.Write([]byte(masterVerifierDomain))
	h.Write([]byte(password))
	return h.Sum(nil)
}

// writeMasterVerifier stores a freshly salted v2 verifier for password at
// path. It writes a temporary file in the same directory and renames it into
// place, so an interrupted write cannot leave a truncated .master file.
func writeMasterVerifier(path, password string) error {
	salt := make([]byte, masterVerifierSaltLen)
	if _, err := io.ReadFull(rand.Reader, salt); err != nil {
		return fmt.Errorf("generate salt: %w", err)
	}
	payload := make([]byte, 0, masterVerifierSaltLen+sha256.Size)
	payload = append(payload, salt...)
	payload = append(payload, masterVerifierDigest(salt, password)...)
	verifier := masterVerifierPrefix + base64.StdEncoding.EncodeToString(payload)

	tmp, err := os.CreateTemp(filepath.Dir(path), ".master-*.tmp")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()

	err = tmp.Chmod(0600)
	if err == nil {
		_, err = tmp.WriteString(verifier)
	}
	if err == nil {
		err = tmp.Sync()
	}
	if closeErr := tmp.Close(); err == nil {
		err = closeErr
	}
	if err == nil {
		err = os.Rename(tmpName, path)
	}
	if err != nil {
		_ = os.Remove(tmpName) // best-effort cleanup; the original error is what matters
		return err
	}
	return nil
}
// readPassword reads a line from the terminal on standard input with echo
// disabled and returns it without the trailing newline.
//
// On success a newline is printed, because the user's Enter key was not
// echoed and later output would otherwise share the prompt's line.
func (ks *KeyStore) readPassword() (string, error) {
	stdinFD := int(syscall.Stdin)
	raw, readErr := term.ReadPassword(stdinFD)
	if readErr != nil {
		return "", readErr
	}

	fmt.Println()
	return string(raw), nil
}
// hashPassword returns the standard base64 encoding of SHA-256(password).
//
// NOTE(review): this digest is the same 32 bytes that encrypt and decrypt use
// as the AES key (both call sha256.Sum256 on the password). Treat any value
// produced here as key material, not as a harmless hash.
func (ks *KeyStore) hashPassword(password string) string {
	digest := sha256.Sum256([]byte(password))
	encoded := base64.StdEncoding.EncodeToString(digest[:])
	return encoded
}
// encrypt seals plaintext with AES-256-GCM and returns
// base64(nonce || ciphertext || tag).
//
// The 32-byte key is SHA-256(password), and each call uses a fresh random
// nonce, so encrypting the same input twice gives different outputs.
func (ks *KeyStore) encrypt(plaintext, password string) (string, error) {
	var (
		key   = sha256.Sum256([]byte(password))
		block cipher.Block
		aead  cipher.AEAD
		err   error
	)
	if block, err = aes.NewCipher(key[:]); err != nil {
		return "", err
	}
	if aead, err = cipher.NewGCM(block); err != nil {
		return "", err
	}

	nonce := make([]byte, aead.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}

	// Seal appends to its first argument, so the nonce becomes the prefix
	// that decrypt later splits off.
	out := aead.Seal(nonce, nonce, []byte(plaintext), nil)
	return base64.StdEncoding.EncodeToString(out), nil
}
// decrypt reverses encrypt. It base64-decodes ciphertext, splits off the GCM
// nonce prefix, and opens the rest with AES-256-GCM keyed by
// SHA-256(password).
//
// It returns an error if the input is not valid base64, is shorter than a
// nonce plus authentication tag (the smallest value encrypt can produce), or
// fails authentication (wrong password or tampered/corrupted data).
func (ks *KeyStore) decrypt(ciphertext, password string) (string, error) {
	data, err := base64.StdEncoding.DecodeString(ciphertext)
	if err != nil {
		return "", fmt.Errorf("decoding base64: %w", err)
	}

	key := sha256.Sum256([]byte(password))
	block, err := aes.NewCipher(key[:])
	if err != nil {
		return "", err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}

	nonceSize := gcm.NonceSize()
	if len(data) < nonceSize+gcm.Overhead() {
		return "", fmt.Errorf("ciphertext too short")
	}
	nonce, sealed := data[:nonceSize], data[nonceSize:]

	plaintext, err := gcm.Open(nil, nonce, sealed, nil)
	if err != nil {
		return "", fmt.Errorf("authentication failed (wrong password or corrupted data): %w", err)
	}
	return string(plaintext), nil
}
// SecureInput prints prompt, then reads a line from the terminal on standard
// input with echo turned off. The returned string excludes the trailing
// newline.
//
// After a successful read a newline is printed, so subsequent output does not
// continue on the prompt's line.
func SecureInput(prompt string) (string, error) {
	fmt.Print(prompt)

	secret, err := term.ReadPassword(int(syscall.Stdin))
	if err == nil {
		fmt.Println()
		return string(secret), nil
	}
	return "", err
}