Updates to pages
This commit is contained in:
190
crates/siren-api/src/admin/mod.rs
Normal file
190
crates/siren-api/src/admin/mod.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::AdminAuthorization,
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
Json,
|
||||
Router,
|
||||
extract::Path,
|
||||
http::StatusCode,
|
||||
routing::{delete, get, put},
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use siren_core::data;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/users", get(list_users))
|
||||
.route("/users/{id}/role", put(set_user_role))
|
||||
.route("/users/{id}/ban", put(ban_user))
|
||||
.route("/users/{id}/unban", put(unban_user))
|
||||
.route("/users/{id}", delete(delete_user))
|
||||
}
|
||||
|
||||
/// Minimal user record returned by the admin list endpoint.
|
||||
#[derive(Serialize, sqlx::FromRow)]
|
||||
pub struct AdminUserRecord {
|
||||
pub id: String,
|
||||
pub username: String,
|
||||
pub email: Option<String>,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct DbAdminUser {
|
||||
id: Uuid,
|
||||
username: String,
|
||||
email: Option<String>,
|
||||
role: String,
|
||||
status: String,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SetRolePayload {
|
||||
role: String,
|
||||
}
|
||||
|
||||
/// `GET /admin/users` — list all user accounts (admin only).
|
||||
async fn list_users(
|
||||
AdminAuthorization(_session): AdminAuthorization,
|
||||
) -> Result<Json<Vec<AdminUserRecord>>> {
|
||||
let pool = data::pool();
|
||||
|
||||
let rows: Vec<DbAdminUser> = sqlx::query_as(
|
||||
"SELECT id, username, email, role, status, created_at \
|
||||
FROM users \
|
||||
ORDER BY created_at ASC",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let records = rows
|
||||
.into_iter()
|
||||
.map(|r| AdminUserRecord {
|
||||
id: r.id.to_string(),
|
||||
username: r.username,
|
||||
email: r.email,
|
||||
role: r.role,
|
||||
status: r.status,
|
||||
created_at: r.created_at,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(records))
|
||||
}
|
||||
|
||||
/// `PUT /admin/users/{id}/role` — promote or demote a user (admin only).
|
||||
async fn set_user_role(
|
||||
AdminAuthorization(session): AdminAuthorization,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(payload): Json<SetRolePayload>,
|
||||
) -> Result<StatusCode> {
|
||||
if payload.role != "admin" && payload.role != "user" {
|
||||
return Err(Error::new(422, "role must be 'admin' or 'user'".into()));
|
||||
}
|
||||
|
||||
// Prevent an admin from demoting themselves
|
||||
if id == session.user_id && payload.role != "admin" {
|
||||
return Err(Error::new(
|
||||
422,
|
||||
"You cannot remove your own admin role".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let pool = data::pool();
|
||||
|
||||
let affected = sqlx::query("UPDATE users SET role = $1, updated_at = NOW() WHERE id = $2")
|
||||
.bind(&payload.role)
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await?
|
||||
.rows_affected();
|
||||
|
||||
if affected == 0 {
|
||||
return Err(Error::not_found(format!("User {id} not found")));
|
||||
}
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
/// `PUT /admin/users/{id}/ban` — ban a user account (admin only).
|
||||
async fn ban_user(
|
||||
AdminAuthorization(session): AdminAuthorization,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<StatusCode> {
|
||||
// Admins cannot ban themselves
|
||||
if id == session.user_id {
|
||||
return Err(Error::new(422, "You cannot ban yourself".into()));
|
||||
}
|
||||
|
||||
let pool = data::pool();
|
||||
|
||||
let affected =
|
||||
sqlx::query("UPDATE users SET status = 'banned', updated_at = NOW() WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await?
|
||||
.rows_affected();
|
||||
|
||||
if affected == 0 {
|
||||
return Err(Error::not_found(format!("User {id} not found")));
|
||||
}
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
/// `PUT /admin/users/{id}/unban` — reinstate a banned user (admin only).
|
||||
async fn unban_user(
|
||||
AdminAuthorization(_session): AdminAuthorization,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<StatusCode> {
|
||||
let pool = data::pool();
|
||||
|
||||
let affected =
|
||||
sqlx::query("UPDATE users SET status = 'active', updated_at = NOW() WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await?
|
||||
.rows_affected();
|
||||
|
||||
if affected == 0 {
|
||||
return Err(Error::not_found(format!("User {id} not found")));
|
||||
}
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
/// `DELETE /admin/users/{id}` — permanently delete a user account (admin only).
|
||||
async fn delete_user(
|
||||
AdminAuthorization(session): AdminAuthorization,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<StatusCode> {
|
||||
// Admins cannot delete themselves
|
||||
if id == session.user_id {
|
||||
return Err(Error::new(
|
||||
422,
|
||||
"You cannot delete your own account via the admin panel".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let pool = data::pool();
|
||||
|
||||
let affected = sqlx::query("DELETE FROM users WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await?
|
||||
.rows_affected();
|
||||
|
||||
if affected == 0 {
|
||||
return Err(Error::not_found(format!("User {id} not found")));
|
||||
}
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
@@ -8,28 +8,42 @@ use axum::{
|
||||
Router,
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
routing::post,
|
||||
routing::{get, post},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use siren_bot::{
|
||||
commands::audio::{
|
||||
join_voice_channel,
|
||||
pause::pause_track,
|
||||
play::enqueue_track,
|
||||
queue::{TrackInfo, get_is_paused, get_queue},
|
||||
resume::resume_track,
|
||||
skip::skip_track,
|
||||
stop::stop_track,
|
||||
},
|
||||
handler::get_songbird,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Routes that don't require a guild_id (nested at /api/audio)
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new().route("/guilds", get(list_guilds))
|
||||
}
|
||||
|
||||
/// Routes that operate on a specific guild (nested at /api/audio/{guild_id})
|
||||
pub fn get_guild_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/play", post(play_audio))
|
||||
.route("/pause", post(pause_audio))
|
||||
.route("/resume", post(resume_audio))
|
||||
.route("/stop", post(stop_audio))
|
||||
.route("/skip", post(skip_audio))
|
||||
.route("/status", get(audio_status))
|
||||
}
|
||||
|
||||
// ── Shared helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct PlayTrackRequest {
|
||||
url: String,
|
||||
@@ -52,6 +66,38 @@ async fn get_discord_snowflake(local_user_id: Uuid) -> Result<u64> {
|
||||
.ok_or_else(|| Error::not_found("Discord account not connected".to_string()))
|
||||
}
|
||||
|
||||
// ── GET /api/audio/guilds ─────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct GuildInfo {
|
||||
id: String,
|
||||
name: String,
|
||||
}
|
||||
|
||||
/// Returns all guilds the bot is currently in (from its Discord cache).
|
||||
async fn list_guilds(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Result<Json<Vec<GuildInfo>>> {
|
||||
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
let guilds: Vec<GuildInfo> = state
|
||||
.cache
|
||||
.guilds()
|
||||
.into_iter()
|
||||
.filter_map(|guild_id| {
|
||||
state.cache.guild(guild_id).map(|g| GuildInfo {
|
||||
id: guild_id.get().to_string(),
|
||||
name: g.name.clone(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(guilds))
|
||||
}
|
||||
|
||||
// ── POST /api/audio/{guild_id}/play ──────────────────────────────────────────
|
||||
|
||||
async fn play_audio(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
@@ -88,6 +134,8 @@ async fn play_audio(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── POST /api/audio/{guild_id}/pause ─────────────────────────────────────────
|
||||
|
||||
async fn pause_audio(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
@@ -96,18 +144,18 @@ async fn pause_audio(
|
||||
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
log::debug!("Pausing audio in guild: {}", guild_id);
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
// Pause the track
|
||||
let manager = get_songbird();
|
||||
pause_track(manager, &guild_id).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── POST /api/audio/{guild_id}/resume ────────────────────────────────────────
|
||||
|
||||
async fn resume_audio(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
@@ -116,14 +164,106 @@ async fn resume_audio(
|
||||
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
log::debug!("Resuming audio in guild: {}", guild_id);
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
// Resume the track
|
||||
let manager = get_songbird();
|
||||
resume_track(manager, &guild_id).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── POST /api/audio/{guild_id}/stop ──────────────────────────────────────────
|
||||
|
||||
async fn stop_audio(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> Result<()> {
|
||||
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
log::debug!("Stopping audio in guild: {}", guild_id);
|
||||
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
let manager = get_songbird();
|
||||
stop_track(manager, &guild_id)
|
||||
.await
|
||||
.map_err(|e| Error::new(500, e))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn skip_audio(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> Result<()> {
|
||||
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
log::debug!("<{}> Skipping audio", guild_id);
|
||||
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
let manager = get_songbird();
|
||||
skip_track(manager, &guild_id)
|
||||
.await
|
||||
.map_err(|e| Error::new(500, e))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AudioStatus {
|
||||
voice_channel: Option<String>,
|
||||
is_paused: bool,
|
||||
current_track: Option<TrackInfo>,
|
||||
queue: Vec<TrackInfo>,
|
||||
}
|
||||
|
||||
async fn audio_status(
|
||||
SessionAuthorization(session): SessionAuthorization,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> Result<Json<AudioStatus>> {
|
||||
session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
let guild_id_snowflake = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
// ── Voice channel: look up the bot's own voice state + channel name from cache ──
|
||||
let bot_user_id = state.cache.current_user().id;
|
||||
let voice_channel = state
|
||||
.cache
|
||||
.guild(guild_id_snowflake)
|
||||
.and_then(|guild| {
|
||||
let ch_id = guild
|
||||
.voice_states
|
||||
.get(&bot_user_id)
|
||||
.and_then(|vs| vs.channel_id)?;
|
||||
guild.channels.get(&ch_id).map(|ch| ch.name.clone())
|
||||
});
|
||||
|
||||
// ── Playback paused state (delegated to siren-bot to keep songbird internal) ──
|
||||
let is_paused = get_is_paused(guild_id).await;
|
||||
|
||||
// ── Queue metadata from our store (index 0 = currently playing) ──
|
||||
let mut full_queue = get_queue(guild_id);
|
||||
let current_track = if !full_queue.is_empty() {
|
||||
Some(full_queue.remove(0))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Json(AudioStatus {
|
||||
voice_channel,
|
||||
is_paused,
|
||||
current_track,
|
||||
queue: full_queue,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -294,8 +294,8 @@ async fn do_oauth_callback(
|
||||
|
||||
None => {
|
||||
// Find existing connection → local user_id
|
||||
let local_user_id: Option<(Uuid, String)> = sqlx::query_as(
|
||||
"SELECT u.id, u.username \
|
||||
let local_user_id: Option<(Uuid, String, String)> = sqlx::query_as(
|
||||
"SELECT u.id, u.username, u.status \
|
||||
FROM user_connections uc \
|
||||
JOIN users u ON u.id = uc.user_id \
|
||||
WHERE uc.provider = 'discord' AND uc.provider_user_id = $1",
|
||||
@@ -311,6 +311,10 @@ async fn do_oauth_callback(
|
||||
let (user_id, username) = match local_user_id {
|
||||
// Already linked — use the existing local user
|
||||
Some(row) => {
|
||||
// Reject banned accounts
|
||||
if row.2 == "banned" {
|
||||
return err_redirect(StatusCode::FORBIDDEN);
|
||||
}
|
||||
// Keep provider fields up to date
|
||||
sqlx::query(
|
||||
"UPDATE user_connections \
|
||||
@@ -327,7 +331,7 @@ async fn do_oauth_callback(
|
||||
err_redirect(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
row
|
||||
(row.0, row.1)
|
||||
}
|
||||
|
||||
// First login — create a local user + connection
|
||||
|
||||
@@ -85,6 +85,10 @@ pub struct UserInfo {
|
||||
/// OAuth and can safely disconnect OAuth providers).
|
||||
pub has_password: bool,
|
||||
pub connections: Vec<ConnectionInfo>,
|
||||
/// Site-level role: `"admin"` or `"user"`.
|
||||
pub role: String,
|
||||
/// Account status: `"active"` or `"banned"`.
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
@@ -95,6 +99,8 @@ struct DbUser {
|
||||
last_name: Option<String>,
|
||||
email: Option<String>,
|
||||
password_hash: Option<String>,
|
||||
role: String,
|
||||
status: String,
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
@@ -176,7 +182,8 @@ async fn load_user_info(user_id: Uuid) -> Result<UserInfo> {
|
||||
let pool = data::pool();
|
||||
|
||||
let user: DbUser = sqlx::query_as(
|
||||
"SELECT id, username, first_name, last_name, email, password_hash FROM users WHERE id = $1",
|
||||
"SELECT id, username, first_name, last_name, email, password_hash, role, status \
|
||||
FROM users WHERE id = $1",
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
@@ -197,6 +204,8 @@ async fn load_user_info(user_id: Uuid) -> Result<UserInfo> {
|
||||
last_name: user.last_name,
|
||||
email: user.email,
|
||||
has_password: user.password_hash.is_some(),
|
||||
role: user.role,
|
||||
status: user.status,
|
||||
connections: connections
|
||||
.into_iter()
|
||||
.map(|c| ConnectionInfo {
|
||||
@@ -260,13 +269,13 @@ async fn login(
|
||||
) -> Result<impl IntoResponse> {
|
||||
let pool = data::pool();
|
||||
|
||||
let row: Option<(Uuid, String, Option<String>)> =
|
||||
sqlx::query_as("SELECT id, username, password_hash FROM users WHERE username = $1")
|
||||
let row: Option<(Uuid, String, Option<String>, String)> =
|
||||
sqlx::query_as("SELECT id, username, password_hash, status FROM users WHERE username = $1")
|
||||
.bind(&payload.username)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let (user_id, username, password_hash) =
|
||||
let (user_id, username, password_hash, status) =
|
||||
row.ok_or_else(|| Error::new(401, "Invalid username or password".into()))?;
|
||||
|
||||
let hash =
|
||||
@@ -276,6 +285,10 @@ async fn login(
|
||||
return Err(Error::new(401, "Invalid username or password".into()));
|
||||
}
|
||||
|
||||
if status == "banned" {
|
||||
return Err(Error::new(403, "This account has been banned".into()));
|
||||
}
|
||||
|
||||
let ip = extract_ip(&headers);
|
||||
let user_agent = headers
|
||||
.get("user-agent")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
auth::{bearer_token::BearerTokenClaims, session::Session},
|
||||
error::Result,
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
@@ -10,6 +10,7 @@ use axum_extra::extract::CookieJar;
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{DecodingKey, Validation, decode};
|
||||
use sha2::{Digest, Sha256};
|
||||
use siren_core::data;
|
||||
|
||||
pub const COOKIE_NAME: &str = "siren_session";
|
||||
|
||||
@@ -130,3 +131,46 @@ pub async fn check_cookie_from_header_str(
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Extractor that requires the caller to be an authenticated site admin.
|
||||
///
|
||||
/// Returns `401 Unauthorized` if there is no valid session, or
|
||||
/// `403 Forbidden` if the user's role is not `"admin"`.
|
||||
/// On success, the inner `Session` is available to the handler.
|
||||
pub struct AdminAuthorization(pub Session);
|
||||
|
||||
impl<S> FromRequestParts<S> for AdminAuthorization
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Error;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> std::result::Result<Self, Self::Rejection> {
|
||||
let SessionAuthorization(maybe_session) =
|
||||
SessionAuthorization::from_request_parts(parts, state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let session = maybe_session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
// Verify admin role in the database
|
||||
let pool = data::pool();
|
||||
let role: Option<String> = sqlx::query_scalar("SELECT role FROM users WHERE id = $1")
|
||||
.bind(session.user_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("DB error checking admin role: {e}");
|
||||
Error::from(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
})?;
|
||||
|
||||
match role.as_deref() {
|
||||
Some("admin") => Ok(AdminAuthorization(session)),
|
||||
Some(_) => Err(Error::from(StatusCode::FORBIDDEN)),
|
||||
None => Err(Error::from(StatusCode::UNAUTHORIZED)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ pub use local::UserInfo;
|
||||
pub use session::Session;
|
||||
|
||||
pub mod middleware;
|
||||
pub use middleware::SessionAuthorization;
|
||||
pub use middleware::{AdminAuthorization, SessionAuthorization};
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod admin;
|
||||
pub mod app;
|
||||
mod app_state;
|
||||
pub mod audio;
|
||||
@@ -13,8 +14,10 @@ use std::sync::Arc;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.nest("/admin", admin::get_routes())
|
||||
.nest("/auth", auth::get_routes())
|
||||
.nest("/audio/{guild_id}", audio::get_routes())
|
||||
.nest("/audio", audio::get_routes())
|
||||
.nest("/audio/{guild_id}", audio::get_guild_routes())
|
||||
.nest("/dice", dice::get_routes())
|
||||
.nest("/grid", grid::get_routes())
|
||||
}
|
||||
|
||||
@@ -20,3 +20,4 @@ chrono = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
|
||||
@@ -11,6 +11,7 @@ use std::sync::Arc;
|
||||
pub mod mute;
|
||||
pub mod pause;
|
||||
pub mod play;
|
||||
pub mod queue;
|
||||
pub mod resume;
|
||||
pub mod skip;
|
||||
pub mod stop;
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
use super::{is_valid_url, join_voice_channel, leave_voice_channel};
|
||||
use super::{
|
||||
is_valid_url,
|
||||
join_voice_channel,
|
||||
leave_voice_channel,
|
||||
queue::{TrackInfo, enqueue_tracks, pop_front},
|
||||
};
|
||||
use crate::{
|
||||
chat::{create_message_response, edit_response, process_message},
|
||||
error::{Error, Result},
|
||||
@@ -113,6 +118,15 @@ pub async fn enqueue_track(
|
||||
|
||||
playlist_items = get_ytdlp_items(track_url)?;
|
||||
|
||||
// Collect TrackInfo for the queue store before borrowing `item` in the loop
|
||||
let track_infos: Vec<TrackInfo> = playlist_items
|
||||
.iter()
|
||||
.map(|item| TrackInfo {
|
||||
title: item.get_title().to_owned(),
|
||||
url: item.get_url().to_owned(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Add each track to the queue
|
||||
for item in &playlist_items {
|
||||
let volume = guild.volume as f32 / 100.0;
|
||||
@@ -137,6 +151,10 @@ pub async fn enqueue_track(
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Store track metadata so the REST API can expose queue info
|
||||
enqueue_tracks(guild_id.get(), track_infos);
|
||||
|
||||
if handler.queue().is_empty() {
|
||||
let _ = handler.queue().resume();
|
||||
}
|
||||
@@ -204,6 +222,9 @@ struct TrackEndNotifier {
|
||||
impl EventHandler for TrackEndNotifier {
|
||||
async fn act(&self, ctx: &songbird::events::EventContext<'_>) -> Option<songbird::events::Event> {
|
||||
if let songbird::EventContext::Track(_track_list) = ctx {
|
||||
// Remove the finished track from our metadata store
|
||||
pop_front(self.guild_id.get());
|
||||
|
||||
if let Some(call) = self.call.get(self.guild_id) {
|
||||
let mut handler = call.lock().await;
|
||||
if handler.queue().is_empty() {
|
||||
|
||||
88
crates/siren-bot/src/commands/audio/queue.rs
Normal file
88
crates/siren-bot/src/commands/audio/queue.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use crate::handler::get_songbird;
|
||||
use dashmap::DashMap;
|
||||
use serde::Serialize;
|
||||
use serenity::model::prelude::GuildId;
|
||||
use songbird::tracks::PlayMode;
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
sync::{Arc, OnceLock},
|
||||
};
|
||||
|
||||
/// Metadata for a single track stored in our queue.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct TrackInfo {
|
||||
pub title: String,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
/// Global map of guild_id → ordered queue of TrackInfo.
|
||||
/// Initialised once by the bot handler's `ready` event.
|
||||
static TRACK_QUEUES: OnceLock<Arc<DashMap<u64, VecDeque<TrackInfo>>>> = OnceLock::new();
|
||||
|
||||
/// Call once from the `ready` event handler to initialise the store.
|
||||
pub fn init_track_queues() {
|
||||
TRACK_QUEUES
|
||||
.set(Arc::new(DashMap::new()))
|
||||
.ok();
|
||||
}
|
||||
|
||||
/// Returns a reference to the global TRACK_QUEUES map.
|
||||
fn queues() -> &'static Arc<DashMap<u64, VecDeque<TrackInfo>>> {
|
||||
TRACK_QUEUES
|
||||
.get()
|
||||
.expect("TRACK_QUEUES not initialised – call init_track_queues() in the ready handler")
|
||||
}
|
||||
|
||||
/// Append one or more tracks to the end of a guild's queue.
|
||||
pub fn enqueue_tracks(guild_id: u64, tracks: Vec<TrackInfo>) {
|
||||
let mut entry = queues().entry(guild_id).or_default();
|
||||
for t in tracks {
|
||||
entry.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove and return the front track (called when a track finishes).
|
||||
pub fn pop_front(guild_id: u64) -> Option<TrackInfo> {
|
||||
queues()
|
||||
.get_mut(&guild_id)
|
||||
.and_then(|mut q: dashmap::mapref::one::RefMut<u64, VecDeque<TrackInfo>>| q.pop_front())
|
||||
}
|
||||
|
||||
/// Clear the entire queue for a guild (called on stop).
|
||||
pub fn clear_queue(guild_id: u64) {
|
||||
if let Some(mut q) = queues().get_mut(&guild_id) {
|
||||
let q: &mut VecDeque<TrackInfo> = q.value_mut();
|
||||
q.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a snapshot of the current queue for a guild.
|
||||
/// Index 0 is the currently-playing track, index 1+ are upcoming.
|
||||
pub fn get_queue(guild_id: u64) -> Vec<TrackInfo> {
|
||||
queues()
|
||||
.get(&guild_id)
|
||||
.map(|q: dashmap::mapref::one::Ref<u64, VecDeque<TrackInfo>>| {
|
||||
q.iter().cloned().collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Returns `true` if the bot is currently paused in the given guild.
|
||||
/// Encapsulates the songbird dependency so `siren-api` doesn't need it directly.
|
||||
pub async fn get_is_paused(guild_id: u64) -> bool {
|
||||
let manager = get_songbird();
|
||||
let serenity_guild_id = GuildId::from(guild_id);
|
||||
if let Some(handler_lock) = manager.get(serenity_guild_id) {
|
||||
let handler = handler_lock.lock().await;
|
||||
let current = handler.queue().current();
|
||||
drop(handler);
|
||||
if let Some(track) = current {
|
||||
return track
|
||||
.get_info()
|
||||
.await
|
||||
.map(|info| info.playing == PlayMode::Pause)
|
||||
.unwrap_or(false);
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
@@ -1,11 +1,14 @@
|
||||
use crate::{
|
||||
chat::{edit_response, process_message},
|
||||
commands::audio::queue::pop_front,
|
||||
handler::get_songbird,
|
||||
};
|
||||
use serenity::{
|
||||
all::{CommandInteraction, CreateCommand},
|
||||
all::{CommandInteraction, CreateCommand, GuildId},
|
||||
prelude::*,
|
||||
};
|
||||
use songbird::Songbird;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Create the initial response
|
||||
@@ -29,17 +32,27 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
};
|
||||
|
||||
// Skip the track
|
||||
match skip_track(manager, guild_id).await {
|
||||
Ok(_) => {
|
||||
log::debug!("<{guild_id}> Skipped the track");
|
||||
edit_response(ctx, command, "Skipping the track".to_string()).await;
|
||||
}
|
||||
Err(err) => edit_response(ctx, command, format!("Failed to skip: {}", err)).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn skip_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> Result<(), String> {
|
||||
if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
|
||||
let handler = handler_lock.lock().await;
|
||||
match handler.queue().skip() {
|
||||
Ok(_) => {
|
||||
log::debug!("<{guild_id}> Skipped the track");
|
||||
edit_response(ctx, command, "Skipping the track".to_string()).await;
|
||||
}
|
||||
Err(err) => {
|
||||
edit_response(ctx, command, format!("Failed to skip: {}", err)).await;
|
||||
}
|
||||
}
|
||||
handler
|
||||
.queue()
|
||||
.skip()
|
||||
.map_err(|e| e.to_string())?;
|
||||
// Pop the current track from our metadata store; the next track (if any) moves to front
|
||||
pop_front(guild_id.get());
|
||||
Ok(())
|
||||
} else {
|
||||
Err("No active audio session in this guild".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
use crate::{
|
||||
chat::{edit_response, process_message},
|
||||
commands::audio::queue::clear_queue,
|
||||
handler::get_songbird,
|
||||
};
|
||||
use serenity::{
|
||||
all::{CommandInteraction, CreateCommand},
|
||||
all::{CommandInteraction, CreateCommand, GuildId},
|
||||
prelude::*,
|
||||
};
|
||||
use songbird::Songbird;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Create the initial response
|
||||
@@ -29,11 +32,23 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
};
|
||||
|
||||
// Stop the track and clear the queue
|
||||
if let Some(handler_lock) = manager.get(guild_id) {
|
||||
match stop_track(manager, &guild_id).await {
|
||||
Ok(_) => {
|
||||
log::debug!("<{guild_id}> Stopped the track");
|
||||
edit_response(ctx, command, "Stopping the tracks".to_string()).await;
|
||||
}
|
||||
Err(err) => edit_response(ctx, command, format!("Failed to stop: {}", err)).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stop_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> Result<(), String> {
|
||||
if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
|
||||
let handler = handler_lock.lock().await;
|
||||
handler.queue().stop();
|
||||
log::debug!("<{guild_id}> Stopped the track");
|
||||
edit_response(ctx, command, "Stopping the tracks".to_string()).await;
|
||||
clear_queue(guild_id.get());
|
||||
Ok(())
|
||||
} else {
|
||||
Err("No active audio session in this guild".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::{chat::create_modal_response, commands};
|
||||
use crate::{
|
||||
HttpKey,
|
||||
commands::fun::roll::{format_roll, roll_dice, send_roll_message},
|
||||
commands::{audio::queue::init_track_queues, fun::roll::{format_roll, roll_dice, send_roll_message}},
|
||||
};
|
||||
use serenity::{
|
||||
all::{
|
||||
@@ -64,6 +64,9 @@ impl EventHandler for BotHandler {
|
||||
log::warn!("No ready guilds found");
|
||||
}
|
||||
|
||||
// Initialise the track-queue metadata store (idempotent)
|
||||
init_track_queues();
|
||||
|
||||
if SONGBIRD.get().is_none() {
|
||||
let songbird = songbird::get(&ctx).await.unwrap();
|
||||
SONGBIRD
|
||||
|
||||
Reference in New Issue
Block a user