mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-05 05:10:38 +08:00
Compare commits
10 Commits
ok/gitlab_
...
ok/gitlat_
Author | SHA1 | Date | |
---|---|---|---|
d23daf880f | |||
adb3f17258 | |||
55eb741965 | |||
cca809e91c | |||
57ff46ecc1 | |||
3819d52eb0 | |||
3072325d2c | |||
abca2fdcb7 | |||
4d84f76948 | |||
1bf27c38a7 |
@ -14,7 +14,6 @@ from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
|||||||
|
|
||||||
class GitLabProvider(GitProvider):
|
class GitLabProvider(GitProvider):
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
|
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
|
||||||
gitlab_url = settings.get("GITLAB.URL", None)
|
gitlab_url = settings.get("GITLAB.URL", None)
|
||||||
if not gitlab_url:
|
if not gitlab_url:
|
||||||
@ -23,8 +22,8 @@ class GitLabProvider(GitProvider):
|
|||||||
if not gitlab_access_token:
|
if not gitlab_access_token:
|
||||||
raise ValueError("GitLab personal access token is not set in the config file")
|
raise ValueError("GitLab personal access token is not set in the config file")
|
||||||
self.gl = gitlab.Gitlab(
|
self.gl = gitlab.Gitlab(
|
||||||
gitlab_url,
|
url=gitlab_url,
|
||||||
gitlab_access_token
|
oauth_token=gitlab_access_token
|
||||||
)
|
)
|
||||||
self.id_project = None
|
self.id_project = None
|
||||||
self.id_mr = None
|
self.id_mr = None
|
||||||
|
@ -8,46 +8,57 @@ from pr_agent.tools.pr_reviewer import PRReviewer
|
|||||||
|
|
||||||
|
|
||||||
async def run_action():
|
async def run_action():
|
||||||
GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME', None)
|
# Get environment variables
|
||||||
|
GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME')
|
||||||
|
GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH')
|
||||||
|
OPENAI_KEY = os.environ.get('OPENAI_KEY')
|
||||||
|
OPENAI_ORG = os.environ.get('OPENAI_ORG')
|
||||||
|
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
|
||||||
|
|
||||||
|
# Check if required environment variables are set
|
||||||
if not GITHUB_EVENT_NAME:
|
if not GITHUB_EVENT_NAME:
|
||||||
print("GITHUB_EVENT_NAME not set")
|
print("GITHUB_EVENT_NAME not set")
|
||||||
return
|
return
|
||||||
GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH', None)
|
|
||||||
if not GITHUB_EVENT_PATH:
|
if not GITHUB_EVENT_PATH:
|
||||||
print("GITHUB_EVENT_PATH not set")
|
print("GITHUB_EVENT_PATH not set")
|
||||||
return
|
return
|
||||||
try:
|
|
||||||
event_payload = json.load(open(GITHUB_EVENT_PATH, 'r'))
|
|
||||||
except json.decoder.JSONDecodeError as e:
|
|
||||||
print(f"Failed to parse JSON: {e}")
|
|
||||||
return
|
|
||||||
OPENAI_KEY = os.environ.get('OPENAI_KEY', None)
|
|
||||||
if not OPENAI_KEY:
|
if not OPENAI_KEY:
|
||||||
print("OPENAI_KEY not set")
|
print("OPENAI_KEY not set")
|
||||||
return
|
return
|
||||||
OPENAI_ORG = os.environ.get('OPENAI_ORG', None)
|
|
||||||
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN', None)
|
|
||||||
if not GITHUB_TOKEN:
|
if not GITHUB_TOKEN:
|
||||||
print("GITHUB_TOKEN not set")
|
print("GITHUB_TOKEN not set")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Set the environment variables in the settings
|
||||||
settings.set("OPENAI.KEY", OPENAI_KEY)
|
settings.set("OPENAI.KEY", OPENAI_KEY)
|
||||||
if OPENAI_ORG:
|
if OPENAI_ORG:
|
||||||
settings.set("OPENAI.ORG", OPENAI_ORG)
|
settings.set("OPENAI.ORG", OPENAI_ORG)
|
||||||
settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
|
settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
|
||||||
settings.set("GITHUB.DEPLOYMENT_TYPE", "user")
|
settings.set("GITHUB.DEPLOYMENT_TYPE", "user")
|
||||||
|
|
||||||
|
# Load the event payload
|
||||||
|
try:
|
||||||
|
with open(GITHUB_EVENT_PATH, 'r') as f:
|
||||||
|
event_payload = json.load(f)
|
||||||
|
except json.decoder.JSONDecodeError as e:
|
||||||
|
print(f"Failed to parse JSON: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle pull request event
|
||||||
if GITHUB_EVENT_NAME == "pull_request":
|
if GITHUB_EVENT_NAME == "pull_request":
|
||||||
action = event_payload.get("action", None)
|
action = event_payload.get("action")
|
||||||
if action in ["opened", "reopened"]:
|
if action in ["opened", "reopened"]:
|
||||||
pr_url = event_payload.get("pull_request", {}).get("url", None)
|
pr_url = event_payload.get("pull_request", {}).get("url")
|
||||||
if pr_url:
|
if pr_url:
|
||||||
await PRReviewer(pr_url).review()
|
await PRReviewer(pr_url).review()
|
||||||
|
|
||||||
|
# Handle issue comment event
|
||||||
elif GITHUB_EVENT_NAME == "issue_comment":
|
elif GITHUB_EVENT_NAME == "issue_comment":
|
||||||
action = event_payload.get("action", None)
|
action = event_payload.get("action")
|
||||||
if action in ["created", "edited"]:
|
if action in ["created", "edited"]:
|
||||||
comment_body = event_payload.get("comment", {}).get("body", None)
|
comment_body = event_payload.get("comment", {}).get("body")
|
||||||
if comment_body:
|
if comment_body:
|
||||||
pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url", None)
|
pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
|
||||||
if pr_url:
|
if pr_url:
|
||||||
body = comment_body.strip().lower()
|
body = comment_body.strip().lower()
|
||||||
await PRAgent().handle_request(pr_url, body)
|
await PRAgent().handle_request(pr_url, body)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
@ -14,11 +15,22 @@ from pr_agent.git_providers.git_provider import get_main_pr_language
|
|||||||
|
|
||||||
class PRDescription:
|
class PRDescription:
|
||||||
def __init__(self, pr_url: str):
|
def __init__(self, pr_url: str):
|
||||||
|
"""
|
||||||
|
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model.
|
||||||
|
Args:
|
||||||
|
pr_url (str): The URL of the pull request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Initialize the git provider and main PR language
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_pr_language = get_main_pr_language(
|
self.main_pr_language = get_main_pr_language(
|
||||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize the AI handler
|
||||||
self.ai_handler = AiHandler()
|
self.ai_handler = AiHandler()
|
||||||
|
|
||||||
|
# Initialize the variables dictionary
|
||||||
self.vars = {
|
self.vars = {
|
||||||
"title": self.git_provider.pr.title,
|
"title": self.git_provider.pr.title,
|
||||||
"branch": self.git_provider.get_pr_branch(),
|
"branch": self.git_provider.get_pr_branch(),
|
||||||
@ -26,20 +38,32 @@ class PRDescription:
|
|||||||
"language": self.main_pr_language,
|
"language": self.main_pr_language,
|
||||||
"diff": "", # empty diff for initial calculation
|
"diff": "", # empty diff for initial calculation
|
||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
|
||||||
|
# Initialize the token handler
|
||||||
|
self.token_handler = TokenHandler(
|
||||||
|
self.git_provider.pr,
|
||||||
self.vars,
|
self.vars,
|
||||||
settings.pr_description_prompt.system,
|
settings.pr_description_prompt.system,
|
||||||
settings.pr_description_prompt.user)
|
settings.pr_description_prompt.user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize patches_diff and prediction attributes
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
|
|
||||||
async def describe(self):
|
async def describe(self):
|
||||||
|
"""
|
||||||
|
Generates a PR description using an AI model and publishes it to the PR.
|
||||||
|
"""
|
||||||
logging.info('Generating a PR description...')
|
logging.info('Generating a PR description...')
|
||||||
if settings.config.publish_output:
|
if settings.config.publish_output:
|
||||||
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
|
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
|
||||||
|
|
||||||
await retry_with_fallback_models(self._prepare_prediction)
|
await retry_with_fallback_models(self._prepare_prediction)
|
||||||
|
|
||||||
logging.info('Preparing answer...')
|
logging.info('Preparing answer...')
|
||||||
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
|
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
|
||||||
|
|
||||||
if settings.config.publish_output:
|
if settings.config.publish_output:
|
||||||
logging.info('Pushing answer...')
|
logging.info('Pushing answer...')
|
||||||
if settings.pr_description.publish_description_as_comment:
|
if settings.pr_description.publish_description_as_comment:
|
||||||
@ -52,45 +76,97 @@ class PRDescription:
|
|||||||
current_labels = []
|
current_labels = []
|
||||||
self.git_provider.publish_labels(pr_types + current_labels)
|
self.git_provider.publish_labels(pr_types + current_labels)
|
||||||
self.git_provider.remove_initial_comment()
|
self.git_provider.remove_initial_comment()
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def _prepare_prediction(self, model: str):
|
async def _prepare_prediction(self, model: str) -> None:
|
||||||
|
"""
|
||||||
|
Prepare the AI prediction for the PR description based on the provided model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): The name of the model to be used for generating the prediction.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Any exceptions raised by the 'get_pr_diff' and '_get_prediction' functions.
|
||||||
|
|
||||||
|
"""
|
||||||
logging.info('Getting PR diff...')
|
logging.info('Getting PR diff...')
|
||||||
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
|
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
|
||||||
logging.info('Getting AI prediction...')
|
logging.info('Getting AI prediction...')
|
||||||
self.prediction = await self._get_prediction(model)
|
self.prediction = await self._get_prediction(model)
|
||||||
|
|
||||||
async def _get_prediction(self, model: str):
|
async def _get_prediction(self, model: str) -> str:
|
||||||
|
"""
|
||||||
|
Generate an AI prediction for the PR description based on the provided model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): The name of the model to be used for generating the prediction.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The generated AI prediction.
|
||||||
|
"""
|
||||||
variables = copy.deepcopy(self.vars)
|
variables = copy.deepcopy(self.vars)
|
||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
|
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables)
|
system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables)
|
user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables)
|
||||||
|
|
||||||
if settings.config.verbosity_level >= 2:
|
if settings.config.verbosity_level >= 2:
|
||||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||||
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
|
||||||
system=system_prompt, user=user_prompt)
|
response, finish_reason = await self.ai_handler.chat_completion(
|
||||||
|
model=model,
|
||||||
|
temperature=0.2,
|
||||||
|
system=system_prompt,
|
||||||
|
user=user_prompt
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _prepare_pr_answer(self):
|
def _prepare_pr_answer(self) -> Tuple[str, str, List[str], str]:
|
||||||
|
"""
|
||||||
|
Prepare the PR description based on the AI prediction data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- title: a string containing the PR title.
|
||||||
|
- pr_body: a string containing the PR body in a markdown format.
|
||||||
|
- pr_types: a list of strings containing the PR types.
|
||||||
|
- markdown_text: a string containing the AI prediction data in a markdown format.
|
||||||
|
"""
|
||||||
|
# Load the AI prediction data into a dictionary
|
||||||
data = json.loads(self.prediction)
|
data = json.loads(self.prediction)
|
||||||
markdown_text = ""
|
|
||||||
|
# Initialization
|
||||||
|
markdown_text = pr_body = ""
|
||||||
|
pr_types = []
|
||||||
|
|
||||||
|
# Iterate over the dictionary items and append the key and value to 'markdown_text' in a markdown format
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
markdown_text += f"## {key}\n\n"
|
markdown_text += f"## {key}\n\n"
|
||||||
markdown_text += f"{value}\n\n"
|
markdown_text += f"{value}\n\n"
|
||||||
pr_body = ""
|
|
||||||
pr_types = []
|
# If the 'PR Type' key is present in the dictionary, split its value by comma and assign it to 'pr_types'
|
||||||
if 'PR Type' in data:
|
if 'PR Type' in data:
|
||||||
pr_types = data['PR Type'].split(',')
|
pr_types = data['PR Type'].split(',')
|
||||||
title = data['PR Title']
|
|
||||||
del data['PR Title']
|
# Assign the value of the 'PR Title' key to 'title' variable and remove it from the dictionary
|
||||||
|
title = data.pop('PR Title')
|
||||||
|
|
||||||
|
# Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
|
||||||
|
# except for the items containing the word 'walkthrough'
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
pr_body += f"{key}:\n"
|
pr_body += f"{key}:\n"
|
||||||
if 'walkthrough' in key.lower():
|
if 'walkthrough' in key.lower():
|
||||||
pr_body += f"{value}\n"
|
pr_body += f"{value}\n"
|
||||||
else:
|
else:
|
||||||
pr_body += f"**{value}**\n\n___\n"
|
pr_body += f"**{value}**\n\n___\n"
|
||||||
|
|
||||||
if settings.config.verbosity_level >= 2:
|
if settings.config.verbosity_level >= 2:
|
||||||
logging.info(f"title:\n{title}\n{pr_body}")
|
logging.info(f"title:\n{title}\n{pr_body}")
|
||||||
|
|
||||||
return title, pr_body, pr_types, markdown_text
|
return title, pr_body, pr_types, markdown_text
|
Reference in New Issue
Block a user