mas_data_model/
tokens.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use base64ct::{Base64UrlUnpadded, Encoding};
8use chrono::{DateTime, Utc};
9use crc::{CRC_32_ISO_HDLC, Crc};
10use mas_iana::oauth::OAuthTokenTypeHint;
11use rand::{Rng, RngCore, distributions::Alphanumeric};
12use thiserror::Error;
13use ulid::Ulid;
14
15use crate::InvalidTransitionError;
16
17#[derive(Debug, Clone, Default, PartialEq, Eq)]
18pub enum AccessTokenState {
19    #[default]
20    Valid,
21    Revoked {
22        revoked_at: DateTime<Utc>,
23    },
24}
25
26impl AccessTokenState {
27    fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
28        match self {
29            Self::Valid => Ok(Self::Revoked { revoked_at }),
30            Self::Revoked { .. } => Err(InvalidTransitionError),
31        }
32    }
33
34    /// Returns `true` if the refresh token state is [`Valid`].
35    ///
36    /// [`Valid`]: AccessTokenState::Valid
37    #[must_use]
38    pub fn is_valid(&self) -> bool {
39        matches!(self, Self::Valid)
40    }
41
42    /// Returns `true` if the refresh token state is [`Revoked`].
43    ///
44    /// [`Revoked`]: AccessTokenState::Revoked
45    #[must_use]
46    pub fn is_revoked(&self) -> bool {
47        matches!(self, Self::Revoked { .. })
48    }
49}
50
51#[derive(Debug, Clone, PartialEq, Eq)]
52pub struct AccessToken {
53    pub id: Ulid,
54    pub state: AccessTokenState,
55    pub session_id: Ulid,
56    pub access_token: String,
57    pub created_at: DateTime<Utc>,
58    pub expires_at: Option<DateTime<Utc>>,
59    pub first_used_at: Option<DateTime<Utc>>,
60}
61
62impl AccessToken {
63    #[must_use]
64    pub fn jti(&self) -> String {
65        self.id.to_string()
66    }
67
68    /// Whether the access token is valid, i.e. not revoked and not expired
69    ///
70    /// # Parameters
71    ///
72    /// * `now` - The current time
73    #[must_use]
74    pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
75        self.state.is_valid() && !self.is_expired(now)
76    }
77
78    /// Whether the access token is expired
79    ///
80    /// Always returns `false` if the access token does not have an expiry time.
81    ///
82    /// # Parameters
83    ///
84    /// * `now` - The current time
85    #[must_use]
86    pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
87        match self.expires_at {
88            Some(expires_at) => expires_at < now,
89            None => false,
90        }
91    }
92
93    /// Whether the access token was used at least once
94    #[must_use]
95    pub fn is_used(&self) -> bool {
96        self.first_used_at.is_some()
97    }
98
99    /// Mark the access token as revoked
100    ///
101    /// # Parameters
102    ///
103    /// * `revoked_at` - The time at which the access token was revoked
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the access token is already revoked
108    pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
109        self.state = self.state.revoke(revoked_at)?;
110        Ok(self)
111    }
112}
113
114#[derive(Debug, Clone, Default, PartialEq, Eq)]
115pub enum RefreshTokenState {
116    #[default]
117    Valid,
118    Consumed {
119        consumed_at: DateTime<Utc>,
120        next_refresh_token_id: Option<Ulid>,
121    },
122    Revoked {
123        revoked_at: DateTime<Utc>,
124    },
125}
126
127impl RefreshTokenState {
128    /// Consume the refresh token, returning a new state.
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if the refresh token is revoked.
133    fn consume(
134        self,
135        consumed_at: DateTime<Utc>,
136        replaced_by: &RefreshToken,
137    ) -> Result<Self, InvalidTransitionError> {
138        match self {
139            Self::Valid | Self::Consumed { .. } => Ok(Self::Consumed {
140                consumed_at,
141                next_refresh_token_id: Some(replaced_by.id),
142            }),
143            Self::Revoked { .. } => Err(InvalidTransitionError),
144        }
145    }
146
147    /// Revoke the refresh token, returning a new state.
148    ///
149    /// # Errors
150    ///
151    /// Returns an error if the refresh token is already consumed or revoked.
152    pub fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
153        match self {
154            Self::Valid => Ok(Self::Revoked { revoked_at }),
155            Self::Consumed { .. } | Self::Revoked { .. } => Err(InvalidTransitionError),
156        }
157    }
158
159    /// Returns `true` if the refresh token state is [`Valid`].
160    ///
161    /// [`Valid`]: RefreshTokenState::Valid
162    #[must_use]
163    pub fn is_valid(&self) -> bool {
164        matches!(self, Self::Valid)
165    }
166
167    /// Returns the next refresh token ID, if any.
168    #[must_use]
169    pub fn next_refresh_token_id(&self) -> Option<Ulid> {
170        match self {
171            Self::Valid | Self::Revoked { .. } => None,
172            Self::Consumed {
173                next_refresh_token_id,
174                ..
175            } => *next_refresh_token_id,
176        }
177    }
178}
179
180#[derive(Debug, Clone, PartialEq, Eq)]
181pub struct RefreshToken {
182    pub id: Ulid,
183    pub state: RefreshTokenState,
184    pub refresh_token: String,
185    pub session_id: Ulid,
186    pub created_at: DateTime<Utc>,
187    pub access_token_id: Option<Ulid>,
188}
189
190impl std::ops::Deref for RefreshToken {
191    type Target = RefreshTokenState;
192
193    fn deref(&self) -> &Self::Target {
194        &self.state
195    }
196}
197
198impl RefreshToken {
199    #[must_use]
200    pub fn jti(&self) -> String {
201        self.id.to_string()
202    }
203
204    /// Consumes the refresh token and returns the consumed token.
205    ///
206    /// # Errors
207    ///
208    /// Returns an error if the refresh token is revoked.
209    pub fn consume(
210        mut self,
211        consumed_at: DateTime<Utc>,
212        replaced_by: &Self,
213    ) -> Result<Self, InvalidTransitionError> {
214        self.state = self.state.consume(consumed_at, replaced_by)?;
215        Ok(self)
216    }
217
218    /// Revokes the refresh token and returns a new revoked token
219    ///
220    /// # Errors
221    ///
222    /// Returns an error if the refresh token is already revoked.
223    pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
224        self.state = self.state.revoke(revoked_at)?;
225        Ok(self)
226    }
227}
228
/// Type of token to generate or validate.
///
/// Generated tokens carry a three-letter prefix identifying their type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// An access token, used by Relying Parties to authenticate requests
    /// (prefix `mat`)
    AccessToken,

    /// A refresh token, used by the refresh token grant (prefix `mar`)
    RefreshToken,

    /// A legacy access token (prefix `mct`, or `syt` for tokens imported from
    /// Synapse)
    CompatAccessToken,

    /// A legacy refresh token (prefix `mcr`, or `syr` for tokens imported from
    /// Synapse)
    CompatRefreshToken,

    /// A personal access token (prefix `mpt`)
    PersonalAccessToken,
}
247
248impl std::fmt::Display for TokenType {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        match self {
251            TokenType::AccessToken => write!(f, "access token"),
252            TokenType::RefreshToken => write!(f, "refresh token"),
253            TokenType::CompatAccessToken => write!(f, "compat access token"),
254            TokenType::CompatRefreshToken => write!(f, "compat refresh token"),
255            TokenType::PersonalAccessToken => write!(f, "personal access token"),
256        }
257    }
258}
259
260impl TokenType {
261    fn prefix(self) -> &'static str {
262        match self {
263            TokenType::AccessToken => "mat",
264            TokenType::RefreshToken => "mar",
265            TokenType::CompatAccessToken => "mct",
266            TokenType::CompatRefreshToken => "mcr",
267            TokenType::PersonalAccessToken => "mpt",
268        }
269    }
270
271    fn match_prefix(prefix: &str) -> Option<Self> {
272        match prefix {
273            "mat" => Some(TokenType::AccessToken),
274            "mar" => Some(TokenType::RefreshToken),
275            "mct" | "syt" => Some(TokenType::CompatAccessToken),
276            "mcr" | "syr" => Some(TokenType::CompatRefreshToken),
277            "mpt" => Some(TokenType::PersonalAccessToken),
278            _ => None,
279        }
280    }
281
282    /// Generate a token for the given type
283    pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String {
284        let random_part: String = rng
285            .sample_iter(&Alphanumeric)
286            .take(30)
287            .map(char::from)
288            .collect();
289
290        let base = format!("{prefix}_{random_part}", prefix = self.prefix());
291        let crc = CRC.checksum(base.as_bytes());
292        let crc = base62_encode(crc);
293        format!("{base}_{crc}")
294    }
295
296    /// Check the format of a token and determine its type
297    ///
298    /// # Errors
299    ///
300    /// Returns an error if the token is not valid
301    pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
302        // these are legacy tokens imported from Synapse
303        // we don't do any validation on them and continue as is
304        if token.starts_with("syt_") || is_likely_synapse_macaroon(token) {
305            return Ok(TokenType::CompatAccessToken);
306        }
307        if token.starts_with("syr_") {
308            return Ok(TokenType::CompatRefreshToken);
309        }
310
311        let split: Vec<&str> = token.split('_').collect();
312        let [prefix, random_part, crc]: [&str; 3] = split
313            .try_into()
314            .map_err(|_| TokenFormatError::InvalidFormat)?;
315
316        if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
317            return Err(TokenFormatError::InvalidFormat);
318        }
319
320        let token_type =
321            TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
322                prefix: prefix.to_owned(),
323            })?;
324
325        let base = format!("{prefix}_{random_part}", prefix = token_type.prefix());
326        let expected_crc = CRC.checksum(base.as_bytes());
327        let expected_crc = base62_encode(expected_crc);
328        if crc != expected_crc {
329            return Err(TokenFormatError::InvalidCrc {
330                expected: expected_crc,
331                got: crc.to_owned(),
332            });
333        }
334
335        Ok(token_type)
336    }
337}
338
339impl PartialEq<OAuthTokenTypeHint> for TokenType {
340    fn eq(&self, other: &OAuthTokenTypeHint) -> bool {
341        matches!(
342            (self, other),
343            (
344                TokenType::AccessToken
345                    | TokenType::CompatAccessToken
346                    | TokenType::PersonalAccessToken,
347                OAuthTokenTypeHint::AccessToken
348            ) | (
349                TokenType::RefreshToken | TokenType::CompatRefreshToken,
350                OAuthTokenTypeHint::RefreshToken
351            )
352        )
353    }
354}
355
356/// Returns true if and only if a token looks like it may be a macaroon.
357///
358/// Macaroons are a standard for tokens that support attenuation.
359/// Synapse used them for old sessions and for guest sessions.
360///
361/// We won't bother to decode them fully, but we can check to see if the first
362/// constraint is the `location` constraint.
363fn is_likely_synapse_macaroon(token: &str) -> bool {
364    let Ok(decoded) = Base64UrlUnpadded::decode_vec(token) else {
365        return false;
366    };
367    decoded.get(4..13) == Some(b"location ")
368}
369
/// Alphabet used by [`base62_encode`]: digits, then upper-case, then
/// lower-case ASCII letters.
const NUM: [u8; 62] = *b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

/// Encode `num` with the base-62 alphabet in [`NUM`], left-padded with `'0'`
/// to six characters.
///
/// Digits are emitted **least-significant first**, so this is not
/// conventional base62, and the padding makes it non-injective: `1`, `62` and
/// `3844` all encode to `"000001"`. The result is embedded as the checksum of
/// every generated token and recomputed when tokens are checked, so the exact
/// output must not change. A `u32` never needs more than six base-62 digits
/// (62^6 > `u32::MAX`), so the result is always exactly six characters.
fn base62_encode(num: u32) -> String {
    // Walk the successive quotients num, num/62, num/62^2, ... until they hit
    // zero; each quotient's remainder mod 62 is the next (lower-order) digit.
    let digits: String = std::iter::successors(Some(num), |&rest| Some(rest / 62))
        .take_while(|&rest| rest != 0)
        .map(|rest| char::from(NUM[(rest % 62) as usize]))
        .collect();

    format!("{digits:0>6}")
}
381
382const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
383
384/// Invalid token
385#[derive(Debug, Error, PartialEq, Eq)]
386pub enum TokenFormatError {
387    /// Overall token format is invalid
388    #[error("invalid token format")]
389    InvalidFormat,
390
391    /// Token used an unknown prefix
392    #[error("unknown token prefix {prefix:?}")]
393    UnknownPrefix {
394        /// The prefix found in the token
395        prefix: String,
396    },
397
398    /// The CRC checksum in the token is invalid
399    #[error("invalid crc {got:?}, expected {expected:?}")]
400    InvalidCrc {
401        /// The CRC hash expected to be found in the token
402        expected: String,
403        /// The CRC found in the token
404        got: String,
405    },
406}
407
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rand::thread_rng;

    use super::*;

    /// Every token type, so that loops over it cover all variants
    const ALL_TOKEN_TYPES: [TokenType; 5] = [
        TokenType::AccessToken,
        TokenType::RefreshToken,
        TokenType::CompatAccessToken,
        TokenType::CompatRefreshToken,
        TokenType::PersonalAccessToken,
    ];

    #[test]
    fn test_prefix_match() {
        use TokenType::{
            AccessToken, CompatAccessToken, CompatRefreshToken, PersonalAccessToken, RefreshToken,
        };
        assert_eq!(TokenType::match_prefix("syt"), Some(CompatAccessToken));
        assert_eq!(TokenType::match_prefix("syr"), Some(CompatRefreshToken));
        assert_eq!(TokenType::match_prefix("mct"), Some(CompatAccessToken));
        assert_eq!(TokenType::match_prefix("mcr"), Some(CompatRefreshToken));
        assert_eq!(TokenType::match_prefix("mat"), Some(AccessToken));
        assert_eq!(TokenType::match_prefix("mar"), Some(RefreshToken));
        assert_eq!(TokenType::match_prefix("mpt"), Some(PersonalAccessToken));
        assert_eq!(TokenType::match_prefix("matt"), None);
        assert_eq!(TokenType::match_prefix("marr"), None);
        assert_eq!(TokenType::match_prefix("ma"), None);

        // Every prefix we generate must map back to its own type
        for t in ALL_TOKEN_TYPES {
            assert_eq!(TokenType::match_prefix(t.prefix()), Some(t));
        }
    }

    #[test]
    fn test_display() {
        assert_eq!(TokenType::AccessToken.to_string(), "access token");
        assert_eq!(TokenType::RefreshToken.to_string(), "refresh token");
        assert_eq!(
            TokenType::CompatAccessToken.to_string(),
            "compat access token"
        );
        assert_eq!(
            TokenType::CompatRefreshToken.to_string(),
            "compat refresh token"
        );
        assert_eq!(
            TokenType::PersonalAccessToken.to_string(),
            "personal access token"
        );
    }

    #[test]
    fn test_token_type_hint_eq() {
        assert!(TokenType::AccessToken == OAuthTokenTypeHint::AccessToken);
        assert!(TokenType::CompatAccessToken == OAuthTokenTypeHint::AccessToken);
        assert!(TokenType::PersonalAccessToken == OAuthTokenTypeHint::AccessToken);
        assert!(TokenType::RefreshToken == OAuthTokenTypeHint::RefreshToken);
        assert!(TokenType::CompatRefreshToken == OAuthTokenTypeHint::RefreshToken);

        assert!(TokenType::AccessToken != OAuthTokenTypeHint::RefreshToken);
        assert!(TokenType::PersonalAccessToken != OAuthTokenTypeHint::RefreshToken);
        assert!(TokenType::RefreshToken != OAuthTokenTypeHint::AccessToken);
    }

    #[test]
    fn test_is_likely_synapse_macaroon() {
        // This is just the prefix of a Synapse macaroon, but it's enough to make the
        // sniffing work
        assert!(is_likely_synapse_macaroon(
            "MDAxYmxvY2F0aW9uIGxpYnJlcHVzaC5uZXQKMDAx"
        ));

        // This is a valid macaroon (even though Synapse did not generate this one)
        assert!(is_likely_synapse_macaroon(
            "MDAxY2xvY2F0aW9uIGh0dHA6Ly9teWJhbmsvCjAwMjZpZGVudGlmaWVyIHdlIHVzZWQgb3VyIHNlY3JldCBrZXkKMDAyZnNpZ25hdHVyZSDj2eApCFJsTAA5rhURQRXZf91ovyujebNCqvD2F9BVLwo"
        ));

        // None of these are macaroons
        assert!(!is_likely_synapse_macaroon(
            "eyJARTOhearotnaeisahtoarsnhiasra.arsohenaor.oarnsteao"
        ));
        assert!(!is_likely_synapse_macaroon("...."));
        assert!(!is_likely_synapse_macaroon("aaa"));
    }

    #[test]
    fn test_generate_and_check() {
        const COUNT: usize = 500; // Generate 500 of each token type

        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        for t in ALL_TOKEN_TYPES {
            // Generate many tokens
            let tokens: HashSet<String> = (0..COUNT).map(|_| t.generate(&mut rng)).collect();

            // Check that they are all different
            assert_eq!(tokens.len(), COUNT, "All tokens are unique");

            // Check that they are all valid and detected as the right token type
            for token in tokens {
                assert_eq!(TokenType::check(&token).unwrap(), t);
            }
        }
    }

    #[test]
    fn test_check_legacy_synapse_tokens() {
        // Legacy Synapse tokens are accepted without further validation
        assert_eq!(
            TokenType::check("syt_anything"),
            Ok(TokenType::CompatAccessToken)
        );
        assert_eq!(
            TokenType::check("syr_anything"),
            Ok(TokenType::CompatRefreshToken)
        );
    }

    #[test]
    fn test_check_rejects_malformed_tokens() {
        // Wrong number of segments
        assert_eq!(TokenType::check(""), Err(TokenFormatError::InvalidFormat));
        assert_eq!(
            TokenType::check("mat_abc"),
            Err(TokenFormatError::InvalidFormat)
        );
        assert_eq!(
            TokenType::check("mat_a_b_c"),
            Err(TokenFormatError::InvalidFormat)
        );

        // Right number of segments, wrong lengths
        assert_eq!(
            TokenType::check("mat_short_000000"),
            Err(TokenFormatError::InvalidFormat)
        );

        // Well-formed but with an unknown prefix
        let unknown = format!("xyz_{}_000000", "a".repeat(30));
        assert_eq!(
            TokenType::check(&unknown),
            Err(TokenFormatError::UnknownPrefix {
                prefix: "xyz".to_owned(),
            })
        );
    }

    #[test]
    fn test_check_rejects_bad_crc() {
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        for t in ALL_TOKEN_TYPES {
            let token = t.generate(&mut rng);
            let (base, crc) = token.rsplit_once('_').unwrap();

            // Pick any checksum that differs from the real one
            let wrong_crc = if crc == "000000" { "000001" } else { "000000" };
            assert_eq!(
                TokenType::check(&format!("{base}_{wrong_crc}")),
                Err(TokenFormatError::InvalidCrc {
                    expected: crc.to_owned(),
                    got: wrong_crc.to_owned(),
                })
            );
        }
    }
}