diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 780c7953..3e3103ad 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -945,12 +945,66 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_token
     """
     Clip the number of tokens in a string to a maximum number of tokens.

+    This function limits text to a specified token count by calculating the approximate
+    character-to-token ratio and truncating the text accordingly. A safety factor of 0.9
+    (a 10% reduction) is applied to ensure the result stays within the token limit.
+
     Args:
-        text (str): The string to clip.
+        text (str): The string to clip. If empty or None, returns the input unchanged.
         max_tokens (int): The maximum number of tokens allowed in the string.
-        add_three_dots (bool, optional): A boolean indicating whether to add three dots at the end of the clipped
+            If negative, returns an empty string.
+        add_three_dots (bool, optional): Whether to add "\\n...(truncated)" at the end
+            of the clipped text to indicate truncation.
+            Defaults to True.
+        num_input_tokens (int, optional): Pre-computed number of tokens in the input text.
+            If provided, skips the token encoding step for efficiency.
+            If None, tokens will be counted using TokenEncoder.
+            Defaults to None.
+        delete_last_line (bool, optional): Whether to remove the last line from the
+            clipped content before adding the truncation indicator.
+            Useful for ensuring clean breaks at line boundaries.
+            Defaults to False.
+
     Returns:
-        str: The clipped string.
+        str: The clipped string. Returns the original text if:
+            - Text is empty/None
+            - Token count is within the limit
+            - An error occurs during processing
+
+            Returns an empty string if max_tokens <= 0.
+
+    Examples:
+        Basic usage:
+            >>> text = "This is a sample text that might be too long"
+            >>> result = clip_tokens(text, max_tokens=10)
+            >>> print(result)
+            This is a sample
+            ...(truncated)
+
+        Without truncation indicator:
+            >>> result = clip_tokens(text, max_tokens=10, add_three_dots=False)
+            >>> print(result)
+            This is a sample
+
+        With pre-computed token count:
+            >>> result = clip_tokens(text, max_tokens=5, num_input_tokens=15)
+            >>> print(result)
+            This
+            ...(truncated)
+
+        With line deletion:
+            >>> multiline_text = "Line 1\\nLine 2\\nLine 3"
+            >>> result = clip_tokens(multiline_text, max_tokens=3, delete_last_line=True)
+            >>> print(result)
+            Line 1
+            Line 2
+            ...(truncated)
+
+    Notes:
+        The function uses a safety factor of 0.9 (a 10% reduction) to ensure the
+        result stays within the token limit, as character-to-token ratios can vary.
+        If token encoding fails, the original text is returned with a warning logged.
""" if not text: return text diff --git a/tests/unittest/test_clip_tokens.py b/tests/unittest/test_clip_tokens.py index 79de6294..a42ef929 100644 --- a/tests/unittest/test_clip_tokens.py +++ b/tests/unittest/test_clip_tokens.py @@ -1,13 +1,302 @@ - -# Generated by CodiumAI - import pytest - +from unittest.mock import patch, MagicMock from pr_agent.algo.utils import clip_tokens +from pr_agent.algo.token_handler import TokenEncoder class TestClipTokens: - def test_clip(self): + """Comprehensive test suite for the clip_tokens function.""" + + def test_empty_input_text(self): + """Test that empty input returns empty string.""" + assert clip_tokens("", 10) == "" + assert clip_tokens(None, 10) is None + + def test_text_under_token_limit(self): + """Test that text under the token limit is returned unchanged.""" + text = "Short text" + max_tokens = 100 + result = clip_tokens(text, max_tokens) + assert result == text + + def test_text_exactly_at_token_limit(self): + """Test text that is exactly at the token limit.""" + text = "This is exactly at the limit" + # Mock the token encoder to return exact limit + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 10 # Exactly 10 tokens + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, 10) + assert result == text + + def test_text_over_token_limit_with_three_dots(self): + """Test text over token limit with three dots addition.""" + text = "This is a longer text that should be clipped when it exceeds the token limit" + max_tokens = 5 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, max_tokens) + assert result.endswith("\n...(truncated)") + assert len(result) < len(text) + + def test_text_over_token_limit_without_three_dots(self): + """Test text over token limit without three dots addition.""" + text = "This is a longer text that should be clipped" + max_tokens = 5 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, max_tokens, add_three_dots=False) + assert not result.endswith("\n...(truncated)") + assert len(result) < len(text) + + def test_negative_max_tokens(self): + """Test that negative max_tokens returns empty string.""" + text = "Some text" + result = clip_tokens(text, -1) + assert result == "" + + result = clip_tokens(text, -100) + assert result == "" + + def test_zero_max_tokens(self): + """Test that zero max_tokens returns empty string.""" + text = "Some text" + result = clip_tokens(text, 0) + assert result == "" + + def test_delete_last_line_functionality(self): + """Test the delete_last_line parameter functionality.""" + text = "Line 1\nLine 2\nLine 3\nLine 4" + max_tokens = 5 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens + mock_encoder.return_value = mock_tokenizer + + # Without delete_last_line + result_normal = clip_tokens(text, max_tokens, delete_last_line=False) + + # With delete_last_line + result_deleted = clip_tokens(text, max_tokens, delete_last_line=True) + + # The result with delete_last_line should be shorter or equal + 
+            assert len(result_deleted) <= len(result_normal)
+
+    def test_pre_computed_num_input_tokens(self):
+        """Test using pre-computed num_input_tokens parameter."""
+        text = "This is a test text"
+        max_tokens = 10
+        num_input_tokens = 15
+
+        # Should not call the encoder when num_input_tokens is provided
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_encoder.return_value = None  # Should not be called
+
+            result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens)
+            assert result.endswith("\n...(truncated)")
+            mock_encoder.assert_not_called()
+
+    def test_pre_computed_tokens_under_limit(self):
+        """Test pre-computed tokens under the limit."""
+        text = "Short text"
+        max_tokens = 20
+        num_input_tokens = 5
+
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_encoder.return_value = None  # Should not be called
+
+            result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens)
+            assert result == text
+            mock_encoder.assert_not_called()
+
+    def test_special_characters_and_unicode(self):
+        """Test text with special characters and Unicode content."""
+        text = "Special chars: @#$%^&*()_+ áéíóú 中문 🚀 emoji"
+        max_tokens = 5
+
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_tokenizer = MagicMock()
+            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
+            mock_encoder.return_value = mock_tokenizer
+
+            result = clip_tokens(text, max_tokens)
+            assert isinstance(result, str)
+            assert len(result) < len(text)
+
+    def test_multiline_text_handling(self):
+        """Test handling of multiline text."""
+        text = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
+        max_tokens = 5
+
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_tokenizer = MagicMock()
+            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
+            mock_encoder.return_value = mock_tokenizer
+
+            result = clip_tokens(text, max_tokens)
+            assert isinstance(result, str)
+
+    def test_very_long_text(self):
+        """Test with very long text."""
+        text = "A" * 10000  # Very long text
+        max_tokens = 10
+
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_tokenizer = MagicMock()
+            mock_tokenizer.encode.return_value = [1] * 5000  # Many tokens
+            mock_encoder.return_value = mock_tokenizer
+
+            result = clip_tokens(text, max_tokens)
+            assert len(result) < len(text)
+            assert result.endswith("\n...(truncated)")
+
+    def test_encoder_exception_handling(self):
+        """Test handling of encoder exceptions."""
+        text = "Test text"
+        max_tokens = 10
+
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_encoder.side_effect = Exception("Encoder error")
+
+            # Should return original text when encoder fails
+            result = clip_tokens(text, max_tokens)
+            assert result == text
+
+    def test_zero_division_scenario(self):
+        """Test scenario that could lead to division by zero."""
+        text = "Test"
+        max_tokens = 10
+
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_tokenizer = MagicMock()
+            mock_tokenizer.encode.return_value = []  # Empty tokens (could cause division by zero)
+            mock_encoder.return_value = mock_tokenizer
+
+            result = clip_tokens(text, max_tokens)
+            # Should handle gracefully and return original text
+            assert result == text
+
+    def test_various_edge_cases(self):
+        """Test various edge cases."""
+        # Single character
+        assert clip_tokens("A", 1000) == "A"
+
+        # Only whitespace
+        text = " \n \t "
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_tokenizer = MagicMock()
+            mock_tokenizer.encode.return_value = [1] * 10
+            mock_encoder.return_value = mock_tokenizer
+
+            result = clip_tokens(text, 5)
+            assert isinstance(result, str)
+
+        # Text with only newlines
+        text = "\n\n\n\n"
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_tokenizer = MagicMock()
+            mock_tokenizer.encode.return_value = [1] * 10
+            mock_encoder.return_value = mock_tokenizer
+
+            result = clip_tokens(text, 2, delete_last_line=True)
+            assert isinstance(result, str)
+
+    def test_parameter_combinations(self):
+        """Test different parameter combinations."""
+        text = "Multi\nline\ntext\nfor\ntesting"
+        max_tokens = 5
+
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_tokenizer = MagicMock()
+            mock_tokenizer.encode.return_value = [1] * 20
+            mock_encoder.return_value = mock_tokenizer
+
+            # Test all combinations
+            combinations = [
+                (True, True),    # add_three_dots=True, delete_last_line=True
+                (True, False),   # add_three_dots=True, delete_last_line=False
+                (False, True),   # add_three_dots=False, delete_last_line=True
+                (False, False),  # add_three_dots=False, delete_last_line=False
+            ]
+
+            for add_dots, delete_line in combinations:
+                result = clip_tokens(text, max_tokens,
+                                     add_three_dots=add_dots,
+                                     delete_last_line=delete_line)
+                assert isinstance(result, str)
+                if add_dots and len(result) > 0:
+                    assert result.endswith("\n...(truncated)") or result == text
+
+    def test_num_output_chars_zero_scenario(self):
+        """Test scenario where num_output_chars becomes zero or negative."""
+        text = "Short"
+        max_tokens = 1
+
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_tokenizer = MagicMock()
+            mock_tokenizer.encode.return_value = [1] * 1000  # Many tokens for short text
+            mock_encoder.return_value = mock_tokenizer
+
+            result = clip_tokens(text, max_tokens)
+            # When num_output_chars is 0 or negative, should return empty string
+            assert result == ""
+
+    def test_logging_on_exception(self):
+        """Test that exceptions are properly logged."""
+        text = "Test text"
+        max_tokens = 10
+
+        # Patch the logger at the module level where it's imported
+        with patch('pr_agent.algo.utils.get_logger') as mock_logger:
+            mock_log_instance = MagicMock()
+            mock_logger.return_value = mock_log_instance
+
+            with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+                mock_encoder.side_effect = Exception("Test exception")
+
+                result = clip_tokens(text, max_tokens)
+
+                # Should log the warning
+                mock_log_instance.warning.assert_called_once()
+                # Should return original text
+                assert result == text
+
+    def test_factor_safety_calculation(self):
+        """Test that the 0.9 factor (10% reduction) works correctly."""
+        text = "Test text that should be reduced by 10 percent for safety"
+        max_tokens = 10
+
+        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
+            mock_tokenizer = MagicMock()
+            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
+            mock_encoder.return_value = mock_tokenizer
+
+            result = clip_tokens(text, max_tokens)
+
+            # The result should be shorter due to the 0.9 factor
+            # Characters per token = len(text) / 20
+            # Expected chars = int(0.9 * (len(text) / 20) * 10)
+            expected_chars = int(0.9 * (len(text) / 20) * 10)
+
+            # Result should be around expected_chars length (plus truncation text)
+            if result.endswith("\n...(truncated)"):
+                actual_content = result[:-len("\n...(truncated)")]
+                assert len(actual_content) <= expected_chars + 5  # Some tolerance
+
+    # Test the original basic functionality to ensure backward compatibility
+    def test_clip_original_functionality(self):
+        """Test original functionality from the existing test."""
         text = "line1\nline2\nline3\nline4\nline5\nline6"
         max_tokens = 25
         result = clip_tokens(text, max_tokens)
@@ -16,4 +305,4 @@ class TestClipTokens:
         max_tokens = 10
         result = clip_tokens(text, max_tokens)
         expected_results = 'line1\nline2\nline3\n\n...(truncated)'
-        assert result == expected_results
+        assert result == expected_results
\ No newline at end of file
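
For readers skimming the diff, a minimal sketch of the clipping approach the new docstring describes follows. This is illustrative only: the function name `clip_tokens_sketch` and the whitespace-based token count are stand-ins I introduce here, not the actual pr_agent implementation (which counts tokens with `TokenEncoder` and logs a warning when encoding fails).

```python
# Sketch of the clipping logic (assumed, not the verbatim implementation):
# estimate characters per token, keep 90% of the character budget as a
# safety margin, then optionally trim the trailing partial line and append
# the truncation indicator.
def clip_tokens_sketch(text, max_tokens, add_three_dots=True,
                       num_input_tokens=None, delete_last_line=False):
    # Empty/None input is returned unchanged.
    if not text:
        return text
    # Non-positive budgets yield an empty string.
    if max_tokens <= 0:
        return ""
    if num_input_tokens is None:
        # Stand-in for TokenEncoder-based counting, kept self-contained here.
        num_input_tokens = len(text.split())
    # Within budget (including the zero-token case): return as-is.
    if num_input_tokens <= max_tokens:
        return text
    # Approximate character budget, shrunk by the 0.9 safety factor.
    chars_per_token = len(text) / num_input_tokens
    num_output_chars = int(0.9 * chars_per_token * max_tokens)
    if num_output_chars <= 0:
        return ""
    clipped = text[:num_output_chars]
    if delete_last_line:
        # Drop everything after the last newline for a clean line break.
        clipped = clipped.rsplit("\n", 1)[0]
    if add_three_dots:
        clipped += "\n...(truncated)"
    return clipped
```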
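The new suite can be run on its own in the usual way, e.g. `python -m pytest tests/unittest/test_clip_tokens.py -v`; since the encoder is mocked throughout, it needs no model or network access.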