"""AAD OAuth handlers."""
import base64
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from authlib import __version__ as authlib_version
from authlib.jose import errors as jwt_errors, JsonWebKey, jwk, jwt
from authlib.jose.util import extract_header
from cachetools import TTLCache
from cryptography.hazmat.primitives import serialization
import msal
from packaging import version
from pkg_resources import resource_string
from pydantic import BaseSettings as _BaseSettings, Field, HttpUrl, PrivateAttr, SecretStr, validator
import requests
from starlette.middleware.authentication import AuthenticationError
from starlette.requests import Request
from fastapi_aad_auth._base.authenticators import SessionAuthenticator
from fastapi_aad_auth._base.provider import Provider, ProviderConfig
from fastapi_aad_auth._base.state import User
from fastapi_aad_auth._base.validators import OAuthFlowType, SessionValidator, TokenValidator
from fastapi_aad_auth.errors import ConfigurationError
from fastapi_aad_auth.utilities import bool_from_env, DeprecatableFieldsMixin, expand_doc, is_deprecated, list_from_env, urls
class TokenType(Enum):
    """Type of token to use.

    Each member's value is the key under which the corresponding token is
    looked up in the MSAL token-acquisition result.
    """

    access = 'access_token'
    id = 'id_token'
class BaseSettings(DeprecatableFieldsMixin, _BaseSettings):
    """Base Settings with Deprecatable Fields.

    Combines pydantic's ``BaseSettings`` with ``DeprecatableFieldsMixin``;
    no members are added here.
    """
class AADSessionAuthenticator(SessionAuthenticator):
    """AAD Authenticator for interactive (UI) sessions.

    Holds an MSAL client application (confidential when a client secret is
    supplied, public otherwise) and uses it to build authorization URLs,
    redeem authorization codes and silently acquire access tokens.
    """

    def __init__(
            self,
            session_validator,
            token_validator,
            client_id,
            tenant_id,
            redirect_path='/oauth/aad/redirect',
            prompt=None,
            client_secret=None,
            scopes=None,
            redirect_uri=None,
            domain_hint=None,
            token_type=TokenType.access):
        """Initialise AAD Authenticator for interactive (UI) sessions.

        Args:
            session_validator: Passed through to the base authenticator.
            token_validator: Passed through to the base authenticator.
            client_id: Client id handed to MSAL; also used to build the
                default scope and to prefix app scopes.
            tenant_id: Tenant used to build the
                ``https://login.microsoftonline.com/{tenant_id}`` authority.

        Keyword Args:
            redirect_path: Path appended to the request host when no full
                ``redirect_uri`` is configured.
            prompt: Forwarded to MSAL when building the authorization URL.
            client_secret: When given, a confidential client is created;
                otherwise a public client is used.
            scopes: Scopes to request; defaults to ``['{client_id}/openid']``.
            redirect_uri: Full redirect URI, overriding the computed one.
            domain_hint: Forwarded to MSAL when building the authorization URL.
            token_type: ``TokenType`` (or its string value) naming the entry
                of the MSAL result that is returned after redeeming a code.
        """
        super().__init__(session_validator, token_validator)
        self._redirect_path = redirect_path
        self._redirect_uri = redirect_uri
        self._domain_hint = domain_hint
        self._prompt = prompt
        self.client_id = client_id
        if scopes is None:
            scopes = [f'{self.client_id}/openid']
        self.logger.info(f'Scopes {scopes}')
        self._scopes = scopes
        # Store the plain string so it can be used directly as a result key.
        if isinstance(token_type, Enum):
            token_type = token_type.value
        self._token_type = token_type
        self._authority = f'https://login.microsoftonline.com/{tenant_id}'
        if client_secret is not None:
            self.logger.info('Client secret provided, using Confidential Client')
            self.msal_application = msal.ConfidentialClientApplication(
                client_id,
                authority=self._authority,
                client_credential=client_secret)
        else:
            self.logger.info('Client secret not provided, using Public Client')
            self.msal_application = msal.PublicClientApplication(
                client_id,
                authority=self._authority)

    def _build_redirect_uri(self, request: Request):
        """Return the redirect URI to hand to AAD for this request.

        The configured ``redirect_uri`` wins when set. Otherwise the URI is
        assembled from the request host, the port (omitted when it is absent
        or the default for the request scheme) and the redirect path.
        """
        if self._redirect_uri:
            redirect_uri = self._redirect_uri
        else:
            url = request.url
            default_port = ((url.port == 80 and url.scheme == 'http')
                            or (url.port == 443 and url.scheme == 'https'))
            port = '' if url.port is None or default_port else f':{url.port}'
            # Use https when a proxy says the client used it, or when the
            # request itself arrived over https; previously a directly served
            # https request produced an http redirect URI.
            forwarded_proto = request.headers.get('X-Forwarded-Proto')
            if forwarded_proto == 'https' or url.scheme == 'https':
                url_scheme = 'https'
            else:
                url_scheme = 'http'
            redirect_uri = f'{url_scheme}://{url.hostname}{port}{self._redirect_path}'
        self.logger.info(f'Created redirect uri: {redirect_uri} from {request.url}')
        return redirect_uri

    def _process_code(self, request: Request, auth_state, code):
        """Redeem an authorization code and return the configured token.

        Raises:
            ConfigurationError: when the MSAL result carries a non-empty
                ``error`` (the description is included when present).
        """
        result = self.msal_application.acquire_token_by_authorization_code(
            code,
            scopes=self._scopes,
            redirect_uri=self._build_redirect_uri(request))
        # Only the keys are logged: the values include the tokens themselves.
        self.logger.debug(f'Result keys {sorted(result)}')
        error = result.get('error')
        if error:
            error_args = [error]
            if 'error_description' in result:
                error_args.append(result['error_description'])
            raise ConfigurationError(*error_args)
        return result[self._token_type]

    def _get_user_from_token(self, token, options=None):
        """Resolve the user for a token, without the ``azp``/``appid`` checks.

        Defaults to the token validator's claim options. The options are
        copied before the two entries are dropped so a dict supplied by the
        caller is left untouched.
        """
        if options is None:
            options = self._token_validator._claims_options
        options = dict(options)
        options.pop('azp', None)
        options.pop('appid', None)
        return super()._get_user_from_token(token, options=options)

    def _get_authorization_url(self, request, session_state):
        """Build the AAD authorization URL for the configured scopes."""
        return self.msal_application.get_authorization_request_url(
            self._scopes,
            state=session_state,
            claims_challenge='{"id_token": {"roles": {"essential": true} } }',
            redirect_uri=self._build_redirect_uri(request),
            prompt=self._prompt,
            domain_hint=self._domain_hint)

    def get_access_token(self, user, scopes=None, app_scopes=True):
        """Get the access token for the user.

        Args:
            user: Object whose ``username`` is used to look up the MSAL account.
            scopes: Scopes to request; defaults to the authenticator's scopes.
            app_scopes: When true, explicitly passed scopes are prefixed with
                the client id via :meth:`as_app_scopes`.

        Returns:
            dict with ``token_type``, ``expires_in`` and ``access_token``.

        Raises:
            ValueError: when MSAL returns no result.
            ConfigurationError: when MSAL returns an error result.
        """
        if scopes is None:
            scopes = self._scopes
        elif app_scopes:
            scopes = self.as_app_scopes(scopes)
        account = None
        if user.username:
            accounts = self.msal_application.get_accounts(user.username)
            if accounts:
                account = accounts[0]
        self.logger.info(account)
        # This needs you to register the scopes in the app registration
        result = self.msal_application.acquire_token_silent_with_error(scopes=scopes, account=account)
        if result is None:
            self.logger.info('Acquired Token: None')
            raise ValueError('Token not found')
        # Only the keys are logged: the values include the token itself.
        self.logger.info(f'Acquired Token: keys {sorted(result)}')
        if 'error' in result:
            error_args = [result['error']]
            if 'error_description' in result:
                error_args.append(result['error_description'])
            raise ConfigurationError(*error_args)
        return {'token_type': result['token_type'],
                'expires_in': result['expires_in'],
                'access_token': result['access_token']}

    def as_app_scopes(self, scopes):
        """Add the application client id to the scopes so that the tokens are valid for this app.

        Only the first scope is prefixed, and only when the client id does
        not already appear in it. A new list is returned; the argument is not
        modified. An empty input yields an empty list.
        """
        app_scopes = list(scopes)
        if app_scopes and self.client_id not in app_scopes[0]:
            app_scopes[0] = f'{self.client_id}/{app_scopes[0]}'
        return app_scopes
class AADTokenValidator(TokenValidator):
    """Validator for AAD token based authentication.

    Fetches the tenant's signing keys from the Microsoft discovery endpoint,
    decodes bearer tokens against them and validates issuer, audience,
    ``azp``/``appid`` and (optionally) role claims.
    """

    # Seconds to wait for the key-discovery endpoint before giving up, so a
    # stalled connection cannot block request handling indefinitely.
    _jwks_request_timeout = 10

    def __init__(self,
                 client_id: str,
                 tenant_id: str,
                 api_audience: Optional[str] = None,
                 scheme_name: Optional[str] = None,
                 scopes: Optional[dict] = None,
                 auto_error: bool = False,
                 enabled: bool = True,
                 use_pkce: bool = True,
                 strict: bool = True,
                 client_app_ids: Optional[List[str]] = None,
                 user_klass: type = User,
                 flow_type: OAuthFlowType = OAuthFlowType.authorizationCode,  # type: ignore
                 roles: Optional[List[str]] = None,
                 jwks_cache_ttl: int = 0):
        """Initialise validator for AAD token based authentication.

        Args:
            client_id: Client id; accepted as audience and as ``azp``/``appid``.
            tenant_id: Tenant used to build the authorize, token and key URLs
                and the accepted issuers.

        Keyword Args:
            api_audience: Additional accepted audience value.
            scheme_name, scopes, auto_error, enabled, use_pkce, user_klass,
                flow_type: Passed through to the base validator.
            strict: When true, tokens with neither ``appid`` nor ``azp`` are
                rejected.
            client_app_ids: Extra client ids accepted as audience and as
                ``azp``/``appid``.
            roles: Roles added to the claim options.
            jwks_cache_ttl: Time-to-live, in seconds, of the cached key set.
        """
        authorization_url = f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/authorize"
        token_url = f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"
        self.key_url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys"
        self.tenant_id = tenant_id
        self._roles = roles
        super().__init__(client_id=client_id, authorizationUrl=authorization_url, tokenUrl=token_url,
                         api_audience=api_audience, scheme_name=scheme_name, scopes=scopes,
                         auto_error=auto_error, enabled=enabled, use_pkce=use_pkce,
                         user_klass=user_klass, flow_type=flow_type)
        self.strict = strict
        if client_app_ids is None:
            client_app_ids = []
        self.client_app_ids = client_app_ids
        # Single-slot cache; the key set is stored under the key ``None``.
        self._jwks_cache: TTLCache = TTLCache(maxsize=1, ttl=jwks_cache_ttl)

    def _request_ms_jwks(self) -> Dict[str, Any]:
        """Return the Microsoft key set, from cache when still valid.

        Raises:
            requests.RequestException: when the download fails or the
                endpoint answers with an HTTP error status.
        """
        # Index directly instead of testing len() first: the entry can expire
        # between the check and the read.
        try:
            return self._jwks_cache[None]
        except KeyError:
            pass
        response = requests.get(self.key_url, timeout=self._jwks_request_timeout)
        # Never cache (or try to parse keys from) an error response.
        response.raise_for_status()
        jwks = response.json()
        self._jwks_cache[None] = jwks
        return jwks

    def _get_ms_jwk(self, token: str) -> JsonWebKey:
        """Return the signing key matching the token's ``kid`` header.

        Raises:
            AuthenticationError: when the header cannot be decoded, the key
                set cannot be retrieved, or no matching key is found.
        """
        try:
            self.logger.info(f'Getting signing keys from {self.key_url}')
            token_header = token.split(".")[0].encode()
            # Parse the header first so a malformed token costs no network call.
            unverified_header = extract_header(token_header, jwt_errors.DecodeError)
            key_id = unverified_header['kid']
            jwks = self._request_ms_jwks()
            # Authlib 1.0.0 changed behaviour of loads in 1.0.0
            if version.parse(authlib_version) >= version.parse('1.0.0'):
                # This can now raise a ValueError: Invalid JSON Web Key Set if
                # the key is not found
                return jwk.loads(jwks, key_id)
            for key in jwks["keys"]:
                if key["kid"] == key_id:
                    self.logger.info(f'Identified key {key["kid"]}')
                    return jwk.loads(key)
            self.logger.error('No signing key matches the token key id')
        except jwt_errors.DecodeError:
            self.logger.exception('Error parsing signing keys')
        except (KeyError, ValueError):
            # KeyError: the header has no 'kid' or the key set has no 'keys'.
            self.logger.exception('Error finding key')
        except requests.RequestException:
            self.logger.exception('Error retrieving signing keys')
        raise AuthenticationError("Unable to parse signing keys")

    def _decode_token(self, token):
        """Decode the token with the matching Microsoft signing key.

        Returns:
            The claims object produced by ``jwt.decode``.

        Raises:
            AuthenticationError: when no key is found or decoding fails.
        """
        jwk_ = self._get_ms_jwk(token)
        self.logger.debug(f'Key is {jwk_}')
        pem = serialization.Encoding.PEM
        pkcs1 = serialization.PublicFormat.PKCS1
        try:
            # The key object's shape differs between Authlib versions, so
            # probe for whichever accessor this one offers.
            if hasattr(jwk_, 'as_pem'):
                # Authlib 1.0.0
                key = jwk_.as_pem()
            elif hasattr(jwk_, 'public_key'):
                key = jwk_.public_key.public_bytes(pem, pkcs1)
            elif hasattr(jwk_, 'public_bytes'):
                key = jwk_.public_bytes(pem, pkcs1)
            else:
                key = jwk_.raw_key.public_bytes(pem, pkcs1)
            self.logger.debug(f'Processed Key: {key}')
            claims = jwt.decode(token, key)
        except Exception as err:
            self.logger.exception('Unable to parse authentication token')
            raise AuthenticationError("Unable to parse authentication token") from err
        return claims

    def _validate_claims(self, claims, options=None):
        """Validate the claims, reconciling the v1 ``appid`` / v2 ``azp`` checks.

        When both checks are configured, the one whose claim is absent from
        the token is dropped. The options are copied first so neither the
        caller's dict nor a shared one is modified.

        Raises:
            AuthenticationError: in strict mode, when the token carries
                neither ``appid`` nor ``azp``.
        """
        if options is None:
            options = self._claims_options
        options = dict(options)
        # We need to do some 1.0/2.0 handling because it doesn't seem to work properly
        # TODO: validate whether we want this claim here?
        if 'appid' in options and 'azp' in options:
            if 'appid' not in claims:
                options.pop('appid')
            elif 'azp' not in claims:
                options.pop('azp')
        if not ('appid' in claims or 'azp' in claims):
            if self.strict:
                self.logger.error('No appid/azp claims found in token')
                raise AuthenticationError('No appid/azp claims found in token')
            # NOTE(review): 'azp' is still present (and essential) in options
            # here, so the base validation presumably still rejects the token
            # even though this branch says it is ignored - confirm intent
            # before relaxing the check.
            self.logger.warning('No appid/azp claims found in token - we are ignoring for now')
        return super()._validate_claims(claims, options)

    @property
    def _claims_options(self):
        """Claim options: the base options plus roles, issuer, audience and client ids."""
        options = super()._claims_options
        if self._roles:
            roles = self._roles
            if not isinstance(roles, str):
                # This is a space delimited string of roles in the claims
                roles = ' '.join(self._roles)
            options['roles'] = roles
        accepted_clients = [self.client_id] + self.client_app_ids
        options["iss"] = {"essential": True,
                          "values": [f"https://sts.windows.net/{self.tenant_id}/",
                                     f"https://login.microsoftonline.com/{self.tenant_id}/v2.0"]}
        options["aud"] = {"essential": True, "values": [self.api_audience] + accepted_clients}
        # Separate lists so the two entries never share a mutable value.
        options["azp"] = {"essential": True, "values": list(accepted_clients)}
        options["appid"] = {"essential": True, "values": list(accepted_clients)}
        self.logger.debug(f'Claims options {options}')
        return options

    def _get_user_from_claims(self, claims):
        """Build a user object from the decoded claims.

        A token without ``name`` but with ``appid`` is treated as an
        application/service principal and named after its ``appid``.
        Otherwise ``preferred_username`` (falling back to ``unique_name``)
        supplies both the email and the username.
        """
        self.logger.debug(f'Processing claims: {claims}')
        username_key = 'preferred_username'
        if username_key not in claims:
            username_key = 'unique_name'
        shared = {'groups': claims.get('groups', None),
                  'roles': claims.get('roles', None),
                  'scopes': claims.get('scp', None)}
        if 'name' not in claims and 'appid' in claims:
            # This is an application/service principal
            return self._user_klass(name=claims['appid'], email='', username=claims['appid'], **shared)
        return self._user_klass(name=claims['name'], email=claims[username_key],
                                username=claims[username_key], **shared)
class AADProvider(Provider):
    """fastapi auth backend for Azure Active Directory."""

    name: str = 'aad'

    def __init__(
            self,
            session_validator: SessionValidator,
            client_id: str,
            tenant_id: str,
            prompt: Optional[str] = None,
            client_secret: Optional[str] = None,
            scopes: Optional[List[str]] = None,
            enabled: bool = True,
            client_app_ids: Optional[List[str]] = None,
            strict_token: bool = True,
            api_audience: Optional[str] = None,
            redirect_uri: Optional[str] = None,
            domain_hint: Optional[str] = None,
            user_klass: type = User,
            oauth_base_route: str = '/oauth',
            token_type: Union[str, TokenType] = TokenType.access,
            token_scopes: Optional[Dict[str, str]] = None,
            flow_type: OAuthFlowType = OAuthFlowType.authorizationCode,  # type: ignore
            roles: Optional[List[str]] = None,
            jwks_cache_ttl: int = 0):
        """Initialise the auth backend.

        Builds an ``AADTokenValidator`` and an ``AADSessionAuthenticator``
        from the arguments and registers them with the base provider.

        Args:
            * session_validator: Session validator used by the authenticator
            * client_id: Client ID from Azure App Registration
            * tenant_id: Tenant ID to connect to for Azure App Registration

        Keyword Args:
            * prompt: Prompt options for Azure AD
            * client_secret: Client secret value
            * scopes: Additional scopes requested
            * enabled: Boolean flag to enable this backend
            * client_app_ids: List of client apps to accept tokens from
            * strict_token: Strictly evaluate token
            * api_audience: Api Audience declared in Azure AD App registration
            * redirect_uri: Full URI for post authentication callbacks
            * domain_hint: Hint for the domain
            * user_klass: Class to use as a user.
            * oauth_base_route: Base route under which the OAuth endpoints
              (including the redirect path) are built
            * token_type: Token (access or id) the authenticator takes from
              the MSAL result
            * token_scopes: Scopes passed to the token validator (defaults to
              an empty dict)
            * flow_type: OAuth flow type passed to the token validator
            * roles: Roles passed to the token validator
            * jwks_cache_ttl: Time-to-live of the token validator's signing
              key cache
        """
        redirect_path = self._build_oauth_url(oauth_base_route, 'redirect')
        if token_scopes is None:
            token_scopes = {}
        token_validator = AADTokenValidator(
            client_id=client_id, tenant_id=tenant_id, api_audience=api_audience,
            client_app_ids=client_app_ids, scopes=token_scopes, enabled=enabled,
            strict=strict_token, user_klass=user_klass, flow_type=flow_type,
            roles=roles, jwks_cache_ttl=jwks_cache_ttl)
        session_authenticator = AADSessionAuthenticator(
            session_validator=session_validator, token_validator=token_validator,
            client_id=client_id, tenant_id=tenant_id, redirect_path=redirect_path,
            prompt=prompt, client_secret=client_secret, scopes=scopes,
            redirect_uri=redirect_uri, domain_hint=domain_hint, token_type=token_type)
        super().__init__(validators=[token_validator], authenticator=session_authenticator,
                         enabled=enabled, oauth_base_route=oauth_base_route)

    @classmethod
    def from_config(cls, session_validator, config, provider_config, user_klass: Optional[type] = None):
        """Load the auth backend from a config.

        Args:
            session_validator (SessionValidator): the session validator to use
            config: Loaded configuration
            provider_config: Provider-specific settings (client id, tenant id,
                secret, scopes, ...) read when constructing the backend

        Keyword Args:
            user_klass: The class to use as a user (defaults to
                ``config.user_klass``)
        """
        client_secret = provider_config.client_secret
        if client_secret is not None:
            client_secret = client_secret.get_secret_value()  # type: ignore
        if user_klass is None:
            user_klass = config.user_klass
        routing = config.routing
        obj = cls(session_validator=session_validator,
                  client_id=provider_config.client_id.get_secret_value(),
                  tenant_id=provider_config.tenant_id.get_secret_value(),
                  client_secret=client_secret, enabled=config.enabled,  # type: ignore
                  scopes=provider_config.scopes, client_app_ids=provider_config.client_app_ids,
                  strict_token=provider_config.strict, api_audience=provider_config.api_audience,
                  prompt=provider_config.prompt, domain_hint=provider_config.domain_hint,
                  redirect_uri=provider_config.redirect_uri, user_klass=user_klass,
                  oauth_base_route=routing.oauth_base_route, token_type=provider_config.token_type,
                  token_scopes=provider_config.token_scopes, flow_type=provider_config.flow_type,
                  roles=provider_config.roles,
                  jwks_cache_ttl=provider_config.jwks_cache_ttl)
        # We need to override the login and redirect etc until it is deprecated
        if (hasattr(routing, 'login_path') and routing.login_path
                and not is_deprecated(routing.__fields__['login_path'])):
            obj._login_url = routing.login_path
        if (hasattr(routing, 'login_redirect_path') and routing.login_redirect_path
                and not is_deprecated(routing.__fields__['login_redirect_path'])):
            obj._redirect_url = routing.login_redirect_path
        # Keep the authenticator's redirect path in step with the provider's.
        obj.authenticator._redirect_path = obj.redirect_url  # type: ignore
        return obj
@expand_doc
class AADConfig(ProviderConfig):
    """Configuration for the AAD application.

    Includes expected claims, application registration, etc.

    Can also provide additional client application ids to accept.

    A list of roles can be provided to accept (requires configuring the
    roles in the AAD application registration manifest)
    """

    # Application registration identity.
    client_id: SecretStr = Field(..., description="Application Registration Client ID", env='AAD_CLIENT_ID')
    tenant_id: SecretStr = Field(..., description="Application Registration Tenant ID", env='AAD_TENANT_ID')
    client_secret: Optional[SecretStr] = Field(None, description="Application Registration Client Secret (if required)", env='AAD_CLIENT_SECRET')

    # Token request and validation options.
    scopes: Optional[List[str]] = Field(None, description="Additional scopes requested - if the scope is not configured to the application this will throw an error when validating the token")
    client_app_ids: Optional[List[str]] = Field(None, description="Additional Client App IDs to accept tokens from (when running as a backend service)",
                                                env='AAD_CLIENT_APP_IDS')
    strict: bool = Field(True, description="Check that all claims are provided", env='AAD_STRICT_CLAIM_CHECK')
    api_audience: Optional[str] = Field(None, description="Corresponds to the Application ID URI - used for token validation, defaults to api://{client_id}",
                                        env='AAD_API_AUDIENCE')

    # Interactive login options.
    redirect_uri: Optional[HttpUrl] = Field(None, description="The redirect URI to use - overwrites the default path handling etc",
                                            env='AAD_REDIRECT_URI')
    prompt: Optional[str] = Field(None, description="AAD prompt to request", env='AAD_PROMPT')
    domain_hint: Optional[str] = Field(None, description="AAD domain hint", env='AAD_DOMAIN_HINT')
    roles: Optional[List[str]] = Field(None, description="AAD roles required in claims", env='AAD_ROLES')
    token_type: TokenType = Field(TokenType.access, description='The AAD token type to use to validate (we should use the access token if it is configured, unless we are acting as a pure UI component')
    jwks_cache_ttl: int = Field(0, description='Cache Time to Live value for MS JWKs token', env='AAD_JWKS_CACHE_TTL')

    # Provider class associated with this configuration.
    _provider_klass: type = PrivateAttr(AADProvider)

    class Config:  # noqa D106
        env_file = '.env'

    # Run the environment-parsing helpers over the fields that have an env name.
    _validate_strict = validator('strict', allow_reuse=True)(bool_from_env)
    _validate_client_app_ids = validator('client_app_ids', allow_reuse=True)(list_from_env)
    _validate_roles = validator('roles', allow_reuse=True)(list_from_env)