diff --git a/docs/docs/tools/describe.md b/docs/docs/tools/describe.md index 0f2c45a7..a967a6c5 100644 --- a/docs/docs/tools/describe.md +++ b/docs/docs/tools/describe.md @@ -56,6 +56,20 @@ Everything below this marker is treated as previously auto-generated content and ![Describe comment](https://codium.ai/images/pr_agent/pr_description_user_description.png){width=512} +### Diagram Support +When the `enable_pr_diagram` option is enabled in your configuration, the `/describe` tool will include a `Mermaid` diagram in the PR description. + +This diagram represents interactions between components/functions based on the diff content. + +### How to enable + +In your configuration: + +```toml +[pr_description] +enable_pr_diagram = true +``` + ## Configuration options !!! example "Possible configurations" @@ -109,6 +123,10 @@ Everything below this marker is treated as previously auto-generated content and enable_help_text If set to true, the tool will display a help text in the comment. Default is false. + + enable_pr_diagram + If set to true, the tool will generate a Mermaid diagram (inside a code block) describing component interactions based on the code changes. Default is false. + ## Inline file summary 💎 diff --git a/docs/docs/tools/improve.md b/docs/docs/tools/improve.md index 2777a6d5..54ece175 100644 --- a/docs/docs/tools/improve.md +++ b/docs/docs/tools/improve.md @@ -435,7 +435,7 @@ To enable auto-approval based on specific criteria, first, you need to enable th enable_auto_approval = true ``` -There are two criteria that can be set for auto-approval: +There are several criteria that can be set for auto-approval: - **Review effort score** @@ -457,7 +457,19 @@ enable_auto_approval = true auto_approve_for_no_suggestions = true ``` -When no [code suggestion](https://www.qodo.ai/images/pr_agent/code_suggestions_as_comment_closed.png) were found for the PR, the PR will be auto-approved. +When no [code suggestions](https://www.qodo.ai/images/pr_agent/code_suggestions_as_comment_closed.png) are found for the PR, the PR will be auto-approved. + +___ + +- **Ticket Compliance** + +```toml +[config] +enable_auto_approval = true +ensure_ticket_compliance = true # Default is false +``` + +If `ensure_ticket_compliance` is set to `true`, auto-approval is blocked whenever a ticket is linked to the PR and that ticket is not compliant (e.g., the `review` tool did not mark the PR as fully compliant with the ticket). This ensures that PRs are auto-approved only when their associated tickets have been properly addressed. ### How many code suggestions are generated?
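The documentation above introduces several plain TOML settings. A minimal combined configuration is sketched below; the option names and defaults come from this PR's `configuration.toml` changes, while the specific values chosen are illustrative only:

```toml
[pr_description]
enable_pr_diagram = true                # adds a Mermaid diagram section to the generated PR description

[config]
enable_auto_approval = true             # must be enabled for any auto-approval criterion to apply
auto_approve_for_low_review_effort = 2  # review-effort threshold for auto-approval; -1 disables
auto_approve_for_no_suggestions = true  # auto-approve when no code suggestions are found
ensure_ticket_compliance = true         # block auto-approval when a linked ticket is not compliant
```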
diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py index e38bd713..23c9795b 100644 --- a/pr_agent/algo/__init__.py +++ b/pr_agent/algo/__init__.py @@ -53,9 +53,11 @@ MAX_TOKENS = { 'vertex_ai/claude-3-5-haiku@20241022': 100000, 'vertex_ai/claude-3-sonnet@20240229': 100000, 'vertex_ai/claude-3-opus@20240229': 100000, + 'vertex_ai/claude-opus-4@20250514': 200000, 'vertex_ai/claude-3-5-sonnet@20240620': 100000, 'vertex_ai/claude-3-5-sonnet-v2@20241022': 100000, 'vertex_ai/claude-3-7-sonnet@20250219': 200000, + 'vertex_ai/claude-sonnet-4@20250514': 200000, 'vertex_ai/gemini-1.5-pro': 1048576, 'vertex_ai/gemini-2.5-pro-preview-03-25': 1048576, 'vertex_ai/gemini-2.5-pro-preview-05-06': 1048576, @@ -74,22 +76,28 @@ MAX_TOKENS = { 'anthropic.claude-v1': 100000, 'anthropic.claude-v2': 100000, 'anthropic/claude-3-opus-20240229': 100000, + 'anthropic/claude-opus-4-20250514': 200000, 'anthropic/claude-3-5-sonnet-20240620': 100000, 'anthropic/claude-3-5-sonnet-20241022': 100000, 'anthropic/claude-3-7-sonnet-20250219': 200000, + 'anthropic/claude-sonnet-4-20250514': 200000, 'claude-3-7-sonnet-20250219': 200000, 'anthropic/claude-3-5-haiku-20241022': 100000, 'bedrock/anthropic.claude-instant-v1': 100000, 'bedrock/anthropic.claude-v2': 100000, 'bedrock/anthropic.claude-v2:1': 100000, 'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': 100000, + 'bedrock/anthropic.claude-opus-4-20250514-v1:0': 200000, 'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000, 'bedrock/anthropic.claude-3-5-haiku-20241022-v1:0': 100000, 'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0': 100000, 'bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0': 100000, 'bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0': 200000, + 'bedrock/anthropic.claude-sonnet-4-20250514-v1:0': 200000, + "bedrock/us.anthropic.claude-opus-4-20250514-v1:0": 200000, "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0": 100000, "bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0": 200000, + "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": 200000, 'claude-3-5-sonnet': 100000, 'groq/meta-llama/llama-4-scout-17b-16e-instruct': 131072, 'groq/meta-llama/llama-4-maverick-17b-128e-instruct': 131072, @@ -102,9 +110,13 @@ MAX_TOKENS = { 'xai/grok-2': 131072, 'xai/grok-2-1212': 131072, 'xai/grok-2-latest': 131072, + 'xai/grok-3': 131072, 'xai/grok-3-beta': 131072, + 'xai/grok-3-fast': 131072, 'xai/grok-3-fast-beta': 131072, + 'xai/grok-3-mini': 131072, 'xai/grok-3-mini-beta': 131072, + 'xai/grok-3-mini-fast': 131072, 'xai/grok-3-mini-fast-beta': 131072, 'ollama/llama3': 4096, 'watsonx/meta-llama/llama-3-8b-instruct': 4096, diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py index d4ea0aa5..4d708fcb 100644 --- a/pr_agent/algo/ai_handlers/langchain_ai_handler.py +++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py @@ -6,8 +6,8 @@ except: # we don't enforce langchain as a dependency, so if it's not installed, import functools -from openai import APIError, RateLimitError, Timeout -from retry import retry +import openai +from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.config_loader import get_settings @@ -36,8 +36,10 @@ class LangChainOpenAIHandler(BaseAiHandler): """ return get_settings().get("OPENAI.DEPLOYMENT_ID", None) - @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError), - tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 
3)) + @retry( + retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError), + stop=stop_after_attempt(OPENAI_RETRIES), + ) async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): try: messages = [SystemMessage(content=system), HumanMessage(content=user)] @@ -47,9 +49,15 @@ class LangChainOpenAIHandler(BaseAiHandler): finish_reason = "completed" return resp.content, finish_reason - except (Exception) as e: - get_logger().error("Unknown error during OpenAI inference: ", e) - raise e + except openai.RateLimitError as e: + get_logger().error(f"Rate limit error during LLM inference: {e}") + raise + except openai.APIError as e: + get_logger().warning(f"Error during LLM inference: {e}") + raise + except Exception as e: + get_logger().warning(f"Unknown error during LLM inference: {e}") + raise openai.APIError from e def _create_chat(self, deployment_id=None): try: diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index 8d727b8b..f20b03f8 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -3,7 +3,7 @@ import litellm import openai import requests from litellm import acompletion -from tenacity import retry, retry_if_exception_type, stop_after_attempt +from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler @@ -274,8 +274,8 @@ class LiteLLMAIHandler(BaseAiHandler): return get_settings().get("OPENAI.DEPLOYMENT_ID", None) @retry( - retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)), # No retry on RateLimitError - stop=stop_after_attempt(OPENAI_RETRIES) + retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError), + stop=stop_after_attempt(OPENAI_RETRIES), ) async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None): try: @@ -371,13 +371,13 @@ class LiteLLMAIHandler(BaseAiHandler): get_logger().info(f"\nUser prompt:\n{user}") response = await acompletion(**kwargs) - except (openai.RateLimitError) as e: + except openai.RateLimitError as e: get_logger().error(f"Rate limit error during LLM inference: {e}") raise - except (openai.APIError, openai.APITimeoutError) as e: + except openai.APIError as e: get_logger().warning(f"Error during LLM inference: {e}") raise - except (Exception) as e: + except Exception as e: get_logger().warning(f"Unknown error during LLM inference: {e}") raise openai.APIError from e if response is None or len(response["choices"]) == 0: diff --git a/pr_agent/algo/ai_handlers/openai_ai_handler.py b/pr_agent/algo/ai_handlers/openai_ai_handler.py index f74444a1..253282b0 100644 --- a/pr_agent/algo/ai_handlers/openai_ai_handler.py +++ b/pr_agent/algo/ai_handlers/openai_ai_handler.py @@ -1,8 +1,8 @@ from os import environ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler import openai -from openai import APIError, AsyncOpenAI, RateLimitError, Timeout -from retry import retry +from openai import AsyncOpenAI +from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt from pr_agent.algo.ai_handlers.base_ai_handler import 
BaseAiHandler from pr_agent.config_loader import get_settings @@ -38,8 +38,10 @@ class OpenAIHandler(BaseAiHandler): """ return get_settings().get("OPENAI.DEPLOYMENT_ID", None) - @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError), - tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) + @retry( + retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError), + stop=stop_after_attempt(OPENAI_RETRIES), + ) async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): try: get_logger().info("System: ", system) @@ -57,12 +59,12 @@ class OpenAIHandler(BaseAiHandler): get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, model=model, usage=usage) return resp, finish_reason - except (APIError, Timeout) as e: - get_logger().error("Error during OpenAI inference: ", e) + except openai.RateLimitError as e: + get_logger().error(f"Rate limit error during LLM inference: {e}") raise - except (RateLimitError) as e: - get_logger().error("Rate limit error during OpenAI inference: ", e) - raise - except (Exception) as e: - get_logger().error("Unknown error during OpenAI inference: ", e) + except openai.APIError as e: + get_logger().warning(f"Error during LLM inference: {e}") raise + except Exception as e: + get_logger().warning(f"Unknown error during LLM inference: {e}") + raise openai.APIError from e diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index 60cf2c84..cb313f02 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -1,4 +1,6 @@ from threading import Lock +from math import ceil +import re from jinja2 import Environment, StrictUndefined from tiktoken import encoding_for_model, get_encoding @@ -7,6 +9,16 @@ from pr_agent.config_loader import get_settings from pr_agent.log import get_logger +class ModelTypeValidator: + @staticmethod + def is_openai_model(model_name: str) -> bool: + return 'gpt' in model_name or re.match(r"^o[1-9](-mini|-preview)?$", model_name) + + @staticmethod + def is_anthropic_model(model_name: str) -> bool: + return 'claude' in model_name + + class TokenEncoder: _encoder_instance = None _model = None @@ -40,6 +52,10 @@ class TokenHandler: method. """ + # Constants + CLAUDE_MODEL = "claude-3-7-sonnet-20250219" + CLAUDE_MAX_CONTENT_SIZE = 9_000_000 # Maximum allowed content size (9MB) for Claude API + def __init__(self, pr=None, vars: dict = {}, system="", user=""): """ Initializes the TokenHandler object. @@ -51,6 +67,7 @@ class TokenHandler: - user: The user string. 
""" self.encoder = TokenEncoder.get_token_encoder() + if pr is not None: self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) @@ -79,22 +96,22 @@ class TokenHandler: get_logger().error(f"Error in _get_system_user_tokens: {e}") return 0 - def calc_claude_tokens(self, patch): + def _calc_claude_tokens(self, patch: str) -> int: try: import anthropic from pr_agent.algo import MAX_TOKENS + client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key')) - MaxTokens = MAX_TOKENS[get_settings().config.model] + max_tokens = MAX_TOKENS[get_settings().config.model] - # Check if the content size is too large (9MB limit) - if len(patch.encode('utf-8')) > 9_000_000: + if len(patch.encode('utf-8')) > self.CLAUDE_MAX_CONTENT_SIZE: get_logger().warning( "Content too large for Anthropic token counting API, falling back to local tokenizer" ) - return MaxTokens + return max_tokens response = client.messages.count_tokens( - model="claude-3-7-sonnet-20250219", + model=self.CLAUDE_MODEL, system="system", messages=[{ "role": "user", @@ -104,42 +121,51 @@ class TokenHandler: return response.input_tokens except Exception as e: - get_logger().error( f"Error in Anthropic token counting: {e}") - return MaxTokens + get_logger().error(f"Error in Anthropic token counting: {e}") + return max_tokens - def estimate_token_count_for_non_anth_claude_models(self, model, default_encoder_estimate): - from math import ceil - import re + def _apply_estimation_factor(self, model_name: str, default_estimate: int) -> int: + factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0) + get_logger().warning(f"{model_name}'s token count cannot be accurately estimated. Using factor of {factor}") + + return ceil(factor * default_estimate) - model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model) - if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'): - return default_encoder_estimate - #else: Model is not an OpenAI one - therefore, cannot provide an accurate token count and instead, return a higher number as best effort. + def _get_token_count_by_model_type(self, patch: str, default_estimate: int) -> int: + """ + Get token count based on model type. - elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0) - get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as best effort estimate") - return ceil(elbow_factor * default_encoder_estimate) + Args: + patch: The text to count tokens for. + default_estimate: The default token count estimate. - def count_tokens(self, patch: str, force_accurate=False) -> int: + Returns: + int: The calculated token count. + """ + model_name = get_settings().config.model.lower() + + if ModelTypeValidator.is_openai_model(model_name) and get_settings(use_context=False).get('openai.key'): + return default_estimate + + if ModelTypeValidator.is_anthropic_model(model_name) and get_settings(use_context=False).get('anthropic.key'): + return self._calc_claude_tokens(patch) + + return self._apply_estimation_factor(model_name, default_estimate) + + def count_tokens(self, patch: str, force_accurate: bool = False) -> int: """ Counts the number of tokens in a given patch string. Args: - patch: The patch string. + - force_accurate: If True, uses a more precise calculation method. Returns: The number of tokens in the patch string. 
""" encoder_estimate = len(self.encoder.encode(patch, disallowed_special=())) - #If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it. + # If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it. if not force_accurate: return encoder_estimate - #else, force_accurate==True: User requested providing an accurate estimation: - model = get_settings().config.model.lower() - if 'claude' in model and get_settings(use_context=False).get('anthropic.key'): - return self.calc_claude_tokens(patch) # API call to Anthropic for accurate token counting for Claude models - - #else: Non Anthropic provided model: - return self.estimate_token_count_for_non_anth_claude_models(model, encoder_estimate) + return self._get_token_count_by_model_type(patch, encoder_estimate) diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 780c7953..3e3103ad 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -945,12 +945,66 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_token """ Clip the number of tokens in a string to a maximum number of tokens. + This function limits text to a specified token count by calculating the approximate + character-to-token ratio and truncating the text accordingly. A safety factor of 0.9 + (10% reduction) is applied to ensure the result stays within the token limit. + Args: - text (str): The string to clip. + text (str): The string to clip. If empty or None, returns the input unchanged. max_tokens (int): The maximum number of tokens allowed in the string. - add_three_dots (bool, optional): A boolean indicating whether to add three dots at the end of the clipped + If negative, returns an empty string. + add_three_dots (bool, optional): Whether to add "\\n...(truncated)" at the end + of the clipped text to indicate truncation. + Defaults to True. + num_input_tokens (int, optional): Pre-computed number of tokens in the input text. + If provided, skips token encoding step for efficiency. + If None, tokens will be counted using TokenEncoder. + Defaults to None. + delete_last_line (bool, optional): Whether to remove the last line from the + clipped content before adding truncation indicator. + Useful for ensuring clean breaks at line boundaries. + Defaults to False. + Returns: - str: The clipped string. + str: The clipped string. Returns original text if: + - Text is empty/None + - Token count is within limit + - An error occurs during processing + + Returns empty string if max_tokens <= 0. + + Examples: + Basic usage: + >>> text = "This is a sample text that might be too long" + >>> result = clip_tokens(text, max_tokens=10) + >>> print(result) + This is a sample... + (truncated) + + Without truncation indicator: + >>> result = clip_tokens(text, max_tokens=10, add_three_dots=False) + >>> print(result) + This is a sample + + With pre-computed token count: + >>> result = clip_tokens(text, max_tokens=5, num_input_tokens=15) + >>> print(result) + This... + (truncated) + + With line deletion: + >>> multiline_text = "Line 1\\nLine 2\\nLine 3" + >>> result = clip_tokens(multiline_text, max_tokens=3, delete_last_line=True) + >>> print(result) + Line 1 + Line 2 + ... + (truncated) + + Notes: + The function uses a safety factor of 0.9 (10% reduction) to ensure the + result stays within the token limit, as character-to-token ratios can vary. + If token encoding fails, the original text is returned with a warning logged. 
""" if not text: return text diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index 9590a84c..460711cb 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -116,4 +116,9 @@ api_base = "" # Your Azure OpenAI service base URL (e.g., https://openai.xyz.co [openrouter] key = "" -api_base = "" \ No newline at end of file +api_base = "" + +[aws] +AWS_ACCESS_KEY_ID = "" +AWS_SECRET_ACCESS_KEY = "" +AWS_REGION_NAME = "" \ No newline at end of file diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index f03a5e66..9c80e9fb 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -64,6 +64,7 @@ reasoning_effort = "medium" # "low", "medium", "high" enable_auto_approval=false # Set to true to enable auto-approval of PRs under certain conditions auto_approve_for_low_review_effort=-1 # -1 to disable, [1-5] to set the threshold for auto-approval auto_approve_for_no_suggestions=false # If true, the PR will be auto-approved if there are no suggestions +ensure_ticket_compliance=false # Set to true to disable auto-approval of PRs if the ticket is not compliant # extended thinking for Claude reasoning models enable_claude_extended_thinking = false # Set to true to enable extended thinking feature extended_thinking_budget_tokens = 2048 @@ -103,6 +104,7 @@ enable_pr_type=true final_update_message = true enable_help_text=false enable_help_comment=true +enable_pr_diagram=false # adds a section with a diagram of the PR changes # describe as comment publish_description_as_comment=false publish_description_as_comment_persistent=true diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml index 73ec8459..4c14abee 100644 --- a/pr_agent/settings/pr_description_prompts.toml +++ b/pr_agent/settings/pr_description_prompts.toml @@ -46,6 +46,9 @@ class PRDescription(BaseModel): type: List[PRType] = Field(description="one or more types that describe the PR content. Return the label member value (e.g. 'Bug fix', not 'bug_fix')") description: str = Field(description="summarize the PR changes in up to four bullet points, each up to 8 words. For large PRs, add sub-bullets if needed. Order bullets by importance, with each bullet highlighting a key change group.") title: str = Field(description="a concise and descriptive title that captures the PR's main theme") +{%- if enable_pr_diagram %} + changes_diagram: str = Field(description="a horizontal diagram that represents the main PR changes, in the format of a valid mermaid LR flowchart. The diagram should be concise and easy to read. Leave empty if no diagram is relevant. To create robust Mermaid diagrams, follow this two-step process: (1) Declare the nodes: nodeID["node description"]. (2) Then define the links: nodeID1 -- "link text" --> nodeID2 ") +{%- endif %} {%- if enable_semantic_files_types %} pr_files: List[FileDescription] = Field(max_items=20, description="a list of all the files that were changed in the PR, and summary of their changes. Each file must be analyzed regardless of change size.") {%- endif %} @@ -62,6 +65,13 @@ description: | ... title: | ... +{%- if enable_pr_diagram %} + changes_diagram: | + ```mermaid + flowchart LR + ... + ``` +{%- endif %} {%- if enable_semantic_files_types %} pr_files: - filename: | @@ -143,6 +153,13 @@ description: | ... title: | ... 
+{%- if enable_pr_diagram %} + changes_diagram: | + ```mermaid + flowchart LR + ... + ``` +{%- endif %} {%- if enable_semantic_files_types %} pr_files: - filename: | @@ -164,4 +181,4 @@ pr_files: Response (should be a valid YAML, and nothing else): ```yaml -""" +""" \ No newline at end of file diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index df82db67..663c5a2d 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -72,7 +72,8 @@ class PRDescription: "enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types, "related_tickets": "", "include_file_summary_changes": len(self.git_provider.get_diff_files()) <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD, - 'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False), + "duplicate_prompt_examples": get_settings().config.get("duplicate_prompt_examples", False), + "enable_pr_diagram": get_settings().pr_description.get("enable_pr_diagram", False), } self.user_description = self.git_provider.get_user_description() @@ -456,6 +457,12 @@ class PRDescription: self.data['labels'] = self.data.pop('labels') if 'description' in self.data: self.data['description'] = self.data.pop('description') + if 'changes_diagram' in self.data: + changes_diagram = self.data.pop('changes_diagram').strip() + if changes_diagram.startswith('```'): + if not changes_diagram.endswith('```'): # fallback for missing closing + changes_diagram += '\n```' + self.data['changes_diagram'] = '\n'+ changes_diagram if 'pr_files' in self.data: self.data['pr_files'] = self.data.pop('pr_files') @@ -820,4 +827,4 @@ def replace_code_tags(text): parts = text.split('`') for i in range(1, len(parts), 2): parts[i] = '' + parts[i] + '' - return ''.join(parts) + return ''.join(parts) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 5290b749..18f6e383 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ aiohttp==3.9.5 -anthropic>=0.48 +anthropic>=0.52.0 #anthropic[vertex]==0.47.1 atlassian-python-api==3.41.4 azure-devops==7.1.0b3 @@ -13,7 +13,7 @@ google-cloud-aiplatform==1.38.0 google-generativeai==0.8.3 google-cloud-storage==2.10.0 Jinja2==3.1.2 -litellm==1.69.3 +litellm==1.70.4 loguru==0.7.2 msrest==0.7.1 openai>=1.55.3 diff --git a/tests/unittest/test_clip_tokens.py b/tests/unittest/test_clip_tokens.py index 79de6294..a42ef929 100644 --- a/tests/unittest/test_clip_tokens.py +++ b/tests/unittest/test_clip_tokens.py @@ -1,13 +1,302 @@ - -# Generated by CodiumAI - import pytest - +from unittest.mock import patch, MagicMock from pr_agent.algo.utils import clip_tokens +from pr_agent.algo.token_handler import TokenEncoder class TestClipTokens: - def test_clip(self): + """Comprehensive test suite for the clip_tokens function.""" + + def test_empty_input_text(self): + """Test that empty input returns empty string.""" + assert clip_tokens("", 10) == "" + assert clip_tokens(None, 10) is None + + def test_text_under_token_limit(self): + """Test that text under the token limit is returned unchanged.""" + text = "Short text" + max_tokens = 100 + result = clip_tokens(text, max_tokens) + assert result == text + + def test_text_exactly_at_token_limit(self): + """Test text that is exactly at the token limit.""" + text = "This is exactly at the limit" + # Mock the token encoder to return exact limit + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 
10 # Exactly 10 tokens + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, 10) + assert result == text + + def test_text_over_token_limit_with_three_dots(self): + """Test text over token limit with three dots addition.""" + text = "This is a longer text that should be clipped when it exceeds the token limit" + max_tokens = 5 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, max_tokens) + assert result.endswith("\n...(truncated)") + assert len(result) < len(text) + + def test_text_over_token_limit_without_three_dots(self): + """Test text over token limit without three dots addition.""" + text = "This is a longer text that should be clipped" + max_tokens = 5 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, max_tokens, add_three_dots=False) + assert not result.endswith("\n...(truncated)") + assert len(result) < len(text) + + def test_negative_max_tokens(self): + """Test that negative max_tokens returns empty string.""" + text = "Some text" + result = clip_tokens(text, -1) + assert result == "" + + result = clip_tokens(text, -100) + assert result == "" + + def test_zero_max_tokens(self): + """Test that zero max_tokens returns empty string.""" + text = "Some text" + result = clip_tokens(text, 0) + assert result == "" + + def test_delete_last_line_functionality(self): + """Test the delete_last_line parameter functionality.""" + text = "Line 1\nLine 2\nLine 3\nLine 4" + max_tokens = 5 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens + mock_encoder.return_value = mock_tokenizer + + # Without delete_last_line + result_normal = clip_tokens(text, max_tokens, delete_last_line=False) + + # With delete_last_line + result_deleted = clip_tokens(text, max_tokens, delete_last_line=True) + + # The result with delete_last_line should be shorter or equal + assert len(result_deleted) <= len(result_normal) + + def test_pre_computed_num_input_tokens(self): + """Test using pre-computed num_input_tokens parameter.""" + text = "This is a test text" + max_tokens = 10 + num_input_tokens = 15 + + # Should not call the encoder when num_input_tokens is provided + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_encoder.return_value = None # Should not be called + + result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens) + assert result.endswith("\n...(truncated)") + mock_encoder.assert_not_called() + + def test_pre_computed_tokens_under_limit(self): + """Test pre-computed tokens under the limit.""" + text = "Short text" + max_tokens = 20 + num_input_tokens = 5 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_encoder.return_value = None # Should not be called + + result = clip_tokens(text, max_tokens, num_input_tokens=num_input_tokens) + assert result == text + mock_encoder.assert_not_called() + + def test_special_characters_and_unicode(self): + """Test text with special characters and Unicode content.""" + text = "Special chars: @#$%^&*()_+ áéíóú 中문 🚀 emoji" + max_tokens = 5 + + with patch.object(TokenEncoder, 'get_token_encoder') 
as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, max_tokens) + assert isinstance(result, str) + assert len(result) < len(text) + + def test_multiline_text_handling(self): + """Test handling of multiline text.""" + text = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + max_tokens = 5 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, max_tokens) + assert isinstance(result, str) + + def test_very_long_text(self): + """Test with very long text.""" + text = "A" * 10000 # Very long text + max_tokens = 10 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 5000 # Many tokens + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, max_tokens) + assert len(result) < len(text) + assert result.endswith("\n...(truncated)") + + def test_encoder_exception_handling(self): + """Test handling of encoder exceptions.""" + text = "Test text" + max_tokens = 10 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_encoder.side_effect = Exception("Encoder error") + + # Should return original text when encoder fails + result = clip_tokens(text, max_tokens) + assert result == text + + def test_zero_division_scenario(self): + """Test scenario that could lead to division by zero.""" + text = "Test" + max_tokens = 10 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [] # Empty tokens (could cause division by zero) + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, max_tokens) + # Should handle gracefully and return original text + assert result == text + + def test_various_edge_cases(self): + """Test various edge cases.""" + # Single character + assert clip_tokens("A", 1000) == "A" + + # Only whitespace + text = " \n \t " + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 10 + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, 5) + assert isinstance(result, str) + + # Text with only newlines + text = "\n\n\n\n" + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 10 + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, 2, delete_last_line=True) + assert isinstance(result, str) + + def test_parameter_combinations(self): + """Test different parameter combinations.""" + text = "Multi\nline\ntext\nfor\ntesting" + max_tokens = 5 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 20 + mock_encoder.return_value = mock_tokenizer + + # Test all combinations + combinations = [ + (True, True), # add_three_dots=True, delete_last_line=True + (True, False), # add_three_dots=True, delete_last_line=False + (False, True), # add_three_dots=False, delete_last_line=True + (False, False), # add_three_dots=False, delete_last_line=False + ] + + for add_dots, delete_line in combinations: + result = clip_tokens(text, max_tokens, + 
add_three_dots=add_dots, + delete_last_line=delete_line) + assert isinstance(result, str) + if add_dots and len(result) > 0: + assert result.endswith("\n...(truncated)") or result == text + + def test_num_output_chars_zero_scenario(self): + """Test scenario where num_output_chars becomes zero or negative.""" + text = "Short" + max_tokens = 1 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 1000 # Many tokens for short text + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, max_tokens) + # When num_output_chars is 0 or negative, should return empty string + assert result == "" + + def test_logging_on_exception(self): + """Test that exceptions are properly logged.""" + text = "Test text" + max_tokens = 10 + + # Patch the logger at the module level where it's imported + with patch('pr_agent.algo.utils.get_logger') as mock_logger: + mock_log_instance = MagicMock() + mock_logger.return_value = mock_log_instance + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_encoder.side_effect = Exception("Test exception") + + result = clip_tokens(text, max_tokens) + + # Should log the warning + mock_log_instance.warning.assert_called_once() + # Should return original text + assert result == text + + def test_factor_safety_calculation(self): + """Test that the 0.9 factor (10% reduction) works correctly.""" + text = "Test text that should be reduced by 10 percent for safety" + max_tokens = 10 + + with patch.object(TokenEncoder, 'get_token_encoder') as mock_encoder: + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1] * 20 # 20 tokens + mock_encoder.return_value = mock_tokenizer + + result = clip_tokens(text, max_tokens) + + # The result should be shorter due to the 0.9 factor + # Characters per token = len(text) / 20 + # Expected chars = int(0.9 * (len(text) / 20) * 10) + expected_chars = int(0.9 * (len(text) / 20) * 10) + + # Result should be around expected_chars length (plus truncation text) + if result.endswith("\n...(truncated)"): + actual_content = result[:-len("\n...(truncated)")] + assert len(actual_content) <= expected_chars + 5 # Some tolerance + + # Test the original basic functionality to ensure backward compatibility + def test_clip_original_functionality(self): + """Test original functionality from the existing test.""" text = "line1\nline2\nline3\nline4\nline5\nline6" max_tokens = 25 result = clip_tokens(text, max_tokens) @@ -16,4 +305,4 @@ class TestClipTokens: max_tokens = 10 result = clip_tokens(text, max_tokens) expected_results = 'line1\nline2\nline3\n\n...(truncated)' - assert result == expected_results + assert result == expected_results \ No newline at end of file diff --git a/tests/unittest/test_try_fix_yaml.py b/tests/unittest/test_try_fix_yaml.py index 826d7312..98773c81 100644 --- a/tests/unittest/test_try_fix_yaml.py +++ b/tests/unittest/test_try_fix_yaml.py @@ -53,12 +53,12 @@ code_suggestions: - relevant_file: | src/index2.ts label: | - enhancment + enhancement ``` We can further improve the code by using the `const` keyword instead of `var` in the `src/index.ts` file. 
''' - expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancment'}]} + expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancement'}]} assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='label') == expected_output @@ -76,10 +76,178 @@ code_suggestions: - relevant_file: | src/index2.ts label: | - enhancment + enhancement ``` We can further improve the code by using the `const` keyword instead of `var` in the `src/index.ts` file. ''' - expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancment'}]} + expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancement'}]} assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='label') == expected_output + + + def test_with_brackets_yaml_content(self): + review_text = '''\ +{ +code_suggestions: +- relevant_file: | + src/index.ts + label: | + best practice + +- relevant_file: | + src/index2.ts + label: | + enhancement +} +''' + expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancement'}]} + assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='label') == expected_output + + def test_tab_indent_yaml(self): + review_text = '''\ +code_suggestions: +- relevant_file: | + src/index.ts + label: | +\tbest practice + +- relevant_file: | + src/index2.ts + label: | + enhancement +''' + expected_output = {'code_suggestions': [{'relevant_file': 'src/index.ts\n', 'label': 'best practice\n'}, {'relevant_file': 'src/index2.ts\n', 'label': 'enhancement\n'}]} + assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='label') == expected_output + + + def test_leading_plus_mark_code(self): + review_text = '''\ +code_suggestions: +- relevant_file: | + src/index.ts + label: | + best practice + existing_code: | ++ var router = createBrowserRouter([ + improved_code: | ++ const router = createBrowserRouter([ +''' + expected_output = {'code_suggestions': [{ + 'relevant_file': 'src/index.ts\n', + 'label': 'best practice\n', + 'existing_code': 'var router = createBrowserRouter([\n', + 'improved_code': 'const router = createBrowserRouter([\n' + }]} + assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='improved_code') == expected_output + + + def test_inconsistent_indentation_in_block_scalar_yaml(self): + """ + This test case represents a situation where the AI outputs the opening '{' with 5 spaces + (resulting in an inferred indent level of 5), while the closing '}' is output with only 4 spaces. + This inconsistency makes it impossible for the YAML parser to automatically determine the correct + indent level, causing a parsing failure. + + The root cause may be the LLM miscounting spaces or misunderstanding the active block scalar context + while generating YAML output. 
+ """ + + review_text = '''\ +code_suggestions: +- relevant_file: | + tsconfig.json + existing_code: | + { + "key1": "value1", + "key2": { + "subkey": "value" + } + } +''' + expected_json = '''\ + { + "key1": "value1", + "key2": { + "subkey": "value" + } +} +''' + expected_output = { + 'code_suggestions': [{ + 'relevant_file': 'tsconfig.json\n', + 'existing_code': expected_json + }] + } + assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='existing_code') == expected_output + + + def test_inconsistent_and_insufficient_indentation_in_block_scalar_yaml(self): + """ + This test case reproduces a YAML parsing failure where the block scalar content + generated by the AI includes inconsistent and insufficient indentation levels. + + The root cause may be the LLM miscounting spaces or misunderstanding the active block scalar context + while generating YAML output. + """ + + review_text = '''\ +code_suggestions: +- relevant_file: | + tsconfig.json + existing_code: | + { + "key1": "value1", + "key2": { + "subkey": "value" + } + } +''' + expected_json = '''\ +{ + "key1": "value1", + "key2": { + "subkey": "value" + } +} +''' + expected_output = { + 'code_suggestions': [{ + 'relevant_file': 'tsconfig.json\n', + 'existing_code': expected_json + }] + } + assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='existing_code') == expected_output + + + def test_wrong_indentation_code_block_scalar(self): + review_text = '''\ +code_suggestions: +- relevant_file: | + a.c + existing_code: | + int sum(int a, int b) { + return a + b; + } + + int sub(int a, int b) { + return a - b; + } +''' + expected_code_block = '''\ +int sum(int a, int b) { + return a + b; +} + +int sub(int a, int b) { + return a - b; +} +''' + expected_output = { + "code_suggestions": [ + { + "relevant_file": "a.c\n", + "existing_code": expected_code_block + } + ] + } + assert try_fix_yaml(review_text, first_key='code_suggestions', last_key='existing_code') == expected_output
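The AI-handler hunks earlier in this diff converge on a single tenacity retry policy. The sketch below isolates that policy; the decorated coroutine and client wiring are hypothetical stand-ins (the real handlers wrap litellm, langchain, or AsyncOpenAI), and the `OPENAI_RETRIES` value is assumed since the diff only imports it — only the predicate composition and stop condition mirror the handlers:

```python
import openai
from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt

OPENAI_RETRIES = 5  # assumed value; the constant's definition is not shown in this diff

# Tenacity predicates compose with `&`: retry on openai.APIError subclasses
# (connection errors, timeouts, 5xx responses) but never on RateLimitError,
# which also derives from APIError and would otherwise be retried as well.
@retry(
    retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError),
    stop=stop_after_attempt(OPENAI_RETRIES),
)
async def chat_completion_stub(client: openai.AsyncOpenAI, model: str, prompt: str) -> str:
    # Hypothetical call site: a RateLimitError propagates to the caller immediately,
    # while other API errors are retried up to OPENAI_RETRIES attempts.
    resp = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return resp.choices[0].message.content or ""
```

This matches the handlers' intent: rate-limit failures surface immediately so the caller can back off, while transient API errors get a bounded number of retries.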