Mirror of https://github.com/qodo-ai/pr-agent.git, synced 2025-07-08 06:40:39 +08:00
Support context-aware settings (for each incoming request), support override of settings, refactor CLI to use pr_agent.py
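The heart of this change is replacing the module-level `settings` object with a `get_settings()` accessor, so each incoming request can carry its own settings (including per-request overrides) via `starlette_context`, while CLI and background code falls back to the shared global configuration. The config loader itself is not part of the hunks below; the following is a minimal sketch of the pattern, assuming a Dynaconf-backed `global_settings` and a request-scoped copy stored under the `"settings"` context key (the settings file list is hypothetical):

# Minimal sketch of a context-aware settings accessor; not the actual pr_agent/config_loader.py.
from dynaconf import Dynaconf
from starlette_context import context

# Shared, process-wide configuration (the file list here is illustrative only).
global_settings = Dynaconf(settings_files=["settings/configuration.toml"])


def get_settings():
    # Prefer the settings object stored for the current request when a starlette
    # context is active; otherwise fall back to the global settings.
    try:
        return context.get("settings", global_settings)
    except Exception:
        return global_settings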
@@ -1,10 +1,10 @@
 import logging

 import openai
-from openai.error import APIError, Timeout, TryAgain, RateLimitError
+from openai.error import APIError, RateLimitError, Timeout, TryAgain
 from retry import retry

-from pr_agent.config_loader import settings
+from pr_agent.config_loader import get_settings

 OPENAI_RETRIES=5

@@ -21,16 +21,16 @@ class AiHandler:
         Raises a ValueError if the OpenAI key is missing.
         """
         try:
-            openai.api_key = settings.openai.key
-            if settings.get("OPENAI.ORG", None):
-                openai.organization = settings.openai.org
-            self.deployment_id = settings.get("OPENAI.DEPLOYMENT_ID", None)
-            if settings.get("OPENAI.API_TYPE", None):
-                openai.api_type = settings.openai.api_type
-            if settings.get("OPENAI.API_VERSION", None):
-                openai.api_version = settings.openai.api_version
-            if settings.get("OPENAI.API_BASE", None):
-                openai.api_base = settings.openai.api_base
+            openai.api_key = get_settings().openai.key
+            if get_settings().get("OPENAI.ORG", None):
+                openai.organization = get_settings().openai.org
+            self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
+            if get_settings().get("OPENAI.API_TYPE", None):
+                openai.api_type = get_settings().openai.api_type
+            if get_settings().get("OPENAI.API_VERSION", None):
+                openai.api_version = get_settings().openai.api_version
+            if get_settings().get("OPENAI.API_BASE", None):
+                openai.api_base = get_settings().openai.api_base
         except AttributeError as e:
             raise ValueError("OpenAI key is required") from e

@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 import re

-from pr_agent.config_loader import settings
+from pr_agent.config_loader import get_settings


 def extend_patch(original_file_str, patch_str, num_lines) -> str:
@@ -55,7 +55,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
                     continue
             extended_patch_lines.append(line)
     except Exception as e:
-        if settings.config.verbosity_level >= 2:
+        if get_settings().config.verbosity_level >= 2:
             logging.error(f"Failed to extend patch: {e}")
         return patch_str

@@ -126,14 +126,14 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
     """
     if not new_file_content_str:
         # logic for handling deleted files - don't show patch, just show that the file was deleted
-        if settings.config.verbosity_level > 0:
+        if get_settings().config.verbosity_level > 0:
             logging.info(f"Processing file: {file_name}, minimizing deletion file")
         patch = None # file was deleted
     else:
         patch_lines = patch.splitlines()
         patch_new = omit_deletion_hunks(patch_lines)
         if patch != patch_new:
-            if settings.config.verbosity_level > 0:
+            if get_settings().config.verbosity_level > 0:
                 logging.info(f"Processing file: {file_name}, hunks were deleted")
             patch = patch_new
     return patch
@@ -141,7 +141,8 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,

 def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
     """
-    Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of the file.
+    Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
+    the file.

     Args:
         patch (str): The patch string to be converted.
@@ -1,15 +1,15 @@
 # Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501
 from typing import Dict

-from pr_agent.config_loader import settings
+from pr_agent.config_loader import get_settings

-language_extension_map_org = settings.language_extension_map_org
+language_extension_map_org = get_settings().language_extension_map_org
 language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}

 # Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
-bad_extensions = settings.bad_extensions.default
-if settings.config.use_extra_bad_extensions:
-    bad_extensions += settings.bad_extensions.extra
+bad_extensions = get_settings().bad_extensions.default
+if get_settings().config.use_extra_bad_extensions:
+    bad_extensions += get_settings().bad_extensions.extra


 def filter_bad_extensions(files):
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Tuple, Union, Callable, List
+from typing import Callable, Tuple

 from github import RateLimitExceededException

@@ -10,7 +10,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers
 from pr_agent.algo.language_handler import sort_files_by_main_languages
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.utils import load_large_diff
-from pr_agent.config_loader import settings
+from pr_agent.config_loader import get_settings
 from pr_agent.git_providers.git_provider import GitProvider

 DELETED_FILES_ = "Deleted files:\n"
@@ -27,11 +27,15 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
     Returns a string with the diff of the pull request, applying diff minimization techniques if needed.

     Args:
-        git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request.
-        token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
+        git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull
+        request.
+        token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
+        pull request.
         model (str): The name of the model used for tokenization.
-        add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False.
-        disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False.
+        add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the
+        diff. Defaults to False.
+        disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with
+        extra lines of context. Defaults to False.

     Returns:
         str: A string with the diff of the pull request, applying diff minimization techniques if needed.
@@ -76,10 +80,12 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
                               add_line_numbers_to_hunks: bool) -> \
         Tuple[list, int]:
     """
-    Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff minimization techniques if needed.
+    Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
+    minimization techniques if needed.

     Args:
-    - pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding files.
+    - pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding
+      files.
     - token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
     - add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.

@@ -119,10 +125,13 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
 def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
                                 convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
     """
-    Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used.
+    Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of
+    tokens used.
     Args:
-    top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files.
-    token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
+    top_langs (list): A list of dictionaries representing the languages used in the pull request and their
+    corresponding files.
+    token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
+    pull request.
     model (str): The model used for tokenization.
     convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
     Returns:
@@ -181,7 +190,7 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
             # Current logic is to skip the patch if it's too large
             # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
             # until we meet the requirements
-            if settings.config.verbosity_level >= 2:
+            if get_settings().config.verbosity_level >= 2:
                 logging.warning(f"Patch too large, minimizing it, {file.filename}")
             if not modified_files_list:
                 total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_)
@@ -196,15 +205,15 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
             patch_final = patch
         patches.append(patch_final)
         total_tokens += token_handler.count_tokens(patch_final)
-        if settings.config.verbosity_level >= 2:
+        if get_settings().config.verbosity_level >= 2:
             logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")

     return patches, modified_files_list, deleted_files_list


 async def retry_with_fallback_models(f: Callable):
-    model = settings.config.model
-    fallback_models = settings.config.fallback_models
+    model = get_settings().config.model
+    fallback_models = get_settings().config.fallback_models
     if not isinstance(fallback_models, list):
         fallback_models = [fallback_models]
     all_models = [model] + fallback_models
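The hunk above shows only how `retry_with_fallback_models` reads its model list from `get_settings()`; the rest of the function is unchanged and therefore not part of this diff. As a rough, hedged sketch of the surrounding control flow (the loop body, the call signature of `f`, and the logging are assumptions, not taken from this commit):

import logging

from pr_agent.config_loader import get_settings


async def retry_with_fallback_models(f):
    model = get_settings().config.model
    fallback_models = get_settings().config.fallback_models
    if not isinstance(fallback_models, list):
        fallback_models = [fallback_models]
    all_models = [model] + fallback_models
    for i, model in enumerate(all_models):
        try:
            # Assumed: `f` accepts the model name and performs the actual call.
            return await f(model)
        except Exception:
            if i == len(all_models) - 1:
                raise  # no fallback models left
            logging.warning(f"Model {model} failed, falling back to {all_models[i + 1]}")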
@@ -1,8 +1,7 @@
 from jinja2 import Environment, StrictUndefined
 from tiktoken import encoding_for_model

 from pr_agent.algo import MAX_TOKENS
-from pr_agent.config_loader import settings
+from pr_agent.config_loader import get_settings


 class TokenHandler:
@@ -10,9 +9,12 @@ class TokenHandler:
     A class for handling tokens in the context of a pull request.

     Attributes:
-    - encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the number of tokens in them.
-    - limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the pr_agent.algo module.
-    - prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens method.
+    - encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the
+      number of tokens in them.
+    - limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the
+      pr_agent.algo module.
+    - prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens
+      method.
     """

     def __init__(self, pr, vars: dict, system, user):
@@ -25,7 +27,7 @@ class TokenHandler:
         - system: The system string.
         - user: The user string.
         """
-        self.encoder = encoding_for_model(settings.config.model)
+        self.encoder = encoding_for_model(get_settings().config.model)
         self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)

     def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
@@ -1,15 +1,24 @@
 from __future__ import annotations
-from typing import List

 import difflib
-from datetime import datetime
 import json
 import logging
 import re
 import textwrap
+from datetime import datetime
+from typing import Any, List

-from pr_agent.config_loader import settings
+from starlette_context import context
+
+from pr_agent.config_loader import get_settings, global_settings
+
+
+def get_setting(key: str) -> Any:
+    try:
+        key = key.upper()
+        return context.get("settings", global_settings).get(key, global_settings.get(key, None))
+    except Exception:
+        return global_settings.get(key, None)


 def convert_to_markdown(output_data: dict) -> str:
     """
@@ -97,12 +106,16 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
    - data: A dictionary containing the parsed JSON data.

    The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion.
-   If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the message.
-   If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON message by parsing until the last valid code suggestion.
-   The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or newlines.
+   If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the
+   message.
+   If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON
+   message by parsing until the last valid code suggestion.
+   The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or
+   newlines.
    It tries to parse the JSON message with the closing bracket and checks if it is valid.
    If the JSON message is valid, the parsed JSON data is returned.
-   If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON message is obtained or the maximum number of iterations is reached.
+   If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON
+   message is obtained or the maximum number of iterations is reached.
    If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned.
    """
@@ -184,7 +197,8 @@ def convert_str_to_datetime(date_str):

 def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
     """
-    Generate a patch for a modified file by comparing the original content of the file with the new content provided as input.
+    Generate a patch for a modified file by comparing the original content of the file with the new content provided as
+    input.

     Args:
         file: The file object for which the patch needs to be generated.
@@ -199,14 +213,16 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str:
        None.

    Additional Information:
-   - If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it as output.
-   - If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating that the file was modified but no patch was found, and a patch is manually created.
+   - If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it
+     as output.
+   - If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating
+     that the file was modified but no patch was found, and a patch is manually created.
    """
    if not patch: # to Do - also add condition for file extension
        try:
            diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
                                        new_file_content_str.splitlines(keepends=True))
-           if settings.config.verbosity_level >= 2:
+           if get_settings().config.verbosity_level >= 2:
                logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.")
            patch = ''.join(diff)
        except Exception:
@@ -214,7 +230,7 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str:
     return patch


-def update_settings_from_args(args: List[str]) -> None:
+def update_settings_from_args(args: List[str]) -> List[str]:
     """
     Update the settings of the Dynaconf object based on the arguments passed to the function.

@@ -230,28 +246,22 @@ def update_settings_from_args(args: List[str]) -> None:
        ValueError: If the argument is not in the correct format.

    """
+   other_args = []
    if args:
        for arg in args:
-           try:
+           arg = arg.strip()
+           if arg.startswith('--'):
                arg = arg.strip('-').strip()
                vals = arg.split('=')
                if len(vals) != 2:
-                   raise ValueError(f'Invalid argument format: {arg}')
+                   logging.error(f'Invalid argument format: {arg}')
+                   other_args.append(arg)
+                   continue
                key, value = vals
-               keys = key.split('.')
-               d = settings
-               for i, k in enumerate(keys[:-1]):
-                   if k not in d:
-                       raise ValueError(f'Invalid setting: {key}')
-                   d = d[k]
-               if keys[-1] not in d:
-                   raise ValueError(f'Invalid setting: {key}')
-               if isinstance(d[keys[-1]], bool):
-                   d[keys[-1]] = value.lower() in ("yes", "true", "t", "1")
-               else:
-                   d[keys[-1]] = type(d[keys[-1]])(value)
+               key = key.strip().upper()
+               value = value.strip()
+               get_settings().set(key, value)
                logging.info(f'Updated setting {key} to: "{value}"')
-           except ValueError as e:
-               logging.error(str(e))
-           except Exception as e:
-               logging.error(f'Failed to parse argument {arg}: {e}')
+           else:
+               other_args.append(arg)
+   return other_args
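Taken together, the rewritten `update_settings_from_args` turns every `--section.key=value` argument into a runtime override applied with `get_settings().set(...)` (keys are upper-cased first) and hands anything that does not start with `--` back to the caller. A small hypothetical usage example, assuming the helper lives in `pr_agent.algo.utils` and with made-up argument values:

from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings

# "--key=value" arguments become settings overrides; everything else is returned untouched.
args = ["--config.verbosity_level=2", "review this PR please"]
other_args = update_settings_from_args(args)

print(get_settings().get("CONFIG.VERBOSITY_LEVEL"))  # "2" - the raw string taken from the CLI
print(other_args)                                    # ["review this PR please"]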