1. Move deployment_type to configuration.toml

2. Lint
3. Inject GitHub app installation ID into GitHub provider using the settings mechanism.
This commit is contained in:
Ori Kotek
2023-07-11 16:55:09 +03:00
parent 6eacf4791d
commit b2d952cafa
16 changed files with 53 additions and 39 deletions

View File

@ -1,5 +1,4 @@
import re import re
from typing import Optional
from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer

View File

@ -93,7 +93,7 @@ def sort_files_by_main_languages(languages: Dict, files: list):
for ext in main_extensions: for ext in main_extensions:
main_extensions_flat.extend(ext) main_extensions_flat.extend(ext)
for extensions, lang in zip(main_extensions, languages_sorted_list): for extensions, lang in zip(main_extensions, languages_sorted_list): # noqa: B905
tmp = [] tmp = []
for file in files_filtered: for file in files_filtered:
extension_str = f".{file.filename.split('.')[-1]}" extension_str = f".{file.filename.split('.')[-1]}"

View File

@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from abc import ABC
from dataclasses import dataclass from dataclasses import dataclass
@ -13,27 +12,35 @@ class FilePatchInfo:
class GitProvider(ABC): class GitProvider(ABC):
@abstractmethod
def get_diff_files(self) -> list[FilePatchInfo]: def get_diff_files(self) -> list[FilePatchInfo]:
pass pass
@abstractmethod
def publish_comment(self, pr_comment: str, is_temporary: bool = False): def publish_comment(self, pr_comment: str, is_temporary: bool = False):
pass pass
@abstractmethod
def remove_initial_comment(self): def remove_initial_comment(self):
pass pass
@abstractmethod
def get_languages(self): def get_languages(self):
pass pass
@abstractmethod
def get_pr_branch(self): def get_pr_branch(self):
pass pass
@abstractmethod
def get_user_id(self): def get_user_id(self):
pass pass
@abstractmethod
def get_pr_description(self): def get_pr_description(self):
pass pass
def get_main_pr_language(languages, files) -> str: def get_main_pr_language(languages, files) -> str:
""" """
Get the main language of the commit. Return an empty string if cannot determine. Get the main language of the commit. Return an empty string if cannot determine.

View File

@ -6,6 +6,7 @@ from urllib.parse import urlparse
from github import AppAuthentication, Github from github import AppAuthentication, Github
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from .git_provider import FilePatchInfo from .git_provider import FilePatchInfo

View File

@ -1,6 +1,8 @@
from urllib.parse import urlparse import logging
import gitlab
from typing import Optional, Tuple from typing import Optional, Tuple
from urllib.parse import urlparse
import gitlab
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
@ -9,24 +11,28 @@ from .git_provider import FilePatchInfo, GitProvider
class GitLabProvider(GitProvider): class GitLabProvider(GitProvider):
def __init__(self, merge_request_url: Optional[str] = None): def __init__(self, merge_request_url: Optional[str] = None):
gitlab_url = settings.get("GITLAB.URL", None)
if not gitlab_url:
raise ValueError("GitLab URL is not set in the config file")
gitlab_access_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_access_token:
raise ValueError("GitLab personal access token is not set in the config file")
self.gl = gitlab.Gitlab( self.gl = gitlab.Gitlab(
settings.get("GITLAB.URL"), gitlab_url,
private_token=settings.get("GITLAB.PERSONAL_ACCESS_TOKEN") gitlab_access_token
) )
self.id_project = None self.id_project = None
self.id_mr = None self.id_mr = None
self.mr = None self.mr = None
self.temp_comments = [] self.temp_comments = []
self._set_merge_request(merge_request_url)
self.set_merge_request(merge_request_url)
@property @property
def pr(self): def pr(self):
'''The GitLab terminology is merge request (MR) instead of pull request (PR)''' '''The GitLab terminology is merge request (MR) instead of pull request (PR)'''
return self.mr return self.mr
def set_merge_request(self, merge_request_url: str): def _set_merge_request(self, merge_request_url: str):
self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url) self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
self.mr = self._get_merge_request() self.mr = self._get_merge_request()

View File

@ -35,7 +35,8 @@ async def handle_github_webhooks(request: Request, response: Response):
async def handle_request(body): async def handle_request(body):
action = body.get("action", None) action = body.get("action", None)
installation_id = body.get("installation", {}).get("id", None) installation_id = body.get("installation", {}).get("id", None)
agent = PRAgent(installation_id) settings.set("GITHUB.INSTALLATION_ID", installation_id)
agent = PRAgent()
if action == 'created': if action == 'created':
if "comment" not in body: if "comment" not in body:
return {} return {}
@ -66,8 +67,8 @@ async def root():
def start(): def start():
if settings.get("GITHUB.DEPLOYMENT_TYPE", "user") != "app": # Override the deployment type to app
raise Exception("Please set deployment type to app in .secrets.toml file") settings.set("GITHUB.DEPLOYMENT_TYPE", "app")
app = FastAPI() app = FastAPI()
app.include_router(router) app.include_router(router)

View File

@ -76,7 +76,8 @@ async def polling_loop():
if comment['user']['login'] == user_id: if comment['user']['login'] == user_id:
continue continue
comment_body = comment['body'] if 'body' in comment else '' comment_body = comment['body'] if 'body' in comment else ''
commenter_github_user = comment['user']['login'] if 'user' in comment else '' commenter_github_user = comment['user']['login'] \
if 'user' in comment else ''
logging.info(f"Commenter: {commenter_github_user}\nComment: {comment_body}") logging.info(f"Commenter: {commenter_github_user}\nComment: {comment_body}")
user_tag = "@" + user_id user_tag = "@" + user_id
if user_tag not in comment_body: if user_tag not in comment_body:

View File

@ -1,12 +1,11 @@
import asyncio import asyncio
import time import time
from urllib.parse import urlparse
import gitlab import gitlab
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
gl = gitlab.Gitlab( gl = gitlab.Gitlab(
settings.get("GITLAB.URL"), settings.get("GITLAB.URL"),
private_token=settings.get("GITLAB.PERSONAL_ACCESS_TOKEN") private_token=settings.get("GITLAB.PERSONAL_ACCESS_TOKEN")

View File

@ -11,9 +11,6 @@ key = "<API_KEY>" # Acquire through https://platform.openai.com
org = "<ORGANIZATION>" # Optional, may be commented out. org = "<ORGANIZATION>" # Optional, may be commented out.
[github] [github]
# The type of deployment to create. Valid values are 'app' or 'user'.
deployment_type = "user"
# ---- Set the following only for deployment type == "user" # ---- Set the following only for deployment type == "user"
user_token = "<TOKEN>" # A GitHub personal access token with 'repo' scope. user_token = "<TOKEN>" # A GitHub personal access token with 'repo' scope.
@ -30,5 +27,3 @@ webhook_secret = "<WEBHOOK SECRET>" # Optional, may be commented out.
# Gitlab personal access token # Gitlab personal access token
personal_access_token = "" personal_access_token = ""
# URL to the gitlab service
gitlab_url = "https://gitlab.com"

View File

@ -11,18 +11,21 @@ require_security_review=true
extended_code_suggestions=false extended_code_suggestions=false
num_code_suggestions=4 num_code_suggestions=4
[pr_questions] [pr_questions]
[github]
# The type of deployment to create. Valid values are 'app' or 'user'.
deployment_type = "user"
[gitlab] [gitlab]
# URL to the gitlab service # URL to the gitlab service
gitlab_url = "https://gitlab.com" gitlab_url = "https://gitlab.com"
# Polling (either proheheject id or namespace/project_name) syntax can be used # Polling (either project id or namespace/project_name) syntax can be used
projects_to_monitor = ['nuclai/algo', 'nuclai/pr-agent-test'] projects_to_monitor = ['org_name/repo_name']
# Polling trigger # Polling trigger
magic_word = "AutoReview" magic_word = "AutoReview"
# Polling interval # Polling interval
polling_interval_seconds = 300 polling_interval_seconds = 30

View File

@ -1,6 +1,5 @@
import copy import copy
import logging import logging
from typing import Optional
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined

View File

@ -1,7 +1,6 @@
import copy import copy
import json import json
import logging import logging
from typing import Optional
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined

View File

@ -7,3 +7,5 @@ Jinja2==3.1.2
tiktoken==0.4.0 tiktoken==0.4.0
uvicorn==0.22.0 uvicorn==0.22.0
python-gitlab==3.15.0 python-gitlab==3.15.0
pytest~=7.4.0
aiohttp~=3.8.4

View File

@ -1,6 +1,6 @@
# Generated by CodiumAI # Generated by CodiumAI
from pr_agent.algo.utils import convert_to_markdown from pr_agent.algo.utils import convert_to_markdown
import pytest
""" """
Code Analysis Code Analysis

View File

@ -1,15 +1,15 @@
# Generated by CodiumAI # Generated by CodiumAI
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
import pytest
""" """
Code Analysis Code Analysis
Objective: Objective:
The objective of the function is to sort a list of files by their main language, putting the files that are in the main language first and the rest of the files after. It takes in a dictionary of languages and their sizes, and a list of files. The objective of the function is to sort a list of files by their main language, putting the files that are in the main
language first and the rest of the files after. It takes in a dictionary of languages and their sizes, and a list of
files.
Inputs: Inputs:
- languages: a dictionary containing the languages and their sizes - languages: a dictionary containing the languages and their sizes
@ -33,6 +33,8 @@ Additional aspects:
- The function uses the filter_bad_extensions function to filter out files with bad extensions - The function uses the filter_bad_extensions function to filter out files with bad extensions
- The function uses a rest_files dictionary to store the files that do not belong to any of the main extensions - The function uses a rest_files dictionary to store the files that do not belong to any of the main extensions
""" """
class TestSortFilesByMainLanguages: class TestSortFilesByMainLanguages:
# Tests that files are sorted by main language, with files in main language first and the rest after # Tests that files are sorted by main language, with files in main language first and the rest after
def test_happy_path_sort_files_by_main_languages(self): def test_happy_path_sort_files_by_main_languages(self):

View File

@ -70,7 +70,7 @@ class TestParseCodeSuggestion:
'before': 'Before 1', 'before': 'Before 1',
'after': 'After 1' 'after': 'After 1'
} }
expected_output = " **suggestion:** Suggestion 1\n **description:** Description 1\n **before:** Before 1\n **after:** After 1\n\n" expected_output = " **suggestion:** Suggestion 1\n **description:** Description 1\n **before:** Before 1\n **after:** After 1\n\n" # noqa: E501
assert parse_code_suggestion(code_suggestions) == expected_output assert parse_code_suggestion(code_suggestions) == expected_output
# Tests that function returns correct output when input dictionary has 'code example' key # Tests that function returns correct output when input dictionary has 'code example' key
@ -84,5 +84,5 @@ class TestParseCodeSuggestion:
'after': 'After 2' 'after': 'After 2'
} }
} }
expected_output = " **suggestion:** Suggestion 2\n **description:** Description 2\n - **code example:**\n - **before:**\n ```\n Before 2\n ```\n - **after:**\n ```\n After 2\n ```\n\n" expected_output = " **suggestion:** Suggestion 2\n **description:** Description 2\n - **code example:**\n - **before:**\n ```\n Before 2\n ```\n - **after:**\n ```\n After 2\n ```\n\n" # noqa: E501
assert parse_code_suggestion(code_suggestions) == expected_output assert parse_code_suggestion(code_suggestions) == expected_output