Merge branch 'base-ai-handler' into abstract-BaseAiHandler

This commit is contained in:
Brian Pham
2023-12-14 07:44:13 +08:00
21 changed files with 256 additions and 57 deletions

View File

@ -1,8 +1,10 @@
import shlex import shlex
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.utils import update_settings_from_args from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.utils import apply_repo_settings from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.log import get_logger
from pr_agent.tools.pr_add_docs import PRAddDocs from pr_agent.tools.pr_add_docs import PRAddDocs
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_config import PRConfig from pr_agent.tools.pr_config import PRConfig
@ -13,6 +15,7 @@ from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_similar_issue import PRSimilarIssue from pr_agent.tools.pr_similar_issue import PRSimilarIssue
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
import inspect
command2class = { command2class = {
"auto_review": PRReviewer, "auto_review": PRReviewer,
@ -37,10 +40,19 @@ command2class = {
commands = list(command2class.keys()) commands = list(command2class.keys())
def has_ai_handler_param(cls: object):
constructor = getattr(cls, "__init__", None)
if constructor is not None:
parameters = inspect.signature(constructor).parameters
return "ai_handler" in parameters
return False
class PRAgent: class PRAgent:
def __init__(self): def __init__(self, ai_handler: BaseAiHandler = None):
self.ai_handler = ai_handler
pass pass
async def handle_request(self, pr_url, request, notify=None) -> bool: async def handle_request(self, pr_url, request, notify=None) -> bool:
# First, apply repo specific settings if exists # First, apply repo specific settings if exists
apply_repo_settings(pr_url) apply_repo_settings(pr_url)
@ -61,13 +73,18 @@ class PRAgent:
if action == "answer": if action == "answer":
if notify: if notify:
notify() notify()
await PRReviewer(pr_url, is_answer=True, args=args).run() await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
elif action == "auto_review": elif action == "auto_review":
await PRReviewer(pr_url, is_auto=True, args=args).run() await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
elif action in command2class: elif action in command2class:
if notify: if notify:
notify() notify()
get_logger().info(f"Class: {command2class[action]}")
if(not has_ai_handler_param(cls=command2class[action])):
await command2class[action](pr_url, args=args).run() await command2class[action](pr_url, args=args).run()
else:
await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
else: else:
return False return False
return True return True

View File

@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
class BaseAiHandler(ABC):
"""
This class defines the interface for an AI handler to be used by the PR Agents.
"""
@abstractmethod
def __init__(self):
pass
@property
@abstractmethod
def deployment_id(self):
pass
@abstractmethod
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
"""
This method should be implemented to return a chat completion from the AI model.
Args:
model (str): the name of the model to use for the chat completion
system (str): the system message string to use for the chat completion
user (str): the user message string to use for the chat completion
temperature (float): the temperature to use for the chat completion
"""
pass

View File

@ -0,0 +1,46 @@
from langchain.chat_models import ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry
OPENAI_RETRIES = 5
class LangChainOpenAIHandler(BaseAiHandler):
def __init__(self):
# Initialize OpenAIHandler specific attributes here
try:
super().__init__()
self._chat = ChatOpenAI(openai_api_key=get_settings().openai.key)
except AttributeError as e:
raise ValueError("OpenAI key is required") from e
@property
def chat(self):
return self._chat
@property
def deployment_id(self):
"""
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
try:
messages=[SystemMessage(content=system), HumanMessage(content=user)]
# get a chat completion from the formatted messages
resp = self.chat(messages, model=model, temperature=temperature)
finish_reason="completed"
return resp.content, finish_reason
except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise e

View File

@ -6,14 +6,14 @@ import openai
from litellm import acompletion from litellm import acompletion
from openai.error import APIError, RateLimitError, Timeout, TryAgain from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry from retry import retry
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.algo.base_ai_handler import BaseAiHandler
from pr_agent.log import get_logger from pr_agent.log import get_logger
OPENAI_RETRIES = 5 OPENAI_RETRIES = 5
class AiHandler(BaseAiHandler): class LiteLLMAIHandler(BaseAiHandler):
""" """
This class handles interactions with the OpenAI API for chat completions. This class handles interactions with the OpenAI API for chat completions.
It initializes the API key and other settings from a configuration file, It initializes the API key and other settings from a configuration file,

View File

@ -0,0 +1,67 @@
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
import openai
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger
OPENAI_RETRIES = 5
class OpenAIHandler(BaseAiHandler):
def __init__(self):
# Initialize OpenAIHandler specific attributes here
try:
super().__init__()
openai.api_key = get_settings().openai.key
if get_settings().get("OPENAI.ORG", None):
openai.organization = get_settings().openai.org
if get_settings().get("OPENAI.API_TYPE", None):
if get_settings().openai.api_type == "azure":
self.azure = True
openai.azure_key = get_settings().openai.key
if get_settings().get("OPENAI.API_VERSION", None):
openai.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = get_settings().openai.api_base
except AttributeError as e:
raise ValueError("OpenAI key is required") from e
@property
def deployment_id(self):
"""
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
try:
deployment_id = self.deployment_id
get_logger().info("System: ", system)
get_logger().info("User: ", user)
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
chat_completion = await openai.ChatCompletion.acreate(
model=model,
deployment_id=deployment_id,
messages=messages,
temperature=temperature,
)
resp = chat_completion["choices"][0]['message']['content']
finish_reason = chat_completion["choices"][0]["finish_reason"]
usage = chat_completion.get("usage")
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
model=model, usage=usage)
return resp, finish_reason
except (APIError, Timeout, TryAgain) as e:
get_logger().error("Error during OpenAI inference: ", e)
raise
except (RateLimitError) as e:
get_logger().error("Rate limit error during OpenAI inference: ", e)
raise
except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e

View File

@ -59,14 +59,14 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool=True) -> str:
if key.lower() == 'code feedback': if key.lower() == 'code feedback':
if gfm_supported: if gfm_supported:
markdown_text += f"\n\n- " markdown_text += f"\n\n- "
markdown_text += f"<details><summary> { emoji } Code feedback:</summary>\n\n" markdown_text += f"<details><summary> { emoji } Code feedback:</summary>"
else: else:
markdown_text += f"\n\n- **{emoji} Code feedback:**\n\n" markdown_text += f"\n\n- **{emoji} Code feedback:**\n\n"
else: else:
markdown_text += f"- {emoji} **{key}:**\n\n" markdown_text += f"- {emoji} **{key}:**\n\n"
for item in value: for i, item in enumerate(value):
if isinstance(item, dict) and key.lower() == 'code feedback': if isinstance(item, dict) and key.lower() == 'code feedback':
markdown_text += parse_code_suggestion(item, gfm_supported) markdown_text += parse_code_suggestion(item, i, gfm_supported)
elif item: elif item:
markdown_text += f" - {item}\n" markdown_text += f" - {item}\n"
if key.lower() == 'code feedback': if key.lower() == 'code feedback':
@ -80,7 +80,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool=True) -> str:
return markdown_text return markdown_text
def parse_code_suggestion(code_suggestions: dict, gfm_supported: bool=True) -> str: def parse_code_suggestion(code_suggestions: dict, i: int = 0, gfm_supported: bool = True) -> str:
""" """
Convert a dictionary of data into markdown format. Convert a dictionary of data into markdown format.
@ -91,6 +91,34 @@ def parse_code_suggestion(code_suggestions: dict, gfm_supported: bool=True) -> s
str: A string containing the markdown formatted text generated from the input dictionary. str: A string containing the markdown formatted text generated from the input dictionary.
""" """
markdown_text = "" markdown_text = ""
if gfm_supported and 'relevant line' in code_suggestions:
if i == 0:
markdown_text += "<hr>"
markdown_text += '<table>'
for sub_key, sub_value in code_suggestions.items():
try:
if sub_key.lower() == 'relevant file':
relevant_file = sub_value.strip('`').strip('"').strip("'")
markdown_text += f"<tr><td>{sub_key}</td><td>{relevant_file}</td></tr>"
# continue
elif sub_key.lower() == 'suggestion':
markdown_text += f"<tr><td>{sub_key} &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td><td><strong>{sub_value}</strong></td></tr>"
elif sub_key.lower() == 'relevant line':
markdown_text += f"<tr><td>relevant line</td>"
sub_value_list = sub_value.split('](')
relevant_line = sub_value_list[0].lstrip('`').lstrip('[')
if len(sub_value_list) > 1:
link = sub_value_list[1].rstrip(')').strip('`')
markdown_text += f"<td><a href={link}>{relevant_line}</a></td>"
else:
markdown_text += f"<td>{relevant_line}</td>"
markdown_text += "</tr>"
except Exception as e:
get_logger().exception(f"Failed to parse code suggestion: {e}")
pass
markdown_text += '</table>'
markdown_text += "<hr>"
else:
for sub_key, sub_value in code_suggestions.items(): for sub_key, sub_value in code_suggestions.items():
if isinstance(sub_value, dict): # "code example" if isinstance(sub_value, dict): # "code example"
markdown_text += f" - **{sub_key}:**\n" markdown_text += f" - **{sub_key}:**\n"
@ -336,7 +364,7 @@ def try_fix_yaml(response_text: str) -> dict:
pass pass
def set_custom_labels(variables): def set_custom_labels(variables, git_provider=None):
if not get_settings().config.enable_custom_labels: if not get_settings().config.enable_custom_labels:
return return
@ -348,11 +376,8 @@ def set_custom_labels(variables):
labels_list = f" - {labels_list}" if labels_list else "" labels_list = f" - {labels_list}" if labels_list else ""
variables["custom_labels"] = labels_list variables["custom_labels"] = labels_list
return return
#final_labels = ""
#for k, v in labels.items(): # Set custom labels
# final_labels += f" - {k} ({v['description']})\n"
#variables["custom_labels"] = final_labels
#variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}"
variables["custom_labels_class"] = "class Label(str, Enum):" variables["custom_labels_class"] = "class Label(str, Enum):"
for k, v in labels.items(): for k, v in labels.items():
description = v['description'].strip('\n').replace('\n', '\\n') description = v['description'].strip('\n').replace('\n', '\\n')

View File

@ -5,7 +5,9 @@ import os
from pr_agent.agent.pr_agent import PRAgent, commands from pr_agent.agent.pr_agent import PRAgent, commands
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.log import setup_logger from pr_agent.log import setup_logger
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger() setup_logger()
@ -57,9 +59,9 @@ For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions
command = args.command.lower() command = args.command.lower()
get_settings().set("CONFIG.CLI_MODE", True) get_settings().set("CONFIG.CLI_MODE", True)
if args.issue_url: if args.issue_url:
result = asyncio.run(PRAgent().handle_request(args.issue_url, [command] + args.rest)) result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.issue_url, [command] + args.rest))
else: else:
result = asyncio.run(PRAgent().handle_request(args.pr_url, [command] + args.rest)) result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.pr_url, [command] + args.rest))
if not result: if not result:
parser.print_help() parser.print_help()

View File

@ -18,7 +18,9 @@ from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.secret_providers import get_secret_provider from pr_agent.secret_providers import get_secret_provider
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON) setup_logger(fmt=LoggingFormat.JSON)
router = APIRouter() router = APIRouter()
secret_provider = get_secret_provider() secret_provider = get_secret_provider()
@ -84,7 +86,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
context['bitbucket_bearer_token'] = bearer_token context['bitbucket_bearer_token'] = bearer_token
context["settings"] = copy.deepcopy(global_settings) context["settings"] = copy.deepcopy(global_settings)
event = data["event"] event = data["event"]
agent = PRAgent() agent = PRAgent(ai_handler=litellm_ai_handler)
if event == "pullrequest:created": if event == "pullrequest:created":
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"] pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
log_context["api_url"] = pr_url log_context["api_url"] = pr_url

View File

@ -12,7 +12,9 @@ from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger, setup_logger from pr_agent.log import get_logger, setup_logger
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger() setup_logger()
router = APIRouter() router = APIRouter()
@ -43,7 +45,7 @@ async def handle_gerrit_request(action: Action, item: Item):
status_code=400, status_code=400,
detail="msg is required for ask command" detail="msg is required for ask command"
) )
await PRAgent().handle_request( await PRAgent(ai_handler=litellm_ai_handler).handle_request(
f"{item.project}:{item.refspec}", f"{item.project}:{item.refspec}",
f"/{item.msg.strip()}" f"/{item.msg.strip()}"
) )

View File

@ -11,7 +11,8 @@ from pr_agent.log import get_logger
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
def is_true(value: Union[str, bool]) -> bool: def is_true(value: Union[str, bool]) -> bool:
if isinstance(value, bool): if isinstance(value, bool):
@ -110,9 +111,9 @@ async def run_action():
comment_id = event_payload.get("comment", {}).get("id") comment_id = event_payload.get("comment", {}).get("id")
provider = get_git_provider()(pr_url=url) provider = get_git_provider()(pr_url=url)
if is_pr: if is_pr:
await PRAgent().handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id)) await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
else: else:
await PRAgent().handle_request(url, body) await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -17,7 +17,9 @@ from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.git_providers.git_provider import IncrementalPR from pr_agent.git_providers.git_provider import IncrementalPR
from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.servers.utils import verify_signature, DefaultDictWithTimeout from pr_agent.servers.utils import verify_signature, DefaultDictWithTimeout
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON) setup_logger(fmt=LoggingFormat.JSON)
router = APIRouter() router = APIRouter()
@ -79,7 +81,7 @@ async def handle_request(body: Dict[str, Any], event: str):
action = body.get("action") action = body.get("action")
if not action: if not action:
return {} return {}
agent = PRAgent() agent = PRAgent(ai_handler=litellm_ai_handler)
bot_user = get_settings().github_app.bot_user bot_user = get_settings().github_app.bot_user
sender = body.get("sender", {}).get("login") sender = body.get("sender", {}).get("login")
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"} log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"}

View File

@ -8,7 +8,9 @@ from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.servers.help import bot_help_text from pr_agent.servers.help import bot_help_text
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON) setup_logger(fmt=LoggingFormat.JSON)
NOTIFICATION_URL = "https://api.github.com/notifications" NOTIFICATION_URL = "https://api.github.com/notifications"
@ -34,7 +36,7 @@ async def polling_loop():
last_modified = [None] last_modified = [None]
git_provider = get_git_provider()() git_provider = get_git_provider()()
user_id = git_provider.get_user_id() user_id = git_provider.get_user_id()
agent = PRAgent() agent = PRAgent(ai_handler=litellm_ai_handler)
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False) get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
try: try:

View File

@ -14,7 +14,9 @@ from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.secret_providers import get_secret_provider from pr_agent.secret_providers import get_secret_provider
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON) setup_logger(fmt=LoggingFormat.JSON)
router = APIRouter() router = APIRouter()
@ -26,7 +28,7 @@ def handle_request(background_tasks: BackgroundTasks, url: str, body: str, log_c
log_context["event"] = "pull_request" if body == "/review" else "comment" log_context["event"] = "pull_request" if body == "/review" else "comment"
log_context["api_url"] = url log_context["api_url"] = url
with get_logger().contextualize(**log_context): with get_logger().contextualize(**log_context):
background_tasks.add_task(PRAgent().handle_request, url, body) background_tasks.add_task(PRAgent(ai_handler=litellm_ai_handler).handle_request, url, body)
@router.post("/webhook") @router.post("/webhook")

View File

@ -4,7 +4,7 @@ from typing import Dict
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml from pr_agent.algo.utils import load_yaml
@ -15,14 +15,14 @@ from pr_agent.log import get_logger
class PRAddDocs: class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None): def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.ai_handler = AiHandler() self.ai_handler = ai_handler
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode self.cli_mode = cli_mode

View File

@ -3,7 +3,7 @@ import textwrap
from typing import Dict, List from typing import Dict, List
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml from pr_agent.algo.utils import load_yaml
@ -14,7 +14,7 @@ from pr_agent.log import get_logger
class PRCodeSuggestions: class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None ):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(

View File

@ -4,7 +4,7 @@ from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
@ -15,7 +15,7 @@ from pr_agent.log import get_logger
class PRDescription: class PRDescription:
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None):
""" """
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model. using an AI model.

View File

@ -2,7 +2,7 @@ import copy
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -12,7 +12,7 @@ from pr_agent.log import get_logger
class PRInformationFromUser: class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()

View File

@ -2,7 +2,7 @@ import copy
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -12,7 +12,7 @@ from pr_agent.log import get_logger
class PRQuestions: class PRQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = None):
question_str = self.parse_args(args) question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(

View File

@ -7,7 +7,7 @@ import yaml
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from yaml import SafeLoader from yaml import SafeLoader
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels
@ -22,13 +22,15 @@ class PRReviewer:
""" """
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
""" """
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = None):
""" """
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
Args: Args:
pr_url (str): The URL of the pull request to be reviewed. pr_url (str): The URL of the pull request to be reviewed.
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False. is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
is_auto (bool, optional): Indicates whether the review is being done in automatic mode. Defaults to False.
ai_handler (BaseAiHandler): The AI handler to be used for the review. Defaults to None.
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None. args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
""" """
self.parse_args(args) # -i command self.parse_args(args) # -i command

View File

@ -5,7 +5,7 @@ from typing import Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -17,7 +17,7 @@ CHANGELOG_LINES = 50
class PRUpdateChangelog: class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = None):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(

View File

@ -23,3 +23,4 @@ starlette-context==0.3.6
tiktoken==0.5.2 tiktoken==0.5.2
ujson==5.8.0 ujson==5.8.0
uvicorn==0.22.0 uvicorn==0.22.0
langchain==0.0.349