Refactor model selection logic for PR tools and update turbo model to gpt-4o

This commit is contained in:
mrT23
2024-05-19 12:29:06 +03:00
parent 9b56c83c1d
commit 2880e48860
5 changed files with 10 additions and 9 deletions

View File

@ -1,6 +1,6 @@
[config]
model="gpt-4-turbo-2024-04-09"
model_turbo="gpt-4-turbo-2024-04-09"
model_turbo="gpt-4o"
fallback_models=["gpt-4-0125-preview"]
git_provider="github"
publish_output=true

View File

@ -82,9 +82,9 @@ class PRCodeSuggestions:
self.git_provider.publish_comment("Preparing suggestions...", is_temporary=True)
if not self.is_extended:
data = await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
data = await retry_with_fallback_models(self._prepare_prediction)
else:
data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)
data = await retry_with_fallback_models(self._prepare_prediction_extended)
if not data:
data = {"code_suggestions": []}
@ -184,7 +184,8 @@ class PRCodeSuggestions:
# self-reflect on suggestions
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"], patches_diff)
model = get_settings().config.model_turbo # use turbo model for self-reflection, since it is an easier task
response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"], patches_diff, model=model)
if response_reflect:
response_reflect_yaml = load_yaml(response_reflect)
code_suggestions_feedback = response_reflect_yaml["code_suggestions"]
@ -546,7 +547,7 @@ class PRCodeSuggestions:
get_logger().info(f"Failed to publish summarized code suggestions, error: {e}")
return ""
async def self_reflect_on_suggestions(self, suggestion_list: List, patches_diff: str) -> str:
async def self_reflect_on_suggestions(self, suggestion_list: List, patches_diff: str, model: str) -> str:
if not suggestion_list:
return ""
@ -559,7 +560,6 @@ class PRCodeSuggestions:
'suggestion_str': suggestion_str,
"diff": patches_diff,
'num_code_suggestions': len(suggestion_list)}
model = get_settings().config.model
environment = Environment(undefined=StrictUndefined)
system_prompt_reflect = environment.from_string(get_settings().pr_code_suggestions_reflect_prompt.system).render(
variables)

View File

@ -82,7 +82,7 @@ class PRDescription:
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing PR description...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO) # turbo model because larger context
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
if self.prediction:
self._prepare_data()

View File

@ -7,6 +7,7 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import ModelType
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
@ -62,7 +63,7 @@ class PRQuestions:
if img_path:
get_logger().debug(f"Image path identified", artifact=img_path)
await retry_with_fallback_models(self._prepare_prediction)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.TURBO)
pr_comment = self._prepare_pr_answer()
get_logger().debug(f"PR output", artifact=pr_comment)

View File

@ -125,7 +125,7 @@ class PRReviewer:
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.TURBO)
await retry_with_fallback_models(self._prepare_prediction)
if not self.prediction:
self.git_provider.remove_initial_comment()
return None