This commit is contained in:
salberts
2023-07-07 16:31:28 +03:00
parent 7d49e080fc
commit 8b3ff7a632
3 changed files with 39 additions and 37 deletions

View File

@ -34,42 +34,42 @@ class GitProvider(ABC):
def get_pr_description(self): def get_pr_description(self):
pass pass
def get_main_pr_language(self, languages, files) -> str: def get_main_pr_language(languages, files) -> str:
""" """
Get the main language of the commit. Return an empty string if cannot determine. Get the main language of the commit. Return an empty string if cannot determine.
""" """
main_language_str = "" main_language_str = ""
try: try:
top_language = max(languages, key=languages.get).lower() top_language = max(languages, key=languages.get).lower()
# validate that the specific commit uses the main language # validate that the specific commit uses the main language
extension_list = [] extension_list = []
for file in files: for file in files:
extension_list.append(file.filename.rsplit('.')[-1]) extension_list.append(file.filename.rsplit('.')[-1])
# get the most common extension # get the most common extension
most_common_extension = max(set(extension_list), key=extension_list.count) most_common_extension = max(set(extension_list), key=extension_list.count)
# look for a match. TBD: add more languages, do this systematically # look for a match. TBD: add more languages, do this systematically
if most_common_extension == 'py' and top_language == 'python' or \ if most_common_extension == 'py' and top_language == 'python' or \
most_common_extension == 'js' and top_language == 'javascript' or \ most_common_extension == 'js' and top_language == 'javascript' or \
most_common_extension == 'ts' and top_language == 'typescript' or \ most_common_extension == 'ts' and top_language == 'typescript' or \
most_common_extension == 'go' and top_language == 'go' or \ most_common_extension == 'go' and top_language == 'go' or \
most_common_extension == 'java' and top_language == 'java' or \ most_common_extension == 'java' and top_language == 'java' or \
most_common_extension == 'c' and top_language == 'c' or \ most_common_extension == 'c' and top_language == 'c' or \
most_common_extension == 'cpp' and top_language == 'c++' or \ most_common_extension == 'cpp' and top_language == 'c++' or \
most_common_extension == 'cs' and top_language == 'c#' or \ most_common_extension == 'cs' and top_language == 'c#' or \
most_common_extension == 'swift' and top_language == 'swift' or \ most_common_extension == 'swift' and top_language == 'swift' or \
most_common_extension == 'php' and top_language == 'php' or \ most_common_extension == 'php' and top_language == 'php' or \
most_common_extension == 'rb' and top_language == 'ruby' or \ most_common_extension == 'rb' and top_language == 'ruby' or \
most_common_extension == 'rs' and top_language == 'rust' or \ most_common_extension == 'rs' and top_language == 'rust' or \
most_common_extension == 'scala' and top_language == 'scala' or \ most_common_extension == 'scala' and top_language == 'scala' or \
most_common_extension == 'kt' and top_language == 'kotlin' or \ most_common_extension == 'kt' and top_language == 'kotlin' or \
most_common_extension == 'pl' and top_language == 'perl' or \ most_common_extension == 'pl' and top_language == 'perl' or \
most_common_extension == 'swift' and top_language == 'swift': most_common_extension == 'swift' and top_language == 'swift':
main_language_str = top_language main_language_str = top_language
except Exception: except Exception:
pass pass
return main_language_str return main_language_str

View File

@ -9,12 +9,13 @@ from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
class PRQuestions: class PRQuestions:
def __init__(self, pr_url: str, question_str: str, installation_id: Optional[int] = None): def __init__(self, pr_url: str, question_str: str, installation_id: Optional[int] = None):
self.git_provider = get_git_provider()(pr_url, installation_id) self.git_provider = get_git_provider()(pr_url, installation_id)
self.main_pr_language = self.git_provider.get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.installation_id = installation_id self.installation_id = installation_id

View File

@ -10,14 +10,15 @@ from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown from pr_agent.algo.utils import convert_to_markdown
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider,
from pr_agent.git_provider import get_main_pr_language
class PRReviewer: class PRReviewer:
def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False): def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = self.git_provider.get_main_pr_language( self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.ai_handler = AiHandler() self.ai_handler = AiHandler()