mirror of https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
Add support for documentation content exceeding token limits (#1670)
* Add support for documentation content exceeding token limits via a two-phase operation:
  1. Ask the LLM to rank the headings most likely to contain an answer to the user's question.
  2. Provide the corresponding files for the LLM to search for an answer.
* Refactor help_docs to make the code more readable.
* For the purpose of getting the canonical path, git providers now use the default branch rather than the PR's source branch.
* Refactor token counting, making it clear when an estimate factor will be used.

Code review changes:
1. Correctly handle exceptions during retry_with_fallback_models (to allow the fallback model to run in case of failure).
2. Better naming for default_branch in the Bitbucket Cloud provider.
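To make the flow concrete, here is a minimal sketch of the two-phase operation (all names are illustrative stand-ins, not the tool's actual API; the real logic lives in `PRHelpDocs.run()` and `_rank_docs_and_return_them_as_prompt()` below, and the budget check there uses token counts rather than characters):

```python
from typing import Callable


def answer_question(question: str,
                    docs: dict[str, str],
                    rank_files: Callable[[str, list[str]], list[str]],
                    ask: Callable[[str, str], str],
                    char_budget: int) -> str:
    """Two-phase fallback for documentation that overflows the model's limits."""
    full_prompt = "\n\n".join(docs.values())
    if len(full_prompt) <= char_budget:  # Phase 0: everything fits, ask directly.
        return ask(question, full_prompt)

    # Phase 1: send only the file names (the real code also extracts headings)
    # and ask the model to rank files by how likely they contain the answer.
    ranking = rank_files(question, list(docs))

    # Phase 2: rebuild the prompt from the top-ranked files, trimmed to budget.
    selected = "\n\n".join(docs[name] for name in ranking if name in docs)
    return ask(question, selected[:char_budget])
```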
pr_agent/algo/token_handler.py
@@ -1,7 +1,6 @@
from threading import Lock

from jinja2 import Environment, StrictUndefined
from math import ceil
from tiktoken import encoding_for_model, get_encoding

from pr_agent.config_loader import get_settings
@@ -105,6 +104,19 @@ class TokenHandler:
            get_logger().error(f"Error in Anthropic token counting: {e}")
            return MaxTokens

    def estimate_token_count_for_non_anth_claude_models(self, model, default_encoder_estimate):
        from math import ceil
        import re

        model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model)
        if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'):
            return default_encoder_estimate
        # else: the model is not an OpenAI one - an accurate token count cannot be provided, so return a higher number as a best-effort estimate.

        elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
        get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as a best-effort estimate")
        return ceil(elbow_factor * default_encoder_estimate)

    def count_tokens(self, patch: str, force_accurate=False) -> int:
        """
        Counts the number of tokens in a given patch string.
@@ -116,21 +128,15 @@
            The number of tokens in the patch string.
        """
        encoder_estimate = len(self.encoder.encode(patch, disallowed_special=()))

        # If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
        if not force_accurate:
            return encoder_estimate
        # else, need to provide an accurate estimation:
        # else, force_accurate==True: the user requested an accurate estimation:
        model = get_settings().config.model.lower()
        if force_accurate and 'claude' in model and get_settings(use_context=False).get('anthropic.key'):
        if 'claude' in model and get_settings(use_context=False).get('anthropic.key'):
            return self.calc_claude_tokens(patch)  # API call to Anthropic for accurate token counting for Claude models
        # else: not an Anthropic-provided model

        import re
        model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model)
        if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'):
            return encoder_estimate
        # else: the model is neither an OpenAI nor an Anthropic model - an accurate token count cannot be provided, so return a higher number as a best-effort estimate.

        elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
        get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as a best-effort estimate")
        return ceil(elbow_factor * encoder_estimate)
        # else: not an Anthropic-provided model:
        return self.estimate_token_count_for_non_anth_claude_models(model, encoder_estimate)
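The elbow-factor estimate introduced above can be reproduced in isolation. A minimal sketch, assuming tiktoken's `cl100k_base` as the generic fallback encoder and the `model_token_count_estimate_factor=0.3` value set in this commit's configuration:

```python
from math import ceil

from tiktoken import get_encoding


def best_effort_token_estimate(text: str, estimate_factor: float = 0.3) -> int:
    # For models without an exact tokenizer available, take the generic
    # encoder's count and inflate it, reducing the chance that the real
    # model's limit is exceeded.
    encoder = get_encoding("cl100k_base")
    encoder_estimate = len(encoder.encode(text, disallowed_special=()))
    return ceil((1 + estimate_factor) * encoder_estimate)
```

With the default factor of 0.3, an input the encoder measures at 1,000 tokens is reported as 1,300.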
pr_agent/config_loader.py
@@ -29,6 +29,7 @@ global_settings = Dynaconf(
        "settings/custom_labels.toml",
        "settings/pr_help_prompts.toml",
        "settings/pr_help_docs_prompts.toml",
        "settings/pr_help_docs_headings_prompts.toml",
        "settings/.secrets.toml",
        "settings_prod/.secrets.toml",
    ]]
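Once the new TOML file is registered above, its section becomes available through the same settings object as the existing prompt files. A minimal sketch (the attribute path mirrors the `[pr_help_docs_headings_prompts]` section defined later in this commit, and is how the new tool code reads it):

```python
from pr_agent.config_loader import get_settings

# A TOML section name becomes an attribute path on the settings object:
system_prompt = get_settings().pr_help_docs_headings_prompts.system
user_prompt = get_settings().pr_help_docs_headings_prompts.user
```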
pr_agent/git_providers/bitbucket_provider.py
@@ -92,7 +92,7 @@ class BitbucketProvider(GitProvider):
                return ("", "")
            workspace_name, project_name = repo_path.split('/')
        else:
            desired_branch = self.get_pr_branch()
            desired_branch = self.get_repo_default_branch()
            parsed_pr_url = urlparse(self.pr_url)
            scheme_and_netloc = parsed_pr_url.scheme + "://" + parsed_pr_url.netloc
            workspace_name, project_name = (self.workspace_slug, self.repo_slug)
@@ -470,6 +470,16 @@ class BitbucketProvider(GitProvider):
    def get_pr_branch(self):
        return self.pr.source_branch

    # Attempts to get the repository's default branch; falls back to the PR destination branch.
    # Note: must be run from a PR context.
    def get_repo_default_branch(self):
        try:
            url_repo = f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/"
            response_repo = requests.request("GET", url_repo, headers=self.headers).json()
            return response_repo['mainbranch']['name']
        except Exception:
            return self.pr.destination_branch

    def get_pr_owner_id(self) -> str | None:
        return self.workspace_slug
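For reference, the REST call behind `get_repo_default_branch` can be exercised standalone. A minimal sketch, assuming an authenticated `headers` dict; the `mainbranch` field is part of the Bitbucket Cloud 2.0 repository resource, as the method above relies on:

```python
import requests


def fetch_default_branch(workspace: str, repo: str, headers: dict) -> str:
    # The repository resource exposes the default branch under 'mainbranch'.
    url = f"https://api.bitbucket.org/2.0/repositories/{workspace}/{repo}/"
    response = requests.get(url, headers=headers, timeout=10)
    response.raise_for_status()
    return response.json()['mainbranch']['name']
```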
pr_agent/git_providers/bitbucket_server_provider.py
@@ -64,9 +64,15 @@ class BitbucketServerProvider(GitProvider):
        workspace_name = None
        project_name = None
        if not repo_git_url:
            desired_branch = self.get_pr_branch()
            workspace_name = self.workspace_slug
            project_name = self.repo_slug
            default_branch_dict = self.bitbucket_client.get_default_branch(workspace_name, project_name)
            if 'displayId' in default_branch_dict:
                desired_branch = default_branch_dict['displayId']
            else:
                get_logger().error(f"Cannot obtain default branch for workspace_name={workspace_name}, "
                                   f"project_name={project_name}, default_branch_dict={default_branch_dict}")
                return ("", "")
        elif '.git' in repo_git_url and 'scm/' in repo_git_url:
            repo_path = repo_git_url.split('.git')[0].split('scm/')[-1]
            if repo_path.count('/') == 1:  # Has to have the form <workspace>/<repo>
pr_agent/git_providers/github_provider.py
@@ -133,7 +133,7 @@ class GithubProvider(GitProvider):
        if (not owner or not repo) and self.repo:  # "else" - the user did not provide an external git url, or this is not an issue; use the self.repo object
            owner, repo = self.repo.split('/')
            scheme_and_netloc = self.base_url_html
            desired_branch = self.get_pr_branch()
            desired_branch = self.repo_obj.default_branch
        if not all([scheme_and_netloc, owner, repo]):  # "else": not invoked from a PR context, and no git url provided for context
            get_logger().error(f"Unable to get canonical url parts since missing context (PR or explicit git url)")
            return ("", "")
pr_agent/git_providers/gitlab_provider.py
@@ -87,7 +87,11 @@ class GitLabProvider(GitProvider):
            return ("", "")
        if not repo_git_url:  # Use PR url as context
            repo_path = self._get_project_path_from_pr_or_issue_url(self.pr_url)
            desired_branch = self.get_pr_branch()
            try:
                desired_branch = self.gl.projects.get(self.id_project).default_branch
            except Exception as e:
                get_logger().exception(f"Cannot get PR: {self.pr_url} default branch. Tried project ID: {self.id_project}")
                return ("", "")
        else:  # Use repo git url
            repo_path = repo_git_url.split('.git')[0].split('.com/')[-1]
        prefix = f"{self.gitlab_url}/{repo_path}/-/blob/{desired_branch}"
@ -9,7 +9,6 @@
|
||||
model="o3-mini"
|
||||
fallback_models=["gpt-4o-2024-11-20"]
|
||||
#model_weak="gpt-4o-mini-2024-07-18" # optional, a weaker model to use for some easier tasks
|
||||
model_token_count_estimate_factor=0.3 # factor to increase the token count estimate, in order to reduce likelihood of model failure due to too many tokens.
|
||||
# CLI
|
||||
git_provider="github"
|
||||
publish_output=true
|
||||
@ -30,6 +29,7 @@ max_description_tokens = 500
|
||||
max_commits_tokens = 500
|
||||
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
|
||||
custom_model_max_tokens=-1 # for models not in the default list
|
||||
model_token_count_estimate_factor=0.3 # factor to increase the token count estimate, in order to reduce likelihood of model failure due to too many tokens - applicable only when requesting an accurate estimate.
|
||||
# patch extension logic
|
||||
patch_extension_skip_types =[".md",".txt"]
|
||||
allow_dynamic_context=true
|
||||
|
pr_agent/settings/pr_help_docs_headings_prompts.toml (new file, 101 lines)
@@ -0,0 +1,101 @@
[pr_help_docs_headings_prompts]
system="""You are Doc-helper, a language model that ranks documentation files based on their relevance to user questions.
You will receive a question, a repository url, and file names along with optional groups of headings extracted from such files in that repository (either as markdown or as restructured text).
Your task is to rank the file paths based on how likely they are to contain the answer to a user's question, using only the headings from each such file and the file name.

======
==file name==

'src/file1.py'

==index==

0 based integer

==file headings==
heading #1
heading #2
...

==file name==

'src/file2.py'

==index==

0 based integer

==file headings==
heading #1
heading #2
...

...
======

Additional instructions:
- Consider only the file names and section headings within each document
- Present the most relevant files first, based strictly on how well their headings and file names align with the user question

The output must be a YAML object equivalent to type $DocHeadingsHelper, according to the following Pydantic definitions:
=====
class file_idx_and_path(BaseModel):
    idx: int = Field(description="The zero based index of file_name, as it appeared in the original list of headings. Cannot be negative.")
    file_name: str = Field(description="The file_name exactly as it appeared in the question")

class DocHeadingsHelper(BaseModel):
    user_question: str = Field(description="The user's question")
    relevant_files_ranking: List[file_idx_and_path] = Field(description="Files sorted in descending order by relevance to question")
=====


Example output:
```yaml
user_question: |
  ...
relevant_files_ranking:
- idx: 101
  file_name: "src/file1.py"
- ...
"""

user="""\
Documentation url: '{{ docs_url|trim }}'
-----


User's Question:
=====
{{ question|trim }}
=====


Filenames with optional headings from documentation website content:
=====
{{ snippets|trim }}
=====


Reminder: The output must be a YAML object equivalent to type $DocHeadingsHelper, similar to the following example output:
=====


Example output:
```yaml
user_question: |
  ...
relevant_files_ranking:
- idx: 101
  file_name: "src/file1.py"
- ...
=====

Important Notes:
1. Output the most relevant file names first, in descending order of relevancy.
2. Only include files with non-negative indices.


Response (should be a valid YAML, and nothing else):
```yaml
"""
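The Pydantic definitions embedded in the prompt can be materialized to validate the model's reply. A minimal sketch; the class bodies come from the prompt above, while the `yaml.safe_load` parsing and the sample values are illustrative (the repository's own `load_yaml` helper does more forgiving parsing):

```python
from typing import List

import yaml
from pydantic import BaseModel, Field


class file_idx_and_path(BaseModel):
    idx: int = Field(description="Zero-based index of file_name in the original list of headings.")
    file_name: str = Field(description="The file_name exactly as it appeared in the question.")


class DocHeadingsHelper(BaseModel):
    user_question: str = Field(description="The user's question.")
    relevant_files_ranking: List[file_idx_and_path] = Field(
        description="Files sorted in descending order by relevance to the question.")


raw = """\
user_question: |
  How do I configure a fallback model?
relevant_files_ranking:
- idx: 3
  file_name: "docs/usage-guide/changing_a_model.md"
"""
ranking = DocHeadingsHelper(**yaml.safe_load(raw))
print([entry.file_name for entry in ranking.relevant_files_ranking])
```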
pr_agent/tools/pr_help_docs.py
@@ -1,11 +1,11 @@
import copy
from functools import partial

from jinja2 import Environment, StrictUndefined
import math
import os
import re
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, Tuple

from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
@@ -78,10 +78,48 @@ def get_maximal_text_input_length_for_token_count_estimation():
        return 900000  # Claude API for token estimation allows a maximal text input of 900K chars
    return math.inf  # Otherwise, no known limitation on input text just for token estimation

# Load documentation files to memory, decorating them with a header to mark where each file begins,
# so as to help the LLM give a better answer.
def aggregate_documentation_files_for_prompt_contents(base_path: str, doc_files: List[str]) -> Optional[str]:
    docs_prompt = ""
def return_document_headings(text: str, ext: str) -> str:
    try:
        lines = text.split('\n')
        headings = set()

        if not text or not re.search(r'[a-zA-Z]', text):
            get_logger().error(f"Empty or non-text content found in text: {text}.")
            return ""

        if ext in ['.md', '.mdx']:
            # Extract Markdown headings (lines starting with #)
            headings = {line.strip() for line in lines if line.strip().startswith('#')}
        elif ext == '.rst':
            # Find indices of lines made up of a single repeated character.
            # Allowed characters according to the list from: https://docutils.sourceforge.io/docs/ref/rst/restructuredtext.html#sections
            section_chars = set('!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~')

            # Find potential section marker lines (underlines/overlines): they have to repeat the same character
            marker_lines = []
            for i, line in enumerate(lines):
                line = line.rstrip()
                if line and all(c == line[0] for c in line) and line[0] in section_chars:
                    marker_lines.append((i, len(line)))

            # Check for headings adjacent to marker lines (below; the heading text must be of equal or lesser length)
            for idx, length in marker_lines:
                # Check if it's an underline (the heading is above it)
                if idx > 0 and lines[idx - 1].rstrip() and len(lines[idx - 1].rstrip()) <= length:
                    headings.add(lines[idx - 1].rstrip())
        else:
            get_logger().error(f"Unsupported file extension: {ext}")
            return ""

        return '\n'.join(headings)
    except Exception as e:
        get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
        return ""
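To illustrate the RST branch above: both heading texts sit directly above a marker line built from a single repeated punctuation character, and neither exceeds its marker's length, so both are collected. An illustrative call (it assumes `return_document_headings` is importable from this module; the returned set-derived order is unspecified):

```python
rst_text = """Installation
============

Some introductory text.

Usage
-----
More text.
"""

# Expected headings: "Installation" and "Usage" (order not guaranteed).
print(return_document_headings(rst_text, '.rst'))
```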
# Load documentation files to memory: full file path (as will be given in the prompt) -> doc contents
def map_documentation_files_to_contents(base_path: str, doc_files: list[str], max_allowed_file_len=5000) -> dict[str, str]:
    try:
        returned_dict = {}
        for file in doc_files:
            try:
                with open(file, 'r', encoding='utf-8') as f:
@@ -89,25 +127,55 @@ def aggregate_documentation_files_for_prompt_contents(base_path: str, doc_files:
                    # Skip files with no text content
                    if not re.search(r'[a-zA-Z]', content):
                        continue
                    if len(content) > max_allowed_file_len:
                        get_logger().warning(f"File {file} length: {len(content)} exceeds limit: {max_allowed_file_len}, so it will be trimmed.")
                        content = content[:max_allowed_file_len]
                    file_path = str(file).replace(str(base_path), '')
                    docs_prompt += f"\n==file name==\n\n{file_path}\n\n==file content==\n\n{content.strip()}\n=========\n\n"
                    returned_dict[file_path] = content.strip()
                except Exception as e:
                    get_logger().warning(f"Error while reading the file {file}: {e}")
                    continue
            if not docs_prompt:
                get_logger().error("Couldn't find any usable documentation files. Returning None.")
                return None
            return docs_prompt
            if not returned_dict:
                get_logger().error("Couldn't find any usable documentation files. Returning empty dict.")
            return returned_dict
        except Exception as e:
            get_logger().exception(f"Unexpected exception thrown. Returning empty dict.")
            return {}

def format_markdown_q_and_a_response(question_str: str, response_str: str, relevant_sections: List[Dict[str, str]],
                                     supported_suffixes: List[str], base_url_prefix: str, base_url_suffix: str="") -> str:
# Goes over the files' contents, generating the payload for the prompt while decorating each file with a header to mark where it begins,
# so as to help the LLM give a better answer.
def aggregate_documentation_files_for_prompt_contents(file_path_to_contents: dict[str, str], return_just_headings=False) -> str:
    try:
        docs_prompt = ""
        for idx, file_path in enumerate(file_path_to_contents):
            file_contents = file_path_to_contents[file_path].strip()
            if not file_contents:
                get_logger().error(f"Got empty file contents for: {file_path}. Skipping this file.")
                continue
            if return_just_headings:
                file_headings = return_document_headings(file_contents, os.path.splitext(file_path)[-1]).strip()
                if file_headings:
                    docs_prompt += f"\n==file name==\n\n{file_path}\n\n==index==\n\n{idx}\n\n==file headings==\n\n{file_headings}\n=========\n\n"
                else:
                    get_logger().warning(f"No headings for: {file_path}. Will only use the filename")
                    docs_prompt += f"\n==file name==\n\n{file_path}\n\n==index==\n\n{idx}\n\n"
            else:
                docs_prompt += f"\n==file name==\n\n{file_path}\n\n==file content==\n\n{file_contents}\n=========\n\n"
        return docs_prompt
    except Exception as e:
        get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
        return ""

def format_markdown_q_and_a_response(question_str: str, response_str: str, relevant_sections: list[dict[str, str]],
                                     supported_suffixes: list[str], base_url_prefix: str, base_url_suffix: str="") -> str:
    try:
        base_url_prefix = base_url_prefix.strip('/')  # Sanitize base_url_prefix
        answer_str = ""
        answer_str += f"### Question: \n{question_str}\n\n"
        answer_str += f"### Answer:\n{response_str.strip()}\n\n"
        answer_str += f"#### Relevant Sources:\n\n"
        for section in relevant_sections:
            file = section.get('file_name').strip()
            file = section.get('file_name').lstrip('/').strip()  # Remove any leading '/', since some models add it anyway
            ext = [suffix for suffix in supported_suffixes if file.endswith(suffix)]
            if not ext:
                get_logger().warning(f"Unsupported file extension: {file}")
@@ -119,6 +187,9 @@ def format_markdown_q_and_a_response(question_str: str, response_str: str, relev
            else:
                answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}\n"
        return answer_str
    except Exception as e:
        get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
        return ""

def format_markdown_header(header: str) -> str:
    try:
@@ -157,6 +228,7 @@ def clean_markdown_content(content: str) -> str:
    Returns:
        Cleaned markdown content
    """
    try:
        # Remove HTML comments
        content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
@@ -180,27 +252,39 @@ def clean_markdown_content(content: str) -> str:
        content = re.sub(r'<(?!table|tr|td|th|thead|tbody)([a-zA-Z][a-zA-Z0-9]*)[^>]*>(.*?)</\1>',
                         r'\2', content, flags=re.DOTALL)
        return content.strip()
    except Exception as e:
        get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
        return ""

class PredictionPreparator:
    def __init__(self, ai_handler, vars, system_prompt, user_prompt):
        try:
            self.ai_handler = ai_handler
            variables = copy.deepcopy(vars)
            environment = Environment(undefined=StrictUndefined)
            self.system_prompt = environment.from_string(system_prompt).render(variables)
            self.user_prompt = environment.from_string(user_prompt).render(variables)
        except Exception as e:
            get_logger().exception(f"Caught exception during init. Setting ai_handler to None to prevent __call__.")
            self.ai_handler = None

    # Called by retry_with_fallback_models and therefore, on any failure, must throw an exception:
    async def __call__(self, model: str) -> str:
        if not self.ai_handler:
            get_logger().error("ai handler not set. Cannot invoke model!")
            raise ValueError("PredictionPreparator not initialized")
        try:
            response, finish_reason = await self.ai_handler.chat_completion(
                model=model, temperature=get_settings().config.temperature, system=self.system_prompt, user=self.user_prompt)
            return response
        except Exception as e:
            get_logger().error(f"Error while preparing prediction: {e}")
            return ""
            get_logger().exception("Caught exception during prediction.", artifacts={'system': self.system_prompt, 'user': self.user_prompt})
            raise e
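The comment above states the contract the review fixes in this commit rely on: `retry_with_fallback_models` can only move on to a fallback model if the callable raises. A minimal sketch of that control flow (a hypothetical simplification; the real helper in pr-agent also distinguishes model types and deployment ids):

```python
async def retry_with_fallback_models_sketch(prediction, models: list[str]) -> str:
    # Try each model in order; a raised exception advances to the next one.
    last_exc: Exception | None = None
    for model in models:
        try:
            return await prediction(model)
        except Exception as exc:
            last_exc = exc
            # Returning "" here instead of raising (the pre-review behavior
            # of PredictionPreparator.__call__) would mask the failure and
            # prevent the fallback model from ever running.
    raise last_exc or RuntimeError("no models configured")
```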
class PRHelpDocs(object):
    def __init__(self, ctx_url, ai_handler:partial[BaseAiHandler,] = LiteLLMAIHandler, args: Tuple[str]=None, return_as_string: bool=False):
    def __init__(self, ctx_url, ai_handler:partial[BaseAiHandler,] = LiteLLMAIHandler, args: tuple[str]=None, return_as_string: bool=False):
        try:
            self.ctx_url = ctx_url
            self.question = args[0] if args else None
            self.return_as_string = return_as_string
@@ -238,6 +322,9 @@ class PRHelpDocs(object):
                self.vars,
                get_settings().pr_help_docs_prompts.system,
                get_settings().pr_help_docs_prompts.user)
        except Exception as e:
            get_logger().exception(f"Caught exception during init. Setting self.question to None to prevent run() from doing anything.")
            self.question = None

    async def run(self):
        if not self.question:
@@ -246,7 +333,93 @@ class PRHelpDocs(object):

        try:
            # Clone the repository and gather the relevant documentation files.
            docs_prompt = None
            docs_filepath_to_contents = self._gen_filenames_to_contents_map_from_repo()

            # Generate the prompt for the AI model. This will be the full text of all the documentation files combined.
            docs_prompt = aggregate_documentation_files_for_prompt_contents(docs_filepath_to_contents)
            if not docs_filepath_to_contents or not docs_prompt:
                get_logger().warning(f"Could not find any usable documentation. Returning with no result...")
                return None
            docs_prompt_to_send_to_model = docs_prompt

            # Estimate how many tokens will be needed.
            # In case the expected number of tokens exceeds LLM limits, retry with just headings, asking the LLM to rank them according to relevance to the question.
            # Based on the returned ranking, rerun, but sort the documents accordingly; this time, trim in case the limit is exceeded.

            # First, check that the text is not too long to even query the LLM provider:
            max_allowed_txt_input = get_maximal_text_input_length_for_token_count_estimation()
            invoke_llm_just_with_headings = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input,
                                                                  only_return_if_trim_needed=True)
            if invoke_llm_just_with_headings:
                # The entire docs content is too long. Rank and return according to relevance.
                docs_prompt_to_send_to_model = await self._rank_docs_and_return_them_as_prompt(docs_filepath_to_contents,
                                                                                               max_allowed_txt_input)

            if not docs_prompt_to_send_to_model:
                get_logger().error("Failed to generate docs prompt for model. Returning with no result...")
                return
            # At this point, either all original documents will be used (if their total length doesn't exceed limits), or only those selected.
            self.vars['snippets'] = docs_prompt_to_send_to_model.strip()
            # Run the AI model and extract sections from its response
            response = await retry_with_fallback_models(PredictionPreparator(self.ai_handler, self.vars,
                                                                             get_settings().pr_help_docs_prompts.system,
                                                                             get_settings().pr_help_docs_prompts.user),
                                                        model_type=ModelType.REGULAR)
            response_yaml = load_yaml(response)
            if not response_yaml:
                get_logger().error("Failed to parse the AI response.", artifacts={'response': response})
                return
            response_str = response_yaml.get('response')
            relevant_sections = response_yaml.get('relevant_sections')
            if not response_str or not relevant_sections:
                get_logger().error("Failed to extract response/relevant sections.",
                                   artifacts={'raw_response': response, 'response_str': response_str, 'relevant_sections': relevant_sections})
                return
            if int(response_yaml.get('question_is_relevant', '1')) == 0:
                get_logger().warning(f"Question is not relevant. Returning without an answer...",
                                     artifacts={'raw_response': response})
                return

            # Format the response as markdown
            answer_str = self._format_model_answer(response_str, relevant_sections)
            if self.return_as_string:  # Skip publishing
                return answer_str
            # Otherwise, publish the answer if it is non-empty and publishing is not turned off:
            if answer_str and get_settings().config.publish_output:
                self.git_provider.publish_comment(answer_str)
            else:
                get_logger().info("Answer:", artifacts={'answer_str': answer_str})
            return answer_str
        except Exception as e:
            get_logger().exception('failed to provide an answer to the given user question as a result of a thrown exception (see above)')

    def _find_all_document_files_matching_exts(self, abs_docs_path: str, ignore_readme=False, max_allowed_files=5000) -> list[str]:
        try:
            matching_files = []

            # Ensure extensions don't have leading dots and are lowercase
            dotless_extensions = [ext.lower().lstrip('.') for ext in self.supported_doc_exts]

            # Walk through the directory and its subdirectories
            file_cntr = 0
            for root, _, files in os.walk(abs_docs_path):
                for file in files:
                    if ignore_readme and root == abs_docs_path and file.lower() in [f"readme.{ext}" for ext in dotless_extensions]:
                        continue
                    # Check if the file has one of the specified extensions
                    if any(file.lower().endswith(f'.{ext}') for ext in dotless_extensions):
                        file_cntr += 1
                        matching_files.append(os.path.join(root, file))
                        if file_cntr >= max_allowed_files:
                            get_logger().warning(f"Found at least {max_allowed_files} files in {abs_docs_path}, skipping the rest.")
                            return matching_files
            return matching_files
        except Exception as e:
            get_logger().exception(f"Unexpected exception thrown. Returning empty list.")
            return []

    def _gen_filenames_to_contents_map_from_repo(self) -> dict[str, str]:
        try:
            with TemporaryDirectory() as tmp_dir:
                get_logger().debug(f"About to clone repository: {self.repo_url} to temporary directory: {tmp_dir}...")
                returned_cloned_repo_root = self.git_provider.clone(self.repo_url, tmp_dir, remove_dest_folder=False)
@@ -268,103 +441,120 @@ class PRHelpDocs(object):
                                                                   ignore_readme=(self.docs_path=='.')))
                if not doc_files:
                    get_logger().warning(f"No documentation files found matching file extensions: "
                                         f"{self.supported_doc_exts} under repo: {self.repo_url} path: {self.docs_path}")
                    return None
                                         f"{self.supported_doc_exts} under repo: {self.repo_url} "
                                         f"path: {self.docs_path}. Returning empty dict.")
                    return {}

                get_logger().info(f'Answering a question inside context {self.ctx_url} for repo: {self.repo_url}'
                                  f' using the following documentation files: ', artifacts={'doc_files': doc_files})
                get_logger().info(f'For context {self.ctx_url} and repo: {self.repo_url}'
                                  f' will be using the following documentation files: ',
                                  artifacts={'doc_files': doc_files})

                docs_prompt = aggregate_documentation_files_for_prompt_contents(returned_cloned_repo_root.path, doc_files)
                if not docs_prompt:
                    get_logger().warning(f"Error reading one of the documentation files. Returning with no result...")
                    return None
                docs_prompt_to_send_to_model = docs_prompt
                return map_documentation_files_to_contents(returned_cloned_repo_root.path, doc_files)
        except Exception as e:
            get_logger().exception(f"Unexpected exception thrown. Returning empty dict.")
            return {}

        # Estimate how many tokens will be needed. Trim in case of exceeding the limit.
        # First, check whether the text needs to be trimmed, as some models fail to return the estimated token count if the input text is too long.
        max_allowed_txt_input = get_maximal_text_input_length_for_token_count_estimation()
        if len(docs_prompt_to_send_to_model) >= max_allowed_txt_input:
            get_logger().warning(f"Text length: {len(docs_prompt_to_send_to_model)} exceeds the current returned limit of {max_allowed_txt_input} just for token count estimation. Trimming the text...")
            docs_prompt_to_send_to_model = docs_prompt_to_send_to_model[:max_allowed_txt_input]
    def _trim_docs_input(self, docs_input: str, max_allowed_txt_input: int, only_return_if_trim_needed=False) -> bool|str:
        try:
            if len(docs_input) >= max_allowed_txt_input:
                get_logger().warning(
                    f"Text length: {len(docs_input)} exceeds the current returned limit of {max_allowed_txt_input} just for token count estimation. Trimming the text...")
                if only_return_if_trim_needed:
                    return True
                docs_input = docs_input[:max_allowed_txt_input]
            # Then, count the tokens in the prompt. If the count exceeds the limit, trim the text.
            token_count = self.token_handler.count_tokens(docs_prompt_to_send_to_model, force_accurate=True)
            token_count = self.token_handler.count_tokens(docs_input, force_accurate=True)
            get_logger().debug(f"Estimated token count of documentation to send to model: {token_count}")
            model = get_settings().config.model
            if model in MAX_TOKENS:
                max_tokens_full = MAX_TOKENS[model]  # note - here we take the actual max tokens, without any reductions; we do aim to get the full documentation website in the prompt
            else:
                max_tokens_full = get_max_tokens(model)
            delta_output = 5000  # Elbow room to reduce the chance of exceeding the token limit or of the model paying less attention to prompt guidelines.
            if token_count > max_tokens_full - delta_output:
                docs_prompt_to_send_to_model = clean_markdown_content(docs_prompt_to_send_to_model)  # Reduce unnecessary text/images/etc.
                get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Attempting to clip text to fit within the limit...")
                docs_prompt_to_send_to_model = clip_tokens(docs_prompt_to_send_to_model, max_tokens_full - delta_output,
                if only_return_if_trim_needed:
                    return True
                docs_input = clean_markdown_content(docs_input)  # Reduce unnecessary text/images/etc.
                get_logger().info(
                    f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Attempting to clip text to fit within the limit...")
                docs_input = clip_tokens(docs_input, max_tokens_full - delta_output,
                                         num_input_tokens=token_count)
            self.vars['snippets'] = docs_prompt_to_send_to_model.strip()
            if only_return_if_trim_needed:
                return False
            return docs_input
        except Exception as e:
            # Unexpected exception. Rethrowing it since:
            # 1. This is an internal function.
            # 2. An empty str/False result is a valid one - callers would then also have to check for None.
            get_logger().exception(f"Unexpected exception thrown. Rethrowing it...")
            raise e

    async def _rank_docs_and_return_them_as_prompt(self, docs_filepath_to_contents: dict[str, str], max_allowed_txt_input: int) -> str:
        try:
            # Return just the file names and their headings (if any exist):
            docs_prompt_to_send_to_model = (
                aggregate_documentation_files_for_prompt_contents(docs_filepath_to_contents,
                                                                  return_just_headings=True))
            # Verify the list of headings does not exceed limits - trim it if it does.
            docs_prompt_to_send_to_model = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input,
                                                                 only_return_if_trim_needed=False)
            if not docs_prompt_to_send_to_model:
                get_logger().error("_trim_docs_input returned an empty result.")
                return ""

            self.vars['snippets'] = docs_prompt_to_send_to_model.strip()
            # Run the AI model and extract sections from its response
            response = await retry_with_fallback_models(PredictionPreparator(self.ai_handler, self.vars,
                                                                             get_settings().pr_help_docs_prompts.system,
                                                                             get_settings().pr_help_docs_prompts.user),
                                                                             get_settings().pr_help_docs_headings_prompts.system,
                                                                             get_settings().pr_help_docs_headings_prompts.user),
                                                        model_type=ModelType.REGULAR)
            response_yaml = load_yaml(response)
            if not response_yaml:
                get_logger().exception("Failed to parse the AI response.", artifacts={'response': response})
                raise Exception(f"Failed to parse the AI response.")
            response_str = response_yaml.get('response')
            relevant_sections = response_yaml.get('relevant_sections')
            if not response_str or not relevant_sections:
                get_logger().exception("Failed to extract response/relevant sections.",
                                       artifacts={'response_str': response_str, 'relevant_sections': relevant_sections})
                raise Exception(f"Failed to extract response/relevant sections.")
                get_logger().error("Failed to parse the AI response.", artifacts={'response': response})
                return ""
            # else: Sanitize the output so that the file names match the dictionary keys 1:1. Do this via the file index, not the file name, which may be altered by the model.
            valid_indices = [int(entry['idx']) for entry in response_yaml.get('relevant_files_ranking')
                             if int(entry['idx']) >= 0 and int(entry['idx']) < len(docs_filepath_to_contents)]
            valid_file_paths = [list(docs_filepath_to_contents.keys())[idx] for idx in valid_indices]
            selected_docs_dict = {file_path: docs_filepath_to_contents[file_path] for file_path in valid_file_paths}
            docs_prompt = aggregate_documentation_files_for_prompt_contents(selected_docs_dict)
            docs_prompt_to_send_to_model = docs_prompt

            # Format the response as markdown
            canonical_url_prefix, canonical_url_suffix = self.git_provider.get_canonical_url_parts(repo_git_url=self.repo_url if self.repo_url_given_explicitly else None,
                                                                                                   desired_branch=self.repo_desired_branch)
            answer_str = format_markdown_q_and_a_response(self.question, response_str, relevant_sections, self.supported_doc_exts, canonical_url_prefix, canonical_url_suffix)
            # Check that the updated list of documents does not exceed limits, and trim if it does:
            docs_prompt_to_send_to_model = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input,
                                                                 only_return_if_trim_needed=False)
            if not docs_prompt_to_send_to_model:
                get_logger().error("_trim_docs_input returned an empty result.")
                return ""
            return docs_prompt_to_send_to_model
        except Exception as e:
            get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
            return ""

    def _format_model_answer(self, response_str: str, relevant_sections: list[dict[str, str]]) -> str:
        try:
            canonical_url_prefix, canonical_url_suffix = (
                self.git_provider.get_canonical_url_parts(repo_git_url=self.repo_url if self.repo_url_given_explicitly else None,
                                                          desired_branch=self.repo_desired_branch))
            answer_str = format_markdown_q_and_a_response(self.question, response_str, relevant_sections,
                                                          self.supported_doc_exts, canonical_url_prefix, canonical_url_suffix)
            if answer_str:
                # Remove the question phrase and replace it with a light bulb and a heading mentioning that this is an automated answer:
                answer_str = modify_answer_section(answer_str)
            # For PR help docs, we return the answer string instead of publishing it
            # In case the response should not be published but returned as a string, stop here:
            if answer_str and self.return_as_string:
                if int(response_yaml.get('question_is_relevant', '1')) == 0:
                    get_logger().warning(f"Chat help docs answer would be ignored due to an invalid question.",
                                         artifacts={'answer_str': answer_str})
                    return ""
                get_logger().info(f"Chat help docs answer", artifacts={'answer_str': answer_str})
                return answer_str

            # Publish the answer
            if not answer_str or int(response_yaml.get('question_is_relevant', '1')) == 0:
            if not answer_str:
                get_logger().info(f"No answer found")
                return ""

            if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_help_docs.enable_help_text:
                answer_str += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
                answer_str += HelpMessage.get_help_docs_usage_guide()
                answer_str += "\n</details>\n"

            if get_settings().config.publish_output:
                self.git_provider.publish_comment(answer_str)
            else:
                get_logger().info("Answer:", artifacts={'answer_str': answer_str})

        except:
            get_logger().exception('failed to provide an answer to the given user question as a result of a thrown exception (see above)')


    def _find_all_document_files_matching_exts(self, abs_docs_path: str, ignore_readme=False) -> List[str]:
        matching_files = []

        # Ensure extensions don't have leading dots and are lowercase
        dotless_extensions = [ext.lower().lstrip('.') for ext in self.supported_doc_exts]

        # Walk through the directory and its subdirectories
        for root, _, files in os.walk(abs_docs_path):
            for file in files:
                if ignore_readme and root == abs_docs_path and file.lower() in [f"readme.{ext}" for ext in dotless_extensions]:
                    continue
                # Check if the file has one of the specified extensions
                if any(file.lower().endswith(f'.{ext}') for ext in dotless_extensions):
                    matching_files.append(os.path.join(root, file))
        return matching_files
            return answer_str
        except Exception as e:
            get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
            return ""