use std::collections::HashMap;
use async_trait::async_trait;
use axum::{
extract::{
rejection::{FailedToDeserializeForm, FormRejection},
Form, FromRequest, FromRequestParts,
},
response::IntoResponse,
BoxError, Json,
};
use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
use headers::{authorization::Basic, Authorization};
use http::{Request, StatusCode};
use mas_data_model::{Client, JwksOrJwksUri};
use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
use mas_keystore::Encrypter;
use mas_storage::{oauth2::OAuth2ClientRepository, RepositoryAccess};
use oauth2_types::errors::{ClientError, ClientErrorCode};
use serde::{de::DeserializeOwned, Deserialize};
use serde_json::Value;
use thiserror::Error;
use tower::{Service, ServiceExt};
use crate::http_client_factory::HttpClientFactory;
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
#[derive(Deserialize)]
struct AuthorizedForm<F = ()> {
client_id: Option<String>,
client_secret: Option<String>,
client_assertion_type: Option<String>,
client_assertion: Option<String>,
#[serde(flatten)]
inner: F,
}
#[derive(Debug, PartialEq, Eq)]
pub enum Credentials {
None {
client_id: String,
},
ClientSecretBasic {
client_id: String,
client_secret: String,
},
ClientSecretPost {
client_id: String,
client_secret: String,
},
ClientAssertionJwtBearer {
client_id: String,
jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
},
}
impl Credentials {
#[must_use]
pub fn client_id(&self) -> &str {
match self {
Credentials::None { client_id }
| Credentials::ClientSecretBasic { client_id, .. }
| Credentials::ClientSecretPost { client_id, .. }
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
}
}
pub async fn fetch<E>(
&self,
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Option<Client>, E> {
let client_id = match self {
Credentials::None { client_id }
| Credentials::ClientSecretBasic { client_id, .. }
| Credentials::ClientSecretPost { client_id, .. }
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
};
repo.oauth2_client().find_by_client_id(client_id).await
}
#[tracing::instrument(skip_all, err)]
pub async fn verify(
&self,
http_client_factory: &HttpClientFactory,
encrypter: &Encrypter,
method: &OAuthClientAuthenticationMethod,
client: &Client,
) -> Result<(), CredentialsVerificationError> {
match (self, method) {
(Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
(
Credentials::ClientSecretPost { client_secret, .. },
OAuthClientAuthenticationMethod::ClientSecretPost,
)
| (
Credentials::ClientSecretBasic { client_secret, .. },
OAuthClientAuthenticationMethod::ClientSecretBasic,
) => {
let encrypted_client_secret = client
.encrypted_client_secret
.as_ref()
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
let decrypted_client_secret = encrypter
.decrypt_string(encrypted_client_secret)
.map_err(|_e| CredentialsVerificationError::DecryptionError)?;
if client_secret.as_bytes() != decrypted_client_secret {
return Err(CredentialsVerificationError::ClientSecretMismatch);
}
}
(
Credentials::ClientAssertionJwtBearer { jwt, .. },
OAuthClientAuthenticationMethod::PrivateKeyJwt,
) => {
let jwks = client
.jwks
.as_ref()
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
let jwks = fetch_jwks(http_client_factory, jwks)
.await
.map_err(|_| CredentialsVerificationError::JwksFetchFailed)?;
jwt.verify_with_jwks(&jwks)
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
}
(
Credentials::ClientAssertionJwtBearer { jwt, .. },
OAuthClientAuthenticationMethod::ClientSecretJwt,
) => {
let encrypted_client_secret = client
.encrypted_client_secret
.as_ref()
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
let decrypted_client_secret = encrypter
.decrypt_string(encrypted_client_secret)
.map_err(|_e| CredentialsVerificationError::DecryptionError)?;
jwt.verify_with_shared_secret(decrypted_client_secret)
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
}
(_, _) => {
return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
}
};
Ok(())
}
}
async fn fetch_jwks(
http_client_factory: &HttpClientFactory,
jwks: &JwksOrJwksUri,
) -> Result<PublicJsonWebKeySet, BoxError> {
let uri = match jwks {
JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
JwksOrJwksUri::JwksUri(u) => u,
};
let request = http::Request::builder()
.uri(uri.as_str())
.body(mas_http::EmptyBody::new())
.unwrap();
let mut client = http_client_factory
.client("client.fetch_jwks")
.response_body_to_bytes()
.json_response::<PublicJsonWebKeySet>();
let response = client.ready().await?.call(request).await?;
Ok(response.into_body())
}
#[derive(Debug, Error)]
pub enum CredentialsVerificationError {
#[error("failed to decrypt client credentials")]
DecryptionError,
#[error("invalid client configuration")]
InvalidClientConfig,
#[error("client secret did not match")]
ClientSecretMismatch,
#[error("authentication method mismatch")]
AuthenticationMethodMismatch,
#[error("invalid assertion signature")]
InvalidAssertionSignature,
#[error("failed to fetch jwks")]
JwksFetchFailed,
}
#[derive(Debug, PartialEq, Eq)]
pub struct ClientAuthorization<F = ()> {
pub credentials: Credentials,
pub form: Option<F>,
}
impl<F> ClientAuthorization<F> {
#[must_use]
pub fn client_id(&self) -> &str {
self.credentials.client_id()
}
}
#[derive(Debug)]
pub enum ClientAuthorizationError {
InvalidHeader,
BadForm(FailedToDeserializeForm),
ClientIdMismatch { credential: String, form: String },
UnsupportedClientAssertion { client_assertion_type: String },
MissingCredentials,
InvalidRequest,
InvalidAssertion,
Internal(Box<dyn std::error::Error>),
}
impl IntoResponse for ClientAuthorizationError {
fn into_response(self) -> axum::response::Response {
match self {
ClientAuthorizationError::InvalidHeader => (
StatusCode::BAD_REQUEST,
Json(ClientError::new(
ClientErrorCode::InvalidRequest,
"Invalid Authorization header",
)),
),
ClientAuthorizationError::BadForm(err) => (
StatusCode::BAD_REQUEST,
Json(
ClientError::from(ClientErrorCode::InvalidRequest)
.with_description(format!("{err}")),
),
),
ClientAuthorizationError::ClientIdMismatch { form, credential } => {
let description = format!(
"client_id in form ({form:?}) does not match credential ({credential:?})"
);
(
StatusCode::BAD_REQUEST,
Json(
ClientError::from(ClientErrorCode::InvalidGrant)
.with_description(description),
),
)
}
ClientAuthorizationError::UnsupportedClientAssertion {
client_assertion_type,
} => (
StatusCode::BAD_REQUEST,
Json(
ClientError::from(ClientErrorCode::InvalidRequest).with_description(format!(
"Unsupported client_assertion_type: {client_assertion_type}",
)),
),
),
ClientAuthorizationError::MissingCredentials => (
StatusCode::BAD_REQUEST,
Json(ClientError::new(
ClientErrorCode::InvalidRequest,
"No credentials were presented",
)),
),
ClientAuthorizationError::InvalidRequest => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
),
ClientAuthorizationError::InvalidAssertion => (
StatusCode::BAD_REQUEST,
Json(ClientError::new(
ClientErrorCode::InvalidRequest,
"Invalid client_assertion",
)),
),
ClientAuthorizationError::Internal(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(
ClientError::from(ClientErrorCode::ServerError)
.with_description(format!("{e}")),
),
),
}
.into_response()
}
}
#[async_trait]
impl<S, F> FromRequest<S> for ClientAuthorization<F>
where
F: DeserializeOwned,
S: Send + Sync,
{
type Rejection = ClientAuthorizationError;
#[allow(clippy::too_many_lines)]
async fn from_request(
req: Request<axum::body::Body>,
state: &S,
) -> Result<Self, Self::Rejection> {
let (mut parts, body) = req.into_parts();
let header =
TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;
let credentials_from_header = match header {
Ok(header) => Some((header.username().to_owned(), header.password().to_owned())),
Err(err) => match err.reason() {
TypedHeaderRejectionReason::Missing => None,
_ => return Err(ClientAuthorizationError::InvalidHeader),
},
};
let req = Request::from_parts(parts, body);
let (
client_id_from_form,
client_secret_from_form,
client_assertion_type,
client_assertion,
form,
) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
Ok(Form(form)) => (
form.client_id,
form.client_secret,
form.client_assertion_type,
form.client_assertion,
Some(form.inner),
),
Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
Err(FormRejection::FailedToDeserializeForm(err)) => {
return Err(ClientAuthorizationError::BadForm(err))
}
Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
};
let credentials = match (
credentials_from_header,
client_id_from_form,
client_secret_from_form,
client_assertion_type,
client_assertion,
) {
(Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
if let Some(client_id_from_form) = client_id_from_form {
if client_id != client_id_from_form {
return Err(ClientAuthorizationError::ClientIdMismatch {
credential: client_id,
form: client_id_from_form,
});
}
}
Credentials::ClientSecretBasic {
client_id,
client_secret,
}
}
(None, Some(client_id), Some(client_secret), None, None) => {
Credentials::ClientSecretPost {
client_id,
client_secret,
}
}
(None, Some(client_id), None, None, None) => {
Credentials::None { client_id }
}
(
None,
client_id_from_form,
None,
Some(client_assertion_type),
Some(client_assertion),
) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
.map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
client_id.clone()
} else {
return Err(ClientAuthorizationError::InvalidAssertion);
};
if let Some(client_id_from_form) = client_id_from_form {
if client_id != client_id_from_form {
return Err(ClientAuthorizationError::ClientIdMismatch {
credential: client_id,
form: client_id_from_form,
});
}
}
Credentials::ClientAssertionJwtBearer {
client_id,
jwt: Box::new(jwt),
}
}
(None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
return Err(ClientAuthorizationError::UnsupportedClientAssertion {
client_assertion_type,
});
}
(None, None, None, None, None) => {
return Err(ClientAuthorizationError::MissingCredentials);
}
_ => {
return Err(ClientAuthorizationError::InvalidRequest);
}
};
Ok(ClientAuthorization { credentials, form })
}
}
#[cfg(test)]
mod tests {
use axum::body::Body;
use http::{Method, Request};
use super::*;
#[tokio::test]
async fn none_test() {
let req = Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.body(Body::new("client_id=client-id&foo=bar".to_owned()))
.unwrap();
assert_eq!(
ClientAuthorization::<serde_json::Value>::from_request(req, &())
.await
.unwrap(),
ClientAuthorization {
credentials: Credentials::None {
client_id: "client-id".to_owned(),
},
form: Some(serde_json::json!({"foo": "bar"})),
}
);
}
#[tokio::test]
async fn client_secret_basic_test() {
let req = Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.header(
http::header::AUTHORIZATION,
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
)
.body(Body::new("foo=bar".to_owned()))
.unwrap();
assert_eq!(
ClientAuthorization::<serde_json::Value>::from_request(req, &())
.await
.unwrap(),
ClientAuthorization {
credentials: Credentials::ClientSecretBasic {
client_id: "client-id".to_owned(),
client_secret: "client-secret".to_owned(),
},
form: Some(serde_json::json!({"foo": "bar"})),
}
);
let req = Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.header(
http::header::AUTHORIZATION,
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
)
.body(Body::new("client_id=client-id&foo=bar".to_owned()))
.unwrap();
assert_eq!(
ClientAuthorization::<serde_json::Value>::from_request(req, &())
.await
.unwrap(),
ClientAuthorization {
credentials: Credentials::ClientSecretBasic {
client_id: "client-id".to_owned(),
client_secret: "client-secret".to_owned(),
},
form: Some(serde_json::json!({"foo": "bar"})),
}
);
let req = Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.header(
http::header::AUTHORIZATION,
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
)
.body(Body::new("client_id=mismatch-id&foo=bar".to_owned()))
.unwrap();
assert!(matches!(
ClientAuthorization::<serde_json::Value>::from_request(req, &()).await,
Err(ClientAuthorizationError::ClientIdMismatch { .. }),
));
let req = Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.header(http::header::AUTHORIZATION, "Basic invalid")
.body(Body::new("foo=bar".to_owned()))
.unwrap();
assert!(matches!(
ClientAuthorization::<serde_json::Value>::from_request(req, &()).await,
Err(ClientAuthorizationError::InvalidHeader),
));
}
#[tokio::test]
async fn client_secret_post_test() {
let req = Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.body(Body::new(
"client_id=client-id&client_secret=client-secret&foo=bar".to_owned(),
))
.unwrap();
assert_eq!(
ClientAuthorization::<serde_json::Value>::from_request(req, &())
.await
.unwrap(),
ClientAuthorization {
credentials: Credentials::ClientSecretPost {
client_id: "client-id".to_owned(),
client_secret: "client-secret".to_owned(),
},
form: Some(serde_json::json!({"foo": "bar"})),
}
);
}
#[tokio::test]
async fn client_assertion_test() {
let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
let body = Body::new(format!(
"client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={jwt}&foo=bar",
));
let req = Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.body(body)
.unwrap();
let authz = ClientAuthorization::<serde_json::Value>::from_request(req, &())
.await
.unwrap();
assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));
let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials else {
panic!("expected a JWT client_assertion");
};
assert_eq!(client_id, "client-id");
jwt.verify_with_shared_secret(b"client-secret".to_vec())
.unwrap();
}
}