From 2880e4886090c6d42428cd33efa108cf3247de3a Mon Sep 17 00:00:00 2001 From: mrT23 Date: Sun, 19 May 2024 12:29:06 +0300 Subject: [PATCH] Refactor model selection logic for PR tools and update turbo model to gpt-4o --- pr_agent/settings/configuration.toml | 2 +- pr_agent/tools/pr_code_suggestions.py | 10 +++++----- pr_agent/tools/pr_description.py | 2 +- pr_agent/tools/pr_questions.py | 3 ++- pr_agent/tools/pr_reviewer.py | 2 +- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index fa9b1afd..448dfb57 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -1,6 +1,6 @@ [config] model="gpt-4-turbo-2024-04-09" -model_turbo="gpt-4-turbo-2024-04-09" +model_turbo="gpt-4o" fallback_models=["gpt-4-0125-preview"] git_provider="github" publish_output=true diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 6dead2be..358f2992 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -82,9 +82,9 @@ class PRCodeSuggestions: self.git_provider.publish_comment("Preparing suggestions...", is_temporary=True) if not self.is_extended: - data = await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO) + data = await retry_with_fallback_models(self._prepare_prediction) else: - data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO) + data = await retry_with_fallback_models(self._prepare_prediction_extended) if not data: data = {"code_suggestions": []} @@ -184,7 +184,8 @@ class PRCodeSuggestions: # self-reflect on suggestions if get_settings().pr_code_suggestions.self_reflect_on_suggestions: - response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"], patches_diff) + model = get_settings().config.model_turbo # use turbo model for self-reflection, since it is an easier task + response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"], patches_diff, model=model) if response_reflect: response_reflect_yaml = load_yaml(response_reflect) code_suggestions_feedback = response_reflect_yaml["code_suggestions"] @@ -546,7 +547,7 @@ class PRCodeSuggestions: get_logger().info(f"Failed to publish summarized code suggestions, error: {e}") return "" - async def self_reflect_on_suggestions(self, suggestion_list: List, patches_diff: str) -> str: + async def self_reflect_on_suggestions(self, suggestion_list: List, patches_diff: str, model: str) -> str: if not suggestion_list: return "" @@ -559,7 +560,6 @@ class PRCodeSuggestions: 'suggestion_str': suggestion_str, "diff": patches_diff, 'num_code_suggestions': len(suggestion_list)} - model = get_settings().config.model environment = Environment(undefined=StrictUndefined) system_prompt_reflect = environment.from_string(get_settings().pr_code_suggestions_reflect_prompt.system).render( variables) diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 0de93c9c..96f6512d 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -82,7 +82,7 @@ class PRDescription: if get_settings().config.publish_output: self.git_provider.publish_comment("Preparing PR description...", is_temporary=True) - await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO) # turbo model because larger context + await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO) if self.prediction: self._prepare_data() diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 78db1452..3e6355f5 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -7,6 +7,7 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.utils import ModelType from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -62,7 +63,7 @@ class PRQuestions: if img_path: get_logger().debug(f"Image path identified", artifact=img_path) - await retry_with_fallback_models(self._prepare_prediction) + await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.TURBO) pr_comment = self._prepare_pr_answer() get_logger().debug(f"PR output", artifact=pr_comment) diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 33adab39..3a127f4c 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -125,7 +125,7 @@ class PRReviewer: if get_settings().config.publish_output: self.git_provider.publish_comment("Preparing review...", is_temporary=True) - await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.TURBO) + await retry_with_fallback_models(self._prepare_prediction) if not self.prediction: self.git_provider.remove_initial_comment() return None