refactor: move CLI argument validation to dedicated class

This commit is contained in:
mrT23
2025-02-20 17:51:16 +02:00
parent a07f6855cb
commit 2887d0a7ed
2 changed files with 44 additions and 19 deletions

View File

@ -3,6 +3,7 @@ from functools import partial
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.cli_args import CliArgs
from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.utils import apply_repo_settings
@ -60,25 +61,15 @@ class PRAgent:
else:
action, *args = request
forbidden_cli_args = ['enable_auto_approval', 'approve_pr_on_self_review', 'base_url', 'url', 'app_name', 'secret_provider',
'git_provider', 'skip_keys', 'openai.key', 'ANALYTICS_FOLDER', 'uri', 'app_id', 'webhook_secret',
'bearer_token', 'PERSONAL_ACCESS_TOKEN', 'override_deployment_type', 'private_key',
'local_cache_path', 'enable_local_cache', 'jira_base_url', 'api_base', 'api_type', 'api_version',
'skip_keys']
if args:
for arg in args:
if arg.startswith('--'):
arg_word = arg.lower()
arg_word = arg_word.replace('__', '.') # replace double underscore with dot, e.g. --openai__key -> --openai.key
for forbidden_arg in forbidden_cli_args:
forbidden_arg_word = forbidden_arg.lower()
if '.' not in forbidden_arg_word:
forbidden_arg_word = '.' + forbidden_arg_word
if forbidden_arg_word in arg_word:
get_logger().error(
f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file."
)
return False
# validate args
is_valid, arg = CliArgs.validate_user_args(args)
if not is_valid:
get_logger().error(
f"CLI argument for param '{arg}' is forbidden. Use instead a configuration file."
)
return False
# Update settings from args
args = update_settings_from_args(args)
action = action.lstrip("/").lower()

34
pr_agent/algo/cli_args.py Normal file
View File

@ -0,0 +1,34 @@
from base64 import b64decode
import hashlib
class CliArgs:
@staticmethod
def validate_user_args(args: list) -> (bool, str):
try:
if not args:
return True, ""
# decode forbidden args
_encoded_args = 'ZW5hYmxlX2F1dG9fYXBwcm92YWw=:YXBwcm92ZV9wcl9vbl9zZWxmX3Jldmlldw==:YmFzZV91cmw=:dXJs:YXBwX25hbWU=:c2VjcmV0X3Byb3ZpZGVy:Z2l0X3Byb3ZpZGVy:c2tpcF9rZXlz:b3BlbmFpLmtleQ==:QU5BTFlUSUNTX0ZPTERFUg==:dXJp:YXBwX2lk:d2ViaG9va19zZWNyZXQ=:YmVhcmVyX3Rva2Vu:UEVSU09OQUxfQUNDRVNTX1RPS0VO:b3ZlcnJpZGVfZGVwbG95bWVudF90eXBl:cHJpdmF0ZV9rZXk=:bG9jYWxfY2FjaGVfcGF0aA==:ZW5hYmxlX2xvY2FsX2NhY2hl:amlyYV9iYXNlX3VybA==:YXBpX2Jhc2U=:YXBpX3R5cGU=:YXBpX3ZlcnNpb24=:c2tpcF9rZXlz'
forbidden_cli_args = []
for e in _encoded_args.split(':'):
forbidden_cli_args.append(b64decode(e).decode())
# lowercase all forbidden args
for i, _ in enumerate(forbidden_cli_args):
forbidden_cli_args[i] = forbidden_cli_args[i].lower()
if '.' not in forbidden_cli_args[i]:
forbidden_cli_args[i] = '.' + forbidden_cli_args[i]
for arg in args:
if arg.startswith('--'):
arg_word = arg.lower()
arg_word = arg_word.replace('__', '.') # replace double underscore with dot, e.g. --openai__key -> --openai.key
for forbidden_arg_word in forbidden_cli_args:
if forbidden_arg_word in arg_word:
return False, forbidden_arg_word
return True, ""
except Exception as e:
return False, str(e)