mas_tasks/new_queue.rs

// Copyright 2025, 2026 Element Creations Ltd.
// Copyright 2024, 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

use std::{collections::HashMap, sync::Arc};

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use cron::Schedule;
use mas_context::LogContext;
use mas_data_model::Clock;
use mas_storage::{
    RepositoryAccess, RepositoryError,
    queue::{InsertableJob, Job, JobMetadata, Worker},
};
use mas_storage_pg::{DatabaseError, PgRepository};
use opentelemetry::{
    KeyValue,
    metrics::{Counter, Histogram, UpDownCounter},
};
use rand::{Rng, RngCore, distributions::Uniform};
use serde::de::DeserializeOwned;
use sqlx::{
    Acquire, Either,
    postgres::{PgAdvisoryLock, PgListener},
};
use thiserror::Error;
use tokio::{task::JoinSet, time::Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument as _, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt as _;
use ulid::Ulid;

use crate::{METER, State};

type JobPayload = serde_json::Value;

#[derive(Clone)]
pub struct JobContext {
    pub id: Ulid,
    pub metadata: JobMetadata,
    pub queue_name: String,
    pub attempt: usize,
    pub start: Instant,
    pub cancellation_token: CancellationToken,
}

impl JobContext {
    pub fn span(&self) -> Span {
        let span = tracing::info_span!(
            parent: Span::none(),
            "job.run",
            job.id = %self.id,
            job.queue.name = self.queue_name,
            job.attempt = self.attempt,
        );

        span.add_link(self.metadata.span_context());

        span
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JobErrorDecision {
    Retry,

    #[default]
    Fail,
}

impl std::fmt::Display for JobErrorDecision {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Retry => f.write_str("retry"),
            Self::Fail => f.write_str("fail"),
        }
    }
}

#[derive(Debug, Error)]
#[error("Job failed to run, will {decision}")]
pub struct JobError {
    decision: JobErrorDecision,
    #[source]
    error: anyhow::Error,
}

impl JobError {
    pub fn retry<T: Into<anyhow::Error>>(error: T) -> Self {
        Self {
            decision: JobErrorDecision::Retry,
            error: error.into(),
        }
    }

    pub fn fail<T: Into<anyhow::Error>>(error: T) -> Self {
        Self {
            decision: JobErrorDecision::Fail,
            error: error.into(),
        }
    }
}

pub trait FromJob {
    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
    where
        Self: Sized;
}

impl<T> FromJob for T
where
    T: DeserializeOwned,
{
    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error> {
        serde_json::from_value(payload).map_err(Into::into)
    }
}

#[async_trait]
pub trait RunnableJob: FromJob + Send + 'static {
    async fn run(&self, state: &State, context: JobContext) -> Result<(), JobError>;

    /// Allows the job to set a timeout for its execution. Jobs should then look
    /// at the cancellation token passed in the [`JobContext`] to handle
    /// graceful shutdowns.
    fn timeout(&self) -> Option<std::time::Duration> {
        None
    }
}
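
// A hypothetical job implementation, to show how the pieces fit together: the
// payload type gets `FromJob` for free through the blanket impl above (it only
// needs to be `DeserializeOwned`), so implementing `RunnableJob` boils down to
// providing `run`. The `PruneSessionsJob` name and its body are illustrative,
// not part of this crate:
//
//     #[derive(serde::Deserialize)]
//     struct PruneSessionsJob;
//
//     #[async_trait]
//     impl RunnableJob for PruneSessionsJob {
//         async fn run(&self, _state: &State, _context: JobContext) -> Result<(), JobError> {
//             // ... do the actual work, mapping errors with JobError::retry/fail ...
//             Ok(())
//         }
//     }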

fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
    Box::new(job)
}

#[derive(Debug, Error)]
pub enum QueueRunnerError {
    #[error("Failed to setup listener")]
    SetupListener(#[source] sqlx::Error),

    #[error("Failed to start transaction")]
    StartTransaction(#[source] sqlx::Error),

    #[error("Failed to commit transaction")]
    CommitTransaction(#[source] sqlx::Error),

    #[error("Failed to acquire leader lock")]
    LeaderLock(#[source] sqlx::Error),

    #[error(transparent)]
    Repository(#[from] RepositoryError),

    #[error(transparent)]
    Database(#[from] DatabaseError),

    #[error("Invalid schedule expression")]
    InvalidSchedule(#[from] cron::error::Error),

    #[error("Worker is not the leader")]
    NotLeader,
}

// When the worker waits for a notification, we still want to wake it up every
// second. Because we don't want all the workers to wake up at the same time, we
// add a random jitter to the sleep duration, so they effectively sleep between
// 0.9 and 1.1 seconds.
const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);

// How many jobs can we run concurrently
const MAX_CONCURRENT_JOBS: usize = 10;

// How many jobs can we fetch at once
const MAX_JOBS_TO_FETCH: usize = 5;

// How many times a job should be retried before being abandoned
const MAX_ATTEMPTS: usize = 10;

/// Returns the delay to wait before retrying a job
///
/// Uses an exponential backoff: 5s, 10s, 20s, 40s, 1m20s, 2m40s, 5m20s, 10m40s,
/// 21m20s, 42m40s
fn retry_delay(attempt: usize) -> Duration {
    let attempt = u32::try_from(attempt).unwrap_or(u32::MAX);
    Duration::milliseconds(2_i64.saturating_pow(attempt) * 5_000)
}
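
// A quick sanity check of the backoff progression documented above; this test
// module is only a minimal illustrative sketch.
#[cfg(test)]
mod retry_delay_tests {
    use super::retry_delay;

    #[test]
    fn backoff_doubles_from_five_seconds() {
        // attempt 0 -> 5s, attempt 1 -> 10s, ..., attempt 9 -> 42m40s
        assert_eq!(retry_delay(0).num_seconds(), 5);
        assert_eq!(retry_delay(1).num_seconds(), 10);
        assert_eq!(retry_delay(7).num_seconds(), 640);
        assert_eq!(retry_delay(9).num_seconds(), 2560);
    }
}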

type JobResult = (std::time::Duration, Result<(), JobError>);
type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;

struct ScheduleDefinition {
    schedule_name: &'static str,
    expression: Schedule,
    queue_name: &'static str,
    payload: serde_json::Value,
}

pub struct QueueWorker {
    listener: PgListener,
    registration: Worker,
    am_i_leader: bool,
    last_heartbeat: DateTime<Utc>,
    cancellation_token: CancellationToken,
    #[expect(dead_code, reason = "This is used on Drop")]
    cancellation_guard: tokio_util::sync::DropGuard,
    state: State,
    schedules: Vec<ScheduleDefinition>,
    tracker: JobTracker,
    wakeup_reason: Counter<u64>,
    tick_time: Histogram<u64>,
}

impl QueueWorker {
    #[tracing::instrument(
        name = "worker.init",
        skip_all,
        fields(worker.id)
    )]
    pub(crate) async fn new(
        state: State,
        cancellation_token: CancellationToken,
    ) -> Result<Self, QueueRunnerError> {
        let mut rng = state.rng();
        let clock = state.clock();

        let mut listener = PgListener::connect_with(&state.pool())
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        // We get notifications of leader stepping down on this channel
        listener
            .listen("queue_leader_stepdown")
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        // We get notifications when a job is available on this channel
        listener
            .listen("queue_available")
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        let txn = listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        let registration = repo.queue_worker().register(&mut rng, clock).await?;
        tracing::Span::current().record("worker.id", tracing::field::display(registration.id));
        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        tracing::info!(worker.id = %registration.id, "Registered worker");
        let now = clock.now();

        let wakeup_reason = METER
            .u64_counter("job.worker.wakeups")
            .with_description("Counts how many times the worker has been woken up, and for which reason")
            .build();

        // Pre-create the reasons on the counter
        wakeup_reason.add(0, &[KeyValue::new("reason", "sleep")]);
        wakeup_reason.add(0, &[KeyValue::new("reason", "task")]);
        wakeup_reason.add(0, &[KeyValue::new("reason", "notification")]);

        let tick_time = METER
            .u64_histogram("job.worker.tick_duration")
            .with_description(
                "How much time the worker took to tick, including performing leader duties",
            )
            .build();

        // We put a cancellation drop guard in the structure, so that when it gets
        // dropped, we're sure to cancel the token
        let cancellation_guard = cancellation_token.clone().drop_guard();

        Ok(Self {
            listener,
            registration,
            am_i_leader: false,
            last_heartbeat: now,
            cancellation_token,
            cancellation_guard,
            state,
            schedules: Vec::new(),
            tracker: JobTracker::new(),
            wakeup_reason,
            tick_time,
        })
    }

    pub(crate) fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
        // There is a potential panic here, which is fine as it's going to be caught
        // within the job task
        let factory = |payload: JobPayload| {
            box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
        };

        self.tracker
            .factories
            .insert(T::QUEUE_NAME, Arc::new(factory));
        self
    }

    pub(crate) fn add_schedule<T: InsertableJob>(
        &mut self,
        schedule_name: &'static str,
        expression: Schedule,
        job: T,
    ) -> &mut Self {
        let payload = serde_json::to_value(job).expect("failed to serialize job payload");

        self.schedules.push(ScheduleDefinition {
            schedule_name,
            expression,
            queue_name: T::QUEUE_NAME,
            payload,
        });

        self
    }
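
    // Wiring a worker up might look roughly like the sketch below; the
    // `PruneSessionsJob` type and the cron expression are hypothetical, only
    // here to illustrate the builder-style calls:
    //
    //     let mut worker = QueueWorker::new(state, cancellation_token).await?;
    //     worker
    //         .register_handler::<PruneSessionsJob>()
    //         .add_schedule("prune-sessions", "0 0 3 * * *".parse()?, PruneSessionsJob);
    //     worker.run().await;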

    pub(crate) async fn run(mut self) {
        if let Err(e) = self.run_inner().await {
            tracing::error!(
                error = &e as &dyn std::error::Error,
                "Failed to run new queue"
            );
        }
    }

    async fn run_inner(&mut self) -> Result<(), QueueRunnerError> {
        self.setup_schedules().await?;

        while !self.cancellation_token.is_cancelled() {
            LogContext::new("worker-run-loop")
                .run(|| self.run_loop())
                .await?;
        }

        self.shutdown().await?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.setup_schedules", skip_all)]
    pub(crate) async fn setup_schedules(&mut self) -> Result<(), QueueRunnerError> {
        let schedules: Vec<_> = self.schedules.iter().map(|s| s.schedule_name).collect();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        let mut repo = PgRepository::from_conn(txn);

        // Setup the entries in the queue_schedules table
        repo.queue_schedule().setup(&schedules).await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.run_loop", skip_all)]
    async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
        self.wait_until_wakeup().await?;

        if self.cancellation_token.is_cancelled() {
            return Ok(());
        }

        let start = Instant::now();
        self.tick().await?;

        if self.am_i_leader {
            self.perform_leader_duties().await?;
        }

        let elapsed = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
        self.tick_time.record(elapsed, &[]);

        Ok(())
    }

    #[tracing::instrument(name = "worker.shutdown", skip_all)]
    async fn shutdown(&mut self) -> Result<(), QueueRunnerError> {
        tracing::info!("Shutting down worker");

        let clock = self.state.clock();
        let mut rng = self.state.rng();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        let mut repo = PgRepository::from_conn(txn);

        // Log about any job still running
        match self.tracker.running_jobs() {
            0 => {}
            1 => tracing::warn!("There is one job still running, waiting for it to finish"),
            n => tracing::warn!("There are {n} jobs still running, waiting for them to finish"),
        }

        // TODO: we may want to introduce a timeout here, and abort the tasks if they
        // take too long. It's fine for now, as we don't have long-running
        // tasks, most of them are idempotent, and the only effect might be that
        // the worker would 'dirtily' shut down, meaning that its tasks would be
        // considered lost and later retried by another worker

        // Wait for all the jobs to finish
        self.tracker
            .process_jobs(&mut rng, clock, &mut repo, true)
            .await?;

        // Tell the other workers we're shutting down
        // This also releases the leader election lease
        repo.queue_worker()
            .shutdown(clock, &self.registration)
            .await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.wait_until_wakeup", skip_all)]
    async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> {
        let mut rng = self.state.rng();

        // This is to make sure we wake up every second to do the maintenance tasks
        // We add a little bit of random jitter to the duration, so that we don't get
        // fully synced workers waking up at the same time after each notification
        let sleep_duration = rng.sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
        let wakeup_sleep = tokio::time::sleep(sleep_duration);

        tokio::select! {
            () = self.cancellation_token.cancelled() => {
                tracing::debug!("Woke up from cancellation");
            },

            () = wakeup_sleep => {
                tracing::debug!("Woke up from sleep");
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "sleep")]);
            },

            () = self.tracker.collect_next_job(), if self.tracker.has_jobs() => {
                tracing::debug!("Joined job task");
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "task")]);
            },

            notification = self.listener.recv() => {
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "notification")]);
                match notification {
                    Ok(notification) => {
                        tracing::debug!(
                            notification.channel = notification.channel(),
                            notification.payload = notification.payload(),
                            "Woke up from notification"
                        );
                    },
                    Err(e) => {
                        tracing::error!(error = &e as &dyn std::error::Error, "Failed to receive notification");
                    },
                }
            },
        }

        Ok(())
    }

    #[tracing::instrument(
        name = "worker.tick",
        skip_all,
        fields(worker.id = %self.registration.id),
    )]
    async fn tick(&mut self) -> Result<(), QueueRunnerError> {
        tracing::debug!("Tick");
        let clock = self.state.clock();
        let mut rng = self.state.rng();
        let now = clock.now();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        // We send a heartbeat every minute, to avoid writing to the database too often
        // on a logged table
        if now - self.last_heartbeat >= chrono::Duration::minutes(1) {
            tracing::info!("Sending heartbeat");
            repo.queue_worker()
                .heartbeat(clock, &self.registration)
                .await?;
            self.last_heartbeat = now;
        }

        // Remove any dead worker leader leases
        repo.queue_worker()
            .remove_leader_lease_if_expired(clock)
            .await?;

        // Try to become (or stay) the leader
        let leader = repo
            .queue_worker()
            .try_get_leader_lease(clock, &self.registration)
            .await?;

        // Process any job task which finished
        self.tracker
            .process_jobs(&mut rng, clock, &mut repo, false)
            .await?;

        // Compute how many jobs we should fetch at most
        let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
            .saturating_sub(self.tracker.running_jobs())
            .min(MAX_JOBS_TO_FETCH);

        if max_jobs_to_fetch == 0 {
            tracing::warn!("Internal job queue is full, not fetching any new jobs");
        } else {
            // Grab a few jobs in the queue
            let queues = self.tracker.queues();
            let jobs = repo
                .queue_job()
                .reserve(clock, &self.registration, &queues, max_jobs_to_fetch)
                .await?;

            for Job {
                id,
                queue_name,
                payload,
                metadata,
                attempt,
            } in jobs
            {
                let cancellation_token = self.cancellation_token.child_token();
                let start = Instant::now();
                let context = JobContext {
                    id,
                    metadata,
                    queue_name,
                    attempt,
                    start,
                    cancellation_token,
                };

                self.tracker.spawn_job(self.state.clone(), context, payload);
            }
        }

        // After this point, we are locking the leader table, so it's important that we
        // commit as soon as possible to not block the other workers for too long
        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        // Save the new leader state to log any change
        if leader != self.am_i_leader {
            // If we flipped state, log it
            self.am_i_leader = leader;
            if self.am_i_leader {
                tracing::info!("I'm the leader now");
            } else {
                tracing::warn!("I am no longer the leader");
            }
        }

        Ok(())
    }

    #[tracing::instrument(name = "worker.perform_leader_duties", skip_all)]
    async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> {
        // This should have been checked by the caller, but better safe than sorry
        if !self.am_i_leader {
            return Err(QueueRunnerError::NotLeader);
        }

        let clock = self.state.clock();
        let mut rng = self.state.rng();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        // The thing with the leader election is that it locks the table during the
        // election, preventing other workers from going through the loop.
        //
        // Ideally, we would do the leader duties in the same transaction so that we
        // make sure only one worker is doing the leader duties, but that
        // would mean we would lock all the workers for the duration of the
        // duties, which is not ideal.
        //
        // So we do the duties in a separate transaction, in which we take an advisory
        // lock, so that in the very rare case where two workers think they are the
        // leader, we still don't have two workers doing the duties at the same time.
        let lock = PgAdvisoryLock::new("leader-duties");

        let locked = lock
            .try_acquire(txn)
            .await
            .map_err(QueueRunnerError::LeaderLock)?;

        let locked = match locked {
            Either::Left(locked) => locked,
            Either::Right(txn) => {
                tracing::error!("Another worker has the leader lock, aborting");
                txn.rollback()
                    .await
                    .map_err(QueueRunnerError::CommitTransaction)?;
                return Ok(());
            }
        };

        let mut repo = PgRepository::from_conn(locked);

        // Look at the state of schedules in the database
        let schedules_status = repo.queue_schedule().list().await?;

        let now = clock.now();
        for schedule in &self.schedules {
            // Find the schedule status from the database
            let Some(status) = schedules_status
                .iter()
                .find(|s| s.schedule_name == schedule.schedule_name)
            else {
                tracing::error!(
                    "Schedule {} was not found in the database",
                    schedule.schedule_name
                );
                continue;
            };

            // Figure out if we should schedule a new job
            if let Some(next_time) = status.last_scheduled_at {
                if next_time > now {
                    // We already have a job scheduled in the future, skip
                    continue;
                }

                if status.last_scheduled_job_completed == Some(false) {
                    // The last scheduled job has not completed yet, skip
                    continue;
                }
            }

            let next_tick = schedule.expression.after(&now).next().unwrap();

            tracing::info!(
                "Scheduling job for {}, next run at {}",
                schedule.schedule_name,
                next_tick
            );

            repo.queue_job()
                .schedule_later(
                    &mut rng,
                    clock,
                    schedule.queue_name,
                    schedule.payload.clone(),
                    serde_json::json!({}),
                    next_tick,
                    Some(schedule.schedule_name),
                )
                .await?;
        }

        // We also shut down any dead workers that haven't checked in within the
        // last two minutes
        repo.queue_worker()
            .shutdown_dead_workers(clock, Duration::minutes(2))
            .await?;

        // TODO: mark tasks those workers had as lost

        // Mark all the scheduled jobs as available
        let scheduled = repo.queue_job().schedule_available_jobs(clock).await?;
        match scheduled {
            0 => {}
            1 => tracing::info!("One scheduled job marked as available"),
            n => tracing::info!("{n} scheduled jobs marked as available"),
        }

        // Release the leader lock
        let txn = repo
            .into_inner()
            .release_now()
            .await
            .map_err(QueueRunnerError::LeaderLock)?;

        txn.commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    /// Process all the pending jobs in the queue.
    /// This should only be called in tests!
    ///
    /// # Errors
    ///
    /// This function can fail if the database connection fails.
    pub async fn process_all_jobs_in_tests(&mut self) -> Result<(), QueueRunnerError> {
        // I swear, I'm the leader!
        self.am_i_leader = true;

        // First, perform the leader duties. This will make sure that we schedule
        // recurring jobs.
        self.perform_leader_duties().await?;

        let clock = self.state.clock();
        let mut rng = self.state.rng();

        // Grab the connection from the PgListener
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        // Spawn all the jobs in the database
        let queues = self.tracker.queues();
        let jobs = repo
            .queue_job()
            // I really hope that we don't spawn more than 10k jobs in tests
            .reserve(clock, &self.registration, &queues, 10_000)
            .await?;

        for Job {
            id,
            queue_name,
            payload,
            metadata,
            attempt,
        } in jobs
        {
            let cancellation_token = self.cancellation_token.child_token();
            let start = Instant::now();
            let context = JobContext {
                id,
                metadata,
                queue_name,
                attempt,
                start,
                cancellation_token,
            };

            self.tracker.spawn_job(self.state.clone(), context, payload);
        }

        self.tracker
            .process_jobs(&mut rng, clock, &mut repo, true)
            .await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }
}

/// Tracks running jobs
///
/// This is a separate structure to be able to borrow it mutably at the same
/// time as the connection to the database is borrowed
struct JobTracker {
    /// Stores a mapping from the job queue name to the job factory
    factories: HashMap<&'static str, JobFactory>,

    /// A join set of all the currently running jobs
    running_jobs: JoinSet<JobResult>,

    /// Stores a mapping from the Tokio task ID to the job context
    job_contexts: HashMap<tokio::task::Id, JobContext>,

    /// Stores the last `join_next_with_id` result for processing, in case we
    /// got woken up in `collect_next_job`
    last_join_result: Option<Result<(tokio::task::Id, JobResult), tokio::task::JoinError>>,

    /// A histogram which records the time it takes to process a job
    job_processing_time: Histogram<u64>,

    /// A counter which records the number of jobs currently in flight
    in_flight_jobs: UpDownCounter<i64>,
}

impl JobTracker {
    fn new() -> Self {
        let job_processing_time = METER
            .u64_histogram("job.process.duration")
            .with_description("The time it takes to process a job in milliseconds")
            .with_unit("ms")
            .build();

        let in_flight_jobs = METER
            .i64_up_down_counter("job.active_tasks")
            .with_description("The number of jobs currently in flight")
            .with_unit("{job}")
            .build();

        Self {
            factories: HashMap::new(),
            running_jobs: JoinSet::new(),
            job_contexts: HashMap::new(),
            last_join_result: None,
            job_processing_time,
            in_flight_jobs,
        }
    }

    /// Returns the queue names that are currently being tracked
    fn queues(&self) -> Vec<&'static str> {
        self.factories.keys().copied().collect()
    }

    /// Spawn a job on the job tracker
    fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) {
        let factory = self.factories.get(context.queue_name.as_str()).cloned();
        let task = {
            let log_context = LogContext::new(format!("job-{}", context.queue_name));
            let context = context.clone();
            let span = context.span();
            log_context
                .run(async move || {
                    // We should never crash, but in case we do, we do that in the task and
                    // don't crash the worker
                    let job = factory.expect("unknown job factory")(payload);

                    let timeout = job.timeout();
                    // If there is a timeout set on the job, spawn a task which will cancel the
                    // CancellationToken once the timeout is reached
                    if let Some(timeout) = timeout {
                        let context = context.clone();

                        // It's fine to spawn this task without tracking it, as it is quite
                        // lightweight and has no reason to crash.
                        tokio::spawn(
                            context
                                .cancellation_token
                                .clone()
                                // This makes sure the task gets cancelled as soon as the job
                                // finishes
                                .run_until_cancelled_owned(async move {
                                    tokio::time::sleep(timeout).await;
                                    tracing::warn!(
                                        job.id = %context.id,
                                        job.queue.name = %context.queue_name,
                                        "Job reached timeout, asking for cancellation"
                                    );
                                    context.cancellation_token.cancel();
                                }),
                        );
                    }

                    tracing::info!(
                        job.id = %context.id,
                        job.queue.name = %context.queue_name,
                        job.attempt = %context.attempt,
                        job.timeout = timeout.map(tracing::field::debug),
                        "Running job"
                    );
                    let result = job.run(&state, context.clone()).await;

                    // Cancel the cancellation token to stop any timeout task
                    // that may be running
                    context.cancellation_token.cancel();

                    let Some(context_stats) =
                        LogContext::maybe_with(mas_context::LogContext::stats)
                    else {
                        // This should never happen, but if it does it's fine: we already
                        // recover from panics in those tasks
                        panic!("Missing log context, this should never happen");
                    };

                    // We log the result here so that it's attached to the right span & log context
                    match &result {
                        Ok(()) => {
                            tracing::info!(
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job completed [{context_stats}]"
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Fail,
                            error,
                        }) => {
                            tracing::error!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed, not retrying [{context_stats}]"
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Retry,
                            error,
                        }) if context.attempt < MAX_ATTEMPTS => {
                            let delay = retry_delay(context.attempt);
                            tracing::warn!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed, will retry in {}s [{context_stats}]",
                                delay.num_seconds()
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Retry,
                            error,
                        }) => {
                            tracing::error!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed too many times, abandoning [{context_stats}]"
                            );
                        }
                    }

                    (context_stats.elapsed, result)
                })
                .instrument(span)
        };

        self.in_flight_jobs.add(
            1,
            &[KeyValue::new("job.queue.name", context.queue_name.clone())],
        );

        let handle = self.running_jobs.spawn(task);
        self.job_contexts.insert(handle.id(), context);
    }

    /// Returns `true` if there are currently running jobs
    fn has_jobs(&self) -> bool {
        !self.running_jobs.is_empty()
    }

    /// Returns the number of currently running jobs
    ///
    /// This also includes the job result which may be stored for processing
    fn running_jobs(&self) -> usize {
        self.running_jobs.len() + usize::from(self.last_join_result.is_some())
    }

    async fn collect_next_job(&mut self) {
        // Double-check that we don't have a job result stored
        if self.last_join_result.is_some() {
            tracing::error!(
                "Job tracker already had a job result stored, this should never happen!"
            );
            return;
        }

        self.last_join_result = self.running_jobs.join_next_with_id().await;
    }

    /// Process all the jobs which are currently running
    ///
    /// If `blocking` is `true`, this function will block until all the jobs
    /// are finished. Otherwise, it will return as soon as it processed the
    /// already finished jobs.
    async fn process_jobs<E: std::error::Error + Send + Sync + 'static>(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        repo: &mut dyn RepositoryAccess<Error = E>,
        blocking: bool,
    ) -> Result<(), E> {
        if self.last_join_result.is_none() {
            if blocking {
                self.last_join_result = self.running_jobs.join_next_with_id().await;
            } else {
                self.last_join_result = self.running_jobs.try_join_next_with_id();
            }
        }

        while let Some(result) = self.last_join_result.take() {
            match result {
                // The job succeeded. The logging and time measurement is already done in the task
                Ok((id, (elapsed, Ok(())))) => {
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
                    self.job_processing_time.record(
                        elapsed_ms,
                        &[
                            KeyValue::new("job.queue.name", context.queue_name),
                            KeyValue::new("job.result", "success"),
                        ],
                    );

                    repo.queue_job()
                        .mark_as_completed(clock, context.id)
                        .await?;
                }

                // The job failed. The logging and time measurement is already done in the task
                Ok((id, (elapsed, Err(e)))) => {
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let reason = format!("{:?}", e.error);
                    repo.queue_job()
                        .mark_as_failed(clock, context.id, &reason)
                        .await?;

                    let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
                    match e.decision {
                        JobErrorDecision::Fail => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "fail"),
                                ],
                            );
                        }

                        JobErrorDecision::Retry if context.attempt < MAX_ATTEMPTS => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "retry"),
                                ],
                            );

                            let delay = retry_delay(context.attempt);
                            repo.queue_job()
                                .retry(&mut *rng, clock, context.id, delay)
                                .await?;
                        }

                        JobErrorDecision::Retry => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "abandon"),
                                ],
                            );
                        }
                    }
                }

                // The job crashed (or was aborted)
                Err(e) => {
                    let id = e.id();
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    // This measurement is not accurate as it includes the time processing the jobs,
                    // but it's fine, it's only for panicked tasks
                    let elapsed = context
                        .start
                        .elapsed()
                        .as_millis()
                        .try_into()
                        .unwrap_or(u64::MAX);

                    let reason = e.to_string();
                    repo.queue_job()
                        .mark_as_failed(clock, context.id, &reason)
                        .await?;

                    if context.attempt < MAX_ATTEMPTS {
                        let delay = retry_delay(context.attempt);
                        tracing::error!(
                            error = &e as &dyn std::error::Error,
                            job.id = %context.id,
                            job.queue.name = %context.queue_name,
                            job.attempt = %context.attempt,
                            job.elapsed = format!("{elapsed}ms"),
                            "Job crashed, will retry in {}s",
                            delay.num_seconds()
                        );

                        self.job_processing_time.record(
                            elapsed,
                            &[
                                KeyValue::new("job.queue.name", context.queue_name),
                                KeyValue::new("job.result", "crashed"),
                                KeyValue::new("job.decision", "retry"),
                            ],
                        );

                        repo.queue_job()
                            .retry(&mut *rng, clock, context.id, delay)
                            .await?;
                    } else {
                        tracing::error!(
                            error = &e as &dyn std::error::Error,
                            job.id = %context.id,
                            job.queue.name = %context.queue_name,
                            job.attempt = %context.attempt,
                            job.elapsed = format!("{elapsed}ms"),
                            "Job crashed too many times, abandoning"
                        );

                        self.job_processing_time.record(
                            elapsed,
                            &[
                                KeyValue::new("job.queue.name", context.queue_name),
                                KeyValue::new("job.result", "crashed"),
                                KeyValue::new("job.decision", "abandon"),
                            ],
                        );
                    }
                }
            }

            if blocking {
                self.last_join_result = self.running_jobs.join_next_with_id().await;
            } else {
                self.last_join_result = self.running_jobs.try_join_next_with_id();
            }
        }

        Ok(())
    }
}