Updating auth
This commit is contained in:
@@ -26,3 +26,6 @@ redis = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
argon2 = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
cookie = { workspace = true }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::{AppState, error::Result};
|
||||
use axum::Router;
|
||||
use axum::{Router, http::HeaderValue};
|
||||
use std::{env, sync::Arc};
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::{
|
||||
@@ -19,17 +19,36 @@ impl App {
|
||||
pub async fn serve(self) -> Result<()> {
|
||||
log::debug!("Starting API...");
|
||||
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
// Build CORS layer.
|
||||
//
|
||||
// In production both the UI and API are served from the same origin so
|
||||
// CORS is a non-issue. In development, Vite proxies all /api/* calls so
|
||||
// the browser also never makes cross-origin requests directly to this
|
||||
// server. We keep a permissive default for convenience, but restrict it
|
||||
// when CORS_ORIGIN is explicitly set.
|
||||
let cors = match env::var("CORS_ORIGIN") {
|
||||
Ok(origin) if origin != "*" => {
|
||||
let header_val = origin
|
||||
.parse::<HeaderValue>()
|
||||
.expect("CORS_ORIGIN is not a valid header value");
|
||||
CorsLayer::new()
|
||||
.allow_origin(header_val)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any)
|
||||
.allow_credentials(true)
|
||||
}
|
||||
_ => CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any),
|
||||
};
|
||||
|
||||
// Serve the built React frontend from frontend/dist (relative to the
|
||||
// working directory). Falls back gracefully if the directory does not
|
||||
// exist yet (e.g. during development when using `npm run dev`).
|
||||
// Serve the built React frontend from ui/dist (relative to the working
|
||||
// directory). Falls back gracefully if the directory does not exist yet
|
||||
// (e.g. during development when using `npm run dev`).
|
||||
let frontend_dir = env::current_dir()
|
||||
.unwrap_or_default()
|
||||
.join("frontend")
|
||||
.join("ui")
|
||||
.join("dist");
|
||||
|
||||
// For SPA routing: any path not matched by a real file (e.g. /map/<id>)
|
||||
|
||||
@@ -5,6 +5,16 @@ use serenity::{
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::sync::broadcast;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Data stored per-entry in the Discord OAuth state cache.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DiscordOAuthState {
|
||||
/// Where to send the browser after the OAuth dance completes.
|
||||
pub redirect_uri: String,
|
||||
/// Set when a logged-in user is connecting (not logging in) via Discord.
|
||||
pub connecting_user_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
@@ -12,9 +22,9 @@ pub struct AppState {
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub base_url: String,
|
||||
/// Maps oauth_state → ui_redirect_uri.
|
||||
/// Populated on /authorize, consumed on /callback.
|
||||
pub discord_authorize_cache: Arc<Mutex<HashMap<String, String>>>,
|
||||
/// Maps oauth_state → DiscordOAuthState.
|
||||
/// Populated on /authorize or /connect, consumed on /callback.
|
||||
pub discord_authorize_cache: Arc<Mutex<HashMap<String, DiscordOAuthState>>>,
|
||||
pub http: Arc<Http>,
|
||||
pub cache: Arc<Cache>,
|
||||
/// Per-map WebSocket broadcast channels for real-time collaboration.
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{AuthorizationMiddleware, Session},
|
||||
auth::SessionAuthorization,
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
Extension,
|
||||
Json,
|
||||
Router,
|
||||
extract::{Path, State},
|
||||
middleware::from_extractor,
|
||||
http::StatusCode,
|
||||
routing::post,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
@@ -22,15 +21,13 @@ use siren_bot::{
|
||||
handler::get_songbird,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/play", post(play_audio))
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
.route("/pause", post(pause_audio))
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
.route("/resume", post(resume_audio))
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -38,19 +35,44 @@ struct PlayTrackRequest {
|
||||
url: String,
|
||||
}
|
||||
|
||||
/// Resolve the Discord snowflake for a local user from `user_connections`.
|
||||
/// Returns an error if the user has no linked Discord account.
|
||||
async fn get_discord_snowflake(local_user_id: Uuid) -> Result<u64> {
|
||||
let pool = siren_core::data::pool();
|
||||
let provider_id: Option<String> = sqlx::query_scalar(
|
||||
"SELECT provider_user_id FROM user_connections \
|
||||
WHERE user_id = $1 AND provider = 'discord'",
|
||||
)
|
||||
.bind(local_user_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
provider_id
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.ok_or_else(|| Error::not_found("Discord account not connected".to_string()))
|
||||
}
|
||||
|
||||
async fn play_audio(
|
||||
Extension(session): Extension<Session>,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
Json(payload): Json<PlayTrackRequest>,
|
||||
) -> Result<()> {
|
||||
log::debug!("Playing audio in guild: {}", guild_id);
|
||||
|
||||
// Check if the user exists in the cache
|
||||
let user_id = session.user_id;
|
||||
let user_id = match state.cache.user(user_id) {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
// Resolve Discord snowflake from the local user_id
|
||||
let discord_snowflake = get_discord_snowflake(session.user_id).await?;
|
||||
|
||||
// Check if the user exists in the Discord cache
|
||||
let user_id = match state.cache.user(discord_snowflake) {
|
||||
Some(user) => user.id,
|
||||
None => return Err(Error::not_found("User not found".to_string())),
|
||||
None => {
|
||||
return Err(Error::not_found(
|
||||
"User not found in Discord cache".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
@@ -61,16 +83,17 @@ async fn play_audio(
|
||||
|
||||
// Play the track
|
||||
let manager = get_songbird();
|
||||
let _channel_id = join_voice_channel(&state.cache, &manager, &guild_id, &user_id).await?;
|
||||
let _channel_id = join_voice_channel(&state.cache, manager, &guild_id, &user_id).await?;
|
||||
enqueue_track(manager, guild_id.to_owned(), &payload.url).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn pause_audio(
|
||||
Extension(_): Extension<Session>,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> Result<()> {
|
||||
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
log::debug!("Pausing audio in guild: {}", guild_id);
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
@@ -86,11 +109,12 @@ async fn pause_audio(
|
||||
}
|
||||
|
||||
async fn resume_audio(
|
||||
Extension(_): Extension<Session>,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> Result<()> {
|
||||
log::debug!("Pausing audio in guild: {}", guild_id);
|
||||
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
log::debug!("Resuming audio in guild: {}", guild_id);
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
@@ -98,7 +122,7 @@ async fn resume_audio(
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
// Pause the track
|
||||
// Resume the track
|
||||
let manager = get_songbird();
|
||||
resume_track(manager, &guild_id).await?;
|
||||
Ok(())
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Claims encoded in the JWT stored in the `siren_session` cookie
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BearerTokenClaims {
|
||||
pub sub: u64,
|
||||
/// Local user UUID (as a string)
|
||||
pub sub: String,
|
||||
/// Display username
|
||||
pub name: String,
|
||||
/// Issued-at epoch seconds
|
||||
pub iat: i64,
|
||||
/// Expiry epoch seconds
|
||||
pub exp: i64,
|
||||
/// Redis session key (used to look up the full session)
|
||||
pub jti: String,
|
||||
}
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{bearer_token::BearerTokenClaims, csprng, session::Session},
|
||||
app_state::DiscordOAuthState,
|
||||
auth::{
|
||||
SessionAuthorization,
|
||||
local::{build_session_cookie, issue_jwt},
|
||||
middleware::{compute_fingerprint, extract_ip},
|
||||
session::Session,
|
||||
},
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
Router,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::get,
|
||||
};
|
||||
use axum_extra::extract::CookieJar;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, sync::Arc};
|
||||
use siren_core::utils::csprng;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
const DISCORD_REDIRECT_PATH: &str = "/api/auth/discord/callback";
|
||||
|
||||
@@ -18,14 +28,15 @@ pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/authorize", get(discord_authorize))
|
||||
.route("/callback", get(discord_callback))
|
||||
.route("/connect", get(discord_connect))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct AuthorizeQuery {
|
||||
redirect_uri: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct CallbackQuery {
|
||||
code: String,
|
||||
state: Option<String>,
|
||||
@@ -48,22 +59,66 @@ struct DiscordUser {
|
||||
avatar: Option<String>,
|
||||
}
|
||||
|
||||
/// Begin a Discord OAuth login flow (anonymous users)
|
||||
///
|
||||
/// Stores the caller's desired `redirect_uri` in the state cache so the
|
||||
/// callback can redirect to the right place after login
|
||||
async fn discord_authorize(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<AuthorizeQuery>,
|
||||
) -> impl IntoResponse {
|
||||
let oauth_state = csprng(16);
|
||||
log::trace!("Discord authorize: {:?}, state={}", query, oauth_state);
|
||||
|
||||
state
|
||||
.discord_authorize_cache
|
||||
.lock()
|
||||
.await
|
||||
.insert(oauth_state.clone(), query.redirect_uri);
|
||||
state.discord_authorize_cache.lock().await.insert(
|
||||
oauth_state.clone(),
|
||||
DiscordOAuthState {
|
||||
redirect_uri: query.redirect_uri,
|
||||
connecting_user_id: None,
|
||||
},
|
||||
);
|
||||
|
||||
build_discord_oauth_url(&state, &oauth_state)
|
||||
}
|
||||
|
||||
/// Begin a Discord OAuth connect flow (already-authenticated users).
|
||||
///
|
||||
/// The caller must have a valid session cookie. Their user ID is stored
|
||||
/// in the state cache so the callback can link the Discord account to the
|
||||
/// existing local account.
|
||||
async fn discord_connect(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<AuthorizeQuery>,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
let oauth_state = csprng(16);
|
||||
log::trace!(
|
||||
"Discord connect: {:?}, state={} (user_id={})",
|
||||
query,
|
||||
oauth_state,
|
||||
session.user_id
|
||||
);
|
||||
|
||||
state.discord_authorize_cache.lock().await.insert(
|
||||
oauth_state.clone(),
|
||||
DiscordOAuthState {
|
||||
redirect_uri: query.redirect_uri,
|
||||
connecting_user_id: Some(session.user_id),
|
||||
},
|
||||
);
|
||||
|
||||
Ok(build_discord_oauth_url(&state, &oauth_state))
|
||||
}
|
||||
|
||||
fn build_discord_oauth_url(
|
||||
state: &AppState,
|
||||
oauth_state: &str,
|
||||
) -> std::result::Result<String, StatusCode> {
|
||||
let discord_callback_url = format!("{}{}", state.base_url, DISCORD_REDIRECT_PATH);
|
||||
let encoded_callback = discord_callback_url.replace(':', "%3A").replace('/', "%2F");
|
||||
let encoded_callback = urlencoding_encode(&discord_callback_url);
|
||||
|
||||
let discord_auth_url = format!(
|
||||
let url = format!(
|
||||
"https://discord.com/api/oauth2/authorize\
|
||||
?client_id={}\
|
||||
&redirect_uri={}\
|
||||
@@ -73,8 +128,11 @@ async fn discord_authorize(
|
||||
state.client_id, encoded_callback, oauth_state,
|
||||
);
|
||||
|
||||
match serde_json::to_string(&discord_auth_url) {
|
||||
Ok(json) => Ok(json),
|
||||
match serde_json::to_string(&url) {
|
||||
Ok(json) => {
|
||||
log::trace!("Discord OAuth URL: {}", json);
|
||||
Ok(json)
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to serialize Discord OAuth URL: {e}");
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
@@ -82,14 +140,26 @@ async fn discord_authorize(
|
||||
}
|
||||
}
|
||||
|
||||
/// Very small percent-encoder for the callback URL (replaces `:` and `/`).
|
||||
fn urlencoding_encode(s: &str) -> String {
|
||||
s.replace(':', "%3A").replace('/', "%2F")
|
||||
}
|
||||
|
||||
/// Handle the Discord OAuth callback.
|
||||
///
|
||||
/// Two modes depending on what was stored in the state cache:
|
||||
/// - **Login** (`connecting_user_id = None`): look up (or create) the local
|
||||
/// user for this Discord account, then issue a session cookie and redirect.
|
||||
/// - **Connect** (`connecting_user_id = Some(id)`): link the Discord account
|
||||
/// to the existing local user, then redirect (no new session needed).
|
||||
async fn discord_callback(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<CallbackQuery>,
|
||||
headers: HeaderMap,
|
||||
jar: CookieJar,
|
||||
) -> impl IntoResponse {
|
||||
match do_oauth_callback(state, query).await {
|
||||
Ok((token, ui_redirect_uri)) => {
|
||||
Redirect::temporary(&format!("{}?token={}", ui_redirect_uri, token)).into_response()
|
||||
}
|
||||
match do_oauth_callback(state, query, headers, jar).await {
|
||||
Ok(response) => response,
|
||||
Err((e, ui_redirect_uri)) => {
|
||||
log::error!("OAuth callback error: {:?}", e);
|
||||
let fallback = ui_redirect_uri.unwrap_or_else(|| "/".to_string());
|
||||
@@ -98,33 +168,37 @@ async fn discord_callback(
|
||||
}
|
||||
}
|
||||
|
||||
type CallbackErr = (Error, Option<String>);
|
||||
|
||||
async fn do_oauth_callback(
|
||||
state: Arc<AppState>,
|
||||
query: CallbackQuery,
|
||||
) -> Result<(String, String), (crate::error::Error, Option<String>)> {
|
||||
// Validate the state and retrieve the associated UI redirect URI
|
||||
let ui_redirect_uri = {
|
||||
let mut oauth_states = state.discord_authorize_cache.lock().await;
|
||||
headers: HeaderMap,
|
||||
jar: CookieJar,
|
||||
) -> std::result::Result<axum::response::Response, CallbackErr> {
|
||||
// Validate state & retrieve stored data
|
||||
let stored = {
|
||||
let mut cache = state.discord_authorize_cache.lock().await;
|
||||
match query.state {
|
||||
Some(ref oauth_state) => match oauth_states.remove(oauth_state) {
|
||||
Some(uri) => uri,
|
||||
Some(ref s) => match cache.remove(s) {
|
||||
Some(v) => v,
|
||||
None => return Err((StatusCode::UNAUTHORIZED.into(), None)),
|
||||
},
|
||||
None => return Err((StatusCode::UNAUTHORIZED.into(), None)),
|
||||
}
|
||||
};
|
||||
log::trace!("Discord callback: query={:?} state={:?}", query, stored);
|
||||
|
||||
// Helper closure to tag errors with the redirect URI we already know
|
||||
let redirect = ui_redirect_uri.clone();
|
||||
let err = |s: StatusCode| -> Result<_, (crate::error::Error, Option<String>)> {
|
||||
Err((s.into(), Some(redirect.clone())))
|
||||
let ui_redirect_uri = stored.redirect_uri.clone();
|
||||
let err_redirect = |s: StatusCode| -> std::result::Result<_, CallbackErr> {
|
||||
Err((s.into(), Some(ui_redirect_uri.clone())))
|
||||
};
|
||||
|
||||
// The discord redirect_uri in the token exchange must match what was sent in /authorize
|
||||
// The redirect_uri sent to Discord must exactly match /authorize
|
||||
let discord_callback_url = format!("{}{}", state.base_url, DISCORD_REDIRECT_PATH);
|
||||
|
||||
// Exchange code for an access token
|
||||
let token_response = state
|
||||
// Exchange code for Discord access token
|
||||
let token_resp = state
|
||||
.client
|
||||
.post("https://discord.com/api/oauth2/token")
|
||||
.form(&[
|
||||
@@ -136,90 +210,229 @@ async fn do_oauth_callback(
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
.map_err(|_| err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
if !token_response.status().is_success() {
|
||||
log::error!(
|
||||
"Failed to exchange token: {:?}",
|
||||
token_response.text().await
|
||||
);
|
||||
return err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
if !token_resp.status().is_success() {
|
||||
log::error!("Token exchange failed: {:?}", token_resp.text().await);
|
||||
return err_redirect(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
let token_data: DiscordTokenResponse = token_response
|
||||
let token_data: DiscordTokenResponse = token_resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
.map_err(|_| err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
// Fetch user information from Discord
|
||||
let user_response = state
|
||||
// Fetch Discord user info
|
||||
let user_resp = state
|
||||
.client
|
||||
.get("https://discord.com/api/users/@me")
|
||||
.bearer_auth(token_data.access_token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
.map_err(|_| err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
if !user_response.status().is_success() {
|
||||
log::error!(
|
||||
"Failed to fetch user information: {:?}",
|
||||
user_response.text().await
|
||||
);
|
||||
return err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
if !user_resp.status().is_success() {
|
||||
log::error!("Discord user fetch failed: {:?}", user_resp.text().await);
|
||||
return err_redirect(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
let user_data: DiscordUser = user_response
|
||||
let discord_user: DiscordUser = user_resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
.map_err(|_| err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
log::debug!("User authenticated: {:?}", user_data);
|
||||
log::debug!("Discord OAuth user: {:?}", discord_user);
|
||||
|
||||
let user_id: i64 = user_data
|
||||
.id
|
||||
.parse::<i64>()
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
// Upsert the Discord user into the local users table
|
||||
let pool = siren_core::data::pool();
|
||||
sqlx::query(
|
||||
"INSERT INTO users (id, username, avatar, updated_at)
|
||||
VALUES ($1, $2, $3, NOW())
|
||||
ON CONFLICT (id) DO UPDATE
|
||||
SET username = EXCLUDED.username,
|
||||
avatar = EXCLUDED.avatar,
|
||||
updated_at = NOW()",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&user_data.username)
|
||||
.bind(&user_data.avatar)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("Failed to upsert user: {e}");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
// Create and insert the session
|
||||
let session = Session::new(user_id as u64, user_data.username.clone());
|
||||
session
|
||||
.insert()
|
||||
.await
|
||||
.map_err(|e| (e, Some(ui_redirect_uri.clone())))?;
|
||||
match stored.connecting_user_id {
|
||||
// Handle connecting an existing local user to a new Discord account
|
||||
Some(connecting_user_id) => {
|
||||
// Make sure this Discord account isn't already linked to a DIFFERENT user
|
||||
let existing_owner: Option<Uuid> = sqlx::query_scalar(
|
||||
"SELECT user_id FROM user_connections \
|
||||
WHERE provider = 'discord' AND provider_user_id = $1",
|
||||
)
|
||||
.bind(&discord_user.id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("DB error checking connection: {e}");
|
||||
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
let issued_at = chrono::Utc::now();
|
||||
let claims = BearerTokenClaims {
|
||||
sub: session.user_id,
|
||||
name: session.user_name.clone(),
|
||||
iat: issued_at.timestamp(),
|
||||
exp: session.expires_at.timestamp(),
|
||||
jti: session.session_id.clone(),
|
||||
};
|
||||
if let Some(owner_id) = existing_owner {
|
||||
if owner_id != connecting_user_id {
|
||||
return err_redirect(StatusCode::CONFLICT);
|
||||
}
|
||||
}
|
||||
|
||||
let jwt_secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
let encoding_key = jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes());
|
||||
let token = jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &encoding_key)
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
// Upsert the connection
|
||||
sqlx::query(
|
||||
"INSERT INTO user_connections \
|
||||
(user_id, provider, provider_user_id, provider_username, provider_avatar) \
|
||||
VALUES ($1, 'discord', $2, $3, $4) \
|
||||
ON CONFLICT (user_id, provider) DO UPDATE \
|
||||
SET provider_user_id = EXCLUDED.provider_user_id, \
|
||||
provider_username = EXCLUDED.provider_username, \
|
||||
provider_avatar = EXCLUDED.provider_avatar",
|
||||
)
|
||||
.bind(connecting_user_id)
|
||||
.bind(&discord_user.id)
|
||||
.bind(&discord_user.username)
|
||||
.bind(&discord_user.avatar)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("DB error upserting connection: {e}");
|
||||
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
Ok((token, ui_redirect_uri))
|
||||
// No new session — redirect back to account page with existing cookie
|
||||
Ok(Redirect::temporary(&ui_redirect_uri).into_response())
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------ //
|
||||
// LOGIN MODE: look up (or create) the local user for this Discord account
|
||||
// ------------------------------------------------------------------ //
|
||||
None => {
|
||||
// Find existing connection → local user_id
|
||||
let local_user_id: Option<(Uuid, String)> = sqlx::query_as(
|
||||
"SELECT u.id, u.username \
|
||||
FROM user_connections uc \
|
||||
JOIN users u ON u.id = uc.user_id \
|
||||
WHERE uc.provider = 'discord' AND uc.provider_user_id = $1",
|
||||
)
|
||||
.bind(&discord_user.id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("DB error looking up discord user: {e}");
|
||||
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
let (user_id, username) = match local_user_id {
|
||||
// Already linked — use the existing local user
|
||||
Some(row) => {
|
||||
// Keep provider fields up to date
|
||||
sqlx::query(
|
||||
"UPDATE user_connections \
|
||||
SET provider_username = $1, provider_avatar = $2 \
|
||||
WHERE user_id = $3 AND provider = 'discord'",
|
||||
)
|
||||
.bind(&discord_user.username)
|
||||
.bind(&discord_user.avatar)
|
||||
.bind(row.0)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("DB error updating connection: {e}");
|
||||
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
row
|
||||
}
|
||||
|
||||
// First login — create a local user + connection
|
||||
None => {
|
||||
let base_username = &discord_user.username;
|
||||
let username = generate_unique_username(pool, base_username)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("DB error generating username: {e}");
|
||||
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
// Create user (no password_hash — OAuth only)
|
||||
let new_id: Uuid =
|
||||
sqlx::query_scalar("INSERT INTO users (username) VALUES ($1) RETURNING id")
|
||||
.bind(&username)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("DB error creating user: {e}");
|
||||
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO user_connections \
|
||||
(user_id, provider, provider_user_id, provider_username, provider_avatar) \
|
||||
VALUES ($1, 'discord', $2, $3, $4)",
|
||||
)
|
||||
.bind(new_id)
|
||||
.bind(&discord_user.id)
|
||||
.bind(&discord_user.username)
|
||||
.bind(&discord_user.avatar)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("DB error inserting connection: {e}");
|
||||
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
(new_id, username)
|
||||
}
|
||||
};
|
||||
|
||||
// Build fingerprint from the callback request's headers
|
||||
let ip = extract_ip(&headers);
|
||||
let user_agent = headers
|
||||
.get("user-agent")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
let fingerprint = compute_fingerprint(&ip, &user_agent);
|
||||
|
||||
// Issue session
|
||||
let session = Session::new(user_id, username, fingerprint);
|
||||
session.insert().await.map_err(|e| {
|
||||
log::error!("Redis error inserting session: {e}");
|
||||
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
let token = issue_jwt(&session).map_err(|e| {
|
||||
log::error!("JWT error: {e}");
|
||||
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
let cookie = build_session_cookie(token);
|
||||
let new_jar = jar.add(cookie);
|
||||
|
||||
Ok((new_jar, Redirect::temporary(&ui_redirect_uri)).into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return a username derived from `base` that does not yet exist in `users`.
|
||||
async fn generate_unique_username(pool: &sqlx::PgPool, base: &str) -> crate::error::Result<String> {
|
||||
// Truncate to 28 chars to leave room for the `_XXXX` suffix
|
||||
let base = if base.len() > 28 { &base[..28] } else { base };
|
||||
|
||||
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users WHERE username = $1")
|
||||
.bind(base)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if count == 0 {
|
||||
return Ok(base.to_string());
|
||||
}
|
||||
|
||||
for _ in 0..20 {
|
||||
let candidate = format!("{}_{}", base, csprng(4));
|
||||
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users WHERE username = $1")
|
||||
.bind(&candidate)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
if count == 0 {
|
||||
return Ok(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::internal_server_error(
|
||||
"Could not generate a unique username".into(),
|
||||
))
|
||||
}
|
||||
|
||||
474
crates/siren-api/src/auth/local.rs
Normal file
474
crates/siren-api/src/auth/local.rs
Normal file
@@ -0,0 +1,474 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{
|
||||
SessionAuthorization,
|
||||
bearer_token::BearerTokenClaims,
|
||||
middleware::{compute_fingerprint, extract_ip},
|
||||
session::Session,
|
||||
},
|
||||
error::{Error, Result},
|
||||
};
|
||||
use argon2::{
|
||||
Argon2,
|
||||
PasswordHash,
|
||||
PasswordHasher,
|
||||
PasswordVerifier,
|
||||
password_hash::{SaltString, rand_core::OsRng},
|
||||
};
|
||||
use axum::{
|
||||
Json,
|
||||
Router,
|
||||
extract::Path,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::IntoResponse,
|
||||
routing::{delete, get, post, put},
|
||||
};
|
||||
use axum_extra::extract::CookieJar;
|
||||
use cookie::{Cookie, SameSite};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use siren_core::data;
|
||||
use std::{env, sync::Arc};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/register", post(register))
|
||||
.route("/login", post(login))
|
||||
.route("/logout", post(logout))
|
||||
.route("/me", get(me))
|
||||
.route("/profile", put(update_profile))
|
||||
.route("/change-password", post(change_password))
|
||||
.route("/connections/{provider}", delete(disconnect_provider))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Payloads
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RegisterPayload {
|
||||
username: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LoginPayload {
|
||||
username: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UpdateProfilePayload {
|
||||
first_name: Option<String>,
|
||||
last_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChangePasswordPayload {
|
||||
/// Required when the user already has a password set. Omit (null) when
|
||||
/// setting a password for the first time (OAuth-only account).
|
||||
current_password: Option<String>,
|
||||
new_password: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Response types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ConnectionInfo {
|
||||
pub provider: String,
|
||||
pub provider_username: Option<String>,
|
||||
pub provider_avatar: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct UserInfo {
|
||||
pub id: String,
|
||||
pub username: String,
|
||||
pub first_name: Option<String>,
|
||||
pub last_name: Option<String>,
|
||||
pub email: Option<String>,
|
||||
/// True when the account has a local password set (i.e. can log in without
|
||||
/// OAuth and can safely disconnect OAuth providers).
|
||||
pub has_password: bool,
|
||||
pub connections: Vec<ConnectionInfo>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DB row types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct DbUser {
|
||||
id: Uuid,
|
||||
username: String,
|
||||
first_name: Option<String>,
|
||||
last_name: Option<String>,
|
||||
email: Option<String>,
|
||||
password_hash: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct DbConnection {
|
||||
provider: String,
|
||||
provider_username: Option<String>,
|
||||
provider_avatar: Option<String>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Password helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Hash and salt a plaintext password with Argon2.
|
||||
pub fn hash_password(password: &str) -> Result<String> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
Argon2::default()
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map(|h| h.to_string())
|
||||
.map_err(|e| Error::internal_server_error(format!("Password hashing error: {e}")))
|
||||
}
|
||||
|
||||
/// Return `true` if `password` matches the stored Argon2id `hash`.
|
||||
pub fn verify_password(password: &str, hash: &str) -> bool {
|
||||
let Ok(parsed) = PasswordHash::new(hash) else {
|
||||
return false;
|
||||
};
|
||||
Argon2::default()
|
||||
.verify_password(password.as_bytes(), &parsed)
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Cookie / session helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build the `siren_session` HttpOnly Secure cookie.
|
||||
pub fn build_session_cookie(token: String) -> Cookie<'static> {
|
||||
Cookie::build(("siren_session", token))
|
||||
.http_only(true)
|
||||
.secure(true)
|
||||
.same_site(SameSite::Lax)
|
||||
.path("/")
|
||||
.build()
|
||||
}
|
||||
|
||||
/// Issue a signed JWT for `session`.
|
||||
pub fn issue_jwt(session: &Session) -> Result<String> {
|
||||
let jwt_secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
let encoding_key = jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes());
|
||||
let claims = BearerTokenClaims {
|
||||
sub: session.user_id.to_string(),
|
||||
name: session.user_name.clone(),
|
||||
iat: chrono::Utc::now().timestamp(),
|
||||
exp: session.expires_at.timestamp(),
|
||||
jti: session.session_id.clone(),
|
||||
};
|
||||
jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &encoding_key)
|
||||
.map_err(|e| Error::internal_server_error(format!("JWT error: {e}")))
|
||||
}
|
||||
|
||||
/// Create a session + JWT + Set-Cookie for `user_id` / `user_name`.
|
||||
#[allow(dead_code)]
|
||||
pub async fn create_session_and_cookie(
|
||||
user_id: Uuid,
|
||||
user_name: String,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<(CookieJar, ())> {
|
||||
let ip = extract_ip(headers);
|
||||
let user_agent = headers
|
||||
.get("user-agent")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
let fingerprint = compute_fingerprint(&ip, &user_agent);
|
||||
|
||||
let session = Session::new(user_id, user_name, fingerprint);
|
||||
session.insert().await?;
|
||||
|
||||
let token = issue_jwt(&session)?;
|
||||
let cookie = build_session_cookie(token);
|
||||
let jar = CookieJar::new().add(cookie);
|
||||
Ok((jar, ()))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper: load full UserInfo for a given user_id
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn load_user_info(user_id: Uuid) -> Result<UserInfo> {
|
||||
let pool = data::pool();
|
||||
|
||||
let user: DbUser = sqlx::query_as(
|
||||
"SELECT id, username, first_name, last_name, email, password_hash FROM users WHERE id = $1",
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let connections: Vec<DbConnection> = sqlx::query_as(
|
||||
"SELECT provider, provider_username, provider_avatar \
|
||||
FROM user_connections WHERE user_id = $1",
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(UserInfo {
|
||||
id: user.id.to_string(),
|
||||
username: user.username,
|
||||
first_name: user.first_name,
|
||||
last_name: user.last_name,
|
||||
email: user.email,
|
||||
has_password: user.password_hash.is_some(),
|
||||
connections: connections
|
||||
.into_iter()
|
||||
.map(|c| ConnectionInfo {
|
||||
provider: c.provider,
|
||||
provider_username: c.provider_username,
|
||||
provider_avatar: c.provider_avatar,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Handlers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn register(
|
||||
headers: HeaderMap,
|
||||
jar: CookieJar,
|
||||
Json(payload): Json<RegisterPayload>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let username = payload.username.trim().to_string();
|
||||
if username.is_empty() || username.len() > 32 {
|
||||
return Err(Error::new(422, "Username must be 1–32 characters".into()));
|
||||
}
|
||||
if payload.password.len() < 8 {
|
||||
return Err(Error::new(
|
||||
422,
|
||||
"Password must be at least 8 characters".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let pool = data::pool();
|
||||
|
||||
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE username = $1)")
|
||||
.bind(&username)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
if exists {
|
||||
return Err(Error::new(409, "Username already taken".into()));
|
||||
}
|
||||
|
||||
let password_hash = hash_password(&payload.password)?;
|
||||
let user_id: Uuid =
|
||||
sqlx::query_scalar("INSERT INTO users (username, password_hash) VALUES ($1, $2) RETURNING id")
|
||||
.bind(&username)
|
||||
.bind(&password_hash)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let ip = extract_ip(&headers);
|
||||
let user_agent = headers
|
||||
.get("user-agent")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
let fingerprint = compute_fingerprint(&ip, &user_agent);
|
||||
|
||||
let session = Session::new(user_id, username, fingerprint);
|
||||
session.insert().await?;
|
||||
|
||||
let token = issue_jwt(&session)?;
|
||||
let cookie = build_session_cookie(token);
|
||||
|
||||
Ok((jar.add(cookie), StatusCode::CREATED))
|
||||
}
|
||||
|
||||
async fn login(
|
||||
headers: HeaderMap,
|
||||
jar: CookieJar,
|
||||
Json(payload): Json<LoginPayload>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let pool = data::pool();
|
||||
|
||||
let row: Option<(Uuid, String, Option<String>)> =
|
||||
sqlx::query_as("SELECT id, username, password_hash FROM users WHERE username = $1")
|
||||
.bind(&payload.username)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let (user_id, username, password_hash) =
|
||||
row.ok_or_else(|| Error::new(401, "Invalid username or password".into()))?;
|
||||
|
||||
let hash =
|
||||
password_hash.ok_or_else(|| Error::new(401, "This account uses external login only".into()))?;
|
||||
|
||||
if !verify_password(&payload.password, &hash) {
|
||||
return Err(Error::new(401, "Invalid username or password".into()));
|
||||
}
|
||||
|
||||
let ip = extract_ip(&headers);
|
||||
let user_agent = headers
|
||||
.get("user-agent")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
let fingerprint = compute_fingerprint(&ip, &user_agent);
|
||||
|
||||
let session = Session::new(user_id, username, fingerprint);
|
||||
session.insert().await?;
|
||||
|
||||
let token = issue_jwt(&session)?;
|
||||
let cookie = build_session_cookie(token);
|
||||
|
||||
Ok((jar.add(cookie), StatusCode::OK))
|
||||
}
|
||||
|
||||
async fn logout(
|
||||
jar: CookieJar,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
) -> impl IntoResponse {
|
||||
if let Some(s) = session {
|
||||
let _ = Session::delete(&s.session_id).await;
|
||||
}
|
||||
let removal = Cookie::build(("siren_session", ""))
|
||||
.http_only(true)
|
||||
.secure(true)
|
||||
.same_site(SameSite::Lax)
|
||||
.path("/")
|
||||
.max_age(cookie::time::Duration::seconds(0))
|
||||
.build();
|
||||
(jar.add(removal), StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn me(SessionAuthorization(session): SessionAuthorization) -> Result<Json<UserInfo>> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
Ok(Json(load_user_info(session.user_id).await?))
|
||||
}
|
||||
|
||||
async fn update_profile(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Json(payload): Json<UpdateProfilePayload>,
|
||||
) -> Result<Json<UserInfo>> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
let pool = data::pool();
|
||||
|
||||
// Validate lengths if provided
|
||||
if let Some(ref f) = payload.first_name {
|
||||
if f.len() > 64 {
|
||||
return Err(Error::new(422, "First name must be ≤ 64 characters".into()));
|
||||
}
|
||||
}
|
||||
if let Some(ref l) = payload.last_name {
|
||||
if l.len() > 64 {
|
||||
return Err(Error::new(422, "Last name must be ≤ 64 characters".into()));
|
||||
}
|
||||
}
|
||||
|
||||
// COALESCE: only update fields that were sent (Some vs None)
|
||||
// We allow explicitly setting a field to an empty string to clear it,
|
||||
// so we map Some("") → SQL NULL.
|
||||
let first = payload
|
||||
.first_name
|
||||
.map(|s| if s.trim().is_empty() { None } else { Some(s) });
|
||||
let last = payload
|
||||
.last_name
|
||||
.map(|s| if s.trim().is_empty() { None } else { Some(s) });
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE users
|
||||
SET first_name = CASE WHEN $2 THEN $3 ELSE first_name END,
|
||||
last_name = CASE WHEN $4 THEN $5 ELSE last_name END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1",
|
||||
)
|
||||
.bind(session.user_id)
|
||||
.bind(first.is_some())
|
||||
.bind(first.flatten())
|
||||
.bind(last.is_some())
|
||||
.bind(last.flatten())
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(Json(load_user_info(session.user_id).await?))
|
||||
}
|
||||
|
||||
async fn change_password(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Json(payload): Json<ChangePasswordPayload>,
|
||||
) -> Result<StatusCode> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
if payload.new_password.len() < 8 {
|
||||
return Err(Error::new(
|
||||
422,
|
||||
"New password must be at least 8 characters".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let pool = data::pool();
|
||||
|
||||
let existing_hash: Option<String> =
|
||||
sqlx::query_scalar("SELECT password_hash FROM users WHERE id = $1")
|
||||
.bind(session.user_id)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
match existing_hash {
|
||||
Some(hash) => {
|
||||
// User already has a password — require current password
|
||||
let current = payload
|
||||
.current_password
|
||||
.ok_or_else(|| Error::new(422, "Current password is required".into()))?;
|
||||
if !verify_password(¤t, &hash) {
|
||||
return Err(Error::new(401, "Current password is incorrect".into()));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// OAuth-only account — allow setting a password without current_password
|
||||
}
|
||||
}
|
||||
|
||||
let new_hash = hash_password(&payload.new_password)?;
|
||||
sqlx::query("UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2")
|
||||
.bind(&new_hash)
|
||||
.bind(session.user_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn disconnect_provider(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path(provider): Path<String>,
|
||||
) -> Result<StatusCode> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
let pool = data::pool();
|
||||
|
||||
// Safety check: ensure the user has a password before disconnecting OAuth.
|
||||
let has_password: bool =
|
||||
sqlx::query_scalar("SELECT password_hash IS NOT NULL FROM users WHERE id = $1")
|
||||
.bind(session.user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if !has_password {
|
||||
return Err(Error::new(
|
||||
422,
|
||||
"Set a password before disconnecting your OAuth provider".into(),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM user_connections WHERE user_id = $1 AND provider = $2")
|
||||
.bind(session.user_id)
|
||||
.bind(&provider)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
@@ -4,63 +4,23 @@ use crate::{
|
||||
};
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
http::{Method, StatusCode, request::Parts},
|
||||
};
|
||||
use axum_extra::{
|
||||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
http::{HeaderMap, StatusCode, request::Parts},
|
||||
};
|
||||
use axum_extra::extract::CookieJar;
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{DecodingKey, Validation, decode};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AuthorizationMiddleware — rejects unauthenticated requests
|
||||
// ---------------------------------------------------------------------------
|
||||
pub const COOKIE_NAME: &str = "siren_session";
|
||||
|
||||
pub struct AuthorizationMiddleware;
|
||||
/// Wraps an optional authenticated session.
|
||||
///
|
||||
/// Handlers using this extractor work for both authenticated and
|
||||
/// unauthenticated callers. A valid `siren_session` cookie grants a
|
||||
/// `Some(session)`.
|
||||
pub struct SessionAuthorization(pub Option<Session>);
|
||||
|
||||
impl<S> FromRequestParts<S> for AuthorizationMiddleware
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = StatusCode;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> std::result::Result<Self, Self::Rejection> {
|
||||
// For options requests browsers will not send the authorization header.
|
||||
if parts.method == Method::OPTIONS {
|
||||
return Ok(Self);
|
||||
}
|
||||
|
||||
// Check for a Bearer token in the `Authorization` header.
|
||||
if let Ok(TypedHeader(Authorization(bearer))) =
|
||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
|
||||
{
|
||||
return match check_bearer_auth(bearer.token()).await {
|
||||
Ok(session) => {
|
||||
parts.extensions.insert(session);
|
||||
Ok(Self)
|
||||
}
|
||||
Err(_) => Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
}
|
||||
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OptionalAuth — extracts a Session if present, otherwise None
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Wraps an optional authenticated session.
|
||||
/// Handlers that use this extractor work for both authenticated and
|
||||
/// unauthenticated callers; callers with a valid Bearer token get a `Some(session)`.
|
||||
pub struct OptionalAuth(pub Option<Session>);
|
||||
|
||||
impl<S> FromRequestParts<S> for OptionalAuth
|
||||
impl<S> FromRequestParts<S> for SessionAuthorization
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
@@ -70,38 +30,103 @@ where
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> std::result::Result<Self, Self::Rejection> {
|
||||
if let Ok(TypedHeader(Authorization(bearer))) =
|
||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
|
||||
{
|
||||
if let Ok(session) = check_bearer_auth(bearer.token()).await {
|
||||
let jar = CookieJar::from_request_parts(parts, state).await.unwrap();
|
||||
|
||||
if let Some(cookie) = jar.get(COOKIE_NAME) {
|
||||
let ip = extract_ip(&parts.headers);
|
||||
let user_agent = parts
|
||||
.headers
|
||||
.get("user-agent")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
if let Ok(session) = check_cookie_auth(cookie.value(), &ip, &user_agent).await {
|
||||
parts.extensions.insert(session.clone());
|
||||
return Ok(Self(Some(session)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self(None))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared helper
|
||||
// ---------------------------------------------------------------------------
|
||||
/// Extract the client IP from common proxy headers, falling back to "unknown".
|
||||
pub fn extract_ip(headers: &HeaderMap) -> String {
|
||||
if let Some(forwarded) = headers.get("x-forwarded-for") {
|
||||
if let Ok(val) = forwarded.to_str() {
|
||||
if let Some(ip) = val.split(',').next() {
|
||||
return ip.trim().to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(real_ip) = headers.get("x-real-ip") {
|
||||
if let Ok(val) = real_ip.to_str() {
|
||||
return val.trim().to_string();
|
||||
}
|
||||
}
|
||||
"unknown".to_string()
|
||||
}
|
||||
|
||||
pub async fn check_bearer_auth(bearer_token: &str) -> Result<Session> {
|
||||
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
|
||||
/// Compute a fingerprint from client IP and User-Agent
|
||||
///
|
||||
/// Stored in the Redis session at login time and re-checked on every
|
||||
/// authenticated request so that a stolen cookie is detected when used
|
||||
/// from a different device or IP.
|
||||
pub fn compute_fingerprint(ip: &str, user_agent: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(format!("{ip}:{user_agent}"));
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
/// Validate a JWT cookie value, look up the Redis session, and verify the
|
||||
/// fingerprint against the current request's IP / User-Agent.
|
||||
pub async fn check_cookie_auth(token: &str, ip: &str, user_agent: &str) -> Result<Session> {
|
||||
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
|
||||
|
||||
let token_data = decode::<BearerTokenClaims>(bearer_token, &decoding_key, &Validation::default())
|
||||
let token_data = decode::<BearerTokenClaims>(token, &decoding_key, &Validation::default())
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let claims = token_data.claims;
|
||||
|
||||
let now = Utc::now().timestamp();
|
||||
if claims.exp < now {
|
||||
if claims.exp < Utc::now().timestamp() {
|
||||
return Err(StatusCode::UNAUTHORIZED.into());
|
||||
}
|
||||
|
||||
match Session::find(&claims.jti).await {
|
||||
Ok(Some(session)) => Ok(session),
|
||||
_ => Err(StatusCode::UNAUTHORIZED)?,
|
||||
let session = match Session::find(&claims.jti).await? {
|
||||
Some(s) => s,
|
||||
None => return Err(StatusCode::UNAUTHORIZED.into()),
|
||||
};
|
||||
|
||||
// Reject if the request comes from a different device / network
|
||||
let expected = compute_fingerprint(ip, user_agent);
|
||||
if session.fingerprint != expected {
|
||||
log::warn!(
|
||||
"Fingerprint mismatch for session {}: stored={} request={}",
|
||||
claims.jti,
|
||||
session.fingerprint,
|
||||
expected
|
||||
);
|
||||
return Err(StatusCode::UNAUTHORIZED.into());
|
||||
}
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
/// Parse the raw `Cookie:` header string and validate the siren_session
|
||||
/// value. Used by the WebSocket upgrade handler where we cannot use the
|
||||
/// normal `FromRequestParts` machinery.
|
||||
pub async fn check_cookie_from_header_str(
|
||||
cookie_header: &str,
|
||||
ip: &str,
|
||||
user_agent: &str,
|
||||
) -> Option<Session> {
|
||||
for pair in cookie_header.split(';') {
|
||||
let pair = pair.trim();
|
||||
if let Some(value) = pair.strip_prefix("siren_session=") {
|
||||
return check_cookie_auth(value, ip, user_agent).await.ok();
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
@@ -1,24 +1,20 @@
|
||||
use crate::AppState;
|
||||
use axum::Router;
|
||||
use rand::RngExt;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod discord;
|
||||
mod session;
|
||||
pub use session::Session;
|
||||
mod bearer_token;
|
||||
mod discord;
|
||||
mod local;
|
||||
mod session;
|
||||
|
||||
pub use local::UserInfo;
|
||||
pub use session::Session;
|
||||
|
||||
pub mod middleware;
|
||||
pub use middleware::{AuthorizationMiddleware, OptionalAuth};
|
||||
pub use middleware::SessionAuthorization;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new().nest("/discord", discord::get_routes())
|
||||
}
|
||||
|
||||
pub fn csprng(take: usize) -> String {
|
||||
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
|
||||
rand::rng()
|
||||
.sample_iter(rand::distr::Alphanumeric)
|
||||
.take(take)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
Router::new()
|
||||
.merge(local::get_routes())
|
||||
.nest("/discord", discord::get_routes())
|
||||
}
|
||||
|
||||
@@ -1,38 +1,45 @@
|
||||
use crate::{auth::csprng, error::Result};
|
||||
use crate::error::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use redis::{AsyncCommands, RedisResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use siren_core::data;
|
||||
use siren_core::{data, utils::csprng};
|
||||
use std::{env, sync::OnceLock};
|
||||
use uuid::Uuid;
|
||||
|
||||
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
|
||||
|
||||
fn get_session_ttl() -> i64 {
|
||||
// Initialize the SESSION_TTL value lazily
|
||||
pub fn get_session_ttl() -> i64 {
|
||||
*SESSION_TTL.get_or_init(|| {
|
||||
env::var("API_SESSION_TTL")
|
||||
.ok()
|
||||
.and_then(|val| val.parse::<i64>().ok())
|
||||
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
|
||||
.unwrap_or(86400) // 24 hours
|
||||
})
|
||||
}
|
||||
|
||||
/// A server-side session stored in Redis.
|
||||
///
|
||||
/// Contains the user's identity and a `fingerprint` (SHA-256 of
|
||||
/// `{client_ip}:{user_agent}`) so that stolen cookies can be detected.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Session {
|
||||
pub session_id: String,
|
||||
pub user_id: u64,
|
||||
pub user_id: Uuid,
|
||||
pub user_name: String,
|
||||
/// SHA-256 hex of `{client_ip}:{user_agent}` captured at login time.
|
||||
pub fingerprint: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(user_id: u64, user_name: String) -> Session {
|
||||
pub fn new(user_id: Uuid, user_name: String, fingerprint: String) -> Session {
|
||||
let now = Utc::now();
|
||||
let session_ttl = get_session_ttl();
|
||||
Session {
|
||||
session_id: csprng(32),
|
||||
user_id,
|
||||
user_name,
|
||||
fingerprint,
|
||||
expires_at: now + chrono::Duration::seconds(session_ttl),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{AuthorizationMiddleware, Session},
|
||||
auth::SessionAuthorization,
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
Extension,
|
||||
Json,
|
||||
Router,
|
||||
extract::{Path, State},
|
||||
middleware::from_extractor,
|
||||
http::StatusCode,
|
||||
routing::post,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -18,9 +17,7 @@ use std::{fmt::Display, str::FromStr, sync::Arc};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/{guild_id}/track", post(add_track_dice))
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
Router::new().route("/{guild_id}/track", post(add_track_dice))
|
||||
}
|
||||
|
||||
const TABLE_NAME: &str = "dice_track";
|
||||
@@ -156,16 +153,34 @@ impl InsertDiceTrack {
|
||||
}
|
||||
|
||||
pub async fn add_track_dice(
|
||||
Extension(session): Extension<Session>,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
Json(payload): Json<DiceTrackPayload>,
|
||||
) -> Result<Json<QueryDiceTrack>> {
|
||||
// Check if the user exists in the cache
|
||||
let owner_id = session.user_id;
|
||||
let owner_id = match state.cache.user(owner_id) {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
// Resolve Discord snowflake for this local user and verify they exist in cache
|
||||
let discord_snowflake: u64 = {
|
||||
let pool = siren_core::data::pool();
|
||||
let pid: Option<String> = sqlx::query_scalar(
|
||||
"SELECT provider_user_id FROM user_connections \
|
||||
WHERE user_id = $1 AND provider = 'discord'",
|
||||
)
|
||||
.bind(session.user_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
pid
|
||||
.and_then(|s| s.parse().ok())
|
||||
.ok_or_else(|| Error::not_found("Discord account not connected".to_string()))?
|
||||
};
|
||||
let owner_id = match state.cache.user(discord_snowflake) {
|
||||
Some(user) => user.id,
|
||||
None => return Err(Error::not_found("User not found".to_string())),
|
||||
None => {
|
||||
return Err(Error::not_found(
|
||||
"User not found in Discord cache".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
@@ -182,10 +197,7 @@ pub async fn add_track_dice(
|
||||
dice: format_roll(dice.0, dice.1, dice.2),
|
||||
user_id: payload.user_id,
|
||||
value: payload.value,
|
||||
operator: match payload.operator {
|
||||
None => None,
|
||||
Some(s) => Some(s.to_string()),
|
||||
},
|
||||
operator: payload.operator.map(|s| s.to_string()),
|
||||
};
|
||||
|
||||
// Check for existing dice tracks
|
||||
|
||||
@@ -2,7 +2,7 @@ pub mod model;
|
||||
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{OptionalAuth, Session, csprng, middleware::check_bearer_auth},
|
||||
auth::{Session, SessionAuthorization, middleware::check_cookie_from_header_str},
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
@@ -10,81 +10,98 @@ use axum::{
|
||||
Router,
|
||||
extract::{
|
||||
Path,
|
||||
Query,
|
||||
State,
|
||||
WebSocketUpgrade,
|
||||
ws::{Message, WebSocket},
|
||||
},
|
||||
http::StatusCode,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::IntoResponse,
|
||||
routing::{delete, get, post, put},
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use model::{
|
||||
AccessRequestWithUser,
|
||||
ClientMessage,
|
||||
CreateAccessRequestPayload,
|
||||
CreateMapPayload,
|
||||
GridCell,
|
||||
GridMap,
|
||||
GridToken,
|
||||
MapPermission,
|
||||
ListedMap,
|
||||
MapRole,
|
||||
MapState,
|
||||
PermissionWithUser,
|
||||
ResolveAccessRequestPayload,
|
||||
ServerMessage,
|
||||
UpdateMapPayload,
|
||||
UpdatePermissionPayload,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use siren_core::utils::csprng;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/maps", get(list_maps))
|
||||
.route("/maps", post(create_map))
|
||||
.route("/maps/{id}", get(get_map))
|
||||
.route("/maps/{id}", put(update_map))
|
||||
.route("/maps/{id}", delete(delete_map))
|
||||
.route("/maps/{id}/permissions", get(list_permissions))
|
||||
.route("/maps/{id}/permissions", put(update_permission))
|
||||
.route("/maps/{id}/favorite", post(favorite_map))
|
||||
.route("/maps/{id}/favorite", delete(unfavorite_map))
|
||||
.route("/maps/{id}/access-requests", post(create_access_request))
|
||||
.route("/maps/{id}/access-requests", get(list_access_requests))
|
||||
.route(
|
||||
"/maps/{id}/access-requests/{request_id}",
|
||||
put(resolve_access_request),
|
||||
)
|
||||
.route("/maps/{id}/ws", get(ws_handler))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Permission helpers
|
||||
// Access helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Fetch the role of `user_id` on `map_id`, or `None` if no record exists.
|
||||
async fn get_user_role(map_id: &str, user_id: i64) -> crate::error::Result<Option<MapRole>> {
|
||||
async fn get_user_role(map_id: &str, user_id: Uuid) -> Result<Option<MapRole>> {
|
||||
let pool = siren_core::data::pool();
|
||||
let perm: Option<MapPermission> = sqlx::query_as(
|
||||
"SELECT map_id, user_id, role FROM map_permissions WHERE map_id = $1 AND user_id = $2",
|
||||
)
|
||||
.bind(map_id)
|
||||
.bind(user_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(perm.map(|p| p.role))
|
||||
let role: Option<String> =
|
||||
sqlx::query_scalar("SELECT role FROM map_permissions WHERE map_id = $1 AND user_id = $2")
|
||||
.bind(map_id)
|
||||
.bind(user_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(role.and_then(|r| match r.as_str() {
|
||||
"owner" => Some(MapRole::Owner),
|
||||
"editor" => Some(MapRole::Editor),
|
||||
"viewer" => Some(MapRole::Viewer),
|
||||
_ => None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Returns whether the caller can view the map:
|
||||
/// - Public maps: always true.
|
||||
/// - Private maps: true only if the user has any role.
|
||||
/// Returns whether the caller can view the map.
|
||||
async fn can_view(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
if map.is_public {
|
||||
if map.public_access == "public_view" || map.public_access == "public_edit" {
|
||||
return true;
|
||||
}
|
||||
let Some(s) = session else { return false };
|
||||
let user_id = s.user_id as i64;
|
||||
get_user_role(&map.id, user_id)
|
||||
get_user_role(&map.id, s.user_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.is_some()
|
||||
}
|
||||
|
||||
/// Returns whether the caller can edit the map (editor or owner role).
|
||||
/// Returns whether the caller can edit the map (editor or owner role, or public_edit).
|
||||
async fn can_edit(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
if map.public_access == "public_edit" {
|
||||
return true;
|
||||
}
|
||||
let Some(s) = session else { return false };
|
||||
let user_id = s.user_id as i64;
|
||||
get_user_role(&map.id, user_id)
|
||||
get_user_role(&map.id, s.user_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
@@ -95,8 +112,7 @@ async fn can_edit(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
/// Returns whether the caller is the owner.
|
||||
async fn is_owner(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
let Some(s) = session else { return false };
|
||||
let user_id = s.user_id as i64;
|
||||
get_user_role(&map.id, user_id)
|
||||
get_user_role(&map.id, s.user_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
@@ -105,60 +121,68 @@ async fn is_owner(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// REST handlers
|
||||
// Map CRUD
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn list_maps(OptionalAuth(session): OptionalAuth) -> Result<Json<Vec<GridMap>>> {
|
||||
pub async fn list_maps(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
) -> Result<Json<Vec<ListedMap>>> {
|
||||
let pool = siren_core::data::pool();
|
||||
let maps: Vec<GridMap> = match &session {
|
||||
let maps: Vec<ListedMap> = match &session {
|
||||
Some(s) => {
|
||||
let user_id = s.user_id as i64;
|
||||
sqlx::query_as(
|
||||
"SELECT DISTINCT gm.*
|
||||
"SELECT
|
||||
gm.id, gm.name, gm.public_access, gm.owner_id,
|
||||
u.username AS owner_username,
|
||||
gm.colors, gm.created_at, gm.updated_at,
|
||||
mp.role AS user_role,
|
||||
(mf.user_id IS NOT NULL) AS is_favorited
|
||||
FROM grid_maps gm
|
||||
JOIN users u ON u.id = gm.owner_id
|
||||
LEFT JOIN map_permissions mp ON mp.map_id = gm.id AND mp.user_id = $1
|
||||
WHERE gm.is_public = TRUE OR mp.user_id IS NOT NULL
|
||||
ORDER BY gm.created_at DESC",
|
||||
LEFT JOIN map_favorites mf ON mf.map_id = gm.id AND mf.user_id = $1
|
||||
WHERE mp.user_id IS NOT NULL OR mf.user_id IS NOT NULL
|
||||
ORDER BY gm.updated_at DESC",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(s.user_id)
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
sqlx::query_as("SELECT * FROM grid_maps WHERE is_public = TRUE ORDER BY created_at DESC")
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
}
|
||||
None => vec![],
|
||||
};
|
||||
Ok(Json(maps))
|
||||
}
|
||||
|
||||
pub async fn create_map(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Json(payload): Json<CreateMapPayload>,
|
||||
) -> Result<(StatusCode, Json<GridMap>)> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
let user_id = session.user_id as i64;
|
||||
let public_access = payload.public_access.as_str();
|
||||
if !matches!(public_access, "private" | "public_view" | "public_edit") {
|
||||
return Err(Error::new(422, "Invalid public_access value".into()));
|
||||
}
|
||||
|
||||
let map_id = csprng(32);
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: GridMap = sqlx::query_as(
|
||||
"INSERT INTO grid_maps (id, name, is_public, owner_id)
|
||||
"INSERT INTO grid_maps (id, name, public_access, owner_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING *",
|
||||
)
|
||||
.bind(&map_id)
|
||||
.bind(&payload.name)
|
||||
.bind(payload.is_public)
|
||||
.bind(user_id)
|
||||
.bind(&payload.public_access)
|
||||
.bind(session.user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
// Auto-assign the creator as owner in map_permissions
|
||||
sqlx::query("INSERT INTO map_permissions (map_id, user_id, role) VALUES ($1, $2, 'owner')")
|
||||
.bind(&map_id)
|
||||
.bind(user_id)
|
||||
.bind(session.user_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
@@ -166,7 +190,7 @@ pub async fn create_map(
|
||||
}
|
||||
|
||||
pub async fn get_map(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<MapState>> {
|
||||
let pool = siren_core::data::pool();
|
||||
@@ -195,8 +219,51 @@ pub async fn get_map(
|
||||
Ok(Json(MapState { map, cells, tokens }))
|
||||
}
|
||||
|
||||
pub async fn update_map(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdateMapPayload>,
|
||||
) -> Result<Json<GridMap>> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !is_owner(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
if let Some(ref pa) = payload.public_access {
|
||||
if !matches!(pa.as_str(), "private" | "public_view" | "public_edit") {
|
||||
return Err(Error::new(422, "Invalid public_access value".into()));
|
||||
}
|
||||
}
|
||||
|
||||
let new_name = payload.name.as_deref().unwrap_or(&map.name);
|
||||
let new_pa = payload
|
||||
.public_access
|
||||
.as_deref()
|
||||
.unwrap_or(&map.public_access);
|
||||
|
||||
let updated: GridMap = sqlx::query_as(
|
||||
"UPDATE grid_maps SET name = $1, public_access = $2, updated_at = NOW()
|
||||
WHERE id = $3 RETURNING *",
|
||||
)
|
||||
.bind(new_name)
|
||||
.bind(new_pa)
|
||||
.bind(&id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(Json(updated))
|
||||
}
|
||||
|
||||
pub async fn delete_map(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<StatusCode> {
|
||||
let pool = siren_core::data::pool();
|
||||
@@ -225,9 +292,9 @@ pub async fn delete_map(
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn list_permissions(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<Vec<MapPermission>>> {
|
||||
) -> Result<Json<Vec<PermissionWithUser>>> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
@@ -241,17 +308,22 @@ pub async fn list_permissions(
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
let perms: Vec<MapPermission> =
|
||||
sqlx::query_as("SELECT map_id, user_id, role FROM map_permissions WHERE map_id = $1")
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
let perms: Vec<PermissionWithUser> = sqlx::query_as(
|
||||
"SELECT mp.map_id, mp.user_id, u.username, mp.role
|
||||
FROM map_permissions mp
|
||||
JOIN users u ON u.id = mp.user_id
|
||||
WHERE mp.map_id = $1
|
||||
ORDER BY mp.role, u.username",
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(Json(perms))
|
||||
}
|
||||
|
||||
pub async fn update_permission(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdatePermissionPayload>,
|
||||
) -> Result<StatusCode> {
|
||||
@@ -268,10 +340,23 @@ pub async fn update_permission(
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
// Prevent the owner from removing their own owner record
|
||||
let caller_id = session.as_ref().map(|s| s.user_id as i64).unwrap_or(0);
|
||||
if payload.user_id == caller_id && payload.role.as_ref().map(|r| r.is_owner()) == Some(false) {
|
||||
return Err(Error::from(StatusCode::UNPROCESSABLE_ENTITY));
|
||||
// Resolve username → user_id
|
||||
let target_id: Option<Uuid> = sqlx::query_scalar("SELECT id FROM users WHERE username = $1")
|
||||
.bind(&payload.username)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let target_id = target_id.ok_or_else(|| Error::not_found("User not found".into()))?;
|
||||
|
||||
// Prevent the owner from stripping their own owner record
|
||||
if let Some(ref s) = session {
|
||||
if target_id == s.user_id {
|
||||
if let Some(ref role) = payload.role {
|
||||
if !role.is_owner() {
|
||||
return Err(Error::from(StatusCode::UNPROCESSABLE_ENTITY));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match payload.role {
|
||||
@@ -282,7 +367,7 @@ pub async fn update_permission(
|
||||
ON CONFLICT (map_id, user_id) DO UPDATE SET role = EXCLUDED.role",
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(payload.user_id)
|
||||
.bind(target_id)
|
||||
.bind(role)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
@@ -290,7 +375,7 @@ pub async fn update_permission(
|
||||
None => {
|
||||
sqlx::query("DELETE FROM map_permissions WHERE map_id = $1 AND user_id = $2")
|
||||
.bind(&id)
|
||||
.bind(payload.user_id)
|
||||
.bind(target_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
@@ -299,26 +384,215 @@ pub async fn update_permission(
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
pub async fn favorite_map(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<StatusCode> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
// Verify the map exists and is viewable
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
if !can_view(&map, &Some(session.clone())).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO map_favorites (user_id, map_id) VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING",
|
||||
)
|
||||
.bind(session.user_id)
|
||||
.bind(&id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
pub async fn unfavorite_map(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<StatusCode> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
sqlx::query("DELETE FROM map_favorites WHERE user_id = $1 AND map_id = $2")
|
||||
.bind(session.user_id)
|
||||
.bind(&id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Access Requests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn create_access_request(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<CreateAccessRequestPayload>,
|
||||
) -> Result<StatusCode> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
// Only editor and viewer roles can be requested
|
||||
if matches!(payload.role, MapRole::Owner) {
|
||||
return Err(Error::new(422, "Cannot request owner role".into()));
|
||||
}
|
||||
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
// Check if user already has a direct permission
|
||||
let existing_role = get_user_role(&id, session.user_id).await?;
|
||||
if existing_role.is_some() {
|
||||
return Err(Error::new(
|
||||
409,
|
||||
"You already have access to this map".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Upsert the request (update role if they change their mind)
|
||||
sqlx::query(
|
||||
"INSERT INTO map_access_requests (map_id, user_id, requested_role, status, updated_at)
|
||||
VALUES ($1, $2, $3, 'pending', NOW())
|
||||
ON CONFLICT (map_id, user_id)
|
||||
DO UPDATE SET requested_role = EXCLUDED.requested_role,
|
||||
status = 'pending',
|
||||
updated_at = NOW()",
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(session.user_id)
|
||||
.bind(&payload.role)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(StatusCode::CREATED)
|
||||
}
|
||||
|
||||
pub async fn list_access_requests(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<Vec<AccessRequestWithUser>>> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !is_owner(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
let requests: Vec<AccessRequestWithUser> = sqlx::query_as(
|
||||
"SELECT mar.id, mar.map_id, mar.user_id, u.username,
|
||||
mar.requested_role, mar.status, mar.created_at, mar.updated_at
|
||||
FROM map_access_requests mar
|
||||
JOIN users u ON u.id = mar.user_id
|
||||
WHERE mar.map_id = $1 AND mar.status = 'pending'
|
||||
ORDER BY mar.created_at ASC",
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(Json(requests))
|
||||
}
|
||||
|
||||
pub async fn resolve_access_request(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
Path((map_id, request_id)): Path<(String, Uuid)>,
|
||||
Json(payload): Json<ResolveAccessRequestPayload>,
|
||||
) -> Result<StatusCode> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&map_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !is_owner(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
if !matches!(payload.action.as_str(), "approve" | "deny") {
|
||||
return Err(Error::new(422, "action must be 'approve' or 'deny'".into()));
|
||||
}
|
||||
|
||||
// Fetch the request
|
||||
let req: Option<(Uuid, String)> = sqlx::query_as(
|
||||
"SELECT user_id, requested_role FROM map_access_requests WHERE id = $1 AND map_id = $2",
|
||||
)
|
||||
.bind(request_id)
|
||||
.bind(&map_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let (user_id, role) = req.ok_or_else(|| Error::not_found("Access request not found".into()))?;
|
||||
|
||||
if payload.action == "approve" {
|
||||
// Grant the requested role
|
||||
sqlx::query(
|
||||
"INSERT INTO map_permissions (map_id, user_id, role)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (map_id, user_id) DO UPDATE SET role = EXCLUDED.role",
|
||||
)
|
||||
.bind(&map_id)
|
||||
.bind(user_id)
|
||||
.bind(&role)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Mark request as resolved
|
||||
let new_status = if payload.action == "approve" {
|
||||
"approved"
|
||||
} else {
|
||||
"denied"
|
||||
};
|
||||
sqlx::query("UPDATE map_access_requests SET status = $1, updated_at = NOW() WHERE id = $2")
|
||||
.bind(new_status)
|
||||
.bind(request_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebSocket handler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct WsQuery {
|
||||
/// Optional Bearer token passed as a query parameter for WS auth.
|
||||
token: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(map_id): Path<String>,
|
||||
Query(query): Query<WsQuery>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
// Resolve the session from query param (WS can't easily send headers)
|
||||
let session: Option<Session> = match query.token {
|
||||
Some(ref tok) => check_bearer_auth(tok).await.ok(),
|
||||
None => None,
|
||||
let session: Option<Session> = {
|
||||
let ip = crate::auth::middleware::extract_ip(&headers);
|
||||
let user_agent = headers
|
||||
.get("user-agent")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("unknown");
|
||||
if let Some(cookie_header) = headers.get("cookie").and_then(|h| h.to_str().ok()) {
|
||||
check_cookie_from_header_str(cookie_header, &ip, user_agent).await
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, map_id, session))
|
||||
@@ -330,20 +604,17 @@ async fn handle_socket(
|
||||
map_id: String,
|
||||
session: Option<Session>,
|
||||
) {
|
||||
// Load the map and verify the caller can view it
|
||||
let map_state = match fetch_map_state(&map_id).await {
|
||||
Ok(ms) => ms,
|
||||
Err(_) => return, // map doesn't exist
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
if !can_view(&map_state.map, &session).await {
|
||||
// Refuse the connection silently (upgrade already happened; just close)
|
||||
return;
|
||||
}
|
||||
|
||||
let editor = can_edit(&map_state.map, &session).await;
|
||||
|
||||
// Get or create a broadcast channel for this map
|
||||
let tx = state
|
||||
.map_rooms
|
||||
.entry(map_id.clone())
|
||||
@@ -356,7 +627,6 @@ async fn handle_socket(
|
||||
|
||||
let (mut ws_tx, mut ws_rx) = socket.split();
|
||||
|
||||
// Send the current full map state to the newly connected client
|
||||
let init_msg = ServerMessage::State {
|
||||
cells: map_state.cells,
|
||||
tokens: map_state.tokens,
|
||||
@@ -366,7 +636,6 @@ async fn handle_socket(
|
||||
let _ = ws_tx.send(Message::Text(json.into())).await;
|
||||
}
|
||||
|
||||
// Task 1: forward broadcast messages to this socket
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(json) = rx.recv().await {
|
||||
if ws_tx.send(Message::Text(json.into())).await.is_err() {
|
||||
@@ -375,7 +644,6 @@ async fn handle_socket(
|
||||
}
|
||||
});
|
||||
|
||||
// Task 2: receive messages from this client, persist, and broadcast
|
||||
let tx_clone = tx.clone();
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = ws_rx.next().await {
|
||||
@@ -430,7 +698,6 @@ async fn handle_client_message(
|
||||
}
|
||||
};
|
||||
|
||||
// All mutating messages require editor or owner role
|
||||
if !can_edit {
|
||||
let err = ServerMessage::Error {
|
||||
message: "You do not have permission to edit this map.".into(),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use chrono::NaiveDateTime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Map Role / Permission
|
||||
@@ -26,10 +27,12 @@ impl MapRole {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct MapPermission {
|
||||
/// A permission record with the associated user's username included (for display).
|
||||
#[derive(Serialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct PermissionWithUser {
|
||||
pub map_id: String,
|
||||
pub user_id: i64,
|
||||
pub user_id: Uuid,
|
||||
pub username: String,
|
||||
pub role: MapRole,
|
||||
}
|
||||
|
||||
@@ -37,32 +40,92 @@ pub struct MapPermission {
|
||||
// Grid Map
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Core map record as stored/returned by create, get, and update endpoints.
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct GridMap {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub is_public: bool,
|
||||
pub owner_id: i64,
|
||||
/// One of: "private", "public_view", "public_edit"
|
||||
pub public_access: String,
|
||||
pub owner_id: Uuid,
|
||||
pub colors: Vec<String>,
|
||||
pub created_at: NaiveDateTime,
|
||||
pub updated_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct CreateMapPayload {
|
||||
/// Extended map record returned by the list endpoint.
|
||||
/// Includes the owner's username, the caller's role (if any), and a
|
||||
/// favorited flag.
|
||||
#[derive(Serialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct ListedMap {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub is_public: bool,
|
||||
pub public_access: String,
|
||||
pub owner_id: Uuid,
|
||||
pub owner_username: String,
|
||||
pub colors: Vec<String>,
|
||||
pub created_at: NaiveDateTime,
|
||||
pub updated_at: NaiveDateTime,
|
||||
/// The authenticated caller's role on this map, or NULL if they only have it
|
||||
/// via a favorite (no explicit permission).
|
||||
pub user_role: Option<MapRole>,
|
||||
pub is_favorited: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct CreateMapPayload {
|
||||
pub name: String,
|
||||
/// Defaults to "private" when omitted.
|
||||
#[serde(default = "default_private")]
|
||||
pub public_access: String,
|
||||
}
|
||||
|
||||
fn default_private() -> String {
|
||||
"private".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct UpdateMapPayload {
|
||||
pub name: Option<String>,
|
||||
pub public_access: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct UpdatePermissionPayload {
|
||||
/// Discord user ID of the target user.
|
||||
pub user_id: i64,
|
||||
/// New role to assign. Omit (null) to remove the permission entry.
|
||||
/// Username of the target user (looked up server-side).
|
||||
pub username: String,
|
||||
/// New role to assign. `null` removes the permission entry.
|
||||
pub role: Option<MapRole>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Map Access Requests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// An access-request row joined with the requesting user's username.
|
||||
#[derive(Serialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct AccessRequestWithUser {
|
||||
pub id: Uuid,
|
||||
pub map_id: String,
|
||||
pub user_id: Uuid,
|
||||
pub username: String,
|
||||
pub requested_role: MapRole,
|
||||
pub status: String,
|
||||
pub created_at: NaiveDateTime,
|
||||
pub updated_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct CreateAccessRequestPayload {
|
||||
pub role: MapRole,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct ResolveAccessRequestPayload {
|
||||
/// "approve" or "deny"
|
||||
pub action: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Grid Cell (no id column — composite PK in DB)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user