diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index 30f7f7cd..b5d8df23 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -34,42 +34,42 @@ class GitProvider(ABC): def get_pr_description(self): pass - def get_main_pr_language(self, languages, files) -> str: - """ - Get the main language of the commit. Return an empty string if cannot determine. - """ - main_language_str = "" - try: - top_language = max(languages, key=languages.get).lower() +def get_main_pr_language(languages, files) -> str: + """ + Get the main language of the commit. Return an empty string if cannot determine. + """ + main_language_str = "" + try: + top_language = max(languages, key=languages.get).lower() - # validate that the specific commit uses the main language - extension_list = [] - for file in files: - extension_list.append(file.filename.rsplit('.')[-1]) + # validate that the specific commit uses the main language + extension_list = [] + for file in files: + extension_list.append(file.filename.rsplit('.')[-1]) - # get the most common extension - most_common_extension = max(set(extension_list), key=extension_list.count) + # get the most common extension + most_common_extension = max(set(extension_list), key=extension_list.count) - # look for a match. TBD: add more languages, do this systematically - if most_common_extension == 'py' and top_language == 'python' or \ - most_common_extension == 'js' and top_language == 'javascript' or \ - most_common_extension == 'ts' and top_language == 'typescript' or \ - most_common_extension == 'go' and top_language == 'go' or \ - most_common_extension == 'java' and top_language == 'java' or \ - most_common_extension == 'c' and top_language == 'c' or \ - most_common_extension == 'cpp' and top_language == 'c++' or \ - most_common_extension == 'cs' and top_language == 'c#' or \ - most_common_extension == 'swift' and top_language == 'swift' or \ - most_common_extension == 'php' and top_language == 'php' or \ - most_common_extension == 'rb' and top_language == 'ruby' or \ - most_common_extension == 'rs' and top_language == 'rust' or \ - most_common_extension == 'scala' and top_language == 'scala' or \ - most_common_extension == 'kt' and top_language == 'kotlin' or \ - most_common_extension == 'pl' and top_language == 'perl' or \ - most_common_extension == 'swift' and top_language == 'swift': - main_language_str = top_language + # look for a match. TBD: add more languages, do this systematically + if most_common_extension == 'py' and top_language == 'python' or \ + most_common_extension == 'js' and top_language == 'javascript' or \ + most_common_extension == 'ts' and top_language == 'typescript' or \ + most_common_extension == 'go' and top_language == 'go' or \ + most_common_extension == 'java' and top_language == 'java' or \ + most_common_extension == 'c' and top_language == 'c' or \ + most_common_extension == 'cpp' and top_language == 'c++' or \ + most_common_extension == 'cs' and top_language == 'c#' or \ + most_common_extension == 'swift' and top_language == 'swift' or \ + most_common_extension == 'php' and top_language == 'php' or \ + most_common_extension == 'rb' and top_language == 'ruby' or \ + most_common_extension == 'rs' and top_language == 'rust' or \ + most_common_extension == 'scala' and top_language == 'scala' or \ + most_common_extension == 'kt' and top_language == 'kotlin' or \ + most_common_extension == 'pl' and top_language == 'perl' or \ + most_common_extension == 'swift' and top_language == 'swift': + main_language_str = top_language - except Exception: - pass + except Exception: + pass - return main_language_str \ No newline at end of file + return main_language_str \ No newline at end of file diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 5b8c503f..c36ce74e 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -9,12 +9,13 @@ from pr_agent.algo.pr_processing import get_pr_diff from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import settings from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers.git_provider import get_main_pr_language class PRQuestions: def __init__(self, pr_url: str, question_str: str, installation_id: Optional[int] = None): self.git_provider = get_git_provider()(pr_url, installation_id) - self.main_pr_language = self.git_provider.get_main_pr_language( + self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) self.installation_id = installation_id diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 2c0b606f..ddc69e4d 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -10,14 +10,15 @@ from pr_agent.algo.pr_processing import get_pr_diff from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown from pr_agent.config_loader import settings -from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers import get_git_provider, +from pr_agent.git_provider import get_main_pr_language class PRReviewer: def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False): self.git_provider = get_git_provider()(pr_url) - self.main_language = self.git_provider.get_main_pr_language( + self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) self.ai_handler = AiHandler()