diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 5608c50a..dfffcb6b 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -1,8 +1,10 @@ import shlex +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.utils import update_settings_from_args from pr_agent.config_loader import get_settings from pr_agent.git_providers.utils import apply_repo_settings +from pr_agent.log import get_logger from pr_agent.tools.pr_add_docs import PRAddDocs from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_config import PRConfig @@ -13,6 +15,7 @@ from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_similar_issue import PRSimilarIssue from pr_agent.tools.pr_update_changelog import PRUpdateChangelog +import inspect command2class = { "auto_review": PRReviewer, @@ -37,9 +40,18 @@ command2class = { commands = list(command2class.keys()) +def has_ai_handler_param(cls: object): + constructor = getattr(cls, "__init__", None) + if constructor is not None: + parameters = inspect.signature(constructor).parameters + return "ai_handler" in parameters + return False + class PRAgent: - def __init__(self): + def __init__(self, ai_handler: BaseAiHandler = None): + self.ai_handler = ai_handler pass + async def handle_request(self, pr_url, request, notify=None) -> bool: # First, apply repo specific settings if exists @@ -61,13 +73,18 @@ class PRAgent: if action == "answer": if notify: notify() - await PRReviewer(pr_url, is_answer=True, args=args).run() + await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run() elif action == "auto_review": - await PRReviewer(pr_url, is_auto=True, args=args).run() + await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run() elif action in command2class: if notify: notify() - await command2class[action](pr_url, args=args).run() + + get_logger().info(f"Class: {command2class[action]}") + if(not has_ai_handler_param(cls=command2class[action])): + await command2class[action](pr_url, args=args).run() + else: + await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run() else: return False return True diff --git a/pr_agent/algo/ai_handlers/base_ai_handler.py b/pr_agent/algo/ai_handlers/base_ai_handler.py new file mode 100644 index 00000000..c8473fb3 --- /dev/null +++ b/pr_agent/algo/ai_handlers/base_ai_handler.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod + +class BaseAiHandler(ABC): + """ + This class defines the interface for an AI handler to be used by the PR Agents. + """ + + @abstractmethod + def __init__(self): + pass + + @property + @abstractmethod + def deployment_id(self): + pass + + @abstractmethod + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + """ + This method should be implemented to return a chat completion from the AI model. + Args: + model (str): the name of the model to use for the chat completion + system (str): the system message string to use for the chat completion + user (str): the user message string to use for the chat completion + temperature (float): the temperature to use for the chat completion + """ + pass + diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py new file mode 100644 index 00000000..5c793f2b --- /dev/null +++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py @@ -0,0 +1,46 @@ +from langchain.chat_models import ChatOpenAI +from langchain.schema import SystemMessage, HumanMessage + +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.config_loader import get_settings +from pr_agent.log import get_logger + +from openai.error import APIError, RateLimitError, Timeout, TryAgain +from retry import retry + +OPENAI_RETRIES = 5 + +class LangChainOpenAIHandler(BaseAiHandler): + def __init__(self): + # Initialize OpenAIHandler specific attributes here + try: + super().__init__() + self._chat = ChatOpenAI(openai_api_key=get_settings().openai.key) + + except AttributeError as e: + raise ValueError("OpenAI key is required") from e + + @property + def chat(self): + return self._chat + + @property + def deployment_id(self): + """ + Returns the deployment ID for the OpenAI API. + """ + return get_settings().get("OPENAI.DEPLOYMENT_ID", None) + @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), + tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + try: + messages=[SystemMessage(content=system), HumanMessage(content=user)] + + # get a chat completion from the formatted messages + resp = self.chat(messages, model=model, temperature=temperature) + finish_reason="completed" + return resp.content, finish_reason + + except (Exception) as e: + get_logger().error("Unknown error during OpenAI inference: ", e) + raise e \ No newline at end of file diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py similarity index 97% rename from pr_agent/algo/ai_handler.py rename to pr_agent/algo/ai_handlers/litellm_ai_handler.py index fc2899b9..7061ca79 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -6,14 +6,14 @@ import openai from litellm import acompletion from openai.error import APIError, RateLimitError, Timeout, TryAgain from retry import retry +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.config_loader import get_settings -from pr_agent.algo.base_ai_handler import BaseAiHandler from pr_agent.log import get_logger OPENAI_RETRIES = 5 -class AiHandler(BaseAiHandler): +class LiteLLMAIHandler(BaseAiHandler): """ This class handles interactions with the OpenAI API for chat completions. It initializes the API key and other settings from a configuration file, @@ -135,4 +135,4 @@ class AiHandler(BaseAiHandler): usage = response.get("usage") get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, model=model, usage=usage) - return resp, finish_reason + return resp, finish_reason \ No newline at end of file diff --git a/pr_agent/algo/ai_handlers/openai_ai_handler.py b/pr_agent/algo/ai_handlers/openai_ai_handler.py new file mode 100644 index 00000000..3856f6f7 --- /dev/null +++ b/pr_agent/algo/ai_handlers/openai_ai_handler.py @@ -0,0 +1,67 @@ +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +import openai +from openai.error import APIError, RateLimitError, Timeout, TryAgain +from retry import retry + +from pr_agent.config_loader import get_settings +from pr_agent.log import get_logger + +OPENAI_RETRIES = 5 + + +class OpenAIHandler(BaseAiHandler): + def __init__(self): + # Initialize OpenAIHandler specific attributes here + try: + super().__init__() + openai.api_key = get_settings().openai.key + if get_settings().get("OPENAI.ORG", None): + openai.organization = get_settings().openai.org + if get_settings().get("OPENAI.API_TYPE", None): + if get_settings().openai.api_type == "azure": + self.azure = True + openai.azure_key = get_settings().openai.key + if get_settings().get("OPENAI.API_VERSION", None): + openai.api_version = get_settings().openai.api_version + if get_settings().get("OPENAI.API_BASE", None): + openai.api_base = get_settings().openai.api_base + + except AttributeError as e: + raise ValueError("OpenAI key is required") from e + @property + def deployment_id(self): + """ + Returns the deployment ID for the OpenAI API. + """ + return get_settings().get("OPENAI.DEPLOYMENT_ID", None) + + @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), + tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + try: + deployment_id = self.deployment_id + get_logger().info("System: ", system) + get_logger().info("User: ", user) + messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + + chat_completion = await openai.ChatCompletion.acreate( + model=model, + deployment_id=deployment_id, + messages=messages, + temperature=temperature, + ) + resp = chat_completion["choices"][0]['message']['content'] + finish_reason = chat_completion["choices"][0]["finish_reason"] + usage = chat_completion.get("usage") + get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, + model=model, usage=usage) + return resp, finish_reason + except (APIError, Timeout, TryAgain) as e: + get_logger().error("Error during OpenAI inference: ", e) + raise + except (RateLimitError) as e: + get_logger().error("Rate limit error during OpenAI inference: ", e) + raise + except (Exception) as e: + get_logger().error("Unknown error during OpenAI inference: ", e) + raise TryAgain from e \ No newline at end of file diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index d5e1a3c6..9e100042 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -59,14 +59,14 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool=True) -> str: if key.lower() == 'code feedback': if gfm_supported: markdown_text += f"\n\n- " - markdown_text += f"
{ emoji } Code feedback:\n\n" + markdown_text += f"
{ emoji } Code feedback:" else: markdown_text += f"\n\n- **{emoji} Code feedback:**\n\n" else: markdown_text += f"- {emoji} **{key}:**\n\n" - for item in value: + for i, item in enumerate(value): if isinstance(item, dict) and key.lower() == 'code feedback': - markdown_text += parse_code_suggestion(item, gfm_supported) + markdown_text += parse_code_suggestion(item, i, gfm_supported) elif item: markdown_text += f" - {item}\n" if key.lower() == 'code feedback': @@ -80,7 +80,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool=True) -> str: return markdown_text -def parse_code_suggestion(code_suggestions: dict, gfm_supported: bool=True) -> str: +def parse_code_suggestion(code_suggestions: dict, i: int = 0, gfm_supported: bool = True) -> str: """ Convert a dictionary of data into markdown format. @@ -91,24 +91,52 @@ def parse_code_suggestion(code_suggestions: dict, gfm_supported: bool=True) -> s str: A string containing the markdown formatted text generated from the input dictionary. """ markdown_text = "" - for sub_key, sub_value in code_suggestions.items(): - if isinstance(sub_value, dict): # "code example" - markdown_text += f" - **{sub_key}:**\n" - for code_key, code_value in sub_value.items(): # 'before' and 'after' code - code_str = f"```\n{code_value}\n```" - code_str_indented = textwrap.indent(code_str, ' ') - markdown_text += f" - **{code_key}:**\n{code_str_indented}\n" - else: - if "relevant file" in sub_key.lower(): - markdown_text += f"\n - **{sub_key}:** {sub_value} \n" + if gfm_supported and 'relevant line' in code_suggestions: + if i == 0: + markdown_text += "
" + markdown_text += '' + for sub_key, sub_value in code_suggestions.items(): + try: + if sub_key.lower() == 'relevant file': + relevant_file = sub_value.strip('`').strip('"').strip("'") + markdown_text += f"" + # continue + elif sub_key.lower() == 'suggestion': + markdown_text += f"" + elif sub_key.lower() == 'relevant line': + markdown_text += f"" + sub_value_list = sub_value.split('](') + relevant_line = sub_value_list[0].lstrip('`').lstrip('[') + if len(sub_value_list) > 1: + link = sub_value_list[1].rstrip(')').strip('`') + markdown_text += f"" + else: + markdown_text += f"" + markdown_text += "" + except Exception as e: + get_logger().exception(f"Failed to parse code suggestion: {e}") + pass + markdown_text += '
{sub_key}{relevant_file}
{sub_key}      {sub_value}
relevant line{relevant_line}{relevant_line}
' + markdown_text += "
" + else: + for sub_key, sub_value in code_suggestions.items(): + if isinstance(sub_value, dict): # "code example" + markdown_text += f" - **{sub_key}:**\n" + for code_key, code_value in sub_value.items(): # 'before' and 'after' code + code_str = f"```\n{code_value}\n```" + code_str_indented = textwrap.indent(code_str, ' ') + markdown_text += f" - **{code_key}:**\n{code_str_indented}\n" else: - markdown_text += f" **{sub_key}:** {sub_value} \n" - if not gfm_supported: - if "relevant line" not in sub_key.lower(): # nicer presentation + if "relevant file" in sub_key.lower(): + markdown_text += f"\n - **{sub_key}:** {sub_value} \n" + else: + markdown_text += f" **{sub_key}:** {sub_value} \n" + if not gfm_supported: + if "relevant line" not in sub_key.lower(): # nicer presentation # markdown_text = markdown_text.rstrip('\n') + "\\\n" # works for gitlab markdown_text = markdown_text.rstrip('\n') + " \n" # works for gitlab and bitbucker - markdown_text += "\n" + markdown_text += "\n" return markdown_text @@ -336,7 +364,7 @@ def try_fix_yaml(response_text: str) -> dict: pass -def set_custom_labels(variables): +def set_custom_labels(variables, git_provider=None): if not get_settings().config.enable_custom_labels: return @@ -348,11 +376,8 @@ def set_custom_labels(variables): labels_list = f" - {labels_list}" if labels_list else "" variables["custom_labels"] = labels_list return - #final_labels = "" - #for k, v in labels.items(): - # final_labels += f" - {k} ({v['description']})\n" - #variables["custom_labels"] = final_labels - #variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}" + + # Set custom labels variables["custom_labels_class"] = "class Label(str, Enum):" for k, v in labels.items(): description = v['description'].strip('\n').replace('\n', '\\n') @@ -422,4 +447,4 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str: return clipped_text except Exception as e: get_logger().warning(f"Failed to clip tokens: {e}") - return text + return text \ No newline at end of file diff --git a/pr_agent/cli.py b/pr_agent/cli.py index 5a6a6640..caa7feed 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -5,7 +5,9 @@ import os from pr_agent.agent.pr_agent import PRAgent, commands from pr_agent.config_loader import get_settings from pr_agent.log import setup_logger +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger() @@ -57,9 +59,9 @@ For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions command = args.command.lower() get_settings().set("CONFIG.CLI_MODE", True) if args.issue_url: - result = asyncio.run(PRAgent().handle_request(args.issue_url, [command] + args.rest)) + result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.issue_url, [command] + args.rest)) else: - result = asyncio.run(PRAgent().handle_request(args.pr_url, [command] + args.rest)) + result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.pr_url, [command] + args.rest)) if not result: parser.print_help() diff --git a/pr_agent/servers/bitbucket_app.py b/pr_agent/servers/bitbucket_app.py index e147fbdd..739ba162 100644 --- a/pr_agent/servers/bitbucket_app.py +++ b/pr_agent/servers/bitbucket_app.py @@ -18,7 +18,9 @@ from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.secret_providers import get_secret_provider +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() secret_provider = get_secret_provider() @@ -84,7 +86,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req context['bitbucket_bearer_token'] = bearer_token context["settings"] = copy.deepcopy(global_settings) event = data["event"] - agent = PRAgent() + agent = PRAgent(ai_handler=litellm_ai_handler) if event == "pullrequest:created": pr_url = data["data"]["pullrequest"]["links"]["html"]["href"] log_context["api_url"] = pr_url diff --git a/pr_agent/servers/gerrit_server.py b/pr_agent/servers/gerrit_server.py index 1783f6b9..b8b90670 100644 --- a/pr_agent/servers/gerrit_server.py +++ b/pr_agent/servers/gerrit_server.py @@ -12,7 +12,9 @@ from starlette_context.middleware import RawContextMiddleware from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import get_logger, setup_logger +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger() router = APIRouter() @@ -43,7 +45,7 @@ async def handle_gerrit_request(action: Action, item: Item): status_code=400, detail="msg is required for ask command" ) - await PRAgent().handle_request( + await PRAgent(ai_handler=litellm_ai_handler).handle_request( f"{item.project}:{item.refspec}", f"/{item.msg.strip()}" ) diff --git a/pr_agent/servers/github_action_runner.py b/pr_agent/servers/github_action_runner.py index 45f9c712..e420a61e 100644 --- a/pr_agent/servers/github_action_runner.py +++ b/pr_agent/servers/github_action_runner.py @@ -11,7 +11,8 @@ from pr_agent.log import get_logger from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_reviewer import PRReviewer - +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() def is_true(value: Union[str, bool]) -> bool: if isinstance(value, bool): @@ -110,9 +111,9 @@ async def run_action(): comment_id = event_payload.get("comment", {}).get("id") provider = get_git_provider()(pr_url=url) if is_pr: - await PRAgent().handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id)) + await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id)) else: - await PRAgent().handle_request(url, body) + await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body) if __name__ == '__main__': diff --git a/pr_agent/servers/github_app.py b/pr_agent/servers/github_app.py index 36cc3e88..00b32b94 100644 --- a/pr_agent/servers/github_app.py +++ b/pr_agent/servers/github_app.py @@ -17,7 +17,9 @@ from pr_agent.git_providers.utils import apply_repo_settings from pr_agent.git_providers.git_provider import IncrementalPR from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.servers.utils import verify_signature, DefaultDictWithTimeout +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() @@ -79,7 +81,7 @@ async def handle_request(body: Dict[str, Any], event: str): action = body.get("action") if not action: return {} - agent = PRAgent() + agent = PRAgent(ai_handler=litellm_ai_handler) bot_user = get_settings().github_app.bot_user sender = body.get("sender", {}).get("login") log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"} diff --git a/pr_agent/servers/github_polling.py b/pr_agent/servers/github_polling.py index 1363b941..b473b8fa 100644 --- a/pr_agent/servers/github_polling.py +++ b/pr_agent/servers/github_polling.py @@ -8,7 +8,9 @@ from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.servers.help import bot_help_text +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) NOTIFICATION_URL = "https://api.github.com/notifications" @@ -34,7 +36,7 @@ async def polling_loop(): last_modified = [None] git_provider = get_git_provider()() user_id = git_provider.get_user_id() - agent = PRAgent() + agent = PRAgent(ai_handler=litellm_ai_handler) get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False) try: diff --git a/pr_agent/servers/gitlab_webhook.py b/pr_agent/servers/gitlab_webhook.py index a5d5a115..91956262 100644 --- a/pr_agent/servers/gitlab_webhook.py +++ b/pr_agent/servers/gitlab_webhook.py @@ -14,7 +14,9 @@ from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.secret_providers import get_secret_provider +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() @@ -26,7 +28,7 @@ def handle_request(background_tasks: BackgroundTasks, url: str, body: str, log_c log_context["event"] = "pull_request" if body == "/review" else "comment" log_context["api_url"] = url with get_logger().contextualize(**log_context): - background_tasks.add_task(PRAgent().handle_request, url, body) + background_tasks.add_task(PRAgent(ai_handler=litellm_ai_handler).handle_request, url, body) @router.post("/webhook") diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index eec75b9c..70dd66c2 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -4,7 +4,7 @@ from typing import Dict from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -15,14 +15,14 @@ from pr_agent.log import get_logger class PRAddDocs: - def __init__(self, pr_url: str, cli_mode=False, args: list = None): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.patches_diff = None self.prediction = None self.cli_mode = cli_mode diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index c18ec06e..8bded7de 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -3,7 +3,7 @@ import textwrap from typing import Dict, List from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -14,7 +14,7 @@ from pr_agent.log import get_logger class PRCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None ): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index a4060ba1..73807cf5 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -4,7 +4,7 @@ from typing import List, Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRDescription: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): """ Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index 27c77180..e52765f7 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -2,7 +2,7 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -12,7 +12,7 @@ from pr_agent.log import get_logger class PRInformationFromUser: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 4aec3edf..79edfd6a 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -2,7 +2,7 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -12,7 +12,7 @@ from pr_agent.log import get_logger class PRQuestions: - def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = None): question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index fd7461eb..d86dc052 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -7,7 +7,7 @@ import yaml from jinja2 import Environment, StrictUndefined from yaml import SafeLoader -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels @@ -22,13 +22,15 @@ class PRReviewer: """ The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = None): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Args: pr_url (str): The URL of the pull request to be reviewed. is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False. + is_auto (bool, optional): Indicates whether the review is being done in automatic mode. Defaults to False. + ai_handler (BaseAiHandler): The AI handler to be used for the review. Defaults to None. args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None. """ self.parse_args(args) # -i command diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index f8a84960..07130749 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -5,7 +5,7 @@ from typing import Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -17,7 +17,7 @@ CHANGELOG_LINES = 50 class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = None): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/requirements.txt b/requirements.txt index 2f38da7b..1f97ed7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ starlette-context==0.3.6 tiktoken==0.5.2 ujson==5.8.0 uvicorn==0.22.0 +langchain==0.0.349