diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index d0037c95..9347d188 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -1,7 +1,10 @@ +import os import shlex +import tempfile from pr_agent.algo.utils import update_settings_from_args from pr_agent.config_loader import get_settings +from pr_agent.git_providers import get_git_provider from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_information_from_user import PRInformationFromUser @@ -31,11 +34,28 @@ class PRAgent: pass async def handle_request(self, pr_url, request) -> bool: + # First, apply repo specific settings if exists + if get_settings().config.use_repo_settings_file: + repo_settings_file = None + try: + git_provider = get_git_provider()(pr_url) + repo_settings = git_provider.get_repo_settings() + if repo_settings: + repo_settings_file = None + fd, repo_settings_file = tempfile.mkstemp(suffix='.toml') + os.write(fd, repo_settings) + get_settings().load_file(repo_settings_file) + finally: + if repo_settings_file: + os.remove(repo_settings_file) + + # Then, apply user specific settings if exists request = request.replace("'", "\\'") lexer = shlex.shlex(request, posix=True) lexer.whitespace_split = True action, *args = list(lexer) args = update_settings_from_args(args) + action = action.lstrip("/").lower() if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect: action = "review" diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index b4b31a5c..4869ca69 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -1,6 +1,4 @@ import logging -import os -import tempfile from datetime import datetime from typing import Optional, Tuple from urllib.parse import urlparse @@ -9,11 +7,11 @@ from github import AppAuthentication, Auth, Github, GithubException from retry import retry from starlette_context import context +from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from ..algo.language_handler import is_valid_file from ..algo.utils import load_large_diff from ..config_loader import get_settings from ..servers.utils import RateLimitExceeded -from .git_provider import FilePatchInfo, GitProvider, IncrementalPR class GithubProvider(GitProvider): @@ -33,17 +31,6 @@ class GithubProvider(GitProvider): if pr_url: self.set_pr(pr_url) self.last_commit_id = list(self.pr.get_commits())[-1] - if get_settings().config.use_repo_settings_file: - repo_settings = self.get_repo_settings() - if repo_settings: - repo_settings_file = None - try: - fd, repo_settings_file = tempfile.mkstemp(suffix='.toml') - os.write(fd, repo_settings) - get_settings().load_file(repo_settings_file) - finally: - if repo_settings_file: - os.remove(repo_settings_file) def is_supported(self, capability: str) -> bool: return True diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index d6a2d591..170b356e 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -1,7 +1,5 @@ import logging -import os import re -import tempfile from typing import Optional, Tuple from urllib.parse import urlparse @@ -37,17 +35,6 @@ class GitLabProvider(GitProvider): self.RE_HUNK_HEADER = re.compile( r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") self.incremental = incremental - if get_settings().config.use_repo_settings_file: - repo_settings = self.get_repo_settings() - if repo_settings: - repo_settings_file = None - try: - fd, repo_settings_file = tempfile.mkstemp(suffix='.toml') - os.write(fd, repo_settings.encode()) - get_settings().load_file(repo_settings_file) - finally: - if repo_settings_file: - os.remove(repo_settings_file) def is_supported(self, capability: str) -> bool: if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments']: @@ -268,7 +255,7 @@ class GitLabProvider(GitProvider): def get_repo_settings(self): try: - contents = self.gl.projects.get(self.id_project).files.get(file_path='.pr_agent.toml', ref=self.mr.source_branch).decode() + contents = self.gl.projects.get(self.id_project).files.get(file_path='.pr_agent.toml', ref=self.mr.source_branch) return contents except Exception: return ""