Update Python code formatting and configuration loading; add a local Ollama model

1. Code Formatting:
   - Standardized Python code formatting across multiple files to align with PEP 8 guidelines, including whitespace around operators, line breaks for long calls, and two-space separation before inline comments (see the sketch just after this list).

2. Configuration Loader Enhancements:
   - Enhanced the `get_settings` function in `config_loader.py` to make settings retrieval more robust: a new docstring explains that settings are fetched from the starlette_context context object, with a fallback to the global Dynaconf settings when no context is available. Also tightened `_find_repository_root` to return `Optional[Path]` and fixed `_find_pyproject` to reuse the already-located repository root.

3. Model Addition in __init__.py:
   - Added "ollama/llama3" with a 4096-token limit to the MAX_TOKENS dictionary in `__init__.py`, enabling a locally served Llama 3 model via Ollama.
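As a reference for item 1, a minimal before/after sketch of the conventions being applied; these lines are illustrative, not taken verbatim from the diff:

```python
# Before: no spaces around '=', a single space before the inline comment
finish_reason="completed" # done

# After: spaces around '=', two spaces before the inline comment (PEP 8)
finish_reason = "completed"  # done
```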
Author: Kamakura
Date: 2024-06-03 23:58:31 +08:00
parent ab31d2f1f8
commit b4f0ad948f
11 changed files with 48 additions and 29 deletions

View File

@@ -46,9 +46,10 @@ command2class = {
commands = list(command2class.keys())
class PRAgent:
def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
- self.ai_handler = ai_handler # will be initialized in run_action
+ self.ai_handler = ai_handler  # will be initialized in run_action
self.forbidden_cli_args = ['enable_auto_approval']
async def handle_request(self, pr_url, request, notify=None) -> bool:
@@ -68,7 +69,9 @@ class PRAgent:
for forbidden_arg in self.forbidden_cli_args:
for arg in args:
if forbidden_arg in arg:
get_logger().error(f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file.")
get_logger().error(
f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file."
)
return False
args = update_settings_from_args(args)
@@ -94,4 +97,3 @@ class PRAgent:
else:
return False
return True
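A self-contained sketch of the forbidden-argument check in the hunks above; the sample args list and print-based logging are hypothetical stand-ins for the real CLI input and get_logger():

```python
forbidden_cli_args = ['enable_auto_approval']
args = ['--pr_url=https://example.com/pr/1', '--enable_auto_approval=true']  # hypothetical input

def handle_request(args) -> bool:
    for forbidden_arg in forbidden_cli_args:
        for arg in args:
            if forbidden_arg in arg:
                # Stand-in for get_logger().error(...)
                print(f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file.")
                return False
    return True

print(handle_request(args))  # False: auto-approval may only be enabled via a configuration file
```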

View File

@@ -9,7 +9,7 @@ MAX_TOKENS = {
'gpt-4': 8000,
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
- 'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
+ 'gpt-4-1106-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
'gpt-4-0125-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'gpt-4o': 128000, # 128K, but may be limited by config.max_model_tokens
'gpt-4o-2024-05-13': 128000, # 128K, but may be limited by config.max_model_tokens
@@ -36,4 +36,5 @@ MAX_TOKENS = {
'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
'groq/llama3-8b-8192': 8192,
'groq/llama3-70b-8192': 8192,
"ollama/llama3": 4096,
}
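To show how the table is typically consumed, a hedged sketch; the get_max_tokens helper is hypothetical, only the dictionary entries mirror the diff:

```python
MAX_TOKENS = {
    'groq/llama3-70b-8192': 8192,
    "ollama/llama3": 4096,  # the entry added by this commit
}

def get_max_tokens(model: str, default: int = 4096) -> int:
    # Hypothetical helper: look up a model's context limit, falling back
    # to a conservative default for models missing from the table.
    return MAX_TOKENS.get(model, default)

print(get_max_tokens("ollama/llama3"))  # 4096
```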

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
class BaseAiHandler(ABC):
"""
This class defines the interface for an AI handler to be used by the PR Agents.
@@ -14,7 +15,7 @@ class BaseAiHandler(ABC):
def deployment_id(self):
pass
- @abstractmethod
+ @abstractmethod
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
"""
This method should be implemented to return a chat completion from the AI model.
@@ -25,4 +26,3 @@ class BaseAiHandler(ABC):
temperature (float): the temperature to use for the chat completion
"""
pass
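For illustration, a minimal concrete handler against the abstract interface above; EchoAiHandler and its canned reply are hypothetical, and deployment_id is assumed to be a property, as in the OpenAIHandler shown further below:

```python
import asyncio
from abc import ABC, abstractmethod

class BaseAiHandler(ABC):
    @property
    @abstractmethod
    def deployment_id(self):
        pass

    @abstractmethod
    async def chat_completion(self, model: str, system: str, user: str,
                              temperature: float = 0.2, img_path: str = None):
        pass

class EchoAiHandler(BaseAiHandler):
    # Hypothetical stub: echoes the user prompt, handy for wiring tests
    # without a live model behind the interface.
    @property
    def deployment_id(self):
        return None

    async def chat_completion(self, model, system, user,
                              temperature=0.2, img_path=None):
        return user, "completed"

print(asyncio.run(EchoAiHandler().chat_completion("noop", "sys", "hello")))
```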

View File

@@ -1,7 +1,7 @@
try:
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.schema import SystemMessage, HumanMessage
- except: # we don't enforce langchain as a dependency, so if it's not installed, just move on
+ except:  # we don't enforce langchain as a dependency, so if it's not installed, just move on
pass
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
@@ -14,6 +14,7 @@ import functools
OPENAI_RETRIES = 5
class LangChainOpenAIHandler(BaseAiHandler):
def __init__(self):
# Initialize OpenAIHandler specific attributes here
@@ -36,7 +37,7 @@ class LangChainOpenAIHandler(BaseAiHandler):
raise ValueError(f"OpenAI {e.name} is required") from e
else:
raise e
@property
def chat(self):
if self.azure:
@@ -51,17 +52,18 @@ class LangChainOpenAIHandler(BaseAiHandler):
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
try:
- messages=[SystemMessage(content=system), HumanMessage(content=user)]
+ messages = [SystemMessage(content=system), HumanMessage(content=user)]
# get a chat completion from the formatted messages
resp = self.chat(messages, model=model, temperature=temperature)
finish_reason="completed"
finish_reason = "completed"
return resp.content, finish_reason
except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
- raise e
+ raise e
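The @retry decorator in this hunk takes kwargs matching the PyPI retry package; assuming that dependency, a small sketch of the backoff behavior (flaky_fetch is hypothetical):

```python
import random
from retry import retry  # assumed dependency; kwargs mirror the decorator above

@retry(exceptions=(TimeoutError,), tries=5, delay=2, backoff=2, jitter=(1, 3))
def flaky_fetch():
    # Hypothetical flaky call: fails half the time and is retried with
    # exponentially growing, jittered delays (roughly 2s, 4s, 8s, ...).
    if random.random() < 0.5:
        raise TimeoutError("transient failure")
    return "ok"

print(flaky_fetch())
```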

View File

@@ -61,7 +61,7 @@ class LiteLLMAIHandler(BaseAiHandler):
if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
litellm.api_base = get_settings().huggingface.api_base
self.api_base = get_settings().huggingface.api_base
if get_settings().get("OLLAMA.API_BASE", None) :
if get_settings().get("OLLAMA.API_BASE", None):
litellm.api_base = get_settings().ollama.api_base
self.api_base = get_settings().ollama.api_base
if get_settings().get("HUGGINGFACE.REPITITION_PENALTY", None):
@@ -129,7 +129,7 @@ class LiteLLMAIHandler(BaseAiHandler):
"messages": messages,
"temperature": temperature,
"force_timeout": get_settings().config.ai_timeout,
"api_base" : self.api_base,
"api_base": self.api_base,
}
if self.aws_bedrock_client:
kwargs["aws_bedrock_client"] = self.aws_bedrock_client

View File

@@ -28,13 +28,14 @@ class OpenAIHandler(BaseAiHandler):
except AttributeError as e:
raise ValueError("OpenAI key is required") from e
@property
def deployment_id(self):
"""
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
@@ -54,8 +55,8 @@ class OpenAIHandler(BaseAiHandler):
finish_reason = chat_completion["choices"][0]["finish_reason"]
usage = chat_completion.get("usage")
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
- model=model, usage=usage)
- return resp, finish_reason
+ model=model, usage=usage)
+ return resp, finish_reason
except (APIError, Timeout, TryAgain) as e:
get_logger().error("Error during OpenAI inference: ", e)
raise
@@ -64,4 +65,4 @@ class OpenAIHandler(BaseAiHandler):
raise
except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
- raise TryAgain from e
+ raise TryAgain from e

View File

@@ -3,6 +3,7 @@ import re
from pr_agent.config_loader import get_settings
def filter_ignored(files):
"""
Filter out files that match the ignore patterns.
@@ -14,7 +15,7 @@ def filter_ignored(files):
if isinstance(patterns, str):
patterns = [patterns]
glob_setting = get_settings().ignore.glob
- if isinstance(glob_setting, str): # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
+ if isinstance(glob_setting, str):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
glob_setting = glob_setting.strip('[]').split(",")
patterns += [fnmatch.translate(glob) for glob in glob_setting]
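A standalone sketch of the glob handling this hunk touches: a bracketed string from the CLI is unwrapped, split on commas, and each glob is compiled to a regex with fnmatch.translate (the sample values are illustrative):

```python
import fnmatch
import re

glob_setting = "[*.md,*.txt]"  # e.g. from --ignore.glob=[*.md,*.txt]
if isinstance(glob_setting, str):
    glob_setting = glob_setting.strip('[]').split(",")

patterns = [fnmatch.translate(glob) for glob in glob_setting]
print(any(re.match(p, "README.md") for p in patterns))  # True: *.md matches
```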

View File

@@ -409,7 +409,7 @@ def update_settings_from_args(args: List[str]) -> List[str]:
arg = arg.strip('-').strip()
vals = arg.split('=', 1)
if len(vals) != 2:
- if len(vals) > 2: # --extended is a valid argument
+ if len(vals) > 2:  # --extended is a valid argument
get_logger().error(f'Invalid argument format: {arg}')
other_args.append(arg)
continue
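A simplified, self-contained sketch of the parsing loop in this hunk; parse_cli_overrides is a hypothetical stand-in for update_settings_from_args, storing overrides in a plain dict instead of the settings object:

```python
def parse_cli_overrides(args):
    settings, other_args = {}, []
    for arg in args:
        arg = arg.strip('-').strip()
        vals = arg.split('=', 1)
        if len(vals) != 2:  # e.g. '--extended' is a valid bare flag
            other_args.append(arg)
            continue
        key, value = vals
        settings[key.strip()] = value
    return settings, other_args

print(parse_cli_overrides(["--config.model=gpt-4o", "--extended"]))
# ({'config.model': 'gpt-4o'}, ['extended'])
```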

View File

@@ -9,6 +9,7 @@ from pr_agent.log import setup_logger
log_level = os.environ.get("LOG_LEVEL", "INFO")
setup_logger(log_level)
def set_parser():
parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
"""\
@@ -50,6 +51,7 @@ def set_parser():
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
return parser
def run_command(pr_url, command):
# Preparing the command
run_command_str = f"--pr_url={pr_url} {command.lstrip('/')}"
@@ -58,6 +60,7 @@ def run_command(pr_url, command):
# Run the command. Feedback will appear in GitHub PR comments
run(args=args)
def run(inargs=None, args=None):
parser = set_parser()
if not args:

View File

@@ -34,6 +34,15 @@ global_settings = Dynaconf(
def get_settings():
"""
Retrieves the current settings.
This function attempts to fetch the settings from the starlette_context's context object. If it fails,
it defaults to the global settings defined outside of this function.
Returns:
Dynaconf: The current settings object, either from the context or the global default.
"""
try:
return context["settings"]
except Exception:
@@ -41,7 +50,7 @@ def get_settings():
# Add local configuration from pyproject.toml of the project being reviewed
- def _find_repository_root() -> Path:
+ def _find_repository_root() -> Optional[Path]:
"""
Identify project root directory by recursively searching for the .git directory in the parent directories.
"""
@@ -61,7 +70,7 @@ def _find_pyproject() -> Optional[Path]:
"""
repo_root = _find_repository_root()
if repo_root:
- pyproject = _find_repository_root() / "pyproject.toml"
+ pyproject = repo_root / "pyproject.toml"
return pyproject if pyproject.is_file() else None
return None
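A minimal sketch of the fallback-and-discovery pattern this file implements, with a simplified stand-in for the starlette_context lookup; the context dict and global_settings here are placeholders, not the real request context or Dynaconf object:

```python
from pathlib import Path
from typing import Optional

context = {}  # stand-in for starlette_context's request-scoped mapping
global_settings = {"config.model": "gpt-4"}  # stand-in for the global Dynaconf object

def get_settings():
    # Prefer request-scoped settings; fall back to the global defaults.
    try:
        return context["settings"]
    except Exception:
        return global_settings

def _find_repository_root() -> Optional[Path]:
    # Walk up from the current directory until a .git directory marks the root.
    cwd = Path.cwd().absolute()
    while cwd != cwd.parent:
        if (cwd / ".git").exists():
            return cwd
        cwd = cwd.parent
    return None

def _find_pyproject() -> Optional[Path]:
    # Reuse the located root rather than searching twice (the bug this commit fixes).
    repo_root = _find_repository_root()
    if repo_root:
        pyproject = repo_root / "pyproject.toml"
        return pyproject if pyproject.is_file() else None
    return None

print(get_settings(), _find_pyproject())
```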

View File

@@ -21,6 +21,7 @@ class PRReviewer:
"""
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
@@ -34,7 +35,7 @@ class PRReviewer:
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
"""
self.args = args
- self.parse_args(args) # -i command
+ self.parse_args(args)  # -i command
self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
self.main_language = get_main_pr_language(
@@ -222,7 +223,6 @@ class PRReviewer:
else:
pass
incremental_review_markdown_text = None
# Add incremental review section
if self.incremental.is_incremental:
@@ -278,7 +278,7 @@ class PRReviewer:
self.git_provider.publish_inline_comment(content, relevant_file, relevant_line_in_file)
if comments:
- self.git_provider.publish_inline_comments(comments)
+ self.git_provider.publish_inline_comments(comments)
def _get_user_answers(self) -> Tuple[str, str]:
"""
@@ -373,10 +373,10 @@ class PRReviewer:
if get_settings().pr_reviewer.enable_review_labels_effort:
estimated_effort = data['review']['estimated_effort_to_review_[1-5]']
estimated_effort_number = int(estimated_effort.split(',')[0])
- if 1 <= estimated_effort_number <= 5: # 1, because ...
+ if 1 <= estimated_effort_number <= 5:  # 1, because ...
review_labels.append(f'Review effort [1-5]: {estimated_effort_number}')
if get_settings().pr_reviewer.enable_review_labels_security:
- security_concerns = data['review']['security_concerns'] # yes, because ...
+ security_concerns = data['review']['security_concerns']  # yes, because ...
security_concerns_bool = 'yes' in security_concerns.lower() or 'true' in security_concerns.lower()
if security_concerns_bool:
review_labels.append('Possible security concern')
@@ -426,4 +426,4 @@ class PRReviewer:
else:
get_logger().info("Auto-approval option is disabled")
self.git_provider.publish_comment("Auto-approval option for PR-Agent is disabled. "
"You can enable it via a [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)")
"You can enable it via a [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)")