mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-04 21:00:40 +08:00
Merge pull request #172 from krrishdholakia/patch-1
adding support for Anthropic, Cohere, Replicate, Azure
This commit is contained in:
@ -7,4 +7,8 @@ MAX_TOKENS = {
|
||||
'gpt-4': 8000,
|
||||
'gpt-4-0613': 8000,
|
||||
'gpt-4-32k': 32000,
|
||||
'claude-instant-1': 100000,
|
||||
'claude-2': 100000,
|
||||
'command-nightly': 4096,
|
||||
'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
|
||||
}
|
||||
|
@ -3,9 +3,10 @@ import logging
|
||||
import openai
|
||||
from openai.error import APIError, RateLimitError, Timeout, TryAgain
|
||||
from retry import retry
|
||||
|
||||
import litellm
|
||||
from litellm import acompletion
|
||||
from pr_agent.config_loader import get_settings
|
||||
|
||||
import traceback
|
||||
OPENAI_RETRIES=5
|
||||
|
||||
class AiHandler:
|
||||
@ -22,15 +23,25 @@ class AiHandler:
|
||||
"""
|
||||
try:
|
||||
openai.api_key = get_settings().openai.key
|
||||
litellm.openai_key = get_settings().openai.key
|
||||
self.azure = False
|
||||
if get_settings().get("OPENAI.ORG", None):
|
||||
openai.organization = get_settings().openai.org
|
||||
litellm.organization = get_settings().openai.org
|
||||
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
|
||||
if get_settings().get("OPENAI.API_TYPE", None):
|
||||
openai.api_type = get_settings().openai.api_type
|
||||
if get_settings().openai.api_type == "azure":
|
||||
self.azure = True
|
||||
litellm.azure_key = get_settings().openai.key
|
||||
if get_settings().get("OPENAI.API_VERSION", None):
|
||||
openai.api_version = get_settings().openai.api_version
|
||||
litellm.api_version = get_settings().openai.api_version
|
||||
if get_settings().get("OPENAI.API_BASE", None):
|
||||
openai.api_base = get_settings().openai.api_base
|
||||
litellm.api_base = get_settings().openai.api_base
|
||||
if get_settings().get("ANTHROPIC.KEY", None):
|
||||
litellm.anthropic_key = get_settings().anthropic.key
|
||||
if get_settings().get("COHERE.KEY", None):
|
||||
litellm.cohere_key = get_settings().cohere.key
|
||||
if get_settings().get("REPLICATE.KEY", None):
|
||||
litellm.replicate_key = get_settings().replicate.key
|
||||
except AttributeError as e:
|
||||
raise ValueError("OpenAI key is required") from e
|
||||
|
||||
@ -57,7 +68,7 @@ class AiHandler:
|
||||
TryAgain: If there is an attribute error during OpenAI inference.
|
||||
"""
|
||||
try:
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
response = await acompletion(
|
||||
model=model,
|
||||
deployment_id=self.deployment_id,
|
||||
messages=[
|
||||
@ -65,6 +76,7 @@ class AiHandler:
|
||||
{"role": "user", "content": user}
|
||||
],
|
||||
temperature=temperature,
|
||||
azure=self.azure
|
||||
)
|
||||
except (APIError, Timeout, TryAgain) as e:
|
||||
logging.error("Error during OpenAI inference: ", e)
|
||||
@ -75,8 +87,9 @@ class AiHandler:
|
||||
except (Exception) as e:
|
||||
logging.error("Unknown error during OpenAI inference: ", e)
|
||||
raise TryAgain from e
|
||||
if response is None or len(response.choices) == 0:
|
||||
if response is None or len(response["choices"]) == 0:
|
||||
raise TryAgain
|
||||
resp = response.choices[0]['message']['content']
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
resp = response["choices"][0]['message']['content']
|
||||
finish_reason = response["choices"][0]["finish_reason"]
|
||||
print(resp, finish_reason)
|
||||
return resp, finish_reason
|
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import difflib
|
||||
import logging
|
||||
from typing import Callable, Tuple, List, Any
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any, Callable, List, Tuple
|
||||
|
||||
from github import RateLimitExceededException
|
||||
|
||||
from pr_agent.algo import MAX_TOKENS
|
||||
@ -11,7 +13,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
|
||||
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers.git_provider import GitProvider, FilePatchInfo
|
||||
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
|
||||
|
||||
DELETED_FILES_ = "Deleted files:\n"
|
||||
|
||||
@ -215,7 +217,7 @@ async def retry_with_fallback_models(f: Callable):
|
||||
try:
|
||||
return await f(model)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to generate prediction with {model}: {e}")
|
||||
logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")
|
||||
if i == len(all_models) - 1: # If it's the last iteration
|
||||
raise # Re-raise the last exception
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
from tiktoken import encoding_for_model
|
||||
from tiktoken import encoding_for_model, get_encoding
|
||||
|
||||
from pr_agent.config_loader import get_settings
|
||||
|
||||
@ -27,7 +27,7 @@ class TokenHandler:
|
||||
- system: The system string.
|
||||
- user: The user string.
|
||||
"""
|
||||
self.encoder = encoding_for_model(get_settings().config.model)
|
||||
self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base")
|
||||
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
||||
|
||||
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
||||
@ -47,7 +47,6 @@ class TokenHandler:
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
system_prompt = environment.from_string(system).render(vars)
|
||||
user_prompt = environment.from_string(user).render(vars)
|
||||
|
||||
system_prompt_tokens = len(encoder.encode(system_prompt))
|
||||
user_prompt_tokens = len(encoder.encode(user_prompt))
|
||||
return system_prompt_tokens + user_prompt_tokens
|
||||
|
@ -7,17 +7,26 @@
|
||||
# See README for details about GitHub App deployment.
|
||||
|
||||
[openai]
|
||||
key = "<API_KEY>" # Acquire through https://platform.openai.com
|
||||
org = "<ORGANIZATION>" # Optional, may be commented out.
|
||||
key = "" # Acquire through https://platform.openai.com
|
||||
#org = "<ORGANIZATION>" # Optional, may be commented out.
|
||||
# Uncomment the following for Azure OpenAI
|
||||
#api_type = "azure"
|
||||
#api_version = '2023-05-15' # Check Azure documentation for the current API version
|
||||
#api_base = "<API_BASE>" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
|
||||
#deployment_id = "<DEPLOYMENT_ID>" # The deployment name you chose when you deployed the engine
|
||||
#api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
|
||||
#deployment_id = "" # The deployment name you chose when you deployed the engine
|
||||
|
||||
[anthropic]
|
||||
key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/
|
||||
|
||||
[cohere]
|
||||
key = "" # Optional, uncomment if you want to use Cohere. Acquire through https://dashboard.cohere.ai/
|
||||
|
||||
[replicate]
|
||||
key = "" # Optional, uncomment if you want to use Replicate. Acquire through https://replicate.com/
|
||||
[github]
|
||||
# ---- Set the following only for deployment type == "user"
|
||||
user_token = "<TOKEN>" # A GitHub personal access token with 'repo' scope.
|
||||
user_token = "" # A GitHub personal access token with 'repo' scope.
|
||||
deployment_type = "user" #set to user by default
|
||||
|
||||
# ---- Set the following only for deployment type == "app", see README for details.
|
||||
private_key = """\
|
||||
|
@ -174,7 +174,7 @@ class PRReviewer:
|
||||
del pr_feedback['Security concerns']
|
||||
data.setdefault('PR Analysis', {})['Security concerns'] = security_concerns
|
||||
|
||||
#
|
||||
#
|
||||
if 'Code feedback' in pr_feedback:
|
||||
code_feedback = pr_feedback['Code feedback']
|
||||
|
||||
@ -218,6 +218,9 @@ class PRReviewer:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
logging.info(f"Markdown response:\n{markdown_text}")
|
||||
|
||||
if markdown_text == None or len(markdown_text) == 0:
|
||||
markdown_text = review
|
||||
|
||||
return markdown_text
|
||||
|
||||
def _publish_inline_code_comments(self) -> None:
|
||||
|
@ -41,7 +41,8 @@ dependencies = [
|
||||
"aiohttp~=3.8.4",
|
||||
"atlassian-python-api==3.39.0",
|
||||
"GitPython~=3.1.32",
|
||||
"starlette-context==0.3.6"
|
||||
"starlette-context==0.3.6",
|
||||
"litellm~=0.1.351"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
@ -10,4 +10,5 @@ python-gitlab==3.15.0
|
||||
pytest~=7.4.0
|
||||
aiohttp~=3.8.4
|
||||
atlassian-python-api==3.39.0
|
||||
GitPython~=3.1.32
|
||||
GitPython~=3.1.32
|
||||
litellm~=0.1.351
|
Reference in New Issue
Block a user