Support context aware settings (for each incoming request), support override of settings, refactor CLI to use pr_agent.py

This commit is contained in:
Ori Kotek
2023-08-01 14:43:26 +03:00
parent 6605f9c444
commit d7b77764c3
26 changed files with 305 additions and 384 deletions

View File

@ -1,6 +1,7 @@
import re
import shlex
from pr_agent.config_loader import settings
from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_information_from_user import PRInformationFromUser
@ -8,29 +9,39 @@ from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
command2class = {
"answer": PRReviewer,
"review": PRReviewer,
"review_pr": PRReviewer,
"reflect": PRInformationFromUser,
"reflect_and_review": PRInformationFromUser,
"describe": PRDescription,
"describe_pr": PRDescription,
"improve": PRCodeSuggestions,
"improve_code": PRCodeSuggestions,
"ask": PRQuestions,
"ask_question": PRQuestions,
"update_changelog": PRUpdateChangelog,
}
commands = list(command2class.keys())
class PRAgent:
def __init__(self):
pass
async def handle_request(self, pr_url, request) -> bool:
action, *args = request.strip().split()
if any(cmd == action for cmd in ["/answer"]):
await PRReviewer(pr_url, is_answer=True, args=args).review()
elif any(cmd == action for cmd in ["/review", "/review_pr", "/reflect_and_review"]):
if settings.pr_reviewer.ask_and_reflect or "/reflect_and_review" in request:
await PRInformationFromUser(pr_url, args=args).generate_questions()
else:
await PRReviewer(pr_url, args=args).review()
elif any(cmd == action for cmd in ["/describe", "/describe_pr"]):
await PRDescription(pr_url, args=args).describe()
elif any(cmd == action for cmd in ["/improve", "/improve_code"]):
await PRCodeSuggestions(pr_url, args=args).suggest()
elif any(cmd == action for cmd in ["/ask", "/ask_question"]):
await PRQuestions(pr_url, args=args).answer()
elif any(cmd == action for cmd in ["/update_changelog"]):
await PRUpdateChangelog(pr_url, args=args).update_changelog()
lexer = shlex.shlex(request, posix=True)
lexer.whitespace_split = True
action, *args = list(lexer)
args = update_settings_from_args(args)
action = action.lstrip("/").lower()
if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect:
action = "review"
if action == "answer":
await PRReviewer(pr_url, is_answer=True, args=args).run()
elif action in command2class:
await command2class[action](pr_url, args=args).run()
else:
return False
return True

View File

@ -1,10 +1,10 @@
import logging
import openai
from openai.error import APIError, Timeout, TryAgain, RateLimitError
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
OPENAI_RETRIES=5
@ -21,16 +21,16 @@ class AiHandler:
Raises a ValueError if the OpenAI key is missing.
"""
try:
openai.api_key = settings.openai.key
if settings.get("OPENAI.ORG", None):
openai.organization = settings.openai.org
self.deployment_id = settings.get("OPENAI.DEPLOYMENT_ID", None)
if settings.get("OPENAI.API_TYPE", None):
openai.api_type = settings.openai.api_type
if settings.get("OPENAI.API_VERSION", None):
openai.api_version = settings.openai.api_version
if settings.get("OPENAI.API_BASE", None):
openai.api_base = settings.openai.api_base
openai.api_key = get_settings().openai.key
if get_settings().get("OPENAI.ORG", None):
openai.organization = get_settings().openai.org
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
if get_settings().get("OPENAI.API_TYPE", None):
openai.api_type = get_settings().openai.api_type
if get_settings().get("OPENAI.API_VERSION", None):
openai.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = get_settings().openai.api_base
except AttributeError as e:
raise ValueError("OpenAI key is required") from e

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
import re
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
def extend_patch(original_file_str, patch_str, num_lines) -> str:
@ -55,7 +55,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
continue
extended_patch_lines.append(line)
except Exception as e:
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.error(f"Failed to extend patch: {e}")
return patch_str
@ -126,14 +126,14 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
"""
if not new_file_content_str:
# logic for handling deleted files - don't show patch, just show that the file was deleted
if settings.config.verbosity_level > 0:
if get_settings().config.verbosity_level > 0:
logging.info(f"Processing file: {file_name}, minimizing deletion file")
patch = None # file was deleted
else:
patch_lines = patch.splitlines()
patch_new = omit_deletion_hunks(patch_lines)
if patch != patch_new:
if settings.config.verbosity_level > 0:
if get_settings().config.verbosity_level > 0:
logging.info(f"Processing file: {file_name}, hunks were deleted")
patch = patch_new
return patch
@ -141,7 +141,8 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
"""
Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of the file.
Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
the file.
Args:
patch (str): The patch string to be converted.

View File

@ -1,15 +1,15 @@
# Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501
from typing import Dict
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
language_extension_map_org = settings.language_extension_map_org
language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
# Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
bad_extensions = settings.bad_extensions.default
if settings.config.use_extra_bad_extensions:
bad_extensions += settings.bad_extensions.extra
bad_extensions = get_settings().bad_extensions.default
if get_settings().config.use_extra_bad_extensions:
bad_extensions += get_settings().bad_extensions.extra
def filter_bad_extensions(files):

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from typing import Tuple, Union, Callable, List
from typing import Callable, Tuple
from github import RateLimitExceededException
@ -10,7 +10,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_large_diff
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider
DELETED_FILES_ = "Deleted files:\n"
@ -27,11 +27,15 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
Args:
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request.
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull
request.
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
pull request.
model (str): The name of the model used for tokenization.
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False.
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False.
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the
diff. Defaults to False.
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with
extra lines of context. Defaults to False.
Returns:
str: A string with the diff of the pull request, applying diff minimization techniques if needed.
@ -76,10 +80,12 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool) -> \
Tuple[list, int]:
"""
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff minimization techniques if needed.
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
minimization techniques if needed.
Args:
- pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding files.
- pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding
files.
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
@ -119,10 +125,13 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
"""
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used.
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of
tokens used.
Args:
top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files.
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
top_langs (list): A list of dictionaries representing the languages used in the pull request and their
corresponding files.
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
pull request.
model (str): The model used for tokenization.
convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
Returns:
@ -181,7 +190,7 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
# Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.warning(f"Patch too large, minimizing it, {file.filename}")
if not modified_files_list:
total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_)
@ -196,15 +205,15 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
patch_final = patch
patches.append(patch_final)
total_tokens += token_handler.count_tokens(patch_final)
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")
return patches, modified_files_list, deleted_files_list
async def retry_with_fallback_models(f: Callable):
model = settings.config.model
fallback_models = settings.config.fallback_models
model = get_settings().config.model
fallback_models = get_settings().config.fallback_models
if not isinstance(fallback_models, list):
fallback_models = [fallback_models]
all_models = [model] + fallback_models

View File

@ -1,8 +1,7 @@
from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model
from pr_agent.algo import MAX_TOKENS
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
class TokenHandler:
@ -10,9 +9,12 @@ class TokenHandler:
A class for handling tokens in the context of a pull request.
Attributes:
- encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the number of tokens in them.
- limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the pr_agent.algo module.
- prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens method.
- encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the
number of tokens in them.
- limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the
pr_agent.algo module.
- prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens
method.
"""
def __init__(self, pr, vars: dict, system, user):
@ -25,7 +27,7 @@ class TokenHandler:
- system: The system string.
- user: The user string.
"""
self.encoder = encoding_for_model(settings.config.model)
self.encoder = encoding_for_model(get_settings().config.model)
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):

View File

@ -1,15 +1,24 @@
from __future__ import annotations
from typing import List
import difflib
from datetime import datetime
import json
import logging
import re
import textwrap
from datetime import datetime
from typing import Any, List
from pr_agent.config_loader import settings
from starlette_context import context
from pr_agent.config_loader import get_settings, global_settings
def get_setting(key: str) -> Any:
try:
key = key.upper()
return context.get("settings", global_settings).get(key, global_settings.get(key, None))
except Exception:
return global_settings.get(key, None)
def convert_to_markdown(output_data: dict) -> str:
"""
@ -97,12 +106,16 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
- data: A dictionary containing the parsed JSON data.
The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion.
If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the message.
If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON message by parsing until the last valid code suggestion.
The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or newlines.
If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the
message.
If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON
message by parsing until the last valid code suggestion.
The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or
newlines.
It tries to parse the JSON message with the closing bracket and checks if it is valid.
If the JSON message is valid, the parsed JSON data is returned.
If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON message is obtained or the maximum number of iterations is reached.
If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON
message is obtained or the maximum number of iterations is reached.
If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned.
"""
@ -184,7 +197,8 @@ def convert_str_to_datetime(date_str):
def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
"""
Generate a patch for a modified file by comparing the original content of the file with the new content provided as input.
Generate a patch for a modified file by comparing the original content of the file with the new content provided as
input.
Args:
file: The file object for which the patch needs to be generated.
@ -199,14 +213,16 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str:
None.
Additional Information:
- If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it as output.
- If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating that the file was modified but no patch was found, and a patch is manually created.
- If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it
as output.
- If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating
that the file was modified but no patch was found, and a patch is manually created.
"""
if not patch: # to Do - also add condition for file extension
try:
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
new_file_content_str.splitlines(keepends=True))
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.")
patch = ''.join(diff)
except Exception:
@ -214,7 +230,7 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str:
return patch
def update_settings_from_args(args: List[str]) -> None:
def update_settings_from_args(args: List[str]) -> List[str]:
"""
Update the settings of the Dynaconf object based on the arguments passed to the function.
@ -230,28 +246,22 @@ def update_settings_from_args(args: List[str]) -> None:
ValueError: If the argument is not in the correct format.
"""
other_args = []
if args:
for arg in args:
try:
arg = arg.strip()
if arg.startswith('--'):
arg = arg.strip('-').strip()
vals = arg.split('=')
if len(vals) != 2:
raise ValueError(f'Invalid argument format: {arg}')
logging.error(f'Invalid argument format: {arg}')
other_args.append(arg)
continue
key, value = vals
keys = key.split('.')
d = settings
for i, k in enumerate(keys[:-1]):
if k not in d:
raise ValueError(f'Invalid setting: {key}')
d = d[k]
if keys[-1] not in d:
raise ValueError(f'Invalid setting: {key}')
if isinstance(d[keys[-1]], bool):
d[keys[-1]] = value.lower() in ("yes", "true", "t", "1")
else:
d[keys[-1]] = type(d[keys[-1]])(value)
key = key.strip().upper()
value = value.strip()
get_settings().set(key, value)
logging.info(f'Updated setting {key} to: "{value}"')
except ValueError as e:
logging.error(str(e))
except Exception as e:
logging.error(f'Failed to parse argument {arg}: {e}')
else:
other_args.append(arg)
return other_args

View File

@ -3,15 +3,11 @@ import asyncio
import logging
import os
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_information_from_user import PRInformationFromUser
from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
from pr_agent.agent.pr_agent import PRAgent, commands
from pr_agent.config_loader import get_settings
def run(args=None):
def run(inargs=None):
parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
"""\
Usage: cli.py --pr-url <URL on supported git hosting service> <command> [<args>].
@ -34,79 +30,16 @@ To edit any configuration parameter from 'configuration.toml', just add -config_
For example: '- cli.py --pr-url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""")
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
parser.add_argument('command', type=str, help='The', choices=['review', 'review_pr',
'ask', 'ask_question',
'describe', 'describe_pr',
'improve', 'improve_code',
'reflect', 'review_after_reflect',
'update_changelog'],
default='review')
parser.add_argument('command', type=str, help='The', choices=commands, default='review')
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
args = parser.parse_args(args)
args = parser.parse_args(inargs)
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
command = args.command.lower()
commands = {
'ask': _handle_ask_command,
'ask_question': _handle_ask_command,
'describe': _handle_describe_command,
'describe_pr': _handle_describe_command,
'improve': _handle_improve_command,
'improve_code': _handle_improve_command,
'review': _handle_review_command,
'review_pr': _handle_review_command,
'reflect': _handle_reflect_command,
'review_after_reflect': _handle_review_after_reflect_command,
'update_changelog': _handle_update_changelog,
}
if command in commands:
commands[command](args.pr_url, args.rest)
else:
print(f"Unknown command: {command}")
get_settings().set("CONFIG.CLI_MODE", True)
result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest)))
if not result:
parser.print_help()
def _handle_ask_command(pr_url: str, rest: list):
if len(rest) == 0:
print("Please specify a question")
return
print(f"Question: {' '.join(rest)} about PR {pr_url}")
reviewer = PRQuestions(pr_url, rest)
asyncio.run(reviewer.answer())
def _handle_describe_command(pr_url: str, rest: list):
print(f"PR description: {pr_url}")
reviewer = PRDescription(pr_url, args=rest)
asyncio.run(reviewer.describe())
def _handle_improve_command(pr_url: str, rest: list):
print(f"PR code suggestions: {pr_url}")
reviewer = PRCodeSuggestions(pr_url, args=rest)
asyncio.run(reviewer.suggest())
def _handle_review_command(pr_url: str, rest: list):
print(f"Reviewing PR: {pr_url}")
reviewer = PRReviewer(pr_url, cli_mode=True, args=rest)
asyncio.run(reviewer.review())
def _handle_reflect_command(pr_url: str, rest: list):
print(f"Asking the PR author questions: {pr_url}")
reviewer = PRInformationFromUser(pr_url)
asyncio.run(reviewer.generate_questions())
def _handle_review_after_reflect_command(pr_url: str, rest: list):
print(f"Processing author's answers and sending review: {pr_url}")
reviewer = PRReviewer(pr_url, cli_mode=True, is_answer=True, args=rest)
asyncio.run(reviewer.review())
def _handle_update_changelog(pr_url: str, rest: list):
print(f"Updating changlog for: {pr_url}")
reviewer = PRUpdateChangelog(pr_url, cli_mode=True, args=rest)
asyncio.run(reviewer.update_changelog())
if __name__ == '__main__':
run()

View File

@ -3,28 +3,36 @@ from pathlib import Path
from typing import Optional
from dynaconf import Dynaconf
from starlette_context import context
PR_AGENT_TOML_KEY = 'pr-agent'
current_dir = dirname(abspath(__file__))
settings = Dynaconf(
global_settings = Dynaconf(
envvar_prefix=False,
merge_enabled=True,
settings_files=[join(current_dir, f) for f in [
"settings/.secrets.toml",
"settings/configuration.toml",
"settings/language_extensions.toml",
"settings/pr_reviewer_prompts.toml",
"settings/pr_questions_prompts.toml",
"settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml",
"settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog_prompts.toml",
"settings_prod/.secrets.toml"
]]
"settings/.secrets.toml",
"settings/configuration.toml",
"settings/language_extensions.toml",
"settings/pr_reviewer_prompts.toml",
"settings/pr_questions_prompts.toml",
"settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml",
"settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog_prompts.toml",
"settings_prod/.secrets.toml"
]]
)
def get_settings():
try:
return context["settings"]
except Exception:
return global_settings
# Add local configuration from pyproject.toml of the project being reviewed
def _find_repository_root() -> Path:
"""
@ -39,6 +47,7 @@ def _find_repository_root() -> Path:
cwd = cwd.parent
return None
def _find_pyproject() -> Optional[Path]:
"""
Search for file pyproject.toml in the repository root.
@ -49,6 +58,7 @@ def _find_pyproject() -> Optional[Path]:
return pyproject if pyproject.is_file() else None
return None
pyproject_path = _find_pyproject()
if pyproject_path is not None:
settings.load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}')
get_settings().load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}')

View File

@ -1,4 +1,4 @@
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.bitbucket_provider import BitbucketProvider
from pr_agent.git_providers.github_provider import GithubProvider
from pr_agent.git_providers.gitlab_provider import GitLabProvider
@ -13,7 +13,7 @@ _GIT_PROVIDERS = {
def get_git_provider():
try:
provider_id = settings.config.git_provider
provider_id = get_settings().config.git_provider
except AttributeError as e:
raise ValueError("git_provider is a required attribute in the configuration file") from e
if provider_id not in _GIT_PROVIDERS:

View File

@ -5,15 +5,14 @@ from urllib.parse import urlparse
import requests
from atlassian.bitbucket import Cloud
from pr_agent.config_loader import settings
from ..config_loader import get_settings
from .git_provider import FilePatchInfo
class BitbucketProvider:
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
s = requests.Session()
s.headers['Authorization'] = f'Bearer {settings.get("BITBUCKET.BEARER_TOKEN", None)}'
s.headers['Authorization'] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
self.bitbucket_client = Cloud(session=s)
self.workspace_slug = None

View File

@ -7,12 +7,11 @@ from github import AppAuthentication, Auth, Github, GithubException
from retry import retry
from starlette_context import context
from pr_agent.config_loader import settings
from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from ..config_loader import get_settings
from ..servers.utils import RateLimitExceeded
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
class GithubProvider(GitProvider):
@ -85,7 +84,7 @@ class GithubProvider(GitProvider):
return self.pr.get_files()
@retry(exceptions=RateLimitExceeded,
tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
def get_diff_files(self) -> list[FilePatchInfo]:
try:
files = self.get_files()
@ -118,7 +117,7 @@ class GithubProvider(GitProvider):
# self.pr.create_issue_comment(pr_comment)
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
if is_temporary and not settings.config.publish_output_progress:
if is_temporary and not get_settings().config.publish_output_progress:
logging.debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
return
response = self.pr.create_issue_comment(pr_comment)
@ -149,7 +148,7 @@ class GithubProvider(GitProvider):
position = i
break
if position == -1:
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
subject_type = "FILE"
else:
@ -174,13 +173,13 @@ class GithubProvider(GitProvider):
relevant_lines_end = suggestion['relevant_lines_end']
if not relevant_lines_start or relevant_lines_start == -1:
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.exception(
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
continue
if relevant_lines_end < relevant_lines_start:
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.exception(f"Failed to publish code suggestion, "
f"relevant_lines_end is {relevant_lines_end} and "
f"relevant_lines_start is {relevant_lines_start}")
@ -207,7 +206,7 @@ class GithubProvider(GitProvider):
self.pr.create_review(commit=self.last_commit_id, comments=post_parameters_list)
return True
except Exception as e:
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.error(f"Failed to publish code suggestion, error: {e}")
return False
@ -241,7 +240,7 @@ class GithubProvider(GitProvider):
return self.github_user_id
def get_notifications(self, since: datetime):
deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user")
deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
if deployment_type != 'user':
raise ValueError("Deployment mode must be set to 'user' to get notifications")
@ -282,12 +281,12 @@ class GithubProvider(GitProvider):
return repo_name, pr_number
def _get_github_client(self):
deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user")
deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
if deployment_type == 'app':
try:
private_key = settings.github.private_key
app_id = settings.github.app_id
private_key = get_settings().github.private_key
app_id = get_settings().github.app_id
except AttributeError as e:
raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e
if not self.installation_id:
@ -298,7 +297,7 @@ class GithubProvider(GitProvider):
if deployment_type == 'user':
try:
token = settings.github.user_token
token = get_settings().github.user_token
except AttributeError as e:
raise ValueError(
"GitHub token is required when using user deployment. See: "
@ -327,7 +326,9 @@ class GithubProvider(GitProvider):
def publish_labels(self, pr_types):
try:
label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5", "Refactoring": "bfdadc", "Enhancement": "bfd4f2", "Documentation": "d4c5f9", "Other": "d1bcf9"}
label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5",
"Refactoring": "bfdadc", "Enhancement": "bfd4f2", "Documentation": "d4c5f9",
"Other": "d1bcf9"}
post_parameters = []
for p in pr_types:
color = label_color_map.get(p, "d1bcf9") # default to "Other" color

View File

@ -6,9 +6,8 @@ from urllib.parse import urlparse
import gitlab
from gitlab import GitlabGetError
from pr_agent.config_loader import settings
from ..algo.language_handler import is_valid_file
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
logger = logging.getLogger()
@ -17,10 +16,10 @@ logger = logging.getLogger()
class GitLabProvider(GitProvider):
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
gitlab_url = settings.get("GITLAB.URL", None)
gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url:
raise ValueError("GitLab URL is not set in the config file")
gitlab_access_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_access_token:
raise ValueError("GitLab personal access token is not set in the config file")
self.gl = gitlab.Gitlab(

View File

@ -5,7 +5,7 @@ from typing import List
from git import Repo
from pr_agent.config_loader import _find_repository_root, settings
from pr_agent.config_loader import _find_repository_root, get_settings
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
@ -38,12 +38,12 @@ class LocalGitProvider(GitProvider):
self._prepare_repo()
self.diff_files = None
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
self.description_path = settings.get('local.description_path') \
if settings.get('local.description_path') is not None else self.repo_path / 'description.md'
self.review_path = settings.get('local.review_path') \
if settings.get('local.review_path') is not None else self.repo_path / 'review.md'
self.description_path = get_settings().get('local.description_path') \
if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md'
self.review_path = get_settings().get('local.review_path') \
if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md'
# inline code comments are not supported for local git repositories
settings.pr_reviewer.inline_code_comments = False
get_settings().pr_reviewer.inline_code_comments = False
def _prepare_repo(self):
"""

View File

@ -3,7 +3,7 @@ import json
import os
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
from pr_agent.tools.pr_reviewer import PRReviewer
@ -30,11 +30,11 @@ async def run_action():
return
# Set the environment variables in the settings
settings.set("OPENAI.KEY", OPENAI_KEY)
get_settings().set("OPENAI.KEY", OPENAI_KEY)
if OPENAI_ORG:
settings.set("OPENAI.ORG", OPENAI_ORG)
settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
settings.set("GITHUB.DEPLOYMENT_TYPE", "user")
get_settings().set("OPENAI.ORG", OPENAI_ORG)
get_settings().set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
get_settings().set("GITHUB.DEPLOYMENT_TYPE", "user")
# Load the event payload
try:
@ -50,7 +50,7 @@ async def run_action():
if action in ["opened", "reopened"]:
pr_url = event_payload.get("pull_request", {}).get("url")
if pr_url:
await PRReviewer(pr_url).review()
await PRReviewer(pr_url).run()
# Handle issue comment event
elif GITHUB_EVENT_NAME == "issue_comment":

View File

@ -1,6 +1,7 @@
from typing import Dict, Any
import copy
import logging
import sys
from typing import Any, Dict
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
@ -9,7 +10,7 @@ from starlette_context import context
from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.servers.utils import verify_signature
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@ -20,7 +21,8 @@ router = APIRouter()
async def handle_github_webhooks(request: Request, response: Response):
"""
Receives and processes incoming GitHub webhook requests.
Verifies the request signature, parses the request body, and passes it to the handle_request function for further processing.
Verifies the request signature, parses the request body, and passes it to the handle_request function for further
processing.
"""
logging.debug("Received a GitHub webhook")
@ -29,6 +31,7 @@ async def handle_github_webhooks(request: Request, response: Response):
logging.debug(f'Request body:\n{body}')
installation_id = body.get("installation", {}).get("id")
context["installation_id"] = installation_id
context["settings"] = copy.deepcopy(global_settings)
return await handle_request(body)
@ -46,7 +49,7 @@ async def get_body(request):
raise HTTPException(status_code=400, detail="Error parsing request body") from e
body_bytes = await request.body()
signature_header = request.headers.get('x-hub-signature-256', None)
webhook_secret = getattr(settings.github, 'webhook_secret', None)
webhook_secret = getattr(get_settings().github, 'webhook_secret', None)
if webhook_secret:
verify_signature(body_bytes, webhook_secret, signature_header)
return body
@ -96,7 +99,7 @@ async def root():
def start():
# Override the deployment type to app
settings.set("GITHUB.DEPLOYMENT_TYPE", "app")
get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app")
middleware = [Middleware(RawContextMiddleware)]
app = FastAPI(middleware=middleware)
app.include_router(router)

View File

@ -6,7 +6,7 @@ from datetime import datetime, timezone
import aiohttp
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.servers.help import bot_help_text
@ -38,8 +38,8 @@ async def polling_loop():
agent = PRAgent()
try:
deployment_type = settings.github.deployment_type
token = settings.github.user_token
deployment_type = get_settings().github.deployment_type
token = get_settings().github.user_token
except AttributeError:
deployment_type = 'none'
token = None

View File

@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse
from starlette.background import BackgroundTasks
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
app = FastAPI()
router = APIRouter()
@ -29,13 +29,13 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
def start():
gitlab_url = settings.get("GITLAB.URL", None)
gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url:
raise ValueError("GITLAB.URL is not set")
gitlab_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_token:
raise ValueError("GITLAB.PERSONAL_ACCESS_TOKEN is not set")
settings.config.git_provider = "gitlab"
get_settings().config.git_provider = "gitlab"
app = FastAPI()
app.include_router(router)

View File

@ -8,8 +8,8 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import try_fix_json, update_settings_from_args
from pr_agent.config_loader import settings
from pr_agent.algo.utils import try_fix_json
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import BitbucketProvider, get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
@ -21,7 +21,6 @@ class PRCodeSuggestions:
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
update_settings_from_args(args)
self.ai_handler = AiHandler()
self.patches_diff = None
@ -33,24 +32,24 @@ class PRCodeSuggestions:
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"num_code_suggestions": settings.pr_code_suggestions.num_code_suggestions,
"extra_instructions": settings.pr_code_suggestions.extra_instructions,
"num_code_suggestions": get_settings().pr_code_suggestions.num_code_suggestions,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_code_suggestions_prompt.system,
settings.pr_code_suggestions_prompt.user)
get_settings().pr_code_suggestions_prompt.system,
get_settings().pr_code_suggestions_prompt.user)
async def suggest(self):
async def run(self):
assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
logging.info('Generating code suggestions for PR...')
if settings.config.publish_output:
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR review...')
data = self._prepare_pr_code_suggestions()
if settings.config.publish_output:
if get_settings().config.publish_output:
logging.info('Pushing PR review...')
self.git_provider.remove_initial_comment()
logging.info('Pushing inline code comments...')
@ -71,9 +70,9 @@ class PRCodeSuggestions:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_code_suggestions_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -86,7 +85,7 @@ class PRCodeSuggestions:
try:
data = json.loads(review)
except json.decoder.JSONDecodeError:
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not parse json response: {review}")
data = try_fix_json(review, code_suggestions=True)
return data
@ -95,7 +94,7 @@ class PRCodeSuggestions:
code_suggestions = []
for d in data['Code suggestions']:
try:
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"suggestion: {d}")
relevant_file = d['relevant file'].strip()
relevant_lines_str = d['relevant lines'].strip()
@ -113,8 +112,8 @@ class PRCodeSuggestions:
code_suggestions.append({'body': body, 'relevant_file': relevant_file,
'relevant_lines_start': relevant_lines_start,
'relevant_lines_end': relevant_lines_end})
except:
if settings.config.verbosity_level >= 2:
except Exception:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not parse suggestion: {d}")
self.git_provider.publish_code_suggestions(code_suggestions)
@ -136,7 +135,7 @@ class PRCodeSuggestions:
if delta_spaces > 0:
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
except Exception as e:
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
return new_code_snippet

View File

@ -1,15 +1,14 @@
import copy
import json
import logging
from typing import Tuple, List
from typing import List, Tuple
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
@ -17,13 +16,12 @@ from pr_agent.git_providers.git_provider import get_main_pr_language
class PRDescription:
def __init__(self, pr_url: str, args: list = None):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model.
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
Args:
pr_url (str): The URL of the pull request.
args (list, optional): List of arguments passed to the PRDescription class. Defaults to None.
"""
update_settings_from_args(args)
# Initialize the git provider and main PR language
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
@ -40,27 +38,27 @@ class PRDescription:
"description": self.git_provider.get_pr_description(),
"language": self.main_pr_language,
"diff": "", # empty diff for initial calculation
"extra_instructions": settings.pr_description.extra_instructions,
"extra_instructions": get_settings().pr_description.extra_instructions,
}
# Initialize the token handler
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
settings.pr_description_prompt.system,
settings.pr_description_prompt.user,
get_settings().pr_description_prompt.system,
get_settings().pr_description_prompt.user,
)
# Initialize patches_diff and prediction attributes
self.patches_diff = None
self.prediction = None
async def describe(self):
async def run(self):
"""
Generates a PR description using an AI model and publishes it to the PR.
"""
logging.info('Generating a PR description...')
if settings.config.publish_output:
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
@ -68,9 +66,9 @@ class PRDescription:
logging.info('Preparing answer...')
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
if settings.config.publish_output:
if get_settings().config.publish_output:
logging.info('Pushing answer...')
if settings.pr_description.publish_description_as_comment:
if get_settings().pr_description.publish_description_as_comment:
self.git_provider.publish_comment(markdown_text)
else:
self.git_provider.publish_description(pr_title, pr_body)
@ -116,10 +114,10 @@ class PRDescription:
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables)
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
@ -170,7 +168,7 @@ class PRDescription:
else:
pr_body += f"{value}\n\n___\n"
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"title:\n{title}\n{pr_body}")
return title, pr_body, pr_types, markdown_text

View File

@ -6,13 +6,11 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None):
self.git_provider = get_git_provider()(pr_url)
@ -29,19 +27,19 @@ class PRInformationFromUser:
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_information_from_user_prompt.system,
settings.pr_information_from_user_prompt.user)
get_settings().pr_information_from_user_prompt.system,
get_settings().pr_information_from_user_prompt.user)
self.patches_diff = None
self.prediction = None
async def generate_questions(self):
logging.info('Generating question to the user...')
if settings.config.publish_output:
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing questions...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing questions...')
pr_comment = self._prepare_pr_answer()
if settings.config.publish_output:
if get_settings().config.publish_output:
logging.info('Pushing questions...')
self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment()
@ -57,9 +55,9 @@ class PRInformationFromUser:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_information_from_user_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_information_from_user_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
system_prompt = environment.from_string(get_settings().pr_information_from_user_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_information_from_user_prompt.user).render(variables)
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -68,7 +66,7 @@ class PRInformationFromUser:
def _prepare_pr_answer(self) -> str:
model_output = self.prediction.strip()
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"answer_str:\n{model_output}")
answer_str = f"{model_output}\n\n Please respond to the questions above in the following format:\n\n" +\
"\n>/answer\n>1) ...\n>2) ...\n>...\n"

View File

@ -6,7 +6,7 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
@ -30,8 +30,8 @@ class PRQuestions:
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_questions_prompt.system,
settings.pr_questions_prompt.user)
get_settings().pr_questions_prompt.system,
get_settings().pr_questions_prompt.user)
self.patches_diff = None
self.prediction = None
@ -42,14 +42,14 @@ class PRQuestions:
question_str = ""
return question_str
async def answer(self):
async def run(self):
logging.info('Answering a PR question...')
if settings.config.publish_output:
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing answer...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing answer...')
pr_comment = self._prepare_pr_answer()
if settings.config.publish_output:
if get_settings().config.publish_output:
logging.info('Pushing answer...')
self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment()
@ -65,9 +65,9 @@ class PRQuestions:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_questions_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_questions_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables)
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -77,6 +77,6 @@ class PRQuestions:
def _prepare_pr_answer(self) -> str:
answer_str = f"Question: {self.question_str}\n\n"
answer_str += f"Answer:\n{self.prediction.strip()}\n\n"
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"answer_str:\n{answer_str}")
return answer_str

View File

@ -2,17 +2,17 @@ import copy
import json
import logging
from collections import OrderedDict
from typing import Tuple, List
from typing import List, Tuple
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json, update_settings_from_args
from pr_agent.config_loader import settings
from pr_agent.algo.utils import convert_to_markdown, try_fix_json
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language, IncrementalPR
from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language
from pr_agent.servers.help import actions_help_text, bot_help_text
@ -20,17 +20,15 @@ class PRReviewer:
"""
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
def __init__(self, pr_url: str, cli_mode: bool = False, is_answer: bool = False, args: list = None):
def __init__(self, pr_url: str, is_answer: bool = False, args: list = None):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
Args:
pr_url (str): The URL of the pull request to be reviewed.
cli_mode (bool, optional): Indicates whether the review is being done in command-line interface mode. Defaults to False.
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
"""
update_settings_from_args(args)
self.parse_args(args) # -i command
self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
@ -41,11 +39,10 @@ class PRReviewer:
self.is_answer = is_answer
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now")
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
self.ai_handler = AiHandler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
answer_str, question_str = self._get_user_answers()
self.vars = {
@ -54,21 +51,21 @@ class PRReviewer:
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"require_score": settings.pr_reviewer.require_score_review,
"require_tests": settings.pr_reviewer.require_tests_review,
"require_security": settings.pr_reviewer.require_security_review,
"require_focused": settings.pr_reviewer.require_focused_review,
'num_code_suggestions': settings.pr_reviewer.num_code_suggestions,
"require_score": get_settings().pr_reviewer.require_score_review,
"require_tests": get_settings().pr_reviewer.require_tests_review,
"require_security": get_settings().pr_reviewer.require_security_review,
"require_focused": get_settings().pr_reviewer.require_focused_review,
'num_code_suggestions': get_settings().pr_reviewer.num_code_suggestions,
'question_str': question_str,
'answer_str': answer_str,
"extra_instructions": settings.pr_reviewer.extra_instructions,
"extra_instructions": get_settings().pr_reviewer.extra_instructions,
}
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
settings.pr_review_prompt.system,
settings.pr_review_prompt.user
get_settings().pr_review_prompt.system,
get_settings().pr_review_prompt.user
)
def parse_args(self, args: List[str]) -> None:
@ -88,13 +85,13 @@ class PRReviewer:
is_incremental = True
self.incremental = IncrementalPR(is_incremental)
async def review(self) -> None:
async def run(self) -> None:
"""
Review the pull request and generate feedback.
"""
logging.info('Reviewing PR...')
if settings.config.publish_output:
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
@ -102,12 +99,12 @@ class PRReviewer:
logging.info('Preparing PR review...')
pr_comment = self._prepare_pr_review()
if settings.config.publish_output:
if get_settings().config.publish_output:
logging.info('Pushing PR review...')
self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment()
if settings.pr_reviewer.inline_code_comments:
if get_settings().pr_reviewer.inline_code_comments:
logging.info('Pushing inline code comments...')
self._publish_inline_code_comments()
@ -140,10 +137,10 @@ class PRReviewer:
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables)
system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
@ -158,7 +155,8 @@ class PRReviewer:
def _prepare_pr_review(self) -> str:
"""
Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes the feedback.
Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes
the feedback.
"""
review = self.prediction.strip()
@ -174,7 +172,8 @@ class PRReviewer:
data['PR Analysis']['Security concerns'] = val
# Filter out code suggestions that can be submitted as inline comments
if settings.config.git_provider != 'bitbucket' and settings.pr_reviewer.inline_code_comments and 'Code suggestions' in data['PR Feedback']:
if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \
and 'Code suggestions' in data['PR Feedback']:
data['PR Feedback']['Code suggestions'] = [
d for d in data['PR Feedback']['Code suggestions']
if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content'))
@ -184,7 +183,8 @@ class PRReviewer:
# Add incremental review section
if self.incremental.is_incremental:
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}"
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \
f"{self.git_provider.incremental.first_new_commit_sha}"
data = OrderedDict(data)
data.update({'Incremental PR Review': {
"⏮️ Review for commits since previous PR-Agent review": f"Starting from commit {last_commit_url}"}})
@ -194,7 +194,7 @@ class PRReviewer:
user = self.git_provider.get_user_id()
# Add help text if not in CLI mode
if not self.cli_mode:
if get_settings().get("CONFIG.CLI_MODE", False):
markdown_text += "\n### How to use\n"
if user and '[bot]' not in user:
markdown_text += bot_help_text(user)
@ -202,7 +202,7 @@ class PRReviewer:
markdown_text += actions_help_text
# Log markdown response if verbosity level is high
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}")
return markdown_text
@ -211,7 +211,7 @@ class PRReviewer:
"""
Publishes inline comments on a pull request with code suggestions generated by the AI model.
"""
if settings.pr_reviewer.num_code_suggestions == 0:
if get_settings().pr_reviewer.num_code_suggestions == 0:
return
review = self.prediction.strip()

View File

@ -9,9 +9,8 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings
from pr_agent.algo.utils import update_settings_from_args
from pr_agent.git_providers import get_git_provider, GithubProvider
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import GithubProvider, get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
CHANGELOG_LINES = 50
@ -24,8 +23,7 @@ class PRUpdateChangelog:
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
update_settings_from_args(args)
self.commit_changelog = settings.pr_update_changelog.push_changelog_changes
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
self._get_changlog_file() # self.changelog_file_str
self.ai_handler = AiHandler()
self.patches_diff = None
@ -39,23 +37,23 @@ class PRUpdateChangelog:
"diff": "", # empty diff for initial calculation
"changelog_file_str": self.changelog_file_str,
"today": date.today(),
"extra_instructions": settings.pr_update_changelog.extra_instructions,
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_update_changelog_prompt.system,
settings.pr_update_changelog_prompt.user)
get_settings().pr_update_changelog_prompt.system,
get_settings().pr_update_changelog_prompt.user)
async def update_changelog(self):
async def run(self):
assert type(self.git_provider) == GithubProvider, "Currently only Github is supported"
logging.info('Updating the changelog...')
if settings.config.publish_output:
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR changelog updates...')
new_file_content, answer = self._prepare_changelog_update()
if settings.config.publish_output:
if get_settings().config.publish_output:
self.git_provider.remove_initial_comment()
logging.info('Publishing changelog updates...')
if self.commit_changelog:
@ -75,9 +73,9 @@ class PRUpdateChangelog:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_update_changelog_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_update_changelog_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
system_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.user).render(variables)
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -86,7 +84,7 @@ class PRUpdateChangelog:
return response
def _prepare_changelog_update(self) -> Tuple[str, str]:
answer = self.prediction.strip().strip("```").strip()
answer = self.prediction.strip().strip("```").strip() # noqa B005
if hasattr(self, "changelog_file"):
existing_content = self.changelog_file.decoded_content.decode()
else:
@ -100,7 +98,7 @@ class PRUpdateChangelog:
answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \
"\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
if settings.config.verbosity_level >= 2:
if get_settings().config.verbosity_level >= 2:
logging.info(f"answer:\n{answer}")
return new_file_content, answer
@ -120,7 +118,7 @@ class PRUpdateChangelog:
last_commit_id = list(self.git_provider.pr.get_commits())[-1]
try:
self.git_provider.pr.create_review(commit=last_commit_id, comments=[d])
except:
except Exception:
# we can't create a review for some reason, let's just publish a comment
self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}")
@ -147,7 +145,7 @@ Example:
changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines()
changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES]
self.changelog_file_str = "\n".join(changelog_file_lines)
except:
except Exception:
self.changelog_file_str = ""
if self.commit_changelog:
logging.info("No CHANGELOG.md file found in the repository. Creating one...")

View File

@ -2,7 +2,7 @@
import logging
from pr_agent.algo.git_patch_processing import handle_patch_deletions
from pr_agent.config_loader import settings
from pr_agent.config_loader import get_settings
"""
Code Analysis
@ -49,7 +49,7 @@ class TestHandlePatchDeletions:
original_file_content_str = 'foo\nbar\n'
new_file_content_str = ''
file_name = 'file.py'
settings.config.verbosity_level = 1
get_settings().config.verbosity_level = 1
with caplog.at_level(logging.INFO):
handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file_name)

View File

@ -1,50 +0,0 @@
# Generated by CodiumAI
from pr_agent.algo.utils import update_settings_from_args
import logging
from pr_agent.config_loader import settings
import pytest
class TestUpdateSettingsFromArgs:
# Tests that the function updates the setting when passed a single valid argument.
def test_single_valid_argument(self):
args = ['--pr_code_suggestions.extra_instructions="be funny"']
update_settings_from_args(args)
assert settings.pr_code_suggestions.extra_instructions == '"be funny"'
# Tests that the function updates the settings when passed multiple valid arguments.
def test_multiple_valid_arguments(self):
args = ['--pr_code_suggestions.extra_instructions="be funny"', '--pr_code_suggestions.num_code_suggestions=3']
update_settings_from_args(args)
assert settings.pr_code_suggestions.extra_instructions == '"be funny"'
assert settings.pr_code_suggestions.num_code_suggestions == 3
# Tests that the function updates the setting when passed a boolean value.
def test_boolean_values(self):
settings.pr_code_suggestions.enabled = False
args = ['--pr_code_suggestions.enabled=true']
update_settings_from_args(args)
assert 'pr_code_suggestions' in settings
assert 'enabled' in settings.pr_code_suggestions
assert settings.pr_code_suggestions.enabled == True
# Tests that the function updates the setting when passed an integer value.
def test_integer_values(self):
args = ['--pr_code_suggestions.num_code_suggestions=3']
update_settings_from_args(args)
assert settings.pr_code_suggestions.num_code_suggestions == 3
# Tests that the function does not update any settings when passed an empty argument list.
def test_empty_argument_list(self):
args = []
update_settings_from_args(args)
assert settings == settings
# Tests that the function logs an error when passed an invalid argument format.
def test_invalid_argument_format(self, caplog):
args = ['--pr_code_suggestions.extra_instructions="be funny"', '--pr_code_suggestions.num_code_suggestions']
with caplog.at_level(logging.ERROR):
update_settings_from_args(args)
assert 'Invalid argument format' in caplog.text