Mirror of https://github.com/qodo-ai/pr-agent.git (synced 2025-07-04 04:40:38 +08:00)
Merge pull request #1534 from qodo-ai/tr/help_r
feat: improve help tool with markdown header formatting and error handling
@@ -13,7 +13,7 @@ The output must be a YAML object equivalent to type $DocHelper, according to the
 class relevant_section(BaseModel):
     file_name: str = Field(description="The name of the relevant file")
-    relevant_section_header_string: str = Field(description="From the relevant file, exact text of the relevant section heading. If no markdown heading is relevant, return empty string")
+    relevant_section_header_string: str = Field(description="The exact text of the relevant markdown section heading from the relevant file (starting with '#', '##', etc.). Return empty string if the entire file is the relevant section, or if the relevant section has no heading")

 class DocHelper(BaseModel):
     user_question: str = Field(description="The user's question")
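For reference, a well-formed model response under the updated schema parses into something like the following; all values are invented for illustration:

```python
# Invented example of a parsed response satisfying the updated schema.
# The header string is copied verbatim ('##' included), and is empty when
# the whole file is the relevant section.
response_yaml = {
    "user_question": "How do I enable the improve tool?",
    "response": "Enable it from the [pr_code_suggestions] configuration section ...",
    "relevant_sections": [
        {"file_name": "improve.md",
         "relevant_section_header_string": "## Configuration options"},
        {"file_name": "installation/locally.md",
         "relevant_section_header_string": ""},  # whole file is relevant
    ],
}
```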
@@ -1,4 +1,5 @@
 import copy
+import re
 from functools import partial
 from pathlib import Path

@@ -9,10 +10,9 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
 from pr_agent.algo.pr_processing import retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import ModelType, clip_tokens, load_yaml
+from pr_agent.algo.utils import ModelType, clip_tokens, load_yaml, get_max_tokens
 from pr_agent.config_loader import get_settings
-from pr_agent.git_providers import (BitbucketServerProvider, GithubProvider,
-                                    get_git_provider_with_context)
+from pr_agent.git_providers import BitbucketServerProvider, GithubProvider, get_git_provider_with_context
 from pr_agent.log import get_logger

@@ -30,10 +30,11 @@ def extract_header(snippet):
     return res

 class PRHelpMessage:
-    def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+    def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, return_as_string=False):
         self.git_provider = get_git_provider_with_context(pr_url)
         self.ai_handler = ai_handler()
         self.question_str = self.parse_args(args)
+        self.return_as_string = return_as_string
         self.num_retrieved_snippets = get_settings().get('pr_help.num_retrieved_snippets', 5)
         if self.question_str:
             self.vars = {
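The diff only stores the new flag; a caller wanting the answer returned rather than posted as a comment would presumably invoke the tool like this (URL and question are placeholders, and the flag's downstream handling is assumed, not shown in this diff):

```python
import asyncio

async def main():
    # Placeholder PR URL and question; assumes self.return_as_string is
    # honored by the parts of run() not shown in this diff.
    tool = PRHelpMessage(
        "https://github.com/org/repo/pull/1",
        args=["how", "to", "configure", "a", "custom", "model"],
        return_as_string=True,
    )
    answer = await tool.run()
    print(answer)

asyncio.run(main())
```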
@@ -65,6 +66,34 @@ class PRHelpMessage:
             question_str = ""
         return question_str

+    def format_markdown_header(self, header: str) -> str:
+        try:
+            # First, strip common characters from both ends
+            cleaned = header.strip('# 💎\n')
+
+            # Define all characters to be removed/replaced in a single pass
+            replacements = {
+                "'": '',
+                "`": '',
+                '(': '',
+                ')': '',
+                ',': '',
+                '.': '',
+                '?': '',
+                '!': '',
+                ' ': '-'
+            }
+
+            # Compile regex pattern for characters to remove
+            pattern = re.compile('|'.join(map(re.escape, replacements.keys())))
+
+            # Perform replacements in a single pass and convert to lowercase
+            return pattern.sub(lambda m: replacements[m.group()], cleaned).lower()
+        except Exception:
+            get_logger().exception(f"Error while formatting markdown header", artifacts={'header': header})
+            return ""
+
     async def run(self):
         try:
             if self.question_str:
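The new helper replaces the long inline replace() chain used later in run(); given any instance, it maps markdown headings to GitHub-style anchors. The inputs below are made up; the outputs follow from the code above:

```python
# helper = PRHelpMessage(pr_url)  # any instance; the method only reads its argument
helper.format_markdown_header("## Configuring the Model 💎")  # -> 'configuring-the-model'
helper.format_markdown_header("# What's new?")                # -> 'whats-new'
helper.format_markdown_header("### FAQ")                      # -> 'faq'
```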
@@ -106,7 +135,10 @@ class PRHelpMessage:
                 get_logger().debug(f"Token count of full documentation website: {token_count}")

                 model = get_settings().config.model
-                max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
+                if model in MAX_TOKENS:
+                    max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
+                else:
+                    max_tokens_full = get_max_tokens(model)
                 delta_output = 2000
                 if token_count > max_tokens_full - delta_output:
                     get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.")
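The guard matters for custom or newly released models that have no MAX_TOKENS entry, where the old direct lookup raised KeyError. A minimal sketch of the failure mode it avoids; the model name and the fallback value are placeholders, and the real fallback is whatever get_max_tokens (from pr_agent.algo.utils) resolves from configuration:

```python
# Sketch only: illustrates why 'if model in MAX_TOKENS' is needed.
MAX_TOKENS_EXCERPT = {"gpt-4o": 128000}  # excerpt; the real table is larger

def resolve_max_tokens(model: str) -> int:
    if model in MAX_TOKENS_EXCERPT:
        return MAX_TOKENS_EXCERPT[model]
    # previously: MAX_TOKENS_EXCERPT[model] raised KeyError here
    return 32000  # stand-in for get_max_tokens(model), which reads a configured limit

assert resolve_max_tokens("gpt-4o") == 128000
assert resolve_max_tokens("my-org/custom-model") == 32000  # placeholder model name
```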
@@ -114,8 +146,16 @@ class PRHelpMessage:
                 self.vars['snippets'] = docs_prompt.strip()

                 # run the AI model
-                response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
+                response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
                 response_yaml = load_yaml(response)
+                if isinstance(response_yaml, str):
+                    get_logger().warning(f"failing to parse response: {response_yaml}, publishing the response as is")
+                    if get_settings().config.publish_output:
+                        answer_str = f"### Question: \n{self.question_str}\n\n"
+                        answer_str += f"### Answer:\n\n"
+                        answer_str += response_yaml
+                        self.git_provider.publish_comment(answer_str)
+                    return ""
                 response_str = response_yaml.get('response')
                 relevant_sections = response_yaml.get('relevant_sections')

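The isinstance guard exists because parsing prose as YAML yields a plain string, on which the subsequent .get() calls would raise AttributeError; the new branch publishes that text verbatim instead of crashing. A short illustration, assuming load_yaml ultimately delegates to yaml.safe_load (the input strings are made up):

```python
import yaml

# A well-formed reply parses into a mapping; free-form prose parses into a str.
good = yaml.safe_load("response: hi\nrelevant_sections: []")  # -> dict
bad = yaml.safe_load("Sorry, I cannot answer that.")          # -> str
assert isinstance(good, dict) and isinstance(bad, str)
bad.get("response")  # AttributeError: 'str' object has no attribute 'get'
```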
@@ -138,7 +178,7 @@ class PRHelpMessage:
                 for section in relevant_sections:
                     file = section.get('file_name').strip().removesuffix('.md')
                     if str(section['relevant_section_header_string']).strip():
-                        markdown_header = section['relevant_section_header_string'].strip().strip('#').strip().lower().replace(' ', '-').replace("'", '').replace('(', '').replace(')', '').replace(',', '').replace('.', '').replace('?', '').replace('!', '')
+                        markdown_header = self.format_markdown_header(section['relevant_section_header_string'])
                         answer_str += f"> - {base_path}{file}#{markdown_header}\n"
                     else:
                         answer_str += f"> - {base_path}{file}\n"
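End to end, each relevant section becomes one quoted bullet link whose URL fragment is the formatted header. With invented values, a loop iteration appends:

```python
# Invented values; in the real tool, base_path points at the docs site.
base_path = "https://qodo-merge-docs.qodo.ai/"   # placeholder docs base URL
file = "improve"                                 # 'improve.md' after removesuffix('.md')
markdown_header = "configuration-options"        # output of format_markdown_header
answer_str = f"> - {base_path}{file}#{markdown_header}\n"
# -> '> - https://qodo-merge-docs.qodo.ai/improve#configuration-options\n'
```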