mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-04 04:40:38 +08:00
Support context aware settings (for each incoming request), support override of settings, refactor CLI to use pr_agent.py
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
import re
|
import shlex
|
||||||
|
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.algo.utils import update_settings_from_args
|
||||||
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
||||||
from pr_agent.tools.pr_description import PRDescription
|
from pr_agent.tools.pr_description import PRDescription
|
||||||
from pr_agent.tools.pr_information_from_user import PRInformationFromUser
|
from pr_agent.tools.pr_information_from_user import PRInformationFromUser
|
||||||
@ -8,29 +9,39 @@ from pr_agent.tools.pr_questions import PRQuestions
|
|||||||
from pr_agent.tools.pr_reviewer import PRReviewer
|
from pr_agent.tools.pr_reviewer import PRReviewer
|
||||||
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
|
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
|
||||||
|
|
||||||
|
command2class = {
|
||||||
|
"answer": PRReviewer,
|
||||||
|
"review": PRReviewer,
|
||||||
|
"review_pr": PRReviewer,
|
||||||
|
"reflect": PRInformationFromUser,
|
||||||
|
"reflect_and_review": PRInformationFromUser,
|
||||||
|
"describe": PRDescription,
|
||||||
|
"describe_pr": PRDescription,
|
||||||
|
"improve": PRCodeSuggestions,
|
||||||
|
"improve_code": PRCodeSuggestions,
|
||||||
|
"ask": PRQuestions,
|
||||||
|
"ask_question": PRQuestions,
|
||||||
|
"update_changelog": PRUpdateChangelog,
|
||||||
|
}
|
||||||
|
|
||||||
|
commands = list(command2class.keys())
|
||||||
|
|
||||||
class PRAgent:
|
class PRAgent:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def handle_request(self, pr_url, request) -> bool:
|
async def handle_request(self, pr_url, request) -> bool:
|
||||||
action, *args = request.strip().split()
|
lexer = shlex.shlex(request, posix=True)
|
||||||
if any(cmd == action for cmd in ["/answer"]):
|
lexer.whitespace_split = True
|
||||||
await PRReviewer(pr_url, is_answer=True, args=args).review()
|
action, *args = list(lexer)
|
||||||
elif any(cmd == action for cmd in ["/review", "/review_pr", "/reflect_and_review"]):
|
args = update_settings_from_args(args)
|
||||||
if settings.pr_reviewer.ask_and_reflect or "/reflect_and_review" in request:
|
action = action.lstrip("/").lower()
|
||||||
await PRInformationFromUser(pr_url, args=args).generate_questions()
|
if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect:
|
||||||
else:
|
action = "review"
|
||||||
await PRReviewer(pr_url, args=args).review()
|
if action == "answer":
|
||||||
elif any(cmd == action for cmd in ["/describe", "/describe_pr"]):
|
await PRReviewer(pr_url, is_answer=True, args=args).run()
|
||||||
await PRDescription(pr_url, args=args).describe()
|
elif action in command2class:
|
||||||
elif any(cmd == action for cmd in ["/improve", "/improve_code"]):
|
await command2class[action](pr_url, args=args).run()
|
||||||
await PRCodeSuggestions(pr_url, args=args).suggest()
|
|
||||||
elif any(cmd == action for cmd in ["/ask", "/ask_question"]):
|
|
||||||
await PRQuestions(pr_url, args=args).answer()
|
|
||||||
elif any(cmd == action for cmd in ["/update_changelog"]):
|
|
||||||
await PRUpdateChangelog(pr_url, args=args).update_changelog()
|
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from openai.error import APIError, Timeout, TryAgain, RateLimitError
|
from openai.error import APIError, RateLimitError, Timeout, TryAgain
|
||||||
from retry import retry
|
from retry import retry
|
||||||
|
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
|
|
||||||
OPENAI_RETRIES=5
|
OPENAI_RETRIES=5
|
||||||
|
|
||||||
@ -21,16 +21,16 @@ class AiHandler:
|
|||||||
Raises a ValueError if the OpenAI key is missing.
|
Raises a ValueError if the OpenAI key is missing.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
openai.api_key = settings.openai.key
|
openai.api_key = get_settings().openai.key
|
||||||
if settings.get("OPENAI.ORG", None):
|
if get_settings().get("OPENAI.ORG", None):
|
||||||
openai.organization = settings.openai.org
|
openai.organization = get_settings().openai.org
|
||||||
self.deployment_id = settings.get("OPENAI.DEPLOYMENT_ID", None)
|
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
|
||||||
if settings.get("OPENAI.API_TYPE", None):
|
if get_settings().get("OPENAI.API_TYPE", None):
|
||||||
openai.api_type = settings.openai.api_type
|
openai.api_type = get_settings().openai.api_type
|
||||||
if settings.get("OPENAI.API_VERSION", None):
|
if get_settings().get("OPENAI.API_VERSION", None):
|
||||||
openai.api_version = settings.openai.api_version
|
openai.api_version = get_settings().openai.api_version
|
||||||
if settings.get("OPENAI.API_BASE", None):
|
if get_settings().get("OPENAI.API_BASE", None):
|
||||||
openai.api_base = settings.openai.api_base
|
openai.api_base = get_settings().openai.api_base
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise ValueError("OpenAI key is required") from e
|
raise ValueError("OpenAI key is required") from e
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
|
|
||||||
|
|
||||||
def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
||||||
@ -55,7 +55,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
|||||||
continue
|
continue
|
||||||
extended_patch_lines.append(line)
|
extended_patch_lines.append(line)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.error(f"Failed to extend patch: {e}")
|
logging.error(f"Failed to extend patch: {e}")
|
||||||
return patch_str
|
return patch_str
|
||||||
|
|
||||||
@ -126,14 +126,14 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
|
|||||||
"""
|
"""
|
||||||
if not new_file_content_str:
|
if not new_file_content_str:
|
||||||
# logic for handling deleted files - don't show patch, just show that the file was deleted
|
# logic for handling deleted files - don't show patch, just show that the file was deleted
|
||||||
if settings.config.verbosity_level > 0:
|
if get_settings().config.verbosity_level > 0:
|
||||||
logging.info(f"Processing file: {file_name}, minimizing deletion file")
|
logging.info(f"Processing file: {file_name}, minimizing deletion file")
|
||||||
patch = None # file was deleted
|
patch = None # file was deleted
|
||||||
else:
|
else:
|
||||||
patch_lines = patch.splitlines()
|
patch_lines = patch.splitlines()
|
||||||
patch_new = omit_deletion_hunks(patch_lines)
|
patch_new = omit_deletion_hunks(patch_lines)
|
||||||
if patch != patch_new:
|
if patch != patch_new:
|
||||||
if settings.config.verbosity_level > 0:
|
if get_settings().config.verbosity_level > 0:
|
||||||
logging.info(f"Processing file: {file_name}, hunks were deleted")
|
logging.info(f"Processing file: {file_name}, hunks were deleted")
|
||||||
patch = patch_new
|
patch = patch_new
|
||||||
return patch
|
return patch
|
||||||
@ -141,7 +141,8 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
|
|||||||
|
|
||||||
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
|
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
|
||||||
"""
|
"""
|
||||||
Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of the file.
|
Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
|
||||||
|
the file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
patch (str): The patch string to be converted.
|
patch (str): The patch string to be converted.
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
# Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501
|
# Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
|
|
||||||
language_extension_map_org = settings.language_extension_map_org
|
language_extension_map_org = get_settings().language_extension_map_org
|
||||||
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
|
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
|
||||||
|
|
||||||
# Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
|
# Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
|
||||||
bad_extensions = settings.bad_extensions.default
|
bad_extensions = get_settings().bad_extensions.default
|
||||||
if settings.config.use_extra_bad_extensions:
|
if get_settings().config.use_extra_bad_extensions:
|
||||||
bad_extensions += settings.bad_extensions.extra
|
bad_extensions += get_settings().bad_extensions.extra
|
||||||
|
|
||||||
|
|
||||||
def filter_bad_extensions(files):
|
def filter_bad_extensions(files):
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Tuple, Union, Callable, List
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
from github import RateLimitExceededException
|
from github import RateLimitExceededException
|
||||||
|
|
||||||
@ -10,7 +10,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
|
|||||||
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import load_large_diff
|
from pr_agent.algo.utils import load_large_diff
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.git_providers.git_provider import GitProvider
|
from pr_agent.git_providers.git_provider import GitProvider
|
||||||
|
|
||||||
DELETED_FILES_ = "Deleted files:\n"
|
DELETED_FILES_ = "Deleted files:\n"
|
||||||
@ -27,11 +27,15 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
|
|||||||
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
|
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request.
|
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull
|
||||||
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
request.
|
||||||
|
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
|
||||||
|
pull request.
|
||||||
model (str): The name of the model used for tokenization.
|
model (str): The name of the model used for tokenization.
|
||||||
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False.
|
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the
|
||||||
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False.
|
diff. Defaults to False.
|
||||||
|
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with
|
||||||
|
extra lines of context. Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: A string with the diff of the pull request, applying diff minimization techniques if needed.
|
str: A string with the diff of the pull request, applying diff minimization techniques if needed.
|
||||||
@ -76,10 +80,12 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
|
|||||||
add_line_numbers_to_hunks: bool) -> \
|
add_line_numbers_to_hunks: bool) -> \
|
||||||
Tuple[list, int]:
|
Tuple[list, int]:
|
||||||
"""
|
"""
|
||||||
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff minimization techniques if needed.
|
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
|
||||||
|
minimization techniques if needed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding files.
|
- pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding
|
||||||
|
files.
|
||||||
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
||||||
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
|
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
|
||||||
|
|
||||||
@ -119,10 +125,13 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
|
|||||||
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
|
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
|
||||||
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
|
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
|
||||||
"""
|
"""
|
||||||
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used.
|
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of
|
||||||
|
tokens used.
|
||||||
Args:
|
Args:
|
||||||
top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files.
|
top_langs (list): A list of dictionaries representing the languages used in the pull request and their
|
||||||
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
corresponding files.
|
||||||
|
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
|
||||||
|
pull request.
|
||||||
model (str): The model used for tokenization.
|
model (str): The model used for tokenization.
|
||||||
convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
|
convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
|
||||||
Returns:
|
Returns:
|
||||||
@ -181,7 +190,7 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
|
|||||||
# Current logic is to skip the patch if it's too large
|
# Current logic is to skip the patch if it's too large
|
||||||
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
|
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
|
||||||
# until we meet the requirements
|
# until we meet the requirements
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.warning(f"Patch too large, minimizing it, {file.filename}")
|
logging.warning(f"Patch too large, minimizing it, {file.filename}")
|
||||||
if not modified_files_list:
|
if not modified_files_list:
|
||||||
total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_)
|
total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_)
|
||||||
@ -196,15 +205,15 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
|
|||||||
patch_final = patch
|
patch_final = patch
|
||||||
patches.append(patch_final)
|
patches.append(patch_final)
|
||||||
total_tokens += token_handler.count_tokens(patch_final)
|
total_tokens += token_handler.count_tokens(patch_final)
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")
|
logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")
|
||||||
|
|
||||||
return patches, modified_files_list, deleted_files_list
|
return patches, modified_files_list, deleted_files_list
|
||||||
|
|
||||||
|
|
||||||
async def retry_with_fallback_models(f: Callable):
|
async def retry_with_fallback_models(f: Callable):
|
||||||
model = settings.config.model
|
model = get_settings().config.model
|
||||||
fallback_models = settings.config.fallback_models
|
fallback_models = get_settings().config.fallback_models
|
||||||
if not isinstance(fallback_models, list):
|
if not isinstance(fallback_models, list):
|
||||||
fallback_models = [fallback_models]
|
fallback_models = [fallback_models]
|
||||||
all_models = [model] + fallback_models
|
all_models = [model] + fallback_models
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
from tiktoken import encoding_for_model
|
from tiktoken import encoding_for_model
|
||||||
|
|
||||||
from pr_agent.algo import MAX_TOKENS
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.config_loader import settings
|
|
||||||
|
|
||||||
|
|
||||||
class TokenHandler:
|
class TokenHandler:
|
||||||
@ -10,9 +9,12 @@ class TokenHandler:
|
|||||||
A class for handling tokens in the context of a pull request.
|
A class for handling tokens in the context of a pull request.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
- encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the number of tokens in them.
|
- encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the
|
||||||
- limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the pr_agent.algo module.
|
number of tokens in them.
|
||||||
- prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens method.
|
- limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the
|
||||||
|
pr_agent.algo module.
|
||||||
|
- prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens
|
||||||
|
method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pr, vars: dict, system, user):
|
def __init__(self, pr, vars: dict, system, user):
|
||||||
@ -25,7 +27,7 @@ class TokenHandler:
|
|||||||
- system: The system string.
|
- system: The system string.
|
||||||
- user: The user string.
|
- user: The user string.
|
||||||
"""
|
"""
|
||||||
self.encoder = encoding_for_model(settings.config.model)
|
self.encoder = encoding_for_model(get_settings().config.model)
|
||||||
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
||||||
|
|
||||||
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
||||||
|
@ -1,15 +1,24 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import difflib
|
import difflib
|
||||||
from datetime import datetime
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
from pr_agent.config_loader import settings
|
from starlette_context import context
|
||||||
|
|
||||||
|
from pr_agent.config_loader import get_settings, global_settings
|
||||||
|
|
||||||
|
|
||||||
|
def get_setting(key: str) -> Any:
|
||||||
|
try:
|
||||||
|
key = key.upper()
|
||||||
|
return context.get("settings", global_settings).get(key, global_settings.get(key, None))
|
||||||
|
except Exception:
|
||||||
|
return global_settings.get(key, None)
|
||||||
|
|
||||||
def convert_to_markdown(output_data: dict) -> str:
|
def convert_to_markdown(output_data: dict) -> str:
|
||||||
"""
|
"""
|
||||||
@ -97,12 +106,16 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
|
|||||||
- data: A dictionary containing the parsed JSON data.
|
- data: A dictionary containing the parsed JSON data.
|
||||||
|
|
||||||
The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion.
|
The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion.
|
||||||
If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the message.
|
If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the
|
||||||
If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON message by parsing until the last valid code suggestion.
|
message.
|
||||||
The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or newlines.
|
If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON
|
||||||
|
message by parsing until the last valid code suggestion.
|
||||||
|
The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or
|
||||||
|
newlines.
|
||||||
It tries to parse the JSON message with the closing bracket and checks if it is valid.
|
It tries to parse the JSON message with the closing bracket and checks if it is valid.
|
||||||
If the JSON message is valid, the parsed JSON data is returned.
|
If the JSON message is valid, the parsed JSON data is returned.
|
||||||
If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON message is obtained or the maximum number of iterations is reached.
|
If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON
|
||||||
|
message is obtained or the maximum number of iterations is reached.
|
||||||
If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned.
|
If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -184,7 +197,8 @@ def convert_str_to_datetime(date_str):
|
|||||||
|
|
||||||
def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
|
def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a patch for a modified file by comparing the original content of the file with the new content provided as input.
|
Generate a patch for a modified file by comparing the original content of the file with the new content provided as
|
||||||
|
input.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file: The file object for which the patch needs to be generated.
|
file: The file object for which the patch needs to be generated.
|
||||||
@ -199,14 +213,16 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str:
|
|||||||
None.
|
None.
|
||||||
|
|
||||||
Additional Information:
|
Additional Information:
|
||||||
- If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it as output.
|
- If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it
|
||||||
- If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating that the file was modified but no patch was found, and a patch is manually created.
|
as output.
|
||||||
|
- If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating
|
||||||
|
that the file was modified but no patch was found, and a patch is manually created.
|
||||||
"""
|
"""
|
||||||
if not patch: # to Do - also add condition for file extension
|
if not patch: # to Do - also add condition for file extension
|
||||||
try:
|
try:
|
||||||
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
|
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
|
||||||
new_file_content_str.splitlines(keepends=True))
|
new_file_content_str.splitlines(keepends=True))
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.")
|
logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.")
|
||||||
patch = ''.join(diff)
|
patch = ''.join(diff)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -214,7 +230,7 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str:
|
|||||||
return patch
|
return patch
|
||||||
|
|
||||||
|
|
||||||
def update_settings_from_args(args: List[str]) -> None:
|
def update_settings_from_args(args: List[str]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Update the settings of the Dynaconf object based on the arguments passed to the function.
|
Update the settings of the Dynaconf object based on the arguments passed to the function.
|
||||||
|
|
||||||
@ -230,28 +246,22 @@ def update_settings_from_args(args: List[str]) -> None:
|
|||||||
ValueError: If the argument is not in the correct format.
|
ValueError: If the argument is not in the correct format.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
other_args = []
|
||||||
if args:
|
if args:
|
||||||
for arg in args:
|
for arg in args:
|
||||||
try:
|
arg = arg.strip()
|
||||||
|
if arg.startswith('--'):
|
||||||
arg = arg.strip('-').strip()
|
arg = arg.strip('-').strip()
|
||||||
vals = arg.split('=')
|
vals = arg.split('=')
|
||||||
if len(vals) != 2:
|
if len(vals) != 2:
|
||||||
raise ValueError(f'Invalid argument format: {arg}')
|
logging.error(f'Invalid argument format: {arg}')
|
||||||
|
other_args.append(arg)
|
||||||
|
continue
|
||||||
key, value = vals
|
key, value = vals
|
||||||
keys = key.split('.')
|
key = key.strip().upper()
|
||||||
d = settings
|
value = value.strip()
|
||||||
for i, k in enumerate(keys[:-1]):
|
get_settings().set(key, value)
|
||||||
if k not in d:
|
|
||||||
raise ValueError(f'Invalid setting: {key}')
|
|
||||||
d = d[k]
|
|
||||||
if keys[-1] not in d:
|
|
||||||
raise ValueError(f'Invalid setting: {key}')
|
|
||||||
if isinstance(d[keys[-1]], bool):
|
|
||||||
d[keys[-1]] = value.lower() in ("yes", "true", "t", "1")
|
|
||||||
else:
|
|
||||||
d[keys[-1]] = type(d[keys[-1]])(value)
|
|
||||||
logging.info(f'Updated setting {key} to: "{value}"')
|
logging.info(f'Updated setting {key} to: "{value}"')
|
||||||
except ValueError as e:
|
else:
|
||||||
logging.error(str(e))
|
other_args.append(arg)
|
||||||
except Exception as e:
|
return other_args
|
||||||
logging.error(f'Failed to parse argument {arg}: {e}')
|
|
||||||
|
@ -3,15 +3,11 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
from pr_agent.agent.pr_agent import PRAgent, commands
|
||||||
from pr_agent.tools.pr_description import PRDescription
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.tools.pr_information_from_user import PRInformationFromUser
|
|
||||||
from pr_agent.tools.pr_questions import PRQuestions
|
|
||||||
from pr_agent.tools.pr_reviewer import PRReviewer
|
|
||||||
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
|
|
||||||
|
|
||||||
|
|
||||||
def run(args=None):
|
def run(inargs=None):
|
||||||
parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
|
parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
|
||||||
"""\
|
"""\
|
||||||
Usage: cli.py --pr-url <URL on supported git hosting service> <command> [<args>].
|
Usage: cli.py --pr-url <URL on supported git hosting service> <command> [<args>].
|
||||||
@ -34,79 +30,16 @@ To edit any configuration parameter from 'configuration.toml', just add -config_
|
|||||||
For example: '- cli.py --pr-url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
|
For example: '- cli.py --pr-url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
|
||||||
""")
|
""")
|
||||||
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
|
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
|
||||||
parser.add_argument('command', type=str, help='The', choices=['review', 'review_pr',
|
parser.add_argument('command', type=str, help='The', choices=commands, default='review')
|
||||||
'ask', 'ask_question',
|
|
||||||
'describe', 'describe_pr',
|
|
||||||
'improve', 'improve_code',
|
|
||||||
'reflect', 'review_after_reflect',
|
|
||||||
'update_changelog'],
|
|
||||||
default='review')
|
|
||||||
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
|
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
|
||||||
args = parser.parse_args(args)
|
args = parser.parse_args(inargs)
|
||||||
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
|
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
|
||||||
command = args.command.lower()
|
command = args.command.lower()
|
||||||
commands = {
|
get_settings().set("CONFIG.CLI_MODE", True)
|
||||||
'ask': _handle_ask_command,
|
result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest)))
|
||||||
'ask_question': _handle_ask_command,
|
if not result:
|
||||||
'describe': _handle_describe_command,
|
|
||||||
'describe_pr': _handle_describe_command,
|
|
||||||
'improve': _handle_improve_command,
|
|
||||||
'improve_code': _handle_improve_command,
|
|
||||||
'review': _handle_review_command,
|
|
||||||
'review_pr': _handle_review_command,
|
|
||||||
'reflect': _handle_reflect_command,
|
|
||||||
'review_after_reflect': _handle_review_after_reflect_command,
|
|
||||||
'update_changelog': _handle_update_changelog,
|
|
||||||
}
|
|
||||||
if command in commands:
|
|
||||||
commands[command](args.pr_url, args.rest)
|
|
||||||
else:
|
|
||||||
print(f"Unknown command: {command}")
|
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
def _handle_ask_command(pr_url: str, rest: list):
|
|
||||||
if len(rest) == 0:
|
|
||||||
print("Please specify a question")
|
|
||||||
return
|
|
||||||
print(f"Question: {' '.join(rest)} about PR {pr_url}")
|
|
||||||
reviewer = PRQuestions(pr_url, rest)
|
|
||||||
asyncio.run(reviewer.answer())
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_describe_command(pr_url: str, rest: list):
|
|
||||||
print(f"PR description: {pr_url}")
|
|
||||||
reviewer = PRDescription(pr_url, args=rest)
|
|
||||||
asyncio.run(reviewer.describe())
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_improve_command(pr_url: str, rest: list):
|
|
||||||
print(f"PR code suggestions: {pr_url}")
|
|
||||||
reviewer = PRCodeSuggestions(pr_url, args=rest)
|
|
||||||
asyncio.run(reviewer.suggest())
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_review_command(pr_url: str, rest: list):
|
|
||||||
print(f"Reviewing PR: {pr_url}")
|
|
||||||
reviewer = PRReviewer(pr_url, cli_mode=True, args=rest)
|
|
||||||
asyncio.run(reviewer.review())
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_reflect_command(pr_url: str, rest: list):
|
|
||||||
print(f"Asking the PR author questions: {pr_url}")
|
|
||||||
reviewer = PRInformationFromUser(pr_url)
|
|
||||||
asyncio.run(reviewer.generate_questions())
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_review_after_reflect_command(pr_url: str, rest: list):
|
|
||||||
print(f"Processing author's answers and sending review: {pr_url}")
|
|
||||||
reviewer = PRReviewer(pr_url, cli_mode=True, is_answer=True, args=rest)
|
|
||||||
asyncio.run(reviewer.review())
|
|
||||||
|
|
||||||
def _handle_update_changelog(pr_url: str, rest: list):
|
|
||||||
print(f"Updating changlog for: {pr_url}")
|
|
||||||
reviewer = PRUpdateChangelog(pr_url, cli_mode=True, args=rest)
|
|
||||||
asyncio.run(reviewer.update_changelog())
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run()
|
run()
|
||||||
|
@ -3,11 +3,12 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from dynaconf import Dynaconf
|
from dynaconf import Dynaconf
|
||||||
|
from starlette_context import context
|
||||||
|
|
||||||
PR_AGENT_TOML_KEY = 'pr-agent'
|
PR_AGENT_TOML_KEY = 'pr-agent'
|
||||||
|
|
||||||
current_dir = dirname(abspath(__file__))
|
current_dir = dirname(abspath(__file__))
|
||||||
settings = Dynaconf(
|
global_settings = Dynaconf(
|
||||||
envvar_prefix=False,
|
envvar_prefix=False,
|
||||||
merge_enabled=True,
|
merge_enabled=True,
|
||||||
settings_files=[join(current_dir, f) for f in [
|
settings_files=[join(current_dir, f) for f in [
|
||||||
@ -25,6 +26,13 @@ settings = Dynaconf(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings():
|
||||||
|
try:
|
||||||
|
return context["settings"]
|
||||||
|
except Exception:
|
||||||
|
return global_settings
|
||||||
|
|
||||||
|
|
||||||
# Add local configuration from pyproject.toml of the project being reviewed
|
# Add local configuration from pyproject.toml of the project being reviewed
|
||||||
def _find_repository_root() -> Path:
|
def _find_repository_root() -> Path:
|
||||||
"""
|
"""
|
||||||
@ -39,6 +47,7 @@ def _find_repository_root() -> Path:
|
|||||||
cwd = cwd.parent
|
cwd = cwd.parent
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _find_pyproject() -> Optional[Path]:
|
def _find_pyproject() -> Optional[Path]:
|
||||||
"""
|
"""
|
||||||
Search for file pyproject.toml in the repository root.
|
Search for file pyproject.toml in the repository root.
|
||||||
@ -49,6 +58,7 @@ def _find_pyproject() -> Optional[Path]:
|
|||||||
return pyproject if pyproject.is_file() else None
|
return pyproject if pyproject.is_file() else None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
pyproject_path = _find_pyproject()
|
pyproject_path = _find_pyproject()
|
||||||
if pyproject_path is not None:
|
if pyproject_path is not None:
|
||||||
settings.load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}')
|
get_settings().load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}')
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.git_providers.bitbucket_provider import BitbucketProvider
|
from pr_agent.git_providers.bitbucket_provider import BitbucketProvider
|
||||||
from pr_agent.git_providers.github_provider import GithubProvider
|
from pr_agent.git_providers.github_provider import GithubProvider
|
||||||
from pr_agent.git_providers.gitlab_provider import GitLabProvider
|
from pr_agent.git_providers.gitlab_provider import GitLabProvider
|
||||||
@ -13,7 +13,7 @@ _GIT_PROVIDERS = {
|
|||||||
|
|
||||||
def get_git_provider():
|
def get_git_provider():
|
||||||
try:
|
try:
|
||||||
provider_id = settings.config.git_provider
|
provider_id = get_settings().config.git_provider
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise ValueError("git_provider is a required attribute in the configuration file") from e
|
raise ValueError("git_provider is a required attribute in the configuration file") from e
|
||||||
if provider_id not in _GIT_PROVIDERS:
|
if provider_id not in _GIT_PROVIDERS:
|
||||||
|
@ -5,15 +5,14 @@ from urllib.parse import urlparse
|
|||||||
import requests
|
import requests
|
||||||
from atlassian.bitbucket import Cloud
|
from atlassian.bitbucket import Cloud
|
||||||
|
|
||||||
from pr_agent.config_loader import settings
|
from ..config_loader import get_settings
|
||||||
|
|
||||||
from .git_provider import FilePatchInfo
|
from .git_provider import FilePatchInfo
|
||||||
|
|
||||||
|
|
||||||
class BitbucketProvider:
|
class BitbucketProvider:
|
||||||
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
|
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
|
||||||
s = requests.Session()
|
s = requests.Session()
|
||||||
s.headers['Authorization'] = f'Bearer {settings.get("BITBUCKET.BEARER_TOKEN", None)}'
|
s.headers['Authorization'] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
|
||||||
self.bitbucket_client = Cloud(session=s)
|
self.bitbucket_client = Cloud(session=s)
|
||||||
|
|
||||||
self.workspace_slug = None
|
self.workspace_slug = None
|
||||||
|
@ -7,12 +7,11 @@ from github import AppAuthentication, Auth, Github, GithubException
|
|||||||
from retry import retry
|
from retry import retry
|
||||||
from starlette_context import context
|
from starlette_context import context
|
||||||
|
|
||||||
from pr_agent.config_loader import settings
|
|
||||||
|
|
||||||
from ..algo.language_handler import is_valid_file
|
from ..algo.language_handler import is_valid_file
|
||||||
from ..algo.utils import load_large_diff
|
from ..algo.utils import load_large_diff
|
||||||
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
|
from ..config_loader import get_settings
|
||||||
from ..servers.utils import RateLimitExceeded
|
from ..servers.utils import RateLimitExceeded
|
||||||
|
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
|
||||||
|
|
||||||
|
|
||||||
class GithubProvider(GitProvider):
|
class GithubProvider(GitProvider):
|
||||||
@ -85,7 +84,7 @@ class GithubProvider(GitProvider):
|
|||||||
return self.pr.get_files()
|
return self.pr.get_files()
|
||||||
|
|
||||||
@retry(exceptions=RateLimitExceeded,
|
@retry(exceptions=RateLimitExceeded,
|
||||||
tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
|
tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
|
||||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||||
try:
|
try:
|
||||||
files = self.get_files()
|
files = self.get_files()
|
||||||
@ -118,7 +117,7 @@ class GithubProvider(GitProvider):
|
|||||||
# self.pr.create_issue_comment(pr_comment)
|
# self.pr.create_issue_comment(pr_comment)
|
||||||
|
|
||||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||||
if is_temporary and not settings.config.publish_output_progress:
|
if is_temporary and not get_settings().config.publish_output_progress:
|
||||||
logging.debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
|
logging.debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
|
||||||
return
|
return
|
||||||
response = self.pr.create_issue_comment(pr_comment)
|
response = self.pr.create_issue_comment(pr_comment)
|
||||||
@ -149,7 +148,7 @@ class GithubProvider(GitProvider):
|
|||||||
position = i
|
position = i
|
||||||
break
|
break
|
||||||
if position == -1:
|
if position == -1:
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
||||||
subject_type = "FILE"
|
subject_type = "FILE"
|
||||||
else:
|
else:
|
||||||
@ -174,13 +173,13 @@ class GithubProvider(GitProvider):
|
|||||||
relevant_lines_end = suggestion['relevant_lines_end']
|
relevant_lines_end = suggestion['relevant_lines_end']
|
||||||
|
|
||||||
if not relevant_lines_start or relevant_lines_start == -1:
|
if not relevant_lines_start or relevant_lines_start == -1:
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.exception(
|
logging.exception(
|
||||||
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
|
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if relevant_lines_end < relevant_lines_start:
|
if relevant_lines_end < relevant_lines_start:
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.exception(f"Failed to publish code suggestion, "
|
logging.exception(f"Failed to publish code suggestion, "
|
||||||
f"relevant_lines_end is {relevant_lines_end} and "
|
f"relevant_lines_end is {relevant_lines_end} and "
|
||||||
f"relevant_lines_start is {relevant_lines_start}")
|
f"relevant_lines_start is {relevant_lines_start}")
|
||||||
@ -207,7 +206,7 @@ class GithubProvider(GitProvider):
|
|||||||
self.pr.create_review(commit=self.last_commit_id, comments=post_parameters_list)
|
self.pr.create_review(commit=self.last_commit_id, comments=post_parameters_list)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.error(f"Failed to publish code suggestion, error: {e}")
|
logging.error(f"Failed to publish code suggestion, error: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -241,7 +240,7 @@ class GithubProvider(GitProvider):
|
|||||||
return self.github_user_id
|
return self.github_user_id
|
||||||
|
|
||||||
def get_notifications(self, since: datetime):
|
def get_notifications(self, since: datetime):
|
||||||
deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user")
|
deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
|
||||||
|
|
||||||
if deployment_type != 'user':
|
if deployment_type != 'user':
|
||||||
raise ValueError("Deployment mode must be set to 'user' to get notifications")
|
raise ValueError("Deployment mode must be set to 'user' to get notifications")
|
||||||
@ -282,12 +281,12 @@ class GithubProvider(GitProvider):
|
|||||||
return repo_name, pr_number
|
return repo_name, pr_number
|
||||||
|
|
||||||
def _get_github_client(self):
|
def _get_github_client(self):
|
||||||
deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user")
|
deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
|
||||||
|
|
||||||
if deployment_type == 'app':
|
if deployment_type == 'app':
|
||||||
try:
|
try:
|
||||||
private_key = settings.github.private_key
|
private_key = get_settings().github.private_key
|
||||||
app_id = settings.github.app_id
|
app_id = get_settings().github.app_id
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e
|
raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e
|
||||||
if not self.installation_id:
|
if not self.installation_id:
|
||||||
@ -298,7 +297,7 @@ class GithubProvider(GitProvider):
|
|||||||
|
|
||||||
if deployment_type == 'user':
|
if deployment_type == 'user':
|
||||||
try:
|
try:
|
||||||
token = settings.github.user_token
|
token = get_settings().github.user_token
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"GitHub token is required when using user deployment. See: "
|
"GitHub token is required when using user deployment. See: "
|
||||||
@ -327,7 +326,9 @@ class GithubProvider(GitProvider):
|
|||||||
|
|
||||||
def publish_labels(self, pr_types):
|
def publish_labels(self, pr_types):
|
||||||
try:
|
try:
|
||||||
label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5", "Refactoring": "bfdadc", "Enhancement": "bfd4f2", "Documentation": "d4c5f9", "Other": "d1bcf9"}
|
label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5",
|
||||||
|
"Refactoring": "bfdadc", "Enhancement": "bfd4f2", "Documentation": "d4c5f9",
|
||||||
|
"Other": "d1bcf9"}
|
||||||
post_parameters = []
|
post_parameters = []
|
||||||
for p in pr_types:
|
for p in pr_types:
|
||||||
color = label_color_map.get(p, "d1bcf9") # default to "Other" color
|
color = label_color_map.get(p, "d1bcf9") # default to "Other" color
|
||||||
|
@ -6,9 +6,8 @@ from urllib.parse import urlparse
|
|||||||
import gitlab
|
import gitlab
|
||||||
from gitlab import GitlabGetError
|
from gitlab import GitlabGetError
|
||||||
|
|
||||||
from pr_agent.config_loader import settings
|
|
||||||
|
|
||||||
from ..algo.language_handler import is_valid_file
|
from ..algo.language_handler import is_valid_file
|
||||||
|
from ..config_loader import get_settings
|
||||||
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@ -17,10 +16,10 @@ logger = logging.getLogger()
|
|||||||
class GitLabProvider(GitProvider):
|
class GitLabProvider(GitProvider):
|
||||||
|
|
||||||
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
|
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
|
||||||
gitlab_url = settings.get("GITLAB.URL", None)
|
gitlab_url = get_settings().get("GITLAB.URL", None)
|
||||||
if not gitlab_url:
|
if not gitlab_url:
|
||||||
raise ValueError("GitLab URL is not set in the config file")
|
raise ValueError("GitLab URL is not set in the config file")
|
||||||
gitlab_access_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
|
gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
|
||||||
if not gitlab_access_token:
|
if not gitlab_access_token:
|
||||||
raise ValueError("GitLab personal access token is not set in the config file")
|
raise ValueError("GitLab personal access token is not set in the config file")
|
||||||
self.gl = gitlab.Gitlab(
|
self.gl = gitlab.Gitlab(
|
||||||
|
@ -5,7 +5,7 @@ from typing import List
|
|||||||
|
|
||||||
from git import Repo
|
from git import Repo
|
||||||
|
|
||||||
from pr_agent.config_loader import _find_repository_root, settings
|
from pr_agent.config_loader import _find_repository_root, get_settings
|
||||||
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
||||||
|
|
||||||
|
|
||||||
@ -38,12 +38,12 @@ class LocalGitProvider(GitProvider):
|
|||||||
self._prepare_repo()
|
self._prepare_repo()
|
||||||
self.diff_files = None
|
self.diff_files = None
|
||||||
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
|
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
|
||||||
self.description_path = settings.get('local.description_path') \
|
self.description_path = get_settings().get('local.description_path') \
|
||||||
if settings.get('local.description_path') is not None else self.repo_path / 'description.md'
|
if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md'
|
||||||
self.review_path = settings.get('local.review_path') \
|
self.review_path = get_settings().get('local.review_path') \
|
||||||
if settings.get('local.review_path') is not None else self.repo_path / 'review.md'
|
if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md'
|
||||||
# inline code comments are not supported for local git repositories
|
# inline code comments are not supported for local git repositories
|
||||||
settings.pr_reviewer.inline_code_comments = False
|
get_settings().pr_reviewer.inline_code_comments = False
|
||||||
|
|
||||||
def _prepare_repo(self):
|
def _prepare_repo(self):
|
||||||
"""
|
"""
|
||||||
|
@ -3,7 +3,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from pr_agent.agent.pr_agent import PRAgent
|
from pr_agent.agent.pr_agent import PRAgent
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.tools.pr_reviewer import PRReviewer
|
from pr_agent.tools.pr_reviewer import PRReviewer
|
||||||
|
|
||||||
|
|
||||||
@ -30,11 +30,11 @@ async def run_action():
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Set the environment variables in the settings
|
# Set the environment variables in the settings
|
||||||
settings.set("OPENAI.KEY", OPENAI_KEY)
|
get_settings().set("OPENAI.KEY", OPENAI_KEY)
|
||||||
if OPENAI_ORG:
|
if OPENAI_ORG:
|
||||||
settings.set("OPENAI.ORG", OPENAI_ORG)
|
get_settings().set("OPENAI.ORG", OPENAI_ORG)
|
||||||
settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
|
get_settings().set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
|
||||||
settings.set("GITHUB.DEPLOYMENT_TYPE", "user")
|
get_settings().set("GITHUB.DEPLOYMENT_TYPE", "user")
|
||||||
|
|
||||||
# Load the event payload
|
# Load the event payload
|
||||||
try:
|
try:
|
||||||
@ -50,7 +50,7 @@ async def run_action():
|
|||||||
if action in ["opened", "reopened"]:
|
if action in ["opened", "reopened"]:
|
||||||
pr_url = event_payload.get("pull_request", {}).get("url")
|
pr_url = event_payload.get("pull_request", {}).get("url")
|
||||||
if pr_url:
|
if pr_url:
|
||||||
await PRReviewer(pr_url).review()
|
await PRReviewer(pr_url).run()
|
||||||
|
|
||||||
# Handle issue comment event
|
# Handle issue comment event
|
||||||
elif GITHUB_EVENT_NAME == "issue_comment":
|
elif GITHUB_EVENT_NAME == "issue_comment":
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from typing import Dict, Any
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
|
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
|
||||||
@ -9,7 +10,7 @@ from starlette_context import context
|
|||||||
from starlette_context.middleware import RawContextMiddleware
|
from starlette_context.middleware import RawContextMiddleware
|
||||||
|
|
||||||
from pr_agent.agent.pr_agent import PRAgent
|
from pr_agent.agent.pr_agent import PRAgent
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings, global_settings
|
||||||
from pr_agent.servers.utils import verify_signature
|
from pr_agent.servers.utils import verify_signature
|
||||||
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||||
@ -20,7 +21,8 @@ router = APIRouter()
|
|||||||
async def handle_github_webhooks(request: Request, response: Response):
|
async def handle_github_webhooks(request: Request, response: Response):
|
||||||
"""
|
"""
|
||||||
Receives and processes incoming GitHub webhook requests.
|
Receives and processes incoming GitHub webhook requests.
|
||||||
Verifies the request signature, parses the request body, and passes it to the handle_request function for further processing.
|
Verifies the request signature, parses the request body, and passes it to the handle_request function for further
|
||||||
|
processing.
|
||||||
"""
|
"""
|
||||||
logging.debug("Received a GitHub webhook")
|
logging.debug("Received a GitHub webhook")
|
||||||
|
|
||||||
@ -29,6 +31,7 @@ async def handle_github_webhooks(request: Request, response: Response):
|
|||||||
logging.debug(f'Request body:\n{body}')
|
logging.debug(f'Request body:\n{body}')
|
||||||
installation_id = body.get("installation", {}).get("id")
|
installation_id = body.get("installation", {}).get("id")
|
||||||
context["installation_id"] = installation_id
|
context["installation_id"] = installation_id
|
||||||
|
context["settings"] = copy.deepcopy(global_settings)
|
||||||
|
|
||||||
return await handle_request(body)
|
return await handle_request(body)
|
||||||
|
|
||||||
@ -46,7 +49,7 @@ async def get_body(request):
|
|||||||
raise HTTPException(status_code=400, detail="Error parsing request body") from e
|
raise HTTPException(status_code=400, detail="Error parsing request body") from e
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
signature_header = request.headers.get('x-hub-signature-256', None)
|
signature_header = request.headers.get('x-hub-signature-256', None)
|
||||||
webhook_secret = getattr(settings.github, 'webhook_secret', None)
|
webhook_secret = getattr(get_settings().github, 'webhook_secret', None)
|
||||||
if webhook_secret:
|
if webhook_secret:
|
||||||
verify_signature(body_bytes, webhook_secret, signature_header)
|
verify_signature(body_bytes, webhook_secret, signature_header)
|
||||||
return body
|
return body
|
||||||
@ -96,7 +99,7 @@ async def root():
|
|||||||
|
|
||||||
def start():
|
def start():
|
||||||
# Override the deployment type to app
|
# Override the deployment type to app
|
||||||
settings.set("GITHUB.DEPLOYMENT_TYPE", "app")
|
get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app")
|
||||||
middleware = [Middleware(RawContextMiddleware)]
|
middleware = [Middleware(RawContextMiddleware)]
|
||||||
app = FastAPI(middleware=middleware)
|
app = FastAPI(middleware=middleware)
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
@ -6,7 +6,7 @@ from datetime import datetime, timezone
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from pr_agent.agent.pr_agent import PRAgent
|
from pr_agent.agent.pr_agent import PRAgent
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.git_providers import get_git_provider
|
from pr_agent.git_providers import get_git_provider
|
||||||
from pr_agent.servers.help import bot_help_text
|
from pr_agent.servers.help import bot_help_text
|
||||||
|
|
||||||
@ -38,8 +38,8 @@ async def polling_loop():
|
|||||||
agent = PRAgent()
|
agent = PRAgent()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
deployment_type = settings.github.deployment_type
|
deployment_type = get_settings().github.deployment_type
|
||||||
token = settings.github.user_token
|
token = get_settings().github.user_token
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
deployment_type = 'none'
|
deployment_type = 'none'
|
||||||
token = None
|
token = None
|
||||||
|
@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse
|
|||||||
from starlette.background import BackgroundTasks
|
from starlette.background import BackgroundTasks
|
||||||
|
|
||||||
from pr_agent.agent.pr_agent import PRAgent
|
from pr_agent.agent.pr_agent import PRAgent
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -29,13 +29,13 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
|
|||||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
|
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
|
||||||
|
|
||||||
def start():
|
def start():
|
||||||
gitlab_url = settings.get("GITLAB.URL", None)
|
gitlab_url = get_settings().get("GITLAB.URL", None)
|
||||||
if not gitlab_url:
|
if not gitlab_url:
|
||||||
raise ValueError("GITLAB.URL is not set")
|
raise ValueError("GITLAB.URL is not set")
|
||||||
gitlab_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
|
gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
|
||||||
if not gitlab_token:
|
if not gitlab_token:
|
||||||
raise ValueError("GITLAB.PERSONAL_ACCESS_TOKEN is not set")
|
raise ValueError("GITLAB.PERSONAL_ACCESS_TOKEN is not set")
|
||||||
settings.config.git_provider = "gitlab"
|
get_settings().config.git_provider = "gitlab"
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
@ -8,8 +8,8 @@ from jinja2 import Environment, StrictUndefined
|
|||||||
from pr_agent.algo.ai_handler import AiHandler
|
from pr_agent.algo.ai_handler import AiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import try_fix_json, update_settings_from_args
|
from pr_agent.algo.utils import try_fix_json
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.git_providers import BitbucketProvider, get_git_provider
|
from pr_agent.git_providers import BitbucketProvider, get_git_provider
|
||||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||||
|
|
||||||
@ -21,7 +21,6 @@ class PRCodeSuggestions:
|
|||||||
self.main_language = get_main_pr_language(
|
self.main_language = get_main_pr_language(
|
||||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
)
|
)
|
||||||
update_settings_from_args(args)
|
|
||||||
|
|
||||||
self.ai_handler = AiHandler()
|
self.ai_handler = AiHandler()
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
@ -33,24 +32,24 @@ class PRCodeSuggestions:
|
|||||||
"description": self.git_provider.get_pr_description(),
|
"description": self.git_provider.get_pr_description(),
|
||||||
"language": self.main_language,
|
"language": self.main_language,
|
||||||
"diff": "", # empty diff for initial calculation
|
"diff": "", # empty diff for initial calculation
|
||||||
"num_code_suggestions": settings.pr_code_suggestions.num_code_suggestions,
|
"num_code_suggestions": get_settings().pr_code_suggestions.num_code_suggestions,
|
||||||
"extra_instructions": settings.pr_code_suggestions.extra_instructions,
|
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
|
||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
self.token_handler = TokenHandler(self.git_provider.pr,
|
||||||
self.vars,
|
self.vars,
|
||||||
settings.pr_code_suggestions_prompt.system,
|
get_settings().pr_code_suggestions_prompt.system,
|
||||||
settings.pr_code_suggestions_prompt.user)
|
get_settings().pr_code_suggestions_prompt.user)
|
||||||
|
|
||||||
async def suggest(self):
|
async def run(self):
|
||||||
assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
|
assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
|
||||||
|
|
||||||
logging.info('Generating code suggestions for PR...')
|
logging.info('Generating code suggestions for PR...')
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
|
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
|
||||||
await retry_with_fallback_models(self._prepare_prediction)
|
await retry_with_fallback_models(self._prepare_prediction)
|
||||||
logging.info('Preparing PR review...')
|
logging.info('Preparing PR review...')
|
||||||
data = self._prepare_pr_code_suggestions()
|
data = self._prepare_pr_code_suggestions()
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
logging.info('Pushing PR review...')
|
logging.info('Pushing PR review...')
|
||||||
self.git_provider.remove_initial_comment()
|
self.git_provider.remove_initial_comment()
|
||||||
logging.info('Pushing inline code comments...')
|
logging.info('Pushing inline code comments...')
|
||||||
@ -71,9 +70,9 @@ class PRCodeSuggestions:
|
|||||||
variables = copy.deepcopy(self.vars)
|
variables = copy.deepcopy(self.vars)
|
||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(settings.pr_code_suggestions_prompt.system).render(variables)
|
system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(settings.pr_code_suggestions_prompt.user).render(variables)
|
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||||
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
||||||
@ -86,7 +85,7 @@ class PRCodeSuggestions:
|
|||||||
try:
|
try:
|
||||||
data = json.loads(review)
|
data = json.loads(review)
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"Could not parse json response: {review}")
|
logging.info(f"Could not parse json response: {review}")
|
||||||
data = try_fix_json(review, code_suggestions=True)
|
data = try_fix_json(review, code_suggestions=True)
|
||||||
return data
|
return data
|
||||||
@ -95,7 +94,7 @@ class PRCodeSuggestions:
|
|||||||
code_suggestions = []
|
code_suggestions = []
|
||||||
for d in data['Code suggestions']:
|
for d in data['Code suggestions']:
|
||||||
try:
|
try:
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"suggestion: {d}")
|
logging.info(f"suggestion: {d}")
|
||||||
relevant_file = d['relevant file'].strip()
|
relevant_file = d['relevant file'].strip()
|
||||||
relevant_lines_str = d['relevant lines'].strip()
|
relevant_lines_str = d['relevant lines'].strip()
|
||||||
@ -113,8 +112,8 @@ class PRCodeSuggestions:
|
|||||||
code_suggestions.append({'body': body, 'relevant_file': relevant_file,
|
code_suggestions.append({'body': body, 'relevant_file': relevant_file,
|
||||||
'relevant_lines_start': relevant_lines_start,
|
'relevant_lines_start': relevant_lines_start,
|
||||||
'relevant_lines_end': relevant_lines_end})
|
'relevant_lines_end': relevant_lines_end})
|
||||||
except:
|
except Exception:
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"Could not parse suggestion: {d}")
|
logging.info(f"Could not parse suggestion: {d}")
|
||||||
|
|
||||||
self.git_provider.publish_code_suggestions(code_suggestions)
|
self.git_provider.publish_code_suggestions(code_suggestions)
|
||||||
@ -136,7 +135,7 @@ class PRCodeSuggestions:
|
|||||||
if delta_spaces > 0:
|
if delta_spaces > 0:
|
||||||
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
|
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
|
logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
|
||||||
|
|
||||||
return new_code_snippet
|
return new_code_snippet
|
||||||
|
@ -1,15 +1,14 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Tuple, List
|
from typing import List, Tuple
|
||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import AiHandler
|
from pr_agent.algo.ai_handler import AiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import update_settings_from_args
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.config_loader import settings
|
|
||||||
from pr_agent.git_providers import get_git_provider
|
from pr_agent.git_providers import get_git_provider
|
||||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||||
|
|
||||||
@ -17,13 +16,12 @@ from pr_agent.git_providers.git_provider import get_main_pr_language
|
|||||||
class PRDescription:
|
class PRDescription:
|
||||||
def __init__(self, pr_url: str, args: list = None):
|
def __init__(self, pr_url: str, args: list = None):
|
||||||
"""
|
"""
|
||||||
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model.
|
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
|
||||||
|
using an AI model.
|
||||||
Args:
|
Args:
|
||||||
pr_url (str): The URL of the pull request.
|
pr_url (str): The URL of the pull request.
|
||||||
args (list, optional): List of arguments passed to the PRDescription class. Defaults to None.
|
args (list, optional): List of arguments passed to the PRDescription class. Defaults to None.
|
||||||
"""
|
"""
|
||||||
update_settings_from_args(args)
|
|
||||||
|
|
||||||
# Initialize the git provider and main PR language
|
# Initialize the git provider and main PR language
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_pr_language = get_main_pr_language(
|
self.main_pr_language = get_main_pr_language(
|
||||||
@ -40,27 +38,27 @@ class PRDescription:
|
|||||||
"description": self.git_provider.get_pr_description(),
|
"description": self.git_provider.get_pr_description(),
|
||||||
"language": self.main_pr_language,
|
"language": self.main_pr_language,
|
||||||
"diff": "", # empty diff for initial calculation
|
"diff": "", # empty diff for initial calculation
|
||||||
"extra_instructions": settings.pr_description.extra_instructions,
|
"extra_instructions": get_settings().pr_description.extra_instructions,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Initialize the token handler
|
# Initialize the token handler
|
||||||
self.token_handler = TokenHandler(
|
self.token_handler = TokenHandler(
|
||||||
self.git_provider.pr,
|
self.git_provider.pr,
|
||||||
self.vars,
|
self.vars,
|
||||||
settings.pr_description_prompt.system,
|
get_settings().pr_description_prompt.system,
|
||||||
settings.pr_description_prompt.user,
|
get_settings().pr_description_prompt.user,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize patches_diff and prediction attributes
|
# Initialize patches_diff and prediction attributes
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
|
|
||||||
async def describe(self):
|
async def run(self):
|
||||||
"""
|
"""
|
||||||
Generates a PR description using an AI model and publishes it to the PR.
|
Generates a PR description using an AI model and publishes it to the PR.
|
||||||
"""
|
"""
|
||||||
logging.info('Generating a PR description...')
|
logging.info('Generating a PR description...')
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
|
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
|
||||||
|
|
||||||
await retry_with_fallback_models(self._prepare_prediction)
|
await retry_with_fallback_models(self._prepare_prediction)
|
||||||
@ -68,9 +66,9 @@ class PRDescription:
|
|||||||
logging.info('Preparing answer...')
|
logging.info('Preparing answer...')
|
||||||
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
|
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
|
||||||
|
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
logging.info('Pushing answer...')
|
logging.info('Pushing answer...')
|
||||||
if settings.pr_description.publish_description_as_comment:
|
if get_settings().pr_description.publish_description_as_comment:
|
||||||
self.git_provider.publish_comment(markdown_text)
|
self.git_provider.publish_comment(markdown_text)
|
||||||
else:
|
else:
|
||||||
self.git_provider.publish_description(pr_title, pr_body)
|
self.git_provider.publish_description(pr_title, pr_body)
|
||||||
@ -116,10 +114,10 @@ class PRDescription:
|
|||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
|
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables)
|
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables)
|
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
|
||||||
|
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||||
|
|
||||||
@ -170,7 +168,7 @@ class PRDescription:
|
|||||||
else:
|
else:
|
||||||
pr_body += f"{value}\n\n___\n"
|
pr_body += f"{value}\n\n___\n"
|
||||||
|
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"title:\n{title}\n{pr_body}")
|
logging.info(f"title:\n{title}\n{pr_body}")
|
||||||
|
|
||||||
return title, pr_body, pr_types, markdown_text
|
return title, pr_body, pr_types, markdown_text
|
@ -6,13 +6,11 @@ from jinja2 import Environment, StrictUndefined
|
|||||||
from pr_agent.algo.ai_handler import AiHandler
|
from pr_agent.algo.ai_handler import AiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.git_providers import get_git_provider
|
from pr_agent.git_providers import get_git_provider
|
||||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PRInformationFromUser:
|
class PRInformationFromUser:
|
||||||
def __init__(self, pr_url: str, args: list = None):
|
def __init__(self, pr_url: str, args: list = None):
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
@ -29,19 +27,19 @@ class PRInformationFromUser:
|
|||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
self.token_handler = TokenHandler(self.git_provider.pr,
|
||||||
self.vars,
|
self.vars,
|
||||||
settings.pr_information_from_user_prompt.system,
|
get_settings().pr_information_from_user_prompt.system,
|
||||||
settings.pr_information_from_user_prompt.user)
|
get_settings().pr_information_from_user_prompt.user)
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
|
|
||||||
async def generate_questions(self):
|
async def generate_questions(self):
|
||||||
logging.info('Generating question to the user...')
|
logging.info('Generating question to the user...')
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment("Preparing questions...", is_temporary=True)
|
self.git_provider.publish_comment("Preparing questions...", is_temporary=True)
|
||||||
await retry_with_fallback_models(self._prepare_prediction)
|
await retry_with_fallback_models(self._prepare_prediction)
|
||||||
logging.info('Preparing questions...')
|
logging.info('Preparing questions...')
|
||||||
pr_comment = self._prepare_pr_answer()
|
pr_comment = self._prepare_pr_answer()
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
logging.info('Pushing questions...')
|
logging.info('Pushing questions...')
|
||||||
self.git_provider.publish_comment(pr_comment)
|
self.git_provider.publish_comment(pr_comment)
|
||||||
self.git_provider.remove_initial_comment()
|
self.git_provider.remove_initial_comment()
|
||||||
@ -57,9 +55,9 @@ class PRInformationFromUser:
|
|||||||
variables = copy.deepcopy(self.vars)
|
variables = copy.deepcopy(self.vars)
|
||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(settings.pr_information_from_user_prompt.system).render(variables)
|
system_prompt = environment.from_string(get_settings().pr_information_from_user_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(settings.pr_information_from_user_prompt.user).render(variables)
|
user_prompt = environment.from_string(get_settings().pr_information_from_user_prompt.user).render(variables)
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||||
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
||||||
@ -68,7 +66,7 @@ class PRInformationFromUser:
|
|||||||
|
|
||||||
def _prepare_pr_answer(self) -> str:
|
def _prepare_pr_answer(self) -> str:
|
||||||
model_output = self.prediction.strip()
|
model_output = self.prediction.strip()
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"answer_str:\n{model_output}")
|
logging.info(f"answer_str:\n{model_output}")
|
||||||
answer_str = f"{model_output}\n\n Please respond to the questions above in the following format:\n\n" +\
|
answer_str = f"{model_output}\n\n Please respond to the questions above in the following format:\n\n" +\
|
||||||
"\n>/answer\n>1) ...\n>2) ...\n>...\n"
|
"\n>/answer\n>1) ...\n>2) ...\n>...\n"
|
||||||
|
@ -6,7 +6,7 @@ from jinja2 import Environment, StrictUndefined
|
|||||||
from pr_agent.algo.ai_handler import AiHandler
|
from pr_agent.algo.ai_handler import AiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.git_providers import get_git_provider
|
from pr_agent.git_providers import get_git_provider
|
||||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||||
|
|
||||||
@ -30,8 +30,8 @@ class PRQuestions:
|
|||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
self.token_handler = TokenHandler(self.git_provider.pr,
|
||||||
self.vars,
|
self.vars,
|
||||||
settings.pr_questions_prompt.system,
|
get_settings().pr_questions_prompt.system,
|
||||||
settings.pr_questions_prompt.user)
|
get_settings().pr_questions_prompt.user)
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
|
|
||||||
@ -42,14 +42,14 @@ class PRQuestions:
|
|||||||
question_str = ""
|
question_str = ""
|
||||||
return question_str
|
return question_str
|
||||||
|
|
||||||
async def answer(self):
|
async def run(self):
|
||||||
logging.info('Answering a PR question...')
|
logging.info('Answering a PR question...')
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment("Preparing answer...", is_temporary=True)
|
self.git_provider.publish_comment("Preparing answer...", is_temporary=True)
|
||||||
await retry_with_fallback_models(self._prepare_prediction)
|
await retry_with_fallback_models(self._prepare_prediction)
|
||||||
logging.info('Preparing answer...')
|
logging.info('Preparing answer...')
|
||||||
pr_comment = self._prepare_pr_answer()
|
pr_comment = self._prepare_pr_answer()
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
logging.info('Pushing answer...')
|
logging.info('Pushing answer...')
|
||||||
self.git_provider.publish_comment(pr_comment)
|
self.git_provider.publish_comment(pr_comment)
|
||||||
self.git_provider.remove_initial_comment()
|
self.git_provider.remove_initial_comment()
|
||||||
@ -65,9 +65,9 @@ class PRQuestions:
|
|||||||
variables = copy.deepcopy(self.vars)
|
variables = copy.deepcopy(self.vars)
|
||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(settings.pr_questions_prompt.system).render(variables)
|
system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(settings.pr_questions_prompt.user).render(variables)
|
user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables)
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||||
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
||||||
@ -77,6 +77,6 @@ class PRQuestions:
|
|||||||
def _prepare_pr_answer(self) -> str:
|
def _prepare_pr_answer(self) -> str:
|
||||||
answer_str = f"Question: {self.question_str}\n\n"
|
answer_str = f"Question: {self.question_str}\n\n"
|
||||||
answer_str += f"Answer:\n{self.prediction.strip()}\n\n"
|
answer_str += f"Answer:\n{self.prediction.strip()}\n\n"
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"answer_str:\n{answer_str}")
|
logging.info(f"answer_str:\n{answer_str}")
|
||||||
return answer_str
|
return answer_str
|
||||||
|
@ -2,17 +2,17 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Tuple, List
|
from typing import List, Tuple
|
||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import AiHandler
|
from pr_agent.algo.ai_handler import AiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import convert_to_markdown, try_fix_json, update_settings_from_args
|
from pr_agent.algo.utils import convert_to_markdown, try_fix_json
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.git_providers import get_git_provider
|
from pr_agent.git_providers import get_git_provider
|
||||||
from pr_agent.git_providers.git_provider import get_main_pr_language, IncrementalPR
|
from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language
|
||||||
from pr_agent.servers.help import actions_help_text, bot_help_text
|
from pr_agent.servers.help import actions_help_text, bot_help_text
|
||||||
|
|
||||||
|
|
||||||
@ -20,17 +20,15 @@ class PRReviewer:
|
|||||||
"""
|
"""
|
||||||
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
|
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
|
||||||
"""
|
"""
|
||||||
def __init__(self, pr_url: str, cli_mode: bool = False, is_answer: bool = False, args: list = None):
|
def __init__(self, pr_url: str, is_answer: bool = False, args: list = None):
|
||||||
"""
|
"""
|
||||||
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
|
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pr_url (str): The URL of the pull request to be reviewed.
|
pr_url (str): The URL of the pull request to be reviewed.
|
||||||
cli_mode (bool, optional): Indicates whether the review is being done in command-line interface mode. Defaults to False.
|
|
||||||
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
|
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
|
||||||
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
|
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
|
||||||
"""
|
"""
|
||||||
update_settings_from_args(args)
|
|
||||||
self.parse_args(args) # -i command
|
self.parse_args(args) # -i command
|
||||||
|
|
||||||
self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
|
self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
|
||||||
@ -41,11 +39,10 @@ class PRReviewer:
|
|||||||
self.is_answer = is_answer
|
self.is_answer = is_answer
|
||||||
|
|
||||||
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
|
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
|
||||||
raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now")
|
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
|
||||||
self.ai_handler = AiHandler()
|
self.ai_handler = AiHandler()
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
self.cli_mode = cli_mode
|
|
||||||
|
|
||||||
answer_str, question_str = self._get_user_answers()
|
answer_str, question_str = self._get_user_answers()
|
||||||
self.vars = {
|
self.vars = {
|
||||||
@ -54,21 +51,21 @@ class PRReviewer:
|
|||||||
"description": self.git_provider.get_pr_description(),
|
"description": self.git_provider.get_pr_description(),
|
||||||
"language": self.main_language,
|
"language": self.main_language,
|
||||||
"diff": "", # empty diff for initial calculation
|
"diff": "", # empty diff for initial calculation
|
||||||
"require_score": settings.pr_reviewer.require_score_review,
|
"require_score": get_settings().pr_reviewer.require_score_review,
|
||||||
"require_tests": settings.pr_reviewer.require_tests_review,
|
"require_tests": get_settings().pr_reviewer.require_tests_review,
|
||||||
"require_security": settings.pr_reviewer.require_security_review,
|
"require_security": get_settings().pr_reviewer.require_security_review,
|
||||||
"require_focused": settings.pr_reviewer.require_focused_review,
|
"require_focused": get_settings().pr_reviewer.require_focused_review,
|
||||||
'num_code_suggestions': settings.pr_reviewer.num_code_suggestions,
|
'num_code_suggestions': get_settings().pr_reviewer.num_code_suggestions,
|
||||||
'question_str': question_str,
|
'question_str': question_str,
|
||||||
'answer_str': answer_str,
|
'answer_str': answer_str,
|
||||||
"extra_instructions": settings.pr_reviewer.extra_instructions,
|
"extra_instructions": get_settings().pr_reviewer.extra_instructions,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.token_handler = TokenHandler(
|
self.token_handler = TokenHandler(
|
||||||
self.git_provider.pr,
|
self.git_provider.pr,
|
||||||
self.vars,
|
self.vars,
|
||||||
settings.pr_review_prompt.system,
|
get_settings().pr_review_prompt.system,
|
||||||
settings.pr_review_prompt.user
|
get_settings().pr_review_prompt.user
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_args(self, args: List[str]) -> None:
|
def parse_args(self, args: List[str]) -> None:
|
||||||
@ -88,13 +85,13 @@ class PRReviewer:
|
|||||||
is_incremental = True
|
is_incremental = True
|
||||||
self.incremental = IncrementalPR(is_incremental)
|
self.incremental = IncrementalPR(is_incremental)
|
||||||
|
|
||||||
async def review(self) -> None:
|
async def run(self) -> None:
|
||||||
"""
|
"""
|
||||||
Review the pull request and generate feedback.
|
Review the pull request and generate feedback.
|
||||||
"""
|
"""
|
||||||
logging.info('Reviewing PR...')
|
logging.info('Reviewing PR...')
|
||||||
|
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
|
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
|
||||||
|
|
||||||
await retry_with_fallback_models(self._prepare_prediction)
|
await retry_with_fallback_models(self._prepare_prediction)
|
||||||
@ -102,12 +99,12 @@ class PRReviewer:
|
|||||||
logging.info('Preparing PR review...')
|
logging.info('Preparing PR review...')
|
||||||
pr_comment = self._prepare_pr_review()
|
pr_comment = self._prepare_pr_review()
|
||||||
|
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
logging.info('Pushing PR review...')
|
logging.info('Pushing PR review...')
|
||||||
self.git_provider.publish_comment(pr_comment)
|
self.git_provider.publish_comment(pr_comment)
|
||||||
self.git_provider.remove_initial_comment()
|
self.git_provider.remove_initial_comment()
|
||||||
|
|
||||||
if settings.pr_reviewer.inline_code_comments:
|
if get_settings().pr_reviewer.inline_code_comments:
|
||||||
logging.info('Pushing inline code comments...')
|
logging.info('Pushing inline code comments...')
|
||||||
self._publish_inline_code_comments()
|
self._publish_inline_code_comments()
|
||||||
|
|
||||||
@ -140,10 +137,10 @@ class PRReviewer:
|
|||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
|
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables)
|
system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables)
|
user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)
|
||||||
|
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||||
|
|
||||||
@ -158,7 +155,8 @@ class PRReviewer:
|
|||||||
|
|
||||||
def _prepare_pr_review(self) -> str:
|
def _prepare_pr_review(self) -> str:
|
||||||
"""
|
"""
|
||||||
Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes the feedback.
|
Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes
|
||||||
|
the feedback.
|
||||||
"""
|
"""
|
||||||
review = self.prediction.strip()
|
review = self.prediction.strip()
|
||||||
|
|
||||||
@ -174,7 +172,8 @@ class PRReviewer:
|
|||||||
data['PR Analysis']['Security concerns'] = val
|
data['PR Analysis']['Security concerns'] = val
|
||||||
|
|
||||||
# Filter out code suggestions that can be submitted as inline comments
|
# Filter out code suggestions that can be submitted as inline comments
|
||||||
if settings.config.git_provider != 'bitbucket' and settings.pr_reviewer.inline_code_comments and 'Code suggestions' in data['PR Feedback']:
|
if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \
|
||||||
|
and 'Code suggestions' in data['PR Feedback']:
|
||||||
data['PR Feedback']['Code suggestions'] = [
|
data['PR Feedback']['Code suggestions'] = [
|
||||||
d for d in data['PR Feedback']['Code suggestions']
|
d for d in data['PR Feedback']['Code suggestions']
|
||||||
if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content'))
|
if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content'))
|
||||||
@ -184,7 +183,8 @@ class PRReviewer:
|
|||||||
|
|
||||||
# Add incremental review section
|
# Add incremental review section
|
||||||
if self.incremental.is_incremental:
|
if self.incremental.is_incremental:
|
||||||
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}"
|
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \
|
||||||
|
f"{self.git_provider.incremental.first_new_commit_sha}"
|
||||||
data = OrderedDict(data)
|
data = OrderedDict(data)
|
||||||
data.update({'Incremental PR Review': {
|
data.update({'Incremental PR Review': {
|
||||||
"⏮️ Review for commits since previous PR-Agent review": f"Starting from commit {last_commit_url}"}})
|
"⏮️ Review for commits since previous PR-Agent review": f"Starting from commit {last_commit_url}"}})
|
||||||
@ -194,7 +194,7 @@ class PRReviewer:
|
|||||||
user = self.git_provider.get_user_id()
|
user = self.git_provider.get_user_id()
|
||||||
|
|
||||||
# Add help text if not in CLI mode
|
# Add help text if not in CLI mode
|
||||||
if not self.cli_mode:
|
if get_settings().get("CONFIG.CLI_MODE", False):
|
||||||
markdown_text += "\n### How to use\n"
|
markdown_text += "\n### How to use\n"
|
||||||
if user and '[bot]' not in user:
|
if user and '[bot]' not in user:
|
||||||
markdown_text += bot_help_text(user)
|
markdown_text += bot_help_text(user)
|
||||||
@ -202,7 +202,7 @@ class PRReviewer:
|
|||||||
markdown_text += actions_help_text
|
markdown_text += actions_help_text
|
||||||
|
|
||||||
# Log markdown response if verbosity level is high
|
# Log markdown response if verbosity level is high
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"Markdown response:\n{markdown_text}")
|
logging.info(f"Markdown response:\n{markdown_text}")
|
||||||
|
|
||||||
return markdown_text
|
return markdown_text
|
||||||
@ -211,7 +211,7 @@ class PRReviewer:
|
|||||||
"""
|
"""
|
||||||
Publishes inline comments on a pull request with code suggestions generated by the AI model.
|
Publishes inline comments on a pull request with code suggestions generated by the AI model.
|
||||||
"""
|
"""
|
||||||
if settings.pr_reviewer.num_code_suggestions == 0:
|
if get_settings().pr_reviewer.num_code_suggestions == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
review = self.prediction.strip()
|
review = self.prediction.strip()
|
||||||
|
@ -9,9 +9,8 @@ from jinja2 import Environment, StrictUndefined
|
|||||||
from pr_agent.algo.ai_handler import AiHandler
|
from pr_agent.algo.ai_handler import AiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.algo.utils import update_settings_from_args
|
from pr_agent.git_providers import GithubProvider, get_git_provider
|
||||||
from pr_agent.git_providers import get_git_provider, GithubProvider
|
|
||||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||||
|
|
||||||
CHANGELOG_LINES = 50
|
CHANGELOG_LINES = 50
|
||||||
@ -24,8 +23,7 @@ class PRUpdateChangelog:
|
|||||||
self.main_language = get_main_pr_language(
|
self.main_language = get_main_pr_language(
|
||||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
)
|
)
|
||||||
update_settings_from_args(args)
|
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
|
||||||
self.commit_changelog = settings.pr_update_changelog.push_changelog_changes
|
|
||||||
self._get_changlog_file() # self.changelog_file_str
|
self._get_changlog_file() # self.changelog_file_str
|
||||||
self.ai_handler = AiHandler()
|
self.ai_handler = AiHandler()
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
@ -39,23 +37,23 @@ class PRUpdateChangelog:
|
|||||||
"diff": "", # empty diff for initial calculation
|
"diff": "", # empty diff for initial calculation
|
||||||
"changelog_file_str": self.changelog_file_str,
|
"changelog_file_str": self.changelog_file_str,
|
||||||
"today": date.today(),
|
"today": date.today(),
|
||||||
"extra_instructions": settings.pr_update_changelog.extra_instructions,
|
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
|
||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
self.token_handler = TokenHandler(self.git_provider.pr,
|
||||||
self.vars,
|
self.vars,
|
||||||
settings.pr_update_changelog_prompt.system,
|
get_settings().pr_update_changelog_prompt.system,
|
||||||
settings.pr_update_changelog_prompt.user)
|
get_settings().pr_update_changelog_prompt.user)
|
||||||
|
|
||||||
async def update_changelog(self):
|
async def run(self):
|
||||||
assert type(self.git_provider) == GithubProvider, "Currently only Github is supported"
|
assert type(self.git_provider) == GithubProvider, "Currently only Github is supported"
|
||||||
|
|
||||||
logging.info('Updating the changelog...')
|
logging.info('Updating the changelog...')
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True)
|
self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True)
|
||||||
await retry_with_fallback_models(self._prepare_prediction)
|
await retry_with_fallback_models(self._prepare_prediction)
|
||||||
logging.info('Preparing PR changelog updates...')
|
logging.info('Preparing PR changelog updates...')
|
||||||
new_file_content, answer = self._prepare_changelog_update()
|
new_file_content, answer = self._prepare_changelog_update()
|
||||||
if settings.config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.remove_initial_comment()
|
self.git_provider.remove_initial_comment()
|
||||||
logging.info('Publishing changelog updates...')
|
logging.info('Publishing changelog updates...')
|
||||||
if self.commit_changelog:
|
if self.commit_changelog:
|
||||||
@ -75,9 +73,9 @@ class PRUpdateChangelog:
|
|||||||
variables = copy.deepcopy(self.vars)
|
variables = copy.deepcopy(self.vars)
|
||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(settings.pr_update_changelog_prompt.system).render(variables)
|
system_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(settings.pr_update_changelog_prompt.user).render(variables)
|
user_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.user).render(variables)
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||||
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
||||||
@ -86,7 +84,7 @@ class PRUpdateChangelog:
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
def _prepare_changelog_update(self) -> Tuple[str, str]:
|
def _prepare_changelog_update(self) -> Tuple[str, str]:
|
||||||
answer = self.prediction.strip().strip("```").strip()
|
answer = self.prediction.strip().strip("```").strip() # noqa B005
|
||||||
if hasattr(self, "changelog_file"):
|
if hasattr(self, "changelog_file"):
|
||||||
existing_content = self.changelog_file.decoded_content.decode()
|
existing_content = self.changelog_file.decoded_content.decode()
|
||||||
else:
|
else:
|
||||||
@ -100,7 +98,7 @@ class PRUpdateChangelog:
|
|||||||
answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \
|
answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \
|
||||||
"\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
|
"\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
|
||||||
|
|
||||||
if settings.config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"answer:\n{answer}")
|
logging.info(f"answer:\n{answer}")
|
||||||
|
|
||||||
return new_file_content, answer
|
return new_file_content, answer
|
||||||
@ -120,7 +118,7 @@ class PRUpdateChangelog:
|
|||||||
last_commit_id = list(self.git_provider.pr.get_commits())[-1]
|
last_commit_id = list(self.git_provider.pr.get_commits())[-1]
|
||||||
try:
|
try:
|
||||||
self.git_provider.pr.create_review(commit=last_commit_id, comments=[d])
|
self.git_provider.pr.create_review(commit=last_commit_id, comments=[d])
|
||||||
except:
|
except Exception:
|
||||||
# we can't create a review for some reason, let's just publish a comment
|
# we can't create a review for some reason, let's just publish a comment
|
||||||
self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}")
|
self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}")
|
||||||
|
|
||||||
@ -147,7 +145,7 @@ Example:
|
|||||||
changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines()
|
changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines()
|
||||||
changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES]
|
changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES]
|
||||||
self.changelog_file_str = "\n".join(changelog_file_lines)
|
self.changelog_file_str = "\n".join(changelog_file_lines)
|
||||||
except:
|
except Exception:
|
||||||
self.changelog_file_str = ""
|
self.changelog_file_str = ""
|
||||||
if self.commit_changelog:
|
if self.commit_changelog:
|
||||||
logging.info("No CHANGELOG.md file found in the repository. Creating one...")
|
logging.info("No CHANGELOG.md file found in the repository. Creating one...")
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from pr_agent.algo.git_patch_processing import handle_patch_deletions
|
from pr_agent.algo.git_patch_processing import handle_patch_deletions
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import get_settings
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Code Analysis
|
Code Analysis
|
||||||
@ -49,7 +49,7 @@ class TestHandlePatchDeletions:
|
|||||||
original_file_content_str = 'foo\nbar\n'
|
original_file_content_str = 'foo\nbar\n'
|
||||||
new_file_content_str = ''
|
new_file_content_str = ''
|
||||||
file_name = 'file.py'
|
file_name = 'file.py'
|
||||||
settings.config.verbosity_level = 1
|
get_settings().config.verbosity_level = 1
|
||||||
|
|
||||||
with caplog.at_level(logging.INFO):
|
with caplog.at_level(logging.INFO):
|
||||||
handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file_name)
|
handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file_name)
|
||||||
|
@ -1,50 +0,0 @@
|
|||||||
|
|
||||||
# Generated by CodiumAI
|
|
||||||
from pr_agent.algo.utils import update_settings_from_args
|
|
||||||
import logging
|
|
||||||
from pr_agent.config_loader import settings
|
|
||||||
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
class TestUpdateSettingsFromArgs:
|
|
||||||
# Tests that the function updates the setting when passed a single valid argument.
|
|
||||||
def test_single_valid_argument(self):
|
|
||||||
args = ['--pr_code_suggestions.extra_instructions="be funny"']
|
|
||||||
update_settings_from_args(args)
|
|
||||||
assert settings.pr_code_suggestions.extra_instructions == '"be funny"'
|
|
||||||
|
|
||||||
# Tests that the function updates the settings when passed multiple valid arguments.
|
|
||||||
def test_multiple_valid_arguments(self):
|
|
||||||
args = ['--pr_code_suggestions.extra_instructions="be funny"', '--pr_code_suggestions.num_code_suggestions=3']
|
|
||||||
update_settings_from_args(args)
|
|
||||||
assert settings.pr_code_suggestions.extra_instructions == '"be funny"'
|
|
||||||
assert settings.pr_code_suggestions.num_code_suggestions == 3
|
|
||||||
|
|
||||||
# Tests that the function updates the setting when passed a boolean value.
|
|
||||||
def test_boolean_values(self):
|
|
||||||
settings.pr_code_suggestions.enabled = False
|
|
||||||
args = ['--pr_code_suggestions.enabled=true']
|
|
||||||
update_settings_from_args(args)
|
|
||||||
assert 'pr_code_suggestions' in settings
|
|
||||||
assert 'enabled' in settings.pr_code_suggestions
|
|
||||||
assert settings.pr_code_suggestions.enabled == True
|
|
||||||
|
|
||||||
# Tests that the function updates the setting when passed an integer value.
|
|
||||||
def test_integer_values(self):
|
|
||||||
args = ['--pr_code_suggestions.num_code_suggestions=3']
|
|
||||||
update_settings_from_args(args)
|
|
||||||
assert settings.pr_code_suggestions.num_code_suggestions == 3
|
|
||||||
|
|
||||||
# Tests that the function does not update any settings when passed an empty argument list.
|
|
||||||
def test_empty_argument_list(self):
|
|
||||||
args = []
|
|
||||||
update_settings_from_args(args)
|
|
||||||
assert settings == settings
|
|
||||||
|
|
||||||
# Tests that the function logs an error when passed an invalid argument format.
|
|
||||||
def test_invalid_argument_format(self, caplog):
|
|
||||||
args = ['--pr_code_suggestions.extra_instructions="be funny"', '--pr_code_suggestions.num_code_suggestions']
|
|
||||||
with caplog.at_level(logging.ERROR):
|
|
||||||
update_settings_from_args(args)
|
|
||||||
assert 'Invalid argument format' in caplog.text
|
|
Reference in New Issue
Block a user