mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
docstring
This commit is contained in:
@ -9,7 +9,17 @@ from pr_agent.config_loader import settings
|
||||
OPENAI_RETRIES=2
|
||||
|
||||
class AiHandler:
|
||||
"""
|
||||
This class handles interactions with the OpenAI API for chat completions.
|
||||
It initializes the API key and other settings from a configuration file,
|
||||
and provides a method for performing chat completions using the OpenAI ChatCompletion API.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the OpenAI API key and other settings from a configuration file.
|
||||
Raises a ValueError if the OpenAI key is missing.
|
||||
"""
|
||||
try:
|
||||
openai.api_key = settings.openai.key
|
||||
if settings.get("OPENAI.ORG", None):
|
||||
@ -27,6 +37,25 @@ class AiHandler:
|
||||
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError),
|
||||
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
|
||||
async def chat_completion(self, model: str, temperature: float, system: str, user: str):
|
||||
"""
|
||||
Performs a chat completion using the OpenAI ChatCompletion API.
|
||||
Retries in case of API errors or timeouts.
|
||||
|
||||
Args:
|
||||
model (str): The model to use for chat completion.
|
||||
temperature (float): The temperature parameter for chat completion.
|
||||
system (str): The system message for chat completion.
|
||||
user (str): The user message for chat completion.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the response and finish reason from the API.
|
||||
|
||||
Raises:
|
||||
TryAgain: If the API response is empty or there are no choices in the response.
|
||||
APIError: If there is an error during OpenAI inference.
|
||||
Timeout: If there is a timeout during OpenAI inference.
|
||||
TryAgain: If there is an attribute error during OpenAI inference.
|
||||
"""
|
||||
try:
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
model=model,
|
||||
@ -44,4 +73,4 @@ class AiHandler:
|
||||
raise TryAgain
|
||||
resp = response.choices[0]['message']['content']
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
return resp, finish_reason
|
||||
return resp, finish_reason
|
@ -8,7 +8,15 @@ from pr_agent.config_loader import settings
|
||||
|
||||
def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
||||
"""
|
||||
Extends the patch to include 'num_lines' more surrounding lines
|
||||
Extends the given patch to include a specified number of surrounding lines.
|
||||
|
||||
Args:
|
||||
original_file_str (str): The original file to which the patch will be applied.
|
||||
patch_str (str): The patch to be applied to the original file.
|
||||
num_lines (int): The number of surrounding lines to include in the extended patch.
|
||||
|
||||
Returns:
|
||||
str: The extended patch string.
|
||||
"""
|
||||
if not patch_str or num_lines == 0:
|
||||
return patch_str
|
||||
@ -61,6 +69,14 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
||||
|
||||
|
||||
def omit_deletion_hunks(patch_lines) -> str:
|
||||
"""
|
||||
Omit deletion hunks from the patch and return the modified patch.
|
||||
Args:
|
||||
- patch_lines: a list of strings representing the lines of the patch
|
||||
Returns:
|
||||
- A string representing the modified patch with deletion hunks omitted
|
||||
"""
|
||||
|
||||
temp_hunk = []
|
||||
added_patched = []
|
||||
add_hunk = False
|
||||
@ -93,7 +109,20 @@ def omit_deletion_hunks(patch_lines) -> str:
|
||||
def handle_patch_deletions(patch: str, original_file_content_str: str,
|
||||
new_file_content_str: str, file_name: str) -> str:
|
||||
"""
|
||||
Handle entire file or deletion patches
|
||||
Handle entire file or deletion patches.
|
||||
|
||||
This function takes a patch, original file content, new file content, and file name as input.
|
||||
It handles entire file or deletion patches and returns the modified patch with deletion hunks omitted.
|
||||
|
||||
Args:
|
||||
patch (str): The patch to be handled.
|
||||
original_file_content_str (str): The original content of the file.
|
||||
new_file_content_str (str): The new content of the file.
|
||||
file_name (str): The name of the file.
|
||||
|
||||
Returns:
|
||||
str: The modified patch with deletion hunks omitted.
|
||||
|
||||
"""
|
||||
if not new_file_content_str:
|
||||
# logic for handling deleted files - don't show patch, just show that the file was deleted
|
||||
@ -111,20 +140,26 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
|
||||
|
||||
|
||||
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
|
||||
# toDO: (maybe remove '-' and '+' from the beginning of the line)
|
||||
"""
|
||||
## src/file.ts
|
||||
Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of the file.
|
||||
|
||||
Args:
|
||||
patch (str): The patch string to be converted.
|
||||
file: An object containing the filename of the file being patched.
|
||||
|
||||
Returns:
|
||||
str: A string with line numbers for each hunk, indicating the new and old content of the file.
|
||||
|
||||
example output:
|
||||
## src/file.ts
|
||||
--new hunk--
|
||||
881 line1
|
||||
882 line2
|
||||
883 line3
|
||||
884 line4
|
||||
885 line6
|
||||
886 line7
|
||||
887 + line8
|
||||
888 + line9
|
||||
889 line10
|
||||
890 line11
|
||||
887 + line4
|
||||
888 + line5
|
||||
889 line6
|
||||
890 line7
|
||||
...
|
||||
--old hunk--
|
||||
line1
|
||||
@ -134,8 +169,8 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
|
||||
line5
|
||||
line6
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
patch_with_lines_str = f"## {file.filename}\n"
|
||||
import re
|
||||
patch_lines = patch.splitlines()
|
||||
|
@ -20,12 +20,21 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
|
||||
PATCH_EXTRA_LINES = 3
|
||||
|
||||
|
||||
def get_pr_diff(git_provider: Union[GitProvider], token_handler: TokenHandler,
|
||||
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
|
||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool =False) -> str:
|
||||
"""
|
||||
Returns a string with the diff of the PR.
|
||||
If needed, apply diff minimization techniques to reduce the number of tokens
|
||||
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
|
||||
|
||||
Args:
|
||||
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request.
|
||||
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
||||
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False.
|
||||
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: A string with the diff of the pull request, applying diff minimization techniques if needed.
|
||||
"""
|
||||
|
||||
if disable_extra_lines:
|
||||
global PATCH_EXTRA_LINES
|
||||
PATCH_EXTRA_LINES = 0
|
||||
@ -61,7 +70,16 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
|
||||
add_line_numbers_to_hunks: bool) -> \
|
||||
Tuple[list, int]:
|
||||
"""
|
||||
Generate a standard diff string, with patch extension
|
||||
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff minimization techniques if needed.
|
||||
|
||||
Args:
|
||||
- pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding files.
|
||||
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
||||
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
|
||||
|
||||
Returns:
|
||||
- patches_extended: A list of extended patches for each file in the pull request.
|
||||
- total_tokens: The total number of tokens used in the extended patches.
|
||||
"""
|
||||
total_tokens = token_handler.prompt_tokens # initial tokens
|
||||
patches_extended = []
|
||||
@ -94,12 +112,26 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
|
||||
|
||||
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
|
||||
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
|
||||
# Apply Diff Minimization techniques to reduce the number of tokens:
|
||||
# 0. Start from the largest diff patch to smaller ones
|
||||
# 1. Don't use extend context lines around diff
|
||||
# 2. Minimize deleted files
|
||||
# 3. Minimize deleted hunks
|
||||
# 4. Minimize all remaining files when you reach token limit
|
||||
"""
|
||||
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used.
|
||||
Args:
|
||||
top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files.
|
||||
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
||||
convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
|
||||
Returns:
|
||||
Tuple[list, list, list]: A tuple containing the following lists:
|
||||
- patches: A list of compressed diff patches for each file in the pull request.
|
||||
- modified_files_list: A list of file names that were skipped due to large patch size.
|
||||
- deleted_files_list: A list of file names that were deleted in the pull request.
|
||||
|
||||
Minimization techniques to reduce the number of tokens:
|
||||
0. Start from the largest diff patch to smaller ones
|
||||
1. Don't use extend context lines around diff
|
||||
2. Minimize deleted files
|
||||
3. Minimize deleted hunks
|
||||
4. Minimize all remaining files when you reach token limit
|
||||
"""
|
||||
|
||||
|
||||
patches = []
|
||||
modified_files_list = []
|
||||
|
@ -6,12 +6,43 @@ from pr_agent.config_loader import settings
|
||||
|
||||
|
||||
class TokenHandler:
|
||||
"""
|
||||
A class for handling tokens in the context of a pull request.
|
||||
|
||||
Attributes:
|
||||
- encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the number of tokens in them.
|
||||
- limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the pr_agent.algo module.
|
||||
- prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens method.
|
||||
"""
|
||||
|
||||
def __init__(self, pr, vars: dict, system, user):
|
||||
"""
|
||||
Initializes the TokenHandler object.
|
||||
|
||||
Args:
|
||||
- pr: The pull request object.
|
||||
- vars: A dictionary of variables.
|
||||
- system: The system string.
|
||||
- user: The user string.
|
||||
"""
|
||||
self.encoder = encoding_for_model(settings.config.model)
|
||||
self.limit = MAX_TOKENS[settings.config.model]
|
||||
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
||||
|
||||
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
||||
"""
|
||||
Calculates the number of tokens in the system and user strings.
|
||||
|
||||
Args:
|
||||
- pr: The pull request object.
|
||||
- encoder: An object of the encoding_for_model class from the tiktoken module.
|
||||
- vars: A dictionary of variables.
|
||||
- system: The system string.
|
||||
- user: The user string.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens in the system and user strings.
|
||||
"""
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
system_prompt = environment.from_string(system).render(vars)
|
||||
user_prompt = environment.from_string(user).render(vars)
|
||||
@ -21,4 +52,13 @@ class TokenHandler:
|
||||
return system_prompt_tokens + user_prompt_tokens
|
||||
|
||||
def count_tokens(self, patch: str) -> int:
|
||||
"""
|
||||
Counts the number of tokens in a given patch string.
|
||||
|
||||
Args:
|
||||
- patch: The patch string.
|
||||
|
||||
Returns:
|
||||
The number of tokens in the patch string.
|
||||
"""
|
||||
return len(self.encoder.encode(patch, disallowed_special=()))
|
@ -11,6 +11,13 @@ from pr_agent.config_loader import settings
|
||||
|
||||
|
||||
def convert_to_markdown(output_data: dict) -> str:
|
||||
"""
|
||||
Convert a dictionary of data into markdown format.
|
||||
Args:
|
||||
output_data (dict): A dictionary containing data to be converted to markdown format.
|
||||
Returns:
|
||||
str: The markdown formatted text generated from the input dictionary.
|
||||
"""
|
||||
markdown_text = ""
|
||||
|
||||
emojis = {
|
||||
@ -49,6 +56,15 @@ def convert_to_markdown(output_data: dict) -> str:
|
||||
|
||||
|
||||
def parse_code_suggestion(code_suggestions: dict) -> str:
|
||||
"""
|
||||
Convert a dictionary of data into markdown format.
|
||||
|
||||
Args:
|
||||
code_suggestions (dict): A dictionary containing data to be converted to markdown format.
|
||||
|
||||
Returns:
|
||||
str: A string containing the markdown formatted text generated from the input dictionary.
|
||||
"""
|
||||
markdown_text = ""
|
||||
for sub_key, sub_value in code_suggestions.items():
|
||||
if isinstance(sub_value, dict): # "code example"
|
||||
@ -68,18 +84,41 @@ def parse_code_suggestion(code_suggestions: dict) -> str:
|
||||
|
||||
|
||||
def try_fix_json(review, max_iter=10, code_suggestions=False):
|
||||
"""
|
||||
Fix broken or incomplete JSON messages and return the parsed JSON data.
|
||||
|
||||
Args:
|
||||
- review: A string containing the JSON message to be fixed.
|
||||
- max_iter: An integer representing the maximum number of iterations to try and fix the JSON message.
|
||||
- code_suggestions: A boolean indicating whether to try and fix JSON messages with code suggestions.
|
||||
|
||||
Returns:
|
||||
- data: A dictionary containing the parsed JSON data.
|
||||
|
||||
The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion.
|
||||
If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the message.
|
||||
If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON message by parsing until the last valid code suggestion.
|
||||
The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or newlines.
|
||||
It tries to parse the JSON message with the closing bracket and checks if it is valid.
|
||||
If the JSON message is valid, the parsed JSON data is returned.
|
||||
If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON message is obtained or the maximum number of iterations is reached.
|
||||
If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned.
|
||||
"""
|
||||
|
||||
if review.endswith("}"):
|
||||
return fix_json_escape_char(review)
|
||||
# Try to fix JSON if it is broken/incomplete: parse until the last valid code suggestion
|
||||
|
||||
data = {}
|
||||
if code_suggestions:
|
||||
closing_bracket = "]}"
|
||||
else:
|
||||
closing_bracket = "]}}"
|
||||
|
||||
if review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0:
|
||||
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
|
||||
valid_json = False
|
||||
iter_count = 0
|
||||
|
||||
while last_code_suggestion_ind > 0 and not valid_json and iter_count < max_iter:
|
||||
try:
|
||||
data = json.loads(review[:last_code_suggestion_ind] + closing_bracket)
|
||||
@ -87,16 +126,30 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
|
||||
review = review[:last_code_suggestion_ind].strip() + closing_bracket
|
||||
except json.decoder.JSONDecodeError:
|
||||
review = review[:last_code_suggestion_ind]
|
||||
# Use regular expression to find the last occurrence of "}," with any number of whitespaces or newlines
|
||||
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
|
||||
iter_count += 1
|
||||
|
||||
if not valid_json:
|
||||
logging.error("Unable to decode JSON response from AI")
|
||||
data = {}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def fix_json_escape_char(json_message=None):
|
||||
"""
|
||||
Fix broken or incomplete JSON messages and return the parsed JSON data.
|
||||
|
||||
Args:
|
||||
json_message (str): A string containing the JSON message to be fixed.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the parsed JSON data.
|
||||
|
||||
Raises:
|
||||
None
|
||||
|
||||
"""
|
||||
try:
|
||||
result = json.loads(json_message)
|
||||
except Exception as e:
|
||||
@ -111,11 +164,43 @@ def fix_json_escape_char(json_message=None):
|
||||
|
||||
|
||||
def convert_str_to_datetime(date_str):
|
||||
"""
|
||||
Convert a string representation of a date and time into a datetime object.
|
||||
|
||||
Args:
|
||||
date_str (str): A string representation of a date and time in the format '%a, %d %b %Y %H:%M:%S %Z'
|
||||
|
||||
Returns:
|
||||
datetime: A datetime object representing the input date and time.
|
||||
|
||||
Example:
|
||||
>>> convert_str_to_datetime('Mon, 01 Jan 2022 12:00:00 UTC')
|
||||
datetime.datetime(2022, 1, 1, 12, 0, 0)
|
||||
"""
|
||||
datetime_format = '%a, %d %b %Y %H:%M:%S %Z'
|
||||
return datetime.strptime(date_str, datetime_format)
|
||||
|
||||
|
||||
def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
|
||||
"""
|
||||
Generate a patch for a modified file by comparing the original content of the file with the new content provided as input.
|
||||
|
||||
Args:
|
||||
file: The file object for which the patch needs to be generated.
|
||||
new_file_content_str: The new content of the file as a string.
|
||||
original_file_content_str: The original content of the file as a string.
|
||||
patch: An optional patch string that can be provided as input.
|
||||
|
||||
Returns:
|
||||
The generated or provided patch string.
|
||||
|
||||
Raises:
|
||||
None.
|
||||
|
||||
Additional Information:
|
||||
- If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it as output.
|
||||
- If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating that the file was modified but no patch was found, and a patch is manually created.
|
||||
"""
|
||||
if not patch: # to Do - also add condition for file extension
|
||||
try:
|
||||
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
|
||||
|
Reference in New Issue
Block a user