Compare commits

..

10 Commits

Author SHA1 Message Date
d23daf880f Change gitlab API to use oauth_token instead of PAT (PAT shuold work as well) 2023-07-25 13:58:48 +03:00
adb3f17258 Merge pull request #131 from Codium-ai/ok/gitlab_webook
GitLab Webhook Integration and Provider Enhancements
2023-07-24 16:01:17 +03:00
55eb741965 Merge pull request #125 from Codium-ai/tr/code_enhancment
Code Enhancement in PR Agent
2023-07-24 15:37:53 +03:00
cca809e91c run_action 2023-07-24 12:45:24 +03:00
57ff46ecc1 stable 2023-07-24 12:41:00 +03:00
3819d52eb0 Merge remote-tracking branch 'origin/tr/code_enhancment' into tr/code_enhancment 2023-07-24 12:15:17 +03:00
3072325d2c PRDescription 2023-07-24 12:14:53 +03:00
abca2fdcb7 Merge remote-tracking branch 'origin/main' into tr/code_enhancment 2023-07-24 12:04:54 +03:00
4d84f76948 _get_prediction 2023-07-24 11:31:35 +03:00
1bf27c38a7 _prepare_pr_answer 2023-07-24 09:15:45 +03:00
3 changed files with 120 additions and 34 deletions

View File

@ -14,7 +14,6 @@ from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
class GitLabProvider(GitProvider):
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
gitlab_url = settings.get("GITLAB.URL", None)
if not gitlab_url:
@ -23,8 +22,8 @@ class GitLabProvider(GitProvider):
if not gitlab_access_token:
raise ValueError("GitLab personal access token is not set in the config file")
self.gl = gitlab.Gitlab(
gitlab_url,
gitlab_access_token
url=gitlab_url,
oauth_token=gitlab_access_token
)
self.id_project = None
self.id_mr = None

View File

@ -8,46 +8,57 @@ from pr_agent.tools.pr_reviewer import PRReviewer
async def run_action():
GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME', None)
# Get environment variables
GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME')
GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH')
OPENAI_KEY = os.environ.get('OPENAI_KEY')
OPENAI_ORG = os.environ.get('OPENAI_ORG')
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
# Check if required environment variables are set
if not GITHUB_EVENT_NAME:
print("GITHUB_EVENT_NAME not set")
return
GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH', None)
if not GITHUB_EVENT_PATH:
print("GITHUB_EVENT_PATH not set")
return
try:
event_payload = json.load(open(GITHUB_EVENT_PATH, 'r'))
except json.decoder.JSONDecodeError as e:
print(f"Failed to parse JSON: {e}")
return
OPENAI_KEY = os.environ.get('OPENAI_KEY', None)
if not OPENAI_KEY:
print("OPENAI_KEY not set")
return
OPENAI_ORG = os.environ.get('OPENAI_ORG', None)
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN', None)
if not GITHUB_TOKEN:
print("GITHUB_TOKEN not set")
return
# Set the environment variables in the settings
settings.set("OPENAI.KEY", OPENAI_KEY)
if OPENAI_ORG:
settings.set("OPENAI.ORG", OPENAI_ORG)
settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
settings.set("GITHUB.DEPLOYMENT_TYPE", "user")
# Load the event payload
try:
with open(GITHUB_EVENT_PATH, 'r') as f:
event_payload = json.load(f)
except json.decoder.JSONDecodeError as e:
print(f"Failed to parse JSON: {e}")
return
# Handle pull request event
if GITHUB_EVENT_NAME == "pull_request":
action = event_payload.get("action", None)
action = event_payload.get("action")
if action in ["opened", "reopened"]:
pr_url = event_payload.get("pull_request", {}).get("url", None)
pr_url = event_payload.get("pull_request", {}).get("url")
if pr_url:
await PRReviewer(pr_url).review()
# Handle issue comment event
elif GITHUB_EVENT_NAME == "issue_comment":
action = event_payload.get("action", None)
action = event_payload.get("action")
if action in ["created", "edited"]:
comment_body = event_payload.get("comment", {}).get("body", None)
comment_body = event_payload.get("comment", {}).get("body")
if comment_body:
pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url", None)
pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
if pr_url:
body = comment_body.strip().lower()
await PRAgent().handle_request(pr_url, body)

View File

@ -1,6 +1,7 @@
import copy
import json
import logging
from typing import Tuple, List
from jinja2 import Environment, StrictUndefined
@ -14,11 +15,22 @@ from pr_agent.git_providers.git_provider import get_main_pr_language
class PRDescription:
def __init__(self, pr_url: str):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model.
Args:
pr_url (str): The URL of the pull request.
"""
# Initialize the git provider and main PR language
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
# Initialize the AI handler
self.ai_handler = AiHandler()
# Initialize the variables dictionary
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
@ -26,20 +38,32 @@ class PRDescription:
"language": self.main_pr_language,
"diff": "", # empty diff for initial calculation
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_description_prompt.system,
settings.pr_description_prompt.user)
# Initialize the token handler
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
settings.pr_description_prompt.system,
settings.pr_description_prompt.user,
)
# Initialize patches_diff and prediction attributes
self.patches_diff = None
self.prediction = None
async def describe(self):
"""
Generates a PR description using an AI model and publishes it to the PR.
"""
logging.info('Generating a PR description...')
if settings.config.publish_output:
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing answer...')
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
if settings.config.publish_output:
logging.info('Pushing answer...')
if settings.pr_description.publish_description_as_comment:
@ -52,45 +76,97 @@ class PRDescription:
current_labels = []
self.git_provider.publish_labels(pr_types + current_labels)
self.git_provider.remove_initial_comment()
return ""
async def _prepare_prediction(self, model: str):
async def _prepare_prediction(self, model: str) -> None:
"""
Prepare the AI prediction for the PR description based on the provided model.
Args:
model (str): The name of the model to be used for generating the prediction.
Returns:
None
Raises:
Any exceptions raised by the 'get_pr_diff' and '_get_prediction' functions.
"""
logging.info('Getting PR diff...')
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
logging.info('Getting AI prediction...')
self.prediction = await self._get_prediction(model)
async def _get_prediction(self, model: str):
async def _get_prediction(self, model: str) -> str:
"""
Generate an AI prediction for the PR description based on the provided model.
Args:
model (str): The name of the model to be used for generating the prediction.
Returns:
str: The generated AI prediction.
"""
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=0.2,
system=system_prompt,
user=user_prompt
)
return response
def _prepare_pr_answer(self):
def _prepare_pr_answer(self) -> Tuple[str, str, List[str], str]:
"""
Prepare the PR description based on the AI prediction data.
Returns:
- title: a string containing the PR title.
- pr_body: a string containing the PR body in a markdown format.
- pr_types: a list of strings containing the PR types.
- markdown_text: a string containing the AI prediction data in a markdown format.
"""
# Load the AI prediction data into a dictionary
data = json.loads(self.prediction)
markdown_text = ""
# Initialization
markdown_text = pr_body = ""
pr_types = []
# Iterate over the dictionary items and append the key and value to 'markdown_text' in a markdown format
for key, value in data.items():
markdown_text += f"## {key}\n\n"
markdown_text += f"{value}\n\n"
pr_body = ""
pr_types = []
# If the 'PR Type' key is present in the dictionary, split its value by comma and assign it to 'pr_types'
if 'PR Type' in data:
pr_types = data['PR Type'].split(',')
title = data['PR Title']
del data['PR Title']
# Assign the value of the 'PR Title' key to 'title' variable and remove it from the dictionary
title = data.pop('PR Title')
# Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
# except for the items containing the word 'walkthrough'
for key, value in data.items():
pr_body += f"{key}:\n"
if 'walkthrough' in key.lower():
pr_body += f"{value}\n"
else:
pr_body += f"**{value}**\n\n___\n"
if settings.config.verbosity_level >= 2:
logging.info(f"title:\n{title}\n{pr_body}")
return title, pr_body, pr_types, markdown_text