def calc_claude_tokens(self, patch):
    """
    Count the tokens of a patch using Anthropic's token counting API.

    Args:
        patch: The text whose tokens are counted.

    Returns:
        The `input_tokens` value reported by the API. When the text is larger than
        9MB, or when a failure happens after the model's token limit was resolved,
        the MAX_TOKENS entry of the configured model is returned instead. When the
        failure happens before that limit is known (for example the `anthropic`
        package cannot be imported, or the model has no MAX_TOKENS entry), the
        local tokenizer count is returned.
    """
    # Bound before the try block so the except handler can never hit an unbound name.
    max_tokens = None
    try:
        import anthropic
        from pr_agent.algo import MAX_TOKENS
        max_tokens = MAX_TOKENS[get_settings().config.model]
        client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key'))

        # Check if the content size is too large (9MB limit)
        if len(patch.encode('utf-8')) > 9_000_000:
            # NOTE(review): the message mentions the local tokenizer, but the model's token
            # limit is what gets returned - confirm which of the two is intended.
            get_logger().warning(
                "Content too large for Anthropic token counting API, falling back to local tokenizer"
            )
            return max_tokens

        response = client.messages.count_tokens(
            model="claude-3-7-sonnet-20250219",
            system="system",
            messages=[{
                "role": "user",
                "content": patch
            }],
        )
        return response.input_tokens

    except Exception as e:
        get_logger().error( f"Error in Anthropic token counting: {e}")
        if max_tokens is not None:
            return max_tokens
        # The token limit itself could not be resolved: use the local tokenizer.
        return len(self.encoder.encode(patch, disallowed_special=()))


def count_tokens(self, patch: str, force_accurate=False) -> int:
    """
    Counts the number of tokens in a given patch string.

    Args:
    - patch: The patch string.
    - force_accurate: When True, and the configured model name contains 'claude' and an
      Anthropic key is configured, the count is obtained from calc_claude_tokens (an API
      call) instead of the local encoder.

    Returns:
    The number of tokens in the patch string.
    """
    if force_accurate and 'claude' in get_settings().config.model.lower() and get_settings(use_context=False).get('anthropic.key'):
        return self.calc_claude_tokens(patch)  # API call to Anthropic for accurate token counting for Claude models
    return len(self.encoder.encode(patch, disallowed_special=()))
def get_canonical_url_parts(self, repo_git_url:str=None, desired_branch:str=None) -> Tuple[str, str]: scheme_and_netloc = None if repo_git_url: @@ -86,6 +89,7 @@ class BitbucketProvider(GitProvider): return ("", "") workspace_name, project_name = repo_path.split('/') else: + desired_branch = self.get_pr_branch() parsed_pr_url = urlparse(self.pr_url) scheme_and_netloc = parsed_pr_url.scheme + "://" + parsed_pr_url.netloc workspace_name, project_name = (self.workspace_slug, self.repo_slug) @@ -586,3 +590,21 @@ class BitbucketProvider(GitProvider): # bitbucket does not support labels def get_pr_labels(self, update=False): pass + #Clone related + def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None: + if "bitbucket.org" not in repo_url_to_clone: + get_logger().error("Repo URL is not a valid bitbucket URL.") + return None + bearer_token = self.bearer_token + if not bearer_token: + get_logger().error("No bearer token provided. Returning None") + return None + + #For example: For repo: https://bitbucket.org/codiumai/pr-agent-tests.git + #clone url will be: https://x-token-auth:@bitbucket.org/codiumai/pr-agent-tests.git + (scheme, base_url) = repo_url_to_clone.split("bitbucket.org") + if not all([scheme, base_url]): + get_logger().error(f"repo_url_to_clone: {repo_url_to_clone} is not a valid bitbucket URL.") + return None + clone_url = f"{scheme}x-token-auth:{bearer_token}@bitbucket.org{base_url}" + return clone_url diff --git a/pr_agent/git_providers/bitbucket_server_provider.py b/pr_agent/git_providers/bitbucket_server_provider.py index fef1bc0e..849b86d2 100644 --- a/pr_agent/git_providers/bitbucket_server_provider.py +++ b/pr_agent/git_providers/bitbucket_server_provider.py @@ -7,6 +7,8 @@ from urllib.parse import quote_plus, urlparse from atlassian.bitbucket import Bitbucket from requests.exceptions import HTTPError +import shlex +import subprocess from ..algo.git_patch_processing import decode_if_bytes from ..algo.language_handler import 
def get_canonical_url_parts(self, repo_git_url:str=None, desired_branch:str=None) -> Tuple[str, str]:
    """
    Return the (prefix, suffix) pair used to build a link to a file of a repo.

    Args:
        repo_git_url: Git url of the repo. Its workspace and repo names are taken from the
            part after 'scm/' and before '.git', which must have exactly one '/'.
            When empty, the workspace, repo and branch of the current PR are used.
        desired_branch: Branch to link to. Overridden by the PR branch when repo_git_url
            is not given.

    Returns:
        prefix: "{bitbucket_server_url}/projects/{workspace}/repos/{repo}/browse"
        suffix: "?at=refs%2Fheads%2F{branch}" with the branch percent-encoded, or "" when
            no branch is known.
        ("", "") is returned (after logging an error) when workspace or repo name is missing.
    """
    workspace_name = None
    project_name = None
    if repo_git_url:
        repo_path = repo_git_url.split('.git')[0].split('scm/')[-1]
        if repo_path.count('/') == 1:  # Has to have the form <workspace>/<repo>
            workspace_name, project_name = repo_path.split('/')
    else:
        desired_branch = self.get_pr_branch()
        workspace_name = self.workspace_slug
        project_name = self.repo_slug

    if not workspace_name or not project_name:
        get_logger().error(f"workspace_name or project_name not found in context, either git url: {repo_git_url} or uninitialized workspace/project.")
        return ("", "")

    prefix = f"{self.bitbucket_server_url}/projects/{workspace_name}/repos/{project_name}/browse"
    # The branch lands inside a query string, so characters such as '/' must be escaped.
    # Without a branch, no 'at' parameter is emitted instead of a literal "None".
    suffix = f"?at=refs%2Fheads%2F{quote_plus(desired_branch)}" if desired_branch else ""
    return (prefix, suffix)
+169,6 @@ class BitbucketServerProvider(GitProvider): return False return True - def get_git_repo_url(self, pr_url: str=None) -> str: #bitbucket server does not support issue url, so ignore param - try: - parsed_url = urlparse(self.pr_url) - return f"{parsed_url.scheme}://{parsed_url.netloc}/scm/{self.workspace_slug.lower()}/{self.repo_slug.lower()}.git" - except Exception as e: - get_logger().exception(f"url is not a valid merge requests url: {self.pr_url}") - return "" - - def get_canonical_url_parts(self, repo_git_url:str=None, desired_branch:str=None) -> Tuple[str, str]: - workspace_name = None - project_name = None - if not repo_git_url: - workspace_name = self.workspace_slug - project_name = self.repo_slug - else: - repo_path = repo_git_url.split('.git')[0].split('scm/')[-1] - if repo_path.count('/') == 1: # Has to have the form / - workspace_name, project_name = repo_path.split('/') - if not workspace_name or not project_name: - get_logger().error(f"workspace_name or project_name not found in context, either git url: {repo_git_url} or uninitialized workspace/project.") - return ("", "") - prefix = f"{self.bitbucket_server_url}/projects/{workspace_name}/repos/{project_name}/browse" - suffix = f"?at=refs%2Fheads%2F{desired_branch}" - return (prefix, suffix) - def set_pr(self, pr_url: str): self.workspace_slug, self.repo_slug, self.pr_num = self._parse_pr_url(pr_url) self.pr = self._get_pr() @@ -506,3 +512,28 @@ class BitbucketServerProvider(GitProvider): def _get_merge_base(self): return f"rest/api/latest/projects/{self.workspace_slug}/repos/{self.repo_slug}/pull-requests/{self.pr_num}/merge-base" + # Clone related + def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None: + if 'bitbucket.' not in repo_url_to_clone: + get_logger().error("Repo URL is not a valid bitbucket URL.") + return None + bearer_token = self.bearer_token + if not bearer_token: + get_logger().error("No bearer token provided. 
Returning None") + return None + # Return unmodified URL as the token is passed via HTTP headers in _clone_inner, as seen below. + return repo_url_to_clone + + #Overriding the shell command, since for some reason usage of x-token-auth doesn't work, as mentioned here: + # https://stackoverflow.com/questions/56760396/cloning-bitbucket-server-repo-with-access-tokens + def _clone_inner(self, repo_url: str, dest_folder: str, operation_timeout_in_seconds: int=None): + bearer_token = self.bearer_token + if not bearer_token: + #Shouldn't happen since this is checked in _prepare_clone, therefore - throwing an exception. + raise RuntimeError(f"Bearer token is required!") + + cli_args = shlex.split(f"git clone -c http.extraHeader='Authorization: Bearer {bearer_token}' " + f"--filter=blob:none --depth 1 {repo_url} {dest_folder}") + + subprocess.run(cli_args, check=True, # check=True will raise an exception if the command fails + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=operation_timeout_in_seconds) diff --git a/pr_agent/git_providers/codecommit_provider.py b/pr_agent/git_providers/codecommit_provider.py index c4f1ed7b..3dbe15bc 100644 --- a/pr_agent/git_providers/codecommit_provider.py +++ b/pr_agent/git_providers/codecommit_provider.py @@ -495,3 +495,6 @@ class CodeCommitProvider(GitProvider): lang: round(count / total_files * 100) for lang, count in lang_count.items() } return lang_percentage + + def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None: + raise Exception("Not implemented!") diff --git a/pr_agent/git_providers/gerrit_provider.py b/pr_agent/git_providers/gerrit_provider.py index ba587656..cce4fad4 100644 --- a/pr_agent/git_providers/gerrit_provider.py +++ b/pr_agent/git_providers/gerrit_provider.py @@ -397,3 +397,6 @@ class GerritProvider(GitProvider): def get_pr_branch(self): return self.repo.head + + def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None: + raise Exception("Not implemented!") 
def get_canonical_url_parts(self, repo_git_url:str, desired_branch:str) -> Tuple[str, str]:
    """
    Default implementation, meant to be overridden by each provider.

    Given a git repo url and a branch, a provider returns the prefix and suffix needed to
    view a file of that repo; a file link is then built as prefix + file path + suffix.
    This base version only logs a warning and returns an empty prefix and suffix.
    """
    empty_prefix, empty_suffix = "", ""
    get_logger().warning("Not implemented! Returning empty prefix and suffix")
    return (empty_prefix, empty_suffix)
class ScopedClonedRepo(object):
    """
    Wrapper around the folder of a cloned repo which deletes that folder.

    The folder is removed when the object is finalized (__del__). Since the moment of
    finalization is not under the caller's control, the folder can also be removed
    deterministically, either by calling cleanup() or by using the object as a
    context manager:

        with provider.clone(url, dest_folder) as cloned_repo:
            ...  # use cloned_repo.path
        # cloned_repo.path no longer exists here.

    Attributes:
        path: The folder holding the cloned repo.
    """
    def __init__(self, dest_folder):
        self.path = dest_folder

    def cleanup(self):
        """Delete the cloned folder if it still exists. Safe to call more than once."""
        # getattr: __del__ may run on an instance whose __init__ never completed.
        path = getattr(self, "path", None)
        if path and os.path.exists(path):
            shutil.rmtree(path, ignore_errors=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.cleanup()
        return False  # Never suppress an exception raised inside the with-block.

    def __del__(self):
        self.cleanup()
def clone(self, repo_url_to_clone: str, dest_folder: str, remove_dest_folder: bool = True,
          operation_timeout_in_seconds: int=CLONE_TIMEOUT_SEC) -> ScopedClonedRepo|None:
    """
    Clone a given url to a destination folder.

    Args:
        repo_url_to_clone: The git url of the repo to clone.
        dest_folder: The folder to clone into.
        remove_dest_folder: When True, an existing dest_folder directory is deleted
            before cloning.
        operation_timeout_in_seconds: Timeout handed to _clone_inner.

    Returns:
        A GitProvider.ScopedClonedRepo wrapping dest_folder on success.
        None when no clone url could be prepared or when cloning raised an Exception
        (the failure is logged).
    """
    clone_url = self._prepare_clone_url_with_token(repo_url_to_clone)
    if not clone_url:
        get_logger().error("Clone failed: Unable to obtain url to clone.")
        return None
    # No 'return' inside a 'finally' clause here: that would also swallow
    # KeyboardInterrupt / SystemExit raised while the clone is running.
    try:
        if remove_dest_folder and os.path.exists(dest_folder) and os.path.isdir(dest_folder):
            shutil.rmtree(dest_folder)
        self._clone_inner(clone_url, dest_folder, operation_timeout_in_seconds)
        return GitProvider.ScopedClonedRepo(dest_folder)
    except Exception as e:
        # Do not log clone_url or str(e): the clone url may embed the access token, and
        # the message of subprocess errors contains the full command line. For the same
        # reason the traceback (which repeats that message) is not attached.
        get_logger().error(f"Clone failed: Could not clone url.",
                           artifact={"error": type(e).__name__,
                                     "returncode": getattr(e, "returncode", None),
                                     "url": repo_url_to_clone, "dest_folder": dest_folder})
        return None
def get_canonical_url_parts(self, repo_git_url:str, desired_branch:str) -> Tuple[str, str]: owner = None repo = None @@ -102,6 +105,7 @@ class GithubProvider(GitProvider): if (not owner or not repo) and self.repo: #"else" - User did not provide an external git url, use self.repo object: owner, repo = self.repo.split('/') scheme_and_netloc = self.base_url_html + desired_branch = self.get_pr_branch() if not any([scheme_and_netloc, owner, repo]): #"else": Not invoked from a PR context,but no provided git url for context get_logger().error(f"Unable to get canonical url parts since missing context (PR or explicit git url)") return ("", "") @@ -750,9 +754,8 @@ class GithubProvider(GitProvider): return repo_name, issue_number def _get_github_client(self): - deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user") - - if deployment_type == 'app': + self.deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user") + if self.deployment_type == 'app': try: private_key = get_settings().github.private_key app_id = get_settings().github.app_id @@ -762,16 +765,19 @@ class GithubProvider(GitProvider): raise ValueError("GitHub app installation ID is required when using GitHub app deployment") auth = AppAuthentication(app_id=app_id, private_key=private_key, installation_id=self.installation_id) - return Github(app_auth=auth, base_url=self.base_url) - - if deployment_type == 'user': + self.auth = auth + elif self.deployment_type == 'user': try: token = get_settings().github.user_token except AttributeError as e: raise ValueError( "GitHub token is required when using user deployment. 
See: " "https://github.com/Codium-ai/pr-agent#method-2-run-from-source") from e - return Github(auth=Auth.Token(token), base_url=self.base_url) + self.auth = Auth.Token(token) + if self.auth: + return Github(auth=self.auth, base_url=self.base_url) + else: + raise ValueError("Could not authenticate to GitHub") def _get_repo(self): if hasattr(self, 'repo_obj') and \ @@ -1111,3 +1117,37 @@ class GithubProvider(GitProvider): get_logger().error(f"Failed to process patch for committable comment, error: {e}") return code_suggestions_copy + #Clone related + def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None: + scheme = "https://" + + #For example, to clone: + #https://github.com/Codium-ai/pr-agent-pro.git + #Need to embed inside the github token: + #https://@github.com/Codium-ai/pr-agent-pro.git + + github_token = self.auth.token + github_base_url = self.base_url_html + if not all([github_token, github_base_url]): + get_logger().error("Either missing auth token or missing base url") + return None + if scheme not in github_base_url: + get_logger().error(f"Base url: {github_base_url} is missing prefix: {scheme}") + return None + github_com = github_base_url.split(scheme)[1] # e.g. 
'github.com' or github..com + if not github_com: + get_logger().error(f"Base url: {github_base_url} has an empty base url") + return None + if github_com not in repo_url_to_clone: + get_logger().error(f"url to clone: {repo_url_to_clone} does not contain {github_com}") + return None + repo_full_name = repo_url_to_clone.split(github_com)[-1] + if not repo_full_name: + get_logger().error(f"url to clone: {repo_url_to_clone} is malformed") + return None + + clone_url = scheme + if self.deployment_type == 'app': + clone_url += "git:" + clone_url += f"{github_token}@{github_com}{repo_full_name}" + return clone_url diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 1040f3b5..c1c589c6 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -76,6 +76,9 @@ class GitLabProvider(GitProvider): return "" return f"{provider_url.split(repo_path)[0]}{repo_path}.git" + # Given a git repo url, return prefix and suffix of the provider in order to view a given file belonging to that repo. + # Example: https://gitlab.com/codiumai/pr-agent.git and branch: t1 -> prefix: "https://gitlab.com/codiumai/pr-agent/-/blob/t1", suffix: "?ref_type=heads" + # In case git url is not provided, provider will use PR context (which includes branch) to determine the prefix and suffix. 
def get_canonical_url_parts(self, repo_git_url:str=None, desired_branch:str=None) -> Tuple[str, str]: repo_path = "" if not repo_git_url and not self.pr_url: @@ -83,6 +86,7 @@ class GitLabProvider(GitProvider): return ("", "") if not repo_git_url: #Use PR url as context repo_path = self._get_project_path_from_pr_or_issue_url(self.pr_url) + desired_branch = self.get_pr_branch() else: #Use repo git url repo_path = repo_git_url.split('.git')[0].split('.com/')[-1] prefix = f"{self.gitlab_url}/{repo_path}/-/blob/{desired_branch}" @@ -629,3 +633,24 @@ class GitLabProvider(GitProvider): get_logger().info(f"Failed adding line link, error: {e}") return "" + #Clone related + def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None: + if "gitlab." not in repo_url_to_clone: + get_logger().error(f"Repo URL: {repo_url_to_clone} is not a valid gitlab URL.") + return None + (scheme, base_url) = repo_url_to_clone.split("gitlab.") + access_token = self.gl.oauth_token + if not all([scheme, access_token, base_url]): + get_logger().error(f"Either no access token found, or repo URL: {repo_url_to_clone} " + f"is missing prefix: {scheme} and/or base URL: {base_url}.") + return None + + #Note that the ""official"" method found here: + # https://docs.gitlab.com/user/profile/personal_access_tokens/#clone-repository-using-personal-access-token + # requires a username, which may not be applicable. 
+ # The following solution is taken from: https://stackoverflow.com/questions/25409700/using-gitlab-token-to-clone-without-authentication/35003812#35003812 + # For example: For repo url: https://gitlab.codium-inc.com/qodo/autoscraper.git + # Then to clone one will issue: 'git clone https://oauth2:@gitlab.codium-inc.com/qodo/autoscraper.git' + + clone_url = f"{scheme}oauth2:{access_token}@gitlab.{base_url}" + return clone_url diff --git a/pr_agent/git_providers/local_git_provider.py b/pr_agent/git_providers/local_git_provider.py index 42028976..beaab24d 100644 --- a/pr_agent/git_providers/local_git_provider.py +++ b/pr_agent/git_providers/local_git_provider.py @@ -190,3 +190,6 @@ class LocalGitProvider(GitProvider): def get_pr_labels(self, update=False): raise NotImplementedError('Getting labels is not implemented for the local git provider') + + def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None: + raise Exception("Not implemented!") diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index e9ee6e6d..899e9c1f 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -214,6 +214,7 @@ num_retrieved_snippets=5 [pr_help_docs] repo_url = "" #If not overwritten, will use the repo from where the context came from (issue or PR) +repo_default_branch = "main" docs_path = "docs" exclude_root_readme = false supported_doc_exts = [".md", ".mdx", ".rst"] diff --git a/pr_agent/tools/pr_help_docs.py b/pr_agent/tools/pr_help_docs.py index 3958ac52..d8dbcc66 100644 --- a/pr_agent/tools/pr_help_docs.py +++ b/pr_agent/tools/pr_help_docs.py @@ -36,7 +36,7 @@ def modify_answer_section(ai_response: str) -> str | None: #### Relevant Sources... 
""" model_answer_and_relevant_sections_in_response \ - = _extract_model_answer_and_relevant_sources(ai_response) + = extract_model_answer_and_relevant_sources(ai_response) if model_answer_and_relevant_sections_in_response is not None: cleaned_question_with_answer = "### :bulb: Auto-generated documentation-based answer:\n" cleaned_question_with_answer += model_answer_and_relevant_sections_in_response @@ -44,7 +44,7 @@ def modify_answer_section(ai_response: str) -> str | None: get_logger().warning(f"Either no answer section found, or that section is malformed: {ai_response}") return None -def _extract_model_answer_and_relevant_sources(ai_response: str) -> str | None: +def extract_model_answer_and_relevant_sources(ai_response: str) -> str | None: # It is assumed that the input contains several sections with leading "### ", # where the answer is the last one of them having the format: "### Answer:\n"), since the model returns the answer # AFTER the user question. By splitting using the string: "### Answer:\n" and grabbing the last part, @@ -71,7 +71,6 @@ def _extract_model_answer_and_relevant_sources(ai_response: str) -> str | None: get_logger().warning(f"Either no answer section found, or that section is malformed: {ai_response}") return None - def get_maximal_text_input_length_for_token_count_estimation(): model = get_settings().config.model if 'claude-3-7-sonnet' in model.lower(): @@ -204,7 +203,8 @@ class PRHelpDocs(object): self.question = args[0] if args else None self.return_as_string = return_as_string self.repo_url_given_explicitly = True - self.repo_url = get_settings()['PR_HELP_DOCS.REPO_URL'] + self.repo_url = get_settings().get('PR_HELP_DOCS.REPO_URL', '') + self.repo_desired_branch = get_settings().get('PR_HELP_DOCS.REPO_DEFAULT_BRANCH', 'main') #Ignored if self.repo_url is empty self.include_root_readme_file = not(get_settings()['PR_HELP_DOCS.EXCLUDE_ROOT_README']) self.supported_doc_exts = get_settings()['PR_HELP_DOCS.SUPPORTED_DOC_EXTS'] self.docs_path 
= get_settings()['PR_HELP_DOCS.DOCS_PATH'] @@ -222,12 +222,7 @@ class PRHelpDocs(object): f"context url: {self.ctx_url}") self.repo_url = self.git_provider.get_git_repo_url(self.ctx_url) get_logger().debug(f"deduced repo url: {self.repo_url}") - try: #Try to get the same branch in case triggered from a PR: - self.repo_desired_branch = self.git_provider.get_pr_branch() - except: #Otherwise (such as in issues) - self.repo_desired_branch = get_settings()['PR_HELP_DOCS.REPO_DEFAULT_BRANCH'] - finally: - get_logger().debug(f"repo_desired_branch: {self.repo_desired_branch}") + self.repo_desired_branch = None #Inferred from the repo provider. self.ai_handler = ai_handler() self.vars = { diff --git a/requirements.txt b/requirements.txt index ad42140f..2625ad66 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiohttp==3.9.5 -anthropic[vertex]==0.47.1 +anthropic>=0.48 +#anthropic[vertex]==0.47.1 atlassian-python-api==3.41.4 azure-devops==7.1.0b3 azure-identity==1.15.0