update openai api
This commit is contained in:
zhouleilei
2024-11-02 09:47:14 +08:00
parent 15e8c988a4
commit dacb45dd8a

View File

@ -1,6 +1,7 @@
from os import environ
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
import openai import openai
from openai.error import APIError, RateLimitError, Timeout, TryAgain from openai import APIError, AsyncOpenAI, RateLimitError, Timeout
from retry import retry from retry import retry
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -14,7 +15,7 @@ class OpenAIHandler(BaseAiHandler):
# Initialize OpenAIHandler specific attributes here # Initialize OpenAIHandler specific attributes here
try: try:
super().__init__() super().__init__()
openai.api_key = get_settings().openai.key environ["OPENAI_API_KEY"] = get_settings().openai.key
if get_settings().get("OPENAI.ORG", None): if get_settings().get("OPENAI.ORG", None):
openai.organization = get_settings().openai.org openai.organization = get_settings().openai.org
if get_settings().get("OPENAI.API_TYPE", None): if get_settings().get("OPENAI.API_TYPE", None):
@ -24,7 +25,7 @@ class OpenAIHandler(BaseAiHandler):
if get_settings().get("OPENAI.API_VERSION", None): if get_settings().get("OPENAI.API_VERSION", None):
openai.api_version = get_settings().openai.api_version openai.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None): if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = get_settings().openai.api_base environ["OPENAI_BASE_URL"] = get_settings().openai.api_base
except AttributeError as e: except AttributeError as e:
raise ValueError("OpenAI key is required") from e raise ValueError("OpenAI key is required") from e
@ -36,7 +37,7 @@ class OpenAIHandler(BaseAiHandler):
""" """
return get_settings().get("OPENAI.DEPLOYMENT_ID", None) return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
try: try:
@ -44,20 +45,19 @@ class OpenAIHandler(BaseAiHandler):
get_logger().info("System: ", system) get_logger().info("System: ", system)
get_logger().info("User: ", user) get_logger().info("User: ", user)
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
client = AsyncOpenAI()
chat_completion = await openai.ChatCompletion.acreate( chat_completion = await client.chat.completions.create(
model=model, model=model,
deployment_id=deployment_id,
messages=messages, messages=messages,
temperature=temperature, temperature=temperature,
) )
resp = chat_completion["choices"][0]['message']['content'] resp = chat_completion.choices[0].message.content
finish_reason = chat_completion["choices"][0]["finish_reason"] finish_reason = chat_completion.choices[0].finish_reason
usage = chat_completion.get("usage") usage = chat_completion.usage
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
model=model, usage=usage) model=model, usage=usage)
return resp, finish_reason return resp, finish_reason
except (APIError, Timeout, TryAgain) as e: except (APIError, Timeout) as e:
get_logger().error("Error during OpenAI inference: ", e) get_logger().error("Error during OpenAI inference: ", e)
raise raise
except (RateLimitError) as e: except (RateLimitError) as e:
@ -65,4 +65,4 @@ class OpenAIHandler(BaseAiHandler):
raise raise
except (Exception) as e: except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e) get_logger().error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e raise