diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 7f617937..d3d3645e 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -5,6 +5,7 @@ from urllib.parse import urlparse from github import AppAuthentication, Auth, Github, GithubException from retry import retry +from starlette_context import context from pr_agent.config_loader import settings @@ -17,7 +18,7 @@ from ..servers.utils import RateLimitExceeded class GithubProvider(GitProvider): def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False)): self.repo_obj = None - self.installation_id = settings.get("GITHUB.INSTALLATION_ID") + self.installation_id = context.get("installation_id", None) self.github_client = self._get_github_client() self.repo = None self.pr_num = None diff --git a/pr_agent/servers/github_app.py b/pr_agent/servers/github_app.py index 8050117f..d23cccd4 100644 --- a/pr_agent/servers/github_app.py +++ b/pr_agent/servers/github_app.py @@ -4,6 +4,9 @@ import sys import uvicorn from fastapi import APIRouter, FastAPI, HTTPException, Request, Response +from starlette.middleware import Middleware +from starlette_context import context +from starlette_context.middleware import RawContextMiddleware from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import settings @@ -36,7 +39,9 @@ async def handle_github_webhooks(request: Request, response: Response): verify_signature(body_bytes, webhook_secret, signature_header) logging.debug(f'Request body:\n{body}') - + installation_id = body.get("installation", {}).get("id") + context["installation_id"] = installation_id + return await handle_request(body) @@ -48,8 +53,6 @@ async def handle_request(body: Dict[str, Any]): body: The request body. """ action = body.get("action") - installation_id = body.get("installation", {}).get("id") - settings.set("GITHUB.INSTALLATION_ID", installation_id) agent = PRAgent() if action == 'created': @@ -85,7 +88,8 @@ async def root(): def start(): # Override the deployment type to app settings.set("GITHUB.DEPLOYMENT_TYPE", "app") - app = FastAPI() + middleware = [Middleware(RawContextMiddleware)] + app = FastAPI(middleware=middleware) app.include_router(router) uvicorn.run(app, host="0.0.0.0", port=3000) diff --git a/pyproject.toml b/pyproject.toml index 03df2480..ac9a4889 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "aiohttp~=3.8.4", "atlassian-python-api==3.39.0", "GitPython~=3.1.32", + "starlette-context==0.3.6" ] [project.urls]