Major refactor
This commit is contained in:
28
crates/siren-api/Cargo.toml
Normal file
28
crates/siren-api/Cargo.toml
Normal file
@@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "siren-api"
|
||||
edition.workspace = true
|
||||
version.workspace = true
|
||||
rust-version.workspace = true
|
||||
authors.workspace = true
|
||||
|
||||
[dependencies]
|
||||
siren-core = { workspace = true }
|
||||
siren-bot = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
log = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
axum-extra = { workspace = true }
|
||||
serenity = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rand_chacha = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
53
crates/siren-api/src/app.rs
Normal file
53
crates/siren-api/src/app.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use crate::{AppState, error::Result};
|
||||
use axum::Router;
|
||||
use std::{env, sync::Arc};
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::{
|
||||
cors::{Any, CorsLayer},
|
||||
services::{ServeDir, ServeFile},
|
||||
};
|
||||
|
||||
pub struct App {
|
||||
app_state: AppState,
|
||||
}
|
||||
|
||||
impl App {
|
||||
pub fn new(app_state: AppState) -> Self {
|
||||
Self { app_state }
|
||||
}
|
||||
|
||||
pub async fn serve(self) -> Result<()> {
|
||||
log::debug!("Starting API...");
|
||||
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
// Serve the built React frontend from frontend/dist (relative to the
|
||||
// working directory). Falls back gracefully if the directory does not
|
||||
// exist yet (e.g. during development when using `npm run dev`).
|
||||
let frontend_dir = env::current_dir()
|
||||
.unwrap_or_default()
|
||||
.join("frontend")
|
||||
.join("dist");
|
||||
|
||||
// For SPA routing: any path not matched by a real file (e.g. /map/<id>)
|
||||
// falls back to index.html so React can handle client-side routing.
|
||||
let index_html = frontend_dir.join("index.html");
|
||||
let serve_dir = ServeDir::new(&frontend_dir).not_found_service(ServeFile::new(index_html));
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/api", crate::get_routes())
|
||||
.fallback_service(serve_dir)
|
||||
.layer(cors)
|
||||
.with_state(Arc::new(self.app_state));
|
||||
|
||||
let api_port: String = env::var("API_PORT").expect("Expected a port in the environment");
|
||||
let addr = format!("0.0.0.0:{}", api_port);
|
||||
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
log::info!("API is listening on {}", &addr);
|
||||
Ok(axum::serve(listener, app).await?)
|
||||
}
|
||||
}
|
||||
23
crates/siren-api/src/app_state.rs
Normal file
23
crates/siren-api/src/app_state.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use dashmap::DashMap;
|
||||
use serenity::{
|
||||
all::{Cache, Http},
|
||||
prelude::Mutex,
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub client: reqwest::Client,
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub base_url: String,
|
||||
/// Maps oauth_state → ui_redirect_uri.
|
||||
/// Populated on /authorize, consumed on /callback.
|
||||
pub discord_authorize_cache: Arc<Mutex<HashMap<String, String>>>,
|
||||
pub http: Arc<Http>,
|
||||
pub cache: Arc<Cache>,
|
||||
/// Per-map WebSocket broadcast channels for real-time collaboration.
|
||||
/// Key is the CSPRNG map ID (TEXT).
|
||||
pub map_rooms: Arc<DashMap<String, broadcast::Sender<String>>>,
|
||||
}
|
||||
105
crates/siren-api/src/audio/mod.rs
Normal file
105
crates/siren-api/src/audio/mod.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{AuthorizationMiddleware, Session},
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
Extension,
|
||||
Json,
|
||||
Router,
|
||||
extract::{Path, State},
|
||||
middleware::from_extractor,
|
||||
routing::post,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use siren_bot::{
|
||||
commands::audio::{
|
||||
join_voice_channel,
|
||||
pause::pause_track,
|
||||
play::enqueue_track,
|
||||
resume::resume_track,
|
||||
},
|
||||
handler::get_songbird,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/play", post(play_audio))
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
.route("/pause", post(pause_audio))
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
.route("/resume", post(resume_audio))
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct PlayTrackRequest {
|
||||
url: String,
|
||||
}
|
||||
|
||||
async fn play_audio(
|
||||
Extension(session): Extension<Session>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
Json(payload): Json<PlayTrackRequest>,
|
||||
) -> Result<()> {
|
||||
log::debug!("Playing audio in guild: {}", guild_id);
|
||||
|
||||
// Check if the user exists in the cache
|
||||
let user_id = session.user_id;
|
||||
let user_id = match state.cache.user(user_id) {
|
||||
Some(user) => user.id,
|
||||
None => return Err(Error::not_found("User not found".to_string())),
|
||||
};
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
// Play the track
|
||||
let manager = get_songbird();
|
||||
let _channel_id = join_voice_channel(&state.cache, &manager, &guild_id, &user_id).await?;
|
||||
enqueue_track(manager, guild_id.to_owned(), &payload.url).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn pause_audio(
|
||||
Extension(_): Extension<Session>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> Result<()> {
|
||||
log::debug!("Pausing audio in guild: {}", guild_id);
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
// Pause the track
|
||||
let manager = get_songbird();
|
||||
pause_track(manager, &guild_id).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn resume_audio(
|
||||
Extension(_): Extension<Session>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> Result<()> {
|
||||
log::debug!("Pausing audio in guild: {}", guild_id);
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
// Pause the track
|
||||
let manager = get_songbird();
|
||||
resume_track(manager, &guild_id).await?;
|
||||
Ok(())
|
||||
}
|
||||
10
crates/siren-api/src/auth/bearer_token.rs
Normal file
10
crates/siren-api/src/auth/bearer_token.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BearerTokenClaims {
|
||||
pub sub: u64,
|
||||
pub name: String,
|
||||
pub iat: i64,
|
||||
pub exp: i64,
|
||||
pub jti: String,
|
||||
}
|
||||
225
crates/siren-api/src/auth/discord.rs
Normal file
225
crates/siren-api/src/auth/discord.rs
Normal file
@@ -0,0 +1,225 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{bearer_token::BearerTokenClaims, csprng, session::Session},
|
||||
};
|
||||
use axum::{
|
||||
Router,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::get,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, sync::Arc};
|
||||
|
||||
const DISCORD_REDIRECT_PATH: &str = "/api/auth/discord/callback";
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/authorize", get(discord_authorize))
|
||||
.route("/callback", get(discord_callback))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthorizeQuery {
|
||||
redirect_uri: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct CallbackQuery {
|
||||
code: String,
|
||||
state: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct DiscordTokenResponse {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
expires_in: u64,
|
||||
refresh_token: String,
|
||||
scope: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct DiscordUser {
|
||||
id: String,
|
||||
username: String,
|
||||
discriminator: String,
|
||||
avatar: Option<String>,
|
||||
}
|
||||
|
||||
async fn discord_authorize(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<AuthorizeQuery>,
|
||||
) -> impl IntoResponse {
|
||||
let oauth_state = csprng(16);
|
||||
|
||||
state
|
||||
.discord_authorize_cache
|
||||
.lock()
|
||||
.await
|
||||
.insert(oauth_state.clone(), query.redirect_uri);
|
||||
|
||||
let discord_callback_url = format!("{}{}", state.base_url, DISCORD_REDIRECT_PATH);
|
||||
let encoded_callback = discord_callback_url.replace(':', "%3A").replace('/', "%2F");
|
||||
|
||||
let discord_auth_url = format!(
|
||||
"https://discord.com/api/oauth2/authorize\
|
||||
?client_id={}\
|
||||
&redirect_uri={}\
|
||||
&response_type=code\
|
||||
&scope=identify\
|
||||
&state={}",
|
||||
state.client_id, encoded_callback, oauth_state,
|
||||
);
|
||||
|
||||
match serde_json::to_string(&discord_auth_url) {
|
||||
Ok(json) => Ok(json),
|
||||
Err(e) => {
|
||||
log::error!("Failed to serialize Discord OAuth URL: {e}");
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn discord_callback(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<CallbackQuery>,
|
||||
) -> impl IntoResponse {
|
||||
match do_oauth_callback(state, query).await {
|
||||
Ok((token, ui_redirect_uri)) => {
|
||||
Redirect::temporary(&format!("{}?token={}", ui_redirect_uri, token)).into_response()
|
||||
}
|
||||
Err((e, ui_redirect_uri)) => {
|
||||
log::error!("OAuth callback error: {:?}", e);
|
||||
let fallback = ui_redirect_uri.unwrap_or_else(|| "/".to_string());
|
||||
Redirect::temporary(&format!("{}?error=auth_failed", fallback)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_oauth_callback(
|
||||
state: Arc<AppState>,
|
||||
query: CallbackQuery,
|
||||
) -> Result<(String, String), (crate::error::Error, Option<String>)> {
|
||||
// Validate the state and retrieve the associated UI redirect URI
|
||||
let ui_redirect_uri = {
|
||||
let mut oauth_states = state.discord_authorize_cache.lock().await;
|
||||
match query.state {
|
||||
Some(ref oauth_state) => match oauth_states.remove(oauth_state) {
|
||||
Some(uri) => uri,
|
||||
None => return Err((StatusCode::UNAUTHORIZED.into(), None)),
|
||||
},
|
||||
None => return Err((StatusCode::UNAUTHORIZED.into(), None)),
|
||||
}
|
||||
};
|
||||
|
||||
// Helper closure to tag errors with the redirect URI we already know
|
||||
let redirect = ui_redirect_uri.clone();
|
||||
let err = |s: StatusCode| -> Result<_, (crate::error::Error, Option<String>)> {
|
||||
Err((s.into(), Some(redirect.clone())))
|
||||
};
|
||||
|
||||
// The discord redirect_uri in the token exchange must match what was sent in /authorize
|
||||
let discord_callback_url = format!("{}{}", state.base_url, DISCORD_REDIRECT_PATH);
|
||||
|
||||
// Exchange code for an access token
|
||||
let token_response = state
|
||||
.client
|
||||
.post("https://discord.com/api/oauth2/token")
|
||||
.form(&[
|
||||
("client_id", state.client_id.as_str()),
|
||||
("client_secret", state.client_secret.as_str()),
|
||||
("grant_type", "authorization_code"),
|
||||
("code", query.code.as_str()),
|
||||
("redirect_uri", discord_callback_url.as_str()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
if !token_response.status().is_success() {
|
||||
log::error!(
|
||||
"Failed to exchange token: {:?}",
|
||||
token_response.text().await
|
||||
);
|
||||
return err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
let token_data: DiscordTokenResponse = token_response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
// Fetch user information from Discord
|
||||
let user_response = state
|
||||
.client
|
||||
.get("https://discord.com/api/users/@me")
|
||||
.bearer_auth(token_data.access_token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
if !user_response.status().is_success() {
|
||||
log::error!(
|
||||
"Failed to fetch user information: {:?}",
|
||||
user_response.text().await
|
||||
);
|
||||
return err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
let user_data: DiscordUser = user_response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
log::debug!("User authenticated: {:?}", user_data);
|
||||
|
||||
let user_id: i64 = user_data
|
||||
.id
|
||||
.parse::<i64>()
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
// Upsert the Discord user into the local users table
|
||||
let pool = siren_core::data::pool();
|
||||
sqlx::query(
|
||||
"INSERT INTO users (id, username, avatar, updated_at)
|
||||
VALUES ($1, $2, $3, NOW())
|
||||
ON CONFLICT (id) DO UPDATE
|
||||
SET username = EXCLUDED.username,
|
||||
avatar = EXCLUDED.avatar,
|
||||
updated_at = NOW()",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&user_data.username)
|
||||
.bind(&user_data.avatar)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("Failed to upsert user: {e}");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
// Create and insert the session
|
||||
let session = Session::new(user_id as u64, user_data.username.clone());
|
||||
session
|
||||
.insert()
|
||||
.await
|
||||
.map_err(|e| (e, Some(ui_redirect_uri.clone())))?;
|
||||
|
||||
let issued_at = chrono::Utc::now();
|
||||
let claims = BearerTokenClaims {
|
||||
sub: session.user_id,
|
||||
name: session.user_name.clone(),
|
||||
iat: issued_at.timestamp(),
|
||||
exp: session.expires_at.timestamp(),
|
||||
jti: session.session_id.clone(),
|
||||
};
|
||||
|
||||
let jwt_secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
let encoding_key = jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes());
|
||||
let token = jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &encoding_key)
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
Ok((token, ui_redirect_uri))
|
||||
}
|
||||
107
crates/siren-api/src/auth/middleware.rs
Normal file
107
crates/siren-api/src/auth/middleware.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use crate::{
|
||||
auth::{bearer_token::BearerTokenClaims, session::Session},
|
||||
error::Result,
|
||||
};
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
http::{Method, StatusCode, request::Parts},
|
||||
};
|
||||
use axum_extra::{
|
||||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{DecodingKey, Validation, decode};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AuthorizationMiddleware — rejects unauthenticated requests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct AuthorizationMiddleware;
|
||||
|
||||
impl<S> FromRequestParts<S> for AuthorizationMiddleware
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = StatusCode;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> std::result::Result<Self, Self::Rejection> {
|
||||
// For options requests browsers will not send the authorization header.
|
||||
if parts.method == Method::OPTIONS {
|
||||
return Ok(Self);
|
||||
}
|
||||
|
||||
// Check for a Bearer token in the `Authorization` header.
|
||||
if let Ok(TypedHeader(Authorization(bearer))) =
|
||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
|
||||
{
|
||||
return match check_bearer_auth(bearer.token()).await {
|
||||
Ok(session) => {
|
||||
parts.extensions.insert(session);
|
||||
Ok(Self)
|
||||
}
|
||||
Err(_) => Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
}
|
||||
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OptionalAuth — extracts a Session if present, otherwise None
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Wraps an optional authenticated session.
|
||||
/// Handlers that use this extractor work for both authenticated and
|
||||
/// unauthenticated callers; callers with a valid Bearer token get a `Some(session)`.
|
||||
pub struct OptionalAuth(pub Option<Session>);
|
||||
|
||||
impl<S> FromRequestParts<S> for OptionalAuth
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = std::convert::Infallible;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> std::result::Result<Self, Self::Rejection> {
|
||||
if let Ok(TypedHeader(Authorization(bearer))) =
|
||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
|
||||
{
|
||||
if let Ok(session) = check_bearer_auth(bearer.token()).await {
|
||||
parts.extensions.insert(session.clone());
|
||||
return Ok(Self(Some(session)));
|
||||
}
|
||||
}
|
||||
Ok(Self(None))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn check_bearer_auth(bearer_token: &str) -> Result<Session> {
|
||||
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
|
||||
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
|
||||
|
||||
let token_data = decode::<BearerTokenClaims>(bearer_token, &decoding_key, &Validation::default())
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let claims = token_data.claims;
|
||||
|
||||
let now = Utc::now().timestamp();
|
||||
if claims.exp < now {
|
||||
return Err(StatusCode::UNAUTHORIZED.into());
|
||||
}
|
||||
|
||||
match Session::find(&claims.jti).await {
|
||||
Ok(Some(session)) => Ok(session),
|
||||
_ => Err(StatusCode::UNAUTHORIZED)?,
|
||||
}
|
||||
}
|
||||
24
crates/siren-api/src/auth/mod.rs
Normal file
24
crates/siren-api/src/auth/mod.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
use crate::AppState;
|
||||
use axum::Router;
|
||||
use rand::RngExt;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod discord;
|
||||
mod session;
|
||||
pub use session::Session;
|
||||
mod bearer_token;
|
||||
pub mod middleware;
|
||||
pub use middleware::{AuthorizationMiddleware, OptionalAuth};
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new().nest("/discord", discord::get_routes())
|
||||
}
|
||||
|
||||
pub fn csprng(take: usize) -> String {
|
||||
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
|
||||
rand::rng()
|
||||
.sample_iter(rand::distr::Alphanumeric)
|
||||
.take(take)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
68
crates/siren-api/src/auth/session.rs
Normal file
68
crates/siren-api/src/auth/session.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
use crate::{auth::csprng, error::Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use redis::{AsyncCommands, RedisResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use siren_core::data;
|
||||
use std::{env, sync::OnceLock};
|
||||
|
||||
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
|
||||
|
||||
fn get_session_ttl() -> i64 {
|
||||
// Initialize the SESSION_TTL value lazily
|
||||
*SESSION_TTL.get_or_init(|| {
|
||||
env::var("API_SESSION_TTL")
|
||||
.ok()
|
||||
.and_then(|val| val.parse::<i64>().ok())
|
||||
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Session {
|
||||
pub session_id: String,
|
||||
pub user_id: u64,
|
||||
pub user_name: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(user_id: u64, user_name: String) -> Session {
|
||||
let now = Utc::now();
|
||||
let session_ttl = get_session_ttl();
|
||||
Session {
|
||||
session_id: csprng(32),
|
||||
user_id,
|
||||
user_name,
|
||||
expires_at: now + chrono::Duration::seconds(session_ttl),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn insert(&self) -> Result<()> {
|
||||
let mut redis = data::redis_async_connection().await?;
|
||||
let session_id = self.session_id.clone();
|
||||
let session_ttl = get_session_ttl();
|
||||
redis
|
||||
.set_ex::<_, _, ()>(session_id, serde_json::to_string(self)?, session_ttl as u64)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn find(session_id: &str) -> Result<Option<Session>> {
|
||||
let mut redis = data::redis_async_connection().await?;
|
||||
let result: RedisResult<Option<String>> = redis.get(session_id).await;
|
||||
match result {
|
||||
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
|
||||
Ok(None) => Ok(None),
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete(session_id: &str) -> Result<()> {
|
||||
let mut redis = data::redis_async_connection().await?;
|
||||
let result: RedisResult<()> = redis.del(session_id).await;
|
||||
match result {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
201
crates/siren-api/src/dice/mod.rs
Normal file
201
crates/siren-api/src/dice/mod.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{AuthorizationMiddleware, Session},
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
Extension,
|
||||
Json,
|
||||
Router,
|
||||
extract::{Path, State},
|
||||
middleware::from_extractor,
|
||||
routing::post,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use siren_bot::commands::fun::roll::{format_roll, parse_dice};
|
||||
use siren_core::data::{ExecutableQuery, Value, condition::Condition, query::QueryBuilder};
|
||||
use std::{fmt::Display, str::FromStr, sync::Arc};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/{guild_id}/track", post(add_track_dice))
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
}
|
||||
|
||||
const TABLE_NAME: &str = "dice_track";
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
enum TrackDiceOperator {
|
||||
#[serde(rename = "eq")]
|
||||
Equal,
|
||||
#[serde(rename = "lt")]
|
||||
LessThan,
|
||||
#[serde(rename = "lte")]
|
||||
LessThanEqual,
|
||||
#[serde(rename = "gt")]
|
||||
GreaterThan,
|
||||
#[serde(rename = "gte")]
|
||||
GreaterThanEqual,
|
||||
}
|
||||
|
||||
// Implementing the ToString trait for converting the enum to a string
|
||||
impl Display for TrackDiceOperator {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let str = match self {
|
||||
TrackDiceOperator::Equal => "eq".to_string(),
|
||||
TrackDiceOperator::LessThan => "lt".to_string(),
|
||||
TrackDiceOperator::LessThanEqual => "lte".to_string(),
|
||||
TrackDiceOperator::GreaterThan => "gt".to_string(),
|
||||
TrackDiceOperator::GreaterThanEqual => "gte".to_string(),
|
||||
};
|
||||
write!(f, "{}", str)
|
||||
}
|
||||
}
|
||||
|
||||
// Implementing the FromStr trait for parsing a string into the enum
|
||||
impl FromStr for TrackDiceOperator {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s {
|
||||
"eq" => Ok(TrackDiceOperator::Equal),
|
||||
"lt" => Ok(TrackDiceOperator::LessThan),
|
||||
"lte" => Ok(TrackDiceOperator::LessThanEqual),
|
||||
"gt" => Ok(TrackDiceOperator::GreaterThan),
|
||||
"gte" => Ok(TrackDiceOperator::GreaterThanEqual),
|
||||
_ => Err(format!("Unknown value for TrackDiceOperator: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct DiceTrackPayload {
|
||||
dice: String,
|
||||
user_id: Option<i64>,
|
||||
value: Option<i32>,
|
||||
operator: Option<TrackDiceOperator>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct InsertDiceTrack {
|
||||
guild_id: i64,
|
||||
owner_id: i64,
|
||||
dice: String,
|
||||
user_id: Option<i64>,
|
||||
value: Option<i32>,
|
||||
operator: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct QueryDiceTrack {
|
||||
id: Uuid,
|
||||
guild_id: i64,
|
||||
owner_id: i64,
|
||||
dice: String,
|
||||
user_id: Option<i64>,
|
||||
value: Option<i32>,
|
||||
operator: Option<String>,
|
||||
}
|
||||
|
||||
impl QueryDiceTrack {
|
||||
pub async fn find(dice: &InsertDiceTrack) -> Option<Self> {
|
||||
QueryBuilder::new(TABLE_NAME)
|
||||
.where_condition(Condition::and(
|
||||
Condition::is_equal("guild_id", Value::BigInt(dice.guild_id)),
|
||||
Condition::and(
|
||||
Condition::is_equal("owner_id", Value::BigInt(dice.owner_id)),
|
||||
Condition::and(
|
||||
Condition::is_equal("dice", Value::Text(dice.dice.clone())),
|
||||
Condition::and(
|
||||
Condition::is_equal("user_id", Value::OptionalBigInt(dice.user_id)),
|
||||
Condition::and(
|
||||
Condition::is_equal("value", Value::OptionalInt(dice.value)),
|
||||
Condition::is_equal("operator", Value::OptionalText(dice.operator.clone())),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
))
|
||||
.fetch_optional()
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl InsertDiceTrack {
|
||||
pub async fn insert(&self) -> Result<QueryDiceTrack> {
|
||||
let pool = siren_core::data::pool();
|
||||
let query = format!(
|
||||
"INSERT INTO {} (
|
||||
guild_id,
|
||||
owner_id,
|
||||
dice,
|
||||
user_id,
|
||||
value,
|
||||
operator
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6
|
||||
) RETURNING *",
|
||||
TABLE_NAME
|
||||
);
|
||||
let item: QueryDiceTrack = match sqlx::query_as(&query)
|
||||
.bind(self.guild_id)
|
||||
.bind(self.owner_id)
|
||||
.bind(&self.dice)
|
||||
.bind(self.user_id)
|
||||
.bind(self.value)
|
||||
.bind(&self.operator)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
{
|
||||
Some(result) => result,
|
||||
None => return Err(Error::new(500, "Error storing".to_string())),
|
||||
};
|
||||
Ok(item)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_track_dice(
|
||||
Extension(session): Extension<Session>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
Json(payload): Json<DiceTrackPayload>,
|
||||
) -> Result<Json<QueryDiceTrack>> {
|
||||
// Check if the user exists in the cache
|
||||
let owner_id = session.user_id;
|
||||
let owner_id = match state.cache.user(owner_id) {
|
||||
Some(user) => user.id,
|
||||
None => return Err(Error::not_found("User not found".to_string())),
|
||||
};
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
let dice = parse_dice(&payload.dice)?;
|
||||
|
||||
let insert_dice = InsertDiceTrack {
|
||||
guild_id: guild_id.get() as i64,
|
||||
owner_id: owner_id.get() as i64,
|
||||
dice: format_roll(dice.0, dice.1, dice.2),
|
||||
user_id: payload.user_id,
|
||||
value: payload.value,
|
||||
operator: match payload.operator {
|
||||
None => None,
|
||||
Some(s) => Some(s.to_string()),
|
||||
},
|
||||
};
|
||||
|
||||
// Check for existing dice tracks
|
||||
let results = QueryDiceTrack::find(&insert_dice).await;
|
||||
|
||||
match results {
|
||||
Some(dice_track) => Ok(Json(dice_track)),
|
||||
None => {
|
||||
let dice_track = insert_dice.insert().await?;
|
||||
Ok(Json(dice_track))
|
||||
}
|
||||
}
|
||||
}
|
||||
128
crates/siren-api/src/error.rs
Normal file
128
crates/siren-api/src/error.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
use axum::{
|
||||
Json,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Error {
|
||||
pub status: u16,
|
||||
pub details: String,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn new(status: u16, details: String) -> Self {
|
||||
Self { status, details }
|
||||
}
|
||||
|
||||
pub fn not_found(details: String) -> Self {
|
||||
Self::new(404, details)
|
||||
}
|
||||
|
||||
pub fn internal_server_error(details: String) -> Self {
|
||||
Self::new(500, details)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(self.details.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn description(&self) -> &str {
|
||||
&self.details
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for Error {
|
||||
fn into_response(self) -> Response {
|
||||
let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = Json(serde_json::json!({
|
||||
"error": {
|
||||
"status": self.status,
|
||||
"details": self.details,
|
||||
}
|
||||
}));
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// --- Conversions from upstream crate errors ---
|
||||
|
||||
impl From<siren_core::error::Error> for Error {
|
||||
fn from(error: siren_core::error::Error) -> Self {
|
||||
Self::new(error.status, error.details)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<siren_bot::error::Error> for Error {
|
||||
fn from(error: siren_bot::error::Error) -> Self {
|
||||
Self::new(error.status, error.details)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Conversions from external crate errors ---
|
||||
|
||||
impl From<StatusCode> for Error {
|
||||
fn from(status: StatusCode) -> Self {
|
||||
Error {
|
||||
status: status.as_u16(),
|
||||
details: status
|
||||
.canonical_reason()
|
||||
.unwrap_or("Unknown error")
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for Error {
|
||||
fn from(error: reqwest::Error) -> Self {
|
||||
Self::new(500, format!("HTTP client error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(error: serde_json::Error) -> Self {
|
||||
Self::new(500, format!("JSON error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<jsonwebtoken::errors::Error> for Error {
|
||||
fn from(error: jsonwebtoken::errors::Error) -> Self {
|
||||
match error.kind() {
|
||||
jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
|
||||
Self::new(401, "Token expired".to_string())
|
||||
}
|
||||
jsonwebtoken::errors::ErrorKind::InvalidToken => Self::new(401, "Invalid token".to_string()),
|
||||
_ => Self::new(500, format!("JWT error: {}", error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Direct conversions for types used in API handlers that bypass the data abstraction layer
|
||||
|
||||
impl From<sqlx::Error> for Error {
|
||||
fn from(error: sqlx::Error) -> Self {
|
||||
let core_err: siren_core::error::Error = error.into();
|
||||
core_err.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<redis::RedisError> for Error {
|
||||
fn from(error: redis::RedisError) -> Self {
|
||||
let core_err: siren_core::error::Error = error.into();
|
||||
core_err.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(error: std::io::Error) -> Self {
|
||||
Self::new(500, format!("IO error: {}", error))
|
||||
}
|
||||
}
|
||||
619
crates/siren-api/src/grid/mod.rs
Normal file
619
crates/siren-api/src/grid/mod.rs
Normal file
@@ -0,0 +1,619 @@
|
||||
pub mod model;
|
||||
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{OptionalAuth, Session, csprng, middleware::check_bearer_auth},
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
Json,
|
||||
Router,
|
||||
extract::{
|
||||
Path,
|
||||
Query,
|
||||
State,
|
||||
WebSocketUpgrade,
|
||||
ws::{Message, WebSocket},
|
||||
},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{delete, get, post, put},
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use model::{
|
||||
ClientMessage,
|
||||
CreateMapPayload,
|
||||
GridCell,
|
||||
GridMap,
|
||||
GridToken,
|
||||
MapPermission,
|
||||
MapRole,
|
||||
MapState,
|
||||
ServerMessage,
|
||||
UpdatePermissionPayload,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/maps", get(list_maps))
|
||||
.route("/maps", post(create_map))
|
||||
.route("/maps/{id}", get(get_map))
|
||||
.route("/maps/{id}", delete(delete_map))
|
||||
.route("/maps/{id}/permissions", get(list_permissions))
|
||||
.route("/maps/{id}/permissions", put(update_permission))
|
||||
.route("/maps/{id}/ws", get(ws_handler))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Permission helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Fetch the role of `user_id` on `map_id`, or `None` if no record exists.
|
||||
async fn get_user_role(map_id: &str, user_id: i64) -> crate::error::Result<Option<MapRole>> {
|
||||
let pool = siren_core::data::pool();
|
||||
let perm: Option<MapPermission> = sqlx::query_as(
|
||||
"SELECT map_id, user_id, role FROM map_permissions WHERE map_id = $1 AND user_id = $2",
|
||||
)
|
||||
.bind(map_id)
|
||||
.bind(user_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(perm.map(|p| p.role))
|
||||
}
|
||||
|
||||
/// Returns whether the caller can view the map:
|
||||
/// - Public maps: always true.
|
||||
/// - Private maps: true only if the user has any role.
|
||||
async fn can_view(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
if map.is_public {
|
||||
return true;
|
||||
}
|
||||
let Some(s) = session else { return false };
|
||||
let user_id = s.user_id as i64;
|
||||
get_user_role(&map.id, user_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.is_some()
|
||||
}
|
||||
|
||||
/// Returns whether the caller can edit the map (editor or owner role).
|
||||
async fn can_edit(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
let Some(s) = session else { return false };
|
||||
let user_id = s.user_id as i64;
|
||||
get_user_role(&map.id, user_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|r| r.can_edit())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Returns whether the caller is the owner.
|
||||
async fn is_owner(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
let Some(s) = session else { return false };
|
||||
let user_id = s.user_id as i64;
|
||||
get_user_role(&map.id, user_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|r| r.is_owner())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// REST handlers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn list_maps(OptionalAuth(session): OptionalAuth) -> Result<Json<Vec<GridMap>>> {
|
||||
let pool = siren_core::data::pool();
|
||||
let maps: Vec<GridMap> = match &session {
|
||||
Some(s) => {
|
||||
let user_id = s.user_id as i64;
|
||||
sqlx::query_as(
|
||||
"SELECT DISTINCT gm.*
|
||||
FROM grid_maps gm
|
||||
LEFT JOIN map_permissions mp ON mp.map_id = gm.id AND mp.user_id = $1
|
||||
WHERE gm.is_public = TRUE OR mp.user_id IS NOT NULL
|
||||
ORDER BY gm.created_at DESC",
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
sqlx::query_as("SELECT * FROM grid_maps WHERE is_public = TRUE ORDER BY created_at DESC")
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
Ok(Json(maps))
|
||||
}
|
||||
|
||||
pub async fn create_map(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
Json(payload): Json<CreateMapPayload>,
|
||||
) -> Result<(StatusCode, Json<GridMap>)> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
let user_id = session.user_id as i64;
|
||||
let map_id = csprng(32);
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: GridMap = sqlx::query_as(
|
||||
"INSERT INTO grid_maps (id, name, is_public, owner_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING *",
|
||||
)
|
||||
.bind(&map_id)
|
||||
.bind(&payload.name)
|
||||
.bind(payload.is_public)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
// Auto-assign the creator as owner in map_permissions
|
||||
sqlx::query("INSERT INTO map_permissions (map_id, user_id, role) VALUES ($1, $2, 'owner')")
|
||||
.bind(&map_id)
|
||||
.bind(user_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok((StatusCode::CREATED, Json(map)))
|
||||
}
|
||||
|
||||
pub async fn get_map(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<MapState>> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !can_view(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
let cells: Vec<GridCell> = sqlx::query_as("SELECT * FROM grid_cells WHERE map_id = $1")
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let tokens: Vec<GridToken> = sqlx::query_as("SELECT * FROM grid_tokens WHERE map_id = $1")
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(Json(MapState { map, cells, tokens }))
|
||||
}
|
||||
|
||||
pub async fn delete_map(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<StatusCode> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !is_owner(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Permission management
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn list_permissions(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<Vec<MapPermission>>> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !is_owner(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
let perms: Vec<MapPermission> =
|
||||
sqlx::query_as("SELECT map_id, user_id, role FROM map_permissions WHERE map_id = $1")
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(Json(perms))
|
||||
}
|
||||
|
||||
pub async fn update_permission(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdatePermissionPayload>,
|
||||
) -> Result<StatusCode> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !is_owner(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
// Prevent the owner from removing their own owner record
|
||||
let caller_id = session.as_ref().map(|s| s.user_id as i64).unwrap_or(0);
|
||||
if payload.user_id == caller_id && payload.role.as_ref().map(|r| r.is_owner()) == Some(false) {
|
||||
return Err(Error::from(StatusCode::UNPROCESSABLE_ENTITY));
|
||||
}
|
||||
|
||||
match payload.role {
|
||||
Some(role) => {
|
||||
sqlx::query(
|
||||
"INSERT INTO map_permissions (map_id, user_id, role)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (map_id, user_id) DO UPDATE SET role = EXCLUDED.role",
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(payload.user_id)
|
||||
.bind(role)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
None => {
|
||||
sqlx::query("DELETE FROM map_permissions WHERE map_id = $1 AND user_id = $2")
|
||||
.bind(&id)
|
||||
.bind(payload.user_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebSocket handler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct WsQuery {
|
||||
/// Optional Bearer token passed as a query parameter for WS auth.
|
||||
token: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(map_id): Path<String>,
|
||||
Query(query): Query<WsQuery>,
|
||||
) -> impl IntoResponse {
|
||||
// Resolve the session from query param (WS can't easily send headers)
|
||||
let session: Option<Session> = match query.token {
|
||||
Some(ref tok) => check_bearer_auth(tok).await.ok(),
|
||||
None => None,
|
||||
};
|
||||
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, map_id, session))
|
||||
}
|
||||
|
||||
async fn handle_socket(
|
||||
socket: WebSocket,
|
||||
state: Arc<AppState>,
|
||||
map_id: String,
|
||||
session: Option<Session>,
|
||||
) {
|
||||
// Load the map and verify the caller can view it
|
||||
let map_state = match fetch_map_state(&map_id).await {
|
||||
Ok(ms) => ms,
|
||||
Err(_) => return, // map doesn't exist
|
||||
};
|
||||
|
||||
if !can_view(&map_state.map, &session).await {
|
||||
// Refuse the connection silently (upgrade already happened; just close)
|
||||
return;
|
||||
}
|
||||
|
||||
let editor = can_edit(&map_state.map, &session).await;
|
||||
|
||||
// Get or create a broadcast channel for this map
|
||||
let tx = state
|
||||
.map_rooms
|
||||
.entry(map_id.clone())
|
||||
.or_insert_with(|| {
|
||||
let (tx, _) = broadcast::channel(256);
|
||||
tx
|
||||
})
|
||||
.clone();
|
||||
let mut rx = tx.subscribe();
|
||||
|
||||
let (mut ws_tx, mut ws_rx) = socket.split();
|
||||
|
||||
// Send the current full map state to the newly connected client
|
||||
let init_msg = ServerMessage::State {
|
||||
cells: map_state.cells,
|
||||
tokens: map_state.tokens,
|
||||
colors: map_state.map.colors,
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&init_msg) {
|
||||
let _ = ws_tx.send(Message::Text(json.into())).await;
|
||||
}
|
||||
|
||||
// Task 1: forward broadcast messages to this socket
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(json) = rx.recv().await {
|
||||
if ws_tx.send(Message::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Task 2: receive messages from this client, persist, and broadcast
|
||||
let tx_clone = tx.clone();
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = ws_rx.next().await {
|
||||
match msg {
|
||||
Message::Text(text) => {
|
||||
handle_client_message(&text, &map_id, editor, &tx_clone).await;
|
||||
}
|
||||
Message::Close(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut send_task => recv_task.abort(),
|
||||
_ = &mut recv_task => send_task.abort(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_map_state(map_id: &str) -> crate::error::Result<MapState> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: GridMap = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(map_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let cells: Vec<GridCell> = sqlx::query_as("SELECT * FROM grid_cells WHERE map_id = $1")
|
||||
.bind(map_id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let tokens: Vec<GridToken> = sqlx::query_as("SELECT * FROM grid_tokens WHERE map_id = $1")
|
||||
.bind(map_id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(MapState { map, cells, tokens })
|
||||
}
|
||||
|
||||
async fn handle_client_message(
|
||||
raw: &str,
|
||||
map_id: &str,
|
||||
can_edit: bool,
|
||||
tx: &broadcast::Sender<String>,
|
||||
) {
|
||||
let client_msg: ClientMessage = match serde_json::from_str(raw) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
log::warn!("Invalid WS message: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// All mutating messages require editor or owner role
|
||||
if !can_edit {
|
||||
let err = ServerMessage::Error {
|
||||
message: "You do not have permission to edit this map.".into(),
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&err) {
|
||||
let _ = tx.send(json);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let server_msg: Option<ServerMessage> = match client_msg {
|
||||
ClientMessage::PaintCell { x, y, color } => {
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO grid_cells (map_id, x, y, color)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (map_id, x, y) DO UPDATE SET color = EXCLUDED.color",
|
||||
)
|
||||
.bind(map_id)
|
||||
.bind(x)
|
||||
.bind(y)
|
||||
.bind(&color)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Some(ServerMessage::CellPainted { x, y, color }),
|
||||
Err(e) => {
|
||||
log::error!("DB error painting cell: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::PaintCells { cells } => {
|
||||
let mut tx_db = match pool.begin().await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
log::error!("DB error starting transaction for batch paint: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut ok = true;
|
||||
for cell in &cells {
|
||||
let res = sqlx::query(
|
||||
"INSERT INTO grid_cells (map_id, x, y, color)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (map_id, x, y) DO UPDATE SET color = EXCLUDED.color",
|
||||
)
|
||||
.bind(map_id)
|
||||
.bind(cell.x)
|
||||
.bind(cell.y)
|
||||
.bind(&cell.color)
|
||||
.execute(&mut *tx_db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = res {
|
||||
log::error!("DB error in batch paint cell ({},{}): {e}", cell.x, cell.y);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if ok {
|
||||
if let Err(e) = tx_db.commit().await {
|
||||
log::error!("DB error committing batch paint: {e}");
|
||||
None
|
||||
} else {
|
||||
Some(ServerMessage::CellsBatchPainted { cells })
|
||||
}
|
||||
} else {
|
||||
let _ = tx_db.rollback().await;
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::EraseCell { x, y } => {
|
||||
let result = sqlx::query("DELETE FROM grid_cells WHERE map_id = $1 AND x = $2 AND y = $3")
|
||||
.bind(map_id)
|
||||
.bind(x)
|
||||
.bind(y)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Some(ServerMessage::CellErased { x, y }),
|
||||
Err(e) => {
|
||||
log::error!("DB error erasing cell: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::AddToken { x, y, label, color } => {
|
||||
let token_id = csprng(16);
|
||||
let result: sqlx::Result<GridToken> = sqlx::query_as(
|
||||
"INSERT INTO grid_tokens (id, map_id, x, y, label, color)
|
||||
VALUES ($1, $2, $3, $4, $5, $6) RETURNING *",
|
||||
)
|
||||
.bind(&token_id)
|
||||
.bind(map_id)
|
||||
.bind(x)
|
||||
.bind(y)
|
||||
.bind(&label)
|
||||
.bind(&color)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(token) => Some(ServerMessage::TokenAdded {
|
||||
id: token.id,
|
||||
x: token.x,
|
||||
y: token.y,
|
||||
label: token.label,
|
||||
color: token.color,
|
||||
}),
|
||||
Err(e) => {
|
||||
log::error!("DB error adding token: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::MoveToken { id, x, y } => {
|
||||
let result =
|
||||
sqlx::query("UPDATE grid_tokens SET x = $1, y = $2 WHERE id = $3 AND map_id = $4")
|
||||
.bind(x)
|
||||
.bind(y)
|
||||
.bind(&id)
|
||||
.bind(map_id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => Some(ServerMessage::TokenMoved { id, x, y }),
|
||||
Ok(_) => None,
|
||||
Err(e) => {
|
||||
log::error!("DB error moving token: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::DeleteToken { id } => {
|
||||
let result = sqlx::query("DELETE FROM grid_tokens WHERE id = $1 AND map_id = $2")
|
||||
.bind(&id)
|
||||
.bind(map_id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => Some(ServerMessage::TokenDeleted { id }),
|
||||
Ok(_) => None,
|
||||
Err(e) => {
|
||||
log::error!("DB error deleting token: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::UpdateColors { colors } => {
|
||||
let result =
|
||||
sqlx::query("UPDATE grid_maps SET colors = $1, updated_at = NOW() WHERE id = $2")
|
||||
.bind(&colors)
|
||||
.bind(map_id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Some(ServerMessage::ColorsUpdated { colors }),
|
||||
Err(e) => {
|
||||
log::error!("DB error updating colors: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(msg) = server_msg {
|
||||
if let Ok(json) = serde_json::to_string(&msg) {
|
||||
let _ = tx.send(json);
|
||||
}
|
||||
}
|
||||
}
|
||||
190
crates/siren-api/src/grid/model.rs
Normal file
190
crates/siren-api/src/grid/model.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use chrono::NaiveDateTime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Map Role / Permission
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::Type, Clone, Debug, PartialEq, Eq)]
|
||||
#[sqlx(type_name = "text", rename_all = "lowercase")]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum MapRole {
|
||||
Owner,
|
||||
Editor,
|
||||
Viewer,
|
||||
}
|
||||
|
||||
impl MapRole {
|
||||
/// Returns true if this role can mutate map content (paint, tokens, colors).
|
||||
pub fn can_edit(&self) -> bool {
|
||||
matches!(self, MapRole::Owner | MapRole::Editor)
|
||||
}
|
||||
|
||||
/// Returns true if this role can manage permissions and delete the map.
|
||||
pub fn is_owner(&self) -> bool {
|
||||
matches!(self, MapRole::Owner)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct MapPermission {
|
||||
pub map_id: String,
|
||||
pub user_id: i64,
|
||||
pub role: MapRole,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Grid Map
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct GridMap {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub is_public: bool,
|
||||
pub owner_id: i64,
|
||||
pub colors: Vec<String>,
|
||||
pub created_at: NaiveDateTime,
|
||||
pub updated_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct CreateMapPayload {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub is_public: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct UpdatePermissionPayload {
|
||||
/// Discord user ID of the target user.
|
||||
pub user_id: i64,
|
||||
/// New role to assign. Omit (null) to remove the permission entry.
|
||||
pub role: Option<MapRole>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Grid Cell (no id column — composite PK in DB)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct GridCell {
|
||||
pub map_id: String,
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub color: String,
|
||||
}
|
||||
|
||||
/// Lightweight cell used for batch operations.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct CellPatch {
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub color: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Grid Token
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct GridToken {
|
||||
pub id: String,
|
||||
pub map_id: String,
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub label: String,
|
||||
pub color: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Full map state (used on initial WS connect and REST GET)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct MapState {
|
||||
pub map: GridMap,
|
||||
pub cells: Vec<GridCell>,
|
||||
pub tokens: Vec<GridToken>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebSocket message types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ClientMessage {
|
||||
PaintCell {
|
||||
x: i32,
|
||||
y: i32,
|
||||
color: String,
|
||||
},
|
||||
PaintCells {
|
||||
cells: Vec<CellPatch>,
|
||||
},
|
||||
EraseCell {
|
||||
x: i32,
|
||||
y: i32,
|
||||
},
|
||||
AddToken {
|
||||
x: i32,
|
||||
y: i32,
|
||||
label: String,
|
||||
color: String,
|
||||
},
|
||||
MoveToken {
|
||||
id: String,
|
||||
x: i32,
|
||||
y: i32,
|
||||
},
|
||||
DeleteToken {
|
||||
id: String,
|
||||
},
|
||||
UpdateColors {
|
||||
colors: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ServerMessage {
|
||||
State {
|
||||
cells: Vec<GridCell>,
|
||||
tokens: Vec<GridToken>,
|
||||
colors: Vec<String>,
|
||||
},
|
||||
CellPainted {
|
||||
x: i32,
|
||||
y: i32,
|
||||
color: String,
|
||||
},
|
||||
CellsBatchPainted {
|
||||
cells: Vec<CellPatch>,
|
||||
},
|
||||
CellErased {
|
||||
x: i32,
|
||||
y: i32,
|
||||
},
|
||||
TokenAdded {
|
||||
id: String,
|
||||
x: i32,
|
||||
y: i32,
|
||||
label: String,
|
||||
color: String,
|
||||
},
|
||||
TokenMoved {
|
||||
id: String,
|
||||
x: i32,
|
||||
y: i32,
|
||||
},
|
||||
TokenDeleted {
|
||||
id: String,
|
||||
},
|
||||
ColorsUpdated {
|
||||
colors: Vec<String>,
|
||||
},
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
20
crates/siren-api/src/lib.rs
Normal file
20
crates/siren-api/src/lib.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
pub mod app;
|
||||
mod app_state;
|
||||
pub mod audio;
|
||||
pub mod auth;
|
||||
pub mod dice;
|
||||
pub mod error;
|
||||
pub mod grid;
|
||||
|
||||
pub use app::App;
|
||||
pub use app_state::AppState;
|
||||
use axum::Router;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.nest("/auth", auth::get_routes())
|
||||
.nest("/audio/{guild_id}", audio::get_routes())
|
||||
.nest("/dice", dice::get_routes())
|
||||
.nest("/grid", grid::get_routes())
|
||||
}
|
||||
Reference in New Issue
Block a user