diff --git a/pr_agent/algo/ai_handlers/base_ai_handler.py b/pr_agent/algo/ai_handlers/base_ai_handler.py index c8473fb3..b5166b8e 100644 --- a/pr_agent/algo/ai_handlers/base_ai_handler.py +++ b/pr_agent/algo/ai_handlers/base_ai_handler.py @@ -15,7 +15,7 @@ class BaseAiHandler(ABC): pass @abstractmethod - async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None): """ This method should be implemented to return a chat completion from the AI model. Args: diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index d07542f6..536faf41 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -102,13 +102,23 @@ class LiteLLMAIHandler(BaseAiHandler): retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.Timeout)), # No retry on RateLimitError stop=stop_after_attempt(OPENAI_RETRIES) ) - async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None): try: resp, finish_reason = None, None deployment_id = self.deployment_id if self.azure: model = 'azure/' + model messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + if img_path: + import requests + r = requests.get(img_path, allow_redirects=True) + if r.status_code == 404: + error_msg = "The image link is not alive. Please repost the image, get a new address, and send the question again." 
+ get_logger().error(error_msg) + return f"{error_msg}", "error" + messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]}, + {"type": "image_url", "image_url": {"url": img_path}}] + kwargs = { "model": model, "deployment_id": deployment_id, diff --git a/pr_agent/servers/github_app.py b/pr_agent/servers/github_app.py index bc7042b1..6d942289 100644 --- a/pr_agent/servers/github_app.py +++ b/pr_agent/servers/github_app.py @@ -86,8 +86,13 @@ async def handle_comments_on_pr(body: Dict[str, Any], return {} comment_body = body.get("comment", {}).get("body") if comment_body and isinstance(comment_body, str) and not comment_body.lstrip().startswith("/"): - get_logger().info("Ignoring comment not starting with /") - return {} + if '/ask' in comment_body and comment_body.strip().startswith('> ![image]'): + comment_body_split = comment_body.split('/ask') + comment_body = '/ask' + comment_body_split[1] +'\n' +comment_body_split[0].strip() + get_logger().info(f"Reformatting comment_body so command is at the beginning: {comment_body}") + else: + get_logger().info("Ignoring comment not starting with /") + return {} disable_eyes = False if "issue" in body and "pull_request" in body["issue"] and "url" in body["issue"]["pull_request"]: api_url = body["issue"]["pull_request"]["url"] diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 39282706..d49b345c 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -1,7 +1,7 @@ [config] -model="gpt-4" # "gpt-4-0125-preview" -model_turbo="gpt-4-0125-preview" -fallback_models=["gpt-3.5-turbo-16k"] +model="gpt-4-turbo" # "gpt-4-0125-preview" +model_turbo="gpt-4-turbo" +fallback_models=["gpt-4-0125-preview"] git_provider="github" publish_output=true publish_output_progress=true diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 4e1d3c1e..1e2a360e 100644 --- a/pr_agent/tools/pr_questions.py +++ 
b/pr_agent/tools/pr_questions.py @@ -56,6 +56,12 @@ class PRQuestions: get_logger().debug("Relevant configs", artifacts=relevant_configs) if get_settings().config.publish_output: self.git_provider.publish_comment("Preparing answer...", is_temporary=True) + + # identify image + img_path = self.identify_image_in_comment() + if img_path: + get_logger().debug("Image path identified", artifact=img_path) + await retry_with_fallback_models(self._prepare_prediction) pr_comment = self._prepare_pr_answer() @@ -71,6 +77,19 @@ class PRQuestions: self.git_provider.remove_initial_comment() return "" + def identify_image_in_comment(self): + img_path = '' + if '![image]' in self.question_str: + # assuming structure: + # /ask question ... > ![image](img_path) + img_path = self.question_str.split('![image]')[1].strip().strip('()') + self.vars['img_path'] = img_path + elif 'https://' in self.question_str and '.png' in self.question_str: # direct image link + # include https:// in the image path + img_path = 'https://' + self.question_str.split('https://')[1] + self.vars['img_path'] = img_path + return img_path + async def _prepare_prediction(self, model: str): self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model) if self.patches_diff: @@ -86,8 +105,14 @@ class PRQuestions: environment = Environment(undefined=StrictUndefined) system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables) - response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, - system=system_prompt, user=user_prompt) + if 'img_path' in variables: + img_path = self.vars['img_path'] + response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, + system=system_prompt, user=user_prompt, + img_path=img_path) + else: + response, finish_reason = await self.ai_handler.chat_completion(model=model, 
temperature=0.2, + system=system_prompt, user=user_prompt) return response def _prepare_pr_answer(self) -> str: