diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 5958d15f..44815f56 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -1,17 +1,16 @@ import re -from typing import Optional from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer class PRAgent: - def __init__(self, installation_id: Optional[int] = None): - self.installation_id = installation_id + def __init__(self): + pass async def handle_request(self, pr_url, request): if 'please review' in request.lower() or 'review' == request.lower().strip() or len(request) == 0: - reviewer = PRReviewer(pr_url, self.installation_id) + reviewer = PRReviewer(pr_url) await reviewer.review() else: @@ -21,5 +20,5 @@ class PRAgent: question = re.split(r'(?i)answer', request)[1].strip() else: question = request - answerer = PRQuestions(pr_url, question, self.installation_id) + answerer = PRQuestions(pr_url, question) await answerer.answer() diff --git a/pr_agent/algo/language_handler.py b/pr_agent/algo/language_handler.py index efc038ca..db99d20a 100644 --- a/pr_agent/algo/language_handler.py +++ b/pr_agent/algo/language_handler.py @@ -93,7 +93,7 @@ def sort_files_by_main_languages(languages: Dict, files: list): for ext in main_extensions: main_extensions_flat.extend(ext) - for extensions, lang in zip(main_extensions, languages_sorted_list): + for extensions, lang in zip(main_extensions, languages_sorted_list): # noqa: B905 tmp = [] for file in files_filtered: extension_str = f".{file.filename.split('.')[-1]}" diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py index 24dfdea7..550f743e 100644 --- a/pr_agent/config_loader.py +++ b/pr_agent/config_loader.py @@ -5,6 +5,7 @@ from dynaconf import Dynaconf current_dir = dirname(abspath(__file__)) settings = Dynaconf( envvar_prefix=False, + merge_enabled=True, settings_files=[join(current_dir, f) for f in [ "settings/.secrets.toml", "settings/configuration.toml", diff --git a/pr_agent/git_providers/__init__.py b/pr_agent/git_providers/__init__.py index 54e52767..15f95ba7 100644 --- a/pr_agent/git_providers/__init__.py +++ b/pr_agent/git_providers/__init__.py @@ -1,15 +1,17 @@ from pr_agent.config_loader import settings from pr_agent.git_providers.github_provider import GithubProvider +from pr_agent.git_providers.gitlab_provider import GitLabProvider _GIT_PROVIDERS = { - 'github': GithubProvider + 'github': GithubProvider, + 'gitlab': GitLabProvider, } def get_git_provider(): try: provider_id = settings.config.git_provider except AttributeError as e: - raise ValueError("github_provider is a required attribute in the configuration file") from e + raise ValueError("git_provider is a required attribute in the configuration file") from e if provider_id not in _GIT_PROVIDERS: raise ValueError(f"Unknown git provider: {provider_id}") return _GIT_PROVIDERS[provider_id] diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py new file mode 100644 index 00000000..a30df90b --- /dev/null +++ b/pr_agent/git_providers/git_provider.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + + +@dataclass +class FilePatchInfo: + base_file: str + head_file: str + patch: str + filename: str + tokens: int = -1 + + +class GitProvider(ABC): + @abstractmethod + def get_diff_files(self) -> list[FilePatchInfo]: + pass + + @abstractmethod + def publish_comment(self, pr_comment: str, is_temporary: bool = False): + pass + + @abstractmethod + def remove_initial_comment(self): + pass + + @abstractmethod + def get_languages(self): + pass + + @abstractmethod + def get_pr_branch(self): + pass + + @abstractmethod + def get_user_id(self): + pass + + @abstractmethod + def get_pr_description(self): + pass + + +def get_main_pr_language(languages, files) -> str: + """ + Get the main language of the commit. Return an empty string if cannot determine. + """ + main_language_str = "" + try: + top_language = max(languages, key=languages.get).lower() + + # validate that the specific commit uses the main language + extension_list = [] + for file in files: + extension_list.append(file.filename.rsplit('.')[-1]) + + # get the most common extension + most_common_extension = max(set(extension_list), key=extension_list.count) + + # look for a match. TBD: add more languages, do this systematically + if most_common_extension == 'py' and top_language == 'python' or \ + most_common_extension == 'js' and top_language == 'javascript' or \ + most_common_extension == 'ts' and top_language == 'typescript' or \ + most_common_extension == 'go' and top_language == 'go' or \ + most_common_extension == 'java' and top_language == 'java' or \ + most_common_extension == 'c' and top_language == 'c' or \ + most_common_extension == 'cpp' and top_language == 'c++' or \ + most_common_extension == 'cs' and top_language == 'c#' or \ + most_common_extension == 'swift' and top_language == 'swift' or \ + most_common_extension == 'php' and top_language == 'php' or \ + most_common_extension == 'rb' and top_language == 'ruby' or \ + most_common_extension == 'rs' and top_language == 'rust' or \ + most_common_extension == 'scala' and top_language == 'scala' or \ + most_common_extension == 'kt' and top_language == 'kotlin' or \ + most_common_extension == 'pl' and top_language == 'perl' or \ + most_common_extension == 'swift' and top_language == 'swift': + main_language_str = top_language + + except Exception: + pass + + return main_language_str diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index a03d0bee..ecc624c7 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -1,25 +1,18 @@ import logging -from collections import namedtuple -from dataclasses import dataclass from datetime import datetime from typing import Optional, Tuple from urllib.parse import urlparse -from github import AppAuthentication, File, Github +from github import AppAuthentication, Github from pr_agent.config_loader import settings -@dataclass -class FilePatchInfo: - base_file: str - head_file: str - patch: str - filename: str - tokens: int = -1 +from .git_provider import FilePatchInfo + class GithubProvider: - def __init__(self, pr_url: Optional[str] = None, installation_id: Optional[int] = None): - self.installation_id = installation_id + def __init__(self, pr_url: Optional[str] = None): + self.installation_id = settings.get("GITHUB.INSTALLATION_ID") self.github_client = self._get_github_client() self.repo = None self.pr_num = None @@ -32,6 +25,9 @@ class GithubProvider: self.repo, self.pr_num = self._parse_pr_url(pr_url) self.pr = self._get_pr() + def get_files(self): + return self.pr.get_files() + def get_diff_files(self) -> list[FilePatchInfo]: files = self.pr.get_files() diff_files = [] @@ -65,53 +61,15 @@ class GithubProvider: return self.pr.body def get_languages(self): - return self._get_repo().get_languages() - - def get_main_pr_language(self) -> str: - """ - Get the main language of the commit. Return an empty string if cannot determine. - """ - main_language_str = "" - try: - languages = self.get_languages() - top_language = max(languages, key=languages.get).lower() - - # validate that the specific commit uses the main language - extension_list = [] - files = self.pr.get_files() - for file in files: - extension_list.append(file.filename.rsplit('.')[-1]) - - # get the most common extension - most_common_extension = max(set(extension_list), key=extension_list.count) - - # look for a match. TBD: add more languages, do this systematically - if most_common_extension == 'py' and top_language == 'python' or \ - most_common_extension == 'js' and top_language == 'javascript' or \ - most_common_extension == 'ts' and top_language == 'typescript' or \ - most_common_extension == 'go' and top_language == 'go' or \ - most_common_extension == 'java' and top_language == 'java' or \ - most_common_extension == 'c' and top_language == 'c' or \ - most_common_extension == 'cpp' and top_language == 'c++' or \ - most_common_extension == 'cs' and top_language == 'c#' or \ - most_common_extension == 'swift' and top_language == 'swift' or \ - most_common_extension == 'php' and top_language == 'php' or \ - most_common_extension == 'rb' and top_language == 'ruby' or \ - most_common_extension == 'rs' and top_language == 'rust' or \ - most_common_extension == 'scala' and top_language == 'scala' or \ - most_common_extension == 'kt' and top_language == 'kotlin' or \ - most_common_extension == 'pl' and top_language == 'perl' or \ - most_common_extension == 'swift' and top_language == 'swift': - main_language_str = top_language - - except Exception: - pass - - return main_language_str + languages = self._get_repo().get_languages() + return languages def get_pr_branch(self): return self.pr.head.ref + def get_pr_description(self): + return self.pr.body + def get_user_id(self): if not self.github_user_id: try: diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py new file mode 100644 index 00000000..e9279a82 --- /dev/null +++ b/pr_agent/git_providers/gitlab_provider.py @@ -0,0 +1,92 @@ +import logging +from typing import Optional, Tuple +from urllib.parse import urlparse + +import gitlab + +from pr_agent.config_loader import settings + +from .git_provider import FilePatchInfo, GitProvider + + +class GitLabProvider(GitProvider): + def __init__(self, merge_request_url: Optional[str] = None): + gitlab_url = settings.get("GITLAB.URL", None) + if not gitlab_url: + raise ValueError("GitLab URL is not set in the config file") + gitlab_access_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None) + if not gitlab_access_token: + raise ValueError("GitLab personal access token is not set in the config file") + self.gl = gitlab.Gitlab( + gitlab_url, + gitlab_access_token + ) + self.id_project = None + self.id_mr = None + self.mr = None + self.temp_comments = [] + self._set_merge_request(merge_request_url) + + @property + def pr(self): + '''The GitLab terminology is merge request (MR) instead of pull request (PR)''' + return self.mr + + def _set_merge_request(self, merge_request_url: str): + self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url) + self.mr = self._get_merge_request() + + def get_diff_files(self) -> list[FilePatchInfo]: + diffs = self.mr.changes()['changes'] + diff_files = [FilePatchInfo("", "", diff['diff'], diff['new_path']) for diff in diffs] + return diff_files + + def get_files(self): + return [change['new_path'] for change in self.mr.changes()['changes']] + + def publish_comment(self, mr_comment: str, is_temporary: bool = False): + comment = self.mr.notes.create({'body': mr_comment}) + if is_temporary: + self.temp_comments.append(comment) + + def remove_initial_comment(self): + try: + for comment in self.temp_comments: + comment.delete() + except Exception as e: + logging.exception(f"Failed to remove temp comments, error: {e}") + + def get_title(self): + return self.mr.title + + def get_description(self): + return self.mr.description + + def get_languages(self): + languages = self.gl.projects.get(self.id_project).languages() + return languages + + def get_pr_branch(self): + return self.mr.source_branch + + def get_pr_description(self): + return self.mr.description + + def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[int, int]: + parsed_url = urlparse(merge_request_url) + + path_parts = parsed_url.path.strip('/').split('/') + if path_parts[-2] != 'merge_requests': + raise ValueError("The provided URL does not appear to be a GitLab merge request URL") + + try: + mr_id = int(path_parts[-1]) + except ValueError as e: + raise ValueError("Unable to convert merge request ID to integer") from e + + # Gitlab supports access by both project numeric ID as well as 'namespace/project_name' + return "/".join(path_parts[:2]), mr_id + + def _get_merge_request(self): + mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr) + return mr diff --git a/pr_agent/servers/github_app.py b/pr_agent/servers/github_app.py index 6dc5782b..52425651 100644 --- a/pr_agent/servers/github_app.py +++ b/pr_agent/servers/github_app.py @@ -35,7 +35,8 @@ async def handle_github_webhooks(request: Request, response: Response): async def handle_request(body): action = body.get("action", None) installation_id = body.get("installation", {}).get("id", None) - agent = PRAgent(installation_id) + settings.set("GITHUB.INSTALLATION_ID", installation_id) + agent = PRAgent() if action == 'created': if "comment" not in body: return {} @@ -66,8 +67,8 @@ async def root(): def start(): - if settings.get("GITHUB.DEPLOYMENT_TYPE", "user") != "app": - raise Exception("Please set deployment type to app in .secrets.toml file") + # Override the deployment type to app + settings.set("GITHUB.DEPLOYMENT_TYPE", "app") app = FastAPI() app.include_router(router) diff --git a/pr_agent/servers/github_polling.py b/pr_agent/servers/github_polling.py index 06293fd6..e8cc4223 100644 --- a/pr_agent/servers/github_polling.py +++ b/pr_agent/servers/github_polling.py @@ -76,7 +76,8 @@ async def polling_loop(): if comment['user']['login'] == user_id: continue comment_body = comment['body'] if 'body' in comment else '' - commenter_github_user = comment['user']['login'] if 'user' in comment else '' + commenter_github_user = comment['user']['login'] \ + if 'user' in comment else '' logging.info(f"Commenter: {commenter_github_user}\nComment: {comment_body}") user_tag = "@" + user_id if user_tag not in comment_body: diff --git a/pr_agent/servers/gitlab_polling.py b/pr_agent/servers/gitlab_polling.py new file mode 100644 index 00000000..c240310e --- /dev/null +++ b/pr_agent/servers/gitlab_polling.py @@ -0,0 +1,64 @@ +import asyncio +import time + +import gitlab + +from pr_agent.agent.pr_agent import PRAgent +from pr_agent.config_loader import settings + +gl = gitlab.Gitlab( + settings.get("GITLAB.URL"), + private_token=settings.get("GITLAB.PERSONAL_ACCESS_TOKEN") +) + +# Set the list of projects to monitor +projects_to_monitor = settings.get("GITLAB.PROJECTS_TO_MONITOR") +magic_word = settings.get("GITLAB.MAGIC_WORD") + +# Hold the previous seen comments +previous_comments = set() + + +def check_comments(): + print('Polling') + new_comments = {} + for project in projects_to_monitor: + project = gl.projects.get(project) + merge_requests = project.mergerequests.list(state='opened') + for mr in merge_requests: + notes = mr.notes.list(get_all=True) + for note in notes: + if note.id not in previous_comments and note.body.startswith(magic_word): + new_comments[note.id] = dict( + body=note.body[len(magic_word):], + project=project.name, + mr=mr + ) + previous_comments.add(note.id) + print(f"New comment in project {project.name}, merge request {mr.title}: {note.body}") + + return new_comments + + +def handle_new_comments(new_comments): + print('Handling new comments') + agent = PRAgent() + for _, comment in new_comments.items(): + print(f"Handling comment: {comment['body']}") + asyncio.run(agent.handle_request(comment['mr'].web_url, comment['body'])) + + +def run(): + assert settings.get('CONFIG.GIT_PROVIDER') == 'gitlab', 'This script is only for GitLab' + # Initial run to populate previous_comments + check_comments() + + # Run the check every minute + while True: + time.sleep(settings.get("GITLAB.POLLING_INTERVAL_SECONDS")) + new_comments = check_comments() + if new_comments: + handle_new_comments(new_comments) + +if __name__ == '__main__': + run() diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index 420ad00e..eb2f9b76 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -11,9 +11,6 @@ key = "" # Acquire through https://platform.openai.com org = "" # Optional, may be commented out. [github] -# The type of deployment to create. Valid values are 'app' or 'user'. -deployment_type = "user" - # ---- Set the following only for deployment type == "user" user_token = "" # A GitHub personal access token with 'repo' scope. @@ -25,3 +22,8 @@ private_key = """\ """ app_id = 123456 # The GitHub App ID, replace with your own. webhook_secret = "" # Optional, may be commented out. + +[gitlab] +# Gitlab personal access token +personal_access_token = "" + diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 50b34cff..f0a646d1 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -11,5 +11,21 @@ require_security_review=true extended_code_suggestions=false num_code_suggestions=4 +[pr_questions] -[pr_questions] \ No newline at end of file +[github] +# The type of deployment to create. Valid values are 'app' or 'user'. +deployment_type = "user" + +[gitlab] +# URL to the gitlab service +gitlab_url = "https://gitlab.com" + +# Polling (either project id or namespace/project_name) syntax can be used +projects_to_monitor = ['org_name/repo_name'] + +# Polling trigger +magic_word = "AutoReview" + +# Polling interval +polling_interval_seconds = 30 diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 51cef14f..1c26cf99 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -1,6 +1,5 @@ import copy import logging -from typing import Optional from jinja2 import Environment, StrictUndefined @@ -9,21 +8,23 @@ from pr_agent.algo.pr_processing import get_pr_diff from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import settings from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers.git_provider import get_main_pr_language class PRQuestions: - def __init__(self, pr_url: str, question_str: str, installation_id: Optional[int] = None): - self.git_provider = get_git_provider()(pr_url, installation_id) - self.main_pr_language = self.git_provider.get_main_pr_language() - self.installation_id = installation_id + def __init__(self, pr_url: str, question_str: str): + self.git_provider = get_git_provider()(pr_url) + self.main_pr_language = get_main_pr_language( + self.git_provider.get_languages(), self.git_provider.get_files() + ) self.ai_handler = AiHandler() self.question_str = question_str self.vars = { "title": self.git_provider.pr.title, "branch": self.git_provider.get_pr_branch(), - "description": self.git_provider.pr.body, - "language": self.git_provider.get_main_pr_language(), - "diff": "", # empty diff for initial calculation + "description": self.git_provider.get_description(), + "language": self.main_pr_language, + "diff": "", # empty diff for initial calculation "questions": self.question_str, } self.token_handler = TokenHandler(self.git_provider.pr, diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 97d527c1..504548b1 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -1,7 +1,6 @@ import copy import json import logging -from typing import Optional from jinja2 import Environment, StrictUndefined @@ -11,14 +10,16 @@ from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown from pr_agent.config_loader import settings from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers.git_provider import get_main_pr_language class PRReviewer: - def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False): + def __init__(self, pr_url: str, cli_mode=False): - self.git_provider = get_git_provider()(pr_url, installation_id) - self.main_language = self.git_provider.get_main_pr_language() - self.installation_id = installation_id + self.git_provider = get_git_provider()(pr_url) + self.main_language = get_main_pr_language( + self.git_provider.get_languages(), self.git_provider.get_files() + ) self.ai_handler = AiHandler() self.patches_diff = None self.prediction = None @@ -26,7 +27,7 @@ class PRReviewer: self.vars = { "title": self.git_provider.pr.title, "branch": self.git_provider.get_pr_branch(), - "description": self.git_provider.pr.body, + "description": self.git_provider.get_pr_description(), "language": self.main_language, "diff": "", # empty diff for initial calculation "require_tests": settings.pr_reviewer.require_tests_review, diff --git a/requirements.txt b/requirements.txt index e7c6d4c1..64134909 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,6 @@ openai==0.27.8 Jinja2==3.1.2 tiktoken==0.4.0 uvicorn==0.22.0 -pytest==7.4.0 \ No newline at end of file +python-gitlab==3.15.0 +pytest~=7.4.0 +aiohttp~=3.8.4 diff --git a/tests/unit/test_convert_to_markdown.py b/tests/unit/test_convert_to_markdown.py index 08a49f76..bfd9c9b5 100644 --- a/tests/unit/test_convert_to_markdown.py +++ b/tests/unit/test_convert_to_markdown.py @@ -1,6 +1,6 @@ # Generated by CodiumAI from pr_agent.algo.utils import convert_to_markdown -import pytest + """ Code Analysis diff --git a/tests/unit/test_language_handler.py b/tests/unit/test_language_handler.py index 5807b09b..875ec1a7 100644 --- a/tests/unit/test_language_handler.py +++ b/tests/unit/test_language_handler.py @@ -1,15 +1,15 @@ # Generated by CodiumAI + from pr_agent.algo.language_handler import sort_files_by_main_languages - -import pytest - """ Code Analysis Objective: -The objective of the function is to sort a list of files by their main language, putting the files that are in the main language first and the rest of the files after. It takes in a dictionary of languages and their sizes, and a list of files. +The objective of the function is to sort a list of files by their main language, putting the files that are in the main +language first and the rest of the files after. It takes in a dictionary of languages and their sizes, and a list of +files. Inputs: - languages: a dictionary containing the languages and their sizes @@ -33,6 +33,8 @@ Additional aspects: - The function uses the filter_bad_extensions function to filter out files with bad extensions - The function uses a rest_files dictionary to store the files that do not belong to any of the main extensions """ + + class TestSortFilesByMainLanguages: # Tests that files are sorted by main language, with files in main language first and the rest after def test_happy_path_sort_files_by_main_languages(self): @@ -118,4 +120,4 @@ class TestSortFilesByMainLanguages: {'language': 'C++', 'files': [files[2], files[7]]}, {'language': 'Other', 'files': []} ] - assert sort_files_by_main_languages(languages, files) == expected_output \ No newline at end of file + assert sort_files_by_main_languages(languages, files) == expected_output diff --git a/tests/unit/test_parse_code_suggestion.py b/tests/unit/test_parse_code_suggestion.py index 082fed77..87e3cac8 100644 --- a/tests/unit/test_parse_code_suggestion.py +++ b/tests/unit/test_parse_code_suggestion.py @@ -70,7 +70,7 @@ class TestParseCodeSuggestion: 'before': 'Before 1', 'after': 'After 1' } - expected_output = " **suggestion:** Suggestion 1\n **description:** Description 1\n **before:** Before 1\n **after:** After 1\n\n" + expected_output = " **suggestion:** Suggestion 1\n **description:** Description 1\n **before:** Before 1\n **after:** After 1\n\n" # noqa: E501 assert parse_code_suggestion(code_suggestions) == expected_output # Tests that function returns correct output when input dictionary has 'code example' key @@ -84,5 +84,5 @@ class TestParseCodeSuggestion: 'after': 'After 2' } } - expected_output = " **suggestion:** Suggestion 2\n **description:** Description 2\n - **code example:**\n - **before:**\n ```\n Before 2\n ```\n - **after:**\n ```\n After 2\n ```\n\n" + expected_output = " **suggestion:** Suggestion 2\n **description:** Description 2\n - **code example:**\n - **before:**\n ```\n Before 2\n ```\n - **after:**\n ```\n After 2\n ```\n\n" # noqa: E501 assert parse_code_suggestion(code_suggestions) == expected_output