mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-06 22:00:40 +08:00
refactor
This commit is contained in:
@ -25,14 +25,51 @@ class GitProvider(ABC):
|
|||||||
def get_languages(self):
|
def get_languages(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_main_pr_language(self) -> str:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_user_id(self):
|
def get_user_id(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_pr_description():
|
def get_pr_description(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_main_pr_language(self, languages, files) -> str:
|
||||||
|
"""
|
||||||
|
Get the main language of the commit. Return an empty string if cannot determine.
|
||||||
|
"""
|
||||||
|
main_language_str = ""
|
||||||
|
try:
|
||||||
|
top_language = max(languages, key=languages.get).lower()
|
||||||
|
|
||||||
|
# validate that the specific commit uses the main language
|
||||||
|
extension_list = []
|
||||||
|
for file in files:
|
||||||
|
extension_list.append(file.filename.rsplit('.')[-1])
|
||||||
|
|
||||||
|
# get the most common extension
|
||||||
|
most_common_extension = max(set(extension_list), key=extension_list.count)
|
||||||
|
|
||||||
|
# look for a match. TBD: add more languages, do this systematically
|
||||||
|
if most_common_extension == 'py' and top_language == 'python' or \
|
||||||
|
most_common_extension == 'js' and top_language == 'javascript' or \
|
||||||
|
most_common_extension == 'ts' and top_language == 'typescript' or \
|
||||||
|
most_common_extension == 'go' and top_language == 'go' or \
|
||||||
|
most_common_extension == 'java' and top_language == 'java' or \
|
||||||
|
most_common_extension == 'c' and top_language == 'c' or \
|
||||||
|
most_common_extension == 'cpp' and top_language == 'c++' or \
|
||||||
|
most_common_extension == 'cs' and top_language == 'c#' or \
|
||||||
|
most_common_extension == 'swift' and top_language == 'swift' or \
|
||||||
|
most_common_extension == 'php' and top_language == 'php' or \
|
||||||
|
most_common_extension == 'rb' and top_language == 'ruby' or \
|
||||||
|
most_common_extension == 'rs' and top_language == 'rust' or \
|
||||||
|
most_common_extension == 'scala' and top_language == 'scala' or \
|
||||||
|
most_common_extension == 'kt' and top_language == 'kotlin' or \
|
||||||
|
most_common_extension == 'pl' and top_language == 'perl' or \
|
||||||
|
most_common_extension == 'swift' and top_language == 'swift':
|
||||||
|
main_language_str = top_language
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return main_language_str
|
@ -24,6 +24,9 @@ class GithubProvider:
|
|||||||
self.repo, self.pr_num = self._parse_pr_url(pr_url)
|
self.repo, self.pr_num = self._parse_pr_url(pr_url)
|
||||||
self.pr = self._get_pr()
|
self.pr = self._get_pr()
|
||||||
|
|
||||||
|
def get_files(self):
|
||||||
|
return self.pr.get_files()
|
||||||
|
|
||||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||||
files = self.pr.get_files()
|
files = self.pr.get_files()
|
||||||
diff_files = []
|
diff_files = []
|
||||||
@ -60,48 +63,6 @@ class GithubProvider:
|
|||||||
languages = self._get_repo().get_languages()
|
languages = self._get_repo().get_languages()
|
||||||
return languages
|
return languages
|
||||||
|
|
||||||
def get_main_pr_language(self) -> str:
|
|
||||||
"""
|
|
||||||
Get the main language of the commit. Return an empty string if cannot determine.
|
|
||||||
"""
|
|
||||||
main_language_str = ""
|
|
||||||
try:
|
|
||||||
languages = self.get_languages()
|
|
||||||
top_language = max(languages, key=languages.get).lower()
|
|
||||||
|
|
||||||
# validate that the specific commit uses the main language
|
|
||||||
extension_list = []
|
|
||||||
files = self.pr.get_files()
|
|
||||||
for file in files:
|
|
||||||
extension_list.append(file.filename.rsplit('.')[-1])
|
|
||||||
|
|
||||||
# get the most common extension
|
|
||||||
most_common_extension = max(set(extension_list), key=extension_list.count)
|
|
||||||
|
|
||||||
# look for a match. TBD: add more languages, do this systematically
|
|
||||||
if most_common_extension == 'py' and top_language == 'python' or \
|
|
||||||
most_common_extension == 'js' and top_language == 'javascript' or \
|
|
||||||
most_common_extension == 'ts' and top_language == 'typescript' or \
|
|
||||||
most_common_extension == 'go' and top_language == 'go' or \
|
|
||||||
most_common_extension == 'java' and top_language == 'java' or \
|
|
||||||
most_common_extension == 'c' and top_language == 'c' or \
|
|
||||||
most_common_extension == 'cpp' and top_language == 'c++' or \
|
|
||||||
most_common_extension == 'cs' and top_language == 'c#' or \
|
|
||||||
most_common_extension == 'swift' and top_language == 'swift' or \
|
|
||||||
most_common_extension == 'php' and top_language == 'php' or \
|
|
||||||
most_common_extension == 'rb' and top_language == 'ruby' or \
|
|
||||||
most_common_extension == 'rs' and top_language == 'rust' or \
|
|
||||||
most_common_extension == 'scala' and top_language == 'scala' or \
|
|
||||||
most_common_extension == 'kt' and top_language == 'kotlin' or \
|
|
||||||
most_common_extension == 'pl' and top_language == 'perl' or \
|
|
||||||
most_common_extension == 'swift' and top_language == 'swift':
|
|
||||||
main_language_str = top_language
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return main_language_str
|
|
||||||
|
|
||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
return self.pr.head.ref
|
return self.pr.head.ref
|
||||||
|
|
||||||
|
@ -1,30 +1,54 @@
|
|||||||
|
from urllib.parse import urlparse
|
||||||
import gitlab
|
import gitlab
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from pr_agent.config_loader import settings
|
||||||
|
|
||||||
class GitLabProvider:
|
from .git_provider import FilePatchInfo, GitProvider
|
||||||
def __init__(self, merge_request_url: Optional[str] = None, personal_access_token: Optional[str] = None):
|
|
||||||
self.gl = gitlab.Gitlab('https://your.gitlab.com', private_token=personal_access_token)
|
|
||||||
self.project = None
|
class GitLabProvider(GitProvider):
|
||||||
self.mr_iid = None
|
def __init__(self, merge_request_url: Optional[str] = None):
|
||||||
|
self.gl = gitlab.Gitlab(
|
||||||
|
settings.get("GITLAB.URL"),
|
||||||
|
private_token=settings.get("GITLAB.PERSONAL_ACCESS_TOKEN")
|
||||||
|
)
|
||||||
|
|
||||||
|
self.id_project = None
|
||||||
|
self.id_mr = None
|
||||||
self.mr = None
|
self.mr = None
|
||||||
if merge_request_url:
|
self.temp_comments = []
|
||||||
self.set_merge_request(merge_request_url)
|
|
||||||
|
self.set_merge_request(merge_request_url)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pr(self):
|
||||||
|
return self.mr
|
||||||
|
|
||||||
def set_merge_request(self, merge_request_url: str):
|
def set_merge_request(self, merge_request_url: str):
|
||||||
self.project, self.mr_iid = self._parse_merge_request_url(merge_request_url)
|
self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
|
||||||
self.mr = self._get_merge_request()
|
self.mr = self._get_merge_request()
|
||||||
|
|
||||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||||
diffs = self.mr.diffs.list()
|
diffs = self.mr.changes()['changes']
|
||||||
diff_files = []
|
diff_files = [FilePatchInfo("", "", diff['diff'], diff['new_path']) for diff in diffs]
|
||||||
for diff in diffs:
|
|
||||||
# GitLab doesn't provide base and head files. Only diffs are available.
|
|
||||||
diff_files.append(FilePatchInfo("", "", diff['diff'], diff['new_path']))
|
|
||||||
return diff_files
|
return diff_files
|
||||||
|
|
||||||
def publish_comment(self, mr_comment: str):
|
def get_files(self):
|
||||||
self.mr.notes.create({'body': mr_comment})
|
return [change['new_path'] for change in self.mr.changes()['changes']]
|
||||||
|
|
||||||
|
def publish_comment(self, mr_comment: str, is_temporary: bool = False):
|
||||||
|
comment = self.mr.notes.create({'body': mr_comment})
|
||||||
|
print(comment)
|
||||||
|
if is_temporary:
|
||||||
|
self.temp_comments.append(comment)
|
||||||
|
|
||||||
|
def remove_initial_comment(self):
|
||||||
|
try:
|
||||||
|
for comment in self.temp_comments:
|
||||||
|
comment.delete()
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(f"Failed to remove temp comments, error: {e}")
|
||||||
|
|
||||||
def get_title(self):
|
def get_title(self):
|
||||||
return self.mr.title
|
return self.mr.title
|
||||||
@ -33,25 +57,44 @@ class GitLabProvider:
|
|||||||
return self.mr.description
|
return self.mr.description
|
||||||
|
|
||||||
def get_languages(self):
|
def get_languages(self):
|
||||||
# GitLab does not have a direct equivalent to get_languages().
|
languages = self.gl.projects.get(self.id_project).languages()
|
||||||
# An alternative could be to manually parse all the repository files and determine the language from the file extensions.
|
return languages
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def get_main_pr_language(self) -> str:
|
|
||||||
# Similar issue as get_languages().
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
return self.mr.source_branch
|
return self.mr.source_branch
|
||||||
|
|
||||||
def get_notifications(self):
|
def get_pr_description(self):
|
||||||
# GitLab doesn't provide a notifications API similar to GitHub's.
|
return self.mr.description
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@staticmethod
|
def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[int, int]:
|
||||||
def _parse_merge_request_url(merge_request_url: str) -> Tuple[str, int]:
|
parsed_url = urlparse(merge_request_url)
|
||||||
# This function will depend on your GitLab setup and URL structure
|
|
||||||
raise NotImplementedError
|
path_parts = parsed_url.path.strip('/').split('/')
|
||||||
|
print(path_parts)
|
||||||
|
if path_parts[-2] != 'merge_requests':
|
||||||
|
raise ValueError("The provided URL does not appear to be a GitLab merge request URL")
|
||||||
|
|
||||||
|
try:
|
||||||
|
mr_id = int(path_parts[-1])
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError("Unable to convert merge request ID to integer") from e
|
||||||
|
|
||||||
|
# Either user or group
|
||||||
|
namespace, project_name = path_parts[:2]
|
||||||
|
|
||||||
|
# Workaround for gitlab API limitation
|
||||||
|
project_ids = [
|
||||||
|
project.get_id() for project in self.gl.projects.list(search=project_name, simple=True)
|
||||||
|
if project.path_with_namespace == f"{namespace}/{project_name}"
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(project_ids) == 0:
|
||||||
|
raise ValueError(f"Unable to find project with name {project_name}")
|
||||||
|
elif len(project_ids) > 1:
|
||||||
|
raise ValueError(f"Multiple projects found with name {namespace}/{project_name}")
|
||||||
|
|
||||||
|
return project_ids[0], mr_id
|
||||||
|
|
||||||
def _get_merge_request(self):
|
def _get_merge_request(self):
|
||||||
return self.gl.projects.get(self.project).mergerequests.get(self.mr_iid)
|
mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr)
|
||||||
|
return mr
|
||||||
|
@ -14,7 +14,9 @@ from pr_agent.git_providers import get_git_provider
|
|||||||
class PRQuestions:
|
class PRQuestions:
|
||||||
def __init__(self, pr_url: str, question_str: str, installation_id: Optional[int] = None):
|
def __init__(self, pr_url: str, question_str: str, installation_id: Optional[int] = None):
|
||||||
self.git_provider = get_git_provider()(pr_url, installation_id)
|
self.git_provider = get_git_provider()(pr_url, installation_id)
|
||||||
self.main_pr_language = self.git_provider.get_main_pr_language()
|
self.main_pr_language = self.git_provider.get_main_pr_language(
|
||||||
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
|
)
|
||||||
self.installation_id = installation_id
|
self.installation_id = installation_id
|
||||||
self.ai_handler = AiHandler()
|
self.ai_handler = AiHandler()
|
||||||
self.question_str = question_str
|
self.question_str = question_str
|
||||||
@ -22,7 +24,7 @@ class PRQuestions:
|
|||||||
"title": self.git_provider.pr.title,
|
"title": self.git_provider.pr.title,
|
||||||
"branch": self.git_provider.get_pr_branch(),
|
"branch": self.git_provider.get_pr_branch(),
|
||||||
"description": self.git_provider.pr.body,
|
"description": self.git_provider.pr.body,
|
||||||
"language": self.git_provider.get_main_pr_language(),
|
"language": self.main_pr_language,
|
||||||
"diff": "", # empty diff for initial calculation
|
"diff": "", # empty diff for initial calculation
|
||||||
"questions": self.question_str,
|
"questions": self.question_str,
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,10 @@ class PRReviewer:
|
|||||||
def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False):
|
def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False):
|
||||||
|
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_language = self.git_provider.get_main_pr_language()
|
print(dir(self.git_provider.pr))
|
||||||
|
self.main_language = self.git_provider.get_main_pr_language(
|
||||||
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
|
)
|
||||||
self.ai_handler = AiHandler()
|
self.ai_handler = AiHandler()
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
|
Reference in New Issue
Block a user