diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py
index d2542cf2..d0ac46ca 100644
--- a/pr_agent/agent/pr_agent.py
+++ b/pr_agent/agent/pr_agent.py
@@ -46,9 +46,10 @@ command2class = {
 
 commands = list(command2class.keys())
 
+
 class PRAgent:
     def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
-        self.ai_handler = ai_handler # will be initialized in run_action
+        self.ai_handler = ai_handler  # will be initialized in run_action
         self.forbidden_cli_args = ['enable_auto_approval']
 
     async def handle_request(self, pr_url, request, notify=None) -> bool:
@@ -68,7 +69,9 @@ class PRAgent:
         for forbidden_arg in self.forbidden_cli_args:
             for arg in args:
                 if forbidden_arg in arg:
-                    get_logger().error(f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file.")
+                    get_logger().error(
+                        f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file."
+                    )
                     return False
 
         args = update_settings_from_args(args)
@@ -94,4 +97,3 @@ class PRAgent:
         else:
             return False
         return True
-
diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py
index f22eb97c..43078f8d 100644
--- a/pr_agent/algo/__init__.py
+++ b/pr_agent/algo/__init__.py
@@ -9,7 +9,7 @@ MAX_TOKENS = {
     'gpt-4': 8000,
     'gpt-4-0613': 8000,
     'gpt-4-32k': 32000,
-    'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
+    'gpt-4-1106-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
     'gpt-4-0125-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
     'gpt-4o': 128000,  # 128K, but may be limited by config.max_model_tokens
     'gpt-4o-2024-05-13': 128000,  # 128K, but may be limited by config.max_model_tokens
@@ -36,4 +36,5 @@ MAX_TOKENS = {
     'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
     'groq/llama3-8b-8192': 8192,
     'groq/llama3-70b-8192': 8192,
+    "ollama/llama3": 4096,
 }
diff --git a/pr_agent/algo/ai_handlers/base_ai_handler.py b/pr_agent/algo/ai_handlers/base_ai_handler.py
index b5166b8e..e3274eac 100644
--- a/pr_agent/algo/ai_handlers/base_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/base_ai_handler.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 
+
 class BaseAiHandler(ABC):
     """
     This class defines the interface for an AI handler to be used by the PR Agents.
@@ -14,7 +15,7 @@ class BaseAiHandler(ABC):
     def deployment_id(self):
         pass
 
-    @abstractmethod 
+    @abstractmethod
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
         """
         This method should be implemented to return a chat completion from the AI model.
@@ -25,4 +26,3 @@ class BaseAiHandler(ABC):
             temperature (float): the temperature to use for the chat completion
         """
         pass
-
diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py
index a7c6d345..cbd3b2c9 100644
--- a/pr_agent/algo/ai_handlers/langchain_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py
@@ -1,7 +1,7 @@
 try:
     from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
     from langchain.schema import SystemMessage, HumanMessage
-except: # we don't enforce langchain as a dependency, so if it's not installed, just move on
+except:  # we don't enforce langchain as a dependency, so if it's not installed, just move on
     pass
 
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
@@ -14,6 +14,7 @@ import functools
 
 OPENAI_RETRIES = 5
 
+
 class LangChainOpenAIHandler(BaseAiHandler):
     def __init__(self):
         # Initialize OpenAIHandler specific attributes here
@@ -36,7 +37,7 @@ class LangChainOpenAIHandler(BaseAiHandler):
                 raise ValueError(f"OpenAI {e.name} is required") from e
             else:
                 raise e
-        
+
     @property
     def chat(self):
         if self.azure:
@@ -51,17 +52,18 @@ class LangChainOpenAIHandler(BaseAiHandler):
         Returns the deployment ID for the OpenAI API.
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
+
     @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
            tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
         try:
-            messages=[SystemMessage(content=system), HumanMessage(content=user)]
-            
+            messages = [SystemMessage(content=system), HumanMessage(content=user)]
+
             # get a chat completion from the formatted messages
             resp = self.chat(messages, model=model, temperature=temperature)
-            finish_reason="completed"
+            finish_reason = "completed"
             return resp.content, finish_reason
-        
+
         except (Exception) as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
-            raise e
\ No newline at end of file
+            raise e
diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
index 969ddebe..a1dfc6c9 100644
--- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -61,7 +61,7 @@ class LiteLLMAIHandler(BaseAiHandler):
         if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
             litellm.api_base = get_settings().huggingface.api_base
             self.api_base = get_settings().huggingface.api_base
-        if get_settings().get("OLLAMA.API_BASE", None) :
+        if get_settings().get("OLLAMA.API_BASE", None):
             litellm.api_base = get_settings().ollama.api_base
             self.api_base = get_settings().ollama.api_base
         if get_settings().get("HUGGINGFACE.REPITITION_PENALTY", None):
@@ -129,7 +129,7 @@ class LiteLLMAIHandler(BaseAiHandler):
             "messages": messages,
             "temperature": temperature,
             "force_timeout": get_settings().config.ai_timeout,
-            "api_base" : self.api_base,
+            "api_base": self.api_base,
         }
         if self.aws_bedrock_client:
             kwargs["aws_bedrock_client"] = self.aws_bedrock_client
diff --git a/pr_agent/algo/ai_handlers/openai_ai_handler.py b/pr_agent/algo/ai_handlers/openai_ai_handler.py
index 3856f6f7..999f3d3f 100644
--- a/pr_agent/algo/ai_handlers/openai_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/openai_ai_handler.py
@@ -28,13 +28,14 @@ class OpenAIHandler(BaseAiHandler):
         except AttributeError as e:
             raise ValueError("OpenAI key is required") from e
 
+
     @property
     def deployment_id(self):
         """
         Returns the deployment ID for the OpenAI API.
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
-    
+
     @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
            tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
@@ -54,8 +55,8 @@ class OpenAIHandler(BaseAiHandler):
             finish_reason = chat_completion["choices"][0]["finish_reason"]
             usage = chat_completion.get("usage")
             get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
-                                model=model, usage=usage)
-            return resp, finish_reason 
+                              model=model, usage=usage)
+            return resp, finish_reason
         except (APIError, Timeout, TryAgain) as e:
             get_logger().error("Error during OpenAI inference: ", e)
             raise
@@ -64,4 +65,4 @@ class OpenAIHandler(BaseAiHandler):
             raise
         except (Exception) as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
-            raise TryAgain from e
\ No newline at end of file
+            raise TryAgain from e
diff --git a/pr_agent/algo/file_filter.py b/pr_agent/algo/file_filter.py
index aa457293..9f396549 100644
--- a/pr_agent/algo/file_filter.py
+++ b/pr_agent/algo/file_filter.py
@@ -3,6 +3,7 @@ import re
 
 from pr_agent.config_loader import get_settings
 
+
 def filter_ignored(files):
     """
     Filter out files that match the ignore patterns.
@@ -14,7 +15,7 @@ def filter_ignored(files):
         if isinstance(patterns, str):
             patterns = [patterns]
         glob_setting = get_settings().ignore.glob
-        if isinstance(glob_setting, str): # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
+        if isinstance(glob_setting, str):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
             glob_setting = glob_setting.strip('[]').split(",")
         patterns += [fnmatch.translate(glob) for glob in glob_setting]
 
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index f9798277..8c6bf673 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -409,7 +409,7 @@ def update_settings_from_args(args: List[str]) -> List[str]:
             arg = arg.strip('-').strip()
             vals = arg.split('=', 1)
             if len(vals) != 2:
-                if len(vals) > 2: # --extended is a valid argument
+                if len(vals) > 2:  # --extended is a valid argument
                     get_logger().error(f'Invalid argument format: {arg}')
                 other_args.append(arg)
                 continue
diff --git a/pr_agent/cli.py b/pr_agent/cli.py
index 209d3641..7ab78a0e 100644
--- a/pr_agent/cli.py
+++ b/pr_agent/cli.py
@@ -9,6 +9,7 @@ from pr_agent.log import setup_logger
 log_level = os.environ.get("LOG_LEVEL", "INFO")
 setup_logger(log_level)
 
+
 def set_parser():
     parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
 """\
@@ -50,6 +51,7 @@ def set_parser():
     parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
     return parser
 
+
 def run_command(pr_url, command):
     # Preparing the command
     run_command_str = f"--pr_url={pr_url} {command.lstrip('/')}"
@@ -58,6 +60,7 @@ def run_command(pr_url, command):
     # Run the command. Feedback will appear in GitHub PR comments
     run(args=args)
 
+
 def run(inargs=None, args=None):
     parser = set_parser()
     if not args:
diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py
index 224a1c04..5facb3b0 100644
--- a/pr_agent/config_loader.py
+++ b/pr_agent/config_loader.py
@@ -34,6 +34,15 @@ global_settings = Dynaconf(
 
 
 def get_settings():
+    """
+    Retrieves the current settings.
+
+    This function attempts to fetch the settings from the starlette_context's context object. If it fails,
+    it defaults to the global settings defined outside of this function.
+
+    Returns:
+        Dynaconf: The current settings object, either from the context or the global default.
+    """
     try:
         return context["settings"]
     except Exception:
@@ -41,7 +50,7 @@ def get_settings():
 
 
 # Add local configuration from pyproject.toml of the project being reviewed
-def _find_repository_root() -> Path:
+def _find_repository_root() -> Optional[Path]:
     """
     Identify project root directory by recursively searching for the .git directory in the parent directories.
     """
@@ -61,7 +70,7 @@ def _find_pyproject() -> Optional[Path]:
     """
     repo_root = _find_repository_root()
     if repo_root:
-        pyproject = _find_repository_root() / "pyproject.toml"
+        pyproject = repo_root / "pyproject.toml"
         return pyproject if pyproject.is_file() else None
     return None
 
diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py
index 3a127f4c..074023d2 100644
--- a/pr_agent/tools/pr_reviewer.py
+++ b/pr_agent/tools/pr_reviewer.py
@@ -21,6 +21,7 @@ class PRReviewer:
     """
     The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
     """
+
     def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
                  ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
         """
@@ -34,7 +35,7 @@ class PRReviewer:
             args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
         """
         self.args = args
-        self.parse_args(args) # -i command
+        self.parse_args(args)  # -i command
 
         self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
         self.main_language = get_main_pr_language(
@@ -222,7 +223,6 @@ class PRReviewer:
             else:
                 pass
 
-
         incremental_review_markdown_text = None
         # Add incremental review section
         if self.incremental.is_incremental:
@@ -278,7 +278,7 @@ class PRReviewer:
                 self.git_provider.publish_inline_comment(content, relevant_file, relevant_line_in_file)
 
         if comments:
-            self.git_provider.publish_inline_comments(comments) 
+            self.git_provider.publish_inline_comments(comments)
 
     def _get_user_answers(self) -> Tuple[str, str]:
         """
@@ -373,10 +373,10 @@ class PRReviewer:
             if get_settings().pr_reviewer.enable_review_labels_effort:
                 estimated_effort = data['review']['estimated_effort_to_review_[1-5]']
                 estimated_effort_number = int(estimated_effort.split(',')[0])
-                if 1 <= estimated_effort_number <= 5: # 1, because ...
+                if 1 <= estimated_effort_number <= 5:  # 1, because ...
                     review_labels.append(f'Review effort [1-5]: {estimated_effort_number}')
             if get_settings().pr_reviewer.enable_review_labels_security:
-                security_concerns = data['review']['security_concerns'] # yes, because ...
+                security_concerns = data['review']['security_concerns']  # yes, because ...
                 security_concerns_bool = 'yes' in security_concerns.lower() or 'true' in security_concerns.lower()
                 if security_concerns_bool:
                     review_labels.append('Possible security concern')
@@ -426,4 +426,4 @@ class PRReviewer:
         else:
             get_logger().info("Auto-approval option is disabled")
             self.git_provider.publish_comment("Auto-approval option for PR-Agent is disabled. "
-                                              "You can enable it via a [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)")
\ No newline at end of file
+                                              "You can enable it via a [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)")
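Usage note: the two Ollama-related changes above (the `"ollama/llama3": 4096` entry in `MAX_TOKENS` and the `OLLAMA.API_BASE` branch in `litellm_ai_handler.py`) are exercised together through the settings object. The sketch below is a minimal illustration, not part of this patch: the `set()` overrides and the `http://localhost:11434` base URL (Ollama's usual default port) are assumptions; in practice these keys would normally live in a configuration file.

```python
# Minimal sketch (not from this diff): pointing pr-agent at a local Ollama server
# via the settings keys referenced in litellm_ai_handler.py above.
from pr_agent.algo import MAX_TOKENS
from pr_agent.config_loader import get_settings

# Select the model key registered in MAX_TOKENS by this diff.
get_settings().set("CONFIG.MODEL", "ollama/llama3")
# Read by LiteLLMAIHandler.__init__ to set litellm.api_base; the URL is an
# assumption (Ollama's usual default) -- adjust to your deployment.
get_settings().set("OLLAMA.API_BASE", "http://localhost:11434")

# The token budget added in pr_agent/algo/__init__.py above.
assert MAX_TOKENS["ollama/llama3"] == 4096
```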