diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 7f617937..dac92e89 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -5,6 +5,7 @@ from urllib.parse import urlparse from github import AppAuthentication, Auth, Github, GithubException from retry import retry +from starlette_context import context from pr_agent.config_loader import settings @@ -17,7 +18,10 @@ from ..servers.utils import RateLimitExceeded class GithubProvider(GitProvider): def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False)): self.repo_obj = None - self.installation_id = settings.get("GITHUB.INSTALLATION_ID") + try: + self.installation_id = context.get("installation_id", None) + except Exception: + self.installation_id = None self.github_client = self._get_github_client() self.repo = None self.pr_num = None diff --git a/pr_agent/servers/github_app.py b/pr_agent/servers/github_app.py index 8050117f..a9ba1de5 100644 --- a/pr_agent/servers/github_app.py +++ b/pr_agent/servers/github_app.py @@ -4,6 +4,9 @@ import sys import uvicorn from fastapi import APIRouter, FastAPI, HTTPException, Request, Response +from starlette.middleware import Middleware +from starlette_context import context +from starlette_context.middleware import RawContextMiddleware from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import settings @@ -20,24 +23,35 @@ async def handle_github_webhooks(request: Request, response: Response): Verifies the request signature, parses the request body, and passes it to the handle_request function for further processing. """ logging.debug("Received a GitHub webhook") - + + body = await get_body(request) + + logging.debug(f'Request body:\n{body}') + installation_id = body.get("installation", {}).get("id") + context["installation_id"] = installation_id + + return await handle_request(body) + + +@router.post("/api/v1/marketplace_webhooks") +async def handle_marketplace_webhooks(request: Request, response: Response): + body = await get_body(request) + logging.info(f'Request body:\n{body}') + +async def get_body(request): try: body = await request.json() except Exception as e: logging.error("Error parsing request body", e) raise HTTPException(status_code=400, detail="Error parsing request body") from e - body_bytes = await request.body() signature_header = request.headers.get('x-hub-signature-256', None) - webhook_secret = getattr(settings.github, 'webhook_secret', None) - if webhook_secret: verify_signature(body_bytes, webhook_secret, signature_header) - - logging.debug(f'Request body:\n{body}') - - return await handle_request(body) + return body + + async def handle_request(body: Dict[str, Any]): @@ -48,8 +62,6 @@ async def handle_request(body: Dict[str, Any]): body: The request body. """ action = body.get("action") - installation_id = body.get("installation", {}).get("id") - settings.set("GITHUB.INSTALLATION_ID", installation_id) agent = PRAgent() if action == 'created': @@ -85,7 +97,8 @@ async def root(): def start(): # Override the deployment type to app settings.set("GITHUB.DEPLOYMENT_TYPE", "app") - app = FastAPI() + middleware = [Middleware(RawContextMiddleware)] + app = FastAPI(middleware=middleware) app.include_router(router) uvicorn.run(app, host="0.0.0.0", port=3000) diff --git a/pyproject.toml b/pyproject.toml index 03df2480..ac9a4889 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "aiohttp~=3.8.4", "atlassian-python-api==3.39.0", "GitPython~=3.1.32", + "starlette-context==0.3.6" ] [project.urls]