Compare commits

..

6 Commits

Author SHA1 Message Date
e0f295659d A less hacky way 2023-08-09 12:17:54 +03:00
e3274af831 A (still) hacky way to clip description and commit messages 2023-08-09 10:17:58 +03:00
ebbe655c40 Don't commment on Github, only eyes reaction 2023-08-07 18:09:39 +03:00
b1148e5f7a Don't commment on Github, only eyes reaction 2023-08-07 16:34:28 +03:00
2012e25596 Merge pull request #182 from Codium-ai/ok/add_eyes_reaction
Add Eyes Reaction to Comments and Configure AI Timeout
2023-08-07 16:28:38 +03:00
079d62af56 Merge pull request #181 from Codium-ai/ok/inference_timeout
Configurable AI Timeout
2023-08-07 16:23:06 +03:00
12 changed files with 71 additions and 15 deletions

View File

@ -37,7 +37,7 @@ class PRAgent:
def __init__(self): def __init__(self):
pass pass
async def handle_request(self, pr_url, request) -> bool: async def handle_request(self, pr_url, request, notify=None) -> bool:
# First, apply repo specific settings if exists # First, apply repo specific settings if exists
if get_settings().config.use_repo_settings_file: if get_settings().config.use_repo_settings_file:
repo_settings_file = None repo_settings_file = None
@ -67,8 +67,12 @@ class PRAgent:
if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect: if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect:
action = "review" action = "review"
if action == "answer": if action == "answer":
if notify:
notify()
await PRReviewer(pr_url, is_answer=True, args=args).run() await PRReviewer(pr_url, is_answer=True, args=args).run()
elif action in command2class: elif action in command2class:
if notify:
notify()
await command2class[action](pr_url, args=args).run() await command2class[action](pr_url, args=args).run()
else: else:
return False return False

View File

@ -11,7 +11,7 @@ from github import RateLimitExceededException
from pr_agent.algo import MAX_TOKENS from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
@ -284,3 +284,26 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
absolute_position = start2 + delta - 1 absolute_position = start2 + delta - 1
break break
return position, absolute_position return position, absolute_position
def clip_tokens(text: str, max_tokens: int) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.
Args:
text (str): The string to clip.
max_tokens (int): The maximum number of tokens allowed in the string.
Returns:
str: The clipped string.
"""
# We'll estimate the number of tokens by hueristically assuming 2.5 tokens per word
encoder = get_token_encoder()
num_input_tokens = len(encoder.encode(text))
if num_input_tokens <= max_tokens:
return text
num_chars = len(text)
chars_per_token = num_chars / num_input_tokens
num_output_chars = int(chars_per_token * max_tokens)
clipped_text = text[:num_output_chars]
return clipped_text

View File

@ -4,6 +4,10 @@ from tiktoken import encoding_for_model, get_encoding
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
def get_token_encoder():
return encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding(
"cl100k_base")
class TokenHandler: class TokenHandler:
""" """
A class for handling tokens in the context of a pull request. A class for handling tokens in the context of a pull request.
@ -27,7 +31,7 @@ class TokenHandler:
- system: The system string. - system: The system string.
- user: The user string. - user: The user string.
""" """
self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base") self.encoder = get_token_encoder()
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):

View File

@ -5,6 +5,7 @@ from urllib.parse import urlparse
import requests import requests
from atlassian.bitbucket import Cloud from atlassian.bitbucket import Cloud
from ..algo.pr_processing import clip_tokens
from ..config_loader import get_settings from ..config_loader import get_settings
from .git_provider import FilePatchInfo from .git_provider import FilePatchInfo
@ -81,6 +82,9 @@ class BitbucketProvider:
return self.pr.source_branch return self.pr.source_branch
def get_pr_description(self): def get_pr_description(self):
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
if max_tokens:
return clip_tokens(self.pr.description, max_tokens)
return self.pr.description return self.pr.description
def get_user_id(self): def get_user_id(self):

View File

@ -97,6 +97,10 @@ class GitProvider(ABC):
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool: def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
pass pass
@abstractmethod
def get_commit_messages(self):
pass
def get_main_pr_language(languages, files) -> str: def get_main_pr_language(languages, files) -> str:
""" """
Get the main language of the commit. Return an empty string if cannot determine. Get the main language of the commit. Return an empty string if cannot determine.

View File

@ -12,7 +12,7 @@ from starlette_context import context
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file from ..algo.pr_processing import find_line_number_of_relevant_line_in_file, clip_tokens
from ..config_loader import get_settings from ..config_loader import get_settings
from ..servers.utils import RateLimitExceeded from ..servers.utils import RateLimitExceeded
@ -234,6 +234,9 @@ class GithubProvider(GitProvider):
return self.pr.head.ref return self.pr.head.ref
def get_pr_description(self): def get_pr_description(self):
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
if max_tokens:
return clip_tokens(self.pr.body, max_tokens)
return self.pr.body return self.pr.body
def get_user_id(self): def get_user_id(self):
@ -375,19 +378,22 @@ class GithubProvider(GitProvider):
logging.exception(f"Failed to get labels, error: {e}") logging.exception(f"Failed to get labels, error: {e}")
return [] return []
def get_commit_messages(self) -> str: def get_commit_messages(self):
""" """
Retrieves the commit messages of a pull request. Retrieves the commit messages of a pull request.
Returns: Returns:
str: A string containing the commit messages of the pull request. str: A string containing the commit messages of the pull request.
""" """
max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
try: try:
commit_list = self.pr.get_commits() commit_list = self.pr.get_commits()
commit_messages = [commit.commit.message for commit in commit_list] commit_messages = [commit.commit.message for commit in commit_list]
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)]) commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)])
except: except Exception:
commit_messages_str = "" commit_messages_str = ""
if max_tokens:
commit_messages_str = clip_tokens(commit_messages_str, max_tokens)
return commit_messages_str return commit_messages_str
def generate_link_to_relevant_line_number(self, suggestion) -> str: def generate_link_to_relevant_line_number(self, suggestion) -> str:

View File

@ -7,6 +7,7 @@ import gitlab
from gitlab import GitlabGetError from gitlab import GitlabGetError
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.pr_processing import clip_tokens
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff
from ..config_loader import get_settings from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
@ -275,6 +276,9 @@ class GitLabProvider(GitProvider):
return self.mr.source_branch return self.mr.source_branch
def get_pr_description(self): def get_pr_description(self):
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
if max_tokens:
return clip_tokens(self.mr.description, max_tokens)
return self.mr.description return self.mr.description
def get_issue_comments(self): def get_issue_comments(self):
@ -338,16 +342,19 @@ class GitLabProvider(GitProvider):
def get_labels(self): def get_labels(self):
return self.mr.labels return self.mr.labels
def get_commit_messages(self) -> str: def get_commit_messages(self):
""" """
Retrieves the commit messages of a pull request. Retrieves the commit messages of a pull request.
Returns: Returns:
str: A string containing the commit messages of the pull request. str: A string containing the commit messages of the pull request.
""" """
max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
try: try:
commit_messages_list = [commit['message'] for commit in self.mr.commits()._list] commit_messages_list = [commit['message'] for commit in self.mr.commits()._list]
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)]) commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)])
except: except Exception:
commit_messages_str = "" commit_messages_str = ""
if max_tokens:
commit_messages_str = clip_tokens(commit_messages_str, max_tokens)
return commit_messages_str return commit_messages_str

View File

@ -15,6 +15,8 @@ async def run_action():
OPENAI_KEY = os.environ.get('OPENAI_KEY') OPENAI_KEY = os.environ.get('OPENAI_KEY')
OPENAI_ORG = os.environ.get('OPENAI_ORG') OPENAI_ORG = os.environ.get('OPENAI_ORG')
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN') GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
# Check if required environment variables are set # Check if required environment variables are set
if not GITHUB_EVENT_NAME: if not GITHUB_EVENT_NAME:
@ -64,8 +66,7 @@ async def run_action():
body = comment_body.strip().lower() body = comment_body.strip().lower()
comment_id = event_payload.get("comment", {}).get("id") comment_id = event_payload.get("comment", {}).get("id")
provider = get_git_provider()(pr_url=pr_url) provider = get_git_provider()(pr_url=pr_url)
provider.add_eyes_reaction(comment_id) await PRAgent().handle_request(pr_url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
await PRAgent().handle_request(pr_url, body)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -83,8 +83,7 @@ async def handle_request(body: Dict[str, Any]):
api_url = pull_request.get("url") api_url = pull_request.get("url")
comment_id = body.get("comment", {}).get("id") comment_id = body.get("comment", {}).get("id")
provider = get_git_provider()(pr_url=api_url) provider = get_git_provider()(pr_url=api_url)
provider.add_eyes_reaction(comment_id) await agent.handle_request(api_url, comment_body, notify=lambda: provider.add_eyes_reaction(comment_id))
await agent.handle_request(api_url, comment_body)
elif action == "opened" or 'reopened' in action: elif action == "opened" or 'reopened' in action:
@ -107,6 +106,7 @@ async def root():
def start(): def start():
# Override the deployment type to app # Override the deployment type to app
get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app") get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app")
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
middleware = [Middleware(RawContextMiddleware)] middleware = [Middleware(RawContextMiddleware)]
app = FastAPI(middleware=middleware) app = FastAPI(middleware=middleware)
app.include_router(router) app.include_router(router)

View File

@ -36,6 +36,7 @@ async def polling_loop():
git_provider = get_git_provider()() git_provider = get_git_provider()()
user_id = git_provider.get_user_id() user_id = git_provider.get_user_id()
agent = PRAgent() agent = PRAgent()
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
try: try:
deployment_type = get_settings().github.deployment_type deployment_type = get_settings().github.deployment_type
@ -100,8 +101,8 @@ async def polling_loop():
rest_of_comment = comment_body.split(user_tag)[1].strip() rest_of_comment = comment_body.split(user_tag)[1].strip()
comment_id = comment['id'] comment_id = comment['id']
git_provider.set_pr(pr_url) git_provider.set_pr(pr_url)
git_provider.add_eyes_reaction(comment_id) success = await agent.handle_request(pr_url, rest_of_comment,
success = await agent.handle_request(pr_url, rest_of_comment) notify=lambda: git_provider.add_eyes_reaction(comment_id)) # noqa E501
if not success: if not success:
git_provider.set_pr(pr_url) git_provider.set_pr(pr_url)
git_provider.publish_comment("### How to use PR-Agent\n" + git_provider.publish_comment("### How to use PR-Agent\n" +

View File

@ -8,6 +8,8 @@ verbosity_level=0 # 0,1,2
use_extra_bad_extensions=false use_extra_bad_extensions=false
use_repo_settings_file=true use_repo_settings_file=true
ai_timeout=180 ai_timeout=180
max_description_tokens = 500
max_commits_tokens = 500
[pr_reviewer] # /review # [pr_reviewer] # /review #
require_focused_review=true require_focused_review=true

View File

@ -8,7 +8,7 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, \ from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, \
find_line_number_of_relevant_line_in_file find_line_number_of_relevant_line_in_file, clip_tokens
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json from pr_agent.algo.utils import convert_to_markdown, try_fix_json
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings