Support context aware settings (for each incoming request), support override of settings, refactor CLI to use pr_agent.py

This commit is contained in:
Ori Kotek
2023-08-01 14:43:26 +03:00
parent 6605f9c444
commit d7b77764c3
26 changed files with 305 additions and 384 deletions

View File

@ -1,6 +1,7 @@
import re import shlex
from pr_agent.config_loader import settings from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_information_from_user import PRInformationFromUser from pr_agent.tools.pr_information_from_user import PRInformationFromUser
@ -8,29 +9,39 @@ from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
command2class = {
"answer": PRReviewer,
"review": PRReviewer,
"review_pr": PRReviewer,
"reflect": PRInformationFromUser,
"reflect_and_review": PRInformationFromUser,
"describe": PRDescription,
"describe_pr": PRDescription,
"improve": PRCodeSuggestions,
"improve_code": PRCodeSuggestions,
"ask": PRQuestions,
"ask_question": PRQuestions,
"update_changelog": PRUpdateChangelog,
}
commands = list(command2class.keys())
class PRAgent: class PRAgent:
def __init__(self): def __init__(self):
pass pass
async def handle_request(self, pr_url, request) -> bool: async def handle_request(self, pr_url, request) -> bool:
action, *args = request.strip().split() lexer = shlex.shlex(request, posix=True)
if any(cmd == action for cmd in ["/answer"]): lexer.whitespace_split = True
await PRReviewer(pr_url, is_answer=True, args=args).review() action, *args = list(lexer)
elif any(cmd == action for cmd in ["/review", "/review_pr", "/reflect_and_review"]): args = update_settings_from_args(args)
if settings.pr_reviewer.ask_and_reflect or "/reflect_and_review" in request: action = action.lstrip("/").lower()
await PRInformationFromUser(pr_url, args=args).generate_questions() if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect:
else: action = "review"
await PRReviewer(pr_url, args=args).review() if action == "answer":
elif any(cmd == action for cmd in ["/describe", "/describe_pr"]): await PRReviewer(pr_url, is_answer=True, args=args).run()
await PRDescription(pr_url, args=args).describe() elif action in command2class:
elif any(cmd == action for cmd in ["/improve", "/improve_code"]): await command2class[action](pr_url, args=args).run()
await PRCodeSuggestions(pr_url, args=args).suggest()
elif any(cmd == action for cmd in ["/ask", "/ask_question"]):
await PRQuestions(pr_url, args=args).answer()
elif any(cmd == action for cmd in ["/update_changelog"]):
await PRUpdateChangelog(pr_url, args=args).update_changelog()
else: else:
return False return False
return True return True

View File

@ -1,10 +1,10 @@
import logging import logging
import openai import openai
from openai.error import APIError, Timeout, TryAgain, RateLimitError from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry from retry import retry
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
OPENAI_RETRIES=5 OPENAI_RETRIES=5
@ -21,16 +21,16 @@ class AiHandler:
Raises a ValueError if the OpenAI key is missing. Raises a ValueError if the OpenAI key is missing.
""" """
try: try:
openai.api_key = settings.openai.key openai.api_key = get_settings().openai.key
if settings.get("OPENAI.ORG", None): if get_settings().get("OPENAI.ORG", None):
openai.organization = settings.openai.org openai.organization = get_settings().openai.org
self.deployment_id = settings.get("OPENAI.DEPLOYMENT_ID", None) self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
if settings.get("OPENAI.API_TYPE", None): if get_settings().get("OPENAI.API_TYPE", None):
openai.api_type = settings.openai.api_type openai.api_type = get_settings().openai.api_type
if settings.get("OPENAI.API_VERSION", None): if get_settings().get("OPENAI.API_VERSION", None):
openai.api_version = settings.openai.api_version openai.api_version = get_settings().openai.api_version
if settings.get("OPENAI.API_BASE", None): if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = settings.openai.api_base openai.api_base = get_settings().openai.api_base
except AttributeError as e: except AttributeError as e:
raise ValueError("OpenAI key is required") from e raise ValueError("OpenAI key is required") from e

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
import re import re
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
def extend_patch(original_file_str, patch_str, num_lines) -> str: def extend_patch(original_file_str, patch_str, num_lines) -> str:
@ -55,7 +55,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
continue continue
extended_patch_lines.append(line) extended_patch_lines.append(line)
except Exception as e: except Exception as e:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.error(f"Failed to extend patch: {e}") logging.error(f"Failed to extend patch: {e}")
return patch_str return patch_str
@ -126,14 +126,14 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
""" """
if not new_file_content_str: if not new_file_content_str:
# logic for handling deleted files - don't show patch, just show that the file was deleted # logic for handling deleted files - don't show patch, just show that the file was deleted
if settings.config.verbosity_level > 0: if get_settings().config.verbosity_level > 0:
logging.info(f"Processing file: {file_name}, minimizing deletion file") logging.info(f"Processing file: {file_name}, minimizing deletion file")
patch = None # file was deleted patch = None # file was deleted
else: else:
patch_lines = patch.splitlines() patch_lines = patch.splitlines()
patch_new = omit_deletion_hunks(patch_lines) patch_new = omit_deletion_hunks(patch_lines)
if patch != patch_new: if patch != patch_new:
if settings.config.verbosity_level > 0: if get_settings().config.verbosity_level > 0:
logging.info(f"Processing file: {file_name}, hunks were deleted") logging.info(f"Processing file: {file_name}, hunks were deleted")
patch = patch_new patch = patch_new
return patch return patch
@ -141,7 +141,8 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str: def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
""" """
Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of the file. Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
the file.
Args: Args:
patch (str): The patch string to be converted. patch (str): The patch string to be converted.

View File

@ -1,15 +1,15 @@
# Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501 # Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501
from typing import Dict from typing import Dict
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
language_extension_map_org = settings.language_extension_map_org language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()} language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
# Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501 # Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
bad_extensions = settings.bad_extensions.default bad_extensions = get_settings().bad_extensions.default
if settings.config.use_extra_bad_extensions: if get_settings().config.use_extra_bad_extensions:
bad_extensions += settings.bad_extensions.extra bad_extensions += get_settings().bad_extensions.extra
def filter_bad_extensions(files): def filter_bad_extensions(files):

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Tuple, Union, Callable, List from typing import Callable, Tuple
from github import RateLimitExceededException from github import RateLimitExceededException
@ -10,7 +10,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_large_diff from pr_agent.algo.utils import load_large_diff
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider from pr_agent.git_providers.git_provider import GitProvider
DELETED_FILES_ = "Deleted files:\n" DELETED_FILES_ = "Deleted files:\n"
@ -27,11 +27,15 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
Returns a string with the diff of the pull request, applying diff minimization techniques if needed. Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
Args: Args:
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request. git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request. request.
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
pull request.
model (str): The name of the model used for tokenization. model (str): The name of the model used for tokenization.
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False. add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False. diff. Defaults to False.
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with
extra lines of context. Defaults to False.
Returns: Returns:
str: A string with the diff of the pull request, applying diff minimization techniques if needed. str: A string with the diff of the pull request, applying diff minimization techniques if needed.
@ -76,10 +80,12 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool) -> \ add_line_numbers_to_hunks: bool) -> \
Tuple[list, int]: Tuple[list, int]:
""" """
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff minimization techniques if needed. Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
minimization techniques if needed.
Args: Args:
- pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding files. - pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding
files.
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request. - token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff. - add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
@ -119,10 +125,13 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str, def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]: convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
""" """
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used. Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of
tokens used.
Args: Args:
top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files. top_langs (list): A list of dictionaries representing the languages used in the pull request and their
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request. corresponding files.
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
pull request.
model (str): The model used for tokenization. model (str): The model used for tokenization.
convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff. convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
Returns: Returns:
@ -181,7 +190,7 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
# Current logic is to skip the patch if it's too large # Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements # until we meet the requirements
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.warning(f"Patch too large, minimizing it, {file.filename}") logging.warning(f"Patch too large, minimizing it, {file.filename}")
if not modified_files_list: if not modified_files_list:
total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_) total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_)
@ -196,15 +205,15 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
patch_final = patch patch_final = patch
patches.append(patch_final) patches.append(patch_final)
total_tokens += token_handler.count_tokens(patch_final) total_tokens += token_handler.count_tokens(patch_final)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}") logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")
return patches, modified_files_list, deleted_files_list return patches, modified_files_list, deleted_files_list
async def retry_with_fallback_models(f: Callable): async def retry_with_fallback_models(f: Callable):
model = settings.config.model model = get_settings().config.model
fallback_models = settings.config.fallback_models fallback_models = get_settings().config.fallback_models
if not isinstance(fallback_models, list): if not isinstance(fallback_models, list):
fallback_models = [fallback_models] fallback_models = [fallback_models]
all_models = [model] + fallback_models all_models = [model] + fallback_models

View File

@ -1,8 +1,7 @@
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model from tiktoken import encoding_for_model
from pr_agent.algo import MAX_TOKENS from pr_agent.config_loader import get_settings
from pr_agent.config_loader import settings
class TokenHandler: class TokenHandler:
@ -10,9 +9,12 @@ class TokenHandler:
A class for handling tokens in the context of a pull request. A class for handling tokens in the context of a pull request.
Attributes: Attributes:
- encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the number of tokens in them. - encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the
- limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the pr_agent.algo module. number of tokens in them.
- prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens method. - limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the
pr_agent.algo module.
- prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens
method.
""" """
def __init__(self, pr, vars: dict, system, user): def __init__(self, pr, vars: dict, system, user):
@ -25,7 +27,7 @@ class TokenHandler:
- system: The system string. - system: The system string.
- user: The user string. - user: The user string.
""" """
self.encoder = encoding_for_model(settings.config.model) self.encoder = encoding_for_model(get_settings().config.model)
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):

View File

@ -1,15 +1,24 @@
from __future__ import annotations from __future__ import annotations
from typing import List
import difflib import difflib
from datetime import datetime
import json import json
import logging import logging
import re import re
import textwrap import textwrap
from datetime import datetime
from typing import Any, List
from pr_agent.config_loader import settings from starlette_context import context
from pr_agent.config_loader import get_settings, global_settings
def get_setting(key: str) -> Any:
try:
key = key.upper()
return context.get("settings", global_settings).get(key, global_settings.get(key, None))
except Exception:
return global_settings.get(key, None)
def convert_to_markdown(output_data: dict) -> str: def convert_to_markdown(output_data: dict) -> str:
""" """
@ -97,12 +106,16 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
- data: A dictionary containing the parsed JSON data. - data: A dictionary containing the parsed JSON data.
The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion. The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion.
If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the message. If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the
If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON message by parsing until the last valid code suggestion. message.
The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or newlines. If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON
message by parsing until the last valid code suggestion.
The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or
newlines.
It tries to parse the JSON message with the closing bracket and checks if it is valid. It tries to parse the JSON message with the closing bracket and checks if it is valid.
If the JSON message is valid, the parsed JSON data is returned. If the JSON message is valid, the parsed JSON data is returned.
If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON message is obtained or the maximum number of iterations is reached. If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON
message is obtained or the maximum number of iterations is reached.
If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned. If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned.
""" """
@ -184,7 +197,8 @@ def convert_str_to_datetime(date_str):
def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str: def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
""" """
Generate a patch for a modified file by comparing the original content of the file with the new content provided as input. Generate a patch for a modified file by comparing the original content of the file with the new content provided as
input.
Args: Args:
file: The file object for which the patch needs to be generated. file: The file object for which the patch needs to be generated.
@ -199,14 +213,16 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str:
None. None.
Additional Information: Additional Information:
- If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it as output. - If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it
- If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating that the file was modified but no patch was found, and a patch is manually created. as output.
- If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating
that the file was modified but no patch was found, and a patch is manually created.
""" """
if not patch: # to Do - also add condition for file extension if not patch: # to Do - also add condition for file extension
try: try:
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True), diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
new_file_content_str.splitlines(keepends=True)) new_file_content_str.splitlines(keepends=True))
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.") logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.")
patch = ''.join(diff) patch = ''.join(diff)
except Exception: except Exception:
@ -214,7 +230,7 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str:
return patch return patch
def update_settings_from_args(args: List[str]) -> None: def update_settings_from_args(args: List[str]) -> List[str]:
""" """
Update the settings of the Dynaconf object based on the arguments passed to the function. Update the settings of the Dynaconf object based on the arguments passed to the function.
@ -230,28 +246,22 @@ def update_settings_from_args(args: List[str]) -> None:
ValueError: If the argument is not in the correct format. ValueError: If the argument is not in the correct format.
""" """
other_args = []
if args: if args:
for arg in args: for arg in args:
try: arg = arg.strip()
if arg.startswith('--'):
arg = arg.strip('-').strip() arg = arg.strip('-').strip()
vals = arg.split('=') vals = arg.split('=')
if len(vals) != 2: if len(vals) != 2:
raise ValueError(f'Invalid argument format: {arg}') logging.error(f'Invalid argument format: {arg}')
other_args.append(arg)
continue
key, value = vals key, value = vals
keys = key.split('.') key = key.strip().upper()
d = settings value = value.strip()
for i, k in enumerate(keys[:-1]): get_settings().set(key, value)
if k not in d:
raise ValueError(f'Invalid setting: {key}')
d = d[k]
if keys[-1] not in d:
raise ValueError(f'Invalid setting: {key}')
if isinstance(d[keys[-1]], bool):
d[keys[-1]] = value.lower() in ("yes", "true", "t", "1")
else:
d[keys[-1]] = type(d[keys[-1]])(value)
logging.info(f'Updated setting {key} to: "{value}"') logging.info(f'Updated setting {key} to: "{value}"')
except ValueError as e: else:
logging.error(str(e)) other_args.append(arg)
except Exception as e: return other_args
logging.error(f'Failed to parse argument {arg}: {e}')

View File

@ -3,15 +3,11 @@ import asyncio
import logging import logging
import os import os
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.agent.pr_agent import PRAgent, commands
from pr_agent.tools.pr_description import PRDescription from pr_agent.config_loader import get_settings
from pr_agent.tools.pr_information_from_user import PRInformationFromUser
from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
def run(args=None): def run(inargs=None):
parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage= parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
"""\ """\
Usage: cli.py --pr-url <URL on supported git hosting service> <command> [<args>]. Usage: cli.py --pr-url <URL on supported git hosting service> <command> [<args>].
@ -34,79 +30,16 @@ To edit any configuration parameter from 'configuration.toml', just add -config_
For example: '- cli.py --pr-url=... review --pr_reviewer.extra_instructions="focus on the file: ..."' For example: '- cli.py --pr-url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""") """)
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True) parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
parser.add_argument('command', type=str, help='The', choices=['review', 'review_pr', parser.add_argument('command', type=str, help='The', choices=commands, default='review')
'ask', 'ask_question',
'describe', 'describe_pr',
'improve', 'improve_code',
'reflect', 'review_after_reflect',
'update_changelog'],
default='review')
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[]) parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
args = parser.parse_args(args) args = parser.parse_args(inargs)
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
command = args.command.lower() command = args.command.lower()
commands = { get_settings().set("CONFIG.CLI_MODE", True)
'ask': _handle_ask_command, result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest)))
'ask_question': _handle_ask_command, if not result:
'describe': _handle_describe_command,
'describe_pr': _handle_describe_command,
'improve': _handle_improve_command,
'improve_code': _handle_improve_command,
'review': _handle_review_command,
'review_pr': _handle_review_command,
'reflect': _handle_reflect_command,
'review_after_reflect': _handle_review_after_reflect_command,
'update_changelog': _handle_update_changelog,
}
if command in commands:
commands[command](args.pr_url, args.rest)
else:
print(f"Unknown command: {command}")
parser.print_help() parser.print_help()
def _handle_ask_command(pr_url: str, rest: list):
if len(rest) == 0:
print("Please specify a question")
return
print(f"Question: {' '.join(rest)} about PR {pr_url}")
reviewer = PRQuestions(pr_url, rest)
asyncio.run(reviewer.answer())
def _handle_describe_command(pr_url: str, rest: list):
print(f"PR description: {pr_url}")
reviewer = PRDescription(pr_url, args=rest)
asyncio.run(reviewer.describe())
def _handle_improve_command(pr_url: str, rest: list):
print(f"PR code suggestions: {pr_url}")
reviewer = PRCodeSuggestions(pr_url, args=rest)
asyncio.run(reviewer.suggest())
def _handle_review_command(pr_url: str, rest: list):
print(f"Reviewing PR: {pr_url}")
reviewer = PRReviewer(pr_url, cli_mode=True, args=rest)
asyncio.run(reviewer.review())
def _handle_reflect_command(pr_url: str, rest: list):
print(f"Asking the PR author questions: {pr_url}")
reviewer = PRInformationFromUser(pr_url)
asyncio.run(reviewer.generate_questions())
def _handle_review_after_reflect_command(pr_url: str, rest: list):
print(f"Processing author's answers and sending review: {pr_url}")
reviewer = PRReviewer(pr_url, cli_mode=True, is_answer=True, args=rest)
asyncio.run(reviewer.review())
def _handle_update_changelog(pr_url: str, rest: list):
print(f"Updating changlog for: {pr_url}")
reviewer = PRUpdateChangelog(pr_url, cli_mode=True, args=rest)
asyncio.run(reviewer.update_changelog())
if __name__ == '__main__': if __name__ == '__main__':
run() run()

View File

@ -3,11 +3,12 @@ from pathlib import Path
from typing import Optional from typing import Optional
from dynaconf import Dynaconf from dynaconf import Dynaconf
from starlette_context import context
PR_AGENT_TOML_KEY = 'pr-agent' PR_AGENT_TOML_KEY = 'pr-agent'
current_dir = dirname(abspath(__file__)) current_dir = dirname(abspath(__file__))
settings = Dynaconf( global_settings = Dynaconf(
envvar_prefix=False, envvar_prefix=False,
merge_enabled=True, merge_enabled=True,
settings_files=[join(current_dir, f) for f in [ settings_files=[join(current_dir, f) for f in [
@ -25,6 +26,13 @@ settings = Dynaconf(
) )
def get_settings():
try:
return context["settings"]
except Exception:
return global_settings
# Add local configuration from pyproject.toml of the project being reviewed # Add local configuration from pyproject.toml of the project being reviewed
def _find_repository_root() -> Path: def _find_repository_root() -> Path:
""" """
@ -39,6 +47,7 @@ def _find_repository_root() -> Path:
cwd = cwd.parent cwd = cwd.parent
return None return None
def _find_pyproject() -> Optional[Path]: def _find_pyproject() -> Optional[Path]:
""" """
Search for file pyproject.toml in the repository root. Search for file pyproject.toml in the repository root.
@ -49,6 +58,7 @@ def _find_pyproject() -> Optional[Path]:
return pyproject if pyproject.is_file() else None return pyproject if pyproject.is_file() else None
return None return None
pyproject_path = _find_pyproject() pyproject_path = _find_pyproject()
if pyproject_path is not None: if pyproject_path is not None:
settings.load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}') get_settings().load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}')

View File

@ -1,4 +1,4 @@
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.bitbucket_provider import BitbucketProvider from pr_agent.git_providers.bitbucket_provider import BitbucketProvider
from pr_agent.git_providers.github_provider import GithubProvider from pr_agent.git_providers.github_provider import GithubProvider
from pr_agent.git_providers.gitlab_provider import GitLabProvider from pr_agent.git_providers.gitlab_provider import GitLabProvider
@ -13,7 +13,7 @@ _GIT_PROVIDERS = {
def get_git_provider(): def get_git_provider():
try: try:
provider_id = settings.config.git_provider provider_id = get_settings().config.git_provider
except AttributeError as e: except AttributeError as e:
raise ValueError("git_provider is a required attribute in the configuration file") from e raise ValueError("git_provider is a required attribute in the configuration file") from e
if provider_id not in _GIT_PROVIDERS: if provider_id not in _GIT_PROVIDERS:

View File

@ -5,15 +5,14 @@ from urllib.parse import urlparse
import requests import requests
from atlassian.bitbucket import Cloud from atlassian.bitbucket import Cloud
from pr_agent.config_loader import settings from ..config_loader import get_settings
from .git_provider import FilePatchInfo from .git_provider import FilePatchInfo
class BitbucketProvider: class BitbucketProvider:
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False): def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
s = requests.Session() s = requests.Session()
s.headers['Authorization'] = f'Bearer {settings.get("BITBUCKET.BEARER_TOKEN", None)}' s.headers['Authorization'] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
self.bitbucket_client = Cloud(session=s) self.bitbucket_client = Cloud(session=s)
self.workspace_slug = None self.workspace_slug = None

View File

@ -7,12 +7,11 @@ from github import AppAuthentication, Auth, Github, GithubException
from retry import retry from retry import retry
from starlette_context import context from starlette_context import context
from pr_agent.config_loader import settings
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from ..config_loader import get_settings
from ..servers.utils import RateLimitExceeded from ..servers.utils import RateLimitExceeded
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
class GithubProvider(GitProvider): class GithubProvider(GitProvider):
@ -85,7 +84,7 @@ class GithubProvider(GitProvider):
return self.pr.get_files() return self.pr.get_files()
@retry(exceptions=RateLimitExceeded, @retry(exceptions=RateLimitExceeded,
tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3)) tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
def get_diff_files(self) -> list[FilePatchInfo]: def get_diff_files(self) -> list[FilePatchInfo]:
try: try:
files = self.get_files() files = self.get_files()
@ -118,7 +117,7 @@ class GithubProvider(GitProvider):
# self.pr.create_issue_comment(pr_comment) # self.pr.create_issue_comment(pr_comment)
def publish_comment(self, pr_comment: str, is_temporary: bool = False): def publish_comment(self, pr_comment: str, is_temporary: bool = False):
if is_temporary and not settings.config.publish_output_progress: if is_temporary and not get_settings().config.publish_output_progress:
logging.debug(f"Skipping publish_comment for temporary comment: {pr_comment}") logging.debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
return return
response = self.pr.create_issue_comment(pr_comment) response = self.pr.create_issue_comment(pr_comment)
@ -149,7 +148,7 @@ class GithubProvider(GitProvider):
position = i position = i
break break
if position == -1: if position == -1:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}") logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
subject_type = "FILE" subject_type = "FILE"
else: else:
@ -174,13 +173,13 @@ class GithubProvider(GitProvider):
relevant_lines_end = suggestion['relevant_lines_end'] relevant_lines_end = suggestion['relevant_lines_end']
if not relevant_lines_start or relevant_lines_start == -1: if not relevant_lines_start or relevant_lines_start == -1:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.exception( logging.exception(
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}") f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
continue continue
if relevant_lines_end < relevant_lines_start: if relevant_lines_end < relevant_lines_start:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.exception(f"Failed to publish code suggestion, " logging.exception(f"Failed to publish code suggestion, "
f"relevant_lines_end is {relevant_lines_end} and " f"relevant_lines_end is {relevant_lines_end} and "
f"relevant_lines_start is {relevant_lines_start}") f"relevant_lines_start is {relevant_lines_start}")
@ -207,7 +206,7 @@ class GithubProvider(GitProvider):
self.pr.create_review(commit=self.last_commit_id, comments=post_parameters_list) self.pr.create_review(commit=self.last_commit_id, comments=post_parameters_list)
return True return True
except Exception as e: except Exception as e:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.error(f"Failed to publish code suggestion, error: {e}") logging.error(f"Failed to publish code suggestion, error: {e}")
return False return False
@ -241,7 +240,7 @@ class GithubProvider(GitProvider):
return self.github_user_id return self.github_user_id
def get_notifications(self, since: datetime): def get_notifications(self, since: datetime):
deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user") deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
if deployment_type != 'user': if deployment_type != 'user':
raise ValueError("Deployment mode must be set to 'user' to get notifications") raise ValueError("Deployment mode must be set to 'user' to get notifications")
@ -282,12 +281,12 @@ class GithubProvider(GitProvider):
return repo_name, pr_number return repo_name, pr_number
def _get_github_client(self): def _get_github_client(self):
deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user") deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
if deployment_type == 'app': if deployment_type == 'app':
try: try:
private_key = settings.github.private_key private_key = get_settings().github.private_key
app_id = settings.github.app_id app_id = get_settings().github.app_id
except AttributeError as e: except AttributeError as e:
raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e
if not self.installation_id: if not self.installation_id:
@ -298,7 +297,7 @@ class GithubProvider(GitProvider):
if deployment_type == 'user': if deployment_type == 'user':
try: try:
token = settings.github.user_token token = get_settings().github.user_token
except AttributeError as e: except AttributeError as e:
raise ValueError( raise ValueError(
"GitHub token is required when using user deployment. See: " "GitHub token is required when using user deployment. See: "
@ -327,7 +326,9 @@ class GithubProvider(GitProvider):
def publish_labels(self, pr_types): def publish_labels(self, pr_types):
try: try:
label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5", "Refactoring": "bfdadc", "Enhancement": "bfd4f2", "Documentation": "d4c5f9", "Other": "d1bcf9"} label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5",
"Refactoring": "bfdadc", "Enhancement": "bfd4f2", "Documentation": "d4c5f9",
"Other": "d1bcf9"}
post_parameters = [] post_parameters = []
for p in pr_types: for p in pr_types:
color = label_color_map.get(p, "d1bcf9") # default to "Other" color color = label_color_map.get(p, "d1bcf9") # default to "Other" color

View File

@ -6,9 +6,8 @@ from urllib.parse import urlparse
import gitlab import gitlab
from gitlab import GitlabGetError from gitlab import GitlabGetError
from pr_agent.config_loader import settings
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
logger = logging.getLogger() logger = logging.getLogger()
@ -17,10 +16,10 @@ logger = logging.getLogger()
class GitLabProvider(GitProvider): class GitLabProvider(GitProvider):
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False): def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
gitlab_url = settings.get("GITLAB.URL", None) gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url: if not gitlab_url:
raise ValueError("GitLab URL is not set in the config file") raise ValueError("GitLab URL is not set in the config file")
gitlab_access_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None) gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_access_token: if not gitlab_access_token:
raise ValueError("GitLab personal access token is not set in the config file") raise ValueError("GitLab personal access token is not set in the config file")
self.gl = gitlab.Gitlab( self.gl = gitlab.Gitlab(

View File

@ -5,7 +5,7 @@ from typing import List
from git import Repo from git import Repo
from pr_agent.config_loader import _find_repository_root, settings from pr_agent.config_loader import _find_repository_root, get_settings
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
@ -38,12 +38,12 @@ class LocalGitProvider(GitProvider):
self._prepare_repo() self._prepare_repo()
self.diff_files = None self.diff_files = None
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files()) self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
self.description_path = settings.get('local.description_path') \ self.description_path = get_settings().get('local.description_path') \
if settings.get('local.description_path') is not None else self.repo_path / 'description.md' if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md'
self.review_path = settings.get('local.review_path') \ self.review_path = get_settings().get('local.review_path') \
if settings.get('local.review_path') is not None else self.repo_path / 'review.md' if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md'
# inline code comments are not supported for local git repositories # inline code comments are not supported for local git repositories
settings.pr_reviewer.inline_code_comments = False get_settings().pr_reviewer.inline_code_comments = False
def _prepare_repo(self): def _prepare_repo(self):
""" """

View File

@ -3,7 +3,7 @@ import json
import os import os
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
@ -30,11 +30,11 @@ async def run_action():
return return
# Set the environment variables in the settings # Set the environment variables in the settings
settings.set("OPENAI.KEY", OPENAI_KEY) get_settings().set("OPENAI.KEY", OPENAI_KEY)
if OPENAI_ORG: if OPENAI_ORG:
settings.set("OPENAI.ORG", OPENAI_ORG) get_settings().set("OPENAI.ORG", OPENAI_ORG)
settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN) get_settings().set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
settings.set("GITHUB.DEPLOYMENT_TYPE", "user") get_settings().set("GITHUB.DEPLOYMENT_TYPE", "user")
# Load the event payload # Load the event payload
try: try:
@ -50,7 +50,7 @@ async def run_action():
if action in ["opened", "reopened"]: if action in ["opened", "reopened"]:
pr_url = event_payload.get("pull_request", {}).get("url") pr_url = event_payload.get("pull_request", {}).get("url")
if pr_url: if pr_url:
await PRReviewer(pr_url).review() await PRReviewer(pr_url).run()
# Handle issue comment event # Handle issue comment event
elif GITHUB_EVENT_NAME == "issue_comment": elif GITHUB_EVENT_NAME == "issue_comment":

View File

@ -1,6 +1,7 @@
from typing import Dict, Any import copy
import logging import logging
import sys import sys
from typing import Any, Dict
import uvicorn import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
@ -9,7 +10,7 @@ from starlette_context import context
from starlette_context.middleware import RawContextMiddleware from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings, global_settings
from pr_agent.servers.utils import verify_signature from pr_agent.servers.utils import verify_signature
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@ -20,7 +21,8 @@ router = APIRouter()
async def handle_github_webhooks(request: Request, response: Response): async def handle_github_webhooks(request: Request, response: Response):
""" """
Receives and processes incoming GitHub webhook requests. Receives and processes incoming GitHub webhook requests.
Verifies the request signature, parses the request body, and passes it to the handle_request function for further processing. Verifies the request signature, parses the request body, and passes it to the handle_request function for further
processing.
""" """
logging.debug("Received a GitHub webhook") logging.debug("Received a GitHub webhook")
@ -29,6 +31,7 @@ async def handle_github_webhooks(request: Request, response: Response):
logging.debug(f'Request body:\n{body}') logging.debug(f'Request body:\n{body}')
installation_id = body.get("installation", {}).get("id") installation_id = body.get("installation", {}).get("id")
context["installation_id"] = installation_id context["installation_id"] = installation_id
context["settings"] = copy.deepcopy(global_settings)
return await handle_request(body) return await handle_request(body)
@ -46,7 +49,7 @@ async def get_body(request):
raise HTTPException(status_code=400, detail="Error parsing request body") from e raise HTTPException(status_code=400, detail="Error parsing request body") from e
body_bytes = await request.body() body_bytes = await request.body()
signature_header = request.headers.get('x-hub-signature-256', None) signature_header = request.headers.get('x-hub-signature-256', None)
webhook_secret = getattr(settings.github, 'webhook_secret', None) webhook_secret = getattr(get_settings().github, 'webhook_secret', None)
if webhook_secret: if webhook_secret:
verify_signature(body_bytes, webhook_secret, signature_header) verify_signature(body_bytes, webhook_secret, signature_header)
return body return body
@ -96,7 +99,7 @@ async def root():
def start(): def start():
# Override the deployment type to app # Override the deployment type to app
settings.set("GITHUB.DEPLOYMENT_TYPE", "app") get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app")
middleware = [Middleware(RawContextMiddleware)] middleware = [Middleware(RawContextMiddleware)]
app = FastAPI(middleware=middleware) app = FastAPI(middleware=middleware)
app.include_router(router) app.include_router(router)

View File

@ -6,7 +6,7 @@ from datetime import datetime, timezone
import aiohttp import aiohttp
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.servers.help import bot_help_text from pr_agent.servers.help import bot_help_text
@ -38,8 +38,8 @@ async def polling_loop():
agent = PRAgent() agent = PRAgent()
try: try:
deployment_type = settings.github.deployment_type deployment_type = get_settings().github.deployment_type
token = settings.github.user_token token = get_settings().github.user_token
except AttributeError: except AttributeError:
deployment_type = 'none' deployment_type = 'none'
token = None token = None

View File

@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse
from starlette.background import BackgroundTasks from starlette.background import BackgroundTasks
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
app = FastAPI() app = FastAPI()
router = APIRouter() router = APIRouter()
@ -29,13 +29,13 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})) return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
def start(): def start():
gitlab_url = settings.get("GITLAB.URL", None) gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url: if not gitlab_url:
raise ValueError("GITLAB.URL is not set") raise ValueError("GITLAB.URL is not set")
gitlab_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None) gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_token: if not gitlab_token:
raise ValueError("GITLAB.PERSONAL_ACCESS_TOKEN is not set") raise ValueError("GITLAB.PERSONAL_ACCESS_TOKEN is not set")
settings.config.git_provider = "gitlab" get_settings().config.git_provider = "gitlab"
app = FastAPI() app = FastAPI()
app.include_router(router) app.include_router(router)

View File

@ -8,8 +8,8 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import try_fix_json, update_settings_from_args from pr_agent.algo.utils import try_fix_json
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import BitbucketProvider, get_git_provider from pr_agent.git_providers import BitbucketProvider, get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
@ -21,7 +21,6 @@ class PRCodeSuggestions:
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
update_settings_from_args(args)
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
@ -33,24 +32,24 @@ class PRCodeSuggestions:
"description": self.git_provider.get_pr_description(), "description": self.git_provider.get_pr_description(),
"language": self.main_language, "language": self.main_language,
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"num_code_suggestions": settings.pr_code_suggestions.num_code_suggestions, "num_code_suggestions": get_settings().pr_code_suggestions.num_code_suggestions,
"extra_instructions": settings.pr_code_suggestions.extra_instructions, "extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.git_provider.pr,
self.vars, self.vars,
settings.pr_code_suggestions_prompt.system, get_settings().pr_code_suggestions_prompt.system,
settings.pr_code_suggestions_prompt.user) get_settings().pr_code_suggestions_prompt.user)
async def suggest(self): async def run(self):
assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now" assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
logging.info('Generating code suggestions for PR...') logging.info('Generating code suggestions for PR...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True) self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR review...') logging.info('Preparing PR review...')
data = self._prepare_pr_code_suggestions() data = self._prepare_pr_code_suggestions()
if settings.config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing PR review...') logging.info('Pushing PR review...')
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
logging.info('Pushing inline code comments...') logging.info('Pushing inline code comments...')
@ -71,9 +70,9 @@ class PRCodeSuggestions:
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_code_suggestions_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_code_suggestions_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -86,7 +85,7 @@ class PRCodeSuggestions:
try: try:
data = json.loads(review) data = json.loads(review)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not parse json response: {review}") logging.info(f"Could not parse json response: {review}")
data = try_fix_json(review, code_suggestions=True) data = try_fix_json(review, code_suggestions=True)
return data return data
@ -95,7 +94,7 @@ class PRCodeSuggestions:
code_suggestions = [] code_suggestions = []
for d in data['Code suggestions']: for d in data['Code suggestions']:
try: try:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"suggestion: {d}") logging.info(f"suggestion: {d}")
relevant_file = d['relevant file'].strip() relevant_file = d['relevant file'].strip()
relevant_lines_str = d['relevant lines'].strip() relevant_lines_str = d['relevant lines'].strip()
@ -113,8 +112,8 @@ class PRCodeSuggestions:
code_suggestions.append({'body': body, 'relevant_file': relevant_file, code_suggestions.append({'body': body, 'relevant_file': relevant_file,
'relevant_lines_start': relevant_lines_start, 'relevant_lines_start': relevant_lines_start,
'relevant_lines_end': relevant_lines_end}) 'relevant_lines_end': relevant_lines_end})
except: except Exception:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not parse suggestion: {d}") logging.info(f"Could not parse suggestion: {d}")
self.git_provider.publish_code_suggestions(code_suggestions) self.git_provider.publish_code_suggestions(code_suggestions)
@ -136,7 +135,7 @@ class PRCodeSuggestions:
if delta_spaces > 0: if delta_spaces > 0:
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n') new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
except Exception as e: except Exception as e:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}") logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
return new_code_snippet return new_code_snippet

View File

@ -1,15 +1,14 @@
import copy import copy
import json import json
import logging import logging
from typing import Tuple, List from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import update_settings_from_args from pr_agent.config_loader import get_settings
from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
@ -17,13 +16,12 @@ from pr_agent.git_providers.git_provider import get_main_pr_language
class PRDescription: class PRDescription:
def __init__(self, pr_url: str, args: list = None): def __init__(self, pr_url: str, args: list = None):
""" """
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
Args: Args:
pr_url (str): The URL of the pull request. pr_url (str): The URL of the pull request.
args (list, optional): List of arguments passed to the PRDescription class. Defaults to None. args (list, optional): List of arguments passed to the PRDescription class. Defaults to None.
""" """
update_settings_from_args(args)
# Initialize the git provider and main PR language # Initialize the git provider and main PR language
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
@ -40,27 +38,27 @@ class PRDescription:
"description": self.git_provider.get_pr_description(), "description": self.git_provider.get_pr_description(),
"language": self.main_pr_language, "language": self.main_pr_language,
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"extra_instructions": settings.pr_description.extra_instructions, "extra_instructions": get_settings().pr_description.extra_instructions,
} }
# Initialize the token handler # Initialize the token handler
self.token_handler = TokenHandler( self.token_handler = TokenHandler(
self.git_provider.pr, self.git_provider.pr,
self.vars, self.vars,
settings.pr_description_prompt.system, get_settings().pr_description_prompt.system,
settings.pr_description_prompt.user, get_settings().pr_description_prompt.user,
) )
# Initialize patches_diff and prediction attributes # Initialize patches_diff and prediction attributes
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
async def describe(self): async def run(self):
""" """
Generates a PR description using an AI model and publishes it to the PR. Generates a PR description using an AI model and publishes it to the PR.
""" """
logging.info('Generating a PR description...') logging.info('Generating a PR description...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True) self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
@ -68,9 +66,9 @@ class PRDescription:
logging.info('Preparing answer...') logging.info('Preparing answer...')
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer() pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
if settings.config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing answer...') logging.info('Pushing answer...')
if settings.pr_description.publish_description_as_comment: if get_settings().pr_description.publish_description_as_comment:
self.git_provider.publish_comment(markdown_text) self.git_provider.publish_comment(markdown_text)
else: else:
self.git_provider.publish_description(pr_title, pr_body) self.git_provider.publish_description(pr_title, pr_body)
@ -116,10 +114,10 @@ class PRDescription:
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
@ -170,7 +168,7 @@ class PRDescription:
else: else:
pr_body += f"{value}\n\n___\n" pr_body += f"{value}\n\n___\n"
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"title:\n{title}\n{pr_body}") logging.info(f"title:\n{title}\n{pr_body}")
return title, pr_body, pr_types, markdown_text return title, pr_body, pr_types, markdown_text

View File

@ -6,13 +6,11 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
class PRInformationFromUser: class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None): def __init__(self, pr_url: str, args: list = None):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
@ -29,19 +27,19 @@ class PRInformationFromUser:
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.git_provider.pr,
self.vars, self.vars,
settings.pr_information_from_user_prompt.system, get_settings().pr_information_from_user_prompt.system,
settings.pr_information_from_user_prompt.user) get_settings().pr_information_from_user_prompt.user)
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
async def generate_questions(self): async def generate_questions(self):
logging.info('Generating question to the user...') logging.info('Generating question to the user...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing questions...", is_temporary=True) self.git_provider.publish_comment("Preparing questions...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing questions...') logging.info('Preparing questions...')
pr_comment = self._prepare_pr_answer() pr_comment = self._prepare_pr_answer()
if settings.config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing questions...') logging.info('Pushing questions...')
self.git_provider.publish_comment(pr_comment) self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
@ -57,9 +55,9 @@ class PRInformationFromUser:
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_information_from_user_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_information_from_user_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_information_from_user_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_information_from_user_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -68,7 +66,7 @@ class PRInformationFromUser:
def _prepare_pr_answer(self) -> str: def _prepare_pr_answer(self) -> str:
model_output = self.prediction.strip() model_output = self.prediction.strip()
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"answer_str:\n{model_output}") logging.info(f"answer_str:\n{model_output}")
answer_str = f"{model_output}\n\n Please respond to the questions above in the following format:\n\n" +\ answer_str = f"{model_output}\n\n Please respond to the questions above in the following format:\n\n" +\
"\n>/answer\n>1) ...\n>2) ...\n>...\n" "\n>/answer\n>1) ...\n>2) ...\n>...\n"

View File

@ -6,7 +6,7 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
@ -30,8 +30,8 @@ class PRQuestions:
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.git_provider.pr,
self.vars, self.vars,
settings.pr_questions_prompt.system, get_settings().pr_questions_prompt.system,
settings.pr_questions_prompt.user) get_settings().pr_questions_prompt.user)
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
@ -42,14 +42,14 @@ class PRQuestions:
question_str = "" question_str = ""
return question_str return question_str
async def answer(self): async def run(self):
logging.info('Answering a PR question...') logging.info('Answering a PR question...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing answer...", is_temporary=True) self.git_provider.publish_comment("Preparing answer...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing answer...') logging.info('Preparing answer...')
pr_comment = self._prepare_pr_answer() pr_comment = self._prepare_pr_answer()
if settings.config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing answer...') logging.info('Pushing answer...')
self.git_provider.publish_comment(pr_comment) self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
@ -65,9 +65,9 @@ class PRQuestions:
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_questions_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_questions_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -77,6 +77,6 @@ class PRQuestions:
def _prepare_pr_answer(self) -> str: def _prepare_pr_answer(self) -> str:
answer_str = f"Question: {self.question_str}\n\n" answer_str = f"Question: {self.question_str}\n\n"
answer_str += f"Answer:\n{self.prediction.strip()}\n\n" answer_str += f"Answer:\n{self.prediction.strip()}\n\n"
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"answer_str:\n{answer_str}") logging.info(f"answer_str:\n{answer_str}")
return answer_str return answer_str

View File

@ -2,17 +2,17 @@ import copy
import json import json
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Tuple, List from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json, update_settings_from_args from pr_agent.algo.utils import convert_to_markdown, try_fix_json
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language, IncrementalPR from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language
from pr_agent.servers.help import actions_help_text, bot_help_text from pr_agent.servers.help import actions_help_text, bot_help_text
@ -20,17 +20,15 @@ class PRReviewer:
""" """
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
""" """
def __init__(self, pr_url: str, cli_mode: bool = False, is_answer: bool = False, args: list = None): def __init__(self, pr_url: str, is_answer: bool = False, args: list = None):
""" """
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
Args: Args:
pr_url (str): The URL of the pull request to be reviewed. pr_url (str): The URL of the pull request to be reviewed.
cli_mode (bool, optional): Indicates whether the review is being done in command-line interface mode. Defaults to False.
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False. is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None. args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
""" """
update_settings_from_args(args)
self.parse_args(args) # -i command self.parse_args(args) # -i command
self.git_provider = get_git_provider()(pr_url, incremental=self.incremental) self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
@ -41,11 +39,10 @@ class PRReviewer:
self.is_answer = is_answer self.is_answer = is_answer
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now") raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode
answer_str, question_str = self._get_user_answers() answer_str, question_str = self._get_user_answers()
self.vars = { self.vars = {
@ -54,21 +51,21 @@ class PRReviewer:
"description": self.git_provider.get_pr_description(), "description": self.git_provider.get_pr_description(),
"language": self.main_language, "language": self.main_language,
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"require_score": settings.pr_reviewer.require_score_review, "require_score": get_settings().pr_reviewer.require_score_review,
"require_tests": settings.pr_reviewer.require_tests_review, "require_tests": get_settings().pr_reviewer.require_tests_review,
"require_security": settings.pr_reviewer.require_security_review, "require_security": get_settings().pr_reviewer.require_security_review,
"require_focused": settings.pr_reviewer.require_focused_review, "require_focused": get_settings().pr_reviewer.require_focused_review,
'num_code_suggestions': settings.pr_reviewer.num_code_suggestions, 'num_code_suggestions': get_settings().pr_reviewer.num_code_suggestions,
'question_str': question_str, 'question_str': question_str,
'answer_str': answer_str, 'answer_str': answer_str,
"extra_instructions": settings.pr_reviewer.extra_instructions, "extra_instructions": get_settings().pr_reviewer.extra_instructions,
} }
self.token_handler = TokenHandler( self.token_handler = TokenHandler(
self.git_provider.pr, self.git_provider.pr,
self.vars, self.vars,
settings.pr_review_prompt.system, get_settings().pr_review_prompt.system,
settings.pr_review_prompt.user get_settings().pr_review_prompt.user
) )
def parse_args(self, args: List[str]) -> None: def parse_args(self, args: List[str]) -> None:
@ -88,13 +85,13 @@ class PRReviewer:
is_incremental = True is_incremental = True
self.incremental = IncrementalPR(is_incremental) self.incremental = IncrementalPR(is_incremental)
async def review(self) -> None: async def run(self) -> None:
""" """
Review the pull request and generate feedback. Review the pull request and generate feedback.
""" """
logging.info('Reviewing PR...') logging.info('Reviewing PR...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True) self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
@ -102,12 +99,12 @@ class PRReviewer:
logging.info('Preparing PR review...') logging.info('Preparing PR review...')
pr_comment = self._prepare_pr_review() pr_comment = self._prepare_pr_review()
if settings.config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing PR review...') logging.info('Pushing PR review...')
self.git_provider.publish_comment(pr_comment) self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
if settings.pr_reviewer.inline_code_comments: if get_settings().pr_reviewer.inline_code_comments:
logging.info('Pushing inline code comments...') logging.info('Pushing inline code comments...')
self._publish_inline_code_comments() self._publish_inline_code_comments()
@ -140,10 +137,10 @@ class PRReviewer:
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
@ -158,7 +155,8 @@ class PRReviewer:
def _prepare_pr_review(self) -> str: def _prepare_pr_review(self) -> str:
""" """
Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes the feedback. Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes
the feedback.
""" """
review = self.prediction.strip() review = self.prediction.strip()
@ -174,7 +172,8 @@ class PRReviewer:
data['PR Analysis']['Security concerns'] = val data['PR Analysis']['Security concerns'] = val
# Filter out code suggestions that can be submitted as inline comments # Filter out code suggestions that can be submitted as inline comments
if settings.config.git_provider != 'bitbucket' and settings.pr_reviewer.inline_code_comments and 'Code suggestions' in data['PR Feedback']: if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \
and 'Code suggestions' in data['PR Feedback']:
data['PR Feedback']['Code suggestions'] = [ data['PR Feedback']['Code suggestions'] = [
d for d in data['PR Feedback']['Code suggestions'] d for d in data['PR Feedback']['Code suggestions']
if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content')) if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content'))
@ -184,7 +183,8 @@ class PRReviewer:
# Add incremental review section # Add incremental review section
if self.incremental.is_incremental: if self.incremental.is_incremental:
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}" last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \
f"{self.git_provider.incremental.first_new_commit_sha}"
data = OrderedDict(data) data = OrderedDict(data)
data.update({'Incremental PR Review': { data.update({'Incremental PR Review': {
"⏮️ Review for commits since previous PR-Agent review": f"Starting from commit {last_commit_url}"}}) "⏮️ Review for commits since previous PR-Agent review": f"Starting from commit {last_commit_url}"}})
@ -194,7 +194,7 @@ class PRReviewer:
user = self.git_provider.get_user_id() user = self.git_provider.get_user_id()
# Add help text if not in CLI mode # Add help text if not in CLI mode
if not self.cli_mode: if get_settings().get("CONFIG.CLI_MODE", False):
markdown_text += "\n### How to use\n" markdown_text += "\n### How to use\n"
if user and '[bot]' not in user: if user and '[bot]' not in user:
markdown_text += bot_help_text(user) markdown_text += bot_help_text(user)
@ -202,7 +202,7 @@ class PRReviewer:
markdown_text += actions_help_text markdown_text += actions_help_text
# Log markdown response if verbosity level is high # Log markdown response if verbosity level is high
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}") logging.info(f"Markdown response:\n{markdown_text}")
return markdown_text return markdown_text
@ -211,7 +211,7 @@ class PRReviewer:
""" """
Publishes inline comments on a pull request with code suggestions generated by the AI model. Publishes inline comments on a pull request with code suggestions generated by the AI model.
""" """
if settings.pr_reviewer.num_code_suggestions == 0: if get_settings().pr_reviewer.num_code_suggestions == 0:
return return
review = self.prediction.strip() review = self.prediction.strip()

View File

@ -9,9 +9,8 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.algo.utils import update_settings_from_args from pr_agent.git_providers import GithubProvider, get_git_provider
from pr_agent.git_providers import get_git_provider, GithubProvider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
CHANGELOG_LINES = 50 CHANGELOG_LINES = 50
@ -24,8 +23,7 @@ class PRUpdateChangelog:
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
update_settings_from_args(args) self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
self.commit_changelog = settings.pr_update_changelog.push_changelog_changes
self._get_changlog_file() # self.changelog_file_str self._get_changlog_file() # self.changelog_file_str
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
@ -39,23 +37,23 @@ class PRUpdateChangelog:
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"changelog_file_str": self.changelog_file_str, "changelog_file_str": self.changelog_file_str,
"today": date.today(), "today": date.today(),
"extra_instructions": settings.pr_update_changelog.extra_instructions, "extra_instructions": get_settings().pr_update_changelog.extra_instructions,
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.git_provider.pr,
self.vars, self.vars,
settings.pr_update_changelog_prompt.system, get_settings().pr_update_changelog_prompt.system,
settings.pr_update_changelog_prompt.user) get_settings().pr_update_changelog_prompt.user)
async def update_changelog(self): async def run(self):
assert type(self.git_provider) == GithubProvider, "Currently only Github is supported" assert type(self.git_provider) == GithubProvider, "Currently only Github is supported"
logging.info('Updating the changelog...') logging.info('Updating the changelog...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True) self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR changelog updates...') logging.info('Preparing PR changelog updates...')
new_file_content, answer = self._prepare_changelog_update() new_file_content, answer = self._prepare_changelog_update()
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
logging.info('Publishing changelog updates...') logging.info('Publishing changelog updates...')
if self.commit_changelog: if self.commit_changelog:
@ -75,9 +73,9 @@ class PRUpdateChangelog:
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_update_changelog_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_update_changelog_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -86,7 +84,7 @@ class PRUpdateChangelog:
return response return response
def _prepare_changelog_update(self) -> Tuple[str, str]: def _prepare_changelog_update(self) -> Tuple[str, str]:
answer = self.prediction.strip().strip("```").strip() answer = self.prediction.strip().strip("```").strip() # noqa B005
if hasattr(self, "changelog_file"): if hasattr(self, "changelog_file"):
existing_content = self.changelog_file.decoded_content.decode() existing_content = self.changelog_file.decoded_content.decode()
else: else:
@ -100,7 +98,7 @@ class PRUpdateChangelog:
answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \ answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \
"\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n" "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"answer:\n{answer}") logging.info(f"answer:\n{answer}")
return new_file_content, answer return new_file_content, answer
@ -120,7 +118,7 @@ class PRUpdateChangelog:
last_commit_id = list(self.git_provider.pr.get_commits())[-1] last_commit_id = list(self.git_provider.pr.get_commits())[-1]
try: try:
self.git_provider.pr.create_review(commit=last_commit_id, comments=[d]) self.git_provider.pr.create_review(commit=last_commit_id, comments=[d])
except: except Exception:
# we can't create a review for some reason, let's just publish a comment # we can't create a review for some reason, let's just publish a comment
self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}") self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}")
@ -147,7 +145,7 @@ Example:
changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines() changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines()
changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES] changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES]
self.changelog_file_str = "\n".join(changelog_file_lines) self.changelog_file_str = "\n".join(changelog_file_lines)
except: except Exception:
self.changelog_file_str = "" self.changelog_file_str = ""
if self.commit_changelog: if self.commit_changelog:
logging.info("No CHANGELOG.md file found in the repository. Creating one...") logging.info("No CHANGELOG.md file found in the repository. Creating one...")

View File

@ -2,7 +2,7 @@
import logging import logging
from pr_agent.algo.git_patch_processing import handle_patch_deletions from pr_agent.algo.git_patch_processing import handle_patch_deletions
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
""" """
Code Analysis Code Analysis
@ -49,7 +49,7 @@ class TestHandlePatchDeletions:
original_file_content_str = 'foo\nbar\n' original_file_content_str = 'foo\nbar\n'
new_file_content_str = '' new_file_content_str = ''
file_name = 'file.py' file_name = 'file.py'
settings.config.verbosity_level = 1 get_settings().config.verbosity_level = 1
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file_name) handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file_name)

View File

@ -1,50 +0,0 @@
# Generated by CodiumAI
from pr_agent.algo.utils import update_settings_from_args
import logging
from pr_agent.config_loader import settings
import pytest
class TestUpdateSettingsFromArgs:
# Tests that the function updates the setting when passed a single valid argument.
def test_single_valid_argument(self):
args = ['--pr_code_suggestions.extra_instructions="be funny"']
update_settings_from_args(args)
assert settings.pr_code_suggestions.extra_instructions == '"be funny"'
# Tests that the function updates the settings when passed multiple valid arguments.
def test_multiple_valid_arguments(self):
args = ['--pr_code_suggestions.extra_instructions="be funny"', '--pr_code_suggestions.num_code_suggestions=3']
update_settings_from_args(args)
assert settings.pr_code_suggestions.extra_instructions == '"be funny"'
assert settings.pr_code_suggestions.num_code_suggestions == 3
# Tests that the function updates the setting when passed a boolean value.
def test_boolean_values(self):
settings.pr_code_suggestions.enabled = False
args = ['--pr_code_suggestions.enabled=true']
update_settings_from_args(args)
assert 'pr_code_suggestions' in settings
assert 'enabled' in settings.pr_code_suggestions
assert settings.pr_code_suggestions.enabled == True
# Tests that the function updates the setting when passed an integer value.
def test_integer_values(self):
args = ['--pr_code_suggestions.num_code_suggestions=3']
update_settings_from_args(args)
assert settings.pr_code_suggestions.num_code_suggestions == 3
# Tests that the function does not update any settings when passed an empty argument list.
def test_empty_argument_list(self):
args = []
update_settings_from_args(args)
assert settings == settings
# Tests that the function logs an error when passed an invalid argument format.
def test_invalid_argument_format(self, caplog):
args = ['--pr_code_suggestions.extra_instructions="be funny"', '--pr_code_suggestions.num_code_suggestions']
with caplog.at_level(logging.ERROR):
update_settings_from_args(args)
assert 'Invalid argument format' in caplog.text