mirror of https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 03:40:38 +08:00
Refactor PR help message tool to use full documentation content for answering questions and update relevant section handling in prompts
Binary file not shown.
@@ -4,7 +4,7 @@ You will receive a question, and the full documentation website content.
 Your goal is to provide the best answer to the question using the documentation provided.
 
 Additional instructions:
-- Try to be short and concise in your answers. Give examples if needed.
+- Try to be short and concise in your answers. Try to give examples if needed.
 - The main tools of PR-Agent are 'describe', 'review', 'improve'. If there is ambiguity to which tool the user is referring to, prioritize snippets of these tools over others.
 
 
||||||
@@ -17,7 +17,7 @@ class relevant_section(BaseModel):
 class DocHelper(BaseModel):
     user_question: str = Field(description="The user's question")
     response: str = Field(description="The response to the user's question")
-    relevant_sections: List[relevant_section] = Field(description="A list of the relevant markdown sections in the documentation that answer the user's question, ordered by importance")
+    relevant_sections: List[relevant_section] = Field(description="A list of the relevant markdown sections in the documentation that answer the user's question, ordered by importance (most relevant first)")
 =====
 
 
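The schema above is the structured reply the tool parses out of the model's answer. As a rough illustration, here is a minimal, hypothetical sketch of validating such a reply: the fields of relevant_section are not visible in this hunk, so the two used below are invented, and yaml.safe_load stands in for pr-agent's load_yaml.

    # Hypothetical sketch only: relevant_section's real fields are not shown
    # in this diff, so file_name / relevant_section_header_string are guesses.
    from typing import List
    import yaml
    from pydantic import BaseModel, Field

    class relevant_section(BaseModel):
        file_name: str
        relevant_section_header_string: str

    class DocHelper(BaseModel):
        user_question: str = Field(description="The user's question")
        response: str = Field(description="The response to the user's question")
        relevant_sections: List[relevant_section] = Field(
            description="Relevant markdown sections, ordered by importance (most relevant first)")

    reply = """
    user_question: "How do I run the improve tool?"
    response: "Comment `/improve` on the pull request."
    relevant_sections:
      - file_name: "tools/improve.md"
        relevant_section_header_string: "## Usage"
    """
    answer = DocHelper(**yaml.safe_load(reply))
    print(answer.relevant_sections[0].file_name)  # tools/improve.md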
@@ -41,7 +41,7 @@ User's Question:
 =====
 
 
-Relevant doc snippets retrieved:
+Documentation website content:
 =====
 {{ snippets|trim }}
 =====
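For orientation, a template like the one above is rendered with Jinja2 before being sent to the model (the StrictUndefined import appears in the next file of this diff). A minimal sketch follows, with an abridged template; the `question` variable name is an assumption, only `snippets` is confirmed by the diff.

    # Minimal rendering sketch. StrictUndefined makes a missing template
    # variable raise an error instead of silently rendering empty text.
    from jinja2 import Environment, StrictUndefined

    template = (
        "User's Question:\n=====\n{{ question|trim }}\n=====\n\n"
        "Documentation website content:\n=====\n{{ snippets|trim }}\n====="
    )
    env = Environment(undefined=StrictUndefined)
    prompt = env.from_string(template).render(
        question="How do I run the improve tool?",
        snippets="==file name:==\n\n/tools/improve.md\n\n==file content:==\n\n...",
    )
    print(prompt)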
@@ -8,11 +8,12 @@ from pathlib import Path
 
 from jinja2 import Environment, StrictUndefined
 
+from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
 from pr_agent.algo.pr_processing import retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import ModelType, load_yaml
+from pr_agent.algo.utils import ModelType, load_yaml, clip_tokens
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers import get_git_provider, GithubProvider, BitbucketServerProvider, \
     get_git_provider_with_context
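The two added imports carry most of this change: MAX_TOKENS maps a model name to its context-window size, and clip_tokens truncates text to a token budget; both uses appear later in this file's diff. A small sketch, where the model name is illustrative (pr-agent reads it from get_settings().config.model):

    # Sketch of the two new imports, mirroring how the diff below uses them.
    from pr_agent.algo import MAX_TOKENS
    from pr_agent.algo.utils import clip_tokens

    model = "gpt-4o"                                # illustrative model name
    delta_output = 2000                             # tokens reserved for the answer
    budget = MAX_TOKENS[model] - delta_output
    docs_prompt = "..." * 100000                    # stand-in for the full docs text
    docs_prompt = clip_tokens(docs_prompt, budget)  # unchanged if already under budget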
@@ -68,83 +69,6 @@ class PRHelpMessage:
             question_str = ""
         return question_str
 
-    def get_sim_results_from_s3_db(self, embeddings):
-        get_logger().info("Loading the S3 index...")
-        sim_results = []
-        try:
-            from langchain_chroma import Chroma
-            from urllib import request
-            with tempfile.TemporaryDirectory() as temp_dir:
-                # Define the local file path within the temporary directory
-                local_file_path = os.path.join(temp_dir, 'chroma_db.zip')
-
-                bucket = 'pr-agent'
-                file_name = 'chroma_db.zip'
-                s3_url = f'https://{bucket}.s3.amazonaws.com/{file_name}'
-                request.urlretrieve(s3_url, local_file_path)
-
-                # # Download the file from S3 to the temporary directory
-                # s3 = boto3.client('s3')
-                # s3.download_file(bucket, file_name, local_file_path)
-
-                # Extract the contents of the zip file
-                with zipfile.ZipFile(local_file_path, 'r') as zip_ref:
-                    zip_ref.extractall(temp_dir)
-
-                vectorstore = Chroma(persist_directory=temp_dir + "/chroma_db",
-                                     embedding_function=embeddings)
-                sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
-        except Exception as e:
-            get_logger().error(f"Error while getting sim from S3: {e}",
-                               artifact={"traceback": traceback.format_exc()})
-        return sim_results
-
-    def get_sim_results_from_local_db(self, embeddings):
-        get_logger().info("Loading the local index...")
-        sim_results = []
-        try:
-            from langchain_chroma import Chroma
-            get_logger().info("Loading the Chroma index...")
-            db_path = "./docs/chroma_db.zip"
-            if not os.path.exists(db_path):
-                db_path= "/app/docs/chroma_db.zip"
-                if not os.path.exists(db_path):
-                    get_logger().error("Local db not found")
-                    return sim_results
-            with tempfile.TemporaryDirectory() as temp_dir:
-
-                # Extract the ZIP file
-                with zipfile.ZipFile(db_path, 'r') as zip_ref:
-                    zip_ref.extractall(temp_dir)
-
-                vectorstore = Chroma(persist_directory=temp_dir + "/chroma_db",
-                                     embedding_function=embeddings)
-
-                # Do similarity search
-                sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
-        except Exception as e:
-            get_logger().error(f"Error while getting sim from local db: {e}",
-                               artifact={"traceback": traceback.format_exc()})
-        return sim_results
-
-    def get_sim_results_from_pinecone_db(self, embeddings):
-        get_logger().info("Loading the Pinecone index...")
-        sim_results = []
-        try:
-            from langchain_pinecone import PineconeVectorStore
-            INDEX_NAME = "pr-agent-docs"
-            vectorstore = PineconeVectorStore(
-                index_name=INDEX_NAME, embedding=embeddings,
-                pinecone_api_key=get_settings().pinecone.api_key
-            )
-
-            # Do similarity search
-            sim_results = vectorstore.similarity_search_with_score(self.question_str, k=self.num_retrieved_snippets)
-        except Exception as e:
-            get_logger().error(f"Error while getting sim from Pinecone db: {e}",
-                               artifact={"traceback": traceback.format_exc()})
-        return sim_results
-
     async def run(self):
         try:
             if self.question_str:
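All three deleted methods shared one retrieval pattern: load a vector store (a zipped Chroma index from S3 or disk, or a Pinecone index) and call similarity_search_with_score on the user's question. A condensed sketch of that pattern, with an illustrative persist directory and embedding model:

    # Condensed sketch of the retrieval pattern the deleted methods shared.
    # The persist directory and embedding model name are illustrative.
    from langchain_chroma import Chroma
    from langchain_openai import OpenAIEmbeddings

    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")  # needs OPENAI_API_KEY
    vectorstore = Chroma(persist_directory="/tmp/chroma_db", embedding_function=embeddings)
    results = vectorstore.similarity_search_with_score("How do I run the improve tool?", k=5)
    for doc, score in results:  # in Chroma, a lower score means a closer match
        print(f"{score:.3f}", doc.metadata.get("source"))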
@@ -159,48 +83,36 @@ class PRHelpMessage:
                 return
 
             # current path
-            docs_path= Path(__file__).parent.parent.parent/'docs'/'docs'
+            docs_path = Path(__file__).parent.parent.parent / 'docs' / 'docs'
             # get all the 'md' files inside docs_path and its subdirectories
             md_files = list(docs_path.glob('**/*.md'))
-            folders_to_exclude =['/finetuning_benchmark/']
-            files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md','compression_strategy.md']
+            folders_to_exclude = ['/finetuning_benchmark/']
+            files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md', '/docs/overview/index.md']
             md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
-            # # calculate the token count of all the md files
-            # token_count = 0
-            # for file in md_files:
-            #     with open(file, 'r') as f:
-            #         token_count += self.token_handler.count_tokens(f.read())
 
-            docs_prompt =""
+            # sort the 'md_files' so that 'priority_files' will be at the top
+            priority_files_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md', 'tools/review.md',
+                                      'tools/improve.md', '/faq']
+            md_files_priority = [file for file in md_files if
+                                 any(priority_string in str(file) for priority_string in priority_files_strings)]
+            md_files_not_priority = [file for file in md_files if file not in md_files_priority]
+            md_files = md_files_priority + md_files_not_priority
+
+            docs_prompt = ""
             for file in md_files:
                 with open(file, 'r') as f:
                     file_path = str(file).replace(str(docs_path), '')
-                    docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read()}\n=========\n\n"
+                    docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read().strip()}\n=========\n\n"
+            token_count = self.token_handler.count_tokens(docs_prompt)
+            get_logger().debug(f"Token count of full documentation website: {token_count}")
+
+            model = get_settings().config.model
+            max_tokens_full = MAX_TOKENS[model]  # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
+            delta_output = 2000
+            if token_count > max_tokens_full - delta_output:
+                get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.")
+                docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output)
             self.vars['snippets'] = docs_prompt.strip()
-
-            # # Initialize embeddings
-            # from langchain_openai import OpenAIEmbeddings
-            # embeddings = OpenAIEmbeddings(model="text-embedding-3-small",
-            #                               api_key=get_settings().openai.key)
-            #
-            # # Get similar snippets via similarity search
-            # if get_settings().pr_help.force_local_db:
-            #     sim_results = self.get_sim_results_from_local_db(embeddings)
-            # elif get_settings().get('pinecone.api_key'):
-            #     sim_results = self.get_sim_results_from_pinecone_db(embeddings)
-            # else:
-            #     sim_results = self.get_sim_results_from_s3_db(embeddings)
-            #     if not sim_results:
-            #         get_logger().info("Failed to load the S3 index. Loading the local index...")
-            #         sim_results = self.get_sim_results_from_local_db(embeddings)
-            # if not sim_results:
-            #     get_logger().error("Failed to retrieve similar snippets. Exiting...")
-            #     return
-
-            # # Prepare relevant snippets
-            # relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str =\
-            #     await self.prepare_relevant_snippets(sim_results)
-            # self.vars['snippets'] = relevant_snippets_str.strip()
 
             # run the AI model
             response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
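Net effect of the hunk above: instead of retrieving the top-k snippets, run() now concatenates every documentation page, priority pages first, and clips the result to the model's context window, so clipping drops the least important pages rather than arbitrary ones. A standalone sketch of that flow, with token counting stubbed out (pr-agent uses TokenHandler.count_tokens and clip_tokens instead):

    # Standalone sketch of the new "stuff the whole docs site" flow.
    # The chars/4 token estimate is a stub, not pr-agent's tokenizer.
    from pathlib import Path

    def build_docs_prompt(docs_path: Path, max_tokens: int, delta_output: int = 2000) -> str:
        md_files = list(docs_path.glob('**/*.md'))
        priority_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md',
                            'tools/review.md', 'tools/improve.md', '/faq']
        priority = [f for f in md_files if any(p in str(f) for p in priority_strings)]
        md_files = priority + [f for f in md_files if f not in priority]  # priority first

        prompt = ""
        for file in md_files:
            rel = str(file).replace(str(docs_path), '')
            prompt += f"==file name:==\n\n{rel}\n\n==file content:==\n\n{file.read_text().strip()}\n=========\n\n"

        budget = max_tokens - delta_output     # leave room for the model's answer
        if len(prompt) // 4 > budget:          # crude token estimate
            prompt = prompt[: budget * 4]      # stand-in for clip_tokens(prompt, budget)
        return prompt.strip()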
@@ -28,12 +28,6 @@ gunicorn==22.0.0
 pytest-cov==5.0.0
 pydantic==2.8.2
 html2text==2024.2.26
-# help bot
-langchain==0.3.0
-langchain-openai==0.2.0
-langchain-pinecone==0.2.0
-langchain-chroma==0.1.4
-chromadb==0.5.7
 # Uncomment the following lines to enable the 'similar issue' tool
 # pinecone-client
 # pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main