mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-05 13:20:39 +08:00
bug fixes and updates
This commit is contained in:
@ -7,4 +7,8 @@ MAX_TOKENS = {
|
|||||||
'gpt-4': 8000,
|
'gpt-4': 8000,
|
||||||
'gpt-4-0613': 8000,
|
'gpt-4-0613': 8000,
|
||||||
'gpt-4-32k': 32000,
|
'gpt-4-32k': 32000,
|
||||||
|
'claude-instant-1': 100000,
|
||||||
|
'claude-2': 100000,
|
||||||
|
'command-nightly': 4096,
|
||||||
|
'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ import logging
|
|||||||
import openai
|
import openai
|
||||||
from openai.error import APIError, RateLimitError, Timeout, TryAgain
|
from openai.error import APIError, RateLimitError, Timeout, TryAgain
|
||||||
from retry import retry
|
from retry import retry
|
||||||
|
import litellm
|
||||||
from litellm import acompletion
|
from litellm import acompletion
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
|
|
||||||
@ -22,6 +23,7 @@ class AiHandler:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
openai.api_key = get_settings().openai.key
|
openai.api_key = get_settings().openai.key
|
||||||
|
litellm.openai_key = get_settings().openai.key
|
||||||
if get_settings().get("OPENAI.ORG", None):
|
if get_settings().get("OPENAI.ORG", None):
|
||||||
openai.organization = get_settings().openai.org
|
openai.organization = get_settings().openai.org
|
||||||
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
|
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
|
||||||
@ -31,6 +33,9 @@ class AiHandler:
|
|||||||
openai.api_version = get_settings().openai.api_version
|
openai.api_version = get_settings().openai.api_version
|
||||||
if get_settings().get("OPENAI.API_BASE", None):
|
if get_settings().get("OPENAI.API_BASE", None):
|
||||||
openai.api_base = get_settings().openai.api_base
|
openai.api_base = get_settings().openai.api_base
|
||||||
|
litellm.api_base = get_settings().openai.api_base
|
||||||
|
if get_settings().get("LITE.KEY", None):
|
||||||
|
self.llm_api_key = get_settings().lite.key
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise ValueError("OpenAI key is required") from e
|
raise ValueError("OpenAI key is required") from e
|
||||||
|
|
||||||
@ -65,6 +70,7 @@ class AiHandler:
|
|||||||
{"role": "user", "content": user}
|
{"role": "user", "content": user}
|
||||||
],
|
],
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
api_key=self.llm_api_key
|
||||||
)
|
)
|
||||||
except (APIError, Timeout, TryAgain) as e:
|
except (APIError, Timeout, TryAgain) as e:
|
||||||
logging.error("Error during OpenAI inference: ", e)
|
logging.error("Error during OpenAI inference: ", e)
|
||||||
@ -75,8 +81,9 @@ class AiHandler:
|
|||||||
except (Exception) as e:
|
except (Exception) as e:
|
||||||
logging.error("Unknown error during OpenAI inference: ", e)
|
logging.error("Unknown error during OpenAI inference: ", e)
|
||||||
raise TryAgain from e
|
raise TryAgain from e
|
||||||
if response is None or len(response.choices) == 0:
|
if response is None or len(response["choices"]) == 0:
|
||||||
raise TryAgain
|
raise TryAgain
|
||||||
resp = response.choices[0]['message']['content']
|
resp = response["choices"][0]['message']['content']
|
||||||
finish_reason = response.choices[0].finish_reason
|
finish_reason = response["choices"][0]["finish_reason"]
|
||||||
|
print(resp, finish_reason)
|
||||||
return resp, finish_reason
|
return resp, finish_reason
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
@ -221,6 +221,6 @@ async def retry_with_fallback_models(f: Callable):
|
|||||||
try:
|
try:
|
||||||
return await f(model)
|
return await f(model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to generate prediction with {model}: {e}")
|
logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")
|
||||||
if i == len(all_models) - 1: # If it's the last iteration
|
if i == len(all_models) - 1: # If it's the last iteration
|
||||||
raise # Re-raise the last exception
|
raise # Re-raise the last exception
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
from tiktoken import encoding_for_model
|
from tiktoken import encoding_for_model, get_encoding
|
||||||
|
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ class TokenHandler:
|
|||||||
- system: The system string.
|
- system: The system string.
|
||||||
- user: The user string.
|
- user: The user string.
|
||||||
"""
|
"""
|
||||||
self.encoder = encoding_for_model(get_settings().config.model)
|
self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base")
|
||||||
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
||||||
|
|
||||||
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
||||||
@ -47,7 +47,6 @@ class TokenHandler:
|
|||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(system).render(vars)
|
system_prompt = environment.from_string(system).render(vars)
|
||||||
user_prompt = environment.from_string(user).render(vars)
|
user_prompt = environment.from_string(user).render(vars)
|
||||||
|
|
||||||
system_prompt_tokens = len(encoder.encode(system_prompt))
|
system_prompt_tokens = len(encoder.encode(system_prompt))
|
||||||
user_prompt_tokens = len(encoder.encode(user_prompt))
|
user_prompt_tokens = len(encoder.encode(user_prompt))
|
||||||
return system_prompt_tokens + user_prompt_tokens
|
return system_prompt_tokens + user_prompt_tokens
|
||||||
|
@ -7,17 +7,20 @@
|
|||||||
# See README for details about GitHub App deployment.
|
# See README for details about GitHub App deployment.
|
||||||
|
|
||||||
[openai]
|
[openai]
|
||||||
key = "<API_KEY>" # Acquire through https://platform.openai.com
|
key = "" # Acquire through https://platform.openai.com
|
||||||
org = "<ORGANIZATION>" # Optional, may be commented out.
|
#org = "<ORGANIZATION>" # Optional, may be commented out.
|
||||||
# Uncomment the following for Azure OpenAI
|
# Uncomment the following for Azure OpenAI
|
||||||
#api_type = "azure"
|
#api_type = "azure"
|
||||||
#api_version = '2023-05-15' # Check Azure documentation for the current API version
|
#api_version = '2023-05-15' # Check Azure documentation for the current API version
|
||||||
#api_base = "<API_BASE>" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
|
#api_base = "<API_BASE>" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
|
||||||
#deployment_id = "<DEPLOYMENT_ID>" # The deployment name you chose when you deployed the engine
|
#deployment_id = "<DEPLOYMENT_ID>" # The deployment name you chose when you deployed the engine
|
||||||
|
|
||||||
|
[lite]
|
||||||
|
key = "YOUR_LLM_API_KEY" # Optional, use this if you'd like to use Anthropic, Llama2 (Replicate), or Cohere models
|
||||||
[github]
|
[github]
|
||||||
# ---- Set the following only for deployment type == "user"
|
# ---- Set the following only for deployment type == "user"
|
||||||
user_token = "<TOKEN>" # A GitHub personal access token with 'repo' scope.
|
user_token = "" # A GitHub personal access token with 'repo' scope.
|
||||||
|
deployment_type = "user" #set to user by default
|
||||||
|
|
||||||
# ---- Set the following only for deployment type == "app", see README for details.
|
# ---- Set the following only for deployment type == "app", see README for details.
|
||||||
private_key = """\
|
private_key = """\
|
||||||
|
@ -160,12 +160,12 @@ class PRReviewer:
|
|||||||
the feedback.
|
the feedback.
|
||||||
"""
|
"""
|
||||||
review = self.prediction.strip()
|
review = self.prediction.strip()
|
||||||
|
print(f"review: {review}")
|
||||||
try:
|
try:
|
||||||
data = json.loads(review)
|
data = json.loads(review)
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
data = try_fix_json(review)
|
data = try_fix_json(review)
|
||||||
|
print(f"data: {data}")
|
||||||
# Move 'Security concerns' key to 'PR Analysis' section for better display
|
# Move 'Security concerns' key to 'PR Analysis' section for better display
|
||||||
if 'PR Feedback' in data and 'Security concerns' in data['PR Feedback']:
|
if 'PR Feedback' in data and 'Security concerns' in data['PR Feedback']:
|
||||||
val = data['PR Feedback']['Security concerns']
|
val = data['PR Feedback']['Security concerns']
|
||||||
@ -173,6 +173,7 @@ class PRReviewer:
|
|||||||
data['PR Analysis']['Security concerns'] = val
|
data['PR Analysis']['Security concerns'] = val
|
||||||
|
|
||||||
# Filter out code suggestions that can be submitted as inline comments
|
# Filter out code suggestions that can be submitted as inline comments
|
||||||
|
if 'PR Feedback' in data:
|
||||||
if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \
|
if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \
|
||||||
and 'Code suggestions' in data['PR Feedback']:
|
and 'Code suggestions' in data['PR Feedback']:
|
||||||
data['PR Feedback']['Code suggestions'] = [
|
data['PR Feedback']['Code suggestions'] = [
|
||||||
@ -206,6 +207,10 @@ class PRReviewer:
|
|||||||
if get_settings().config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"Markdown response:\n{markdown_text}")
|
logging.info(f"Markdown response:\n{markdown_text}")
|
||||||
|
|
||||||
|
if markdown_text == None or len(markdown_text) == 0:
|
||||||
|
markdown_text = review
|
||||||
|
|
||||||
|
print(f"markdown text: {markdown_text}")
|
||||||
return markdown_text
|
return markdown_text
|
||||||
|
|
||||||
def _publish_inline_code_comments(self) -> None:
|
def _publish_inline_code_comments(self) -> None:
|
||||||
|
@ -41,7 +41,8 @@ dependencies = [
|
|||||||
"aiohttp~=3.8.4",
|
"aiohttp~=3.8.4",
|
||||||
"atlassian-python-api==3.39.0",
|
"atlassian-python-api==3.39.0",
|
||||||
"GitPython~=3.1.32",
|
"GitPython~=3.1.32",
|
||||||
"starlette-context==0.3.6"
|
"starlette-context==0.3.6",
|
||||||
|
"litellm==0.1.2291"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
Reference in New Issue
Block a user