Merge remote-tracking branch 'origin/main' into tr/issue_tool

# Conflicts:
#	pr_agent/algo/utils.py
This commit is contained in:
mrT23
2023-09-05 08:05:33 +03:00
43 changed files with 2842 additions and 470 deletions

View File

@ -1,14 +1,13 @@
import copy
import json
import logging
import textwrap
from typing import List, Dict
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, get_pr_multi_diffs
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import try_fix_json
from pr_agent.algo.utils import load_yaml
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import BitbucketProvider, get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
@ -22,6 +21,13 @@ class PRCodeSuggestions:
self.git_provider.get_languages(), self.git_provider.get_files()
)
# extended mode
self.is_extended = any(["extended" in arg for arg in args])
if self.is_extended:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
self.ai_handler = AiHandler()
self.patches_diff = None
self.prediction = None
@ -32,7 +38,7 @@ class PRCodeSuggestions:
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"num_code_suggestions": get_settings().pr_code_suggestions.num_code_suggestions,
"num_code_suggestions": num_code_suggestions,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
@ -42,18 +48,26 @@ class PRCodeSuggestions:
get_settings().pr_code_suggestions_prompt.user)
async def run(self):
assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
logging.info('Generating code suggestions for PR...')
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR review...')
data = self._prepare_pr_code_suggestions()
if not self.is_extended:
await retry_with_fallback_models(self._prepare_prediction)
data = self._prepare_pr_code_suggestions()
else:
data = await retry_with_fallback_models(self._prepare_prediction_extended)
if (not self.is_extended and get_settings().pr_code_suggestions.rank_suggestions) or \
(self.is_extended and get_settings().pr_code_suggestions.rank_extended_suggestions):
logging.info('Ranking Suggestions...')
data['Code suggestions'] = await self.rank_suggestions(data['Code suggestions'])
if get_settings().config.publish_output:
logging.info('Pushing PR review...')
self.git_provider.remove_initial_comment()
logging.info('Pushing inline code comments...')
logging.info('Pushing inline code suggestions...')
self.push_inline_code_suggestions(data)
async def _prepare_prediction(self, model: str):
@ -81,14 +95,11 @@ class PRCodeSuggestions:
return response
def _prepare_pr_code_suggestions(self) -> str:
def _prepare_pr_code_suggestions(self) -> Dict:
review = self.prediction.strip()
try:
data = json.loads(review)
except json.decoder.JSONDecodeError:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not parse json response: {review}")
data = try_fix_json(review, code_suggestions=True)
data = load_yaml(review)
if isinstance(data, list):
data = {'Code suggestions': data}
return data
def push_inline_code_suggestions(self, data):
@ -102,11 +113,8 @@ class PRCodeSuggestions:
if get_settings().config.verbosity_level >= 2:
logging.info(f"suggestion: {d}")
relevant_file = d['relevant file'].strip()
relevant_lines_str = d['relevant lines'].strip()
if ',' in relevant_lines_str: # handling 'relevant lines': '181, 190' or '178-184, 188-194'
relevant_lines_str = relevant_lines_str.split(',')[0]
relevant_lines_start = int(relevant_lines_str.split('-')[0]) # absolute position
relevant_lines_end = int(relevant_lines_str.split('-')[-1])
relevant_lines_start = int(d['relevant lines start']) # absolute position
relevant_lines_end = int(d['relevant lines end'])
content = d['suggestion content']
new_code_snippet = d['improved code']
@ -121,7 +129,11 @@ class PRCodeSuggestions:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not parse suggestion: {d}")
self.git_provider.publish_code_suggestions(code_suggestions)
is_successful = self.git_provider.publish_code_suggestions(code_suggestions)
if not is_successful:
logging.info("Failed to publish code suggestions, trying to publish each suggestion separately")
for code_suggestion in code_suggestions:
self.git_provider.publish_code_suggestions([code_suggestion])
def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet):
try: # dedent code snippet
@ -145,3 +157,81 @@ class PRCodeSuggestions:
return new_code_snippet
async def _prepare_prediction_extended(self, model: str) -> dict:
logging.info('Getting PR diff...')
patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
logging.info('Getting multi AI predictions...')
prediction_list = []
for i, patches_diff in enumerate(patches_diff_list):
logging.info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
self.patches_diff = patches_diff
prediction = await self._get_prediction(model)
prediction_list.append(prediction)
self.prediction_list = prediction_list
data = {}
for prediction in prediction_list:
self.prediction = prediction
data_per_chunk = self._prepare_pr_code_suggestions()
if "Code suggestions" in data:
data["Code suggestions"].extend(data_per_chunk["Code suggestions"])
else:
data.update(data_per_chunk)
self.data = data
return data
async def rank_suggestions(self, data: List) -> List:
"""
Call a model to rank (sort) code suggestions based on their importance order.
Args:
data (List): A list of code suggestions to be ranked.
Returns:
List: The ranked list of code suggestions.
"""
suggestion_list = []
# remove invalid suggestions
for i, suggestion in enumerate(data):
if suggestion['existing code'] != suggestion['improved code']:
suggestion_list.append(suggestion)
data_sorted = [[]] * len(suggestion_list)
try:
suggestion_str = ""
for i, suggestion in enumerate(suggestion_list):
suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
variables = {'suggestion_list': suggestion_list, 'suggestion_str': suggestion_str}
model = get_settings().config.model
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render(
variables)
user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables)
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt,
user=user_prompt)
sort_order = load_yaml(response)
for s in sort_order['Sort Order']:
suggestion_number = s['suggestion number']
importance_order = s['importance order']
data_sorted[importance_order - 1] = suggestion_list[suggestion_number - 1]
if get_settings().pr_code_suggestions.final_clip_factor != 1:
new_len = int(0.5 + len(data_sorted) * get_settings().pr_code_suggestions.final_clip_factor)
data_sorted = data_sorted[:new_len]
except Exception as e:
if get_settings().config.verbosity_level >= 1:
logging.info(f"Could not sort suggestions, error: {e}")
data_sorted = suggestion_list
return data_sorted

View File

@ -36,12 +36,14 @@ class PRDescription:
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.get_pr_description(),
"description": self.git_provider.get_pr_description(full=False),
"language": self.main_pr_language,
"diff": "", # empty diff for initial calculation
"extra_instructions": get_settings().pr_description.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages()
}
self.user_description = self.git_provider.get_user_description()
# Initialize the token handler
self.token_handler = TokenHandler(
@ -145,15 +147,12 @@ class PRDescription:
# Load the AI prediction data into a dictionary
data = load_yaml(self.prediction.strip())
if get_settings().pr_description.add_original_user_description and self.user_description:
data["User Description"] = self.user_description
# Initialization
pr_types = []
# Iterate over the dictionary items and append the key and value to 'markdown_text' in a markdown format
markdown_text = ""
for key, value in data.items():
markdown_text += f"## {key}\n\n"
markdown_text += f"{value}\n\n"
# If the 'PR Type' key is present in the dictionary, split its value by comma and assign it to 'pr_types'
if 'PR Type' in data:
if type(data['PR Type']) == list:
@ -161,13 +160,19 @@ class PRDescription:
elif type(data['PR Type']) == str:
pr_types = data['PR Type'].split(',')
# Assign the value of the 'PR Title' key to 'title' variable and remove it from the dictionary
title = data.pop('PR Title')
# Remove the 'PR Title' key from the dictionary
ai_title = data.pop('PR Title')
if get_settings().pr_description.keep_original_user_title:
# Assign the original PR title to the 'title' variable
title = self.vars["title"]
else:
# Assign the value of the 'PR Title' key to 'title' variable
title = ai_title
# Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
# except for the items containing the word 'walkthrough'
pr_body = ""
for key, value in data.items():
for idx, (key, value) in enumerate(data.items()):
pr_body += f"## {key}:\n"
if 'walkthrough' in key.lower():
# for filename, description in value.items():
@ -179,7 +184,11 @@ class PRDescription:
# if the value is a list, join its items by comma
if type(value) == list:
value = ', '.join(v for v in value)
pr_body += f"{value}\n\n___\n"
pr_body += f"{value}\n"
if idx < len(data) - 1:
pr_body += "\n___\n"
markdown_text = f"## Title\n\n{title}\n\n___\n{pr_body}"
if get_settings().config.verbosity_level >= 2:
logging.info(f"title:\n{title}\n{pr_body}")

View File

@ -23,7 +23,7 @@ class PRReviewer:
"""
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
def __init__(self, pr_url: str, is_answer: bool = False, args: list = None):
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
@ -40,6 +40,7 @@ class PRReviewer:
)
self.pr_url = pr_url
self.is_answer = is_answer
self.is_auto = is_auto
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
@ -93,8 +94,12 @@ class PRReviewer:
"""
Review the pull request and generate feedback.
"""
logging.info('Reviewing PR...')
if self.is_auto and not get_settings().pr_reviewer.automatic_review:
logging.info(f'Automatic review is disabled {self.pr_url}')
return None
logging.info(f'Reviewing PR: {self.pr_url} ...')
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True)