mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-06 13:50:44 +08:00
Support multiple model types for different reasoning tasks
This commit is contained in:
@ -19,7 +19,7 @@ from pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files,
|
||||
retry_with_fallback_models)
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import (ModelType, load_yaml, replace_code_tags,
|
||||
show_relevant_configurations, get_max_tokens, clip_tokens)
|
||||
show_relevant_configurations, get_max_tokens, clip_tokens, get_model)
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import (AzureDevopsProvider, GithubProvider,
|
||||
GitLabProvider, get_git_provider,
|
||||
@ -121,7 +121,7 @@ class PRCodeSuggestions:
|
||||
# if not self.is_extended:
|
||||
# data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
|
||||
# else:
|
||||
data = await retry_with_fallback_models(self._prepare_prediction_extended, model_type=ModelType.REGULAR)
|
||||
data = await retry_with_fallback_models(self.prepare_prediction_main, model_type=ModelType.REGULAR)
|
||||
if not data:
|
||||
data = {"code_suggestions": []}
|
||||
self.data = data
|
||||
@ -416,9 +416,14 @@ class PRCodeSuggestions:
|
||||
data = self._prepare_pr_code_suggestions(response)
|
||||
|
||||
# self-reflect on suggestions (mandatory, since line numbers are generated now here)
|
||||
model_reflection = get_settings().config.model
|
||||
model_reflect_with_reasoning = get_model('model_reasoning')
|
||||
if model_reflect_with_reasoning == get_settings().config.model and model != get_settings().config.model and model == \
|
||||
get_settings().config.fallback_models[0]:
|
||||
# we are using a fallback model (should not happen on regular conditions)
|
||||
get_logger().warning(f"Using the same model for self-reflection as the one used for suggestions")
|
||||
model_reflect_with_reasoning = model
|
||||
response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"],
|
||||
patches_diff, model=model_reflection)
|
||||
patches_diff, model=model_reflect_with_reasoning)
|
||||
if response_reflect:
|
||||
await self.analyze_self_reflection_response(data, response_reflect)
|
||||
else:
|
||||
@ -675,7 +680,7 @@ class PRCodeSuggestions:
|
||||
get_logger().error(f"Error removing line numbers from patches_diff_list, error: {e}")
|
||||
return patches_diff_list
|
||||
|
||||
async def _prepare_prediction_extended(self, model: str) -> dict:
|
||||
async def prepare_prediction_main(self, model: str) -> dict:
|
||||
# get PR diff
|
||||
if get_settings().pr_code_suggestions.decouple_hunks:
|
||||
self.patches_diff_list = get_pr_multi_diffs(self.git_provider,
|
||||
|
Reference in New Issue
Block a user