Compare commits

..

19 Commits

SHA1 Message Date
8ae5faca53 Fix cyclic dependency 2023-07-25 16:52:18 +03:00
1229fba346 + settings.github.ratelimit_retries setup in configuration.toml 2023-07-25 16:37:13 +03:00
f6036e936e + settings.github.ratelimit_retries setup in configuration.toml 2023-07-25 15:23:40 +03:00
3b334805ee still need GithubException.RateLimitExceededException in pr_processing.py for correct exception catch 2023-07-25 15:14:56 +03:00
b6f6c903a0 moved @retry to github_provider.py and fetch number of retries from settings 2023-07-25 15:12:02 +03:00
55637a5620 added retry decorator similar to used in ai_handler following @okotek suggestion 2023-07-25 14:42:54 +03:00
404cc0a00e small change to show message and fail 2023-07-25 14:20:20 +03:00
d1a8a610e9 Revert "show how much time until rate limit reset"
This reverts commit 8f482cd41a.
2023-07-25 13:38:55 +03:00
8f482cd41a show how much time until rate limit reset
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2023-07-25 13:23:19 +03:00
34096059ff quick and dirty response for github API ratelimit, until some smart solution will be implemented 2023-07-25 13:05:56 +03:00
adb3f17258 Merge pull request #131 from Codium-ai/ok/gitlab_webook
GitLab Webhook Integration and Provider Enhancements
2023-07-24 16:01:17 +03:00
55eb741965 Merge pull request #125 from Codium-ai/tr/code_enhancment
Code Enhancement in PR Agent
2023-07-24 15:37:53 +03:00
cca809e91c run_action 2023-07-24 12:45:24 +03:00
57ff46ecc1 stable 2023-07-24 12:41:00 +03:00
3819d52eb0 Merge remote-tracking branch 'origin/tr/code_enhancment' into tr/code_enhancment 2023-07-24 12:15:17 +03:00
3072325d2c PRDescription 2023-07-24 12:14:53 +03:00
abca2fdcb7 Merge remote-tracking branch 'origin/main' into tr/code_enhancment 2023-07-24 12:04:54 +03:00
4d84f76948 _get_prediction 2023-07-24 11:31:35 +03:00
1bf27c38a7 _prepare_pr_answer 2023-07-24 09:15:45 +03:00
7 changed files with 161 additions and 54 deletions

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
 import logging
 from typing import Tuple, Union, Callable, List
 
+from github import RateLimitExceededException
+
 from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
 from pr_agent.algo.language_handler import sort_files_by_main_languages

@@ -19,7 +21,6 @@ OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000
 OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
 PATCH_EXTRA_LINES = 3
 
 def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
                 add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
     """

@@ -40,7 +41,11 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
     global PATCH_EXTRA_LINES
     PATCH_EXTRA_LINES = 0
 
-    diff_files = list(git_provider.get_diff_files())
+    try:
+        diff_files = list(git_provider.get_diff_files())
+    except RateLimitExceededException as e:
+        logging.error(f"Rate limit exceeded for git provider API. original message {e}")
+        raise
 
     # get pr languages
     pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
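
For reference, the RateLimitExceededException caught above comes from PyGithub, which also exposes the current quota via get_rate_limit(). The sketch below is illustrative only (it assumes a GITHUB_TOKEN environment variable) and shows how a caller could compute the time left until the limit resets, in the spirit of the reverted "show how much time until rate limit reset" commit:

import os
from datetime import datetime, timezone

from github import Github

def seconds_until_rate_limit_reset(token: str) -> float:
    # Query PyGithub for the core rate limit and return the seconds until it resets.
    core = Github(token).get_rate_limit().core
    reset_at = core.reset if core.reset.tzinfo else core.reset.replace(tzinfo=timezone.utc)
    return max(0.0, (reset_at - datetime.now(timezone.utc)).total_seconds())

if __name__ == "__main__":
    print(f"Rate limit resets in {seconds_until_rate_limit_reset(os.environ['GITHUB_TOKEN']):.0f}s")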

View File

@@ -136,3 +136,4 @@ class IncrementalPR:
         self.commits_range = None
         self.first_new_commit_sha = None
         self.last_seen_commit_sha = None
+

View File

@@ -3,13 +3,15 @@ from datetime import datetime
 from typing import Optional, Tuple
 from urllib.parse import urlparse
 
-from github import AppAuthentication, Github, Auth
+from github import AppAuthentication, Auth, Github, GithubException
+from retry import retry
 
 from pr_agent.config_loader import settings
-from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
 from ..algo.language_handler import is_valid_file
 from ..algo.utils import load_large_diff
+from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
+from ..servers.utils import RateLimitExceeded
 
 
 class GithubProvider(GitProvider):
@@ -78,27 +80,34 @@ class GithubProvider(GitProvider):
             return self.file_set.values()
         return self.pr.get_files()
 
+    @retry(exceptions=RateLimitExceeded,
+           tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
     def get_diff_files(self) -> list[FilePatchInfo]:
-        files = self.get_files()
-        diff_files = []
-        for file in files:
-            if is_valid_file(file.filename):
-                new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
-                patch = file.patch
-                if self.incremental.is_incremental and self.file_set:
-                    original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha)
-                    patch = load_large_diff(file,
-                                            new_file_content_str,
-                                            original_file_content_str,
-                                            None)
-                    self.file_set[file.filename] = patch
-                else:
-                    original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
-                diff_files.append(
-                    FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename))
-        self.diff_files = diff_files
-        return diff_files
+        try:
+            files = self.get_files()
+            diff_files = []
+            for file in files:
+                if is_valid_file(file.filename):
+                    new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
+                    patch = file.patch
+                    if self.incremental.is_incremental and self.file_set:
+                        original_file_content_str = self._get_pr_file_content(file,
+                                                                              self.incremental.last_seen_commit_sha)
+                        patch = load_large_diff(file,
+                                                new_file_content_str,
+                                                original_file_content_str,
+                                                None)
+                        self.file_set[file.filename] = patch
+                    else:
+                        original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
+                    diff_files.append(
+                        FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename))
+            self.diff_files = diff_files
+            return diff_files
+        except GithubException.RateLimitExceededException as e:
+            logging.error(f"Rate limit exceeded for GitHub API. Original message: {e}")
+            raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e
 
     def publish_description(self, pr_title: str, pr_body: str):
         self.pr.edit(title=pr_title, body=pr_body)
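
The @retry decorator above comes from the third-party retry package: when the wrapped call raises RateLimitExceeded, it is re-invoked up to `tries` times, waiting delay * backoff**n seconds plus a random jitter between attempts. A minimal, self-contained sketch of that behaviour; the flaky_call function, the attempt counter, and the hard-coded tries=5 are illustrative stand-ins (in the real code the count comes from settings.github.ratelimit_retries):

import logging
from retry import retry

logging.basicConfig(level=logging.WARNING)

class RateLimitExceeded(Exception):
    """Stand-in for the RateLimitExceeded exception defined in servers/utils."""

attempts = {"count": 0}

@retry(exceptions=RateLimitExceeded, tries=5, delay=2, backoff=2, jitter=(1, 3))
def flaky_call() -> str:
    # Fail twice, then succeed; each failure is retried after an exponentially
    # growing delay (2s, 4s, ...) plus a random 1-3 second jitter.
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RateLimitExceeded("simulated GitHub rate limit")
    return "ok"

print(flaky_call())  # retried transparently until it returns "ok"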

View File

@@ -8,50 +8,61 @@ from pr_agent.tools.pr_reviewer import PRReviewer
 
 async def run_action():
-    GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME', None)
+    # Get environment variables
+    GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME')
+    GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH')
+    OPENAI_KEY = os.environ.get('OPENAI_KEY')
+    OPENAI_ORG = os.environ.get('OPENAI_ORG')
+    GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
+
+    # Check if required environment variables are set
     if not GITHUB_EVENT_NAME:
         print("GITHUB_EVENT_NAME not set")
         return
-    GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH', None)
+
     if not GITHUB_EVENT_PATH:
         print("GITHUB_EVENT_PATH not set")
         return
-    try:
-        event_payload = json.load(open(GITHUB_EVENT_PATH, 'r'))
-    except json.decoder.JSONDecodeError as e:
-        print(f"Failed to parse JSON: {e}")
-        return
-    OPENAI_KEY = os.environ.get('OPENAI_KEY', None)
+
     if not OPENAI_KEY:
         print("OPENAI_KEY not set")
         return
-    OPENAI_ORG = os.environ.get('OPENAI_ORG', None)
-    GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN', None)
+
     if not GITHUB_TOKEN:
         print("GITHUB_TOKEN not set")
         return
+
+    # Set the environment variables in the settings
     settings.set("OPENAI.KEY", OPENAI_KEY)
     if OPENAI_ORG:
         settings.set("OPENAI.ORG", OPENAI_ORG)
     settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
     settings.set("GITHUB.DEPLOYMENT_TYPE", "user")
+
+    # Load the event payload
+    try:
+        with open(GITHUB_EVENT_PATH, 'r') as f:
+            event_payload = json.load(f)
+    except json.decoder.JSONDecodeError as e:
+        print(f"Failed to parse JSON: {e}")
+        return
+
+    # Handle pull request event
     if GITHUB_EVENT_NAME == "pull_request":
-        action = event_payload.get("action", None)
+        action = event_payload.get("action")
         if action in ["opened", "reopened"]:
-            pr_url = event_payload.get("pull_request", {}).get("url", None)
+            pr_url = event_payload.get("pull_request", {}).get("url")
             if pr_url:
                 await PRReviewer(pr_url).review()
+
+    # Handle issue comment event
     elif GITHUB_EVENT_NAME == "issue_comment":
-        action = event_payload.get("action", None)
+        action = event_payload.get("action")
         if action in ["created", "edited"]:
-            comment_body = event_payload.get("comment", {}).get("body", None)
+            comment_body = event_payload.get("comment", {}).get("body")
             if comment_body:
-                pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url", None)
+                pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
                 if pr_url:
                     body = comment_body.strip().lower()
                     await PRAgent().handle_request(pr_url, body)
 
 
 if __name__ == '__main__':
     asyncio.run(run_action())
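
A hypothetical local smoke test for run_action(), populating the same environment variables the GitHub Actions runner would provide. The import path, the event payload, and the placeholder secrets are assumptions for illustration only:

import asyncio
import json
import os
import tempfile

# Assumed import path for the runner shown above; adjust to wherever run_action() actually lives.
from pr_agent.servers.github_action_runner import run_action

# Fake pull_request event payload, written to a temp file to play the role of GITHUB_EVENT_PATH.
event = {"action": "opened", "pull_request": {"url": "https://api.github.com/repos/org/repo/pulls/1"}}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(event, f)

os.environ.update({
    "GITHUB_EVENT_NAME": "pull_request",
    "GITHUB_EVENT_PATH": f.name,
    "OPENAI_KEY": "<openai-api-key>",        # placeholder
    "GITHUB_TOKEN": "<github-access-token>",  # placeholder
})

asyncio.run(run_action())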

View File

@@ -21,3 +21,7 @@ def verify_signature(payload_body, secret_token, signature_header):
     if not hmac.compare_digest(expected_signature, signature_header):
         raise HTTPException(status_code=403, detail="Request signatures didn't match!")
+
+class RateLimitExceeded(Exception):
+    """Raised when the git provider API rate limit has been exceeded."""
+    pass
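
The new exception is deliberately minimal; what matters is the `raise ... from e` chaining used in the GitHub provider above, which keeps the original PyGithub error reachable via __cause__. A tiny self-contained demonstration of that pattern (the RuntimeError below is just a stand-in for GithubException.RateLimitExceededException):

class RateLimitExceeded(Exception):
    """Raised when the git provider API rate limit has been exceeded."""

try:
    try:
        raise RuntimeError("API rate limit exceeded")  # stand-in for the PyGithub exception
    except RuntimeError as e:
        raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e
except RateLimitExceeded as wrapped:
    # The original error is preserved on the wrapper for logging and debugging.
    print(type(wrapped.__cause__).__name__, "->", wrapped.__cause__)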

View File

@@ -27,6 +27,7 @@ num_code_suggestions=4
 [github]
 # The type of deployment to create. Valid values are 'app' or 'user'.
 deployment_type = "user"
+ratelimit_retries = 5
 
 [gitlab]
 # URL to the gitlab service
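
The settings object imported from pr_agent.config_loader appears to be a Dynaconf-style object (TOML sections exposed as attributes), which is how settings.github.ratelimit_retries in the provider code picks up the new key. A rough sketch of reading the value; the settings_files path below is an assumption for illustration:

from dynaconf import Dynaconf

# Illustrative loader; the real settings object is built in pr_agent's config_loader module.
settings = Dynaconf(settings_files=["configuration.toml"])

retries = settings.github.ratelimit_retries  # -> 5 with the configuration above
print(f"GitHub API calls will be retried up to {retries} times")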

View File

@@ -1,6 +1,7 @@
 import copy
 import json
 import logging
+from typing import Tuple, List
 
 from jinja2 import Environment, StrictUndefined
@@ -14,11 +15,22 @@ from pr_agent.git_providers.git_provider import get_main_pr_language
 class PRDescription:
     def __init__(self, pr_url: str):
+        """
+        Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model.
+
+        Args:
+            pr_url (str): The URL of the pull request.
+        """
+        # Initialize the git provider and main PR language
         self.git_provider = get_git_provider()(pr_url)
         self.main_pr_language = get_main_pr_language(
             self.git_provider.get_languages(), self.git_provider.get_files()
         )
+
+        # Initialize the AI handler
         self.ai_handler = AiHandler()
+
+        # Initialize the variables dictionary
         self.vars = {
             "title": self.git_provider.pr.title,
             "branch": self.git_provider.get_pr_branch(),
@@ -26,20 +38,32 @@ class PRDescription:
             "language": self.main_pr_language,
             "diff": "",  # empty diff for initial calculation
         }
-        self.token_handler = TokenHandler(self.git_provider.pr,
-                                          self.vars,
-                                          settings.pr_description_prompt.system,
-                                          settings.pr_description_prompt.user)
+
+        # Initialize the token handler
+        self.token_handler = TokenHandler(
+            self.git_provider.pr,
+            self.vars,
+            settings.pr_description_prompt.system,
+            settings.pr_description_prompt.user,
+        )
+
+        # Initialize patches_diff and prediction attributes
         self.patches_diff = None
         self.prediction = None
 
     async def describe(self):
+        """
+        Generates a PR description using an AI model and publishes it to the PR.
+        """
         logging.info('Generating a PR description...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
 
         await retry_with_fallback_models(self._prepare_prediction)
 
         logging.info('Preparing answer...')
         pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
+
         if settings.config.publish_output:
             logging.info('Pushing answer...')
             if settings.pr_description.publish_description_as_comment:
@@ -52,45 +76,97 @@ class PRDescription:
                     current_labels = []
                 self.git_provider.publish_labels(pr_types + current_labels)
             self.git_provider.remove_initial_comment()
+
         return ""
 
-    async def _prepare_prediction(self, model: str):
+    async def _prepare_prediction(self, model: str) -> None:
+        """
+        Prepare the AI prediction for the PR description based on the provided model.
+
+        Args:
+            model (str): The name of the model to be used for generating the prediction.
+
+        Returns:
+            None
+
+        Raises:
+            Any exceptions raised by the 'get_pr_diff' and '_get_prediction' functions.
+        """
         logging.info('Getting PR diff...')
         self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
         logging.info('Getting AI prediction...')
         self.prediction = await self._get_prediction(model)
 
-    async def _get_prediction(self, model: str):
+    async def _get_prediction(self, model: str) -> str:
+        """
+        Generate an AI prediction for the PR description based on the provided model.
+
+        Args:
+            model (str): The name of the model to be used for generating the prediction.
+
+        Returns:
+            str: The generated AI prediction.
+        """
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
+
         environment = Environment(undefined=StrictUndefined)
         system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables)
         user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables)
+
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
-                                                                        system=system_prompt, user=user_prompt)
+
+        response, finish_reason = await self.ai_handler.chat_completion(
+            model=model,
+            temperature=0.2,
+            system=system_prompt,
+            user=user_prompt
+        )
+
         return response
 
-    def _prepare_pr_answer(self):
+    def _prepare_pr_answer(self) -> Tuple[str, str, List[str], str]:
+        """
+        Prepare the PR description based on the AI prediction data.
+
+        Returns:
+            - title: a string containing the PR title.
+            - pr_body: a string containing the PR body in a markdown format.
+            - pr_types: a list of strings containing the PR types.
+            - markdown_text: a string containing the AI prediction data in a markdown format.
+        """
+        # Load the AI prediction data into a dictionary
         data = json.loads(self.prediction)
-        markdown_text = ""
+
+        # Initialization
+        markdown_text = pr_body = ""
+        pr_types = []
+
+        # Iterate over the dictionary items and append the key and value to 'markdown_text' in a markdown format
         for key, value in data.items():
             markdown_text += f"## {key}\n\n"
             markdown_text += f"{value}\n\n"
-        pr_body = ""
-        pr_types = []
+
+        # If the 'PR Type' key is present in the dictionary, split its value by comma and assign it to 'pr_types'
         if 'PR Type' in data:
             pr_types = data['PR Type'].split(',')
-        title = data['PR Title']
-        del data['PR Title']
+
+        # Assign the value of the 'PR Title' key to 'title' variable and remove it from the dictionary
+        title = data.pop('PR Title')
+
+        # Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
+        # except for the items containing the word 'walkthrough'
        for key, value in data.items():
             pr_body += f"{key}:\n"
             if 'walkthrough' in key.lower():
                 pr_body += f"{value}\n"
             else:
                 pr_body += f"**{value}**\n\n___\n"
+
         if settings.config.verbosity_level >= 2:
             logging.info(f"title:\n{title}\n{pr_body}")
+
         return title, pr_body, pr_types, markdown_text
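
To make the parsing in _prepare_pr_answer concrete, here is a standalone replay of the same logic on a made-up prediction payload; the JSON keys and values below are invented for illustration:

import json

prediction = json.dumps({
    "PR Title": "Add rate limit handling",
    "PR Type": "Bug fix, Enhancement",
    "PR Description": "Retries GitHub API calls when the rate limit is exceeded.",
    "PR Main Files Walkthrough": "github_provider.py: wraps get_diff_files with @retry",
})

data = json.loads(prediction)
markdown_text = pr_body = ""
pr_types = []

# Every key/value pair becomes a markdown section.
for key, value in data.items():
    markdown_text += f"## {key}\n\n{value}\n\n"

# 'PR Type' is split on commas into a list of labels.
if 'PR Type' in data:
    pr_types = data['PR Type'].split(',')

# 'PR Title' is pulled out and removed from the dictionary.
title = data.pop('PR Title')

# Remaining keys form the PR body; walkthrough sections are not bolded.
for key, value in data.items():
    pr_body += f"{key}:\n"
    if 'walkthrough' in key.lower():
        pr_body += f"{value}\n"
    else:
        pr_body += f"**{value}**\n\n___\n"

print(title)     # Add rate limit handling
print(pr_types)  # ['Bug fix', ' Enhancement']
print(pr_body)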