diff --git a/pr_agent/algo/language_handler.py b/pr_agent/algo/language_handler.py index 66e85025..b4c02bee 100644 --- a/pr_agent/algo/language_handler.py +++ b/pr_agent/algo/language_handler.py @@ -3,8 +3,7 @@ from typing import Dict from pr_agent.config_loader import get_settings -language_extension_map_org = get_settings().language_extension_map_org -language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()} + # Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501 bad_extensions = get_settings().bad_extensions.default @@ -29,6 +28,8 @@ def sort_files_by_main_languages(languages: Dict, files: list): # languages_sorted = sorted(languages, key=lambda x: x[1], reverse=True) # get all extensions for the languages main_extensions = [] + language_extension_map_org = get_settings().language_extension_map_org + language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()} for language in languages_sorted_list: if language.lower() in language_extension_map: main_extensions.append(language_extension_map[language.lower()]) diff --git a/pr_agent/git_providers/codecommit_provider.py b/pr_agent/git_providers/codecommit_provider.py index a4836849..399f0a94 100644 --- a/pr_agent/git_providers/codecommit_provider.py +++ b/pr_agent/git_providers/codecommit_provider.py @@ -6,9 +6,9 @@ from urllib.parse import urlparse from pr_agent.git_providers.codecommit_client import CodeCommitClient -from ..algo.language_handler import is_valid_file, language_extension_map from ..algo.utils import load_large_diff from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider +from ..config_loader import get_settings from ..log import get_logger @@ -269,6 +269,8 @@ class CodeCommitProvider(GitProvider): # where each dictionary item is a language name. # We build that language->extension dictionary here in main_extensions_flat. main_extensions_flat = {} + language_extension_map_org = get_settings().language_extension_map_org + language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()} for language, extensions in language_extension_map.items(): for ext in extensions: main_extensions_flat[ext] = language diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index d0012b5e..a341f43a 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from pr_agent.algo.language_handler import language_extension_map # enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED) from enum import Enum from typing import Optional +from pr_agent.config_loader import get_settings from pr_agent.log import get_logger @@ -176,6 +176,9 @@ def get_main_pr_language(languages, files) -> str: # get the most common extension most_common_extension = '.' + max(set(extension_list), key=extension_list.count) try: + language_extension_map_org = get_settings().language_extension_map_org + language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()} + if top_language in language_extension_map and most_common_extension in language_extension_map[top_language]: main_language_str = top_language else: