diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py
index 798fc6c5..82a2af40 100644
--- a/pr_agent/algo/__init__.py
+++ b/pr_agent/algo/__init__.py
@@ -1,4 +1,5 @@
MAX_TOKENS = {
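+ # text-embedding-ada-002 accepts up to 8191 input tokens; 8000 leaves headroom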
+ 'text-embedding-ada-002': 8000,
'gpt-3.5-turbo': 4000,
'gpt-3.5-turbo-0613': 4000,
'gpt-3.5-turbo-0301': 4000,
diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py
index f018a92b..d7eff9d7 100644
--- a/pr_agent/algo/token_handler.py
+++ b/pr_agent/algo/token_handler.py
@@ -21,7 +21,7 @@ class TokenHandler:
method.
"""
- def __init__(self, pr, vars: dict, system, user):
+ def __init__(self, pr=None, vars: dict = {}, system="", user=""):
"""
Initializes the TokenHandler object.
@@ -32,7 +32,8 @@ class TokenHandler:
- user: The user string.
"""
self.encoder = get_token_encoder()
- self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
+ if pr is not None:
+ self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
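+ # when pr is None the handler is used only for count_tokens(); prompt_tokens stays unset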
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
"""
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 0124c3d6..7ac4b468 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -32,33 +32,37 @@ def convert_to_markdown(output_data: dict) -> str:
emojis = {
"Main theme": "๐ฏ",
+ "PR summary": "๐",
"Type of PR": "๐",
"Score": "๐
",
"Relevant tests added": "๐งช",
"Unrelated changes": "โ ๏ธ",
"Focused PR": "โจ",
"Security concerns": "๐",
- "General PR suggestions": "๐ก",
+ "General suggestions": "๐ก",
"Insights from user's answers": "๐",
"Code feedback": "๐ค",
}
for key, value in output_data.items():
- if not value:
+ if value is None or value == '' or value == {}:
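+ # unlike "if not value", this keeps falsy-but-meaningful values such as 0 and False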
continue
if isinstance(value, dict):
markdown_text += f"## {key}\n\n"
markdown_text += convert_to_markdown(value)
elif isinstance(value, list):
- if key.lower() == 'code feedback':
- markdown_text += "\n" # just looks nicer with additional line breaks
emoji = emojis.get(key, "")
- markdown_text += f"- {emoji} **{key}:**\n\n"
+ if key.lower() == 'code feedback':
+ markdown_text += f"\n\n- ** { emoji } Code feedback:**
\n\n"
+ else:
+ markdown_text += f"- {emoji} **{key}:**\n\n"
for item in value:
if isinstance(item, dict) and key.lower() == 'code feedback':
markdown_text += parse_code_suggestion(item)
elif item:
markdown_text += f" - {item}\n"
+ if key.lower() == 'code feedback':
+ markdown_text += " \n\n"
elif value != 'n/a':
emoji = emojis.get(key, "")
markdown_text += f"- {emoji} **{key}:** {value}\n"
@@ -164,7 +168,7 @@ def fix_json_escape_char(json_message=None):
Raises:
None
- """
+ """
try:
result = json.loads(json_message)
except Exception as e:
@@ -191,7 +195,7 @@ def convert_str_to_datetime(date_str):
Example:
>>> convert_str_to_datetime('Mon, 01 Jan 2022 12:00:00 UTC')
datetime.datetime(2022, 1, 1, 12, 0, 0)
- """
+ """
datetime_format = '%a, %d %b %Y %H:%M:%S %Z'
return datetime.strptime(date_str, datetime_format)
@@ -245,27 +249,34 @@ def update_settings_from_args(args: List[str]) -> List[str]:
arg = arg.strip()
if arg.startswith('--'):
arg = arg.strip('-').strip()
- vals = arg.split('=')
+ vals = arg.split('=', 1)
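+ # maxsplit=1, so values that themselves contain '=' survive intact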
if len(vals) != 2:
- logging.error(f'Invalid argument format: {arg}')
+ if len(vals) > 2: # --extended is a valid argument
+ logging.error(f'Invalid argument format: {arg}')
other_args.append(arg)
continue
key, value = _fix_key_value(*vals)
- if key in get_settings():
- get_settings().set(key, value)
- logging.info(f'Updated setting {key} to: "{value}"')
- else:
- logging.info(f'No argument: {key}')
- other_args.append(arg)
+ get_settings().set(key, value)
+ logging.info(f'Updated setting {key} to: "{value}"')
else:
other_args.append(arg)
return other_args
+def _fix_key_value(key: str, value: str):
+ key = key.strip().upper()
+ value = value.strip()
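+ # YAML-parse the value so "true", "3" or "[a, b]" arrive typed rather than as raw strings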
+ try:
+ value = yaml.safe_load(value)
+ except Exception as e:
+ logging.error(f"Failed to parse YAML for config override {key}={value}", exc_info=e)
+ return key, value
+
+
def load_yaml(review_text: str) -> dict:
review_text = review_text.removeprefix('```yaml').rstrip('`')
try:
- data = yaml.load(review_text, Loader=yaml.SafeLoader)
+ data = yaml.safe_load(review_text)
except Exception as e:
logging.error(f"Failed to parse AI prediction: {e}")
data = try_fix_yaml(review_text)
diff --git a/pr_agent/cli.py b/pr_agent/cli.py
index 01c1a7ec..7c4508d9 100644
--- a/pr_agent/cli.py
+++ b/pr_agent/cli.py
@@ -5,6 +5,7 @@ import os
from pr_agent.agent.pr_agent import PRAgent, commands
from pr_agent.config_loader import get_settings
+from pr_agent.tools.pr_similar_issue import PRSimilarIssue
def run(inargs=None):
@@ -37,14 +38,19 @@ Configuration:
To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""")
- parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
+ parser.add_argument('--pr_url', type=str, help='The URL of the PR to review')
+ parser.add_argument('--issue_url', type=str, help='The URL of the Issue to review', default=None)
parser.add_argument('command', type=str, help='The', choices=commands, default='review')
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
args = parser.parse_args(inargs)
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
command = args.command.lower()
get_settings().set("CONFIG.CLI_MODE", True)
- result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest)))
+ if args.issue_url:
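+ # --pr_url is no longer required, so route by whichever URL was provided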
+ result = asyncio.run(PRAgent().handle_request(args.issue_url, command + " " + " ".join(args.rest)))
+ # result = asyncio.run(PRSimilarIssue(args.issue_url, cli_mode=True, args=command + " " + " ".join(args.rest)).run())
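+ # e.g.: python cli.py --issue_url=https://github.com/<org>/<repo>/issues/1 similar_issue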
+ else:
+ result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest)))
if not result:
parser.print_help()
diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py
index 7e93d18c..0521716b 100644
--- a/pr_agent/git_providers/github_provider.py
+++ b/pr_agent/git_providers/github_provider.py
@@ -32,7 +32,7 @@ class GithubProvider(GitProvider):
self.diff_files = None
self.git_files = None
self.incremental = incremental
- if pr_url:
+ if pr_url and 'pull' in pr_url:
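+ # only PR URLs ('.../pull/<n>') initialize PR state; issue URLs skip this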
self.set_pr(pr_url)
self.last_commit_id = list(self.pr.get_commits())[-1]
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index f8abd555..9bfdf3a3 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -84,4 +84,14 @@ polling_interval_seconds = 30
[local]
# LocalGitProvider settings - uncomment to use paths other than default
# description_path= "path/to/description.md"
-# review_path= "path/to/review.md"
\ No newline at end of file
+# review_path= "path/to/review.md"
+
+[pr_similar_issue]
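+# settings for the similar-issue tool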
+skip_comments = false
+force_update_dataset = false
+max_issues_to_scan = 1000
+
+[pinecone]
+# fill and place in .secrets.toml
+#api_key = ...
+# environment = "gcp-starter"
\ No newline at end of file
diff --git a/pr_agent/tools/pr_similar_issue.py b/pr_agent/tools/pr_similar_issue.py
index 497f2f5d..94dc10d3 100644
--- a/pr_agent/tools/pr_similar_issue.py
+++ b/pr_agent/tools/pr_similar_issue.py
@@ -1,77 +1,250 @@
import copy
import json
import logging
+from enum import Enum
from typing import List, Tuple
+import pinecone
+import openai
+import pandas as pd
+from pydantic import BaseModel, Field
-from jinja2 import Environment, StrictUndefined
-
-from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
+from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import load_yaml
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
-from pr_agent.git_providers.git_provider import get_main_pr_language
+from pinecone_datasets import Dataset, DatasetMetadata
+
+MODEL = "text-embedding-ada-002"
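+# OpenAI model used to embed both the indexed issues and the query issue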
class PRSimilarIssue:
- def __init__(self, pr_url: str, issue_url: str, args: list = None):
- load_data_from_local = True
- if not load_data_from_local:
- self.git_provider = get_git_provider()()
- repo_name, issue_number = self.git_provider._parse_issue_url(issue_url.split('=')[-1])
- self.git_provider.repo = repo_name
- self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name)
- repo_obj = self.git_provider.repo_obj
+ def __init__(self, issue_url: str, args: list = None):
+ if get_settings().config.git_provider != "github":
+ raise Exception("Only github is supported for similar issue tool")
- def _process_issue(issue):
- header = body = issue_str = comments_str = ""
- if issue.pull_request:
- return header, body, issue_str, comments_str
- header = issue.title
- body = issue.body
- comments_obj = list(issue.get_comments())
- comments_str = ""
- for i, comment in enumerate(comments_obj):
- comments_str += f"comment {i}:\n{comment.body}\n\n\n"
- issue_str = f"Issue Header: \"{header}\"\n\nIssue Body:\n{body}"
- return header, body, issue_str, comments_str
+ self.cli_mode = get_settings().CONFIG.CLI_MODE
+ self.max_issues_to_scan = get_settings().pr_similar_issue.max_issues_to_scan
+ self.issue_url = issue_url
+ self.git_provider = get_git_provider()()
+ repo_name, issue_number = self.git_provider._parse_issue_url(issue_url.split('=')[-1])
+ self.git_provider.repo = repo_name
+ self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name)
+ self.token_handler = TokenHandler()
+ repo_obj = self.git_provider.repo_obj
+ repo_name_for_index = self.repo_name_for_index = repo_obj.full_name.lower().replace('/', '-').replace('_/', '-')
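+ # all repos share one Pinecone index; vectors are scoped by the 'repo' metadata field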
+ index_name = self.index_name = "codium-ai-pr-agent-issues"
- main_issue = repo_obj.get_issue(issue_number)
- assert not main_issue.pull_request
- _, _, main_issue_str, main_comments_str = _process_issue(main_issue)
+ # assuming pinecone api key and environment are set in secrets file
+ try:
+ api_key = get_settings().pinecone.api_key
+ environment = get_settings().pinecone.environment
+ except Exception:
+ if not self.cli_mode:
+ repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
+ issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
+ issue_main.create_comment("Please set pinecone api key and environment in secrets file")
+ raise Exception("Please set pinecone api key and environment in secrets file")
- issues_str_list = []
- comments_str_list = []
- issues = list(repo_obj.get_issues(state='all')) # 'open', 'closed', 'all'
- for i, issue in enumerate(issues):
- if issue.url == main_issue.url:
- continue
- if issue.pull_request:
- continue
- _, _, issue_str, comments_str = _process_issue(issue)
- issues_str_list.append(issue_str)
- comments_str_list.append(comments_str)
-
- json_output = {}
- json_output['main_issue'] = {}
- json_output['main_issue']['issue'] = main_issue_str
- json_output['main_issue']['comment'] = main_comments_str
- json_output['issues'] = {}
- for i in range(len(issues_str_list)):
- json_output['issues'][f'issue_{i}'] = {}
- json_output['issues'][f'issue_{i}']['issue'] = issues_str_list[i]
- json_output['issues'][f'issue_{i}'][f'comments'] = comments_str_list[i]
-
- jsonFile = open("/Users/talrid/Desktop/issues_data.json", "w")
- jsonFile.write(json.dumps(json_output))
- jsonFile.close()
+ # check if index exists, and if repo is already indexed
+ run_from_scratch = False
+ upsert = True
+ pinecone.init(api_key=api_key, environment=environment)
+ if index_name not in pinecone.list_indexes():
+ run_from_scratch = True
+ upsert = False
else:
- jsonFile = open("/Users/talrid/Desktop/issues_data.json", "r")
- json_output=json.loads(jsonFile.read())
+ if get_settings().pr_similar_issue.force_update_dataset:
+ upsert = True
+ else:
+ pinecone_index = pinecone.Index(index_name=index_name)
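+ # probe for the sentinel vector that marks this repo as already indexed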
+ res = pinecone_index.fetch([f"example_issue_{repo_name_for_index}"]).to_dict()
+ if res["vectors"]:
+ upsert = False
- from langchain.document_loaders import TextLoader
- from langchain.text_splitter import CharacterTextSplitter
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+ if run_from_scratch or upsert: # index the entire repo
+ logging.info('Indexing the entire repo...')
- aaa=3
+ logging.info('Getting issues...')
+ issues = list(repo_obj.get_issues(state='all'))
+ logging.info('Done')
+ self._update_index_with_issues(issues, repo_name_for_index, upsert=upsert)
+ else: # update index if needed
+ pinecone_index = pinecone.Index(index_name=index_name)
+ issues_to_update = []
+ issues_paginated_list = repo_obj.get_issues(state='all')
+ counter = 0
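+ # GitHub returns issues newest-first, so stop at the first issue already present in the index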
+ for issue in issues_paginated_list:
+ if issue.pull_request:
+ continue
+ issue_str, comments, number = self._process_issue(issue)
+ issue_key = f"issue_{number}"
+ id = issue_key + "." + "issue"
+ res = pinecone_index.fetch([id]).to_dict()
+ is_new_issue = True
+ for vector in res["vectors"].values():
+ if vector['metadata']['repo'] == repo_name_for_index:
+ is_new_issue = False
+ break
+ if is_new_issue:
+ counter += 1
+ issues_to_update.append(issue)
+ else:
+ break
+
+ if issues_to_update:
+ logging.info(f'Updating index with {counter} new issues...')
+ self._update_index_with_issues(issues_to_update, repo_name_for_index, upsert=True)
+ else:
+ logging.info('No new issues to update')
+
+ async def run(self):
+ repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
+ issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
+ issue_str, comments, number = self._process_issue(issue_main)
+ openai.api_key = get_settings().openai.key
+
+ res = openai.Embedding.create(input=[issue_str], engine=MODEL)
+ embeds = [record['embedding'] for record in res['data']]
+ pinecone_index = pinecone.Index(index_name=self.index_name)
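+ # retrieve the 5 nearest neighbors, restricted to this repo via the metadata filter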
+ res = pinecone_index.query(embeds[0],
+ top_k=5,
+ filter={"repo": self.repo_name_for_index},
+ include_metadata=True).to_dict()
+ relevant_issues_number_list = []
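+ # match ids look like "issue_<number>.issue" or "issue_<number>.comment_<j>"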
+ for r in res['matches']:
+ issue_number = int(r["id"].split('.')[0].split('_')[-1])
+ if original_issue_number == issue_number:
+ continue
+ if issue_number not in relevant_issues_number_list:
+ relevant_issues_number_list.append(issue_number)
+
+ similar_issues_str = "Similar Issues:\n\n"
+ for i, issue_number_similar in enumerate(relevant_issues_number_list):
+ issue = self.git_provider.repo_obj.get_issue(issue_number_similar)
+ title = issue.title
+ url = issue.html_url
+ similar_issues_str += f"{i + 1}. [{title}]({url})\n\n"
+ if get_settings().config.publish_output:
+ response = issue_main.create_comment(similar_issues_str)
+ logging.info(similar_issues_str)
+
+ def _process_issue(self, issue):
+ header = issue.title
+ body = issue.body
+ number = issue.number
+ if get_settings().pr_similar_issue.skip_comments:
+ comments = []
+ else:
+ comments = list(issue.get_comments())
+ issue_str = f"Issue Header: \"{header}\"\n\nIssue Body:\n{body}"
+ return issue_str, comments, number
+
+ def _update_index_with_issues(self, issues_list, repo_name_for_index, upsert=False):
+ logging.info('Processing issues...')
+ corpus = Corpus()
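+ # the sentinel record below lets __init__ detect that this repo was already indexed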
+ example_issue_record = Record(
+ id=f"example_issue_{repo_name_for_index}",
+ text="example_issue",
+ metadata=Metadata(repo=repo_name_for_index)
+ )
+ corpus.append(example_issue_record)
+
+ counter = 0
+ for issue in issues_list:
+
+ if issue.pull_request:
+ continue
+
+ counter += 1
+ if counter >= self.max_issues_to_scan:
+ logging.info(f"Scanned {self.max_issues_to_scan} issues, stopping")
+ break
+
+ issue_str, comments, number = self._process_issue(issue)
+ issue_key = f"issue_{number}"
+ username = issue.user.login
+ created_at = str(issue.created_at)
+ if len(issue_str) < 8000 or \
+ self.token_handler.count_tokens(issue_str) < MAX_TOKENS[MODEL]: # fast reject first
+ issue_record = Record(
+ id=issue_key + "." + "issue",
+ text=issue_str,
+ metadata=Metadata(repo=repo_name_for_index,
+ username=username,
+ created_at=created_at,
+ level=IssueLevel.ISSUE)
+ )
+ corpus.append(issue_record)
+ if comments:
+ for j, comment in enumerate(comments):
+ comment_body = comment.body
+ num_words_comment = len(comment_body.split())
+ if num_words_comment < 10:
+ continue
+
+ if len(comment_body) < 8000 or \
+ self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]:
+ comment_record = Record(
+ id=issue_key + ".comment_" + str(j + 1),
+ text=comment_body,
+ metadata=Metadata(repo=repo_name_for_index,
+ username=username, # use issue username for all comments
+ created_at=created_at,
+ level=IssueLevel.COMMENT)
+ )
+ corpus.append(comment_record)
+ df = pd.DataFrame(corpus.dict()["documents"])
+ logging.info('Done')
+
+ logging.info('Embedding...')
+ openai.api_key = get_settings().openai.key
+ list_to_encode = list(df["text"].values)
+ res = openai.Embedding.create(input=list_to_encode, engine=MODEL)
+ embeds = [record['embedding'] for record in res['data']]
+ df["values"] = embeds
+ meta = DatasetMetadata.empty()
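+ # record the embedding dimension (1536 for ada-002) in the dataset metadata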
+ meta.dense_model.dimension = len(embeds[0])
+ ds = Dataset.from_pandas(df, meta)
+ logging.info('Done')
+
+ api_key = get_settings().pinecone.api_key
+ environment = get_settings().pinecone.environment
+ if not upsert:
+ logging.info('Creating index...')
+ ds.to_pinecone_index(self.index_name, api_key=api_key, environment=environment)
+ else:
+ logging.info('Upserting index...')
+ namespace = ""
+ batch_size: int = 100
+ concurrency: int = 10
+ pinecone.init(api_key=api_key, environment=environment)
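+ # note: _upsert_to_index is a private pinecone_datasets helper and may change between releases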
+ ds._upsert_to_index(self.index_name, namespace, batch_size, concurrency)
+ logging.info('Done')
+
+
+class IssueLevel(str, Enum):
+ ISSUE = "issue"
+ COMMENT = "comment"
+
+
+class Metadata(BaseModel):
+ repo: str
+ username: str = Field(default="@codium")
+ created_at: str = Field(default="01-01-1970 00:00:00.00000")
+ level: IssueLevel = Field(default=IssueLevel.ISSUE)
+
+ class Config:
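+ # store the enum as its string value so Pinecone metadata stays JSON-serializable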
+ use_enum_values = True
+
+
+class Record(BaseModel):
+ id: str
+ text: str
+ metadata: Metadata
+
+
+class Corpus(BaseModel):
+ documents: List[Record] = Field(default=[])
+
+ def append(self, r: Record):
+ self.documents.append(r)
diff --git a/requirements.txt b/requirements.txt
index 99efa846..5d4caaa6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,4 +18,6 @@ boto3~=1.28.25
google-cloud-storage==2.10.0
ujson==5.8.0
azure-devops==7.1.0b3
-msrest==0.7.1
\ No newline at end of file
+msrest==0.7.1
+pinecone-client==2.2.2
+pinecone_datasets==0.6.1
\ No newline at end of file