Refactor PR help message tool to use full documentation content for answering questions and update relevant section handling in prompts

mrT23 committed 2024-10-24 21:38:31 +03:00
parent c077c71fdb
commit 4f14742233
2 changed files with 63 additions and 42 deletions

pr_agent/tools/pr_help_message.py

@@ -4,6 +4,7 @@ import zipfile
 import tempfile
 import copy
 from functools import partial
+from pathlib import Path
 from jinja2 import Environment, StrictUndefined
@@ -157,38 +158,58 @@ class PRHelpMessage:
                 get_logger().error("The `Help` tool chat feature requires an OpenAI API key for calculating embeddings")
                 return
-            # Initialize embeddings
-            from langchain_openai import OpenAIEmbeddings
-            embeddings = OpenAIEmbeddings(model="text-embedding-3-small",
-                                          api_key=get_settings().openai.key)
+            # current path
+            docs_path = Path(__file__).parent.parent.parent / 'docs' / 'docs'
+            # get all the 'md' files inside docs_path and its subdirectories
+            md_files = list(docs_path.glob('**/*.md'))
+            folders_to_exclude = ['/finetuning_benchmark/']
+            files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md']
+            md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
+            # # calculate the token count of all the md files
+            # token_count = 0
+            # for file in md_files:
+            #     with open(file, 'r') as f:
+            #         token_count += self.token_handler.count_tokens(f.read())
-            # Get similar snippets via similarity search
-            if get_settings().pr_help.force_local_db:
-                sim_results = self.get_sim_results_from_local_db(embeddings)
-            elif get_settings().get('pinecone.api_key'):
-                sim_results = self.get_sim_results_from_pinecone_db(embeddings)
-            else:
-                sim_results = self.get_sim_results_from_s3_db(embeddings)
-                if not sim_results:
-                    get_logger().info("Failed to load the S3 index. Loading the local index...")
-                    sim_results = self.get_sim_results_from_local_db(embeddings)
-            if not sim_results:
-                get_logger().error("Failed to retrieve similar snippets. Exiting...")
-                return
+            docs_prompt = ""
+            for file in md_files:
+                with open(file, 'r') as f:
+                    file_path = str(file).replace(str(docs_path), '')
+                    docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read()}\n=========\n\n"
-            # Prepare relevant snippets
-            relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str = \
-                await self.prepare_relevant_snippets(sim_results)
-            self.vars['snippets'] = relevant_snippets_str.strip()
+            self.vars['snippets'] = docs_prompt.strip()
+            # # Initialize embeddings
+            # from langchain_openai import OpenAIEmbeddings
+            # embeddings = OpenAIEmbeddings(model="text-embedding-3-small",
+            #                               api_key=get_settings().openai.key)
+            #
+            # # Get similar snippets via similarity search
+            # if get_settings().pr_help.force_local_db:
+            #     sim_results = self.get_sim_results_from_local_db(embeddings)
+            # elif get_settings().get('pinecone.api_key'):
+            #     sim_results = self.get_sim_results_from_pinecone_db(embeddings)
+            # else:
+            #     sim_results = self.get_sim_results_from_s3_db(embeddings)
+            #     if not sim_results:
+            #         get_logger().info("Failed to load the S3 index. Loading the local index...")
+            #         sim_results = self.get_sim_results_from_local_db(embeddings)
+            # if not sim_results:
+            #     get_logger().error("Failed to retrieve similar snippets. Exiting...")
+            #     return
+            # # Prepare relevant snippets
+            # relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str = \
+            #     await self.prepare_relevant_snippets(sim_results)
+            # self.vars['snippets'] = relevant_snippets_str.strip()
             # run the AI model
             response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
             response_yaml = load_yaml(response)
             response_str = response_yaml.get('response')
-            relevant_snippets_numbers = response_yaml.get('relevant_snippets')
+            relevant_sections = response_yaml.get('relevant_sections')
-            if not relevant_snippets_numbers:
-                get_logger().info(f"Could not find relevant snippets for the question: {self.question_str}")
+            if not relevant_sections:
+                get_logger().info(f"Could not find relevant answer for the question: {self.question_str}")
             if get_settings().config.publish_output:
                 answer_str = f"### Question: \n{self.question_str}\n\n"
                 answer_str += f"### Answer:\n\n"
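
For reference, a minimal standalone sketch of the new flow introduced above: every docs page is concatenated into one delimited prompt string, and the model answers with YAML naming the relevant sections. The `build_docs_prompt` helper and the sample response below are illustrative assumptions, not taken from the repository; the tool itself parses the reply with its own `load_yaml` helper rather than PyYAML.

```python
import yaml  # stand-in for the project's load_yaml helper

def build_docs_prompt(md_files, docs_path):
    """Concatenate every docs page into one delimited prompt string."""
    docs_prompt = ""
    for file in md_files:
        with open(file, 'r') as f:
            # Keep only the path relative to the docs root as the label.
            file_path = str(file).replace(str(docs_path), '')
            docs_prompt += (f"==file name:==\n\n{file_path}\n\n"
                            f"==file content:==\n\n{f.read()}\n=========\n\n")
    return docs_prompt

# Hypothetical model reply in the expected YAML shape:
response = """
response: |
  Use the /improve command to get code suggestions.
relevant_sections:
- file_name: tools/improve.md
  relevant_section_header_string: '## Overview'
"""
response_yaml = yaml.safe_load(response)
response_str = response_yaml.get('response')
relevant_sections = response_yaml.get('relevant_sections')
```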
@@ -202,16 +223,12 @@ class PRHelpMessage:
                 answer_str += f"### Question: \n{self.question_str}\n\n"
                 answer_str += f"### Answer:\n{response_str.strip()}\n\n"
                 answer_str += f"#### Relevant Sources:\n\n"
-                paged_published = []
-                for page in relevant_snippets_numbers:
-                    page = int(page - 1)
-                    if page < len(relevant_pages_full) and page >= 0:
-                        if relevant_pages_full[page] in paged_published:
-                            continue
-                        link = f"{relevant_pages_full[page]}{relevant_snippets_full_header[page]}"
-                        # answer_str += f"> - [{relevant_pages_full[page]}]({link})\n"
-                        answer_str += f"> - {link}\n"
-                        paged_published.append(relevant_pages_full[page])
+                base_path = "https://qodo-merge-docs.qodo.ai/"
+                for section in relevant_sections:
+                    file = section.get('file_name').strip().removesuffix('.md')
+                    markdown_header = section['relevant_section_header_string'].strip().strip('#').strip().lower().replace(' ', '-')
+                    answer_str += f"> - {base_path}{file}#{markdown_header}\n"
             # publish the answer
             if get_settings().config.publish_output:
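
The new source-link construction can be exercised in isolation. A small sketch, assuming the MkDocs-style anchor convention implied above; `section_to_link` and the input dict are made-up examples:

```python
# Turn one relevant-sections entry into a docs URL with a markdown anchor.
base_path = "https://qodo-merge-docs.qodo.ai/"

def section_to_link(section: dict) -> str:
    # Drop the '.md' suffix to get the page path (requires Python 3.9+).
    file = section.get('file_name').strip().removesuffix('.md')
    header = section['relevant_section_header_string']
    # Anchor slug: strip '#' markers, lowercase, hyphenate spaces.
    markdown_header = header.strip().strip('#').strip().lower().replace(' ', '-')
    return f"{base_path}{file}#{markdown_header}"

# Made-up input, not real model output:
print(section_to_link({'file_name': 'tools/improve.md',
                       'relevant_section_header_string': '## Overview'}))
# -> https://qodo-merge-docs.qodo.ai/tools/improve#overview
```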