Merge remote-tracking branch 'origin/tr/agent_logic' into tr/agent_logic

This commit is contained in:
mrT23
2023-07-18 10:32:43 +03:00
7 changed files with 37 additions and 25 deletions

View File

@ -16,7 +16,7 @@ class PRAgent:
if any(cmd in request for cmd in ["/answer"]): if any(cmd in request for cmd in ["/answer"]):
await PRReviewer(pr_url, is_answer=True).review() await PRReviewer(pr_url, is_answer=True).review()
elif any(cmd in request for cmd in ["/review", "/review_pr", "/reflect_and_review"]): elif any(cmd in request for cmd in ["/review", "/review_pr", "/reflect_and_review"]):
if settings.pr_reviewer.ask_and_reflect or any(cmd in request for cmd in ["/reflect_and_review"]): if settings.pr_reviewer.ask_and_reflect or "/reflect_and_review" in request:
await PRInformationFromUser(pr_url).generate_questions() await PRInformationFromUser(pr_url).generate_questions()
else: else:
await PRReviewer(pr_url).review() await PRReviewer(pr_url).review()

View File

@ -25,6 +25,11 @@ class BitbucketProvider:
if pr_url: if pr_url:
self.set_pr(pr_url) self.set_pr(pr_url)
def is_supported(self, capability: str) -> bool:
if capability == 'get_issue_comments':
return False
return True
def set_pr(self, pr_url: str): def set_pr(self, pr_url: str):
self.workspace_slug, self.repo_slug, self.pr_num = self._parse_pr_url(pr_url) self.workspace_slug, self.repo_slug, self.pr_num = self._parse_pr_url(pr_url)
self.pr = self._get_pr() self.pr = self._get_pr()
@ -74,6 +79,9 @@ class BitbucketProvider:
def get_user_id(self): def get_user_id(self):
return 0 return 0
def get_issue_comments(self):
raise NotImplementedError("Bitbucket provider does not support issue comments yet")
@staticmethod @staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, int]: def _parse_pr_url(pr_url: str) -> Tuple[str, int]:
parsed_url = urlparse(pr_url) parsed_url = urlparse(pr_url)

View File

@ -21,6 +21,10 @@ class FilePatchInfo:
class GitProvider(ABC): class GitProvider(ABC):
@abstractmethod
def is_supported(self, capability: str) -> bool:
pass
@abstractmethod @abstractmethod
def get_diff_files(self) -> list[FilePatchInfo]: def get_diff_files(self) -> list[FilePatchInfo]:
pass pass
@ -62,6 +66,10 @@ class GitProvider(ABC):
def get_pr_description(self): def get_pr_description(self):
pass pass
@abstractmethod
def get_issue_comments(self):
pass
def get_main_pr_language(languages, files) -> str: def get_main_pr_language(languages, files) -> str:
""" """

View File

@ -23,6 +23,9 @@ class GithubProvider(GitProvider):
self.set_pr(pr_url) self.set_pr(pr_url)
self.last_commit_id = list(self.pr.get_commits())[-1] self.last_commit_id = list(self.pr.get_commits())[-1]
def is_supported(self, capability: str) -> bool:
return True
def set_pr(self, pr_url: str): def set_pr(self, pr_url: str):
self.repo, self.pr_num = self._parse_pr_url(pr_url) self.repo, self.pr_num = self._parse_pr_url(pr_url)
self.pr = self._get_pr() self.pr = self._get_pr()
@ -161,6 +164,9 @@ class GithubProvider(GitProvider):
notifications = self.github_client.get_user().get_notifications(since=since) notifications = self.github_client.get_user().get_notifications(since=since)
return notifications return notifications
def get_issue_comments(self):
return self.pr.get_issue_comments()
@staticmethod @staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, int]: def _parse_pr_url(pr_url: str) -> Tuple[str, int]:
parsed_url = urlparse(pr_url) parsed_url = urlparse(pr_url)

View File

@ -31,6 +31,11 @@ class GitLabProvider(GitProvider):
self.RE_HUNK_HEADER = re.compile( self.RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
def is_supported(self, capability: str) -> bool:
if capability == 'get_issue_comments':
return False
return True
@property @property
def pr(self): def pr(self):
'''The GitLab terminology is merge request (MR) instead of pull request (PR)''' '''The GitLab terminology is merge request (MR) instead of pull request (PR)'''
@ -203,6 +208,9 @@ class GitLabProvider(GitProvider):
def get_pr_description(self): def get_pr_description(self):
return self.mr.description return self.mr.description
def get_issue_comments(self):
raise NotImplementedError("GitLab provider does not support issue comments yet")
def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[int, int]: def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[int, int]:
parsed_url = urlparse(merge_request_url) parsed_url = urlparse(merge_request_url)

View File

@ -3,6 +3,7 @@ import json
import os import os
import re import re
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_description import PRDescription
@ -54,26 +55,7 @@ async def run_action():
pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url", None) pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url", None)
if pr_url: if pr_url:
body = comment_body.strip().lower() body = comment_body.strip().lower()
if any(cmd in body for cmd in ["/answer"]): await PRAgent().handle_request(pr_url, body)
await PRReviewer(pr_url, is_answer=True).review()
elif any(cmd in body for cmd in ["/review", "/review_pr", "/reflect_and_review"]):
if settings.pr_reviewer.ask_and_reflect or \
any(cmd in body for cmd in ["/reflect_and_review"]):
await PRInformationFromUser(pr_url).generate_questions()
else:
await PRReviewer(pr_url).review()
elif any(cmd in body for cmd in ["/describe", "/describe_pr"]):
await PRDescription(pr_url).describe()
elif any(cmd in body for cmd in ["/improve", "/improve_code"]):
await PRCodeSuggestions(pr_url).suggest()
elif any(cmd in body for cmd in ["/ask", "/ask_question"]):
pattern = r'(/ask|/ask_question)\s*(.*)'
matches = re.findall(pattern, comment_body, re.IGNORECASE)
if matches:
question = matches[0][1]
await PRQuestions(pr_url, question).answer()
else:
print(f"Unknown command: {body}")
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -9,7 +9,7 @@ from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json from pr_agent.algo.utils import convert_to_markdown, try_fix_json
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider, GithubProvider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
from pr_agent.servers.help import bot_help_text, actions_help_text from pr_agent.servers.help import bot_help_text, actions_help_text
@ -22,8 +22,8 @@ class PRReviewer:
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.is_answer = is_answer self.is_answer = is_answer
if self.is_answer and type(self.git_provider) != GithubProvider: if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception("Answer mode is only supported for Github for now") raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now")
answer_str = question_str = self._get_user_answers() answer_str = question_str = self._get_user_answers()
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
@ -139,7 +139,7 @@ class PRReviewer:
def _get_user_answers(self): def _get_user_answers(self):
answer_str = question_str = "" answer_str = question_str = ""
if self.is_answer: if self.is_answer:
discussion_messages = self.git_provider.pr.get_issue_comments() discussion_messages = self.git_provider.get_issue_comments()
for message in discussion_messages.reversed: for message in discussion_messages.reversed:
if "Questions to better understand the PR:" in message.body: if "Questions to better understand the PR:" in message.body:
question_str = message.body question_str = message.body