This commit is contained in:
mrT23
2024-07-11 18:30:16 +03:00
parent 4b351cfe38
commit eccd00b86f
2 changed files with 122 additions and 6 deletions

View File

@ -6,7 +6,7 @@ from typing import Optional
from pr_agent.config_loader import get_settings
from pr_agent.algo.types import FilePatchInfo
from pr_agent.log import get_logger
MAX_FILES_ALLOWED_FULL = 50
class GitProvider(ABC):
@abstractmethod
@ -51,6 +51,12 @@ class GitProvider(ABC):
def edit_comment(self, comment, body: str):
pass
def edit_comment_from_comment_id(self, comment_id: int, body: str):
pass
def get_comment_body_from_comment_id(self, comment_id: int) -> str:
pass
def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
pass
@ -74,6 +80,7 @@ class GitProvider(ABC):
# if the existing description wasn't generated by the pr-agent, just return it as-is
if not self._is_generated_by_pr_agent(description_lowercase):
get_logger().info(f"Existing description was not generated by the pr-agent")
self.user_description = description
return description
# if the existing description was generated by the pr-agent, but it doesn't contain a user description,
@ -120,12 +127,18 @@ class GitProvider(ABC):
def get_repo_settings(self):
pass
def get_workspace_name(self):
return ""
def get_pr_id(self):
return ""
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
return ""
def get_lines_link_original_file(self, filepath:str, component_range: Range) -> str:
return ""
#### comments operations ####
@abstractmethod
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
@ -166,6 +179,7 @@ class GitProvider(ABC):
pass
self.publish_comment(pr_comment)
@abstractmethod
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
pass

View File

@ -1,3 +1,4 @@
import itertools
import time
import hashlib
from datetime import datetime
@ -14,7 +15,7 @@ from ..algo.utils import PRReviewHeader, load_large_diff, clip_tokens, find_line
from ..config_loader import get_settings
from ..log import get_logger
from ..servers.utils import RateLimitExceeded
from .git_provider import GitProvider, IncrementalPR
from .git_provider import GitProvider, IncrementalPR, MAX_FILES_ALLOWED_FULL
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
@ -164,20 +165,36 @@ class GithubProvider(GitProvider):
diff_files = []
invalid_files_names = []
counter_valid = 0
for file in files:
if not is_valid_file(file.filename):
invalid_files_names.append(file.filename)
continue
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) # communication with GitHub
patch = file.patch
# allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
counter_valid += 1
avoid_load = False
if counter_valid >= MAX_FILES_ALLOWED_FULL and patch and not self.incremental.is_incremental:
avoid_load = True
if counter_valid == MAX_FILES_ALLOWED_FULL:
get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files")
if avoid_load:
new_file_content_str = ""
else:
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) # communication with GitHub
if self.incremental.is_incremental and self.unreviewed_files_set:
original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha)
patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)
self.unreviewed_files_set[file.filename] = patch
else:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
if avoid_load:
original_file_content_str = ""
else:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
if not patch:
patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)
@ -427,6 +444,16 @@ class GithubProvider(GitProvider):
def edit_comment(self, comment, body: str):
comment.edit(body=body)
def edit_comment_from_comment_id(self, comment_id: int, body: str):
try:
# self.pr.get_issue_comment(comment_id).edit(body)
headers, data_patch = self.pr._requester.requestJsonAndCheck(
"PATCH", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}",
input={"body": body}
)
except Exception as e:
get_logger().exception(f"Failed to edit comment, error: {e}")
def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
try:
# self.pr.get_issue_comment(comment_id).edit(body)
@ -437,6 +464,50 @@ class GithubProvider(GitProvider):
except Exception as e:
get_logger().exception(f"Failed to reply comment, error: {e}")
def get_comment_body_from_comment_id(self, comment_id: int):
try:
# self.pr.get_issue_comment(comment_id).edit(body)
headers, data_patch = self.pr._requester.requestJsonAndCheck(
"GET", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}"
)
return data_patch.get("body","")
except Exception as e:
get_logger().exception(f"Failed to edit comment, error: {e}")
return None
def publish_file_comments(self, file_comments: list) -> bool:
try:
headers, existing_comments = self.pr._requester.requestJsonAndCheck(
"GET", f"{self.pr.url}/comments"
)
for comment in file_comments:
comment['commit_id'] = self.last_commit_id.sha
found = False
for existing_comment in existing_comments:
comment['commit_id'] = self.last_commit_id.sha
our_app_name = get_settings().get("GITHUB.APP_NAME", "")
same_comment_creator = False
if self.deployment_type == 'app':
same_comment_creator = our_app_name.lower() in existing_comment['user']['login'].lower()
elif self.deployment_type == 'user':
same_comment_creator = self.github_user_id == existing_comment['user']['login']
if existing_comment['subject_type'] == 'file' and comment['path'] == existing_comment['path'] and same_comment_creator:
headers, data_patch = self.pr._requester.requestJsonAndCheck(
"PATCH", f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}", input={"body":comment['body']}
)
found = True
break
if not found:
headers, data_post = self.pr._requester.requestJsonAndCheck(
"POST", f"{self.pr.url}/comments", input=comment
)
return True
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to publish diffview file summary, error: {e}")
return False
def remove_initial_comment(self):
try:
for comment in getattr(self.pr, 'comments_list', []):
@ -461,6 +532,11 @@ class GithubProvider(GitProvider):
def get_pr_branch(self):
return self.pr.head.ref
def get_pr_owner_id(self) -> str | None:
if not self.repo:
return None
return self.repo.split('/')[0]
def get_pr_description_full(self):
return self.pr.body
@ -495,6 +571,9 @@ class GithubProvider(GitProvider):
except Exception:
return ""
def get_workspace_name(self):
return self.repo.split('/')[0]
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
if disable_eyes:
return None
@ -673,7 +752,7 @@ class GithubProvider(GitProvider):
def get_repo_labels(self):
labels = self.repo_obj.get_labels()
return [label for label in labels]
return [label for label in itertools.islice(labels, 50)]
def get_commit_messages(self):
"""
@ -728,6 +807,29 @@ class GithubProvider(GitProvider):
link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}R{relevant_line_start}"
return link
def get_lines_link_original_file(self, filepath: str, component_range: Range) -> str:
"""
Returns the link to the original file on GitHub that corresponds to the given filepath and component range.
Args:
filepath (str): The path of the file.
component_range (Range): The range of lines that represent the component.
Returns:
str: The link to the original file on GitHub.
Example:
>>> filepath = "path/to/file.py"
>>> component_range = Range(line_start=10, line_end=20)
>>> link = get_lines_link_original_file(filepath, component_range)
>>> print(link)
"https://github.com/{repo}/blob/{commit_sha}/{filepath}/#L11-L21"
"""
line_start = component_range.line_start + 1
line_end = component_range.line_end + 1
link = (f"https://github.com/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
f"#L{line_start}-L{line_end}")
return link
def get_pr_id(self):
try:
@ -747,4 +849,4 @@ class GithubProvider(GitProvider):
return False
def calc_pr_statistics(self, pull_request_data: dict):
return {}
return {}