Merge branch 'main' of github.com:qodo-ai/pr-agent into feature/gitea-implement

This commit is contained in:
Pinyoo Thotaboot
2025-05-26 10:59:19 +07:00
15 changed files with 693 additions and 72 deletions

View File

@ -56,6 +56,21 @@ Everything below this marker is treated as previously auto-generated content and
![Describe comment](https://codium.ai/images/pr_agent/pr_description_user_description.png){width=512}
### Mermaid Diagram Support
When the `enable_pr_diagram` option is enabled in your configuration, the `/describe` tool will include a horizontal `Mermaid` flowchart in the PR description.
This diagram summarizes the main PR changes and the interactions between components/functions, based on the diff content.
### How to enable
In your configuration:
```toml
[pr_description]
enable_pr_diagram = true
```
## Configuration options
!!! example "Possible configurations"
@ -109,6 +124,10 @@ Everything below this marker is treated as previously auto-generated content and
<td><b>enable_help_text</b></td>
<td>If set to true, the tool will display a help text in the comment. Default is false.</td>
</tr>
<tr>
<td><b>enable_pr_diagram</b></td>
<td>If set to true, the tool will generate a horizontal <code>Mermaid</code> flowchart (in code block format) summarizing the main PR changes and component interactions based on the code changes. Default is false.</td>
</tr>
</table>
## Inline file summary 💎

View File

@ -435,7 +435,7 @@ To enable auto-approval based on specific criteria, first, you need to enable th
enable_auto_approval = true
```
There are two criteria that can be set for auto-approval:
There are several criteria that can be set for auto-approval:
- **Review effort score**
@ -457,7 +457,19 @@ enable_auto_approval = true
auto_approve_for_no_suggestions = true
```
When no [code suggestion](https://www.qodo.ai/images/pr_agent/code_suggestions_as_comment_closed.png) were found for the PR, the PR will be auto-approved.
When no [code suggestions](https://www.qodo.ai/images/pr_agent/code_suggestions_as_comment_closed.png) were found for the PR, the PR will be auto-approved.
___
- **Ticket Compliance**
```toml
[config]
enable_auto_approval = true
ensure_ticket_compliance = true # Default is false
```
If `ensure_ticket_compliance` is set to `true`, auto-approval will be disabled if a ticket is linked to the PR and the ticket is not compliant (e.g., the `review` tool did not mark the PR as fully compliant with the ticket). This ensures that PRs are only auto-approved if their associated tickets are properly resolved.
### How many code suggestions are generated?

View File

@ -53,9 +53,11 @@ MAX_TOKENS = {
'vertex_ai/claude-3-5-haiku@20241022': 100000,
'vertex_ai/claude-3-sonnet@20240229': 100000,
'vertex_ai/claude-3-opus@20240229': 100000,
'vertex_ai/claude-opus-4@20250514': 200000,
'vertex_ai/claude-3-5-sonnet@20240620': 100000,
'vertex_ai/claude-3-5-sonnet-v2@20241022': 100000,
'vertex_ai/claude-3-7-sonnet@20250219': 200000,
'vertex_ai/claude-sonnet-4@20250514': 200000,
'vertex_ai/gemini-1.5-pro': 1048576,
'vertex_ai/gemini-2.5-pro-preview-03-25': 1048576,
'vertex_ai/gemini-2.5-pro-preview-05-06': 1048576,
@ -74,22 +76,28 @@ MAX_TOKENS = {
'anthropic.claude-v1': 100000,
'anthropic.claude-v2': 100000,
'anthropic/claude-3-opus-20240229': 100000,
'anthropic/claude-opus-4-20250514': 200000,
'anthropic/claude-3-5-sonnet-20240620': 100000,
'anthropic/claude-3-5-sonnet-20241022': 100000,
'anthropic/claude-3-7-sonnet-20250219': 200000,
'anthropic/claude-sonnet-4-20250514': 200000,
'claude-3-7-sonnet-20250219': 200000,
'anthropic/claude-3-5-haiku-20241022': 100000,
'bedrock/anthropic.claude-instant-v1': 100000,
'bedrock/anthropic.claude-v2': 100000,
'bedrock/anthropic.claude-v2:1': 100000,
'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': 100000,
'bedrock/anthropic.claude-opus-4-20250514-v1:0': 200000,
'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
'bedrock/anthropic.claude-3-5-haiku-20241022-v1:0': 100000,
'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0': 100000,
'bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0': 100000,
'bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0': 200000,
'bedrock/anthropic.claude-sonnet-4-20250514-v1:0': 200000,
"bedrock/us.anthropic.claude-opus-4-20250514-v1:0": 200000,
"bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0": 100000,
"bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0": 200000,
"bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": 200000,
'claude-3-5-sonnet': 100000,
'groq/meta-llama/llama-4-scout-17b-16e-instruct': 131072,
'groq/meta-llama/llama-4-maverick-17b-128e-instruct': 131072,
@ -102,9 +110,13 @@ MAX_TOKENS = {
'xai/grok-2': 131072,
'xai/grok-2-1212': 131072,
'xai/grok-2-latest': 131072,
'xai/grok-3': 131072,
'xai/grok-3-beta': 131072,
'xai/grok-3-fast': 131072,
'xai/grok-3-fast-beta': 131072,
'xai/grok-3-mini': 131072,
'xai/grok-3-mini-beta': 131072,
'xai/grok-3-mini-fast': 131072,
'xai/grok-3-mini-fast-beta': 131072,
'ollama/llama3': 4096,
'watsonx/meta-llama/llama-3-8b-instruct': 4096,

View File

@ -6,8 +6,8 @@ except: # we don't enforce langchain as a dependency, so if it's not installed,
import functools
from openai import APIError, RateLimitError, Timeout
from retry import retry
import openai
from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings
@ -36,8 +36,10 @@ class LangChainOpenAIHandler(BaseAiHandler):
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
@retry(
retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError),
stop=stop_after_attempt(OPENAI_RETRIES),
)
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
try:
messages = [SystemMessage(content=system), HumanMessage(content=user)]
@ -47,9 +49,15 @@ class LangChainOpenAIHandler(BaseAiHandler):
finish_reason = "completed"
return resp.content, finish_reason
except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise e
except openai.RateLimitError as e:
get_logger().error(f"Rate limit error during LLM inference: {e}")
raise
except openai.APIError as e:
get_logger().warning(f"Error during LLM inference: {e}")
raise
except Exception as e:
get_logger().warning(f"Unknown error during LLM inference: {e}")
raise openai.APIError from e
def _create_chat(self, deployment_id=None):
try:

View File

@ -3,7 +3,7 @@ import litellm
import openai
import requests
from litellm import acompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt
from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
@ -274,8 +274,8 @@ class LiteLLMAIHandler(BaseAiHandler):
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(
retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)), # No retry on RateLimitError
stop=stop_after_attempt(OPENAI_RETRIES)
retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError),
stop=stop_after_attempt(OPENAI_RETRIES),
)
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
try:
@ -371,13 +371,13 @@ class LiteLLMAIHandler(BaseAiHandler):
get_logger().info(f"\nUser prompt:\n{user}")
response = await acompletion(**kwargs)
except (openai.RateLimitError) as e:
except openai.RateLimitError as e:
get_logger().error(f"Rate limit error during LLM inference: {e}")
raise
except (openai.APIError, openai.APITimeoutError) as e:
except openai.APIError as e:
get_logger().warning(f"Error during LLM inference: {e}")
raise
except (Exception) as e:
except Exception as e:
get_logger().warning(f"Unknown error during LLM inference: {e}")
raise openai.APIError from e
if response is None or len(response["choices"]) == 0:

View File

@ -1,8 +1,8 @@
from os import environ
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
import openai
from openai import APIError, AsyncOpenAI, RateLimitError, Timeout
from retry import retry
from openai import AsyncOpenAI
from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings
@ -38,8 +38,10 @@ class OpenAIHandler(BaseAiHandler):
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
@retry(
retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError),
stop=stop_after_attempt(OPENAI_RETRIES),
)
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
try:
get_logger().info("System: ", system)
@ -57,12 +59,12 @@ class OpenAIHandler(BaseAiHandler):
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
model=model, usage=usage)
return resp, finish_reason
except (APIError, Timeout) as e:
get_logger().error("Error during OpenAI inference: ", e)
except openai.RateLimitError as e:
get_logger().error(f"Rate limit error during LLM inference: {e}")
raise
except (RateLimitError) as e:
get_logger().error("Rate limit error during OpenAI inference: ", e)
raise
except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
except openai.APIError as e:
get_logger().warning(f"Error during LLM inference: {e}")
raise
except Exception as e:
get_logger().warning(f"Unknown error during LLM inference: {e}")
raise openai.APIError from e

View File

@ -1,4 +1,6 @@
from threading import Lock
from math import ceil
import re
from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model, get_encoding
@ -7,6 +9,16 @@ from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger
class ModelTypeValidator:
    """Classify a model name by provider family using name-based heuristics.

    Both checks are plain substring / pattern tests on the given name; they do
    not normalize case, so callers should pass the name in the form they want
    matched (the caller in this file lower-cases it first).
    """

    @staticmethod
    def is_openai_model(model_name: str) -> bool:
        """Return True if the name looks like an OpenAI model.

        A name matches when it contains ``'gpt'`` anywhere, or when the whole
        name is an o-series identifier such as ``o1``, ``o3-mini`` or
        ``o1-preview``.

        Args:
            model_name: The model name to test.

        Returns:
            bool: True for a match, False otherwise.
        """
        # re.match returns a Match object or None; coerce it so the declared
        # bool return type actually holds (previously None/Match leaked out).
        return 'gpt' in model_name or bool(re.match(r"^o[1-9](-mini|-preview)?$", model_name))

    @staticmethod
    def is_anthropic_model(model_name: str) -> bool:
        """Return True if the name contains ``'claude'``.

        Args:
            model_name: The model name to test.

        Returns:
            bool: True when ``'claude'`` is a substring of the name.
        """
        return 'claude' in model_name
class TokenEncoder:
_encoder_instance = None
_model = None
@ -40,6 +52,10 @@ class TokenHandler:
method.
"""
# Constants
CLAUDE_MODEL = "claude-3-7-sonnet-20250219"
CLAUDE_MAX_CONTENT_SIZE = 9_000_000 # Maximum allowed content size (9MB) for Claude API
def __init__(self, pr=None, vars: dict = {}, system="", user=""):
"""
Initializes the TokenHandler object.
@ -51,6 +67,7 @@ class TokenHandler:
- user: The user string.
"""
self.encoder = TokenEncoder.get_token_encoder()
if pr is not None:
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
@ -79,22 +96,22 @@ class TokenHandler:
get_logger().error(f"Error in _get_system_user_tokens: {e}")
return 0
def calc_claude_tokens(self, patch):
def _calc_claude_tokens(self, patch: str) -> int:
try:
import anthropic
from pr_agent.algo import MAX_TOKENS
client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key'))
MaxTokens = MAX_TOKENS[get_settings().config.model]
# Check if the content size is too large (9MB limit)
if len(patch.encode('utf-8')) > 9_000_000:
client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key'))
max_tokens = MAX_TOKENS[get_settings().config.model]
if len(patch.encode('utf-8')) > self.CLAUDE_MAX_CONTENT_SIZE:
get_logger().warning(
"Content too large for Anthropic token counting API, falling back to local tokenizer"
)
return MaxTokens
return max_tokens
response = client.messages.count_tokens(
model="claude-3-7-sonnet-20250219",
model=self.CLAUDE_MODEL,
system="system",
messages=[{
"role": "user",
@ -104,42 +121,51 @@ class TokenHandler:
return response.input_tokens
except Exception as e:
get_logger().error( f"Error in Anthropic token counting: {e}")
return MaxTokens
get_logger().error(f"Error in Anthropic token counting: {e}")
return max_tokens
def estimate_token_count_for_non_anth_claude_models(self, model, default_encoder_estimate):
from math import ceil
import re
def _apply_estimation_factor(self, model_name: str, default_estimate: int) -> int:
factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
get_logger().warning(f"{model_name}'s token count cannot be accurately estimated. Using factor of {factor}")
model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model)
if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'):
return default_encoder_estimate
#else: Model is not an OpenAI one - therefore, cannot provide an accurate token count and instead, return a higher number as best effort.
return ceil(factor * default_estimate)
elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as best effort estimate")
return ceil(elbow_factor * default_encoder_estimate)
def _get_token_count_by_model_type(self, patch: str, default_estimate: int) -> int:
"""
Get token count based on model type.
def count_tokens(self, patch: str, force_accurate=False) -> int:
Args:
patch: The text to count tokens for.
default_estimate: The default token count estimate.
Returns:
int: The calculated token count.
"""
model_name = get_settings().config.model.lower()
if ModelTypeValidator.is_openai_model(model_name) and get_settings(use_context=False).get('openai.key'):
return default_estimate
if ModelTypeValidator.is_anthropic_model(model_name) and get_settings(use_context=False).get('anthropic.key'):
return self._calc_claude_tokens(patch)
return self._apply_estimation_factor(model_name, default_estimate)
def count_tokens(self, patch: str, force_accurate: bool = False) -> int:
"""
Counts the number of tokens in a given patch string.
Args:
- patch: The patch string.
- force_accurate: If True, uses a more precise calculation method.
Returns:
The number of tokens in the patch string.
"""
encoder_estimate = len(self.encoder.encode(patch, disallowed_special=()))
#If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
# If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
if not force_accurate:
return encoder_estimate
#else, force_accurate==True: User requested providing an accurate estimation:
model = get_settings().config.model.lower()
if 'claude' in model and get_settings(use_context=False).get('anthropic.key'):
return self.calc_claude_tokens(patch) # API call to Anthropic for accurate token counting for Claude models
#else: Non Anthropic provided model:
return self.estimate_token_count_for_non_anth_claude_models(model, encoder_estimate)
return self._get_token_count_by_model_type(patch, encoder_estimate)

View File

@ -945,12 +945,66 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_token
"""
Clip the number of tokens in a string to a maximum number of tokens.
This function limits text to a specified token count by calculating the approximate
character-to-token ratio and truncating the text accordingly. A safety factor of 0.9
(10% reduction) is applied to ensure the result stays within the token limit.
Args:
text (str): The string to clip.
text (str): The string to clip. If empty or None, returns the input unchanged.
max_tokens (int): The maximum number of tokens allowed in the string.
add_three_dots (bool, optional): A boolean indicating whether to add three dots at the end of the clipped
If negative, returns an empty string.
add_three_dots (bool, optional): Whether to add "\\n...(truncated)" at the end
of the clipped text to indicate truncation.
Defaults to True.
num_input_tokens (int, optional): Pre-computed number of tokens in the input text.
If provided, skips token encoding step for efficiency.
If None, tokens will be counted using TokenEncoder.
Defaults to None.
delete_last_line (bool, optional): Whether to remove the last line from the
clipped content before adding truncation indicator.
Useful for ensuring clean breaks at line boundaries.
Defaults to False.
Returns:
str: The clipped string.
str: The clipped string. Returns original text if:
- Text is empty/None
- Token count is within limit
- An error occurs during processing
Returns empty string if max_tokens <= 0.
Examples:
Basic usage:
>>> text = "This is a sample text that might be too long"
>>> result = clip_tokens(text, max_tokens=10)
>>> print(result)
This is a sample...
(truncated)
Without truncation indicator:
>>> result = clip_tokens(text, max_tokens=10, add_three_dots=False)
>>> print(result)
This is a sample
With pre-computed token count:
>>> result = clip_tokens(text, max_tokens=5, num_input_tokens=15)
>>> print(result)
This...
(truncated)
With line deletion:
>>> multiline_text = "Line 1\\nLine 2\\nLine 3"
>>> result = clip_tokens(multiline_text, max_tokens=3, delete_last_line=True)
>>> print(result)
Line 1
Line 2
...
(truncated)
Notes:
The function uses a safety factor of 0.9 (10% reduction) to ensure the
result stays within the token limit, as character-to-token ratios can vary.
If token encoding fails, the original text is returned with a warning logged.
"""
if not text:
return text

View File

@ -117,3 +117,8 @@ api_base = "" # Your Azure OpenAI service base URL (e.g., https://openai.xyz.co
[openrouter]
key = ""
api_base = ""
[aws]
AWS_ACCESS_KEY_ID = ""
AWS_SECRET_ACCESS_KEY = ""
AWS_REGION_NAME = ""

View File

@ -64,6 +64,7 @@ reasoning_effort = "medium" # "low", "medium", "high"
enable_auto_approval=false # Set to true to enable auto-approval of PRs under certain conditions
auto_approve_for_low_review_effort=-1 # -1 to disable, [1-5] to set the threshold for auto-approval
auto_approve_for_no_suggestions=false # If true, the PR will be auto-approved if there are no suggestions
ensure_ticket_compliance=false # Set to true to disable auto-approval of PRs if the ticket is not compliant
# extended thinking for Claude reasoning models
enable_claude_extended_thinking = false # Set to true to enable extended thinking feature
extended_thinking_budget_tokens = 2048
@ -103,6 +104,7 @@ enable_pr_type=true
final_update_message = true
enable_help_text=false
enable_help_comment=true
enable_pr_diagram=false # adds a section with a diagram of the PR changes
# describe as comment
publish_description_as_comment=false
publish_description_as_comment_persistent=true

View File

@ -46,6 +46,9 @@ class PRDescription(BaseModel):
type: List[PRType] = Field(description="one or more types that describe the PR content. Return the label member value (e.g. 'Bug fix', not 'bug_fix')")
description: str = Field(description="summarize the PR changes in up to four bullet points, each up to 8 words. For large PRs, add sub-bullets if needed. Order bullets by importance, with each bullet highlighting a key change group.")
title: str = Field(description="a concise and descriptive title that captures the PR's main theme")
{%- if enable_pr_diagram %}
changes_diagram: str = Field(description="a horizontal diagram that represents the main PR changes, in the format of a valid mermaid LR flowchart. The diagram should be concise and easy to read. Leave empty if no diagram is relevant. To create robust Mermaid diagrams, follow this two-step process: (1) Declare the nodes: nodeID["node description"]. (2) Then define the links: nodeID1 -- "link text" --> nodeID2 ")
{%- endif %}
{%- if enable_semantic_files_types %}
pr_files: List[FileDescription] = Field(max_items=20, description="a list of all the files that were changed in the PR, and summary of their changes. Each file must be analyzed regardless of change size.")
{%- endif %}
@ -62,6 +65,13 @@ description: |
...
title: |
...
{%- if enable_pr_diagram %}
changes_diagram: |
```mermaid
flowchart LR
...
```
{%- endif %}
{%- if enable_semantic_files_types %}
pr_files:
- filename: |
@ -143,6 +153,13 @@ description: |
...
title: |
...
{%- if enable_pr_diagram %}
changes_diagram: |
```mermaid
flowchart LR
...
```
{%- endif %}
{%- if enable_semantic_files_types %}
pr_files:
- filename: |

View File

@ -72,7 +72,8 @@ class PRDescription:
"enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
"related_tickets": "",
"include_file_summary_changes": len(self.git_provider.get_diff_files()) <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
"duplicate_prompt_examples": get_settings().config.get("duplicate_prompt_examples", False),
"enable_pr_diagram": get_settings().pr_description.get("enable_pr_diagram", False),
}
self.user_description = self.git_provider.get_user_description()
@ -456,6 +457,12 @@ class PRDescription:
self.data['labels'] = self.data.pop('labels')
if 'description' in self.data:
self.data['description'] = self.data.pop('description')
if 'changes_diagram' in self.data:
changes_diagram = self.data.pop('changes_diagram').strip()
if changes_diagram.startswith('```'):
if not changes_diagram.endswith('```'): # fallback for missing closing
changes_diagram += '\n```'
self.data['changes_diagram'] = '\n'+ changes_diagram
if 'pr_files' in self.data:
self.data['pr_files'] = self.data.pop('pr_files')

View File

@ -1,5 +1,5 @@
aiohttp==3.9.5
anthropic>=0.48
anthropic>=0.52.0
#anthropic[vertex]==0.47.1
atlassian-python-api==3.41.4
azure-devops==7.1.0b3
@ -13,7 +13,7 @@ google-cloud-aiplatform==1.38.0
google-generativeai==0.8.3
google-cloud-storage==2.10.0
Jinja2==3.1.2
litellm==1.69.3
litellm==1.70.4
loguru==0.7.2
msrest==0.7.1
openai>=1.55.3

View File

@ -1,13 +1,302 @@
# Generated by CodiumAI
import pytest
from unittest.mock import patch, MagicMock
from pr_agent.algo.utils import clip_tokens
from pr_agent.algo.token_handler import TokenEncoder
class TestClipTokens:
def test_clip(self):
"""Comprehensive test suite for the clip_tokens function."""
def test_empty_input_text(self):
"""Test that empty input returns empty string."""
assert clip_tokens("", 10) == ""
assert clip_tokens(None, 10) is None
def test_text_under_token_limit(self):
"""Test that text under the token limit is returned unchanged."""
text = "Short text"
max_tokens = 100
result = clip_tokens(text, max_tokens)
assert result == text
def test_text_exactly_at_token_limit(self):
"""Test text that is exactly at the token limit."""
text = "This is exactly at the limit"
# Mock the token encoder to return exact limit
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 10 # Exactly 10 tokens
mock_encoder.return_value = mock_tokenizer
result = clip_tokens(text, 10)
assert result == text
def test_text_over_token_limit_with_three_dots(self):
"""Test text over token limit with three dots addition."""
text = "This is a longer text that should be clipped when it exceeds the token limit"
max_tokens = 5
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens
mock_encoder.return_value = mock_tokenizer
result = clip_tokens(text, max_tokens)
assert result.endswith("\n...(truncated)")
assert len(result) < len(text)
def test_text_over_token_limit_without_three_dots(self):
"""Test text over token limit without three dots addition."""
text = "This is a longer text that should be clipped"
max_tokens = 5
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens
mock_encoder.return_value = mock_tokenizer
result = clip_tokens(text, max_tokens, add_three_dots=False)
assert not result.endswith("\n...(truncated)")
assert len(result) < len(text)
def test_negative_max_tokens(self):
"""Test that negative max_tokens returns empty string."""
text = "Some text"
result = clip_tokens(text, -1)
assert result == ""
result = clip_tokens(text, -100)
assert result == ""
def test_zero_max_tokens(self):
"""Test that zero max_tokens returns empty string."""
text = "Some text"
result = clip_tokens(text, 0)
assert result == ""
def test_delete_last_line_functionality(self):
"""Test the delete_last_line parameter functionality."""
text = "Line 1\nLine 2\nLine 3\nLine 4"
max_tokens = 5
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens
mock_encoder.return_value = mock_tokenizer
# Without delete_last_line
result_normal = clip_tokens(text, max_tokens, delete_last_line=False)
# With delete_last_line
result_deleted = clip_tokens(text, max_tokens, delete_last_line=True)
# The result with delete_last_line should be shorter or equal
assert len(result_deleted) <= len(result_normal)
def test_pre_computed_num_input_tokens(self):
"""Test using pre-computed num_input_tokens parameter."""
text = "This is a test text"
max_tokens = 10
num_input_tokens = 15
# Should not call the encoder when num_input_tokens is provided
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_encoder.return_value = None # Should not be called
result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens)
assert result.endswith("\n...(truncated)")
mock_encoder.assert_not_called()
def test_pre_computed_tokens_under_limit(self):
"""Test pre-computed tokens under the limit."""
text = "Short text"
max_tokens = 20
num_input_tokens = 5
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_encoder.return_value = None # Should not be called
result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens)
assert result == text
mock_encoder.assert_not_called()
def test_special_characters_and_unicode(self):
"""Test text with special characters and Unicode content."""
text = "Special chars: @#$%^&*()_+ áéíóú 中문 🚀 emoji"
max_tokens = 5
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens
mock_encoder.return_value = mock_tokenizer
result = clip_tokens(text, max_tokens)
assert isinstance(result, str)
assert len(result) < len(text)
def test_multiline_text_handling(self):
"""Test handling of multiline text."""
text = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
max_tokens = 5
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens
mock_encoder.return_value = mock_tokenizer
result = clip_tokens(text, max_tokens)
assert isinstance(result, str)
def test_very_long_text(self):
"""Test with very long text."""
text = "A" * 10000 # Very long text
max_tokens = 10
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 5000 # Many tokens
mock_encoder.return_value = mock_tokenizer
result = clip_tokens(text, max_tokens)
assert len(result) < len(text)
assert result.endswith("\n...(truncated)")
def test_encoder_exception_handling(self):
"""Test handling of encoder exceptions."""
text = "Test text"
max_tokens = 10
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_encoder.side_effect = Exception("Encoder error")
# Should return original text when encoder fails
result = clip_tokens(text, max_tokens)
assert result == text
def test_zero_division_scenario(self):
"""Test scenario that could lead to division by zero."""
text = "Test"
max_tokens = 10
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [] # Empty tokens (could cause division by zero)
mock_encoder.return_value = mock_tokenizer
result = clip_tokens(text, max_tokens)
# Should handle gracefully and return original text
assert result == text
def test_various_edge_cases(self):
"""Test various edge cases."""
# Single character
assert clip_tokens("A", 1000) == "A"
# Only whitespace
text = " \n \t "
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 10
mock_encoder.return_value = mock_tokenizer
result = clip_tokens(text, 5)
assert isinstance(result, str)
# Text with only newlines
text = "\n\n\n\n"
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 10
mock_encoder.return_value = mock_tokenizer
result = clip_tokens(text, 2, delete_last_line=True)
assert isinstance(result, str)
def test_parameter_combinations(self):
"""Test different parameter combinations."""
text = "Multi\nline\ntext\nfor\ntesting"
max_tokens = 5
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 20
mock_encoder.return_value = mock_tokenizer
# Test all combinations
combinations = [
(True, True), # add_three_dots=True, delete_last_line=True
(True, False), # add_three_dots=True, delete_last_line=False
(False, True), # add_three_dots=False, delete_last_line=True
(False, False), # add_three_dots=False, delete_last_line=False
]
for add_dots, delete_line in combinations:
result = clip_tokens(text, max_tokens,
add_three_dots=add_dots,
delete_last_line=delete_line)
assert isinstance(result, str)
if add_dots and len(result) > 0:
assert result.endswith("\n...(truncated)") or result == text
def test_num_output_chars_zero_scenario(self):
"""Test scenario where num_output_chars becomes zero or negative."""
text = "Short"
max_tokens = 1
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1] * 1000 # Many tokens for short text
mock_encoder.return_value = mock_tokenizer
result = clip_tokens(text, max_tokens)
# When num_output_chars is 0 or negative, should return empty string
assert result == ""
def test_logging_on_exception(self):
"""Test that exceptions are properly logged."""
text = "Test text"
max_tokens = 10
# Patch the logger at the module level where it's imported
with patch('pr_agent.algo.utils.get_logger') as mock_logger:
mock_log_instance = MagicMock()
mock_logger.return_value = mock_log_instance
with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
mock_encoder.side_effect = Exception("Test exception")
result = clip_tokens(text, max_tokens)
# Should log the warning
mock_log_instance.warning.assert_called_once()
# Should return original text
assert result == text
def test_factor_safety_calculation(self):
    """Test that the 0.9 factor (10% reduction) bounds the clipped length.

    The encoder is mocked to report 20 tokens for the text, which exceeds
    max_tokens=10, so the result is expected to be clipped and to end with
    the truncation marker.
    """
    text = "Test text that should be reduced by 10 percent for safety"
    max_tokens = 10
    truncation_marker = "\n...(truncated)"

    with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder:
        mock_tokenizer = MagicMock()
        mock_tokenizer.encode.return_value = [1] * 20  # 20 tokens
        mock_encoder.return_value = mock_tokenizer

        result = clip_tokens(text, max_tokens)

        # Characters per token = len(text) / 20
        # Expected chars = int(0.9 * (len(text) / 20) * 10)
        expected_chars = int(0.9 * (len(text) / 20) * 10)

        # Assert the marker unconditionally: guarding the length check behind
        # an `if` would let the test pass without verifying anything when the
        # text is not clipped at all.
        assert result.endswith(truncation_marker)
        actual_content = result[:-len(truncation_marker)]
        assert len(actual_content) <= expected_chars + 5  # Some tolerance
# Test the original basic functionality to ensure backward compatibility
def test_clip_original_functionality(self):
"""Test original functionality from the existing test."""
text = "line1\nline2\nline3\nline4\nline5\nline6"
max_tokens = 25
result = clip_tokens(text, max_tokens)

View File

@ -53,12 +53,12 @@ code_suggestions:
- relevant_file: |
src/index2.ts
label: |
enhancment
enhancement
```
We can further improve the code by using the `const` keyword instead of `var` in the `src/index.ts` file.
'''
expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancment'}]}
expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancement'}]}
assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='label') == expected_output
@ -76,10 +76,178 @@ code_suggestions:
- relevant_file: |
src/index2.ts
label: |
enhancment
enhancement
```
We can further improve the code by using the `const` keyword instead of `var` in the `src/index.ts` file.
'''
expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancment'}]}
expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancement'}]}
assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='label') == expected_output
def test_with_brackets_yaml_content(self):
# Case: the YAML payload is wrapped between a bare '{' line and a bare '}' line.
# The test passes when try_fix_yaml still returns the two-suggestion mapping below.
# NOTE(review): leading whitespace inside the literal is significant to YAML and
# is not discernible in this view - confirm against the committed file.
review_text = '''\
{
code_suggestions:
- relevant_file: |
src/index.ts
label: |
best practice
- relevant_file: |
src/index2.ts
label: |
enhancement
}
'''
# The expected mapping contains no trace of the wrapper lines; the last label
# is expected without a trailing newline, unlike the first one.
expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancement'}]}
assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='label') == expected_output
def test_tab_indent_yaml(self):
# Case: one block-scalar content line ("best practice") starts with a tab.
# The literal is not a raw string, so the `\t` escape below is a real tab character.
# NOTE(review): the space indentation of the other lines is not discernible in
# this view - confirm against the committed file.
review_text = '''\
code_suggestions:
- relevant_file: |
src/index.ts
label: |
\tbest practice
- relevant_file: |
src/index2.ts
label: |
enhancement
'''
# The tab does not appear in the expected label; both labels keep their trailing newline.
expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancement\n'}]}
assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='label') == expected_output
def test_leading_plus_mark_code(self):
# Case: the code lines under existing_code / improved_code start with a '+' marker.
# NOTE(review): leading whitespace inside the literal is not discernible in this
# view - confirm against the committed file.
review_text = '''\
code_suggestions:
- relevant_file: |
src/index.ts
label: |
best practice
existing_code: |
+ var router = createBrowserRouter([
improved_code: |
+ const router = createBrowserRouter([
'''
# The expected code values carry neither the '+' nor the space that follows it.
expected_output = {'code_suggestions': [{
'relevant_file': 'src/index.ts\n',
'label': 'best practice\n',
'existing_code': 'var router = createBrowserRouter([\n',
'improved_code': 'const router = createBrowserRouter([\n'
}]}
# last_key is 'improved_code' here because that is the final key of the suggestion.
assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='improved_code') == expected_output
def test_inconsistent_indentation_in_block_scalar_yaml(self):
"""
This test case represents a situation where the AI outputs the opening '{' with 5 spaces
(resulting in an inferred indent level of 5), while the closing '}' is output with only 4 spaces.
This inconsistency makes it impossible for the YAML parser to automatically determine the correct
indent level, causing a parsing failure.
The root cause may be the LLM miscounting spaces or misunderstanding the active block scalar context
while generating YAML output.
"""
# Input: a single suggestion whose existing_code block scalar holds a JSON object.
# NOTE(review): the 5-space / 4-space indentation described in the docstring lives
# inside this literal and is not discernible in this view - confirm it is intact
# in the committed file, since the test is meaningless without it.
review_text = '''\
code_suggestions:
- relevant_file: |
tsconfig.json
existing_code: |
{
"key1": "value1",
"key2": {
"subkey": "value"
}
}
'''
# The JSON text that existing_code is expected to come back as.
expected_json = '''\
{
"key1": "value1",
"key2": {
"subkey": "value"
}
}
'''
expected_output = {
'code_suggestions': [{
'relevant_file': 'tsconfig.json\n',
'existing_code': expected_json
}]
}
# existing_code is the last key of the suggestion, hence last_key='existing_code'.
assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='existing_code') == expected_output
def test_inconsistent_and_insufficient_indentation_in_block_scalar_yaml(self):
"""
This test case reproduces a YAML parsing failure where the block scalar content
generated by the AI includes inconsistent and insufficient indentation levels.
The root cause may be the LLM miscounting spaces or misunderstanding the active block scalar context
while generating YAML output.
"""
# Input: a single suggestion whose existing_code block scalar holds a JSON object.
# NOTE(review): the inconsistent / insufficient indentation the docstring refers to
# lives inside this literal and is not discernible in this view - confirm it is
# intact in the committed file.
review_text = '''\
code_suggestions:
- relevant_file: |
tsconfig.json
existing_code: |
{
"key1": "value1",
"key2": {
"subkey": "value"
}
}
'''
# The JSON text that existing_code is expected to come back as.
expected_json = '''\
{
"key1": "value1",
"key2": {
"subkey": "value"
}
}
'''
expected_output = {
'code_suggestions': [{
'relevant_file': 'tsconfig.json\n',
'existing_code': expected_json
}]
}
# existing_code is the last key of the suggestion, hence last_key='existing_code'.
assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='existing_code') == expected_output
def test_wrong_indentation_code_block_scalar(self):
# Case: a single suggestion whose existing_code block scalar holds two C functions.
# NOTE(review): judging by the test name, the literal's indentation is deliberately
# wrong; that whitespace is not discernible in this view - confirm against the
# committed file.
review_text = '''\
code_suggestions:
- relevant_file: |
a.c
existing_code: |
int sum(int a, int b) {
return a + b;
}
int sub(int a, int b) {
return a - b;
}
'''
# The C source that existing_code is expected to come back as.
expected_code_block = '''\
int sum(int a, int b) {
return a + b;
}
int sub(int a, int b) {
return a - b;
}
'''
expected_output = {
"code_suggestions": [
{
"relevant_file": "a.c\n",
"existing_code": expected_code_block
}
]
}
# existing_code is the last key of the suggestion, hence last_key='existing_code'.
assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='existing_code') == expected_output