From 7e47baa9db8a0a92292516e76e1d726b25c5ab43 Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Sun, 10 Dec 2023 00:25:25 +0800 Subject: [PATCH 01/17] Refactor AI handler classes --- pr_agent/algo/ai_handler.py | 4 +- pr_agent/algo/base_ai_handler.py | 48 ++++++++++++++++++++++ pr_agent/tools/pr_add_docs.py | 6 +-- pr_agent/tools/pr_code_suggestions.py | 6 +-- pr_agent/tools/pr_description.py | 6 +-- pr_agent/tools/pr_information_from_user.py | 6 +-- pr_agent/tools/pr_questions.py | 6 +-- pr_agent/tools/pr_reviewer.py | 6 +-- pr_agent/tools/pr_update_changelog.py | 6 +-- 9 files changed, 71 insertions(+), 23 deletions(-) create mode 100644 pr_agent/algo/base_ai_handler.py diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py index c3989563..3e4dc18a 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handler.py @@ -7,11 +7,11 @@ from openai.error import APIError, RateLimitError, Timeout, TryAgain from retry import retry from pr_agent.config_loader import get_settings from pr_agent.log import get_logger - +from pr_agent.algo.base_ai_handler import BaseAiHandler OPENAI_RETRIES = 5 -class AiHandler: +class AiHandler(BaseAiHandler): """ This class handles interactions with the OpenAI API for chat completions. It initializes the API key and other settings from a configuration file, diff --git a/pr_agent/algo/base_ai_handler.py b/pr_agent/algo/base_ai_handler.py new file mode 100644 index 00000000..9cb40b5c --- /dev/null +++ b/pr_agent/algo/base_ai_handler.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod + +class BaseAiHandler(ABC): + """ + This class defines the interface for an AI handler. + """ + + @abstractmethod + def __init__(self): + pass + + @property + @abstractmethod + def deployment_id(self): + pass + + @abstractmethod + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + pass + + +class AiHandler(BaseAiHandler): + """ + This class handles interactions with the OpenAI API for chat completions. + It initializes the API key and other settings from a configuration file, + and provides a method for performing chat completions using the OpenAI ChatCompletion API. + """ + + # ... rest of your code ... + + +class CustomAiHandler(BaseAiHandler): + """ + This class is your custom AI handler that uses a different LLM library. + """ + + def __init__(self): + # Initialize your custom AI handler + pass + + @property + def deployment_id(self): + # Return the deployment ID for your custom AI handler + pass + + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + # Implement the chat completion method for your custom AI handler + pass \ No newline at end of file diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index eec75b9c..04cc53f4 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -4,7 +4,7 @@ from typing import Dict from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -15,14 +15,14 @@ from pr_agent.log import get_logger class PRAddDocs: - def __init__(self, pr_url: str, cli_mode=False, args: list = None): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = AiHandler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.patches_diff = None self.prediction = None self.cli_mode = cli_mode diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 9e8d7f15..00e58a21 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -4,7 +4,7 @@ from typing import Dict, List from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = AiHandler() ): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( @@ -32,7 +32,7 @@ class PRCodeSuggestions: else: num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.patches_diff = None self.prediction = None self.cli_mode = cli_mode diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index c1bd03fd..9b02cd35 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -4,7 +4,7 @@ from typing import List, Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRDescription: - def __init__(self, pr_url: str, args: list = None): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()): """ Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. @@ -31,7 +31,7 @@ class PRDescription: self.pr_id = self.git_provider.get_pr_id() # Initialize the AI handler - self.ai_handler = AiHandler() + self.ai_handler = ai_handler # Initialize the variables dictionary self.vars = { diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index 059966e1..27c77180 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -2,7 +2,7 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -12,12 +12,12 @@ from pr_agent.log import get_logger class PRInformationFromUser: - def __init__(self, pr_url: str, args: list = None): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()): self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.vars = { "title": self.git_provider.pr.title, "branch": self.git_provider.get_pr_branch(), diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 7740fd4a..4aec3edf 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -2,7 +2,7 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -12,13 +12,13 @@ from pr_agent.log import get_logger class PRQuestions: - def __init__(self, pr_url: str, args=None): + def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = AiHandler()): question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.question_str = question_str self.vars = { "title": self.git_provider.pr.title, diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index ed99ddf6..c78a11e8 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -6,7 +6,7 @@ import yaml from jinja2 import Environment, StrictUndefined from yaml import SafeLoader -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml @@ -21,7 +21,7 @@ class PRReviewer: """ The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None): + def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = AiHandler()): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. @@ -42,7 +42,7 @@ class PRReviewer: if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now") - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.patches_diff = None self.prediction = None diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index a5f24e0d..f8a84960 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -5,7 +5,7 @@ from typing import Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -17,7 +17,7 @@ CHANGELOG_LINES = 50 class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False, args=None): + def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = AiHandler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( @@ -25,7 +25,7 @@ class PRUpdateChangelog: ) self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes self._get_changlog_file() # self.changelog_file_str - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.patches_diff = None self.prediction = None self.cli_mode = cli_mode From 523a896465b477cd52db284071a5b46585c362bb Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Mon, 11 Dec 2023 16:56:49 +0800 Subject: [PATCH 02/17] Rename AiHandler to LiteLLMAiHandler --- pr_agent/algo/ai_handler.py | 2 +- pr_agent/tools/pr_add_docs.py | 4 ++-- pr_agent/tools/pr_code_suggestions.py | 4 ++-- pr_agent/tools/pr_description.py | 4 ++-- pr_agent/tools/pr_information_from_user.py | 4 ++-- pr_agent/tools/pr_questions.py | 4 ++-- pr_agent/tools/pr_reviewer.py | 4 ++-- pr_agent/tools/pr_update_changelog.py | 4 ++-- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py index 3e4dc18a..7d88a00d 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handler.py @@ -11,7 +11,7 @@ from pr_agent.algo.base_ai_handler import BaseAiHandler OPENAI_RETRIES = 5 -class AiHandler(BaseAiHandler): +class LiteLLMAiHandler(BaseAiHandler): """ This class handles interactions with the OpenAI API for chat completions. It initializes the API key and other settings from a configuration file, diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index 04cc53f4..3f5f01b1 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -4,7 +4,7 @@ from typing import Dict from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRAddDocs: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 00e58a21..02a894fc 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -4,7 +4,7 @@ from typing import Dict, List from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = AiHandler() ): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler() ): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 9b02cd35..49564812 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -4,7 +4,7 @@ from typing import List, Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRDescription: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): """ Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index 27c77180..d2beade0 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -2,7 +2,7 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -12,7 +12,7 @@ from pr_agent.log import get_logger class PRInformationFromUser: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 4aec3edf..e21ab9d0 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -2,7 +2,7 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -12,7 +12,7 @@ from pr_agent.log import get_logger class PRQuestions: - def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index c78a11e8..0a8018a7 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -6,7 +6,7 @@ import yaml from jinja2 import Environment, StrictUndefined from yaml import SafeLoader -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml @@ -21,7 +21,7 @@ class PRReviewer: """ The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index f8a84960..b85bb060 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -5,7 +5,7 @@ from typing import Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler +from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -17,7 +17,7 @@ CHANGELOG_LINES = 50 class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = AiHandler()): + def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( From b8021d7ca3d50936907a77e49ea30b22dd4a01d4 Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Mon, 11 Dec 2023 16:57:23 +0800 Subject: [PATCH 03/17] rename file --- pr_agent/algo/{ai_handler.py => litellm_ai_handler.py} | 0 pr_agent/tools/pr_add_docs.py | 2 +- pr_agent/tools/pr_code_suggestions.py | 2 +- pr_agent/tools/pr_description.py | 2 +- pr_agent/tools/pr_information_from_user.py | 2 +- pr_agent/tools/pr_questions.py | 2 +- pr_agent/tools/pr_reviewer.py | 2 +- pr_agent/tools/pr_update_changelog.py | 2 +- 8 files changed, 7 insertions(+), 7 deletions(-) rename pr_agent/algo/{ai_handler.py => litellm_ai_handler.py} (100%) diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/litellm_ai_handler.py similarity index 100% rename from pr_agent/algo/ai_handler.py rename to pr_agent/algo/litellm_ai_handler.py diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index 3f5f01b1..3d50afe2 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -4,7 +4,7 @@ from typing import Dict from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 02a894fc..e55720ae 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -4,7 +4,7 @@ from typing import Dict, List from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 49564812..a1b73df2 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -4,7 +4,7 @@ from typing import List, Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index d2beade0..beb1b0ab 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -2,7 +2,7 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index e21ab9d0..b1930924 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -2,7 +2,7 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 0a8018a7..4de24b06 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -6,7 +6,7 @@ import yaml from jinja2 import Environment, StrictUndefined from yaml import SafeLoader -from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index b85bb060..7625218b 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -5,7 +5,7 @@ from typing import Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings From a1cbd80b2a3a3cf5faa860499d235e3ee85b3cb3 Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Mon, 11 Dec 2023 17:49:09 +0800 Subject: [PATCH 04/17] update base ai handler --- pr_agent/algo/base_ai_handler.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/pr_agent/algo/base_ai_handler.py b/pr_agent/algo/base_ai_handler.py index 9cb40b5c..21050db9 100644 --- a/pr_agent/algo/base_ai_handler.py +++ b/pr_agent/algo/base_ai_handler.py @@ -18,31 +18,3 @@ class BaseAiHandler(ABC): async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): pass - -class AiHandler(BaseAiHandler): - """ - This class handles interactions with the OpenAI API for chat completions. - It initializes the API key and other settings from a configuration file, - and provides a method for performing chat completions using the OpenAI ChatCompletion API. - """ - - # ... rest of your code ... - - -class CustomAiHandler(BaseAiHandler): - """ - This class is your custom AI handler that uses a different LLM library. - """ - - def __init__(self): - # Initialize your custom AI handler - pass - - @property - def deployment_id(self): - # Return the deployment ID for your custom AI handler - pass - - async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): - # Implement the chat completion method for your custom AI handler - pass \ No newline at end of file From ebf7027aabb63d7e26b792cb9a01e0ed2a0e9ab0 Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Mon, 11 Dec 2023 17:49:20 +0800 Subject: [PATCH 05/17] add openai handler --- pr_agent/algo/openai_ai_handler.py | 50 ++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 pr_agent/algo/openai_ai_handler.py diff --git a/pr_agent/algo/openai_ai_handler.py b/pr_agent/algo/openai_ai_handler.py new file mode 100644 index 00000000..4c9264d7 --- /dev/null +++ b/pr_agent/algo/openai_ai_handler.py @@ -0,0 +1,50 @@ +from pr_agent.algo.base_ai_handler import BaseAiHandler +import openai +from openai.error import APIError, RateLimitError, Timeout, TryAgain +from retry import retry + +from pr_agent.config_loader import get_settings + +OPENAI_RETRIES = 5 + + +class OpenAIHandler(BaseAiHandler): + def __init__(self): + # Initialize OpenAIHandler specific attributes here + try: + super().__init__() + openai.api_key = get_settings().openai.key + if get_settings().get("OPENAI.ORG", None): + openai.organization = get_settings().openai.org + if get_settings().get("OPENAI.API_TYPE", None): + if get_settings().openai.api_type == "azure": + self.azure = True + openai.azure_key = get_settings().openai.key + if get_settings().get("OPENAI.API_VERSION", None): + openai.api_version = get_settings().openai.api_version + if get_settings().get("OPENAI.API_BASE", None): + openai.api_base = get_settings().openai.api_base + + except AttributeError as e: + raise ValueError("OpenAI key is required") from e + @property + def deployment_id(self): + """ + Returns the deployment ID for the OpenAI API. + """ + return get_settings().get("OPENAI.DEPLOYMENT_ID", None) + + @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), + tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + chat_completion = await openai.ChatCompletion.acreate( + model=model, + messages=[{ + "role": "system", + "content": system + }, { + "role": "user", + "content": user + }], + ) + return chat_completion.choices[0].message.content From 5239e1c3e9189ad9543092d1f72649d98d5f882d Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Tue, 12 Dec 2023 21:51:05 +0800 Subject: [PATCH 06/17] Load default AI Handler from util function --- pr_agent/algo/openai_ai_handler.py | 39 ++++++++++++++++------ pr_agent/algo/utils.py | 6 ++++ pr_agent/tools/pr_add_docs.py | 6 ++-- pr_agent/tools/pr_code_suggestions.py | 6 ++-- pr_agent/tools/pr_description.py | 6 ++-- pr_agent/tools/pr_information_from_user.py | 5 +-- pr_agent/tools/pr_questions.py | 5 +-- pr_agent/tools/pr_reviewer.py | 6 ++-- pr_agent/tools/pr_update_changelog.py | 5 +-- 9 files changed, 55 insertions(+), 29 deletions(-) diff --git a/pr_agent/algo/openai_ai_handler.py b/pr_agent/algo/openai_ai_handler.py index 4c9264d7..c521442d 100644 --- a/pr_agent/algo/openai_ai_handler.py +++ b/pr_agent/algo/openai_ai_handler.py @@ -4,6 +4,7 @@ from openai.error import APIError, RateLimitError, Timeout, TryAgain from retry import retry from pr_agent.config_loader import get_settings +from pr_agent.log import get_logger OPENAI_RETRIES = 5 @@ -37,14 +38,30 @@ class OpenAIHandler(BaseAiHandler): @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): - chat_completion = await openai.ChatCompletion.acreate( - model=model, - messages=[{ - "role": "system", - "content": system - }, { - "role": "user", - "content": user - }], - ) - return chat_completion.choices[0].message.content + try: + deployment_id = self.deployment_id + get_logger().info("System: ", system) + get_logger().info("User: ", user) + messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + + chat_completion = await openai.ChatCompletion.acreate( + model=model, + deployment_id=deployment_id, + messages=messages, + temperature=temperature, + ) + resp = chat_completion["choices"][0]['message']['content'] + finish_reason = chat_completion["choices"][0]["finish_reason"] + usage = chat_completion.get("usage") + get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, + model=model, usage=usage) + return resp, finish_reason + except (APIError, Timeout, TryAgain) as e: + get_logger().error("Error during OpenAI inference: ", e) + raise + except (RateLimitError) as e: + get_logger().error("Rate limit error during OpenAI inference: ", e) + raise + except (Exception) as e: + get_logger().error("Unknown error during OpenAI inference: ", e) + raise TryAgain from e \ No newline at end of file diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 4e88b33e..14cde04a 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -8,6 +8,9 @@ from datetime import datetime from typing import Any, List import yaml +from pr_agent.algo.litellm_ai_handler import LiteLLMAiHandler +from pr_agent.algo.base_ai_handler import BaseAiHandler +from pr_agent.algo.openai_ai_handler import OpenAIHandler from starlette_context import context from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import get_logger @@ -304,3 +307,6 @@ def try_fix_yaml(review_text: str) -> dict: except: pass return data + +def get_ai_handler() -> BaseAiHandler: + return OpenAIHandler() \ No newline at end of file diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index 3d50afe2..f76baa02 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -4,10 +4,10 @@ from typing import Dict from jinja2 import Environment, StrictUndefined -from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml +from pr_agent.algo.utils import load_yaml, get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRAddDocs: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index e55720ae..96bd79fb 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -4,10 +4,10 @@ from typing import Dict, List from jinja2 import Environment, StrictUndefined -from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml +from pr_agent.algo.utils import load_yaml, get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler() ): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler() ): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index a1b73df2..58de3b4e 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -4,10 +4,10 @@ from typing import List, Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml +from pr_agent.algo.utils import load_yaml, get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRDescription: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): """ Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index beb1b0ab..78903490 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -2,9 +2,10 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.utils import get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -12,7 +13,7 @@ from pr_agent.log import get_logger class PRInformationFromUser: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index b1930924..43c276cb 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -2,9 +2,10 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.utils import get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -12,7 +13,7 @@ from pr_agent.log import get_logger class PRQuestions: - def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): + def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = get_ai_handler()): question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 4de24b06..d68b893e 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -6,10 +6,10 @@ import yaml from jinja2 import Environment, StrictUndefined from yaml import SafeLoader -from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml +from pr_agent.algo.utils import convert_to_markdown, get_ai_handler, load_yaml, try_fix_yaml from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language @@ -21,7 +21,7 @@ class PRReviewer: """ The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): + def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index 7625218b..febe6fec 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -5,9 +5,10 @@ from typing import Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.litellm_ai_handler import BaseAiHandler, LiteLLMAiHandler +from pr_agent.algo.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.utils import get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -17,7 +18,7 @@ CHANGELOG_LINES = 50 class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAiHandler()): + def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = get_ai_handler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( From 7eb2e769cf82c259c33c30e0a158c9fade2302af Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Tue, 12 Dec 2023 23:03:38 +0800 Subject: [PATCH 07/17] Move ai handlers to specific folder --- pr_agent/algo/{ => ai_handlers}/base_ai_handler.py | 0 pr_agent/algo/{ => ai_handlers}/litellm_ai_handler.py | 2 +- pr_agent/algo/{ => ai_handlers}/openai_ai_handler.py | 2 +- pr_agent/algo/utils.py | 6 +++--- pr_agent/tools/pr_add_docs.py | 2 +- pr_agent/tools/pr_code_suggestions.py | 2 +- pr_agent/tools/pr_description.py | 2 +- pr_agent/tools/pr_information_from_user.py | 2 +- pr_agent/tools/pr_questions.py | 2 +- pr_agent/tools/pr_reviewer.py | 2 +- pr_agent/tools/pr_update_changelog.py | 2 +- 11 files changed, 12 insertions(+), 12 deletions(-) rename pr_agent/algo/{ => ai_handlers}/base_ai_handler.py (100%) rename pr_agent/algo/{ => ai_handlers}/litellm_ai_handler.py (98%) rename pr_agent/algo/{ => ai_handlers}/openai_ai_handler.py (96%) diff --git a/pr_agent/algo/base_ai_handler.py b/pr_agent/algo/ai_handlers/base_ai_handler.py similarity index 100% rename from pr_agent/algo/base_ai_handler.py rename to pr_agent/algo/ai_handlers/base_ai_handler.py diff --git a/pr_agent/algo/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py similarity index 98% rename from pr_agent/algo/litellm_ai_handler.py rename to pr_agent/algo/ai_handlers/litellm_ai_handler.py index 7d88a00d..2e37a592 100644 --- a/pr_agent/algo/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -7,7 +7,7 @@ from openai.error import APIError, RateLimitError, Timeout, TryAgain from retry import retry from pr_agent.config_loader import get_settings from pr_agent.log import get_logger -from pr_agent.algo.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler OPENAI_RETRIES = 5 diff --git a/pr_agent/algo/openai_ai_handler.py b/pr_agent/algo/ai_handlers/openai_ai_handler.py similarity index 96% rename from pr_agent/algo/openai_ai_handler.py rename to pr_agent/algo/ai_handlers/openai_ai_handler.py index c521442d..3856f6f7 100644 --- a/pr_agent/algo/openai_ai_handler.py +++ b/pr_agent/algo/ai_handlers/openai_ai_handler.py @@ -1,4 +1,4 @@ -from pr_agent.algo.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler import openai from openai.error import APIError, RateLimitError, Timeout, TryAgain from retry import retry diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 14cde04a..824e4b70 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -8,9 +8,9 @@ from datetime import datetime from typing import Any, List import yaml -from pr_agent.algo.litellm_ai_handler import LiteLLMAiHandler -from pr_agent.algo.base_ai_handler import BaseAiHandler -from pr_agent.algo.openai_ai_handler import OpenAIHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.openai_ai_handler import OpenAIHandler from starlette_context import context from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import get_logger diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index f76baa02..916f479f 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -4,7 +4,7 @@ from typing import Dict from jinja2 import Environment, StrictUndefined -from pr_agent.algo.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, get_ai_handler diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 96bd79fb..61a382a5 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -4,7 +4,7 @@ from typing import Dict, List from jinja2 import Environment, StrictUndefined -from pr_agent.algo.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, get_ai_handler diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 58de3b4e..c3db0cef 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -4,7 +4,7 @@ from typing import List, Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, get_ai_handler diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index 78903490..c4240723 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -2,7 +2,7 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import get_ai_handler diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 43c276cb..ecaf4d8d 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -2,7 +2,7 @@ import copy from jinja2 import Environment, StrictUndefined -from pr_agent.algo.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import get_ai_handler diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index d68b893e..138ad5ad 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -6,7 +6,7 @@ import yaml from jinja2 import Environment, StrictUndefined from yaml import SafeLoader -from pr_agent.algo.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, get_ai_handler, load_yaml, try_fix_yaml diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index febe6fec..33ba941d 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -5,7 +5,7 @@ from typing import Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import get_ai_handler From 6c7beccb4f11dcd495522c3d6a9cd65ef5c87c4b Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Tue, 12 Dec 2023 23:03:49 +0800 Subject: [PATCH 08/17] add LangChain AI Handler --- .../algo/ai_handlers/langchain_ai_handler.py | 45 +++++++++++++++++++ pr_agent/algo/utils.py | 3 +- 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 pr_agent/algo/ai_handlers/langchain_ai_handler.py diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py new file mode 100644 index 00000000..406c6c40 --- /dev/null +++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py @@ -0,0 +1,45 @@ +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from langchain.chat_models import ChatOpenAI +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain.schema import SystemMessage, HumanMessage + + +from pr_agent.config_loader import get_settings +from pr_agent.log import get_logger + +OPENAI_RETRIES = 5 +chat = ChatOpenAI(openai_api_key = get_settings().openai.key, model="gpt-4") + +class LangChainAIHandler(BaseAiHandler): + def __init__(self): + # Initialize OpenAIHandler specific attributes here + try: + super().__init__() + + except AttributeError as e: + raise ValueError("OpenAI key is required") from e + @property + def deployment_id(self): + """ + Returns the deployment ID for the OpenAI API. + """ + return get_settings().get("OPENAI.DEPLOYMENT_ID", None) + + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + try: + + messages=[SystemMessage(content=system), HumanMessage(content=user)] + + # get a chat completion from the formatted messages + resp = chat(messages) + get_logger().info("AI response: ", resp.content) + finish_reason="completed" + return resp.content, finish_reason + + except (Exception) as e: + get_logger().error("Unknown error during OpenAI inference: ", e) + raise e \ No newline at end of file diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 824e4b70..d0b86b63 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -11,6 +11,7 @@ import yaml from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.openai_ai_handler import OpenAIHandler +from pr_agent.algo.ai_handlers.langchain_ai_handler import LangChainAIHandler from starlette_context import context from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import get_logger @@ -309,4 +310,4 @@ def try_fix_yaml(review_text: str) -> dict: return data def get_ai_handler() -> BaseAiHandler: - return OpenAIHandler() \ No newline at end of file + return LangChainAIHandler() \ No newline at end of file From 506eafc0c5e0438b784cceb73d70777eef59d8fe Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Tue, 12 Dec 2023 23:04:01 +0800 Subject: [PATCH 09/17] add langchain in requirement --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 8589b30b..43738491 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,3 +22,4 @@ msrest==0.7.1 pinecone-client pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main loguru==0.7.2 +langchain==0.0.349 From 0c66554d505a1a607e0cc14bc80e4c750bcad13c Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Tue, 12 Dec 2023 23:07:46 +0800 Subject: [PATCH 10/17] langchain: move model and temperature to chat_completion --- pr_agent/algo/ai_handlers/langchain_ai_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py index 406c6c40..87cf1072 100644 --- a/pr_agent/algo/ai_handlers/langchain_ai_handler.py +++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py @@ -12,7 +12,7 @@ from pr_agent.config_loader import get_settings from pr_agent.log import get_logger OPENAI_RETRIES = 5 -chat = ChatOpenAI(openai_api_key = get_settings().openai.key, model="gpt-4") +chat = ChatOpenAI(openai_api_key = get_settings().openai.key) class LangChainAIHandler(BaseAiHandler): def __init__(self): @@ -35,7 +35,7 @@ class LangChainAIHandler(BaseAiHandler): messages=[SystemMessage(content=system), HumanMessage(content=user)] # get a chat completion from the formatted messages - resp = chat(messages) + resp = chat(messages, model=model, temperature=temperature) get_logger().info("AI response: ", resp.content) finish_reason="completed" return resp.content, finish_reason From a627dcd64fe1de6e740e385323bf9b451d6d2772 Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Tue, 12 Dec 2023 23:28:58 +0800 Subject: [PATCH 11/17] Update langchain --- .../algo/ai_handlers/langchain_ai_handler.py | 26 ++++++++----------- pr_agent/algo/utils.py | 4 +-- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py index 87cf1072..bc26e624 100644 --- a/pr_agent/algo/ai_handlers/langchain_ai_handler.py +++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py @@ -1,41 +1,37 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from langchain.chat_models import ChatOpenAI -from langchain.prompts.chat import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - SystemMessagePromptTemplate, -) from langchain.schema import SystemMessage, HumanMessage - - from pr_agent.config_loader import get_settings from pr_agent.log import get_logger -OPENAI_RETRIES = 5 -chat = ChatOpenAI(openai_api_key = get_settings().openai.key) - -class LangChainAIHandler(BaseAiHandler): +class LangChainOpenAIHandler(BaseAiHandler): def __init__(self): # Initialize OpenAIHandler specific attributes here try: super().__init__() - + self._chat = ChatOpenAI(openai_api_key=get_settings().openai.key) + except AttributeError as e: raise ValueError("OpenAI key is required") from e + + @property + def chat(self): + return self._chat + @property def deployment_id(self): """ Returns the deployment ID for the OpenAI API. """ return get_settings().get("OPENAI.DEPLOYMENT_ID", None) - + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): try: - + get_logger().info("model: ", model) messages=[SystemMessage(content=system), HumanMessage(content=user)] # get a chat completion from the formatted messages - resp = chat(messages, model=model, temperature=temperature) + resp = self.chat(messages, model=model, temperature=temperature) get_logger().info("AI response: ", resp.content) finish_reason="completed" return resp.content, finish_reason diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index d0b86b63..8c66d96e 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -11,7 +11,7 @@ import yaml from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.openai_ai_handler import OpenAIHandler -from pr_agent.algo.ai_handlers.langchain_ai_handler import LangChainAIHandler +from pr_agent.algo.ai_handlers.langchain_ai_handler import LangChainOpenAIHandler from starlette_context import context from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import get_logger @@ -310,4 +310,4 @@ def try_fix_yaml(review_text: str) -> dict: return data def get_ai_handler() -> BaseAiHandler: - return LangChainAIHandler() \ No newline at end of file + return LangChainOpenAIHandler() \ No newline at end of file From b7225cc674d28e6c337768630223090ad78cc8ee Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Tue, 12 Dec 2023 23:52:50 +0800 Subject: [PATCH 12/17] update langchain --- pr_agent/algo/ai_handlers/langchain_ai_handler.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py index bc26e624..5c793f2b 100644 --- a/pr_agent/algo/ai_handlers/langchain_ai_handler.py +++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py @@ -1,9 +1,15 @@ -from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from langchain.chat_models import ChatOpenAI from langchain.schema import SystemMessage, HumanMessage + +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.config_loader import get_settings from pr_agent.log import get_logger +from openai.error import APIError, RateLimitError, Timeout, TryAgain +from retry import retry + +OPENAI_RETRIES = 5 + class LangChainOpenAIHandler(BaseAiHandler): def __init__(self): # Initialize OpenAIHandler specific attributes here @@ -24,15 +30,14 @@ class LangChainOpenAIHandler(BaseAiHandler): Returns the deployment ID for the OpenAI API. """ return get_settings().get("OPENAI.DEPLOYMENT_ID", None) - + @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), + tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): try: - get_logger().info("model: ", model) messages=[SystemMessage(content=system), HumanMessage(content=user)] # get a chat completion from the formatted messages resp = self.chat(messages, model=model, temperature=temperature) - get_logger().info("AI response: ", resp.content) finish_reason="completed" return resp.content, finish_reason From ca1ccd7b91a2a3cc8c0596b9207a9779da301134 Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Tue, 12 Dec 2023 23:56:20 +0800 Subject: [PATCH 13/17] update base --- pr_agent/algo/ai_handlers/base_ai_handler.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pr_agent/algo/ai_handlers/base_ai_handler.py b/pr_agent/algo/ai_handlers/base_ai_handler.py index 21050db9..7c6c3ddf 100644 --- a/pr_agent/algo/ai_handlers/base_ai_handler.py +++ b/pr_agent/algo/ai_handlers/base_ai_handler.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod class BaseAiHandler(ABC): """ - This class defines the interface for an AI handler. + This class defines the interface for an AI handler to be used by the PR Agents. """ @abstractmethod @@ -15,6 +15,14 @@ class BaseAiHandler(ABC): pass @abstractmethod + """ + This method should be implemented to return a chat completion from the AI model. + params: + model: the name of the model to use for the chat completion + system: the system message string to use for the chat completion + user: the user message string to use for the chat completion + temperature: the temperature to use for the chat completion + """ async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): pass From 8fb4a42ef1e45aa5ac67a5eca42eeabe5c3c999a Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Wed, 13 Dec 2023 08:16:02 +0800 Subject: [PATCH 14/17] Update AI handler instantiation in server files --- pr_agent/agent/pr_agent.py | 22 ++++++++++++++++---- pr_agent/algo/ai_handlers/base_ai_handler.py | 18 ++++++++-------- pr_agent/cli.py | 6 ++++-- pr_agent/servers/bitbucket_app.py | 4 +++- pr_agent/servers/gerrit_server.py | 4 +++- pr_agent/servers/github_action_runner.py | 7 ++++--- pr_agent/servers/github_app.py | 4 +++- pr_agent/servers/github_polling.py | 4 +++- pr_agent/servers/gitlab_webhook.py | 4 +++- pr_agent/tools/pr_add_docs.py | 4 ++-- pr_agent/tools/pr_code_suggestions.py | 4 ++-- pr_agent/tools/pr_description.py | 4 ++-- pr_agent/tools/pr_information_from_user.py | 3 +-- pr_agent/tools/pr_questions.py | 3 +-- pr_agent/tools/pr_reviewer.py | 6 ++++-- pr_agent/tools/pr_update_changelog.py | 3 +-- 16 files changed, 63 insertions(+), 37 deletions(-) diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index cd2bf2cc..a94984ac 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -1,4 +1,5 @@ import shlex +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.utils import update_settings_from_args from pr_agent.config_loader import get_settings @@ -12,6 +13,7 @@ from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_similar_issue import PRSimilarIssue from pr_agent.tools.pr_update_changelog import PRUpdateChangelog +import inspect command2class = { "auto_review": PRReviewer, @@ -36,8 +38,16 @@ command2class = { commands = list(command2class.keys()) class PRAgent: - def __init__(self): + def __init__(self, ai_handler: BaseAiHandler = None): + self.ai_handler = ai_handler pass + + def has_ai_handler_param(cls): + constructor = getattr(cls, "__init__", None) + if constructor is not None: + parameters = inspect.signature(constructor).parameters + return "ai_handler" in parameters + return False async def handle_request(self, pr_url, request, notify=None) -> bool: # First, apply repo specific settings if exists @@ -56,13 +66,17 @@ class PRAgent: if action == "answer": if notify: notify() - await PRReviewer(pr_url, is_answer=True, args=args).run() + await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run() elif action == "auto_review": - await PRReviewer(pr_url, is_auto=True, args=args).run() + await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run() elif action in command2class: if notify: notify() - await command2class[action](pr_url, args=args).run() + + if(not self.has_ai_handler_param(command2class[action])): + await command2class[action](pr_url, args=args).run() + else + await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run() else: return False return True diff --git a/pr_agent/algo/ai_handlers/base_ai_handler.py b/pr_agent/algo/ai_handlers/base_ai_handler.py index 7c6c3ddf..c8473fb3 100644 --- a/pr_agent/algo/ai_handlers/base_ai_handler.py +++ b/pr_agent/algo/ai_handlers/base_ai_handler.py @@ -14,15 +14,15 @@ class BaseAiHandler(ABC): def deployment_id(self): pass - @abstractmethod - """ - This method should be implemented to return a chat completion from the AI model. - params: - model: the name of the model to use for the chat completion - system: the system message string to use for the chat completion - user: the user message string to use for the chat completion - temperature: the temperature to use for the chat completion - """ + @abstractmethod async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + """ + This method should be implemented to return a chat completion from the AI model. + Args: + model (str): the name of the model to use for the chat completion + system (str): the system message string to use for the chat completion + user (str): the user message string to use for the chat completion + temperature (float): the temperature to use for the chat completion + """ pass diff --git a/pr_agent/cli.py b/pr_agent/cli.py index 6728db9f..422185ab 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -5,7 +5,9 @@ import os from pr_agent.agent.pr_agent import PRAgent, commands from pr_agent.config_loader import get_settings from pr_agent.log import setup_logger +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger() def run(inargs=None): @@ -51,9 +53,9 @@ For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions command = args.command.lower() get_settings().set("CONFIG.CLI_MODE", True) if args.issue_url: - result = asyncio.run(PRAgent().handle_request(args.issue_url, command + " " + " ".join(args.rest))) + result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.issue_url, command + " " + " ".join(args.rest))) else: - result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest))) + result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.pr_url, command + " " + " ".join(args.rest))) if not result: parser.print_help() diff --git a/pr_agent/servers/bitbucket_app.py b/pr_agent/servers/bitbucket_app.py index e147fbdd..739ba162 100644 --- a/pr_agent/servers/bitbucket_app.py +++ b/pr_agent/servers/bitbucket_app.py @@ -18,7 +18,9 @@ from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.secret_providers import get_secret_provider +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() secret_provider = get_secret_provider() @@ -84,7 +86,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req context['bitbucket_bearer_token'] = bearer_token context["settings"] = copy.deepcopy(global_settings) event = data["event"] - agent = PRAgent() + agent = PRAgent(ai_handler=litellm_ai_handler) if event == "pullrequest:created": pr_url = data["data"]["pullrequest"]["links"]["html"]["href"] log_context["api_url"] = pr_url diff --git a/pr_agent/servers/gerrit_server.py b/pr_agent/servers/gerrit_server.py index 1783f6b9..b8b90670 100644 --- a/pr_agent/servers/gerrit_server.py +++ b/pr_agent/servers/gerrit_server.py @@ -12,7 +12,9 @@ from starlette_context.middleware import RawContextMiddleware from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import get_logger, setup_logger +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger() router = APIRouter() @@ -43,7 +45,7 @@ async def handle_gerrit_request(action: Action, item: Item): status_code=400, detail="msg is required for ask command" ) - await PRAgent().handle_request( + await PRAgent(ai_handler=litellm_ai_handler).handle_request( f"{item.project}:{item.refspec}", f"/{item.msg.strip()}" ) diff --git a/pr_agent/servers/github_action_runner.py b/pr_agent/servers/github_action_runner.py index 714e7297..030e3f2c 100644 --- a/pr_agent/servers/github_action_runner.py +++ b/pr_agent/servers/github_action_runner.py @@ -8,7 +8,8 @@ from pr_agent.git_providers import get_git_provider from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_reviewer import PRReviewer - +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() async def run_action(): # Get environment variables @@ -83,9 +84,9 @@ async def run_action(): comment_id = event_payload.get("comment", {}).get("id") provider = get_git_provider()(pr_url=url) if is_pr: - await PRAgent().handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id)) + await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id)) else: - await PRAgent().handle_request(url, body) + await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body) if __name__ == '__main__': diff --git a/pr_agent/servers/github_app.py b/pr_agent/servers/github_app.py index 37f96e2d..bb595a89 100644 --- a/pr_agent/servers/github_app.py +++ b/pr_agent/servers/github_app.py @@ -16,7 +16,9 @@ from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.utils import apply_repo_settings from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.servers.utils import verify_signature +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() @@ -75,7 +77,7 @@ async def handle_request(body: Dict[str, Any], event: str): action = body.get("action") if not action: return {} - agent = PRAgent() + agent = PRAgent(ai_handler=litellm_ai_handler) bot_user = get_settings().github_app.bot_user sender = body.get("sender", {}).get("login") log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"} diff --git a/pr_agent/servers/github_polling.py b/pr_agent/servers/github_polling.py index 1363b941..b473b8fa 100644 --- a/pr_agent/servers/github_polling.py +++ b/pr_agent/servers/github_polling.py @@ -8,7 +8,9 @@ from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.servers.help import bot_help_text +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) NOTIFICATION_URL = "https://api.github.com/notifications" @@ -34,7 +36,7 @@ async def polling_loop(): last_modified = [None] git_provider = get_git_provider()() user_id = git_provider.get_user_id() - agent = PRAgent() + agent = PRAgent(ai_handler=litellm_ai_handler) get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False) try: diff --git a/pr_agent/servers/gitlab_webhook.py b/pr_agent/servers/gitlab_webhook.py index 63bf99ce..b6921684 100644 --- a/pr_agent/servers/gitlab_webhook.py +++ b/pr_agent/servers/gitlab_webhook.py @@ -14,7 +14,9 @@ from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.secret_providers import get_secret_provider +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() @@ -26,7 +28,7 @@ def handle_request(background_tasks: BackgroundTasks, url: str, body: str, log_c log_context["event"] = "pull_request" if body == "/review" else "comment" log_context["api_url"] = url with get_logger().contextualize(**log_context): - background_tasks.add_task(PRAgent().handle_request, url, body) + background_tasks.add_task(PRAgent(ai_handler=litellm_ai_handler).handle_request, url, body) @router.post("/webhook") diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index 916f479f..70dd66c2 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml, get_ai_handler +from pr_agent.algo.utils import load_yaml from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRAddDocs: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 61a382a5..a85bab5f 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml, get_ai_handler +from pr_agent.algo.utils import load_yaml from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler() ): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None ): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index c3db0cef..d377d75c 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml, get_ai_handler +from pr_agent.algo.utils import load_yaml from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRDescription: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): """ Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index c4240723..e52765f7 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -5,7 +5,6 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -13,7 +12,7 @@ from pr_agent.log import get_logger class PRInformationFromUser: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index ecaf4d8d..79edfd6a 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -5,7 +5,6 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -13,7 +12,7 @@ from pr_agent.log import get_logger class PRQuestions: - def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = None): question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 138ad5ad..becd2191 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -9,7 +9,7 @@ from yaml import SafeLoader from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import convert_to_markdown, get_ai_handler, load_yaml, try_fix_yaml +from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language @@ -21,13 +21,15 @@ class PRReviewer: """ The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = None): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Args: pr_url (str): The URL of the pull request to be reviewed. is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False. + is_auto (bool, optional): Indicates whether the review is being done in automatic mode. Defaults to False. + ai_handler (BaseAiHandler): The AI handler to be used for the review. Defaults to None. args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None. """ self.parse_args(args) # -i command diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index 33ba941d..07130749 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -8,7 +8,6 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -18,7 +17,7 @@ CHANGELOG_LINES = 50 class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = None): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( From be8d6af87fed0da767752087863dd1d3447ffdc8 Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Wed, 13 Dec 2023 08:16:31 +0800 Subject: [PATCH 15/17] Add code documentation generation for PR diffs --- langchain.ipynb | 166 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 langchain.ipynb diff --git a/langchain.ipynb b/langchain.ipynb new file mode 100644 index 00000000..10e163cc --- /dev/null +++ b/langchain.ipynb @@ -0,0 +1,166 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.prompts.chat import (\n", + " ChatPromptTemplate,\n", + " HumanMessagePromptTemplate,\n", + " SystemMessagePromptTemplate,\n", + ")\n", + "\n", + "chat = ChatOpenAI(temperature=0, openai_api_key=\"\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "template = \"\"\"You are a language model called PR-Code-Documentation Agent, that specializes in generating documentation for code.\n", + "Your task is to generate meaningfull {{ docs_for_language }} to a PR (the '+' lines).\n", + "\n", + "Example for a PR Diff input:\n", + "'\n", + "## src/file1.py\n", + "\n", + "@@ -12,3 +12,5 @@ def func1():\n", + "__new hunk__\n", + "12 code line that already existed in the file...\n", + "13 code line that already existed in the file....\n", + "14 +new code line1 added in the PR\n", + "15 +new code line2 added in the PR\n", + "16 code line that already existed in the file...\n", + "__old hunk__\n", + " code line that already existed in the file...\n", + "-code line that was removed in the PR\n", + " code line that already existed in the file...\n", + "\n", + "\n", + "@@ ... @@ def func2():\n", + "__new hunk__\n", + "...\n", + "__old hunk__\n", + "...\n", + "\n", + "\n", + "## src/file2.py\n", + "...\n", + "'\n", + "\n", + "Specific instructions:\n", + "- Try to identify edited/added code components (classes/functions/methods...) that are undocumented. and generate {{ docs_for_language }} for each one.\n", + "- If there are documented (any type of {{ language }} documentation) code components in the PR, Don't generate {{ docs_for_language }} for them.\n", + "- Ignore code components that don't appear fully in the '__new hunk__' section. For example. you must see the component header and body,\n", + "- Make sure the {{ docs_for_language }} starts and ends with standart {{ language }} {{ docs_for_language }} signs.\n", + "- The {{ docs_for_language }} should be in standard format.\n", + "- Provide the exact line number (inclusive) where the {{ docs_for_language }} should be added.\n", + "\n", + "\n", + "You must use the following YAML schema to format your answer:\n", + "```yaml\n", + "Code Documentation:\n", + " type: array\n", + " uniqueItems: true\n", + " items:\n", + " relevant file:\n", + " type: string\n", + " description: the relevant file full path\n", + " relevant line:\n", + " type: integer\n", + " description: |-\n", + " The relevant line number from a '__new hunk__' section where the {{ docs_for_language }} should be added.\n", + " doc placement:\n", + " type: string\n", + " enum:\n", + " - before\n", + " - after\n", + " description: |-\n", + " The {{ docs_for_language }} placement relative to the relevant line (code component).\n", + " documentation:\n", + " type: string\n", + " description: |-\n", + " The {{ docs_for_language }} content. It should be complete, correctly formatted and indented, and without line numbers.\n", + "```\n", + "\n", + "Example output:\n", + "```yaml\n", + "Code Documentation:\n", + "- relevant file: |-\n", + " src/file1.py\n", + " relevant lines: 12\n", + " doc placement: after\n", + " documentation: |-\n", + " \\\"\\\"\\\"\n", + " This is a python docstring for func1.\n", + " \\\"\\\"\\\"\n", + "- ...\n", + "...\n", + "```\n", + "\n", + "\n", + "Each YAML output MUST be after a newline, indented, with block scalar indicator ('|-').\n", + "Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.\"\"\"\n", + "\n", + "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n", + "human_template = \"{text}\"\n", + "human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='```yaml\\nCode Documentation:\\n- relevant file: |-\\n src/file1.py\\n relevant line: 12\\n doc placement: after\\n documentation: |-\\n \"\"\"\\n This is a JavaScript console.log statement that prints \\'hello world\\'.\\n \"\"\"\\n```')" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_prompt = ChatPromptTemplate.from_messages(\n", + " [system_message_prompt, human_message_prompt]\n", + ")\n", + "\n", + "# get a chat completion from the formatted messages\n", + "chat(\n", + " chat_prompt.format_prompt(\n", + " docs_for_language=\"JSDoc\", language=\"JavaScript\", text=\"console.log('hello world')\"\n", + " ).to_messages()\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.0.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From ebb2ed891b1a7ce297eef4f9572f03f758fee01d Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Wed, 13 Dec 2023 08:16:45 +0800 Subject: [PATCH 16/17] Add logging to pr_agent.py --- pr_agent/agent/pr_agent.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index a94984ac..5c6e4ec1 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -4,6 +4,7 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.utils import update_settings_from_args from pr_agent.config_loader import get_settings from pr_agent.git_providers.utils import apply_repo_settings +from pr_agent.log import get_logger from pr_agent.tools.pr_add_docs import PRAddDocs from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_config import PRConfig @@ -42,7 +43,7 @@ class PRAgent: self.ai_handler = ai_handler pass - def has_ai_handler_param(cls): + def has_ai_handler_param(cls: object): constructor = getattr(cls, "__init__", None) if constructor is not None: parameters = inspect.signature(constructor).parameters @@ -73,9 +74,10 @@ class PRAgent: if notify: notify() - if(not self.has_ai_handler_param(command2class[action])): + get_logger().info(f"Class: {command2class[action]}") + if(not self.has_ai_handler_param(cls=command2class[action])): await command2class[action](pr_url, args=args).run() - else + else: await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run() else: return False From 69a7c77a0d8762c6eb14285138040c42fca493d6 Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Thu, 14 Dec 2023 07:15:56 +0800 Subject: [PATCH 17/17] Refactor PRAgent class and has_ai_handler_param method This commit refactors the PRAgent class and the has_ai_handler_param method. The has_ai_handler_param method is moved outside the class and made a standalone function. This change improves code organization and readability. The has_ai_handler_param function now takes a class object as a parameter and checks if the class constructor has an "ai_handler" parameter. This refactoring is done to streamline the code and improve its maintainability. No issue references. --- pr_agent/agent/pr_agent.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 5c6e4ec1..ff2237e0 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -38,17 +38,18 @@ command2class = { commands = list(command2class.keys()) +def has_ai_handler_param(cls: object): + constructor = getattr(cls, "__init__", None) + if constructor is not None: + parameters = inspect.signature(constructor).parameters + return "ai_handler" in parameters + return False + class PRAgent: def __init__(self, ai_handler: BaseAiHandler = None): self.ai_handler = ai_handler pass - def has_ai_handler_param(cls: object): - constructor = getattr(cls, "__init__", None) - if constructor is not None: - parameters = inspect.signature(constructor).parameters - return "ai_handler" in parameters - return False async def handle_request(self, pr_url, request, notify=None) -> bool: # First, apply repo specific settings if exists @@ -75,7 +76,7 @@ class PRAgent: notify() get_logger().info(f"Class: {command2class[action]}") - if(not self.has_ai_handler_param(cls=command2class[action])): + if(not has_ai_handler_param(cls=command2class[action])): await command2class[action](pr_url, args=args).run() else: await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()