import copy
import math
import os
import re
from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, Tuple

from jinja2 import Environment, StrictUndefined

from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import clip_tokens, get_max_tokens, load_yaml, ModelType
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider_with_context
from pr_agent.log import get_logger

# Section markers shared by the code that formats the Q&A markdown and the code that parses it back.
_ANSWER_MARKER = "### Answer:\n"
_RELEVANT_SOURCES_MARKER = "#### Relevant Sources:\n\n"

# Characters dropped (None) or replaced when turning a markdown heading into a URL anchor.
_HEADER_ANCHOR_TRANSLATION = str.maketrans({
    "'": None,
    "`": None,
    '(': None,
    ')': None,
    ',': None,
    '.': None,
    '?': None,
    '!': None,
    ' ': '-',
})


# Common code that can be called from similar tools:
def modify_answer_section(ai_response: str) -> str | None:
    """
    Replace everything up to and including the answer heading with a generic "auto-generated" heading.

    Gets the model's answer and relevant sources section, replacing the heading of the answer section with:
    ### :bulb: Auto-generated documentation-based answer:

    For example, the following input:

        ### Question:
        The following general issue was asked by a user: Title: How does one request to re-review a PR?
        More Info: I cannot seem to find to do this.
        ### Answer:
        According to the documentation, one needs to invoke the command: /review
        #### Relevant Sources...

    Should become:

        ### :bulb: Auto-generated documentation-based answer:
        According to the documentation, one needs to invoke the command: /review
        #### Relevant Sources...

    Args:
        ai_response: Markdown produced by format_markdown_q_and_a_response (or text with the same layout).

    Returns:
        The re-headed markdown, or None (after logging a warning) when no well-formed answer section exists.
    """
    answer_and_sources = _extract_model_answer_and_relevant_sources(ai_response)
    if answer_and_sources is None:
        get_logger().warning(f"Either no answer section found, or that section is malformed: {ai_response}")
        return None
    return "### :bulb: Auto-generated documentation-based answer:\n" + answer_and_sources


def _extract_model_answer_and_relevant_sources(ai_response: str) -> str | None:
    """
    Return the text following the last answer heading (answer plus relevant sources), or None.

    It is assumed that the input contains several sections with leading "### ", where the answer is the last
    one of them (heading: "### Answer:" followed by a newline), since the model returns the answer AFTER the
    user question. By splitting on that heading and grabbing the last part, the model answer is guaranteed to
    be in that last part, provided it is followed by a "#### Relevant Sources:" heading and a blank line.
    (for more details, see here:
    https://github.com/Codium-ai/pr-agent-pro/blob/main/pr_agent/tools/pr_help_message.py#L173)

    For example, given:

        ### Question:
        How does one request to re-review a PR?

        ### Answer:
        According to the documentation, one needs to invoke the command: /review

        #### Relevant Sources:

        ...

    the answer part is "According to the documentation, one needs to invoke the command: /review" and the
    returned text starts at that sentence and still includes the "#### Relevant Sources:" section.

    Returns:
        The answer-and-sources text, or None when a heading is missing or the answer text itself is empty.
    """
    if _ANSWER_MARKER in ai_response:
        answer_and_sources = ai_response.split(_ANSWER_MARKER)[-1]
        # Split such part by the "Relevant Sources" heading to look at the model answer alone:
        if _RELEVANT_SOURCES_MARKER in answer_and_sources:
            answer_only = answer_and_sources.split(_RELEVANT_SOURCES_MARKER)[0]
            get_logger().info(f"Found model answer: {answer_only}")
            # An empty answer is treated as "no answer", even though the sources section exists.
            return answer_and_sources if answer_only else None
    get_logger().warning(f"Either no answer section found, or that section is malformed: {ai_response}")
    return None


def get_maximal_text_input_length_for_token_count_estimation():
    """
    Return the maximal number of characters that may be sent for token count estimation.

    Returns:
        900000 when the configured model name contains 'claude-3-7-sonnet', otherwise math.inf.
    """
    model = get_settings().config.model
    if 'claude-3-7-sonnet' in model.lower():
        return 900000  # Claude API for token estimation allows maximal text input of 900K chars
    return math.inf  # Otherwise, no known limitation on input text just for token estimation


# Load documentation files to memory, decorating them with a header to mark where each file begins,
# as to help the LLM to give a better answer.
def aggregate_documentation_files_for_prompt_contents(base_path: str, doc_files: List[str]) -> Optional[str]:
    """
    Concatenate the given documentation files into a single prompt string.

    Each file is preceded by a "==file name==" header (its path with base_path removed) and followed by a
    "=========" separator. Files that cannot be read, or that contain no ASCII letter, are skipped.

    Args:
        base_path: Prefix removed from each file path before it is written to the header.
        doc_files: Paths of the files to read (as UTF-8).

    Returns:
        The aggregated text, or None (after logging an error) when no file contributed any content.
    """
    base_path_str = str(base_path)
    sections = []
    for file in doc_files:
        try:
            with open(file, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            # Best effort: a single unreadable file must not prevent answering from the remaining ones.
            get_logger().warning(f"Error while reading the file {file}: {e}")
            continue
        # Skip files with no text content
        if not re.search(r'[a-zA-Z]', content):
            continue
        file_path = str(file).replace(base_path_str, '')
        sections.append(f"\n==file name==\n\n{file_path}\n\n==file content==\n\n{content.strip()}\n=========\n\n")
    if not sections:
        get_logger().error("Couldn't find any usable documentation files. Returning None.")
        return None
    return "".join(sections)


def format_markdown_q_and_a_response(question_str: str, response_str: str, relevant_sections: List[Dict[str, str]],
                                     supported_suffixes: List[str], base_url_prefix: str,
                                     base_url_suffix: str = "") -> str:
    """
    Build the markdown holding the question, the model answer and links to the relevant sources.

    Args:
        question_str: The user question, written under the "### Question:" heading.
        response_str: The model answer, stripped and written under the "### Answer:" heading.
        relevant_sections: Dicts with the keys 'file_name' and 'relevant_section_header_string'.
        supported_suffixes: Sections whose file name ends with none of these are skipped with a warning.
        base_url_prefix: Text placed before the file name in every link.
        base_url_suffix: Text placed after the file name in every link.

    Returns:
        The formatted markdown string.
    """
    lines = [
        f"### Question: \n{question_str}\n\n",
        f"{_ANSWER_MARKER}{response_str.strip()}\n\n",
        _RELEVANT_SOURCES_MARKER,
    ]
    for section in relevant_sections:
        # The sections come from model-generated YAML, so a key may be missing or hold None.
        file = str(section.get('file_name') or '').strip()
        if not any(file.endswith(suffix) for suffix in supported_suffixes):
            get_logger().warning(f"Unsupported file extension: {file}")
            continue
        header = section.get('relevant_section_header_string') or ''
        if str(header).strip():
            markdown_header = format_markdown_header(header)
            # An anchored link is only emitted when there is a base url to anchor it to.
            if base_url_prefix:
                lines.append(f"> - {base_url_prefix}{file}{base_url_suffix}#{markdown_header}\n")
        else:
            lines.append(f"> - {base_url_prefix}{file}{base_url_suffix}\n")
    return "".join(lines)


def format_markdown_header(header: str) -> str:
    """
    Convert a markdown heading into the anchor used to link to it.

    Leading/trailing '#', spaces, gem emoji and newlines are stripped; quotes, backticks, parentheses and
    the characters , . ? ! are removed; spaces become '-'; the result is lowercased.

    Returns:
        The anchor string, or "" (after logging) when the header cannot be processed, e.g. it is not a string.
    """
    try:
        # First, strip common characters from both ends
        cleaned = header.strip('# 💎\n')
        # Then remove/replace the remaining special characters in a single pass and convert to lowercase
        return cleaned.translate(_HEADER_ANCHOR_TRANSLATION).lower()
    except Exception:
        get_logger().exception(f"Error while formatting markdown header", artifacts={'header': header})
        return ""


def clean_markdown_content(content: str) -> str:
    """
    Remove hidden comments and unnecessary elements from markdown content to reduce size.

    Args:
        content: The original markdown content

    Returns:
        Cleaned markdown content
    """
    # Remove HTML comments
    content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)

    # Remove frontmatter (YAML between --- or +++ delimiters); '^' only matches at the very start of content
    content = re.sub(r'^---\s*\n.*?\n---\s*\n', '', content, flags=re.DOTALL)
    content = re.sub(r'^\+\+\+\s*\n.*?\n\+\+\+\s*\n', '', content, flags=re.DOTALL)

    # Remove excessive blank lines (more than 2 consecutive)
    content = re.sub(r'\n{3,}', '\n\n', content)

    # Remove HTML tags that are often used for styling only
    content = re.sub(r'<div.*?>|</div>|<span.*?>|</span>', '', content, flags=re.DOTALL)

    # Remove image alt text which can be verbose
    content = re.sub(r'!\[(.*?)\]', '![]', content)

    # Remove images completely
    content = re.sub(r'!\[.*?\]\(.*?\)', '', content)

    # Remove simple HTML tags but preserve content between them. The closing tag must match the opening
    # one (\1); tags whose name starts with a table-related name are left untouched.
    content = re.sub(r'<(?!table|tr|td|th|thead|tbody)([a-zA-Z][a-zA-Z0-9]*)[^>]*>(.*?)</\1>',
                     r'\2', content, flags=re.DOTALL)
    return content.strip()


class PredictionPreparator:
    """
    Callable that renders the system/user prompts once and sends them to a given model on each call.

    Its instances are passed to retry_with_fallback_models, which supplies the model name.
    """

    def __init__(self, ai_handler, vars, system_prompt, user_prompt):
        """
        Render both prompt templates with a deep copy of `vars`.

        StrictUndefined makes rendering raise when a template references a variable missing from `vars`.
        """
        self.ai_handler = ai_handler
        variables = copy.deepcopy(vars)
        environment = Environment(undefined=StrictUndefined)
        self.system_prompt = environment.from_string(system_prompt).render(variables)
        self.user_prompt = environment.from_string(user_prompt).render(variables)

    async def __call__(self, model: str) -> str:
        """Return the chat completion of `model` for the rendered prompts, or "" (after logging) on any error."""
        try:
            response, _ = await self.ai_handler.chat_completion(
                model=model, temperature=get_settings().config.temperature,
                system=self.system_prompt, user=self.user_prompt)
            return response
        except Exception as e:
            get_logger().error(f"Error while preparing prediction: {e}")
            return ""


class PRHelpDocs(object):
    """
    Answer a user question using the documentation files of a git repository.

    The repository is cloned to a temporary directory, its documentation files are aggregated into a prompt,
    and the model answer is either returned as a string or published as a comment.
    """

    def __init__(self, ctx_url, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
                 args: Optional[Tuple[str]] = None, return_as_string: bool = False):
        """
        Read the PR_HELP_DOCS settings and resolve the git provider, repository url and branch.

        Args:
            ctx_url: Url of the context (e.g. PR or issue) the question was asked in.
            ai_handler: Factory, called without arguments, creating the AI handler.
            args: Command arguments; the first one is the question.
            return_as_string: When True, run() returns the answer instead of publishing it.

        Raises:
            Exception: When one of the required settings is None or no git provider matches ctx_url.
        """
        self.ctx_url = ctx_url
        self.question = args[0] if args else None
        self.return_as_string = return_as_string
        self.repo_url_given_explicitly = True
        self.repo_url = get_settings()['PR_HELP_DOCS.REPO_URL']
        self.include_root_readme_file = not (get_settings()['PR_HELP_DOCS.EXCLUDE_ROOT_README'])
        self.supported_doc_exts = get_settings()['PR_HELP_DOCS.SUPPORTED_DOC_EXTS']
        self.docs_path = get_settings()['PR_HELP_DOCS.DOCS_PATH']

        retrieved_settings = [self.include_root_readme_file, self.supported_doc_exts, self.docs_path]
        if any(setting is None for setting in retrieved_settings):
            raise Exception(f"One of the settings is invalid: {retrieved_settings}")

        self.git_provider = get_git_provider_with_context(ctx_url)
        if not self.git_provider:
            raise Exception(f"No git provider found at {ctx_url}")
        if not self.repo_url:
            self.repo_url_given_explicitly = False
            get_logger().debug(f"No explicit repo url provided, deducing it from type: {self.git_provider.__class__.__name__} "
                               f"context url: {self.ctx_url}")
            self.repo_url = self.git_provider.get_git_repo_url(self.ctx_url)
            get_logger().debug(f"deduced repo url: {self.repo_url}")

        try:  # Try to get the same branch in case triggered from a PR:
            self.repo_desired_branch = self.git_provider.get_pr_branch()
        except Exception:  # Otherwise (such as in issues)
            self.repo_desired_branch = get_settings()['PR_HELP_DOCS.REPO_DEFAULT_BRANCH']
        # Logged after (not in a `finally`) so a failing fallback is not masked by a missing attribute.
        get_logger().debug(f"repo_desired_branch: {self.repo_desired_branch}")

        self.ai_handler = ai_handler()
        self.vars = {
            "docs_url": self.repo_url,
            "question": self.question,
            "snippets": "",
        }
        self.token_handler = TokenHandler(None,
                                          self.vars,
                                          get_settings().pr_help_docs_prompts.system,
                                          get_settings().pr_help_docs_prompts.user)

    async def run(self) -> Optional[str]:
        """
        Answer the question from the repository documentation.

        Returns:
            - The answer markdown when return_as_string is set and an answer was produced.
            - "" when the model marked the question as not relevant or no answer could be formatted.
            - None when there is no question, no usable documentation, the answer was published/logged
              rather than returned, or an exception occurred (it is logged, not propagated).
        """
        if not self.question:
            get_logger().warning('No question provided. Will do nothing.')
            return None

        try:
            # Clone the repository and gather relevant documentation files.
            docs_prompt = None
            with TemporaryDirectory() as tmp_dir:
                get_logger().debug(f"About to clone repository: {self.repo_url} to temporary directory: {tmp_dir}...")
                returned_cloned_repo_root = self.git_provider.clone(self.repo_url, tmp_dir, remove_dest_folder=False)
                if not returned_cloned_repo_root:
                    raise Exception(f"Failed to clone {self.repo_url} to {tmp_dir}")

                get_logger().debug(f"About to gather relevant documentation files...")
                repo_root = returned_cloned_repo_root.path
                doc_files = []
                if self.include_root_readme_file:
                    # Only look at files in the root directory, not subdirectories: the first entry
                    # os.walk yields is the root itself, so the rest of the tree is never traversed.
                    _, _, root_files = next(os.walk(repo_root), (repo_root, [], []))
                    doc_files.extend(os.path.join(repo_root, file) for file in root_files
                                     if file.lower().startswith("readme."))

                abs_docs_path = os.path.join(repo_root, self.docs_path)
                if os.path.exists(abs_docs_path):
                    # When the docs path is the repo root, the root readme was already handled above.
                    doc_files.extend(self._find_all_document_files_matching_exts(
                        abs_docs_path, ignore_readme=(self.docs_path == '.')))
                if not doc_files:
                    get_logger().warning(f"No documentation files found matching file extensions: "
                                         f"{self.supported_doc_exts} under repo: {self.repo_url} path: {self.docs_path}")
                    return None

                get_logger().info(f'Answering a question inside context {self.ctx_url} for repo: {self.repo_url}'
                                  f' using the following documentation files: ', artifacts={'doc_files': doc_files})

                # Must happen while the temporary directory (and hence the clone) still exists.
                docs_prompt = aggregate_documentation_files_for_prompt_contents(repo_root, doc_files)
            if not docs_prompt:
                get_logger().warning(f"Error reading one of the documentation files. Returning with no result...")
                return None
            docs_prompt_to_send_to_model = docs_prompt

            # Estimate how many tokens will be needed. Trim in case of exceeding limit.
            # Firstly, check if text needs to be trimmed, as some models fail to return the estimated token
            # count if the input text is too long.
            max_allowed_txt_input = get_maximal_text_input_length_for_token_count_estimation()
            # Never true for math.inf, so the slice below always receives the integer limit.
            if len(docs_prompt_to_send_to_model) >= max_allowed_txt_input:
                get_logger().warning(f"Text length: {len(docs_prompt_to_send_to_model)} exceeds the current returned limit of {max_allowed_txt_input} just for token count estimation. Trimming the text...")
                docs_prompt_to_send_to_model = docs_prompt_to_send_to_model[:max_allowed_txt_input]

            # Then, count the tokens in the prompt. If the count exceeds the limit, trim the text.
            token_count = self.token_handler.count_tokens(docs_prompt_to_send_to_model, force_accurate=True)
            get_logger().debug(f"Estimated token count of documentation to send to model: {token_count}")
            model = get_settings().config.model
            if model in MAX_TOKENS:
                # note - here we take the actual max tokens, without any reductions.
                # we do aim to get the full documentation website in the prompt
                max_tokens_full = MAX_TOKENS[model]
            else:
                max_tokens_full = get_max_tokens(model)
            # Elbow room to reduce chance of exceeding token limit or model paying less attention to prompt guidelines.
            delta_output = 5000
            token_limit = max_tokens_full - delta_output
            if token_count > token_limit:
                # Reduce unnecessary text/images/etc.
                docs_prompt_to_send_to_model = clean_markdown_content(docs_prompt_to_send_to_model)
                get_logger().info(f"Token count {token_count} exceeds the limit {token_limit}. Attempting to clip text to fit within the limit...")
                docs_prompt_to_send_to_model = clip_tokens(docs_prompt_to_send_to_model, token_limit,
                                                           num_input_tokens=token_count)
            self.vars['snippets'] = docs_prompt_to_send_to_model.strip()

            # Run the AI model and extract sections from its response
            response = await retry_with_fallback_models(
                PredictionPreparator(self.ai_handler, self.vars,
                                     get_settings().pr_help_docs_prompts.system,
                                     get_settings().pr_help_docs_prompts.user),
                model_type=ModelType.REGULAR)
            response_yaml = load_yaml(response)
            if not response_yaml:
                get_logger().error("Failed to parse the AI response.", artifacts={'response': response})
                raise Exception(f"Failed to parse the AI response.")
            response_str = response_yaml.get('response')
            relevant_sections = response_yaml.get('relevant_sections')
            if not response_str or not relevant_sections:
                get_logger().error("Failed to extract response/relevant sections.",
                                   artifacts={'response_str': response_str, 'relevant_sections': relevant_sections})
                raise Exception(f"Failed to extract response/relevant sections.")

            # Format the response as markdown
            canonical_url_prefix, canonical_url_suffix = self.git_provider.get_canonical_url_parts(
                repo_git_url=self.repo_url if self.repo_url_given_explicitly else None,
                desired_branch=self.repo_desired_branch)
            answer_str = format_markdown_q_and_a_response(self.question, response_str, relevant_sections,
                                                          self.supported_doc_exts, canonical_url_prefix,
                                                          canonical_url_suffix)
            if answer_str:
                # Remove the question phrase and replace with light bulb and a heading mentioning this is
                # an automated answer:
                answer_str = modify_answer_section(answer_str)

            # For PR help docs, we return the answer string instead of publishing it
            if answer_str and self.return_as_string:
                if int(response_yaml.get('question_is_relevant', '1')) == 0:
                    get_logger().warning(f"Chat help docs answer would be ignored due to an invalid question.",
                                         artifacts={'answer_str': answer_str})
                    return ""
                get_logger().info(f"Chat help docs answer", artifacts={'answer_str': answer_str})
                return answer_str

            # Publish the answer
            if not answer_str or int(response_yaml.get('question_is_relevant', '1')) == 0:
                get_logger().info(f"No answer found")
                return ""
            if get_settings().config.publish_output:
                self.git_provider.publish_comment(answer_str)
            else:
                get_logger().info("Answer:", artifacts={'answer_str': answer_str})
        except Exception:
            # `Exception` rather than a bare except, so task cancellation (asyncio.CancelledError) and
            # KeyboardInterrupt still propagate instead of being logged and swallowed.
            get_logger().exception('failed to provide answer to given user question as a result of a thrown exception (see above)')

    def _find_all_document_files_matching_exts(self, abs_docs_path: str, ignore_readme=False) -> List[str]:
        """
        Recursively collect the files under abs_docs_path having one of the supported extensions.

        Args:
            abs_docs_path: Directory to walk.
            ignore_readme: When True, "readme.<ext>" files located directly in abs_docs_path are skipped.

        Returns:
            Paths (joined to abs_docs_path) of the matching files; extensions are compared case-insensitively.
        """
        # Ensure extensions don't have leading dots and are lowercase
        dotless_extensions = [ext.lower().lstrip('.') for ext in self.supported_doc_exts]
        # Both lookups are invariant, so they are built once rather than for every file.
        doc_suffixes = tuple(f'.{ext}' for ext in dotless_extensions)
        readme_names = {f"readme.{ext}" for ext in dotless_extensions}

        matching_files = []
        # Walk through directory and subdirectories
        for root, _, files in os.walk(abs_docs_path):
            skip_readme_here = ignore_readme and root == abs_docs_path
            for file in files:
                file_lower = file.lower()
                if skip_readme_here and file_lower in readme_names:
                    continue
                # Check if file has one of the specified extensions
                if file_lower.endswith(doc_suffixes):
                    matching_files.append(os.path.join(root, file))
        return matching_files