from datetime import datetime, timedelta, timezone
from typing import Optional

import httpx
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt

from config import settings

security = HTTPBearer()

# Algorithm used both to sign this application's access tokens
# (create_access_token) and to verify them (get_current_user). A single
# constant keeps the signer and the verifier from drifting apart.
JWT_ALGORITHM = "HS256"


class OIDCClient:
    """Minimal async client for an OpenID Connect provider.

    Endpoint URLs are resolved from the provider's discovery document
    (``<issuer>/.well-known/openid-configuration``). The document is fetched
    on first use and cached on the instance.
    """

    def __init__(self):
        # All connection parameters come from the application settings.
        self.issuer = settings.oidc_issuer
        self.client_id = settings.oidc_client_id
        self.client_secret = settings.oidc_client_secret
        self.redirect_uri = settings.oidc_redirect_uri
        # Parsed discovery document; None until the first successful fetch.
        self._discovery_cache: Optional[dict] = None

    async def get_discovery_document(self) -> dict:
        """Return the provider's discovery document, fetching it on first use.

        Returns:
            The parsed JSON discovery document.

        Raises:
            httpx.HTTPStatusError: if the discovery endpoint answers with a
                non-2xx status. Nothing is cached in that case, so a later call
                retries instead of reusing an error payload forever.
        """
        if self._discovery_cache is not None:
            return self._discovery_cache
        # OIDC Discovery 1.0 section 4: a terminating "/" on the issuer must be
        # removed before appending the well-known path. str() also covers
        # settings types (e.g. URL objects) that are not plain strings.
        issuer = str(self.issuer).rstrip("/")
        url = f"{issuer}/.well-known/openid-configuration"
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
        response.raise_for_status()
        self._discovery_cache = response.json()
        return self._discovery_cache

    async def exchange_code(self, code: str) -> dict:
        """Exchange an authorization code at the provider's token endpoint.

        Sends an ``authorization_code`` grant, with the client credentials in
        the form body, to the ``token_endpoint`` from the discovery document.

        Args:
            code: The authorization code received on the redirect URI.

        Returns:
            The parsed JSON response. The HTTP status is not checked, so an
            OAuth error payload is returned to the caller unchanged.
        """
        discovery = await self.get_discovery_document()
        async with httpx.AsyncClient() as client:
            response = await client.post(
                discovery["token_endpoint"],
                data={
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": self.redirect_uri,
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                },
            )
        return response.json()

    async def get_user_info(self, access_token: str) -> dict:
        """Fetch the user's claims from the provider's userinfo endpoint.

        Args:
            access_token: Access token issued by the provider, sent as a
                Bearer token.

        Returns:
            The parsed JSON response. As with ``exchange_code``, the HTTP
            status is not checked.
        """
        discovery = await self.get_discovery_document()
        async with httpx.AsyncClient() as client:
            response = await client.get(
                discovery["userinfo_endpoint"],
                headers={"Authorization": f"Bearer {access_token}"},
            )
        return response.json()


oidc_client = OIDCClient()


def _credentials_exception() -> HTTPException:
    """Build the 401 raised for any undecodable or unacceptable bearer token."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        # RFC 6750 section 3: a 401 from a bearer-protected resource must
        # carry a WWW-Authenticate challenge.
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> dict:
    """FastAPI dependency that decodes the bearer token and returns its claims.

    The token must be a JWT signed with ``settings.secret_key`` using
    ``JWT_ALGORITHM``, and it must carry a non-empty ``sub`` claim.

    Args:
        credentials: Bearer credentials extracted by ``HTTPBearer``.

    Returns:
        The decoded JWT payload.

    Raises:
        HTTPException: 401 if the token cannot be decoded or verified, or if
            it has no ``sub`` claim.
    """
    try:
        payload = jwt.decode(
            credentials.credentials,
            settings.secret_key,
            algorithms=[JWT_ALGORITHM],
        )
    except JWTError as exc:
        raise _credentials_exception() from exc
    if not payload.get("sub"):
        raise _credentials_exception()
    return payload


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Sign ``data`` as a JWT using ``settings.secret_key`` and ``JWT_ALGORITHM``.

    Args:
        data: Claims to encode. The caller's dict is not modified.
        expires_delta: If given, the ``exp`` claim is set to now (UTC) plus
            this delta, overriding any ``exp`` already present in ``data``.
            If omitted, the claims are encoded exactly as supplied, which is
            the original behaviour; no ``exp`` claim is added.

    Returns:
        The encoded token string.
    """
    claims = dict(data)
    if expires_delta is not None:
        expire = datetime.now(timezone.utc) + expires_delta
        # NumericDate per RFC 7519: integer seconds since the epoch.
        claims["exp"] = int(expire.timestamp())
    return jwt.encode(claims, settings.secret_key, algorithm=JWT_ALGORITHM)