Extracted to helper functions

This commit is contained in:
zmeir
2023-08-13 11:03:10 +03:00
parent edcf89a456
commit 6ca0655517

View File

@ -208,24 +208,8 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
async def retry_with_fallback_models(f: Callable):
# getting all models
model = get_settings().config.model
fallback_models = get_settings().config.fallback_models
if not isinstance(fallback_models, list):
fallback_models = [m.strip() for m in fallback_models.split(",")]
all_models = [model] + fallback_models
# getting all deployments
deployment_id = get_settings().get("openai.deployment_id", None)
fallback_deployments = get_settings().get("openai.fallback_deployments", [])
if not isinstance(fallback_deployments, list) and fallback_deployments:
fallback_deployments = [d.strip() for d in fallback_deployments.split(",")]
if fallback_deployments:
all_deployments = [deployment_id] + fallback_deployments
if len(fallback_deployments) < len(fallback_models):
raise ValueError(f"The number of fallback deployments ({len(all_deployments)}) "
f"is less than the number of fallback models ({len(all_models)})")
else:
all_deployments = [deployment_id] * len(all_models)
all_models = _get_all_models()
all_deployments = _get_all_deployments(all_models)
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
try:
@ -241,6 +225,30 @@ async def retry_with_fallback_models(f: Callable):
raise # Re-raise the last exception
def _get_all_models() -> List[str]:
model = get_settings().config.model
fallback_models = get_settings().config.fallback_models
if not isinstance(fallback_models, list):
fallback_models = [m.strip() for m in fallback_models.split(",")]
all_models = [model] + fallback_models
return all_models
def _get_all_deployments(all_models: List[str]) -> List[str]:
deployment_id = get_settings().get("openai.deployment_id", None)
fallback_deployments = get_settings().get("openai.fallback_deployments", [])
if not isinstance(fallback_deployments, list) and fallback_deployments:
fallback_deployments = [d.strip() for d in fallback_deployments.split(",")]
if fallback_deployments:
all_deployments = [deployment_id] + fallback_deployments
if len(all_deployments) < len(all_models):
raise ValueError(f"The number of deployments ({len(all_deployments)}) "
f"is less than the number of models ({len(all_models)})")
else:
all_deployments = [deployment_id] * len(all_models)
return all_deployments
def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
relevant_file: str,
relevant_line_in_file: str) -> Tuple[int, int]: