Refactor PR help message tool to use full documentation content for answering questions and update relevant section handling in prompts

Author: mrT23
Date: 2024-10-24 22:01:40 +03:00
parent 4f14742233
commit 9786499fa6
4 changed files with 26 additions and 120 deletions

Binary file not shown.

View File

@@ -4,7 +4,7 @@ You will recieve a question, and the full documentation website content.
 Your goal is to provide the best answer to the question using the documentation provided.
 Additional instructions:
-- Try to be short and concise in your answers. Give examples if needed.
+- Try to be short and concise in your answers. Try to give examples if needed.
 - The main tools of PR-Agent are 'describe', 'review', 'improve'. If there is ambiguity to which tool the user is referring to, prioritize snippets of these tools over others.
@@ -17,7 +17,7 @@ class relevant_section(BaseModel):
 class DocHelper(BaseModel):
     user_question: str = Field(description="The user's question")
     response: str = Field(description="The response to the user's question")
-    relevant_sections: List[relevant_section] = Field(description="A list of the relevant markdown sections in the documentation that answer the user's question, ordered by importance")
+    relevant_sections: List[relevant_section] = Field(description="A list of the relevant markdown sections in the documentation that answer the user's question, ordered by importance (most relevant first)")
 =====
@@ -41,7 +41,7 @@ User's Question:
 =====
-Relevant doc snippets retrieved:
+Documentation website content:
 =====
 {{ snippets|trim }}
 =====
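
For context, a minimal sketch of the structured output these prompt hunks describe, assuming pydantic is available. The relevant_section field names are assumptions, since the hunk shows only that class's header:

    # Sketch of the output schema the prompt asks the model to follow.
    # The relevant_section fields are assumptions, not shown in this hunk.
    from typing import List

    from pydantic import BaseModel, Field


    class relevant_section(BaseModel):
        file_name: str = Field(description="The doc file containing the relevant section")
        relevant_section_header_string: str = Field(description="The markdown header of the section")


    class DocHelper(BaseModel):
        user_question: str = Field(description="The user's question")
        response: str = Field(description="The response to the user's question")
        relevant_sections: List[relevant_section] = Field(
            description="A list of the relevant markdown sections in the documentation that "
                        "answer the user's question, ordered by importance (most relevant first)")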

View File

@@ -8,11 +8,12 @@ from pathlib import Path
 from jinja2 import Environment, StrictUndefined
+from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
 from pr_agent.algo.pr_processing import retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import ModelType, load_yaml
+from pr_agent.algo.utils import ModelType, load_yaml, clip_tokens
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers import get_git_provider, GithubProvider, BitbucketServerProvider, \
     get_git_provider_with_context
@@ -68,83 +69,6 @@ class PRHelpMessage:
             question_str = ""
         return question_str
 
-    def get_sim_results_from_s3_db(self, embeddings):
-        get_logger().info("Loading the S3 index...")
-        sim_results = []
-        try:
-            from langchain_chroma import Chroma
-            from urllib import request
-            with tempfile.TemporaryDirectory() as temp_dir:
-                # Define the local file path within the temporary directory
-                local_file_path = os.path.join(temp_dir, 'chroma_db.zip')
-                bucket = 'pr-agent'
-                file_name = 'chroma_db.zip'
-                s3_url = f'https://{bucket}.s3.amazonaws.com/{file_name}'
-                request.urlretrieve(s3_url, local_file_path)
-                # # Download the file from S3 to the temporary directory
-                # s3 = boto3.client('s3')
-                # s3.download_file(bucket, file_name, local_file_path)
-                # Extract the contents of the zip file
-                with zipfile.ZipFile(local_file_path, 'r') as zip_ref:
-                    zip_ref.extractall(temp_dir)
-                vectorstore = Chroma(persist_directory=temp_dir + "/chroma_db",
-                                     embedding_function=embeddings)
-                sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
-        except Exception as e:
-            get_logger().error(f"Error while getting sim from S3: {e}",
-                               artifact={"traceback": traceback.format_exc()})
-        return sim_results
-
-    def get_sim_results_from_local_db(self, embeddings):
-        get_logger().info("Loading the local index...")
-        sim_results = []
-        try:
-            from langchain_chroma import Chroma
-            get_logger().info("Loading the Chroma index...")
-            db_path = "./docs/chroma_db.zip"
-            if not os.path.exists(db_path):
-                db_path = "/app/docs/chroma_db.zip"
-            if not os.path.exists(db_path):
-                get_logger().error("Local db not found")
-                return sim_results
-            with tempfile.TemporaryDirectory() as temp_dir:
-                # Extract the ZIP file
-                with zipfile.ZipFile(db_path, 'r') as zip_ref:
-                    zip_ref.extractall(temp_dir)
-                vectorstore = Chroma(persist_directory=temp_dir + "/chroma_db",
-                                     embedding_function=embeddings)
-                # Do similarity search
-                sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
-        except Exception as e:
-            get_logger().error(f"Error while getting sim from local db: {e}",
-                               artifact={"traceback": traceback.format_exc()})
-        return sim_results
-
-    def get_sim_results_from_pinecone_db(self, embeddings):
-        get_logger().info("Loading the Pinecone index...")
-        sim_results = []
-        try:
-            from langchain_pinecone import PineconeVectorStore
-            INDEX_NAME = "pr-agent-docs"
-            vectorstore = PineconeVectorStore(
-                index_name=INDEX_NAME, embedding=embeddings,
-                pinecone_api_key=get_settings().pinecone.api_key
-            )
-            # Do similarity search
-            sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
-        except Exception as e:
-            get_logger().error(f"Error while getting sim from Pinecone db: {e}",
-                               artifact={"traceback": traceback.format_exc()})
-        return sim_results
-
     async def run(self):
         try:
             if self.question_str:
@@ -163,44 +87,32 @@ class PRHelpMessage:
                 # get all the 'md' files inside docs_path and its subdirectories
                 md_files = list(docs_path.glob('**/*.md'))
                 folders_to_exclude = ['/finetuning_benchmark/']
-                files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md','compression_strategy.md']
+                files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md', '/docs/overview/index.md']
                 md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
 
-                # # calculate the token count of all the md files
-                # token_count = 0
-                # for file in md_files:
-                #     with open(file, 'r') as f:
-                #         token_count += self.token_handler.count_tokens(f.read())
+                # sort the 'md_files' so that 'priority_files' will be at the top
+                priority_files_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md', 'tools/review.md',
+                                          'tools/improve.md', '/faq']
+                md_files_priority = [file for file in md_files if
+                                     any(priority_string in str(file) for priority_string in priority_files_strings)]
+                md_files_not_priority = [file for file in md_files if file not in md_files_priority]
+                md_files = md_files_priority + md_files_not_priority
 
                 docs_prompt = ""
                 for file in md_files:
                     with open(file, 'r') as f:
                         file_path = str(file).replace(str(docs_path), '')
-                        docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read()}\n=========\n\n"
+                        docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read().strip()}\n=========\n\n"
+                token_count = self.token_handler.count_tokens(docs_prompt)
+                get_logger().debug(f"Token count of full documentation website: {token_count}")
+
+                model = get_settings().config.model
+                max_tokens_full = MAX_TOKENS[model]  # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
+                delta_output = 2000
+                if token_count > max_tokens_full - delta_output:
+                    get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Clipping the documentation website content")
+                    docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output)
                 self.vars['snippets'] = docs_prompt.strip()
 
-                # # Initialize embeddings
-                # from langchain_openai import OpenAIEmbeddings
-                # embeddings = OpenAIEmbeddings(model="text-embedding-3-small",
-                #                               api_key=get_settings().openai.key)
-                #
-                # # Get similar snippets via similarity search
-                # if get_settings().pr_help.force_local_db:
-                #     sim_results = self.get_sim_results_from_local_db(embeddings)
-                # elif get_settings().get('pinecone.api_key'):
-                #     sim_results = self.get_sim_results_from_pinecone_db(embeddings)
-                # else:
-                #     sim_results = self.get_sim_results_from_s3_db(embeddings)
-                #     if not sim_results:
-                #         get_logger().info("Failed to load the S3 index. Loading the local index...")
-                #         sim_results = self.get_sim_results_from_local_db(embeddings)
-                # if not sim_results:
-                #     get_logger().error("Failed to retrieve similar snippets. Exiting...")
-                #     return
-                # # Prepare relevant snippets
-                # relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str = \
-                #     await self.prepare_relevant_snippets(sim_results)
-                # self.vars['snippets'] = relevant_snippets_str.strip()
 
                 # run the AI model
                 response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
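
The net effect in run(): instead of retrieving snippets from a vector index, the tool now concatenates the whole docs site, puts priority files first, and clips the prompt to the model's context window. Below is a self-contained sketch of those two new steps; count_tokens() and the clipping rule are simplified stand-ins for the repo's TokenHandler and clip_tokens utilities, and MAX_TOKENS_FULL is an assumed window size:

    # Sketch: priority ordering of doc files, then clipping to the context window.
    from pathlib import Path

    PRIORITY_STRINGS = ['/docs/index.md', '/usage-guide', 'tools/describe.md',
                        'tools/review.md', 'tools/improve.md', '/faq']
    MAX_TOKENS_FULL = 128_000  # assumed context window of the configured model
    DELTA_OUTPUT = 2000        # headroom reserved for the model's answer


    def order_by_priority(md_files: list[Path]) -> list[Path]:
        # Stable partition: priority files first, original order kept within each group.
        priority = [f for f in md_files if any(s in str(f) for s in PRIORITY_STRINGS)]
        rest = [f for f in md_files if f not in priority]
        return priority + rest


    def count_tokens(text: str) -> int:
        return len(text) // 4  # rough heuristic: ~4 characters per token


    def clip_to_window(docs_prompt: str) -> str:
        limit = MAX_TOKENS_FULL - DELTA_OUTPUT
        if count_tokens(docs_prompt) <= limit:
            return docs_prompt
        # Priority files were concatenated first, so clipping the tail drops
        # the least important docs.
        return docs_prompt[:limit * 4]

The ordering matters precisely because of the clip: when the full site exceeds the window, the tail, i.e. the lowest-priority docs, is what gets truncated.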

View File

@@ -28,12 +28,6 @@ gunicorn==22.0.0
 pytest-cov==5.0.0
 pydantic==2.8.2
 html2text==2024.2.26
-# help bot
-langchain==0.3.0
-langchain-openai==0.2.0
-langchain-pinecone==0.2.0
-langchain-chroma==0.1.4
-chromadb==0.5.7
 # Uncomment the following lines to enable the 'similar issue' tool
 # pinecone-client
 # pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main
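
With the vector-search path gone, the langchain/chroma pins are no longer needed. A quick sanity-check script (a sketch; the pr_agent package path is assumed) to confirm nothing in the tree still imports the removed packages:

    # Sketch: scan the source tree for leftover imports of the removed packages.
    import pathlib

    removed = ('langchain', 'langchain_openai', 'langchain_pinecone',
               'langchain_chroma', 'chromadb')
    for py_file in pathlib.Path('pr_agent').rglob('*.py'):
        text = py_file.read_text(encoding='utf-8', errors='ignore')
        for pkg in removed:
            if f'import {pkg}' in text or f'from {pkg}' in text:
                print(f'{py_file}: still references {pkg}')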