Refactor PR help message tool to use full documentation content for answering questions and update relevant section handling in prompts

This commit is contained in:
mrT23
2024-10-24 22:01:40 +03:00
parent 4f14742233
commit 9786499fa6
4 changed files with 26 additions and 120 deletions

View File

@ -8,11 +8,12 @@ from pathlib import Path
from jinja2 import Environment, StrictUndefined
from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import ModelType, load_yaml
from pr_agent.algo.utils import ModelType, load_yaml, clip_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider, GithubProvider, BitbucketServerProvider, \
get_git_provider_with_context
@ -68,83 +69,6 @@ class PRHelpMessage:
question_str = ""
return question_str
def get_sim_results_from_s3_db(self, embeddings):
get_logger().info("Loading the S3 index...")
sim_results = []
try:
from langchain_chroma import Chroma
from urllib import request
with tempfile.TemporaryDirectory() as temp_dir:
# Define the local file path within the temporary directory
local_file_path = os.path.join(temp_dir, 'chroma_db.zip')
bucket = 'pr-agent'
file_name = 'chroma_db.zip'
s3_url = f'https://{bucket}.s3.amazonaws.com/{file_name}'
request.urlretrieve(s3_url, local_file_path)
# # Download the file from S3 to the temporary directory
# s3 = boto3.client('s3')
# s3.download_file(bucket, file_name, local_file_path)
# Extract the contents of the zip file
with zipfile.ZipFile(local_file_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
vectorstore = Chroma(persist_directory=temp_dir + "/chroma_db",
embedding_function=embeddings)
sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
except Exception as e:
get_logger().error(f"Error while getting sim from S3: {e}",
artifact={"traceback": traceback.format_exc()})
return sim_results
def get_sim_results_from_local_db(self, embeddings):
get_logger().info("Loading the local index...")
sim_results = []
try:
from langchain_chroma import Chroma
get_logger().info("Loading the Chroma index...")
db_path = "./docs/chroma_db.zip"
if not os.path.exists(db_path):
db_path= "/app/docs/chroma_db.zip"
if not os.path.exists(db_path):
get_logger().error("Local db not found")
return sim_results
with tempfile.TemporaryDirectory() as temp_dir:
# Extract the ZIP file
with zipfile.ZipFile(db_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
vectorstore = Chroma(persist_directory=temp_dir + "/chroma_db",
embedding_function=embeddings)
# Do similarity search
sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
except Exception as e:
get_logger().error(f"Error while getting sim from local db: {e}",
artifact={"traceback": traceback.format_exc()})
return sim_results
def get_sim_results_from_pinecone_db(self, embeddings):
get_logger().info("Loading the Pinecone index...")
sim_results = []
try:
from langchain_pinecone import PineconeVectorStore
INDEX_NAME = "pr-agent-docs"
vectorstore = PineconeVectorStore(
index_name=INDEX_NAME, embedding=embeddings,
pinecone_api_key=get_settings().pinecone.api_key
)
# Do similarity search
sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
except Exception as e:
get_logger().error(f"Error while getting sim from Pinecone db: {e}",
artifact={"traceback": traceback.format_exc()})
return sim_results
async def run(self):
try:
if self.question_str:
@ -159,48 +83,36 @@ class PRHelpMessage:
return
# current path
docs_path= Path(__file__).parent.parent.parent/'docs'/'docs'
docs_path= Path(__file__).parent.parent.parent / 'docs' / 'docs'
# get all the 'md' files inside docs_path and its subdirectories
md_files = list(docs_path.glob('**/*.md'))
folders_to_exclude =['/finetuning_benchmark/']
files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md','compression_strategy.md']
folders_to_exclude = ['/finetuning_benchmark/']
files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md', '/docs/overview/index.md']
md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
# # calculate the token count of all the md files
# token_count = 0
# for file in md_files:
# with open(file, 'r') as f:
# token_count += self.token_handler.count_tokens(f.read())
docs_prompt =""
# sort the 'md_files' so that 'priority_files' will be at the top
priority_files_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md', 'tools/review.md',
'tools/improve.md', '/faq']
md_files_priority = [file for file in md_files if
any(priority_string in str(file) for priority_string in priority_files_strings)]
md_files_not_priority = [file for file in md_files if file not in md_files_priority]
md_files = md_files_priority + md_files_not_priority
docs_prompt = ""
for file in md_files:
with open(file, 'r') as f:
file_path = str(file).replace(str(docs_path), '')
docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read()}\n=========\n\n"
docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read().strip()}\n=========\n\n"
token_count = self.token_handler.count_tokens(docs_prompt)
get_logger().debug(f"Token count of full documentation website: {token_count}")
model = get_settings().config.model
max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
delta_output = 2000
if token_count > max_tokens_full - delta_output:
get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.")
docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output)
self.vars['snippets'] = docs_prompt.strip()
# # Initialize embeddings
# from langchain_openai import OpenAIEmbeddings
# embeddings = OpenAIEmbeddings(model="text-embedding-3-small",
# api_key=get_settings().openai.key)
#
# # Get similar snippets via similarity search
# if get_settings().pr_help.force_local_db:
# sim_results = self.get_sim_results_from_local_db(embeddings)
# elif get_settings().get('pinecone.api_key'):
# sim_results = self.get_sim_results_from_pinecone_db(embeddings)
# else:
# sim_results = self.get_sim_results_from_s3_db(embeddings)
# if not sim_results:
# get_logger().info("Failed to load the S3 index. Loading the local index...")
# sim_results = self.get_sim_results_from_local_db(embeddings)
# if not sim_results:
# get_logger().error("Failed to retrieve similar snippets. Exiting...")
# return
# # Prepare relevant snippets
# relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str =\
# await self.prepare_relevant_snippets(sim_results)
# self.vars['snippets'] = relevant_snippets_str.strip()
# run the AI model
response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)