mod bound;
mod worker;
use std::net::IpAddr;
use chrono::{DateTime, Utc};
use mas_data_model::{BrowserSession, CompatSession, Session};
use mas_storage::Clock;
use sqlx::PgPool;
use ulid::Ulid;
pub use self::bound::Bound;
use self::worker::Worker;
/// Capacity of the bounded queue between tracker handles and the worker.
///
/// Once this many messages are pending, sending waits for the worker to
/// catch up, so recording applies back-pressure instead of growing memory.
const MESSAGE_QUEUE_SIZE: usize = 1000;

// `tokio::sync::mpsc::channel` panics on a zero capacity; reject that at
// compile time instead of at startup.
const _: () = assert!(MESSAGE_QUEUE_SIZE > 0, "MESSAGE_QUEUE_SIZE must be non-zero");
/// The kind of session an activity record refers to.
///
/// Variant order matters: it defines the derived `PartialOrd` ordering.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd)]
enum SessionKind {
    /// Activity on an OAuth 2.0 `Session`.
    OAuth2,
    /// Activity on a `CompatSession`.
    Compat,
    /// Activity on a `BrowserSession`.
    Browser,
}
impl SessionKind {
const fn as_str(self) -> &'static str {
match self {
SessionKind::OAuth2 => "oauth2",
SessionKind::Compat => "compat",
SessionKind::Browser => "browser",
}
}
}
/// Messages sent from [`ActivityTracker`] handles to the background worker.
enum Message {
    /// A session of the given `kind` and `id` was seen at `date_time`,
    /// optionally from `ip`.
    Record {
        kind: SessionKind,
        id: Ulid,
        date_time: DateTime<Utc>,
        ip: Option<IpAddr>,
    },
    /// Flush request. The caller awaits the paired receiver, so the worker
    /// presumably signals through this sender once the flush is done.
    Flush(tokio::sync::oneshot::Sender<()>),
    /// Shutdown request. The caller awaits the paired receiver, so the worker
    /// presumably signals through this sender once it has shut down.
    Shutdown(tokio::sync::oneshot::Sender<()>),
}
/// A clonable handle used to record session activity.
///
/// All clones share the same bounded channel to one background worker.
#[derive(Clone)]
pub struct ActivityTracker {
    /// Sending half of the queue consumed by the worker.
    channel: tokio::sync::mpsc::Sender<Message>,
}
impl ActivityTracker {
#[must_use]
pub fn new(pool: PgPool, flush_interval: std::time::Duration) -> Self {
let worker = Worker::new(pool);
let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
let tracker = ActivityTracker { channel: sender };
tokio::spawn(tracker.clone().flush_loop(flush_interval));
tokio::spawn(worker.run(receiver));
tracker
}
#[must_use]
pub fn bind(self, ip: Option<IpAddr>) -> Bound {
Bound::new(self, ip)
}
pub async fn record_oauth2_session(
&self,
clock: &dyn Clock,
session: &Session,
ip: Option<IpAddr>,
) {
let res = self
.channel
.send(Message::Record {
kind: SessionKind::OAuth2,
id: session.id,
date_time: clock.now(),
ip,
})
.await;
if let Err(e) = res {
tracing::error!("Failed to record OAuth2 session: {}", e);
}
}
pub async fn record_compat_session(
&self,
clock: &dyn Clock,
compat_session: &CompatSession,
ip: Option<IpAddr>,
) {
let res = self
.channel
.send(Message::Record {
kind: SessionKind::Compat,
id: compat_session.id,
date_time: clock.now(),
ip,
})
.await;
if let Err(e) = res {
tracing::error!("Failed to record compat session: {}", e);
}
}
pub async fn record_browser_session(
&self,
clock: &dyn Clock,
browser_session: &BrowserSession,
ip: Option<IpAddr>,
) {
let res = self
.channel
.send(Message::Record {
kind: SessionKind::Browser,
id: browser_session.id,
date_time: clock.now(),
ip,
})
.await;
if let Err(e) = res {
tracing::error!("Failed to record browser session: {}", e);
}
}
pub async fn flush(&self) {
let (tx, rx) = tokio::sync::oneshot::channel();
let res = self.channel.send(Message::Flush(tx)).await;
match res {
Ok(()) => {
if let Err(e) = rx.await {
tracing::error!("Failed to flush activity tracker: {}", e);
}
}
Err(e) => {
tracing::error!("Failed to flush activity tracker: {}", e);
}
}
}
async fn flush_loop(self, interval: std::time::Duration) {
loop {
tokio::select! {
biased;
() = self.channel.closed() => {
break;
}
() = tokio::time::sleep(interval) => {
self.flush().await;
}
}
}
}
pub async fn shutdown(&self) {
let (tx, rx) = tokio::sync::oneshot::channel();
let res = self.channel.send(Message::Shutdown(tx)).await;
match res {
Ok(()) => {
if let Err(e) = rx.await {
tracing::error!("Failed to shutdown activity tracker: {}", e);
}
}
Err(e) => {
tracing::error!("Failed to shutdown activity tracker: {}", e);
}
}
}
}