Merge pull request #942 from barnett-yuxiang/main

Update Python code formatting, configuration loading, and local model additions
Authored by Tal on 2024-06-04 15:38:20 +03:00, committed by GitHub
15 changed files with 64 additions and 45 deletions

View File

@@ -1,17 +1,17 @@
-After [installation](https://codium-ai.github.io/Docs-PR-Agent/installation/), there are three basic ways to invoke CodiumAI PR-Agent:
+After [installation](https://pr-agent-docs.codium.ai/installation/), there are three basic ways to invoke CodiumAI PR-Agent:
 1. Locally running a CLI command
 2. Online usage - by [commenting](https://github.com/Codium-ai/pr-agent/pull/229#issuecomment-1695021901) on a PR
 3. Enabling PR-Agent tools to run automatically when a new PR is opened
-Specifically, CLI commands can be issued by invoking a pre-built [docker image](https://codium-ai.github.io/Docs-PR-Agent/installation/#run-from-source), or by invoking a [locally cloned repo](https://codium-ai.github.io/Docs-PR-Agent/installation/#locally).
-For online usage, you will need to setup either a [GitHub App](https://codium-ai.github.io/Docs-PR-Agent/installation/#run-as-a-github-app), or a [GitHub Action](https://codium-ai.github.io/Docs-PR-Agent/installation/#run-as-a-github-action).
+Specifically, CLI commands can be issued by invoking a pre-built [docker image](https://pr-agent-docs.codium.ai/installation/locally/#using-docker-image), or by invoking a [locally cloned repo](https://pr-agent-docs.codium.ai/installation/locally/#run-from-source).
+For online usage, you will need to setup either a [GitHub App](https://pr-agent-docs.codium.ai/installation/github/#run-as-a-github-app), or a [GitHub Action](https://pr-agent-docs.codium.ai/installation/github/#run-as-a-github-action).
 GitHub App and GitHub Action also enable to run PR-Agent specific tool automatically when a new PR is opened.
-**git provider**: The [git_provider](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L4) field in the configuration file determines the GIT provider that will be used by PR-Agent. Currently, the following providers are supported:
+**git provider**: The [git_provider](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L5) field in the configuration file determines the GIT provider that will be used by PR-Agent. Currently, the following providers are supported:
 `
 "github", "gitlab", "bitbucket", "azure", "codecommit", "local", "gerrit"
 `
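For orientation, a minimal sketch of how that git_provider setting is consumed at runtime; get_settings(), config.git_provider, get_git_provider() and the pr_url keyword are taken from the code changed later in this commit, while the import paths and the placeholder PR URL are assumptions for illustration.

```
# Minimal sketch: resolve the configured git provider and instantiate it for a PR.
# The PR URL is a placeholder; provider credentials must already be configured.
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider

provider_id = get_settings().config.git_provider   # e.g. "github", "gitlab", "local", ...
provider_class = get_git_provider()                 # maps the configured id to a provider class
provider = provider_class(pr_url="https://github.com/owner/repo/pull/1")
```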

View File

@@ -46,6 +46,7 @@ command2class = {
 commands = list(command2class.keys())
 class PRAgent:
     def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
         self.ai_handler = ai_handler # will be initialized in run_action
@@ -68,7 +69,9 @@ class PRAgent:
         for forbidden_arg in self.forbidden_cli_args:
             for arg in args:
                 if forbidden_arg in arg:
-                    get_logger().error(f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file.")
+                    get_logger().error(
+                        f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file."
+                    )
                     return False
         args = update_settings_from_args(args)
@@ -94,4 +97,3 @@ class PRAgent:
         else:
             return False
         return True

View File

@@ -36,4 +36,5 @@ MAX_TOKENS = {
     'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
     'groq/llama3-8b-8192': 8192,
     'groq/llama3-70b-8192': 8192,
+    'ollama/llama3': 4096,
 }
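The new MAX_TOKENS entry registers the context window of a locally served Ollama model. A hedged sketch of pointing PR-Agent at that model, assuming the setting names read by the LiteLLM handler later in this commit (config.model, OLLAMA.API_BASE), Dynaconf's dotted-key set(), and Ollama's default local endpoint:

```
# Illustrative only: select the newly registered local model and its API base.
from pr_agent.config_loader import get_settings

get_settings().set("config.model", "ollama/llama3")               # matches the new MAX_TOKENS key
get_settings().set("ollama.api_base", "http://localhost:11434")   # assumed default Ollama endpoint
```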

View File

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 class BaseAiHandler(ABC):
     """
     This class defines the interface for an AI handler to be used by the PR Agents.
@@ -25,4 +26,3 @@ class BaseAiHandler(ABC):
         temperature (float): the temperature to use for the chat completion
         """
         pass
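Because PRAgent takes the handler class as a constructor argument (see the PRAgent.__init__ signature earlier in this commit), this interface can be filled in by a custom handler. A hypothetical sketch follows; EchoAiHandler, its canned reply, and the module path are illustrative assumptions, and the deployment_id property is included only to match the concrete handlers changed in this commit.

```
# Hypothetical handler implementing the interface described above; not repository code.
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler


class EchoAiHandler(BaseAiHandler):
    def __init__(self):
        pass  # no client state needed for this sketch

    @property
    def deployment_id(self):
        return None  # no Azure-style deployment in this sketch

    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
        # Return (content, finish_reason), mirroring the handlers changed in this commit.
        return f"[{model}] {user[:80]}", "completed"
```

Such a class could then be passed in as PRAgent(ai_handler=EchoAiHandler), in place of the LiteLLMAIHandler default shown above.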

View File

@@ -14,6 +14,7 @@ import functools
 OPENAI_RETRIES = 5
 class LangChainOpenAIHandler(BaseAiHandler):
     def __init__(self):
         # Initialize OpenAIHandler specific attributes here
@@ -51,15 +52,16 @@ class LangChainOpenAIHandler(BaseAiHandler):
         Returns the deployment ID for the OpenAI API.
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
     @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
            tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
         try:
-            messages=[SystemMessage(content=system), HumanMessage(content=user)]
+            messages = [SystemMessage(content=system), HumanMessage(content=user)]
             # get a chat completion from the formatted messages
             resp = self.chat(messages, model=model, temperature=temperature)
-            finish_reason="completed"
+            finish_reason = "completed"
             return resp.content, finish_reason
         except (Exception) as e:

View File

@@ -61,7 +61,7 @@ class LiteLLMAIHandler(BaseAiHandler):
         if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
             litellm.api_base = get_settings().huggingface.api_base
             self.api_base = get_settings().huggingface.api_base
-        if get_settings().get("OLLAMA.API_BASE", None) :
+        if get_settings().get("OLLAMA.API_BASE", None):
             litellm.api_base = get_settings().ollama.api_base
             self.api_base = get_settings().ollama.api_base
         if get_settings().get("HUGGINGFACE.REPITITION_PENALTY", None):
@@ -129,7 +129,7 @@ class LiteLLMAIHandler(BaseAiHandler):
             "messages": messages,
             "temperature": temperature,
             "force_timeout": get_settings().config.ai_timeout,
-            "api_base" : self.api_base,
+            "api_base": self.api_base,
         }
         if self.aws_bedrock_client:
             kwargs["aws_bedrock_client"] = self.aws_bedrock_client

View File

@@ -28,6 +28,7 @@ class OpenAIHandler(BaseAiHandler):
         except AttributeError as e:
             raise ValueError("OpenAI key is required") from e
     @property
     def deployment_id(self):
         """

View File

@@ -3,6 +3,7 @@ import re
 from pr_agent.config_loader import get_settings
 def filter_ignored(files):
     """
     Filter out files that match the ignore patterns.

View File

@@ -9,6 +9,7 @@ from pr_agent.log import setup_logger
 log_level = os.environ.get("LOG_LEVEL", "INFO")
 setup_logger(log_level)
 def set_parser():
     parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
 """\
@@ -50,6 +51,7 @@ def set_parser():
     parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
     return parser
 def run_command(pr_url, command):
     # Preparing the command
     run_command_str = f"--pr_url={pr_url} {command.lstrip('/')}"
@@ -58,6 +60,7 @@ def run_command(pr_url, command):
     # Run the command. Feedback will appear in GitHub PR comments
     run(args=args)
 def run(inargs=None, args=None):
     parser = set_parser()
     if not args:
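As a usage note, run_command above simply rebuilds a CLI string from the PR URL and the slash command and hands it to run(). A hedged sketch of driving it from Python, assuming this file is pr_agent/cli.py and that provider and model credentials are already configured; the PR URL is a placeholder:

```
# Illustrative invocation of the CLI helper shown above.
from pr_agent.cli import run_command

run_command("https://github.com/owner/repo/pull/1", "/review")
```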

View File

@@ -34,6 +34,15 @@ global_settings = Dynaconf(
 def get_settings():
+    """
+    Retrieves the current settings.
+
+    This function attempts to fetch the settings from the starlette_context's context object. If it fails,
+    it defaults to the global settings defined outside of this function.
+
+    Returns:
+        Dynaconf: The current settings object, either from the context or the global default.
+    """
     try:
         return context["settings"]
     except Exception:
@@ -41,7 +50,7 @@ def get_settings():
 # Add local configuration from pyproject.toml of the project being reviewed
-def _find_repository_root() -> Path:
+def _find_repository_root() -> Optional[Path]:
     """
     Identify project root directory by recursively searching for the .git directory in the parent directories.
     """
@@ -61,7 +70,7 @@ def _find_pyproject() -> Optional[Path]:
     """
     repo_root = _find_repository_root()
     if repo_root:
-        pyproject = _find_repository_root() / "pyproject.toml"
+        pyproject = repo_root / "pyproject.toml"
         return pyproject if pyproject.is_file() else None
     return None
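The _find_repository_root docstring above describes a walk up the directory tree looking for a .git directory. An illustrative re-implementation of that idea (not the repository's code), returning Optional[Path] as in the annotation introduced here:

```
# Sketch of the .git walk-up described above; the starting directory defaults to the cwd.
from pathlib import Path
from typing import Optional


def find_repository_root_sketch(start: Optional[Path] = None) -> Optional[Path]:
    here = (start or Path.cwd()).resolve()
    for candidate in (here, *here.parents):
        if (candidate / ".git").is_dir():
            return candidate
    return None  # not inside a git repository
```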

View File

@@ -8,7 +8,6 @@ from pr_agent.git_providers.local_git_provider import LocalGitProvider
 from pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
 from pr_agent.git_providers.gerrit_provider import GerritProvider
 _GIT_PROVIDERS = {
     'github': GithubProvider,
     'gitlab': GitLabProvider,
@@ -16,10 +15,11 @@ _GIT_PROVIDERS = {
     'bitbucket_server': BitbucketServerProvider,
     'azure': AzureDevopsProvider,
     'codecommit': CodeCommitProvider,
-    'local' : LocalGitProvider,
+    'local': LocalGitProvider,
     'gerrit': GerritProvider,
 }
 def get_git_provider():
     try:
         provider_id = get_settings().config.git_provider

View File

@@ -139,8 +139,11 @@ async def run_action():
         comment_id = event_payload.get("comment", {}).get("id")
         provider = get_git_provider()(pr_url=url)
         if is_pr:
-            await PRAgent().handle_request(url, body,
-                                           notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes))
+            await PRAgent().handle_request(
+                url, body, notify=lambda: provider.add_eyes_reaction(
+                    comment_id, disable_eyes=disable_eyes
+                )
+            )
         else:
             await PRAgent().handle_request(url, body)

View File

@@ -43,7 +43,6 @@ class PRDescription:
         self.ai_handler = ai_handler()
         self.ai_handler.main_pr_language = self.main_pr_language
         # Initialize the variables dictionary
         self.vars = {
             "title": self.git_provider.pr.title,
@@ -221,9 +220,6 @@ class PRDescription:
         if 'pr_files' in self.data:
             self.data['pr_files'] = self.data.pop('pr_files')
     def _prepare_labels(self) -> List[str]:
         pr_types = []
@@ -321,7 +317,7 @@ class PRDescription:
             value = self.file_label_dict
         else:
             key_publish = key.rstrip(':').replace("_", " ").capitalize()
-            if key_publish== "Type":
+            if key_publish == "Type":
                 key_publish = "PR Type"
             # elif key_publish == "Description":
             #     key_publish = "PR Description"
@@ -512,6 +508,7 @@ def insert_br_after_x_chars(text, x=70):
             is_inside_code = False
     return ''.join(new_text).strip()
 def replace_code_tags(text):
     """
     Replace odd instances of ` with <code> and even instances of ` with </code>
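The replace_code_tags docstring describes an alternating backtick replacement: odd occurrences open a <code> tag, even occurrences close it. An illustrative, self-contained re-implementation of that rule (separate from the repository's function):

```
import re


def replace_code_tags_sketch(text: str) -> str:
    # Split on backticks, keeping them as tokens, then alternate <code>/</code>.
    parts = re.split(r'(`)', text)
    out, opening = [], True
    for part in parts:
        if part == '`':
            out.append('<code>' if opening else '</code>')
            opening = not opening
        else:
            out.append(part)
    return ''.join(out)

# replace_code_tags_sketch("wrap `foo` inline") -> "wrap <code>foo</code> inline"
```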

View File

@@ -21,6 +21,7 @@ class PRReviewer:
     """
     The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
     """
     def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
                  ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
         """
@@ -222,7 +223,6 @@ class PRReviewer:
         else:
             pass
         incremental_review_markdown_text = None
         # Add incremental review section
         if self.incremental.is_incremental: