This commit is contained in:
salberts
2023-07-07 16:31:28 +03:00
parent 7d49e080fc
commit 8b3ff7a632
3 changed files with 39 additions and 37 deletions

View File

@ -34,42 +34,42 @@ class GitProvider(ABC):
def get_pr_description(self):
pass
def get_main_pr_language(self, languages, files) -> str:
"""
Get the main language of the commit. Return an empty string if cannot determine.
"""
main_language_str = ""
try:
top_language = max(languages, key=languages.get).lower()
def get_main_pr_language(languages, files) -> str:
"""
Get the main language of the commit. Return an empty string if cannot determine.
"""
main_language_str = ""
try:
top_language = max(languages, key=languages.get).lower()
# validate that the specific commit uses the main language
extension_list = []
for file in files:
extension_list.append(file.filename.rsplit('.')[-1])
# validate that the specific commit uses the main language
extension_list = []
for file in files:
extension_list.append(file.filename.rsplit('.')[-1])
# get the most common extension
most_common_extension = max(set(extension_list), key=extension_list.count)
# get the most common extension
most_common_extension = max(set(extension_list), key=extension_list.count)
# look for a match. TBD: add more languages, do this systematically
if most_common_extension == 'py' and top_language == 'python' or \
most_common_extension == 'js' and top_language == 'javascript' or \
most_common_extension == 'ts' and top_language == 'typescript' or \
most_common_extension == 'go' and top_language == 'go' or \
most_common_extension == 'java' and top_language == 'java' or \
most_common_extension == 'c' and top_language == 'c' or \
most_common_extension == 'cpp' and top_language == 'c++' or \
most_common_extension == 'cs' and top_language == 'c#' or \
most_common_extension == 'swift' and top_language == 'swift' or \
most_common_extension == 'php' and top_language == 'php' or \
most_common_extension == 'rb' and top_language == 'ruby' or \
most_common_extension == 'rs' and top_language == 'rust' or \
most_common_extension == 'scala' and top_language == 'scala' or \
most_common_extension == 'kt' and top_language == 'kotlin' or \
most_common_extension == 'pl' and top_language == 'perl' or \
most_common_extension == 'swift' and top_language == 'swift':
main_language_str = top_language
# look for a match. TBD: add more languages, do this systematically
if most_common_extension == 'py' and top_language == 'python' or \
most_common_extension == 'js' and top_language == 'javascript' or \
most_common_extension == 'ts' and top_language == 'typescript' or \
most_common_extension == 'go' and top_language == 'go' or \
most_common_extension == 'java' and top_language == 'java' or \
most_common_extension == 'c' and top_language == 'c' or \
most_common_extension == 'cpp' and top_language == 'c++' or \
most_common_extension == 'cs' and top_language == 'c#' or \
most_common_extension == 'swift' and top_language == 'swift' or \
most_common_extension == 'php' and top_language == 'php' or \
most_common_extension == 'rb' and top_language == 'ruby' or \
most_common_extension == 'rs' and top_language == 'rust' or \
most_common_extension == 'scala' and top_language == 'scala' or \
most_common_extension == 'kt' and top_language == 'kotlin' or \
most_common_extension == 'pl' and top_language == 'perl' or \
most_common_extension == 'swift' and top_language == 'swift':
main_language_str = top_language
except Exception:
pass
except Exception:
pass
return main_language_str
return main_language_str

View File

@ -9,12 +9,13 @@ from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
class PRQuestions:
def __init__(self, pr_url: str, question_str: str, installation_id: Optional[int] = None):
self.git_provider = get_git_provider()(pr_url, installation_id)
self.main_pr_language = self.git_provider.get_main_pr_language(
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.installation_id = installation_id

View File

@ -10,14 +10,15 @@ from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown
from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers import get_git_provider,
from pr_agent.git_provider import get_main_pr_language
class PRReviewer:
def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False):
self.git_provider = get_git_provider()(pr_url)
self.main_language = self.git_provider.get_main_pr_language(
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = AiHandler()