Compare commits

..

8 Commits

Author SHA1 Message Date
e0f295659d A less hacky way 2023-08-09 12:17:54 +03:00
e3274af831 A (still) hacky way to clip description and commit messages 2023-08-09 10:17:58 +03:00
ebbe655c40 Don't comment on GitHub, only eyes reaction 2023-08-07 18:09:39 +03:00
b1148e5f7a Don't comment on GitHub, only eyes reaction 2023-08-07 16:34:28 +03:00
2012e25596 Merge pull request #182 from Codium-ai/ok/add_eyes_reaction
Add Eyes Reaction to Comments and Configure AI Timeout
2023-08-07 16:28:38 +03:00
a75253097b Don't remove eyes 2023-08-07 16:28:20 +03:00
079d62af56 Merge pull request #181 from Codium-ai/ok/inference_timeout
Configurable AI Timeout
2023-08-07 16:23:06 +03:00
886139c6b5 Support adding / removing reaction from comments in GitHub different servers 2023-08-07 16:18:08 +03:00
12 changed files with 119 additions and 15 deletions

View File

@ -37,7 +37,7 @@ class PRAgent:
def __init__(self): def __init__(self):
pass pass
async def handle_request(self, pr_url, request) -> bool: async def handle_request(self, pr_url, request, notify=None) -> bool:
# First, apply repo specific settings if exists # First, apply repo specific settings if exists
if get_settings().config.use_repo_settings_file: if get_settings().config.use_repo_settings_file:
repo_settings_file = None repo_settings_file = None
@ -67,8 +67,12 @@ class PRAgent:
if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect: if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect:
action = "review" action = "review"
if action == "answer": if action == "answer":
if notify:
notify()
await PRReviewer(pr_url, is_answer=True, args=args).run() await PRReviewer(pr_url, is_answer=True, args=args).run()
elif action in command2class: elif action in command2class:
if notify:
notify()
await command2class[action](pr_url, args=args).run() await command2class[action](pr_url, args=args).run()
else: else:
return False return False

View File

@ -11,7 +11,7 @@ from github import RateLimitExceededException
from pr_agent.algo import MAX_TOKENS from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
@ -284,3 +284,26 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
absolute_position = start2 + delta - 1 absolute_position = start2 + delta - 1
break break
return position, absolute_position return position, absolute_position
def clip_tokens(text: str, max_tokens: int) -> str:
    """
    Clip a string so that it holds at most roughly `max_tokens` tokens.

    The text is tokenized once with the encoder returned by `get_token_encoder()`
    to obtain its token count. When the count exceeds the budget, the cut point is
    estimated from the text's own average characters-per-token ratio and the text
    is sliced by characters, so the result is approximately (not exactly)
    `max_tokens` tokens long.

    Args:
        text (str): The string to clip. Empty or None input is returned unchanged.
        max_tokens (int): The maximum number of tokens allowed in the string.

    Returns:
        str: The original text if it is within budget, otherwise a truncated prefix.
    """
    # Nothing to clip. This also keeps a missing value (e.g. a PR without a
    # description) from being handed to the encoder.
    if not text:
        return text

    encoder = get_token_encoder()
    num_input_tokens = len(encoder.encode(text))
    if num_input_tokens <= max_tokens:
        return text

    # Translate the token budget into a character budget using the average
    # number of characters per token measured on this very text.
    chars_per_token = len(text) / num_input_tokens
    num_output_chars = int(chars_per_token * max_tokens)
    return text[:num_output_chars]

View File

@ -4,6 +4,10 @@ from tiktoken import encoding_for_model, get_encoding
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
def get_token_encoder():
    """
    Return the tiktoken encoder matching the configured model.

    Models whose configured name contains "gpt" are resolved through
    `encoding_for_model`; any other model falls back to the "cl100k_base" encoding.
    """
    model_name = get_settings().config.model
    if "gpt" in model_name:
        return encoding_for_model(model_name)
    return get_encoding("cl100k_base")
class TokenHandler: class TokenHandler:
""" """
A class for handling tokens in the context of a pull request. A class for handling tokens in the context of a pull request.
@ -27,7 +31,7 @@ class TokenHandler:
- system: The system string. - system: The system string.
- user: The user string. - user: The user string.
""" """
self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base") self.encoder = get_token_encoder()
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):

View File

@ -5,6 +5,7 @@ from urllib.parse import urlparse
import requests import requests
from atlassian.bitbucket import Cloud from atlassian.bitbucket import Cloud
from ..algo.pr_processing import clip_tokens
from ..config_loader import get_settings from ..config_loader import get_settings
from .git_provider import FilePatchInfo from .git_provider import FilePatchInfo
@ -81,6 +82,9 @@ class BitbucketProvider:
return self.pr.source_branch return self.pr.source_branch
def get_pr_description(self): def get_pr_description(self):
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
if max_tokens:
return clip_tokens(self.pr.description, max_tokens)
return self.pr.description return self.pr.description
def get_user_id(self): def get_user_id(self):
@ -89,6 +93,12 @@ class BitbucketProvider:
def get_issue_comments(self): def get_issue_comments(self):
raise NotImplementedError("Bitbucket provider does not support issue comments yet") raise NotImplementedError("Bitbucket provider does not support issue comments yet")
def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]:
    """
    Placeholder for adding an "eyes" reaction to a comment.

    This implementation performs no action: `issue_comment_id` is ignored and
    True is always returned.

    NOTE(review): True does not match the declared Optional[int] return type —
    confirm what callers expect from providers without reaction support.
    """
    return True
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
    """
    Placeholder for removing a reaction from a comment.

    This implementation performs no action: both arguments are ignored and
    True is always returned.
    """
    return True
@staticmethod @staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, int]: def _parse_pr_url(pr_url: str) -> Tuple[str, int]:
parsed_url = urlparse(pr_url) parsed_url = urlparse(pr_url)

View File

@ -3,6 +3,7 @@ from dataclasses import dataclass
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED) # enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
from enum import Enum from enum import Enum
from typing import Optional
class EDIT_TYPE(Enum): class EDIT_TYPE(Enum):
@ -88,6 +89,17 @@ class GitProvider(ABC):
def get_issue_comments(self): def get_issue_comments(self):
pass pass
@abstractmethod
def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]:
    """
    Add an "eyes" reaction to the comment identified by `issue_comment_id`.

    Abstract: every concrete provider must supply its own implementation.

    Args:
        issue_comment_id (int): Identifier of the comment to react to.

    Returns:
        Optional[int]: Per the annotation, an integer or None.
        NOTE(review): presumably the id of the created reaction, with None on
        failure — confirm against the concrete providers.
    """
    pass
@abstractmethod
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
    """
    Remove the reaction `reaction_id` from the comment `issue_comment_id`.

    Abstract: every concrete provider must supply its own implementation.

    Args:
        issue_comment_id (int): Identifier of the comment holding the reaction.
        reaction_id (int): Identifier of the reaction to remove.

    Returns:
        bool: Per the annotation, a boolean.
        NOTE(review): presumably True on success and False on failure — confirm
        against the concrete providers.
    """
    pass
@abstractmethod
def get_commit_messages(self):
    """
    Retrieve the commit messages of the pull request.

    Abstract: every concrete provider must supply its own implementation.
    NOTE(review): no return annotation is declared here; presumably a single
    string containing all commit messages — confirm against the concrete providers.
    """
    pass
def get_main_pr_language(languages, files) -> str: def get_main_pr_language(languages, files) -> str:
""" """

View File

@ -2,17 +2,17 @@ import logging
import hashlib import hashlib
from datetime import datetime from datetime import datetime
from typing import Optional, Tuple from typing import Optional, Tuple, Any
from urllib.parse import urlparse from urllib.parse import urlparse
from github import AppAuthentication, Auth, Github, GithubException from github import AppAuthentication, Auth, Github, GithubException, Reaction
from retry import retry from retry import retry
from starlette_context import context from starlette_context import context
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file from ..algo.pr_processing import find_line_number_of_relevant_line_in_file, clip_tokens
from ..config_loader import get_settings from ..config_loader import get_settings
from ..servers.utils import RateLimitExceeded from ..servers.utils import RateLimitExceeded
@ -234,6 +234,9 @@ class GithubProvider(GitProvider):
return self.pr.head.ref return self.pr.head.ref
def get_pr_description(self): def get_pr_description(self):
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
if max_tokens:
return clip_tokens(self.pr.body, max_tokens)
return self.pr.body return self.pr.body
def get_user_id(self): def get_user_id(self):
@ -263,6 +266,23 @@ class GithubProvider(GitProvider):
except Exception: except Exception:
return "" return ""
def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]:
    """
    Add an "eyes" reaction to one of the pull request's issue comments.

    Args:
        issue_comment_id (int): Identifier of the comment to react to.

    Returns:
        Optional[int]: The id of the created reaction, or None if any step failed
        (the failure is logged, not raised).
    """
    try:
        comment = self.pr.get_issue_comment(issue_comment_id)
        eyes = comment.create_reaction("eyes")
        return eyes.id
    except Exception as e:
        # Best effort: a failed reaction must not abort handling of the request.
        logging.exception(f"Failed to add eyes reaction, error: {e}")
        return None
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
    """
    Delete a reaction from one of the pull request's issue comments.

    Args:
        issue_comment_id (int): Identifier of the comment holding the reaction.
        reaction_id (int): Identifier of the reaction to delete.

    Returns:
        bool: True if the deletion call completed, False if any step raised
        (the failure is logged, not raised).
    """
    try:
        comment = self.pr.get_issue_comment(issue_comment_id)
        comment.delete_reaction(reaction_id)
    except Exception as e:
        # Best effort: a failed cleanup must not abort handling of the request.
        logging.exception(f"Failed to remove eyes reaction, error: {e}")
        return False
    return True
@staticmethod @staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, int]: def _parse_pr_url(pr_url: str) -> Tuple[str, int]:
parsed_url = urlparse(pr_url) parsed_url = urlparse(pr_url)
@ -358,19 +378,22 @@ class GithubProvider(GitProvider):
logging.exception(f"Failed to get labels, error: {e}") logging.exception(f"Failed to get labels, error: {e}")
return [] return []
def get_commit_messages(self) -> str: def get_commit_messages(self):
""" """
Retrieves the commit messages of a pull request. Retrieves the commit messages of a pull request.
Returns: Returns:
str: A string containing the commit messages of the pull request. str: A string containing the commit messages of the pull request.
""" """
max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
try: try:
commit_list = self.pr.get_commits() commit_list = self.pr.get_commits()
commit_messages = [commit.commit.message for commit in commit_list] commit_messages = [commit.commit.message for commit in commit_list]
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)]) commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)])
except: except Exception:
commit_messages_str = "" commit_messages_str = ""
if max_tokens:
commit_messages_str = clip_tokens(commit_messages_str, max_tokens)
return commit_messages_str return commit_messages_str
def generate_link_to_relevant_line_number(self, suggestion) -> str: def generate_link_to_relevant_line_number(self, suggestion) -> str:

View File

@ -7,6 +7,7 @@ import gitlab
from gitlab import GitlabGetError from gitlab import GitlabGetError
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.pr_processing import clip_tokens
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff
from ..config_loader import get_settings from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
@ -275,6 +276,9 @@ class GitLabProvider(GitProvider):
return self.mr.source_branch return self.mr.source_branch
def get_pr_description(self): def get_pr_description(self):
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
if max_tokens:
return clip_tokens(self.mr.description, max_tokens)
return self.mr.description return self.mr.description
def get_issue_comments(self): def get_issue_comments(self):
@ -287,6 +291,12 @@ class GitLabProvider(GitProvider):
except Exception: except Exception:
return "" return ""
def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]:
    """
    Placeholder for adding an "eyes" reaction to a comment.

    This implementation performs no action: `issue_comment_id` is ignored and
    True is always returned.

    NOTE(review): True does not match the declared Optional[int] return type —
    confirm what callers expect from providers without reaction support.
    """
    return True
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
    """
    Placeholder for removing a reaction from a comment.

    This implementation performs no action: both arguments are ignored and
    True is always returned.
    """
    return True
def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[str, int]: def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[str, int]:
parsed_url = urlparse(merge_request_url) parsed_url = urlparse(merge_request_url)
@ -332,16 +342,19 @@ class GitLabProvider(GitProvider):
def get_labels(self): def get_labels(self):
return self.mr.labels return self.mr.labels
def get_commit_messages(self) -> str: def get_commit_messages(self):
""" """
Retrieves the commit messages of a pull request. Retrieves the commit messages of a pull request.
Returns: Returns:
str: A string containing the commit messages of the pull request. str: A string containing the commit messages of the pull request.
""" """
max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
try: try:
commit_messages_list = [commit['message'] for commit in self.mr.commits()._list] commit_messages_list = [commit['message'] for commit in self.mr.commits()._list]
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)]) commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)])
except: except Exception:
commit_messages_str = "" commit_messages_str = ""
if max_tokens:
commit_messages_str = clip_tokens(commit_messages_str, max_tokens)
return commit_messages_str return commit_messages_str

View File

@ -4,6 +4,7 @@ import os
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
@ -14,6 +15,8 @@ async def run_action():
OPENAI_KEY = os.environ.get('OPENAI_KEY') OPENAI_KEY = os.environ.get('OPENAI_KEY')
OPENAI_ORG = os.environ.get('OPENAI_ORG') OPENAI_ORG = os.environ.get('OPENAI_ORG')
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN') GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
# Check if required environment variables are set # Check if required environment variables are set
if not GITHUB_EVENT_NAME: if not GITHUB_EVENT_NAME:
@ -61,7 +64,9 @@ async def run_action():
pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url") pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
if pr_url: if pr_url:
body = comment_body.strip().lower() body = comment_body.strip().lower()
await PRAgent().handle_request(pr_url, body) comment_id = event_payload.get("comment", {}).get("id")
provider = get_git_provider()(pr_url=pr_url)
await PRAgent().handle_request(pr_url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -11,6 +11,7 @@ from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings from pr_agent.config_loader import get_settings, global_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.servers.utils import verify_signature from pr_agent.servers.utils import verify_signature
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@ -80,7 +81,10 @@ async def handle_request(body: Dict[str, Any]):
return {} return {}
pull_request = body["issue"]["pull_request"] pull_request = body["issue"]["pull_request"]
api_url = pull_request.get("url") api_url = pull_request.get("url")
await agent.handle_request(api_url, comment_body) comment_id = body.get("comment", {}).get("id")
provider = get_git_provider()(pr_url=api_url)
await agent.handle_request(api_url, comment_body, notify=lambda: provider.add_eyes_reaction(comment_id))
elif action == "opened" or 'reopened' in action: elif action == "opened" or 'reopened' in action:
pull_request = body.get("pull_request") pull_request = body.get("pull_request")
@ -102,6 +106,7 @@ async def root():
def start(): def start():
# Override the deployment type to app # Override the deployment type to app
get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app") get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app")
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
middleware = [Middleware(RawContextMiddleware)] middleware = [Middleware(RawContextMiddleware)]
app = FastAPI(middleware=middleware) app = FastAPI(middleware=middleware)
app.include_router(router) app.include_router(router)

View File

@ -36,6 +36,7 @@ async def polling_loop():
git_provider = get_git_provider()() git_provider = get_git_provider()()
user_id = git_provider.get_user_id() user_id = git_provider.get_user_id()
agent = PRAgent() agent = PRAgent()
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
try: try:
deployment_type = get_settings().github.deployment_type deployment_type = get_settings().github.deployment_type
@ -98,8 +99,10 @@ async def polling_loop():
if user_tag not in comment_body: if user_tag not in comment_body:
continue continue
rest_of_comment = comment_body.split(user_tag)[1].strip() rest_of_comment = comment_body.split(user_tag)[1].strip()
comment_id = comment['id']
success = await agent.handle_request(pr_url, rest_of_comment) git_provider.set_pr(pr_url)
success = await agent.handle_request(pr_url, rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(comment_id)) # noqa E501
if not success: if not success:
git_provider.set_pr(pr_url) git_provider.set_pr(pr_url)
git_provider.publish_comment("### How to use PR-Agent\n" + git_provider.publish_comment("### How to use PR-Agent\n" +

View File

@ -8,6 +8,8 @@ verbosity_level=0 # 0,1,2
use_extra_bad_extensions=false use_extra_bad_extensions=false
use_repo_settings_file=true use_repo_settings_file=true
ai_timeout=180 ai_timeout=180
max_description_tokens = 500
max_commits_tokens = 500
[pr_reviewer] # /review # [pr_reviewer] # /review #
require_focused_review=true require_focused_review=true

View File

@ -8,7 +8,7 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, \ from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, \
find_line_number_of_relevant_line_in_file find_line_number_of_relevant_line_in_file, clip_tokens
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json from pr_agent.algo.utils import convert_to_markdown, try_fix_json
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings