Compare commits


25 Commits

SHA1 Message Date
8ae5faca53 Fix cyclic dependency 2023-07-25 16:52:18 +03:00
1229fba346 + settings.github.ratelimit_retries setup in configuration.toml 2023-07-25 16:37:13 +03:00
f6036e936e + settings.github.ratelimit_retries setup in configuration.toml 2023-07-25 15:23:40 +03:00
3b334805ee still need GithubException.RateLimitExceededException in pr_processing.py for correct exception catch 2023-07-25 15:14:56 +03:00
b6f6c903a0 moved @retry to github_provider.py and fetch number of retries from settings 2023-07-25 15:12:02 +03:00
55637a5620 added retry decorator similar to used in ai_handler following @okotek suggestion 2023-07-25 14:42:54 +03:00
404cc0a00e small change to show message and fail 2023-07-25 14:20:20 +03:00
d1a8a610e9 Revert "show how much time until rate limit reset"
This reverts commit 8f482cd41a.
2023-07-25 13:38:55 +03:00
8f482cd41a show how much time until rate limit reset
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2023-07-25 13:23:19 +03:00
34096059ff quick and dirty response for github API ratelimit, until some smart solution will be implemented 2023-07-25 13:05:56 +03:00
adb3f17258 Merge pull request #131 from Codium-ai/ok/gitlab_webook
GitLab Webhook Integration and Provider Enhancements
2023-07-24 16:01:17 +03:00
2c03a67312 Add labels 2023-07-24 16:00:51 +03:00
55eb741965 Merge pull request #125 from Codium-ai/tr/code_enhancment
Code Enhancement in PR Agent
2023-07-24 15:37:53 +03:00
c9c95d60d4 Implement gitlab webhook 2023-07-24 15:05:24 +03:00
cca809e91c run_action 2023-07-24 12:45:24 +03:00
57ff46ecc1 stable 2023-07-24 12:41:00 +03:00
3819d52eb0 Merge remote-tracking branch 'origin/tr/code_enhancment' into tr/code_enhancment 2023-07-24 12:15:17 +03:00
3072325d2c PRDescription 2023-07-24 12:14:53 +03:00
abca2fdcb7 Merge remote-tracking branch 'origin/main' into tr/code_enhancment 2023-07-24 12:04:54 +03:00
4d84f76948 _get_prediction 2023-07-24 11:31:35 +03:00
dd8f6eb923 Merge pull request #126 from Codium-ai/ok/preserve_labels
Add functionality to preserve existing labels in PRs
2023-07-24 10:22:51 +03:00
b9c25e487a On /describe, preserve the current labels 2023-07-24 10:17:26 +03:00
1bf27c38a7 _prepare_pr_answer 2023-07-24 09:15:45 +03:00
1f987380ed Merge pull request #124 from Xyand/bugfix/mising_model
Bugfix - missing function argument
2023-07-24 07:36:21 +03:00
cd8bbbf889 bugfix 2023-07-24 00:58:21 +03:00
12 changed files with 243 additions and 75 deletions

View File

@@ -1,11 +0,0 @@
bot-review:
stage: test
variables:
MR_URL: ${CI_MERGE_REQUEST_PROJECT_URL}/-/merge_requests/${CI_MERGE_REQUEST_IID}
image: docker:latest
services:
- docker:19-dind
script:
- docker run --rm -e OPENAI.KEY=${OPEN_API_KEY} -e OPENAI.ORG=${OPEN_API_ORG} -e GITLAB.PERSONAL_ACCESS_TOKEN=${GITLAB_PAT} -e CONFIG.GIT_PROVIDER=gitlab codiumai/pr-agent --pr_url ${MR_URL} describe
rules:
- if: $CI_COMMIT_BRANCH != $CI_DEFAULT_BRANCH

View File

@@ -83,6 +83,7 @@ CodiumAI `PR-Agent` is an open-source tool aiming to help developers review pull
| | Reflect and Review | :white_check_mark: | | |
| | | | | |
| USAGE | CLI | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| | App / webhook | :white_check_mark: | :white_check_mark: | |
| | Tagging bot | :white_check_mark: | | |
| | Actions | :white_check_mark: | | |
| | | | | |

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
import logging
from typing import Tuple, Union, Callable, List
from github import RateLimitExceededException
from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
@@ -19,7 +21,6 @@ OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
PATCH_EXTRA_LINES = 3
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
"""
@@ -40,7 +41,11 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
global PATCH_EXTRA_LINES
PATCH_EXTRA_LINES = 0
diff_files = list(git_provider.get_diff_files())
try:
diff_files = list(git_provider.get_diff_files())
except RateLimitExceededException as e:
logging.error(f"Rate limit exceeded for git provider API. original message {e}")
raise
# get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
@@ -55,7 +60,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
# if we are over the limit, start pruning
patches_compressed, modified_file_names, deleted_file_names = \
pr_generate_compressed_diff(pr_languages, token_handler, add_line_numbers_to_hunks)
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)
final_diff = "\n".join(patches_compressed)
if modified_file_names:
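
The change above wraps the diff fetch in a try/except that logs the provider's rate-limit error and re-raises it for the caller to handle. A minimal sketch of that call-site pattern in isolation, assuming a hypothetical `fetch_diff_files` callable standing in for `git_provider.get_diff_files` (the exception class is the same PyGithub import used in the diff):

```python
import logging

from github import RateLimitExceededException


def get_diff_with_logging(fetch_diff_files):
    """Call a diff-fetching callable and surface rate-limit errors loudly.

    `fetch_diff_files` is a hypothetical zero-argument callable used here
    for illustration only.
    """
    try:
        return list(fetch_diff_files())
    except RateLimitExceededException as e:
        # Log the original message, then re-raise so an outer retry
        # (or the caller) can decide what to do next.
        logging.error(f"Rate limit exceeded for git provider API. Original message: {e}")
        raise
```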

View File

@@ -27,7 +27,7 @@ class BitbucketProvider:
self.set_pr(pr_url)
def is_supported(self, capability: str) -> bool:
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments']:
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels']:
return False
return True
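
Adding `get_labels` to the unsupported capabilities lets callers degrade gracefully on Bitbucket. A short sketch of how a caller might guard on `is_supported`, assuming a provider object with the same method names shown in these diffs (the `merge_labels` helper is a made-up name for illustration):

```python
from typing import List, Protocol


class LabelCapableProvider(Protocol):
    # Structural type matching the provider methods shown in the diffs.
    def is_supported(self, capability: str) -> bool: ...
    def get_labels(self) -> List[str]: ...
    def publish_labels(self, labels: List[str]) -> None: ...


def merge_labels(provider: LabelCapableProvider, new_labels: List[str]) -> None:
    """Publish new labels, preserving existing ones when the provider allows it."""
    if provider.is_supported("get_labels"):
        current = provider.get_labels() or []
    else:
        # e.g. the Bitbucket provider above reports 'get_labels' as unsupported.
        current = []
    provider.publish_labels(new_labels + current)
```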

View File

@@ -60,6 +60,10 @@ class GitProvider(ABC):
def publish_labels(self, labels):
pass
@abstractmethod
def get_labels(self):
pass
@abstractmethod
def remove_initial_comment(self):
pass
@@ -132,3 +136,4 @@ class IncrementalPR:
self.commits_range = None
self.first_new_commit_sha = None
self.last_seen_commit_sha = None
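
Declaring `get_labels` as an `@abstractmethod` forces every concrete provider to either implement it or advertise the gap via `is_supported`. A compressed sketch of that base-class shape, with only the label-related methods and an illustrative in-memory provider (not one of the real providers):

```python
from abc import ABC, abstractmethod


class GitProvider(ABC):
    # Only the label-related surface is sketched here.
    @abstractmethod
    def publish_labels(self, labels):
        pass

    @abstractmethod
    def get_labels(self):
        pass


class InMemoryProvider(GitProvider):
    """Illustrative provider that keeps labels in a plain list."""

    def __init__(self):
        self._labels = []

    def publish_labels(self, labels):
        self._labels = list(dict.fromkeys(labels))  # de-duplicate, keep order

    def get_labels(self):
        return list(self._labels)


provider = InMemoryProvider()
provider.publish_labels(["Bug fix", "Enhancement", "Bug fix"])
print(provider.get_labels())  # ['Bug fix', 'Enhancement']
```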

View File

@@ -3,13 +3,15 @@ from datetime import datetime
from typing import Optional, Tuple
from urllib.parse import urlparse
from github import AppAuthentication, Github, Auth
from github import AppAuthentication, Auth, Github, GithubException
from retry import retry
from pr_agent.config_loader import settings
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from ..servers.utils import RateLimitExceeded
class GithubProvider(GitProvider):
@@ -78,27 +80,34 @@ class GithubProvider(GitProvider):
return self.file_set.values()
return self.pr.get_files()
@retry(exceptions=RateLimitExceeded,
tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
def get_diff_files(self) -> list[FilePatchInfo]:
files = self.get_files()
diff_files = []
for file in files:
if is_valid_file(file.filename):
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
patch = file.patch
if self.incremental.is_incremental and self.file_set:
original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha)
patch = load_large_diff(file,
new_file_content_str,
original_file_content_str,
None)
self.file_set[file.filename] = patch
else:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
try:
files = self.get_files()
diff_files = []
for file in files:
if is_valid_file(file.filename):
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
patch = file.patch
if self.incremental.is_incremental and self.file_set:
original_file_content_str = self._get_pr_file_content(file,
self.incremental.last_seen_commit_sha)
patch = load_large_diff(file,
new_file_content_str,
original_file_content_str,
None)
self.file_set[file.filename] = patch
else:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
diff_files.append(
FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename))
self.diff_files = diff_files
return diff_files
diff_files.append(
FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename))
self.diff_files = diff_files
return diff_files
except GithubException.RateLimitExceededException as e:
logging.error(f"Rate limit exceeded for GitHub API. Original message: {e}")
raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e
def publish_description(self, pr_title: str, pr_body: str):
self.pr.edit(title=pr_title, body=pr_body)
@@ -322,5 +331,12 @@ class GithubProvider(GitProvider):
headers, data = self.pr._requester.requestJsonAndCheck(
"PUT", f"{self.pr.issue_url}/labels", input=post_parameters
)
except:
logging.exception("Failed to publish labels")
except Exception as e:
logging.exception(f"Failed to publish labels, error: {e}")
def get_labels(self):
try:
return [label.name for label in self.pr.labels]
except Exception as e:
logging.exception(f"Failed to get labels, error: {e}")
return []
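
Two ideas are combined in the `get_diff_files` change: the `retry` decorator (from the `retry` package) re-invokes the method with exponential backoff whenever the internal `RateLimitExceeded` is raised, and the `except` block translates PyGithub's rate-limit error into that internal type so the decorator can see it. Below is a minimal sketch of the same shape, with a fake API object standing in for the GitHub calls and a hard-coded retry count in place of `settings.github.ratelimit_retries`:

```python
import logging

from retry import retry


class RateLimitExceeded(Exception):
    """Internal marker exception, mirroring pr_agent.servers.utils.RateLimitExceeded."""


class FakeRateLimitedAPI:
    """Stand-in for the GitHub API: fails twice, then succeeds."""

    def __init__(self):
        self.calls = 0

    def fetch(self):
        self.calls += 1
        if self.calls < 3:
            raise ConnectionError("API rate limit exceeded")  # stand-in for PyGithub's error
        return ["diff-1", "diff-2"]


api = FakeRateLimitedAPI()


# delay/backoff/jitter mirror the decorator in the diff; tries=5 stands in for the setting.
@retry(exceptions=RateLimitExceeded, tries=5, delay=2, backoff=2, jitter=(1, 3))
def get_diff_files():
    try:
        return api.fetch()
    except ConnectionError as e:
        # Translate the provider-specific error into the internal exception
        # so the @retry decorator above catches it and retries.
        logging.error(f"Rate limit exceeded for GitHub API. Original message: {e}")
        raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e


print(get_diff_files())  # succeeds on the third attempt
```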

View File

@@ -8,11 +8,13 @@ from gitlab import GitlabGetError
from pr_agent.config_loader import settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
from ..algo.language_handler import is_valid_file
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
class GitLabProvider(GitProvider):
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
gitlab_url = settings.get("GITLAB.URL", None)
if not gitlab_url:
@@ -112,7 +114,7 @@ class GitLabProvider(GitProvider):
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
raise NotImplementedError("Gitlab provider does not support creating inline comments yet")
def create_inline_comment(self, comments: list[dict]):
def create_inline_comments(self, comments: list[dict]):
raise NotImplementedError("Gitlab provider does not support publishing inline comments yet")
def send_inline_comment(self, body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
@@ -258,8 +260,15 @@ class GitLabProvider(GitProvider):
def get_user_id(self):
return None
def publish_labels(self, labels):
pass
def publish_labels(self, pr_types):
try:
self.mr.labels = list(set(pr_types))
self.mr.save()
except Exception as e:
logging.exception(f"Failed to publish labels, error: {e}")
def publish_inline_comments(self, comments: list[dict]):
pass
pass
def get_labels(self):
return self.mr.labels
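
On GitLab, labels live directly on the merge-request object, so `publish_labels` simply assigns and calls `save()`. A hedged sketch of the same flow using python-gitlab directly, outside the provider class (the URL, token, project path, and MR number are placeholders):

```python
import gitlab

# Placeholders: substitute a real GitLab URL, token, project path, and MR IID.
gl = gitlab.Gitlab("https://gitlab.com", private_token="<PERSONAL_ACCESS_TOKEN>")
project = gl.projects.get("example-group/example-project")
mr = project.mergerequests.get(1)

# Combine new types with the labels already on the MR (as the /describe flow does),
# de-duplicate with a set, then persist -- the same assign-and-save pattern as publish_labels.
new_types = ["Enhancement", "Tests"]
mr.labels = list(set(list(mr.labels) + new_types))
mr.save()

print(mr.labels)
```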

View File

@@ -8,50 +8,61 @@ from pr_agent.tools.pr_reviewer import PRReviewer
async def run_action():
GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME', None)
# Get environment variables
GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME')
GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH')
OPENAI_KEY = os.environ.get('OPENAI_KEY')
OPENAI_ORG = os.environ.get('OPENAI_ORG')
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
# Check if required environment variables are set
if not GITHUB_EVENT_NAME:
print("GITHUB_EVENT_NAME not set")
return
GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH', None)
if not GITHUB_EVENT_PATH:
print("GITHUB_EVENT_PATH not set")
return
try:
event_payload = json.load(open(GITHUB_EVENT_PATH, 'r'))
except json.decoder.JSONDecodeError as e:
print(f"Failed to parse JSON: {e}")
return
OPENAI_KEY = os.environ.get('OPENAI_KEY', None)
if not OPENAI_KEY:
print("OPENAI_KEY not set")
return
OPENAI_ORG = os.environ.get('OPENAI_ORG', None)
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN', None)
if not GITHUB_TOKEN:
print("GITHUB_TOKEN not set")
return
# Set the environment variables in the settings
settings.set("OPENAI.KEY", OPENAI_KEY)
if OPENAI_ORG:
settings.set("OPENAI.ORG", OPENAI_ORG)
settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
settings.set("GITHUB.DEPLOYMENT_TYPE", "user")
# Load the event payload
try:
with open(GITHUB_EVENT_PATH, 'r') as f:
event_payload = json.load(f)
except json.decoder.JSONDecodeError as e:
print(f"Failed to parse JSON: {e}")
return
# Handle pull request event
if GITHUB_EVENT_NAME == "pull_request":
action = event_payload.get("action", None)
action = event_payload.get("action")
if action in ["opened", "reopened"]:
pr_url = event_payload.get("pull_request", {}).get("url", None)
pr_url = event_payload.get("pull_request", {}).get("url")
if pr_url:
await PRReviewer(pr_url).review()
# Handle issue comment event
elif GITHUB_EVENT_NAME == "issue_comment":
action = event_payload.get("action", None)
action = event_payload.get("action")
if action in ["created", "edited"]:
comment_body = event_payload.get("comment", {}).get("body", None)
comment_body = event_payload.get("comment", {}).get("body")
if comment_body:
pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url", None)
pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
if pr_url:
body = comment_body.strip().lower()
await PRAgent().handle_request(pr_url, body)
if __name__ == '__main__':
asyncio.run(run_action())
asyncio.run(run_action())
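
The reorganized `run_action` mostly reads GitHub Actions environment variables, loads the JSON event payload, and dispatches on the event name. A stripped-down, synchronous sketch of that dispatch skeleton, with print statements standing in for the real `PRReviewer`/`PRAgent` calls:

```python
import json
import os


def run_action_skeleton():
    # Required environment variables provided by the Actions runtime.
    event_name = os.environ.get("GITHUB_EVENT_NAME")
    event_path = os.environ.get("GITHUB_EVENT_PATH")
    if not event_name or not event_path:
        print("GITHUB_EVENT_NAME / GITHUB_EVENT_PATH not set")
        return

    # Load the event payload written by the runner.
    try:
        with open(event_path, "r") as f:
            event_payload = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        print(f"Failed to read event payload: {e}")
        return

    action = event_payload.get("action")
    if event_name == "pull_request" and action in ["opened", "reopened"]:
        pr_url = event_payload.get("pull_request", {}).get("url")
        print(f"would review {pr_url}")  # PRReviewer(pr_url).review() in the real code
    elif event_name == "issue_comment" and action in ["created", "edited"]:
        comment = event_payload.get("comment", {}).get("body", "")
        pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
        if comment and pr_url:
            # PRAgent().handle_request(pr_url, body) in the real code
            print(f"would handle '{comment.strip().lower()}' on {pr_url}")


if __name__ == "__main__":
    run_action_skeleton()
```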

View File

@@ -0,0 +1,47 @@
import logging
import uvicorn
from fastapi import APIRouter, FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from starlette.background import BackgroundTasks
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings
app = FastAPI()
router = APIRouter()
@router.post("/webhook")
async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
data = await request.json()
if data.get('object_kind') == 'merge_request' and data['object_attributes'].get('action') in ['open', 'reopen']:
logging.info(f"A merge request has been opened: {data['object_attributes'].get('title')}")
url = data['object_attributes'].get('url')
background_tasks.add_task(PRAgent().handle_request, url, "/review")
elif data.get('object_kind') == 'note' and data['event_type'] == 'note':
if 'merge_request' in data:
mr = data['merge_request']
url = mr.get('url')
body = data.get('object_attributes', {}).get('note')
background_tasks.add_task(PRAgent().handle_request, url, body)
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
def start():
gitlab_url = settings.get("GITLAB.URL", None)
if not gitlab_url:
raise ValueError("GITLAB.URL is not set")
gitlab_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_token:
raise ValueError("GITLAB.PERSONAL_ACCESS_TOKEN is not set")
settings.config.git_provider = "gitlab"
app = FastAPI()
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=3000)
if __name__ == '__main__':
start()
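
The new webhook server reacts to `merge_request` open/reopen events and to `note` events. One way to exercise it locally is to POST a minimal payload of the same shape; the sketch below assumes the server is running on localhost:3000 as in `start()`, and uses only the fields the handler reads (the URL and title values are placeholders):

```python
import requests

# Minimal merge_request payload containing only the fields the /webhook handler reads.
payload = {
    "object_kind": "merge_request",
    "object_attributes": {
        "action": "open",
        "title": "Example MR",
        "url": "https://gitlab.example.com/group/project/-/merge_requests/1",
    },
}

resp = requests.post("http://localhost:3000/webhook", json=payload, timeout=10)
print(resp.status_code, resp.json())  # expect 200 and {"message": "success"}
```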

View File

@@ -21,3 +21,7 @@ def verify_signature(payload_body, secret_token, signature_header):
if not hmac.compare_digest(expected_signature, signature_header):
raise HTTPException(status_code=403, detail="Request signatures didn't match!")
class RateLimitExceeded(Exception):
"""Raised when the git provider API rate limit has been exceeded."""
pass
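
The context lines show the tail of `verify_signature`, which compares an expected HMAC digest against the incoming signature header and rejects mismatches with a 403. A self-contained sketch of that comparison, assuming GitHub's usual `sha256=<hexdigest>` header format (the `HTTPException` is plain FastAPI usage; the secret and payload are made-up values):

```python
import hashlib
import hmac

from fastapi import HTTPException


def verify_signature(payload_body: bytes, secret_token: str, signature_header: str) -> None:
    """Raise HTTP 403 unless signature_header matches the HMAC of the payload."""
    expected_signature = "sha256=" + hmac.new(
        secret_token.encode("utf-8"), payload_body, hashlib.sha256
    ).hexdigest()
    # compare_digest avoids timing side channels when checking the signatures.
    if not hmac.compare_digest(expected_signature, signature_header):
        raise HTTPException(status_code=403, detail="Request signatures didn't match!")


body = b'{"action": "opened"}'
secret = "example-webhook-secret"
good_header = "sha256=" + hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
verify_signature(body, secret, good_header)  # passes silently
```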

View File

@@ -27,6 +27,7 @@ num_code_suggestions=4
[github]
# The type of deployment to create. Valid values are 'app' or 'user'.
deployment_type = "user"
ratelimit_retries = 5
[gitlab]
# URL to the gitlab service
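
At runtime this value is read as `settings.github.ratelimit_retries` and fed to the `tries=` argument of the retry decorator in the GitHub provider. A tiny illustration of the TOML shape being parsed, using the standard-library `tomllib` rather than the project's own settings loader:

```python
import tomllib  # Python 3.11+; use the third-party 'tomli' package on older versions

TOML_TEXT = """
[github]
deployment_type = "user"
ratelimit_retries = 5
"""

config = tomllib.loads(TOML_TEXT)
retries = config["github"]["ratelimit_retries"]
print(retries)  # 5 -- this is the number passed as tries= to the @retry decorator
```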

View File

@@ -1,6 +1,7 @@
import copy
import json
import logging
from typing import Tuple, List
from jinja2 import Environment, StrictUndefined
@@ -14,11 +15,22 @@ from pr_agent.git_providers.git_provider import get_main_pr_language
class PRDescription:
def __init__(self, pr_url: str):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model.
Args:
pr_url (str): The URL of the pull request.
"""
# Initialize the git provider and main PR language
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
# Initialize the AI handler
self.ai_handler = AiHandler()
# Initialize the variables dictionary
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
@@ -26,67 +38,135 @@ class PRDescription:
"language": self.main_pr_language,
"diff": "", # empty diff for initial calculation
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_description_prompt.system,
settings.pr_description_prompt.user)
# Initialize the token handler
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
settings.pr_description_prompt.system,
settings.pr_description_prompt.user,
)
# Initialize patches_diff and prediction attributes
self.patches_diff = None
self.prediction = None
async def describe(self):
"""
Generates a PR description using an AI model and publishes it to the PR.
"""
logging.info('Generating a PR description...')
if settings.config.publish_output:
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing answer...')
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
if settings.config.publish_output:
logging.info('Pushing answer...')
if settings.pr_description.publish_description_as_comment:
self.git_provider.publish_comment(markdown_text)
else:
self.git_provider.publish_description(pr_title, pr_body)
self.git_provider.publish_labels(pr_types)
if self.git_provider.is_supported("get_labels"):
current_labels = self.git_provider.get_labels()
if current_labels is None:
current_labels = []
self.git_provider.publish_labels(pr_types + current_labels)
self.git_provider.remove_initial_comment()
return ""
async def _prepare_prediction(self, model: str):
async def _prepare_prediction(self, model: str) -> None:
"""
Prepare the AI prediction for the PR description based on the provided model.
Args:
model (str): The name of the model to be used for generating the prediction.
Returns:
None
Raises:
Any exceptions raised by the 'get_pr_diff' and '_get_prediction' functions.
"""
logging.info('Getting PR diff...')
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
logging.info('Getting AI prediction...')
self.prediction = await self._get_prediction(model)
async def _get_prediction(self, model: str):
async def _get_prediction(self, model: str) -> str:
"""
Generate an AI prediction for the PR description based on the provided model.
Args:
model (str): The name of the model to be used for generating the prediction.
Returns:
str: The generated AI prediction.
"""
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=0.2,
system=system_prompt,
user=user_prompt
)
return response
def _prepare_pr_answer(self):
def _prepare_pr_answer(self) -> Tuple[str, str, List[str], str]:
"""
Prepare the PR description based on the AI prediction data.
Returns:
- title: a string containing the PR title.
- pr_body: a string containing the PR body in a markdown format.
- pr_types: a list of strings containing the PR types.
- markdown_text: a string containing the AI prediction data in a markdown format.
"""
# Load the AI prediction data into a dictionary
data = json.loads(self.prediction)
markdown_text = ""
# Initialization
markdown_text = pr_body = ""
pr_types = []
# Iterate over the dictionary items and append the key and value to 'markdown_text' in a markdown format
for key, value in data.items():
markdown_text += f"## {key}\n\n"
markdown_text += f"{value}\n\n"
pr_body = ""
pr_types = []
# If the 'PR Type' key is present in the dictionary, split its value by comma and assign it to 'pr_types'
if 'PR Type' in data:
pr_types = data['PR Type'].split(',')
title = data['PR Title']
del data['PR Title']
# Assign the value of the 'PR Title' key to 'title' variable and remove it from the dictionary
title = data.pop('PR Title')
# Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
# except for the items containing the word 'walkthrough'
for key, value in data.items():
pr_body += f"{key}:\n"
if 'walkthrough' in key.lower():
pr_body += f"{value}\n"
else:
pr_body += f"**{value}**\n\n___\n"
if settings.config.verbosity_level >= 2:
logging.info(f"title:\n{title}\n{pr_body}")
return title, pr_body, pr_types, markdown_text
return title, pr_body, pr_types, markdown_text
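
`_prepare_pr_answer` turns the model's JSON prediction into a title, a markdown body, and a list of PR types. A self-contained sketch of that parsing, run against a made-up prediction string (the key names 'PR Title' and 'PR Type' and the walkthrough handling follow the code above; the sample values are illustrative):

```python
import json
from typing import List, Tuple

prediction = json.dumps({
    "PR Title": "Add retry handling for GitHub rate limits",
    "PR Type": "Enhancement, Bug fix",
    "PR Description": "Wraps GitHub calls with a retry decorator.",
    "PR Main Files Walkthrough": "github_provider.py: added @retry to get_diff_files",
})


def prepare_pr_answer(raw_prediction: str) -> Tuple[str, str, List[str], str]:
    data = json.loads(raw_prediction)

    # Full markdown dump of every key/value pair.
    markdown_text = ""
    for key, value in data.items():
        markdown_text += f"## {key}\n\n{value}\n\n"

    # PR types arrive as a comma-separated string (items keep their surrounding spaces,
    # as in the original code).
    pr_types = data["PR Type"].split(",") if "PR Type" in data else []

    # The title is popped out of the dict; the remaining keys become the body.
    title = data.pop("PR Title")
    pr_body = ""
    for key, value in data.items():
        pr_body += f"{key}:\n"
        if "walkthrough" in key.lower():
            pr_body += f"{value}\n"
        else:
            pr_body += f"**{value}**\n\n___\n"
    return title, pr_body, pr_types, markdown_text


title, body, types, md = prepare_pr_answer(prediction)
print(title)
print(types)
```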