Merge pull request #719 from Codium-ai/ok/identity_provider

Ok/identity provider
This commit is contained in:
Tal
2024-02-26 09:34:06 -08:00
committed by GitHub
5 changed files with 104 additions and 26 deletions

View File

@ -0,0 +1,13 @@
from pr_agent.config_loader import get_settings
from pr_agent.identity_providers.default_identity_provider import DefaultIdentityProvider
_IDENTITY_PROVIDERS = {
'default': DefaultIdentityProvider
}
def get_identity_provider():
identity_provider_id = get_settings().get("CONFIG.IDENTITY_PROVIDER", "default")
if identity_provider_id not in _IDENTITY_PROVIDERS:
raise ValueError(f"Unknown identity provider: {identity_provider_id}")
return _IDENTITY_PROVIDERS[identity_provider_id]()

View File

@ -0,0 +1,9 @@
from pr_agent.identity_providers.identity_provider import Eligibility, IdentityProvider
class DefaultIdentityProvider(IdentityProvider):
def verify_eligibility(self, git_provider, git_provider_id, pr_url):
return Eligibility.ELIGIBLE
def inc_invocation_count(self, git_provider, git_provider_id):
pass

View File

@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from enum import Enum
class Eligibility(Enum):
NOT_ELIGIBLE = 0
ELIGIBLE = 1
TRIAL = 2
class IdentityProvider(ABC):
@abstractmethod
def verify_eligibility(self, git_provider, git_provier_id, pr_url):
pass
@abstractmethod
def inc_invocation_count(self, git_provider, git_provider_id):
pass

View File

@ -1,3 +1,4 @@
import base64
import copy
import hashlib
import json
@ -17,6 +18,8 @@ from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.identity_providers import get_identity_provider
from pr_agent.identity_providers.identity_provider import Eligibility
from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.secret_providers import get_secret_provider
from pr_agent.servers.github_action_runner import get_setting_or_env, is_true
@ -80,11 +83,27 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
get_logger().debug(data)
async def inner():
try:
owner = data["data"]["repository"]["owner"]["username"]
try:
if data["data"]["actor"]["type"] != "user":
return "OK"
except KeyError:
get_logger().error("Failed to get actor type, check previous logs, this shouldn't happen.")
try:
owner = data["data"]["repository"]["owner"]["username"]
except Exception as e:
get_logger().error(f"Failed to get owner, will continue: {e}")
owner = "unknown"
sender_id = data["data"]["actor"]["account_id"]
log_context["sender"] = owner
secrets = json.loads(secret_provider.get_secret(owner))
log_context["sender_id"] = sender_id
jwt_parts = input_jwt.split(".")
claim_part = jwt_parts[1]
claim_part += "=" * (-len(claim_part) % 4)
decoded_claims = base64.urlsafe_b64decode(claim_part)
claims = json.loads(decoded_claims)
client_key = claims["iss"]
secrets = json.loads(secret_provider.get_secret(client_key))
shared_secret = secrets["shared_secret"]
client_key = secrets["client_key"]
jwt.decode(input_jwt, shared_secret, audience=client_key, algorithms=["HS256"])
bearer_token = await get_bearer_token(shared_secret, client_key)
context['bitbucket_bearer_token'] = bearer_token
@ -98,15 +117,17 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
if pr_url:
with get_logger().contextualize(**log_context):
apply_repo_settings(pr_url)
auto_review = get_setting_or_env("BITBUCKET_APP.AUTO_REVIEW", None)
if auto_review is None or is_true(auto_review): # by default, auto review is enabled
await PRReviewer(pr_url).run()
auto_improve = get_setting_or_env("BITBUCKET_APP.AUTO_IMPROVE", None)
if is_true(auto_improve): # by default, auto improve is disabled
await PRCodeSuggestions(pr_url).run()
auto_describe = get_setting_or_env("BITBUCKET_APP.AUTO_DESCRIBE", None)
if is_true(auto_describe): # by default, auto describe is disabled
await PRDescription(pr_url).run()
if get_identity_provider().verify_eligibility("bitbucket",
sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
auto_review = get_setting_or_env("BITBUCKET_APP.AUTO_REVIEW", None)
if auto_review is None or is_true(auto_review): # by default, auto review is enabled
await PRReviewer(pr_url).run()
auto_improve = get_setting_or_env("BITBUCKET_APP.AUTO_IMPROVE", None)
if is_true(auto_improve): # by default, auto improve is disabled
await PRCodeSuggestions(pr_url).run()
auto_describe = get_setting_or_env("BITBUCKET_APP.AUTO_DESCRIBE", None)
if is_true(auto_describe): # by default, auto describe is disabled
await PRDescription(pr_url).run()
# with get_logger().contextualize(**log_context):
# await agent.handle_request(pr_url, "review")
elif event == "pullrequest:comment_created":
@ -115,7 +136,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
log_context["event"] = "comment"
comment_body = data["data"]["comment"]["content"]["raw"]
with get_logger().contextualize(**log_context):
await agent.handle_request(pr_url, comment_body)
if get_identity_provider().verify_eligibility("bitbucket",
sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
await agent.handle_request(pr_url, comment_body)
except Exception as e:
get_logger().error(f"Failed to handle webhook: {e}")
background_tasks.add_task(inner)

View File

@ -3,7 +3,7 @@ import copy
import os
import re
import uuid
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, Tuple
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
@ -17,11 +17,19 @@ from pr_agent.config_loader import get_settings, global_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import IncrementalPR
from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.identity_providers import get_identity_provider
from pr_agent.identity_providers.identity_provider import Eligibility
from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.servers.utils import DefaultDictWithTimeout, verify_signature
setup_logger(fmt=LoggingFormat.JSON)
base_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
build_number_path = os.path.join(base_path, "build_number.txt")
if os.path.exists(build_number_path):
with open(build_number_path) as f:
build_number = f.read().strip()
else:
build_number = "unknown"
router = APIRouter()
@ -70,6 +78,7 @@ _pending_task_duplicate_push_conditions = DefaultDictWithTimeout(asyncio.locks.C
async def handle_comments_on_pr(body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent):
@ -98,13 +107,15 @@ async def handle_comments_on_pr(body: Dict[str, Any],
comment_id = body.get("comment", {}).get("id")
provider = get_git_provider()(pr_url=api_url)
with get_logger().contextualize(**log_context):
get_logger().info(f"Processing comment on PR {api_url=}, comment_body={comment_body}")
await agent.handle_request(api_url, comment_body,
notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes))
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
get_logger().info(f"Processing comment on PR {api_url=}, comment_body={comment_body}")
await agent.handle_request(api_url, comment_body,
notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes))
async def handle_new_pr_opened(body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent):
@ -123,11 +134,13 @@ async def handle_new_pr_opened(body: Dict[str, Any],
get_logger().info(f"Invalid PR event: {action=} {api_url=}")
return {}
if action in get_settings().github_app.handle_pr_actions: # ['opened', 'reopened', 'ready_for_review', 'review_requested']
await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context)
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context)
async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent):
@ -180,8 +193,9 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
if get_settings().github_app.push_trigger_wait_for_initial_review and not get_git_provider()(api_url,
incremental=IncrementalPR(
True)).previous_review:
get_logger().info(f"Skipping incremental review because there was no initial review for {api_url=} yet")
return {}
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
get_logger().info(f"Skipping incremental review because there was no initial review for {api_url=} yet")
return {}
get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}")
await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context)
@ -205,25 +219,26 @@ async def handle_request(body: Dict[str, Any], event: str):
return {}
agent = PRAgent()
sender = body.get("sender", {}).get("login")
sender_id = body.get("sender", {}).get("id")
app_name = get_settings().get("CONFIG.APP_NAME", "Unknown")
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app",
"request_id": uuid.uuid4().hex, "app_name": app_name}
"request_id": uuid.uuid4().hex, "build_number": build_number, "app_name": app_name}
# handle comments on PRs
if action == 'created':
get_logger().debug(f'Request body', artifact=body)
await handle_comments_on_pr(body, event, sender, action, log_context, agent)
await handle_comments_on_pr(body, event, sender, sender_id, action, log_context, agent)
# handle new PRs
elif event == 'pull_request' and action != 'synchronize':
get_logger().debug(f'Request body', artifact=body)
await handle_new_pr_opened(body, event, sender, action, log_context, agent)
await handle_new_pr_opened(body, event, sender, sender_id, action, log_context, agent)
# handle pull_request event with synchronize action - "push trigger" for new commits
elif event == 'pull_request' and action == 'synchronize':
get_logger().debug(f'Request body', artifact=body)
await handle_push_trigger_for_new_commits(body, event, sender, action, log_context, agent)
await handle_push_trigger_for_new_commits(body, event, sender, sender_id, action, log_context, agent)
else:
get_logger().info(f"event {event=} action {action=} does not require any handling")
return {}
return {}
def handle_line_comments(body: Dict, comment_body: [str, Any]) -> str: