enhance pr_reviewer.py code

This commit is contained in:
mrT23
2023-07-26 17:24:03 +03:00
parent a60a58794c
commit 1bd47b0d53

View File

@ -2,6 +2,7 @@ import copy
import json import json
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Tuple, List
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -16,7 +17,19 @@ from pr_agent.servers.help import actions_help_text, bot_help_text
class PRReviewer: class PRReviewer:
def __init__(self, pr_url: str, cli_mode=False, is_answer: bool = False, args=None): """
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
def __init__(self, pr_url: str, cli_mode: bool = False, is_answer: bool = False, args: list = None):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
Args:
pr_url (str): The URL of the pull request to be reviewed.
cli_mode (bool, optional): Indicates whether the review is being done in command-line interface mode. Defaults to False.
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
"""
self.parse_args(args) self.parse_args(args)
self.git_provider = get_git_provider()(pr_url, incremental=self.incremental) self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
@ -25,13 +38,15 @@ class PRReviewer:
) )
self.pr_url = pr_url self.pr_url = pr_url
self.is_answer = is_answer self.is_answer = is_answer
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now") raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now")
answer_str, question_str = self._get_user_answers()
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode self.cli_mode = cli_mode
answer_str, question_str = self._get_user_answers()
self.vars = { self.vars = {
"title": self.git_provider.pr.title, "title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(), "branch": self.git_provider.get_pr_branch(),
@ -43,16 +58,27 @@ class PRReviewer:
"require_security": settings.pr_reviewer.require_security_review, "require_security": settings.pr_reviewer.require_security_review,
"require_focused": settings.pr_reviewer.require_focused_review, "require_focused": settings.pr_reviewer.require_focused_review,
'num_code_suggestions': settings.pr_reviewer.num_code_suggestions, 'num_code_suggestions': settings.pr_reviewer.num_code_suggestions,
#
'question_str': question_str, 'question_str': question_str,
'answer_str': answer_str, 'answer_str': answer_str,
} }
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_review_prompt.system,
settings.pr_review_prompt.user)
def parse_args(self, args): self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
settings.pr_review_prompt.system,
settings.pr_review_prompt.user
)
def parse_args(self, args: List[str]) -> None:
"""
Parse the arguments passed to the PRReviewer class and set the 'incremental' attribute accordingly.
Args:
args: A list of arguments passed to the PRReviewer class.
Returns:
None
"""
is_incremental = False is_incremental = False
if args and len(args) >= 1: if args and len(args) >= 1:
arg = args[0] arg = args[0]
@ -60,60 +86,93 @@ class PRReviewer:
is_incremental = True is_incremental = True
self.incremental = IncrementalPR(is_incremental) self.incremental = IncrementalPR(is_incremental)
async def review(self): async def review(self) -> None:
"""
Review the pull request and generate feedback.
"""
logging.info('Reviewing PR...') logging.info('Reviewing PR...')
if settings.config.publish_output: if settings.config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True) self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR review...') logging.info('Preparing PR review...')
pr_comment = self._prepare_pr_review() pr_comment = self._prepare_pr_review()
if settings.config.publish_output: if settings.config.publish_output:
logging.info('Pushing PR review...') logging.info('Pushing PR review...')
self.git_provider.publish_comment(pr_comment) self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
if settings.pr_reviewer.inline_code_comments: if settings.pr_reviewer.inline_code_comments:
logging.info('Pushing inline code comments...') logging.info('Pushing inline code comments...')
self._publish_inline_code_comments() self._publish_inline_code_comments()
return ""
async def _prepare_prediction(self, model: str): async def _prepare_prediction(self, model: str) -> None:
"""
Prepare the AI prediction for the pull request review.
Args:
model: A string representing the AI model to be used for the prediction.
Returns:
None
"""
logging.info('Getting PR diff...') logging.info('Getting PR diff...')
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model) self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
logging.info('Getting AI prediction...') logging.info('Getting AI prediction...')
self.prediction = await self._get_prediction(model) self.prediction = await self._get_prediction(model)
async def _get_prediction(self, model: str): async def _get_prediction(self, model: str) -> str:
"""
Generate an AI prediction for the pull request review.
Args:
model: A string representing the AI model to be used for the prediction.
Returns:
A string representing the AI prediction for the pull request review.
"""
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables) system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables) user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if settings.config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt) response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=0.2,
system=system_prompt,
user=user_prompt
)
return response return response
def _prepare_pr_review(self) -> str: def _prepare_pr_review(self) -> str:
"""
Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes the feedback.
"""
review = self.prediction.strip() review = self.prediction.strip()
try: try:
data = json.loads(review) data = json.loads(review)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
data = try_fix_json(review) data = try_fix_json(review)
# reordering for nicer display # Move 'Security concerns' key to 'PR Analysis' section for better display
if 'PR Feedback' in data: if 'PR Feedback' in data and 'Security concerns' in data['PR Feedback']:
if 'Security concerns' in data['PR Feedback']: val = data['PR Feedback']['Security concerns']
val = data['PR Feedback']['Security concerns'] del data['PR Feedback']['Security concerns']
del data['PR Feedback']['Security concerns'] data['PR Analysis']['Security concerns'] = val
data['PR Analysis']['Security concerns'] = val
if settings.config.git_provider != 'bitbucket' and \ # Filter out code suggestions that can be submitted as inline comments
settings.pr_reviewer.inline_code_comments and \ if settings.config.git_provider != 'bitbucket' and settings.pr_reviewer.inline_code_comments and 'Code suggestions' in data['PR Feedback']:
'Code suggestions' in data['PR Feedback']:
# keeping only code suggestions that can't be submitted as inline comments
data['PR Feedback']['Code suggestions'] = [ data['PR Feedback']['Code suggestions'] = [
d for d in data['PR Feedback']['Code suggestions'] d for d in data['PR Feedback']['Code suggestions']
if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content')) if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content'))
@ -121,8 +180,8 @@ class PRReviewer:
if not data['PR Feedback']['Code suggestions']: if not data['PR Feedback']['Code suggestions']:
del data['PR Feedback']['Code suggestions'] del data['PR Feedback']['Code suggestions']
# Add incremental review section
if self.incremental.is_incremental: if self.incremental.is_incremental:
# Rename title when incremental review - Add to the beginning of the dict
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}" last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}"
data = OrderedDict(data) data = OrderedDict(data)
data.update({'Incremental PR Review': { data.update({'Incremental PR Review': {
@ -132,6 +191,7 @@ class PRReviewer:
markdown_text = convert_to_markdown(data) markdown_text = convert_to_markdown(data)
user = self.git_provider.get_user_id() user = self.git_provider.get_user_id()
# Add help text if not in CLI mode
if not self.cli_mode: if not self.cli_mode:
markdown_text += "\n### How to use\n" markdown_text += "\n### How to use\n"
if user and '[bot]' not in user: if user and '[bot]' not in user:
@ -139,11 +199,16 @@ class PRReviewer:
else: else:
markdown_text += actions_help_text markdown_text += actions_help_text
# Log markdown response if verbosity level is high
if settings.config.verbosity_level >= 2: if settings.config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}") logging.info(f"Markdown response:\n{markdown_text}")
return markdown_text return markdown_text
def _publish_inline_code_comments(self): def _publish_inline_code_comments(self) -> None:
"""
Publishes inline comments on a pull request with code suggestions generated by the AI model.
"""
if settings.pr_reviewer.num_code_suggestions == 0: if settings.pr_reviewer.num_code_suggestions == 0:
return return
@ -153,11 +218,11 @@ class PRReviewer:
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
data = try_fix_json(review) data = try_fix_json(review)
comments = [] comments: List[str] = []
for d in data['PR Feedback']['Code suggestions']: for suggestion in data.get('PR Feedback', {}).get('Code suggestions', []):
relevant_file = d.get('relevant file', '').strip() relevant_file = suggestion.get('relevant file', '').strip()
relevant_line_in_file = d.get('relevant line in file', '').strip() relevant_line_in_file = suggestion.get('relevant line in file', '').strip()
content = d.get('suggestion content', '') content = suggestion.get('suggestion content', '')
if not relevant_file or not relevant_line_in_file or not content: if not relevant_file or not relevant_line_in_file or not content:
logging.info("Skipping inline comment with missing file/line/content") logging.info("Skipping inline comment with missing file/line/content")
continue continue
@ -172,15 +237,26 @@ class PRReviewer:
if comments: if comments:
self.git_provider.publish_inline_comments(comments) self.git_provider.publish_inline_comments(comments)
def _get_user_answers(self): def _get_user_answers(self) -> Tuple[str, str]:
answer_str = question_str = "" """
Retrieves the question and answer strings from the discussion messages related to a pull request.
Returns:
A tuple containing the question and answer strings.
"""
question_str = ""
answer_str = ""
if self.is_answer: if self.is_answer:
discussion_messages = self.git_provider.get_issue_comments() discussion_messages = self.git_provider.get_issue_comments()
for message in discussion_messages.reversed:
for message in reversed(discussion_messages):
if "Questions to better understand the PR:" in message.body: if "Questions to better understand the PR:" in message.body:
question_str = message.body question_str = message.body
elif '/answer' in message.body: elif '/answer' in message.body:
answer_str = message.body answer_str = message.body
if answer_str and question_str: if answer_str and question_str:
break break
return question_str, answer_str return question_str, answer_str