diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 1c34e603..4319af30 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -1,19 +1,17 @@ from __future__ import annotations -import difflib import logging -import re import traceback -from typing import Any, Callable, List, Tuple +from typing import Callable, List, Tuple from github import RateLimitExceededException from pr_agent.algo import MAX_TOKENS from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.language_handler import sort_files_by_main_languages -from pr_agent.algo.token_handler import TokenHandler, get_token_encoder +from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings -from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider +from pr_agent.git_providers.git_provider import GitProvider DELETED_FILES_ = "Deleted files:\n" @@ -247,99 +245,6 @@ def _get_all_deployments(all_models: List[str]) -> List[str]: return all_deployments -def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], - relevant_file: str, - relevant_line_in_file: str) -> Tuple[int, int]: - """ - Find the line number and absolute position of a relevant line in a file. - - Args: - diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files. - relevant_file (str): The name of the file where the relevant line is located. - relevant_line_in_file (str): The content of the relevant line. - - Returns: - Tuple[int, int]: A tuple containing the line number and absolute position of the relevant line in the file. - """ - position = -1 - absolute_position = -1 - re_hunk_header = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? 
@@[ ]?(.*)") - - for file in diff_files: - if file.filename.strip() == relevant_file: - patch = file.patch - patch_lines = patch.splitlines() - - # try to find the line in the patch using difflib, with some margin of error - matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file, - patch_lines, n=3, cutoff=0.93) - if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'): - relevant_line_in_file = matches_difflib[0] - - delta = 0 - start1, size1, start2, size2 = 0, 0, 0, 0 - for i, line in enumerate(patch_lines): - if line.startswith('@@'): - delta = 0 - match = re_hunk_header.match(line) - start1, size1, start2, size2 = map(int, match.groups()[:4]) - elif not line.startswith('-'): - delta += 1 - - if relevant_line_in_file in line and line[0] != '-': - position = i - absolute_position = start2 + delta - 1 - break - - if position == -1 and relevant_line_in_file[0] == '+': - no_plus_line = relevant_line_in_file[1:].lstrip() - for i, line in enumerate(patch_lines): - if line.startswith('@@'): - delta = 0 - match = re_hunk_header.match(line) - start1, size1, start2, size2 = map(int, match.groups()[:4]) - elif not line.startswith('-'): - delta += 1 - - if no_plus_line in line and line[0] != '-': - # The model might add a '+' to the beginning of the relevant_line_in_file even if originally - # it's a context line - position = i - absolute_position = start2 + delta - 1 - break - return position, absolute_position - - -def clip_tokens(text: str, max_tokens: int) -> str: - """ - Clip the number of tokens in a string to a maximum number of tokens. - - Args: - text (str): The string to clip. - max_tokens (int): The maximum number of tokens allowed in the string. - - Returns: - str: The clipped string. 
- """ - if not text: - return text - - try: - encoder = get_token_encoder() - num_input_tokens = len(encoder.encode(text)) - if num_input_tokens <= max_tokens: - return text - num_chars = len(text) - chars_per_token = num_chars / num_input_tokens - num_output_chars = int(chars_per_token * max_tokens) - clipped_text = text[:num_output_chars] - return clipped_text - except Exception as e: - logging.warning(f"Failed to clip tokens: {e}") - return text - - def get_pr_multi_diffs(git_provider: GitProvider, token_handler: TokenHandler, model: str, diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index f018a92b..e94d1ed2 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -21,7 +21,7 @@ class TokenHandler: method. """ - def __init__(self, pr, vars: dict, system, user): + def __init__(self, vars: dict, system, user): """ Initializes the TokenHandler object. @@ -32,9 +32,9 @@ class TokenHandler: - user: The user string. """ self.encoder = get_token_encoder() - self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) + self.prompt_tokens = self._get_system_user_tokens(self.encoder, vars, system, user) - def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): + def _get_system_user_tokens(self, encoder, vars: dict, system, user): """ Calculates the number of tokens in the system and user strings. 
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 4d09b6e7..2d7a6d39 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -5,14 +5,24 @@ import json import logging import re import textwrap +from dataclasses import dataclass from datetime import datetime -from typing import Any, List +from enum import Enum +from typing import Any, List, Tuple, Optional import yaml from starlette_context import context + +from pr_agent.algo.token_handler import get_token_encoder from pr_agent.config_loader import get_settings, global_settings +class EDIT_TYPE(Enum): + ADDED = 1 + DELETED = 2 + MODIFIED = 3 + RENAMED = 4 + def get_setting(key: str) -> Any: try: key = key.upper() @@ -294,3 +304,108 @@ def try_fix_yaml(review_text: str) -> dict: except: pass return data + +def clip_tokens(text: str, max_tokens: int) -> str: + """ + Clip the number of tokens in a string to a maximum number of tokens. + + Args: + text (str): The string to clip. + max_tokens (int): The maximum number of tokens allowed in the string. + + Returns: + str: The clipped string. + """ + if not text: + return text + + try: + encoder = get_token_encoder() + num_input_tokens = len(encoder.encode(text)) + if num_input_tokens <= max_tokens: + return text + num_chars = len(text) + chars_per_token = num_chars / num_input_tokens + num_output_chars = int(chars_per_token * max_tokens) + clipped_text = text[:num_output_chars] + return clipped_text + except Exception as e: + logging.warning(f"Failed to clip tokens: {e}") + return text + + +@dataclass +class FilePatchInfo: + base_file: str + head_file: str + patch: str + filename: str + tokens: int = -1 + edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED + old_filename: str = None + language: Optional[str] = None + + +def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], + relevant_file: str, + relevant_line_in_file: str) -> Tuple[int, int]: + """ + Find the line number and absolute position of a relevant line in a file. 
+ + Args: + diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files. + relevant_file (str): The name of the file where the relevant line is located. + relevant_line_in_file (str): The content of the relevant line. + + Returns: + Tuple[int, int]: A tuple containing the line number and absolute position of the relevant line in the file. + """ + position = -1 + absolute_position = -1 + re_hunk_header = re.compile( + r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") + + for file in diff_files: + if file.filename.strip() == relevant_file: + patch = file.patch + patch_lines = patch.splitlines() + + # try to find the line in the patch using difflib, with some margin of error + matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file, + patch_lines, n=3, cutoff=0.93) + if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'): + relevant_line_in_file = matches_difflib[0] + + delta = 0 + start1, size1, start2, size2 = 0, 0, 0, 0 + for i, line in enumerate(patch_lines): + if line.startswith('@@'): + delta = 0 + match = re_hunk_header.match(line) + start1, size1, start2, size2 = map(int, match.groups()[:4]) + elif not line.startswith('-'): + delta += 1 + + if relevant_line_in_file in line and line[0] != '-': + position = i + absolute_position = start2 + delta - 1 + break + + if position == -1 and relevant_line_in_file[0] == '+': + no_plus_line = relevant_line_in_file[1:].lstrip() + for i, line in enumerate(patch_lines): + if line.startswith('@@'): + delta = 0 + match = re_hunk_header.match(line) + start1, size1, start2, size2 = map(int, match.groups()[:4]) + elif not line.startswith('-'): + delta += 1 + + if no_plus_line in line and line[0] != '-': + # The model might add a '+' to the beginning of the relevant_line_in_file even if originally + # it's a context line + position = i + absolute_position = start2 + delta - 1 + break + return position, absolute_position + diff --git 
a/pr_agent/git_providers/bitbucket_provider.py b/pr_agent/git_providers/bitbucket_provider.py index 0cd860fa..0239d10a 100644 --- a/pr_agent/git_providers/bitbucket_provider.py +++ b/pr_agent/git_providers/bitbucket_provider.py @@ -8,7 +8,8 @@ from atlassian.bitbucket import Cloud from starlette_context import context from ..config_loader import get_settings -from .git_provider import FilePatchInfo, GitProvider +from .git_provider import GitProvider +from ..algo.utils import FilePatchInfo class BitbucketProvider(GitProvider): diff --git a/pr_agent/git_providers/codecommit_client.py b/pr_agent/git_providers/codecommit_client.py index c1cfa763..6c8f3320 100644 --- a/pr_agent/git_providers/codecommit_client.py +++ b/pr_agent/git_providers/codecommit_client.py @@ -1,6 +1,9 @@ -import boto3 -import botocore - +try: # Allow this module to be imported without requiring boto3 + import boto3 + import botocore +except ModuleNotFoundError: + boto3 = None + botocore = None class CodeCommitDifferencesResponse: """ diff --git a/pr_agent/git_providers/codecommit_provider.py b/pr_agent/git_providers/codecommit_provider.py index a747e7f2..a7510a69 100644 --- a/pr_agent/git_providers/codecommit_provider.py +++ b/pr_agent/git_providers/codecommit_provider.py @@ -4,13 +4,12 @@ from collections import Counter from typing import List, Optional, Tuple from urllib.parse import urlparse -from ..algo.language_handler import is_valid_file, language_extension_map -from ..algo.pr_processing import clip_tokens -from ..algo.utils import load_large_diff -from ..config_loader import get_settings -from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider, IncrementalPR from pr_agent.git_providers.codecommit_client import CodeCommitClient +from ..algo.language_handler import is_valid_file, language_extension_map +from ..algo.utils import EDIT_TYPE, FilePatchInfo, load_large_diff +from .git_provider import GitProvider + class PullRequestCCMimic: """ diff --git 
a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index f0a5419e..72e1cf07 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -1,27 +1,9 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass # enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED) -from enum import Enum from typing import Optional - -class EDIT_TYPE(Enum): - ADDED = 1 - DELETED = 2 - MODIFIED = 3 - RENAMED = 4 - - -@dataclass -class FilePatchInfo: - base_file: str - head_file: str - patch: str - filename: str - tokens: int = -1 - edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED - old_filename: str = None +from pr_agent.algo.utils import FilePatchInfo class GitProvider(ABC): @@ -87,7 +69,7 @@ class GitProvider(ABC): def get_pr_description(self) -> str: from pr_agent.config_loader import get_settings - from pr_agent.algo.pr_processing import clip_tokens + from pr_agent.algo.utils import clip_tokens max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) description = self.get_pr_description_full() if max_tokens: diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 057bc15a..055ba9c5 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -9,10 +9,9 @@ from github import AppAuthentication, Auth, Github, GithubException, Reaction from retry import retry from starlette_context import context -from .git_provider import FilePatchInfo, GitProvider, IncrementalPR +from .git_provider import GitProvider, IncrementalPR from ..algo.language_handler import is_valid_file -from ..algo.utils import load_large_diff -from ..algo.pr_processing import find_line_number_of_relevant_line_in_file, clip_tokens +from ..algo.utils import load_large_diff, clip_tokens, find_line_number_of_relevant_line_in_file, FilePatchInfo from ..config_loader import get_settings from ..servers.utils import RateLimitExceeded diff --git 
a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 2deae177..977e1b4c 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -7,10 +7,9 @@ import gitlab from gitlab import GitlabGetError from ..algo.language_handler import is_valid_file -from ..algo.pr_processing import clip_tokens -from ..algo.utils import load_large_diff +from ..algo.utils import load_large_diff, clip_tokens, EDIT_TYPE, FilePatchInfo from ..config_loader import get_settings -from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider +from .git_provider import GitProvider logger = logging.getLogger() diff --git a/pr_agent/git_providers/local_git_provider.py b/pr_agent/git_providers/local_git_provider.py index e6ee1456..9cca9c86 100644 --- a/pr_agent/git_providers/local_git_provider.py +++ b/pr_agent/git_providers/local_git_provider.py @@ -6,7 +6,8 @@ from typing import List from git import Repo from pr_agent.config_loader import _find_repository_root, get_settings -from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider +from pr_agent.git_providers.git_provider import GitProvider +from pr_agent.algo.utils import EDIT_TYPE, FilePatchInfo class PullRequestMimic: diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index cc787f5e..2d9bdd33 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -45,9 +45,7 @@ class PRCodeSuggestions: "extra_instructions": get_settings().pr_code_suggestions.extra_instructions, "commit_messages_str": self.git_provider.get_commit_messages(), } - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - get_settings().pr_code_suggestions_prompt.system, + self.token_handler = TokenHandler(self.vars, get_settings().pr_code_suggestions_prompt.system, get_settings().pr_code_suggestions_prompt.user) async def run(self): diff --git 
a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index acd272bc..a93fb04f 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -46,12 +46,8 @@ class PRDescription: self.user_description = self.git_provider.get_user_description() # Initialize the token handler - self.token_handler = TokenHandler( - self.git_provider.pr, - self.vars, - get_settings().pr_description_prompt.system, - get_settings().pr_description_prompt.user, - ) + self.token_handler = TokenHandler(self.vars, get_settings().pr_description_prompt.system, + get_settings().pr_description_prompt.user) # Initialize patches_diff and prediction attributes self.patches_diff = None diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index c049250f..27d6f00b 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -26,9 +26,7 @@ class PRInformationFromUser: "diff": "", # empty diff for initial calculation "commit_messages_str": self.git_provider.get_commit_messages(), } - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - get_settings().pr_information_from_user_prompt.system, + self.token_handler = TokenHandler(self.vars, get_settings().pr_information_from_user_prompt.system, get_settings().pr_information_from_user_prompt.user) self.patches_diff = None self.prediction = None diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 959bebe7..16ca32c7 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -29,9 +29,7 @@ class PRQuestions: "questions": self.question_str, "commit_messages_str": self.git_provider.get_commit_messages(), } - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - get_settings().pr_questions_prompt.system, + self.token_handler = TokenHandler(self.vars, get_settings().pr_questions_prompt.system, get_settings().pr_questions_prompt.user) 
self.patches_diff = None self.prediction = None diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index a89c27a3..230bb2b2 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -9,8 +9,7 @@ from jinja2 import Environment, StrictUndefined from yaml import SafeLoader from pr_agent.algo.ai_handler import AiHandler -from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, \ - find_line_number_of_relevant_line_in_file, clip_tokens +from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, try_fix_json, try_fix_yaml, load_yaml from pr_agent.config_loader import get_settings @@ -66,12 +65,8 @@ class PRReviewer: "commit_messages_str": self.git_provider.get_commit_messages(), } - self.token_handler = TokenHandler( - self.git_provider.pr, - self.vars, - get_settings().pr_review_prompt.system, - get_settings().pr_review_prompt.user - ) + self.token_handler = TokenHandler(self.vars, get_settings().pr_review_prompt.system, + get_settings().pr_review_prompt.user) def parse_args(self, args: List[str]) -> None: """ @@ -217,8 +212,8 @@ class PRReviewer: markdown_text = convert_to_markdown(data) user = self.git_provider.get_user_id() - # Add help text if not in CLI mode - if not get_settings().get("CONFIG.CLI_MODE", False): + # Add help text if not in CLI mode + if not get_settings().get("CONFIG.CLI_MODE", False): markdown_text += "\n### How to use\n" if user and '[bot]' not in user: markdown_text += bot_help_text(user) diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index 1ec62709..812efe4d 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -40,9 +40,7 @@ class PRUpdateChangelog: "extra_instructions": get_settings().pr_update_changelog.extra_instructions, "commit_messages_str": 
self.git_provider.get_commit_messages(), } - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - get_settings().pr_update_changelog_prompt.system, + self.token_handler = TokenHandler(self.vars, get_settings().pr_update_changelog_prompt.system, get_settings().pr_update_changelog_prompt.user) async def run(self): diff --git a/tests/unittest/test_codecommit_provider.py b/tests/unittest/test_codecommit_provider.py index e35f7250..9bcece8b 100644 --- a/tests/unittest/test_codecommit_provider.py +++ b/tests/unittest/test_codecommit_provider.py @@ -1,7 +1,7 @@ import pytest from pr_agent.git_providers.codecommit_provider import CodeCommitFile from pr_agent.git_providers.codecommit_provider import CodeCommitProvider -from pr_agent.git_providers.git_provider import EDIT_TYPE +from pr_agent.algo.utils import EDIT_TYPE class TestCodeCommitFile: diff --git a/tests/unittest/test_find_line_number_of_relevant_line_in_file.py b/tests/unittest/test_find_line_number_of_relevant_line_in_file.py index 7488c6df..0c63f9f0 100644 --- a/tests/unittest/test_find_line_number_of_relevant_line_in_file.py +++ b/tests/unittest/test_find_line_number_of_relevant_line_in_file.py @@ -1,8 +1,6 @@ # Generated by CodiumAI -from pr_agent.git_providers.git_provider import FilePatchInfo -from pr_agent.algo.pr_processing import find_line_number_of_relevant_line_in_file - +from pr_agent.algo.utils import FilePatchInfo, find_line_number_of_relevant_line_in_file import pytest