mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-03 20:30:41 +08:00
Merge pull request #187 from Codium-ai/ok/limit_description
Limiting Description and Commit Messages Length
This commit is contained in:
@ -11,7 +11,7 @@ from github import RateLimitExceededException
|
||||
from pr_agent.algo import MAX_TOKENS
|
||||
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
|
||||
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
|
||||
|
||||
@ -284,3 +284,26 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
|
||||
absolute_position = start2 + delta - 1
|
||||
break
|
||||
return position, absolute_position
|
||||
|
||||
|
||||
def clip_tokens(text: str, max_tokens: int) -> str:
|
||||
"""
|
||||
Clip the number of tokens in a string to a maximum number of tokens.
|
||||
|
||||
Args:
|
||||
text (str): The string to clip.
|
||||
max_tokens (int): The maximum number of tokens allowed in the string.
|
||||
|
||||
Returns:
|
||||
str: The clipped string.
|
||||
"""
|
||||
# We'll estimate the number of tokens by hueristically assuming 2.5 tokens per word
|
||||
encoder = get_token_encoder()
|
||||
num_input_tokens = len(encoder.encode(text))
|
||||
if num_input_tokens <= max_tokens:
|
||||
return text
|
||||
num_chars = len(text)
|
||||
chars_per_token = num_chars / num_input_tokens
|
||||
num_output_chars = int(chars_per_token * max_tokens)
|
||||
clipped_text = text[:num_output_chars]
|
||||
return clipped_text
|
||||
|
@ -4,6 +4,10 @@ from tiktoken import encoding_for_model, get_encoding
|
||||
from pr_agent.config_loader import get_settings
|
||||
|
||||
|
||||
def get_token_encoder():
|
||||
return encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding(
|
||||
"cl100k_base")
|
||||
|
||||
class TokenHandler:
|
||||
"""
|
||||
A class for handling tokens in the context of a pull request.
|
||||
@ -27,7 +31,7 @@ class TokenHandler:
|
||||
- system: The system string.
|
||||
- user: The user string.
|
||||
"""
|
||||
self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base")
|
||||
self.encoder = get_token_encoder()
|
||||
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
||||
|
||||
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
||||
|
@ -5,6 +5,7 @@ from urllib.parse import urlparse
|
||||
import requests
|
||||
from atlassian.bitbucket import Cloud
|
||||
|
||||
from ..algo.pr_processing import clip_tokens
|
||||
from ..config_loader import get_settings
|
||||
from .git_provider import FilePatchInfo
|
||||
|
||||
@ -81,6 +82,9 @@ class BitbucketProvider:
|
||||
return self.pr.source_branch
|
||||
|
||||
def get_pr_description(self):
|
||||
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
||||
if max_tokens:
|
||||
return clip_tokens(self.pr.description, max_tokens)
|
||||
return self.pr.description
|
||||
|
||||
def get_user_id(self):
|
||||
|
@ -97,6 +97,10 @@ class GitProvider(ABC):
|
||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_commit_messages(self):
|
||||
pass
|
||||
|
||||
def get_main_pr_language(languages, files) -> str:
|
||||
"""
|
||||
Get the main language of the commit. Return an empty string if cannot determine.
|
||||
|
@ -12,7 +12,7 @@ from starlette_context import context
|
||||
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
|
||||
from ..algo.language_handler import is_valid_file
|
||||
from ..algo.utils import load_large_diff
|
||||
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file
|
||||
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file, clip_tokens
|
||||
from ..config_loader import get_settings
|
||||
from ..servers.utils import RateLimitExceeded
|
||||
|
||||
@ -234,6 +234,9 @@ class GithubProvider(GitProvider):
|
||||
return self.pr.head.ref
|
||||
|
||||
def get_pr_description(self):
|
||||
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
||||
if max_tokens:
|
||||
return clip_tokens(self.pr.body, max_tokens)
|
||||
return self.pr.body
|
||||
|
||||
def get_user_id(self):
|
||||
@ -375,19 +378,22 @@ class GithubProvider(GitProvider):
|
||||
logging.exception(f"Failed to get labels, error: {e}")
|
||||
return []
|
||||
|
||||
def get_commit_messages(self) -> str:
|
||||
def get_commit_messages(self):
|
||||
"""
|
||||
Retrieves the commit messages of a pull request.
|
||||
|
||||
Returns:
|
||||
str: A string containing the commit messages of the pull request.
|
||||
"""
|
||||
max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
|
||||
try:
|
||||
commit_list = self.pr.get_commits()
|
||||
commit_messages = [commit.commit.message for commit in commit_list]
|
||||
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)])
|
||||
except:
|
||||
except Exception:
|
||||
commit_messages_str = ""
|
||||
if max_tokens:
|
||||
commit_messages_str = clip_tokens(commit_messages_str, max_tokens)
|
||||
return commit_messages_str
|
||||
|
||||
def generate_link_to_relevant_line_number(self, suggestion) -> str:
|
||||
|
@ -7,6 +7,7 @@ import gitlab
|
||||
from gitlab import GitlabGetError
|
||||
|
||||
from ..algo.language_handler import is_valid_file
|
||||
from ..algo.pr_processing import clip_tokens
|
||||
from ..algo.utils import load_large_diff
|
||||
from ..config_loader import get_settings
|
||||
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
||||
@ -275,6 +276,9 @@ class GitLabProvider(GitProvider):
|
||||
return self.mr.source_branch
|
||||
|
||||
def get_pr_description(self):
|
||||
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
||||
if max_tokens:
|
||||
return clip_tokens(self.mr.description, max_tokens)
|
||||
return self.mr.description
|
||||
|
||||
def get_issue_comments(self):
|
||||
@ -338,16 +342,19 @@ class GitLabProvider(GitProvider):
|
||||
def get_labels(self):
|
||||
return self.mr.labels
|
||||
|
||||
def get_commit_messages(self) -> str:
|
||||
def get_commit_messages(self):
|
||||
"""
|
||||
Retrieves the commit messages of a pull request.
|
||||
|
||||
Returns:
|
||||
str: A string containing the commit messages of the pull request.
|
||||
"""
|
||||
max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
|
||||
try:
|
||||
commit_messages_list = [commit['message'] for commit in self.mr.commits()._list]
|
||||
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)])
|
||||
except:
|
||||
except Exception:
|
||||
commit_messages_str = ""
|
||||
if max_tokens:
|
||||
commit_messages_str = clip_tokens(commit_messages_str, max_tokens)
|
||||
return commit_messages_str
|
@ -8,6 +8,8 @@ verbosity_level=0 # 0,1,2
|
||||
use_extra_bad_extensions=false
|
||||
use_repo_settings_file=true
|
||||
ai_timeout=180
|
||||
max_description_tokens = 500
|
||||
max_commits_tokens = 500
|
||||
|
||||
[pr_reviewer] # /review #
|
||||
require_focused_review=true
|
||||
|
@ -10,7 +10,7 @@ from yaml import SafeLoader
|
||||
|
||||
from pr_agent.algo.ai_handler import AiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, \
|
||||
find_line_number_of_relevant_line_in_file
|
||||
find_line_number_of_relevant_line_in_file, clip_tokens
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import convert_to_markdown, try_fix_json, try_fix_yaml, load_yaml
|
||||
from pr_agent.config_loader import get_settings
|
||||
|
Reference in New Issue
Block a user