Mirror of https://github.com/qodo-ai/pr-agent.git (synced 2025-07-05 05:10:38 +08:00)
Add decoupled and non-decoupled modes for code suggestions
@@ -10,14 +10,16 @@ from typing import Dict, List
 from jinja2 import Environment, StrictUndefined
 
+from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
+from pr_agent.algo.git_patch_processing import decouple_and_convert_to_hunks_with_lines_numbers
 from pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files,
                                          get_pr_diff, get_pr_multi_diffs,
                                          retry_with_fallback_models)
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.utils import (ModelType, load_yaml, replace_code_tags,
-                                 show_relevant_configurations)
+                                 show_relevant_configurations, get_max_tokens, clip_tokens)
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers import (AzureDevopsProvider, GithubProvider,
                                     GitLabProvider, get_git_provider,
@@ -45,14 +47,8 @@ class PRCodeSuggestions:
         get_settings().config.max_model_tokens_original = get_settings().config.max_model_tokens
         get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
 
-        # extended mode
-        try:
-            self.is_extended = self._get_is_extended(args or [])
-        except:
-            self.is_extended = False
         num_code_suggestions = int(get_settings().pr_code_suggestions.num_code_suggestions_per_chunk)
 
 
         self.ai_handler = ai_handler()
         self.ai_handler.main_pr_language = self.main_language
         self.patches_diff = None
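The two context lines at the top of this hunk stash the configured max_model_tokens in max_model_tokens_original before capping it with MAX_CONTEXT_TOKENS_IMPROVE. A minimal sketch of that save-then-cap pattern follows; the Config class, the numeric values, and the final restore step are illustrative assumptions, not code from this diff:

    # Illustrative stand-in for get_settings().config; values are made up.
    class Config:
        max_model_tokens = 128000
        max_model_tokens_original = None

    MAX_CONTEXT_TOKENS_IMPROVE = 64000  # assumed cap for the improve flow

    config = Config()
    config.max_model_tokens_original = config.max_model_tokens   # remember the real limit
    config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE         # cap it for this run
    # ... generate code suggestions under the reduced limit ...
    config.max_model_tokens = config.max_model_tokens_original   # restore afterwards (assumed)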
@@ -85,12 +81,18 @@ class PRCodeSuggestions:
             "date": datetime.now().strftime('%Y-%m-%d'),
             'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
         }
-        self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt.system
+
+        if get_settings().pr_code_suggestions.get("decouple_hunks", True):
+            self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt.system
+            self.pr_code_suggestions_prompt_user = get_settings().pr_code_suggestions_prompt.user
+        else:
+            self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt_not_decoupled.system
+            self.pr_code_suggestions_prompt_user = get_settings().pr_code_suggestions_prompt_not_decoupled.user
 
         self.token_handler = TokenHandler(self.git_provider.pr,
                                           self.vars,
                                           self.pr_code_suggestions_prompt_system,
-                                          get_settings().pr_code_suggestions_prompt.user)
+                                          self.pr_code_suggestions_prompt_user)
 
         self.progress = f"## Generating PR code suggestions\n\n"
         self.progress += f"""\nWork in progress ...<br>\n<img src="https://codium.ai/images/pr_agent/dual_ball_loading-crop.gif" width=48>"""
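The hunk above is the heart of the commit's two modes: a single decouple_hunks toggle (defaulting to True) picks either the decoupled prompt pair or the not_decoupled pair, and the chosen user prompt is then handed to the TokenHandler. Below is a minimal, self-contained sketch of that selection pattern; the plain dicts and placeholder prompt strings stand in for pr-agent's get_settings() object and its configuration-defined prompts:

    # Sketch only: plain dicts stand in for the real configuration object.
    from typing import Tuple

    settings = {
        "pr_code_suggestions": {"decouple_hunks": True},
        "pr_code_suggestions_prompt": {
            "system": "system prompt for decoupled hunks",
            "user": "user prompt for decoupled hunks",
        },
        "pr_code_suggestions_prompt_not_decoupled": {
            "system": "system prompt for non-decoupled hunks",
            "user": "user prompt for non-decoupled hunks",
        },
    }

    def select_prompts(cfg: dict) -> Tuple[str, str]:
        """Return the (system, user) prompt pair; default to decoupled hunks when the toggle is absent."""
        if cfg["pr_code_suggestions"].get("decouple_hunks", True):
            prompts = cfg["pr_code_suggestions_prompt"]
        else:
            prompts = cfg["pr_code_suggestions_prompt_not_decoupled"]
        return prompts["system"], prompts["user"]

    system_prompt, user_prompt = select_prompts(settings)  # both later feed the TokenHandler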
@@ -115,11 +117,11 @@ class PRCodeSuggestions:
             else:
                 self.git_provider.publish_comment("Preparing suggestions...", is_temporary=True)
 
-            # call the model to get the suggestions, and self-reflect on them
-            if not self.is_extended:
-                data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
-            else:
-                data = await retry_with_fallback_models(self._prepare_prediction_extended, model_type=ModelType.REGULAR)
+            # # call the model to get the suggestions, and self-reflect on them
+            # if not self.is_extended:
+            #     data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
+            # else:
+            data = await retry_with_fallback_models(self._prepare_prediction_extended, model_type=ModelType.REGULAR)
             if not data:
                 data = {"code_suggestions": []}
             self.data = data
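With the non-extended branch commented out, run() now always awaits retry_with_fallback_models around _prepare_prediction_extended. The diff does not show the helper's internals, so the sketch below is only a hypothetical, simplified stand-in for the retry-over-fallback-models idea; the model names and the fake_prepare coroutine are invented for illustration:

    # Hypothetical sketch; pr-agent's real helper lives in pr_agent.algo.pr_processing.
    import asyncio
    from typing import Awaitable, Callable, List, Optional

    async def retry_over_models(prepare: Callable[[str], Awaitable[Optional[dict]]],
                                models: List[str]) -> Optional[dict]:
        """Try each model in order and return the first non-empty result."""
        for model in models:
            try:
                data = await prepare(model)
                if data:
                    return data
            except Exception as exc:
                print(f"model {model} failed ({exc}); falling back to the next one")
        return None

    async def fake_prepare(model: str) -> dict:
        return {"code_suggestions": [], "model": model}  # placeholder payload

    result = asyncio.run(retry_over_models(fake_prepare, ["primary-model", "fallback-model"]))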
@@ -623,16 +625,6 @@ class PRCodeSuggestions:
 
         return new_code_snippet
 
-    def _get_is_extended(self, args: list[str]) -> bool:
-        """Check if extended mode should be enabled by the `--extended` flag or automatically according to the configuration"""
-        if any(["extended" in arg for arg in args]):
-            get_logger().info("Extended mode is enabled by the `--extended` flag")
-            return True
-        if get_settings().pr_code_suggestions.auto_extended_mode:
-            # get_logger().info("Extended mode is enabled automatically based on the configuration toggle")
-            return True
-        return False
-
     def validate_one_liner_suggestion_not_repeating_code(self, suggestion):
         try:
             existing_code = suggestion.get('existing_code', '').strip()
@@ -683,11 +675,31 @@ class PRCodeSuggestions:
         return patches_diff_list
 
     async def _prepare_prediction_extended(self, model: str) -> dict:
-        self.patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
-                                                    max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
-
-        # create a copy of the patches_diff_list, without line numbers for '__new hunk__' sections
-        self.patches_diff_list_no_line_numbers = self.remove_line_numbers(self.patches_diff_list)
+        # get PR diff
+        if get_settings().pr_code_suggestions.decouple_hunks:
+            self.patches_diff_list = get_pr_multi_diffs(self.git_provider,
+                                                        self.token_handler,
+                                                        model,
+                                                        max_calls=get_settings().pr_code_suggestions.max_number_of_calls,
+                                                        add_line_numbers=True)  # decouple hunk with line numbers
+            self.patches_diff_list_no_line_numbers = self.remove_line_numbers(self.patches_diff_list)  # decouple hunk
+        else:
+            # non-decoupled hunks
+            self.patches_diff_list_no_line_numbers = get_pr_multi_diffs(self.git_provider,
+                                                                        self.token_handler,
+                                                                        model,
+                                                                        max_calls=get_settings().pr_code_suggestions.max_number_of_calls,
+                                                                        add_line_numbers=False)
+            self.patches_diff_list = await self.convert_to_decoupled_with_line_numbers(
+                self.patches_diff_list_no_line_numbers, model)
+            if not self.patches_diff_list:
+                # fallback to decoupled hunks
+                self.patches_diff_list = get_pr_multi_diffs(self.git_provider,
+                                                            self.token_handler,
+                                                            model,
+                                                            max_calls=get_settings().pr_code_suggestions.max_number_of_calls,
+                                                            add_line_numbers=True)  # decouple hunk with line numbers
 
         if self.patches_diff_list:
             get_logger().info(f"Number of PR chunk calls: {len(self.patches_diff_list)}")
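The rewritten _prepare_prediction_extended has two paths: with decouple_hunks enabled it asks get_pr_multi_diffs for decoupled, line-numbered hunks and keeps a copy with the numbers stripped; with it disabled it fetches non-decoupled hunks, converts them to the decoupled form, and falls back to the decoupled call if that conversion returns nothing. The sketch below condenses that control flow; get_diffs and convert_to_decoupled are hypothetical stand-ins for get_pr_multi_diffs and convert_to_decoupled_with_line_numbers:

    # Condensed, synchronous sketch of the branch-and-fallback flow above.
    from typing import List

    def get_diffs(add_line_numbers: bool) -> List[str]:
        tag = "numbered" if add_line_numbers else "plain"
        return [f"## File: 'src/example.py'\n@@ ... @@ ({tag} hunk)"]  # one prompt-sized chunk

    def convert_to_decoupled(chunks: List[str]) -> List[str]:
        return []  # simulate a failed conversion so the fallback is exercised

    decouple_hunks = False  # mirrors get_settings().pr_code_suggestions.decouple_hunks
    if decouple_hunks:
        patches_with_lines = get_diffs(add_line_numbers=True)
        patches_plain = [p.replace("numbered", "plain") for p in patches_with_lines]
    else:
        patches_plain = get_diffs(add_line_numbers=False)
        patches_with_lines = convert_to_decoupled(patches_plain)
        if not patches_with_lines:
            # conversion failed -> fall back to decoupled hunks with line numbers
            patches_with_lines = get_diffs(add_line_numbers=True)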
@@ -728,6 +740,42 @@ class PRCodeSuggestions:
             self.data = data = None
         return data
 
+    async def convert_to_decoupled_with_line_numbers(self, patches_diff_list_no_line_numbers, model) -> List[str]:
+        with get_logger().contextualize(sub_feature='convert_to_decoupled_with_line_numbers'):
+            try:
+                patches_diff_list = []
+                for patch_prompt in patches_diff_list_no_line_numbers:
+                    file_prefix = "## File: "
+                    patches = patch_prompt.strip().split(f"\n{file_prefix}")
+                    patches_new = copy.deepcopy(patches)
+                    for i in range(len(patches_new)):
+                        if i == 0:
+                            prefix = patches_new[i].split("\n@@")[0].strip()
+                        else:
+                            prefix = file_prefix + patches_new[i].split("\n@@")[0][1:]
+                        prefix = prefix.strip()
+                        patches_new[i] = prefix + '\n\n' + decouple_and_convert_to_hunks_with_lines_numbers(patches_new[i],
+                                                                                                            file=None).strip()
+                        patches_new[i] = patches_new[i].strip()
+                    patch_final = "\n\n\n".join(patches_new)
+                    if model in MAX_TOKENS:
+                        max_tokens_full = MAX_TOKENS[model]  # note - here we take the actual max tokens, without any reductions
+                    else:
+                        max_tokens_full = get_max_tokens(model)
+                    delta_output = 2000
+                    token_count = self.token_handler.count_tokens(patch_final)
+                    if token_count > max_tokens_full - delta_output:
+                        get_logger().warning(
+                            f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. clipping the tokens")
+                        patch_final = clip_tokens(patch_final, max_tokens_full - delta_output)
+                    patches_diff_list.append(patch_final)
+                return patches_diff_list
+            except Exception as e:
+                get_logger().exception("Error converting to decoupled with line numbers",
+                                       artifact={'patches_diff_list_no_line_numbers': patches_diff_list_no_line_numbers})
+                return []
+
     def generate_summarized_suggestions(self, data: Dict) -> str:
         try:
             pr_body = "## PR Code Suggestions ✨\n\n"
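The new convert_to_decoupled_with_line_numbers method splits each prompt on the "## File: " prefix, runs each file's patch through decouple_and_convert_to_hunks_with_lines_numbers, and then clips the rejoined prompt so that delta_output (2000) tokens stay available for the model's reply. A minimal sketch of that final token-budget guard is shown below; the whitespace tokenizer and truncation helper are crude stand-ins for TokenHandler.count_tokens() and clip_tokens():

    # Sketch only: a whitespace "tokenizer" stands in for the real token counting.
    def count_tokens(text: str) -> int:
        return len(text.split())

    def clip(text: str, max_tokens: int) -> str:
        return " ".join(text.split()[:max_tokens])

    max_tokens_full = 8192   # assumed full context window for the model
    delta_output = 2000      # tokens reserved for the model's response
    patch_final = "word " * 10000

    budget = max_tokens_full - delta_output
    token_count = count_tokens(patch_final)
    if token_count > budget:
        print(f"Token count {token_count} exceeds the limit {budget}; clipping")
        patch_final = clip(patch_final, budget)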