Initial commit - PR-Agent OSS release

Ori Kotek
2023-07-06 00:21:08 +03:00
commit 4b4d91dfe9
44 changed files with 2426 additions and 0 deletions

10
pr_agent/algo/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
MAX_TOKENS = {
    'gpt-3.5-turbo': 4000,
    'gpt-3.5-turbo-0613': 4000,
    'gpt-3.5-turbo-0301': 4000,
    'gpt-3.5-turbo-16k': 16000,
    'gpt-3.5-turbo-16k-0613': 16000,
    'gpt-4': 8000,
    'gpt-4-0613': 8000,
    'gpt-4-32k': 32000,
}
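This table maps each supported model name to its context window; the TokenHandler further down in this commit looks up settings.config.model here to set its per-request limit. A minimal sketch of the same lookup, with a hard-coded model name standing in for the configured one:

# Illustrative lookup only; the real code takes the model name from settings.config.model
from pr_agent.algo import MAX_TOKENS

model = 'gpt-4'                # assumed model name for this example
limit = MAX_TOKENS[model]      # 8000 tokens shared by prompt and completion
print(f'{model} limit: {limit}')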

View File

@@ -0,0 +1,37 @@
import logging

import openai
from openai.error import APIError, Timeout, TryAgain
from retry import retry

from pr_agent.config_loader import settings

OPENAI_RETRIES = 2


class AiHandler:
    def __init__(self):
        try:
            openai.api_key = settings.openai.key
        except AttributeError as e:
            raise ValueError("OpenAI key is required") from e

    @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError),
           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
    async def chat_completion(self, model: str, temperature: float, system: str, user: str):
        try:
            response = await openai.ChatCompletion.acreate(
                model=model,
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": user}
                ],
                temperature=temperature,
            )
        except (APIError, Timeout, TryAgain) as e:
            logging.error(f"Error during OpenAI inference: {e}")
            raise
        if response is None or len(response.choices) == 0:
            raise TryAgain
        resp = response.choices[0]['message']['content']
        finish_reason = response.choices[0].finish_reason
        return resp, finish_reason
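A minimal sketch of calling the handler, assuming this file lives at pr_agent/algo/ai_handler.py (the path is not shown above) and that settings.openai.key is configured; the prompts are made up:

import asyncio

from pr_agent.algo.ai_handler import AiHandler  # assumed module path


async def main():
    handler = AiHandler()
    resp, finish_reason = await handler.chat_completion(
        model='gpt-4',
        temperature=0.2,
        system='You are a helpful PR reviewer.',      # illustrative system prompt
        user='Summarize the changes in this diff.',   # illustrative user prompt
    )
    print(finish_reason, resp)


asyncio.run(main())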

107
pr_agent/algo/git_patch_processing.py Normal file
View File

@@ -0,0 +1,107 @@
from __future__ import annotations

import logging
import re

from pr_agent.config_loader import settings


def extend_patch(original_file_str, patch_str, num_lines) -> str:
    """
    Extends the patch to include 'num_lines' more surrounding lines
    """
    if not patch_str or num_lines == 0:
        return patch_str

    original_lines = original_file_str.splitlines()
    patch_lines = patch_str.splitlines()
    extended_patch_lines = []
    start1, size1, start2, size2 = -1, -1, -1, -1
    RE_HUNK_HEADER = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
    try:
        for line in patch_lines:
            if line.startswith('@@'):
                match = RE_HUNK_HEADER.match(line)
                if match:
                    # finish previous hunk
                    if start1 != -1:
                        extended_patch_lines.extend(
                            original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines])

                    start1, size1, start2, size2 = map(int, match.groups()[:4])
                    section_header = match.groups()[4]
                    extended_start1 = max(1, start1 - num_lines)
                    extended_size1 = size1 + (start1 - extended_start1) + num_lines
                    extended_start2 = max(1, start2 - num_lines)
                    extended_size2 = size2 + (start2 - extended_start2) + num_lines
                    extended_patch_lines.append(
                        f'@@ -{extended_start1},{extended_size1} '
                        f'+{extended_start2},{extended_size2} @@ {section_header}')
                    extended_patch_lines.extend(
                        original_lines[extended_start1 - 1:start1 - 1])  # one to zero based
                    continue
            extended_patch_lines.append(line)
    except Exception as e:
        if settings.config.verbosity_level >= 2:
            logging.error(f"Failed to extend patch: {e}")
        return patch_str

    # finish previous hunk
    if start1 != -1:
        extended_patch_lines.extend(
            original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines])

    extended_patch_str = '\n'.join(extended_patch_lines)
    return extended_patch_str


def omit_deletion_hunks(patch_lines) -> str:
    temp_hunk = []
    added_patched = []
    add_hunk = False
    inside_hunk = False
    RE_HUNK_HEADER = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")

    for line in patch_lines:
        if line.startswith('@@'):
            match = RE_HUNK_HEADER.match(line)
            if match:
                # finish previous hunk
                if inside_hunk and add_hunk:
                    added_patched.extend(temp_hunk)
                    temp_hunk = []
                    add_hunk = False
                temp_hunk.append(line)
                inside_hunk = True
        else:
            temp_hunk.append(line)
            edit_type = line[0]
            if edit_type == '+':
                add_hunk = True

    if inside_hunk and add_hunk:
        added_patched.extend(temp_hunk)

    return '\n'.join(added_patched)


def handle_patch_deletions(patch: str, original_file_content_str: str,
                           new_file_content_str: str, file_name: str) -> str:
    """
    Handle entire file or deletion patches
    """
    if not new_file_content_str:
        # logic for handling deleted files - don't show patch, just show that the file was deleted
        if settings.config.verbosity_level > 0:
            logging.info(f"Processing file: {file_name}, minimizing deletion file")
        patch = "File was deleted\n"
    else:
        patch_lines = patch.splitlines()
        patch_new = omit_deletion_hunks(patch_lines)
        if patch != patch_new:
            if settings.config.verbosity_level > 0:
                logging.info(f"Processing file: {file_name}, hunks were deleted")
            patch = patch_new
    return patch
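handle_patch_deletions is what the compression pass later in this commit uses to shrink delete-heavy patches. A small sketch of its deleted-file branch, assuming pr_agent's default configuration loads; the file name and contents are invented:

from pr_agent.algo.git_patch_processing import handle_patch_deletions

# An empty new_file_content_str marks the file as deleted, so the patch body is discarded
old_content = "first line\nsecond line\n"
patch = "@@ -1,2 +0,0 @@\n-first line\n-second line\n"
print(handle_patch_deletions(patch, old_content, "", "old_module.py"))
# prints "File was deleted" because the new file content is empty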

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,128 @@
from __future__ import annotations

import difflib
import logging
from typing import Any, Dict, Tuple, Union

from pr_agent.algo.git_patch_processing import extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings
from pr_agent.git_providers import GithubProvider

OUTPUT_BUFFER_TOKENS = 800
PATCH_EXTRA_LINES = 3


def get_pr_diff(git_provider: Union[GithubProvider, Any], token_handler: TokenHandler) -> str:
    """
    Returns a string with the diff of the PR.
    If needed, apply diff minimization techniques to reduce the number of tokens
    """
    files = list(git_provider.get_diff_files())

    # get pr languages
    pr_languages = sort_files_by_main_languages(git_provider.get_languages(), files)

    # generate a standard diff string, with patch extension
    patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler)

    # if we are under the limit, return the full diff
    if total_tokens + OUTPUT_BUFFER_TOKENS < token_handler.limit:
        return "\n".join(patches_extended)

    # if we are over the limit, start pruning
    patches_compressed = pr_generate_compressed_diff(pr_languages, token_handler)
    return "\n".join(patches_compressed)


def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler) -> \
        Tuple[list, int]:
    """
    Generate a standard diff string, with patch extension
    """
    total_tokens = token_handler.prompt_tokens  # initial tokens
    patches_extended = []
    for lang in pr_languages:
        for file in lang['files']:
            original_file_content_str = file.base_file
            new_file_content_str = file.head_file
            patch = file.patch

            # handle the case of large patch, that initially was not loaded
            patch = load_large_diff(file, new_file_content_str, original_file_content_str, patch)
            if not patch:
                continue

            # extend each patch with extra lines of context
            extended_patch = extend_patch(original_file_content_str, patch, num_lines=PATCH_EXTRA_LINES)
            full_extended_patch = f"## {file.filename}\n\n{extended_patch}\n"
            patch_tokens = token_handler.count_tokens(full_extended_patch)
            file.tokens = patch_tokens
            total_tokens += patch_tokens
            patches_extended.append(full_extended_patch)

    return patches_extended, total_tokens


def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler) -> list:
    # Apply Diff Minimization techniques to reduce the number of tokens:
    # 0. Start from the largest diff patch to smaller ones
    # 1. Don't use extend context lines around diff
    # 2. Minimize deleted files
    # 3. Minimize deleted hunks
    # 4. Minimize all remaining files when you reach token limit

    patches = []

    # sort each one of the languages in top_langs by the number of tokens in the diff
    sorted_files = []
    for lang in top_langs:
        sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))

    total_tokens = token_handler.prompt_tokens
    for file in sorted_files:
        original_file_content_str = file.base_file
        new_file_content_str = file.head_file
        patch = file.patch
        patch = load_large_diff(file, new_file_content_str, original_file_content_str, patch)
        if not patch:
            continue

        # removing delete-only hunks
        patch = handle_patch_deletions(patch, original_file_content_str,
                                       new_file_content_str, file.filename)
        new_patch_tokens = token_handler.count_tokens(patch)
        if total_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS // 2:
            logging.warning(f"File was fully skipped, no more tokens: {file.filename}.")
            continue  # Hard Stop, no more tokens
        if total_tokens + new_patch_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS:
            # Current logic is to skip the patch if it's too large
            # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
            #  until we meet the requirements
            if settings.config.verbosity_level >= 2:
                logging.warning(f"Patch too large, minimizing it, {file.filename}")
            patch = "File was modified"
        if patch:
            patch_final = f"## {file.filename}\n\n{patch}\n"
            patches.append(patch_final)
            total_tokens += token_handler.count_tokens(patch_final)
            if settings.config.verbosity_level >= 2:
                logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")

    return patches


def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
    if not patch:  # TODO: also add a condition for file extension
        try:
            diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
                                        new_file_content_str.splitlines(keepends=True))
            if settings.config.verbosity_level >= 2:
                logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.")
            patch = ''.join(diff)
        except Exception:
            pass
    return patch
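load_large_diff falls back to synthesizing a unified diff with difflib when the provider returned no patch for a modified file. A standalone sketch of that fallback with made-up file contents:

import difflib

# Hypothetical before/after contents of a modified file
original = "def add(a, b):\n    return a + b\n"
updated = "def add(a, b):\n    # add two numbers\n    return a + b\n"

diff = difflib.unified_diff(original.splitlines(keepends=True),
                            updated.splitlines(keepends=True))
print(''.join(diff))   # prints a unified diff with a single '+' line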

24
pr_agent/algo/token_handler.py Normal file
View File

@@ -0,0 +1,24 @@
from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model

from pr_agent.algo import MAX_TOKENS
from pr_agent.config_loader import settings


class TokenHandler:
    def __init__(self, pr, vars: dict, system, user):
        self.encoder = encoding_for_model(settings.config.model)
        self.limit = MAX_TOKENS[settings.config.model]
        self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)

    def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
        environment = Environment(undefined=StrictUndefined)
        system_prompt = environment.from_string(system).render(vars)
        user_prompt = environment.from_string(user).render(vars)
        system_prompt_tokens = len(encoder.encode(system_prompt))
        user_prompt_tokens = len(encoder.encode(user_prompt))
        return system_prompt_tokens + user_prompt_tokens

    def count_tokens(self, patch: str) -> int:
        return len(self.encoder.encode(patch))
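The prompt budget is simply the tiktoken encoding of the rendered system and user prompts. A short sketch of the same counting done directly with tiktoken, assuming gpt-4 is the configured model and using an invented patch string:

from tiktoken import encoding_for_model

encoder = encoding_for_model('gpt-4')         # assumed model name for this example
patch = "## some_file.py\n\n+print('hello')"  # made-up patch text
print(len(encoder.encode(patch)))             # tokens this patch consumes from the limit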

59
pr_agent/algo/utils.py Normal file
View File

@@ -0,0 +1,59 @@
from __future__ import annotations

import textwrap


def convert_to_markdown(output_data: dict) -> str:
    markdown_text = ""

    emojis = {
        "Main theme": "🎯",
        "Description and title": "🔍",
        "Type of PR": "📌",
        "Relevant tests added": "🧪",
        "Unrelated changes": "⚠️",
        "Minimal and focused": "",
        "Security concerns": "🔒",
        "General PR suggestions": "💡",
        "Code suggestions": "🤖"
    }

    for key, value in output_data.items():
        if not value:
            continue
        if isinstance(value, dict):
            markdown_text += f"## {key}\n\n"
            markdown_text += convert_to_markdown(value)
        elif isinstance(value, list):
            if key.lower() == 'code suggestions':
                markdown_text += "\n"  # just looks nicer with additional line breaks
            emoji = emojis.get(key, "")  # default to an empty string if no emoji is found for the key
            markdown_text += f"- {emoji} **{key}:**\n\n"
            for item in value:
                if isinstance(item, dict) and key.lower() == 'code suggestions':
                    markdown_text += parse_code_suggestion(item)
                elif item:
                    markdown_text += f" - {item}\n"
        elif value != 'n/a':
            emoji = emojis.get(key, "")  # default to an empty string if no emoji is found for the key
            markdown_text += f"- {emoji} **{key}:** {value}\n"

    return markdown_text


def parse_code_suggestion(code_suggestions: dict) -> str:
    markdown_text = ""
    for sub_key, sub_value in code_suggestions.items():
        if isinstance(sub_value, dict):  # "code example"
            markdown_text += f" - **{sub_key}:**\n"
            for code_key, code_value in sub_value.items():  # 'before' and 'after' code
                code_str = f"```\n{code_value}\n```"
                code_str_indented = textwrap.indent(code_str, ' ')
                markdown_text += f" - **{code_key}:**\n{code_str_indented}\n"
        else:
            if "suggestion number" in sub_key.lower():
                markdown_text += f"- **suggestion {sub_value}:**\n"  # prettier formatting
            else:
                markdown_text += f" - **{sub_key}:** {sub_value}\n"

    markdown_text += "\n"
    return markdown_text
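A minimal sketch of feeding convert_to_markdown a review-shaped dict; the keys match the emoji table above and the values are invented:

from pr_agent.algo.utils import convert_to_markdown

output_data = {
    "Main theme": "Refactor the diff-compression logic",
    "Type of PR": "Refactoring",
    "Security concerns": "No",
    "General PR suggestions": ["Add unit tests for extend_patch"],
}
print(convert_to_markdown(output_data))   # prints a bullet list with the matching emojis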