mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-06 05:40:38 +08:00
Merge branch 'base-ai-handler' into abstract-BaseAiHandler
This commit is contained in:
@ -1,8 +1,10 @@
|
|||||||
import shlex
|
import shlex
|
||||||
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
|
|
||||||
from pr_agent.algo.utils import update_settings_from_args
|
from pr_agent.algo.utils import update_settings_from_args
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.git_providers.utils import apply_repo_settings
|
from pr_agent.git_providers.utils import apply_repo_settings
|
||||||
|
from pr_agent.log import get_logger
|
||||||
from pr_agent.tools.pr_add_docs import PRAddDocs
|
from pr_agent.tools.pr_add_docs import PRAddDocs
|
||||||
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
||||||
from pr_agent.tools.pr_config import PRConfig
|
from pr_agent.tools.pr_config import PRConfig
|
||||||
@ -13,6 +15,7 @@ from pr_agent.tools.pr_questions import PRQuestions
|
|||||||
from pr_agent.tools.pr_reviewer import PRReviewer
|
from pr_agent.tools.pr_reviewer import PRReviewer
|
||||||
from pr_agent.tools.pr_similar_issue import PRSimilarIssue
|
from pr_agent.tools.pr_similar_issue import PRSimilarIssue
|
||||||
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
|
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
|
||||||
|
import inspect
|
||||||
|
|
||||||
command2class = {
|
command2class = {
|
||||||
"auto_review": PRReviewer,
|
"auto_review": PRReviewer,
|
||||||
@ -37,10 +40,19 @@ command2class = {
|
|||||||
|
|
||||||
commands = list(command2class.keys())
|
commands = list(command2class.keys())
|
||||||
|
|
||||||
|
def has_ai_handler_param(cls: object):
|
||||||
|
constructor = getattr(cls, "__init__", None)
|
||||||
|
if constructor is not None:
|
||||||
|
parameters = inspect.signature(constructor).parameters
|
||||||
|
return "ai_handler" in parameters
|
||||||
|
return False
|
||||||
|
|
||||||
class PRAgent:
|
class PRAgent:
|
||||||
def __init__(self):
|
def __init__(self, ai_handler: BaseAiHandler = None):
|
||||||
|
self.ai_handler = ai_handler
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def handle_request(self, pr_url, request, notify=None) -> bool:
|
async def handle_request(self, pr_url, request, notify=None) -> bool:
|
||||||
# First, apply repo specific settings if exists
|
# First, apply repo specific settings if exists
|
||||||
apply_repo_settings(pr_url)
|
apply_repo_settings(pr_url)
|
||||||
@ -61,13 +73,18 @@ class PRAgent:
|
|||||||
if action == "answer":
|
if action == "answer":
|
||||||
if notify:
|
if notify:
|
||||||
notify()
|
notify()
|
||||||
await PRReviewer(pr_url, is_answer=True, args=args).run()
|
await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
|
||||||
elif action == "auto_review":
|
elif action == "auto_review":
|
||||||
await PRReviewer(pr_url, is_auto=True, args=args).run()
|
await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
|
||||||
elif action in command2class:
|
elif action in command2class:
|
||||||
if notify:
|
if notify:
|
||||||
notify()
|
notify()
|
||||||
|
|
||||||
|
get_logger().info(f"Class: {command2class[action]}")
|
||||||
|
if(not has_ai_handler_param(cls=command2class[action])):
|
||||||
await command2class[action](pr_url, args=args).run()
|
await command2class[action](pr_url, args=args).run()
|
||||||
|
else:
|
||||||
|
await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
28
pr_agent/algo/ai_handlers/base_ai_handler.py
Normal file
28
pr_agent/algo/ai_handlers/base_ai_handler.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
class BaseAiHandler(ABC):
|
||||||
|
"""
|
||||||
|
This class defines the interface for an AI handler to be used by the PR Agents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def deployment_id(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
|
||||||
|
"""
|
||||||
|
This method should be implemented to return a chat completion from the AI model.
|
||||||
|
Args:
|
||||||
|
model (str): the name of the model to use for the chat completion
|
||||||
|
system (str): the system message string to use for the chat completion
|
||||||
|
user (str): the user message string to use for the chat completion
|
||||||
|
temperature (float): the temperature to use for the chat completion
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
46
pr_agent/algo/ai_handlers/langchain_ai_handler.py
Normal file
46
pr_agent/algo/ai_handlers/langchain_ai_handler.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from langchain.schema import SystemMessage, HumanMessage
|
||||||
|
|
||||||
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
|
from pr_agent.config_loader import get_settings
|
||||||
|
from pr_agent.log import get_logger
|
||||||
|
|
||||||
|
from openai.error import APIError, RateLimitError, Timeout, TryAgain
|
||||||
|
from retry import retry
|
||||||
|
|
||||||
|
OPENAI_RETRIES = 5
|
||||||
|
|
||||||
|
class LangChainOpenAIHandler(BaseAiHandler):
|
||||||
|
def __init__(self):
|
||||||
|
# Initialize OpenAIHandler specific attributes here
|
||||||
|
try:
|
||||||
|
super().__init__()
|
||||||
|
self._chat = ChatOpenAI(openai_api_key=get_settings().openai.key)
|
||||||
|
|
||||||
|
except AttributeError as e:
|
||||||
|
raise ValueError("OpenAI key is required") from e
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat(self):
|
||||||
|
return self._chat
|
||||||
|
|
||||||
|
@property
|
||||||
|
def deployment_id(self):
|
||||||
|
"""
|
||||||
|
Returns the deployment ID for the OpenAI API.
|
||||||
|
"""
|
||||||
|
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
|
||||||
|
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
|
||||||
|
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
|
||||||
|
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
|
||||||
|
try:
|
||||||
|
messages=[SystemMessage(content=system), HumanMessage(content=user)]
|
||||||
|
|
||||||
|
# get a chat completion from the formatted messages
|
||||||
|
resp = self.chat(messages, model=model, temperature=temperature)
|
||||||
|
finish_reason="completed"
|
||||||
|
return resp.content, finish_reason
|
||||||
|
|
||||||
|
except (Exception) as e:
|
||||||
|
get_logger().error("Unknown error during OpenAI inference: ", e)
|
||||||
|
raise e
|
@ -6,14 +6,14 @@ import openai
|
|||||||
from litellm import acompletion
|
from litellm import acompletion
|
||||||
from openai.error import APIError, RateLimitError, Timeout, TryAgain
|
from openai.error import APIError, RateLimitError, Timeout, TryAgain
|
||||||
from retry import retry
|
from retry import retry
|
||||||
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.algo.base_ai_handler import BaseAiHandler
|
|
||||||
from pr_agent.log import get_logger
|
from pr_agent.log import get_logger
|
||||||
|
|
||||||
OPENAI_RETRIES = 5
|
OPENAI_RETRIES = 5
|
||||||
|
|
||||||
|
|
||||||
class AiHandler(BaseAiHandler):
|
class LiteLLMAIHandler(BaseAiHandler):
|
||||||
"""
|
"""
|
||||||
This class handles interactions with the OpenAI API for chat completions.
|
This class handles interactions with the OpenAI API for chat completions.
|
||||||
It initializes the API key and other settings from a configuration file,
|
It initializes the API key and other settings from a configuration file,
|
67
pr_agent/algo/ai_handlers/openai_ai_handler.py
Normal file
67
pr_agent/algo/ai_handlers/openai_ai_handler.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
|
import openai
|
||||||
|
from openai.error import APIError, RateLimitError, Timeout, TryAgain
|
||||||
|
from retry import retry
|
||||||
|
|
||||||
|
from pr_agent.config_loader import get_settings
|
||||||
|
from pr_agent.log import get_logger
|
||||||
|
|
||||||
|
OPENAI_RETRIES = 5
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIHandler(BaseAiHandler):
|
||||||
|
def __init__(self):
|
||||||
|
# Initialize OpenAIHandler specific attributes here
|
||||||
|
try:
|
||||||
|
super().__init__()
|
||||||
|
openai.api_key = get_settings().openai.key
|
||||||
|
if get_settings().get("OPENAI.ORG", None):
|
||||||
|
openai.organization = get_settings().openai.org
|
||||||
|
if get_settings().get("OPENAI.API_TYPE", None):
|
||||||
|
if get_settings().openai.api_type == "azure":
|
||||||
|
self.azure = True
|
||||||
|
openai.azure_key = get_settings().openai.key
|
||||||
|
if get_settings().get("OPENAI.API_VERSION", None):
|
||||||
|
openai.api_version = get_settings().openai.api_version
|
||||||
|
if get_settings().get("OPENAI.API_BASE", None):
|
||||||
|
openai.api_base = get_settings().openai.api_base
|
||||||
|
|
||||||
|
except AttributeError as e:
|
||||||
|
raise ValueError("OpenAI key is required") from e
|
||||||
|
@property
|
||||||
|
def deployment_id(self):
|
||||||
|
"""
|
||||||
|
Returns the deployment ID for the OpenAI API.
|
||||||
|
"""
|
||||||
|
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
|
||||||
|
|
||||||
|
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
|
||||||
|
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
|
||||||
|
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
|
||||||
|
try:
|
||||||
|
deployment_id = self.deployment_id
|
||||||
|
get_logger().info("System: ", system)
|
||||||
|
get_logger().info("User: ", user)
|
||||||
|
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
|
||||||
|
|
||||||
|
chat_completion = await openai.ChatCompletion.acreate(
|
||||||
|
model=model,
|
||||||
|
deployment_id=deployment_id,
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
resp = chat_completion["choices"][0]['message']['content']
|
||||||
|
finish_reason = chat_completion["choices"][0]["finish_reason"]
|
||||||
|
usage = chat_completion.get("usage")
|
||||||
|
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
|
||||||
|
model=model, usage=usage)
|
||||||
|
return resp, finish_reason
|
||||||
|
except (APIError, Timeout, TryAgain) as e:
|
||||||
|
get_logger().error("Error during OpenAI inference: ", e)
|
||||||
|
raise
|
||||||
|
except (RateLimitError) as e:
|
||||||
|
get_logger().error("Rate limit error during OpenAI inference: ", e)
|
||||||
|
raise
|
||||||
|
except (Exception) as e:
|
||||||
|
get_logger().error("Unknown error during OpenAI inference: ", e)
|
||||||
|
raise TryAgain from e
|
@ -59,14 +59,14 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool=True) -> str:
|
|||||||
if key.lower() == 'code feedback':
|
if key.lower() == 'code feedback':
|
||||||
if gfm_supported:
|
if gfm_supported:
|
||||||
markdown_text += f"\n\n- "
|
markdown_text += f"\n\n- "
|
||||||
markdown_text += f"<details><summary> { emoji } Code feedback:</summary>\n\n"
|
markdown_text += f"<details><summary> { emoji } Code feedback:</summary>"
|
||||||
else:
|
else:
|
||||||
markdown_text += f"\n\n- **{emoji} Code feedback:**\n\n"
|
markdown_text += f"\n\n- **{emoji} Code feedback:**\n\n"
|
||||||
else:
|
else:
|
||||||
markdown_text += f"- {emoji} **{key}:**\n\n"
|
markdown_text += f"- {emoji} **{key}:**\n\n"
|
||||||
for item in value:
|
for i, item in enumerate(value):
|
||||||
if isinstance(item, dict) and key.lower() == 'code feedback':
|
if isinstance(item, dict) and key.lower() == 'code feedback':
|
||||||
markdown_text += parse_code_suggestion(item, gfm_supported)
|
markdown_text += parse_code_suggestion(item, i, gfm_supported)
|
||||||
elif item:
|
elif item:
|
||||||
markdown_text += f" - {item}\n"
|
markdown_text += f" - {item}\n"
|
||||||
if key.lower() == 'code feedback':
|
if key.lower() == 'code feedback':
|
||||||
@ -80,7 +80,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool=True) -> str:
|
|||||||
return markdown_text
|
return markdown_text
|
||||||
|
|
||||||
|
|
||||||
def parse_code_suggestion(code_suggestions: dict, gfm_supported: bool=True) -> str:
|
def parse_code_suggestion(code_suggestions: dict, i: int = 0, gfm_supported: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
Convert a dictionary of data into markdown format.
|
Convert a dictionary of data into markdown format.
|
||||||
|
|
||||||
@ -91,6 +91,34 @@ def parse_code_suggestion(code_suggestions: dict, gfm_supported: bool=True) -> s
|
|||||||
str: A string containing the markdown formatted text generated from the input dictionary.
|
str: A string containing the markdown formatted text generated from the input dictionary.
|
||||||
"""
|
"""
|
||||||
markdown_text = ""
|
markdown_text = ""
|
||||||
|
if gfm_supported and 'relevant line' in code_suggestions:
|
||||||
|
if i == 0:
|
||||||
|
markdown_text += "<hr>"
|
||||||
|
markdown_text += '<table>'
|
||||||
|
for sub_key, sub_value in code_suggestions.items():
|
||||||
|
try:
|
||||||
|
if sub_key.lower() == 'relevant file':
|
||||||
|
relevant_file = sub_value.strip('`').strip('"').strip("'")
|
||||||
|
markdown_text += f"<tr><td>{sub_key}</td><td>{relevant_file}</td></tr>"
|
||||||
|
# continue
|
||||||
|
elif sub_key.lower() == 'suggestion':
|
||||||
|
markdown_text += f"<tr><td>{sub_key} </td><td><strong>{sub_value}</strong></td></tr>"
|
||||||
|
elif sub_key.lower() == 'relevant line':
|
||||||
|
markdown_text += f"<tr><td>relevant line</td>"
|
||||||
|
sub_value_list = sub_value.split('](')
|
||||||
|
relevant_line = sub_value_list[0].lstrip('`').lstrip('[')
|
||||||
|
if len(sub_value_list) > 1:
|
||||||
|
link = sub_value_list[1].rstrip(')').strip('`')
|
||||||
|
markdown_text += f"<td><a href={link}>{relevant_line}</a></td>"
|
||||||
|
else:
|
||||||
|
markdown_text += f"<td>{relevant_line}</td>"
|
||||||
|
markdown_text += "</tr>"
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().exception(f"Failed to parse code suggestion: {e}")
|
||||||
|
pass
|
||||||
|
markdown_text += '</table>'
|
||||||
|
markdown_text += "<hr>"
|
||||||
|
else:
|
||||||
for sub_key, sub_value in code_suggestions.items():
|
for sub_key, sub_value in code_suggestions.items():
|
||||||
if isinstance(sub_value, dict): # "code example"
|
if isinstance(sub_value, dict): # "code example"
|
||||||
markdown_text += f" - **{sub_key}:**\n"
|
markdown_text += f" - **{sub_key}:**\n"
|
||||||
@ -336,7 +364,7 @@ def try_fix_yaml(response_text: str) -> dict:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def set_custom_labels(variables):
|
def set_custom_labels(variables, git_provider=None):
|
||||||
if not get_settings().config.enable_custom_labels:
|
if not get_settings().config.enable_custom_labels:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -348,11 +376,8 @@ def set_custom_labels(variables):
|
|||||||
labels_list = f" - {labels_list}" if labels_list else ""
|
labels_list = f" - {labels_list}" if labels_list else ""
|
||||||
variables["custom_labels"] = labels_list
|
variables["custom_labels"] = labels_list
|
||||||
return
|
return
|
||||||
#final_labels = ""
|
|
||||||
#for k, v in labels.items():
|
# Set custom labels
|
||||||
# final_labels += f" - {k} ({v['description']})\n"
|
|
||||||
#variables["custom_labels"] = final_labels
|
|
||||||
#variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}"
|
|
||||||
variables["custom_labels_class"] = "class Label(str, Enum):"
|
variables["custom_labels_class"] = "class Label(str, Enum):"
|
||||||
for k, v in labels.items():
|
for k, v in labels.items():
|
||||||
description = v['description'].strip('\n').replace('\n', '\\n')
|
description = v['description'].strip('\n').replace('\n', '\\n')
|
||||||
|
@ -5,7 +5,9 @@ import os
|
|||||||
from pr_agent.agent.pr_agent import PRAgent, commands
|
from pr_agent.agent.pr_agent import PRAgent, commands
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.log import setup_logger
|
from pr_agent.log import setup_logger
|
||||||
|
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||||
|
|
||||||
|
litellm_ai_handler = LiteLLMAiHandler()
|
||||||
setup_logger()
|
setup_logger()
|
||||||
|
|
||||||
|
|
||||||
@ -57,9 +59,9 @@ For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions
|
|||||||
command = args.command.lower()
|
command = args.command.lower()
|
||||||
get_settings().set("CONFIG.CLI_MODE", True)
|
get_settings().set("CONFIG.CLI_MODE", True)
|
||||||
if args.issue_url:
|
if args.issue_url:
|
||||||
result = asyncio.run(PRAgent().handle_request(args.issue_url, [command] + args.rest))
|
result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.issue_url, [command] + args.rest))
|
||||||
else:
|
else:
|
||||||
result = asyncio.run(PRAgent().handle_request(args.pr_url, [command] + args.rest))
|
result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.pr_url, [command] + args.rest))
|
||||||
if not result:
|
if not result:
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
|
|
||||||
|
@ -18,7 +18,9 @@ from pr_agent.agent.pr_agent import PRAgent
|
|||||||
from pr_agent.config_loader import get_settings, global_settings
|
from pr_agent.config_loader import get_settings, global_settings
|
||||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||||
from pr_agent.secret_providers import get_secret_provider
|
from pr_agent.secret_providers import get_secret_provider
|
||||||
|
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||||
|
|
||||||
|
litellm_ai_handler = LiteLLMAiHandler()
|
||||||
setup_logger(fmt=LoggingFormat.JSON)
|
setup_logger(fmt=LoggingFormat.JSON)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
secret_provider = get_secret_provider()
|
secret_provider = get_secret_provider()
|
||||||
@ -84,7 +86,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
|
|||||||
context['bitbucket_bearer_token'] = bearer_token
|
context['bitbucket_bearer_token'] = bearer_token
|
||||||
context["settings"] = copy.deepcopy(global_settings)
|
context["settings"] = copy.deepcopy(global_settings)
|
||||||
event = data["event"]
|
event = data["event"]
|
||||||
agent = PRAgent()
|
agent = PRAgent(ai_handler=litellm_ai_handler)
|
||||||
if event == "pullrequest:created":
|
if event == "pullrequest:created":
|
||||||
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
|
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
|
||||||
log_context["api_url"] = pr_url
|
log_context["api_url"] = pr_url
|
||||||
|
@ -12,7 +12,9 @@ from starlette_context.middleware import RawContextMiddleware
|
|||||||
from pr_agent.agent.pr_agent import PRAgent
|
from pr_agent.agent.pr_agent import PRAgent
|
||||||
from pr_agent.config_loader import get_settings, global_settings
|
from pr_agent.config_loader import get_settings, global_settings
|
||||||
from pr_agent.log import get_logger, setup_logger
|
from pr_agent.log import get_logger, setup_logger
|
||||||
|
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||||
|
|
||||||
|
litellm_ai_handler = LiteLLMAiHandler()
|
||||||
setup_logger()
|
setup_logger()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -43,7 +45,7 @@ async def handle_gerrit_request(action: Action, item: Item):
|
|||||||
status_code=400,
|
status_code=400,
|
||||||
detail="msg is required for ask command"
|
detail="msg is required for ask command"
|
||||||
)
|
)
|
||||||
await PRAgent().handle_request(
|
await PRAgent(ai_handler=litellm_ai_handler).handle_request(
|
||||||
f"{item.project}:{item.refspec}",
|
f"{item.project}:{item.refspec}",
|
||||||
f"/{item.msg.strip()}"
|
f"/{item.msg.strip()}"
|
||||||
)
|
)
|
||||||
|
@ -11,7 +11,8 @@ from pr_agent.log import get_logger
|
|||||||
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
||||||
from pr_agent.tools.pr_description import PRDescription
|
from pr_agent.tools.pr_description import PRDescription
|
||||||
from pr_agent.tools.pr_reviewer import PRReviewer
|
from pr_agent.tools.pr_reviewer import PRReviewer
|
||||||
|
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||||
|
litellm_ai_handler = LiteLLMAiHandler()
|
||||||
|
|
||||||
def is_true(value: Union[str, bool]) -> bool:
|
def is_true(value: Union[str, bool]) -> bool:
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
@ -110,9 +111,9 @@ async def run_action():
|
|||||||
comment_id = event_payload.get("comment", {}).get("id")
|
comment_id = event_payload.get("comment", {}).get("id")
|
||||||
provider = get_git_provider()(pr_url=url)
|
provider = get_git_provider()(pr_url=url)
|
||||||
if is_pr:
|
if is_pr:
|
||||||
await PRAgent().handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
|
await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
|
||||||
else:
|
else:
|
||||||
await PRAgent().handle_request(url, body)
|
await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -17,7 +17,9 @@ from pr_agent.git_providers.utils import apply_repo_settings
|
|||||||
from pr_agent.git_providers.git_provider import IncrementalPR
|
from pr_agent.git_providers.git_provider import IncrementalPR
|
||||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||||
from pr_agent.servers.utils import verify_signature, DefaultDictWithTimeout
|
from pr_agent.servers.utils import verify_signature, DefaultDictWithTimeout
|
||||||
|
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||||
|
|
||||||
|
litellm_ai_handler = LiteLLMAiHandler()
|
||||||
setup_logger(fmt=LoggingFormat.JSON)
|
setup_logger(fmt=LoggingFormat.JSON)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -79,7 +81,7 @@ async def handle_request(body: Dict[str, Any], event: str):
|
|||||||
action = body.get("action")
|
action = body.get("action")
|
||||||
if not action:
|
if not action:
|
||||||
return {}
|
return {}
|
||||||
agent = PRAgent()
|
agent = PRAgent(ai_handler=litellm_ai_handler)
|
||||||
bot_user = get_settings().github_app.bot_user
|
bot_user = get_settings().github_app.bot_user
|
||||||
sender = body.get("sender", {}).get("login")
|
sender = body.get("sender", {}).get("login")
|
||||||
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"}
|
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"}
|
||||||
|
@ -8,7 +8,9 @@ from pr_agent.config_loader import get_settings
|
|||||||
from pr_agent.git_providers import get_git_provider
|
from pr_agent.git_providers import get_git_provider
|
||||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||||
from pr_agent.servers.help import bot_help_text
|
from pr_agent.servers.help import bot_help_text
|
||||||
|
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||||
|
|
||||||
|
litellm_ai_handler = LiteLLMAiHandler()
|
||||||
setup_logger(fmt=LoggingFormat.JSON)
|
setup_logger(fmt=LoggingFormat.JSON)
|
||||||
NOTIFICATION_URL = "https://api.github.com/notifications"
|
NOTIFICATION_URL = "https://api.github.com/notifications"
|
||||||
|
|
||||||
@ -34,7 +36,7 @@ async def polling_loop():
|
|||||||
last_modified = [None]
|
last_modified = [None]
|
||||||
git_provider = get_git_provider()()
|
git_provider = get_git_provider()()
|
||||||
user_id = git_provider.get_user_id()
|
user_id = git_provider.get_user_id()
|
||||||
agent = PRAgent()
|
agent = PRAgent(ai_handler=litellm_ai_handler)
|
||||||
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
|
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -14,7 +14,9 @@ from pr_agent.agent.pr_agent import PRAgent
|
|||||||
from pr_agent.config_loader import get_settings, global_settings
|
from pr_agent.config_loader import get_settings, global_settings
|
||||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||||
from pr_agent.secret_providers import get_secret_provider
|
from pr_agent.secret_providers import get_secret_provider
|
||||||
|
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||||
|
|
||||||
|
litellm_ai_handler = LiteLLMAiHandler()
|
||||||
setup_logger(fmt=LoggingFormat.JSON)
|
setup_logger(fmt=LoggingFormat.JSON)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -26,7 +28,7 @@ def handle_request(background_tasks: BackgroundTasks, url: str, body: str, log_c
|
|||||||
log_context["event"] = "pull_request" if body == "/review" else "comment"
|
log_context["event"] = "pull_request" if body == "/review" else "comment"
|
||||||
log_context["api_url"] = url
|
log_context["api_url"] = url
|
||||||
with get_logger().contextualize(**log_context):
|
with get_logger().contextualize(**log_context):
|
||||||
background_tasks.add_task(PRAgent().handle_request, url, body)
|
background_tasks.add_task(PRAgent(ai_handler=litellm_ai_handler).handle_request, url, body)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/webhook")
|
@router.post("/webhook")
|
||||||
|
@ -4,7 +4,7 @@ from typing import Dict
|
|||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import AiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import load_yaml
|
from pr_agent.algo.utils import load_yaml
|
||||||
@ -15,14 +15,14 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
|
|
||||||
class PRAddDocs:
|
class PRAddDocs:
|
||||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None):
|
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None):
|
||||||
|
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_language = get_main_pr_language(
|
self.main_language = get_main_pr_language(
|
||||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ai_handler = AiHandler()
|
self.ai_handler = ai_handler
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
self.cli_mode = cli_mode
|
self.cli_mode = cli_mode
|
||||||
|
@ -3,7 +3,7 @@ import textwrap
|
|||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import load_yaml
|
from pr_agent.algo.utils import load_yaml
|
||||||
@ -14,7 +14,7 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
|
|
||||||
class PRCodeSuggestions:
|
class PRCodeSuggestions:
|
||||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = AiHandler()):
|
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None ):
|
||||||
|
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_language = get_main_pr_language(
|
self.main_language = get_main_pr_language(
|
||||||
|
@ -4,7 +4,7 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
|
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
|
||||||
@ -15,7 +15,7 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
|
|
||||||
class PRDescription:
|
class PRDescription:
|
||||||
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()):
|
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None):
|
||||||
"""
|
"""
|
||||||
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
|
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
|
||||||
using an AI model.
|
using an AI model.
|
||||||
|
@ -2,7 +2,7 @@ import copy
|
|||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
@ -12,7 +12,7 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
|
|
||||||
class PRInformationFromUser:
|
class PRInformationFromUser:
|
||||||
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()):
|
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None):
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_pr_language = get_main_pr_language(
|
self.main_pr_language = get_main_pr_language(
|
||||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
|
@ -2,7 +2,7 @@ import copy
|
|||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
@ -12,7 +12,7 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
|
|
||||||
class PRQuestions:
|
class PRQuestions:
|
||||||
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = AiHandler()):
|
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = None):
|
||||||
question_str = self.parse_args(args)
|
question_str = self.parse_args(args)
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_pr_language = get_main_pr_language(
|
self.main_pr_language = get_main_pr_language(
|
||||||
|
@ -7,7 +7,7 @@ import yaml
|
|||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
from yaml import SafeLoader
|
from yaml import SafeLoader
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels
|
from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels
|
||||||
@ -22,13 +22,15 @@ class PRReviewer:
|
|||||||
"""
|
"""
|
||||||
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
|
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
|
||||||
"""
|
"""
|
||||||
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = AiHandler()):
|
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = None):
|
||||||
"""
|
"""
|
||||||
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
|
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pr_url (str): The URL of the pull request to be reviewed.
|
pr_url (str): The URL of the pull request to be reviewed.
|
||||||
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
|
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
|
||||||
|
is_auto (bool, optional): Indicates whether the review is being done in automatic mode. Defaults to False.
|
||||||
|
ai_handler (BaseAiHandler): The AI handler to be used for the review. Defaults to None.
|
||||||
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
|
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.parse_args(args) # -i command
|
self.parse_args(args) # -i command
|
||||||
|
@ -5,7 +5,7 @@ from typing import Tuple
|
|||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
@ -17,7 +17,7 @@ CHANGELOG_LINES = 50
|
|||||||
|
|
||||||
|
|
||||||
class PRUpdateChangelog:
|
class PRUpdateChangelog:
|
||||||
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = AiHandler()):
|
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = None):
|
||||||
|
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_language = get_main_pr_language(
|
self.main_language = get_main_pr_language(
|
||||||
|
@ -23,3 +23,4 @@ starlette-context==0.3.6
|
|||||||
tiktoken==0.5.2
|
tiktoken==0.5.2
|
||||||
ujson==5.8.0
|
ujson==5.8.0
|
||||||
uvicorn==0.22.0
|
uvicorn==0.22.0
|
||||||
|
langchain==0.0.349
|
||||||
|
Reference in New Issue
Block a user