import pytest
from unittest.mock import patch, MagicMock

from pr_agent.algo.utils import clip_tokens
from pr_agent.algo.token_handler import TokenEncoder
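
# The tests below treat clip_tokens as roughly matching this signature. This
# is a sketch inferred from the calls in this file, not the function's
# canonical definition:
#
#   clip_tokens(text, max_tokens, add_three_dots=True,
#               num_input_tokens=None, delete_last_line=False) -> str
#
# Over-limit text is clipped in proportion to the token overrun (with a 0.9
# safety factor) and, when add_three_dots is set, the suffix
# "\n...(truncated)" is appended.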


class TestClipTokens:
    """Comprehensive test suite for the clip_tokens function."""

    def test_empty_input_text(self):
        """Test that empty input returns empty string."""
        assert clip_tokens("", 10) == ""
        assert clip_tokens(None, 10) is None

    def test_text_under_token_limit(self):
        """Test that text under the token limit is returned unchanged."""
        text = "Short text"
        max_tokens = 100
        result = clip_tokens(text, max_tokens)
        assert result == text

    def test_text_exactly_at_token_limit(self):
        """Test text that is exactly at the token limit."""
        text = "This is exactly at the limit"
        # Mock the token encoder to return exactly the limit
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 10  # Exactly 10 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, 10)
            assert result == text

    def test_text_over_token_limit_with_three_dots(self):
        """Test text over the token limit with the three-dots suffix added."""
        text = "This is a longer text that should be clipped when it exceeds the token limit"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            assert result.endswith("\n...(truncated)")
            assert len(result) < len(text)

    def test_text_over_token_limit_without_three_dots(self):
        """Test text over the token limit without the three-dots suffix."""
        text = "This is a longer text that should be clipped"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens, add_three_dots=False)
            assert not result.endswith("\n...(truncated)")
            assert len(result) < len(text)

    def test_negative_max_tokens(self):
        """Test that negative max_tokens returns empty string."""
        text = "Some text"
        result = clip_tokens(text, -1)
        assert result == ""

        result = clip_tokens(text, -100)
        assert result == ""

    def test_zero_max_tokens(self):
        """Test that zero max_tokens returns empty string."""
        text = "Some text"
        result = clip_tokens(text, 0)
        assert result == ""

    def test_delete_last_line_functionality(self):
        """Test the delete_last_line parameter functionality."""
        text = "Line 1\nLine 2\nLine 3\nLine 4"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            # Without delete_last_line
            result_normal = clip_tokens(text, max_tokens, delete_last_line=False)

            # With delete_last_line
            result_deleted = clip_tokens(text, max_tokens, delete_last_line=True)

            # The result with delete_last_line should be shorter or equal
            assert len(result_deleted) <= len(result_normal)

    def test_pre_computed_num_input_tokens(self):
        """Test using pre-computed num_input_tokens parameter."""
        text = "This is a test text"
        max_tokens = 10
        num_input_tokens = 15

        # Should not call the encoder when num_input_tokens is provided
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_encoder.return_value = None  # Should not be called

            result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens)
            assert result.endswith("\n...(truncated)")
            mock_encoder.assert_not_called()

    def test_pre_computed_tokens_under_limit(self):
        """Test pre-computed tokens under the limit."""
        text = "Short text"
        max_tokens = 20
        num_input_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_encoder.return_value = None  # Should not be called

            result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens)
            assert result == text
            mock_encoder.assert_not_called()

    def test_special_characters_and_unicode(self):
        """Test text with special characters and Unicode content."""
        text = "Special chars: @#$%^&*()_+ áéíóú 中文 🚀 emoji"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            assert isinstance(result, str)
            assert len(result) < len(text)

    def test_multiline_text_handling(self):
        """Test handling of multiline text."""
        text = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            assert isinstance(result, str)

    def test_very_long_text(self):
        """Test with very long text."""
        text = "A" * 10000  # Very long text
        max_tokens = 10

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 5000  # Many tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            assert len(result) < len(text)
            assert result.endswith("\n...(truncated)")

    def test_encoder_exception_handling(self):
        """Test handling of encoder exceptions."""
        text = "Test text"
        max_tokens = 10

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_encoder.side_effect = Exception("Encoder error")

            # Should return original text when encoder fails
            result = clip_tokens(text, max_tokens)
            assert result == text

    def test_zero_division_scenario(self):
        """Test scenario that could lead to division by zero."""
        text = "Test"
        max_tokens = 10

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = []  # Empty tokens (could cause division by zero)
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            # Should handle gracefully and return original text
            assert result == text

    def test_various_edge_cases(self):
        """Test various edge cases."""
        # Single character
        assert clip_tokens("A", 1000) == "A"

        # Only whitespace
        text = " \n \t "
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 10
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, 5)
            assert isinstance(result, str)

        # Text with only newlines
        text = "\n\n\n\n"
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 10
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, 2, delete_last_line=True)
            assert isinstance(result, str)

    def test_parameter_combinations(self):
        """Test different parameter combinations."""
        text = "Multi\nline\ntext\nfor\ntesting"
        max_tokens = 5

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20
            mock_encoder.return_value = mock_tokenizer

            # Test all combinations
            combinations = [
                (True, True),    # add_three_dots=True, delete_last_line=True
                (True, False),   # add_three_dots=True, delete_last_line=False
                (False, True),   # add_three_dots=False, delete_last_line=True
                (False, False),  # add_three_dots=False, delete_last_line=False
            ]

            for add_dots, delete_line in combinations:
                result = clip_tokens(text, max_tokens,
                                     add_three_dots=add_dots,
                                     delete_last_line=delete_line)
                assert isinstance(result, str)
                if add_dots and len(result) > 0:
                    assert result.endswith("\n...(truncated)") or result == text

    def test_num_output_chars_zero_scenario(self):
        """Test scenario where num_output_chars becomes zero or negative."""
        text = "Short"
        max_tokens = 1

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 1000  # Many tokens for short text
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)
            # When num_output_chars is 0 or negative, should return empty string
            assert result == ""

    def test_logging_on_exception(self):
        """Test that exceptions are properly logged."""
        text = "Test text"
        max_tokens = 10

        # Patch the logger at the module level where it's imported
        with patch('pr_agent.algo.utils.get_logger') as mock_logger:
            mock_log_instance = MagicMock()
            mock_logger.return_value = mock_log_instance

            with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
                mock_encoder.side_effect = Exception("Test exception")

                result = clip_tokens(text, max_tokens)

            # Should log the warning
            mock_log_instance.warning.assert_called_once()
            # Should return original text
            assert result == text

    def test_factor_safety_calculation(self):
        """Test that the 0.9 factor (10% reduction) works correctly."""
        text = "Test text that should be reduced by 10 percent for safety"
        max_tokens = 10

        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer

            result = clip_tokens(text, max_tokens)

            # The result should be shorter due to the 0.9 factor:
            # characters per token = len(text) / 20, so
            # expected chars = int(0.9 * (len(text) / 20) * max_tokens).
            # For this 57-character text: int(0.9 * 2.85 * 10) == 25.
            expected_chars = int(0.9 * (len(text) / 20) * 10)

            # Result should be around expected_chars long (plus the truncation suffix)
            if result.endswith("\n...(truncated)"):
                actual_content = result[:-len("\n...(truncated)")]
                assert len(actual_content) <= expected_chars + 5  # Some tolerance

    # Test the original basic functionality to ensure backward compatibility
    def test_clip_original_functionality(self):
        """Test original functionality from the existing test."""
        text = "line1\nline2\nline3\nline4\nline5\nline6"
        max_tokens = 25
        result = clip_tokens(text, max_tokens)
        assert result == text

        max_tokens = 10
        result = clip_tokens(text, max_tokens)
        expected_results = 'line1\nline2\nline3\n\n...(truncated)'
        assert result == expected_results
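

# Optional convenience entry point: lets the suite run directly via
# `python <this file>` by delegating to pytest (imported above), equivalent
# to invoking `pytest -v` on this file.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v"]))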