diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index d3733b6d..30f7f7cd 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -25,14 +25,51 @@ class GitProvider(ABC): def get_languages(self): pass - def get_main_pr_language(self) -> str: - pass - def get_pr_branch(self): pass def get_user_id(self): pass - def get_pr_description(): - pass \ No newline at end of file + def get_pr_description(self): + pass + + def get_main_pr_language(self, languages, files) -> str: + """ + Get the main language of the commit. Return an empty string if cannot determine. + """ + main_language_str = "" + try: + top_language = max(languages, key=languages.get).lower() + + # validate that the specific commit uses the main language + extension_list = [] + for file in files: + extension_list.append(file.filename.rsplit('.')[-1]) + + # get the most common extension + most_common_extension = max(set(extension_list), key=extension_list.count) + + # look for a match. TBD: add more languages, do this systematically + if most_common_extension == 'py' and top_language == 'python' or \ + most_common_extension == 'js' and top_language == 'javascript' or \ + most_common_extension == 'ts' and top_language == 'typescript' or \ + most_common_extension == 'go' and top_language == 'go' or \ + most_common_extension == 'java' and top_language == 'java' or \ + most_common_extension == 'c' and top_language == 'c' or \ + most_common_extension == 'cpp' and top_language == 'c++' or \ + most_common_extension == 'cs' and top_language == 'c#' or \ + most_common_extension == 'swift' and top_language == 'swift' or \ + most_common_extension == 'php' and top_language == 'php' or \ + most_common_extension == 'rb' and top_language == 'ruby' or \ + most_common_extension == 'rs' and top_language == 'rust' or \ + most_common_extension == 'scala' and top_language == 'scala' or \ + most_common_extension == 'kt' and top_language == 'kotlin' or \ + most_common_extension == 'pl' and top_language == 'perl' or \ + most_common_extension == 'swift' and top_language == 'swift': + main_language_str = top_language + + except Exception: + pass + + return main_language_str \ No newline at end of file diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index bd5576c4..0b65c539 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -24,6 +24,9 @@ class GithubProvider: self.repo, self.pr_num = self._parse_pr_url(pr_url) self.pr = self._get_pr() + def get_files(self): + return self.pr.get_files() + def get_diff_files(self) -> list[FilePatchInfo]: files = self.pr.get_files() diff_files = [] @@ -60,48 +63,6 @@ class GithubProvider: languages = self._get_repo().get_languages() return languages - def get_main_pr_language(self) -> str: - """ - Get the main language of the commit. Return an empty string if cannot determine. - """ - main_language_str = "" - try: - languages = self.get_languages() - top_language = max(languages, key=languages.get).lower() - - # validate that the specific commit uses the main language - extension_list = [] - files = self.pr.get_files() - for file in files: - extension_list.append(file.filename.rsplit('.')[-1]) - - # get the most common extension - most_common_extension = max(set(extension_list), key=extension_list.count) - - # look for a match. TBD: add more languages, do this systematically - if most_common_extension == 'py' and top_language == 'python' or \ - most_common_extension == 'js' and top_language == 'javascript' or \ - most_common_extension == 'ts' and top_language == 'typescript' or \ - most_common_extension == 'go' and top_language == 'go' or \ - most_common_extension == 'java' and top_language == 'java' or \ - most_common_extension == 'c' and top_language == 'c' or \ - most_common_extension == 'cpp' and top_language == 'c++' or \ - most_common_extension == 'cs' and top_language == 'c#' or \ - most_common_extension == 'swift' and top_language == 'swift' or \ - most_common_extension == 'php' and top_language == 'php' or \ - most_common_extension == 'rb' and top_language == 'ruby' or \ - most_common_extension == 'rs' and top_language == 'rust' or \ - most_common_extension == 'scala' and top_language == 'scala' or \ - most_common_extension == 'kt' and top_language == 'kotlin' or \ - most_common_extension == 'pl' and top_language == 'perl' or \ - most_common_extension == 'swift' and top_language == 'swift': - main_language_str = top_language - - except Exception: - pass - - return main_language_str - def get_pr_branch(self): return self.pr.head.ref diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 54c27ad3..3c1c1e6e 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -1,30 +1,54 @@ +from urllib.parse import urlparse import gitlab from typing import Optional, Tuple +from pr_agent.config_loader import settings -class GitLabProvider: - def __init__(self, merge_request_url: Optional[str] = None, personal_access_token: Optional[str] = None): - self.gl = gitlab.Gitlab('https://your.gitlab.com', private_token=personal_access_token) - self.project = None - self.mr_iid = None +from .git_provider import FilePatchInfo, GitProvider + + +class GitLabProvider(GitProvider): + def __init__(self, merge_request_url: Optional[str] = None): + self.gl = gitlab.Gitlab( + settings.get("GITLAB.URL"), + private_token=settings.get("GITLAB.PERSONAL_ACCESS_TOKEN") + ) + + self.id_project = None + self.id_mr = None self.mr = None - if merge_request_url: - self.set_merge_request(merge_request_url) + self.temp_comments = [] + + self.set_merge_request(merge_request_url) + + @property + def pr(self): + return self.mr def set_merge_request(self, merge_request_url: str): - self.project, self.mr_iid = self._parse_merge_request_url(merge_request_url) + self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url) self.mr = self._get_merge_request() def get_diff_files(self) -> list[FilePatchInfo]: - diffs = self.mr.diffs.list() - diff_files = [] - for diff in diffs: - # GitLab doesn't provide base and head files. Only diffs are available. - diff_files.append(FilePatchInfo("", "", diff['diff'], diff['new_path'])) + diffs = self.mr.changes()['changes'] + diff_files = [FilePatchInfo("", "", diff['diff'], diff['new_path']) for diff in diffs] return diff_files - def publish_comment(self, mr_comment: str): - self.mr.notes.create({'body': mr_comment}) + def get_files(self): + return [change['new_path'] for change in self.mr.changes()['changes']] + + def publish_comment(self, mr_comment: str, is_temporary: bool = False): + comment = self.mr.notes.create({'body': mr_comment}) + print(comment) + if is_temporary: + self.temp_comments.append(comment) + + def remove_initial_comment(self): + try: + for comment in self.temp_comments: + comment.delete() + except Exception as e: + logging.exception(f"Failed to remove temp comments, error: {e}") def get_title(self): return self.mr.title @@ -33,25 +57,44 @@ class GitLabProvider: return self.mr.description def get_languages(self): - # GitLab does not have a direct equivalent to get_languages(). - # An alternative could be to manually parse all the repository files and determine the language from the file extensions. - raise NotImplementedError - - def get_main_pr_language(self) -> str: - # Similar issue as get_languages(). - raise NotImplementedError + languages = self.gl.projects.get(self.id_project).languages() + return languages def get_pr_branch(self): return self.mr.source_branch - def get_notifications(self): - # GitLab doesn't provide a notifications API similar to GitHub's. - raise NotImplementedError + def get_pr_description(self): + return self.mr.description - @staticmethod - def _parse_merge_request_url(merge_request_url: str) -> Tuple[str, int]: - # This function will depend on your GitLab setup and URL structure - raise NotImplementedError + def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[int, int]: + parsed_url = urlparse(merge_request_url) + + path_parts = parsed_url.path.strip('/').split('/') + print(path_parts) + if path_parts[-2] != 'merge_requests': + raise ValueError("The provided URL does not appear to be a GitLab merge request URL") + + try: + mr_id = int(path_parts[-1]) + except ValueError as e: + raise ValueError("Unable to convert merge request ID to integer") from e + + # Either user or group + namespace, project_name = path_parts[:2] + + # Workaround for gitlab API limitation + project_ids = [ + project.get_id() for project in self.gl.projects.list(search=project_name, simple=True) + if project.path_with_namespace == f"{namespace}/{project_name}" + ] + + if len(project_ids) == 0: + raise ValueError(f"Unable to find project with name {project_name}") + elif len(project_ids) > 1: + raise ValueError(f"Multiple projects found with name {namespace}/{project_name}") + + return project_ids[0], mr_id def _get_merge_request(self): - return self.gl.projects.get(self.project).mergerequests.get(self.mr_iid) + mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr) + return mr diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 51cef14f..4cadf2d9 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -14,7 +14,9 @@ from pr_agent.git_providers import get_git_provider class PRQuestions: def __init__(self, pr_url: str, question_str: str, installation_id: Optional[int] = None): self.git_provider = get_git_provider()(pr_url, installation_id) - self.main_pr_language = self.git_provider.get_main_pr_language() + self.main_pr_language = self.git_provider.get_main_pr_language( + self.git_provider.get_languages(), self.git_provider.get_files() + ) self.installation_id = installation_id self.ai_handler = AiHandler() self.question_str = question_str @@ -22,7 +24,7 @@ class PRQuestions: "title": self.git_provider.pr.title, "branch": self.git_provider.get_pr_branch(), "description": self.git_provider.pr.body, - "language": self.git_provider.get_main_pr_language(), + "language": self.main_pr_language, "diff": "", # empty diff for initial calculation "questions": self.question_str, } diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 214056a1..418b2b9d 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -17,7 +17,10 @@ class PRReviewer: def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False): self.git_provider = get_git_provider()(pr_url) - self.main_language = self.git_provider.get_main_pr_language() + print(dir(self.git_provider.pr)) + self.main_language = self.git_provider.get_main_pr_language( + self.git_provider.get_languages(), self.git_provider.get_files() + ) self.ai_handler = AiHandler() self.patches_diff = None self.prediction = None