Merge remote-tracking branch 'upstream/main' into abstract-BaseAiHandler

This commit is contained in:
Brian Pham
2023-12-14 07:45:43 +08:00
17 changed files with 80 additions and 27 deletions

View File

@ -354,5 +354,5 @@ class BitbucketProvider(GitProvider):
pass
# bitbucket does not support labels
def get_labels(self):
def get_pr_labels(self):
pass

View File

@ -344,7 +344,7 @@ class BitbucketServerProvider(GitProvider):
pass
# bitbucket does not support labels
def get_labels(self):
def get_pr_labels(self):
pass
def _get_pr_comments_url(self):

View File

@ -216,7 +216,7 @@ class CodeCommitProvider(GitProvider):
def publish_labels(self, labels):
return [""] # not implemented yet
def get_labels(self):
def get_pr_labels(self):
return [""] # not implemented yet
def remove_initial_comment(self):

View File

@ -207,7 +207,7 @@ class GerritProvider(GitProvider):
Comment = namedtuple('Comment', ['body'])
return Comments([Comment(c['message']) for c in reversed(comments)])
def get_labels(self):
def get_pr_labels(self):
raise NotImplementedError(
'Getting labels is not implemented for the gerrit provider')

View File

@ -135,7 +135,10 @@ class GitProvider(ABC):
pass
@abstractmethod
def get_labels(self):
def get_pr_labels(self):
pass
def get_repo_labels(self):
pass
@abstractmethod

View File

@ -461,13 +461,17 @@ class GithubProvider(GitProvider):
except Exception as e:
get_logger().exception(f"Failed to publish labels, error: {e}")
def get_labels(self):
def get_pr_labels(self):
try:
return [label.name for label in self.pr.labels]
except Exception as e:
get_logger().exception(f"Failed to get labels, error: {e}")
return []
def get_repo_labels(self):
labels = self.repo_obj.get_labels()
return [label for label in labels]
def get_commit_messages(self):
"""
Retrieves the commit messages of a pull request.

View File

@ -211,7 +211,11 @@ class GitLabProvider(GitProvider):
pos_obj['new_line'] = target_line_no - 1
pos_obj['old_line'] = source_line_no - 1
get_logger().debug(f"Creating comment in {self.id_mr} with body {body} and position {pos_obj}")
self.mr.discussions.create({'body': body, 'position': pos_obj})
try:
self.mr.discussions.create({'body': body, 'position': pos_obj})
except Exception as e:
get_logger().debug(
f"Failed to create comment in {self.id_mr} with position {pos_obj} (probably not a '+' line)")
def get_relevant_diff(self, relevant_file: str, relevant_line_in_file: int) -> Optional[dict]:
changes = self.mr.changes() # Retrieve the changes for the merge request once
@ -404,7 +408,7 @@ class GitLabProvider(GitProvider):
def publish_inline_comments(self, comments: list[dict]):
pass
def get_labels(self):
def get_pr_labels(self):
return self.mr.labels
def get_commit_messages(self):

View File

@ -178,5 +178,5 @@ class LocalGitProvider(GitProvider):
def get_issue_comments(self):
raise NotImplementedError('Getting issue comments is not implemented for the local git provider')
def get_labels(self):
def get_pr_labels(self):
raise NotImplementedError('Getting labels is not implemented for the local git provider')

View File

@ -16,8 +16,13 @@ from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.secret_providers import get_secret_provider
from pr_agent.servers.github_action_runner import get_setting_or_env, is_true
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
@ -91,8 +96,20 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
log_context["api_url"] = pr_url
log_context["event"] = "pull_request"
with get_logger().contextualize(**log_context):
await agent.handle_request(pr_url, "review")
if pr_url:
with get_logger().contextualize(**log_context):
apply_repo_settings(pr_url)
auto_review = get_setting_or_env("BITBUCKET_APP.AUTO_REVIEW", None)
if auto_review is None or is_true(auto_review): # by default, auto review is enabled
await PRReviewer(pr_url).run()
auto_improve = get_setting_or_env("BITBUCKET_APP.AUTO_IMPROVE", None)
if is_true(auto_improve): # by default, auto improve is disabled
await PRCodeSuggestions(pr_url).run()
auto_describe = get_setting_or_env("BITBUCKET_APP.AUTO_DESCRIBE", None)
if is_true(auto_describe): # by default, auto describe is disabled
await PRDescription(pr_url).run()
# with get_logger().contextualize(**log_context):
# await agent.handle_request(pr_url, "review")
elif event == "pullrequest:comment_created":
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
log_context["api_url"] = pr_url
@ -139,7 +156,6 @@ async def handle_uninstalled_webhooks(request: Request, response: Response):
def start():
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
get_settings().set("CONFIG.GIT_PROVIDER", "bitbucket")
get_settings().set("PR_DESCRIPTION.PUBLISH_DESCRIPTION_AS_COMMENT", True)
middleware = [Middleware(RawContextMiddleware)]
app = FastAPI(middleware=middleware)
app.include_router(router)

View File

@ -127,11 +127,15 @@ async def handle_request(body: Dict[str, Any], event: str):
await _perform_commands("pr_commands", agent, body, api_url, log_context)
# handle pull_request event with synchronize action - "push trigger" for new commits
elif event == 'pull_request' and action == 'synchronize' and get_settings().github_app.handle_push_trigger:
elif event == 'pull_request' and action == 'synchronize':
pull_request, api_url = _check_pull_request_event(action, body, log_context, bot_user)
if not (pull_request and api_url):
return {}
apply_repo_settings(api_url)
if not get_settings().github_app.handle_push_trigger:
return {}
# TODO: do we still want to get the list of commits to filter bot/merge commits?
before_sha = body.get("before")
after_sha = body.get("after")

View File

@ -143,6 +143,12 @@ magic_word = "AutoReview"
# Polling interval
polling_interval_seconds = 30
[bitbucket_app]
#auto_review = true # set as config var in .pr_agent.toml
#auto_describe = true # set as config var in .pr_agent.toml
#auto_improve = true # set as config var in .pr_agent.toml
[local]
# LocalGitProvider settings - uncomment to use paths other than default
# description_path= "path/to/description.md"
@ -170,3 +176,4 @@ max_issues_to_scan = 500
# fill and place in .secrets.toml
#api_key = ...
# environment = "gcp-starter"

View File

@ -102,11 +102,12 @@ class PRDescription:
if get_settings().config.publish_output:
get_logger().info(f"Pushing answer {self.pr_id}")
if get_settings().pr_description.publish_description_as_comment:
get_logger().info(f"Publishing answer as comment")
self.git_provider.publish_comment(full_markdown_description)
else:
self.git_provider.publish_description(pr_title, pr_body)
if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"):
current_labels = self.git_provider.get_labels()
current_labels = self.git_provider.get_pr_labels()
user_labels = get_user_labels(current_labels)
self.git_provider.publish_labels(pr_labels + user_labels)
@ -158,7 +159,7 @@ class PRDescription:
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables)
set_custom_labels(variables, self.git_provider)
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
@ -290,7 +291,7 @@ class PRDescription:
value = ', '.join(v for v in value)
pr_body += f"{value}\n"
if idx < len(self.data) - 1:
pr_body += "\n___\n"
pr_body += "\n\n___\n\n"
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"title:\n{title}\n{pr_body}")
@ -315,7 +316,6 @@ class PRDescription:
if not self.git_provider.is_supported("gfm_markdown"):
get_logger().info(f"Disabling semantic files types for {self.pr_id} since gfm_markdown is not supported")
return pr_body
try:
pr_body += "<table>"
header = f"Relevant files"

View File

@ -82,7 +82,7 @@ class PRGenerateLabels:
if get_settings().config.publish_output:
get_logger().info(f"Pushing labels {self.pr_id}")
current_labels = self.git_provider.get_labels()
current_labels = self.git_provider.get_pr_labels()
user_labels = get_user_labels(current_labels)
pr_labels = pr_labels + user_labels
@ -132,7 +132,7 @@ class PRGenerateLabels:
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables)
set_custom_labels(variables, self.git_provider)
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)

View File

@ -394,7 +394,7 @@ class PRReviewer:
if security_concerns_bool:
review_labels.append('Possible security concern')
current_labels = self.git_provider.get_labels()
current_labels = self.git_provider.get_pr_labels()
current_labels_filtered = [label for label in current_labels if
not label.lower().startswith('review effort [1-5]:') and not label.lower().startswith(
'possible security concern')]