Add decoupled and non-decoupled modes for code suggestions

mrT23
2025-03-11 16:46:53 +02:00
parent f5bd98a3b9
commit d16012a568
11 changed files with 269 additions and 100 deletions
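In summary: this commit adds a `pr_code_suggestions.decouple_hunks` setting that chooses how PR diffs are rendered for the model. When enabled (the default), each hunk is decoupled into a new-side section with line numbers and an old-side section without them; when disabled, the model receives plain unified-diff hunks, and a new helper converts them back to the decoupled form for the line-number-dependent post-processing. A rough sketch of the two formats follows; the exact layout is an assumption based on pr_agent's prompt conventions, with only the `## File: ` prefix and the `__new hunk__` marker confirmed by the diff below.

```python
# Sketch of the two patch formats (layout assumed; see
# decouple_and_convert_to_hunks_with_lines_numbers for the real output).

DECOUPLED = """\
## File: 'src/example.py'

__new hunk__
11  unchanged line
12 +line added in the PR
13  unchanged line

__old hunk__
 unchanged line
-line removed in the PR
"""

NOT_DECOUPLED = """\
## File: 'src/example.py'

@@ -11,2 +11,2 @@
 unchanged line
-line removed in the PR
+line added in the PR
"""
```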


@@ -10,14 +10,16 @@ from typing import Dict, List
 from jinja2 import Environment, StrictUndefined
 from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
+from pr_agent.algo.git_patch_processing import decouple_and_convert_to_hunks_with_lines_numbers
 from pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files,
                                          get_pr_diff, get_pr_multi_diffs,
                                          retry_with_fallback_models)
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.utils import (ModelType, load_yaml, replace_code_tags,
-                                 show_relevant_configurations)
+                                 show_relevant_configurations, get_max_tokens, clip_tokens)
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers import (AzureDevopsProvider, GithubProvider,
                                     GitLabProvider, get_git_provider,
@@ -45,14 +47,8 @@ class PRCodeSuggestions:
             get_settings().config.max_model_tokens_original = get_settings().config.max_model_tokens
             get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
-        # extended mode
-        try:
-            self.is_extended = self._get_is_extended(args or [])
-        except:
-            self.is_extended = False
+        num_code_suggestions = int(get_settings().pr_code_suggestions.num_code_suggestions_per_chunk)
 
         self.ai_handler = ai_handler()
         self.ai_handler.main_pr_language = self.main_language
         self.patches_diff = None
@@ -85,12 +81,18 @@ class PRCodeSuggestions:
             "date": datetime.now().strftime('%Y-%m-%d'),
             'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
         }
-        self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt.system
+        if get_settings().pr_code_suggestions.get("decouple_hunks", True):
+            self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt.system
+            self.pr_code_suggestions_prompt_user = get_settings().pr_code_suggestions_prompt.user
+        else:
+            self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt_not_decoupled.system
+            self.pr_code_suggestions_prompt_user = get_settings().pr_code_suggestions_prompt_not_decoupled.user
 
         self.token_handler = TokenHandler(self.git_provider.pr,
                                           self.vars,
                                           self.pr_code_suggestions_prompt_system,
-                                          get_settings().pr_code_suggestions_prompt.user)
+                                          self.pr_code_suggestions_prompt_user)
 
         self.progress = f"## Generating PR code suggestions\n\n"
         self.progress += f"""\nWork in progress ...<br>\n<img src="https://codium.ai/images/pr_agent/dual_ball_loading-crop.gif" width=48>"""
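Note that the hunk above keys the prompt pair off the same toggle, so non-decoupled runs get prompts written for plain unified diffs. A minimal standalone sketch of that selection (the helper function is illustrative and not part of pr_agent; the settings keys are copied from the hunk):

```python
from pr_agent.config_loader import get_settings

def select_code_suggestions_prompts() -> tuple[str, str]:
    # Mirrors the constructor logic above: decoupled hunks (the default) use
    # the regular prompt pair, non-decoupled hunks use the *_not_decoupled pair.
    settings = get_settings()
    if settings.pr_code_suggestions.get("decouple_hunks", True):
        prompt = settings.pr_code_suggestions_prompt
    else:
        prompt = settings.pr_code_suggestions_prompt_not_decoupled
    return prompt.system, prompt.user
```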
@@ -115,11 +117,11 @@ class PRCodeSuggestions:
             else:
                 self.git_provider.publish_comment("Preparing suggestions...", is_temporary=True)
-            # call the model to get the suggestions, and self-reflect on them
-            if not self.is_extended:
-                data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
-            else:
-                data = await retry_with_fallback_models(self._prepare_prediction_extended, model_type=ModelType.REGULAR)
+            # # call the model to get the suggestions, and self-reflect on them
+            # if not self.is_extended:
+            #     data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
+            # else:
+            data = await retry_with_fallback_models(self._prepare_prediction_extended, model_type=ModelType.REGULAR)
             if not data:
                 data = {"code_suggestions": []}
             self.data = data
@@ -623,16 +625,6 @@ class PRCodeSuggestions:
         return new_code_snippet
 
-    def _get_is_extended(self, args: list[str]) -> bool:
-        """Check if extended mode should be enabled by the `--extended` flag or automatically according to the configuration"""
-        if any(["extended" in arg for arg in args]):
-            get_logger().info("Extended mode is enabled by the `--extended` flag")
-            return True
-        if get_settings().pr_code_suggestions.auto_extended_mode:
-            # get_logger().info("Extended mode is enabled automatically based on the configuration toggle")
-            return True
-        return False
 
     def validate_one_liner_suggestion_not_repeating_code(self, suggestion):
         try:
             existing_code = suggestion.get('existing_code', '').strip()
@@ -683,11 +675,31 @@ class PRCodeSuggestions:
         return patches_diff_list
 
     async def _prepare_prediction_extended(self, model: str) -> dict:
         # get PR diff
-        self.patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
-                                                    max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
-        self.patches_diff_list_no_line_numbers = self.remove_line_numbers(self.patches_diff_list)  # decouple hunk
+        if get_settings().pr_code_suggestions.decouple_hunks:
+            self.patches_diff_list = get_pr_multi_diffs(self.git_provider,
+                                                        self.token_handler,
+                                                        model,
+                                                        max_calls=get_settings().pr_code_suggestions.max_number_of_calls,
+                                                        add_line_numbers=True)  # decouple hunk with line numbers
+            # create a copy of the patches_diff_list, without line numbers for '__new hunk__' sections
+            self.patches_diff_list_no_line_numbers = self.remove_line_numbers(self.patches_diff_list)
+        else:
+            # non-decoupled hunks
+            self.patches_diff_list_no_line_numbers = get_pr_multi_diffs(self.git_provider,
+                                                                        self.token_handler,
+                                                                        model,
+                                                                        max_calls=get_settings().pr_code_suggestions.max_number_of_calls,
+                                                                        add_line_numbers=False)
+            self.patches_diff_list = await self.convert_to_decoupled_with_line_numbers(
+                self.patches_diff_list_no_line_numbers, model)
+            if not self.patches_diff_list:
+                # fallback to decoupled hunks
+                self.patches_diff_list = get_pr_multi_diffs(self.git_provider,
+                                                            self.token_handler,
+                                                            model,
+                                                            max_calls=get_settings().pr_code_suggestions.max_number_of_calls,
+                                                            add_line_numbers=True)  # decouple hunk with line numbers
 
         if self.patches_diff_list:
             get_logger().info(f"Number of PR chunk calls: {len(self.patches_diff_list)}")
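The hunk below adds `convert_to_decoupled_with_line_numbers`, which splits a multi-file patch prompt on the `## File: ` prefix, re-renders each file's patch in the decoupled format, and clips the result to the model's full context window minus a fixed output budget. The clipping guard can be sketched on its own roughly as follows; the standalone wrapper and its `count_tokens` parameter are illustrative, while `MAX_TOKENS`, `get_max_tokens`, and `clip_tokens` are the pr_agent helpers imported above.

```python
from typing import Callable

from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.utils import clip_tokens, get_max_tokens

DELTA_OUTPUT = 2000  # tokens reserved for the model's answer, matching the hunk below

def clip_to_model_window(patch_final: str, model: str,
                         count_tokens: Callable[[str], int]) -> str:
    # Use the model's full context size here, without the usual safety
    # reductions, since the goal is only to avoid overflowing the window.
    max_tokens_full = MAX_TOKENS[model] if model in MAX_TOKENS else get_max_tokens(model)
    budget = max_tokens_full - DELTA_OUTPUT
    if count_tokens(patch_final) > budget:
        patch_final = clip_tokens(patch_final, budget)
    return patch_final
```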
@@ -728,6 +740,42 @@ class PRCodeSuggestions:
             self.data = data = None
         return data
 
+    async def convert_to_decoupled_with_line_numbers(self, patches_diff_list_no_line_numbers, model) -> List[str]:
+        with get_logger().contextualize(sub_feature='convert_to_decoupled_with_line_numbers'):
+            try:
+                patches_diff_list = []
+                for patch_prompt in patches_diff_list_no_line_numbers:
+                    file_prefix = "## File: "
+                    patches = patch_prompt.strip().split(f"\n{file_prefix}")
+                    patches_new = copy.deepcopy(patches)
+                    for i in range(len(patches_new)):
+                        if i == 0:
+                            prefix = patches_new[i].split("\n@@")[0].strip()
+                        else:
+                            prefix = file_prefix + patches_new[i].split("\n@@")[0][1:]
+                        prefix = prefix.strip()
+                        patches_new[i] = prefix + '\n\n' + decouple_and_convert_to_hunks_with_lines_numbers(
+                            patches_new[i], file=None).strip()
+                        patches_new[i] = patches_new[i].strip()
+                    patch_final = "\n\n\n".join(patches_new)
+                    if model in MAX_TOKENS:
+                        # note - take the model's actual max tokens here, without any reductions,
+                        # since we only need to avoid overflowing the context window
+                        max_tokens_full = MAX_TOKENS[model]
+                    else:
+                        max_tokens_full = get_max_tokens(model)
+                    delta_output = 2000
+                    token_count = self.token_handler.count_tokens(patch_final)
+                    if token_count > max_tokens_full - delta_output:
+                        get_logger().warning(
+                            f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}, clipping the tokens")
+                        patch_final = clip_tokens(patch_final, max_tokens_full - delta_output)
+                    patches_diff_list.append(patch_final)
+                return patches_diff_list
+            except Exception as e:
+                get_logger().exception("Error converting to decoupled with line numbers",
+                                       artifact={'patches_diff_list_no_line_numbers': patches_diff_list_no_line_numbers})
+                return []
 
     def generate_summarized_suggestions(self, data: Dict) -> str:
         try:
             pr_body = "## PR Code Suggestions ✨\n\n"