Merge pull request #1387 from KennyDizi/main

Introduce to weak model
This commit is contained in:
Tal
2024-12-11 17:36:18 +02:00
committed by GitHub
12 changed files with 26 additions and 27 deletions

View File

@ -333,7 +333,7 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
return total_tokens, patches, remaining_files_list_new, files_in_patch_list
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.WEAK):
all_models = _get_all_models(model_type)
all_deployments = _get_all_deployments(all_models)
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
@ -354,8 +354,8 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
if model_type == ModelType.TURBO:
model = get_settings().config.model_turbo
if get_settings().config.get('model_weak') and model_type == ModelType.WEAK:
model = get_settings().config.model_weak
else:
model = get_settings().config.model
fallback_models = get_settings().config.fallback_models

View File

@ -35,8 +35,7 @@ class Range(BaseModel):
class ModelType(str, Enum):
REGULAR = "regular"
TURBO = "turbo"
WEAK = "weak"
class PRReviewHeader(str, Enum):
REGULAR = "## PR Reviewer Guide"

View File

@ -99,5 +99,5 @@ def set_claude_model():
"""
model_claude = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
get_settings().set('config.model', model_claude)
get_settings().set('config.model_turbo', model_claude)
get_settings().set('config.model_weak', model_claude)
get_settings().set('config.fallback_models', [model_claude])

View File

@ -1,7 +1,7 @@
[config]
# models
model="gpt-4-turbo-2024-04-09"
model_turbo="gpt-4o-2024-11-20"
model_weak="gpt-4o-mini-2024-07-18"
model="gpt-4o-2024-11-20"
fallback_models=["gpt-4o-2024-08-06"]
# CLI
git_provider="github"

View File

@ -114,9 +114,9 @@ class PRCodeSuggestions:
# call the model to get the suggestions, and self-reflect on them
if not self.is_extended:
data = await retry_with_fallback_models(self._prepare_prediction)
data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
else:
data = await retry_with_fallback_models(self._prepare_prediction_extended)
data = await retry_with_fallback_models(self._prepare_prediction_extended, model_type=ModelType.REGULAR)
if not data:
data = {"code_suggestions": []}
self.data = data

View File

@ -99,7 +99,7 @@ class PRDescription:
# ticket extraction if exists
await extract_and_cache_pr_tickets(self.git_provider, self.vars)
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
await retry_with_fallback_models(self._prepare_prediction, ModelType.WEAK)
if self.prediction:
self._prepare_data()

View File

@ -114,7 +114,7 @@ class PRHelpMessage:
self.vars['snippets'] = docs_prompt.strip()
# run the AI model
response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
response_yaml = load_yaml(response)
response_str = response_yaml.get('response')
relevant_sections = response_yaml.get('relevant_sections')

View File

@ -79,7 +79,7 @@ class PR_LineQuestions:
line_end=line_end,
side=side)
if self.patch_with_lines:
response = await retry_with_fallback_models(self._get_prediction, model_type=ModelType.TURBO)
response = await retry_with_fallback_models(self._get_prediction, model_type=ModelType.WEAK)
get_logger().info('Preparing answer...')
if comment_id:

View File

@ -63,7 +63,7 @@ class PRQuestions:
if img_path:
get_logger().debug(f"Image path identified", artifact=img_path)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.TURBO)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
pr_comment = self._prepare_pr_answer()
get_logger().debug(f"PR output", artifact=pr_comment)

View File

@ -148,7 +148,7 @@ class PRReviewer:
if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
if not self.prediction:
self.git_provider.remove_initial_comment()
return None

View File

@ -73,7 +73,7 @@ class PRUpdateChangelog:
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.TURBO)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
new_file_content, answer = self._prepare_changelog_update()