diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py
index 4d708fcb..2d4fa08b 100644
--- a/pr_agent/algo/ai_handlers/langchain_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py
@@ -1,6 +1,10 @@
+_LANGCHAIN_INSTALLED = False
+
 try:
     from langchain_core.messages import HumanMessage, SystemMessage
+    from langchain_core.runnables import Runnable
     from langchain_openai import AzureChatOpenAI, ChatOpenAI
+    _LANGCHAIN_INSTALLED = True
 except: # we don't enforce langchain as a dependency, so if it's not installed, just move on
     pass
@@ -18,17 +22,14 @@ OPENAI_RETRIES = 5
 
 class LangChainOpenAIHandler(BaseAiHandler):
     def __init__(self):
-        # Initialize OpenAIHandler specific attributes here
+        if not _LANGCHAIN_INSTALLED:
+            error_msg = "LangChain is not installed. Please install it with `pip install langchain`."
+            get_logger().error(error_msg)
+            raise ImportError(error_msg)
+
         super().__init__()
         self.azure = get_settings().get("OPENAI.API_TYPE", "").lower() == "azure"
 
-        # Create a default unused chat object to trigger early validation
-        self._create_chat(self.deployment_id)
-
-    def chat(self, messages: list, model: str, temperature: float):
-        chat = self._create_chat(self.deployment_id)
-        return chat.invoke(input=messages, model=model, temperature=temperature)
-
     @property
     def deployment_id(self):
         """
@@ -36,16 +37,66 @@ class LangChainOpenAIHandler(BaseAiHandler):
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
 
+    async def _create_chat_async(self, deployment_id=None):
+        try:
+            if self.azure:
+                # Using Azure OpenAI service
+                return AzureChatOpenAI(
+                    openai_api_key=get_settings().openai.key,
+                    openai_api_version=get_settings().openai.api_version,
+                    azure_deployment=deployment_id,
+                    azure_endpoint=get_settings().openai.api_base,
+                )
+            else:
+                # Using standard OpenAI or other OpenAI-compatible LLM services
+                openai_api_base = get_settings().get("OPENAI.API_BASE", None)
+                if openai_api_base is None or len(openai_api_base) == 0:
+                    return ChatOpenAI(openai_api_key=get_settings().openai.key)
+                else:
+                    return ChatOpenAI(
+                        openai_api_key=get_settings().openai.key,
+                        openai_api_base=openai_api_base
+                    )
+        except AttributeError as e:
+            # Handle configuration errors (missing settings surface as AttributeError)
+            error_msg = f"OpenAI {e.name} is required" if getattr(e, "name", None) else str(e)
+            get_logger().error(error_msg)
+            raise ValueError(error_msg) from e
+
     @retry(
         retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError),
         stop=stop_after_attempt(OPENAI_RETRIES),
     )
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
+    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
+        if img_path:
+            get_logger().warning(f"Image path is not supported for LangChainOpenAIHandler. Ignoring image path: {img_path}")
         try:
             messages = [SystemMessage(content=system), HumanMessage(content=user)]
+            llm = await self._create_chat_async(deployment_id=self.deployment_id)
+
+            if not isinstance(llm, Runnable):
+                error_message = (
+                    f"The Langchain LLM object ({type(llm)}) does not implement the Runnable interface. "
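For reviewers, a minimal usage sketch of the new async entry point (model name and prompts are illustrative; assumes langchain is installed and an OpenAI key is configured, e.g. via `OPENAI_API_KEY`):

```python
import asyncio

from pr_agent.algo.ai_handlers.langchain_ai_handler import LangChainOpenAIHandler

async def main():
    # With the new guard, this raises ImportError up front when langchain
    # is missing, instead of failing later inside chat_completion.
    handler = LangChainOpenAIHandler()
    content, finish_reason = await handler.chat_completion(
        model='gpt-3.5-turbo',   # illustrative model name
        system='You are a helpful assistant.',
        user='Say hello.',
        temperature=0.2,
    )
    print(finish_reason, content)

asyncio.run(main())
```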
" + f"Please update your Langchain library to the latest version or " + f"check your LLM configuration to support async calls. " + f"PR-Agent is designed to utilize Langchain's async capabilities." + ) + get_logger().error(error_message) + raise NotImplementedError(error_message) + + # Handle parameters based on LLM type + if isinstance(llm, (ChatOpenAI, AzureChatOpenAI)): + # OpenAI models support all parameters + resp = await llm.ainvoke( + input=messages, + model=model, + temperature=temperature + ) + else: + # Other LLMs (like Gemini) only support input parameter + get_logger().info(f"Using simplified ainvoke for {type(llm)}") + resp = await llm.ainvoke(input=messages) - # get a chat completion from the formatted messages - resp = self.chat(messages, model=model, temperature=temperature) finish_reason = "completed" return resp.content, finish_reason @@ -58,27 +109,3 @@ class LangChainOpenAIHandler(BaseAiHandler): except Exception as e: get_logger().warning(f"Unknown error during LLM inference: {e}") raise openai.APIError from e - - def _create_chat(self, deployment_id=None): - try: - if self.azure: - # using a partial function so we can set the deployment_id later to support fallback_deployments - # but still need to access the other settings now so we can raise a proper exception if they're missing - return AzureChatOpenAI( - openai_api_key=get_settings().openai.key, - openai_api_version=get_settings().openai.api_version, - azure_deployment=deployment_id, - azure_endpoint=get_settings().openai.api_base, - ) - else: - # for llms that compatible with openai, should use custom api base - openai_api_base = get_settings().get("OPENAI.API_BASE", None) - if openai_api_base is None or len(openai_api_base) == 0: - return ChatOpenAI(openai_api_key=get_settings().openai.key) - else: - return ChatOpenAI(openai_api_key=get_settings().openai.key, openai_api_base=openai_api_base) - except AttributeError as e: - if getattr(e, "name"): - raise ValueError(f"OpenAI {e.name} is required") from e - else: - raise e diff --git a/pr_agent/algo/ai_handlers/openai_ai_handler.py b/pr_agent/algo/ai_handlers/openai_ai_handler.py index 253282b0..f5fb99f6 100644 --- a/pr_agent/algo/ai_handlers/openai_ai_handler.py +++ b/pr_agent/algo/ai_handlers/openai_ai_handler.py @@ -42,8 +42,10 @@ class OpenAIHandler(BaseAiHandler): retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError), stop=stop_after_attempt(OPENAI_RETRIES), ) - async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None): try: + if img_path: + get_logger().warning(f"Image path is not supported for OpenAIHandler. 
diff --git a/tests/e2e_tests/langchain_ai_handler.py b/tests/e2e_tests/langchain_ai_handler.py
new file mode 100644
index 00000000..d75c4292
--- /dev/null
+++ b/tests/e2e_tests/langchain_ai_handler.py
@@ -0,0 +1,90 @@
+import asyncio
+import os
+import time
+from pr_agent.algo.ai_handlers.langchain_ai_handler import LangChainOpenAIHandler
+from pr_agent.config_loader import get_settings
+
+def check_settings():
+    print('Checking settings...')
+    settings = get_settings()
+
+    # Check OpenAI settings
+    if not hasattr(settings, 'openai'):
+        print('OpenAI settings not found')
+        return False
+
+    if not hasattr(settings.openai, 'key'):
+        print('OpenAI API key not found')
+        return False
+
+    print('OpenAI API key found')
+    return True
+
+async def measure_performance(handler, num_requests=3):
+    print(f'\nRunning performance test with {num_requests} requests...')
+    start_time = time.time()
+
+    # Create multiple requests
+    tasks = [
+        handler.chat_completion(
+            model='gpt-3.5-turbo',
+            system='You are a helpful assistant',
+            user=f'Test message {i}',
+            temperature=0.2
+        ) for i in range(num_requests)
+    ]
+
+    # Execute requests concurrently
+    responses = await asyncio.gather(*tasks)
+
+    end_time = time.time()
+    total_time = end_time - start_time
+    avg_time = total_time / num_requests
+
+    print('Performance results:')
+    print(f'Total time: {total_time:.2f} seconds')
+    print(f'Average time per request: {avg_time:.2f} seconds')
+    print(f'Requests per second: {num_requests/total_time:.2f}')
+
+    return responses
+
+async def test():
+    print('Starting test...')
+
+    # Check settings first
+    if not check_settings():
+        print('Please set up your environment variables or configuration file')
+        print('Required: OPENAI_API_KEY')
+        return
+
+    try:
+        handler = LangChainOpenAIHandler()
+        print('Handler created')
+
+        # Basic functionality test
+        response = await handler.chat_completion(
+            model='gpt-3.5-turbo',
+            system='You are a helpful assistant',
+            user='Hello',
+            temperature=0.2,
+            img_path='test.jpg'
+        )
+        print('Response:', response)
+
+        # Performance test
+        await measure_performance(handler)
+
+    except Exception as e:
+        print('Error:', str(e))
+        print('Error type:', type(e))
+        print('Error details:', e.__dict__ if hasattr(e, '__dict__') else 'No additional details')
+
+if __name__ == '__main__':
+    print('Environment variables:')
+    print('OPENAI_API_KEY:', 'Set' if os.getenv('OPENAI_API_KEY') else 'Not set')
+    print('OPENAI_API_TYPE:', os.getenv('OPENAI_API_TYPE', 'Not set'))
+    print('OPENAI_API_BASE:', os.getenv('OPENAI_API_BASE', 'Not set'))
+
+    asyncio.run(test())
+
+
\ No newline at end of file
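The e2e script above runs standalone (`python tests/e2e_tests/langchain_ai_handler.py` with `OPENAI_API_KEY` exported). If a CI-friendly variant is ever wanted, a possible pytest wrapper for the new `img_path` behavior could look like the following hypothetical test (assumes `pytest` and `pytest-asyncio` are available, which this PR does not add):

```python
import pytest

from pr_agent.algo.ai_handlers.langchain_ai_handler import LangChainOpenAIHandler

@pytest.mark.asyncio
async def test_img_path_is_ignored_not_fatal():
    handler = LangChainOpenAIHandler()
    content, finish_reason = await handler.chat_completion(
        model='gpt-3.5-turbo',   # illustrative model name
        system='You are a helpful assistant',
        user='Hello',
        temperature=0.2,
        img_path='test.jpg',     # new parameter: should only log a warning
    )
    assert finish_reason == 'completed'
    assert isinstance(content, str)
```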