diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 2ab13d69..c1ab4803 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -9,6 +9,7 @@ from pr_agent.git_providers import get_git_provider from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_information_from_user import PRInformationFromUser +from pr_agent.tools.pr_similar_issue import PRSimilarIssue from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_update_changelog import PRUpdateChangelog @@ -29,6 +30,7 @@ command2class = { "update_changelog": PRUpdateChangelog, "config": PRConfig, "settings": PRConfig, + "similar_issue": PRSimilarIssue, } commands = list(command2class.keys()) @@ -73,7 +75,7 @@ class PRAgent: elif action in command2class: if notify: notify() - await command2class[action](pr_url, args=args).run() + await command2class[action](pr_url, *args).run() else: return False return True diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 725d75ec..14fdda59 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -253,8 +253,12 @@ def update_settings_from_args(args: List[str]) -> List[str]: key, value = vals key = key.strip().upper() value = value.strip() - get_settings().set(key, value) - logging.info(f'Updated setting {key} to: "{value}"') + if key in get_settings(): + get_settings().set(key, value) + logging.info(f'Updated setting {key} to: "{value}"') + else: + logging.info(f'No argument: {key}') + other_args.append(arg) else: other_args.append(arg) return other_args diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index be0fa645..c010158d 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -312,6 +312,35 @@ class GithubProvider(GitProvider): return repo_name, pr_number + @staticmethod + def _parse_issue_url(issue_url: str) -> Tuple[str, int]: + parsed_url = urlparse(issue_url) + + if 'github.com' not in parsed_url.netloc: + raise ValueError("The provided URL is not a valid GitHub URL") + + path_parts = parsed_url.path.strip('/').split('/') + if 'api.github.com' in parsed_url.netloc: + if len(path_parts) < 5 or path_parts[3] != 'issues': + raise ValueError("The provided URL does not appear to be a GitHub ISSUE URL") + repo_name = '/'.join(path_parts[1:3]) + try: + issue_number = int(path_parts[4]) + except ValueError as e: + raise ValueError("Unable to convert issue number to integer") from e + return repo_name, issue_number + + if len(path_parts) < 4 or path_parts[2] != 'issues': + raise ValueError("The provided URL does not appear to be a GitHub PR issue") + + repo_name = '/'.join(path_parts[:2]) + try: + issue_number = int(path_parts[3]) + except ValueError as e: + raise ValueError("Unable to convert issue number to integer") from e + + return repo_name, issue_number + def _get_github_client(self): deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user") diff --git a/pr_agent/tools/pr_similar_issue.py b/pr_agent/tools/pr_similar_issue.py new file mode 100644 index 00000000..497f2f5d --- /dev/null +++ b/pr_agent/tools/pr_similar_issue.py @@ -0,0 +1,77 @@ +import copy +import json +import logging +from typing import List, Tuple + +from jinja2 import Environment, StrictUndefined + +from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models +from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.utils import load_yaml +from pr_agent.config_loader import get_settings +from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers.git_provider import get_main_pr_language + + +class PRSimilarIssue: + def __init__(self, pr_url: str, issue_url: str, args: list = None): + load_data_from_local = True + if not load_data_from_local: + self.git_provider = get_git_provider()() + repo_name, issue_number = self.git_provider._parse_issue_url(issue_url.split('=')[-1]) + self.git_provider.repo = repo_name + self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name) + repo_obj = self.git_provider.repo_obj + + def _process_issue(issue): + header = body = issue_str = comments_str = "" + if issue.pull_request: + return header, body, issue_str, comments_str + header = issue.title + body = issue.body + comments_obj = list(issue.get_comments()) + comments_str = "" + for i, comment in enumerate(comments_obj): + comments_str += f"comment {i}:\n{comment.body}\n\n\n" + issue_str = f"Issue Header: \"{header}\"\n\nIssue Body:\n{body}" + return header, body, issue_str, comments_str + + main_issue = repo_obj.get_issue(issue_number) + assert not main_issue.pull_request + _, _, main_issue_str, main_comments_str = _process_issue(main_issue) + + issues_str_list = [] + comments_str_list = [] + issues = list(repo_obj.get_issues(state='all')) # 'open', 'closed', 'all' + for i, issue in enumerate(issues): + if issue.url == main_issue.url: + continue + if issue.pull_request: + continue + _, _, issue_str, comments_str = _process_issue(issue) + issues_str_list.append(issue_str) + comments_str_list.append(comments_str) + + json_output = {} + json_output['main_issue'] = {} + json_output['main_issue']['issue'] = main_issue_str + json_output['main_issue']['comment'] = main_comments_str + json_output['issues'] = {} + for i in range(len(issues_str_list)): + json_output['issues'][f'issue_{i}'] = {} + json_output['issues'][f'issue_{i}']['issue'] = issues_str_list[i] + json_output['issues'][f'issue_{i}'][f'comments'] = comments_str_list[i] + + jsonFile = open("/Users/talrid/Desktop/issues_data.json", "w") + jsonFile.write(json.dumps(json_output)) + jsonFile.close() + else: + jsonFile = open("/Users/talrid/Desktop/issues_data.json", "r") + json_output=json.loads(jsonFile.read()) + + from langchain.document_loaders import TextLoader + from langchain.text_splitter import CharacterTextSplitter + text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) + + aaa=3