From e0f295659dbc33c681dfd95859f149b3c5aae854 Mon Sep 17 00:00:00 2001 From: Ori Kotek Date: Wed, 9 Aug 2023 12:17:54 +0300 Subject: [PATCH] A less hacky way --- pr_agent/algo/pr_processing.py | 19 ++++++++++--------- pr_agent/algo/token_handler.py | 6 +++++- pr_agent/git_providers/bitbucket_provider.py | 4 ++++ pr_agent/git_providers/git_provider.py | 4 ++++ pr_agent/git_providers/github_provider.py | 12 +++++++++--- pr_agent/git_providers/gitlab_provider.py | 11 +++++++++-- pr_agent/settings/configuration.toml | 2 ++ pr_agent/tools/pr_reviewer.py | 2 -- 8 files changed, 43 insertions(+), 17 deletions(-) diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index be3a461b..b195f9f4 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -11,7 +11,7 @@ from github import RateLimitExceededException from pr_agent.algo import MAX_TOKENS from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.language_handler import sort_files_by_main_languages -from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.token_handler import TokenHandler, get_token_encoder from pr_agent.config_loader import get_settings from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider @@ -298,11 +298,12 @@ def clip_tokens(text: str, max_tokens: int) -> str: str: The clipped string. """ # We'll estimate the number of tokens by hueristically assuming 2.5 tokens per word - words = re.finditer(r'\S+', text) - max_words = max_tokens // 2.5 - end_pos = None - for i, token in enumerate(words): - if i == max_words: - end_pos = token.start() - break - return text if end_pos is None else text[:end_pos] \ No newline at end of file + encoder = get_token_encoder() + num_input_tokens = len(encoder.encode(text)) + if num_input_tokens <= max_tokens: + return text + num_chars = len(text) + chars_per_token = num_chars / num_input_tokens + num_output_chars = int(chars_per_token * max_tokens) + clipped_text = text[:num_output_chars] + return clipped_text diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index 3686f521..f018a92b 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -4,6 +4,10 @@ from tiktoken import encoding_for_model, get_encoding from pr_agent.config_loader import get_settings +def get_token_encoder(): + return encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding( + "cl100k_base") + class TokenHandler: """ A class for handling tokens in the context of a pull request. @@ -27,7 +31,7 @@ class TokenHandler: - system: The system string. - user: The user string. """ - self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base") + self.encoder = get_token_encoder() self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): diff --git a/pr_agent/git_providers/bitbucket_provider.py b/pr_agent/git_providers/bitbucket_provider.py index 122b0db3..07b92295 100644 --- a/pr_agent/git_providers/bitbucket_provider.py +++ b/pr_agent/git_providers/bitbucket_provider.py @@ -5,6 +5,7 @@ from urllib.parse import urlparse import requests from atlassian.bitbucket import Cloud +from ..algo.pr_processing import clip_tokens from ..config_loader import get_settings from .git_provider import FilePatchInfo @@ -81,6 +82,9 @@ class BitbucketProvider: return self.pr.source_branch def get_pr_description(self): + max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) + if max_tokens: + return clip_tokens(self.pr.description, max_tokens) return self.pr.description def get_user_id(self): diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index 8e161252..2a891938 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -97,6 +97,10 @@ class GitProvider(ABC): def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool: pass + @abstractmethod + def get_commit_messages(self): + pass + def get_main_pr_language(languages, files) -> str: """ Get the main language of the commit. Return an empty string if cannot determine. diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index bc5cc6a7..dbad5388 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -12,7 +12,7 @@ from starlette_context import context from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from ..algo.language_handler import is_valid_file from ..algo.utils import load_large_diff -from ..algo.pr_processing import find_line_number_of_relevant_line_in_file +from ..algo.pr_processing import find_line_number_of_relevant_line_in_file, clip_tokens from ..config_loader import get_settings from ..servers.utils import RateLimitExceeded @@ -234,6 +234,9 @@ class GithubProvider(GitProvider): return self.pr.head.ref def get_pr_description(self): + max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) + if max_tokens: + return clip_tokens(self.pr.body, max_tokens) return self.pr.body def get_user_id(self): @@ -375,19 +378,22 @@ class GithubProvider(GitProvider): logging.exception(f"Failed to get labels, error: {e}") return [] - def get_commit_messages(self) -> str: + def get_commit_messages(self): """ Retrieves the commit messages of a pull request. Returns: str: A string containing the commit messages of the pull request. """ + max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None) try: commit_list = self.pr.get_commits() commit_messages = [commit.commit.message for commit in commit_list] commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)]) - except: + except Exception: commit_messages_str = "" + if max_tokens: + commit_messages_str = clip_tokens(commit_messages_str, max_tokens) return commit_messages_str def generate_link_to_relevant_line_number(self, suggestion) -> str: diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index a4d2d127..73a3a2f9 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -7,6 +7,7 @@ import gitlab from gitlab import GitlabGetError from ..algo.language_handler import is_valid_file +from ..algo.pr_processing import clip_tokens from ..algo.utils import load_large_diff from ..config_loader import get_settings from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider @@ -275,6 +276,9 @@ class GitLabProvider(GitProvider): return self.mr.source_branch def get_pr_description(self): + max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) + if max_tokens: + return clip_tokens(self.mr.description, max_tokens) return self.mr.description def get_issue_comments(self): @@ -338,16 +342,19 @@ class GitLabProvider(GitProvider): def get_labels(self): return self.mr.labels - def get_commit_messages(self) -> str: + def get_commit_messages(self): """ Retrieves the commit messages of a pull request. Returns: str: A string containing the commit messages of the pull request. """ + max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None) try: commit_messages_list = [commit['message'] for commit in self.mr.commits()._list] commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)]) - except: + except Exception: commit_messages_str = "" + if max_tokens: + commit_messages_str = clip_tokens(commit_messages_str, max_tokens) return commit_messages_str \ No newline at end of file diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 8334049d..0c502df9 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -8,6 +8,8 @@ verbosity_level=0 # 0,1,2 use_extra_bad_extensions=false use_repo_settings_file=true ai_timeout=180 +max_description_tokens = 500 +max_commits_tokens = 500 [pr_reviewer] # /review # require_focused_review=true diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 982f18cc..f679851b 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -62,8 +62,6 @@ class PRReviewer: "extra_instructions": get_settings().pr_reviewer.extra_instructions, "commit_messages_str": self.git_provider.get_commit_messages(), } - self.vars["description"] = clip_tokens(self.vars["description"], 500) - self.vars["commit_messages_str"] = clip_tokens(self.vars["commit_messages_str"], 500) self.token_handler = TokenHandler( self.git_provider.pr,