81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
from datetime import datetime, timedelta, timezone
from typing import Optional

import httpx
from fastapi import HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwt, JWTError

from config import settings
|
|
|
|
# Bearer-token extractor injected into get_current_user via Depends().
# auto_error=True is HTTPBearer's default, spelled out here for clarity.
security = HTTPBearer(auto_error=True)
|
|
|
|
class OIDCClient:
    """Minimal async OpenID Connect relying-party client.

    Provider configuration (issuer, client credentials, redirect URI) is read
    from ``settings`` when the instance is constructed. The provider's
    discovery document is fetched lazily on first use and cached on the
    instance. Each HTTP call opens its own short-lived ``httpx.AsyncClient``.
    """

    def __init__(self):
        self.issuer = settings.oidc_issuer
        self.client_id = settings.oidc_client_id
        self.client_secret = settings.oidc_client_secret
        self.redirect_uri = settings.oidc_redirect_uri
        # Filled in by get_discovery_document(); None means "not fetched yet".
        self._discovery_cache: Optional[dict] = None

    @staticmethod
    def _discovery_url(issuer: str) -> str:
        """Return the OpenID discovery URL for *issuer*.

        OIDC Discovery requires any terminating "/" on the issuer to be
        removed before appending "/.well-known/openid-configuration"; without
        this, issuers such as "https://tenant.example.com/" would yield a
        "//.well-known" path.
        """
        return f"{issuer.rstrip('/')}/.well-known/openid-configuration"

    async def get_discovery_document(self) -> dict:
        """Return the provider's discovery document, fetching it on first use.

        Returns:
            The parsed JSON discovery document (cached after the first
            successful fetch).

        Raises:
            httpx.HTTPStatusError: if the provider responds with a non-2xx
                status. Nothing is cached in that case, so a later call
                retries instead of reusing an error body forever.
        """
        if self._discovery_cache:
            return self._discovery_cache

        async with httpx.AsyncClient() as client:
            response = await client.get(self._discovery_url(self.issuer))
            # Validate before caching so a transient failure is not persisted.
            response.raise_for_status()
            self._discovery_cache = response.json()
        return self._discovery_cache

    async def exchange_code(self, code: str) -> dict:
        """Exchange an authorization *code* for tokens at the token endpoint.

        Client credentials are sent in the form body alongside the
        ``authorization_code`` grant and the configured redirect URI.

        Returns:
            The parsed JSON token response.

        Raises:
            httpx.HTTPStatusError: if the token endpoint responds with a
                non-2xx status (e.g. an OAuth error such as invalid_grant),
                rather than returning the error body as if it were tokens.
            KeyError: if the discovery document has no "token_endpoint".
        """
        discovery = await self.get_discovery_document()
        token_endpoint = discovery["token_endpoint"]

        async with httpx.AsyncClient() as client:
            response = await client.post(
                token_endpoint,
                data={
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": self.redirect_uri,
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                },
            )
            response.raise_for_status()
            return response.json()

    async def get_user_info(self, access_token: str) -> dict:
        """Fetch the user's claims from the provider's userinfo endpoint.

        *access_token* is sent as a Bearer credential in the Authorization
        header.

        Returns:
            The parsed JSON userinfo response.

        Raises:
            httpx.HTTPStatusError: if the userinfo endpoint responds with a
                non-2xx status (e.g. an expired or revoked access token).
            KeyError: if the discovery document has no "userinfo_endpoint".
        """
        discovery = await self.get_discovery_document()
        userinfo_endpoint = discovery["userinfo_endpoint"]

        async with httpx.AsyncClient() as client:
            response = await client.get(
                userinfo_endpoint,
                headers={"Authorization": f"Bearer {access_token}"},
            )
            response.raise_for_status()
            return response.json()
|
|
|
|
# Module-level client instance; constructing it reads the OIDC settings at import time.
oidc_client = OIDCClient()
|
|
|
|
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that authenticates a request from its bearer token.

    The token is decoded and verified as an HS256 JWT signed with
    ``settings.secret_key``.

    Returns:
        The decoded JWT claims dict.

    Raises:
        HTTPException: 401 if the token cannot be decoded/verified or has a
            missing or empty ``sub`` claim.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        # RFC 7235: a 401 response must include a WWW-Authenticate challenge.
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(
            credentials.credentials,
            settings.secret_key,
            algorithms=["HS256"],
        )
    except JWTError as err:
        # Chain the cause so the verification failure stays visible in logs.
        raise credentials_exception from err

    # An absent or empty subject cannot identify a user.
    user_id = payload.get("sub")
    if not user_id:
        raise credentials_exception
    return payload
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Sign *data* as an HS256 JWT using ``settings.secret_key``.

    Args:
        data: Claims to encode. The caller's dict is copied and never
            modified.
        expires_delta: Optional lifetime. When given, an ``exp`` claim of
            now (UTC) + *expires_delta* is set, overriding any ``exp`` in
            *data*. When omitted, the claims are encoded exactly as supplied
            (no expiry is added), which matches the previous behavior.

    Returns:
        The encoded JWT string.
    """
    # Work on a copy so adding/normalizing claims never leaks back to the caller.
    claims = dict(data)
    if expires_delta is not None:
        claims["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(claims, settings.secret_key, algorithm="HS256")
|