Merge pull request #1534 from qodo-ai/tr/help_r

feat: improve help tool with markdown header formatting and error han…
This commit is contained in:
Tal
2025-02-13 11:50:59 +02:00
committed by GitHub
2 changed files with 48 additions and 8 deletions

View File

@ -13,7 +13,7 @@ The output must be a YAML object equivalent to type $DocHelper, according to the
===== =====
class relevant_section(BaseModel): class relevant_section(BaseModel):
file_name: str = Field(description="The name of the relevant file") file_name: str = Field(description="The name of the relevant file")
relevant_section_header_string: str = Field(description="From the relevant file, exact text of the relevant section heading. If no markdown heading is relevant, return empty string") relevant_section_header_string: str = Field(description="The exact text of the relevant markdown section heading from the relevant file (starting with '#', '##', etc.). Return empty string if the entire file is the relevant section, or if the relevant section has no heading")
class DocHelper(BaseModel): class DocHelper(BaseModel):
user_question: str = Field(description="The user's question") user_question: str = Field(description="The user's question")

View File

@ -1,4 +1,5 @@
import copy import copy
import re
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@ -9,10 +10,9 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import retry_with_fallback_models from pr_agent.algo.pr_processing import retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import ModelType, clip_tokens, load_yaml from pr_agent.algo.utils import ModelType, clip_tokens, load_yaml, get_max_tokens
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import (BitbucketServerProvider, GithubProvider, from pr_agent.git_providers import BitbucketServerProvider, GithubProvider, get_git_provider_with_context
get_git_provider_with_context)
from pr_agent.log import get_logger from pr_agent.log import get_logger
@ -30,10 +30,11 @@ def extract_header(snippet):
return res return res
class PRHelpMessage: class PRHelpMessage:
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, return_as_string=False):
self.git_provider = get_git_provider_with_context(pr_url) self.git_provider = get_git_provider_with_context(pr_url)
self.ai_handler = ai_handler() self.ai_handler = ai_handler()
self.question_str = self.parse_args(args) self.question_str = self.parse_args(args)
self.return_as_string = return_as_string
self.num_retrieved_snippets = get_settings().get('pr_help.num_retrieved_snippets', 5) self.num_retrieved_snippets = get_settings().get('pr_help.num_retrieved_snippets', 5)
if self.question_str: if self.question_str:
self.vars = { self.vars = {
@ -65,6 +66,34 @@ class PRHelpMessage:
question_str = "" question_str = ""
return question_str return question_str
def format_markdown_header(self, header: str) -> str:
try:
# First, strip common characters from both ends
cleaned = header.strip('# 💎\n')
# Define all characters to be removed/replaced in a single pass
replacements = {
"'": '',
"`": '',
'(': '',
')': '',
',': '',
'.': '',
'?': '',
'!': '',
' ': '-'
}
# Compile regex pattern for characters to remove
pattern = re.compile('|'.join(map(re.escape, replacements.keys())))
# Perform replacements in a single pass and convert to lowercase
return pattern.sub(lambda m: replacements[m.group()], cleaned).lower()
except Exception:
get_logger().exception(f"Error while formatting markdown header", artifacts={'header': header})
return ""
async def run(self): async def run(self):
try: try:
if self.question_str: if self.question_str:
@ -106,7 +135,10 @@ class PRHelpMessage:
get_logger().debug(f"Token count of full documentation website: {token_count}") get_logger().debug(f"Token count of full documentation website: {token_count}")
model = get_settings().config.model model = get_settings().config.model
max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt if model in MAX_TOKENS:
max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
else:
max_tokens_full = get_max_tokens(model)
delta_output = 2000 delta_output = 2000
if token_count > max_tokens_full - delta_output: if token_count > max_tokens_full - delta_output:
get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.") get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.")
@ -114,8 +146,16 @@ class PRHelpMessage:
self.vars['snippets'] = docs_prompt.strip() self.vars['snippets'] = docs_prompt.strip()
# run the AI model # run the AI model
response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK) response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
response_yaml = load_yaml(response) response_yaml = load_yaml(response)
if isinstance(response_yaml, str):
get_logger().warning(f"failing to parse response: {response_yaml}, publishing the response as is")
if get_settings().config.publish_output:
answer_str = f"### Question: \n{self.question_str}\n\n"
answer_str += f"### Answer:\n\n"
answer_str += response_yaml
self.git_provider.publish_comment(answer_str)
return ""
response_str = response_yaml.get('response') response_str = response_yaml.get('response')
relevant_sections = response_yaml.get('relevant_sections') relevant_sections = response_yaml.get('relevant_sections')
@ -138,7 +178,7 @@ class PRHelpMessage:
for section in relevant_sections: for section in relevant_sections:
file = section.get('file_name').strip().removesuffix('.md') file = section.get('file_name').strip().removesuffix('.md')
if str(section['relevant_section_header_string']).strip(): if str(section['relevant_section_header_string']).strip():
markdown_header = section['relevant_section_header_string'].strip().strip('#').strip().lower().replace(' ', '-').replace("'", '').replace('(', '').replace(')', '').replace(',', '').replace('.', '').replace('?', '').replace('!', '') markdown_header = self.format_markdown_header(section['relevant_section_header_string'])
answer_str += f"> - {base_path}{file}#{markdown_header}\n" answer_str += f"> - {base_path}{file}#{markdown_header}\n"
else: else:
answer_str += f"> - {base_path}{file}\n" answer_str += f"> - {base_path}{file}\n"