fix basic auth

This commit is contained in:
Yochai Lehman
2024-02-11 17:42:06 -05:00
parent bc38fad4db
commit 95344c7083

View File

@ -5,8 +5,9 @@
import json import json
import os import os
import re import re
import secrets
import uvicorn import uvicorn
from fastapi import APIRouter, Depends, FastAPI from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from starlette import status from starlette import status
@ -19,7 +20,11 @@ from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent, command2class from pr_agent.agent.pr_agent import PRAgent, command2class
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger from pr_agent.log import get_logger
from fastapi import Request, Depends
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pr_agent.log import get_logger
security = HTTPBasic()
router = APIRouter() router = APIRouter()
available_commands_rgx = re.compile(r"^\/(" + "|".join(command2class.keys()) + r")\s*") available_commands_rgx = re.compile(r"^\/(" + "|".join(command2class.keys()) + r")\s*")
azuredevops_server = get_settings().get("azure_devops_server") azuredevops_server = get_settings().get("azure_devops_server")
@ -35,18 +40,24 @@ def handle_request(
background_tasks.add_task(PRAgent().handle_request, url, body) background_tasks.add_task(PRAgent().handle_request, url, body)
@router.post("/") # currently only basic auth is supported with azure webhooks
# for this reason, https must be enabled to ensure the credentials are not sent in clear text
def authorize(credentials: HTTPBasicCredentials = Depends(security)):
is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
if not (is_user_ok and is_pass_ok):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Incorrect username or password.',
headers={'WWW-Authenticate': 'Basic'},
)
@router.post("/", dependencies=[Depends(authorize)])
async def handle_webhook(background_tasks: BackgroundTasks, request: Request): async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
log_context = {"server_type": "azuredevops_server"} log_context = {"server_type": "azuredevops_server"}
data = await request.json() data = await request.json()
get_logger().info(json.dumps(data)) get_logger().info(json.dumps(data))
if not validate_basic_auth(request):
get_logger().error("Unauthorized webhook request")
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=json.dumps({"message": "unauthorized"}),
)
actions = [] actions = []
if data["eventType"] == "git.pullrequest.created": if data["eventType"] == "git.pullrequest.created":
# API V1 (latest) # API V1 (latest)
@ -96,23 +107,6 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "webhook triggerd successfully"}) status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "webhook triggerd successfully"})
) )
# currently only basic auth is supported with azure webhooks
# for this reason, https must be enabled to ensure the credentials are not sent in clear text
def validate_basic_auth(request: Request):
try:
auth = request.headers.get("Authorization")
if not auth:
return False
if not auth.startswith("Basic "):
return False
security = HTTPBasic()
credentials: HTTPBasicCredentials = Depends(security)
username = credentials.username
password = credentials.password
return username == WEBHOOK_USERNAME and password == WEBHOOK_PASSWORD
except:
get_logger().error("Failed to validate basic auth")
return False
def start(): def start():
app = FastAPI(middleware=[Middleware(RawContextMiddleware)]) app = FastAPI(middleware=[Middleware(RawContextMiddleware)])