Refactor PR help message tool to use full documentation content for answering questions and update relevant section handling in prompts

This commit is contained in:
mrT23
2024-10-24 21:38:31 +03:00
parent c077c71fdb
commit 4f14742233
2 changed files with 63 additions and 42 deletions

View File

@ -1,20 +1,23 @@
[pr_help_prompts] [pr_help_prompts]
system="""You are Doc-helper, a language model designed to answer questions about a documentation website for an open-source project called "PR-Agent" (recently renamed to "Qodo Merge"). system="""You are Doc-helper, a language model designed to answer questions about a documentation website for an open-source project called "PR-Agent" (recently renamed to "Qodo Merge").
You will receive a question, and a list of snippets that were collected for a documentation site using RAG as the retrieval method. You will receive a question, and the full documentation website content.
Your goal is to provide the best answer to the question using the snippets provided. Your goal is to provide the best answer to the question using the documentation provided.
Additional instructions: Additional instructions:
- Try to be short and concise in your answers. Give examples if needed. - Try to be short and concise in your answers. Give examples if needed.
- It is possible some of the snippets may not be relevant to the question. In that case, you should ignore them and focus on the ones that are relevant.
- The main tools of PR-Agent are 'describe', 'review', 'improve'. If there is ambiguity to which tool the user is referring to, prioritize snippets of these tools over others. - The main tools of PR-Agent are 'describe', 'review', 'improve'. If there is ambiguity to which tool the user is referring to, prioritize snippets of these tools over others.
The output must be a YAML object equivalent to type $DocHelper, according to the following Pydantic definitions: The output must be a YAML object equivalent to type $DocHelper, according to the following Pydantic definitions:
===== =====
class relevant_section(BaseModel):
file_name: str = Field(description="The name of the relevant file")
relevant_section_header_string: str = Field(description="Exact text of the relevant section heading")
class DocHelper(BaseModel): class DocHelper(BaseModel):
user_question: str = Field(description="The user's question") user_question: str = Field(description="The user's question")
response: str = Field(description="The response to the user's question") response: str = Field(description="The response to the user's question")
relevant_snippets: List[int] = Field(description="One-based index of the relevant snippets in the list of snippets provided. Order the by relevance, with the most relevant first. If a snippet was not relevant, do not include it in the list.") relevant_sections: List[relevant_section] = Field(description="A list of the relevant markdown sections in the documentation that answer the user's question, ordered by importance")
===== =====
@ -24,10 +27,11 @@ user_question: |
... ...
response: | response: |
... ...
relevant_snippets: relevant_sections:
- 2 - file_name: "src/file1.py"
- 1 relevant_section_header_string: |
- 4 ...
- ...
""" """
user="""\ user="""\

View File

@ -4,6 +4,7 @@ import zipfile
import tempfile import tempfile
import copy import copy
from functools import partial from functools import partial
from pathlib import Path
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -157,38 +158,58 @@ class PRHelpMessage:
get_logger().error("The `Help` tool chat feature requires an OpenAI API key for calculating embeddings") get_logger().error("The `Help` tool chat feature requires an OpenAI API key for calculating embeddings")
return return
# Initialize embeddings # current path
from langchain_openai import OpenAIEmbeddings docs_path= Path(__file__).parent.parent.parent/'docs'/'docs'
embeddings = OpenAIEmbeddings(model="text-embedding-3-small", # get all the 'md' files inside docs_path and its subdirectories
api_key=get_settings().openai.key) md_files = list(docs_path.glob('**/*.md'))
folders_to_exclude =['/finetuning_benchmark/']
files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md','compression_strategy.md']
md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
# # calculate the token count of all the md files
# token_count = 0
# for file in md_files:
# with open(file, 'r') as f:
# token_count += self.token_handler.count_tokens(f.read())
# Get similar snippets via similarity search docs_prompt =""
if get_settings().pr_help.force_local_db: for file in md_files:
sim_results = self.get_sim_results_from_local_db(embeddings) with open(file, 'r') as f:
elif get_settings().get('pinecone.api_key'): file_path = str(file).replace(str(docs_path), '')
sim_results = self.get_sim_results_from_pinecone_db(embeddings) docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read()}\n=========\n\n"
else:
sim_results = self.get_sim_results_from_s3_db(embeddings)
if not sim_results:
get_logger().info("Failed to load the S3 index. Loading the local index...")
sim_results = self.get_sim_results_from_local_db(embeddings)
if not sim_results:
get_logger().error("Failed to retrieve similar snippets. Exiting...")
return
# Prepare relevant snippets self.vars['snippets'] = docs_prompt.strip()
relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str =\ # # Initialize embeddings
await self.prepare_relevant_snippets(sim_results) # from langchain_openai import OpenAIEmbeddings
self.vars['snippets'] = relevant_snippets_str.strip() # embeddings = OpenAIEmbeddings(model="text-embedding-3-small",
# api_key=get_settings().openai.key)
#
# # Get similar snippets via similarity search
# if get_settings().pr_help.force_local_db:
# sim_results = self.get_sim_results_from_local_db(embeddings)
# elif get_settings().get('pinecone.api_key'):
# sim_results = self.get_sim_results_from_pinecone_db(embeddings)
# else:
# sim_results = self.get_sim_results_from_s3_db(embeddings)
# if not sim_results:
# get_logger().info("Failed to load the S3 index. Loading the local index...")
# sim_results = self.get_sim_results_from_local_db(embeddings)
# if not sim_results:
# get_logger().error("Failed to retrieve similar snippets. Exiting...")
# return
# # Prepare relevant snippets
# relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str =\
# await self.prepare_relevant_snippets(sim_results)
# self.vars['snippets'] = relevant_snippets_str.strip()
# run the AI model # run the AI model
response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR) response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
response_yaml = load_yaml(response) response_yaml = load_yaml(response)
response_str = response_yaml.get('response') response_str = response_yaml.get('response')
relevant_snippets_numbers = response_yaml.get('relevant_snippets') relevant_sections = response_yaml.get('relevant_sections')
if not relevant_snippets_numbers: if not relevant_sections:
get_logger().info(f"Could not find relevant snippets for the question: {self.question_str}") get_logger().info(f"Could not find relevant answer for the question: {self.question_str}")
if get_settings().config.publish_output: if get_settings().config.publish_output:
answer_str = f"### Question: \n{self.question_str}\n\n" answer_str = f"### Question: \n{self.question_str}\n\n"
answer_str += f"### Answer:\n\n" answer_str += f"### Answer:\n\n"
@ -202,16 +223,12 @@ class PRHelpMessage:
answer_str += f"### Question: \n{self.question_str}\n\n" answer_str += f"### Question: \n{self.question_str}\n\n"
answer_str += f"### Answer:\n{response_str.strip()}\n\n" answer_str += f"### Answer:\n{response_str.strip()}\n\n"
answer_str += f"#### Relevant Sources:\n\n" answer_str += f"#### Relevant Sources:\n\n"
paged_published = [] base_path = "https://qodo-merge-docs.qodo.ai/"
for page in relevant_snippets_numbers: for section in relevant_sections:
page = int(page - 1) file = section.get('file_name').strip().removesuffix('.md')
if page < len(relevant_pages_full) and page >= 0: markdown_header = section['relevant_section_header_string'].strip().strip('#').strip().lower().replace(' ', '-')
if relevant_pages_full[page] in paged_published: answer_str += f"> - {base_path}{file}#{markdown_header}\n"
continue
link = f"{relevant_pages_full[page]}{relevant_snippets_full_header[page]}"
# answer_str += f"> - [{relevant_pages_full[page]}]({link})\n"
answer_str += f"> - {link}\n"
paged_published.append(relevant_pages_full[page])
# publish the answer # publish the answer
if get_settings().config.publish_output: if get_settings().config.publish_output: