This commit is contained in:
mrT23
2024-06-27 08:32:14 +03:00
parent a57896aa94
commit 556dc68add

View File

@ -467,46 +467,3 @@ def get_pr_multi_diffs(git_provider: GitProvider,
final_diff_list.append(final_diff)
return final_diff_list
def prune_context(token_handler, curr_component_str, component_context_str, minium_output_tokens, max_tokens=None) -> Tuple[str, str]:
try:
# Get the max tokens possible
if not max_tokens:
get_logger().error(f"Max tokens not provided, using default value")
max_tokens = get_max_tokens(get_settings().config.model_turbo)
# Check if the component + context are too long
component_tokens = token_handler.count_tokens(curr_component_str)
context_tokens = token_handler.count_tokens(component_context_str)
total_tokens = component_tokens + context_tokens + token_handler.prompt_tokens
get_logger().info(
f"Total tokens: {total_tokens}, context_tokens: {context_tokens}, component_tokens: {component_tokens}, prompt_tokens: {token_handler.prompt_tokens}, max_tokens: {max_tokens}")
# clip the context to fit the max tokens
if total_tokens > max_tokens - minium_output_tokens:
# clip the context to fit the max tokens
max_context_tokens = max_tokens - (minium_output_tokens) - component_tokens - token_handler.prompt_tokens
component_context_str = clip_tokens(component_context_str,
max_context_tokens, num_input_tokens=context_tokens)
context_tokens_old = context_tokens
context_tokens = token_handler.count_tokens(component_context_str)
total_tokens = component_tokens + context_tokens + token_handler.prompt_tokens
get_logger().info(f"Clipped context from {context_tokens_old} to {context_tokens} tokens, total tokens: {total_tokens}")
# clip the class itself to fit the max tokens, if needed
delta = 50 # extra tokens to prevent clipping the component if not necessary
if total_tokens > (max_tokens - minium_output_tokens-delta):
max_context_tokens = max_tokens - minium_output_tokens - context_tokens - token_handler.prompt_tokens # notice 'context_tokens'
curr_component_str= clip_tokens(curr_component_str,
max_context_tokens, num_input_tokens=component_tokens)
component_tokens_new = token_handler.count_tokens(curr_component_str)
total_tokens = component_tokens_new + context_tokens + token_handler.prompt_tokens
get_logger().info(f"Clipped component to fit the max tokens, from {component_tokens} to {component_tokens_new} tokens, total tokens: {total_tokens}")
except Exception as e:
component_context_str = ''
curr_component_str = ''
return curr_component_str, component_context_str