diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 1a00c977..d0037c95 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -1,6 +1,7 @@ -import re +import shlex -from pr_agent.config_loader import settings +from pr_agent.algo.utils import update_settings_from_args +from pr_agent.config_loader import get_settings from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_information_from_user import PRInformationFromUser @@ -8,29 +9,40 @@ from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_update_changelog import PRUpdateChangelog +command2class = { + "answer": PRReviewer, + "review": PRReviewer, + "review_pr": PRReviewer, + "reflect": PRInformationFromUser, + "reflect_and_review": PRInformationFromUser, + "describe": PRDescription, + "describe_pr": PRDescription, + "improve": PRCodeSuggestions, + "improve_code": PRCodeSuggestions, + "ask": PRQuestions, + "ask_question": PRQuestions, + "update_changelog": PRUpdateChangelog, +} + +commands = list(command2class.keys()) class PRAgent: def __init__(self): pass async def handle_request(self, pr_url, request) -> bool: - action, *args = request.strip().split() - if any(cmd == action for cmd in ["/answer"]): - await PRReviewer(pr_url, is_answer=True, args=args).review() - elif any(cmd == action for cmd in ["/review", "/review_pr", "/reflect_and_review"]): - if settings.pr_reviewer.ask_and_reflect or "/reflect_and_review" in request: - await PRInformationFromUser(pr_url, args=args).generate_questions() - else: - await PRReviewer(pr_url, args=args).review() - elif any(cmd == action for cmd in ["/describe", "/describe_pr"]): - await PRDescription(pr_url, args=args).describe() - elif any(cmd == action for cmd in ["/improve", "/improve_code"]): - await PRCodeSuggestions(pr_url, args=args).suggest() - elif any(cmd == action for cmd in ["/ask", "/ask_question"]): - await PRQuestions(pr_url, args=args).answer() - elif any(cmd == action for cmd in ["/update_changelog"]): - await PRUpdateChangelog(pr_url, args=args).update_changelog() + request = request.replace("'", "\\'") + lexer = shlex.shlex(request, posix=True) + lexer.whitespace_split = True + action, *args = list(lexer) + args = update_settings_from_args(args) + action = action.lstrip("/").lower() + if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect: + action = "review" + if action == "answer": + await PRReviewer(pr_url, is_answer=True, args=args).run() + elif action in command2class: + await command2class[action](pr_url, args=args).run() else: return False - return True diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py index 8ab22e05..aa92f5e5 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handler.py @@ -1,10 +1,10 @@ import logging import openai -from openai.error import APIError, Timeout, TryAgain, RateLimitError +from openai.error import APIError, RateLimitError, Timeout, TryAgain from retry import retry -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings OPENAI_RETRIES=5 @@ -21,16 +21,16 @@ class AiHandler: Raises a ValueError if the OpenAI key is missing. """ try: - openai.api_key = settings.openai.key - if settings.get("OPENAI.ORG", None): - openai.organization = settings.openai.org - self.deployment_id = settings.get("OPENAI.DEPLOYMENT_ID", None) - if settings.get("OPENAI.API_TYPE", None): - openai.api_type = settings.openai.api_type - if settings.get("OPENAI.API_VERSION", None): - openai.api_version = settings.openai.api_version - if settings.get("OPENAI.API_BASE", None): - openai.api_base = settings.openai.api_base + openai.api_key = get_settings().openai.key + if get_settings().get("OPENAI.ORG", None): + openai.organization = get_settings().openai.org + self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None) + if get_settings().get("OPENAI.API_TYPE", None): + openai.api_type = get_settings().openai.api_type + if get_settings().get("OPENAI.API_VERSION", None): + openai.api_version = get_settings().openai.api_version + if get_settings().get("OPENAI.API_BASE", None): + openai.api_base = get_settings().openai.api_base except AttributeError as e: raise ValueError("OpenAI key is required") from e diff --git a/pr_agent/algo/git_patch_processing.py b/pr_agent/algo/git_patch_processing.py index 8128da48..57e091b4 100644 --- a/pr_agent/algo/git_patch_processing.py +++ b/pr_agent/algo/git_patch_processing.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging import re -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings def extend_patch(original_file_str, patch_str, num_lines) -> str: @@ -55,7 +55,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str: continue extended_patch_lines.append(line) except Exception as e: - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.error(f"Failed to extend patch: {e}") return patch_str @@ -126,14 +126,14 @@ def handle_patch_deletions(patch: str, original_file_content_str: str, """ if not new_file_content_str: # logic for handling deleted files - don't show patch, just show that the file was deleted - if settings.config.verbosity_level > 0: + if get_settings().config.verbosity_level > 0: logging.info(f"Processing file: {file_name}, minimizing deletion file") patch = None # file was deleted else: patch_lines = patch.splitlines() patch_new = omit_deletion_hunks(patch_lines) if patch != patch_new: - if settings.config.verbosity_level > 0: + if get_settings().config.verbosity_level > 0: logging.info(f"Processing file: {file_name}, hunks were deleted") patch = patch_new return patch @@ -141,7 +141,8 @@ def handle_patch_deletions(patch: str, original_file_content_str: str, def convert_to_hunks_with_lines_numbers(patch: str, file) -> str: """ - Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of the file. + Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of + the file. Args: patch (str): The patch string to be converted. diff --git a/pr_agent/algo/language_handler.py b/pr_agent/algo/language_handler.py index c81d7fd6..586a3161 100644 --- a/pr_agent/algo/language_handler.py +++ b/pr_agent/algo/language_handler.py @@ -1,15 +1,15 @@ # Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501 from typing import Dict -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings -language_extension_map_org = settings.language_extension_map_org +language_extension_map_org = get_settings().language_extension_map_org language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()} # Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501 -bad_extensions = settings.bad_extensions.default -if settings.config.use_extra_bad_extensions: - bad_extensions += settings.bad_extensions.extra +bad_extensions = get_settings().bad_extensions.default +if get_settings().config.use_extra_bad_extensions: + bad_extensions += get_settings().bad_extensions.extra def filter_bad_extensions(files): diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 45ef40b2..f29a24e9 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Tuple, Union, Callable, List +from typing import Callable, Tuple from github import RateLimitExceededException @@ -10,7 +10,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_large_diff -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings from pr_agent.git_providers.git_provider import GitProvider DELETED_FILES_ = "Deleted files:\n" @@ -27,11 +27,15 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s Returns a string with the diff of the pull request, applying diff minimization techniques if needed. Args: - git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request. - token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request. + git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull + request. + token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the + pull request. model (str): The name of the model used for tokenization. - add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False. - disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False. + add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the + diff. Defaults to False. + disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with + extra lines of context. Defaults to False. Returns: str: A string with the diff of the pull request, applying diff minimization techniques if needed. @@ -76,10 +80,12 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler, add_line_numbers_to_hunks: bool) -> \ Tuple[list, int]: """ - Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff minimization techniques if needed. + Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff + minimization techniques if needed. Args: - - pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding files. + - pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding + files. - token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request. - add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff. @@ -119,10 +125,13 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler, def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str, convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]: """ - Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used. + Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of + tokens used. Args: - top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files. - token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request. + top_langs (list): A list of dictionaries representing the languages used in the pull request and their + corresponding files. + token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the + pull request. model (str): The model used for tokenization. convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff. Returns: @@ -181,7 +190,7 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo # Current logic is to skip the patch if it's too large # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens # until we meet the requirements - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.warning(f"Patch too large, minimizing it, {file.filename}") if not modified_files_list: total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_) @@ -196,15 +205,15 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo patch_final = patch patches.append(patch_final) total_tokens += token_handler.count_tokens(patch_final) - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}") return patches, modified_files_list, deleted_files_list async def retry_with_fallback_models(f: Callable): - model = settings.config.model - fallback_models = settings.config.fallback_models + model = get_settings().config.model + fallback_models = get_settings().config.fallback_models if not isinstance(fallback_models, list): fallback_models = [fallback_models] all_models = [model] + fallback_models diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index 66659824..0888f8b8 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -1,8 +1,7 @@ from jinja2 import Environment, StrictUndefined from tiktoken import encoding_for_model -from pr_agent.algo import MAX_TOKENS -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings class TokenHandler: @@ -10,9 +9,12 @@ class TokenHandler: A class for handling tokens in the context of a pull request. Attributes: - - encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the number of tokens in them. - - limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the pr_agent.algo module. - - prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens method. + - encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the + number of tokens in them. + - limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the + pr_agent.algo module. + - prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens + method. """ def __init__(self, pr, vars: dict, system, user): @@ -25,7 +27,7 @@ class TokenHandler: - system: The system string. - user: The user string. """ - self.encoder = encoding_for_model(settings.config.model) + self.encoder = encoding_for_model(get_settings().config.model) self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index c31556f1..6d0a9206 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -1,15 +1,24 @@ from __future__ import annotations -from typing import List import difflib -from datetime import datetime import json import logging import re import textwrap +from datetime import datetime +from typing import Any, List -from pr_agent.config_loader import settings +from starlette_context import context +from pr_agent.config_loader import get_settings, global_settings + + +def get_setting(key: str) -> Any: + try: + key = key.upper() + return context.get("settings", global_settings).get(key, global_settings.get(key, None)) + except Exception: + return global_settings.get(key, None) def convert_to_markdown(output_data: dict) -> str: """ @@ -97,12 +106,16 @@ def try_fix_json(review, max_iter=10, code_suggestions=False): - data: A dictionary containing the parsed JSON data. The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion. - If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the message. - If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON message by parsing until the last valid code suggestion. - The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or newlines. + If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the + message. + If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON + message by parsing until the last valid code suggestion. + The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or + newlines. It tries to parse the JSON message with the closing bracket and checks if it is valid. If the JSON message is valid, the parsed JSON data is returned. - If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON message is obtained or the maximum number of iterations is reached. + If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON + message is obtained or the maximum number of iterations is reached. If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned. """ @@ -184,7 +197,8 @@ def convert_str_to_datetime(date_str): def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str: """ - Generate a patch for a modified file by comparing the original content of the file with the new content provided as input. + Generate a patch for a modified file by comparing the original content of the file with the new content provided as + input. Args: file: The file object for which the patch needs to be generated. @@ -199,14 +213,16 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str: None. Additional Information: - - If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it as output. - - If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating that the file was modified but no patch was found, and a patch is manually created. + - If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it + as output. + - If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating + that the file was modified but no patch was found, and a patch is manually created. """ if not patch: # to Do - also add condition for file extension try: diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True), new_file_content_str.splitlines(keepends=True)) - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.") patch = ''.join(diff) except Exception: @@ -214,7 +230,7 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str: return patch -def update_settings_from_args(args: List[str]) -> None: +def update_settings_from_args(args: List[str]) -> List[str]: """ Update the settings of the Dynaconf object based on the arguments passed to the function. @@ -230,28 +246,22 @@ def update_settings_from_args(args: List[str]) -> None: ValueError: If the argument is not in the correct format. """ + other_args = [] if args: for arg in args: - try: + arg = arg.strip() + if arg.startswith('--'): arg = arg.strip('-').strip() vals = arg.split('=') if len(vals) != 2: - raise ValueError(f'Invalid argument format: {arg}') + logging.error(f'Invalid argument format: {arg}') + other_args.append(arg) + continue key, value = vals - keys = key.split('.') - d = settings - for i, k in enumerate(keys[:-1]): - if k not in d: - raise ValueError(f'Invalid setting: {key}') - d = d[k] - if keys[-1] not in d: - raise ValueError(f'Invalid setting: {key}') - if isinstance(d[keys[-1]], bool): - d[keys[-1]] = value.lower() in ("yes", "true", "t", "1") - else: - d[keys[-1]] = type(d[keys[-1]])(value) + key = key.strip().upper() + value = value.strip() + get_settings().set(key, value) logging.info(f'Updated setting {key} to: "{value}"') - except ValueError as e: - logging.error(str(e)) - except Exception as e: - logging.error(f'Failed to parse argument {arg}: {e}') \ No newline at end of file + else: + other_args.append(arg) + return other_args diff --git a/pr_agent/cli.py b/pr_agent/cli.py index b0173ace..8dd21b3f 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -3,15 +3,11 @@ import asyncio import logging import os -from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions -from pr_agent.tools.pr_description import PRDescription -from pr_agent.tools.pr_information_from_user import PRInformationFromUser -from pr_agent.tools.pr_questions import PRQuestions -from pr_agent.tools.pr_reviewer import PRReviewer -from pr_agent.tools.pr_update_changelog import PRUpdateChangelog +from pr_agent.agent.pr_agent import PRAgent, commands +from pr_agent.config_loader import get_settings -def run(args=None): +def run(inargs=None): parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage= """\ Usage: cli.py --pr-url []. @@ -34,79 +30,16 @@ To edit any configuration parameter from 'configuration.toml', just add -config_ For example: '- cli.py --pr-url=... review --pr_reviewer.extra_instructions="focus on the file: ..."' """) parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True) - parser.add_argument('command', type=str, help='The', choices=['review', 'review_pr', - 'ask', 'ask_question', - 'describe', 'describe_pr', - 'improve', 'improve_code', - 'reflect', 'review_after_reflect', - 'update_changelog'], - default='review') + parser.add_argument('command', type=str, help='The', choices=commands, default='review') parser.add_argument('rest', nargs=argparse.REMAINDER, default=[]) - args = parser.parse_args(args) + args = parser.parse_args(inargs) logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) command = args.command.lower() - commands = { - 'ask': _handle_ask_command, - 'ask_question': _handle_ask_command, - 'describe': _handle_describe_command, - 'describe_pr': _handle_describe_command, - 'improve': _handle_improve_command, - 'improve_code': _handle_improve_command, - 'review': _handle_review_command, - 'review_pr': _handle_review_command, - 'reflect': _handle_reflect_command, - 'review_after_reflect': _handle_review_after_reflect_command, - 'update_changelog': _handle_update_changelog, - } - if command in commands: - commands[command](args.pr_url, args.rest) - else: - print(f"Unknown command: {command}") + get_settings().set("CONFIG.CLI_MODE", True) + result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest))) + if not result: parser.print_help() -def _handle_ask_command(pr_url: str, rest: list): - if len(rest) == 0: - print("Please specify a question") - return - print(f"Question: {' '.join(rest)} about PR {pr_url}") - reviewer = PRQuestions(pr_url, rest) - asyncio.run(reviewer.answer()) - - -def _handle_describe_command(pr_url: str, rest: list): - print(f"PR description: {pr_url}") - reviewer = PRDescription(pr_url, args=rest) - asyncio.run(reviewer.describe()) - - -def _handle_improve_command(pr_url: str, rest: list): - print(f"PR code suggestions: {pr_url}") - reviewer = PRCodeSuggestions(pr_url, args=rest) - asyncio.run(reviewer.suggest()) - - -def _handle_review_command(pr_url: str, rest: list): - print(f"Reviewing PR: {pr_url}") - reviewer = PRReviewer(pr_url, cli_mode=True, args=rest) - asyncio.run(reviewer.review()) - - -def _handle_reflect_command(pr_url: str, rest: list): - print(f"Asking the PR author questions: {pr_url}") - reviewer = PRInformationFromUser(pr_url) - asyncio.run(reviewer.generate_questions()) - - -def _handle_review_after_reflect_command(pr_url: str, rest: list): - print(f"Processing author's answers and sending review: {pr_url}") - reviewer = PRReviewer(pr_url, cli_mode=True, is_answer=True, args=rest) - asyncio.run(reviewer.review()) - -def _handle_update_changelog(pr_url: str, rest: list): - print(f"Updating changlog for: {pr_url}") - reviewer = PRUpdateChangelog(pr_url, cli_mode=True, args=rest) - asyncio.run(reviewer.update_changelog()) - if __name__ == '__main__': run() diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py index 5716fe8a..3075e8dc 100644 --- a/pr_agent/config_loader.py +++ b/pr_agent/config_loader.py @@ -3,28 +3,36 @@ from pathlib import Path from typing import Optional from dynaconf import Dynaconf +from starlette_context import context PR_AGENT_TOML_KEY = 'pr-agent' current_dir = dirname(abspath(__file__)) -settings = Dynaconf( +global_settings = Dynaconf( envvar_prefix=False, merge_enabled=True, settings_files=[join(current_dir, f) for f in [ - "settings/.secrets.toml", - "settings/configuration.toml", - "settings/language_extensions.toml", - "settings/pr_reviewer_prompts.toml", - "settings/pr_questions_prompts.toml", - "settings/pr_description_prompts.toml", - "settings/pr_code_suggestions_prompts.toml", - "settings/pr_information_from_user_prompts.toml", - "settings/pr_update_changelog_prompts.toml", - "settings_prod/.secrets.toml" - ]] + "settings/.secrets.toml", + "settings/configuration.toml", + "settings/language_extensions.toml", + "settings/pr_reviewer_prompts.toml", + "settings/pr_questions_prompts.toml", + "settings/pr_description_prompts.toml", + "settings/pr_code_suggestions_prompts.toml", + "settings/pr_information_from_user_prompts.toml", + "settings/pr_update_changelog_prompts.toml", + "settings_prod/.secrets.toml" + ]] ) +def get_settings(): + try: + return context["settings"] + except Exception: + return global_settings + + # Add local configuration from pyproject.toml of the project being reviewed def _find_repository_root() -> Path: """ @@ -39,6 +47,7 @@ def _find_repository_root() -> Path: cwd = cwd.parent return None + def _find_pyproject() -> Optional[Path]: """ Search for file pyproject.toml in the repository root. @@ -49,6 +58,7 @@ def _find_pyproject() -> Optional[Path]: return pyproject if pyproject.is_file() else None return None + pyproject_path = _find_pyproject() if pyproject_path is not None: - settings.load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}') + get_settings().load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}') diff --git a/pr_agent/git_providers/__init__.py b/pr_agent/git_providers/__init__.py index 421fcd5d..e7c2aa0f 100644 --- a/pr_agent/git_providers/__init__.py +++ b/pr_agent/git_providers/__init__.py @@ -1,4 +1,4 @@ -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings from pr_agent.git_providers.bitbucket_provider import BitbucketProvider from pr_agent.git_providers.github_provider import GithubProvider from pr_agent.git_providers.gitlab_provider import GitLabProvider @@ -13,7 +13,7 @@ _GIT_PROVIDERS = { def get_git_provider(): try: - provider_id = settings.config.git_provider + provider_id = get_settings().config.git_provider except AttributeError as e: raise ValueError("git_provider is a required attribute in the configuration file") from e if provider_id not in _GIT_PROVIDERS: diff --git a/pr_agent/git_providers/bitbucket_provider.py b/pr_agent/git_providers/bitbucket_provider.py index 27694f8f..2f3ec2c2 100644 --- a/pr_agent/git_providers/bitbucket_provider.py +++ b/pr_agent/git_providers/bitbucket_provider.py @@ -5,15 +5,14 @@ from urllib.parse import urlparse import requests from atlassian.bitbucket import Cloud -from pr_agent.config_loader import settings - +from ..config_loader import get_settings from .git_provider import FilePatchInfo class BitbucketProvider: def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False): s = requests.Session() - s.headers['Authorization'] = f'Bearer {settings.get("BITBUCKET.BEARER_TOKEN", None)}' + s.headers['Authorization'] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}' self.bitbucket_client = Cloud(session=s) self.workspace_slug = None diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 341335df..ae3eaeba 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -7,12 +7,11 @@ from github import AppAuthentication, Auth, Github, GithubException from retry import retry from starlette_context import context -from pr_agent.config_loader import settings - from ..algo.language_handler import is_valid_file from ..algo.utils import load_large_diff -from .git_provider import FilePatchInfo, GitProvider, IncrementalPR +from ..config_loader import get_settings from ..servers.utils import RateLimitExceeded +from .git_provider import FilePatchInfo, GitProvider, IncrementalPR class GithubProvider(GitProvider): @@ -85,7 +84,7 @@ class GithubProvider(GitProvider): return self.pr.get_files() @retry(exceptions=RateLimitExceeded, - tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3)) + tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3)) def get_diff_files(self) -> list[FilePatchInfo]: try: files = self.get_files() @@ -118,7 +117,7 @@ class GithubProvider(GitProvider): # self.pr.create_issue_comment(pr_comment) def publish_comment(self, pr_comment: str, is_temporary: bool = False): - if is_temporary and not settings.config.publish_output_progress: + if is_temporary and not get_settings().config.publish_output_progress: logging.debug(f"Skipping publish_comment for temporary comment: {pr_comment}") return response = self.pr.create_issue_comment(pr_comment) @@ -149,7 +148,7 @@ class GithubProvider(GitProvider): position = i break if position == -1: - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}") subject_type = "FILE" else: @@ -174,13 +173,13 @@ class GithubProvider(GitProvider): relevant_lines_end = suggestion['relevant_lines_end'] if not relevant_lines_start or relevant_lines_start == -1: - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.exception( f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}") continue if relevant_lines_end < relevant_lines_start: - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.exception(f"Failed to publish code suggestion, " f"relevant_lines_end is {relevant_lines_end} and " f"relevant_lines_start is {relevant_lines_start}") @@ -207,7 +206,7 @@ class GithubProvider(GitProvider): self.pr.create_review(commit=self.last_commit_id, comments=post_parameters_list) return True except Exception as e: - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.error(f"Failed to publish code suggestion, error: {e}") return False @@ -241,7 +240,7 @@ class GithubProvider(GitProvider): return self.github_user_id def get_notifications(self, since: datetime): - deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user") + deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user") if deployment_type != 'user': raise ValueError("Deployment mode must be set to 'user' to get notifications") @@ -282,12 +281,12 @@ class GithubProvider(GitProvider): return repo_name, pr_number def _get_github_client(self): - deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user") + deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user") if deployment_type == 'app': try: - private_key = settings.github.private_key - app_id = settings.github.app_id + private_key = get_settings().github.private_key + app_id = get_settings().github.app_id except AttributeError as e: raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e if not self.installation_id: @@ -298,7 +297,7 @@ class GithubProvider(GitProvider): if deployment_type == 'user': try: - token = settings.github.user_token + token = get_settings().github.user_token except AttributeError as e: raise ValueError( "GitHub token is required when using user deployment. See: " @@ -327,7 +326,9 @@ class GithubProvider(GitProvider): def publish_labels(self, pr_types): try: - label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5", "Refactoring": "bfdadc", "Enhancement": "bfd4f2", "Documentation": "d4c5f9", "Other": "d1bcf9"} + label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5", + "Refactoring": "bfdadc", "Enhancement": "bfd4f2", "Documentation": "d4c5f9", + "Other": "d1bcf9"} post_parameters = [] for p in pr_types: color = label_color_map.get(p, "d1bcf9") # default to "Other" color diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 279815d6..10363ec1 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -6,9 +6,8 @@ from urllib.parse import urlparse import gitlab from gitlab import GitlabGetError -from pr_agent.config_loader import settings - from ..algo.language_handler import is_valid_file +from ..config_loader import get_settings from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider logger = logging.getLogger() @@ -17,10 +16,10 @@ logger = logging.getLogger() class GitLabProvider(GitProvider): def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False): - gitlab_url = settings.get("GITLAB.URL", None) + gitlab_url = get_settings().get("GITLAB.URL", None) if not gitlab_url: raise ValueError("GitLab URL is not set in the config file") - gitlab_access_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None) + gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None) if not gitlab_access_token: raise ValueError("GitLab personal access token is not set in the config file") self.gl = gitlab.Gitlab( diff --git a/pr_agent/git_providers/local_git_provider.py b/pr_agent/git_providers/local_git_provider.py index 304417ea..a4f21969 100644 --- a/pr_agent/git_providers/local_git_provider.py +++ b/pr_agent/git_providers/local_git_provider.py @@ -5,7 +5,7 @@ from typing import List from git import Repo -from pr_agent.config_loader import _find_repository_root, settings +from pr_agent.config_loader import _find_repository_root, get_settings from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider @@ -38,12 +38,12 @@ class LocalGitProvider(GitProvider): self._prepare_repo() self.diff_files = None self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files()) - self.description_path = settings.get('local.description_path') \ - if settings.get('local.description_path') is not None else self.repo_path / 'description.md' - self.review_path = settings.get('local.review_path') \ - if settings.get('local.review_path') is not None else self.repo_path / 'review.md' + self.description_path = get_settings().get('local.description_path') \ + if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md' + self.review_path = get_settings().get('local.review_path') \ + if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md' # inline code comments are not supported for local git repositories - settings.pr_reviewer.inline_code_comments = False + get_settings().pr_reviewer.inline_code_comments = False def _prepare_repo(self): """ diff --git a/pr_agent/servers/github_action_runner.py b/pr_agent/servers/github_action_runner.py index 2aa424d7..9846e199 100644 --- a/pr_agent/servers/github_action_runner.py +++ b/pr_agent/servers/github_action_runner.py @@ -3,7 +3,7 @@ import json import os from pr_agent.agent.pr_agent import PRAgent -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings from pr_agent.tools.pr_reviewer import PRReviewer @@ -30,11 +30,11 @@ async def run_action(): return # Set the environment variables in the settings - settings.set("OPENAI.KEY", OPENAI_KEY) + get_settings().set("OPENAI.KEY", OPENAI_KEY) if OPENAI_ORG: - settings.set("OPENAI.ORG", OPENAI_ORG) - settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN) - settings.set("GITHUB.DEPLOYMENT_TYPE", "user") + get_settings().set("OPENAI.ORG", OPENAI_ORG) + get_settings().set("GITHUB.USER_TOKEN", GITHUB_TOKEN) + get_settings().set("GITHUB.DEPLOYMENT_TYPE", "user") # Load the event payload try: @@ -50,7 +50,7 @@ async def run_action(): if action in ["opened", "reopened"]: pr_url = event_payload.get("pull_request", {}).get("url") if pr_url: - await PRReviewer(pr_url).review() + await PRReviewer(pr_url).run() # Handle issue comment event elif GITHUB_EVENT_NAME == "issue_comment": diff --git a/pr_agent/servers/github_app.py b/pr_agent/servers/github_app.py index a9ba1de5..263f5ba5 100644 --- a/pr_agent/servers/github_app.py +++ b/pr_agent/servers/github_app.py @@ -1,6 +1,7 @@ -from typing import Dict, Any +import copy import logging import sys +from typing import Any, Dict import uvicorn from fastapi import APIRouter, FastAPI, HTTPException, Request, Response @@ -9,7 +10,7 @@ from starlette_context import context from starlette_context.middleware import RawContextMiddleware from pr_agent.agent.pr_agent import PRAgent -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings, global_settings from pr_agent.servers.utils import verify_signature logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) @@ -20,7 +21,8 @@ router = APIRouter() async def handle_github_webhooks(request: Request, response: Response): """ Receives and processes incoming GitHub webhook requests. - Verifies the request signature, parses the request body, and passes it to the handle_request function for further processing. + Verifies the request signature, parses the request body, and passes it to the handle_request function for further + processing. """ logging.debug("Received a GitHub webhook") @@ -29,6 +31,7 @@ async def handle_github_webhooks(request: Request, response: Response): logging.debug(f'Request body:\n{body}') installation_id = body.get("installation", {}).get("id") context["installation_id"] = installation_id + context["settings"] = copy.deepcopy(global_settings) return await handle_request(body) @@ -46,7 +49,7 @@ async def get_body(request): raise HTTPException(status_code=400, detail="Error parsing request body") from e body_bytes = await request.body() signature_header = request.headers.get('x-hub-signature-256', None) - webhook_secret = getattr(settings.github, 'webhook_secret', None) + webhook_secret = getattr(get_settings().github, 'webhook_secret', None) if webhook_secret: verify_signature(body_bytes, webhook_secret, signature_header) return body @@ -62,6 +65,8 @@ async def handle_request(body: Dict[str, Any]): body: The request body. """ action = body.get("action") + if not action: + return {} agent = PRAgent() if action == 'created': @@ -77,7 +82,7 @@ async def handle_request(body: Dict[str, Any]): api_url = pull_request.get("url") await agent.handle_request(api_url, comment_body) - elif action in ["opened"] or 'reopened' in action: + elif action == "opened" or 'reopened' in action: pull_request = body.get("pull_request") if not pull_request: return {} @@ -96,7 +101,7 @@ async def root(): def start(): # Override the deployment type to app - settings.set("GITHUB.DEPLOYMENT_TYPE", "app") + get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app") middleware = [Middleware(RawContextMiddleware)] app = FastAPI(middleware=middleware) app.include_router(router) diff --git a/pr_agent/servers/github_polling.py b/pr_agent/servers/github_polling.py index 06550418..18f71dd7 100644 --- a/pr_agent/servers/github_polling.py +++ b/pr_agent/servers/github_polling.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone import aiohttp from pr_agent.agent.pr_agent import PRAgent -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.servers.help import bot_help_text @@ -38,8 +38,8 @@ async def polling_loop(): agent = PRAgent() try: - deployment_type = settings.github.deployment_type - token = settings.github.user_token + deployment_type = get_settings().github.deployment_type + token = get_settings().github.user_token except AttributeError: deployment_type = 'none' token = None diff --git a/pr_agent/servers/gitlab_webhook.py b/pr_agent/servers/gitlab_webhook.py index 75bef3cc..c9b623f7 100644 --- a/pr_agent/servers/gitlab_webhook.py +++ b/pr_agent/servers/gitlab_webhook.py @@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse from starlette.background import BackgroundTasks from pr_agent.agent.pr_agent import PRAgent -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings app = FastAPI() router = APIRouter() @@ -29,13 +29,13 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request): return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})) def start(): - gitlab_url = settings.get("GITLAB.URL", None) + gitlab_url = get_settings().get("GITLAB.URL", None) if not gitlab_url: raise ValueError("GITLAB.URL is not set") - gitlab_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None) + gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None) if not gitlab_token: raise ValueError("GITLAB.PERSONAL_ACCESS_TOKEN is not set") - settings.config.git_provider = "gitlab" + get_settings().config.git_provider = "gitlab" app = FastAPI() app.include_router(router) diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 3c1477cc..71aecd7a 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -8,8 +8,8 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import try_fix_json, update_settings_from_args -from pr_agent.config_loader import settings +from pr_agent.algo.utils import try_fix_json +from pr_agent.config_loader import get_settings from pr_agent.git_providers import BitbucketProvider, get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -21,7 +21,6 @@ class PRCodeSuggestions: self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - update_settings_from_args(args) self.ai_handler = AiHandler() self.patches_diff = None @@ -33,24 +32,24 @@ class PRCodeSuggestions: "description": self.git_provider.get_pr_description(), "language": self.main_language, "diff": "", # empty diff for initial calculation - "num_code_suggestions": settings.pr_code_suggestions.num_code_suggestions, - "extra_instructions": settings.pr_code_suggestions.extra_instructions, + "num_code_suggestions": get_settings().pr_code_suggestions.num_code_suggestions, + "extra_instructions": get_settings().pr_code_suggestions.extra_instructions, } self.token_handler = TokenHandler(self.git_provider.pr, self.vars, - settings.pr_code_suggestions_prompt.system, - settings.pr_code_suggestions_prompt.user) + get_settings().pr_code_suggestions_prompt.system, + get_settings().pr_code_suggestions_prompt.user) - async def suggest(self): + async def run(self): assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now" logging.info('Generating code suggestions for PR...') - if settings.config.publish_output: + if get_settings().config.publish_output: self.git_provider.publish_comment("Preparing review...", is_temporary=True) await retry_with_fallback_models(self._prepare_prediction) logging.info('Preparing PR review...') data = self._prepare_pr_code_suggestions() - if settings.config.publish_output: + if get_settings().config.publish_output: logging.info('Pushing PR review...') self.git_provider.remove_initial_comment() logging.info('Pushing inline code comments...') @@ -71,9 +70,9 @@ class PRCodeSuggestions: variables = copy.deepcopy(self.vars) variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(settings.pr_code_suggestions_prompt.system).render(variables) - user_prompt = environment.from_string(settings.pr_code_suggestions_prompt.user).render(variables) - if settings.config.verbosity_level >= 2: + system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables) + user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables) + if get_settings().config.verbosity_level >= 2: logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}") response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, @@ -86,7 +85,7 @@ class PRCodeSuggestions: try: data = json.loads(review) except json.decoder.JSONDecodeError: - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"Could not parse json response: {review}") data = try_fix_json(review, code_suggestions=True) return data @@ -95,7 +94,7 @@ class PRCodeSuggestions: code_suggestions = [] for d in data['Code suggestions']: try: - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"suggestion: {d}") relevant_file = d['relevant file'].strip() relevant_lines_str = d['relevant lines'].strip() @@ -113,8 +112,8 @@ class PRCodeSuggestions: code_suggestions.append({'body': body, 'relevant_file': relevant_file, 'relevant_lines_start': relevant_lines_start, 'relevant_lines_end': relevant_lines_end}) - except: - if settings.config.verbosity_level >= 2: + except Exception: + if get_settings().config.verbosity_level >= 2: logging.info(f"Could not parse suggestion: {d}") self.git_provider.publish_code_suggestions(code_suggestions) @@ -136,7 +135,7 @@ class PRCodeSuggestions: if delta_spaces > 0: new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n') except Exception as e: - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}") return new_code_snippet diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index ea0fb92f..aaf14227 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -1,15 +1,14 @@ import copy import json import logging -from typing import Tuple, List +from typing import List, Tuple from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import update_settings_from_args -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -17,13 +16,12 @@ from pr_agent.git_providers.git_provider import get_main_pr_language class PRDescription: def __init__(self, pr_url: str, args: list = None): """ - Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. + Initialize the PRDescription object with the necessary attributes and objects for generating a PR description + using an AI model. Args: pr_url (str): The URL of the pull request. args (list, optional): List of arguments passed to the PRDescription class. Defaults to None. """ - update_settings_from_args(args) - # Initialize the git provider and main PR language self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( @@ -41,28 +39,28 @@ class PRDescription: "description": self.git_provider.get_pr_description(), "language": self.main_pr_language, "diff": "", # empty diff for initial calculation - "extra_instructions": settings.pr_description.extra_instructions, - "commit_messages_str": commit_messages_str, + "extra_instructions": get_settings().pr_description.extra_instructions, + "commit_messages_str": commit_messages_str } # Initialize the token handler self.token_handler = TokenHandler( self.git_provider.pr, self.vars, - settings.pr_description_prompt.system, - settings.pr_description_prompt.user, + get_settings().pr_description_prompt.system, + get_settings().pr_description_prompt.user, ) # Initialize patches_diff and prediction attributes self.patches_diff = None self.prediction = None - async def describe(self): + async def run(self): """ Generates a PR description using an AI model and publishes it to the PR. """ logging.info('Generating a PR description...') - if settings.config.publish_output: + if get_settings().config.publish_output: self.git_provider.publish_comment("Preparing pr description...", is_temporary=True) await retry_with_fallback_models(self._prepare_prediction) @@ -70,9 +68,9 @@ class PRDescription: logging.info('Preparing answer...') pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer() - if settings.config.publish_output: + if get_settings().config.publish_output: logging.info('Pushing answer...') - if settings.pr_description.publish_description_as_comment: + if get_settings().pr_description.publish_description_as_comment: self.git_provider.publish_comment(markdown_text) else: self.git_provider.publish_description(pr_title, pr_body) @@ -118,10 +116,10 @@ class PRDescription: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables) - user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables) + system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables) + user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables) - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}") @@ -172,7 +170,7 @@ class PRDescription: else: pr_body += f"{value}\n\n___\n" - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"title:\n{title}\n{pr_body}") return title, pr_body, pr_types, markdown_text \ No newline at end of file diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index feeb0e31..10d32381 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -6,13 +6,11 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language - - class PRInformationFromUser: def __init__(self, pr_url: str, args: list = None): self.git_provider = get_git_provider()(pr_url) @@ -29,19 +27,19 @@ class PRInformationFromUser: } self.token_handler = TokenHandler(self.git_provider.pr, self.vars, - settings.pr_information_from_user_prompt.system, - settings.pr_information_from_user_prompt.user) + get_settings().pr_information_from_user_prompt.system, + get_settings().pr_information_from_user_prompt.user) self.patches_diff = None self.prediction = None - async def generate_questions(self): + async def run(self): logging.info('Generating question to the user...') - if settings.config.publish_output: + if get_settings().config.publish_output: self.git_provider.publish_comment("Preparing questions...", is_temporary=True) await retry_with_fallback_models(self._prepare_prediction) logging.info('Preparing questions...') pr_comment = self._prepare_pr_answer() - if settings.config.publish_output: + if get_settings().config.publish_output: logging.info('Pushing questions...') self.git_provider.publish_comment(pr_comment) self.git_provider.remove_initial_comment() @@ -57,9 +55,9 @@ class PRInformationFromUser: variables = copy.deepcopy(self.vars) variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(settings.pr_information_from_user_prompt.system).render(variables) - user_prompt = environment.from_string(settings.pr_information_from_user_prompt.user).render(variables) - if settings.config.verbosity_level >= 2: + system_prompt = environment.from_string(get_settings().pr_information_from_user_prompt.system).render(variables) + user_prompt = environment.from_string(get_settings().pr_information_from_user_prompt.user).render(variables) + if get_settings().config.verbosity_level >= 2: logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}") response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, @@ -68,7 +66,7 @@ class PRInformationFromUser: def _prepare_pr_answer(self) -> str: model_output = self.prediction.strip() - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"answer_str:\n{model_output}") answer_str = f"{model_output}\n\n Please respond to the questions above in the following format:\n\n" +\ "\n>/answer\n>1) ...\n>2) ...\n>...\n" diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 589cf3e9..33923776 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -6,7 +6,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -30,8 +30,8 @@ class PRQuestions: } self.token_handler = TokenHandler(self.git_provider.pr, self.vars, - settings.pr_questions_prompt.system, - settings.pr_questions_prompt.user) + get_settings().pr_questions_prompt.system, + get_settings().pr_questions_prompt.user) self.patches_diff = None self.prediction = None @@ -42,14 +42,14 @@ class PRQuestions: question_str = "" return question_str - async def answer(self): + async def run(self): logging.info('Answering a PR question...') - if settings.config.publish_output: + if get_settings().config.publish_output: self.git_provider.publish_comment("Preparing answer...", is_temporary=True) await retry_with_fallback_models(self._prepare_prediction) logging.info('Preparing answer...') pr_comment = self._prepare_pr_answer() - if settings.config.publish_output: + if get_settings().config.publish_output: logging.info('Pushing answer...') self.git_provider.publish_comment(pr_comment) self.git_provider.remove_initial_comment() @@ -65,9 +65,9 @@ class PRQuestions: variables = copy.deepcopy(self.vars) variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(settings.pr_questions_prompt.system).render(variables) - user_prompt = environment.from_string(settings.pr_questions_prompt.user).render(variables) - if settings.config.verbosity_level >= 2: + system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables) + user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables) + if get_settings().config.verbosity_level >= 2: logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}") response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, @@ -77,6 +77,6 @@ class PRQuestions: def _prepare_pr_answer(self) -> str: answer_str = f"Question: {self.question_str}\n\n" answer_str += f"Answer:\n{self.prediction.strip()}\n\n" - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"answer_str:\n{answer_str}") return answer_str diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 8e2343d0..f8610863 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -2,17 +2,17 @@ import copy import json import logging from collections import OrderedDict -from typing import Tuple, List +from typing import List, Tuple from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import convert_to_markdown, try_fix_json, update_settings_from_args -from pr_agent.config_loader import settings +from pr_agent.algo.utils import convert_to_markdown, try_fix_json +from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider -from pr_agent.git_providers.git_provider import get_main_pr_language, IncrementalPR +from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language from pr_agent.servers.help import actions_help_text, bot_help_text @@ -20,17 +20,15 @@ class PRReviewer: """ The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, cli_mode: bool = False, is_answer: bool = False, args: list = None): + def __init__(self, pr_url: str, is_answer: bool = False, args: list = None): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Args: pr_url (str): The URL of the pull request to be reviewed. - cli_mode (bool, optional): Indicates whether the review is being done in command-line interface mode. Defaults to False. is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False. args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None. """ - update_settings_from_args(args) self.parse_args(args) # -i command self.git_provider = get_git_provider()(pr_url, incremental=self.incremental) @@ -41,11 +39,10 @@ class PRReviewer: self.is_answer = is_answer if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): - raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now") + raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now") self.ai_handler = AiHandler() self.patches_diff = None self.prediction = None - self.cli_mode = cli_mode answer_str, question_str = self._get_user_answers() self.vars = { @@ -54,21 +51,21 @@ class PRReviewer: "description": self.git_provider.get_pr_description(), "language": self.main_language, "diff": "", # empty diff for initial calculation - "require_score": settings.pr_reviewer.require_score_review, - "require_tests": settings.pr_reviewer.require_tests_review, - "require_security": settings.pr_reviewer.require_security_review, - "require_focused": settings.pr_reviewer.require_focused_review, - 'num_code_suggestions': settings.pr_reviewer.num_code_suggestions, + "require_score": get_settings().pr_reviewer.require_score_review, + "require_tests": get_settings().pr_reviewer.require_tests_review, + "require_security": get_settings().pr_reviewer.require_security_review, + "require_focused": get_settings().pr_reviewer.require_focused_review, + 'num_code_suggestions': get_settings().pr_reviewer.num_code_suggestions, 'question_str': question_str, 'answer_str': answer_str, - "extra_instructions": settings.pr_reviewer.extra_instructions, + "extra_instructions": get_settings().pr_reviewer.extra_instructions, } self.token_handler = TokenHandler( self.git_provider.pr, self.vars, - settings.pr_review_prompt.system, - settings.pr_review_prompt.user + get_settings().pr_review_prompt.system, + get_settings().pr_review_prompt.user ) def parse_args(self, args: List[str]) -> None: @@ -88,13 +85,13 @@ class PRReviewer: is_incremental = True self.incremental = IncrementalPR(is_incremental) - async def review(self) -> None: + async def run(self) -> None: """ Review the pull request and generate feedback. """ logging.info('Reviewing PR...') - if settings.config.publish_output: + if get_settings().config.publish_output: self.git_provider.publish_comment("Preparing review...", is_temporary=True) await retry_with_fallback_models(self._prepare_prediction) @@ -102,12 +99,12 @@ class PRReviewer: logging.info('Preparing PR review...') pr_comment = self._prepare_pr_review() - if settings.config.publish_output: + if get_settings().config.publish_output: logging.info('Pushing PR review...') self.git_provider.publish_comment(pr_comment) self.git_provider.remove_initial_comment() - if settings.pr_reviewer.inline_code_comments: + if get_settings().pr_reviewer.inline_code_comments: logging.info('Pushing inline code comments...') self._publish_inline_code_comments() @@ -140,10 +137,10 @@ class PRReviewer: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables) - user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables) + system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables) + user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables) - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}") @@ -158,7 +155,8 @@ class PRReviewer: def _prepare_pr_review(self) -> str: """ - Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes the feedback. + Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes + the feedback. """ review = self.prediction.strip() @@ -174,7 +172,8 @@ class PRReviewer: data['PR Analysis']['Security concerns'] = val # Filter out code suggestions that can be submitted as inline comments - if settings.config.git_provider != 'bitbucket' and settings.pr_reviewer.inline_code_comments and 'Code suggestions' in data['PR Feedback']: + if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \ + and 'Code suggestions' in data['PR Feedback']: data['PR Feedback']['Code suggestions'] = [ d for d in data['PR Feedback']['Code suggestions'] if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content')) @@ -184,7 +183,8 @@ class PRReviewer: # Add incremental review section if self.incremental.is_incremental: - last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}" + last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \ + f"{self.git_provider.incremental.first_new_commit_sha}" data = OrderedDict(data) data.update({'Incremental PR Review': { "⏮️ Review for commits since previous PR-Agent review": f"Starting from commit {last_commit_url}"}}) @@ -194,7 +194,7 @@ class PRReviewer: user = self.git_provider.get_user_id() # Add help text if not in CLI mode - if not self.cli_mode: + if not get_settings().get("CONFIG.CLI_MODE", False): markdown_text += "\n### How to use\n" if user and '[bot]' not in user: markdown_text += bot_help_text(user) @@ -202,7 +202,7 @@ class PRReviewer: markdown_text += actions_help_text # Log markdown response if verbosity level is high - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"Markdown response:\n{markdown_text}") return markdown_text @@ -211,7 +211,7 @@ class PRReviewer: """ Publishes inline comments on a pull request with code suggestions generated by the AI model. """ - if settings.pr_reviewer.num_code_suggestions == 0: + if get_settings().pr_reviewer.num_code_suggestions == 0: return review = self.prediction.strip() diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index a4f93978..91b4da9e 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -9,9 +9,8 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.config_loader import settings -from pr_agent.algo.utils import update_settings_from_args -from pr_agent.git_providers import get_git_provider, GithubProvider +from pr_agent.config_loader import get_settings +from pr_agent.git_providers import GithubProvider, get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language CHANGELOG_LINES = 50 @@ -24,8 +23,7 @@ class PRUpdateChangelog: self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - update_settings_from_args(args) - self.commit_changelog = settings.pr_update_changelog.push_changelog_changes + self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes self._get_changlog_file() # self.changelog_file_str self.ai_handler = AiHandler() self.patches_diff = None @@ -39,23 +37,23 @@ class PRUpdateChangelog: "diff": "", # empty diff for initial calculation "changelog_file_str": self.changelog_file_str, "today": date.today(), - "extra_instructions": settings.pr_update_changelog.extra_instructions, + "extra_instructions": get_settings().pr_update_changelog.extra_instructions, } self.token_handler = TokenHandler(self.git_provider.pr, self.vars, - settings.pr_update_changelog_prompt.system, - settings.pr_update_changelog_prompt.user) + get_settings().pr_update_changelog_prompt.system, + get_settings().pr_update_changelog_prompt.user) - async def update_changelog(self): + async def run(self): assert type(self.git_provider) == GithubProvider, "Currently only Github is supported" logging.info('Updating the changelog...') - if settings.config.publish_output: + if get_settings().config.publish_output: self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True) await retry_with_fallback_models(self._prepare_prediction) logging.info('Preparing PR changelog updates...') new_file_content, answer = self._prepare_changelog_update() - if settings.config.publish_output: + if get_settings().config.publish_output: self.git_provider.remove_initial_comment() logging.info('Publishing changelog updates...') if self.commit_changelog: @@ -75,9 +73,9 @@ class PRUpdateChangelog: variables = copy.deepcopy(self.vars) variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(settings.pr_update_changelog_prompt.system).render(variables) - user_prompt = environment.from_string(settings.pr_update_changelog_prompt.user).render(variables) - if settings.config.verbosity_level >= 2: + system_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.system).render(variables) + user_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.user).render(variables) + if get_settings().config.verbosity_level >= 2: logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}") response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, @@ -86,7 +84,7 @@ class PRUpdateChangelog: return response def _prepare_changelog_update(self) -> Tuple[str, str]: - answer = self.prediction.strip().strip("```").strip() + answer = self.prediction.strip().strip("```").strip() # noqa B005 if hasattr(self, "changelog_file"): existing_content = self.changelog_file.decoded_content.decode() else: @@ -100,7 +98,7 @@ class PRUpdateChangelog: answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \ "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n" - if settings.config.verbosity_level >= 2: + if get_settings().config.verbosity_level >= 2: logging.info(f"answer:\n{answer}") return new_file_content, answer @@ -120,7 +118,7 @@ class PRUpdateChangelog: last_commit_id = list(self.git_provider.pr.get_commits())[-1] try: self.git_provider.pr.create_review(commit=last_commit_id, comments=[d]) - except: + except Exception: # we can't create a review for some reason, let's just publish a comment self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}") @@ -147,7 +145,7 @@ Example: changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines() changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES] self.changelog_file_str = "\n".join(changelog_file_lines) - except: + except Exception: self.changelog_file_str = "" if self.commit_changelog: logging.info("No CHANGELOG.md file found in the repository. Creating one...") diff --git a/tests/unittest/test_handle_patch_deletions.py b/tests/unittest/test_handle_patch_deletions.py index 95ab8674..152ea4b2 100644 --- a/tests/unittest/test_handle_patch_deletions.py +++ b/tests/unittest/test_handle_patch_deletions.py @@ -2,7 +2,7 @@ import logging from pr_agent.algo.git_patch_processing import handle_patch_deletions -from pr_agent.config_loader import settings +from pr_agent.config_loader import get_settings """ Code Analysis @@ -49,7 +49,7 @@ class TestHandlePatchDeletions: original_file_content_str = 'foo\nbar\n' new_file_content_str = '' file_name = 'file.py' - settings.config.verbosity_level = 1 + get_settings().config.verbosity_level = 1 with caplog.at_level(logging.INFO): handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file_name) diff --git a/tests/unittest/test_update_settings_from_args.py b/tests/unittest/test_update_settings_from_args.py deleted file mode 100644 index 5cfa2202..00000000 --- a/tests/unittest/test_update_settings_from_args.py +++ /dev/null @@ -1,50 +0,0 @@ - -# Generated by CodiumAI -from pr_agent.algo.utils import update_settings_from_args -import logging -from pr_agent.config_loader import settings - - -import pytest - -class TestUpdateSettingsFromArgs: - # Tests that the function updates the setting when passed a single valid argument. - def test_single_valid_argument(self): - args = ['--pr_code_suggestions.extra_instructions="be funny"'] - update_settings_from_args(args) - assert settings.pr_code_suggestions.extra_instructions == '"be funny"' - - # Tests that the function updates the settings when passed multiple valid arguments. - def test_multiple_valid_arguments(self): - args = ['--pr_code_suggestions.extra_instructions="be funny"', '--pr_code_suggestions.num_code_suggestions=3'] - update_settings_from_args(args) - assert settings.pr_code_suggestions.extra_instructions == '"be funny"' - assert settings.pr_code_suggestions.num_code_suggestions == 3 - - # Tests that the function updates the setting when passed a boolean value. - def test_boolean_values(self): - settings.pr_code_suggestions.enabled = False - args = ['--pr_code_suggestions.enabled=true'] - update_settings_from_args(args) - assert 'pr_code_suggestions' in settings - assert 'enabled' in settings.pr_code_suggestions - assert settings.pr_code_suggestions.enabled == True - - # Tests that the function updates the setting when passed an integer value. - def test_integer_values(self): - args = ['--pr_code_suggestions.num_code_suggestions=3'] - update_settings_from_args(args) - assert settings.pr_code_suggestions.num_code_suggestions == 3 - - # Tests that the function does not update any settings when passed an empty argument list. - def test_empty_argument_list(self): - args = [] - update_settings_from_args(args) - assert settings == settings - - # Tests that the function logs an error when passed an invalid argument format. - def test_invalid_argument_format(self, caplog): - args = ['--pr_code_suggestions.extra_instructions="be funny"', '--pr_code_suggestions.num_code_suggestions'] - with caplog.at_level(logging.ERROR): - update_settings_from_args(args) - assert 'Invalid argument format' in caplog.text \ No newline at end of file