package cmd

import (
	"bufio"
	"fmt"
	"os"
	"strconv"
	"strings"

	"github.com/spf13/cobra"

	"github.com/iwasforcedtobehere/git-automation-cli/internal/config"
	"github.com/iwasforcedtobehere/git-automation-cli/internal/git"
	"github.com/iwasforcedtobehere/git-automation-cli/internal/github"
	"github.com/iwasforcedtobehere/git-automation-cli/internal/validation"
)

// prCmd is the parent "pr" command; it only groups the create, list and merge
// subcommands registered in init.
var prCmd = &cobra.Command{
	Use:   "pr",
	Short: "Pull request utilities",
	Long:  `Create, list, and manage GitHub pull requests. Supports creating pull requests from the current branch and managing existing PRs.`,
}

// prCreateCmd opens a pull request from the current branch into the base
// branch. The title is all positional arguments joined with spaces. The body
// comes from --body or, when that is empty, from an interactive prompt.
var prCreateCmd = &cobra.Command{
	Use:   "create [title]",
	Short: "Create a new pull request",
	Args:  cobra.MinimumNArgs(1),
	RunE: func(cmd *cobra.Command, args []string) error {
		ctx := cmd.Context()
		title := strings.Join(args, " ")

		// Flag lookups cannot fail here: every flag read below is registered
		// on this command in init, so the error results are ignored.
		body, _ := cmd.Flags().GetString("body")
		if body == "" {
			body = promptForBody()
		}

		draft, _ := cmd.Flags().GetBool("draft")

		// Base branch precedence: --base flag, configured default, "main".
		base, _ := cmd.Flags().GetString("base")
		if base == "" {
			base = config.GlobalConfig.DefaultBranch
		}
		if base == "" {
			base = "main"
		}

		// The validation text is passed as an argument, never as the format
		// string, so a '%' inside it is printed literally.
		if validationResult := validation.ValidateGitRepository(ctx); !validationResult.IsValid {
			return fmt.Errorf("%s", validationResult.GetErrors())
		}

		currentBranch, err := git.CurrentBranch(ctx)
		if err != nil {
			return fmt.Errorf("failed to get current branch: %w", err)
		}

		// A pull request needs distinct head and base branches.
		if currentBranch == base {
			return fmt.Errorf("cannot create pull request from %s branch to itself", base)
		}

		client, err := github.NewClient(ctx)
		if err != nil {
			return fmt.Errorf("failed to create GitHub client: %w", err)
		}

		owner, repo, err := resolvePRRepository(cmd)
		if err != nil {
			return err
		}
		client.SetRepository(owner, repo)

		pr := &github.PullRequestRequest{
			Title: title,
			Body:  body,
			Head:  currentBranch,
			Base:  base,
			Draft: draft,
		}

		pullRequest, err := client.CreatePullRequest(ctx, pr)
		if err != nil {
			return fmt.Errorf("failed to create pull request: %w", err)
		}

		fmt.Printf("Created pull request: %s\n", pullRequest.HTMLURL)
		return nil
	},
}

// prListCmd prints the pull requests of the repository behind the selected
// remote, filtered by --state (defaults to "open").
var prListCmd = &cobra.Command{
	Use:   "list",
	Short: "List pull requests",
	RunE: func(cmd *cobra.Command, args []string) error {
		ctx := cmd.Context()

		// The "state" flag is registered in init, so the lookup error is ignored.
		state, _ := cmd.Flags().GetString("state")
		if state == "" {
			state = "open"
		}

		// The validation text is passed as an argument, never as the format
		// string, so a '%' inside it is printed literally.
		if validationResult := validation.ValidateGitRepository(ctx); !validationResult.IsValid {
			return fmt.Errorf("%s", validationResult.GetErrors())
		}

		client, err := github.NewClient(ctx)
		if err != nil {
			return fmt.Errorf("failed to create GitHub client: %w", err)
		}

		owner, repo, err := resolvePRRepository(cmd)
		if err != nil {
			return err
		}
		client.SetRepository(owner, repo)

		pullRequests, err := client.GetPullRequests(ctx, state)
		if err != nil {
			return fmt.Errorf("failed to get pull requests: %w", err)
		}

		if len(pullRequests) == 0 {
			fmt.Printf("No %s pull requests found\n", state)
			return nil
		}

		// NOTE(review): strings.Title is deprecated since Go 1.18; it is kept
		// because the stdlib has no drop-in replacement and the state values
		// are short ASCII words.
		fmt.Printf("%s pull requests:\n", strings.Title(state))
		for _, pr := range pullRequests {
			fmt.Printf("#%d: %s (%s)\n", pr.Number, pr.Title, pr.User.Login)
		}

		return nil
	},
}

// prMergeCmd merges the pull request whose number is given as the single
// positional argument, using the strategy named by --method (defaults to
// "merge").
var prMergeCmd = &cobra.Command{
	Use:   "merge [number]",
	Short: "Merge a pull request",
	Args:  cobra.ExactArgs(1),
	RunE: func(cmd *cobra.Command, args []string) error {
		ctx := cmd.Context()

		// Parse strictly and fail fast: a typo must not silently turn into
		// pull request #0 or a truncated number.
		number, err := strconv.Atoi(args[0])
		if err != nil || number <= 0 {
			return fmt.Errorf("invalid pull request number %q: must be a positive integer", args[0])
		}

		// The validation text is passed as an argument, never as the format
		// string, so a '%' inside it is printed literally.
		if validationResult := validation.ValidateGitRepository(ctx); !validationResult.IsValid {
			return fmt.Errorf("%s", validationResult.GetErrors())
		}

		client, err := github.NewClient(ctx)
		if err != nil {
			return fmt.Errorf("failed to create GitHub client: %w", err)
		}

		owner, repo, err := resolvePRRepository(cmd)
		if err != nil {
			return err
		}
		client.SetRepository(owner, repo)

		// The "method" flag is registered in init, so the lookup error is ignored.
		method, _ := cmd.Flags().GetString("method")
		if method == "" {
			method = "merge"
		}

		mergeRequest := &github.MergePullRequestRequest{
			MergeMethod: method,
		}

		mergeResponse, err := client.MergePullRequest(ctx, number, mergeRequest)
		if err != nil {
			return fmt.Errorf("failed to merge pull request: %w", err)
		}

		fmt.Printf("Merged pull request: %s\n", mergeResponse.SHA)
		return nil
	},
}

// resolvePRRepository determines the GitHub owner and repository name for a
// pr subcommand. The remote is taken from the command's --remote flag, then
// from config.GlobalConfig.DefaultRemote, then falls back to "origin". Its URL
// must be recognised by github.IsGitHubURL and parse via
// github.ParseGitHubURL; otherwise an error is returned.
func resolvePRRepository(cmd *cobra.Command) (string, string, error) {
	// Every pr subcommand registers "remote" in init, so the lookup error is ignored.
	remote, _ := cmd.Flags().GetString("remote")
	if remote == "" {
		remote = config.GlobalConfig.DefaultRemote
	}
	if remote == "" {
		remote = "origin"
	}

	remoteURL, err := git.GetRemoteURL(cmd.Context(), remote)
	if err != nil {
		return "", "", fmt.Errorf("failed to get remote URL: %w", err)
	}

	if !github.IsGitHubURL(remoteURL) {
		return "", "", fmt.Errorf("remote URL is not a GitHub URL: %s", remoteURL)
	}

	owner, repo, err := github.ParseGitHubURL(remoteURL)
	if err != nil {
		return "", "", fmt.Errorf("failed to parse GitHub URL: %w", err)
	}

	return owner, repo, nil
}

// promptForBody reads a multi-line pull request body from standard input.
// Each line is trimmed of surrounding whitespace; blank lines are not stored.
// Input ends after two consecutive blank lines or when reading fails (for
// example at end of input). The non-blank lines are returned joined by "\n".
func promptForBody() string {
	fmt.Print("Enter pull request body (press Enter twice to finish):\n")

	reader := bufio.NewReader(os.Stdin)
	var body strings.Builder
	blankRun := 0

	for {
		fmt.Print("> ")
		raw, err := reader.ReadString('\n')

		// ReadString can return data together with an error (typically io.EOF
		// when the input lacks a trailing newline), so the text is handled
		// before the error is inspected.
		line := strings.TrimSpace(raw)
		if line == "" {
			blankRun++
		} else {
			blankRun = 0
			if body.Len() > 0 {
				body.WriteByte('\n')
			}
			body.WriteString(line)
		}

		if err != nil || blankRun >= 2 {
			break
		}
	}

	return body.String()
}

// readLine reads one line from standard input and returns it with surrounding
// whitespace removed. If reading fails, whatever was read before the failure
// is still returned (an empty string when nothing was read).
//
// NOTE(review): each call wraps os.Stdin in a fresh bufio.Reader, so input
// buffered by one call is not visible to the next — confirm callers only need
// a single line.
func readLine() string {
	reader := bufio.NewReader(os.Stdin)
	// The error is deliberately ignored: ReadString returns the data read so
	// far alongside it, and that partial line is the best available answer.
	line, _ := reader.ReadString('\n')
	return strings.TrimSpace(line)
}

// parseInt scans a leading decimal integer from s using fmt.Sscanf and
// returns it, or 0 when the scan fails. Invalid input is therefore
// indistinguishable from a literal "0"; callers that must reject bad input
// should use strconv.Atoi instead.
func parseInt(s string) int {
	var result int
	if _, err := fmt.Sscanf(s, "%d", &result); err != nil {
		return 0
	}
	return result
}

// init wires the pr command tree into rootCmd and registers the flags of each
// subcommand.
func init() {
	rootCmd.AddCommand(prCmd)
	prCmd.AddCommand(prCreateCmd)
	prCmd.AddCommand(prListCmd)
	prCmd.AddCommand(prMergeCmd)

	// Flags for "pr create".
	prCreateCmd.Flags().String("body", "", "Pull request body")
	prCreateCmd.Flags().Bool("draft", false, "Create a draft pull request")
	prCreateCmd.Flags().String("base", "", "Base branch for the pull request")
	prCreateCmd.Flags().String("remote", "", "Remote to use for the pull request")

	// Flags for "pr list".
	prListCmd.Flags().String("state", "open", "State of pull requests to list (open, closed, all)")
	prListCmd.Flags().String("remote", "", "Remote to use for the pull request")

	// Flags for "pr merge".
	prMergeCmd.Flags().String("method", "merge", "Merge method (merge, squash, rebase)")
	prMergeCmd.Flags().String("remote", "", "Remote to use for the pull request")
}