mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
Add support for documentation content exceeding token limits (#1670)
* - Add support for documentation content exceeding token limits via two phase operation: 1. Ask LLM to rank headings which are most likely to contain an answer to a user question 2. Provide the corresponding files for the LLM to search for an answer. - Refactor of help_docs to make the code more readable - For the purpose of getting canonical path: git providers to use default branch and not the PR's source branch. - Refactor of token counting and making it clear on when an estimate factor will be used. * Code review changes: 1. Correctly handle exception during retry_with_fallback_models (to allow fallback model to run in case of failure) 2. Better naming for default_branch in bitbucket cloud provider
This commit is contained in:
@ -1,7 +1,6 @@
|
|||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
from math import ceil
|
|
||||||
from tiktoken import encoding_for_model, get_encoding
|
from tiktoken import encoding_for_model, get_encoding
|
||||||
|
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
@ -105,6 +104,19 @@ class TokenHandler:
|
|||||||
get_logger().error( f"Error in Anthropic token counting: {e}")
|
get_logger().error( f"Error in Anthropic token counting: {e}")
|
||||||
return MaxTokens
|
return MaxTokens
|
||||||
|
|
||||||
|
def estimate_token_count_for_non_anth_claude_models(self, model, default_encoder_estimate):
|
||||||
|
from math import ceil
|
||||||
|
import re
|
||||||
|
|
||||||
|
model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model)
|
||||||
|
if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'):
|
||||||
|
return default_encoder_estimate
|
||||||
|
#else: Model is not an OpenAI one - therefore, cannot provide an accurate token count and instead, return a higher number as best effort.
|
||||||
|
|
||||||
|
elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
|
||||||
|
get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as best effort estimate")
|
||||||
|
return ceil(elbow_factor * default_encoder_estimate)
|
||||||
|
|
||||||
def count_tokens(self, patch: str, force_accurate=False) -> int:
|
def count_tokens(self, patch: str, force_accurate=False) -> int:
|
||||||
"""
|
"""
|
||||||
Counts the number of tokens in a given patch string.
|
Counts the number of tokens in a given patch string.
|
||||||
@ -116,21 +128,15 @@ class TokenHandler:
|
|||||||
The number of tokens in the patch string.
|
The number of tokens in the patch string.
|
||||||
"""
|
"""
|
||||||
encoder_estimate = len(self.encoder.encode(patch, disallowed_special=()))
|
encoder_estimate = len(self.encoder.encode(patch, disallowed_special=()))
|
||||||
|
|
||||||
|
#If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
|
||||||
if not force_accurate:
|
if not force_accurate:
|
||||||
return encoder_estimate
|
return encoder_estimate
|
||||||
#else, need to provide an accurate estimation:
|
|
||||||
|
|
||||||
|
#else, force_accurate==True: User requested providing an accurate estimation:
|
||||||
model = get_settings().config.model.lower()
|
model = get_settings().config.model.lower()
|
||||||
if force_accurate and 'claude' in model and get_settings(use_context=False).get('anthropic.key'):
|
if 'claude' in model and get_settings(use_context=False).get('anthropic.key'):
|
||||||
return self.calc_claude_tokens(patch) # API call to Anthropic for accurate token counting for Claude models
|
return self.calc_claude_tokens(patch) # API call to Anthropic for accurate token counting for Claude models
|
||||||
#else: Non Anthropic provided model
|
|
||||||
|
|
||||||
import re
|
#else: Non Anthropic provided model:
|
||||||
model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model)
|
return self.estimate_token_count_for_non_anth_claude_models(model, encoder_estimate)
|
||||||
if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'):
|
|
||||||
return encoder_estimate
|
|
||||||
#else: Model is neither an OpenAI, nor an Anthropic model - therefore, cannot provide an accurate token count and instead, return a higher number as best effort.
|
|
||||||
|
|
||||||
elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
|
|
||||||
get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as best effort estimate")
|
|
||||||
return ceil(elbow_factor * encoder_estimate)
|
|
||||||
|
@ -29,6 +29,7 @@ global_settings = Dynaconf(
|
|||||||
"settings/custom_labels.toml",
|
"settings/custom_labels.toml",
|
||||||
"settings/pr_help_prompts.toml",
|
"settings/pr_help_prompts.toml",
|
||||||
"settings/pr_help_docs_prompts.toml",
|
"settings/pr_help_docs_prompts.toml",
|
||||||
|
"settings/pr_help_docs_headings_prompts.toml",
|
||||||
"settings/.secrets.toml",
|
"settings/.secrets.toml",
|
||||||
"settings_prod/.secrets.toml",
|
"settings_prod/.secrets.toml",
|
||||||
]]
|
]]
|
||||||
|
@ -92,7 +92,7 @@ class BitbucketProvider(GitProvider):
|
|||||||
return ("", "")
|
return ("", "")
|
||||||
workspace_name, project_name = repo_path.split('/')
|
workspace_name, project_name = repo_path.split('/')
|
||||||
else:
|
else:
|
||||||
desired_branch = self.get_pr_branch()
|
desired_branch = self.get_repo_default_branch()
|
||||||
parsed_pr_url = urlparse(self.pr_url)
|
parsed_pr_url = urlparse(self.pr_url)
|
||||||
scheme_and_netloc = parsed_pr_url.scheme + "://" + parsed_pr_url.netloc
|
scheme_and_netloc = parsed_pr_url.scheme + "://" + parsed_pr_url.netloc
|
||||||
workspace_name, project_name = (self.workspace_slug, self.repo_slug)
|
workspace_name, project_name = (self.workspace_slug, self.repo_slug)
|
||||||
@ -470,6 +470,16 @@ class BitbucketProvider(GitProvider):
|
|||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
return self.pr.source_branch
|
return self.pr.source_branch
|
||||||
|
|
||||||
|
# This function attempts to get the default branch of the repository. As a fallback, uses the PR destination branch.
|
||||||
|
# Note: Must be running from a PR context.
|
||||||
|
def get_repo_default_branch(self):
|
||||||
|
try:
|
||||||
|
url_repo = f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/"
|
||||||
|
response_repo = requests.request("GET", url_repo, headers=self.headers).json()
|
||||||
|
return response_repo['mainbranch']['name']
|
||||||
|
except:
|
||||||
|
return self.pr.destination_branch
|
||||||
|
|
||||||
def get_pr_owner_id(self) -> str | None:
|
def get_pr_owner_id(self) -> str | None:
|
||||||
return self.workspace_slug
|
return self.workspace_slug
|
||||||
|
|
||||||
|
@ -64,9 +64,15 @@ class BitbucketServerProvider(GitProvider):
|
|||||||
workspace_name = None
|
workspace_name = None
|
||||||
project_name = None
|
project_name = None
|
||||||
if not repo_git_url:
|
if not repo_git_url:
|
||||||
desired_branch = self.get_pr_branch()
|
|
||||||
workspace_name = self.workspace_slug
|
workspace_name = self.workspace_slug
|
||||||
project_name = self.repo_slug
|
project_name = self.repo_slug
|
||||||
|
default_branch_dict = self.bitbucket_client.get_default_branch(workspace_name, project_name)
|
||||||
|
if 'displayId' in default_branch_dict:
|
||||||
|
desired_branch = default_branch_dict['displayId']
|
||||||
|
else:
|
||||||
|
get_logger().error(f"Cannot obtain default branch for workspace_name={workspace_name}, "
|
||||||
|
f"project_name={project_name}, default_branch_dict={default_branch_dict}")
|
||||||
|
return ("", "")
|
||||||
elif '.git' in repo_git_url and 'scm/' in repo_git_url:
|
elif '.git' in repo_git_url and 'scm/' in repo_git_url:
|
||||||
repo_path = repo_git_url.split('.git')[0].split('scm/')[-1]
|
repo_path = repo_git_url.split('.git')[0].split('scm/')[-1]
|
||||||
if repo_path.count('/') == 1: # Has to have the form <workspace>/<repo>
|
if repo_path.count('/') == 1: # Has to have the form <workspace>/<repo>
|
||||||
|
@ -133,7 +133,7 @@ class GithubProvider(GitProvider):
|
|||||||
if (not owner or not repo) and self.repo: #"else" - User did not provide an external git url, or not an issue, use self.repo object
|
if (not owner or not repo) and self.repo: #"else" - User did not provide an external git url, or not an issue, use self.repo object
|
||||||
owner, repo = self.repo.split('/')
|
owner, repo = self.repo.split('/')
|
||||||
scheme_and_netloc = self.base_url_html
|
scheme_and_netloc = self.base_url_html
|
||||||
desired_branch = self.get_pr_branch()
|
desired_branch = self.repo_obj.default_branch
|
||||||
if not all([scheme_and_netloc, owner, repo]): #"else": Not invoked from a PR context,but no provided git url for context
|
if not all([scheme_and_netloc, owner, repo]): #"else": Not invoked from a PR context,but no provided git url for context
|
||||||
get_logger().error(f"Unable to get canonical url parts since missing context (PR or explicit git url)")
|
get_logger().error(f"Unable to get canonical url parts since missing context (PR or explicit git url)")
|
||||||
return ("", "")
|
return ("", "")
|
||||||
|
@ -87,7 +87,11 @@ class GitLabProvider(GitProvider):
|
|||||||
return ("", "")
|
return ("", "")
|
||||||
if not repo_git_url: #Use PR url as context
|
if not repo_git_url: #Use PR url as context
|
||||||
repo_path = self._get_project_path_from_pr_or_issue_url(self.pr_url)
|
repo_path = self._get_project_path_from_pr_or_issue_url(self.pr_url)
|
||||||
desired_branch = self.get_pr_branch()
|
try:
|
||||||
|
desired_branch = self.gl.projects.get(self.id_project).default_branch
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Cannot get PR: {self.pr_url} default branch. Tried project ID: {self.id_project}")
|
||||||
|
return ("", "")
|
||||||
else: #Use repo git url
|
else: #Use repo git url
|
||||||
repo_path = repo_git_url.split('.git')[0].split('.com/')[-1]
|
repo_path = repo_git_url.split('.git')[0].split('.com/')[-1]
|
||||||
prefix = f"{self.gitlab_url}/{repo_path}/-/blob/{desired_branch}"
|
prefix = f"{self.gitlab_url}/{repo_path}/-/blob/{desired_branch}"
|
||||||
|
@ -9,7 +9,6 @@
|
|||||||
model="o3-mini"
|
model="o3-mini"
|
||||||
fallback_models=["gpt-4o-2024-11-20"]
|
fallback_models=["gpt-4o-2024-11-20"]
|
||||||
#model_weak="gpt-4o-mini-2024-07-18" # optional, a weaker model to use for some easier tasks
|
#model_weak="gpt-4o-mini-2024-07-18" # optional, a weaker model to use for some easier tasks
|
||||||
model_token_count_estimate_factor=0.3 # factor to increase the token count estimate, in order to reduce likelihood of model failure due to too many tokens.
|
|
||||||
# CLI
|
# CLI
|
||||||
git_provider="github"
|
git_provider="github"
|
||||||
publish_output=true
|
publish_output=true
|
||||||
@ -30,6 +29,7 @@ max_description_tokens = 500
|
|||||||
max_commits_tokens = 500
|
max_commits_tokens = 500
|
||||||
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
|
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
|
||||||
custom_model_max_tokens=-1 # for models not in the default list
|
custom_model_max_tokens=-1 # for models not in the default list
|
||||||
|
model_token_count_estimate_factor=0.3 # factor to increase the token count estimate, in order to reduce likelihood of model failure due to too many tokens - applicable only when requesting an accurate estimate.
|
||||||
# patch extension logic
|
# patch extension logic
|
||||||
patch_extension_skip_types =[".md",".txt"]
|
patch_extension_skip_types =[".md",".txt"]
|
||||||
allow_dynamic_context=true
|
allow_dynamic_context=true
|
||||||
|
101
pr_agent/settings/pr_help_docs_headings_prompts.toml
Normal file
101
pr_agent/settings/pr_help_docs_headings_prompts.toml
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
|
||||||
|
[pr_help_docs_headings_prompts]
|
||||||
|
system="""You are Doc-helper, a language model that ranks documentation files based on their relevance to user questions.
|
||||||
|
You will receive a question, a repository url and file names along with optional groups of headings extracted from such files from that repository (either as markdown or as restructred text).
|
||||||
|
Your task is to rank file paths based on how likely they contain the answer to a user's question, using only the headings from each such file and the file name.
|
||||||
|
|
||||||
|
======
|
||||||
|
==file name==
|
||||||
|
|
||||||
|
'src/file1.py'
|
||||||
|
|
||||||
|
==index==
|
||||||
|
|
||||||
|
0 based integer
|
||||||
|
|
||||||
|
==file headings==
|
||||||
|
heading #1
|
||||||
|
heading #2
|
||||||
|
...
|
||||||
|
|
||||||
|
==file name==
|
||||||
|
|
||||||
|
'src/file2.py'
|
||||||
|
|
||||||
|
==index==
|
||||||
|
|
||||||
|
0 based integer
|
||||||
|
|
||||||
|
==file headings==
|
||||||
|
heading #1
|
||||||
|
heading #2
|
||||||
|
...
|
||||||
|
|
||||||
|
...
|
||||||
|
======
|
||||||
|
|
||||||
|
Additional instructions:
|
||||||
|
- Consider only the file names and section headings within each document
|
||||||
|
- Present the most relevant files first, based strictly on how well their headings and file names align with user question
|
||||||
|
|
||||||
|
The output must be a YAML object equivalent to type $DocHeadingsHelper, according to the following Pydantic definitions:
|
||||||
|
=====
|
||||||
|
class file_idx_and_path(BaseModel):
|
||||||
|
idx: int = Field(description="The zero based index of file_name, as it appeared in the original list of headings. Cannot be negative.")
|
||||||
|
file_name: str = Field(description="The file_name exactly as it appeared in the question")
|
||||||
|
|
||||||
|
class DocHeadingsHelper(BaseModel):
|
||||||
|
user_question: str = Field(description="The user's question")
|
||||||
|
relevant_files_ranking: List[file_idx_and_path] = Field(description="Files sorted in descending order by relevance to question")
|
||||||
|
=====
|
||||||
|
|
||||||
|
|
||||||
|
Example output:
|
||||||
|
```yaml
|
||||||
|
user_question: |
|
||||||
|
...
|
||||||
|
relevant_files_ranking:
|
||||||
|
- idx: 101
|
||||||
|
file_name: "src/file1.py"
|
||||||
|
- ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
user="""\
|
||||||
|
Documentation url: '{{ docs_url|trim }}'
|
||||||
|
-----
|
||||||
|
|
||||||
|
|
||||||
|
User's Question:
|
||||||
|
=====
|
||||||
|
{{ question|trim }}
|
||||||
|
=====
|
||||||
|
|
||||||
|
|
||||||
|
Filenames with optional headings from documentation website content:
|
||||||
|
=====
|
||||||
|
{{ snippets|trim }}
|
||||||
|
=====
|
||||||
|
|
||||||
|
|
||||||
|
Reminder: The output must be a YAML object equivalent to type $DocHeadingsHelper, similar to the following example output:
|
||||||
|
=====
|
||||||
|
|
||||||
|
|
||||||
|
Example output:
|
||||||
|
```yaml
|
||||||
|
user_question: |
|
||||||
|
...
|
||||||
|
relevant_files_ranking:
|
||||||
|
- idx: 101
|
||||||
|
file_name: "src/file1.py"
|
||||||
|
- ...
|
||||||
|
=====
|
||||||
|
|
||||||
|
Important Notes:
|
||||||
|
1. Output most relevant file names first, by descending order of relevancy.
|
||||||
|
2. Only include files with non-negative indices
|
||||||
|
|
||||||
|
|
||||||
|
Response (should be a valid YAML, and nothing else).
|
||||||
|
```yaml
|
||||||
|
"""
|
@ -1,11 +1,11 @@
|
|||||||
import copy
|
import copy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from pr_agent.algo import MAX_TOKENS
|
from pr_agent.algo import MAX_TOKENS
|
||||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
@ -78,47 +78,118 @@ def get_maximal_text_input_length_for_token_count_estimation():
|
|||||||
return 900000 #Claude API for token estimation allows maximal text input of 900K chars
|
return 900000 #Claude API for token estimation allows maximal text input of 900K chars
|
||||||
return math.inf #Otherwise, no known limitation on input text just for token estimation
|
return math.inf #Otherwise, no known limitation on input text just for token estimation
|
||||||
|
|
||||||
# Load documentation files to memory, decorating them with a header to mark where each file begins,
|
def return_document_headings(text: str, ext: str) -> str:
|
||||||
# as to help the LLM to give a better answer.
|
try:
|
||||||
def aggregate_documentation_files_for_prompt_contents(base_path: str, doc_files: List[str]) -> Optional[str]:
|
lines = text.split('\n')
|
||||||
docs_prompt = ""
|
headings = set()
|
||||||
for file in doc_files:
|
|
||||||
try:
|
|
||||||
with open(file, 'r', encoding='utf-8') as f:
|
|
||||||
content = f.read()
|
|
||||||
# Skip files with no text content
|
|
||||||
if not re.search(r'[a-zA-Z]', content):
|
|
||||||
continue
|
|
||||||
file_path = str(file).replace(str(base_path), '')
|
|
||||||
docs_prompt += f"\n==file name==\n\n{file_path}\n\n==file content==\n\n{content.strip()}\n=========\n\n"
|
|
||||||
except Exception as e:
|
|
||||||
get_logger().warning(f"Error while reading the file {file}: {e}")
|
|
||||||
continue
|
|
||||||
if not docs_prompt:
|
|
||||||
get_logger().error("Couldn't find any usable documentation files. Returning None.")
|
|
||||||
return None
|
|
||||||
return docs_prompt
|
|
||||||
|
|
||||||
def format_markdown_q_and_a_response(question_str: str, response_str: str, relevant_sections: List[Dict[str, str]],
|
if not text or not re.search(r'[a-zA-Z]', text):
|
||||||
supported_suffixes: List[str], base_url_prefix: str, base_url_suffix: str="") -> str:
|
get_logger().error(f"Empty or non text content found in text: {text}.")
|
||||||
base_url_prefix = base_url_prefix.strip('/') #Sanitize base_url_prefix
|
return ""
|
||||||
answer_str = ""
|
|
||||||
answer_str += f"### Question: \n{question_str}\n\n"
|
if ext in ['.md', '.mdx']:
|
||||||
answer_str += f"### Answer:\n{response_str.strip()}\n\n"
|
# Extract Markdown headings (lines starting with #)
|
||||||
answer_str += f"#### Relevant Sources:\n\n"
|
headings = {line.strip() for line in lines if line.strip().startswith('#')}
|
||||||
for section in relevant_sections:
|
elif ext == '.rst':
|
||||||
file = section.get('file_name').strip()
|
# Find indices of lines that have all same character:
|
||||||
ext = [suffix for suffix in supported_suffixes if file.endswith(suffix)]
|
#Allowed characters according to list from: https://docutils.sourceforge.io/docs/ref/rst/restructuredtext.html#sections
|
||||||
if not ext:
|
section_chars = set('!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~')
|
||||||
get_logger().warning(f"Unsupported file extension: {file}")
|
|
||||||
continue
|
# Find potential section marker lines (underlines/overlines): They have to be the same character
|
||||||
if str(section['relevant_section_header_string']).strip():
|
marker_lines = []
|
||||||
markdown_header = format_markdown_header(section['relevant_section_header_string'])
|
for i, line in enumerate(lines):
|
||||||
if base_url_prefix:
|
line = line.rstrip()
|
||||||
answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}#{markdown_header}\n"
|
if line and all(c == line[0] for c in line) and line[0] in section_chars:
|
||||||
|
marker_lines.append((i, len(line)))
|
||||||
|
|
||||||
|
# Check for headings adjacent to marker lines (below + text must be in length equal or less)
|
||||||
|
for idx, length in marker_lines:
|
||||||
|
# Check if it's an underline (heading is above it)
|
||||||
|
if idx > 0 and lines[idx - 1].rstrip() and len(lines[idx - 1].rstrip()) <= length:
|
||||||
|
headings.add(lines[idx - 1].rstrip())
|
||||||
else:
|
else:
|
||||||
answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}\n"
|
get_logger().error(f"Unsupported file extension: {ext}")
|
||||||
return answer_str
|
return ""
|
||||||
|
|
||||||
|
return '\n'.join(headings)
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Load documentation files to memory: full file path (as will be given as prompt) -> doc contents
|
||||||
|
def map_documentation_files_to_contents(base_path: str, doc_files: list[str], max_allowed_file_len=5000) -> dict[str, str]:
|
||||||
|
try:
|
||||||
|
returned_dict = {}
|
||||||
|
for file in doc_files:
|
||||||
|
try:
|
||||||
|
with open(file, 'r', encoding='utf-8') as f:
|
||||||
|
content = f.read()
|
||||||
|
# Skip files with no text content
|
||||||
|
if not re.search(r'[a-zA-Z]', content):
|
||||||
|
continue
|
||||||
|
if len(content) > max_allowed_file_len:
|
||||||
|
get_logger().warning(f"File {file} length: {len(content)} exceeds limit: {max_allowed_file_len}, so it will be trimmed.")
|
||||||
|
content = content[:max_allowed_file_len]
|
||||||
|
file_path = str(file).replace(str(base_path), '')
|
||||||
|
returned_dict[file_path] = content.strip()
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().warning(f"Error while reading the file {file}: {e}")
|
||||||
|
continue
|
||||||
|
if not returned_dict:
|
||||||
|
get_logger().error("Couldn't find any usable documentation files. Returning empty dict.")
|
||||||
|
return returned_dict
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Unexpected exception thrown. Returning empty dict.")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Goes over files' contents, generating payload for prompt while decorating them with a header to mark where each file begins,
|
||||||
|
# as to help the LLM to give a better answer.
|
||||||
|
def aggregate_documentation_files_for_prompt_contents(file_path_to_contents: dict[str, str], return_just_headings=False) -> str:
|
||||||
|
try:
|
||||||
|
docs_prompt = ""
|
||||||
|
for idx, file_path in enumerate(file_path_to_contents):
|
||||||
|
file_contents = file_path_to_contents[file_path].strip()
|
||||||
|
if not file_contents:
|
||||||
|
get_logger().error(f"Got empty file contents for: {file_path}. Skipping this file.")
|
||||||
|
continue
|
||||||
|
if return_just_headings:
|
||||||
|
file_headings = return_document_headings(file_contents, os.path.splitext(file_path)[-1]).strip()
|
||||||
|
if file_headings:
|
||||||
|
docs_prompt += f"\n==file name==\n\n{file_path}\n\n==index==\n\n{idx}\n\n==file headings==\n\n{file_headings}\n=========\n\n"
|
||||||
|
else:
|
||||||
|
get_logger().warning(f"No headers for: {file_path}. Will only use filename")
|
||||||
|
docs_prompt += f"\n==file name==\n\n{file_path}\n\n==index==\n\n{idx}\n\n"
|
||||||
|
else:
|
||||||
|
docs_prompt += f"\n==file name==\n\n{file_path}\n\n==file content==\n\n{file_contents}\n=========\n\n"
|
||||||
|
return docs_prompt
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def format_markdown_q_and_a_response(question_str: str, response_str: str, relevant_sections: list[dict[str, str]],
|
||||||
|
supported_suffixes: list[str], base_url_prefix: str, base_url_suffix: str="") -> str:
|
||||||
|
try:
|
||||||
|
base_url_prefix = base_url_prefix.strip('/') #Sanitize base_url_prefix
|
||||||
|
answer_str = ""
|
||||||
|
answer_str += f"### Question: \n{question_str}\n\n"
|
||||||
|
answer_str += f"### Answer:\n{response_str.strip()}\n\n"
|
||||||
|
answer_str += f"#### Relevant Sources:\n\n"
|
||||||
|
for section in relevant_sections:
|
||||||
|
file = section.get('file_name').lstrip('/').strip() #Remove any '/' in the beginning, since some models do it anyway
|
||||||
|
ext = [suffix for suffix in supported_suffixes if file.endswith(suffix)]
|
||||||
|
if not ext:
|
||||||
|
get_logger().warning(f"Unsupported file extension: {file}")
|
||||||
|
continue
|
||||||
|
if str(section['relevant_section_header_string']).strip():
|
||||||
|
markdown_header = format_markdown_header(section['relevant_section_header_string'])
|
||||||
|
if base_url_prefix:
|
||||||
|
answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}#{markdown_header}\n"
|
||||||
|
else:
|
||||||
|
answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}\n"
|
||||||
|
return answer_str
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
|
||||||
|
return ""
|
||||||
|
|
||||||
def format_markdown_header(header: str) -> str:
|
def format_markdown_header(header: str) -> str:
|
||||||
try:
|
try:
|
||||||
@ -157,87 +228,103 @@ def clean_markdown_content(content: str) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
Cleaned markdown content
|
Cleaned markdown content
|
||||||
"""
|
"""
|
||||||
# Remove HTML comments
|
try:
|
||||||
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
|
# Remove HTML comments
|
||||||
|
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
|
||||||
|
|
||||||
# Remove frontmatter (YAML between --- or +++ delimiters)
|
# Remove frontmatter (YAML between --- or +++ delimiters)
|
||||||
content = re.sub(r'^---\s*\n.*?\n---\s*\n', '', content, flags=re.DOTALL)
|
content = re.sub(r'^---\s*\n.*?\n---\s*\n', '', content, flags=re.DOTALL)
|
||||||
content = re.sub(r'^\+\+\+\s*\n.*?\n\+\+\+\s*\n', '', content, flags=re.DOTALL)
|
content = re.sub(r'^\+\+\+\s*\n.*?\n\+\+\+\s*\n', '', content, flags=re.DOTALL)
|
||||||
|
|
||||||
# Remove excessive blank lines (more than 2 consecutive)
|
# Remove excessive blank lines (more than 2 consecutive)
|
||||||
content = re.sub(r'\n{3,}', '\n\n', content)
|
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||||
|
|
||||||
# Remove HTML tags that are often used for styling only
|
# Remove HTML tags that are often used for styling only
|
||||||
content = re.sub(r'<div.*?>|</div>|<span.*?>|</span>', '', content, flags=re.DOTALL)
|
content = re.sub(r'<div.*?>|</div>|<span.*?>|</span>', '', content, flags=re.DOTALL)
|
||||||
|
|
||||||
# Remove image alt text which can be verbose
|
# Remove image alt text which can be verbose
|
||||||
content = re.sub(r'!\[(.*?)\]', '![]', content)
|
content = re.sub(r'!\[(.*?)\]', '![]', content)
|
||||||
|
|
||||||
# Remove images completely
|
# Remove images completely
|
||||||
content = re.sub(r'!\[.*?\]\(.*?\)', '', content)
|
content = re.sub(r'!\[.*?\]\(.*?\)', '', content)
|
||||||
|
|
||||||
# Remove simple HTML tags but preserve content between them
|
# Remove simple HTML tags but preserve content between them
|
||||||
content = re.sub(r'<(?!table|tr|td|th|thead|tbody)([a-zA-Z][a-zA-Z0-9]*)[^>]*>(.*?)</\1>',
|
content = re.sub(r'<(?!table|tr|td|th|thead|tbody)([a-zA-Z][a-zA-Z0-9]*)[^>]*>(.*?)</\1>',
|
||||||
r'\2', content, flags=re.DOTALL)
|
r'\2', content, flags=re.DOTALL)
|
||||||
return content.strip()
|
return content.strip()
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
|
||||||
|
return ""
|
||||||
|
|
||||||
class PredictionPreparator:
|
class PredictionPreparator:
|
||||||
def __init__(self, ai_handler, vars, system_prompt, user_prompt):
|
def __init__(self, ai_handler, vars, system_prompt, user_prompt):
|
||||||
self.ai_handler = ai_handler
|
try:
|
||||||
variables = copy.deepcopy(vars)
|
self.ai_handler = ai_handler
|
||||||
environment = Environment(undefined=StrictUndefined)
|
variables = copy.deepcopy(vars)
|
||||||
self.system_prompt = environment.from_string(system_prompt).render(variables)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
self.user_prompt = environment.from_string(user_prompt).render(variables)
|
self.system_prompt = environment.from_string(system_prompt).render(variables)
|
||||||
|
self.user_prompt = environment.from_string(user_prompt).render(variables)
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Caught exception during init. Setting ai_handler to None to prevent __call__.")
|
||||||
|
self.ai_handler = None
|
||||||
|
|
||||||
|
#Called by retry_with_fallback_models and therefore, on any failure must throw an exception:
|
||||||
async def __call__(self, model: str) -> str:
|
async def __call__(self, model: str) -> str:
|
||||||
|
if not self.ai_handler:
|
||||||
|
get_logger().error("ai handler not set. Cannot invoke model!")
|
||||||
|
raise ValueError("PredictionPreparator not initialized")
|
||||||
try:
|
try:
|
||||||
response, finish_reason = await self.ai_handler.chat_completion(
|
response, finish_reason = await self.ai_handler.chat_completion(
|
||||||
model=model, temperature=get_settings().config.temperature, system=self.system_prompt, user=self.user_prompt)
|
model=model, temperature=get_settings().config.temperature, system=self.system_prompt, user=self.user_prompt)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Error while preparing prediction: {e}")
|
get_logger().exception("Caught exception during prediction.", artifacts={'system': self.system_prompt, 'user': self.user_prompt})
|
||||||
return ""
|
raise e
|
||||||
|
|
||||||
|
|
||||||
class PRHelpDocs(object):
|
class PRHelpDocs(object):
|
||||||
def __init__(self, ctx_url, ai_handler:partial[BaseAiHandler,] = LiteLLMAIHandler, args: Tuple[str]=None, return_as_string: bool=False):
|
def __init__(self, ctx_url, ai_handler:partial[BaseAiHandler,] = LiteLLMAIHandler, args: tuple[str]=None, return_as_string: bool=False):
|
||||||
self.ctx_url = ctx_url
|
try:
|
||||||
self.question = args[0] if args else None
|
self.ctx_url = ctx_url
|
||||||
self.return_as_string = return_as_string
|
self.question = args[0] if args else None
|
||||||
self.repo_url_given_explicitly = True
|
self.return_as_string = return_as_string
|
||||||
self.repo_url = get_settings().get('PR_HELP_DOCS.REPO_URL', '')
|
self.repo_url_given_explicitly = True
|
||||||
self.repo_desired_branch = get_settings().get('PR_HELP_DOCS.REPO_DEFAULT_BRANCH', 'main') #Ignored if self.repo_url is empty
|
self.repo_url = get_settings().get('PR_HELP_DOCS.REPO_URL', '')
|
||||||
self.include_root_readme_file = not(get_settings()['PR_HELP_DOCS.EXCLUDE_ROOT_README'])
|
self.repo_desired_branch = get_settings().get('PR_HELP_DOCS.REPO_DEFAULT_BRANCH', 'main') #Ignored if self.repo_url is empty
|
||||||
self.supported_doc_exts = get_settings()['PR_HELP_DOCS.SUPPORTED_DOC_EXTS']
|
self.include_root_readme_file = not(get_settings()['PR_HELP_DOCS.EXCLUDE_ROOT_README'])
|
||||||
self.docs_path = get_settings()['PR_HELP_DOCS.DOCS_PATH']
|
self.supported_doc_exts = get_settings()['PR_HELP_DOCS.SUPPORTED_DOC_EXTS']
|
||||||
|
self.docs_path = get_settings()['PR_HELP_DOCS.DOCS_PATH']
|
||||||
|
|
||||||
retrieved_settings = [self.include_root_readme_file, self.supported_doc_exts, self.docs_path]
|
retrieved_settings = [self.include_root_readme_file, self.supported_doc_exts, self.docs_path]
|
||||||
if any([setting is None for setting in retrieved_settings]):
|
if any([setting is None for setting in retrieved_settings]):
|
||||||
raise Exception(f"One of the settings is invalid: {retrieved_settings}")
|
raise Exception(f"One of the settings is invalid: {retrieved_settings}")
|
||||||
|
|
||||||
self.git_provider = get_git_provider_with_context(ctx_url)
|
self.git_provider = get_git_provider_with_context(ctx_url)
|
||||||
if not self.git_provider:
|
if not self.git_provider:
|
||||||
raise Exception(f"No git provider found at {ctx_url}")
|
raise Exception(f"No git provider found at {ctx_url}")
|
||||||
if not self.repo_url:
|
|
||||||
self.repo_url_given_explicitly = False
|
|
||||||
get_logger().debug(f"No explicit repo url provided, deducing it from type: {self.git_provider.__class__.__name__} "
|
|
||||||
f"context url: {self.ctx_url}")
|
|
||||||
self.repo_url = self.git_provider.get_git_repo_url(self.ctx_url)
|
|
||||||
if not self.repo_url:
|
if not self.repo_url:
|
||||||
raise Exception(f"Unable to deduce repo url from type: {self.git_provider.__class__.__name__} url: {self.ctx_url}")
|
self.repo_url_given_explicitly = False
|
||||||
get_logger().debug(f"deduced repo url: {self.repo_url}")
|
get_logger().debug(f"No explicit repo url provided, deducing it from type: {self.git_provider.__class__.__name__} "
|
||||||
self.repo_desired_branch = None #Inferred from the repo provider.
|
f"context url: {self.ctx_url}")
|
||||||
|
self.repo_url = self.git_provider.get_git_repo_url(self.ctx_url)
|
||||||
|
if not self.repo_url:
|
||||||
|
raise Exception(f"Unable to deduce repo url from type: {self.git_provider.__class__.__name__} url: {self.ctx_url}")
|
||||||
|
get_logger().debug(f"deduced repo url: {self.repo_url}")
|
||||||
|
self.repo_desired_branch = None #Inferred from the repo provider.
|
||||||
|
|
||||||
self.ai_handler = ai_handler()
|
self.ai_handler = ai_handler()
|
||||||
self.vars = {
|
self.vars = {
|
||||||
"docs_url": self.repo_url,
|
"docs_url": self.repo_url,
|
||||||
"question": self.question,
|
"question": self.question,
|
||||||
"snippets": "",
|
"snippets": "",
|
||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(None,
|
self.token_handler = TokenHandler(None,
|
||||||
self.vars,
|
self.vars,
|
||||||
get_settings().pr_help_docs_prompts.system,
|
get_settings().pr_help_docs_prompts.system,
|
||||||
get_settings().pr_help_docs_prompts.user)
|
get_settings().pr_help_docs_prompts.user)
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Caught exception during init. Setting self.question to None to prevent run() to do anything.")
|
||||||
|
self.question = None
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
if not self.question:
|
if not self.question:
|
||||||
@ -246,7 +333,93 @@ class PRHelpDocs(object):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Clone the repository and gather relevant documentation files.
|
# Clone the repository and gather relevant documentation files.
|
||||||
docs_prompt = None
|
docs_filepath_to_contents = self._gen_filenames_to_contents_map_from_repo()
|
||||||
|
|
||||||
|
#Generate prompt for the AI model. This will be the full text of all the documentation files combined.
|
||||||
|
docs_prompt = aggregate_documentation_files_for_prompt_contents(docs_filepath_to_contents)
|
||||||
|
if not docs_filepath_to_contents or not docs_prompt:
|
||||||
|
get_logger().warning(f"Could not find any usable documentation. Returning with no result...")
|
||||||
|
return None
|
||||||
|
docs_prompt_to_send_to_model = docs_prompt
|
||||||
|
|
||||||
|
# Estimate how many tokens will be needed.
|
||||||
|
# In case the expected number of tokens exceeds LLM limits, retry with just headings, asking the LLM to rank according to relevance to the question.
|
||||||
|
# Based on returned ranking, rerun but sort the documents accordingly, this time, trim in case of exceeding limit.
|
||||||
|
|
||||||
|
#First, check if the text is not too long to even query the LLM provider:
|
||||||
|
max_allowed_txt_input = get_maximal_text_input_length_for_token_count_estimation()
|
||||||
|
invoke_llm_just_with_headings = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input,
|
||||||
|
only_return_if_trim_needed=True)
|
||||||
|
if invoke_llm_just_with_headings:
|
||||||
|
#Entire docs is too long. Rank and return according to relevance.
|
||||||
|
docs_prompt_to_send_to_model = await self._rank_docs_and_return_them_as_prompt(docs_filepath_to_contents,
|
||||||
|
max_allowed_txt_input)
|
||||||
|
|
||||||
|
if not docs_prompt_to_send_to_model:
|
||||||
|
get_logger().error("Failed to generate docs prompt for model. Returning with no result...")
|
||||||
|
return
|
||||||
|
# At this point, either all original documents be used (if their total length doesn't exceed limits), or only those selected.
|
||||||
|
self.vars['snippets'] = docs_prompt_to_send_to_model.strip()
|
||||||
|
# Run the AI model and extract sections from its response
|
||||||
|
response = await retry_with_fallback_models(PredictionPreparator(self.ai_handler, self.vars,
|
||||||
|
get_settings().pr_help_docs_prompts.system,
|
||||||
|
get_settings().pr_help_docs_prompts.user),
|
||||||
|
model_type=ModelType.REGULAR)
|
||||||
|
response_yaml = load_yaml(response)
|
||||||
|
if not response_yaml:
|
||||||
|
get_logger().error("Failed to parse the AI response.", artifacts={'response': response})
|
||||||
|
return
|
||||||
|
response_str = response_yaml.get('response')
|
||||||
|
relevant_sections = response_yaml.get('relevant_sections')
|
||||||
|
if not response_str or not relevant_sections:
|
||||||
|
get_logger().error("Failed to extract response/relevant sections.",
|
||||||
|
artifacts={'raw_response': response, 'response_str': response_str, 'relevant_sections': relevant_sections})
|
||||||
|
return
|
||||||
|
if int(response_yaml.get('question_is_relevant', '1')) == 0:
|
||||||
|
get_logger().warning(f"Question is not relevant. Returning without an answer...",
|
||||||
|
artifacts={'raw_response': response})
|
||||||
|
return
|
||||||
|
|
||||||
|
# Format the response as markdown
|
||||||
|
answer_str = self._format_model_answer(response_str, relevant_sections)
|
||||||
|
if self.return_as_string: #Skip publishing
|
||||||
|
return answer_str
|
||||||
|
#Otherwise, publish the answer if answer is non empty and publish is not turned off:
|
||||||
|
if answer_str and get_settings().config.publish_output:
|
||||||
|
self.git_provider.publish_comment(answer_str)
|
||||||
|
else:
|
||||||
|
get_logger().info("Answer:", artifacts={'answer_str': answer_str})
|
||||||
|
return answer_str
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception('failed to provide answer to given user question as a result of a thrown exception (see above)')
|
||||||
|
|
||||||
|
def _find_all_document_files_matching_exts(self, abs_docs_path: str, ignore_readme=False, max_allowed_files=5000) -> list[str]:
|
||||||
|
try:
|
||||||
|
matching_files = []
|
||||||
|
|
||||||
|
# Ensure extensions don't have leading dots and are lowercase
|
||||||
|
dotless_extensions = [ext.lower().lstrip('.') for ext in self.supported_doc_exts]
|
||||||
|
|
||||||
|
# Walk through directory and subdirectories
|
||||||
|
file_cntr = 0
|
||||||
|
for root, _, files in os.walk(abs_docs_path):
|
||||||
|
for file in files:
|
||||||
|
if ignore_readme and root == abs_docs_path and file.lower() in [f"readme.{ext}" for ext in dotless_extensions]:
|
||||||
|
continue
|
||||||
|
# Check if file has one of the specified extensions
|
||||||
|
if any(file.lower().endswith(f'.{ext}') for ext in dotless_extensions):
|
||||||
|
file_cntr+=1
|
||||||
|
matching_files.append(os.path.join(root, file))
|
||||||
|
if file_cntr >= max_allowed_files:
|
||||||
|
get_logger().warning(f"Found at least {max_allowed_files} files in {abs_docs_path}, skipping the rest.")
|
||||||
|
return matching_files
|
||||||
|
return matching_files
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Unexpected exception thrown. Returning empty list.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _gen_filenames_to_contents_map_from_repo(self) -> dict[str, str]:
|
||||||
|
try:
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
get_logger().debug(f"About to clone repository: {self.repo_url} to temporary directory: {tmp_dir}...")
|
get_logger().debug(f"About to clone repository: {self.repo_url} to temporary directory: {tmp_dir}...")
|
||||||
returned_cloned_repo_root = self.git_provider.clone(self.repo_url, tmp_dir, remove_dest_folder=False)
|
returned_cloned_repo_root = self.git_provider.clone(self.repo_url, tmp_dir, remove_dest_folder=False)
|
||||||
@ -268,103 +441,120 @@ class PRHelpDocs(object):
|
|||||||
ignore_readme=(self.docs_path=='.')))
|
ignore_readme=(self.docs_path=='.')))
|
||||||
if not doc_files:
|
if not doc_files:
|
||||||
get_logger().warning(f"No documentation files found matching file extensions: "
|
get_logger().warning(f"No documentation files found matching file extensions: "
|
||||||
f"{self.supported_doc_exts} under repo: {self.repo_url} path: {self.docs_path}")
|
f"{self.supported_doc_exts} under repo: {self.repo_url} "
|
||||||
return None
|
f"path: {self.docs_path}. Returning empty dict.")
|
||||||
|
return {}
|
||||||
|
|
||||||
get_logger().info(f'Answering a question inside context {self.ctx_url} for repo: {self.repo_url}'
|
get_logger().info(f'For context {self.ctx_url} and repo: {self.repo_url}'
|
||||||
f' using the following documentation files: ', artifacts={'doc_files': doc_files})
|
f' will be using the following documentation files: ',
|
||||||
|
artifacts={'doc_files': doc_files})
|
||||||
|
|
||||||
docs_prompt = aggregate_documentation_files_for_prompt_contents(returned_cloned_repo_root.path, doc_files)
|
return map_documentation_files_to_contents(returned_cloned_repo_root.path, doc_files)
|
||||||
if not docs_prompt:
|
except Exception as e:
|
||||||
get_logger().warning(f"Error reading one of the documentation files. Returning with no result...")
|
get_logger().exception(f"Unexpected exception thrown. Returning empty dict.")
|
||||||
return None
|
return {}
|
||||||
docs_prompt_to_send_to_model = docs_prompt
|
|
||||||
|
|
||||||
# Estimate how many tokens will be needed. Trim in case of exceeding limit.
|
def _trim_docs_input(self, docs_input: str, max_allowed_txt_input: int, only_return_if_trim_needed=False) -> bool|str:
|
||||||
# Firstly, check if text needs to be trimmed, as some models fail to return the estimated token count if the input text is too long.
|
try:
|
||||||
max_allowed_txt_input = get_maximal_text_input_length_for_token_count_estimation()
|
if len(docs_input) >= max_allowed_txt_input:
|
||||||
if len(docs_prompt_to_send_to_model) >= max_allowed_txt_input:
|
get_logger().warning(
|
||||||
get_logger().warning(f"Text length: {len(docs_prompt_to_send_to_model)} exceeds the current returned limit of {max_allowed_txt_input} just for token count estimation. Trimming the text...")
|
f"Text length: {len(docs_input)} exceeds the current returned limit of {max_allowed_txt_input} just for token count estimation. Trimming the text...")
|
||||||
docs_prompt_to_send_to_model = docs_prompt_to_send_to_model[:max_allowed_txt_input]
|
if only_return_if_trim_needed:
|
||||||
|
return True
|
||||||
|
docs_input = docs_input[:max_allowed_txt_input]
|
||||||
# Then, count the tokens in the prompt. If the count exceeds the limit, trim the text.
|
# Then, count the tokens in the prompt. If the count exceeds the limit, trim the text.
|
||||||
token_count = self.token_handler.count_tokens(docs_prompt_to_send_to_model, force_accurate=True)
|
token_count = self.token_handler.count_tokens(docs_input, force_accurate=True)
|
||||||
get_logger().debug(f"Estimated token count of documentation to send to model: {token_count}")
|
get_logger().debug(f"Estimated token count of documentation to send to model: {token_count}")
|
||||||
model = get_settings().config.model
|
model = get_settings().config.model
|
||||||
if model in MAX_TOKENS:
|
if model in MAX_TOKENS:
|
||||||
max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
|
max_tokens_full = MAX_TOKENS[
|
||||||
|
model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
|
||||||
else:
|
else:
|
||||||
max_tokens_full = get_max_tokens(model)
|
max_tokens_full = get_max_tokens(model)
|
||||||
delta_output = 5000 #Elbow room to reduce chance of exceeding token limit or model paying less attention to prompt guidelines.
|
delta_output = 5000 # Elbow room to reduce chance of exceeding token limit or model paying less attention to prompt guidelines.
|
||||||
if token_count > max_tokens_full - delta_output:
|
if token_count > max_tokens_full - delta_output:
|
||||||
docs_prompt_to_send_to_model = clean_markdown_content(docs_prompt_to_send_to_model) #Reduce unnecessary text/images/etc.
|
if only_return_if_trim_needed:
|
||||||
get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Attempting to clip text to fit within the limit...")
|
return True
|
||||||
docs_prompt_to_send_to_model = clip_tokens(docs_prompt_to_send_to_model, max_tokens_full - delta_output,
|
docs_input = clean_markdown_content(
|
||||||
|
docs_input) # Reduce unnecessary text/images/etc.
|
||||||
|
get_logger().info(
|
||||||
|
f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Attempting to clip text to fit within the limit...")
|
||||||
|
docs_input = clip_tokens(docs_input, max_tokens_full - delta_output,
|
||||||
num_input_tokens=token_count)
|
num_input_tokens=token_count)
|
||||||
self.vars['snippets'] = docs_prompt_to_send_to_model.strip()
|
if only_return_if_trim_needed:
|
||||||
|
return False
|
||||||
|
return docs_input
|
||||||
|
except Exception as e:
|
||||||
|
# Unexpected exception. Rethrowing it since:
|
||||||
|
# 1. This is an internal function.
|
||||||
|
# 2. An empty str/False result is a valid one - would require now checking also for None.
|
||||||
|
get_logger().exception(f"Unexpected exception thrown. Rethrowing it...")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def _rank_docs_and_return_them_as_prompt(self, docs_filepath_to_contents: dict[str, str], max_allowed_txt_input: int) -> str:
|
||||||
|
try:
|
||||||
|
#Return just file name and their headings (if exist):
|
||||||
|
docs_prompt_to_send_to_model = (
|
||||||
|
aggregate_documentation_files_for_prompt_contents(docs_filepath_to_contents,
|
||||||
|
return_just_headings=True))
|
||||||
|
# Verify list of headings does not exceed limits - trim it if it does.
|
||||||
|
docs_prompt_to_send_to_model = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input,
|
||||||
|
only_return_if_trim_needed=False)
|
||||||
|
if not docs_prompt_to_send_to_model:
|
||||||
|
get_logger().error("_trim_docs_input returned an empty result.")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
self.vars['snippets'] = docs_prompt_to_send_to_model.strip()
|
||||||
# Run the AI model and extract sections from its response
|
# Run the AI model and extract sections from its response
|
||||||
response = await retry_with_fallback_models(PredictionPreparator(self.ai_handler, self.vars,
|
response = await retry_with_fallback_models(PredictionPreparator(self.ai_handler, self.vars,
|
||||||
get_settings().pr_help_docs_prompts.system,
|
get_settings().pr_help_docs_headings_prompts.system,
|
||||||
get_settings().pr_help_docs_prompts.user),
|
get_settings().pr_help_docs_headings_prompts.user),
|
||||||
model_type=ModelType.REGULAR)
|
model_type=ModelType.REGULAR)
|
||||||
response_yaml = load_yaml(response)
|
response_yaml = load_yaml(response)
|
||||||
if not response_yaml:
|
if not response_yaml:
|
||||||
get_logger().exception("Failed to parse the AI response.", artifacts={'response': response})
|
get_logger().error("Failed to parse the AI response.", artifacts={'response': response})
|
||||||
raise Exception(f"Failed to parse the AI response.")
|
return ""
|
||||||
response_str = response_yaml.get('response')
|
# else: Sanitize the output so that the file names match 1:1 dictionary keys. Do this via the file index and not its name, which may be altered by the model.
|
||||||
relevant_sections = response_yaml.get('relevant_sections')
|
valid_indices = [int(entry['idx']) for entry in response_yaml.get('relevant_files_ranking')
|
||||||
if not response_str or not relevant_sections:
|
if int(entry['idx']) >= 0 and int(entry['idx']) < len(docs_filepath_to_contents)]
|
||||||
get_logger().exception("Failed to extract response/relevant sections.",
|
valid_file_paths = [list(docs_filepath_to_contents.keys())[idx] for idx in valid_indices]
|
||||||
artifacts={'response_str': response_str, 'relevant_sections': relevant_sections})
|
selected_docs_dict = {file_path: docs_filepath_to_contents[file_path] for file_path in valid_file_paths}
|
||||||
raise Exception(f"Failed to extract response/relevant sections.")
|
docs_prompt = aggregate_documentation_files_for_prompt_contents(selected_docs_dict)
|
||||||
|
docs_prompt_to_send_to_model = docs_prompt
|
||||||
|
|
||||||
# Format the response as markdown
|
# Check if the updated list of documents does not exceed limits and trim if it does:
|
||||||
canonical_url_prefix, canonical_url_suffix = self.git_provider.get_canonical_url_parts(repo_git_url=self.repo_url if self.repo_url_given_explicitly else None,
|
docs_prompt_to_send_to_model = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input,
|
||||||
desired_branch=self.repo_desired_branch)
|
only_return_if_trim_needed=False)
|
||||||
answer_str = format_markdown_q_and_a_response(self.question, response_str, relevant_sections, self.supported_doc_exts, canonical_url_prefix, canonical_url_suffix)
|
if not docs_prompt_to_send_to_model:
|
||||||
|
get_logger().error("_trim_docs_input returned an empty result.")
|
||||||
|
return ""
|
||||||
|
return docs_prompt_to_send_to_model
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _format_model_answer(self, response_str: str, relevant_sections: list[dict[str, str]]) -> str:
|
||||||
|
try:
|
||||||
|
canonical_url_prefix, canonical_url_suffix = (
|
||||||
|
self.git_provider.get_canonical_url_parts(repo_git_url=self.repo_url if self.repo_url_given_explicitly else None,
|
||||||
|
desired_branch=self.repo_desired_branch))
|
||||||
|
answer_str = format_markdown_q_and_a_response(self.question, response_str, relevant_sections,
|
||||||
|
self.supported_doc_exts, canonical_url_prefix, canonical_url_suffix)
|
||||||
if answer_str:
|
if answer_str:
|
||||||
#Remove the question phrase and replace with light bulb and a heading mentioning this is an automated answer:
|
#Remove the question phrase and replace with light bulb and a heading mentioning this is an automated answer:
|
||||||
answer_str = modify_answer_section(answer_str)
|
answer_str = modify_answer_section(answer_str)
|
||||||
# For PR help docs, we return the answer string instead of publishing it
|
#In case the response should not be published and returned as string, stop here:
|
||||||
if answer_str and self.return_as_string:
|
if answer_str and self.return_as_string:
|
||||||
if int(response_yaml.get('question_is_relevant', '1')) == 0:
|
|
||||||
get_logger().warning(f"Chat help docs answer would be ignored due to an invalid question.",
|
|
||||||
artifacts={'answer_str': answer_str})
|
|
||||||
return ""
|
|
||||||
get_logger().info(f"Chat help docs answer", artifacts={'answer_str': answer_str})
|
get_logger().info(f"Chat help docs answer", artifacts={'answer_str': answer_str})
|
||||||
return answer_str
|
return answer_str
|
||||||
|
if not answer_str:
|
||||||
# Publish the answer
|
|
||||||
if not answer_str or int(response_yaml.get('question_is_relevant', '1')) == 0:
|
|
||||||
get_logger().info(f"No answer found")
|
get_logger().info(f"No answer found")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_help_docs.enable_help_text:
|
if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_help_docs.enable_help_text:
|
||||||
answer_str += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
|
answer_str += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
|
||||||
answer_str += HelpMessage.get_help_docs_usage_guide()
|
answer_str += HelpMessage.get_help_docs_usage_guide()
|
||||||
answer_str += "\n</details>\n"
|
answer_str += "\n</details>\n"
|
||||||
|
return answer_str
|
||||||
if get_settings().config.publish_output:
|
except Exception as e:
|
||||||
self.git_provider.publish_comment(answer_str)
|
get_logger().exception(f"Unexpected exception thrown. Returning empty result.")
|
||||||
else:
|
return ""
|
||||||
get_logger().info("Answer:", artifacts={'answer_str': answer_str})
|
|
||||||
|
|
||||||
except:
|
|
||||||
get_logger().exception('failed to provide answer to given user question as a result of a thrown exception (see above)')
|
|
||||||
|
|
||||||
|
|
||||||
def _find_all_document_files_matching_exts(self, abs_docs_path: str, ignore_readme=False) -> List[str]:
|
|
||||||
matching_files = []
|
|
||||||
|
|
||||||
# Ensure extensions don't have leading dots and are lowercase
|
|
||||||
dotless_extensions = [ext.lower().lstrip('.') for ext in self.supported_doc_exts]
|
|
||||||
|
|
||||||
# Walk through directory and subdirectories
|
|
||||||
for root, _, files in os.walk(abs_docs_path):
|
|
||||||
for file in files:
|
|
||||||
if ignore_readme and root == abs_docs_path and file.lower() in [f"readme.{ext}" for ext in dotless_extensions]:
|
|
||||||
continue
|
|
||||||
# Check if file has one of the specified extensions
|
|
||||||
if any(file.lower().endswith(f'.{ext}') for ext in dotless_extensions):
|
|
||||||
matching_files.append(os.path.join(root, file))
|
|
||||||
return matching_files
|
|
||||||
|
Reference in New Issue
Block a user