mirror of https://github.com/qodo-ai/pr-agent.git (synced 2025-07-01 19:30:40 +08:00)

Refactor PR help message tool to use full documentation content for answering questions and update relevant section handling in prompts
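
At a high level, the change below replaces embeddings-based snippet retrieval with a single prompt that concatenates the full documentation markdown and clips it to the model's context window. A minimal, self-contained sketch of that flow; the count_tokens/clip_tokens helpers here are simplified stand-ins for pr_agent's tokenizer-based versions, and the token budget is illustrative:

from pathlib import Path

# Simplified stand-ins (assumption: the real TokenHandler and clip_tokens use a
# model tokenizer; whitespace tokens are a rough approximation for illustration).
def count_tokens(text: str) -> int:
    return len(text.split())

def clip_tokens(text: str, max_tokens: int) -> str:
    return " ".join(text.split()[:max_tokens])

def build_docs_prompt(docs_path: Path, max_tokens_full: int, delta_output: int = 2000) -> str:
    docs_prompt = ""
    for md_file in sorted(docs_path.glob("**/*.md")):
        file_path = str(md_file).replace(str(docs_path), "")
        docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{md_file.read_text().strip()}\n=========\n\n"
    # Clip when the full documentation would not leave room for the model's output
    if count_tokens(docs_prompt) > max_tokens_full - delta_output:
        docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output)
    return docs_prompt.strip()

print(build_docs_prompt(Path("./docs"), max_tokens_full=32000)[:200])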
Binary file not shown.
@@ -4,7 +4,7 @@ You will receive a question, and the full documentation website content.
 Your goal is to provide the best answer to the question using the documentation provided.
 
 Additional instructions:
-- Try to be short and concise in your answers. Give examples if needed.
+- Try to be short and concise in your answers. Try to give examples if needed.
 - The main tools of PR-Agent are 'describe', 'review', 'improve'. If there is ambiguity to which tool the user is referring to, prioritize snippets of these tools over others.
@@ -17,7 +17,7 @@ class relevant_section(BaseModel):
 class DocHelper(BaseModel):
     user_question: str = Field(description="The user's question")
     response: str = Field(description="The response to the user's question")
-    relevant_sections: List[relevant_section] = Field(description="A list of the relevant markdown sections in the documentation that answer the user's question, ordered by importance")
+    relevant_sections: List[relevant_section] = Field(description="A list of the relevant markdown sections in the documentation that answer the user's question, ordered by importance (most relevant first)")
 =====
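
For context, the model is asked to reply with YAML matching this schema. A minimal sketch of validating such a reply with pydantic; the relevant_section fields and the sample YAML are illustrative assumptions (the real class is defined just above this hunk):

import yaml
from typing import List
from pydantic import BaseModel, Field

class relevant_section(BaseModel):
    # Assumed fields for illustration; see the definition above this hunk
    file_name: str = Field(description="The name of the relevant documentation file")
    relevant_section_header_string: str = Field(description="The header of the relevant section")

class DocHelper(BaseModel):
    user_question: str = Field(description="The user's question")
    response: str = Field(description="The response to the user's question")
    relevant_sections: List[relevant_section] = Field(description="Relevant sections, most relevant first")

raw = """\
user_question: How do I run the 'improve' tool?
response: Comment '/improve' on the pull request.
relevant_sections:
- file_name: tools/improve.md
  relevant_section_header_string: Overview
"""
answer = DocHelper(**yaml.safe_load(raw))
print(answer.relevant_sections[0].file_name)  # tools/improve.md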
@@ -41,7 +41,7 @@ User's Question:
 =====
 
-Relevant doc snippets retrieved:
+Documentation website content:
 =====
 {{ snippets|trim }}
 =====
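
The {{ snippets|trim }} placeholder is filled with Jinja2 (the tool imports Environment and StrictUndefined, as the next hunk shows). A minimal sketch of rendering such a template, with an illustrative template string and snippet value:

from jinja2 import Environment, StrictUndefined

template_str = "Documentation website content:\n=====\n{{ snippets|trim }}\n====="
# StrictUndefined makes rendering fail loudly if 'snippets' is missing
env = Environment(undefined=StrictUndefined)
prompt = env.from_string(template_str).render(snippets="  ==file name:==\n\n/index.md\n  ")
print(prompt)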
@@ -8,11 +8,12 @@ from pathlib import Path
 
 from jinja2 import Environment, StrictUndefined
 
+from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
 from pr_agent.algo.pr_processing import retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import ModelType, load_yaml
+from pr_agent.algo.utils import ModelType, load_yaml, clip_tokens
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers import get_git_provider, GithubProvider, BitbucketServerProvider, \
     get_git_provider_with_context
@@ -68,83 +69,6 @@ class PRHelpMessage:
             question_str = ""
         return question_str
 
-    def get_sim_results_from_s3_db(self, embeddings):
-        get_logger().info("Loading the S3 index...")
-        sim_results = []
-        try:
-            from langchain_chroma import Chroma
-            from urllib import request
-            with tempfile.TemporaryDirectory() as temp_dir:
-                # Define the local file path within the temporary directory
-                local_file_path = os.path.join(temp_dir, 'chroma_db.zip')
-
-                bucket = 'pr-agent'
-                file_name = 'chroma_db.zip'
-                s3_url = f'https://{bucket}.s3.amazonaws.com/{file_name}'
-                request.urlretrieve(s3_url, local_file_path)
-
-                # # Download the file from S3 to the temporary directory
-                # s3 = boto3.client('s3')
-                # s3.download_file(bucket, file_name, local_file_path)
-
-                # Extract the contents of the zip file
-                with zipfile.ZipFile(local_file_path, 'r') as zip_ref:
-                    zip_ref.extractall(temp_dir)
-
-                vectorstore = Chroma(persist_directory=temp_dir + "/chroma_db",
-                                     embedding_function=embeddings)
-                sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
-        except Exception as e:
-            get_logger().error(f"Error while getting sim from S3: {e}",
-                               artifact={"traceback": traceback.format_exc()})
-        return sim_results
-
-    def get_sim_results_from_local_db(self, embeddings):
-        get_logger().info("Loading the local index...")
-        sim_results = []
-        try:
-            from langchain_chroma import Chroma
-            get_logger().info("Loading the Chroma index...")
-            db_path = "./docs/chroma_db.zip"
-            if not os.path.exists(db_path):
-                db_path = "/app/docs/chroma_db.zip"
-                if not os.path.exists(db_path):
-                    get_logger().error("Local db not found")
-                    return sim_results
-            with tempfile.TemporaryDirectory() as temp_dir:
-
-                # Extract the ZIP file
-                with zipfile.ZipFile(db_path, 'r') as zip_ref:
-                    zip_ref.extractall(temp_dir)
-
-                vectorstore = Chroma(persist_directory=temp_dir + "/chroma_db",
-                                     embedding_function=embeddings)
-
-                # Do similarity search
-                sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
-        except Exception as e:
-            get_logger().error(f"Error while getting sim from local db: {e}",
-                               artifact={"traceback": traceback.format_exc()})
-        return sim_results
-
-    def get_sim_results_from_pinecone_db(self, embeddings):
-        get_logger().info("Loading the Pinecone index...")
-        sim_results = []
-        try:
-            from langchain_pinecone import PineconeVectorStore
-            INDEX_NAME = "pr-agent-docs"
-            vectorstore = PineconeVectorStore(
-                index_name=INDEX_NAME, embedding=embeddings,
-                pinecone_api_key=get_settings().pinecone.api_key
-            )
-
-            # Do similarity search
-            sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
-        except Exception as e:
-            get_logger().error(f"Error while getting sim from Pinecone db: {e}",
-                               artifact={"traceback": traceback.format_exc()})
-        return sim_results
-
     async def run(self):
         try:
             if self.question_str:
@@ -163,44 +87,32 @@ class PRHelpMessage:
             # get all the 'md' files inside docs_path and its subdirectories
             md_files = list(docs_path.glob('**/*.md'))
             folders_to_exclude = ['/finetuning_benchmark/']
-            files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md']
+            files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md', '/docs/overview/index.md']
             md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
             # # calculate the token count of all the md files
             # token_count = 0
             # for file in md_files:
             #     with open(file, 'r') as f:
             #         token_count += self.token_handler.count_tokens(f.read())
 
             # sort the 'md_files' so that 'priority_files' will be at the top
             priority_files_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md', 'tools/review.md',
                                       'tools/improve.md', '/faq']
             md_files_priority = [file for file in md_files if
                                  any(priority_string in str(file) for priority_string in priority_files_strings)]
             md_files_not_priority = [file for file in md_files if file not in md_files_priority]
             md_files = md_files_priority + md_files_not_priority
 
             docs_prompt = ""
             for file in md_files:
                 with open(file, 'r') as f:
                     file_path = str(file).replace(str(docs_path), '')
-                    docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read()}\n=========\n\n"
+                    docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read().strip()}\n=========\n\n"
+            token_count = self.token_handler.count_tokens(docs_prompt)
+            get_logger().debug(f"Token count of full documentation website: {token_count}")
+
+            model = get_settings().config.model
+            max_tokens_full = MAX_TOKENS[model]  # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
+            delta_output = 2000
+            if token_count > max_tokens_full - delta_output:
+                get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.")
+                docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output)
+            self.vars['snippets'] = docs_prompt.strip()
             # # Initialize embeddings
             # from langchain_openai import OpenAIEmbeddings
             # embeddings = OpenAIEmbeddings(model="text-embedding-3-small",
             #                               api_key=get_settings().openai.key)
             #
             # # Get similar snippets via similarity search
             # if get_settings().pr_help.force_local_db:
             #     sim_results = self.get_sim_results_from_local_db(embeddings)
             # elif get_settings().get('pinecone.api_key'):
             #     sim_results = self.get_sim_results_from_pinecone_db(embeddings)
             # else:
             #     sim_results = self.get_sim_results_from_s3_db(embeddings)
             #     if not sim_results:
             #         get_logger().info("Failed to load the S3 index. Loading the local index...")
             #         sim_results = self.get_sim_results_from_local_db(embeddings)
             # if not sim_results:
             #     get_logger().error("Failed to retrieve similar snippets. Exiting...")
             #     return
 
             # # Prepare relevant snippets
             # relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str =\
             #     await self.prepare_relevant_snippets(sim_results)
             # self.vars['snippets'] = relevant_snippets_str.strip()
 
             # run the AI model
             response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
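
The run flow ends by delegating to retry_with_fallback_models. A generic sketch of that retry-with-fallback pattern; the model names and predict function are hypothetical, not pr_agent's implementation:

import asyncio

async def retry_with_fallbacks(predict, models):
    # Try each model in order; fall back to the next on failure
    last_err = None
    for model in models:
        try:
            return await predict(model)
        except Exception as e:
            last_err = e
    raise RuntimeError(f"All models failed: {last_err}")

async def main():
    async def predict(model):
        if model == "primary-model":
            raise TimeoutError("simulated failure")  # force a fallback
        return f"answer from {model}"

    print(await retry_with_fallbacks(predict, ["primary-model", "fallback-model"]))

asyncio.run(main())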
@@ -28,12 +28,6 @@ gunicorn==22.0.0
 pytest-cov==5.0.0
 pydantic==2.8.2
 html2text==2024.2.26
-# help bot
-langchain==0.3.0
-langchain-openai==0.2.0
-langchain-pinecone==0.2.0
-langchain-chroma==0.1.4
-chromadb==0.5.7
 # Uncomment the following lines to enable the 'similar issue' tool
 # pinecone-client
 # pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main