pr-agent/tests/unittest/test_clip_tokens.py
commit f5bb508736 (TaskerJang): fix: use identity check for None comparison in clip_tokens tests
- Replace `== None` with `is None` in test_empty_input_text method
- Follow Python best practice for None comparisons as recommended in code review
- Address feedback from PR #1816 review comment

Co-authored-by: mrT23
2025-05-25 07:55:18 +09:00


import pytest
from unittest.mock import patch, MagicMock
from pr_agent.algo.utils import clip_tokens
from pr_agent.algo.token_handler import TokenEncoder
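
# Behavioral summary (inferred from the tests below, not from the implementation
# itself): clip_tokens(text, max_tokens, ...) trims `text` so that it fits within
# `max_tokens` tokens, optionally appending a "\n...(truncated)" marker and/or
# dropping a partially clipped last line.
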
class TestClipTokens:
    """Comprehensive test suite for the clip_tokens function."""

    def test_empty_input_text(self):
        """Test that empty or None input is returned unchanged."""
        assert clip_tokens("", 10) == ""
        assert clip_tokens(None, 10) is None
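        # Note: `is None` checks identity and is preferred over `== None`, whose
        # behavior can be altered by a custom __eq__; this is the fix described
        # in the commit message above.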

    def test_text_under_token_limit(self):
        """Test that text under the token limit is returned unchanged."""
        text = "Short text"
        max_tokens = 100
        result = clip_tokens(text, max_tokens)
        assert result == text

    def test_text_exactly_at_token_limit(self):
        """Test text that is exactly at the token limit."""
        text = "This is exactly at the limit"
        # Mock the token encoder to report exactly the limit (10 tokens)
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 10  # Exactly 10 tokens
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, 10)
            assert result == text

    def test_text_over_token_limit_with_three_dots(self):
        """Test text over the token limit with the three-dots addition."""
        text = "This is a longer text that should be clipped when it exceeds the token limit"
        max_tokens = 5
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, max_tokens)
            assert result.endswith("\n...(truncated)")
            assert len(result) < len(text)

    def test_text_over_token_limit_without_three_dots(self):
        """Test text over the token limit without the three-dots addition."""
        text = "This is a longer text that should be clipped"
        max_tokens = 5
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, max_tokens, add_three_dots=False)
            assert not result.endswith("\n...(truncated)")
            assert len(result) < len(text)

    def test_negative_max_tokens(self):
        """Test that negative max_tokens returns empty string."""
        text = "Some text"
        result = clip_tokens(text, -1)
        assert result == ""
        result = clip_tokens(text, -100)
        assert result == ""

    def test_zero_max_tokens(self):
        """Test that zero max_tokens returns empty string."""
        text = "Some text"
        result = clip_tokens(text, 0)
        assert result == ""

    def test_delete_last_line_functionality(self):
        """Test the delete_last_line parameter functionality."""
        text = "Line 1\nLine 2\nLine 3\nLine 4"
        max_tokens = 5
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer
            # Without delete_last_line
            result_normal = clip_tokens(text, max_tokens, delete_last_line=False)
            # With delete_last_line
            result_deleted = clip_tokens(text, max_tokens, delete_last_line=True)
            # The result with delete_last_line should be shorter or equal
            assert len(result_deleted) <= len(result_normal)

    def test_pre_computed_num_input_tokens(self):
        """Test using the pre-computed num_input_tokens parameter."""
        text = "This is a test text"
        max_tokens = 10
        num_input_tokens = 15
        # Should not call the encoder when num_input_tokens is provided
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_encoder.return_value = None  # Should not be called
            result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens)
            assert result.endswith("\n...(truncated)")
            mock_encoder.assert_not_called()
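            # Rationale (inferred from this test, not documented here): a
            # pre-computed count lets callers skip a second, potentially
            # expensive, encoding pass.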

    def test_pre_computed_tokens_under_limit(self):
        """Test pre-computed tokens under the limit."""
        text = "Short text"
        max_tokens = 20
        num_input_tokens = 5
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_encoder.return_value = None  # Should not be called
            result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens)
            assert result == text
            mock_encoder.assert_not_called()

    def test_special_characters_and_unicode(self):
        """Test text with special characters and Unicode content."""
        text = "Special chars: @#$%^&*()_+ áéíóú 中文 🚀 emoji"
        max_tokens = 5
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, max_tokens)
            assert isinstance(result, str)
            assert len(result) < len(text)

    def test_multiline_text_handling(self):
        """Test handling of multiline text."""
        text = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
        max_tokens = 5
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, max_tokens)
            assert isinstance(result, str)

    def test_very_long_text(self):
        """Test with very long text."""
        text = "A" * 10000  # Very long text
        max_tokens = 10
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 5000  # Many tokens
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, max_tokens)
            assert len(result) < len(text)
            assert result.endswith("\n...(truncated)")

    def test_encoder_exception_handling(self):
        """Test handling of encoder exceptions."""
        text = "Test text"
        max_tokens = 10
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_encoder.side_effect = Exception("Encoder error")
            # Should return original text when encoder fails
            result = clip_tokens(text, max_tokens)
            assert result == text

    def test_zero_division_scenario(self):
        """Test scenario that could lead to division by zero."""
        text = "Test"
        max_tokens = 10
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = []  # Empty token list (could cause division by zero)
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, max_tokens)
            # Should handle gracefully and return original text
            assert result == text
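            # Note (inferred): an empty token list means zero input tokens, which
            # is under any positive limit, so the text comes back unchanged and no
            # chars-per-token division is attempted.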

    def test_various_edge_cases(self):
        """Test various edge cases."""
        # Single character
        assert clip_tokens("A", 1000) == "A"

        # Only whitespace
        text = " \n \t "
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 10
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, 5)
            assert isinstance(result, str)

        # Text with only newlines
        text = "\n\n\n\n"
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 10
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, 2, delete_last_line=True)
            assert isinstance(result, str)

    def test_parameter_combinations(self):
        """Test different parameter combinations."""
        text = "Multi\nline\ntext\nfor\ntesting"
        max_tokens = 5
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20
            mock_encoder.return_value = mock_tokenizer
            # Test all combinations
            combinations = [
                (True, True),    # add_three_dots=True, delete_last_line=True
                (True, False),   # add_three_dots=True, delete_last_line=False
                (False, True),   # add_three_dots=False, delete_last_line=True
                (False, False),  # add_three_dots=False, delete_last_line=False
            ]
            for add_dots, delete_line in combinations:
                result = clip_tokens(text, max_tokens,
                                     add_three_dots=add_dots,
                                     delete_last_line=delete_line)
                assert isinstance(result, str)
                if add_dots and len(result) > 0:
                    assert result.endswith("\n...(truncated)") or result == text

    def test_num_output_chars_zero_scenario(self):
        """Test scenario where num_output_chars becomes zero or negative."""
        text = "Short"
        max_tokens = 1
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 1000  # Many tokens for short text
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, max_tokens)
            # When num_output_chars is 0 or negative, should return empty string
            assert result == ""

    def test_logging_on_exception(self):
        """Test that exceptions are properly logged."""
        text = "Test text"
        max_tokens = 10
        # Patch the logger at the module level where it's imported
        with patch('pr_agent.algo.utils.get_logger') as mock_logger:
            mock_log_instance = MagicMock()
            mock_logger.return_value = mock_log_instance
            with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
                mock_encoder.side_effect = Exception("Test exception")
                result = clip_tokens(text, max_tokens)
                # Should log the warning
                mock_log_instance.warning.assert_called_once()
                # Should return original text
                assert result == text

    def test_factor_safety_calculation(self):
        """Test that the 0.9 factor (10% reduction) works correctly."""
        text = "Test text that should be reduced by 10 percent for safety"
        max_tokens = 10
        with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
            mock_tokenizer = MagicMock()
            mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
            mock_encoder.return_value = mock_tokenizer
            result = clip_tokens(text, max_tokens)
            # The result should be shorter due to the 0.9 factor
            # Characters per token = len(text) / 20
            # Expected chars = int(0.9 * (len(text) / 20) * 10)
            expected_chars = int(0.9 * (len(text) / 20) * 10)
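            # Worked example (assumes the formula above matches the implementation):
            # this text is 57 characters, so int(0.9 * (57 / 20) * 10) == 25.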
            # Result should be around expected_chars length (plus truncation text)
            if result.endswith("\n...(truncated)"):
                actual_content = result[:-len("\n...(truncated)")]
                assert len(actual_content) <= expected_chars + 5  # Some tolerance

    # Test the original basic functionality to ensure backward compatibility
    def test_clip_original_functionality(self):
        """Test original functionality from the existing test."""
        text = "line1\nline2\nline3\nline4\nline5\nline6"
        max_tokens = 25
        result = clip_tokens(text, max_tokens)
        assert result == text

        max_tokens = 10
        result = clip_tokens(text, max_tokens)
        expected_results = 'line1\nline2\nline3\n\n...(truncated)'
        assert result == expected_results
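

# ---------------------------------------------------------------------------
# Illustrative sketch only (not the real pr_agent implementation, which lives
# in pr_agent/algo/utils.py): the clipping strategy the tests above imply.
# The word-count tokenizer stand-in below is a deliberate simplification.
# ---------------------------------------------------------------------------
def _clip_tokens_sketch(text, max_tokens, num_input_tokens=None,
                        add_three_dots=True, delete_last_line=False):
    if not text:
        return text  # "" stays "", None stays None
    if max_tokens <= 0:
        return ""
    if num_input_tokens is None:
        # Stand-in for TokenEncoder.get_token_encoder().encode(text)
        num_input_tokens = len(text.split())
    if num_input_tokens <= max_tokens:
        return text  # under the limit; also covers the zero-token edge case
    # Character budget with a 10% safety margin (the 0.9 factor tested above)
    chars_per_token = len(text) / num_input_tokens
    num_output_chars = int(0.9 * chars_per_token * max_tokens)
    if num_output_chars <= 0:
        return ""
    clipped = text[:num_output_chars]
    if delete_last_line:
        # Drop the last (possibly mid-line) fragment
        clipped = clipped.rsplit("\n", 1)[0]
    if add_three_dots:
        clipped += "\n...(truncated)"
    return clipped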