From 2c37b02aa05a28b0ee17073e4ad3f92f8592c97e Mon Sep 17 00:00:00 2001 From: mrT23 Date: Thu, 13 Feb 2025 11:44:50 +0200 Subject: [PATCH] feat: improve help tool with markdown header formatting and error handling --- pr_agent/settings/pr_help_prompts.toml | 2 +- pr_agent/tools/pr_help_message.py | 58 ++++++++++++++++++++++---- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/pr_agent/settings/pr_help_prompts.toml b/pr_agent/settings/pr_help_prompts.toml index 1911db9c..8e4c7a28 100644 --- a/pr_agent/settings/pr_help_prompts.toml +++ b/pr_agent/settings/pr_help_prompts.toml @@ -13,7 +13,7 @@ The output must be a YAML object equivalent to type $DocHelper, according to the ===== class relevant_section(BaseModel): file_name: str = Field(description="The name of the relevant file") - relevant_section_header_string: str = Field(description="From the relevant file, exact text of the relevant section heading. If no markdown heading is relevant, return empty string") + relevant_section_header_string: str = Field(description="The exact text of the relevant markdown section heading from the relevant file (starting with '#', '##', etc.). 
Return empty string if the entire file is the relevant section, or if the relevant section has no heading") class DocHelper(BaseModel): user_question: str = Field(description="The user's question") diff --git a/pr_agent/tools/pr_help_message.py b/pr_agent/tools/pr_help_message.py index 82bdf43c..8a07874c 100644 --- a/pr_agent/tools/pr_help_message.py +++ b/pr_agent/tools/pr_help_message.py @@ -1,4 +1,5 @@ import copy +import re from functools import partial from pathlib import Path @@ -9,10 +10,9 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import ModelType, clip_tokens, load_yaml +from pr_agent.algo.utils import ModelType, clip_tokens, load_yaml, get_max_tokens from pr_agent.config_loader import get_settings -from pr_agent.git_providers import (BitbucketServerProvider, GithubProvider, - get_git_provider_with_context) +from pr_agent.git_providers import BitbucketServerProvider, GithubProvider, get_git_provider_with_context from pr_agent.log import get_logger @@ -30,10 +30,11 @@ def extract_header(snippet): return res class PRHelpMessage: - def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): + def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, return_as_string=False): self.git_provider = get_git_provider_with_context(pr_url) self.ai_handler = ai_handler() self.question_str = self.parse_args(args) + self.return_as_string = return_as_string self.num_retrieved_snippets = get_settings().get('pr_help.num_retrieved_snippets', 5) if self.question_str: self.vars = { @@ -65,6 +66,34 @@ class PRHelpMessage: question_str = "" return question_str + def format_markdown_header(self, header: str) -> str: + try: + # First, strip common characters 
from both ends + cleaned = header.strip('# 💎\n') + + # Define all characters to be removed/replaced in a single pass + replacements = { + "'": '', + "`": '', + '(': '', + ')': '', + ',': '', + '.': '', + '?': '', + '!': '', + ' ': '-' + } + + # Compile regex pattern for characters to remove + pattern = re.compile('|'.join(map(re.escape, replacements.keys()))) + + # Perform replacements in a single pass and convert to lowercase + return pattern.sub(lambda m: replacements[m.group()], cleaned).lower() + except Exception: + get_logger().exception(f"Error while formatting markdown header", artifacts={'header': header}) + return "" + + async def run(self): try: if self.question_str: @@ -106,7 +135,10 @@ class PRHelpMessage: get_logger().debug(f"Token count of full documentation website: {token_count}") model = get_settings().config.model - max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt + if model in MAX_TOKENS: + max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt + else: + max_tokens_full = get_max_tokens(model) delta_output = 2000 if token_count > max_tokens_full - delta_output: get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. 
Skipping the PR Help message.") @@ -114,8 +146,16 @@ class PRHelpMessage: self.vars['snippets'] = docs_prompt.strip() # run the AI model - response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK) + response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR) response_yaml = load_yaml(response) + if isinstance(response_yaml, str): + get_logger().warning(f"failing to parse response: {response_yaml}, publishing the response as is") + if get_settings().config.publish_output: + answer_str = f"### Question: \n{self.question_str}\n\n" + answer_str += f"### Answer:\n\n" + answer_str += response_yaml + self.git_provider.publish_comment(answer_str) + return "" response_str = response_yaml.get('response') relevant_sections = response_yaml.get('relevant_sections') @@ -138,7 +178,7 @@ class PRHelpMessage: for section in relevant_sections: file = section.get('file_name').strip().removesuffix('.md') if str(section['relevant_section_header_string']).strip(): - markdown_header = section['relevant_section_header_string'].strip().strip('#').strip().lower().replace(' ', '-').replace("'", '').replace('(', '').replace(')', '').replace(',', '').replace('.', '').replace('?', '').replace('!', '') + markdown_header = self.format_markdown_header(section['relevant_section_header_string']) answer_str += f"> - {base_path}{file}#{markdown_header}\n" else: answer_str += f"> - {base_path}{file}\n" @@ -236,7 +276,7 @@ class PRHelpMessage: for i in range(len(tool_names)): pr_comment += f"\n\n\n{tool_names[i]}\n{descriptions[i]}\n\n\n{checkbox_list[i]}\n" pr_comment += "\n\n" - pr_comment += f"""\n\n(1) Note that each tool can be [triggered automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) when a new PR is opened, or called manually by [commenting on a PR](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#online-usage).""" 
+ pr_comment += f"""\n\n(1) Note that each tool can be [triggered automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) when a new PR is opened, or called manually by [commenting on a PR](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#online-usage).""" pr_comment += f"""\n\n(2) Tools marked with [*] require additional parameters to be passed. For example, to invoke the `/ask` tool, you need to comment on a PR: `/ask ""`. See the relevant documentation for each tool for more details.""" elif isinstance(self.git_provider, BitbucketServerProvider): # only support basic commands in BBDC @@ -246,7 +286,7 @@ class PRHelpMessage: for i in range(len(tool_names)): pr_comment += f"\n\n\n{tool_names[i]}{commands[i]}{descriptions[i]}" pr_comment += "\n\n" - pr_comment += f"""\n\nNote that each tool can be [invoked automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/) when a new PR is opened, or called manually by [commenting on a PR](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#online-usage).""" + pr_comment += f"""\n\nNote that each tool can be [invoked automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/) when a new PR is opened, or called manually by [commenting on a PR](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#online-usage).""" if get_settings().config.publish_output: self.git_provider.publish_comment(pr_comment)