Mirror of https://github.com/qodo-ai/pr-agent.git (synced 2025-07-04 04:40:38 +08:00)

Merge pull request #942 from barnett-yuxiang/main

Update Python code formatting, configuration loading, and local model additions
@@ -1,17 +1,17 @@

-After [installation](https://codium-ai.github.io/Docs-PR-Agent/installation/), there are three basic ways to invoke CodiumAI PR-Agent:
+After [installation](https://pr-agent-docs.codium.ai/installation/), there are three basic ways to invoke CodiumAI PR-Agent:
 1. Locally running a CLI command
 2. Online usage - by [commenting](https://github.com/Codium-ai/pr-agent/pull/229#issuecomment-1695021901) on a PR
 3. Enabling PR-Agent tools to run automatically when a new PR is opened


-Specifically, CLI commands can be issued by invoking a pre-built [docker image](https://codium-ai.github.io/Docs-PR-Agent/installation/#run-from-source), or by invoking a [locally cloned repo](https://codium-ai.github.io/Docs-PR-Agent/installation/#locally).
-For online usage, you will need to setup either a [GitHub App](https://codium-ai.github.io/Docs-PR-Agent/installation/#run-as-a-github-app), or a [GitHub Action](https://codium-ai.github.io/Docs-PR-Agent/installation/#run-as-a-github-action).
+Specifically, CLI commands can be issued by invoking a pre-built [docker image](https://pr-agent-docs.codium.ai/installation/locally/#using-docker-image), or by invoking a [locally cloned repo](https://pr-agent-docs.codium.ai/installation/locally/#run-from-source).
+For online usage, you will need to setup either a [GitHub App](https://pr-agent-docs.codium.ai/installation/github/#run-as-a-github-app), or a [GitHub Action](https://pr-agent-docs.codium.ai/installation/github/#run-as-a-github-action).
 GitHub App and GitHub Action also enable to run PR-Agent specific tool automatically when a new PR is opened.


-**git provider**: The [git_provider](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L4) field in the configuration file determines the GIT provider that will be used by PR-Agent. Currently, the following providers are supported:
+**git provider**: The [git_provider](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L5) field in the configuration file determines the GIT provider that will be used by PR-Agent. Currently, the following providers are supported:
 `
 "github", "gitlab", "bitbucket", "azure", "codecommit", "local", "gerrit"
 `
@@ -46,9 +46,10 @@ command2class = {

 commands = list(command2class.keys())

+
 class PRAgent:
     def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
         self.ai_handler = ai_handler  # will be initialized in run_action
         self.forbidden_cli_args = ['enable_auto_approval']

     async def handle_request(self, pr_url, request, notify=None) -> bool:
@@ -68,7 +69,9 @@ class PRAgent:
         for forbidden_arg in self.forbidden_cli_args:
             for arg in args:
                 if forbidden_arg in arg:
-                    get_logger().error(f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file.")
+                    get_logger().error(
+                        f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file."
+                    )
                     return False
         args = update_settings_from_args(args)

@@ -94,4 +97,3 @@ class PRAgent:
         else:
             return False
         return True
-
@@ -9,7 +9,7 @@ MAX_TOKENS = {
     'gpt-4': 8000,
     'gpt-4-0613': 8000,
     'gpt-4-32k': 32000,
     'gpt-4-1106-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
     'gpt-4-0125-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
     'gpt-4o': 128000,  # 128K, but may be limited by config.max_model_tokens
     'gpt-4o-2024-05-13': 128000,  # 128K, but may be limited by config.max_model_tokens
@@ -36,4 +36,5 @@ MAX_TOKENS = {
     'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
     'groq/llama3-8b-8192': 8192,
     'groq/llama3-70b-8192': 8192,
+    'ollama/llama3': 4096,
 }
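Aside: the entries above carry the comment "may be limited by config.max_model_tokens". A minimal sketch of that clamping, where effective_max_tokens is a hypothetical helper name rather than the repository's actual code:

    # Hypothetical helper illustrating the cap described by the comments above.
    MAX_TOKENS = {
        'gpt-4o': 128000,
        'ollama/llama3': 4096,  # the local model registered in this commit
    }

    def effective_max_tokens(model: str, max_model_tokens: int = 0) -> int:
        limit = MAX_TOKENS[model]      # KeyError for unregistered models
        if max_model_tokens:           # a truthy config value caps every model
            limit = min(limit, max_model_tokens)
        return limit

    print(effective_max_tokens('gpt-4o', max_model_tokens=32000))  # 32000
    print(effective_max_tokens('ollama/llama3'))                   # 4096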
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod

+
 class BaseAiHandler(ABC):
     """
     This class defines the interface for an AI handler to be used by the PR Agents.

@@ -14,7 +15,7 @@ class BaseAiHandler:
     def deployment_id(self):
         pass

     @abstractmethod
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
         """
         This method should be implemented to return a chat completion from the AI model.

@@ -25,4 +26,3 @@ class BaseAiHandler:
             temperature (float): the temperature to use for the chat completion
         """
         pass
-
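BaseAiHandler is the extension point for plugging in custom backends. A sketch of a minimal concrete subclass, assuming deployment_id is exposed as a property the way the concrete handlers below expose it; EchoAiHandler and its echo behavior are purely illustrative:

    from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler

    class EchoAiHandler(BaseAiHandler):
        """Toy handler: echoes the user prompt instead of calling a model."""

        @property
        def deployment_id(self):
            return None  # no deployment concept for this toy handler

        async def chat_completion(self, model: str, system: str, user: str,
                                  temperature: float = 0.2, img_path: str = None):
            # Real handlers return (response_text, finish_reason); mirror that shape.
            return f"echo: {user}", "completed"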
@@ -1,7 +1,7 @@
 try:
     from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
     from langchain.schema import SystemMessage, HumanMessage
 except:  # we don't enforce langchain as a dependency, so if it's not installed, just move on
     pass

 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler

@@ -14,6 +14,7 @@ import functools

 OPENAI_RETRIES = 5

+
 class LangChainOpenAIHandler(BaseAiHandler):
     def __init__(self):
         # Initialize OpenAIHandler specific attributes here

@@ -36,7 +37,7 @@ class LangChainOpenAIHandler(BaseAiHandler):
                 raise ValueError(f"OpenAI {e.name} is required") from e
             else:
                 raise e

     @property
     def chat(self):
         if self.azure:

@@ -51,17 +52,18 @@ class LangChainOpenAIHandler(BaseAiHandler):
         Returns the deployment ID for the OpenAI API.
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

     @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
            tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
         try:
-            messages=[SystemMessage(content=system), HumanMessage(content=user)]
+            messages = [SystemMessage(content=system), HumanMessage(content=user)]
+
             # get a chat completion from the formatted messages
             resp = self.chat(messages, model=model, temperature=temperature)
-            finish_reason="completed"
+            finish_reason = "completed"
             return resp.content, finish_reason

         except (Exception) as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise e
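The @retry decorator above allows up to OPENAI_RETRIES (5) attempts. Assuming it follows the usual semantics of the PyPI retry package (delay multiplied by backoff after each failure, plus random jitter drawn from the (1, 3) range), the waits between attempts grow roughly as this sketch computes:

    import random

    def retry_waits(tries=5, delay=2, backoff=2, jitter=(1, 3)):
        waits, current = [], delay
        for _ in range(tries - 1):               # no wait after the final attempt
            waits.append(current + random.uniform(*jitter))
            current *= backoff                   # exponential growth: 2, 4, 8, 16
        return waits

    print(retry_waits())  # e.g. [3.4, 5.1, 9.8, 17.2] seconds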
@@ -61,7 +61,7 @@ class LiteLLMAIHandler(BaseAiHandler):
         if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
             litellm.api_base = get_settings().huggingface.api_base
             self.api_base = get_settings().huggingface.api_base
-        if get_settings().get("OLLAMA.API_BASE", None) :
+        if get_settings().get("OLLAMA.API_BASE", None):
             litellm.api_base = get_settings().ollama.api_base
             self.api_base = get_settings().ollama.api_base
         if get_settings().get("HUGGINGFACE.REPITITION_PENALTY", None):

@@ -129,7 +129,7 @@ class LiteLLMAIHandler(BaseAiHandler):
             "messages": messages,
             "temperature": temperature,
             "force_timeout": get_settings().config.ai_timeout,
-            "api_base" : self.api_base,
+            "api_base": self.api_base,
         }
         if self.aws_bedrock_client:
             kwargs["aws_bedrock_client"] = self.aws_bedrock_client
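The OLLAMA.API_BASE setting read above is what ties the new 'ollama/llama3' entry in MAX_TOKENS to a running Ollama server. A hedged wiring sketch; the endpoint URL is Ollama's conventional default, not something this diff specifies:

    from pr_agent.config_loader import get_settings

    # Point PR-Agent at the local model registered in MAX_TOKENS above.
    get_settings().set("CONFIG.MODEL", "ollama/llama3")
    get_settings().set("OLLAMA.API_BASE", "http://localhost:11434")  # assumed default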
@@ -28,13 +28,14 @@ class OpenAIHandler(BaseAiHandler):

         except AttributeError as e:
             raise ValueError("OpenAI key is required") from e
+
     @property
     def deployment_id(self):
         """
         Returns the deployment ID for the OpenAI API.
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

     @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
            tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):

@@ -54,8 +55,8 @@ class OpenAIHandler(BaseAiHandler):
             finish_reason = chat_completion["choices"][0]["finish_reason"]
             usage = chat_completion.get("usage")
             get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
                               model=model, usage=usage)
             return resp, finish_reason
         except (APIError, Timeout, TryAgain) as e:
             get_logger().error("Error during OpenAI inference: ", e)
             raise

@@ -64,4 +65,4 @@ class OpenAIHandler(BaseAiHandler):
             raise
         except (Exception) as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise TryAgain from e
@@ -3,6 +3,7 @@ import re

 from pr_agent.config_loader import get_settings

+
 def filter_ignored(files):
     """
     Filter out files that match the ignore patterns.

@@ -14,7 +15,7 @@ def filter_ignored(files):
         if isinstance(patterns, str):
             patterns = [patterns]
         glob_setting = get_settings().ignore.glob
         if isinstance(glob_setting, str):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
             glob_setting = glob_setting.strip('[]').split(",")
             patterns += [fnmatch.translate(glob) for glob in glob_setting]
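For the glob handling above: a bracketed CLI value is unwrapped, split on commas, and each glob is compiled to a regex by fnmatch.translate. A worked example with made-up values:

    import fnmatch
    import re

    glob_setting = "[*.md,*.lock]"                   # e.g. --ignore.glob=[*.md,*.lock]
    globs = glob_setting.strip('[]').split(",")      # ['*.md', '*.lock']
    patterns = [fnmatch.translate(g) for g in globs]
    print(patterns[0])                               # (?s:.*\.md)\Z
    print(bool(re.match(patterns[0], "poetry.md")))  # True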
@@ -409,7 +409,7 @@ def update_settings_from_args(args: List[str]) -> List[str]:
             arg = arg.strip('-').strip()
             vals = arg.split('=', 1)
             if len(vals) != 2:
                 if len(vals) > 2:  # --extended is a valid argument
                     get_logger().error(f'Invalid argument format: {arg}')
                 other_args.append(arg)
                 continue
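The maxsplit argument in arg.split('=', 1) above is what lets option values themselves contain '=' characters; only the first '=' separates key from value. For example (the key name here is illustrative):

    arg = "pr_reviewer.extra_instructions=prefer a == b over cmp(a, b)"
    vals = arg.split('=', 1)
    print(vals)  # ['pr_reviewer.extra_instructions', 'prefer a == b over cmp(a, b)']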
@@ -9,6 +9,7 @@ from pr_agent.log import setup_logger
 log_level = os.environ.get("LOG_LEVEL", "INFO")
 setup_logger(log_level)

+
 def set_parser():
     parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
 """\

@@ -50,6 +51,7 @@ def set_parser():
     parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
     return parser

+
 def run_command(pr_url, command):
     # Preparing the command
     run_command_str = f"--pr_url={pr_url} {command.lstrip('/')}"

@@ -58,6 +60,7 @@ def run_command(pr_url, command):
     # Run the command. Feedback will appear in GitHub PR comments
     run(args=args)

+
 def run(inargs=None, args=None):
     parser = set_parser()
     if not args:
@@ -34,6 +34,15 @@ global_settings = Dynaconf(


 def get_settings():
+    """
+    Retrieves the current settings.
+
+    This function attempts to fetch the settings from the starlette_context's context object. If it fails,
+    it defaults to the global settings defined outside of this function.
+
+    Returns:
+        Dynaconf: The current settings object, either from the context or the global default.
+    """
     try:
         return context["settings"]
     except Exception:

@@ -41,7 +50,7 @@ def get_settings():


 # Add local configuration from pyproject.toml of the project being reviewed
-def _find_repository_root() -> Path:
+def _find_repository_root() -> Optional[Path]:
     """
     Identify project root directory by recursively searching for the .git directory in the parent directories.
     """

@@ -61,7 +70,7 @@ def _find_pyproject() -> Optional[Path]:
     """
     repo_root = _find_repository_root()
     if repo_root:
-        pyproject = _find_repository_root() / "pyproject.toml"
+        pyproject = repo_root / "pyproject.toml"
         return pyproject if pyproject.is_file() else None
     return None
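The last hunk also removes a redundancy: the old line repeated the directory walk by calling _find_repository_root() a second time instead of reusing the already-computed repo_root, which the new Optional[Path] annotation makes explicit may be None. For reference, a minimal sketch of how such a root finder typically walks parent directories looking for .git (assumed behavior, not the repository's exact code):

    from pathlib import Path
    from typing import Optional

    def find_repository_root(start: Optional[Path] = None) -> Optional[Path]:
        start = (start or Path.cwd()).resolve()
        for candidate in (start, *start.parents):
            if (candidate / ".git").is_dir():
                return candidate      # repository root found
        return None                   # not in a git checkout; callers must handle None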
@@ -8,7 +8,6 @@ from pr_agent.git_providers.local_git_provider import LocalGitProvider
 from pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
 from pr_agent.git_providers.gerrit_provider import GerritProvider

-
 _GIT_PROVIDERS = {
     'github': GithubProvider,
     'gitlab': GitLabProvider,

@@ -16,10 +15,11 @@ _GIT_PROVIDERS = {
     'bitbucket_server': BitbucketServerProvider,
     'azure': AzureDevopsProvider,
     'codecommit': CodeCommitProvider,
-    'local' : LocalGitProvider,
+    'local': LocalGitProvider,
     'gerrit': GerritProvider,
 }

+
 def get_git_provider():
     try:
         provider_id = get_settings().config.git_provider
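_GIT_PROVIDERS maps the configured provider id to a class, so dispatch is a dict lookup followed by instantiation. A usage sketch, assuming GitHub credentials are already configured; the PR URL is the example one from the README above:

    from pr_agent.config_loader import get_settings
    from pr_agent.git_providers import get_git_provider

    get_settings().set("CONFIG.GIT_PROVIDER", "github")
    provider_class = get_git_provider()  # -> GithubProvider, via the dict above
    provider = provider_class("https://github.com/Codium-ai/pr-agent/pull/229")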
@@ -127,7 +127,7 @@ async def run_action():
     if event_payload.get("issue", {}).get("pull_request"):
         url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
         is_pr = True
     elif event_payload.get("comment", {}).get("pull_request_url"):  # for 'pull_request_review_comment
         url = event_payload.get("comment", {}).get("pull_request_url")
         is_pr = True
         disable_eyes = True

@@ -139,8 +139,11 @@ async def run_action():
         comment_id = event_payload.get("comment", {}).get("id")
         provider = get_git_provider()(pr_url=url)
         if is_pr:
-            await PRAgent().handle_request(url, body,
-                                           notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes))
+            await PRAgent().handle_request(
+                url, body, notify=lambda: provider.add_eyes_reaction(
+                    comment_id, disable_eyes=disable_eyes
+                )
+            )
         else:
             await PRAgent().handle_request(url, body)
@@ -43,7 +43,6 @@ class PRDescription:
         self.ai_handler = ai_handler()
         self.ai_handler.main_pr_language = self.main_pr_language

-
         # Initialize the variables dictionary
         self.vars = {
             "title": self.git_provider.pr.title,

@@ -157,7 +156,7 @@ class PRDescription:
             self.git_provider.remove_initial_comment()
         except Exception as e:
             get_logger().error(f"Error generating PR description {self.pr_id}: {e}")

         return ""

     async def _prepare_prediction(self, model: str) -> None:

@@ -221,9 +220,6 @@ class PRDescription:
         if 'pr_files' in self.data:
             self.data['pr_files'] = self.data.pop('pr_files')
-
-
-

     def _prepare_labels(self) -> List[str]:
         pr_types = []

@@ -321,7 +317,7 @@ class PRDescription:
             value = self.file_label_dict
         else:
             key_publish = key.rstrip(':').replace("_", " ").capitalize()
-            if key_publish== "Type":
+            if key_publish == "Type":
                 key_publish = "PR Type"
             # elif key_publish == "Description":
             #     key_publish = "PR Description"

@@ -512,6 +508,7 @@ def insert_br_after_x_chars(text, x=70):
             is_inside_code = False
     return ''.join(new_text).strip()

+
 def replace_code_tags(text):
     """
     Replace odd instances of ` with <code> and even instances of ` with </code>

@@ -519,4 +516,4 @@ def replace_code_tags(text):
     parts = text.split('`')
     for i in range(1, len(parts), 2):
         parts[i] = '<code>' + parts[i] + '</code>'
     return ''.join(parts)
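A worked example of replace_code_tags: split('`') leaves the text between backtick pairs at the odd indices, which the loop wraps in code tags:

    text = "wrap `foo` and `bar` in tags"
    parts = text.split('`')        # ['wrap ', 'foo', ' and ', 'bar', ' in tags']
    for i in range(1, len(parts), 2):
        parts[i] = '<code>' + parts[i] + '</code>'
    print(''.join(parts))          # wrap <code>foo</code> and <code>bar</code> in tags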
@@ -21,6 +21,7 @@ class PRReviewer:
     """
     The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
     """
+
     def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
                  ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
         """

@@ -34,7 +35,7 @@ class PRReviewer:
             args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
         """
         self.args = args
         self.parse_args(args)  # -i command

         self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
         self.main_language = get_main_pr_language(

@@ -222,7 +223,6 @@ class PRReviewer:
         else:
             pass

-
         incremental_review_markdown_text = None
         # Add incremental review section
         if self.incremental.is_incremental:

@@ -278,7 +278,7 @@ class PRReviewer:
                 self.git_provider.publish_inline_comment(content, relevant_file, relevant_line_in_file)

         if comments:
             self.git_provider.publish_inline_comments(comments)

     def _get_user_answers(self) -> Tuple[str, str]:
         """

@@ -373,10 +373,10 @@ class PRReviewer:
             if get_settings().pr_reviewer.enable_review_labels_effort:
                 estimated_effort = data['review']['estimated_effort_to_review_[1-5]']
                 estimated_effort_number = int(estimated_effort.split(',')[0])
                 if 1 <= estimated_effort_number <= 5:  # 1, because ...
                     review_labels.append(f'Review effort [1-5]: {estimated_effort_number}')
             if get_settings().pr_reviewer.enable_review_labels_security:
                 security_concerns = data['review']['security_concerns']  # yes, because ...
                 security_concerns_bool = 'yes' in security_concerns.lower() or 'true' in security_concerns.lower()
                 if security_concerns_bool:
                     review_labels.append('Possible security concern')

@@ -426,4 +426,4 @@ class PRReviewer:
         else:
             get_logger().info("Auto-approval option is disabled")
             self.git_provider.publish_comment("Auto-approval option for PR-Agent is disabled. "
                                               "You can enable it via a [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)")
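The "# 1, because ..." comment above documents the answer format the model is expected to produce, "<number>, because <reason>", and the label parsing relies on it. A small worked example with an invented answer:

    estimated_effort = "3, because the changes are mostly formatting"
    estimated_effort_number = int(estimated_effort.split(',')[0])    # 3
    if 1 <= estimated_effort_number <= 5:
        print(f'Review effort [1-5]: {estimated_effort_number}')     # Review effort [1-5]: 3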