This commit is contained in:
mrT23
2024-04-14 12:00:19 +03:00
parent a4680ded93
commit 8f0f08006f
5 changed files with 49 additions and 9 deletions

View File

@ -15,7 +15,7 @@ class BaseAiHandler(ABC):
pass
@abstractmethod
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
"""
This method should be implemented to return a chat completion from the AI model.
Args:

View File

@ -102,13 +102,23 @@ class LiteLLMAIHandler(BaseAiHandler):
retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.Timeout)), # No retry on RateLimitError
stop=stop_after_attempt(OPENAI_RETRIES)
)
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
try:
resp, finish_reason = None, None
deployment_id = self.deployment_id
if self.azure:
model = 'azure/' + model
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
if img_path:
import requests
r = requests.get(img_path, allow_redirects=True)
if r.status_code == 404:
error_msg = "The image link is not alive. Please repost the image, get a new address, and send the question again."
get_logger().error(error_msg)
return f"{error_msg}", "error"
messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
{"type": "image_url", "image_url": {"url": img_path}}]
kwargs = {
"model": model,
"deployment_id": deployment_id,

View File

@ -86,8 +86,13 @@ async def handle_comments_on_pr(body: Dict[str, Any],
return {}
comment_body = body.get("comment", {}).get("body")
if comment_body and isinstance(comment_body, str) and not comment_body.lstrip().startswith("/"):
get_logger().info("Ignoring comment not starting with /")
return {}
if '/ask' in comment_body and comment_body.strip().startswith('> ![image]'):
comment_body_split = comment_body.split('/ask')
comment_body = '/ask' + comment_body_split[1] +'/n' +comment_body_split[0].strip()
get_logger().info(f"Reformatting comment_body so command is at the beginning: {comment_body}")
else:
get_logger().info("Ignoring comment not starting with /")
return {}
disable_eyes = False
if "issue" in body and "pull_request" in body["issue"] and "url" in body["issue"]["pull_request"]:
api_url = body["issue"]["pull_request"]["url"]

View File

@ -1,7 +1,7 @@
[config]
model="gpt-4" # "gpt-4-0125-preview"
model_turbo="gpt-4-0125-preview"
fallback_models=["gpt-3.5-turbo-16k"]
model="gpt-4-turbo" # "gpt-4-0125-preview"
model_turbo="gpt-4-turbo"
fallback_models=["gpt-4-0125-preview"]
git_provider="github"
publish_output=true
publish_output_progress=true

View File

@ -56,6 +56,12 @@ class PRQuestions:
get_logger().debug("Relevant configs", artifacts=relevant_configs)
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing answer...", is_temporary=True)
# identify image
img_path = self.idenfity_image_in_comment()
if img_path:
get_logger().debug(f"Image path identified", artifact=img_path)
await retry_with_fallback_models(self._prepare_prediction)
pr_comment = self._prepare_pr_answer()
@ -71,6 +77,19 @@ class PRQuestions:
self.git_provider.remove_initial_comment()
return ""
def idenfity_image_in_comment(self):
img_path = ''
if '![image]' in self.question_str:
# assuming structure:
# /ask question ... > ![image](img_path)
img_path = self.question_str.split('![image]')[1].strip().strip('()')
self.vars['img_path'] = img_path
elif 'https://' in self.question_str and '.png' in self.question_str: # direct image link
# include https:// in the image path
img_path = 'https://' + self.question_str.split('https://')[1]
self.vars['img_path'] = img_path
return img_path
async def _prepare_prediction(self, model: str):
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
if self.patches_diff:
@ -86,8 +105,14 @@ class PRQuestions:
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
if 'img_path' in variables:
img_path = self.vars['img_path']
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt,
img_path=img_path)
else:
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
return response
def _prepare_pr_answer(self) -> str: