From c92648cbd5b3603cb473c05e81dc55da2a9eb66d Mon Sep 17 00:00:00 2001 From: mrT23 Date: Thu, 3 Aug 2023 21:38:18 +0300 Subject: [PATCH] caching --- pr_agent/git_providers/github_provider.py | 55 ++++++++++++++--------- pr_agent/git_providers/gitlab_provider.py | 41 ++++++++++++----- pr_agent/tools/pr_code_suggestions.py | 1 + 3 files changed, 65 insertions(+), 32 deletions(-) diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 4869ca69..1fccda9f 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -27,6 +27,7 @@ class GithubProvider(GitProvider): self.pr = None self.github_user_id = None self.diff_files = None + self.git_files = None self.incremental = incremental if pr_url: self.set_pr(pr_url) @@ -81,40 +82,54 @@ class GithubProvider(GitProvider): def get_files(self): if self.incremental.is_incremental and self.file_set: return self.file_set.values() - return self.pr.get_files() + if not self.git_files: + # bring files from GitHub only once + self.git_files = self.pr.get_files() + return self.git_files @retry(exceptions=RateLimitExceeded, tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3)) def get_diff_files(self) -> list[FilePatchInfo]: + """ + Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitHub, + along with their content and patch information. + + Returns: + diff_files (List[FilePatchInfo]): List of FilePatchInfo objects representing the modified, added, deleted, + or renamed files in the merge request. + """ try: + if self.diff_files: + return self.diff_files + files = self.get_files() diff_files = [] - for file in files: - if is_valid_file(file.filename): - new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) - patch = file.patch - if self.incremental.is_incremental and self.file_set: - original_file_content_str = self._get_pr_file_content(file, - self.incremental.last_seen_commit_sha) - patch = load_large_diff(file, - new_file_content_str, - original_file_content_str, - None) - self.file_set[file.filename] = patch - else: - original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) - diff_files.append( - FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename)) + for file in files: + if not is_valid_file(file.filename): + continue + + new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) # communication with GitHub + patch = file.patch + + if self.incremental.is_incremental and self.file_set: + original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha) + patch = load_large_diff(file, new_file_content_str, original_file_content_str, "") + self.file_set[file.filename] = patch + else: + original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) + + diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename)) + self.diff_files = diff_files return diff_files + except GithubException.RateLimitExceededException as e: logging.error(f"Rate limit exceeded for GitHub API. Original message: {e}") raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e def publish_description(self, pr_title: str, pr_body: str): self.pr.edit(title=pr_title, body=pr_body) - # self.pr.create_issue_comment(pr_comment) def publish_comment(self, pr_comment: str, is_temporary: bool = False): if is_temporary and not get_settings().config.publish_output_progress: @@ -132,9 +147,9 @@ class GithubProvider(GitProvider): self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)]) def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str): - self.diff_files = self.diff_files if self.diff_files else self.get_diff_files() + diff_files = self.get_diff_files() position = -1 - for file in self.diff_files: + for file in diff_files: if file.filename.strip() == relevant_file: patch = file.patch patch_lines = patch.splitlines() diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 170b356e..3fcf5dbb 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -30,6 +30,7 @@ class GitLabProvider(GitProvider): self.id_mr = None self.mr = None self.diff_files = None + self.git_files = None self.temp_comments = [] self._set_merge_request(merge_request_url) self.RE_HUNK_HEADER = re.compile( @@ -65,19 +66,25 @@ class GitLabProvider(GitProvider): return '' def get_diff_files(self) -> list[FilePatchInfo]: + """ + Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitLab, + along with their content and patch information. + + Returns: + diff_files (List[FilePatchInfo]): List of FilePatchInfo objects representing the modified, added, deleted, + or renamed files in the merge request. + """ + + if self.diff_files: + return self.diff_files + diffs = self.mr.changes()['changes'] diff_files = [] for diff in diffs: if is_valid_file(diff['new_path']): original_file_content_str = self._get_pr_file_content(diff['old_path'], self.mr.target_branch) new_file_content_str = self._get_pr_file_content(diff['new_path'], self.mr.source_branch) - edit_type = EDIT_TYPE.MODIFIED - if diff['new_file']: - edit_type = EDIT_TYPE.ADDED - elif diff['deleted_file']: - edit_type = EDIT_TYPE.DELETED - elif diff['renamed_file']: - edit_type = EDIT_TYPE.RENAMED + try: if isinstance(original_file_content_str, bytes): original_file_content_str = bytes.decode(original_file_content_str, 'utf-8') @@ -86,6 +93,15 @@ class GitLabProvider(GitProvider): except UnicodeDecodeError: logging.warning( f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}") + + edit_type = EDIT_TYPE.MODIFIED + if diff['new_file']: + edit_type = EDIT_TYPE.ADDED + elif diff['deleted_file']: + edit_type = EDIT_TYPE.DELETED + elif diff['renamed_file']: + edit_type = EDIT_TYPE.RENAMED + diff_files.append( FilePatchInfo(original_file_content_str, new_file_content_str, diff['diff'], diff['new_path'], edit_type=edit_type, @@ -94,7 +110,9 @@ class GitLabProvider(GitProvider): return diff_files def get_files(self): - return [change['new_path'] for change in self.mr.changes()['changes']] + if not self.git_files: + self.git_files = [change['new_path'] for change in self.mr.changes()['changes']] + return self.git_files def publish_description(self, pr_title: str, pr_body: str): try: @@ -110,7 +128,6 @@ class GitLabProvider(GitProvider): self.temp_comments.append(comment) def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str): - self.diff_files = self.diff_files if self.diff_files else self.get_diff_files() edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file, relevant_line_in_file) self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no, @@ -151,9 +168,9 @@ class GitLabProvider(GitProvider): relevant_lines_start = suggestion['relevant_lines_start'] relevant_lines_end = suggestion['relevant_lines_end'] - self.diff_files = self.diff_files if self.diff_files else self.get_diff_files() + diff_files = self.get_diff_files() target_file = None - for file in self.diff_files: + for file in diff_files: if file.filename == relevant_file: if file.filename == relevant_file: target_file = file @@ -180,7 +197,7 @@ class GitLabProvider(GitProvider): target_file = None edit_type = self.get_edit_type(relevant_line_in_file) - for file in self.diff_files: + for file in self.get_diff_files(): if file.filename == relevant_file: edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file, relevant_line_in_file) diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 1816088f..32550f69 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -64,6 +64,7 @@ class PRCodeSuggestions: model, add_line_numbers_to_hunks=True, disable_extra_lines=True) + logging.info('Getting AI prediction...') self.prediction = await self._get_prediction(model)