bug fixes and updates

This commit is contained in:
Krrish Dholakia
2023-08-03 16:05:46 -07:00
parent 102edcdcf1
commit ed8554699b
7 changed files with 42 additions and 23 deletions

View File

@ -7,4 +7,8 @@ MAX_TOKENS = {
'gpt-4': 8000, 'gpt-4': 8000,
'gpt-4-0613': 8000, 'gpt-4-0613': 8000,
'gpt-4-32k': 32000, 'gpt-4-32k': 32000,
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,
'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
} }

View File

@ -3,6 +3,7 @@ import logging
import openai import openai
from openai.error import APIError, RateLimitError, Timeout, TryAgain from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry from retry import retry
import litellm
from litellm import acompletion from litellm import acompletion
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -22,6 +23,7 @@ class AiHandler:
""" """
try: try:
openai.api_key = get_settings().openai.key openai.api_key = get_settings().openai.key
litellm.openai_key = get_settings().openai.key
if get_settings().get("OPENAI.ORG", None): if get_settings().get("OPENAI.ORG", None):
openai.organization = get_settings().openai.org openai.organization = get_settings().openai.org
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None) self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@ -31,6 +33,9 @@ class AiHandler:
openai.api_version = get_settings().openai.api_version openai.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None): if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = get_settings().openai.api_base openai.api_base = get_settings().openai.api_base
litellm.api_base = get_settings().openai.api_base
if get_settings().get("LITE.KEY", None):
self.llm_api_key = get_settings().lite.key
except AttributeError as e: except AttributeError as e:
raise ValueError("OpenAI key is required") from e raise ValueError("OpenAI key is required") from e
@ -65,6 +70,7 @@ class AiHandler:
{"role": "user", "content": user} {"role": "user", "content": user}
], ],
temperature=temperature, temperature=temperature,
api_key=self.llm_api_key
) )
except (APIError, Timeout, TryAgain) as e: except (APIError, Timeout, TryAgain) as e:
logging.error("Error during OpenAI inference: ", e) logging.error("Error during OpenAI inference: ", e)
@ -75,8 +81,9 @@ class AiHandler:
except (Exception) as e: except (Exception) as e:
logging.error("Unknown error during OpenAI inference: ", e) logging.error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e raise TryAgain from e
if response is None or len(response.choices) == 0: if response is None or len(response["choices"]) == 0:
raise TryAgain raise TryAgain
resp = response.choices[0]['message']['content'] resp = response["choices"][0]['message']['content']
finish_reason = response.choices[0].finish_reason finish_reason = response["choices"][0]["finish_reason"]
print(resp, finish_reason)
return resp, finish_reason return resp, finish_reason

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
import traceback
import logging import logging
from typing import Callable, Tuple from typing import Callable, Tuple
@ -221,6 +221,6 @@ async def retry_with_fallback_models(f: Callable):
try: try:
return await f(model) return await f(model)
except Exception as e: except Exception as e:
logging.warning(f"Failed to generate prediction with {model}: {e}") logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")
if i == len(all_models) - 1: # If it's the last iteration if i == len(all_models) - 1: # If it's the last iteration
raise # Re-raise the last exception raise # Re-raise the last exception

View File

@ -1,5 +1,5 @@
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model from tiktoken import encoding_for_model, get_encoding
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -27,7 +27,7 @@ class TokenHandler:
- system: The system string. - system: The system string.
- user: The user string. - user: The user string.
""" """
self.encoder = encoding_for_model(get_settings().config.model) self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base")
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
@ -47,7 +47,6 @@ class TokenHandler:
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(system).render(vars) system_prompt = environment.from_string(system).render(vars)
user_prompt = environment.from_string(user).render(vars) user_prompt = environment.from_string(user).render(vars)
system_prompt_tokens = len(encoder.encode(system_prompt)) system_prompt_tokens = len(encoder.encode(system_prompt))
user_prompt_tokens = len(encoder.encode(user_prompt)) user_prompt_tokens = len(encoder.encode(user_prompt))
return system_prompt_tokens + user_prompt_tokens return system_prompt_tokens + user_prompt_tokens

View File

@ -7,17 +7,20 @@
# See README for details about GitHub App deployment. # See README for details about GitHub App deployment.
[openai] [openai]
key = "<API_KEY>" # Acquire through https://platform.openai.com key = "" # Acquire through https://platform.openai.com
org = "<ORGANIZATION>" # Optional, may be commented out. #org = "<ORGANIZATION>" # Optional, may be commented out.
# Uncomment the following for Azure OpenAI # Uncomment the following for Azure OpenAI
#api_type = "azure" #api_type = "azure"
#api_version = '2023-05-15' # Check Azure documentation for the current API version #api_version = '2023-05-15' # Check Azure documentation for the current API version
#api_base = "<API_BASE>" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com" #api_base = "<API_BASE>" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
#deployment_id = "<DEPLOYMENT_ID>" # The deployment name you chose when you deployed the engine #deployment_id = "<DEPLOYMENT_ID>" # The deployment name you chose when you deployed the engine
[lite]
key = "YOUR_LLM_API_KEY" # Optional, use this if you'd like to use Anthropic, Llama2 (Replicate), or Cohere models
[github] [github]
# ---- Set the following only for deployment type == "user" # ---- Set the following only for deployment type == "user"
user_token = "<TOKEN>" # A GitHub personal access token with 'repo' scope. user_token = "" # A GitHub personal access token with 'repo' scope.
deployment_type = "user" #set to user by default
# ---- Set the following only for deployment type == "app", see README for details. # ---- Set the following only for deployment type == "app", see README for details.
private_key = """\ private_key = """\

View File

@ -160,12 +160,12 @@ class PRReviewer:
the feedback. the feedback.
""" """
review = self.prediction.strip() review = self.prediction.strip()
print(f"review: {review}")
try: try:
data = json.loads(review) data = json.loads(review)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
data = try_fix_json(review) data = try_fix_json(review)
print(f"data: {data}")
# Move 'Security concerns' key to 'PR Analysis' section for better display # Move 'Security concerns' key to 'PR Analysis' section for better display
if 'PR Feedback' in data and 'Security concerns' in data['PR Feedback']: if 'PR Feedback' in data and 'Security concerns' in data['PR Feedback']:
val = data['PR Feedback']['Security concerns'] val = data['PR Feedback']['Security concerns']
@ -173,6 +173,7 @@ class PRReviewer:
data['PR Analysis']['Security concerns'] = val data['PR Analysis']['Security concerns'] = val
# Filter out code suggestions that can be submitted as inline comments # Filter out code suggestions that can be submitted as inline comments
if 'PR Feedback' in data:
if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \ if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \
and 'Code suggestions' in data['PR Feedback']: and 'Code suggestions' in data['PR Feedback']:
data['PR Feedback']['Code suggestions'] = [ data['PR Feedback']['Code suggestions'] = [
@ -206,6 +207,10 @@ class PRReviewer:
if get_settings().config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}") logging.info(f"Markdown response:\n{markdown_text}")
if markdown_text == None or len(markdown_text) == 0:
markdown_text = review
print(f"markdown text: {markdown_text}")
return markdown_text return markdown_text
def _publish_inline_code_comments(self) -> None: def _publish_inline_code_comments(self) -> None:

View File

@ -41,7 +41,8 @@ dependencies = [
"aiohttp~=3.8.4", "aiohttp~=3.8.4",
"atlassian-python-api==3.39.0", "atlassian-python-api==3.39.0",
"GitPython~=3.1.32", "GitPython~=3.1.32",
"starlette-context==0.3.6" "starlette-context==0.3.6",
"litellm==0.1.2291"
] ]
[project.urls] [project.urls]