mas_storage_pg/oauth2/
access_token.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use async_trait::async_trait;
9use chrono::{DateTime, Duration, Utc};
10use mas_data_model::{AccessToken, AccessTokenState, Clock, Session};
11use mas_storage::oauth2::OAuth2AccessTokenRepository;
12use rand::RngCore;
13use sqlx::PgConnection;
14use ulid::Ulid;
15use uuid::Uuid;
16
17use crate::{DatabaseError, tracing::ExecuteExt};
18
19/// An implementation of [`OAuth2AccessTokenRepository`] for a PostgreSQL
20/// connection
21pub struct PgOAuth2AccessTokenRepository<'c> {
22    conn: &'c mut PgConnection,
23}
24
25impl<'c> PgOAuth2AccessTokenRepository<'c> {
26    /// Create a new [`PgOAuth2AccessTokenRepository`] from an active PostgreSQL
27    /// connection
28    pub fn new(conn: &'c mut PgConnection) -> Self {
29        Self { conn }
30    }
31}
32
33struct OAuth2AccessTokenLookup {
34    oauth2_access_token_id: Uuid,
35    oauth2_session_id: Uuid,
36    access_token: String,
37    created_at: DateTime<Utc>,
38    expires_at: Option<DateTime<Utc>>,
39    revoked_at: Option<DateTime<Utc>>,
40    first_used_at: Option<DateTime<Utc>>,
41}
42
43impl From<OAuth2AccessTokenLookup> for AccessToken {
44    fn from(value: OAuth2AccessTokenLookup) -> Self {
45        let state = match value.revoked_at {
46            None => AccessTokenState::Valid,
47            Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
48        };
49
50        Self {
51            id: value.oauth2_access_token_id.into(),
52            state,
53            session_id: value.oauth2_session_id.into(),
54            access_token: value.access_token,
55            created_at: value.created_at,
56            expires_at: value.expires_at,
57            first_used_at: value.first_used_at,
58        }
59    }
60}
61
62#[async_trait]
63impl OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'_> {
64    type Error = DatabaseError;
65
66    async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
67        let res = sqlx::query_as!(
68            OAuth2AccessTokenLookup,
69            r#"
70                SELECT oauth2_access_token_id
71                     , access_token
72                     , created_at
73                     , expires_at
74                     , revoked_at
75                     , oauth2_session_id
76                     , first_used_at
77
78                FROM oauth2_access_tokens
79
80                WHERE oauth2_access_token_id = $1
81            "#,
82            Uuid::from(id),
83        )
84        .fetch_optional(&mut *self.conn)
85        .await?;
86
87        let Some(res) = res else { return Ok(None) };
88
89        Ok(Some(res.into()))
90    }
91
92    #[tracing::instrument(
93        name = "db.oauth2_access_token.find_by_token",
94        skip_all,
95        fields(
96            db.query.text,
97        ),
98        err,
99    )]
100    async fn find_by_token(
101        &mut self,
102        access_token: &str,
103    ) -> Result<Option<AccessToken>, Self::Error> {
104        let res = sqlx::query_as!(
105            OAuth2AccessTokenLookup,
106            r#"
107                SELECT oauth2_access_token_id
108                     , access_token
109                     , created_at
110                     , expires_at
111                     , revoked_at
112                     , oauth2_session_id
113                     , first_used_at
114
115                FROM oauth2_access_tokens
116
117                WHERE access_token = $1
118            "#,
119            access_token,
120        )
121        .fetch_optional(&mut *self.conn)
122        .await?;
123
124        let Some(res) = res else { return Ok(None) };
125
126        Ok(Some(res.into()))
127    }
128
129    #[tracing::instrument(
130        name = "db.oauth2_access_token.add",
131        skip_all,
132        fields(
133            db.query.text,
134            %session.id,
135            client.id = %session.client_id,
136            access_token.id,
137        ),
138        err,
139    )]
140    async fn add(
141        &mut self,
142        rng: &mut (dyn RngCore + Send),
143        clock: &dyn Clock,
144        session: &Session,
145        access_token: String,
146        expires_after: Option<Duration>,
147    ) -> Result<AccessToken, Self::Error> {
148        let created_at = clock.now();
149        let expires_at = expires_after.map(|d| created_at + d);
150        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
151
152        tracing::Span::current().record("access_token.id", tracing::field::display(id));
153
154        sqlx::query!(
155            r#"
156                INSERT INTO oauth2_access_tokens
157                    (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
158                VALUES
159                    ($1, $2, $3, $4, $5)
160            "#,
161            Uuid::from(id),
162            Uuid::from(session.id),
163            &access_token,
164            created_at,
165            expires_at,
166        )
167            .traced()
168        .execute(&mut *self.conn)
169        .await?;
170
171        Ok(AccessToken {
172            id,
173            state: AccessTokenState::default(),
174            access_token,
175            session_id: session.id,
176            created_at,
177            expires_at,
178            first_used_at: None,
179        })
180    }
181
182    #[tracing::instrument(
183        name = "db.oauth2_access_token.revoke",
184        skip_all,
185        fields(
186            db.query.text,
187            session.id = %access_token.session_id,
188            %access_token.id,
189        ),
190        err,
191    )]
192    async fn revoke(
193        &mut self,
194        clock: &dyn Clock,
195        access_token: AccessToken,
196    ) -> Result<AccessToken, Self::Error> {
197        let revoked_at = clock.now();
198        let res = sqlx::query!(
199            r#"
200                UPDATE oauth2_access_tokens
201                SET revoked_at = $2
202                WHERE oauth2_access_token_id = $1
203            "#,
204            Uuid::from(access_token.id),
205            revoked_at,
206        )
207        .traced()
208        .execute(&mut *self.conn)
209        .await?;
210
211        DatabaseError::ensure_affected_rows(&res, 1)?;
212
213        access_token
214            .revoke(revoked_at)
215            .map_err(DatabaseError::to_invalid_operation)
216    }
217
218    #[tracing::instrument(
219        name = "db.oauth2_access_token.mark_used",
220        skip_all,
221        fields(
222            db.query.text,
223            session.id = %access_token.session_id,
224            %access_token.id,
225        ),
226        err,
227    )]
228    async fn mark_used(
229        &mut self,
230        clock: &dyn Clock,
231        mut access_token: AccessToken,
232    ) -> Result<AccessToken, Self::Error> {
233        let now = clock.now();
234        let res = sqlx::query!(
235            r#"
236                UPDATE oauth2_access_tokens
237                SET first_used_at = $2
238                WHERE oauth2_access_token_id = $1
239            "#,
240            Uuid::from(access_token.id),
241            now,
242        )
243        .execute(&mut *self.conn)
244        .await?;
245
246        DatabaseError::ensure_affected_rows(&res, 1)?;
247
248        access_token.first_used_at = Some(now);
249
250        Ok(access_token)
251    }
252
253    #[tracing::instrument(
254        name = "db.oauth2_access_token.cleanup_revoked",
255        skip_all,
256        fields(
257            db.query.text,
258        ),
259        err,
260    )]
261    async fn cleanup_revoked(
262        &mut self,
263        since: Option<DateTime<Utc>>,
264        until: DateTime<Utc>,
265        limit: usize,
266    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
267        let res = sqlx::query!(
268            r#"
269                WITH
270                    to_delete AS (
271                        SELECT oauth2_access_token_id
272                        FROM oauth2_access_tokens
273                        WHERE revoked_at IS NOT NULL
274                          AND ($1::timestamptz IS NULL OR revoked_at >= $1::timestamptz)
275                          AND revoked_at < $2::timestamptz
276                        ORDER BY revoked_at ASC
277                        LIMIT $3
278                        FOR UPDATE
279                    ),
280
281                    deleted AS (
282                        DELETE FROM oauth2_access_tokens
283                        USING to_delete
284                        WHERE oauth2_access_tokens.oauth2_access_token_id = to_delete.oauth2_access_token_id
285                        RETURNING oauth2_access_tokens.revoked_at
286                    )
287
288                SELECT
289                    COUNT(*) as "count!",
290                    MAX(revoked_at) as last_revoked_at
291                FROM deleted
292            "#,
293            since,
294            until,
295            i64::try_from(limit).unwrap_or(i64::MAX),
296        )
297        .traced()
298        .fetch_one(&mut *self.conn)
299        .await?;
300
301        Ok((
302            res.count.try_into().unwrap_or(usize::MAX),
303            res.last_revoked_at,
304        ))
305    }
306}