diff --git a/pr_agent/servers/gitlab_webhook.py b/pr_agent/servers/gitlab_webhook.py index cfc3a8a0..565405b4 100644 --- a/pr_agent/servers/gitlab_webhook.py +++ b/pr_agent/servers/gitlab_webhook.py @@ -1,3 +1,5 @@ +import copy +import json import logging import uvicorn @@ -5,23 +7,38 @@ from fastapi import APIRouter, FastAPI, Request, status from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from starlette.background import BackgroundTasks +from starlette.middleware import Middleware +from starlette_context import context +from starlette_context.middleware import RawContextMiddleware from pr_agent.agent.pr_agent import PRAgent -from pr_agent.config_loader import get_settings +from pr_agent.config_loader import get_settings, global_settings from pr_agent.secret_providers import get_secret_provider -app = FastAPI() router = APIRouter() -if get_settings().config.secret_provider: - secret_provider = get_secret_provider() +secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None + @router.post("/webhook") async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request): - if get_settings().get("GITLAB.SHARED_SECRET"): + if request.headers.get("X-Gitlab-Token") and secret_provider: + request_token = request.headers.get("X-Gitlab-Token") + secret = secret_provider.get_secret(request_token) + try: + secret_dict = json.loads(secret) + gitlab_token = secret_dict["gitlab_token"] + context["settings"] = copy.deepcopy(global_settings) + context["settings"].gitlab.personal_access_token = gitlab_token + except Exception as e: + logging.error(f"Failed to validate secret {request_token}: {e}") + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"})) + elif get_settings().get("GITLAB.SHARED_SECRET"): secret = get_settings().get("GITLAB.SHARED_SECRET") if not request.headers.get("X-Gitlab-Token") == secret: return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"})) + else: + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"})) gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None) if not gitlab_token: return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"})) @@ -43,8 +60,8 @@ def start(): if not gitlab_url: raise ValueError("GITLAB.URL is not set") get_settings().config.git_provider = "gitlab" - - app = FastAPI() + middleware = [Middleware(RawContextMiddleware)] + app = FastAPI(middleware=middleware) app.include_router(router) uvicorn.run(app, host="0.0.0.0", port=3000)