diff --git a/README.md b/README.md
index d30c82e5..7b419225 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,8 @@
@@ -14,9 +14,9 @@
CodiumAI `PR-Agent` is an open-source tool aiming to help developers review pull-requests faster and more efficiently. It automatically analyzes the pull-request and can provide several types of feedback:
-**Auto-Description**: Automatically generating PR description - name, type, summary, and code walkthrough.
+**Auto-Description**: Automatically generating PR description - title, type, summary, code walkthrough and PR labels.
\
-**PR Review**: Feedback about the PR main theme, type, relevant tests, security issues, focused PR, and various suggestions for the PR content.
+**PR Review**: Adjustable feedback about the PR main theme, type, relevant tests, security issues, focus, score, and various suggestions for the PR content.
\
**Question Answering**: Answering free-text questions about the PR.
\
@@ -128,7 +128,7 @@ There are several ways to use PR-Agent:
- The "PR Q&A" tool answers free-text questions about the PR.
- The "PR Description" tool automatically sets the PR Title and body.
- The "PR Code Suggestion" tool provide inline code suggestions for the PR that can be applied and committed.
-- The "PR Reflect and Review" tool first initiates a dialog with the user and asks them to reflect on the PR, and then provides a review.
+- The "PR Reflect and Review" tool initiates a dialog with the user, asks them to reflect on the PR, and then provides a more focused review.
## How it works
@@ -138,9 +138,9 @@ Check out the [PR Compression strategy](./PR_COMPRESSION.md) page for more detai
## Roadmap
-- [ ] Support open-source models, as a replacement for openai models. (Note - a minimal requirement for each open-source model is to have 8k+ context, and good support for generating json as an output)
+- [ ] Support open-source models, as a replacement for OpenAI models. (Note - a minimal requirement for each open-source model is to have 8k+ context, and good support for generating JSON as an output)
- [x] Support other Git providers, such as Gitlab and Bitbucket.
-- [ ] Develop additional logics for handling large PRs, and compressing git patches
+- [ ] Develop additional logic for handling large PRs, and compressing git patches
- [ ] Add additional context to the prompt. For example, repo (or relevant files) summarization, with tools such a [ctags](https://github.com/universal-ctags/ctags)
- [ ] Adding more tools. Possible directions:
- [x] PR description
diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py
index 0f3f13f5..a97b97ac 100644
--- a/pr_agent/algo/ai_handler.py
+++ b/pr_agent/algo/ai_handler.py
@@ -9,7 +9,17 @@ from pr_agent.config_loader import settings
OPENAI_RETRIES=2
class AiHandler:
+ """
+ This class handles interactions with the OpenAI API for chat completions.
+ It initializes the API key and other settings from a configuration file,
+ and provides a method for performing chat completions using the OpenAI ChatCompletion API.
+ """
+
def __init__(self):
+ """
+ Initializes the OpenAI API key and other settings from a configuration file.
+ Raises a ValueError if the OpenAI key is missing.
+ """
try:
openai.api_key = settings.openai.key
if settings.get("OPENAI.ORG", None):
@@ -27,6 +37,25 @@ class AiHandler:
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, temperature: float, system: str, user: str):
+ """
+ Performs a chat completion using the OpenAI ChatCompletion API.
+ Retries in case of API errors or timeouts.
+
+ Args:
+ model (str): The model to use for chat completion.
+ temperature (float): The temperature parameter for chat completion.
+ system (str): The system message for chat completion.
+ user (str): The user message for chat completion.
+
+ Returns:
+ tuple: A tuple containing the response and finish reason from the API.
+
+ Raises:
+ TryAgain: If the API response is empty, has no choices,
+ or an attribute error occurs during OpenAI inference.
+ APIError: If there is an error during OpenAI inference.
+ Timeout: If there is a timeout during OpenAI inference.
+ """
try:
response = await openai.ChatCompletion.acreate(
model=model,
@@ -44,4 +73,4 @@ class AiHandler:
raise TryAgain
resp = response.choices[0]['message']['content']
finish_reason = response.choices[0].finish_reason
- return resp, finish_reason
+ return resp, finish_reason
\ No newline at end of file
diff --git a/pr_agent/algo/git_patch_processing.py b/pr_agent/algo/git_patch_processing.py
index d8aa1802..8128da48 100644
--- a/pr_agent/algo/git_patch_processing.py
+++ b/pr_agent/algo/git_patch_processing.py
@@ -8,7 +8,15 @@ from pr_agent.config_loader import settings
def extend_patch(original_file_str, patch_str, num_lines) -> str:
"""
- Extends the patch to include 'num_lines' more surrounding lines
+ Extends the given patch to include a specified number of surrounding lines.
+
+ Args:
+ original_file_str (str): The original file to which the patch will be applied.
+ patch_str (str): The patch to be applied to the original file.
+ num_lines (int): The number of surrounding lines to include in the extended patch.
+
+ Returns:
+ str: The extended patch string.
"""
if not patch_str or num_lines == 0:
return patch_str
@@ -61,6 +69,14 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
def omit_deletion_hunks(patch_lines) -> str:
+ """
+ Omit deletion hunks from the patch and return the modified patch.
+ Args:
+ - patch_lines: a list of strings representing the lines of the patch
+ Returns:
+ - A string representing the modified patch with deletion hunks omitted
+ """
+
temp_hunk = []
added_patched = []
add_hunk = False
@@ -93,7 +109,20 @@ def omit_deletion_hunks(patch_lines) -> str:
def handle_patch_deletions(patch: str, original_file_content_str: str,
new_file_content_str: str, file_name: str) -> str:
"""
- Handle entire file or deletion patches
+ Handle entire file or deletion patches.
+
+ This function takes a patch, original file content, new file content, and file name as input.
+ It handles entire file or deletion patches and returns the modified patch with deletion hunks omitted.
+
+ Args:
+ patch (str): The patch to be handled.
+ original_file_content_str (str): The original content of the file.
+ new_file_content_str (str): The new content of the file.
+ file_name (str): The name of the file.
+
+ Returns:
+ str: The modified patch with deletion hunks omitted.
+
"""
if not new_file_content_str:
# logic for handling deleted files - don't show patch, just show that the file was deleted
@@ -111,20 +140,26 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
- # toDO: (maybe remove '-' and '+' from the beginning of the line)
"""
- ## src/file.ts
+ Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of the file.
+
+ Args:
+ patch (str): The patch string to be converted.
+ file: An object containing the filename of the file being patched.
+
+ Returns:
+ str: A string with line numbers for each hunk, indicating the new and old content of the file.
+
+ Example output:
+## src/file.ts
--new hunk--
881 line1
882 line2
883 line3
-884 line4
-885 line6
-886 line7
-887 + line8
-888 + line9
-889 line10
-890 line11
+887 + line4
+888 + line5
+889 line6
+890 line7
...
--old hunk--
line1
@@ -134,8 +169,8 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
line5
line6
...
-
"""
+
patch_with_lines_str = f"## {file.filename}\n"
import re
patch_lines = patch.splitlines()
diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py
index 11f16449..8bfaac50 100644
--- a/pr_agent/algo/pr_processing.py
+++ b/pr_agent/algo/pr_processing.py
@@ -20,12 +20,21 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
PATCH_EXTRA_LINES = 3
-def get_pr_diff(git_provider: Union[GitProvider], token_handler: TokenHandler,
+def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool =False) -> str:
"""
- Returns a string with the diff of the PR.
- If needed, apply diff minimization techniques to reduce the number of tokens
+ Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
+
+ Args:
+ git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request.
+ token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
+ add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False.
+ disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False.
+
+ Returns:
+ str: A string with the diff of the pull request, applying diff minimization techniques if needed.
"""
+
if disable_extra_lines:
global PATCH_EXTRA_LINES
PATCH_EXTRA_LINES = 0
@@ -61,7 +70,16 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool) -> \
Tuple[list, int]:
"""
- Generate a standard diff string, with patch extension
+ Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff minimization techniques if needed.
+
+ Args:
+ - pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding files.
+ - token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
+ - add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
+
+ Returns:
+ - patches_extended: A list of extended patches for each file in the pull request.
+ - total_tokens: The total number of tokens used in the extended patches.
"""
total_tokens = token_handler.prompt_tokens # initial tokens
patches_extended = []
@@ -94,12 +112,26 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
- # Apply Diff Minimization techniques to reduce the number of tokens:
- # 0. Start from the largest diff patch to smaller ones
- # 1. Don't use extend context lines around diff
- # 2. Minimize deleted files
- # 3. Minimize deleted hunks
- # 4. Minimize all remaining files when you reach token limit
+ """
+ Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used.
+ Args:
+ top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files.
+ token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
+ convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
+ Returns:
+ Tuple[list, list, list]: A tuple containing the following lists:
+ - patches: A list of compressed diff patches for each file in the pull request.
+ - modified_files_list: A list of file names that were skipped due to large patch size.
+ - deleted_files_list: A list of file names that were deleted in the pull request.
+
+ Minimization techniques to reduce the number of tokens:
+ 0. Start from the largest diff patch to smaller ones
+ 1. Don't use extend context lines around diff
+ 2. Minimize deleted files
+ 3. Minimize deleted hunks
+ 4. Minimize all remaining files when you reach token limit
+ """
+
patches = []
modified_files_list = []
diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py
index 53fb3ac9..19d03df3 100644
--- a/pr_agent/algo/token_handler.py
+++ b/pr_agent/algo/token_handler.py
@@ -6,12 +6,43 @@ from pr_agent.config_loader import settings
class TokenHandler:
+ """
+ A class for handling tokens in the context of a pull request.
+
+ Attributes:
+ - encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the number of tokens in them.
+ - limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the pr_agent.algo module.
+ - prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens method.
+ """
+
def __init__(self, pr, vars: dict, system, user):
+ """
+ Initializes the TokenHandler object.
+
+ Args:
+ - pr: The pull request object.
+ - vars: A dictionary of variables.
+ - system: The system string.
+ - user: The user string.
+ """
self.encoder = encoding_for_model(settings.config.model)
self.limit = MAX_TOKENS[settings.config.model]
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
+ """
+ Calculates the number of tokens in the system and user strings.
+
+ Args:
+ - pr: The pull request object.
+ - encoder: An object of the encoding_for_model class from the tiktoken module.
+ - vars: A dictionary of variables.
+ - system: The system string.
+ - user: The user string.
+
+ Returns:
+ The sum of the number of tokens in the system and user strings.
+ """
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(system).render(vars)
user_prompt = environment.from_string(user).render(vars)
@@ -21,4 +52,13 @@ class TokenHandler:
return system_prompt_tokens + user_prompt_tokens
def count_tokens(self, patch: str) -> int:
+ """
+ Counts the number of tokens in a given patch string.
+
+ Args:
+ - patch: The patch string.
+
+ Returns:
+ The number of tokens in the patch string.
+ """
return len(self.encoder.encode(patch, disallowed_special=()))
\ No newline at end of file
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 673a3570..23bf8dfd 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -11,6 +11,13 @@ from pr_agent.config_loader import settings
def convert_to_markdown(output_data: dict) -> str:
+ """
+ Convert a dictionary of data into markdown format.
+ Args:
+ output_data (dict): A dictionary containing data to be converted to markdown format.
+ Returns:
+ str: The markdown formatted text generated from the input dictionary.
+ """
markdown_text = ""
emojis = {
@@ -49,6 +56,15 @@ def convert_to_markdown(output_data: dict) -> str:
def parse_code_suggestion(code_suggestions: dict) -> str:
+ """
+ Convert a dictionary of code suggestions into markdown format.
+
+ Args:
+ code_suggestions (dict): A dictionary containing data to be converted to markdown format.
+
+ Returns:
+ str: A string containing the markdown formatted text generated from the input dictionary.
+ """
markdown_text = ""
for sub_key, sub_value in code_suggestions.items():
if isinstance(sub_value, dict): # "code example"
@@ -68,18 +84,41 @@ def parse_code_suggestion(code_suggestions: dict) -> str:
def try_fix_json(review, max_iter=10, code_suggestions=False):
+ """
+ Fix broken or incomplete JSON messages and return the parsed JSON data.
+
+ Args:
+ - review: A string containing the JSON message to be fixed.
+ - max_iter: An integer representing the maximum number of iterations to try and fix the JSON message.
+ - code_suggestions: A boolean indicating whether to try and fix JSON messages with code suggestions.
+
+ Returns:
+ - data: A dictionary containing the parsed JSON data.
+
+ The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion.
+ If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the message.
+ If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON message by parsing until the last valid code suggestion.
+ The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or newlines.
+ It tries to parse the JSON message with the closing bracket and checks if it is valid.
+ If the JSON message is valid, the parsed JSON data is returned.
+ If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON message is obtained or the maximum number of iterations is reached.
+ If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned.
+ """
+
if review.endswith("}"):
return fix_json_escape_char(review)
- # Try to fix JSON if it is broken/incomplete: parse until the last valid code suggestion
+
data = {}
if code_suggestions:
closing_bracket = "]}"
else:
closing_bracket = "]}}"
+
if review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0:
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
valid_json = False
iter_count = 0
+
while last_code_suggestion_ind > 0 and not valid_json and iter_count < max_iter:
try:
data = json.loads(review[:last_code_suggestion_ind] + closing_bracket)
@@ -87,16 +126,30 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
review = review[:last_code_suggestion_ind].strip() + closing_bracket
except json.decoder.JSONDecodeError:
review = review[:last_code_suggestion_ind]
- # Use regular expression to find the last occurrence of "}," with any number of whitespaces or newlines
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
iter_count += 1
+
if not valid_json:
logging.error("Unable to decode JSON response from AI")
data = {}
+
return data
def fix_json_escape_char(json_message=None):
+ """
+ Attempt to parse a JSON message, fixing invalid escape characters if parsing fails, and return the parsed JSON data.
+
+ Args:
+ json_message (str): A string containing the JSON message to be fixed.
+
+ Returns:
+ dict: A dictionary containing the parsed JSON data.
+
+ Raises:
+ None
+
+ """
try:
result = json.loads(json_message)
except Exception as e:
@@ -111,11 +164,43 @@ def fix_json_escape_char(json_message=None):
def convert_str_to_datetime(date_str):
+ """
+ Convert a string representation of a date and time into a datetime object.
+
+ Args:
+ date_str (str): A string representation of a date and time in the format '%a, %d %b %Y %H:%M:%S %Z'
+
+ Returns:
+ datetime: A datetime object representing the input date and time.
+
+ Example:
+ >>> convert_str_to_datetime('Mon, 01 Jan 2022 12:00:00 UTC')
+ datetime.datetime(2022, 1, 1, 12, 0)
+ """
datetime_format = '%a, %d %b %Y %H:%M:%S %Z'
return datetime.strptime(date_str, datetime_format)
def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
+ """
+ Generate a patch for a modified file by comparing the original content of the file with the new content provided as input.
+
+ Args:
+ file: The file object for which the patch needs to be generated.
+ new_file_content_str: The new content of the file as a string.
+ original_file_content_str: The original content of the file as a string.
+ patch: An optional patch string that can be provided as input.
+
+ Returns:
+ The generated or provided patch string.
+
+ Raises:
+ None.
+
+ Additional Information:
+ - If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it as output.
+ - If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating that the file was modified but no patch was found, and a patch is manually created.
+ """
if not patch: # to Do - also add condition for file extension
try:
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),