diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 6063dece..4c1352f0 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -10,7 +10,7 @@ from github import RateLimitExceededException from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.file_filter import filter_ignored -from pr_agent.algo.token_handler import TokenHandler, get_token_encoder +from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import get_max_tokens from pr_agent.config_loader import get_settings from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE @@ -326,35 +326,6 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], return position, absolute_position -def clip_tokens(text: str, max_tokens: int) -> str: - """ - Clip the number of tokens in a string to a maximum number of tokens. - - Args: - text (str): The string to clip. - max_tokens (int): The maximum number of tokens allowed in the string. - - Returns: - str: The clipped string. - """ - if not text: - return text - - try: - encoder = get_token_encoder() - num_input_tokens = len(encoder.encode(text)) - if num_input_tokens <= max_tokens: - return text - num_chars = len(text) - chars_per_token = num_chars / num_input_tokens - num_output_chars = int(chars_per_token * max_tokens) - clipped_text = text[:num_output_chars] - return clipped_text - except Exception as e: - get_logger().warning(f"Failed to clip tokens: {e}") - return text - - def get_pr_multi_diffs(git_provider: GitProvider, token_handler: TokenHandler, model: str, diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index b9aaee94..73074098 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -11,6 +11,7 @@ import yaml from starlette_context import context from pr_agent.algo import MAX_TOKENS +from pr_agent.algo.token_handler import get_token_encoder from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import get_logger @@ -378,3 +379,34 @@ def get_max_tokens(model): max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model) # get_logger().debug(f"limiting max tokens to {max_tokens_model}") return max_tokens_model + + +def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str: + """ + Clip the number of tokens in a string to a maximum number of tokens. + + Args: + text (str): The string to clip. + max_tokens (int): The maximum number of tokens allowed in the string. + add_three_dots (bool, optional): A boolean indicating whether to add three dots at the end of the clipped + Returns: + str: The clipped string. + """ + if not text: + return text + + try: + encoder = get_token_encoder() + num_input_tokens = len(encoder.encode(text)) + if num_input_tokens <= max_tokens: + return text + num_chars = len(text) + chars_per_token = num_chars / num_input_tokens + num_output_chars = int(chars_per_token * max_tokens) + clipped_text = text[:num_output_chars] + if add_three_dots: + clipped_text += "...(truncated)" + return clipped_text + except Exception as e: + get_logger().warning(f"Failed to clip tokens: {e}") + return text diff --git a/pr_agent/git_providers/azuredevops_provider.py b/pr_agent/git_providers/azuredevops_provider.py index 6a404532..ca11b9d8 100644 --- a/pr_agent/git_providers/azuredevops_provider.py +++ b/pr_agent/git_providers/azuredevops_provider.py @@ -14,9 +14,8 @@ try: except ImportError: AZURE_DEVOPS_AVAILABLE = False -from ..algo.pr_processing import clip_tokens from ..config_loader import get_settings -from ..algo.utils import load_large_diff +from ..algo.utils import load_large_diff, clip_tokens from ..algo.language_handler import is_valid_file from .git_provider import EDIT_TYPE, FilePatchInfo diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index d929ed37..d0012b5e 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -63,7 +63,7 @@ class GitProvider(ABC): def get_pr_description(self, *, full: bool = True) -> str: from pr_agent.config_loader import get_settings - from pr_agent.algo.pr_processing import clip_tokens + from pr_agent.algo.utils import clip_tokens max_tokens_description = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) description = self.get_pr_description_full() if full else self.get_user_description() if max_tokens_description: diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 634b8694..46afbad6 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -8,8 +8,8 @@ from retry import retry from starlette_context import context from ..algo.language_handler import is_valid_file -from ..algo.pr_processing import clip_tokens, find_line_number_of_relevant_line_in_file -from ..algo.utils import load_large_diff +from ..algo.pr_processing import find_line_number_of_relevant_line_in_file +from ..algo.utils import load_large_diff, clip_tokens from ..config_loader import get_settings from ..log import get_logger from ..servers.utils import RateLimitExceeded diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 078ca9dd..2eb00ce1 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -7,8 +7,8 @@ import gitlab from gitlab import GitlabGetError from ..algo.language_handler import is_valid_file -from ..algo.pr_processing import clip_tokens, find_line_number_of_relevant_line_in_file -from ..algo.utils import load_large_diff +from ..algo.pr_processing import find_line_number_of_relevant_line_in_file +from ..algo.utils import load_large_diff, clip_tokens from ..config_loader import get_settings from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider from ..log import get_logger diff --git a/tests/unittest/test_clip_tokens.py b/tests/unittest/test_clip_tokens.py new file mode 100644 index 00000000..cc52ab7e --- /dev/null +++ b/tests/unittest/test_clip_tokens.py @@ -0,0 +1,19 @@ + +# Generated by CodiumAI + +import pytest + +from pr_agent.algo.utils import clip_tokens + + +class TestClipTokens: + def test_clip(self): + text = "line1\nline2\nline3\nline4\nline5\nline6" + max_tokens = 25 + result = clip_tokens(text, max_tokens) + assert result == text + + max_tokens = 10 + result = clip_tokens(text, max_tokens) + expected_results = 'line1\nline2\nline3\nli...(truncated)' + assert result == expected_results