diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py index 798fc6c5..82a2af40 100644 --- a/pr_agent/algo/__init__.py +++ b/pr_agent/algo/__init__.py @@ -1,4 +1,5 @@ MAX_TOKENS = { + 'text-embedding-ada-002': 8000, 'gpt-3.5-turbo': 4000, 'gpt-3.5-turbo-0613': 4000, 'gpt-3.5-turbo-0301': 4000, diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index f018a92b..d7eff9d7 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -21,7 +21,7 @@ class TokenHandler: method. """ - def __init__(self, pr, vars: dict, system, user): + def __init__(self, pr=None, vars: dict = {}, system="", user=""): """ Initializes the TokenHandler object. @@ -32,7 +32,8 @@ class TokenHandler: - user: The user string. """ self.encoder = get_token_encoder() - self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) + if pr is not None: + self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): """ diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 0124c3d6..7ac4b468 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -32,33 +32,37 @@ def convert_to_markdown(output_data: dict) -> str: emojis = { "Main theme": "๐ŸŽฏ", + "PR summary": "๐Ÿ“", "Type of PR": "๐Ÿ“Œ", "Score": "๐Ÿ…", "Relevant tests added": "๐Ÿงช", "Unrelated changes": "โš ๏ธ", "Focused PR": "โœจ", "Security concerns": "๐Ÿ”’", - "General PR suggestions": "๐Ÿ’ก", + "General suggestions": "๐Ÿ’ก", "Insights from user's answers": "๐Ÿ“", "Code feedback": "๐Ÿค–", } for key, value in output_data.items(): - if not value: + if value is None or value == '' or value == {}: continue if isinstance(value, dict): markdown_text += f"## {key}\n\n" markdown_text += convert_to_markdown(value) elif isinstance(value, list): - if key.lower() == 'code feedback': - markdown_text += "\n" # just looks nicer with additional line breaks emoji = emojis.get(key, "") - markdown_text += f"- {emoji} **{key}:**\n\n" + if key.lower() == 'code feedback': + markdown_text += f"\n\n- **
{ emoji } Code feedback:**\n\n" + else: + markdown_text += f"- {emoji} **{key}:**\n\n" for item in value: if isinstance(item, dict) and key.lower() == 'code feedback': markdown_text += parse_code_suggestion(item) elif item: markdown_text += f" - {item}\n" + if key.lower() == 'code feedback': + markdown_text += "
\n\n" elif value != 'n/a': emoji = emojis.get(key, "") markdown_text += f"- {emoji} **{key}:** {value}\n" @@ -164,7 +168,7 @@ def fix_json_escape_char(json_message=None): Raises: None - """ + """ try: result = json.loads(json_message) except Exception as e: @@ -191,7 +195,7 @@ def convert_str_to_datetime(date_str): Example: >>> convert_str_to_datetime('Mon, 01 Jan 2022 12:00:00 UTC') datetime.datetime(2022, 1, 1, 12, 0, 0) - """ + """ datetime_format = '%a, %d %b %Y %H:%M:%S %Z' return datetime.strptime(date_str, datetime_format) @@ -245,27 +249,34 @@ def update_settings_from_args(args: List[str]) -> List[str]: arg = arg.strip() if arg.startswith('--'): arg = arg.strip('-').strip() - vals = arg.split('=') + vals = arg.split('=', 1) if len(vals) != 2: - logging.error(f'Invalid argument format: {arg}') + if len(vals) > 2: # --extended is a valid argument + logging.error(f'Invalid argument format: {arg}') other_args.append(arg) continue key, value = _fix_key_value(*vals) - if key in get_settings(): - get_settings().set(key, value) - logging.info(f'Updated setting {key} to: "{value}"') - else: - logging.info(f'No argument: {key}') - other_args.append(arg) + get_settings().set(key, value) + logging.info(f'Updated setting {key} to: "{value}"') else: other_args.append(arg) return other_args +def _fix_key_value(key: str, value: str): + key = key.strip().upper() + value = value.strip() + try: + value = yaml.safe_load(value) + except Exception as e: + logging.error(f"Failed to parse YAML for config override {key}={value}", exc_info=e) + return key, value + + def load_yaml(review_text: str) -> dict: review_text = review_text.removeprefix('```yaml').rstrip('`') try: - data = yaml.load(review_text, Loader=yaml.SafeLoader) + data = yaml.safe_load(review_text) except Exception as e: logging.error(f"Failed to parse AI prediction: {e}") data = try_fix_yaml(review_text) diff --git a/pr_agent/cli.py b/pr_agent/cli.py index 01c1a7ec..7c4508d9 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -5,6 +5,7 @@ import os from pr_agent.agent.pr_agent import PRAgent, commands from pr_agent.config_loader import get_settings +from pr_agent.tools.pr_similar_issue import PRSimilarIssue def run(inargs=None): @@ -37,14 +38,19 @@ Configuration: To edit any configuration parameter from 'configuration.toml', just add -config_path=. For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."' """) - parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True) + parser.add_argument('--pr_url', type=str, help='The URL of the PR to review') + parser.add_argument('--issue_url', type=str, help='The URL of the Issue to review', default=None) parser.add_argument('command', type=str, help='The', choices=commands, default='review') parser.add_argument('rest', nargs=argparse.REMAINDER, default=[]) args = parser.parse_args(inargs) logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) command = args.command.lower() get_settings().set("CONFIG.CLI_MODE", True) - result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest))) + if args.issue_url: + result = asyncio.run(PRAgent().handle_request(args.issue_url, command + " " + " ".join(args.rest))) + # result = asyncio.run(PRSimilarIssue(args.issue_url, cli_mode=True, args=command + " " + " ".join(args.rest)).run()) + else: + result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest))) if not result: parser.print_help() diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 7e93d18c..0521716b 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -32,7 +32,7 @@ class GithubProvider(GitProvider): self.diff_files = None self.git_files = None self.incremental = incremental - if pr_url: + if pr_url and 'pull' in pr_url: self.set_pr(pr_url) self.last_commit_id = list(self.pr.get_commits())[-1] diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index f8abd555..9bfdf3a3 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -84,4 +84,14 @@ polling_interval_seconds = 30 [local] # LocalGitProvider settings - uncomment to use paths other than default # description_path= "path/to/description.md" -# review_path= "path/to/review.md" \ No newline at end of file +# review_path= "path/to/review.md" + +[pr_similar_issue] +skip_comments = false +force_update_dataset = false +max_issues_to_scan = 1000 + +[pinecone] +# fill and place in .secrets.toml +#api_key = ... +# environment = "gcp-starter" \ No newline at end of file diff --git a/pr_agent/tools/pr_similar_issue.py b/pr_agent/tools/pr_similar_issue.py index 497f2f5d..94dc10d3 100644 --- a/pr_agent/tools/pr_similar_issue.py +++ b/pr_agent/tools/pr_similar_issue.py @@ -1,77 +1,250 @@ import copy import json import logging +from enum import Enum from typing import List, Tuple +import pinecone +import openai +import pandas as pd +from pydantic import BaseModel, Field -from jinja2 import Environment, StrictUndefined - -from pr_agent.algo.ai_handler import AiHandler -from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models +from pr_agent.algo import MAX_TOKENS from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider -from pr_agent.git_providers.git_provider import get_main_pr_language +from pinecone_datasets import Dataset, DatasetMetadata + +MODEL = "text-embedding-ada-002" class PRSimilarIssue: - def __init__(self, pr_url: str, issue_url: str, args: list = None): - load_data_from_local = True - if not load_data_from_local: - self.git_provider = get_git_provider()() - repo_name, issue_number = self.git_provider._parse_issue_url(issue_url.split('=')[-1]) - self.git_provider.repo = repo_name - self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name) - repo_obj = self.git_provider.repo_obj + def __init__(self, issue_url: str, args: list = None): + if get_settings().config.git_provider != "github": + raise Exception("Only github is supported for similar issue tool") - def _process_issue(issue): - header = body = issue_str = comments_str = "" - if issue.pull_request: - return header, body, issue_str, comments_str - header = issue.title - body = issue.body - comments_obj = list(issue.get_comments()) - comments_str = "" - for i, comment in enumerate(comments_obj): - comments_str += f"comment {i}:\n{comment.body}\n\n\n" - issue_str = f"Issue Header: \"{header}\"\n\nIssue Body:\n{body}" - return header, body, issue_str, comments_str + self.cli_mode = get_settings().CONFIG.CLI_MODE + self.max_issues_to_scan = get_settings().pr_similar_issue.max_issues_to_scan + self.issue_url = issue_url + self.git_provider = get_git_provider()() + repo_name, issue_number = self.git_provider._parse_issue_url(issue_url.split('=')[-1]) + self.git_provider.repo = repo_name + self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name) + self.token_handler = TokenHandler() + repo_obj = self.git_provider.repo_obj + repo_name_for_index = self.repo_name_for_index = repo_obj.full_name.lower().replace('/', '-').replace('_/', '-') + index_name = self.index_name = "codium-ai-pr-agent-issues" - main_issue = repo_obj.get_issue(issue_number) - assert not main_issue.pull_request - _, _, main_issue_str, main_comments_str = _process_issue(main_issue) + # assuming pinecone api key and environment are set in secrets file + try: + api_key = get_settings().pinecone.api_key + environment = get_settings().pinecone.environment + except Exception: + if not self.cli_mode: + repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1]) + issue_main = self.git_provider.repo_obj.get_issue(original_issue_number) + issue_main.create_comment("Please set pinecone api key and environment in secrets file") + raise Exception("Please set pinecone api key and environment in secrets file") - issues_str_list = [] - comments_str_list = [] - issues = list(repo_obj.get_issues(state='all')) # 'open', 'closed', 'all' - for i, issue in enumerate(issues): - if issue.url == main_issue.url: - continue - if issue.pull_request: - continue - _, _, issue_str, comments_str = _process_issue(issue) - issues_str_list.append(issue_str) - comments_str_list.append(comments_str) - - json_output = {} - json_output['main_issue'] = {} - json_output['main_issue']['issue'] = main_issue_str - json_output['main_issue']['comment'] = main_comments_str - json_output['issues'] = {} - for i in range(len(issues_str_list)): - json_output['issues'][f'issue_{i}'] = {} - json_output['issues'][f'issue_{i}']['issue'] = issues_str_list[i] - json_output['issues'][f'issue_{i}'][f'comments'] = comments_str_list[i] - - jsonFile = open("/Users/talrid/Desktop/issues_data.json", "w") - jsonFile.write(json.dumps(json_output)) - jsonFile.close() + # check if index exists, and if repo is already indexed + run_from_scratch = False + upsert = True + pinecone.init(api_key=api_key, environment=environment) + if not index_name in pinecone.list_indexes(): + run_from_scratch = True + upsert = False else: - jsonFile = open("/Users/talrid/Desktop/issues_data.json", "r") - json_output=json.loads(jsonFile.read()) + if get_settings().pr_similar_issue.force_update_dataset: + upsert = True + else: + pinecone_index = pinecone.Index(index_name=index_name) + res = pinecone_index.fetch([f"example_issue_{repo_name_for_index}"]).to_dict() + if res["vectors"]: + upsert = False - from langchain.document_loaders import TextLoader - from langchain.text_splitter import CharacterTextSplitter - text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) + if run_from_scratch or upsert: # index the entire repo + logging.info('Indexing the entire repo...') - aaa=3 + logging.info('Getting issues...') + issues = list(repo_obj.get_issues(state='all')) + logging.info('Done') + self._update_index_with_issues(issues, repo_name_for_index, upsert=upsert) + else: # update index if needed + pinecone_index = pinecone.Index(index_name=index_name) + issues_to_update = [] + issues_paginated_list = repo_obj.get_issues(state='all') + counter = 1 + for issue in issues_paginated_list: + if issue.pull_request: + continue + issue_str, comments, number = self._process_issue(issue) + issue_key = f"issue_{number}" + id = issue_key + "." + "issue" + res = pinecone_index.fetch([id]).to_dict() + is_new_issue = True + for vector in res["vectors"].values(): + if vector['metadata']['repo'] == repo_name_for_index: + is_new_issue = False + break + if is_new_issue: + counter += 1 + issues_to_update.append(issue) + else: + break + + if issues_to_update: + logging.info(f'Updating index with {counter} new issues...') + self._update_index_with_issues(issues_to_update, repo_name_for_index, upsert=True) + else: + logging.info('No new issues to update') + + async def run(self): + repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1]) + issue_main = self.git_provider.repo_obj.get_issue(original_issue_number) + issue_str, comments, number = self._process_issue(issue_main) + openai.api_key = get_settings().openai.key + + res = openai.Embedding.create(input=[issue_str], engine=MODEL) + embeds = [record['embedding'] for record in res['data']] + pinecone_index = pinecone.Index(index_name=self.index_name) + res = pinecone_index.query(embeds[0], + top_k=5, + filter={"repo": self.repo_name_for_index}, + include_metadata=True).to_dict() + relevant_issues_number_list = [] + for r in res['matches']: + issue_number = int(r["id"].split('.')[0].split('_')[-1]) + if original_issue_number == issue_number: + continue + if issue_number not in relevant_issues_number_list: + relevant_issues_number_list.append(issue_number) + + similar_issues_str = "Similar Issues:\n\n" + for i, issue_number_similar in enumerate(relevant_issues_number_list): + issue = self.git_provider.repo_obj.get_issue(issue_number_similar) + title = issue.title + url = issue.html_url + similar_issues_str += f"{i + 1}. [{title}]({url})\n\n" + if get_settings().config.publish_output: + response = issue_main.create_comment(similar_issues_str) + logging.info(similar_issues_str) + + def _process_issue(self, issue): + header = issue.title + body = issue.body + number = issue.number + if get_settings().pinecone.skip_comments: + comments = [] + else: + comments = list(issue.get_comments()) + issue_str = f"Issue Header: \"{header}\"\n\nIssue Body:\n{body}" + return issue_str, comments, number + + def _update_index_with_issues(self, issues_list, repo_name_for_index, upsert=False): + logging.info('Processing issues...') + corpus = Corpus() + example_issue_record = Record( + id=f"example_issue_{repo_name_for_index}", + text="example_issue", + metadata=Metadata(repo=repo_name_for_index) + ) + corpus.append(example_issue_record) + + counter = 0 + for issue in issues_list: + + if issue.pull_request: + continue + + counter += 1 + if counter >= self.max_issues_to_scan: + logging.info(f"Scanned {self.max_issues_to_scan} issues, stopping") + break + + issue_str, comments, number = self._process_issue(issue) + issue_key = f"issue_{number}" + username = issue.user.login + created_at = str(issue.created_at) + if len(issue_str) < 8000 or \ + self.token_handler.count_tokens(issue_str) < MAX_TOKENS[MODEL]: # fast reject first + issue_record = Record( + id=issue_key + "." + "issue", + text=issue_str, + metadata=Metadata(repo=repo_name_for_index, + username=username, + created_at=created_at, + level=IssueLevel.ISSUE) + ) + corpus.append(issue_record) + if comments: + for j, comment in enumerate(comments): + comment_body = comment.body + num_words_comment = len(comment_body.split()) + if num_words_comment < 10: + continue + + if len(issue_str) < 8000 or \ + self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]: + comment_record = Record( + id=issue_key + ".comment_" + str(j + 1), + text=comment_body, + metadata=Metadata(repo=repo_name_for_index, + username=username, # use issue username for all comments + created_at=created_at, + level=IssueLevel.COMMENT) + ) + corpus.append(comment_record) + df = pd.DataFrame(corpus.dict()["documents"]) + logging.info('Done') + + logging.info('Embedding...') + openai.api_key = get_settings().openai.key + list_to_encode = list(df["text"].values) + res = openai.Embedding.create(input=list_to_encode, engine=MODEL) + embeds = [record['embedding'] for record in res['data']] + df["values"] = embeds + meta = DatasetMetadata.empty() + meta.dense_model.dimension = len(embeds[0]) + ds = Dataset.from_pandas(df, meta) + logging.info('Done') + + api_key = get_settings().pinecone.api_key + environment = get_settings().pinecone.environment + if not upsert: + logging.info('Creating index...') + ds.to_pinecone_index(self.index_name, api_key=api_key, environment=environment) + else: + logging.info('Upserting index...') + namespace = "" + batch_size: int = 100 + concurrency: int = 10 + pinecone.init(api_key=api_key, environment=environment) + ds._upsert_to_index(self.index_name, namespace, batch_size, concurrency) + logging.info('Done') + + +class IssueLevel(str, Enum): + ISSUE = "issue" + COMMENT = "comment" + + +class Metadata(BaseModel): + repo: str + username: str = Field(default="@codium") + created_at: str = Field(default="01-01-1970 00:00:00.00000") + level: IssueLevel = Field(default=IssueLevel.ISSUE) + + class Config: + use_enum_values = True + + +class Record(BaseModel): + id: str + text: str + metadata: Metadata + + +class Corpus(BaseModel): + documents: List[Record] = Field(default=[]) + + def append(self, r: Record): + self.documents.append(r) diff --git a/requirements.txt b/requirements.txt index 99efa846..5d4caaa6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,6 @@ boto3~=1.28.25 google-cloud-storage==2.10.0 ujson==5.8.0 azure-devops==7.1.0b3 -msrest==0.7.1 \ No newline at end of file +msrest==0.7.1 +pinecone-client==2.2.2 +pinecone_datasets==0.6.1 \ No newline at end of file