From 14971c4f5fa23bae133e2b5e12ab33698d0aceeb Mon Sep 17 00:00:00 2001 From: sharoneyal Date: Thu, 3 Apr 2025 11:51:26 +0300 Subject: [PATCH] Add support for documentation content exceeding token limits (#1670) * - Add support for documentation content exceeding token limits via two phase operation: 1. Ask LLM to rank headings which are most likely to contain an answer to a user question 2. Provide the corresponding files for the LLM to search for an answer. - Refactor of help_docs to make the code more readable - For the purpose of getting canonical path: git providers to use default branch and not the PR's source branch. - Refactor of token counting and making it clear on when an estimate factor will be used. * Code review changes: 1. Correctly handle exception during retry_with_fallback_models (to allow fallback model to run in case of failure) 2. Better naming for default_branch in bitbucket cloud provider --- pr_agent/algo/token_handler.py | 32 +- pr_agent/config_loader.py | 1 + pr_agent/git_providers/bitbucket_provider.py | 12 +- .../bitbucket_server_provider.py | 8 +- pr_agent/git_providers/github_provider.py | 2 +- pr_agent/git_providers/gitlab_provider.py | 6 +- pr_agent/settings/configuration.toml | 2 +- .../pr_help_docs_headings_prompts.toml | 101 ++++ pr_agent/tools/pr_help_docs.py | 528 ++++++++++++------ 9 files changed, 505 insertions(+), 187 deletions(-) create mode 100644 pr_agent/settings/pr_help_docs_headings_prompts.toml diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index 9bc801ed..f1393e38 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -1,7 +1,6 @@ from threading import Lock from jinja2 import Environment, StrictUndefined -from math import ceil from tiktoken import encoding_for_model, get_encoding from pr_agent.config_loader import get_settings @@ -105,6 +104,19 @@ class TokenHandler: get_logger().error( f"Error in Anthropic token counting: {e}") return MaxTokens + def estimate_token_count_for_non_anth_claude_models(self, model, default_encoder_estimate): + from math import ceil + import re + + model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model) + if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'): + return default_encoder_estimate + #else: Model is not an OpenAI one - therefore, cannot provide an accurate token count and instead, return a higher number as best effort. + + elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0) + get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as best effort estimate") + return ceil(elbow_factor * default_encoder_estimate) + def count_tokens(self, patch: str, force_accurate=False) -> int: """ Counts the number of tokens in a given patch string. @@ -116,21 +128,15 @@ class TokenHandler: The number of tokens in the patch string. """ encoder_estimate = len(self.encoder.encode(patch, disallowed_special=())) + + #If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it. 
if not force_accurate: return encoder_estimate - #else, need to provide an accurate estimation: + #else, force_accurate==True: User requested providing an accurate estimation: model = get_settings().config.model.lower() - if force_accurate and 'claude' in model and get_settings(use_context=False).get('anthropic.key'): + if 'claude' in model and get_settings(use_context=False).get('anthropic.key'): return self.calc_claude_tokens(patch) # API call to Anthropic for accurate token counting for Claude models - #else: Non Anthropic provided model - import re - model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model) - if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'): - return encoder_estimate - #else: Model is neither an OpenAI, nor an Anthropic model - therefore, cannot provide an accurate token count and instead, return a higher number as best effort. - - elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0) - get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as best effort estimate") - return ceil(elbow_factor * encoder_estimate) + #else: Non Anthropic provided model: + return self.estimate_token_count_for_non_anth_claude_models(model, encoder_estimate) diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py index 575c02a3..7a62adec 100644 --- a/pr_agent/config_loader.py +++ b/pr_agent/config_loader.py @@ -29,6 +29,7 @@ global_settings = Dynaconf( "settings/custom_labels.toml", "settings/pr_help_prompts.toml", "settings/pr_help_docs_prompts.toml", + "settings/pr_help_docs_headings_prompts.toml", "settings/.secrets.toml", "settings_prod/.secrets.toml", ]] diff --git a/pr_agent/git_providers/bitbucket_provider.py b/pr_agent/git_providers/bitbucket_provider.py index d3882fda..969ddf9e 100644 --- a/pr_agent/git_providers/bitbucket_provider.py +++ b/pr_agent/git_providers/bitbucket_provider.py @@ -92,7 +92,7 @@ class BitbucketProvider(GitProvider): return ("", "") workspace_name, project_name = repo_path.split('/') else: - desired_branch = self.get_pr_branch() + desired_branch = self.get_repo_default_branch() parsed_pr_url = urlparse(self.pr_url) scheme_and_netloc = parsed_pr_url.scheme + "://" + parsed_pr_url.netloc workspace_name, project_name = (self.workspace_slug, self.repo_slug) @@ -470,6 +470,16 @@ class BitbucketProvider(GitProvider): def get_pr_branch(self): return self.pr.source_branch + # This function attempts to get the default branch of the repository. As a fallback, uses the PR destination branch. + # Note: Must be running from a PR context. 
+ def get_repo_default_branch(self): + try: + url_repo = f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/" + response_repo = requests.request("GET", url_repo, headers=self.headers).json() + return response_repo['mainbranch']['name'] + except: + return self.pr.destination_branch + def get_pr_owner_id(self) -> str | None: return self.workspace_slug diff --git a/pr_agent/git_providers/bitbucket_server_provider.py b/pr_agent/git_providers/bitbucket_server_provider.py index ddbb60cc..e10d0319 100644 --- a/pr_agent/git_providers/bitbucket_server_provider.py +++ b/pr_agent/git_providers/bitbucket_server_provider.py @@ -64,9 +64,15 @@ class BitbucketServerProvider(GitProvider): workspace_name = None project_name = None if not repo_git_url: - desired_branch = self.get_pr_branch() workspace_name = self.workspace_slug project_name = self.repo_slug + default_branch_dict = self.bitbucket_client.get_default_branch(workspace_name, project_name) + if 'displayId' in default_branch_dict: + desired_branch = default_branch_dict['displayId'] + else: + get_logger().error(f"Cannot obtain default branch for workspace_name={workspace_name}, " + f"project_name={project_name}, default_branch_dict={default_branch_dict}") + return ("", "") elif '.git' in repo_git_url and 'scm/' in repo_git_url: repo_path = repo_git_url.split('.git')[0].split('scm/')[-1] if repo_path.count('/') == 1: # Has to have the form / diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 92e256fd..e782f9cf 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -133,7 +133,7 @@ class GithubProvider(GitProvider): if (not owner or not repo) and self.repo: #"else" - User did not provide an external git url, or not an issue, use self.repo object owner, repo = self.repo.split('/') scheme_and_netloc = self.base_url_html - desired_branch = self.get_pr_branch() + desired_branch = self.repo_obj.default_branch if not all([scheme_and_netloc, owner, repo]): #"else": Not invoked from a PR context,but no provided git url for context get_logger().error(f"Unable to get canonical url parts since missing context (PR or explicit git url)") return ("", "") diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 590aa32e..df18c957 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -87,7 +87,11 @@ class GitLabProvider(GitProvider): return ("", "") if not repo_git_url: #Use PR url as context repo_path = self._get_project_path_from_pr_or_issue_url(self.pr_url) - desired_branch = self.get_pr_branch() + try: + desired_branch = self.gl.projects.get(self.id_project).default_branch + except Exception as e: + get_logger().exception(f"Cannot get PR: {self.pr_url} default branch. 
Tried project ID: {self.id_project}") + return ("", "") else: #Use repo git url repo_path = repo_git_url.split('.git')[0].split('.com/')[-1] prefix = f"{self.gitlab_url}/{repo_path}/-/blob/{desired_branch}" diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index d5ddc95c..952dedd8 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -9,7 +9,6 @@ model="o3-mini" fallback_models=["gpt-4o-2024-11-20"] #model_weak="gpt-4o-mini-2024-07-18" # optional, a weaker model to use for some easier tasks -model_token_count_estimate_factor=0.3 # factor to increase the token count estimate, in order to reduce likelihood of model failure due to too many tokens. # CLI git_provider="github" publish_output=true @@ -30,6 +29,7 @@ max_description_tokens = 500 max_commits_tokens = 500 max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities. custom_model_max_tokens=-1 # for models not in the default list +model_token_count_estimate_factor=0.3 # factor to increase the token count estimate, in order to reduce likelihood of model failure due to too many tokens - applicable only when requesting an accurate estimate. # patch extension logic patch_extension_skip_types =[".md",".txt"] allow_dynamic_context=true diff --git a/pr_agent/settings/pr_help_docs_headings_prompts.toml b/pr_agent/settings/pr_help_docs_headings_prompts.toml new file mode 100644 index 00000000..da9d6e53 --- /dev/null +++ b/pr_agent/settings/pr_help_docs_headings_prompts.toml @@ -0,0 +1,101 @@ + +[pr_help_docs_headings_prompts] +system="""You are Doc-helper, a language model that ranks documentation files based on their relevance to user questions. +You will receive a question, a repository url and file names along with optional groups of headings extracted from such files from that repository (either as markdown or as restructred text). +Your task is to rank file paths based on how likely they contain the answer to a user's question, using only the headings from each such file and the file name. + +====== +==file name== + +'src/file1.py' + +==index== + +0 based integer + +==file headings== +heading #1 +heading #2 +... + +==file name== + +'src/file2.py' + +==index== + +0 based integer + +==file headings== +heading #1 +heading #2 +... + +... +====== + +Additional instructions: +- Consider only the file names and section headings within each document +- Present the most relevant files first, based strictly on how well their headings and file names align with user question + +The output must be a YAML object equivalent to type $DocHeadingsHelper, according to the following Pydantic definitions: +===== +class file_idx_and_path(BaseModel): + idx: int = Field(description="The zero based index of file_name, as it appeared in the original list of headings. Cannot be negative.") + file_name: str = Field(description="The file_name exactly as it appeared in the question") + +class DocHeadingsHelper(BaseModel): + user_question: str = Field(description="The user's question") + relevant_files_ranking: List[file_idx_and_path] = Field(description="Files sorted in descending order by relevance to question") +===== + + +Example output: +```yaml +user_question: | + ... +relevant_files_ranking: +- idx: 101 + file_name: "src/file1.py" +- ... 
+""" + +user="""\ +Documentation url: '{{ docs_url|trim }}' +----- + + +User's Question: +===== +{{ question|trim }} +===== + + +Filenames with optional headings from documentation website content: +===== +{{ snippets|trim }} +===== + + +Reminder: The output must be a YAML object equivalent to type $DocHeadingsHelper, similar to the following example output: +===== + + +Example output: +```yaml +user_question: | + ... +relevant_files_ranking: +- idx: 101 + file_name: "src/file1.py" +- ... +===== + +Important Notes: +1. Output most relevant file names first, by descending order of relevancy. +2. Only include files with non-negative indices + + +Response (should be a valid YAML, and nothing else). +```yaml +""" diff --git a/pr_agent/tools/pr_help_docs.py b/pr_agent/tools/pr_help_docs.py index ddd42509..89849aa7 100644 --- a/pr_agent/tools/pr_help_docs.py +++ b/pr_agent/tools/pr_help_docs.py @@ -1,11 +1,11 @@ import copy from functools import partial + from jinja2 import Environment, StrictUndefined import math import os import re from tempfile import TemporaryDirectory -from typing import Dict, List, Optional, Tuple from pr_agent.algo import MAX_TOKENS from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler @@ -78,47 +78,118 @@ def get_maximal_text_input_length_for_token_count_estimation(): return 900000 #Claude API for token estimation allows maximal text input of 900K chars return math.inf #Otherwise, no known limitation on input text just for token estimation -# Load documentation files to memory, decorating them with a header to mark where each file begins, -# as to help the LLM to give a better answer. -def aggregate_documentation_files_for_prompt_contents(base_path: str, doc_files: List[str]) -> Optional[str]: - docs_prompt = "" - for file in doc_files: - try: - with open(file, 'r', encoding='utf-8') as f: - content = f.read() - # Skip files with no text content - if not re.search(r'[a-zA-Z]', content): - continue - file_path = str(file).replace(str(base_path), '') - docs_prompt += f"\n==file name==\n\n{file_path}\n\n==file content==\n\n{content.strip()}\n=========\n\n" - except Exception as e: - get_logger().warning(f"Error while reading the file {file}: {e}") - continue - if not docs_prompt: - get_logger().error("Couldn't find any usable documentation files. 
Returning None.") - return None - return docs_prompt +def return_document_headings(text: str, ext: str) -> str: + try: + lines = text.split('\n') + headings = set() -def format_markdown_q_and_a_response(question_str: str, response_str: str, relevant_sections: List[Dict[str, str]], - supported_suffixes: List[str], base_url_prefix: str, base_url_suffix: str="") -> str: - base_url_prefix = base_url_prefix.strip('/') #Sanitize base_url_prefix - answer_str = "" - answer_str += f"### Question: \n{question_str}\n\n" - answer_str += f"### Answer:\n{response_str.strip()}\n\n" - answer_str += f"#### Relevant Sources:\n\n" - for section in relevant_sections: - file = section.get('file_name').strip() - ext = [suffix for suffix in supported_suffixes if file.endswith(suffix)] - if not ext: - get_logger().warning(f"Unsupported file extension: {file}") - continue - if str(section['relevant_section_header_string']).strip(): - markdown_header = format_markdown_header(section['relevant_section_header_string']) - if base_url_prefix: - answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}#{markdown_header}\n" + if not text or not re.search(r'[a-zA-Z]', text): + get_logger().error(f"Empty or non text content found in text: {text}.") + return "" + + if ext in ['.md', '.mdx']: + # Extract Markdown headings (lines starting with #) + headings = {line.strip() for line in lines if line.strip().startswith('#')} + elif ext == '.rst': + # Find indices of lines that have all same character: + #Allowed characters according to list from: https://docutils.sourceforge.io/docs/ref/rst/restructuredtext.html#sections + section_chars = set('!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~') + + # Find potential section marker lines (underlines/overlines): They have to be the same character + marker_lines = [] + for i, line in enumerate(lines): + line = line.rstrip() + if line and all(c == line[0] for c in line) and line[0] in section_chars: + marker_lines.append((i, len(line))) + + # Check for headings adjacent to marker lines (below + text must be in length equal or less) + for idx, length in marker_lines: + # Check if it's an underline (heading is above it) + if idx > 0 and lines[idx - 1].rstrip() and len(lines[idx - 1].rstrip()) <= length: + headings.add(lines[idx - 1].rstrip()) else: - answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}\n" - return answer_str + get_logger().error(f"Unsupported file extension: {ext}") + return "" + + return '\n'.join(headings) + except Exception as e: + get_logger().exception(f"Unexpected exception thrown. Returning empty result.") + return "" + +# Load documentation files to memory: full file path (as will be given as prompt) -> doc contents +def map_documentation_files_to_contents(base_path: str, doc_files: list[str], max_allowed_file_len=5000) -> dict[str, str]: + try: + returned_dict = {} + for file in doc_files: + try: + with open(file, 'r', encoding='utf-8') as f: + content = f.read() + # Skip files with no text content + if not re.search(r'[a-zA-Z]', content): + continue + if len(content) > max_allowed_file_len: + get_logger().warning(f"File {file} length: {len(content)} exceeds limit: {max_allowed_file_len}, so it will be trimmed.") + content = content[:max_allowed_file_len] + file_path = str(file).replace(str(base_path), '') + returned_dict[file_path] = content.strip() + except Exception as e: + get_logger().warning(f"Error while reading the file {file}: {e}") + continue + if not returned_dict: + get_logger().error("Couldn't find any usable documentation files. 
Returning empty dict.") + return returned_dict + except Exception as e: + get_logger().exception(f"Unexpected exception thrown. Returning empty dict.") + return {} + +# Goes over files' contents, generating payload for prompt while decorating them with a header to mark where each file begins, +# as to help the LLM to give a better answer. +def aggregate_documentation_files_for_prompt_contents(file_path_to_contents: dict[str, str], return_just_headings=False) -> str: + try: + docs_prompt = "" + for idx, file_path in enumerate(file_path_to_contents): + file_contents = file_path_to_contents[file_path].strip() + if not file_contents: + get_logger().error(f"Got empty file contents for: {file_path}. Skipping this file.") + continue + if return_just_headings: + file_headings = return_document_headings(file_contents, os.path.splitext(file_path)[-1]).strip() + if file_headings: + docs_prompt += f"\n==file name==\n\n{file_path}\n\n==index==\n\n{idx}\n\n==file headings==\n\n{file_headings}\n=========\n\n" + else: + get_logger().warning(f"No headers for: {file_path}. Will only use filename") + docs_prompt += f"\n==file name==\n\n{file_path}\n\n==index==\n\n{idx}\n\n" + else: + docs_prompt += f"\n==file name==\n\n{file_path}\n\n==file content==\n\n{file_contents}\n=========\n\n" + return docs_prompt + except Exception as e: + get_logger().exception(f"Unexpected exception thrown. Returning empty result.") + return "" + +def format_markdown_q_and_a_response(question_str: str, response_str: str, relevant_sections: list[dict[str, str]], + supported_suffixes: list[str], base_url_prefix: str, base_url_suffix: str="") -> str: + try: + base_url_prefix = base_url_prefix.strip('/') #Sanitize base_url_prefix + answer_str = "" + answer_str += f"### Question: \n{question_str}\n\n" + answer_str += f"### Answer:\n{response_str.strip()}\n\n" + answer_str += f"#### Relevant Sources:\n\n" + for section in relevant_sections: + file = section.get('file_name').lstrip('/').strip() #Remove any '/' in the beginning, since some models do it anyway + ext = [suffix for suffix in supported_suffixes if file.endswith(suffix)] + if not ext: + get_logger().warning(f"Unsupported file extension: {file}") + continue + if str(section['relevant_section_header_string']).strip(): + markdown_header = format_markdown_header(section['relevant_section_header_string']) + if base_url_prefix: + answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}#{markdown_header}\n" + else: + answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}\n" + return answer_str + except Exception as e: + get_logger().exception(f"Unexpected exception thrown. 
Returning empty result.") + return "" def format_markdown_header(header: str) -> str: try: @@ -157,87 +228,103 @@ def clean_markdown_content(content: str) -> str: Returns: Cleaned markdown content """ - # Remove HTML comments - content = re.sub(r'', '', content, flags=re.DOTALL) + try: + # Remove HTML comments + content = re.sub(r'', '', content, flags=re.DOTALL) - # Remove frontmatter (YAML between --- or +++ delimiters) - content = re.sub(r'^---\s*\n.*?\n---\s*\n', '', content, flags=re.DOTALL) - content = re.sub(r'^\+\+\+\s*\n.*?\n\+\+\+\s*\n', '', content, flags=re.DOTALL) + # Remove frontmatter (YAML between --- or +++ delimiters) + content = re.sub(r'^---\s*\n.*?\n---\s*\n', '', content, flags=re.DOTALL) + content = re.sub(r'^\+\+\+\s*\n.*?\n\+\+\+\s*\n', '', content, flags=re.DOTALL) - # Remove excessive blank lines (more than 2 consecutive) - content = re.sub(r'\n{3,}', '\n\n', content) + # Remove excessive blank lines (more than 2 consecutive) + content = re.sub(r'\n{3,}', '\n\n', content) - # Remove HTML tags that are often used for styling only - content = re.sub(r'|||', '', content, flags=re.DOTALL) + # Remove HTML tags that are often used for styling only + content = re.sub(r'|||', '', content, flags=re.DOTALL) - # Remove image alt text which can be verbose - content = re.sub(r'!\[(.*?)\]', '![]', content) + # Remove image alt text which can be verbose + content = re.sub(r'!\[(.*?)\]', '![]', content) - # Remove images completely - content = re.sub(r'!\[.*?\]\(.*?\)', '', content) + # Remove images completely + content = re.sub(r'!\[.*?\]\(.*?\)', '', content) - # Remove simple HTML tags but preserve content between them - content = re.sub(r'<(?!table|tr|td|th|thead|tbody)([a-zA-Z][a-zA-Z0-9]*)[^>]*>(.*?)', - r'\2', content, flags=re.DOTALL) - return content.strip() + # Remove simple HTML tags but preserve content between them + content = re.sub(r'<(?!table|tr|td|th|thead|tbody)([a-zA-Z][a-zA-Z0-9]*)[^>]*>(.*?)', + r'\2', content, flags=re.DOTALL) + return content.strip() + except Exception as e: + get_logger().exception(f"Unexpected exception thrown. Returning empty result.") + return "" class PredictionPreparator: def __init__(self, ai_handler, vars, system_prompt, user_prompt): - self.ai_handler = ai_handler - variables = copy.deepcopy(vars) - environment = Environment(undefined=StrictUndefined) - self.system_prompt = environment.from_string(system_prompt).render(variables) - self.user_prompt = environment.from_string(user_prompt).render(variables) + try: + self.ai_handler = ai_handler + variables = copy.deepcopy(vars) + environment = Environment(undefined=StrictUndefined) + self.system_prompt = environment.from_string(system_prompt).render(variables) + self.user_prompt = environment.from_string(user_prompt).render(variables) + except Exception as e: + get_logger().exception(f"Caught exception during init. Setting ai_handler to None to prevent __call__.") + self.ai_handler = None + #Called by retry_with_fallback_models and therefore, on any failure must throw an exception: async def __call__(self, model: str) -> str: + if not self.ai_handler: + get_logger().error("ai handler not set. 
Cannot invoke model!") + raise ValueError("PredictionPreparator not initialized") try: response, finish_reason = await self.ai_handler.chat_completion( model=model, temperature=get_settings().config.temperature, system=self.system_prompt, user=self.user_prompt) return response except Exception as e: - get_logger().error(f"Error while preparing prediction: {e}") - return "" + get_logger().exception("Caught exception during prediction.", artifacts={'system': self.system_prompt, 'user': self.user_prompt}) + raise e class PRHelpDocs(object): - def __init__(self, ctx_url, ai_handler:partial[BaseAiHandler,] = LiteLLMAIHandler, args: Tuple[str]=None, return_as_string: bool=False): - self.ctx_url = ctx_url - self.question = args[0] if args else None - self.return_as_string = return_as_string - self.repo_url_given_explicitly = True - self.repo_url = get_settings().get('PR_HELP_DOCS.REPO_URL', '') - self.repo_desired_branch = get_settings().get('PR_HELP_DOCS.REPO_DEFAULT_BRANCH', 'main') #Ignored if self.repo_url is empty - self.include_root_readme_file = not(get_settings()['PR_HELP_DOCS.EXCLUDE_ROOT_README']) - self.supported_doc_exts = get_settings()['PR_HELP_DOCS.SUPPORTED_DOC_EXTS'] - self.docs_path = get_settings()['PR_HELP_DOCS.DOCS_PATH'] + def __init__(self, ctx_url, ai_handler:partial[BaseAiHandler,] = LiteLLMAIHandler, args: tuple[str]=None, return_as_string: bool=False): + try: + self.ctx_url = ctx_url + self.question = args[0] if args else None + self.return_as_string = return_as_string + self.repo_url_given_explicitly = True + self.repo_url = get_settings().get('PR_HELP_DOCS.REPO_URL', '') + self.repo_desired_branch = get_settings().get('PR_HELP_DOCS.REPO_DEFAULT_BRANCH', 'main') #Ignored if self.repo_url is empty + self.include_root_readme_file = not(get_settings()['PR_HELP_DOCS.EXCLUDE_ROOT_README']) + self.supported_doc_exts = get_settings()['PR_HELP_DOCS.SUPPORTED_DOC_EXTS'] + self.docs_path = get_settings()['PR_HELP_DOCS.DOCS_PATH'] - retrieved_settings = [self.include_root_readme_file, self.supported_doc_exts, self.docs_path] - if any([setting is None for setting in retrieved_settings]): - raise Exception(f"One of the settings is invalid: {retrieved_settings}") + retrieved_settings = [self.include_root_readme_file, self.supported_doc_exts, self.docs_path] + if any([setting is None for setting in retrieved_settings]): + raise Exception(f"One of the settings is invalid: {retrieved_settings}") - self.git_provider = get_git_provider_with_context(ctx_url) - if not self.git_provider: - raise Exception(f"No git provider found at {ctx_url}") - if not self.repo_url: - self.repo_url_given_explicitly = False - get_logger().debug(f"No explicit repo url provided, deducing it from type: {self.git_provider.__class__.__name__} " - f"context url: {self.ctx_url}") - self.repo_url = self.git_provider.get_git_repo_url(self.ctx_url) + self.git_provider = get_git_provider_with_context(ctx_url) + if not self.git_provider: + raise Exception(f"No git provider found at {ctx_url}") if not self.repo_url: - raise Exception(f"Unable to deduce repo url from type: {self.git_provider.__class__.__name__} url: {self.ctx_url}") - get_logger().debug(f"deduced repo url: {self.repo_url}") - self.repo_desired_branch = None #Inferred from the repo provider. 
+ self.repo_url_given_explicitly = False + get_logger().debug(f"No explicit repo url provided, deducing it from type: {self.git_provider.__class__.__name__} " + f"context url: {self.ctx_url}") + self.repo_url = self.git_provider.get_git_repo_url(self.ctx_url) + if not self.repo_url: + raise Exception(f"Unable to deduce repo url from type: {self.git_provider.__class__.__name__} url: {self.ctx_url}") + get_logger().debug(f"deduced repo url: {self.repo_url}") + self.repo_desired_branch = None #Inferred from the repo provider. - self.ai_handler = ai_handler() - self.vars = { - "docs_url": self.repo_url, - "question": self.question, - "snippets": "", - } - self.token_handler = TokenHandler(None, - self.vars, - get_settings().pr_help_docs_prompts.system, - get_settings().pr_help_docs_prompts.user) + self.ai_handler = ai_handler() + self.vars = { + "docs_url": self.repo_url, + "question": self.question, + "snippets": "", + } + self.token_handler = TokenHandler(None, + self.vars, + get_settings().pr_help_docs_prompts.system, + get_settings().pr_help_docs_prompts.user) + except Exception as e: + get_logger().exception(f"Caught exception during init. Setting self.question to None to prevent run() to do anything.") + self.question = None async def run(self): if not self.question: @@ -246,7 +333,93 @@ class PRHelpDocs(object): try: # Clone the repository and gather relevant documentation files. - docs_prompt = None + docs_filepath_to_contents = self._gen_filenames_to_contents_map_from_repo() + + #Generate prompt for the AI model. This will be the full text of all the documentation files combined. + docs_prompt = aggregate_documentation_files_for_prompt_contents(docs_filepath_to_contents) + if not docs_filepath_to_contents or not docs_prompt: + get_logger().warning(f"Could not find any usable documentation. Returning with no result...") + return None + docs_prompt_to_send_to_model = docs_prompt + + # Estimate how many tokens will be needed. + # In case the expected number of tokens exceeds LLM limits, retry with just headings, asking the LLM to rank according to relevance to the question. + # Based on returned ranking, rerun but sort the documents accordingly, this time, trim in case of exceeding limit. + + #First, check if the text is not too long to even query the LLM provider: + max_allowed_txt_input = get_maximal_text_input_length_for_token_count_estimation() + invoke_llm_just_with_headings = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input, + only_return_if_trim_needed=True) + if invoke_llm_just_with_headings: + #Entire docs is too long. Rank and return according to relevance. + docs_prompt_to_send_to_model = await self._rank_docs_and_return_them_as_prompt(docs_filepath_to_contents, + max_allowed_txt_input) + + if not docs_prompt_to_send_to_model: + get_logger().error("Failed to generate docs prompt for model. Returning with no result...") + return + # At this point, either all original documents be used (if their total length doesn't exceed limits), or only those selected. 
+ self.vars['snippets'] = docs_prompt_to_send_to_model.strip() + # Run the AI model and extract sections from its response + response = await retry_with_fallback_models(PredictionPreparator(self.ai_handler, self.vars, + get_settings().pr_help_docs_prompts.system, + get_settings().pr_help_docs_prompts.user), + model_type=ModelType.REGULAR) + response_yaml = load_yaml(response) + if not response_yaml: + get_logger().error("Failed to parse the AI response.", artifacts={'response': response}) + return + response_str = response_yaml.get('response') + relevant_sections = response_yaml.get('relevant_sections') + if not response_str or not relevant_sections: + get_logger().error("Failed to extract response/relevant sections.", + artifacts={'raw_response': response, 'response_str': response_str, 'relevant_sections': relevant_sections}) + return + if int(response_yaml.get('question_is_relevant', '1')) == 0: + get_logger().warning(f"Question is not relevant. Returning without an answer...", + artifacts={'raw_response': response}) + return + + # Format the response as markdown + answer_str = self._format_model_answer(response_str, relevant_sections) + if self.return_as_string: #Skip publishing + return answer_str + #Otherwise, publish the answer if answer is non empty and publish is not turned off: + if answer_str and get_settings().config.publish_output: + self.git_provider.publish_comment(answer_str) + else: + get_logger().info("Answer:", artifacts={'answer_str': answer_str}) + return answer_str + except Exception as e: + get_logger().exception('failed to provide answer to given user question as a result of a thrown exception (see above)') + + def _find_all_document_files_matching_exts(self, abs_docs_path: str, ignore_readme=False, max_allowed_files=5000) -> list[str]: + try: + matching_files = [] + + # Ensure extensions don't have leading dots and are lowercase + dotless_extensions = [ext.lower().lstrip('.') for ext in self.supported_doc_exts] + + # Walk through directory and subdirectories + file_cntr = 0 + for root, _, files in os.walk(abs_docs_path): + for file in files: + if ignore_readme and root == abs_docs_path and file.lower() in [f"readme.{ext}" for ext in dotless_extensions]: + continue + # Check if file has one of the specified extensions + if any(file.lower().endswith(f'.{ext}') for ext in dotless_extensions): + file_cntr+=1 + matching_files.append(os.path.join(root, file)) + if file_cntr >= max_allowed_files: + get_logger().warning(f"Found at least {max_allowed_files} files in {abs_docs_path}, skipping the rest.") + return matching_files + return matching_files + except Exception as e: + get_logger().exception(f"Unexpected exception thrown. Returning empty list.") + return [] + + def _gen_filenames_to_contents_map_from_repo(self) -> dict[str, str]: + try: with TemporaryDirectory() as tmp_dir: get_logger().debug(f"About to clone repository: {self.repo_url} to temporary directory: {tmp_dir}...") returned_cloned_repo_root = self.git_provider.clone(self.repo_url, tmp_dir, remove_dest_folder=False) @@ -268,103 +441,120 @@ class PRHelpDocs(object): ignore_readme=(self.docs_path=='.'))) if not doc_files: get_logger().warning(f"No documentation files found matching file extensions: " - f"{self.supported_doc_exts} under repo: {self.repo_url} path: {self.docs_path}") - return None + f"{self.supported_doc_exts} under repo: {self.repo_url} " + f"path: {self.docs_path}. 
Returning empty dict.") + return {} - get_logger().info(f'Answering a question inside context {self.ctx_url} for repo: {self.repo_url}' - f' using the following documentation files: ', artifacts={'doc_files': doc_files}) + get_logger().info(f'For context {self.ctx_url} and repo: {self.repo_url}' + f' will be using the following documentation files: ', + artifacts={'doc_files': doc_files}) - docs_prompt = aggregate_documentation_files_for_prompt_contents(returned_cloned_repo_root.path, doc_files) - if not docs_prompt: - get_logger().warning(f"Error reading one of the documentation files. Returning with no result...") - return None - docs_prompt_to_send_to_model = docs_prompt + return map_documentation_files_to_contents(returned_cloned_repo_root.path, doc_files) + except Exception as e: + get_logger().exception(f"Unexpected exception thrown. Returning empty dict.") + return {} - # Estimate how many tokens will be needed. Trim in case of exceeding limit. - # Firstly, check if text needs to be trimmed, as some models fail to return the estimated token count if the input text is too long. - max_allowed_txt_input = get_maximal_text_input_length_for_token_count_estimation() - if len(docs_prompt_to_send_to_model) >= max_allowed_txt_input: - get_logger().warning(f"Text length: {len(docs_prompt_to_send_to_model)} exceeds the current returned limit of {max_allowed_txt_input} just for token count estimation. Trimming the text...") - docs_prompt_to_send_to_model = docs_prompt_to_send_to_model[:max_allowed_txt_input] + def _trim_docs_input(self, docs_input: str, max_allowed_txt_input: int, only_return_if_trim_needed=False) -> bool|str: + try: + if len(docs_input) >= max_allowed_txt_input: + get_logger().warning( + f"Text length: {len(docs_input)} exceeds the current returned limit of {max_allowed_txt_input} just for token count estimation. Trimming the text...") + if only_return_if_trim_needed: + return True + docs_input = docs_input[:max_allowed_txt_input] # Then, count the tokens in the prompt. If the count exceeds the limit, trim the text. - token_count = self.token_handler.count_tokens(docs_prompt_to_send_to_model, force_accurate=True) + token_count = self.token_handler.count_tokens(docs_input, force_accurate=True) get_logger().debug(f"Estimated token count of documentation to send to model: {token_count}") model = get_settings().config.model if model in MAX_TOKENS: - max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt + max_tokens_full = MAX_TOKENS[ + model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt else: max_tokens_full = get_max_tokens(model) - delta_output = 5000 #Elbow room to reduce chance of exceeding token limit or model paying less attention to prompt guidelines. + delta_output = 5000 # Elbow room to reduce chance of exceeding token limit or model paying less attention to prompt guidelines. if token_count > max_tokens_full - delta_output: - docs_prompt_to_send_to_model = clean_markdown_content(docs_prompt_to_send_to_model) #Reduce unnecessary text/images/etc. - get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. 
Attempting to clip text to fit within the limit...") - docs_prompt_to_send_to_model = clip_tokens(docs_prompt_to_send_to_model, max_tokens_full - delta_output, + if only_return_if_trim_needed: + return True + docs_input = clean_markdown_content( + docs_input) # Reduce unnecessary text/images/etc. + get_logger().info( + f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Attempting to clip text to fit within the limit...") + docs_input = clip_tokens(docs_input, max_tokens_full - delta_output, num_input_tokens=token_count) - self.vars['snippets'] = docs_prompt_to_send_to_model.strip() + if only_return_if_trim_needed: + return False + return docs_input + except Exception as e: + # Unexpected exception. Rethrowing it since: + # 1. This is an internal function. + # 2. An empty str/False result is a valid one - would require now checking also for None. + get_logger().exception(f"Unexpected exception thrown. Rethrowing it...") + raise e + async def _rank_docs_and_return_them_as_prompt(self, docs_filepath_to_contents: dict[str, str], max_allowed_txt_input: int) -> str: + try: + #Return just file name and their headings (if exist): + docs_prompt_to_send_to_model = ( + aggregate_documentation_files_for_prompt_contents(docs_filepath_to_contents, + return_just_headings=True)) + # Verify list of headings does not exceed limits - trim it if it does. + docs_prompt_to_send_to_model = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input, + only_return_if_trim_needed=False) + if not docs_prompt_to_send_to_model: + get_logger().error("_trim_docs_input returned an empty result.") + return "" + + self.vars['snippets'] = docs_prompt_to_send_to_model.strip() # Run the AI model and extract sections from its response response = await retry_with_fallback_models(PredictionPreparator(self.ai_handler, self.vars, - get_settings().pr_help_docs_prompts.system, - get_settings().pr_help_docs_prompts.user), + get_settings().pr_help_docs_headings_prompts.system, + get_settings().pr_help_docs_headings_prompts.user), model_type=ModelType.REGULAR) response_yaml = load_yaml(response) if not response_yaml: - get_logger().exception("Failed to parse the AI response.", artifacts={'response': response}) - raise Exception(f"Failed to parse the AI response.") - response_str = response_yaml.get('response') - relevant_sections = response_yaml.get('relevant_sections') - if not response_str or not relevant_sections: - get_logger().exception("Failed to extract response/relevant sections.", - artifacts={'response_str': response_str, 'relevant_sections': relevant_sections}) - raise Exception(f"Failed to extract response/relevant sections.") + get_logger().error("Failed to parse the AI response.", artifacts={'response': response}) + return "" + # else: Sanitize the output so that the file names match 1:1 dictionary keys. Do this via the file index and not its name, which may be altered by the model. 
+ valid_indices = [int(entry['idx']) for entry in response_yaml.get('relevant_files_ranking') + if int(entry['idx']) >= 0 and int(entry['idx']) < len(docs_filepath_to_contents)] + valid_file_paths = [list(docs_filepath_to_contents.keys())[idx] for idx in valid_indices] + selected_docs_dict = {file_path: docs_filepath_to_contents[file_path] for file_path in valid_file_paths} + docs_prompt = aggregate_documentation_files_for_prompt_contents(selected_docs_dict) + docs_prompt_to_send_to_model = docs_prompt - # Format the response as markdown - canonical_url_prefix, canonical_url_suffix = self.git_provider.get_canonical_url_parts(repo_git_url=self.repo_url if self.repo_url_given_explicitly else None, - desired_branch=self.repo_desired_branch) - answer_str = format_markdown_q_and_a_response(self.question, response_str, relevant_sections, self.supported_doc_exts, canonical_url_prefix, canonical_url_suffix) + # Check if the updated list of documents does not exceed limits and trim if it does: + docs_prompt_to_send_to_model = self._trim_docs_input(docs_prompt_to_send_to_model, max_allowed_txt_input, + only_return_if_trim_needed=False) + if not docs_prompt_to_send_to_model: + get_logger().error("_trim_docs_input returned an empty result.") + return "" + return docs_prompt_to_send_to_model + except Exception as e: + get_logger().exception(f"Unexpected exception thrown. Returning empty result.") + return "" + + def _format_model_answer(self, response_str: str, relevant_sections: list[dict[str, str]]) -> str: + try: + canonical_url_prefix, canonical_url_suffix = ( + self.git_provider.get_canonical_url_parts(repo_git_url=self.repo_url if self.repo_url_given_explicitly else None, + desired_branch=self.repo_desired_branch)) + answer_str = format_markdown_q_and_a_response(self.question, response_str, relevant_sections, + self.supported_doc_exts, canonical_url_prefix, canonical_url_suffix) if answer_str: #Remove the question phrase and replace with light bulb and a heading mentioning this is an automated answer: answer_str = modify_answer_section(answer_str) - # For PR help docs, we return the answer string instead of publishing it + #In case the response should not be published and returned as string, stop here: if answer_str and self.return_as_string: - if int(response_yaml.get('question_is_relevant', '1')) == 0: - get_logger().warning(f"Chat help docs answer would be ignored due to an invalid question.", - artifacts={'answer_str': answer_str}) - return "" get_logger().info(f"Chat help docs answer", artifacts={'answer_str': answer_str}) return answer_str - - # Publish the answer - if not answer_str or int(response_yaml.get('question_is_relevant', '1')) == 0: + if not answer_str: get_logger().info(f"No answer found") return "" - if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_help_docs.enable_help_text: answer_str += "
\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n" answer_str += HelpMessage.get_help_docs_usage_guide() answer_str += "\n</details>
\n" - - if get_settings().config.publish_output: - self.git_provider.publish_comment(answer_str) - else: - get_logger().info("Answer:", artifacts={'answer_str': answer_str}) - - except: - get_logger().exception('failed to provide answer to given user question as a result of a thrown exception (see above)') - - - def _find_all_document_files_matching_exts(self, abs_docs_path: str, ignore_readme=False) -> List[str]: - matching_files = [] - - # Ensure extensions don't have leading dots and are lowercase - dotless_extensions = [ext.lower().lstrip('.') for ext in self.supported_doc_exts] - - # Walk through directory and subdirectories - for root, _, files in os.walk(abs_docs_path): - for file in files: - if ignore_readme and root == abs_docs_path and file.lower() in [f"readme.{ext}" for ext in dotless_extensions]: - continue - # Check if file has one of the specified extensions - if any(file.lower().endswith(f'.{ext}') for ext in dotless_extensions): - matching_files.append(os.path.join(root, file)) - return matching_files + return answer_str + except Exception as e: + get_logger().exception(f"Unexpected exception thrown. Returning empty result.") + return ""