Add truncation and summarization features to PR code suggestions

This commit is contained in:
mrT23
2024-03-04 08:16:05 +02:00
parent 248c6b13be
commit eed23a7aaa
2 changed files with 37 additions and 9 deletions

View File

@ -40,12 +40,14 @@ Specific instructions:
{%- if extra_instructions %} {%- if extra_instructions %}
Extra instructions from the user: Extra instructions from the user:
====== ======
{{ extra_instructions }} {{ extra_instructions }}
====== ======
{%- endif %} {%- endif %}
The output must be a YAML object equivalent to type $PRCodeSuggestions, according to the following Pydantic definitions: The output must be a YAML object equivalent to type $PRCodeSuggestions, according to the following Pydantic definitions:
===== =====
class CodeSuggestion(BaseModel): class CodeSuggestion(BaseModel):
@ -80,18 +82,18 @@ code_suggestions:
... ...
{%- if summarize_mode %} {%- if summarize_mode %}
existing_code: | existing_code: |
def func1(): ...
improved_code: | improved_code: |
... ...
one_sentence_summary: | one_sentence_summary: |
... ...
relevant_lines_start: 12 relevant_lines_start: 12
relevant_lines_end: 12 relevant_lines_end: 13
{%- else %} {%- else %}
existing_code: | existing_code: |
def func1(): ...
relevant_lines_start: 12 relevant_lines_start: 12
relevant_lines_end: 12 relevant_lines_end: 13
improved_code: | improved_code: |
... ...
{%- endif %} {%- endif %}

View File

@ -137,6 +137,7 @@ class PRCodeSuggestions:
model, model,
add_line_numbers_to_hunks=True, add_line_numbers_to_hunks=True,
disable_extra_lines=True) disable_extra_lines=True)
if self.patches_diff: if self.patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff) get_logger().debug(f"PR diff", artifact=self.patches_diff)
self.prediction = await self._get_prediction(model, self.patches_diff) self.prediction = await self._get_prediction(model, self.patches_diff)
@ -150,12 +151,23 @@ class PRCodeSuggestions:
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt) system=system_prompt, user=user_prompt)
return response return response
@staticmethod
def _truncate_if_needed(suggestion):
    """Truncate a suggestion's 'improved_code' to a configured maximum length.

    Reads PR_CODE_SUGGESTIONS.MAX_CODE_SUGGESTION_LENGTH (0 disables truncation)
    and PR_CODE_SUGGESTIONS.SUGGESTION_TRUNCATION_MESSAGE from settings. When the
    improved code exceeds the limit, it is cut to the limit and the truncation
    message is appended on a new line; the mutated dict is returned either way.

    :param suggestion: dict with at least an 'improved_code' str entry
                       (mutated in place when truncation applies)
    :return: the same suggestion dict
    """
    max_code_suggestion_length = get_settings().get("PR_CODE_SUGGESTIONS.MAX_CODE_SUGGESTION_LENGTH", 0)
    suggestion_truncation_message = get_settings().get("PR_CODE_SUGGESTIONS.SUGGESTION_TRUNCATION_MESSAGE", "")
    if max_code_suggestion_length > 0:
        # Capture the length BEFORE mutating: the original code computed
        # len(...) after truncating and appending the message, so the log
        # always reported the post-truncation size instead of the real one.
        original_length = len(suggestion['improved_code'])
        if original_length > max_code_suggestion_length:
            suggestion['improved_code'] = suggestion['improved_code'][:max_code_suggestion_length]
            suggestion['improved_code'] += f"\n{suggestion_truncation_message}"
            get_logger().info(f"Truncated suggestion from {original_length} "
                              f"characters to {max_code_suggestion_length} characters")
    return suggestion
def _prepare_pr_code_suggestions(self) -> Dict: def _prepare_pr_code_suggestions(self) -> Dict:
review = self.prediction.strip() review = self.prediction.strip()
data = load_yaml(review, data = load_yaml(review,
@ -165,8 +177,22 @@ class PRCodeSuggestions:
# remove invalid suggestions # remove invalid suggestions
suggestion_list = [] suggestion_list = []
one_sentence_summary_list = []
for i, suggestion in enumerate(data['code_suggestions']): for i, suggestion in enumerate(data['code_suggestions']):
if suggestion['existing_code'] != suggestion['improved_code']: if get_settings().pr_code_suggestions.summarize:
if not suggestion or 'one_sentence_summary' not in suggestion or 'label' not in suggestion or 'relevant_file' not in suggestion:
get_logger().debug(f"Skipping suggestion {i + 1}, because it is invalid: {suggestion}")
continue
if suggestion['one_sentence_summary'] in one_sentence_summary_list:
get_logger().debug(f"Skipping suggestion {i + 1}, because it is a duplicate: {suggestion}")
continue
if ('existing_code' in suggestion) and ('improved_code' in suggestion) and (
suggestion['existing_code'] != suggestion['improved_code']):
suggestion = self._truncate_if_needed(suggestion)
if get_settings().pr_code_suggestions.summarize:
one_sentence_summary_list.append(suggestion['one_sentence_summary'])
suggestion_list.append(suggestion) suggestion_list.append(suggestion)
else: else:
get_logger().debug( get_logger().debug(
@ -250,7 +276,8 @@ class PRCodeSuggestions:
# parallelize calls to AI: # parallelize calls to AI:
if get_settings().pr_code_suggestions.parallel_calls: if get_settings().pr_code_suggestions.parallel_calls:
prediction_list = await asyncio.gather(*[self._get_prediction(model, patches_diff) for patches_diff in self.patches_diff_list]) prediction_list = await asyncio.gather(
*[self._get_prediction(model, patches_diff) for patches_diff in self.patches_diff_list])
self.prediction_list = prediction_list self.prediction_list = prediction_list
else: else:
prediction_list = [] prediction_list = []
@ -304,7 +331,6 @@ class PRCodeSuggestions:
system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render( system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render(
variables) variables)
user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt, response, finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt,
user=user_prompt) user=user_prompt)