diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py index 23c9795b..a2b35e71 100644 --- a/pr_agent/algo/__init__.py +++ b/pr_agent/algo/__init__.py @@ -64,10 +64,13 @@ MAX_TOKENS = { 'vertex_ai/gemini-1.5-flash': 1048576, 'vertex_ai/gemini-2.0-flash': 1048576, 'vertex_ai/gemini-2.5-flash-preview-04-17': 1048576, + 'vertex_ai/gemini-2.5-flash-preview-05-20': 1048576, 'vertex_ai/gemma2': 8200, 'gemini/gemini-1.5-pro': 1048576, 'gemini/gemini-1.5-flash': 1048576, 'gemini/gemini-2.0-flash': 1048576, + 'gemini/gemini-2.5-flash-preview-04-17': 1048576, + 'gemini/gemini-2.5-flash-preview-05-20': 1048576, 'gemini/gemini-2.5-pro-preview-03-25': 1048576, 'gemini/gemini-2.5-pro-preview-05-06': 1048576, 'codechat-bison': 6144, diff --git a/tests/unittest/test_fix_json_escape_char.py b/tests/unittest/test_fix_json_escape_char.py new file mode 100644 index 00000000..afc870a2 --- /dev/null +++ b/tests/unittest/test_fix_json_escape_char.py @@ -0,0 +1,21 @@ +from pr_agent.algo.utils import fix_json_escape_char + + +class TestFixJsonEscapeChar: + def test_valid_json(self): + """Return the parsed object when the input JSON is already valid""" + text = '{"a": 1, "b": "ok"}' + expected_output = {"a": 1, "b": "ok"} + assert fix_json_escape_char(text) == expected_output + + def test_single_control_char(self): + """Replace a single ASCII control character with a space""" + text = '{"msg": "hel\x01lo"}' + expected_output = {"msg": "hel lo"} + assert fix_json_escape_char(text) == expected_output + + def test_multiple_control_chars(self): + """Replace each of several control characters with a space""" + text = '{"x": "A\x02B\x03C"}' + expected_output = {"x": "A B C"} + assert fix_json_escape_char(text) == expected_output diff --git a/tests/unittest/test_get_max_tokens.py b/tests/unittest/test_get_max_tokens.py new file mode 100644 index 00000000..fe2f6685 --- /dev/null +++ b/tests/unittest/test_get_max_tokens.py @@ -0,0 +1,67 @@ +import pytest +from pr_agent.algo.utils import get_max_tokens, MAX_TOKENS +import 
pr_agent.algo.utils as utils + +class TestGetMaxTokens: + + # Test a model that is registered in MAX_TOKENS + def test_model_max_tokens(self, monkeypatch): + fake_settings = type('', (), { + 'config': type('', (), { + 'custom_model_max_tokens': 0, + 'max_model_tokens': 0 + })() + })() + + monkeypatch.setattr(utils, "get_settings", lambda: fake_settings) + + model = "gpt-3.5-turbo" + expected = MAX_TOKENS[model] + + assert get_max_tokens(model) == expected + + # Test a model that is not registered but has a custom token limit configured + def test_model_has_custom(self, monkeypatch): + fake_settings = type('', (), { + 'config': type('', (), { + 'custom_model_max_tokens': 5000, + 'max_model_tokens': 0 # no limit + })() + })() + + monkeypatch.setattr(utils, "get_settings", lambda: fake_settings) + + model = "custom-model" + expected = 5000 + + assert get_max_tokens(model) == expected + + def test_model_not_max_tokens_and_not_has_custom(self, monkeypatch): + fake_settings = type('', (), { + 'config': type('', (), { + 'custom_model_max_tokens': 0, + 'max_model_tokens': 0 + })() + })() + + monkeypatch.setattr(utils, "get_settings", lambda: fake_settings) + + model = "custom-model" + + with pytest.raises(Exception): + get_max_tokens(model) + + def test_model_max_tokens_with__limit(self, monkeypatch): + fake_settings = type('', (), { + 'config': type('', (), { + 'custom_model_max_tokens': 0, + 'max_model_tokens': 10000 + })() + })() + + monkeypatch.setattr(utils, "get_settings", lambda: fake_settings) + + model = "gpt-3.5-turbo" # assumes the registered limit exceeds 10000 (original note: 160000) - TODO confirm against MAX_TOKENS + expected = 10000 + + assert get_max_tokens(model) == expected diff --git a/tests/unittest/test_language_handler.py b/tests/unittest/test_language_handler.py index f76a15ea..8895a286 100644 --- a/tests/unittest/test_language_handler.py +++ b/tests/unittest/test_language_handler.py @@ -79,13 +79,14 @@ class TestSortFilesByMainLanguages: files = [ type('', (object,), {'filename': 'file1.py'})(), type('', (object,), {'filename': 
'file2.java'})(), - type('', (object,), {'filename': 'file3.cpp'})() + type('', (object,), {'filename': 'file3.cpp'})(), + type('', (object,), {'filename': 'file3.test'})() ] expected_output = [ {'language': 'Python', 'files': [files[0]]}, {'language': 'Java', 'files': [files[1]]}, {'language': 'C++', 'files': [files[2]]}, - {'language': 'Other', 'files': []} + {'language': 'Other', 'files': [files[3]]} ] assert sort_files_by_main_languages(languages, files) == expected_output