Support repo-specific configuration file

This commit is contained in:
Ori Kotek
2023-08-01 17:44:08 +03:00
parent 696e2bd6ff
commit e12874b696
3 changed files with 22 additions and 28 deletions

View File

@ -1,7 +1,10 @@
import os
import shlex import shlex
import tempfile
from pr_agent.algo.utils import update_settings_from_args from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_information_from_user import PRInformationFromUser from pr_agent.tools.pr_information_from_user import PRInformationFromUser
@ -31,11 +34,28 @@ class PRAgent:
pass pass
async def handle_request(self, pr_url, request) -> bool: async def handle_request(self, pr_url, request) -> bool:
# First, apply repo specific settings if exists
if get_settings().config.use_repo_settings_file:
repo_settings_file = None
try:
git_provider = get_git_provider()(pr_url)
repo_settings = git_provider.get_repo_settings()
if repo_settings:
repo_settings_file = None
fd, repo_settings_file = tempfile.mkstemp(suffix='.toml')
os.write(fd, repo_settings)
get_settings().load_file(repo_settings_file)
finally:
if repo_settings_file:
os.remove(repo_settings_file)
# Then, apply user specific settings if exists
request = request.replace("'", "\\'") request = request.replace("'", "\\'")
lexer = shlex.shlex(request, posix=True) lexer = shlex.shlex(request, posix=True)
lexer.whitespace_split = True lexer.whitespace_split = True
action, *args = list(lexer) action, *args = list(lexer)
args = update_settings_from_args(args) args = update_settings_from_args(args)
action = action.lstrip("/").lower() action = action.lstrip("/").lower()
if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect: if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect:
action = "review" action = "review"

View File

@ -1,6 +1,4 @@
import logging import logging
import os
import tempfile
from datetime import datetime from datetime import datetime
from typing import Optional, Tuple from typing import Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
@ -9,11 +7,11 @@ from github import AppAuthentication, Auth, Github, GithubException
from retry import retry from retry import retry
from starlette_context import context from starlette_context import context
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff
from ..config_loader import get_settings from ..config_loader import get_settings
from ..servers.utils import RateLimitExceeded from ..servers.utils import RateLimitExceeded
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
class GithubProvider(GitProvider): class GithubProvider(GitProvider):
@ -33,17 +31,6 @@ class GithubProvider(GitProvider):
if pr_url: if pr_url:
self.set_pr(pr_url) self.set_pr(pr_url)
self.last_commit_id = list(self.pr.get_commits())[-1] self.last_commit_id = list(self.pr.get_commits())[-1]
if get_settings().config.use_repo_settings_file:
repo_settings = self.get_repo_settings()
if repo_settings:
repo_settings_file = None
try:
fd, repo_settings_file = tempfile.mkstemp(suffix='.toml')
os.write(fd, repo_settings)
get_settings().load_file(repo_settings_file)
finally:
if repo_settings_file:
os.remove(repo_settings_file)
def is_supported(self, capability: str) -> bool: def is_supported(self, capability: str) -> bool:
return True return True

View File

@ -1,7 +1,5 @@
import logging import logging
import os
import re import re
import tempfile
from typing import Optional, Tuple from typing import Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
@ -37,17 +35,6 @@ class GitLabProvider(GitProvider):
self.RE_HUNK_HEADER = re.compile( self.RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
self.incremental = incremental self.incremental = incremental
if get_settings().config.use_repo_settings_file:
repo_settings = self.get_repo_settings()
if repo_settings:
repo_settings_file = None
try:
fd, repo_settings_file = tempfile.mkstemp(suffix='.toml')
os.write(fd, repo_settings.encode())
get_settings().load_file(repo_settings_file)
finally:
if repo_settings_file:
os.remove(repo_settings_file)
def is_supported(self, capability: str) -> bool: def is_supported(self, capability: str) -> bool:
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments']: if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments']:
@ -268,7 +255,7 @@ class GitLabProvider(GitProvider):
def get_repo_settings(self): def get_repo_settings(self):
try: try:
contents = self.gl.projects.get(self.id_project).files.get(file_path='.pr_agent.toml', ref=self.mr.source_branch).decode() contents = self.gl.projects.get(self.id_project).files.get(file_path='.pr_agent.toml', ref=self.mr.source_branch)
return contents return contents
except Exception: except Exception:
return "" return ""