Merge pull request #172 from krrishdholakia/patch-1

adding support for Anthropic, Cohere, Replicate, Azure
This commit is contained in:
Ori Kotek
2023-08-06 18:38:36 +03:00
committed by GitHub
8 changed files with 57 additions and 25 deletions

View File

@@ -7,4 +7,8 @@ MAX_TOKENS = {
'gpt-4': 8000, 'gpt-4': 8000,
'gpt-4-0613': 8000, 'gpt-4-0613': 8000,
'gpt-4-32k': 32000, 'gpt-4-32k': 32000,
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,
'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
} }

View File

@@ -3,9 +3,10 @@ import logging
import openai import openai
from openai.error import APIError, RateLimitError, Timeout, TryAgain from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry from retry import retry
import litellm
from litellm import acompletion
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
import traceback
OPENAI_RETRIES=5 OPENAI_RETRIES=5
class AiHandler: class AiHandler:
@@ -22,15 +23,25 @@ class AiHandler:
""" """
try: try:
openai.api_key = get_settings().openai.key openai.api_key = get_settings().openai.key
litellm.openai_key = get_settings().openai.key
self.azure = False
if get_settings().get("OPENAI.ORG", None): if get_settings().get("OPENAI.ORG", None):
openai.organization = get_settings().openai.org litellm.organization = get_settings().openai.org
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None) self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
if get_settings().get("OPENAI.API_TYPE", None): if get_settings().get("OPENAI.API_TYPE", None):
openai.api_type = get_settings().openai.api_type if get_settings().openai.api_type == "azure":
self.azure = True
litellm.azure_key = get_settings().openai.key
if get_settings().get("OPENAI.API_VERSION", None): if get_settings().get("OPENAI.API_VERSION", None):
openai.api_version = get_settings().openai.api_version litellm.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None): if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = get_settings().openai.api_base litellm.api_base = get_settings().openai.api_base
if get_settings().get("ANTHROPIC.KEY", None):
litellm.anthropic_key = get_settings().anthropic.key
if get_settings().get("COHERE.KEY", None):
litellm.cohere_key = get_settings().cohere.key
if get_settings().get("REPLICATE.KEY", None):
litellm.replicate_key = get_settings().replicate.key
except AttributeError as e: except AttributeError as e:
raise ValueError("OpenAI key is required") from e raise ValueError("OpenAI key is required") from e
@@ -57,7 +68,7 @@ class AiHandler:
TryAgain: If there is an attribute error during OpenAI inference. TryAgain: If there is an attribute error during OpenAI inference.
""" """
try: try:
response = await openai.ChatCompletion.acreate( response = await acompletion(
model=model, model=model,
deployment_id=self.deployment_id, deployment_id=self.deployment_id,
messages=[ messages=[
@@ -65,6 +76,7 @@ class AiHandler:
{"role": "user", "content": user} {"role": "user", "content": user}
], ],
temperature=temperature, temperature=temperature,
azure=self.azure
) )
except (APIError, Timeout, TryAgain) as e: except (APIError, Timeout, TryAgain) as e:
logging.error("Error during OpenAI inference: ", e) logging.error("Error during OpenAI inference: ", e)
@@ -75,8 +87,9 @@ class AiHandler:
except (Exception) as e: except (Exception) as e:
logging.error("Unknown error during OpenAI inference: ", e) logging.error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e raise TryAgain from e
if response is None or len(response.choices) == 0: if response is None or len(response["choices"]) == 0:
raise TryAgain raise TryAgain
resp = response.choices[0]['message']['content'] resp = response["choices"][0]['message']['content']
finish_reason = response.choices[0].finish_reason finish_reason = response["choices"][0]["finish_reason"]
print(resp, finish_reason)
return resp, finish_reason return resp, finish_reason

View File

@@ -1,9 +1,11 @@
from __future__ import annotations from __future__ import annotations
import re
import difflib import difflib
import logging import logging
from typing import Callable, Tuple, List, Any import re
import traceback
from typing import Any, Callable, List, Tuple
from github import RateLimitExceededException from github import RateLimitExceededException
from pr_agent.algo import MAX_TOKENS from pr_agent.algo import MAX_TOKENS
@@ -11,7 +13,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider, FilePatchInfo from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
DELETED_FILES_ = "Deleted files:\n" DELETED_FILES_ = "Deleted files:\n"
@@ -215,7 +217,7 @@ async def retry_with_fallback_models(f: Callable):
try: try:
return await f(model) return await f(model)
except Exception as e: except Exception as e:
logging.warning(f"Failed to generate prediction with {model}: {e}") logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")
if i == len(all_models) - 1: # If it's the last iteration if i == len(all_models) - 1: # If it's the last iteration
raise # Re-raise the last exception raise # Re-raise the last exception

View File

@@ -1,5 +1,5 @@
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model from tiktoken import encoding_for_model, get_encoding
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@@ -27,7 +27,7 @@ class TokenHandler:
- system: The system string. - system: The system string.
- user: The user string. - user: The user string.
""" """
self.encoder = encoding_for_model(get_settings().config.model) self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base")
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
@@ -47,7 +47,6 @@ class TokenHandler:
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(system).render(vars) system_prompt = environment.from_string(system).render(vars)
user_prompt = environment.from_string(user).render(vars) user_prompt = environment.from_string(user).render(vars)
system_prompt_tokens = len(encoder.encode(system_prompt)) system_prompt_tokens = len(encoder.encode(system_prompt))
user_prompt_tokens = len(encoder.encode(user_prompt)) user_prompt_tokens = len(encoder.encode(user_prompt))
return system_prompt_tokens + user_prompt_tokens return system_prompt_tokens + user_prompt_tokens

View File

@@ -7,17 +7,26 @@
# See README for details about GitHub App deployment. # See README for details about GitHub App deployment.
[openai] [openai]
key = "<API_KEY>" # Acquire through https://platform.openai.com key = "" # Acquire through https://platform.openai.com
org = "<ORGANIZATION>" # Optional, may be commented out. #org = "<ORGANIZATION>" # Optional, may be commented out.
# Uncomment the following for Azure OpenAI # Uncomment the following for Azure OpenAI
#api_type = "azure" #api_type = "azure"
#api_version = '2023-05-15' # Check Azure documentation for the current API version #api_version = '2023-05-15' # Check Azure documentation for the current API version
#api_base = "<API_BASE>" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com" #api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
#deployment_id = "<DEPLOYMENT_ID>" # The deployment name you chose when you deployed the engine #deployment_id = "" # The deployment name you chose when you deployed the engine
[anthropic]
key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/
[cohere]
key = "" # Optional, uncomment if you want to use Cohere. Acquire through https://dashboard.cohere.ai/
[replicate]
key = "" # Optional, uncomment if you want to use Replicate. Acquire through https://replicate.com/
[github] [github]
# ---- Set the following only for deployment type == "user" # ---- Set the following only for deployment type == "user"
user_token = "<TOKEN>" # A GitHub personal access token with 'repo' scope. user_token = "" # A GitHub personal access token with 'repo' scope.
deployment_type = "user" #set to user by default
# ---- Set the following only for deployment type == "app", see README for details. # ---- Set the following only for deployment type == "app", see README for details.
private_key = """\ private_key = """\

View File

@@ -174,7 +174,7 @@ class PRReviewer:
del pr_feedback['Security concerns'] del pr_feedback['Security concerns']
data.setdefault('PR Analysis', {})['Security concerns'] = security_concerns data.setdefault('PR Analysis', {})['Security concerns'] = security_concerns
# #
if 'Code feedback' in pr_feedback: if 'Code feedback' in pr_feedback:
code_feedback = pr_feedback['Code feedback'] code_feedback = pr_feedback['Code feedback']
@@ -218,6 +218,9 @@ class PRReviewer:
if get_settings().config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}") logging.info(f"Markdown response:\n{markdown_text}")
if markdown_text == None or len(markdown_text) == 0:
markdown_text = review
return markdown_text return markdown_text
def _publish_inline_code_comments(self) -> None: def _publish_inline_code_comments(self) -> None:

View File

@@ -41,7 +41,8 @@ dependencies = [
"aiohttp~=3.8.4", "aiohttp~=3.8.4",
"atlassian-python-api==3.39.0", "atlassian-python-api==3.39.0",
"GitPython~=3.1.32", "GitPython~=3.1.32",
"starlette-context==0.3.6" "starlette-context==0.3.6",
"litellm~=0.1.351"
] ]
[project.urls] [project.urls]

View File

@@ -10,4 +10,5 @@ python-gitlab==3.15.0
pytest~=7.4.0 pytest~=7.4.0
aiohttp~=3.8.4 aiohttp~=3.8.4
atlassian-python-api==3.39.0 atlassian-python-api==3.39.0
GitPython~=3.1.32 GitPython~=3.1.32
litellm~=0.1.351