Initial commit - PR-Agent OSS release

This commit is contained in:
Ori Kotek
2023-07-06 00:21:08 +03:00
commit 4b4d91dfe9
44 changed files with 2426 additions and 0 deletions

1
pr_agent/__init__.py Normal file
View File

@ -0,0 +1 @@

View File

View File

@ -0,0 +1,20 @@
import re
from typing import Optional
from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer
class PRAgent:
def __init__(self, installation_id: Optional[int] = None):
self.installation_id = installation_id
async def handle_request(self, pr_url, request):
if 'please review' in request.lower():
reviewer = PRReviewer(pr_url, self.installation_id)
await reviewer.review()
elif 'please answer' in request.lower():
question = re.split(r'(?i)please answer', request)[1].strip()
answerer = PRQuestions(pr_url, question, self.installation_id)
await answerer.answer()

10
pr_agent/algo/__init__.py Normal file
View File

@ -0,0 +1,10 @@
MAX_TOKENS = {
'gpt-3.5-turbo': 4000,
'gpt-3.5-turbo-0613': 4000,
'gpt-3.5-turbo-0301': 4000,
'gpt-3.5-turbo-16k': 16000,
'gpt-3.5-turbo-16k-0613': 16000,
'gpt-4': 8000,
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
}

View File

@ -0,0 +1,37 @@
import logging
import openai
from openai.error import APIError, Timeout, TryAgain
from retry import retry
from pr_agent.config_loader import settings
OPENAI_RETRIES=2
class AiHandler:
def __init__(self):
try:
openai.api_key = settings.openai.key
except AttributeError as e:
raise ValueError("OpenAI key is required") from e
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, temperature: float, system: str, user: str):
try:
response = await openai.ChatCompletion.acreate(
model=model,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user}
],
temperature=temperature,
)
except (APIError, Timeout, TryAgain) as e:
logging.error("Error during OpenAI inference: ", e)
raise
if response is None or len(response.choices) == 0:
raise TryAgain
resp = response.choices[0]['message']['content']
finish_reason = response.choices[0].finish_reason
return resp, finish_reason

View File

@ -0,0 +1,107 @@
from __future__ import annotations
import logging
import re
from pr_agent.config_loader import settings
def extend_patch(original_file_str, patch_str, num_lines) -> str:
"""
Extends the patch to include 'num_lines' more surrounding lines
"""
if not patch_str or num_lines == 0:
return patch_str
original_lines = original_file_str.splitlines()
patch_lines = patch_str.splitlines()
extended_patch_lines = []
start1, size1, start2, size2 = -1, -1, -1, -1
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
try:
for line in patch_lines:
if line.startswith('@@'):
match = RE_HUNK_HEADER.match(line)
if match:
# finish previous hunk
if start1 != -1:
extended_patch_lines.extend(
original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines])
start1, size1, start2, size2 = map(int, match.groups()[:4])
section_header = match.groups()[4]
extended_start1 = max(1, start1 - num_lines)
extended_size1 = size1 + (start1 - extended_start1) + num_lines
extended_start2 = max(1, start2 - num_lines)
extended_size2 = size2 + (start2 - extended_start2) + num_lines
extended_patch_lines.append(
f'@@ -{extended_start1},{extended_size1} '
f'+{extended_start2},{extended_size2} @@ {section_header}')
extended_patch_lines.extend(
original_lines[extended_start1 - 1:start1 - 1]) # one to zero based
continue
extended_patch_lines.append(line)
except Exception as e:
if settings.config.verbosity_level >= 2:
logging.error(f"Failed to extend patch: {e}")
return patch_str
# finish previous hunk
if start1 != -1:
extended_patch_lines.extend(
original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines])
extended_patch_str = '\n'.join(extended_patch_lines)
return extended_patch_str
def omit_deletion_hunks(patch_lines) -> str:
temp_hunk = []
added_patched = []
add_hunk = False
inside_hunk = False
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")
for line in patch_lines:
if line.startswith('@@'):
match = RE_HUNK_HEADER.match(line)
if match:
# finish previous hunk
if inside_hunk and add_hunk:
added_patched.extend(temp_hunk)
temp_hunk = []
add_hunk = False
temp_hunk.append(line)
inside_hunk = True
else:
temp_hunk.append(line)
edit_type = line[0]
if edit_type == '+':
add_hunk = True
if inside_hunk and add_hunk:
added_patched.extend(temp_hunk)
return '\n'.join(added_patched)
def handle_patch_deletions(patch: str, original_file_content_str: str,
new_file_content_str: str, file_name: str) -> str:
"""
Handle entire file or deletion patches
"""
if not new_file_content_str:
# logic for handling deleted files - don't show patch, just show that the file was deleted
if settings.config.verbosity_level > 0:
logging.info(f"Processing file: {file_name}, minimizing deletion file")
patch = "File was deleted\n"
else:
patch_lines = patch.splitlines()
patch_new = omit_deletion_hunks(patch_lines)
if patch != patch_new:
if settings.config.verbosity_level > 0:
logging.info(f"Processing file: {file_name}, hunks were deleted")
patch = patch_new
return patch

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,128 @@
from __future__ import annotations
import difflib
import logging
from typing import Any, Dict, Tuple
from pr_agent.algo.git_patch_processing import extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings
from pr_agent.git_providers import GithubProvider
OUTPUT_BUFFER_TOKENS = 800
PATCH_EXTRA_LINES = 3
def get_pr_diff(git_provider: [GithubProvider, Any], token_handler: TokenHandler) -> str:
"""
Returns a string with the diff of the PR.
If needed, apply diff minimization techniques to reduce the number of tokens
"""
files = list(git_provider.get_diff_files())
# get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), files)
# generate a standard diff string, with patch extension
patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler)
# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS < token_handler.limit:
return "\n".join(patches_extended)
# if we are over the limit, start pruning
patches_compressed = pr_generate_compressed_diff(pr_languages, token_handler)
return "\n".join(patches_compressed)
def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler) -> \
Tuple[list, int]:
"""
Generate a standard diff string, with patch extension
"""
total_tokens = token_handler.prompt_tokens # initial tokens
patches_extended = []
for lang in pr_languages:
for file in lang['files']:
original_file_content_str = file.base_file
new_file_content_str = file.head_file
patch = file.patch
# handle the case of large patch, that initially was not loaded
patch = load_large_diff(file, new_file_content_str, original_file_content_str, patch)
if not patch:
continue
# extend each patch with extra lines of context
extended_patch = extend_patch(original_file_content_str, patch, num_lines=PATCH_EXTRA_LINES)
full_extended_patch = f"## {file.filename}\n\n{extended_patch}\n"
patch_tokens = token_handler.count_tokens(full_extended_patch)
file.tokens = patch_tokens
total_tokens += patch_tokens
patches_extended.append(full_extended_patch)
return patches_extended, total_tokens
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler) -> list:
# Apply Diff Minimization techniques to reduce the number of tokens:
# 0. Start from the largest diff patch to smaller ones
# 1. Don't use extend context lines around diff
# 2. Minimize deleted files
# 3. Minimize deleted hunks
# 4. Minimize all remaining files when you reach token limit
patches = []
# sort each one of the languages in top_langs by the number of tokens in the diff
sorted_files = []
for lang in top_langs:
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
total_tokens = token_handler.prompt_tokens
for file in sorted_files:
original_file_content_str = file.base_file
new_file_content_str = file.head_file
patch = file.patch
patch = load_large_diff(file, new_file_content_str, original_file_content_str, patch)
if not patch:
continue
# removing delete-only hunks
patch = handle_patch_deletions(patch, original_file_content_str,
new_file_content_str, file.filename)
new_patch_tokens = token_handler.count_tokens(patch)
if total_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS // 2:
logging.warning(f"File was fully skipped, no more tokens: {file.filename}.")
continue # Hard Stop, no more tokens
if total_tokens + new_patch_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS:
# Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements
if settings.config.verbosity_level >= 2:
logging.warning(f"Patch too large, minimizing it, {file.filename}")
patch = "File was modified"
if patch:
patch_final = f"## {file.filename}\n\n{patch}\n"
patches.append(patch_final)
total_tokens += token_handler.count_tokens(patch_final)
if settings.config.verbosity_level >= 2:
logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")
return patches
def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
if not patch: # to Do - also add condition for file extension
try:
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
new_file_content_str.splitlines(keepends=True))
if settings.config.verbosity_level >= 2:
logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.")
patch = ''.join(diff)
except Exception:
pass
return patch

View File

@ -0,0 +1,24 @@
from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model
from pr_agent.algo import MAX_TOKENS
from pr_agent.config_loader import settings
class TokenHandler:
def __init__(self, pr, vars: dict, system, user):
self.encoder = encoding_for_model(settings.config.model)
self.limit = MAX_TOKENS[settings.config.model]
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(system).render(vars)
user_prompt = environment.from_string(user).render(vars)
system_prompt_tokens = len(encoder.encode(system_prompt))
user_prompt_tokens = len(encoder.encode(user_prompt))
return system_prompt_tokens + user_prompt_tokens
def count_tokens(self, patch: str) -> int:
return len(self.encoder.encode(patch))

59
pr_agent/algo/utils.py Normal file
View File

@ -0,0 +1,59 @@
from __future__ import annotations
import textwrap
def convert_to_markdown(output_data: dict) -> str:
markdown_text = ""
emojis = {
"Main theme": "🎯",
"Description and title": "🔍",
"Type of PR": "📌",
"Relevant tests added": "🧪",
"Unrelated changes": "⚠️",
"Minimal and focused": "",
"Security concerns": "🔒",
"General PR suggestions": "💡",
"Code suggestions": "🤖"
}
for key, value in output_data.items():
if not value:
continue
if isinstance(value, dict):
markdown_text += f"## {key}\n\n"
markdown_text += convert_to_markdown(value)
elif isinstance(value, list):
if key.lower() == 'code suggestions':
markdown_text += "\n" # just looks nicer with additional line breaks
emoji = emojis.get(key, "") # Use a dash if no emoji is found for the key
markdown_text += f"- {emoji} **{key}:**\n\n"
for item in value:
if isinstance(item, dict) and key.lower() == 'code suggestions':
markdown_text += parse_code_suggestion(item)
elif item:
markdown_text += f" - {item}\n"
elif value != 'n/a':
emoji = emojis.get(key, "") # Use a dash if no emoji is found for the key
markdown_text += f"- {emoji} **{key}:** {value}\n"
return markdown_text
def parse_code_suggestion(code_suggestions: dict) -> str:
markdown_text = ""
for sub_key, sub_value in code_suggestions.items():
if isinstance(sub_value, dict): # "code example"
markdown_text += f" - **{sub_key}:**\n"
for code_key, code_value in sub_value.items(): # 'before' and 'after' code
code_str = f"```\n{code_value}\n```"
code_str_indented = textwrap.indent(code_str, ' ')
markdown_text += f" - **{code_key}:**\n{code_str_indented}\n"
else:
if "suggestion number" in sub_key.lower():
markdown_text += f"- **suggestion {sub_value}:**\n" # prettier formatting
else:
markdown_text += f" - **{sub_key}:** {sub_value}\n"
markdown_text += "\n"
return markdown_text

14
pr_agent/config_loader.py Normal file
View File

@ -0,0 +1,14 @@
from os.path import abspath, dirname, join
from dynaconf import Dynaconf
current_dir = dirname(abspath(__file__))
settings = Dynaconf(
envvar_prefix=False,
settings_files=[join(current_dir, f) for f in [
"settings/.secrets.toml",
"settings/configuration.toml",
"settings/pr_reviewer_prompts.toml",
"settings/pr_questions_prompts.toml"
]]
)

View File

@ -0,0 +1,15 @@
from pr_agent.config_loader import settings
from pr_agent.git_providers.github_provider import GithubProvider
_GIT_PROVIDERS = {
'github': GithubProvider
}
def get_git_provider():
try:
provider_id = settings.config.git_provider
except AttributeError as e:
raise ValueError("github_provider is a required attribute in the configuration file") from e
if provider_id not in _GIT_PROVIDERS:
raise ValueError(f"Unknown git provider: {provider_id}")
return _GIT_PROVIDERS[provider_id]

View File

@ -0,0 +1,170 @@
from collections import namedtuple
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Tuple
from urllib.parse import urlparse
from github import AppAuthentication, File, Github
from pr_agent.config_loader import settings
@dataclass
class FilePatchInfo:
base_file: str
head_file: str
patch: str
filename: str
tokens: int = -1
class GithubProvider:
def __init__(self, pr_url: Optional[str] = None, installation_id: Optional[int] = None):
self.installation_id = installation_id
self.github_client = self._get_github_client()
self.repo = None
self.pr_num = None
self.pr = None
if pr_url:
self.set_pr(pr_url)
def set_pr(self, pr_url: str):
self.repo, self.pr_num = self._parse_pr_url(pr_url)
self.pr = self._get_pr()
def get_diff_files(self) -> list[FilePatchInfo]:
files = self.pr.get_files()
diff_files = []
for file in files:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, file.patch, file.filename))
return diff_files
def publish_comment(self, pr_comment: str):
self.pr.create_issue_comment(pr_comment)
def get_title(self):
return self.pr.title
def get_description(self):
return self.pr.body
def get_languages(self):
return self._get_repo().get_languages()
def get_main_pr_language(self) -> str:
"""
Get the main language of the commit. Return an empty string if cannot determine.
"""
main_language_str = ""
try:
languages = self.get_languages()
top_language = max(languages, key=languages.get).lower()
# validate that the specific commit uses the main language
extension_list = []
files = self.pr.get_files()
for file in files:
extension_list.append(file.filename.rsplit('.')[-1])
# get the most common extension
most_common_extension = max(set(extension_list), key=extension_list.count)
# look for a match. TBD: add more languages, do this systematically
if most_common_extension == 'py' and top_language == 'python' or \
most_common_extension == 'js' and top_language == 'javascript' or \
most_common_extension == 'ts' and top_language == 'typescript' or \
most_common_extension == 'go' and top_language == 'go' or \
most_common_extension == 'java' and top_language == 'java' or \
most_common_extension == 'c' and top_language == 'c' or \
most_common_extension == 'cpp' and top_language == 'c++' or \
most_common_extension == 'cs' and top_language == 'c#' or \
most_common_extension == 'swift' and top_language == 'swift' or \
most_common_extension == 'php' and top_language == 'php' or \
most_common_extension == 'rb' and top_language == 'ruby' or \
most_common_extension == 'rs' and top_language == 'rust' or \
most_common_extension == 'scala' and top_language == 'scala' or \
most_common_extension == 'kt' and top_language == 'kotlin' or \
most_common_extension == 'pl' and top_language == 'perl' or \
most_common_extension == 'swift' and top_language == 'swift':
main_language_str = top_language
except Exception:
pass
return main_language_str
def get_pr_branch(self):
return self.pr.head.ref
def get_notifications(self, since: datetime):
deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user")
if deployment_type != 'user':
raise ValueError("Deployment mode must be set to 'user' to get notifications")
notifications = self.github_client.get_user().get_notifications(since=since)
return notifications
@staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, int]:
parsed_url = urlparse(pr_url)
if 'github.com' not in parsed_url.netloc:
raise ValueError("The provided URL is not a valid GitHub URL")
path_parts = parsed_url.path.strip('/').split('/')
if 'api.github.com' in parsed_url.netloc:
if len(path_parts) < 5 or path_parts[3] != 'pulls':
raise ValueError("The provided URL does not appear to be a GitHub PR URL")
repo_name = '/'.join(path_parts[1:3])
try:
pr_number = int(path_parts[4])
except ValueError as e:
raise ValueError("Unable to convert PR number to integer") from e
return repo_name, pr_number
if len(path_parts) < 4 or path_parts[2] != 'pull':
raise ValueError("The provided URL does not appear to be a GitHub PR URL")
repo_name = '/'.join(path_parts[:2])
try:
pr_number = int(path_parts[3])
except ValueError as e:
raise ValueError("Unable to convert PR number to integer") from e
return repo_name, pr_number
def _get_github_client(self):
deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user")
if deployment_type == 'app':
try:
private_key = settings.github.private_key
app_id = settings.github.app_id
except AttributeError as e:
raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e
if not self.installation_id:
raise ValueError("GitHub app installation ID is required when using GitHub app deployment")
auth = AppAuthentication(app_id=app_id, private_key=private_key,
installation_id=self.installation_id)
return Github(app_auth=auth)
if deployment_type == 'user':
try:
token = settings.github.user_token
except AttributeError as e:
raise ValueError("GitHub token is required when using user deployment") from e
return Github(token)
def _get_repo(self):
return self.github_client.get_repo(self.repo)
def _get_pr(self):
return self._get_repo().get_pull(self.pr_num)
def _get_pr_file_content(self, file: FilePatchInfo, sha: str):
try:
file_content_str = self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode()
except Exception:
file_content_str = ""
return file_content_str

View File

@ -0,0 +1,16 @@
import argparse
import asyncio
import logging
import os
from pr_agent.tools.pr_questions import PRQuestions
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Review a PR from a URL')
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
parser.add_argument('--question_str', type=str, help='The question to answer', required=True)
args = parser.parse_args()
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
reviewer = PRQuestions(args.pr_url, args.question_str, None)
asyncio.run(reviewer.answer())

View File

@ -0,0 +1,14 @@
import argparse
import asyncio
import logging
import os
from pr_agent.tools.pr_reviewer import PRReviewer
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Review a PR from a URL')
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
args = parser.parse_args()
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
reviewer = PRReviewer(args.pr_url, None)
asyncio.run(reviewer.review())

View File

@ -0,0 +1,78 @@
import logging
import sys
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings
from pr_agent.servers.utils import verify_signature
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
router = APIRouter()
@router.post("/api/v1/github_webhooks")
async def handle_github_webhooks(request: Request, response: Response):
logging.debug("Received a github webhook")
try:
body = await request.json()
except Exception as e:
logging.error("Error parsing request body", e)
raise HTTPException(status_code=400, detail="Error parsing request body") from e
body_bytes = await request.body()
signature_header = request.headers.get('x-hub-signature-256', None)
try:
webhook_secret = settings.github.webhook_secret
except AttributeError:
webhook_secret = None
if webhook_secret:
verify_signature(body_bytes, webhook_secret, signature_header)
logging.debug(f'Request body:\n{body}')
return await handle_request(body)
async def handle_request(body):
action = body.get("action", None)
installation_id = body.get("installation", {}).get("id", None)
agent = PRAgent(installation_id)
if action == 'created':
if "comment" not in body:
return {}
comment_body = body.get("comment", {}).get("body", None)
if "says 'Please" in comment_body:
return {}
if "issue" not in body and "pull_request" not in body["issue"]:
return {}
pull_request = body["issue"]["pull_request"]
api_url = pull_request.get("url", None)
await agent.handle_request(api_url, comment_body)
elif action in ["opened"] or 'reopened' in action:
pull_request = body.get("pull_request", None)
if not pull_request:
return {}
api_url = pull_request.get("url", None)
if api_url is None:
return {}
await agent.handle_request(api_url, "please review")
else:
return {}
@router.get("/")
async def root():
return {"status": "ok"}
def start():
if settings.get("GITHUB.DEPLOYMENT_TYPE", "user") != "app":
raise Exception("Please set deployment type to app in .secrets.toml file")
app = FastAPI()
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=3000)
if __name__ == '__main__':
start()

View File

@ -0,0 +1,73 @@
import asyncio
import logging
import sys
from datetime import datetime, timezone
import aiohttp
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
NOTIFICATION_URL = "https://api.github.com/notifications"
def now() -> str:
now_utc = datetime.now(timezone.utc).isoformat()
now_utc = now_utc.replace("+00:00", "Z")
return now_utc
async def polling_loop():
since = [now()]
last_modified = [None]
try:
deployment_type = settings.github.deployment_type
token = settings.github.user_token
except AttributeError:
deployment_type = 'none'
token = None
if deployment_type != 'user':
raise ValueError("Deployment mode must be set to 'user' to get notifications")
if not token:
raise ValueError("User token must be set to get notifications")
async with aiohttp.ClientSession() as session:
while True:
headers = {
"Accept": "application/vnd.github.v3+json",
"Authorization": f"Bearer {token}"
}
params = {
"participating": "true"
}
if since[0]:
params["since"] = since[0]
if last_modified[0]:
headers["If-Modified-Since"] = last_modified[0]
async with session.get(NOTIFICATION_URL, headers=headers, params=params) as response:
if response.status == 200:
if 'Last-Modified' in response.headers:
last_modified[0] = response.headers['Last-Modified']
since[0] = None
notifications = await response.json()
for notification in notifications:
if 'reason' in notification and notification['reason'] == 'mention':
if 'subject' in notification and notification['subject']['type'] == 'PullRequest':
pr_url = notification['subject']['url']
latest_comment = notification['subject']['latest_comment_url']
async with session.get(latest_comment, headers=headers) as comment_response:
if comment_response.status == 200:
comment = await comment_response.json()
comment_body = comment['body'] if 'body' in comment else ''
commenter_github_user = comment['user']['login'] if 'user' in comment else ''
logging.info(f"Commenter: {commenter_github_user}\nComment: {comment_body}")
if comment_body.strip().startswith("@"):
agent = PRAgent()
await agent.handle_request(pr_url, comment_body)
elif response.status != 304:
print(f"Failed to fetch notifications. Status code: {response.status}")
await asyncio.sleep(5)
if __name__ == '__main__':
asyncio.run(polling_loop())

23
pr_agent/servers/utils.py Normal file
View File

@ -0,0 +1,23 @@
import hashlib
import hmac
from fastapi import HTTPException
def verify_signature(payload_body, secret_token, signature_header):
"""Verify that the payload was sent from GitHub by validating SHA256.
Raise and return 403 if not authorized.
Args:
payload_body: original request body to verify (request.body())
secret_token: GitHub app webhook token (WEBHOOK_SECRET)
signature_header: header received from GitHub (x-hub-signature-256)
"""
if not signature_header:
raise HTTPException(status_code=403, detail="x-hub-signature-256 header is missing!")
hash_object = hmac.new(secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256)
expected_signature = "sha256=" + hash_object.hexdigest()
if not hmac.compare_digest(expected_signature, signature_header):
raise HTTPException(status_code=403, detail="Request signatures didn't match!")

View File

@ -0,0 +1,26 @@
# QUICKSTART:
# Copy this file to .secrets in the same folder.
# The minimum workable settings - set openai.key to your API key.
# Set github.deployment_type to "user" and github.user_token to your GitHub personal access token.
# This will allow you to run the CLI scripts in the scripts/ folder and the github_polling server.
#
# See README for details about GitHub App deployment.
[openai]
key = "<API_KEY>"
[github]
# The type of deployment to create. Valid values are 'app' or 'user'.
deployment_type = "user"
# ---- Set the following only for deployment type == "user"
user_token = "<TOKEN>" # A GitHub personal access token with 'repo' scope.
# ---- Set the following only for deployment type == "app", see README for details.
private_key = """\
-----BEGIN RSA PRIVATE KEY-----
<GITHUB PRIVATE KEY>
-----END RSA PRIVATE KEY-----
"""
app_id = 123456 # The GitHub App ID, replace with your own.
webhook_secret = "<WEBHOOK SECRET>" # Optional, may be commented out.

View File

@ -0,0 +1,15 @@
[config]
model="gpt-4-0613"
git_provider="github"
publish_review=true
verbosity_level=0 # 0,1,2
[pr_reviewer]
require_minimal_and_focused_review=true
require_tests_review=true
require_security_review=true
extended_code_suggestions=false
num_code_suggestions=4
[pr_questions]

View File

@ -0,0 +1,30 @@
[pr_questions_prompt]
system="""You are CodiumAI-PR-Reviewer, a language model designed to review git pull requests.
Your task is to answer questions about the new PR code (the '+' lines), and provide feedback.
Be informative, constructive, and give examples. Try to be as specific as possible, and don't avoid answering the questions.
Make sure not to repeat modifications already implemented in the new PR code (the '+' lines).
"""
user="""PR Info:
Title: '{{title}}'
Branch: '{{branch}}'
Description: '{{description}}'
{%- if language %}
Main language: {{language}}
{%- endif %}
The PR Git Diff:
```
{{diff}}
```
Note that lines in the diff body are prefixed with a symbol that represents the type of change: '-' for deletions, '+' for additions, and ' ' (a space) for unchanged lines
The PR Questions:
```
{{ questions }}
```
Response:
"""

View File

@ -0,0 +1,159 @@
[pr_review_prompt]
system="""You are CodiumAI-PR-Reviewer, a language model designed to review git pull requests.
Your task is to provide constructive and concise feedback for the PR, and also provide meaningfull code suggestions to improve the new PR code (the '+' lines).
- Provide up to {{ num_code_suggestions }} code suggestions.
- Try to focus on important suggestions like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningfull code improvements, like performance, vulnerability, modularity, and best practices.
{%- if extended_code_suggestions %}
- For each suggestion, provide a short and concise code snippet to illustrate the existing code, and the improved code.
{%- endif %}
- Make sure not to provide suggestion repeating modifications already implemented in the new PR code (the '+' lines).
You must use the following JSON schema to format your answer:
```json
{
"PR Analysis": {
"Main theme": {
"type": "string",
"description": "a short explanation of the PR"
},
"Description and title": {
"type": "string",
"description": "yes\\no question: does this PR have a relevant description and title"
},
"Type of PR": {
"type": "string",
"enum": ["Bug fix", "Tests", "Bug fix with tests", "Refactoring", "Enhancement", "Documentation", "Other"]
},
{%- if require_tests %}
"Relevant tests added": {
"type": "string",
"description": "yes\\no question: does this PR have relevant tests ?"
},
{%- endif %}
{%- if require_minimal_and_focused %}
"Minimal and focused": {
"type": "string",
"description": "is this PR as minimal and focused as possible, with all code changes centered around a single coherent theme, described in the PR description and title ?" explain your answer"
}
},
{%- endif %}
"PR Feedback": {
"General PR suggestions": {
"type": "string",
"description": "important suggestions for the contributors and maintainers of this PR, may include overall structure, primary purpose and best practices. consider using specific filenames, classes and functions names. explain yourself!"
},
"Code suggestions": {
"type": "array",
"maxItems": {{ num_code_suggestions }},
"uniqueItems": true,
"items": {
"suggestion number": {
"type": "int",
"description": "suggestion number, starting from 1"
},
"relevant file": {
"type": "string",
"description": "the relevant file name"
},
"suggestion content": {
"type": "string",
{%- if extended_code_suggestions %}
"description": "a concrete suggestion for meaningfully improving the new PR code. Don't repeat previous suggestions. Add tags with importance measure that matches each suggestion ('important' or 'medium'). Do not make suggestions for updating or adding docstrings, renaming PR title and description, or linter like.
{%- else %}
"description": "a concrete suggestion for meaningfully improving the new PR code. Also describe how, specifically, the suggestion can be applied to new PR code. Add tags with importance measure that matches each suggestion ('important' or 'medium'). Do not make suggestions for updating or adding docstrings, renaming PR title and description, or linter like.
{%- endif %}
},
{%- if extended_code_suggestions %}
"why": {
"type": "string",
"description": "shortly explain why this suggestion is important"
},
"code example": {
"type": "object",
"properties": {
"before code": {
"type": "string",
"description": "Short and concise code snippet, to illustrate the existing code"
},
"after code": {
"type": "string",
"description": "Short and concise code snippet, to illustrate the improved code"
}
}
}
{%- endif %}
}
},
{%- if require_security %}
"Security concerns": {
"type": "string",
"description": "yes\\no question: does this PR code introduce possible security concerns or issues, like SQL injection, XSS, CSRF, and others ? explain your answer"
? explain your answer"
}
{%- endif %}
}
}
```
Example output:
'
{
"PR Analysis":
{
"Main theme": "xxx",
"Description and title": "Yes",
"Type of PR": "Bug fix",
{%- if require_tests %}
"Relevant tests added": "No",
{%- endif %}
{%- if require_minimal_and_focused %}
"Minimal and focused": "No, because ..."
{%- endif %}
},
"PR Feedback":
{
"General PR suggestions": "..., `xxx`...",
"Code suggestions": [
{
"suggestion number": 1,
"relevant file": "xxx.py",
"suggestion content": "xxx [important]",
{%- if extended_code_suggestions %}
"why": "xxx",
"code example":
{
"before code": "xxx",
"after code": "xxx"
}
{%- endif %}
},
...
]
{%- if require_security %},
"Security concerns": "No, because ..."
{%- endif %}
}
}
'
Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
"""
user="""PR Info:
Title: '{{title}}'
Branch: '{{branch}}'
Description: '{{description}}'
{%- if language %}
Main language: {{language}}
{%- endif %}
The PR Git Diff:
```
{{diff}}
```
Note that lines in the diff body are prefixed with a symbol that represents the type of change: '-' for deletions, '+' for additions, and ' ' (a space) for unchanged lines.
Response (should be a valid JSON, and nothing else):
```json
"""

View File

View File

@ -0,0 +1,67 @@
import copy
import logging
from typing import Optional
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider
class PRQuestions:
def __init__(self, pr_url: str, question_str: str, installation_id: Optional[int] = None):
self.git_provider = get_git_provider()(pr_url, installation_id)
self.main_pr_language = self.git_provider.get_main_pr_language()
self.installation_id = installation_id
self.ai_handler = AiHandler()
self.question_str = question_str
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.pr.body,
"language": self.git_provider.get_main_pr_language(),
"diff": "", # empty diff for initial calculation
"questions": self.question_str,
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_questions_prompt.system,
settings.pr_questions_prompt.user)
self.patches_diff = None
self.prediction = None
async def answer(self):
logging.info('Answering a PR question...')
self.git_provider.publish_comment("Preparing answer...")
logging.info('Getting PR diff...')
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
logging.info('Getting AI prediction...')
self.prediction = await self._get_prediction()
logging.info('Preparing answer...')
pr_comment = self._prepare_pr_answer()
if settings.config.publish_review:
logging.info('Pushing answer...')
self.git_provider.publish_comment(pr_comment)
return ""
async def _get_prediction(self):
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_questions_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_questions_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
model = settings.config.model
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
return response
def _prepare_pr_answer(self) -> str:
answer_str = f"Questions: {self.question_str}\n\n"
answer_str += f"Answer: {self.prediction.strip()}\n\n"
return answer_str

View File

@ -0,0 +1,88 @@
import copy
import json
import logging
from typing import Optional
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown
from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider
class PRReviewer:
def __init__(self, pr_url: str, installation_id: Optional[int] = None):
self.git_provider = get_git_provider()(pr_url, installation_id)
self.main_language = self.git_provider.get_main_pr_language()
self.installation_id = installation_id
self.ai_handler = AiHandler()
self.patches_diff = None
self.prediction = None
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.pr.body,
"language": self.git_provider.get_main_pr_language(),
"diff": "", # empty diff for initial calculation
"require_tests": settings.pr_reviewer.require_tests_review,
"require_security": settings.pr_reviewer.require_security_review,
"require_minimal_and_focused": settings.pr_reviewer.require_minimal_and_focused_review,
'extended_code_suggestions': settings.pr_reviewer.extended_code_suggestions,
'num_code_suggestions': settings.pr_reviewer.num_code_suggestions,
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_review_prompt.system,
settings.pr_review_prompt.user)
async def review(self):
logging.info('Reviewing PR...')
if settings.config.publish_review:
self.git_provider.publish_comment("Preparing review...")
logging.info('Getting PR diff...')
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
logging.info('Getting AI prediction...')
self.prediction = await self._get_prediction()
logging.info('Preparing PR review...')
pr_comment = self._prepare_pr_review()
if settings.config.publish_review:
logging.info('Pushing PR review...')
self.git_provider.publish_comment(pr_comment)
return ""
async def _get_prediction(self):
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
model = settings.config.model
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
try:
json.loads(response)
except json.decoder.JSONDecodeError:
logging.warning("Could not decode JSON")
response = {}
return response
def _prepare_pr_review(self) -> str:
review = self.prediction.strip()
try:
data = json.loads(review)
except json.decoder.JSONDecodeError:
logging.error("Unable to decode JSON response from AI")
data = {}
markdown_text = convert_to_markdown(data)
markdown_text += "\nAdd a comment that says 'Please review' to ask for a new review after you update the PR.\n"
markdown_text += "Add a comment that says 'Please answer <QUESTION...>' to ask a question about this PR.\n"
if settings.config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}")
return markdown_text