Add support for documentation content exceeding token limits (#1670)

* Add support for documentation content exceeding token limits via a two-phase operation:
  1. Ask the LLM to rank the headings most likely to contain an answer to the user's question.
  2. Provide the corresponding files for the LLM to search for an answer.

- Refactor help_docs to make the code more readable.
- For the purpose of getting the canonical path, git providers now use the repository's default branch rather than the PR's source branch.
- Refactor token counting and make it clear when an estimate factor will be used.

* Code review changes:
  1. Correctly handle exceptions during retry_with_fallback_models (to allow a fallback model to run in case of failure).
  2. Better naming for default_branch in the Bitbucket Cloud provider.
sharoneyal authored on 2025-04-03 11:51:26 +03:00, committed by GitHub
parent ceaca3e621
commit 14971c4f5f
9 changed files with 505 additions and 187 deletions
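
The two-phase operation above can be summarized with a minimal sketch (all helper names here are hypothetical placeholders, not functions from this commit):

```python
from typing import Callable

def answer_question_over_docs(question: str,
                              file_to_content: dict[str, str],
                              fits_in_context: Callable[[str], bool],
                              rank_files_by_headings: Callable[[str, dict[str, str]], list[str]],
                              answer_from_files: Callable[[str, dict[str, str]], str]) -> str:
    # Phase 0: if the full documentation fits in the model's context, ask directly.
    full_payload = "\n".join(file_to_content.values())
    if fits_in_context(full_payload):
        return answer_from_files(question, file_to_content)

    # Phase 1: send only file names and headings; the LLM ranks the files most
    # likely to contain the answer to the user's question.
    ranked_files = rank_files_by_headings(question, file_to_content)

    # Phase 2: resend only the top-ranked files (trimmed if still too long) and
    # ask the LLM to answer from them.
    selected = {f: file_to_content[f] for f in ranked_files if f in file_to_content}
    return answer_from_files(question, selected)
```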

View File

@@ -1,7 +1,6 @@
 from threading import Lock
 
 from jinja2 import Environment, StrictUndefined
-from math import ceil
 from tiktoken import encoding_for_model, get_encoding
 
 from pr_agent.config_loader import get_settings
@@ -105,6 +104,19 @@ class TokenHandler:
             get_logger().error( f"Error in Anthropic token counting: {e}")
             return MaxTokens
 
+    def estimate_token_count_for_non_anth_claude_models(self, model, default_encoder_estimate):
+        from math import ceil
+        import re
+
+        model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model)
+        if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'):
+            return default_encoder_estimate
+        #else: Model is not an OpenAI one - therefore, cannot provide an accurate token count and instead, return a higher number as best effort.
+        elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
+        get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as best effort estimate")
+        return ceil(elbow_factor * default_encoder_estimate)
+
     def count_tokens(self, patch: str, force_accurate=False) -> int:
         """
         Counts the number of tokens in a given patch string.
@@ -116,21 +128,15 @@ class TokenHandler:
         The number of tokens in the patch string.
         """
         encoder_estimate = len(self.encoder.encode(patch, disallowed_special=()))
+
+        #If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
         if not force_accurate:
             return encoder_estimate
-        #else, need to provide an accurate estimation:
 
+        #else, force_accurate==True: User requested providing an accurate estimation:
         model = get_settings().config.model.lower()
-
-        if force_accurate and 'claude' in model and get_settings(use_context=False).get('anthropic.key'):
+        if 'claude' in model and get_settings(use_context=False).get('anthropic.key'):
             return self.calc_claude_tokens(patch) # API call to Anthropic for accurate token counting for Claude models
 
-        #else: Non Anthropic provided model
-        import re
-        model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model)
-        if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'):
-            return encoder_estimate
-        #else: Model is neither an OpenAI, nor an Anthropic model - therefore, cannot provide an accurate token count and instead, return a higher number as best effort.
-        elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
-        get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as best effort estimate")
-        return ceil(elbow_factor * encoder_estimate)
+        #else: Non Anthropic provided model:
+        return self.estimate_token_count_for_non_anth_claude_models(model, encoder_estimate)
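
For orientation, a hedged usage sketch of the refactored token counting (not part of the diff; it assumes TokenHandler can be constructed without a PR object, as done elsewhere in this commit via TokenHandler(None, vars, system, user)):

```python
from pr_agent.algo.token_handler import TokenHandler

handler = TokenHandler(None, {}, "", "")
patch = "some long documentation payload ..."

# Cheap path: plain tiktoken estimate, no provider calls.
rough = handler.count_tokens(patch)

# Accurate path: Claude models with an Anthropic key use the provider API;
# GPT/o-series models with an OpenAI key trust the encoder; any other model gets
# the encoder estimate inflated by model_token_count_estimate_factor,
# e.g. ceil(1.3 * 1000) == 1300 tokens with the default factor of 0.3.
strict = handler.count_tokens(patch, force_accurate=True)
```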

View File

@@ -29,6 +29,7 @@ global_settings = Dynaconf(
         "settings/custom_labels.toml",
         "settings/pr_help_prompts.toml",
         "settings/pr_help_docs_prompts.toml",
+        "settings/pr_help_docs_headings_prompts.toml",
         "settings/.secrets.toml",
         "settings_prod/.secrets.toml",
     ]]

View File

@@ -92,7 +92,7 @@ class BitbucketProvider(GitProvider):
                 return ("", "")
             workspace_name, project_name = repo_path.split('/')
         else:
-            desired_branch = self.get_pr_branch()
+            desired_branch = self.get_repo_default_branch()
             parsed_pr_url = urlparse(self.pr_url)
             scheme_and_netloc = parsed_pr_url.scheme + "://" + parsed_pr_url.netloc
             workspace_name, project_name = (self.workspace_slug, self.repo_slug)
@@ -470,6 +470,16 @@ class BitbucketProvider(GitProvider):
     def get_pr_branch(self):
         return self.pr.source_branch
 
+    # This function attempts to get the default branch of the repository. As a fallback, uses the PR destination branch.
+    # Note: Must be running from a PR context.
+    def get_repo_default_branch(self):
+        try:
+            url_repo = f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/"
+            response_repo = requests.request("GET", url_repo, headers=self.headers).json()
+            return response_repo['mainbranch']['name']
+        except:
+            return self.pr.destination_branch
+
     def get_pr_owner_id(self) -> str | None:
         return self.workspace_slug
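
A hedged sketch (not part of the diff) of what the Bitbucket Cloud call above relies on: the repositories endpoint returns the default branch under the "mainbranch" key, which is why get_repo_default_branch reads response_repo['mainbranch']['name'] and falls back to the PR's destination branch on failure:

```python
import requests

def get_default_branch(workspace: str, repo_slug: str, headers: dict) -> str:
    # GET https://api.bitbucket.org/2.0/repositories/{workspace}/{repo_slug}/ returns the
    # repository object, e.g. {..., "mainbranch": {"type": "branch", "name": "main"}, ...}
    url = f"https://api.bitbucket.org/2.0/repositories/{workspace}/{repo_slug}/"
    repo = requests.get(url, headers=headers).json()
    return repo.get('mainbranch', {}).get('name', 'main')
```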

View File

@@ -64,9 +64,15 @@ class BitbucketServerProvider(GitProvider):
         workspace_name = None
         project_name = None
         if not repo_git_url:
-            desired_branch = self.get_pr_branch()
             workspace_name = self.workspace_slug
             project_name = self.repo_slug
+            default_branch_dict = self.bitbucket_client.get_default_branch(workspace_name, project_name)
+            if 'displayId' in default_branch_dict:
+                desired_branch = default_branch_dict['displayId']
+            else:
+                get_logger().error(f"Cannot obtain default branch for workspace_name={workspace_name}, "
+                                   f"project_name={project_name}, default_branch_dict={default_branch_dict}")
+                return ("", "")
         elif '.git' in repo_git_url and 'scm/' in repo_git_url:
             repo_path = repo_git_url.split('.git')[0].split('scm/')[-1]
             if repo_path.count('/') == 1: # Has to have the form <workspace>/<repo>

View File

@@ -133,7 +133,7 @@ class GithubProvider(GitProvider):
         if (not owner or not repo) and self.repo: #"else" - User did not provide an external git url, or not an issue, use self.repo object
             owner, repo = self.repo.split('/')
             scheme_and_netloc = self.base_url_html
-            desired_branch = self.get_pr_branch()
+            desired_branch = self.repo_obj.default_branch
         if not all([scheme_and_netloc, owner, repo]): #"else": Not invoked from a PR context,but no provided git url for context
             get_logger().error(f"Unable to get canonical url parts since missing context (PR or explicit git url)")
             return ("", "")

View File

@@ -87,7 +87,11 @@ class GitLabProvider(GitProvider):
             return ("", "")
         if not repo_git_url: #Use PR url as context
             repo_path = self._get_project_path_from_pr_or_issue_url(self.pr_url)
-            desired_branch = self.get_pr_branch()
+            try:
+                desired_branch = self.gl.projects.get(self.id_project).default_branch
+            except Exception as e:
+                get_logger().exception(f"Cannot get PR: {self.pr_url} default branch. Tried project ID: {self.id_project}")
+                return ("", "")
         else: #Use repo git url
             repo_path = repo_git_url.split('.git')[0].split('.com/')[-1]
         prefix = f"{self.gitlab_url}/{repo_path}/-/blob/{desired_branch}"

View File

@@ -9,7 +9,6 @@
 model="o3-mini"
 fallback_models=["gpt-4o-2024-11-20"]
 #model_weak="gpt-4o-mini-2024-07-18" # optional, a weaker model to use for some easier tasks
-model_token_count_estimate_factor=0.3 # factor to increase the token count estimate, in order to reduce likelihood of model failure due to too many tokens.
 # CLI
 git_provider="github"
 publish_output=true
@@ -30,6 +29,7 @@ max_description_tokens = 500
 max_commits_tokens = 500
 max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
 custom_model_max_tokens=-1 # for models not in the default list
+model_token_count_estimate_factor=0.3 # factor to increase the token count estimate, in order to reduce likelihood of model failure due to too many tokens - applicable only when requesting an accurate estimate.
 # patch extension logic
 patch_extension_skip_types =[".md",".txt"]
 allow_dynamic_context=true

View File

@@ -0,0 +1,101 @@
[pr_help_docs_headings_prompts]
system="""You are Doc-helper, a language model that ranks documentation files based on their relevance to user questions.
You will receive a question, a repository url and file names along with optional groups of headings extracted from such files from that repository (either as markdown or as restructured text).
Your task is to rank file paths based on how likely they contain the answer to a user's question, using only the headings from each such file and the file name.
======
==file name==
'src/file1.py'
==index==
0 based integer
==file headings==
heading #1
heading #2
...
==file name==
'src/file2.py'
==index==
0 based integer
==file headings==
heading #1
heading #2
...
...
======
Additional instructions:
- Consider only the file names and section headings within each document
- Present the most relevant files first, based strictly on how well their headings and file names align with user question
The output must be a YAML object equivalent to type $DocHeadingsHelper, according to the following Pydantic definitions:
=====
class file_idx_and_path(BaseModel):
idx: int = Field(description="The zero based index of file_name, as it appeared in the original list of headings. Cannot be negative.")
file_name: str = Field(description="The file_name exactly as it appeared in the question")
class DocHeadingsHelper(BaseModel):
user_question: str = Field(description="The user's question")
relevant_files_ranking: List[file_idx_and_path] = Field(description="Files sorted in descending order by relevance to question")
=====
Example output:
```yaml
user_question: |
...
relevant_files_ranking:
- idx: 101
file_name: "src/file1.py"
- ...
"""
user="""\
Documentation url: '{{ docs_url|trim }}'
-----
User's Question:
=====
{{ question|trim }}
=====
Filenames with optional headings from documentation website content:
=====
{{ snippets|trim }}
=====
Reminder: The output must be a YAML object equivalent to type $DocHeadingsHelper, similar to the following example output:
=====
Example output:
```yaml
user_question: |
...
relevant_files_ranking:
- idx: 101
file_name: "src/file1.py"
- ...
=====
Important Notes:
1. Output most relevant file names first, by descending order of relevancy.
2. Only include files with non-negative indices
Response (should be a valid YAML, and nothing else).
```yaml
"""

View File

@@ -1,11 +1,11 @@
 import copy
 from functools import partial
 
 from jinja2 import Environment, StrictUndefined
 import math
 import os
 import re
 from tempfile import TemporaryDirectory
-from typing import Dict, List, Optional, Tuple
 
 from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
@@ -78,47 +78,118 @@ def get_maximal_text_input_length_for_token_count_estimation():
         return 900000 #Claude API for token estimation allows maximal text input of 900K chars
     return math.inf #Otherwise, no known limitation on input text just for token estimation
 
-# Load documentation files to memory, decorating them with a header to mark where each file begins,
-# as to help the LLM to give a better answer.
-def aggregate_documentation_files_for_prompt_contents(base_path: str, doc_files: List[str]) -> Optional[str]:
-    docs_prompt = ""
-    for file in doc_files:
-        try:
-            with open(file, 'r', encoding='utf-8') as f:
-                content = f.read()
-                # Skip files with no text content
-                if not re.search(r'[a-zA-Z]', content):
-                    continue
-                file_path = str(file).replace(str(base_path), '')
-                docs_prompt += f"\n==file name==\n\n{file_path}\n\n==file content==\n\n{content.strip()}\n=========\n\n"
-        except Exception as e:
-            get_logger().warning(f"Error while reading the file {file}: {e}")
-            continue
-    if not docs_prompt:
-        get_logger().error("Couldn't find any usable documentation files. Returning None.")
-        return None
-    return docs_prompt
+def return_document_headings(text: str, ext: str) -> str:
+    try:
+        lines = text.split('\n')
+        headings = set()
+
+        if not text or not re.search(r'[a-zA-Z]', text):
+            get_logger().error(f"Empty or non text content found in text: {text}.")
+            return ""
+
+        if ext in ['.md', '.mdx']:
+            # Extract Markdown headings (lines starting with #)
+            headings = {line.strip() for line in lines if line.strip().startswith('#')}
+        elif ext == '.rst':
+            # Find indices of lines that have all same character:
+            #Allowed characters according to list from: https://docutils.sourceforge.io/docs/ref/rst/restructuredtext.html#sections
+            section_chars = set('!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~')
+
+            # Find potential section marker lines (underlines/overlines): They have to be the same character
+            marker_lines = []
+            for i, line in enumerate(lines):
+                line = line.rstrip()
+                if line and all(c == line[0] for c in line) and line[0] in section_chars:
+                    marker_lines.append((i, len(line)))
+
+            # Check for headings adjacent to marker lines (below + text must be in length equal or less)
+            for idx, length in marker_lines:
+                # Check if it's an underline (heading is above it)
+                if idx > 0 and lines[idx - 1].rstrip() and len(lines[idx - 1].rstrip()) <= length:
+                    headings.add(lines[idx - 1].rstrip())
+        else:
+            get_logger().error(f"Unsupported file extension: {ext}")
+            return ""
+
+        return '\n'.join(headings)
+    except Exception as e:
+        get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
+        return ""
+
+# Load documentation files to memory: full file path (as will be given as prompt) -> doc contents
+def map_documentation_files_to_contents(base_path: str, doc_files: list[str], max_allowed_file_len=5000) -> dict[str, str]:
+    try:
+        returned_dict = {}
+        for file in doc_files:
+            try:
+                with open(file, 'r', encoding='utf-8') as f:
+                    content = f.read()
+                    # Skip files with no text content
+                    if not re.search(r'[a-zA-Z]', content):
+                        continue
+                    if len(content) > max_allowed_file_len:
+                        get_logger().warning(f"File {file} length: {len(content)} exceeds limit: {max_allowed_file_len}, so it will be trimmed.")
+                        content = content[:max_allowed_file_len]
+                    file_path = str(file).replace(str(base_path), '')
+                    returned_dict[file_path] = content.strip()
+            except Exception as e:
+                get_logger().warning(f"Error while reading the file {file}: {e}")
+                continue
+        if not returned_dict:
+            get_logger().error("Couldn't find any usable documentation files. Returning empty dict.")
+        return returned_dict
+    except Exception as e:
+        get_logger().exception(f"Unexpected exception thrown. Returning empty dict.")
+        return {}
+
+# Goes over files' contents, generating payload for prompt while decorating them with a header to mark where each file begins,
+# as to help the LLM to give a better answer.
+def aggregate_documentation_files_for_prompt_contents(file_path_to_contents: dict[str, str], return_just_headings=False) -> str:
+    try:
+        docs_prompt = ""
+        for idx, file_path in enumerate(file_path_to_contents):
+            file_contents = file_path_to_contents[file_path].strip()
+            if not file_contents:
+                get_logger().error(f"Got empty file contents for: {file_path}. Skipping this file.")
+                continue
+            if return_just_headings:
+                file_headings = return_document_headings(file_contents, os.path.splitext(file_path)[-1]).strip()
+                if file_headings:
+                    docs_prompt += f"\n==file name==\n\n{file_path}\n\n==index==\n\n{idx}\n\n==file headings==\n\n{file_headings}\n=========\n\n"
+                else:
+                    get_logger().warning(f"No headers for: {file_path}. Will only use filename")
+                    docs_prompt += f"\n==file name==\n\n{file_path}\n\n==index==\n\n{idx}\n\n"
+            else:
+                docs_prompt += f"\n==file name==\n\n{file_path}\n\n==file content==\n\n{file_contents}\n=========\n\n"
+        return docs_prompt
+    except Exception as e:
+        get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
+        return ""
 
-def format_markdown_q_and_a_response(question_str: str, response_str: str, relevant_sections: List[Dict[str, str]],
-                                     supported_suffixes: List[str], base_url_prefix: str, base_url_suffix: str="") -> str:
-    base_url_prefix = base_url_prefix.strip('/') #Sanitize base_url_prefix
-    answer_str = ""
-    answer_str += f"### Question: \n{question_str}\n\n"
-    answer_str += f"### Answer:\n{response_str.strip()}\n\n"
-    answer_str += f"#### Relevant Sources:\n\n"
-    for section in relevant_sections:
-        file = section.get('file_name').strip()
-        ext = [suffix for suffix in supported_suffixes if file.endswith(suffix)]
-        if not ext:
-            get_logger().warning(f"Unsupported file extension: {file}")
-            continue
-        if str(section['relevant_section_header_string']).strip():
-            markdown_header = format_markdown_header(section['relevant_section_header_string'])
-            if base_url_prefix:
-                answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}#{markdown_header}\n"
-        else:
-            answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}\n"
-    return answer_str
+def format_markdown_q_and_a_response(question_str: str, response_str: str, relevant_sections: list[dict[str, str]],
+                                     supported_suffixes: list[str], base_url_prefix: str, base_url_suffix: str="") -> str:
+    try:
+        base_url_prefix = base_url_prefix.strip('/') #Sanitize base_url_prefix
+        answer_str = ""
+        answer_str += f"### Question: \n{question_str}\n\n"
+        answer_str += f"### Answer:\n{response_str.strip()}\n\n"
+        answer_str += f"#### Relevant Sources:\n\n"
+        for section in relevant_sections:
+            file = section.get('file_name').lstrip('/').strip() #Remove any '/' in the beginning, since some models do it anyway
+            ext = [suffix for suffix in supported_suffixes if file.endswith(suffix)]
+            if not ext:
+                get_logger().warning(f"Unsupported file extension: {file}")
+                continue
+            if str(section['relevant_section_header_string']).strip():
+                markdown_header = format_markdown_header(section['relevant_section_header_string'])
+                if base_url_prefix:
+                    answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}#{markdown_header}\n"
+            else:
+                answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}\n"
+        return answer_str
+    except Exception as e:
+        get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
+        return ""
 
 def format_markdown_header(header: str) -> str:
     try:
@@ -157,87 +228,103 @@ def clean_markdown_content(content: str) -> str:
     Returns:
     Cleaned markdown content
     """
-    # Remove HTML comments
-    content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
+    try:
+        # Remove HTML comments
+        content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
 
-    # Remove frontmatter (YAML between --- or +++ delimiters)
-    content = re.sub(r'^---\s*\n.*?\n---\s*\n', '', content, flags=re.DOTALL)
-    content = re.sub(r'^\+\+\+\s*\n.*?\n\+\+\+\s*\n', '', content, flags=re.DOTALL)
+        # Remove frontmatter (YAML between --- or +++ delimiters)
+        content = re.sub(r'^---\s*\n.*?\n---\s*\n', '', content, flags=re.DOTALL)
+        content = re.sub(r'^\+\+\+\s*\n.*?\n\+\+\+\s*\n', '', content, flags=re.DOTALL)
 
-    # Remove excessive blank lines (more than 2 consecutive)
-    content = re.sub(r'\n{3,}', '\n\n', content)
+        # Remove excessive blank lines (more than 2 consecutive)
+        content = re.sub(r'\n{3,}', '\n\n', content)
 
-    # Remove HTML tags that are often used for styling only
-    content = re.sub(r'<div.*?>|</div>|<span.*?>|</span>', '', content, flags=re.DOTALL)
+        # Remove HTML tags that are often used for styling only
+        content = re.sub(r'<div.*?>|</div>|<span.*?>|</span>', '', content, flags=re.DOTALL)
 
-    # Remove image alt text which can be verbose
-    content = re.sub(r'!\[(.*?)\]', '![]', content)
+        # Remove image alt text which can be verbose
+        content = re.sub(r'!\[(.*?)\]', '![]', content)
 
-    # Remove images completely
-    content = re.sub(r'!\[.*?\]\(.*?\)', '', content)
+        # Remove images completely
+        content = re.sub(r'!\[.*?\]\(.*?\)', '', content)
 
-    # Remove simple HTML tags but preserve content between them
-    content = re.sub(r'<(?!table|tr|td|th|thead|tbody)([a-zA-Z][a-zA-Z0-9]*)[^>]*>(.*?)</\1>',
-                     r'\2', content, flags=re.DOTALL)
-    return content.strip()
+        # Remove simple HTML tags but preserve content between them
+        content = re.sub(r'<(?!table|tr|td|th|thead|tbody)([a-zA-Z][a-zA-Z0-9]*)[^>]*>(.*?)</\1>',
+                         r'\2', content, flags=re.DOTALL)
+        return content.strip()
+    except Exception as e:
+        get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
+        return ""
 
 class PredictionPreparator:
     def __init__(self, ai_handler, vars, system_prompt, user_prompt):
-        self.ai_handler = ai_handler
-        variables = copy.deepcopy(vars)
-        environment = Environment(undefined=StrictUndefined)
-        self.system_prompt = environment.from_string(system_prompt).render(variables)
-        self.user_prompt = environment.from_string(user_prompt).render(variables)
+        try:
+            self.ai_handler = ai_handler
+            variables = copy.deepcopy(vars)
+            environment = Environment(undefined=StrictUndefined)
+            self.system_prompt = environment.from_string(system_prompt).render(variables)
+            self.user_prompt = environment.from_string(user_prompt).render(variables)
+        except Exception as e:
+            get_logger().exception(f"Caught exception during init. Setting ai_handler to None to prevent __call__.")
+            self.ai_handler = None
 
+    #Called by retry_with_fallback_models and therefore, on any failure must throw an exception:
     async def __call__(self, model: str) -> str:
+        if not self.ai_handler:
+            get_logger().error("ai handler not set. Cannot invoke model!")
+            raise ValueError("PredictionPreparator not initialized")
         try:
             response, finish_reason = await self.ai_handler.chat_completion(
                 model=model, temperature=get_settings().config.temperature, system=self.system_prompt, user=self.user_prompt)
             return response
         except Exception as e:
-            get_logger().error(f"Error while preparing prediction: {e}")
-            return ""
+            get_logger().exception("Caught exception during prediction.", artifacts={'system': self.system_prompt, 'user': self.user_prompt})
+            raise e
 
 class PRHelpDocs(object):
-    def __init__(self, ctx_url, ai_handler:partial[BaseAiHandler,] = LiteLLMAIHandler, args: Tuple[str]=None, return_as_string: bool=False):
-        self.ctx_url = ctx_url
-        self.question = args[0] if args else None
-        self.return_as_string = return_as_string
-        self.repo_url_given_explicitly = True
-        self.repo_url = get_settings().get('PR_HELP_DOCS.REPO_URL', '')
-        self.repo_desired_branch = get_settings().get('PR_HELP_DOCS.REPO_DEFAULT_BRANCH', 'main') #Ignored if self.repo_url is empty
-        self.include_root_readme_file = not(get_settings()['PR_HELP_DOCS.EXCLUDE_ROOT_README'])
-        self.supported_doc_exts = get_settings()['PR_HELP_DOCS.SUPPORTED_DOC_EXTS']
-        self.docs_path = get_settings()['PR_HELP_DOCS.DOCS_PATH']
+    def __init__(self, ctx_url, ai_handler:partial[BaseAiHandler,] = LiteLLMAIHandler, args: tuple[str]=None, return_as_string: bool=False):
+        try:
+            self.ctx_url = ctx_url
+            self.question = args[0] if args else None
+            self.return_as_string = return_as_string
+            self.repo_url_given_explicitly = True
+            self.repo_url = get_settings().get('PR_HELP_DOCS.REPO_URL', '')
+            self.repo_desired_branch = get_settings().get('PR_HELP_DOCS.REPO_DEFAULT_BRANCH', 'main') #Ignored if self.repo_url is empty
+            self.include_root_readme_file = not(get_settings()['PR_HELP_DOCS.EXCLUDE_ROOT_README'])
+            self.supported_doc_exts = get_settings()['PR_HELP_DOCS.SUPPORTED_DOC_EXTS']
+            self.docs_path = get_settings()['PR_HELP_DOCS.DOCS_PATH']
 
-        retrieved_settings = [self.include_root_readme_file, self.supported_doc_exts, self.docs_path]
-        if any([setting is None for setting in retrieved_settings]):
-            raise Exception(f"One of the settings is invalid: {retrieved_settings}")
+            retrieved_settings = [self.include_root_readme_file, self.supported_doc_exts, self.docs_path]
+            if any([setting is None for setting in retrieved_settings]):
+                raise Exception(f"One of the settings is invalid: {retrieved_settings}")
 
-        self.git_provider = get_git_provider_with_context(ctx_url)
-        if not self.git_provider:
-            raise Exception(f"No git provider found at {ctx_url}")
-
-        if not self.repo_url:
-            self.repo_url_given_explicitly = False
-            get_logger().debug(f"No explicit repo url provided, deducing it from type: {self.git_provider.__class__.__name__} "
-                               f"context url: {self.ctx_url}")
-            self.repo_url = self.git_provider.get_git_repo_url(self.ctx_url)
+            self.git_provider = get_git_provider_with_context(ctx_url)
+            if not self.git_provider:
+                raise Exception(f"No git provider found at {ctx_url}")
+
             if not self.repo_url:
-                raise Exception(f"Unable to deduce repo url from type: {self.git_provider.__class__.__name__} url: {self.ctx_url}")
-            get_logger().debug(f"deduced repo url: {self.repo_url}")
-            self.repo_desired_branch = None #Inferred from the repo provider.
+                self.repo_url_given_explicitly = False
+                get_logger().debug(f"No explicit repo url provided, deducing it from type: {self.git_provider.__class__.__name__} "
+                                   f"context url: {self.ctx_url}")
+                self.repo_url = self.git_provider.get_git_repo_url(self.ctx_url)
+                if not self.repo_url:
+                    raise Exception(f"Unable to deduce repo url from type: {self.git_provider.__class__.__name__} url: {self.ctx_url}")
+                get_logger().debug(f"deduced repo url: {self.repo_url}")
+                self.repo_desired_branch = None #Inferred from the repo provider.
 
-        self.ai_handler = ai_handler()
-        self.vars = {
-            "docs_url": self.repo_url,
-            "question": self.question,
-            "snippets": "",
-        }
-        self.token_handler = TokenHandler(None,
-                                          self.vars,
-                                          get_settings().pr_help_docs_prompts.system,
-                                          get_settings().pr_help_docs_prompts.user)
+            self.ai_handler = ai_handler()
+            self.vars = {
+                "docs_url": self.repo_url,
+                "question": self.question,
+                "snippets": "",
+            }
+            self.token_handler = TokenHandler(None,
+                                              self.vars,
+                                              get_settings().pr_help_docs_prompts.system,
+                                              get_settings().pr_help_docs_prompts.user)
+        except Exception as e:
+            get_logger().exception(f"Caught exception during init. Setting self.question to None to prevent run() to do anything.")
+            self.question = None
 
     async def run(self):
         if not self.question:
@@ -246,7 +333,93 @@ class PRHelpDocs(object):
         try:
             # Clone the repository and gather relevant documentation files.
-            docs_prompt = None
+            docs_filepath_to_contents = self._gen_filenames_to_contents_map_from_repo()
+            #Generate prompt for the AI model. This will be the full text of all the documentation files combined.
+            docs_prompt = aggregate_documentation_files_for_prompt_contents(docs_filepath_to_contents)
+            if not docs_filepath_to_contents or not docs_prompt:
+                get_logger().warning(f"Could not find any usable documentation. Returning with no result...")
+                return None
+            docs_prompt_to_send_to_model = docs_prompt
+
+            # Estimate how many tokens will be needed.
+            # In case the expected number of tokens exceeds LLM limits, retry with just headings, asking the LLM to rank according to relevance to the question.
+            # Based on returned ranking, rerun but sort the documents accordingly, this time, trim in case of exceeding limit.
+            #First, check if the text is not too long to even query the LLM provider:
+            max_allowed_txt_input = get_maximal_text_input_length_for_token_count_estimation()
+            invoke_llm_just_with_headings = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input,
+                                                                  only_return_if_trim_needed=True)
+            if invoke_llm_just_with_headings:
+                #Entire docs is too long. Rank and return according to relevance.
+                docs_prompt_to_send_to_model = await self._rank_docs_and_return_them_as_prompt(docs_filepath_to_contents,
+                                                                                               max_allowed_txt_input)
+
+            if not docs_prompt_to_send_to_model:
+                get_logger().error("Failed to generate docs prompt for model. Returning with no result...")
+                return
+
+            # At this point, either all original documents be used (if their total length doesn't exceed limits), or only those selected.
+            self.vars['snippets'] = docs_prompt_to_send_to_model.strip()
+
+            # Run the AI model and extract sections from its response
+            response = await retry_with_fallback_models(PredictionPreparator(self.ai_handler, self.vars,
+                                                                             get_settings().pr_help_docs_prompts.system,
+                                                                             get_settings().pr_help_docs_prompts.user),
+                                                        model_type=ModelType.REGULAR)
+            response_yaml = load_yaml(response)
+            if not response_yaml:
+                get_logger().error("Failed to parse the AI response.", artifacts={'response': response})
+                return
+            response_str = response_yaml.get('response')
+            relevant_sections = response_yaml.get('relevant_sections')
+            if not response_str or not relevant_sections:
+                get_logger().error("Failed to extract response/relevant sections.",
+                                   artifacts={'raw_response': response, 'response_str': response_str, 'relevant_sections': relevant_sections})
+                return
+            if int(response_yaml.get('question_is_relevant', '1')) == 0:
+                get_logger().warning(f"Question is not relevant. Returning without an answer...",
+                                     artifacts={'raw_response': response})
+                return
+
+            # Format the response as markdown
+            answer_str = self._format_model_answer(response_str, relevant_sections)
+            if self.return_as_string: #Skip publishing
+                return answer_str
+
+            #Otherwise, publish the answer if answer is non empty and publish is not turned off:
+            if answer_str and get_settings().config.publish_output:
+                self.git_provider.publish_comment(answer_str)
+            else:
+                get_logger().info("Answer:", artifacts={'answer_str': answer_str})
+            return answer_str
+        except Exception as e:
+            get_logger().exception('failed to provide answer to given user question as a result of a thrown exception (see above)')
+
+    def _find_all_document_files_matching_exts(self, abs_docs_path: str, ignore_readme=False, max_allowed_files=5000) -> list[str]:
+        try:
+            matching_files = []
+            # Ensure extensions don't have leading dots and are lowercase
+            dotless_extensions = [ext.lower().lstrip('.') for ext in self.supported_doc_exts]
+            # Walk through directory and subdirectories
+            file_cntr = 0
+            for root, _, files in os.walk(abs_docs_path):
+                for file in files:
+                    if ignore_readme and root == abs_docs_path and file.lower() in [f"readme.{ext}" for ext in dotless_extensions]:
+                        continue
+                    # Check if file has one of the specified extensions
+                    if any(file.lower().endswith(f'.{ext}') for ext in dotless_extensions):
+                        file_cntr+=1
+                        matching_files.append(os.path.join(root, file))
+                        if file_cntr >= max_allowed_files:
+                            get_logger().warning(f"Found at least {max_allowed_files} files in {abs_docs_path}, skipping the rest.")
+                            return matching_files
+            return matching_files
+        except Exception as e:
+            get_logger().exception(f"Unexpected exception thrown. Returning empty list.")
+            return []
+
+    def _gen_filenames_to_contents_map_from_repo(self) -> dict[str, str]:
+        try:
             with TemporaryDirectory() as tmp_dir:
                 get_logger().debug(f"About to clone repository: {self.repo_url} to temporary directory: {tmp_dir}...")
                 returned_cloned_repo_root = self.git_provider.clone(self.repo_url, tmp_dir, remove_dest_folder=False)
@@ -268,103 +441,120 @@ class PRHelpDocs(object):
                                                ignore_readme=(self.docs_path=='.')))
                 if not doc_files:
                     get_logger().warning(f"No documentation files found matching file extensions: "
-                                         f"{self.supported_doc_exts} under repo: {self.repo_url} path: {self.docs_path}")
-                    return None
+                                         f"{self.supported_doc_exts} under repo: {self.repo_url} "
+                                         f"path: {self.docs_path}. Returning empty dict.")
+                    return {}
 
-                get_logger().info(f'Answering a question inside context {self.ctx_url} for repo: {self.repo_url}'
-                                  f' using the following documentation files: ', artifacts={'doc_files': doc_files})
+                get_logger().info(f'For context {self.ctx_url} and repo: {self.repo_url}'
+                                  f' will be using the following documentation files: ',
+                                  artifacts={'doc_files': doc_files})
 
-                docs_prompt = aggregate_documentation_files_for_prompt_contents(returned_cloned_repo_root.path, doc_files)
-            if not docs_prompt:
-                get_logger().warning(f"Error reading one of the documentation files. Returning with no result...")
-                return None
-            docs_prompt_to_send_to_model = docs_prompt
+                return map_documentation_files_to_contents(returned_cloned_repo_root.path, doc_files)
+        except Exception as e:
+            get_logger().exception(f"Unexpected exception thrown. Returning empty dict.")
+            return {}
 
-            # Estimate how many tokens will be needed. Trim in case of exceeding limit.
-            # Firstly, check if text needs to be trimmed, as some models fail to return the estimated token count if the input text is too long.
-            max_allowed_txt_input = get_maximal_text_input_length_for_token_count_estimation()
-            if len(docs_prompt_to_send_to_model) >= max_allowed_txt_input:
-                get_logger().warning(f"Text length: {len(docs_prompt_to_send_to_model)} exceeds the current returned limit of {max_allowed_txt_input} just for token count estimation. Trimming the text...")
-                docs_prompt_to_send_to_model = docs_prompt_to_send_to_model[:max_allowed_txt_input]
+    def _trim_docs_input(self, docs_input: str, max_allowed_txt_input: int, only_return_if_trim_needed=False) -> bool|str:
+        try:
+            if len(docs_input) >= max_allowed_txt_input:
+                get_logger().warning(
+                    f"Text length: {len(docs_input)} exceeds the current returned limit of {max_allowed_txt_input} just for token count estimation. Trimming the text...")
+                if only_return_if_trim_needed:
+                    return True
+                docs_input = docs_input[:max_allowed_txt_input]
 
             # Then, count the tokens in the prompt. If the count exceeds the limit, trim the text.
-            token_count = self.token_handler.count_tokens(docs_prompt_to_send_to_model, force_accurate=True)
+            token_count = self.token_handler.count_tokens(docs_input, force_accurate=True)
             get_logger().debug(f"Estimated token count of documentation to send to model: {token_count}")
             model = get_settings().config.model
             if model in MAX_TOKENS:
-                max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
+                max_tokens_full = MAX_TOKENS[
+                    model]  # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
             else:
                 max_tokens_full = get_max_tokens(model)
-            delta_output = 5000 #Elbow room to reduce chance of exceeding token limit or model paying less attention to prompt guidelines.
+            delta_output = 5000  # Elbow room to reduce chance of exceeding token limit or model paying less attention to prompt guidelines.
             if token_count > max_tokens_full - delta_output:
-                docs_prompt_to_send_to_model = clean_markdown_content(docs_prompt_to_send_to_model) #Reduce unnecessary text/images/etc.
-                get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Attempting to clip text to fit within the limit...")
-                docs_prompt_to_send_to_model = clip_tokens(docs_prompt_to_send_to_model, max_tokens_full - delta_output,
-                                                           num_input_tokens=token_count)
+                if only_return_if_trim_needed:
+                    return True
+                docs_input = clean_markdown_content(
+                    docs_input)  # Reduce unnecessary text/images/etc.
+                get_logger().info(
+                    f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Attempting to clip text to fit within the limit...")
+                docs_input = clip_tokens(docs_input, max_tokens_full - delta_output,
+                                         num_input_tokens=token_count)
 
-            self.vars['snippets'] = docs_prompt_to_send_to_model.strip()
+            if only_return_if_trim_needed:
+                return False
+            return docs_input
+        except Exception as e:
+            # Unexpected exception. Rethrowing it since:
+            # 1. This is an internal function.
+            # 2. An empty str/False result is a valid one - would require now checking also for None.
+            get_logger().exception(f"Unexpected exception thrown. Rethrowing it...")
+            raise e
+
+    async def _rank_docs_and_return_them_as_prompt(self, docs_filepath_to_contents: dict[str, str], max_allowed_txt_input: int) -> str:
+        try:
+            #Return just file name and their headings (if exist):
+            docs_prompt_to_send_to_model = (
+                aggregate_documentation_files_for_prompt_contents(docs_filepath_to_contents,
+                                                                  return_just_headings=True))
+            # Verify list of headings does not exceed limits - trim it if it does.
+            docs_prompt_to_send_to_model = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input,
+                                                                 only_return_if_trim_needed=False)
+            if not docs_prompt_to_send_to_model:
+                get_logger().error("_trim_docs_input returned an empty result.")
+                return ""
+            self.vars['snippets'] = docs_prompt_to_send_to_model.strip()
 
             # Run the AI model and extract sections from its response
             response = await retry_with_fallback_models(PredictionPreparator(self.ai_handler, self.vars,
-                                                                             get_settings().pr_help_docs_prompts.system,
-                                                                             get_settings().pr_help_docs_prompts.user),
+                                                                             get_settings().pr_help_docs_headings_prompts.system,
+                                                                             get_settings().pr_help_docs_headings_prompts.user),
                                                         model_type=ModelType.REGULAR)
             response_yaml = load_yaml(response)
             if not response_yaml:
-                get_logger().exception("Failed to parse the AI response.", artifacts={'response': response})
-                raise Exception(f"Failed to parse the AI response.")
-            response_str = response_yaml.get('response')
-            relevant_sections = response_yaml.get('relevant_sections')
-            if not response_str or not relevant_sections:
-                get_logger().exception("Failed to extract response/relevant sections.",
-                                       artifacts={'response_str': response_str, 'relevant_sections': relevant_sections})
-                raise Exception(f"Failed to extract response/relevant sections.")
+                get_logger().error("Failed to parse the AI response.", artifacts={'response': response})
+                return ""
+            # else: Sanitize the output so that the file names match 1:1 dictionary keys. Do this via the file index and not its name, which may be altered by the model.
+            valid_indices = [int(entry['idx']) for entry in response_yaml.get('relevant_files_ranking')
+                             if int(entry['idx']) >= 0 and int(entry['idx']) < len(docs_filepath_to_contents)]
+            valid_file_paths = [list(docs_filepath_to_contents.keys())[idx] for idx in valid_indices]
+            selected_docs_dict = {file_path: docs_filepath_to_contents[file_path] for file_path in valid_file_paths}
+            docs_prompt = aggregate_documentation_files_for_prompt_contents(selected_docs_dict)
+            docs_prompt_to_send_to_model = docs_prompt
 
-            # Format the response as markdown
-            canonical_url_prefix, canonical_url_suffix = self.git_provider.get_canonical_url_parts(repo_git_url=self.repo_url if self.repo_url_given_explicitly else None,
-                                                                                                   desired_branch=self.repo_desired_branch)
-            answer_str = format_markdown_q_and_a_response(self.question, response_str, relevant_sections, self.supported_doc_exts, canonical_url_prefix, canonical_url_suffix)
+            # Check if the updated list of documents does not exceed limits and trim if it does:
+            docs_prompt_to_send_to_model = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input,
+                                                                 only_return_if_trim_needed=False)
+            if not docs_prompt_to_send_to_model:
+                get_logger().error("_trim_docs_input returned an empty result.")
+                return ""
+            return docs_prompt_to_send_to_model
+        except Exception as e:
+            get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
+            return ""
+
+    def _format_model_answer(self, response_str: str, relevant_sections: list[dict[str, str]]) -> str:
+        try:
+            canonical_url_prefix, canonical_url_suffix = (
+                self.git_provider.get_canonical_url_parts(repo_git_url=self.repo_url if self.repo_url_given_explicitly else None,
+                                                          desired_branch=self.repo_desired_branch))
+            answer_str = format_markdown_q_and_a_response(self.question, response_str, relevant_sections,
+                                                          self.supported_doc_exts, canonical_url_prefix, canonical_url_suffix)
             if answer_str:
                 #Remove the question phrase and replace with light bulb and a heading mentioning this is an automated answer:
                 answer_str = modify_answer_section(answer_str)
-            # For PR help docs, we return the answer string instead of publishing it
+            #In case the response should not be published and returned as string, stop here:
             if answer_str and self.return_as_string:
-                if int(response_yaml.get('question_is_relevant', '1')) == 0:
-                    get_logger().warning(f"Chat help docs answer would be ignored due to an invalid question.",
-                                         artifacts={'answer_str': answer_str})
-                    return ""
                 get_logger().info(f"Chat help docs answer", artifacts={'answer_str': answer_str})
                 return answer_str
-
-            # Publish the answer
-            if not answer_str or int(response_yaml.get('question_is_relevant', '1')) == 0:
+            if not answer_str:
                 get_logger().info(f"No answer found")
                 return ""
-
             if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_help_docs.enable_help_text:
                 answer_str += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
                 answer_str += HelpMessage.get_help_docs_usage_guide()
                 answer_str += "\n</details>\n"
-
-            if get_settings().config.publish_output:
-                self.git_provider.publish_comment(answer_str)
-            else:
-                get_logger().info("Answer:", artifacts={'answer_str': answer_str})
-        except:
-            get_logger().exception('failed to provide answer to given user question as a result of a thrown exception (see above)')
-
-    def _find_all_document_files_matching_exts(self, abs_docs_path: str, ignore_readme=False) -> List[str]:
-        matching_files = []
-        # Ensure extensions don't have leading dots and are lowercase
-        dotless_extensions = [ext.lower().lstrip('.') for ext in self.supported_doc_exts]
-        # Walk through directory and subdirectories
-        for root, _, files in os.walk(abs_docs_path):
-            for file in files:
-                if ignore_readme and root == abs_docs_path and file.lower() in [f"readme.{ext}" for ext in dotless_extensions]:
-                    continue
-                # Check if file has one of the specified extensions
-                if any(file.lower().endswith(f'.{ext}') for ext in dotless_extensions):
-                    matching_files.append(os.path.join(root, file))
-        return matching_files
+            return answer_str
+        except Exception as e:
+            get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
+            return ""
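
To make the headings-only payload concrete, a small hedged sketch (not part of the diff) of the "==file name== / ==index== / ==file headings==" format that aggregate_documentation_files_for_prompt_contents builds when return_just_headings=True, with hypothetical file paths:

```python
def build_headings_payload(file_path_to_headings: dict[str, str]) -> str:
    # Mirrors the payload layout consumed by pr_help_docs_headings_prompts.
    payload = ""
    for idx, (file_path, headings) in enumerate(file_path_to_headings.items()):
        payload += (f"\n==file name==\n\n{file_path}\n\n==index==\n\n{idx}"
                    f"\n\n==file headings==\n\n{headings}\n=========\n\n")
    return payload

print(build_headings_payload({
    "docs/installation.md": "# Installation\n## Docker",
    "docs/usage-guide/configuration.md": "# Configuration options\n## Wiki settings",
}))
```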