Updates to pages

This commit is contained in:
2026-04-04 18:31:28 -04:00
parent 070337577c
commit ca95582d92
42 changed files with 2831 additions and 640 deletions

View File

@@ -0,0 +1,190 @@
use crate::{
AppState,
auth::AdminAuthorization,
error::{Error, Result},
};
use axum::{
Json,
Router,
extract::Path,
http::StatusCode,
routing::{delete, get, put},
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use siren_core::data;
use std::sync::Arc;
use uuid::Uuid;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/users", get(list_users))
.route("/users/{id}/role", put(set_user_role))
.route("/users/{id}/ban", put(ban_user))
.route("/users/{id}/unban", put(unban_user))
.route("/users/{id}", delete(delete_user))
}
/// Minimal user record returned by the admin list endpoint.
#[derive(Serialize, sqlx::FromRow)]
pub struct AdminUserRecord {
pub id: String,
pub username: String,
pub email: Option<String>,
pub role: String,
pub status: String,
pub created_at: DateTime<Utc>,
}
#[derive(sqlx::FromRow)]
struct DbAdminUser {
id: Uuid,
username: String,
email: Option<String>,
role: String,
status: String,
created_at: DateTime<Utc>,
}
#[derive(Deserialize)]
struct SetRolePayload {
role: String,
}
/// `GET /admin/users` — list all user accounts (admin only).
async fn list_users(
AdminAuthorization(_session): AdminAuthorization,
) -> Result<Json<Vec<AdminUserRecord>>> {
let pool = data::pool();
let rows: Vec<DbAdminUser> = sqlx::query_as(
"SELECT id, username, email, role, status, created_at \
FROM users \
ORDER BY created_at ASC",
)
.fetch_all(pool)
.await?;
let records = rows
.into_iter()
.map(|r| AdminUserRecord {
id: r.id.to_string(),
username: r.username,
email: r.email,
role: r.role,
status: r.status,
created_at: r.created_at,
})
.collect();
Ok(Json(records))
}
/// `PUT /admin/users/{id}/role` — promote or demote a user (admin only).
async fn set_user_role(
AdminAuthorization(session): AdminAuthorization,
Path(id): Path<Uuid>,
Json(payload): Json<SetRolePayload>,
) -> Result<StatusCode> {
if payload.role != "admin" && payload.role != "user" {
return Err(Error::new(422, "role must be 'admin' or 'user'".into()));
}
// Prevent an admin from demoting themselves
if id == session.user_id && payload.role != "admin" {
return Err(Error::new(
422,
"You cannot remove your own admin role".into(),
));
}
let pool = data::pool();
let affected = sqlx::query("UPDATE users SET role = $1, updated_at = NOW() WHERE id = $2")
.bind(&payload.role)
.bind(id)
.execute(pool)
.await?
.rows_affected();
if affected == 0 {
return Err(Error::not_found(format!("User {id} not found")));
}
Ok(StatusCode::NO_CONTENT)
}
/// `PUT /admin/users/{id}/ban` — ban a user account (admin only).
async fn ban_user(
AdminAuthorization(session): AdminAuthorization,
Path(id): Path<Uuid>,
) -> Result<StatusCode> {
// Admins cannot ban themselves
if id == session.user_id {
return Err(Error::new(422, "You cannot ban yourself".into()));
}
let pool = data::pool();
let affected =
sqlx::query("UPDATE users SET status = 'banned', updated_at = NOW() WHERE id = $1")
.bind(id)
.execute(pool)
.await?
.rows_affected();
if affected == 0 {
return Err(Error::not_found(format!("User {id} not found")));
}
Ok(StatusCode::NO_CONTENT)
}
/// `PUT /admin/users/{id}/unban` — reinstate a banned user (admin only).
async fn unban_user(
AdminAuthorization(_session): AdminAuthorization,
Path(id): Path<Uuid>,
) -> Result<StatusCode> {
let pool = data::pool();
let affected =
sqlx::query("UPDATE users SET status = 'active', updated_at = NOW() WHERE id = $1")
.bind(id)
.execute(pool)
.await?
.rows_affected();
if affected == 0 {
return Err(Error::not_found(format!("User {id} not found")));
}
Ok(StatusCode::NO_CONTENT)
}
/// `DELETE /admin/users/{id}` — permanently delete a user account (admin only).
async fn delete_user(
AdminAuthorization(session): AdminAuthorization,
Path(id): Path<Uuid>,
) -> Result<StatusCode> {
// Admins cannot delete themselves
if id == session.user_id {
return Err(Error::new(
422,
"You cannot delete your own account via the admin panel".into(),
));
}
let pool = data::pool();
let affected = sqlx::query("DELETE FROM users WHERE id = $1")
.bind(id)
.execute(pool)
.await?
.rows_affected();
if affected == 0 {
return Err(Error::not_found(format!("User {id} not found")));
}
Ok(StatusCode::NO_CONTENT)
}

View File

@@ -8,28 +8,42 @@ use axum::{
Router,
extract::{Path, State},
http::StatusCode,
routing::post,
routing::{get, post},
};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use siren_bot::{
commands::audio::{
join_voice_channel,
pause::pause_track,
play::enqueue_track,
queue::{TrackInfo, get_is_paused, get_queue},
resume::resume_track,
skip::skip_track,
stop::stop_track,
},
handler::get_songbird,
};
use std::sync::Arc;
use uuid::Uuid;
/// Routes that don't require a guild_id (nested at /api/audio)
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new().route("/guilds", get(list_guilds))
}
/// Routes that operate on a specific guild (nested at /api/audio/{guild_id})
pub fn get_guild_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/play", post(play_audio))
.route("/pause", post(pause_audio))
.route("/resume", post(resume_audio))
.route("/stop", post(stop_audio))
.route("/skip", post(skip_audio))
.route("/status", get(audio_status))
}
// ── Shared helpers ────────────────────────────────────────────────────────────
#[derive(Deserialize)]
struct PlayTrackRequest {
url: String,
@@ -52,6 +66,38 @@ async fn get_discord_snowflake(local_user_id: Uuid) -> Result<u64> {
.ok_or_else(|| Error::not_found("Discord account not connected".to_string()))
}
// ── GET /api/audio/guilds ─────────────────────────────────────────────────────
#[derive(Serialize)]
struct GuildInfo {
id: String,
name: String,
}
/// Returns all guilds the bot is currently in (from its Discord cache).
async fn list_guilds(
SessionAuthorization(session): SessionAuthorization,
State(state): State<Arc<AppState>>,
) -> Result<Json<Vec<GuildInfo>>> {
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
let guilds: Vec<GuildInfo> = state
.cache
.guilds()
.into_iter()
.filter_map(|guild_id| {
state.cache.guild(guild_id).map(|g| GuildInfo {
id: guild_id.get().to_string(),
name: g.name.clone(),
})
})
.collect();
Ok(Json(guilds))
}
// ── POST /api/audio/{guild_id}/play ──────────────────────────────────────────
async fn play_audio(
SessionAuthorization(session): SessionAuthorization,
State(state): State<Arc<AppState>>,
@@ -88,6 +134,8 @@ async fn play_audio(
Ok(())
}
// ── POST /api/audio/{guild_id}/pause ─────────────────────────────────────────
async fn pause_audio(
SessionAuthorization(session): SessionAuthorization,
State(state): State<Arc<AppState>>,
@@ -96,18 +144,18 @@ async fn pause_audio(
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
log::debug!("Pausing audio in guild: {}", guild_id);
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
// Pause the track
let manager = get_songbird();
pause_track(manager, &guild_id).await?;
Ok(())
}
// ── POST /api/audio/{guild_id}/resume ────────────────────────────────────────
async fn resume_audio(
SessionAuthorization(session): SessionAuthorization,
State(state): State<Arc<AppState>>,
@@ -116,14 +164,106 @@ async fn resume_audio(
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
log::debug!("Resuming audio in guild: {}", guild_id);
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
// Resume the track
let manager = get_songbird();
resume_track(manager, &guild_id).await?;
Ok(())
}
// ── POST /api/audio/{guild_id}/stop ──────────────────────────────────────────
async fn stop_audio(
SessionAuthorization(session): SessionAuthorization,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
) -> Result<()> {
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
log::debug!("Stopping audio in guild: {}", guild_id);
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
let manager = get_songbird();
stop_track(manager, &guild_id)
.await
.map_err(|e| Error::new(500, e))?;
Ok(())
}
async fn skip_audio(
SessionAuthorization(session): SessionAuthorization,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
) -> Result<()> {
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
log::debug!("<{}> Skipping audio", guild_id);
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
let manager = get_songbird();
skip_track(manager, &guild_id)
.await
.map_err(|e| Error::new(500, e))?;
Ok(())
}
#[derive(Serialize)]
struct AudioStatus {
voice_channel: Option<String>,
is_paused: bool,
current_track: Option<TrackInfo>,
queue: Vec<TrackInfo>,
}
async fn audio_status(
SessionAuthorization(session): SessionAuthorization,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
) -> Result<Json<AudioStatus>> {
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
let guild_id_snowflake = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
// ── Voice channel: look up the bot's own voice state + channel name from cache ──
let bot_user_id = state.cache.current_user().id;
let voice_channel = state
.cache
.guild(guild_id_snowflake)
.and_then(|guild| {
let ch_id = guild
.voice_states
.get(&bot_user_id)
.and_then(|vs| vs.channel_id)?;
guild.channels.get(&ch_id).map(|ch| ch.name.clone())
});
// ── Playback paused state (delegated to siren-bot to keep songbird internal) ──
let is_paused = get_is_paused(guild_id).await;
// ── Queue metadata from our store (index 0 = currently playing) ──
let mut full_queue = get_queue(guild_id);
let current_track = if !full_queue.is_empty() {
Some(full_queue.remove(0))
} else {
None
};
Ok(Json(AudioStatus {
voice_channel,
is_paused,
current_track,
queue: full_queue,
}))
}

View File

@@ -294,8 +294,8 @@ async fn do_oauth_callback(
None => {
// Find existing connection → local user_id
let local_user_id: Option<(Uuid, String)> = sqlx::query_as(
"SELECT u.id, u.username \
let local_user_id: Option<(Uuid, String, String)> = sqlx::query_as(
"SELECT u.id, u.username, u.status \
FROM user_connections uc \
JOIN users u ON u.id = uc.user_id \
WHERE uc.provider = 'discord' AND uc.provider_user_id = $1",
@@ -311,6 +311,10 @@ async fn do_oauth_callback(
let (user_id, username) = match local_user_id {
// Already linked — use the existing local user
Some(row) => {
// Reject banned accounts
if row.2 == "banned" {
return err_redirect(StatusCode::FORBIDDEN);
}
// Keep provider fields up to date
sqlx::query(
"UPDATE user_connections \
@@ -327,7 +331,7 @@ async fn do_oauth_callback(
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
})?;
row
(row.0, row.1)
}
// First login — create a local user + connection

View File

@@ -85,6 +85,10 @@ pub struct UserInfo {
/// OAuth and can safely disconnect OAuth providers).
pub has_password: bool,
pub connections: Vec<ConnectionInfo>,
/// Site-level role: `"admin"` or `"user"`.
pub role: String,
/// Account status: `"active"` or `"banned"`.
pub status: String,
}
#[derive(sqlx::FromRow)]
@@ -95,6 +99,8 @@ struct DbUser {
last_name: Option<String>,
email: Option<String>,
password_hash: Option<String>,
role: String,
status: String,
}
#[derive(sqlx::FromRow)]
@@ -176,7 +182,8 @@ async fn load_user_info(user_id: Uuid) -> Result<UserInfo> {
let pool = data::pool();
let user: DbUser = sqlx::query_as(
"SELECT id, username, first_name, last_name, email, password_hash FROM users WHERE id = $1",
"SELECT id, username, first_name, last_name, email, password_hash, role, status \
FROM users WHERE id = $1",
)
.bind(user_id)
.fetch_one(pool)
@@ -197,6 +204,8 @@ async fn load_user_info(user_id: Uuid) -> Result<UserInfo> {
last_name: user.last_name,
email: user.email,
has_password: user.password_hash.is_some(),
role: user.role,
status: user.status,
connections: connections
.into_iter()
.map(|c| ConnectionInfo {
@@ -260,13 +269,13 @@ async fn login(
) -> Result<impl IntoResponse> {
let pool = data::pool();
let row: Option<(Uuid, String, Option<String>)> =
sqlx::query_as("SELECT id, username, password_hash FROM users WHERE username = $1")
let row: Option<(Uuid, String, Option<String>, String)> =
sqlx::query_as("SELECT id, username, password_hash, status FROM users WHERE username = $1")
.bind(&payload.username)
.fetch_optional(pool)
.await?;
let (user_id, username, password_hash) =
let (user_id, username, password_hash, status) =
row.ok_or_else(|| Error::new(401, "Invalid username or password".into()))?;
let hash =
@@ -276,6 +285,10 @@ async fn login(
return Err(Error::new(401, "Invalid username or password".into()));
}
if status == "banned" {
return Err(Error::new(403, "This account has been banned".into()));
}
let ip = extract_ip(&headers);
let user_agent = headers
.get("user-agent")

View File

@@ -1,6 +1,6 @@
use crate::{
auth::{bearer_token::BearerTokenClaims, session::Session},
error::Result,
error::{Error, Result},
};
use axum::{
extract::FromRequestParts,
@@ -10,6 +10,7 @@ use axum_extra::extract::CookieJar;
use chrono::Utc;
use jsonwebtoken::{DecodingKey, Validation, decode};
use sha2::{Digest, Sha256};
use siren_core::data;
pub const COOKIE_NAME: &str = "siren_session";
@@ -130,3 +131,46 @@ pub async fn check_cookie_from_header_str(
}
None
}
/// Extractor that requires the caller to be an authenticated site admin.
///
/// Returns `401 Unauthorized` if there is no valid session, or
/// `403 Forbidden` if the user's role is not `"admin"`.
/// On success, the inner `Session` is available to the handler.
pub struct AdminAuthorization(pub Session);
impl<S> FromRequestParts<S> for AdminAuthorization
where
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> std::result::Result<Self, Self::Rejection> {
let SessionAuthorization(maybe_session) =
SessionAuthorization::from_request_parts(parts, state)
.await
.unwrap();
let session = maybe_session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
// Verify admin role in the database
let pool = data::pool();
let role: Option<String> = sqlx::query_scalar("SELECT role FROM users WHERE id = $1")
.bind(session.user_id)
.fetch_optional(pool)
.await
.map_err(|e| {
log::error!("DB error checking admin role: {e}");
Error::from(StatusCode::INTERNAL_SERVER_ERROR)
})?;
match role.as_deref() {
Some("admin") => Ok(AdminAuthorization(session)),
Some(_) => Err(Error::from(StatusCode::FORBIDDEN)),
None => Err(Error::from(StatusCode::UNAUTHORIZED)),
}
}
}

View File

@@ -11,7 +11,7 @@ pub use local::UserInfo;
pub use session::Session;
pub mod middleware;
pub use middleware::SessionAuthorization;
pub use middleware::{AdminAuthorization, SessionAuthorization};
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()

View File

@@ -1,3 +1,4 @@
pub mod admin;
pub mod app;
mod app_state;
pub mod audio;
@@ -13,8 +14,10 @@ use std::sync::Arc;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.nest("/admin", admin::get_routes())
.nest("/auth", auth::get_routes())
.nest("/audio/{guild_id}", audio::get_routes())
.nest("/audio", audio::get_routes())
.nest("/audio/{guild_id}", audio::get_guild_routes())
.nest("/dice", dice::get_routes())
.nest("/grid", grid::get_routes())
}