diff --git a/README.md b/README.md index d30c82e5..7b419225 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@
- - + +
@@ -14,9 +14,9 @@ CodiumAI `PR-Agent` is an open-source tool aiming to help developers review pull-requests faster and more efficiently. It automatically analyzes the pull-request and can provide several types of feedback: -**Auto-Description**: Automatically generating PR description - name, type, summary, and code walkthrough. +**Auto-Description**: Automatically generating PR description - title, type, summary, code walkthrough and PR labels. \ -**PR Review**: Feedback about the PR main theme, type, relevant tests, security issues, focused PR, and various suggestions for the PR content. +**PR Review**: Adjustable feedback about the PR main theme, type, relevant tests, security issues, focus, score, and various suggestions for the PR content. \ **Question Answering**: Answering free-text questions about the PR. \ @@ -128,7 +128,7 @@ There are several ways to use PR-Agent: - The "PR Q&A" tool answers free-text questions about the PR. - The "PR Description" tool automatically sets the PR Title and body. - The "PR Code Suggestion" tool provide inline code suggestions for the PR that can be applied and committed. -- The "PR Reflect and Review" tool first initiates a dialog with the user and asks them to reflect on the PR, and then provides a review. +- The "PR Reflect and Review" tool initiates a dialog with the user, asks them to reflect on the PR, and then provides a more focused review. ## How it works @@ -138,9 +138,9 @@ Check out the [PR Compression strategy](./PR_COMPRESSION.md) page for more detai ## Roadmap -- [ ] Support open-source models, as a replacement for openai models. (Note - a minimal requirement for each open-source model is to have 8k+ context, and good support for generating json as an output) +- [ ] Support open-source models, as a replacement for OpenAI models. 
(Note - a minimal requirement for each open-source model is to have 8k+ context, and good support for generating JSON as an output) - [x] Support other Git providers, such as Gitlab and Bitbucket. -- [ ] Develop additional logics for handling large PRs, and compressing git patches +- [ ] Develop additional logic for handling large PRs, and compressing git patches - [ ] Add additional context to the prompt. For example, repo (or relevant files) summarization, with tools such a [ctags](https://github.com/universal-ctags/ctags) - [ ] Adding more tools. Possible directions: - [x] PR description diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py index 0f3f13f5..a97b97ac 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handler.py @@ -9,7 +9,17 @@ from pr_agent.config_loader import settings OPENAI_RETRIES=2 class AiHandler: + """ + This class handles interactions with the OpenAI API for chat completions. + It initializes the API key and other settings from a configuration file, + and provides a method for performing chat completions using the OpenAI ChatCompletion API. + """ + def __init__(self): + """ + Initializes the OpenAI API key and other settings from a configuration file. + Raises a ValueError if the OpenAI key is missing. + """ try: openai.api_key = settings.openai.key if settings.get("OPENAI.ORG", None): @@ -27,6 +37,25 @@ class AiHandler: @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError), tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) async def chat_completion(self, model: str, temperature: float, system: str, user: str): + """ + Performs a chat completion using the OpenAI ChatCompletion API. + Retries in case of API errors or timeouts. + + Args: + model (str): The model to use for chat completion. + temperature (float): The temperature parameter for chat completion. + system (str): The system message for chat completion. + user (str): The user message for chat completion. 
+ + Returns: + tuple: A tuple containing the response and finish reason from the API. + + Raises: + TryAgain: If the API response is empty, there are no choices in the response, + or there is an attribute error during OpenAI inference. + APIError: If there is an error during OpenAI inference. + Timeout: If there is a timeout during OpenAI inference. + """ try: response = await openai.ChatCompletion.acreate( model=model, @@ -44,4 +73,4 @@ class AiHandler: raise TryAgain resp = response.choices[0]['message']['content'] finish_reason = response.choices[0].finish_reason - return resp, finish_reason + return resp, finish_reason \ No newline at end of file diff --git a/pr_agent/algo/git_patch_processing.py b/pr_agent/algo/git_patch_processing.py index d8aa1802..8128da48 100644 --- a/pr_agent/algo/git_patch_processing.py +++ b/pr_agent/algo/git_patch_processing.py @@ -8,7 +8,15 @@ from pr_agent.config_loader import settings def extend_patch(original_file_str, patch_str, num_lines) -> str: """ - Extends the patch to include 'num_lines' more surrounding lines + Extends the given patch to include a specified number of surrounding lines. + + Args: + original_file_str (str): The original file to which the patch will be applied. + patch_str (str): The patch to be applied to the original file. + num_lines (int): The number of surrounding lines to include in the extended patch. + + Returns: + str: The extended patch string. """ if not patch_str or num_lines == 0: return patch_str @@ -61,6 +69,14 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str: def omit_deletion_hunks(patch_lines) -> str: + """ + Omit deletion hunks from the patch and return the modified patch.
+ Args: + - patch_lines: a list of strings representing the lines of the patch + Returns: + - A string representing the modified patch with deletion hunks omitted + """ + temp_hunk = [] added_patched = [] add_hunk = False @@ -93,7 +109,20 @@ def omit_deletion_hunks(patch_lines) -> str: def handle_patch_deletions(patch: str, original_file_content_str: str, new_file_content_str: str, file_name: str) -> str: """ - Handle entire file or deletion patches + Handle entire file or deletion patches. + + This function takes a patch, original file content, new file content, and file name as input. + It handles entire file or deletion patches and returns the modified patch with deletion hunks omitted. + + Args: + patch (str): The patch to be handled. + original_file_content_str (str): The original content of the file. + new_file_content_str (str): The new content of the file. + file_name (str): The name of the file. + + Returns: + str: The modified patch with deletion hunks omitted. + """ if not new_file_content_str: # logic for handling deleted files - don't show patch, just show that the file was deleted @@ -111,20 +140,26 @@ def handle_patch_deletions(patch: str, original_file_content_str: str, def convert_to_hunks_with_lines_numbers(patch: str, file) -> str: - # toDO: (maybe remove '-' and '+' from the beginning of the line) """ - ## src/file.ts + Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of the file. + + Args: + patch (str): The patch string to be converted. + file: An object containing the filename of the file being patched. + + Returns: + str: A string with line numbers for each hunk, indicating the new and old content of the file. + + example output: +## src/file.ts --new hunk-- 881 line1 882 line2 883 line3 -884 line4 -885 line6 -886 line7 -887 + line8 -888 + line9 -889 line10 -890 line11 +887 + line4 +888 + line5 +889 line6 +890 line7 ... 
--old hunk-- line1 @@ -134,8 +169,8 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str: line5 line6 ... - """ + patch_with_lines_str = f"## {file.filename}\n" import re patch_lines = patch.splitlines() diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 11f16449..8bfaac50 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -20,12 +20,21 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600 PATCH_EXTRA_LINES = 3 -def get_pr_diff(git_provider: Union[GitProvider], token_handler: TokenHandler, +def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool =False) -> str: """ - Returns a string with the diff of the PR. - If needed, apply diff minimization techniques to reduce the number of tokens + Returns a string with the diff of the pull request, applying diff minimization techniques if needed. + + Args: + git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request. + token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request. + add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False. + disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False. + + Returns: + str: A string with the diff of the pull request, applying diff minimization techniques if needed. 
""" + if disable_extra_lines: global PATCH_EXTRA_LINES PATCH_EXTRA_LINES = 0 @@ -61,7 +70,16 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler, add_line_numbers_to_hunks: bool) -> \ Tuple[list, int]: """ - Generate a standard diff string, with patch extension + Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff minimization techniques if needed. + + Args: + - pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding files. + - token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request. + - add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff. + + Returns: + - patches_extended: A list of extended patches for each file in the pull request. + - total_tokens: The total number of tokens used in the extended patches. """ total_tokens = token_handler.prompt_tokens # initial tokens patches_extended = [] @@ -94,12 +112,26 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler, def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]: - # Apply Diff Minimization techniques to reduce the number of tokens: - # 0. Start from the largest diff patch to smaller ones - # 1. Don't use extend context lines around diff - # 2. Minimize deleted files - # 3. Minimize deleted hunks - # 4. Minimize all remaining files when you reach token limit + """ + Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used. + Args: + top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files. + token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request. 
+ convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff. + Returns: + Tuple[list, list, list]: A tuple containing the following lists: + - patches: A list of compressed diff patches for each file in the pull request. + - modified_files_list: A list of file names that were skipped due to large patch size. + - deleted_files_list: A list of file names that were deleted in the pull request. + + Minimization techniques to reduce the number of tokens: + 0. Start from the largest diff patch to smaller ones + 1. Don't use extend context lines around diff + 2. Minimize deleted files + 3. Minimize deleted hunks + 4. Minimize all remaining files when you reach token limit + """ + patches = [] modified_files_list = [] diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index 53fb3ac9..19d03df3 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -6,12 +6,43 @@ from pr_agent.config_loader import settings class TokenHandler: + """ + A class for handling tokens in the context of a pull request. + + Attributes: + - encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the number of tokens in them. + - limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the pr_agent.algo module. + - prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens method. + """ + def __init__(self, pr, vars: dict, system, user): + """ + Initializes the TokenHandler object. + + Args: + - pr: The pull request object. + - vars: A dictionary of variables. + - system: The system string. + - user: The user string. 
+ """ self.encoder = encoding_for_model(settings.config.model) self.limit = MAX_TOKENS[settings.config.model] self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): + """ + Calculates the number of tokens in the system and user strings. + + Args: + - pr: The pull request object. + - encoder: An object of the encoding_for_model class from the tiktoken module. + - vars: A dictionary of variables. + - system: The system string. + - user: The user string. + + Returns: + The sum of the number of tokens in the system and user strings. + """ environment = Environment(undefined=StrictUndefined) system_prompt = environment.from_string(system).render(vars) user_prompt = environment.from_string(user).render(vars) @@ -21,4 +52,13 @@ class TokenHandler: return system_prompt_tokens + user_prompt_tokens def count_tokens(self, patch: str) -> int: + """ + Counts the number of tokens in a given patch string. + + Args: + - patch: The patch string. + + Returns: + The number of tokens in the patch string. + """ return len(self.encoder.encode(patch, disallowed_special=())) \ No newline at end of file diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 673a3570..23bf8dfd 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -11,6 +11,13 @@ from pr_agent.config_loader import settings def convert_to_markdown(output_data: dict) -> str: + """ + Convert a dictionary of data into markdown format. + Args: + output_data (dict): A dictionary containing data to be converted to markdown format. + Returns: + str: The markdown formatted text generated from the input dictionary. + """ markdown_text = "" emojis = { @@ -49,6 +56,15 @@ def convert_to_markdown(output_data: dict) -> str: def parse_code_suggestion(code_suggestions: dict) -> str: + """ + Convert a dictionary of data into markdown format. 
+ + Args: + code_suggestions (dict): A dictionary containing data to be converted to markdown format. + + Returns: + str: A string containing the markdown formatted text generated from the input dictionary. + """ markdown_text = "" for sub_key, sub_value in code_suggestions.items(): if isinstance(sub_value, dict): # "code example" @@ -68,18 +84,41 @@ def parse_code_suggestion(code_suggestions: dict) -> str: def try_fix_json(review, max_iter=10, code_suggestions=False): + """ + Fix broken or incomplete JSON messages and return the parsed JSON data. + + Args: + - review: A string containing the JSON message to be fixed. + - max_iter: An integer representing the maximum number of iterations to try and fix the JSON message. + - code_suggestions: A boolean indicating whether to try and fix JSON messages with code suggestions. + + Returns: + - data: A dictionary containing the parsed JSON data. + + The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion. + If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the message. + If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON message by parsing until the last valid code suggestion. + The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or newlines. + It tries to parse the JSON message with the closing bracket and checks if it is valid. + If the JSON message is valid, the parsed JSON data is returned. + If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON message is obtained or the maximum number of iterations is reached. + If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned. 
+ """ + if review.endswith("}"): return fix_json_escape_char(review) - # Try to fix JSON if it is broken/incomplete: parse until the last valid code suggestion + data = {} if code_suggestions: closing_bracket = "]}" else: closing_bracket = "]}}" + if review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0: last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1 valid_json = False iter_count = 0 + while last_code_suggestion_ind > 0 and not valid_json and iter_count < max_iter: try: data = json.loads(review[:last_code_suggestion_ind] + closing_bracket) @@ -87,16 +126,30 @@ def try_fix_json(review, max_iter=10, code_suggestions=False): review = review[:last_code_suggestion_ind].strip() + closing_bracket except json.decoder.JSONDecodeError: review = review[:last_code_suggestion_ind] - # Use regular expression to find the last occurrence of "}," with any number of whitespaces or newlines last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1 iter_count += 1 + if not valid_json: logging.error("Unable to decode JSON response from AI") data = {} + return data def fix_json_escape_char(json_message=None): + """ + Fix broken or incomplete JSON messages and return the parsed JSON data. + + Args: + json_message (str): A string containing the JSON message to be fixed. + + Returns: + dict: A dictionary containing the parsed JSON data. + + Raises: + None + + """ try: result = json.loads(json_message) except Exception as e: @@ -111,11 +164,43 @@ def fix_json_escape_char(json_message=None): def convert_str_to_datetime(date_str): + """ + Convert a string representation of a date and time into a datetime object. + + Args: + date_str (str): A string representation of a date and time in the format '%a, %d %b %Y %H:%M:%S %Z' + + Returns: + datetime: A datetime object representing the input date and time. 
+ + Example: + >>> convert_str_to_datetime('Sat, 01 Jan 2022 12:00:00 UTC') + datetime.datetime(2022, 1, 1, 12, 0) + """ datetime_format = '%a, %d %b %Y %H:%M:%S %Z' return datetime.strptime(date_str, datetime_format) def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str: + """ + Generate a patch for a modified file by comparing the original content of the file with the new content provided as input. + + Args: + file: The file object for which the patch needs to be generated. + new_file_content_str: The new content of the file as a string. + original_file_content_str: The original content of the file as a string. + patch: An optional patch string that can be provided as input. + + Returns: + The generated or provided patch string. + + Raises: + None. + + Additional Information: + - If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it as output. + - If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating that the file was modified but no patch was found, and a patch is manually created. + """ if not patch: # to Do - also add condition for file extension try: diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),