refactor + add description options

This commit is contained in:
Hussam.lawen
2023-10-24 22:28:57 +03:00
parent 07617eab5a
commit 1a89c7eadf
5 changed files with 38 additions and 24 deletions

View File

@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml
from pr_agent.algo.utils import load_yaml, set_custom_labels
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
@ -44,7 +44,6 @@ class PRDescription:
"extra_instructions": get_settings().pr_description.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
"custom_labels": ""
}
self.user_description = self.git_provider.get_user_description()
@ -142,7 +141,7 @@ class PRDescription:
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
await self.set_custom_labels(variables)
await set_custom_labels(variables)
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
@ -159,15 +158,6 @@ class PRDescription:
return response
async def set_custom_labels(self, variables):
labels = get_settings().pr_description.custom_labels
if not labels:
# set default labels
labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other']
labels_list = "\n - ".join(labels) if labels else ""
labels_list = f" - {labels_list}" if labels_list else ""
variables["custom_labels"] = labels_list
def _prepare_data(self):
# Load the AI prediction data into a dictionary
self.data = load_yaml(self.prediction.strip())