Mirror of https://github.com/qodo-ai/pr-agent.git, synced 2025-07-04 21:00:40 +08:00
Merge pull request #1816 from TaskerJang/feature/clip-tokens-tests-and-docs
Add Unit Tests and Improve Documentation for utils.py clip_tokens Function
@@ -945,12 +945,66 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False):
    """
    Clip the number of tokens in a string to a maximum number of tokens.

    This function limits text to a specified token count by calculating the approximate
    character-to-token ratio and truncating the text accordingly. A safety factor of 0.9
    (10% reduction) is applied to ensure the result stays within the token limit.

    Args:
        text (str): The string to clip. If empty or None, returns the input unchanged.
        max_tokens (int): The maximum number of tokens allowed in the string.
            If negative, returns an empty string.
        add_three_dots (bool, optional): Whether to add "\\n...(truncated)" at the end
            of the clipped text to indicate truncation.
            Defaults to True.
        num_input_tokens (int, optional): Pre-computed number of tokens in the input text.
            If provided, skips token encoding step for efficiency.
            If None, tokens will be counted using TokenEncoder.
            Defaults to None.
        delete_last_line (bool, optional): Whether to remove the last line from the
            clipped content before adding truncation indicator.
            Useful for ensuring clean breaks at line boundaries.
            Defaults to False.

    Returns:
        str: The clipped string. Returns original text if:
            - Text is empty/None
            - Token count is within limit
            - An error occurs during processing

            Returns empty string if max_tokens <= 0.

    Examples:
        Basic usage:
        >>> text = "This is a sample text that might be too long"
        >>> result = clip_tokens(text, max_tokens=10)
        >>> print(result)
        This is a sample...
        (truncated)

        Without truncation indicator:
        >>> result = clip_tokens(text, max_tokens=10, add_three_dots=False)
        >>> print(result)
        This is a sample

        With pre-computed token count:
        >>> result = clip_tokens(text, max_tokens=5, num_input_tokens=15)
        >>> print(result)
        This...
        (truncated)

        With line deletion:
        >>> multiline_text = "Line 1\\nLine 2\\nLine 3"
        >>> result = clip_tokens(multiline_text, max_tokens=3, delete_last_line=True)
        >>> print(result)
        Line 1
        Line 2
        ...
        (truncated)

    Notes:
        The function uses a safety factor of 0.9 (10% reduction) to ensure the
        result stays within the token limit, as character-to-token ratios can vary.
        If token encoding fails, the original text is returned with a warning logged.
    """
    if not text:
        return text
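To make the behavior above concrete, here is a minimal sketch of the clipping logic the docstring describes, assuming the TokenEncoder.get_token_encoder().encode API that the tests below patch. It is not the actual implementation in pr_agent/algo/utils.py, which also handles delete_last_line and logs a warning on failure.

from pr_agent.algo.token_handler import TokenEncoder

def clip_tokens_sketch(text, max_tokens, add_three_dots=True, num_input_tokens=None):
    # Hypothetical sketch, not the real clip_tokens.
    if not text:
        return text  # empty or None passes through unchanged
    if max_tokens < 0:
        return ""  # a negative limit yields an empty string
    if num_input_tokens is None:
        # Count tokens only when a pre-computed count was not supplied
        num_input_tokens = len(TokenEncoder.get_token_encoder().encode(text))
    if num_input_tokens <= max_tokens:
        return text  # already within the limit
    # Approximate character-to-token ratio, reduced by the 0.9 safety factor
    chars_per_token = len(text) / num_input_tokens
    num_output_chars = int(0.9 * chars_per_token * max_tokens)
    if num_output_chars <= 0:
        return ""  # covers max_tokens == 0 and similarly degenerate limits
    clipped = text[:num_output_chars]
    if add_three_dots:
        clipped += "\n...(truncated)"
    return clipped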
@@ -1,13 +1,302 @@
# Generated by CodiumAI

import pytest
from unittest.mock import patch, MagicMock

from pr_agent.algo.utils import clip_tokens
from pr_agent.algo.token_handler import TokenEncoder


class TestClipTokens:
    """Comprehensive test suite for the clip_tokens function."""

    def test_empty_input_text(self):
        """Test that empty input returns empty string."""
        assert clip_tokens("", 10) == ""
        assert clip_tokens(None, 10) is None

    def test_text_under_token_limit(self):
        """Test that text under the token limit is returned unchanged."""
        text = "Short text"
        max_tokens = 100
        result = clip_tokens(text, max_tokens)
        assert result == text

    def test_text_exactly_at_token_limit(self):
        """Test text that is exactly at the token limit."""
        text = "This is exactly at the limit"
        # Mock the token encoder to return exactly the limit
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 10  # Exactly 10 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, 10)
            assert result == text

    def test_text_over_token_limit_with_three_dots(self):
        """Test text over the token limit with the three-dots indicator."""
        text = "This is a longer text that should be clipped when it exceeds the token limit"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            assert result.endswith("\n...(truncated)")
            assert len(result) < len(text)

    def test_text_over_token_limit_without_three_dots(self):
        """Test text over the token limit without the three-dots indicator."""
        text = "This is a longer text that should be clipped"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens, add_three_dots=False)
            assert not result.endswith("\n...(truncated)")
            assert len(result) < len(text)

    def test_negative_max_tokens(self):
        """Test that negative max_tokens returns an empty string."""
        text = "Some text"
        result = clip_tokens(text, -1)
        assert result == ""

        result = clip_tokens(text, -100)
        assert result == ""

    def test_zero_max_tokens(self):
        """Test that zero max_tokens returns an empty string."""
        text = "Some text"
        result = clip_tokens(text, 0)
        assert result == ""

    def test_delete_last_line_functionality(self):
        """Test the delete_last_line parameter."""
        text = "Line 1\nLine 2\nLine 3\nLine 4"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            # Without delete_last_line
            result_normal = clip_tokens(text, max_tokens, delete_last_line=False)

            # With delete_last_line
            result_deleted = clip_tokens(text, max_tokens, delete_last_line=True)

            # The result with delete_last_line should be shorter or equal
            assert len(result_deleted) <= len(result_normal)

    def test_pre_computed_num_input_tokens(self):
        """Test using the pre-computed num_input_tokens parameter."""
        text = "This is a test text"
        max_tokens = 10
        num_input_tokens = 15

        # Should not call the encoder when num_input_tokens is provided
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_encoder.return_value = None  # Should not be called

            result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens)
            assert result.endswith("\n...(truncated)")
            mock_encoder.assert_not_called()

    def test_pre_computed_tokens_under_limit(self):
        """Test a pre-computed token count under the limit."""
        text = "Short text"
        max_tokens = 20
        num_input_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_encoder.return_value = None  # Should not be called

            result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens)
            assert result == text
            mock_encoder.assert_not_called()

    def test_special_characters_and_unicode(self):
        """Test text with special characters and Unicode content."""
        text = "Special chars: @#$%^&*()_+ áéíóú 中문 🚀 emoji"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            assert isinstance(result, str)
            assert len(result) < len(text)

    def test_multiline_text_handling(self):
        """Test handling of multiline text."""
        text = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            assert isinstance(result, str)

    def test_very_long_text(self):
        """Test with very long text."""
        text = "A" * 10000  # Very long text
        max_tokens = 10

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 5000  # Many tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            assert len(result) < len(text)
            assert result.endswith("\n...(truncated)")

    def test_encoder_exception_handling(self):
        """Test handling of encoder exceptions."""
        text = "Test text"
        max_tokens = 10

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_encoder.side_effect = Exception("Encoder error")

            # Should return the original text when the encoder fails
            result = clip_tokens(text, max_tokens)
            assert result == text

    def test_zero_division_scenario(self):
        """Test a scenario that could lead to division by zero."""
        text = "Test"
        max_tokens = 10

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = []  # Empty tokens (could cause division by zero)
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            # Should handle gracefully and return the original text
            assert result == text

    def test_various_edge_cases(self):
        """Test various edge cases."""
        # Single character
        assert clip_tokens("A", 1000) == "A"

        # Only whitespace
        text = " \n \t "
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 10
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, 5)
            assert isinstance(result, str)

        # Text with only newlines
        text = "\n\n\n\n"
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 10
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, 2, delete_last_line=True)
            assert isinstance(result, str)

    def test_parameter_combinations(self):
        """Test different parameter combinations."""
        text = "Multi\nline\ntext\nfor\ntesting"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20
            mock_encoder.return_value = mock_tokenizer

            # Test all combinations
            combinations = [
                (True, True),    # add_three_dots=True, delete_last_line=True
                (True, False),   # add_three_dots=True, delete_last_line=False
                (False, True),   # add_three_dots=False, delete_last_line=True
                (False, False),  # add_three_dots=False, delete_last_line=False
            ]

            for add_dots, delete_line in combinations:
                result = clip_tokens(text, max_tokens,
                                     add_three_dots=add_dots,
                                     delete_last_line=delete_line)
                assert isinstance(result, str)
                if add_dots and len(result) > 0:
                    assert result.endswith("\n...(truncated)") or result == text

    def test_num_output_chars_zero_scenario(self):
        """Test a scenario where num_output_chars becomes zero or negative."""
        text = "Short"
        max_tokens = 1

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 1000  # Many tokens for a short text
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            # When num_output_chars is 0 or negative, should return an empty string
            assert result == ""

    def test_logging_on_exception(self):
        """Test that exceptions are properly logged."""
        text = "Test text"
        max_tokens = 10

        # Patch the logger at the module level where it's imported
        with patch('pr_agent.algo.utils.get_logger') as mock_logger:
            mock_log_instance = MagicMock()
            mock_logger.return_value = mock_log_instance

            with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
                mock_encoder.side_effect = Exception("Test exception")

                result = clip_tokens(text, max_tokens)

                # Should log the warning
                mock_log_instance.warning.assert_called_once()
                # Should return the original text
                assert result == text

    def test_factor_safety_calculation(self):
        """Test that the 0.9 factor (10% reduction) works correctly."""
        text = "Test text that should be reduced by 10 percent for safety"
        max_tokens = 10

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)

            # The result should be shorter due to the 0.9 factor:
            #   characters per token = len(text) / 20
            #   expected chars = int(0.9 * (len(text) / 20) * 10)
            expected_chars = int(0.9 * (len(text) / 20) * 10)

            # The result should be around expected_chars long (plus the truncation text)
            if result.endswith("\n...(truncated)"):
                actual_content = result[:-len("\n...(truncated)")]
                assert len(actual_content) <= expected_chars + 5  # Some tolerance

    # Test the original basic functionality to ensure backward compatibility
    def test_clip_original_functionality(self):
        """Test the original functionality from the existing test."""
        text = "line1\nline2\nline3\nline4\nline5\nline6"
        max_tokens = 25
        result = clip_tokens(text, max_tokens)
@@ -16,4 +305,4 @@ class TestClipTokens:
        max_tokens = 10
        result = clip_tokens(text, max_tokens)
        expected_results = 'line1\nline2\nline3\n\n...(truncated)'
        assert result == expected_results
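The new suite can be run directly with pytest. Assuming the file lives at tests/unittest/test_clip_tokens.py (the path is a guess; this diff does not show the file name), an invocation such as:

    python -m pytest tests/unittest/test_clip_tokens.py -v

runs all of the cases above, including the backward-compatibility test.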