1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

//! Requests for obtaining [Claims] about an end-user.
//!
//! [Claims]: https://openid.net/specs/openid-connect-core-1_0.html#Claims

use std::collections::HashMap;

use bytes::Bytes;
use headers::{Authorization, ContentType, HeaderMapExt, HeaderValue};
use http::header::ACCEPT;
use mas_http::CatchHttpCodesLayer;
use mas_jose::claims;
use mime::Mime;
use serde_json::Value;
use tower::{Layer, Service, ServiceExt};
use url::Url;

use super::jose::JwtVerificationData;
use crate::{
    error::{IdTokenError, UserInfoError},
    http_service::HttpService,
    requests::jose::verify_signed_jwt,
    types::IdToken,
    utils::{http_all_error_status_codes, http_error_mapper},
};

/// Obtain information about an authenticated end-user.
///
/// Returns a map of claims with their value, that should be extracted with
/// one of the [`Claim`] methods.
///
/// The `sub` claim of the response is checked against the `sub` claim of
/// `auth_id_token`, and is kept in the returned map.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `userinfo_endpoint` - The URL of the issuer's User Info endpoint.
///
/// * `access_token` - The access token of the end-user.
///
/// * `jwt_verification_data` - The data required to verify the response if a
///   signed response was requested during client registration.
///
///   The signing algorithm corresponds to the `userinfo_signed_response_alg`
///   field in the client metadata.
///
///   When this is `Some`, the response is expected to be a signed JWT
///   (`application/jwt`); otherwise it is expected to be plain JSON
///   (`application/json`).
///
/// * `auth_id_token` - The ID token that was returned from the latest
///   authorization request.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or the
/// validation of the signed response fails. In particular, an error is
/// returned if:
///
/// * the response has no `Content-Type` header, an unparsable one, or one
///   whose essence does not match the expected content type;
/// * the response body is not valid UTF-8, or cannot be decoded as JSON or
///   verified as a signed JWT;
/// * the `sub` claim is missing from the response or from `auth_id_token`,
///   or the two values differ.
///
/// [`Claim`]: mas_jose::claims::Claim
#[tracing::instrument(skip_all, fields(userinfo_endpoint = %userinfo_endpoint))]
pub async fn fetch_userinfo(
    http_service: &HttpService,
    userinfo_endpoint: &Url,
    access_token: &str,
    jwt_verification_data: Option<JwtVerificationData<'_>>,
    auth_id_token: &IdToken<'_>,
) -> Result<HashMap<String, Value>, UserInfoError> {
    tracing::debug!("Obtaining user info…");

    let mut userinfo_request = http::Request::get(userinfo_endpoint.as_str());

    // A signed response is served as a JWT, an unsigned one as plain JSON.
    let expected_content_type = if jwt_verification_data.is_some() {
        "application/jwt"
    } else {
        mime::APPLICATION_JSON.as_ref()
    };

    // `headers_mut` only yields `None` when the builder is already in an error
    // state; that error is then surfaced by `body` below.
    if let Some(headers) = userinfo_request.headers_mut() {
        headers.typed_insert(Authorization::bearer(access_token)?);
        headers.insert(ACCEPT, HeaderValue::from_static(expected_content_type));
    }

    let userinfo_request = userinfo_request.body(Bytes::new())?;

    // Presumably maps HTTP error status codes to errors via
    // `http_error_mapper` — see `CatchHttpCodesLayer`.
    let service = CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper)
        .layer(http_service.clone());

    let userinfo_response = service
        .ready_oneshot()
        .await?
        .call(userinfo_request)
        .await?;

    let content_type: Mime = userinfo_response
        .headers()
        .typed_try_get::<ContentType>()
        .map_err(|_| UserInfoError::InvalidResponseContentTypeValue)?
        .ok_or(UserInfoError::MissingResponseContentType)?
        .into();

    // Compare only the essence (`type/subtype`), so parameters such as
    // `charset` do not cause a mismatch.
    if content_type.essence_str() != expected_content_type {
        return Err(UserInfoError::UnexpectedResponseContentType {
            expected: expected_content_type.to_owned(),
            got: content_type.to_string(),
        });
    }

    let response_body = std::str::from_utf8(userinfo_response.body())?;

    let mut claims = if let Some(verification_data) = jwt_verification_data {
        verify_signed_jwt(response_body, verification_data)
            .map_err(IdTokenError::from)?
            .into_parts()
            .1
    } else {
        serde_json::from_str(response_body)?
    };

    // Only the `sub` claim of the ID token is needed; the payload is cloned
    // because `extract_required` takes the map mutably.
    let mut auth_claims = auth_id_token.payload().clone();

    // Subject identifier must always be the same.
    let sub = claims::SUB
        .extract_required(&mut claims)
        .map_err(IdTokenError::from)?;
    let auth_sub = claims::SUB
        .extract_required(&mut auth_claims)
        .map_err(IdTokenError::from)?;
    if sub != auth_sub {
        return Err(IdTokenError::WrongSubjectIdentifier.into());
    }

    // `extract_required` took `sub` out of the userinfo map; put the validated
    // value back so callers still receive the full set of claims.
    claims::SUB
        .insert(&mut claims, sub)
        .map_err(IdTokenError::from)?;

    Ok(claims)
}