diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index f29a24e9..1a84f736 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -9,7 +9,6 @@ from pr_agent.algo import MAX_TOKENS from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_large_diff from pr_agent.config_loader import get_settings from pr_agent.git_providers.git_provider import GitProvider @@ -46,7 +45,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s PATCH_EXTRA_LINES = 0 try: - diff_files = list(git_provider.get_diff_files()) + diff_files = git_provider.get_diff_files() except RateLimitExceededException as e: logging.error(f"Rate limit exceeded for git provider API. original message {e}") raise @@ -98,12 +97,7 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler, for lang in pr_languages: for file in lang['files']: original_file_content_str = file.base_file - new_file_content_str = file.head_file patch = file.patch - - # handle the case of large patch, that initially was not loaded - patch = load_large_diff(file, new_file_content_str, original_file_content_str, patch) - if not patch: continue @@ -161,7 +155,6 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo original_file_content_str = file.base_file new_file_content_str = file.head_file patch = file.patch - patch = load_large_diff(file, new_file_content_str, original_file_content_str, patch) if not patch: continue diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 6d0a9206..2f446613 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -195,38 +195,30 @@ def convert_str_to_datetime(date_str): return datetime.strptime(date_str, datetime_format) 
def load_large_diff(filename, new_file_content_str: str, original_file_content_str: str) -> str:
    """
    Generate a unified-diff patch for a file by comparing its original content with its new content.

    Callers use this when the git provider did not supply a patch for the file.

    Args:
        filename: Name of the file; used only in the diagnostic log messages.
        new_file_content_str: The new content of the file as a string.
        original_file_content_str: The original content of the file as a string.

    Returns:
        The generated patch in unified-diff format, or an empty string when the two contents are identical,
        when either content is not a string, or when the diff could not be generated.

    Raises:
        None. Failures are logged and an empty patch is returned.
    """
    # Non-string content cannot be diffed; keep the established contract of returning an empty patch.
    if not isinstance(original_file_content_str, str) or not isinstance(new_file_content_str, str):
        return ""

    # The warning is purely diagnostic, so a failure while reading the settings must not
    # discard a patch that can otherwise be computed.
    try:
        if get_settings().config.verbosity_level >= 2:
            logging.warning(f"File was modified, but no patch was found. Manually creating patch: {filename}.")
    except Exception:
        logging.debug("Could not read the verbosity level while creating a patch for %s", filename, exc_info=True)

    try:
        diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
                                    new_file_content_str.splitlines(keepends=True))
        return ''.join(diff)
    except Exception:
        # Stay best-effort (callers treat an empty patch as "nothing to show"), but leave a trace.
        logging.exception("Failed to manually create a patch for %s", filename)
        return ""
+ filename=filename, edit_type=edit_type, old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path'])) self.diff_files = diff_files diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 32550f69..a235852e 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -58,7 +58,6 @@ class PRCodeSuggestions: async def _prepare_prediction(self, model: str): logging.info('Getting PR diff...') - # we are using extended hunk with line numbers for code suggestions self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model,