From 4d84f76948367870ff71555e138d6f4e4de343e9 Mon Sep 17 00:00:00 2001
From: mrT23
Date: Mon, 24 Jul 2023 11:31:35 +0300
Subject: [PATCH] _get_prediction

---
 pr_agent/tools/pr_description.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py
index 07b5ebb1..a57df47d 100644
--- a/pr_agent/tools/pr_description.py
+++ b/pr_agent/tools/pr_description.py
@@ -57,17 +57,34 @@ class PRDescription:
         logging.info('Getting AI prediction...')
         self.prediction = await self._get_prediction(model)
 
-    async def _get_prediction(self, model: str):
+    async def _get_prediction(self, model: str) -> str:
+        """
+        Generate an AI prediction for the PR description based on the provided model.
+
+        Args:
+            model (str): The name of the model to be used for generating the prediction.
+
+        Returns:
+            str: The generated AI prediction.
+        """
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
+
         environment = Environment(undefined=StrictUndefined)
         system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables)
         user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables)
+
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
-                                                                        system=system_prompt, user=user_prompt)
+
+        response, finish_reason = await self.ai_handler.chat_completion(
+            model=model,
+            temperature=0.2,
+            system=system_prompt,
+            user=user_prompt
+        )
+
         return response
 
     def _prepare_pr_answer(self) -> Tuple[str, str, List[str], str]:
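
Reviewer note: the prompt rendering in _get_prediction relies on Jinja2's StrictUndefined, which raises an error when a template references a variable that was never supplied, rather than silently rendering it as an empty string. A minimal standalone sketch of that pattern follows; the template strings and variable values are hypothetical stand-ins for settings.pr_description_prompt.system/.user and self.vars, not code from this repo.

    from jinja2 import Environment, StrictUndefined

    # Fail loudly if a template references a variable we never provided,
    # instead of silently rendering it as an empty string.
    environment = Environment(undefined=StrictUndefined)

    # Hypothetical templates standing in for settings.pr_description_prompt.*
    system_template = "You are a PR assistant for {{ repo }}."
    user_template = "Describe this diff:\n{{ diff }}"

    # Hypothetical stand-in for the self.vars/patches_diff variables dict.
    variables = {"repo": "pr-agent", "diff": "- old line\n+ new line"}

    system_prompt = environment.from_string(system_template).render(variables)
    user_prompt = environment.from_string(user_template).render(variables)

    print(system_prompt)
    print(user_prompt)

With StrictUndefined, removing "repo" from the variables dict above would make render() raise jinja2.exceptions.UndefinedError, which is why the patched method can trust that both rendered prompts are complete before passing them to chat_completion.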