From 4f14742233cb3baeafae889efa61929863ff8e03 Mon Sep 17 00:00:00 2001
From: mrT23
Date: Thu, 24 Oct 2024 21:38:31 +0300
Subject: [PATCH] Refactor PR help message tool to use full documentation content for answering questions and update relevant section handling in prompts

---
 pr_agent/settings/pr_help_prompts.toml | 20 +++---
 pr_agent/tools/pr_help_message.py      | 85 +++++++++++++++-----------
 2 files changed, 63 insertions(+), 42 deletions(-)

diff --git a/pr_agent/settings/pr_help_prompts.toml b/pr_agent/settings/pr_help_prompts.toml
index b537b7b7..4ff6cf6d 100644
--- a/pr_agent/settings/pr_help_prompts.toml
+++ b/pr_agent/settings/pr_help_prompts.toml
@@ -1,20 +1,23 @@
 [pr_help_prompts]
 system="""You are Doc-helper, a language model designed to answer questions about a documentation website for an open-source project called "PR-Agent" (recently renamed to "Qodo Merge").
-You will receive a question, and a list of snippets that were collected for a documentation site using RAG as the retrieval method.
-Your goal is to provide the best answer to the question using the snippets provided.
+You will receive a question, and the full documentation website content.
+Your goal is to provide the best answer to the question using the documentation provided.
 Additional instructions:
 - Try to be short and concise in your answers. Give examples if needed.
-- It is possible some of the snippets may not be relevant to the question. In that case, you should ignore them and focus on the ones that are relevant.
 - The main tools of PR-Agent are 'describe', 'review', 'improve'. If there is ambiguity as to which tool the user is referring to, prioritize snippets of these tools over others.
 
 The output must be a YAML object equivalent to type $DocHelper, according to the following Pydantic definitions:
 =====
+class relevant_section(BaseModel):
+    file_name: str = Field(description="The name of the relevant file")
+    relevant_section_header_string: str = Field(description="Exact text of the relevant section heading")
+
 class DocHelper(BaseModel):
     user_question: str = Field(description="The user's question")
     response: str = Field(description="The response to the user's question")
-    relevant_snippets: List[int] = Field(description="One-based index of the relevant snippets in the list of snippets provided. Order them by relevance, with the most relevant first. If a snippet was not relevant, do not include it in the list.")
+    relevant_sections: List[relevant_section] = Field(description="A list of the relevant markdown sections in the documentation that answer the user's question, ordered by importance")
 =====
@@ -24,10 +27,11 @@ user_question: |
 ...
 response: |
 ...
-relevant_snippets:
-- 2
-- 1
-- 4
+relevant_sections:
+- file_name: "src/file1.py"
+  relevant_section_header_string: |
+    ...
+- ...
 """
 
 user="""\
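For reference, a minimal standalone sketch (not part of the patch) of how the new output schema above could be exercised end to end. The Pydantic models mirror the definitions embedded in the prompt; the YAML payload is a hypothetical model answer, and yaml.safe_load stands in for the project's load_yaml helper:

from typing import List

import yaml
from pydantic import BaseModel, Field


class relevant_section(BaseModel):
    file_name: str = Field(description="The name of the relevant file")
    relevant_section_header_string: str = Field(description="Exact text of the relevant section heading")


class DocHelper(BaseModel):
    user_question: str = Field(description="The user's question")
    response: str = Field(description="The response to the user's question")
    relevant_sections: List[relevant_section] = Field(description="Relevant markdown sections, ordered by importance")


# hypothetical model output, shaped like the example in the prompt
raw_yaml = """\
user_question: |
  How do I run the improve tool?
response: |
  Comment `/improve` on your pull request.
relevant_sections:
- file_name: "tools/improve.md"
  relevant_section_header_string: |
    Overview
"""

answer = DocHelper(**yaml.safe_load(raw_yaml))
print(answer.relevant_sections[0].file_name)  # -> tools/improve.md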
""" user="""\ diff --git a/pr_agent/tools/pr_help_message.py b/pr_agent/tools/pr_help_message.py index 5c909ea6..b2933632 100644 --- a/pr_agent/tools/pr_help_message.py +++ b/pr_agent/tools/pr_help_message.py @@ -4,6 +4,7 @@ import zipfile import tempfile import copy from functools import partial +from pathlib import Path from jinja2 import Environment, StrictUndefined @@ -157,38 +158,58 @@ class PRHelpMessage: get_logger().error("The `Help` tool chat feature requires an OpenAI API key for calculating embeddings") return - # Initialize embeddings - from langchain_openai import OpenAIEmbeddings - embeddings = OpenAIEmbeddings(model="text-embedding-3-small", - api_key=get_settings().openai.key) + # current path + docs_path= Path(__file__).parent.parent.parent/'docs'/'docs' + # get all the 'md' files inside docs_path and its subdirectories + md_files = list(docs_path.glob('**/*.md')) + folders_to_exclude =['/finetuning_benchmark/'] + files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md','compression_strategy.md'] + md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)] + # # calculate the token count of all the md files + # token_count = 0 + # for file in md_files: + # with open(file, 'r') as f: + # token_count += self.token_handler.count_tokens(f.read()) - # Get similar snippets via similarity search - if get_settings().pr_help.force_local_db: - sim_results = self.get_sim_results_from_local_db(embeddings) - elif get_settings().get('pinecone.api_key'): - sim_results = self.get_sim_results_from_pinecone_db(embeddings) - else: - sim_results = self.get_sim_results_from_s3_db(embeddings) - if not sim_results: - get_logger().info("Failed to load the S3 index. Loading the local index...") - sim_results = self.get_sim_results_from_local_db(embeddings) - if not sim_results: - get_logger().error("Failed to retrieve similar snippets. Exiting...") - return + docs_prompt ="" + for file in md_files: + with open(file, 'r') as f: + file_path = str(file).replace(str(docs_path), '') + docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read()}\n=========\n\n" - # Prepare relevant snippets - relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str =\ - await self.prepare_relevant_snippets(sim_results) - self.vars['snippets'] = relevant_snippets_str.strip() + self.vars['snippets'] = docs_prompt.strip() + # # Initialize embeddings + # from langchain_openai import OpenAIEmbeddings + # embeddings = OpenAIEmbeddings(model="text-embedding-3-small", + # api_key=get_settings().openai.key) + # + # # Get similar snippets via similarity search + # if get_settings().pr_help.force_local_db: + # sim_results = self.get_sim_results_from_local_db(embeddings) + # elif get_settings().get('pinecone.api_key'): + # sim_results = self.get_sim_results_from_pinecone_db(embeddings) + # else: + # sim_results = self.get_sim_results_from_s3_db(embeddings) + # if not sim_results: + # get_logger().info("Failed to load the S3 index. Loading the local index...") + # sim_results = self.get_sim_results_from_local_db(embeddings) + # if not sim_results: + # get_logger().error("Failed to retrieve similar snippets. 
Exiting...") + # return + + # # Prepare relevant snippets + # relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str =\ + # await self.prepare_relevant_snippets(sim_results) + # self.vars['snippets'] = relevant_snippets_str.strip() # run the AI model response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR) response_yaml = load_yaml(response) response_str = response_yaml.get('response') - relevant_snippets_numbers = response_yaml.get('relevant_snippets') + relevant_sections = response_yaml.get('relevant_sections') - if not relevant_snippets_numbers: - get_logger().info(f"Could not find relevant snippets for the question: {self.question_str}") + if not relevant_sections: + get_logger().info(f"Could not find relevant answer for the question: {self.question_str}") if get_settings().config.publish_output: answer_str = f"### Question: \n{self.question_str}\n\n" answer_str += f"### Answer:\n\n" @@ -202,16 +223,12 @@ class PRHelpMessage: answer_str += f"### Question: \n{self.question_str}\n\n" answer_str += f"### Answer:\n{response_str.strip()}\n\n" answer_str += f"#### Relevant Sources:\n\n" - paged_published = [] - for page in relevant_snippets_numbers: - page = int(page - 1) - if page < len(relevant_pages_full) and page >= 0: - if relevant_pages_full[page] in paged_published: - continue - link = f"{relevant_pages_full[page]}{relevant_snippets_full_header[page]}" - # answer_str += f"> - [{relevant_pages_full[page]}]({link})\n" - answer_str += f"> - {link}\n" - paged_published.append(relevant_pages_full[page]) + base_path = "https://qodo-merge-docs.qodo.ai/" + for section in relevant_sections: + file = section.get('file_name').strip().removesuffix('.md') + markdown_header = section['relevant_section_header_string'].strip().strip('#').strip().lower().replace(' ', '-') + answer_str += f"> - {base_path}{file}#{markdown_header}\n" + # publish the answer if get_settings().config.publish_output: