Major refactor

This commit is contained in:
2026-04-03 23:04:51 -04:00
parent e7f337c735
commit 35d07e8df1
124 changed files with 4929 additions and 2429 deletions

View File

@@ -0,0 +1,28 @@
[package]
name = "siren-api"
edition.workspace = true
version.workspace = true
rust-version.workspace = true
authors.workspace = true
[dependencies]
siren-core = { workspace = true }
siren-bot = { workspace = true }
tokio = { workspace = true }
log = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
axum = { workspace = true }
axum-extra = { workspace = true }
serenity = { workspace = true }
reqwest = { workspace = true }
jsonwebtoken = { workspace = true }
chrono = { workspace = true }
uuid = { workspace = true }
rand = { workspace = true }
rand_chacha = { workspace = true }
sqlx = { workspace = true }
redis = { workspace = true }
tower-http = { workspace = true }
dashmap = { workspace = true }
futures-util = { workspace = true }

View File

@@ -0,0 +1,53 @@
use crate::{AppState, error::Result};
use axum::Router;
use std::{env, sync::Arc};
use tokio::net::TcpListener;
use tower_http::{
cors::{Any, CorsLayer},
services::{ServeDir, ServeFile},
};
pub struct App {
app_state: AppState,
}
impl App {
pub fn new(app_state: AppState) -> Self {
Self { app_state }
}
pub async fn serve(self) -> Result<()> {
log::debug!("Starting API...");
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
// Serve the built React frontend from frontend/dist (relative to the
// working directory). Falls back gracefully if the directory does not
// exist yet (e.g. during development when using `npm run dev`).
let frontend_dir = env::current_dir()
.unwrap_or_default()
.join("frontend")
.join("dist");
// For SPA routing: any path not matched by a real file (e.g. /map/<id>)
// falls back to index.html so React can handle client-side routing.
let index_html = frontend_dir.join("index.html");
let serve_dir = ServeDir::new(&frontend_dir).not_found_service(ServeFile::new(index_html));
let app = Router::new()
.nest("/api", crate::get_routes())
.fallback_service(serve_dir)
.layer(cors)
.with_state(Arc::new(self.app_state));
let api_port: String = env::var("API_PORT").expect("Expected a port in the environment");
let addr = format!("0.0.0.0:{}", api_port);
let listener = TcpListener::bind(&addr).await?;
log::info!("API is listening on {}", &addr);
Ok(axum::serve(listener, app).await?)
}
}

View File

@@ -0,0 +1,23 @@
use dashmap::DashMap;
use serenity::{
all::{Cache, Http},
prelude::Mutex,
};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::broadcast;
#[derive(Clone)]
pub struct AppState {
pub client: reqwest::Client,
pub client_id: String,
pub client_secret: String,
pub base_url: String,
/// Maps oauth_state → ui_redirect_uri.
/// Populated on /authorize, consumed on /callback.
pub discord_authorize_cache: Arc<Mutex<HashMap<String, String>>>,
pub http: Arc<Http>,
pub cache: Arc<Cache>,
/// Per-map WebSocket broadcast channels for real-time collaboration.
/// Key is the CSPRNG map ID (TEXT).
pub map_rooms: Arc<DashMap<String, broadcast::Sender<String>>>,
}

View File

@@ -0,0 +1,105 @@
use crate::{
AppState,
auth::{AuthorizationMiddleware, Session},
error::{Error, Result},
};
use axum::{
Extension,
Json,
Router,
extract::{Path, State},
middleware::from_extractor,
routing::post,
};
use serde::Deserialize;
use siren_bot::{
commands::audio::{
join_voice_channel,
pause::pause_track,
play::enqueue_track,
resume::resume_track,
},
handler::get_songbird,
};
use std::sync::Arc;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/play", post(play_audio))
.route_layer(from_extractor::<AuthorizationMiddleware>())
.route("/pause", post(pause_audio))
.route_layer(from_extractor::<AuthorizationMiddleware>())
.route("/resume", post(resume_audio))
.route_layer(from_extractor::<AuthorizationMiddleware>())
}
#[derive(Deserialize)]
struct PlayTrackRequest {
url: String,
}
async fn play_audio(
Extension(session): Extension<Session>,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
Json(payload): Json<PlayTrackRequest>,
) -> Result<()> {
log::debug!("Playing audio in guild: {}", guild_id);
// Check if the user exists in the cache
let user_id = session.user_id;
let user_id = match state.cache.user(user_id) {
Some(user) => user.id,
None => return Err(Error::not_found("User not found".to_string())),
};
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
// Play the track
let manager = get_songbird();
let _channel_id = join_voice_channel(&state.cache, &manager, &guild_id, &user_id).await?;
enqueue_track(manager, guild_id.to_owned(), &payload.url).await?;
Ok(())
}
async fn pause_audio(
Extension(_): Extension<Session>,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
) -> Result<()> {
log::debug!("Pausing audio in guild: {}", guild_id);
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
// Pause the track
let manager = get_songbird();
pause_track(manager, &guild_id).await?;
Ok(())
}
async fn resume_audio(
Extension(_): Extension<Session>,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
) -> Result<()> {
log::debug!("Pausing audio in guild: {}", guild_id);
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
// Pause the track
let manager = get_songbird();
resume_track(manager, &guild_id).await?;
Ok(())
}

View File

@@ -0,0 +1,10 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct BearerTokenClaims {
pub sub: u64,
pub name: String,
pub iat: i64,
pub exp: i64,
pub jti: String,
}

View File

@@ -0,0 +1,225 @@
use crate::{
AppState,
auth::{bearer_token::BearerTokenClaims, csprng, session::Session},
};
use axum::{
Router,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Redirect},
routing::get,
};
use serde::{Deserialize, Serialize};
use std::{env, sync::Arc};
const DISCORD_REDIRECT_PATH: &str = "/api/auth/discord/callback";
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/authorize", get(discord_authorize))
.route("/callback", get(discord_callback))
}
#[derive(Deserialize)]
struct AuthorizeQuery {
redirect_uri: String,
}
#[derive(Deserialize)]
struct CallbackQuery {
code: String,
state: Option<String>,
}
#[derive(Serialize, Deserialize)]
struct DiscordTokenResponse {
access_token: String,
token_type: String,
expires_in: u64,
refresh_token: String,
scope: String,
}
#[derive(Serialize, Deserialize, Debug)]
struct DiscordUser {
id: String,
username: String,
discriminator: String,
avatar: Option<String>,
}
async fn discord_authorize(
State(state): State<Arc<AppState>>,
Query(query): Query<AuthorizeQuery>,
) -> impl IntoResponse {
let oauth_state = csprng(16);
state
.discord_authorize_cache
.lock()
.await
.insert(oauth_state.clone(), query.redirect_uri);
let discord_callback_url = format!("{}{}", state.base_url, DISCORD_REDIRECT_PATH);
let encoded_callback = discord_callback_url.replace(':', "%3A").replace('/', "%2F");
let discord_auth_url = format!(
"https://discord.com/api/oauth2/authorize\
?client_id={}\
&redirect_uri={}\
&response_type=code\
&scope=identify\
&state={}",
state.client_id, encoded_callback, oauth_state,
);
match serde_json::to_string(&discord_auth_url) {
Ok(json) => Ok(json),
Err(e) => {
log::error!("Failed to serialize Discord OAuth URL: {e}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
async fn discord_callback(
State(state): State<Arc<AppState>>,
Query(query): Query<CallbackQuery>,
) -> impl IntoResponse {
match do_oauth_callback(state, query).await {
Ok((token, ui_redirect_uri)) => {
Redirect::temporary(&format!("{}?token={}", ui_redirect_uri, token)).into_response()
}
Err((e, ui_redirect_uri)) => {
log::error!("OAuth callback error: {:?}", e);
let fallback = ui_redirect_uri.unwrap_or_else(|| "/".to_string());
Redirect::temporary(&format!("{}?error=auth_failed", fallback)).into_response()
}
}
}
async fn do_oauth_callback(
state: Arc<AppState>,
query: CallbackQuery,
) -> Result<(String, String), (crate::error::Error, Option<String>)> {
// Validate the state and retrieve the associated UI redirect URI
let ui_redirect_uri = {
let mut oauth_states = state.discord_authorize_cache.lock().await;
match query.state {
Some(ref oauth_state) => match oauth_states.remove(oauth_state) {
Some(uri) => uri,
None => return Err((StatusCode::UNAUTHORIZED.into(), None)),
},
None => return Err((StatusCode::UNAUTHORIZED.into(), None)),
}
};
// Helper closure to tag errors with the redirect URI we already know
let redirect = ui_redirect_uri.clone();
let err = |s: StatusCode| -> Result<_, (crate::error::Error, Option<String>)> {
Err((s.into(), Some(redirect.clone())))
};
// The discord redirect_uri in the token exchange must match what was sent in /authorize
let discord_callback_url = format!("{}{}", state.base_url, DISCORD_REDIRECT_PATH);
// Exchange code for an access token
let token_response = state
.client
.post("https://discord.com/api/oauth2/token")
.form(&[
("client_id", state.client_id.as_str()),
("client_secret", state.client_secret.as_str()),
("grant_type", "authorization_code"),
("code", query.code.as_str()),
("redirect_uri", discord_callback_url.as_str()),
])
.send()
.await
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
if !token_response.status().is_success() {
log::error!(
"Failed to exchange token: {:?}",
token_response.text().await
);
return err(StatusCode::INTERNAL_SERVER_ERROR);
}
let token_data: DiscordTokenResponse = token_response
.json()
.await
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
// Fetch user information from Discord
let user_response = state
.client
.get("https://discord.com/api/users/@me")
.bearer_auth(token_data.access_token)
.send()
.await
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
if !user_response.status().is_success() {
log::error!(
"Failed to fetch user information: {:?}",
user_response.text().await
);
return err(StatusCode::INTERNAL_SERVER_ERROR);
}
let user_data: DiscordUser = user_response
.json()
.await
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
log::debug!("User authenticated: {:?}", user_data);
let user_id: i64 = user_data
.id
.parse::<i64>()
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
// Upsert the Discord user into the local users table
let pool = siren_core::data::pool();
sqlx::query(
"INSERT INTO users (id, username, avatar, updated_at)
VALUES ($1, $2, $3, NOW())
ON CONFLICT (id) DO UPDATE
SET username = EXCLUDED.username,
avatar = EXCLUDED.avatar,
updated_at = NOW()",
)
.bind(user_id)
.bind(&user_data.username)
.bind(&user_data.avatar)
.execute(pool)
.await
.map_err(|e| {
log::error!("Failed to upsert user: {e}");
err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
})?;
// Create and insert the session
let session = Session::new(user_id as u64, user_data.username.clone());
session
.insert()
.await
.map_err(|e| (e, Some(ui_redirect_uri.clone())))?;
let issued_at = chrono::Utc::now();
let claims = BearerTokenClaims {
sub: session.user_id,
name: session.user_name.clone(),
iat: issued_at.timestamp(),
exp: session.expires_at.timestamp(),
jti: session.session_id.clone(),
};
let jwt_secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
let encoding_key = jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes());
let token = jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &encoding_key)
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
Ok((token, ui_redirect_uri))
}

View File

@@ -0,0 +1,107 @@
use crate::{
auth::{bearer_token::BearerTokenClaims, session::Session},
error::Result,
};
use axum::{
extract::FromRequestParts,
http::{Method, StatusCode, request::Parts},
};
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use chrono::Utc;
use jsonwebtoken::{DecodingKey, Validation, decode};
// ---------------------------------------------------------------------------
// AuthorizationMiddleware — rejects unauthenticated requests
// ---------------------------------------------------------------------------
pub struct AuthorizationMiddleware;
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> std::result::Result<Self, Self::Rejection> {
// For options requests browsers will not send the authorization header.
if parts.method == Method::OPTIONS {
return Ok(Self);
}
// Check for a Bearer token in the `Authorization` header.
if let Ok(TypedHeader(Authorization(bearer))) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
{
return match check_bearer_auth(bearer.token()).await {
Ok(session) => {
parts.extensions.insert(session);
Ok(Self)
}
Err(_) => Err(StatusCode::UNAUTHORIZED),
};
}
Err(StatusCode::UNAUTHORIZED)
}
}
// ---------------------------------------------------------------------------
// OptionalAuth — extracts a Session if present, otherwise None
// ---------------------------------------------------------------------------
/// Wraps an optional authenticated session.
/// Handlers that use this extractor work for both authenticated and
/// unauthenticated callers; callers with a valid Bearer token get a `Some(session)`.
pub struct OptionalAuth(pub Option<Session>);
impl<S> FromRequestParts<S> for OptionalAuth
where
S: Send + Sync,
{
type Rejection = std::convert::Infallible;
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> std::result::Result<Self, Self::Rejection> {
if let Ok(TypedHeader(Authorization(bearer))) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
{
if let Ok(session) = check_bearer_auth(bearer.token()).await {
parts.extensions.insert(session.clone());
return Ok(Self(Some(session)));
}
}
Ok(Self(None))
}
}
// ---------------------------------------------------------------------------
// Shared helper
// ---------------------------------------------------------------------------
pub async fn check_bearer_auth(bearer_token: &str) -> Result<Session> {
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
let token_data = decode::<BearerTokenClaims>(bearer_token, &decoding_key, &Validation::default())
.map_err(|_| StatusCode::UNAUTHORIZED)?;
let claims = token_data.claims;
let now = Utc::now().timestamp();
if claims.exp < now {
return Err(StatusCode::UNAUTHORIZED.into());
}
match Session::find(&claims.jti).await {
Ok(Some(session)) => Ok(session),
_ => Err(StatusCode::UNAUTHORIZED)?,
}
}

View File

@@ -0,0 +1,24 @@
use crate::AppState;
use axum::Router;
use rand::RngExt;
use std::sync::Arc;
mod discord;
mod session;
pub use session::Session;
mod bearer_token;
pub mod middleware;
pub use middleware::{AuthorizationMiddleware, OptionalAuth};
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new().nest("/discord", discord::get_routes())
}
pub fn csprng(take: usize) -> String {
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
rand::rng()
.sample_iter(rand::distr::Alphanumeric)
.take(take)
.map(char::from)
.collect()
}

View File

@@ -0,0 +1,68 @@
use crate::{auth::csprng, error::Result};
use chrono::{DateTime, Utc};
use redis::{AsyncCommands, RedisResult};
use serde::{Deserialize, Serialize};
use siren_core::data;
use std::{env, sync::OnceLock};
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
fn get_session_ttl() -> i64 {
// Initialize the SESSION_TTL value lazily
*SESSION_TTL.get_or_init(|| {
env::var("API_SESSION_TTL")
.ok()
.and_then(|val| val.parse::<i64>().ok())
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
})
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Session {
pub session_id: String,
pub user_id: u64,
pub user_name: String,
pub expires_at: DateTime<Utc>,
}
impl Session {
pub fn new(user_id: u64, user_name: String) -> Session {
let now = Utc::now();
let session_ttl = get_session_ttl();
Session {
session_id: csprng(32),
user_id,
user_name,
expires_at: now + chrono::Duration::seconds(session_ttl),
}
}
pub async fn insert(&self) -> Result<()> {
let mut redis = data::redis_async_connection().await?;
let session_id = self.session_id.clone();
let session_ttl = get_session_ttl();
redis
.set_ex::<_, _, ()>(session_id, serde_json::to_string(self)?, session_ttl as u64)
.await?;
Ok(())
}
pub async fn find(session_id: &str) -> Result<Option<Session>> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<Option<String>> = redis.get(session_id).await;
match result {
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
Ok(None) => Ok(None),
Err(err) => Err(err.into()),
}
}
pub async fn delete(session_id: &str) -> Result<()> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<()> = redis.del(session_id).await;
match result {
Ok(_) => Ok(()),
Err(err) => Err(err.into()),
}
}
}

View File

@@ -0,0 +1,201 @@
use crate::{
AppState,
auth::{AuthorizationMiddleware, Session},
error::{Error, Result},
};
use axum::{
Extension,
Json,
Router,
extract::{Path, State},
middleware::from_extractor,
routing::post,
};
use serde::{Deserialize, Serialize};
use siren_bot::commands::fun::roll::{format_roll, parse_dice};
use siren_core::data::{ExecutableQuery, Value, condition::Condition, query::QueryBuilder};
use std::{fmt::Display, str::FromStr, sync::Arc};
use uuid::Uuid;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/{guild_id}/track", post(add_track_dice))
.route_layer(from_extractor::<AuthorizationMiddleware>())
}
const TABLE_NAME: &str = "dice_track";
#[derive(Serialize, Deserialize, Clone, Debug)]
enum TrackDiceOperator {
#[serde(rename = "eq")]
Equal,
#[serde(rename = "lt")]
LessThan,
#[serde(rename = "lte")]
LessThanEqual,
#[serde(rename = "gt")]
GreaterThan,
#[serde(rename = "gte")]
GreaterThanEqual,
}
// Implementing the ToString trait for converting the enum to a string
impl Display for TrackDiceOperator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let str = match self {
TrackDiceOperator::Equal => "eq".to_string(),
TrackDiceOperator::LessThan => "lt".to_string(),
TrackDiceOperator::LessThanEqual => "lte".to_string(),
TrackDiceOperator::GreaterThan => "gt".to_string(),
TrackDiceOperator::GreaterThanEqual => "gte".to_string(),
};
write!(f, "{}", str)
}
}
// Implementing the FromStr trait for parsing a string into the enum
impl FromStr for TrackDiceOperator {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"eq" => Ok(TrackDiceOperator::Equal),
"lt" => Ok(TrackDiceOperator::LessThan),
"lte" => Ok(TrackDiceOperator::LessThanEqual),
"gt" => Ok(TrackDiceOperator::GreaterThan),
"gte" => Ok(TrackDiceOperator::GreaterThanEqual),
_ => Err(format!("Unknown value for TrackDiceOperator: {}", s)),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct DiceTrackPayload {
dice: String,
user_id: Option<i64>,
value: Option<i32>,
operator: Option<TrackDiceOperator>,
}
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
pub struct InsertDiceTrack {
guild_id: i64,
owner_id: i64,
dice: String,
user_id: Option<i64>,
value: Option<i32>,
operator: Option<String>,
}
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
pub struct QueryDiceTrack {
id: Uuid,
guild_id: i64,
owner_id: i64,
dice: String,
user_id: Option<i64>,
value: Option<i32>,
operator: Option<String>,
}
impl QueryDiceTrack {
pub async fn find(dice: &InsertDiceTrack) -> Option<Self> {
QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::and(
Condition::is_equal("guild_id", Value::BigInt(dice.guild_id)),
Condition::and(
Condition::is_equal("owner_id", Value::BigInt(dice.owner_id)),
Condition::and(
Condition::is_equal("dice", Value::Text(dice.dice.clone())),
Condition::and(
Condition::is_equal("user_id", Value::OptionalBigInt(dice.user_id)),
Condition::and(
Condition::is_equal("value", Value::OptionalInt(dice.value)),
Condition::is_equal("operator", Value::OptionalText(dice.operator.clone())),
),
),
),
),
))
.fetch_optional()
.await
}
}
impl InsertDiceTrack {
pub async fn insert(&self) -> Result<QueryDiceTrack> {
let pool = siren_core::data::pool();
let query = format!(
"INSERT INTO {} (
guild_id,
owner_id,
dice,
user_id,
value,
operator
) VALUES (
$1, $2, $3, $4, $5, $6
) RETURNING *",
TABLE_NAME
);
let item: QueryDiceTrack = match sqlx::query_as(&query)
.bind(self.guild_id)
.bind(self.owner_id)
.bind(&self.dice)
.bind(self.user_id)
.bind(self.value)
.bind(&self.operator)
.fetch_optional(pool)
.await?
{
Some(result) => result,
None => return Err(Error::new(500, "Error storing".to_string())),
};
Ok(item)
}
}
pub async fn add_track_dice(
Extension(session): Extension<Session>,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
Json(payload): Json<DiceTrackPayload>,
) -> Result<Json<QueryDiceTrack>> {
// Check if the user exists in the cache
let owner_id = session.user_id;
let owner_id = match state.cache.user(owner_id) {
Some(user) => user.id,
None => return Err(Error::not_found("User not found".to_string())),
};
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
let dice = parse_dice(&payload.dice)?;
let insert_dice = InsertDiceTrack {
guild_id: guild_id.get() as i64,
owner_id: owner_id.get() as i64,
dice: format_roll(dice.0, dice.1, dice.2),
user_id: payload.user_id,
value: payload.value,
operator: match payload.operator {
None => None,
Some(s) => Some(s.to_string()),
},
};
// Check for existing dice tracks
let results = QueryDiceTrack::find(&insert_dice).await;
match results {
Some(dice_track) => Ok(Json(dice_track)),
None => {
let dice_track = insert_dice.insert().await?;
Ok(Json(dice_track))
}
}
}

View File

@@ -0,0 +1,128 @@
use axum::{
Json,
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::{Deserialize, Serialize};
use std::fmt;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Deserialize, Serialize)]
pub struct Error {
pub status: u16,
pub details: String,
}
impl Error {
pub fn new(status: u16, details: String) -> Self {
Self { status, details }
}
pub fn not_found(details: String) -> Self {
Self::new(404, details)
}
pub fn internal_server_error(details: String) -> Self {
Self::new(500, details)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(self.details.as_str())
}
}
impl std::error::Error for Error {
fn description(&self) -> &str {
&self.details
}
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let body = Json(serde_json::json!({
"error": {
"status": self.status,
"details": self.details,
}
}));
(status, body).into_response()
}
}
// --- Conversions from upstream crate errors ---
impl From<siren_core::error::Error> for Error {
fn from(error: siren_core::error::Error) -> Self {
Self::new(error.status, error.details)
}
}
impl From<siren_bot::error::Error> for Error {
fn from(error: siren_bot::error::Error) -> Self {
Self::new(error.status, error.details)
}
}
// --- Conversions from external crate errors ---
impl From<StatusCode> for Error {
fn from(status: StatusCode) -> Self {
Error {
status: status.as_u16(),
details: status
.canonical_reason()
.unwrap_or("Unknown error")
.to_string(),
}
}
}
impl From<reqwest::Error> for Error {
fn from(error: reqwest::Error) -> Self {
Self::new(500, format!("HTTP client error: {}", error))
}
}
impl From<serde_json::Error> for Error {
fn from(error: serde_json::Error) -> Self {
Self::new(500, format!("JSON error: {}", error))
}
}
impl From<jsonwebtoken::errors::Error> for Error {
fn from(error: jsonwebtoken::errors::Error) -> Self {
match error.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
Self::new(401, "Token expired".to_string())
}
jsonwebtoken::errors::ErrorKind::InvalidToken => Self::new(401, "Invalid token".to_string()),
_ => Self::new(500, format!("JWT error: {}", error)),
}
}
}
// Direct conversions for types used in API handlers that bypass the data abstraction layer
impl From<sqlx::Error> for Error {
fn from(error: sqlx::Error) -> Self {
let core_err: siren_core::error::Error = error.into();
core_err.into()
}
}
impl From<redis::RedisError> for Error {
fn from(error: redis::RedisError) -> Self {
let core_err: siren_core::error::Error = error.into();
core_err.into()
}
}
impl From<std::io::Error> for Error {
fn from(error: std::io::Error) -> Self {
Self::new(500, format!("IO error: {}", error))
}
}

View File

@@ -0,0 +1,619 @@
pub mod model;
use crate::{
AppState,
auth::{OptionalAuth, Session, csprng, middleware::check_bearer_auth},
error::{Error, Result},
};
use axum::{
Json,
Router,
extract::{
Path,
Query,
State,
WebSocketUpgrade,
ws::{Message, WebSocket},
},
http::StatusCode,
response::IntoResponse,
routing::{delete, get, post, put},
};
use futures_util::{SinkExt, StreamExt};
use model::{
ClientMessage,
CreateMapPayload,
GridCell,
GridMap,
GridToken,
MapPermission,
MapRole,
MapState,
ServerMessage,
UpdatePermissionPayload,
};
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::broadcast;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/maps", get(list_maps))
.route("/maps", post(create_map))
.route("/maps/{id}", get(get_map))
.route("/maps/{id}", delete(delete_map))
.route("/maps/{id}/permissions", get(list_permissions))
.route("/maps/{id}/permissions", put(update_permission))
.route("/maps/{id}/ws", get(ws_handler))
}
// ---------------------------------------------------------------------------
// Permission helpers
// ---------------------------------------------------------------------------
/// Fetch the role of `user_id` on `map_id`, or `None` if no record exists.
async fn get_user_role(map_id: &str, user_id: i64) -> crate::error::Result<Option<MapRole>> {
let pool = siren_core::data::pool();
let perm: Option<MapPermission> = sqlx::query_as(
"SELECT map_id, user_id, role FROM map_permissions WHERE map_id = $1 AND user_id = $2",
)
.bind(map_id)
.bind(user_id)
.fetch_optional(pool)
.await?;
Ok(perm.map(|p| p.role))
}
/// Returns whether the caller can view the map:
/// - Public maps: always true.
/// - Private maps: true only if the user has any role.
async fn can_view(map: &GridMap, session: &Option<Session>) -> bool {
if map.is_public {
return true;
}
let Some(s) = session else { return false };
let user_id = s.user_id as i64;
get_user_role(&map.id, user_id)
.await
.ok()
.flatten()
.is_some()
}
/// Returns whether the caller can edit the map (editor or owner role).
async fn can_edit(map: &GridMap, session: &Option<Session>) -> bool {
let Some(s) = session else { return false };
let user_id = s.user_id as i64;
get_user_role(&map.id, user_id)
.await
.ok()
.flatten()
.map(|r| r.can_edit())
.unwrap_or(false)
}
/// Returns whether the caller is the owner.
async fn is_owner(map: &GridMap, session: &Option<Session>) -> bool {
let Some(s) = session else { return false };
let user_id = s.user_id as i64;
get_user_role(&map.id, user_id)
.await
.ok()
.flatten()
.map(|r| r.is_owner())
.unwrap_or(false)
}
// ---------------------------------------------------------------------------
// REST handlers
// ---------------------------------------------------------------------------
pub async fn list_maps(OptionalAuth(session): OptionalAuth) -> Result<Json<Vec<GridMap>>> {
let pool = siren_core::data::pool();
let maps: Vec<GridMap> = match &session {
Some(s) => {
let user_id = s.user_id as i64;
sqlx::query_as(
"SELECT DISTINCT gm.*
FROM grid_maps gm
LEFT JOIN map_permissions mp ON mp.map_id = gm.id AND mp.user_id = $1
WHERE gm.is_public = TRUE OR mp.user_id IS NOT NULL
ORDER BY gm.created_at DESC",
)
.bind(user_id)
.fetch_all(pool)
.await?
}
None => {
sqlx::query_as("SELECT * FROM grid_maps WHERE is_public = TRUE ORDER BY created_at DESC")
.fetch_all(pool)
.await?
}
};
Ok(Json(maps))
}
pub async fn create_map(
OptionalAuth(session): OptionalAuth,
Json(payload): Json<CreateMapPayload>,
) -> Result<(StatusCode, Json<GridMap>)> {
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
let user_id = session.user_id as i64;
let map_id = csprng(32);
let pool = siren_core::data::pool();
let map: GridMap = sqlx::query_as(
"INSERT INTO grid_maps (id, name, is_public, owner_id)
VALUES ($1, $2, $3, $4)
RETURNING *",
)
.bind(&map_id)
.bind(&payload.name)
.bind(payload.is_public)
.bind(user_id)
.fetch_one(pool)
.await?;
// Auto-assign the creator as owner in map_permissions
sqlx::query("INSERT INTO map_permissions (map_id, user_id, role) VALUES ($1, $2, 'owner')")
.bind(&map_id)
.bind(user_id)
.execute(pool)
.await?;
Ok((StatusCode::CREATED, Json(map)))
}
pub async fn get_map(
OptionalAuth(session): OptionalAuth,
Path(id): Path<String>,
) -> Result<Json<MapState>> {
let pool = siren_core::data::pool();
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
.bind(&id)
.fetch_optional(pool)
.await?;
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
if !can_view(&map, &session).await {
return Err(StatusCode::FORBIDDEN.into());
}
let cells: Vec<GridCell> = sqlx::query_as("SELECT * FROM grid_cells WHERE map_id = $1")
.bind(&id)
.fetch_all(pool)
.await?;
let tokens: Vec<GridToken> = sqlx::query_as("SELECT * FROM grid_tokens WHERE map_id = $1")
.bind(&id)
.fetch_all(pool)
.await?;
Ok(Json(MapState { map, cells, tokens }))
}
pub async fn delete_map(
OptionalAuth(session): OptionalAuth,
Path(id): Path<String>,
) -> Result<StatusCode> {
let pool = siren_core::data::pool();
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
.bind(&id)
.fetch_optional(pool)
.await?;
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
if !is_owner(&map, &session).await {
return Err(StatusCode::FORBIDDEN.into());
}
sqlx::query("DELETE FROM grid_maps WHERE id = $1")
.bind(&id)
.execute(pool)
.await?;
Ok(StatusCode::NO_CONTENT)
}
// ---------------------------------------------------------------------------
// Permission management
// ---------------------------------------------------------------------------
pub async fn list_permissions(
OptionalAuth(session): OptionalAuth,
Path(id): Path<String>,
) -> Result<Json<Vec<MapPermission>>> {
let pool = siren_core::data::pool();
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
.bind(&id)
.fetch_optional(pool)
.await?;
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
if !is_owner(&map, &session).await {
return Err(StatusCode::FORBIDDEN.into());
}
let perms: Vec<MapPermission> =
sqlx::query_as("SELECT map_id, user_id, role FROM map_permissions WHERE map_id = $1")
.bind(&id)
.fetch_all(pool)
.await?;
Ok(Json(perms))
}
pub async fn update_permission(
OptionalAuth(session): OptionalAuth,
Path(id): Path<String>,
Json(payload): Json<UpdatePermissionPayload>,
) -> Result<StatusCode> {
let pool = siren_core::data::pool();
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
.bind(&id)
.fetch_optional(pool)
.await?;
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
if !is_owner(&map, &session).await {
return Err(StatusCode::FORBIDDEN.into());
}
// Prevent the owner from removing their own owner record
let caller_id = session.as_ref().map(|s| s.user_id as i64).unwrap_or(0);
if payload.user_id == caller_id && payload.role.as_ref().map(|r| r.is_owner()) == Some(false) {
return Err(Error::from(StatusCode::UNPROCESSABLE_ENTITY));
}
match payload.role {
Some(role) => {
sqlx::query(
"INSERT INTO map_permissions (map_id, user_id, role)
VALUES ($1, $2, $3)
ON CONFLICT (map_id, user_id) DO UPDATE SET role = EXCLUDED.role",
)
.bind(&id)
.bind(payload.user_id)
.bind(role)
.execute(pool)
.await?;
}
None => {
sqlx::query("DELETE FROM map_permissions WHERE map_id = $1 AND user_id = $2")
.bind(&id)
.bind(payload.user_id)
.execute(pool)
.await?;
}
}
Ok(StatusCode::NO_CONTENT)
}
// ---------------------------------------------------------------------------
// WebSocket handler
// ---------------------------------------------------------------------------
#[derive(Deserialize)]
pub struct WsQuery {
/// Optional Bearer token passed as a query parameter for WS auth.
token: Option<String>,
}
pub async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
Path(map_id): Path<String>,
Query(query): Query<WsQuery>,
) -> impl IntoResponse {
// Resolve the session from query param (WS can't easily send headers)
let session: Option<Session> = match query.token {
Some(ref tok) => check_bearer_auth(tok).await.ok(),
None => None,
};
ws.on_upgrade(move |socket| handle_socket(socket, state, map_id, session))
}
async fn handle_socket(
socket: WebSocket,
state: Arc<AppState>,
map_id: String,
session: Option<Session>,
) {
// Load the map and verify the caller can view it
let map_state = match fetch_map_state(&map_id).await {
Ok(ms) => ms,
Err(_) => return, // map doesn't exist
};
if !can_view(&map_state.map, &session).await {
// Refuse the connection silently (upgrade already happened; just close)
return;
}
let editor = can_edit(&map_state.map, &session).await;
// Get or create a broadcast channel for this map
let tx = state
.map_rooms
.entry(map_id.clone())
.or_insert_with(|| {
let (tx, _) = broadcast::channel(256);
tx
})
.clone();
let mut rx = tx.subscribe();
let (mut ws_tx, mut ws_rx) = socket.split();
// Send the current full map state to the newly connected client
let init_msg = ServerMessage::State {
cells: map_state.cells,
tokens: map_state.tokens,
colors: map_state.map.colors,
};
if let Ok(json) = serde_json::to_string(&init_msg) {
let _ = ws_tx.send(Message::Text(json.into())).await;
}
// Task 1: forward broadcast messages to this socket
let mut send_task = tokio::spawn(async move {
while let Ok(json) = rx.recv().await {
if ws_tx.send(Message::Text(json.into())).await.is_err() {
break;
}
}
});
// Task 2: receive messages from this client, persist, and broadcast
let tx_clone = tx.clone();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = ws_rx.next().await {
match msg {
Message::Text(text) => {
handle_client_message(&text, &map_id, editor, &tx_clone).await;
}
Message::Close(_) => break,
_ => {}
}
}
});
tokio::select! {
_ = &mut send_task => recv_task.abort(),
_ = &mut recv_task => send_task.abort(),
}
}
async fn fetch_map_state(map_id: &str) -> crate::error::Result<MapState> {
let pool = siren_core::data::pool();
let map: GridMap = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
.bind(map_id)
.fetch_one(pool)
.await?;
let cells: Vec<GridCell> = sqlx::query_as("SELECT * FROM grid_cells WHERE map_id = $1")
.bind(map_id)
.fetch_all(pool)
.await?;
let tokens: Vec<GridToken> = sqlx::query_as("SELECT * FROM grid_tokens WHERE map_id = $1")
.bind(map_id)
.fetch_all(pool)
.await?;
Ok(MapState { map, cells, tokens })
}
async fn handle_client_message(
raw: &str,
map_id: &str,
can_edit: bool,
tx: &broadcast::Sender<String>,
) {
let client_msg: ClientMessage = match serde_json::from_str(raw) {
Ok(m) => m,
Err(e) => {
log::warn!("Invalid WS message: {e}");
return;
}
};
// All mutating messages require editor or owner role
if !can_edit {
let err = ServerMessage::Error {
message: "You do not have permission to edit this map.".into(),
};
if let Ok(json) = serde_json::to_string(&err) {
let _ = tx.send(json);
}
return;
}
let pool = siren_core::data::pool();
let server_msg: Option<ServerMessage> = match client_msg {
ClientMessage::PaintCell { x, y, color } => {
let result = sqlx::query(
"INSERT INTO grid_cells (map_id, x, y, color)
VALUES ($1, $2, $3, $4)
ON CONFLICT (map_id, x, y) DO UPDATE SET color = EXCLUDED.color",
)
.bind(map_id)
.bind(x)
.bind(y)
.bind(&color)
.execute(pool)
.await;
match result {
Ok(_) => Some(ServerMessage::CellPainted { x, y, color }),
Err(e) => {
log::error!("DB error painting cell: {e}");
None
}
}
}
ClientMessage::PaintCells { cells } => {
let mut tx_db = match pool.begin().await {
Ok(t) => t,
Err(e) => {
log::error!("DB error starting transaction for batch paint: {e}");
return;
}
};
let mut ok = true;
for cell in &cells {
let res = sqlx::query(
"INSERT INTO grid_cells (map_id, x, y, color)
VALUES ($1, $2, $3, $4)
ON CONFLICT (map_id, x, y) DO UPDATE SET color = EXCLUDED.color",
)
.bind(map_id)
.bind(cell.x)
.bind(cell.y)
.bind(&cell.color)
.execute(&mut *tx_db)
.await;
if let Err(e) = res {
log::error!("DB error in batch paint cell ({},{}): {e}", cell.x, cell.y);
ok = false;
break;
}
}
if ok {
if let Err(e) = tx_db.commit().await {
log::error!("DB error committing batch paint: {e}");
None
} else {
Some(ServerMessage::CellsBatchPainted { cells })
}
} else {
let _ = tx_db.rollback().await;
None
}
}
ClientMessage::EraseCell { x, y } => {
let result = sqlx::query("DELETE FROM grid_cells WHERE map_id = $1 AND x = $2 AND y = $3")
.bind(map_id)
.bind(x)
.bind(y)
.execute(pool)
.await;
match result {
Ok(_) => Some(ServerMessage::CellErased { x, y }),
Err(e) => {
log::error!("DB error erasing cell: {e}");
None
}
}
}
ClientMessage::AddToken { x, y, label, color } => {
let token_id = csprng(16);
let result: sqlx::Result<GridToken> = sqlx::query_as(
"INSERT INTO grid_tokens (id, map_id, x, y, label, color)
VALUES ($1, $2, $3, $4, $5, $6) RETURNING *",
)
.bind(&token_id)
.bind(map_id)
.bind(x)
.bind(y)
.bind(&label)
.bind(&color)
.fetch_one(pool)
.await;
match result {
Ok(token) => Some(ServerMessage::TokenAdded {
id: token.id,
x: token.x,
y: token.y,
label: token.label,
color: token.color,
}),
Err(e) => {
log::error!("DB error adding token: {e}");
None
}
}
}
ClientMessage::MoveToken { id, x, y } => {
let result =
sqlx::query("UPDATE grid_tokens SET x = $1, y = $2 WHERE id = $3 AND map_id = $4")
.bind(x)
.bind(y)
.bind(&id)
.bind(map_id)
.execute(pool)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => Some(ServerMessage::TokenMoved { id, x, y }),
Ok(_) => None,
Err(e) => {
log::error!("DB error moving token: {e}");
None
}
}
}
ClientMessage::DeleteToken { id } => {
let result = sqlx::query("DELETE FROM grid_tokens WHERE id = $1 AND map_id = $2")
.bind(&id)
.bind(map_id)
.execute(pool)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => Some(ServerMessage::TokenDeleted { id }),
Ok(_) => None,
Err(e) => {
log::error!("DB error deleting token: {e}");
None
}
}
}
ClientMessage::UpdateColors { colors } => {
let result =
sqlx::query("UPDATE grid_maps SET colors = $1, updated_at = NOW() WHERE id = $2")
.bind(&colors)
.bind(map_id)
.execute(pool)
.await;
match result {
Ok(_) => Some(ServerMessage::ColorsUpdated { colors }),
Err(e) => {
log::error!("DB error updating colors: {e}");
None
}
}
}
};
if let Some(msg) = server_msg {
if let Ok(json) = serde_json::to_string(&msg) {
let _ = tx.send(json);
}
}
}

View File

@@ -0,0 +1,190 @@
use chrono::NaiveDateTime;
use serde::{Deserialize, Serialize};
// ---------------------------------------------------------------------------
// Map Role / Permission
// ---------------------------------------------------------------------------
#[derive(Serialize, Deserialize, sqlx::Type, Clone, Debug, PartialEq, Eq)]
#[sqlx(type_name = "text", rename_all = "lowercase")]
#[serde(rename_all = "lowercase")]
pub enum MapRole {
Owner,
Editor,
Viewer,
}
impl MapRole {
/// Returns true if this role can mutate map content (paint, tokens, colors).
pub fn can_edit(&self) -> bool {
matches!(self, MapRole::Owner | MapRole::Editor)
}
/// Returns true if this role can manage permissions and delete the map.
pub fn is_owner(&self) -> bool {
matches!(self, MapRole::Owner)
}
}
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
pub struct MapPermission {
pub map_id: String,
pub user_id: i64,
pub role: MapRole,
}
// ---------------------------------------------------------------------------
// Grid Map
// ---------------------------------------------------------------------------
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
pub struct GridMap {
pub id: String,
pub name: String,
pub is_public: bool,
pub owner_id: i64,
pub colors: Vec<String>,
pub created_at: NaiveDateTime,
pub updated_at: NaiveDateTime,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct CreateMapPayload {
pub name: String,
#[serde(default)]
pub is_public: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct UpdatePermissionPayload {
/// Discord user ID of the target user.
pub user_id: i64,
/// New role to assign. Omit (null) to remove the permission entry.
pub role: Option<MapRole>,
}
// ---------------------------------------------------------------------------
// Grid Cell (no id column — composite PK in DB)
// ---------------------------------------------------------------------------
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
pub struct GridCell {
pub map_id: String,
pub x: i32,
pub y: i32,
pub color: String,
}
/// Lightweight cell used for batch operations.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct CellPatch {
pub x: i32,
pub y: i32,
pub color: String,
}
// ---------------------------------------------------------------------------
// Grid Token
// ---------------------------------------------------------------------------
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
pub struct GridToken {
pub id: String,
pub map_id: String,
pub x: i32,
pub y: i32,
pub label: String,
pub color: String,
}
// ---------------------------------------------------------------------------
// Full map state (used on initial WS connect and REST GET)
// ---------------------------------------------------------------------------
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct MapState {
pub map: GridMap,
pub cells: Vec<GridCell>,
pub tokens: Vec<GridToken>,
}
// ---------------------------------------------------------------------------
// WebSocket message types
// ---------------------------------------------------------------------------
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientMessage {
PaintCell {
x: i32,
y: i32,
color: String,
},
PaintCells {
cells: Vec<CellPatch>,
},
EraseCell {
x: i32,
y: i32,
},
AddToken {
x: i32,
y: i32,
label: String,
color: String,
},
MoveToken {
id: String,
x: i32,
y: i32,
},
DeleteToken {
id: String,
},
UpdateColors {
colors: Vec<String>,
},
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerMessage {
State {
cells: Vec<GridCell>,
tokens: Vec<GridToken>,
colors: Vec<String>,
},
CellPainted {
x: i32,
y: i32,
color: String,
},
CellsBatchPainted {
cells: Vec<CellPatch>,
},
CellErased {
x: i32,
y: i32,
},
TokenAdded {
id: String,
x: i32,
y: i32,
label: String,
color: String,
},
TokenMoved {
id: String,
x: i32,
y: i32,
},
TokenDeleted {
id: String,
},
ColorsUpdated {
colors: Vec<String>,
},
Error {
message: String,
},
}

View File

@@ -0,0 +1,20 @@
pub mod app;
mod app_state;
pub mod audio;
pub mod auth;
pub mod dice;
pub mod error;
pub mod grid;
pub use app::App;
pub use app_state::AppState;
use axum::Router;
use std::sync::Arc;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.nest("/auth", auth::get_routes())
.nest("/audio/{guild_id}", audio::get_routes())
.nest("/dice", dice::get_routes())
.nest("/grid", grid::get_routes())
}

View File

@@ -0,0 +1,22 @@
[package]
name = "siren-bot"
edition.workspace = true
version.workspace = true
rust-version.workspace = true
authors.workspace = true
[dependencies]
siren-core = { workspace = true }
tokio = { workspace = true }
log = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serenity = { workspace = true }
songbird = { workspace = true }
symphonia = { workspace = true }
reqwest = { workspace = true }
rand = { workspace = true }
chrono = { workspace = true }
regex = { workspace = true }
uuid = { workspace = true }
lazy_static = { workspace = true }

View File

@@ -0,0 +1,67 @@
use serenity::all::{
CommandInteraction,
Context,
CreateInteractionResponse,
CreateInteractionResponseMessage,
CreateMessage,
EditInteractionResponse,
InteractionResponseFlags,
Message,
ModalInteraction,
UserId,
};
pub async fn process_message(ctx: &Context, command: &CommandInteraction, private: bool) {
create_message_response(&ctx, &command, "Processing...".to_string(), private).await;
}
pub async fn user_dm(ctx: &Context, user_id: &UserId, content: String) -> Option<Message> {
let data = CreateMessage::new().content(content.to_owned());
match user_id.dm(ctx, data).await {
Ok(message) => Some(message),
Err(err) => {
log::error!("Failed to create direct message for {content}\n{err}");
None
}
}
}
pub async fn create_message_response(
ctx: &Context,
command: &CommandInteraction,
content: String,
private: bool,
) {
let mut data = CreateInteractionResponseMessage::new().content(content.to_owned());
if private {
data = data.flags(InteractionResponseFlags::EPHEMERAL);
}
let builder = CreateInteractionResponse::Message(data);
match command.create_response(&ctx.http, builder).await {
Ok(_) => {}
Err(err) => {
log::error!("Failed to create message response for {content}\n{err}");
}
};
}
pub async fn create_modal_response(ctx: &Context, modal: &ModalInteraction) {
let data = CreateInteractionResponseMessage::new();
let builder = CreateInteractionResponse::Message(data);
match modal.create_response(&ctx.http, builder).await {
Ok(_) => {}
Err(err) => {
log::error!("Failed to create modal response\n{err}");
}
}
}
pub async fn edit_response(ctx: &Context, command: &CommandInteraction, content: String) {
let builder = EditInteractionResponse::new().content(content.to_owned());
match command.edit_response(&ctx.http, builder).await {
Ok(_) => {}
Err(err) => {
log::error!("Failed to create response for {content}\n{err}");
}
}
}

View File

@@ -0,0 +1,89 @@
use crate::error::{Error, Result};
use reqwest::Url;
use serenity::{
all::UserId,
client::Cache,
model::prelude::{ChannelId, GuildId},
};
use songbird::Songbird;
use std::sync::Arc;
pub mod mute;
pub mod pause;
pub mod play;
pub mod resume;
pub mod skip;
pub mod stop;
pub mod volume;
/**
* Finds a voice channel that the user is currently in, and attempts to join it.
*/
pub async fn join_voice_channel(
cache: &Arc<Cache>,
manager: &Arc<Songbird>,
guild_id: &GuildId,
user_id: &UserId,
) -> Result<ChannelId> {
let channel_id = find_voice_channel(cache, guild_id, user_id)?;
log::debug!("<{}> Joining channel {}", guild_id.get(), channel_id.get());
match manager
.join(guild_id.to_owned(), channel_id.to_owned())
.await
{
Ok(_) => Ok(channel_id),
Err(e) => {
if e.should_leave_server() || e.should_reconnect_driver() {
log::debug!("<{}> Cleaning up failed voice connection", guild_id.get());
let _ = manager.remove(*guild_id).await;
}
Err(e.into())
}
}
}
/**
* Leaves a voice channel.
*/
pub async fn leave_voice_channel(manager: &Arc<Songbird>, guild_id: &GuildId) -> Result<()> {
if manager.get(guild_id.to_owned()).is_some() {
log::debug!("<{}> Disconnecting from channel", guild_id.get());
manager.remove(*guild_id).await?;
}
Ok(())
}
/**
* Validates whether the given string is a properly formatted URL.
*
* Returns `true` if the input string is a valid URL, otherwise `false`.
*/
fn is_valid_url(url: &str) -> bool {
Url::parse(url).is_ok()
}
/**
* Finds a voice channel that the user is currently in.
*/
fn find_voice_channel(
cache: &Arc<Cache>,
guild_id: &GuildId,
user_id: &UserId,
) -> Result<ChannelId> {
let guild = match guild_id.to_guild_cached(cache) {
Some(g) => g,
None => return Err(Error::new(404, "Guild not found".to_string())),
};
match guild
.voice_states
.get(&user_id)
.and_then(|voice_state| voice_state.channel_id)
{
Some(channel) => Ok(channel),
None => Err(Error::new(
400,
"User is not in a voice channel".to_string(),
)),
}
}

View File

@@ -0,0 +1,54 @@
use crate::{
chat::{edit_response, process_message},
handler::get_songbird,
};
use serenity::{
all::{CommandInteraction, CreateCommand},
prelude::*,
};
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Create the initial response
process_message(&ctx, &command, false).await;
// Get the songbird manager
let manager = get_songbird();
// Extract the guild ID
let guild_id = match &command.guild_id {
Some(guild_id) => guild_id,
None => {
edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return;
}
};
// Mute the track
if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
let mut handler = handler_lock.lock().await;
let is_muted = handler.is_mute();
match handler.mute(!is_muted).await {
Ok(_) => {
if is_muted {
log::debug!("<{guild_id}> Unmuted");
edit_response(&ctx, &command, "Unmuted".to_string()).await;
} else {
log::debug!("<{guild_id}> Muted");
edit_response(&ctx, &command, "Muted".to_string()).await;
}
}
Err(err) => {
edit_response(&ctx, &command, format!("Failed to mute: {}", err)).await;
}
}
}
}
pub fn register() -> CreateCommand {
CreateCommand::new("mute").description("Mute/unmute Siren")
}

View File

@@ -0,0 +1,63 @@
use crate::{
chat::{edit_response, process_message},
error::{Error, Result},
handler::get_songbird,
};
use serenity::{
all::{CommandInteraction, CreateCommand, GuildId},
prelude::*,
};
use songbird::Songbird;
use std::sync::Arc;
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Create the initial response
process_message(&ctx, &command, false).await;
// Get the songbird manager
let manager = get_songbird();
// Extract the guild ID
let guild_id = match &command.guild_id {
Some(guild_id) => guild_id,
None => {
edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return;
}
};
// Pause the track
match pause_track(manager, guild_id).await {
Ok(_) => {
log::debug!("<{guild_id}> Paused the track");
edit_response(&ctx, &command, "Pausing the track".to_string()).await;
}
Err(err) => edit_response(&ctx, &command, format!("Failed to pause: {}", err)).await,
}
}
pub async fn pause_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> Result<()> {
if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
let handler = handler_lock.lock().await;
match handler.queue().current() {
Some(track) => track.pause()?,
None => {
return Err(Error {
status: 404,
details: "No track is currently playing".to_string(),
});
}
}
};
Ok(())
}
pub fn register() -> CreateCommand {
CreateCommand::new("pause").description("Pause the current track")
}

View File

@@ -0,0 +1,218 @@
use super::{is_valid_url, join_voice_channel, leave_voice_channel};
use crate::{
chat::{create_message_response, edit_response, process_message},
error::{Error, Result},
handler::{get_client, get_songbird},
ytdlp::{YtDlp, YtDlpItem},
};
use serenity::{
all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption},
async_trait,
model::prelude::GuildId,
prelude::*,
};
use siren_core::data::guilds::GuildCache;
use songbird::{
Event,
EventHandler,
Songbird,
TrackEvent,
input::{Input, YoutubeDl},
tracks::TrackHandle,
};
use std::sync::Arc;
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Process the command options
let track_url = match command.data.options.first() {
Some(o) => o.value.as_str().unwrap(),
None => {
log::warn!(
"<{}> {} attempted to play a track without a track option",
command.guild_id.unwrap(),
command.user.id.get()
);
create_message_response(&ctx, &command, "Track option is missing".to_string(), false).await;
return;
}
};
// Create the initial response
process_message(&ctx, &command, false).await;
// Get the songbird manager
let manager = get_songbird();
// Extract the guild ID
let guild_id = match &command.guild_id {
Some(guild_id) => guild_id,
None => {
edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return;
}
};
// Join the user's voice channel
match join_voice_channel(&ctx.cache, &manager, guild_id, &command.user.id).await {
Ok(channel_id) => {
log::debug!(
"<{guild_id}> Play command executed on channel {channel_id} with track: {track_url:?}"
);
// Handle the track url
match enqueue_track(manager, guild_id.to_owned(), track_url).await {
Ok(items) => {
let mut message = format!("Added {} tracks", items.len());
if items.len() == 0 {
message = "No tracks were played".to_string();
log::warn!("<{guild_id}> No tracks were played");
if let Err(err) = leave_voice_channel(&manager, guild_id).await {
log::error!("Failed to leave voice channel: {}", err);
};
} else if items.len() == 1 {
message = format!("Added **{}**", items[0].get_title());
}
edit_response(&ctx, &command, message).await;
}
Err(err) => {
log::error!("Failed to play track: {}", err);
if let Err(err) = leave_voice_channel(&manager, guild_id).await {
log::error!("Failed to leave voice channel: {}", err);
}
edit_response(&ctx, &command, format!("Failed to play track: {}", err)).await;
}
};
}
Err(err) => {
log::warn!("<{guild_id}> Failed to join voice channel: {}", err);
edit_response(&ctx, &command, format!("{}", err)).await;
}
}
}
pub async fn enqueue_track(
manager: &Arc<Songbird>,
guild_id: GuildId,
track_url: &str,
) -> Result<Vec<YtDlpItem>> {
let mut playlist_items: Vec<YtDlpItem> = Vec::new();
if let Some(handler_lock) = manager.get(guild_id) {
let mut handler = handler_lock.lock().await;
let guild = GuildCache::find_by_id(guild_id.get() as i64).await.unwrap();
let valid = is_valid_url(&track_url);
// Check if the URL is valid
if !valid {
log::warn!("<{guild_id}> Invalid track url: {}", track_url);
return Err(Error::new(422, format!("Invalid track url: {}", track_url)));
}
playlist_items = get_ytdlp_items(&track_url)?;
// Add each track to the queue
for item in &playlist_items {
let volume = guild.volume as f32 / 100.0;
let http_client = get_client();
let source = YoutubeDl::new(http_client.to_owned(), item.get_url().to_owned());
let input: Input = source.into();
let track_title = item.get_title().to_owned();
let track_handle: TrackHandle;
track_handle = handler.enqueue_input(input).await;
// Set the volume
let _ = track_handle.set_volume(volume);
log::debug!("<{guild_id}> Added track: {}", track_title);
handler.remove_all_global_events();
handler.add_global_event(
Event::Track(TrackEvent::End),
TrackEndNotifier {
guild_id,
call: manager.clone(),
},
);
}
if handler.queue().is_empty() {
let _ = handler.queue().resume();
}
}
Ok(playlist_items)
}
pub fn get_ytdlp_items(url: &str) -> Result<Vec<YtDlpItem>> {
let output = YtDlp::new()
.arg("--flat-playlist")
.arg("--dump-json")
.arg("--no-check-formats")
.arg(url)
.execute()?;
// Check if yt-dlp exited successfully; log stderr if not
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(Error::new(
500,
format!("yt-dlp failed ({}): {}", output.status, stderr.trim()),
));
}
let stdout = String::from_utf8(output.stdout)?;
let items: Vec<YtDlpItem> = stdout
.split('\n')
.filter_map(|line| {
if line.is_empty() {
None
} else {
Some(
serde_json::from_slice::<YtDlpItem>(line.as_bytes())
.map_err(|err| Error::new(500, err.to_string())),
)
}
})
.filter_map(|parsed| match parsed {
Ok(item) => Some(item),
Err(err) => {
log::warn!("Failed to parse yt-dlp item: {}", err);
None
}
})
.collect();
Ok(items)
}
pub fn register() -> CreateCommand {
CreateCommand::new("play")
.description("Plays the given track")
.add_option(
CreateCommandOption::new(CommandOptionType::String, "track", "The track to be played")
.required(true),
)
}
struct TrackEndNotifier {
pub call: Arc<Songbird>,
pub guild_id: GuildId,
}
#[async_trait]
impl EventHandler for TrackEndNotifier {
async fn act(&self, ctx: &songbird::events::EventContext<'_>) -> Option<songbird::events::Event> {
if let songbird::EventContext::Track(_track_list) = ctx {
if let Some(call) = self.call.get(self.guild_id) {
let mut handler = call.lock().await;
if handler.queue().is_empty() {
log::debug!("Queue is empty, leaving voice channel");
handler.leave().await.unwrap();
}
}
}
None
}
}

View File

@@ -0,0 +1,63 @@
use crate::{
chat::{edit_response, process_message},
error::{Error, Result},
handler::get_songbird,
};
use serenity::{
all::{CommandInteraction, CreateCommand, GuildId},
prelude::*,
};
use songbird::Songbird;
use std::sync::Arc;
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Create the initial response
process_message(&ctx, &command, false).await;
// Get the songbird manager
let manager = get_songbird();
// Extract the guild ID
let guild_id = match &command.guild_id {
Some(guild_id) => guild_id,
None => {
edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return;
}
};
// Resume the track
match resume_track(manager, guild_id).await {
Ok(_) => {
log::debug!("<{guild_id}> Resumed the track");
edit_response(&ctx, &command, "resuming the track".to_string()).await;
}
Err(err) => edit_response(&ctx, &command, format!("Failed to resume: {}", err)).await,
}
}
pub async fn resume_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> Result<()> {
if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
let handler = handler_lock.lock().await;
match handler.queue().current() {
Some(track) => track.play()?,
None => {
return Err(Error {
status: 404,
details: "No track is currently playing".to_string(),
});
}
}
};
Ok(())
}
pub fn register() -> CreateCommand {
CreateCommand::new("resume").description("Resume the current track")
}

View File

@@ -0,0 +1,48 @@
use crate::{
chat::{edit_response, process_message},
handler::get_songbird,
};
use serenity::{
all::{CommandInteraction, CreateCommand},
prelude::*,
};
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Create the initial response
process_message(&ctx, &command, false).await;
// Get the songbird manager
let manager = get_songbird();
// Extract the guild ID
let guild_id = match &command.guild_id {
Some(guild_id) => guild_id,
None => {
edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return;
}
};
// Skip the track
if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
let handler = handler_lock.lock().await;
match handler.queue().skip() {
Ok(_) => {
log::debug!("<{guild_id}> Skipped the track");
edit_response(&ctx, &command, "Skipping the track".to_string()).await;
}
Err(err) => {
edit_response(&ctx, &command, format!("Failed to skip: {}", err)).await;
}
}
}
}
pub fn register() -> CreateCommand {
CreateCommand::new("skip").description("Skip the current track")
}

View File

@@ -0,0 +1,42 @@
use crate::{
chat::{edit_response, process_message},
handler::get_songbird,
};
use serenity::{
all::{CommandInteraction, CreateCommand},
prelude::*,
};
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Create the initial response
process_message(&ctx, &command, false).await;
// Get the songbird manager
let manager = get_songbird();
// Extract the guild ID
let guild_id = match command.guild_id {
Some(g) => g,
None => {
edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return;
}
};
// Stop the track and clear the queue
if let Some(handler_lock) = manager.get(guild_id) {
let handler = handler_lock.lock().await;
handler.queue().stop();
log::debug!("<{guild_id}> Stopped the track");
edit_response(&ctx, &command, "Stopping the tracks".to_string()).await;
}
}
pub fn register() -> CreateCommand {
CreateCommand::new("stop").description("Stop the current track and clear the queue")
}

View File

@@ -0,0 +1,92 @@
use crate::{
chat::{create_message_response, edit_response, process_message},
handler::get_songbird,
};
use serenity::{
all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption},
model::prelude::GuildId,
prelude::*,
};
use siren_core::data::guilds::GuildCache;
use songbird::Songbird;
use std::sync::Arc;
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Process the command options
let volume = match command.data.options.first() {
Some(o) => o.value.as_i64().unwrap() as i32,
None => {
log::warn!(
"{} attempted to change the volume without a volume option",
command.user.id.get()
);
create_message_response(
&ctx,
&command,
"Volume option is missing".to_string(),
false,
)
.await;
return;
}
};
// Create the initial response
process_message(&ctx, &command, false).await;
// Get the songbird manager
let manager = get_songbird();
// Extract the guild ID
let guild_id = match &command.guild_id {
Some(guild_id) => guild_id,
None => {
edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return;
}
};
// Set the volume
set_volume(&manager, guild_id, volume).await;
log::debug!("<{guild_id}> Setting the volume to {}", volume);
edit_response(&ctx, &command, format!("Setting the volume to {}", volume)).await;
}
pub async fn set_volume(manager: &Arc<Songbird>, guild_id: &GuildId, volume: i32) {
// Format volume to f32 bound between 0.0 and 1.0
let volume = std::cmp::min(100, std::cmp::max(0, volume));
let bound_volume = volume as f32 / 100.0;
// Update the guild cache
let mut guild_cache = GuildCache::find_by_id(guild_id.get() as i64).await.unwrap();
guild_cache.volume = volume;
guild_cache.update().await.unwrap();
// Update the volume of the songbird handler
if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
let handler = handler_lock.lock().await;
for (_, track_handle) in handler.queue().current_queue().iter().enumerate() {
if let Err(err) = track_handle.set_volume(bound_volume) {
log::error!("Unable to set volume: {err}");
}
}
}
}
pub fn register() -> CreateCommand {
CreateCommand::new("volume")
.description("Set the audio player volume")
.add_option(
CreateCommandOption::new(
CommandOptionType::Integer,
"volume",
"Volume between 0 and 100",
)
.required(true),
)
}

View File

@@ -0,0 +1 @@
pub mod schedule;

View File

@@ -0,0 +1,146 @@
use crate::chat::process_message;
use chrono::{DateTime, NaiveDate, TimeZone, Utc};
use regex::Regex;
use serenity::all::{
Color,
CommandInteraction,
CommandOptionType,
Context,
CreateCommand,
CreateCommandOption,
CreateEmbed,
CreateEmbedFooter,
EditInteractionResponse,
Timestamp,
};
use siren_core::data::events::Event;
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Create the initial response
process_message(&ctx, &command, true).await;
// Process the command options
let title = command.data.options.get(0).unwrap().value.as_str().unwrap();
// let datetime_string = command.data.options.get(1).unwrap().value.as_str().unwrap();
let description = command
.data
.options
.get(2)
.map(|option| option.value.as_str().unwrap());
// Parse the guild ID and author ID
let guild_id = command.guild_id.unwrap();
let author_id = command.user.id;
// Parse the datetime string into a DateTime object
let date_time = Utc::now();
// Create the event
let event = Event {
id: uuid::Uuid::new_v4(),
guild_id: guild_id.get() as i64,
author_id: author_id.get() as i64,
title: title.to_string(),
date_time,
description: description.map(|s| s.to_string()),
rsvp: vec![],
};
// Save the event to the database
event.insert().await.unwrap();
// Create the response embed
let embed_footer = CreateEmbedFooter::new(format!("Created by {}", command.user.name));
let embed = CreateEmbed::new()
.title(title)
.color(Color::TEAL)
.timestamp(Timestamp::now())
.description(description.unwrap_or(""))
.field("Time", date_time.to_rfc2822(), false)
.footer(embed_footer);
let builder = EditInteractionResponse::new().embed(embed);
match command.edit_response(&ctx.http, builder).await {
Ok(_) => {}
Err(err) => {
log::error!("Failed to create schedule embed: {err}");
}
}
}
pub fn register() -> CreateCommand {
CreateCommand::new("schedule")
.description("Schedule a new event")
.add_option(
CreateCommandOption::new(CommandOptionType::String, "title", "The title of the event")
.required(true),
)
.add_option(
CreateCommandOption::new(
CommandOptionType::String,
"datetime",
"The date and time of the event",
)
.required(true),
)
.add_option(CreateCommandOption::new(
CommandOptionType::String,
"description",
"A description of the event",
))
}
// The datetime string can be formatted in the following ways:
// (in) XX <seconds, minutes, hours, days, weeks>
// (at) YYYY-MM-DD HH:MM (AM/PM)
// (at) MM DD (YYYY) HH:MM (AM/PM)
#[allow(dead_code)]
fn parse_datetime(input: &str) -> Option<DateTime<Utc>> {
let regexes = vec![
Regex::new(r"(?i)^\(?at\)?\s+(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2})\s*(AM|PM)?$").unwrap(),
Regex::new(r"(?i)^\(?at\)?\s+(\d{2})\s+(\d{2})\s*(\d{4})?\s+(\d{2}):(\d{2})\s*(AM|PM)?$")
.unwrap(),
// ... add other regexes here
];
for regex in regexes {
if let Some(captures) = regex.captures(input) {
if captures.len() == 7 {
// Matches the second format
let (year, month, day) = (
captures.get(1).unwrap().as_str().parse().unwrap_or(1970),
captures.get(2).unwrap().as_str().parse().unwrap_or(1),
captures.get(3).unwrap().as_str().parse().unwrap_or(1),
);
let (mut hour, minute) = (
captures.get(4).unwrap().as_str().parse().unwrap_or(0),
captures.get(5).unwrap().as_str().parse().unwrap_or(0),
);
if let Some(am_pm) = captures.get(6) {
if am_pm.as_str().eq_ignore_ascii_case("PM") && hour != 12 {
hour += 12;
}
if am_pm.as_str().eq_ignore_ascii_case("AM") && hour == 12 {
hour = 0;
}
}
// Create a NaiveDate instance from year, month, day
let naive_date =
NaiveDate::from_ymd_opt(year, month, day).expect("Invalid date parameters");
// Create a NaiveDateTime instance from NaiveDate and time components
let naive_time = naive_date
.and_hms_opt(hour, minute, 0)
.expect("Invalid time parameters");
// Convert the NaiveDateTime to a DateTime<Utc>
return Some(Utc.from_utc_datetime(&naive_time));
}
// handle other cases
}
}
None
}

View File

@@ -0,0 +1,2 @@
pub mod request_roll;
pub mod roll;

View File

@@ -0,0 +1,109 @@
use crate::{
chat::{create_message_response, edit_response},
commands::fun::roll::parse_dice,
};
use serenity::all::{
ButtonStyle,
CommandInteraction,
CommandOptionType,
Context,
CreateActionRow,
CreateButton,
CreateCommand,
CreateCommandOption,
CreateMessage,
Mentionable,
UserId,
};
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Check if the roll result is hidden
let hidden = command
.data
.options
.iter()
.find(|opt| opt.name == "hidden")
.and_then(|o| o.value.as_bool())
.unwrap_or(false);
// Retrieve the user
let user = command
.data
.options
.iter()
.find(|opt| opt.name == "user")
.and_then(|o| o.value.as_mentionable())
.unwrap();
let user_id = UserId::new(user.get());
create_message_response(
ctx,
&command,
format!("Sending request to {}", user_id.mention()),
true,
)
.await;
let dice_string = command
.data
.options
.get(0)
.and_then(|o| o.value.as_str())
.map(|s| s.split_whitespace().collect::<String>())
.unwrap();
let dice_result = parse_dice(dice_string.as_str());
match dice_result {
Ok(dice) => {
let roll_button = CreateButton::new(format!(
"request_dice_roll|{}|{}|{}|{}|{}",
dice.0,
dice.1,
dice.2,
command.user.id.get(),
hidden
))
.label(format!("🎲 Roll {} 🎲", dice_string)) // The label you want on the button
.style(ButtonStyle::Primary);
let action_row = CreateActionRow::Buttons(vec![roll_button]);
let message = CreateMessage::new()
.content(format!("-# Roll requested from {}", command.user.mention()))
.components(vec![action_row]);
if let Err(why) = user_id.dm(ctx, message).await {
log::error!("failed to send request due to {}", why);
edit_response(ctx, command, "Unable to send dice request".to_string()).await;
};
}
Err(why) => {
edit_response(ctx, &command, why.to_string()).await;
}
}
}
pub fn register() -> CreateCommand {
CreateCommand::new("requestroll")
.description("Request a dice roll from a user")
.add_option(
CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll").required(true),
)
.add_option(
CreateCommandOption::new(
CommandOptionType::Mentionable,
"user",
"User to receive the dice roll request",
)
.required(true),
)
.add_option(
CreateCommandOption::new(
CommandOptionType::Boolean,
"hidden",
"Hide the dice roll from the user (Default: False",
)
.required(false),
)
}

View File

@@ -0,0 +1,236 @@
use crate::{
chat::{create_message_response, edit_response},
error::{Error, Result},
};
use rand::RngExt;
use serenity::all::{
CommandInteraction,
CommandOptionType,
Context,
CreateCommand,
CreateCommandOption,
CreateEmbed,
CreateMessage,
Mentionable,
UserId,
};
use siren_core::utils::{a_or_an, number_to_words};
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Check if the roll result is private
let private = command
.data
.options
.iter()
.find(|opt| opt.name == "private")
.and_then(|o| o.value.as_bool())
.unwrap_or(true);
// Retrieve the user if present
let user = command
.data
.options
.iter()
.find(|opt| opt.name == "user")
.and_then(|o| o.value.as_mentionable());
create_message_response(ctx, &command, "Rolling...".to_string(), private).await;
let dice_string = match command
.data
.options
.get(0)
.and_then(|o| o.value.as_str())
.map(|s| s.split_whitespace().collect::<String>())
{
Some(dice_value) => dice_value,
None => {
log::warn!("Missing or invalid dice option");
let _ = edit_response(&ctx, &command, "Dice option is missing".to_string()).await;
return;
}
};
let dice = parse_dice(dice_string.as_str());
match dice {
Ok((count, sides, modifier)) => {
let total = roll_dice(count, sides, modifier);
let response = format!("(Rolled {})", format_roll(count, sides, modifier));
match user {
Some(id) => {
let user_id = UserId::new(id.get());
let roller_id = command.user.id;
send_roll_message(ctx, total, user_id, roller_id, &response).await;
edit_response(
&ctx,
command,
format!("Sending dice roll results to {}", &user_id.mention()),
)
.await;
}
None => edit_response(&ctx, &command, format!("🎲 {}\n-# {}", total, response)).await,
};
// Check for dice tracks
}
Err(why) => {
edit_response(&ctx, &command, format!("Invalid dice string: {}", why)).await;
}
}
}
pub async fn send_roll_message(
ctx: &Context,
total: i32,
user_id: UserId,
roller_id: UserId,
dice_string: &str,
) {
// Create the dice roll embed
let a = a_or_an(&number_to_words(total));
let embed = CreateEmbed::new()
.title("🎲 Received a dice roll! 🎲".to_string())
.color(0x00FF00)
.description(format!(
"{} rolled {} **{}**\n-# *{}*",
&roller_id.mention(),
a,
total,
dice_string
));
let message = CreateMessage::new().embed(embed);
if let Err(err) = user_id.dm(ctx, message).await {
log::error!("Could not send message: {}", err);
}
}
pub fn format_roll(count: u32, sides: u32, modifier: i32) -> String {
format!(
"{}d{}{}",
count,
sides,
if modifier > 0 {
format!("+{}", modifier)
} else if modifier < 0 {
format!("-{}", modifier)
} else {
"".to_string()
}
)
}
pub fn roll_dice(count: u32, sides: u32, modifier: i32) -> i32 {
let mut rolls = Vec::new();
let mut total = modifier;
for _ in 0..count {
let roll = rand::rng().random_range(1..=sides as i32);
total += roll;
rolls.push(roll);
}
total
}
pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32)> {
// If the input is just a number (e.g., "20" or "6"), assume it's the number of sides
if let Ok(n) = dice.parse::<u32>() {
return Ok((1, n, 0)); // Assume 1 dice with 0 modifiers
}
// If the input starts with "d", assume it's shorthand for "1dX"
let dice = if dice.starts_with("d") {
format!("1{}", dice) // Prepend "1"
} else {
dice.to_string()
};
let mut parts = dice.split(['d', '+', '-'].as_ref());
let mut positive_modifier = true;
// Parse the dice count
let count = match parts.next() {
Some("") => 1, // Handle cases like "d6", assume 1 dice
Some(c) => match c.parse::<u32>() {
Ok(n) => n,
Err(_) => return Err(Error::new(400, format!("Invalid dice count: {}", c))),
},
None => return Err(Error::new(400, format!("Invalid dice string: {}", dice))),
};
// Parse the number of sides
let sides_part = parts
.next()
.ok_or_else(|| Error::new(400, format!("Invalid dice string: {}", dice)))?;
let sides = match sides_part.parse::<u32>() {
Ok(n) => {
if [4, 6, 8, 10, 12, 20, 100].contains(&n) {
n
} else {
return Err(Error::new(
400,
format!(
"Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}",
n
),
));
}
}
Err(_) => {
return Err(Error::new(
400,
format!(
"Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}",
sides_part
),
));
}
};
// Determine if there's a modifier (+ or -)
if dice.contains('+') {
positive_modifier = true;
} else if dice.contains('-') {
positive_modifier = false;
}
// Parse the modifier, if present
let modifier = match parts.next() {
Some(m) => match m.parse::<i32>() {
Ok(n) => {
if positive_modifier {
n
} else {
-n
}
}
Err(_) => return Err(Error::new(400, format!("Invalid dice modifier: {}", m))),
},
None => 0, // No modifier found
};
Ok((count, sides, modifier))
}
pub fn register() -> CreateCommand {
CreateCommand::new("roll")
.description("Roll dice")
.add_option(
CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll").required(true),
)
.add_option(
CreateCommandOption::new(
CommandOptionType::Boolean,
"private",
"Make the roll private (Default: True)",
)
.required(false),
)
.add_option(
CreateCommandOption::new(
CommandOptionType::Mentionable,
"user",
"User to receive the roll results",
)
.required(false),
)
}

View File

@@ -0,0 +1,4 @@
pub mod audio;
pub mod event;
pub mod fun;
pub mod utility;

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,2 @@
pub mod help;
pub mod ping;

View File

@@ -0,0 +1,19 @@
use crate::chat::create_message_response;
use serenity::all::{CommandInteraction, Context, CreateCommand};
pub async fn run(ctx: &Context, command: &CommandInteraction) {
log::debug!("Ping command executed");
if let Some(guild_id) = command.guild_id {
if let Some(guild) = guild_id.to_guild_cached(&ctx.cache) {
let owner_id = guild.owner_id;
if command.user.id == owner_id {}
}
}
create_message_response(&ctx, &command, "pong".to_string(), true).await;
}
pub fn register() -> CreateCommand {
CreateCommand::new("ping").description("Displays the bot latency")
}

View File

@@ -0,0 +1,89 @@
use serde::{Deserialize, Serialize};
use std::fmt;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Deserialize, Serialize)]
pub struct Error {
pub status: u16,
pub details: String,
}
impl Error {
pub fn new(status: u16, details: String) -> Self {
Self { status, details }
}
pub fn not_found(details: String) -> Self {
Self::new(404, details)
}
pub fn internal_server_error(details: String) -> Self {
Self::new(500, details)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(self.details.as_str())
}
}
impl std::error::Error for Error {
fn description(&self) -> &str {
&self.details
}
}
impl From<siren_core::error::Error> for Error {
fn from(error: siren_core::error::Error) -> Self {
Self::new(error.status, error.details)
}
}
impl From<serenity::Error> for Error {
fn from(error: serenity::Error) -> Self {
Self::new(500, format!("Discord error: {}", error))
}
}
impl From<songbird::error::JoinError> for Error {
fn from(error: songbird::error::JoinError) -> Self {
use std::error::Error as StdError;
let details = match error.source() {
Some(source) => format!("Unable to join channel: {} ({})", error, source),
None => format!("Unable to join channel: {}", error),
};
Self::new(500, details)
}
}
impl From<songbird::tracks::ControlError> for Error {
fn from(error: songbird::tracks::ControlError) -> Self {
Self::new(500, format!("Track control error: {}", error))
}
}
impl From<std::io::Error> for Error {
fn from(error: std::io::Error) -> Self {
Self::new(500, format!("IO error: {}", error))
}
}
impl From<std::string::FromUtf8Error> for Error {
fn from(error: std::string::FromUtf8Error) -> Self {
Self::new(500, format!("UTF-8 error: {}", error))
}
}
impl From<reqwest::Error> for Error {
fn from(error: reqwest::Error) -> Self {
Self::new(500, format!("HTTP client error: {}", error))
}
}
impl From<serde_json::Error> for Error {
fn from(error: serde_json::Error) -> Self {
Self::new(500, format!("JSON error: {}", error))
}
}

View File

@@ -0,0 +1,244 @@
use super::{chat::create_modal_response, commands};
use crate::{
HttpKey,
commands::fun::roll::{format_roll, roll_dice, send_roll_message},
};
use serenity::{
all::{
CreateInteractionResponse,
EditInteractionResponse,
Interaction,
ResumedEvent,
UnavailableGuild,
UserId,
},
async_trait,
model::{channel::Message, gateway::Ready},
prelude::*,
};
use siren_core::{
data::guilds::GuildCache,
utils::{a_or_an, number_to_words},
};
use songbird::Songbird;
use std::sync::{Arc, OnceLock};
pub struct BotHandler {
pub force_register: bool,
}
static REGISTERED: OnceLock<bool> = OnceLock::new();
static SONGBIRD: OnceLock<Arc<Songbird>> = OnceLock::new();
static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
pub fn get_songbird() -> &'static Arc<Songbird> {
SONGBIRD.get().unwrap()
}
pub fn get_client() -> &'static reqwest::Client {
CLIENT.get().unwrap()
}
impl BotHandler {
pub fn new(force_register: bool) -> Self {
Self { force_register }
}
}
#[async_trait]
impl EventHandler for BotHandler {
async fn message(&self, _ctx: Context, msg: Message) {
// Ignore bot messages
if msg.author.bot {
return;
}
// Handle direct messages
if let None = msg.guild_id {
log::trace!("Received DM from {}: {}", msg.author, msg.content);
}
}
async fn ready(&self, ctx: Context, ready: Ready) {
if ready.guilds.is_empty() {
log::warn!("No ready guilds found");
}
if SONGBIRD.get().is_none() {
let songbird = songbird::get(&ctx).await.unwrap();
SONGBIRD
.set(songbird.clone())
.expect("Songbird value could not be set");
}
if CLIENT.get().is_none() {
let http_client = {
let data = ctx.data.read().await;
data
.get::<HttpKey>()
.cloned()
.expect("Guaranteed to exist in the typemap.")
};
CLIENT.set(http_client).ok();
}
// Update registered to prevent reloading the commands
if REGISTERED.get().is_some() {
return;
} else {
REGISTERED.set(true).ok();
}
log::debug!("Registering in {} guild(s)", ready.guilds.len());
for guild in ready.guilds {
update_guild_commands(&ctx, &guild, self.force_register).await;
}
}
async fn resume(&self, _: Context, _: ResumedEvent) {
log::trace!("Resumed");
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
if let Interaction::Command(command) = interaction {
log::trace!(
"<{}> Received command: {}",
command.guild_id.unwrap(),
command.data.name
);
match command.data.name.as_str() {
// Match commands without returns
"play" => commands::audio::play::run(&ctx, &command).await,
"stop" => commands::audio::stop::run(&ctx, &command).await,
"pause" => commands::audio::pause::run(&ctx, &command).await,
"resume" => commands::audio::resume::run(&ctx, &command).await,
"mute" => commands::audio::mute::run(&ctx, &command).await,
"skip" => commands::audio::skip::run(&ctx, &command).await,
"volume" => commands::audio::volume::run(&ctx, &command).await,
"schedule" => commands::event::schedule::run(&ctx, &command).await,
"roll" => commands::fun::roll::run(&ctx, &command).await,
"requestroll" => commands::fun::request_roll::run(&ctx, &command).await,
"ping" => commands::utility::ping::run(&ctx, &command).await,
_ => {}
}
} else if let Interaction::Component(component) = interaction {
log::trace!("Received COMPONENT");
let custom_id = &component.data.custom_id;
if custom_id.starts_with("request_dice_roll") {
// Acknowledge the interaction
if let Err(err) = component
.create_response(ctx.http.clone(), CreateInteractionResponse::Acknowledge)
.await
{
log::error!("Could not create dice response: {}", err);
};
let parts = custom_id.split('|').collect::<Vec<&str>>();
if parts.len() == 6 {
let count = parts[1].parse().unwrap();
let sides = parts[2].parse().unwrap();
let modifier = parts[3].parse().unwrap();
let result = roll_dice(count, sides, modifier);
let response = format!("(Rolled {})", format_roll(count, sides, modifier));
let user_id = UserId::from(parts[4].parse::<u64>().unwrap());
let roller_id = component.user.id;
let hidden: bool = parts[5].parse().unwrap();
// Prepare the message based on visibility
let new_message = if hidden {
// For hidden rolls, only reveal "results sent" to the requester
format!("🎲 Results sent to {}\n-# {}", user_id.mention(), response)
} else {
// For public rolls, show the roll result
format!(
"🎲 You rolled {} {}\n-# {}",
a_or_an(&number_to_words(result)),
result,
response
)
};
// Edit the message to update the text and remove buttons
if let Err(err) = component
.edit_response(
ctx.http.clone(),
EditInteractionResponse::new()
.content(new_message)
.components(Vec::new()),
)
.await
{
log::error!("Could not update dice roll message: {}", err);
}
// Send message to the requester
send_roll_message(&ctx, result, user_id, roller_id, &response).await;
} else {
log::error!("Could not handle dice click: {}", custom_id);
}
}
} else if let Interaction::Ping(_ping) = interaction {
log::trace!("Received PING");
} else if let Interaction::Autocomplete(_autocomplete) = interaction {
log::trace!("Received AUTOCOMPLETE");
} else if let Interaction::Modal(modal) = interaction {
log::trace!("Received MODAL");
create_modal_response(&ctx, &modal).await;
}
}
}
async fn update_guild_commands(ctx: &Context, guild: &UnavailableGuild, force_register: bool) {
// List of commands to register for the guild
let guild_commands = vec![
commands::audio::play::register(),
commands::audio::stop::register(),
commands::audio::pause::register(),
commands::audio::resume::register(),
commands::audio::mute::register(),
commands::audio::skip::register(),
commands::audio::volume::register(),
commands::event::schedule::register(),
commands::fun::roll::register(),
commands::fun::request_roll::register(),
commands::utility::ping::register(),
];
let guild_id = guild.id.get() as i64;
let register_commands = match GuildCache::find_by_id(guild_id).await {
Some(_) => force_register,
None => {
// If no guild cache is found, create a new one.
let guild_cache = GuildCache {
id: guild_id,
name: guild.id.name(&ctx.cache),
owner_id: None,
volume: 100,
};
if let Err(err) = guild_cache.insert().await {
log::error!("Could not insert guild cache: {err}");
};
true
}
};
if register_commands {
// Register the commands in the guild
match guild.id.set_commands(&ctx.http, guild_commands).await {
Ok(registered_commands) => {
log::info!(
"Registered {} commands for guild {}",
registered_commands.len(),
guild.id.get()
);
}
Err(why) => {
log::error!(
"Could not register commands for guild {}: {:?}",
guild.id.get(),
why
);
}
};
} else {
log::debug!("Guild {guild_id} is already registered");
}
}

View File

@@ -0,0 +1,14 @@
pub mod chat;
pub mod commands;
pub mod error;
pub mod handler;
pub mod ytdlp;
use reqwest::Client as HttpClient;
use serenity::prelude::TypeMapKey;
pub struct HttpKey;
impl TypeMapKey for HttpKey {
type Value = HttpClient;
}

View File

@@ -0,0 +1,39 @@
mod model;
pub use model::*;
use std::process::{Child, Command, Output, Stdio};
const YOUTUBE_DL_COMMAND: &str = "yt-dlp";
pub struct YtDlp {
command: Command,
args: Vec<String>,
}
impl YtDlp {
pub fn new() -> Self {
let mut cmd = Command::new(YOUTUBE_DL_COMMAND);
cmd
.env("LC_ALL", "en_US.UTF-8")
.stdout(Stdio::piped())
.stdin(Stdio::piped())
.stderr(Stdio::piped());
Self {
command: cmd,
args: Vec::new(),
}
}
pub fn arg(&mut self, arg: &str) -> &mut Self {
self.args.push(arg.to_owned());
self
}
pub fn execute(&mut self) -> std::io::Result<Output> {
self
.command
.args(self.args.clone())
.spawn()
.and_then(Child::wait_with_output)
}
}

View File

@@ -0,0 +1,35 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum YtDlpItem {
PlaylistItem {
id: String,
url: String,
title: String,
duration: Option<f64>,
playlist_index: Option<i32>,
},
VideoItem {
id: String,
webpage_url: String,
title: String,
duration: Option<f64>,
},
}
impl YtDlpItem {
pub fn get_title(&self) -> &str {
match self {
YtDlpItem::PlaylistItem { title, .. } => title,
YtDlpItem::VideoItem { title, .. } => title,
}
}
pub fn get_url(&self) -> &str {
match self {
YtDlpItem::PlaylistItem { url, .. } => url,
YtDlpItem::VideoItem { webpage_url, .. } => webpage_url,
}
}
}

View File

@@ -0,0 +1,21 @@
[package]
name = "siren-core"
edition.workspace = true
version.workspace = true
rust-version.workspace = true
authors.workspace = true
[dependencies]
tokio = { workspace = true }
log = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
chrono = { workspace = true }
reqwest = { workspace = true }
uuid = { workspace = true }
redis = { workspace = true }
rand = { workspace = true }
rand_chacha = { workspace = true }
regex = { workspace = true }
lazy_static = { workspace = true }

View File

@@ -0,0 +1,81 @@
use crate::error::Result;
use std::env;
pub struct EnvironmentConfiguration {
pub rust_log: String,
pub discord_token: String,
pub discord_secret: String,
pub jwt_secret: String,
pub postgres_user: String,
pub postgres_password: String,
pub postgres_database: String,
pub postgres_host: String,
pub postgres_port: u16,
pub api_base_url: String,
pub api_port: u16,
pub api_session_ttl: u64,
pub valkey_host: String,
pub valkey_port: u16,
pub minio_root_user: String,
pub minio_root_password: String,
pub minio_host: String,
pub minio_port: u16,
pub minio_port_internal: u16,
pub data_dir_path: Option<String>,
pub force_register: bool,
pub default_api_key: String,
pub default_server: Option<String>,
pub default_user: Option<String>,
}
impl EnvironmentConfiguration {
pub fn load() -> Result<Self> {
Ok(Self {
rust_log: env::var("RUST_LOG").unwrap_or_else(|_| "warn,siren=info".to_string()),
discord_token: env::var("DISCORD_BOT_TOKEN")?,
discord_secret: env::var("DISCORD_CLIENT_SECRET")?,
jwt_secret: env::var("JWT_SECRET")?,
postgres_user: env::var("POSTGRES_USER")?,
postgres_password: env::var("POSTGRES_PASSWORD")?,
postgres_database: env::var("POSTGRES_DB")?,
postgres_host: env::var("POSTGRES_HOST")?,
postgres_port: env::var("POSTGRES_PORT")
.unwrap_or_else(|_| "5432".to_string())
.parse()
.unwrap_or(5432),
api_base_url: env::var("API_BASE_URL")?,
api_port: env::var("API_PORT")
.unwrap_or_else(|_| "3000".to_string())
.parse()
.unwrap_or(3000),
api_session_ttl: env::var("API_SESSION_TTL")
.unwrap_or_else(|_| "86400".to_string())
.parse()
.unwrap_or(86400),
valkey_host: env::var("VALKEY_HOST").unwrap_or_else(|_| "localhost".to_string()),
valkey_port: env::var("VALKEY_PORT")
.unwrap_or_else(|_| "6379".to_string())
.parse()
.unwrap_or(6379),
minio_root_user: env::var("MINIO_ROOT_USER")?,
minio_root_password: env::var("MINIO_ROOT_PASSWORD")?,
minio_host: env::var("MINIO_HOST").unwrap_or_else(|_| "localhost".to_string()),
minio_port: env::var("MINIO_PORT")
.unwrap_or_else(|_| "9000".to_string())
.parse()
.unwrap_or(9000),
minio_port_internal: env::var("MINIO_PORT_INTERNAL")
.unwrap_or_else(|_| "9001".to_string())
.parse()
.unwrap_or(9001),
data_dir_path: env::var("DATA_DIR_PATH").ok().filter(|s| !s.is_empty()),
force_register: env::var("FORCE_REGISTER")
.ok()
.map(|v| v.to_lowercase() == "true")
.unwrap_or(false),
default_api_key: env::var("DEFAULT_API_KEY").unwrap_or_default(),
default_server: env::var("DEFAULT_SERVER").ok().filter(|s| !s.is_empty()),
default_user: env::var("DEFAULT_USER").ok().filter(|s| !s.is_empty()),
})
}
}

View File

@@ -0,0 +1,243 @@
use crate::data::Value;
pub enum Condition {
Simple(String, Vec<Value>),
And(Box<Condition>, Box<Condition>),
Or(Box<Condition>, Box<Condition>),
Group(Box<Condition>),
}
impl Condition {
pub fn new(condition: &str) -> Self {
Condition::Simple(condition.to_string(), vec![])
}
pub fn and(self, other: Self) -> Self {
Condition::And(Box::new(self), Box::new(other))
}
pub fn or(self, other: Self) -> Self {
Condition::Or(Box::new(self), Box::new(other))
}
pub fn group(self) -> Self {
Condition::Group(Box::new(self))
}
pub fn is_equal(left: &str, right: impl Into<Value> + Clone) -> Self {
let value = right.clone().into();
match Self::from_optional_value(left, value) {
Some(condition) => condition,
None => Condition::Simple(format!("{} = ?", left), vec![right.into()]),
}
}
pub fn not_equal(left: &str, right: impl Into<Value> + Clone) -> Self {
let value = right.clone().into();
match Self::from_optional_value(left, value) {
Some(condition) => condition,
None => Condition::Simple(format!("{} != ?", left), vec![right.into()]),
}
}
pub fn is_null(value: &str) -> Self {
Condition::Simple(format!("{} IS NULL", value), vec![])
}
pub fn not_null(value: &str) -> Self {
Condition::Simple(format!("{} IS NOT NULL", value), vec![])
}
pub fn is_in(left: &str, right: Vec<Value>) -> Self {
// Use helper function to handle special cases
if let Some(condition) = Self::handle_empty_or_all_none(left, &right, true) {
return condition;
}
let right_list = right
.iter()
.map(|_v| "'?'".to_string())
.collect::<Vec<_>>()
.join(", ");
Condition::Simple(format!("{} IN ({})", left, right_list), right)
}
pub fn not_in(left: &str, right: Vec<Value>) -> Self {
// Use helper function to handle special cases
if let Some(condition) = Self::handle_empty_or_all_none(left, &right, true) {
return condition;
}
let right_list = right
.iter()
.map(|_v| "'?'".to_string())
.collect::<Vec<_>>()
.join(", ");
Condition::Simple(format!("{} NOT IN ({})", left, right_list), right)
}
pub fn like(left: &str, right: impl Into<Value> + Clone) -> Self {
let value = right.clone().into();
match Self::from_optional_value(left, value) {
Some(condition) => condition,
None => Condition::Simple(format!("{} LIKE '?'", left), vec![right.into()]),
}
}
pub fn not_like(left: &str, right: impl Into<Value> + Clone) -> Self {
let value = right.clone().into();
match Self::from_optional_value(left, value) {
Some(condition) => condition,
None => Condition::Simple(format!("{} NOT LIKE '?'", left), vec![right.into()]),
}
}
pub fn i_like(left: &str, right: impl Into<Value> + Clone) -> Self {
let value = right.clone().into();
match Self::from_optional_value(left, value) {
Some(condition) => condition,
None => Condition::Simple(format!("{} ILIKE '?'", left), vec![right.into()]),
}
}
pub fn not_i_like(left: &str, right: impl Into<Value> + Clone) -> Self {
let value = right.clone().into();
match Self::from_optional_value(left, value) {
Some(condition) => condition,
None => Condition::Simple(format!("{} NOT ILIKE '?'", left), vec![right.into()]),
}
}
pub fn gt(left: &str, right: impl Into<Value> + Clone) -> Self {
let value = right.clone().into();
match Self::from_optional_value(left, value) {
Some(condition) => condition,
None => Condition::Simple(format!("{} > ?", left), vec![right.into()]),
}
}
pub fn gte(left: &str, right: impl Into<Value> + Clone) -> Self {
let value = right.clone().into();
match Self::from_optional_value(left, value) {
Some(condition) => condition,
None => Condition::Simple(format!("{} >= ?", left), vec![right.into()]),
}
}
pub fn lt(left: &str, right: impl Into<Value> + Clone) -> Self {
let value = right.clone().into();
match Self::from_optional_value(left, value) {
Some(condition) => condition,
None => Condition::Simple(format!("{} < ?", left), vec![right.into()]),
}
}
pub fn lte(left: &str, right: impl Into<Value> + Clone) -> Self {
let value = right.clone().into();
match Self::from_optional_value(left, value) {
Some(condition) => condition,
None => Condition::Simple(format!("{} <= ?", left), vec![right.into()]),
}
}
// Private helper function to handle optional values
fn from_optional_value(left: &str, value: Value) -> Option<Self> {
match value {
Value::OptionalInt(None) => Some(Condition::is_null(left)),
Value::OptionalBigInt(None) => Some(Condition::is_null(left)),
Value::OptionalFloat(None) => Some(Condition::is_null(left)),
Value::OptionalDouble(None) => Some(Condition::is_null(left)),
Value::OptionalBool(None) => Some(Condition::is_null(left)),
Value::OptionalText(None) => Some(Condition::is_null(left)),
Value::OptionalDateTime(None) => Some(Condition::is_null(left)),
_ => None, // For non-optional or Some(value), let the primary method handle it
}
}
// Private helper to handle `empty` or `all-None` lists
fn handle_empty_or_all_none(left: &str, right: &[Value], negate: bool) -> Option<Self> {
if right.is_empty() {
// For an empty list, return an always-false condition
// NOT IN with empty list is always TRUE, but we're defaulting to SQL-SAFE result (FALSE)
return Some(Condition::Simple(
if negate {
"TRUE".to_string()
} else {
"FALSE".to_string()
},
vec![],
));
}
// Check if all elements in the `right` vector are `None` (Optional*)
if right.iter().all(|v| {
matches!(
v,
Value::OptionalInt(None)
| Value::OptionalBigInt(None)
| Value::OptionalFloat(None)
| Value::OptionalDouble(None)
| Value::OptionalBool(None)
| Value::OptionalText(None)
| Value::OptionalDateTime(None)
)
}) {
// If all values are None, handle as NULL or NOT NULL
return Some(if negate {
Condition::not_null(left)
} else {
Condition::is_null(left)
});
}
// Otherwise, this is not an empty or all-none case
None
}
pub fn to_sql(&self, counter: &mut usize) -> (String, Vec<Value>) {
let mut sql = String::new();
let mut binds = Vec::new();
match self {
Condition::Simple(condition, values) => {
// Replace each instance of '?' with increasing numbered binds
let mut numbered_condition = String::new();
let mut chars = condition.chars().peekable();
while let Some(c) = chars.next() {
if c == '?' {
// Increment the counter and replace `?` with a numbered bind
*counter += 1;
numbered_condition.push_str(&format!("${}", *counter));
} else {
numbered_condition.push(c);
}
}
sql.push_str(&numbered_condition);
binds.extend(values.clone());
}
Condition::And(left, right) => {
let (left_sql, left_binds) = left.to_sql(counter);
let (right_sql, right_binds) = right.to_sql(counter);
sql.push_str(&format!("{} AND {}", left_sql, right_sql));
binds.extend(left_binds);
binds.extend(right_binds);
}
Condition::Or(left, right) => {
let (left_sql, left_binds) = left.to_sql(counter);
let (right_sql, right_binds) = right.to_sql(counter);
sql.push_str(&format!("{} OR {}", left_sql, right_sql));
binds.extend(left_binds);
binds.extend(right_binds);
}
Condition::Group(inner) => {
let (inner_sql, inner_binds) = inner.to_sql(counter);
sql.push_str(&format!("({})", inner_sql));
binds.extend(inner_binds);
}
};
(sql, binds)
}
}

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,3 @@
mod model;
pub use model::*;

View File

@@ -0,0 +1,46 @@
use std::str::FromStr;
use serde::{Serialize, Deserialize};
#[derive(Debug, Serialize, Deserialize)]
pub enum AbilityType {
#[serde(rename = "strength")]
Strength,
#[serde(rename = "dexterity")]
Dexterity,
#[serde(rename = "constitution")]
Constitution,
#[serde(rename = "intelligence")]
Intelligence,
#[serde(rename = "wisdom")]
Wisdom,
#[serde(rename = "charisma")]
Charisma,
}
impl AbilityType {
pub fn to_string(&self) -> String {
match self {
AbilityType::Strength => "Strength".to_string(),
AbilityType::Dexterity => "Dexterity".to_string(),
AbilityType::Constitution => "Constitution".to_string(),
AbilityType::Intelligence => "Intelligence".to_string(),
AbilityType::Wisdom => "Wisdom".to_string(),
AbilityType::Charisma => "Charisma".to_string(),
}
}
}
impl FromStr for AbilityType {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Strength" => Ok(AbilityType::Strength),
"Dexterity" => Ok(AbilityType::Dexterity),
"Constitution" => Ok(AbilityType::Constitution),
"Intelligence" => Ok(AbilityType::Intelligence),
"Wisdom" => Ok(AbilityType::Wisdom),
"Charisma" => Ok(AbilityType::Charisma),
_ => Err(()),
}
}
}

View File

@@ -0,0 +1,82 @@
use std::str::FromStr;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub enum ConditionType {
#[serde(rename = "blinded")]
Blinded,
#[serde(rename = "charmed")]
Charmed,
#[serde(rename = "deafened")]
Deafened,
#[serde(rename = "exhaustion")]
Exhaustion,
#[serde(rename = "frightened")]
Frightened,
#[serde(rename = "grappled")]
Grappled,
#[serde(rename = "incapacitated")]
Incapacitated,
#[serde(rename = "invisible")]
Invisible,
#[serde(rename = "paralyzed")]
Paralyzed,
#[serde(rename = "petrified")]
Petrified,
#[serde(rename = "poisoned")]
Poisoned,
#[serde(rename = "prone")]
Prone,
#[serde(rename = "restrained")]
Restrained,
#[serde(rename = "stunned")]
Stunned,
#[serde(rename = "unconscious")]
Unconscious,
}
impl ConditionType {
pub fn to_string(&self) -> String {
match self {
ConditionType::Blinded => "Blinded".to_string(),
ConditionType::Charmed => "Charmed".to_string(),
ConditionType::Deafened => "Deafened".to_string(),
ConditionType::Exhaustion => "Exhaustion".to_string(),
ConditionType::Frightened => "Frightened".to_string(),
ConditionType::Grappled => "Grappled".to_string(),
ConditionType::Incapacitated => "Incapacitated".to_string(),
ConditionType::Invisible => "Invisible".to_string(),
ConditionType::Paralyzed => "Paralyzed".to_string(),
ConditionType::Petrified => "Petrified".to_string(),
ConditionType::Poisoned => "Poisoned".to_string(),
ConditionType::Prone => "Prone".to_string(),
ConditionType::Restrained => "Restrained".to_string(),
ConditionType::Stunned => "Stunned".to_string(),
ConditionType::Unconscious => "Unconscious".to_string(),
}
}
}
impl FromStr for ConditionType {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Blinded" => Ok(ConditionType::Blinded),
"Charmed" => Ok(ConditionType::Charmed),
"Deafened" => Ok(ConditionType::Deafened),
"Exhaustion" => Ok(ConditionType::Exhaustion),
"Frightened" => Ok(ConditionType::Frightened),
"Grappled" => Ok(ConditionType::Grappled),
"Incapacitated" => Ok(ConditionType::Incapacitated),
"Invisible" => Ok(ConditionType::Invisible),
"Paralyzed" => Ok(ConditionType::Paralyzed),
"Petrified" => Ok(ConditionType::Petrified),
"Poisoned" => Ok(ConditionType::Poisoned),
"Prone" => Ok(ConditionType::Prone),
"Restrained" => Ok(ConditionType::Restrained),
"Stunned" => Ok(ConditionType::Stunned),
"Unconscious" => Ok(ConditionType::Unconscious),
_ => Err(()),
}
}
}

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,13 @@
pub mod backgrounds;
pub mod bestiary;
pub mod classes;
pub mod conditions;
pub mod feats;
pub mod items;
pub mod options;
pub mod races;
pub mod spells;
pub fn load_data(data_dir_path: &str) {
spells::load_data(data_dir_path);
}

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,62 @@
mod model;
mod types;
use std::{
fs::{metadata, File, read_dir},
path::Path,
io::BufReader,
};
pub use model::*;
pub use types::*;
pub fn load_data(data_dir_path: &str) {
if Path::new(data_dir_path).exists() {
let meta = metadata(data_dir_path).unwrap();
if meta.is_dir() {
let spells_dir_path = format!("{}/spells", data_dir_path);
if Path::new(&spells_dir_path).exists() {
let meta = metadata(&spells_dir_path).unwrap();
if meta.is_dir() {
for entry in read_dir(&spells_dir_path).unwrap() {
let entry = entry.unwrap();
let path = entry.path();
if path.is_file() {
let file = File::open(path).unwrap();
let reader = BufReader::new(file);
let result: Result<Vec<Spell>, serde_json::Error> = serde_json::from_reader(reader);
match result {
Ok(spells) => {
for spell in spells {
// let mut filters = QueryFilters::default();
// filters.by_name = Some(spell.name.clone());
// match QuerySpell::get_all(&filters, 100, 1) {
// Ok(spells) => {
// if spells.len() > 0 {
// trace!("Spell '{}' already exists", spell.name);
// continue;
// }
// }
// Err(err) => {
// warn!("Error checking if spell '{}' exists: {}", spell.name, err);
// continue;
// }
// };
// let spell = InsertSpell::insert(spell.into()).unwrap();
// trace!("Inserted spell: {}", spell.name);
}
}
Err(err) => log::warn!("Error reading spells from file: {}", err),
};
}
}
}
}
}
} else {
log::warn!(
"Data path '{}' does not exist, no data imported",
data_dir_path
);
}
}

View File

@@ -0,0 +1,184 @@
use serde::{Deserialize, Serialize};
use crate::dnd::{classes::AbilityType, conditions::ConditionType};
use super::{
SchoolType, CastingTime, SpellAttackType, SpellDamageType, Range, Area, Components, Duration,
Source, Description, DurationType, Effect,
};
#[derive(Debug, Serialize, Deserialize)]
pub struct QuerySpell {
pub id: i32,
pub name: String,
pub school: String,
pub level: i32,
pub ritual: bool,
pub concentration: bool,
pub classes: Vec<String>,
pub damage_inflict: Vec<String>,
pub damage_resist: Vec<String>,
pub conditions: Vec<String>,
pub saving_throw: Vec<String>,
pub attack_type: Option<String>,
pub data: serde_json::Value,
}
#[derive(Debug)]
pub struct InsertSpell {
pub name: String,
pub school: String,
pub level: i32,
pub ritual: bool,
pub concentration: bool,
pub classes: Vec<String>,
pub damage_inflict: Vec<String>,
pub damage_resist: Vec<String>,
pub conditions: Vec<String>,
pub saving_throw: Vec<String>,
pub attack_type: Option<String>,
pub data: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Spell {
pub id: Option<i32>,
pub name: String,
pub school: SchoolType,
pub level: i32,
pub ritual: bool,
pub casting_time: CastingTime,
#[serde(skip_serializing_if = "Option::is_none")]
pub effect: Option<Effect>,
#[serde(skip_serializing_if = "Option::is_none")]
pub saving_throw: Option<Vec<AbilityType>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub attack_type: Option<SpellAttackType>,
#[serde(skip_serializing_if = "Option::is_none")]
pub damage_inflict: Option<Vec<SpellDamageType>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub damage_resist: Option<Vec<SpellDamageType>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub conditions: Option<Vec<ConditionType>>,
pub range: Range,
#[serde(skip_serializing_if = "Option::is_none")]
pub area: Option<Area>,
pub components: Components,
pub durations: Vec<Duration>,
pub classes: Vec<String>,
pub sources: Vec<Source>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tags: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<Description>,
}
impl From<QuerySpell> for Spell {
fn from(query: QuerySpell) -> Self {
return match serde_json::from_value(query.data) {
Ok(data) => data,
Err(err) => {
log::error!("Failed to parse spell: {}", err);
Self {
id: None,
name: "".to_string(),
school: SchoolType::Abjuration,
level: 0,
ritual: false,
casting_time: CastingTime {
value: 0,
casting_type: "".to_string(),
note: None,
},
effect: None,
saving_throw: None,
attack_type: None,
damage_inflict: None,
damage_resist: None,
conditions: None,
range: Range {
range_type: "".to_string(),
value: None,
unit: None,
},
area: None,
components: Components {
verbal: false,
somatic: false,
material: false,
materials_needed: None,
materials_cost: None,
materials_consumed: None,
},
durations: vec![],
classes: vec![],
sources: vec![],
tags: None,
description: None,
}
}
};
}
}
impl Into<InsertSpell> for Spell {
fn into(self) -> InsertSpell {
return InsertSpell {
name: self.name.to_string(),
school: self.school.to_string(),
level: self.level,
ritual: self.ritual,
concentration: self
.durations
.iter()
.any(|duration| match duration.duration_type {
DurationType::Concentration => true,
_ => false,
}),
classes: self
.classes
.iter()
.map(|class| class.to_string())
.collect::<Vec<String>>(),
damage_inflict: match &self.damage_inflict {
Some(damage_inflict) => damage_inflict
.iter()
.map(|damage_inflict| damage_inflict.to_string())
.collect(),
None => vec![],
},
damage_resist: match &self.damage_resist {
Some(damage_resist) => damage_resist
.iter()
.map(|damage_resist| damage_resist.to_string())
.collect(),
None => vec![],
},
conditions: match &self.conditions {
Some(conditions) => conditions
.iter()
.map(|condition| condition.to_string())
.collect(),
None => vec![],
},
saving_throw: match &self.saving_throw {
Some(saving_throw) => saving_throw
.iter()
.map(|saving_throw| saving_throw.to_string())
.collect(),
None => vec![],
},
attack_type: self
.attack_type
.as_ref()
.map(|attack_type| attack_type.to_string()),
data: match serde_json::to_value(&self) {
Ok(data) => data,
Err(err) => {
log::error!("Failed to serialize spell: {}", err);
serde_json::Value::Null
}
},
};
}
}

View File

@@ -0,0 +1,366 @@
use std::str::FromStr;
use serde::{Deserialize, Serialize, ser::SerializeMap};
#[derive(Debug, Serialize, Deserialize)]
pub enum SchoolType {
#[serde(rename = "abjuration")]
Abjuration,
#[serde(rename = "conjuration")]
Conjuration,
#[serde(rename = "divination")]
Divination,
#[serde(rename = "enchantment")]
Enchantment,
#[serde(rename = "evocation")]
Evocation,
#[serde(rename = "illusion")]
Illusion,
#[serde(rename = "necromancy")]
Necromancy,
#[serde(rename = "transmutation")]
Transmutation,
}
impl SchoolType {
pub fn to_string(&self) -> String {
match self {
SchoolType::Abjuration => "abjuration".to_string(),
SchoolType::Conjuration => "conjuration".to_string(),
SchoolType::Divination => "divination".to_string(),
SchoolType::Enchantment => "enchantment".to_string(),
SchoolType::Evocation => "evocation".to_string(),
SchoolType::Illusion => "illusion".to_string(),
SchoolType::Necromancy => "necromancy".to_string(),
SchoolType::Transmutation => "transmutation".to_string(),
}
}
}
impl FromStr for SchoolType {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"abjuration" => Ok(SchoolType::Abjuration),
"conjuration" => Ok(SchoolType::Conjuration),
"divination" => Ok(SchoolType::Divination),
"enchantment" => Ok(SchoolType::Enchantment),
"evocation" => Ok(SchoolType::Evocation),
"illusion" => Ok(SchoolType::Illusion),
"necromancy" => Ok(SchoolType::Necromancy),
"transmutation" => Ok(SchoolType::Transmutation),
_ => Err(()),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CastingTime {
pub value: i32,
#[serde(rename = "unit")]
pub casting_type: String,
pub note: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum SpellAttackType {
#[serde(rename = "melee")]
Melee,
#[serde(rename = "ranged")]
Ranged,
}
impl SpellAttackType {
pub fn to_string(&self) -> String {
match self {
SpellAttackType::Melee => "melee".to_string(),
SpellAttackType::Ranged => "ranged".to_string(),
}
}
}
impl FromStr for SpellAttackType {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"melee" => Ok(SpellAttackType::Melee),
"ranged" => Ok(SpellAttackType::Ranged),
_ => Err(()),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub enum SpellDamageType {
#[serde(rename = "acid")]
Acid,
#[serde(rename = "bludgeoning")]
Bludgeoning,
#[serde(rename = "cold")]
Cold,
#[serde(rename = "fire")]
Fire,
#[serde(rename = "force")]
Force,
#[serde(rename = "lightning")]
Lightning,
#[serde(rename = "necrotic")]
Necrotic,
#[serde(rename = "piercing")]
Piercing,
#[serde(rename = "poison")]
Poison,
#[serde(rename = "psychic")]
Psychic,
#[serde(rename = "radiant")]
Radiant,
#[serde(rename = "slashing")]
Slashing,
#[serde(rename = "thunder")]
Thunder,
}
impl SpellDamageType {
pub fn to_string(&self) -> String {
match self {
SpellDamageType::Acid => "acid".to_string(),
SpellDamageType::Bludgeoning => "bludgeoning".to_string(),
SpellDamageType::Cold => "cold".to_string(),
SpellDamageType::Fire => "fire".to_string(),
SpellDamageType::Force => "force".to_string(),
SpellDamageType::Lightning => "lightning".to_string(),
SpellDamageType::Necrotic => "necrotic".to_string(),
SpellDamageType::Piercing => "piercing".to_string(),
SpellDamageType::Poison => "poison".to_string(),
SpellDamageType::Psychic => "psychic".to_string(),
SpellDamageType::Radiant => "radiant".to_string(),
SpellDamageType::Slashing => "slashing".to_string(),
SpellDamageType::Thunder => "thunder".to_string(),
}
}
}
impl FromStr for SpellDamageType {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"acid" => Ok(SpellDamageType::Acid),
"bludgeoning" => Ok(SpellDamageType::Bludgeoning),
"cold" => Ok(SpellDamageType::Cold),
"fire" => Ok(SpellDamageType::Fire),
"force" => Ok(SpellDamageType::Force),
"lightning" => Ok(SpellDamageType::Lightning),
"necrotic" => Ok(SpellDamageType::Necrotic),
"piercing" => Ok(SpellDamageType::Piercing),
"poison" => Ok(SpellDamageType::Poison),
"psychic" => Ok(SpellDamageType::Psychic),
"radiant" => Ok(SpellDamageType::Radiant),
"slashing" => Ok(SpellDamageType::Slashing),
"thunder" => Ok(SpellDamageType::Thunder),
_ => Err(()),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Range {
#[serde(rename = "type")]
pub range_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub value: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub unit: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Area {
#[serde(rename = "type")]
pub area_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub value: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub unit: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Duration {
#[serde(rename = "type")]
pub duration_type: DurationType,
#[serde(skip_serializing_if = "Option::is_none")]
pub value: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub unit: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum DurationType {
#[serde(rename = "concentration")]
Concentration,
#[serde(rename = "instantaneous")]
Instantaneous,
#[serde(rename = "timed")]
Timed,
#[serde(rename = "dispelled")]
UntilDispelled,
#[serde(rename = "special")]
Special,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Source {
pub source: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub page: Option<i32>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Description {
pub entries: Vec<Entry>,
}
#[derive(Debug)]
pub struct Entry {
pub text: Option<String>,
pub list: Option<Vec<String>>,
pub table: Option<EntryTable>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct EntryTable {
pub headers: Vec<String>,
pub rows: Vec<Vec<String>>,
}
impl<'de> Deserialize<'de> for Entry {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
match value {
serde_json::Value::String(s) => Ok(Entry {
text: Some(s),
list: None,
table: None,
}),
serde_json::Value::Object(o) => {
let text = match o.get("text") {
Some(t) => match t.as_str() {
Some(s) => Some(s.to_string()),
None => return Err(serde::de::Error::custom("Invalid entry text")),
},
None => None,
};
let list = match o.get("list") {
Some(i) => match i.as_array() {
Some(a) => {
let mut list = Vec::new();
for item in a {
match item.as_str() {
Some(s) => list.push(s.to_string()),
None => return Err(serde::de::Error::custom("Invalid entry list item")),
}
}
Some(list)
}
None => return Err(serde::de::Error::custom("Invalid entry list items")),
},
None => None,
};
let table = match o.get("table") {
Some(t) => match t.as_object() {
Some(o) => {
let mut headers = Vec::new();
let mut rows = Vec::new();
match o.get("headers") {
Some(c) => match c.as_array() {
Some(a) => {
for item in a {
match item.as_str() {
Some(s) => headers.push(s.to_string()),
None => return Err(serde::de::Error::custom("Invalid entry table header")),
}
}
}
None => return Err(serde::de::Error::custom("Invalid entry table headers")),
},
None => return Err(serde::de::Error::custom("Missing entry table headers")),
};
match o.get("rows") {
Some(r) => match r.as_array() {
Some(a) => {
for row in a {
match row.as_array() {
Some(a) => {
let mut row = Vec::new();
for item in a {
match item.as_str() {
Some(s) => row.push(s.to_string()),
None => {
return Err(serde::de::Error::custom(
"Invalid entry table row item",
))
}
}
}
rows.push(row);
}
None => return Err(serde::de::Error::custom("Invalid entry table row")),
}
}
}
None => return Err(serde::de::Error::custom("Invalid entry table rows")),
},
None => return Err(serde::de::Error::custom("Missing entry table rows")),
};
Some(EntryTable { headers, rows })
}
None => return Err(serde::de::Error::custom("Invalid entry table")),
},
None => None,
};
Ok(Entry { text, list, table })
}
_ => Err(serde::de::Error::custom("Invalid entry")),
}
}
}
impl Serialize for Entry {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(1))?;
if let Some(text) = &self.text {
map.serialize_entry("text", text)?;
}
if let Some(list) = &self.list {
map.serialize_entry("list", list)?;
}
if let Some(table) = &self.table {
map.serialize_entry("table", table)?;
}
map.end()
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Components {
pub verbal: bool,
pub somatic: bool,
pub material: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub materials_needed: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub materials_cost: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub materials_consumed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Effect {
pub effect_type: Option<String>,
}

View File

@@ -0,0 +1,3 @@
mod model;
pub use model::*;

View File

@@ -0,0 +1,57 @@
use crate::error::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
const TABLE_NAME: &str = "events";
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
pub struct Event {
pub id: Uuid,
pub guild_id: i64,
pub author_id: i64,
pub title: String,
pub date_time: DateTime<Utc>,
pub description: Option<String>,
pub rsvp: Vec<i64>,
}
impl Event {
pub async fn insert(&self) -> Result<()> {
let pool = crate::data::pool();
sqlx::query(&format!(
"INSERT INTO {} (
id,
guild_id,
author_id,
title,
date_time,
description,
rsvp
) VALUES (
$1, $2, $3, $4, $5, $6, $7
)",
TABLE_NAME
))
.bind(self.id)
.bind(self.guild_id)
.bind(self.author_id)
.bind(&self.title)
.bind(self.date_time)
.bind(&self.description)
.bind(&self.rsvp)
.execute(pool)
.await?;
Ok(())
}
pub async fn get_by_id(id: i64) -> Result<Option<Self>> {
let pool = crate::data::pool();
let item = sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME))
.bind(id)
.fetch_optional(pool)
.await?;
Ok(item)
}
}

View File

@@ -0,0 +1,106 @@
use crate::data::Value;
use sqlx::{FromRow, Postgres};
#[allow(async_fn_in_trait)]
pub trait ExecutableQuery {
fn build(&self) -> (String, Vec<Value>);
async fn execute(&self) -> Result<sqlx::postgres::PgQueryResult, sqlx::Error> {
// Build the SQL query and its values
let (query_string, values) = self.build();
// Start constructing the query
let mut query = sqlx::query(&query_string);
// Bind each value to its respective placeholder
for value in values {
match value {
Value::Int(n) => query = query.bind(n),
Value::OptionalInt(n) => query = query.bind(n),
Value::BigInt(n) => query = query.bind(n),
Value::OptionalBigInt(n) => query = query.bind(n),
Value::Float(n) => query = query.bind(n),
Value::OptionalFloat(n) => query = query.bind(n),
Value::Double(n) => query = query.bind(n),
Value::OptionalDouble(n) => query = query.bind(n),
Value::Bool(n) => query = query.bind(n),
Value::OptionalBool(n) => query = query.bind(n),
Value::Text(n) => query = query.bind(n),
Value::OptionalText(n) => query = query.bind(n),
Value::DateTime(n) => query = query.bind(n),
Value::OptionalDateTime(n) => query = query.bind(n),
}
}
let pool = crate::data::pool();
query.execute(pool).await
}
async fn fetch_optional<
T: Send + Unpin + for<'r> FromRow<'r, <Postgres as sqlx::Database>::Row>,
>(
&self,
) -> Option<T> {
let (query_string, values) = self.build();
let mut query_as = sqlx::query_as(&query_string);
for value in values {
match value {
Value::Int(n) => query_as = query_as.bind(n),
Value::OptionalInt(n) => query_as = query_as.bind(n),
Value::BigInt(n) => query_as = query_as.bind(n),
Value::OptionalBigInt(n) => query_as = query_as.bind(n),
Value::Float(n) => query_as = query_as.bind(n),
Value::OptionalFloat(n) => query_as = query_as.bind(n),
Value::Double(n) => query_as = query_as.bind(n),
Value::OptionalDouble(n) => query_as = query_as.bind(n),
Value::Bool(n) => query_as = query_as.bind(n),
Value::OptionalBool(n) => query_as = query_as.bind(n),
Value::Text(n) => query_as = query_as.bind(n),
Value::OptionalText(n) => query_as = query_as.bind(n),
Value::DateTime(n) => query_as = query_as.bind(n),
Value::OptionalDateTime(n) => query_as = query_as.bind(n),
}
}
let pool = crate::data::pool();
query_as.fetch_optional(pool).await.unwrap_or_else(|err| {
log::error!(
"Unable to fetch optional on query '{}': {}",
query_string,
err
);
None
})
}
async fn fetch_all<T: Send + Unpin + for<'r> FromRow<'r, <Postgres as sqlx::Database>::Row>>(
&self,
) -> Vec<T> {
let (query_string, values) = self.build();
let mut query_as = sqlx::query_as(&query_string);
for value in values {
match value {
Value::Int(n) => query_as = query_as.bind(n),
Value::OptionalInt(n) => query_as = query_as.bind(n),
Value::BigInt(n) => query_as = query_as.bind(n),
Value::OptionalBigInt(n) => query_as = query_as.bind(n),
Value::Float(n) => query_as = query_as.bind(n),
Value::OptionalFloat(n) => query_as = query_as.bind(n),
Value::Double(n) => query_as = query_as.bind(n),
Value::OptionalDouble(n) => query_as = query_as.bind(n),
Value::Bool(n) => query_as = query_as.bind(n),
Value::OptionalBool(n) => query_as = query_as.bind(n),
Value::Text(n) => query_as = query_as.bind(n),
Value::OptionalText(n) => query_as = query_as.bind(n),
Value::DateTime(n) => query_as = query_as.bind(n),
Value::OptionalDateTime(n) => query_as = query_as.bind(n),
}
}
let pool = crate::data::pool();
query_as.fetch_all(pool).await.unwrap_or_else(|err| {
log::error!("Unable to fetch all on query '{}': {}", query_string, err);
vec![]
})
}
}

View File

@@ -0,0 +1,3 @@
mod model;
pub use model::*;

View File

@@ -0,0 +1,53 @@
use crate::{
data::{
Value,
condition::Condition,
executable_query::ExecutableQuery,
insert::InsertBuilder,
query::QueryBuilder,
update::UpdateBuilder,
},
error::Result,
};
use serde::{Deserialize, Serialize};
const TABLE_NAME: &str = "guilds";
#[derive(Serialize, Deserialize, sqlx::FromRow, Debug)]
pub struct GuildCache {
pub id: i64,
pub name: Option<String>,
pub owner_id: Option<i64>,
pub volume: i32,
}
impl GuildCache {
pub async fn insert(&self) -> Result<()> {
InsertBuilder::new(TABLE_NAME)
.column("id", Value::BigInt(self.id))
.column("name", Value::OptionalText(self.name.clone()))
.column("owner_id", Value::OptionalBigInt(self.owner_id))
.column("volume", Value::Int(self.volume))
.execute()
.await?;
Ok(())
}
pub async fn find_by_id(id: i64) -> Option<Self> {
QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::is_equal("id", Value::BigInt(id)))
.fetch_optional()
.await
}
pub async fn update(&self) -> Result<()> {
UpdateBuilder::new(TABLE_NAME)
.column("name", Value::OptionalText(self.name.clone()))
.column("owner_id", Value::OptionalBigInt(self.owner_id))
.column("volume", Value::Int(self.volume))
.where_condition(Condition::is_equal("id", Value::BigInt(self.id)))
.execute()
.await?;
Ok(())
}
}

View File

@@ -0,0 +1,60 @@
use crate::data::{Value, executable_query::ExecutableQuery};
pub struct InsertBuilder {
table: String,
columns: Vec<String>,
returning: Vec<String>,
values: Vec<Value>,
}
impl InsertBuilder {
pub fn new(table: &str) -> Self {
Self {
table: table.to_string(),
columns: Vec::new(),
returning: Vec::new(),
values: Vec::new(),
}
}
pub fn column(mut self, column: &str, value: Value) -> Self {
self.columns.push(column.to_string());
self.values.push(value);
self
}
pub fn returning(mut self, columns: &[&str]) -> Self {
self.returning = columns.iter().map(|s| s.to_string()).collect();
self
}
}
impl ExecutableQuery for InsertBuilder {
fn build(&self) -> (String, Vec<Value>) {
if self.columns.is_empty() || self.values.is_empty() {
panic!("Cannot build insert query without columns and values");
}
// Create the list of column names
let columns = self.columns.join(", ");
// Generate placeholders for values ($1, $2, etc.)
let placeholders = (1..=self.values.len())
.map(|i| format!("${}", i))
.collect::<Vec<_>>()
.join(", ");
// Create the basic INSERT statement
let mut query = format!(
"INSERT INTO {} ({}) VALUES ({})",
self.table, columns, placeholders
);
// Add RETURNING clause if specified
if !self.returning.is_empty() {
query.push_str(&format!(" RETURNING {}", self.returning.join(", ")));
}
(query, self.values.clone())
}
}

View File

@@ -0,0 +1,3 @@
mod model;
pub use model::*;

View File

@@ -0,0 +1,74 @@
use crate::error::Result;
use serde::{Deserialize, Serialize};
const TABLE_NAME: &str = "messages";
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageCache {
pub id: String,
pub guild_id: i64,
pub channel_id: i64,
pub author_id: i64,
pub created: i64,
pub model: String,
pub request: String,
pub response: String,
pub request_tags: Vec<String>,
pub response_tags: Vec<String>,
}
impl MessageCache {
pub async fn insert(&self) -> Result<()> {
let pool = crate::data::pool();
sqlx::query(&format!(
"INSERT INTO {} (
id,
guild_id,
channel_id,
author_id,
created,
model,
request,
response,
request_tags,
response_tags
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10
)",
TABLE_NAME
))
.bind(&self.id)
.bind(self.guild_id)
.bind(self.channel_id)
.bind(self.author_id)
.bind(self.created)
.bind(&self.model)
.bind(&self.request)
.bind(&self.response)
.bind(&self.request_tags)
.bind(&self.response_tags)
.execute(pool)
.await?;
Ok(())
}
pub async fn find(
guild_id: i64,
channel_id: i64,
author_id: i64,
limit: i64,
) -> Result<Vec<MessageCache>> {
let pool = crate::data::pool();
let messages = sqlx::query_as::<_, MessageCache>(&format!(
"SELECT * FROM {} WHERE guild_id = $1 AND channel_id = $2 AND author_id = $3 ORDER BY created ASC LIMIT $4",
TABLE_NAME
))
.bind(guild_id)
.bind(channel_id)
.bind(author_id)
.bind(limit)
.fetch_all(pool)
.await?;
Ok(messages)
}
}

View File

@@ -0,0 +1,137 @@
use crate::error::Result;
use chrono::{DateTime, Utc};
use redis::{Client as RedisClient, RedisResult, aio::MultiplexedConnection as RedisConnection};
use sqlx::{Pool, Postgres, postgres::PgPoolOptions};
use std::{fmt, fmt::Display, sync::OnceLock, time::Duration};
pub mod condition;
pub mod events;
mod executable_query;
pub mod guilds;
pub mod insert;
pub mod messages;
pub mod query;
pub mod update;
use crate::config::EnvironmentConfiguration;
pub use executable_query::ExecutableQuery;
static POOL: OnceLock<Pool<Postgres>> = OnceLock::new();
static REDIS: OnceLock<RedisClient> = OnceLock::new();
pub async fn initialize(config: &EnvironmentConfiguration) -> Result<()> {
log::info!("Initializing database...");
// Setup Postgres pool connection
let pool = PgPoolOptions::new()
.max_connections(5)
.acquire_timeout(Duration::from_secs(30))
.connect(&format!(
"postgres://{}:{}@{}:{}/{}",
config.postgres_user,
config.postgres_password,
config.postgres_host,
config.postgres_port,
config.postgres_database
))
.await?;
match POOL.set(pool) {
Ok(_) => {}
Err(_) => {
log::warn!("Database pool already initialized");
}
}
// Setup Redis connection
let redis = {
let host = std::env::var("VALKEY_HOST").unwrap_or("localhost".to_string());
let port = std::env::var("VALKEY_PORT").unwrap_or("6379".to_string());
let url = format!("redis://{}:{}", host, port);
RedisClient::open(url).expect("Failed to create valkey client")
};
match REDIS.set(redis) {
Ok(_) => {}
Err(_) => {
log::warn!("Valkey client already initialized");
}
}
// Run migrations
match run_migrations().await {
Ok(_) => log::debug!("Successfully ran migrations"),
Err(e) => log::error!("Failed to run migrations: {}", e),
}
log::info!("Database initialized");
Ok(())
}
pub fn pool() -> &'static Pool<Postgres> {
POOL.get().unwrap()
}
fn redis() -> &'static RedisClient {
REDIS.get().unwrap()
}
pub fn redis_connection() -> RedisResult<redis::Connection> {
let conn = redis().get_connection()?;
Ok(conn)
}
pub async fn redis_async_connection() -> RedisResult<RedisConnection> {
let conn = redis().get_multiplexed_async_connection().await?;
Ok(conn)
}
async fn run_migrations() -> Result<()> {
log::debug!("Running migrations");
let pool = pool();
sqlx::migrate!("../../migrations").run(pool).await?;
Ok(())
}
#[derive(Debug, Clone)]
pub enum Value {
Int(i32),
OptionalInt(Option<i32>),
BigInt(i64),
OptionalBigInt(Option<i64>),
Float(f32),
OptionalFloat(Option<f32>),
Double(f64),
OptionalDouble(Option<f64>),
Bool(bool),
OptionalBool(Option<bool>),
Text(String),
OptionalText(Option<String>),
DateTime(DateTime<Utc>),
OptionalDateTime(Option<DateTime<Utc>>),
}
impl Display for Value {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Value::Int(n) => write!(f, "{}", n),
Value::OptionalInt(Some(n)) => write!(f, "{}", n),
Value::OptionalInt(None) => write!(f, "NULL"),
Value::BigInt(n) => write!(f, "{}", n),
Value::OptionalBigInt(Some(n)) => write!(f, "{}", n),
Value::OptionalBigInt(None) => write!(f, "NULL"),
Value::Float(n) => write!(f, "{}", n),
Value::OptionalFloat(Some(n)) => write!(f, "{}", n),
Value::OptionalFloat(None) => write!(f, "NULL"),
Value::Double(n) => write!(f, "{}", n),
Value::OptionalDouble(Some(n)) => write!(f, "{}", n),
Value::OptionalDouble(None) => write!(f, "NULL"),
Value::Bool(n) => write!(f, "{}", n),
Value::OptionalBool(Some(n)) => write!(f, "{}", n),
Value::OptionalBool(None) => write!(f, "NULL"),
Value::Text(s) => write!(f, "'{}'", s.replace("'", "''")),
Value::OptionalText(Some(s)) => write!(f, "'{}'", s.replace("'", "''")),
Value::OptionalText(None) => write!(f, "NULL"),
Value::DateTime(n) => write!(f, "{}", n),
Value::OptionalDateTime(Some(n)) => write!(f, "{}", n),
Value::OptionalDateTime(None) => write!(f, "NULL"),
}
}
}

View File

@@ -0,0 +1,110 @@
use crate::data::{Value, condition::Condition, executable_query::ExecutableQuery};
pub struct QueryBuilder<'a> {
table: &'a str,
columns: Vec<&'a str>,
distinct_on: Option<Vec<String>>,
condition: Option<Condition>,
order_by: Vec<String>,
limit: Option<usize>,
offset: Option<usize>,
}
impl<'a> QueryBuilder<'a> {
pub fn new(table: &'a str) -> Self {
QueryBuilder {
table,
columns: Vec::new(),
distinct_on: None,
condition: None,
order_by: Vec::new(),
limit: None,
offset: None,
}
}
pub fn select(mut self, columns: &[&'a str]) -> Self {
self.columns.extend(columns);
self
}
pub fn distinct_on(mut self, columns: &[&str]) -> Self {
self.distinct_on = Some(columns.iter().map(|s| s.to_string()).collect());
self
}
pub fn where_condition(mut self, condition: Condition) -> Self {
self.condition = Some(condition);
self
}
pub fn order_by(mut self, column: &str, direction: Option<OrderDirection>) -> Self {
match direction {
Some(order) => self
.order_by
.push(format!("{} {}", column, order.to_string())),
None => self.order_by.push(column.to_string()),
}
self
}
pub fn limit(mut self, limit: usize) -> Self {
self.limit = Some(limit);
self
}
}
pub enum OrderDirection {
Asc,
Desc,
}
impl std::fmt::Display for OrderDirection {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let direction_str = match self {
OrderDirection::Asc => "ASC",
OrderDirection::Desc => "DESC",
};
write!(f, "{}", direction_str)
}
}
impl<'a> ExecutableQuery for QueryBuilder<'a> {
fn build(&self) -> (String, Vec<Value>) {
let columns = if self.columns.is_empty() {
"*".to_string()
} else {
self.columns.join(",")
};
let mut query = if let Some(distinct_columns) = &self.distinct_on {
let distinct_on_clause = distinct_columns.join(",");
format!("SELECT DISTINCT ON ({}) {}", distinct_on_clause, columns)
} else {
format!("SELECT {}", columns)
};
query.push_str(format!(" FROM {}", self.table).as_str());
let mut values: Vec<Value> = Vec::new();
if let Some(condition) = &self.condition {
let where_condition = condition.to_sql(&mut 0);
query.push_str(&format!(" WHERE {}", where_condition.0));
values = where_condition.1;
}
if !self.order_by.is_empty() {
query.push_str(&format!(" ORDER BY {}", self.order_by.join(" ORDER BY")));
}
if let Some(limit) = self.limit {
query.push_str(&format!(" LIMIT {}", limit));
}
if let Some(offset) = self.offset {
query.push_str(&format!(" OFFSET {}", offset));
}
(query, values)
}
}

View File

@@ -0,0 +1,60 @@
use crate::data::{Value, condition::Condition, executable_query::ExecutableQuery};
pub struct UpdateBuilder {
table: String,
columns: Vec<String>,
values: Vec<Value>,
condition: Option<Condition>,
}
impl UpdateBuilder {
pub fn new(table: &str) -> Self {
Self {
table: table.to_string(),
columns: Vec::new(),
values: Vec::new(),
condition: None,
}
}
pub fn column(mut self, column: &str, value: Value) -> Self {
self.columns.push(column.to_string());
self.values.push(value);
self
}
pub fn where_condition(mut self, condition: Condition) -> Self {
self.condition = Some(condition);
self
}
}
impl ExecutableQuery for UpdateBuilder {
fn build(&self) -> (String, Vec<Value>) {
if self.columns.is_empty() {
panic!("Cannot build update query without columns to set");
}
// Generate the SET clause
let set_clause = self
.columns
.iter()
.enumerate()
.map(|(i, col)| format!("{} = ${}", col, i + 1))
.collect::<Vec<_>>()
.join(", ");
let mut query = format!("UPDATE {} SET {}", self.table, set_clause);
let mut counter = self.values.len();
let mut values: Vec<Value> = self.values.clone();
// Build where clause
if let Some(condition) = &self.condition {
let where_condition = condition.to_sql(&mut counter);
query.push_str(&format!(" WHERE {}", where_condition.0));
values.extend(where_condition.1);
}
(query, values)
}
}

View File

@@ -0,0 +1,108 @@
use serde::{Deserialize, Serialize};
use std::fmt;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Deserialize, Serialize)]
pub struct Error {
pub status: u16,
pub details: String,
}
impl Error {
pub fn new(status: u16, details: String) -> Self {
Self { status, details }
}
pub fn not_found(details: String) -> Self {
Self::new(404, details)
}
pub fn internal_server_error(details: String) -> Self {
Self::new(500, details)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(self.details.as_str())
}
}
impl std::error::Error for Error {
fn description(&self) -> &str {
&self.details
}
}
impl From<std::io::Error> for Error {
fn from(error: std::io::Error) -> Self {
Self::new(500, format!("IO error: {}", error))
}
}
impl From<std::string::FromUtf8Error> for Error {
fn from(error: std::string::FromUtf8Error) -> Self {
Self::new(500, format!("UTF-8 error: {}", error))
}
}
impl From<std::env::VarError> for Error {
fn from(error: std::env::VarError) -> Self {
Self::new(500, format!("Environment variable error: {}", error))
}
}
impl From<sqlx::Error> for Error {
fn from(error: sqlx::Error) -> Self {
match error {
sqlx::Error::RowNotFound => Error::new(404, "Not found".to_string()),
sqlx::Error::ColumnIndexOutOfBounds { .. } => Error::new(422, error.to_string()),
sqlx::Error::ColumnNotFound { .. } => Error::new(422, error.to_string()),
sqlx::Error::ColumnDecode { .. } => Error::new(422, error.to_string()),
sqlx::Error::Decode(_) => Error::new(422, error.to_string()),
sqlx::Error::PoolTimedOut => Error::new(503, error.to_string()),
sqlx::Error::PoolClosed => Error::new(503, error.to_string()),
sqlx::Error::Database(err) => {
if let Some(code) = err.code() {
match code.trim() {
"23505" => return Error::new(409, err.to_string()),
_ => (),
}
}
Error::new(500, err.to_string())
}
_ => Error::new(500, error.to_string()),
}
}
}
impl From<sqlx::migrate::MigrateError> for Error {
fn from(error: sqlx::migrate::MigrateError) -> Self {
Error::new(500, error.to_string())
}
}
impl From<redis::RedisError> for Error {
fn from(error: redis::RedisError) -> Self {
Self::new(500, format!("Redis error: {}", error))
}
}
impl From<reqwest::Error> for Error {
fn from(error: reqwest::Error) -> Self {
Self::new(500, format!("HTTP client error: {}", error))
}
}
impl From<serde_json::Error> for Error {
fn from(error: serde_json::Error) -> Self {
Self::new(500, format!("JSON error: {}", error))
}
}
impl From<uuid::Error> for Error {
fn from(error: uuid::Error) -> Self {
Self::new(500, format!("UUID error: {}", error))
}
}

View File

@@ -0,0 +1,4 @@
pub mod config;
pub mod data;
pub mod error;
pub mod utils;

View File

@@ -0,0 +1,2 @@
pub mod text_utils;
pub use text_utils::*;

View File

@@ -0,0 +1,62 @@
pub fn a_or_an(word: &str) -> &'static str {
let vowels = ['a', 'e', 'i', 'o', 'u'];
let lowercase_word = word.to_lowercase();
// Special cases where the article should be "a"
let special_cases_a = vec!["one"];
if special_cases_a.contains(&lowercase_word.as_str()) {
return "a";
}
// Special cases where the article should be "an"
let special_cases_an = vec!["hour"];
if special_cases_an.contains(&lowercase_word.as_str()) {
return "an";
}
let first_char = lowercase_word.chars().next();
match first_char {
// If the first character is a vowel, return "an"
Some(c) if vowels.contains(&c) => "an",
// Otherwise, return "a"
_ => "a",
}
}
pub fn number_to_words(n: i32) -> String {
let ones = [
"", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
];
let teens = [
"ten",
"eleven",
"twelve",
"thirteen",
"fourteen",
"fifteen",
"sixteen",
"seventeen",
"eighteen",
"nineteen",
];
let tens = [
"", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety",
];
if n < 10 {
ones[n as usize].to_string()
} else if n < 20 {
teens[(n - 10) as usize].to_string()
} else if n < 100 {
let ten_part = tens[(n / 10) as usize];
let one_part = ones[(n % 10) as usize];
if n % 10 == 0 {
ten_part.to_string() // e.g., 20 → "twenty"
} else {
format!("{}-{}", ten_part, one_part) // e.g., 42 → "forty-two"
}
} else {
"Number out of range".to_string() // Handle numbers >= 100 (or extend the logic)
}
}

20
crates/siren/Cargo.toml Normal file
View File

@@ -0,0 +1,20 @@
[package]
name = "siren"
edition.workspace = true
version.workspace = true
rust-version.workspace = true
authors.workspace = true
[dependencies]
siren-core = { workspace = true }
siren-bot = { workspace = true }
siren-api = { workspace = true }
dotenv = { workspace = true }
log = { workspace = true }
env_logger = { workspace = true }
serenity = { workspace = true }
songbird = { workspace = true }
reqwest = { workspace = true }
# Add the `signal` feature on top of the workspace base for graceful shutdown
tokio = { workspace = true, features = ["signal"] }
dashmap = { workspace = true }

117
crates/siren/src/main.rs Normal file
View File

@@ -0,0 +1,117 @@
use dashmap::DashMap;
use dotenv::from_filename;
use reqwest::Client as HttpClient;
use serenity::{
all::{ShardManager, UserId},
http::Http,
prelude::*,
};
use siren_api::{App, AppState};
use siren_bot::{HttpKey, handler::BotHandler};
use siren_core::{
config::EnvironmentConfiguration,
error::{Error, Result},
};
use songbird::{SerenityInit, Songbird};
use std::{collections::HashMap, sync::Arc};
#[tokio::main]
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
initialize_environment()?;
let config = EnvironmentConfiguration::load()?;
siren_core::data::initialize(&config).await?;
let handler = BotHandler::new(config.force_register);
let songbird = Songbird::serenity();
let intents: GatewayIntents = GatewayIntents::all();
let mut client = Client::builder(&config.discord_token, intents)
.event_handler(handler)
.register_songbird_with(Arc::clone(&songbird))
.type_map_insert::<HttpKey>(HttpClient::new())
.await
.expect("Error creating client");
let (bot_owner, bot_id) = get_bot_info(&client.http).await?;
let app_state = AppState {
client: HttpClient::new(),
client_id: bot_id.to_string(),
client_secret: config.discord_secret,
base_url: config.api_base_url,
discord_authorize_cache: Arc::new(Mutex::new(HashMap::new())),
http: Arc::clone(&client.http),
cache: Arc::clone(&client.cache),
map_rooms: Arc::new(DashMap::new()),
};
log::debug!(
"Starting Siren with ID: {bot_id} (Contact: {:?})",
bot_owner
);
let shard_manager = Arc::clone(&client.shard_manager);
tokio::spawn(async move {
signal_shutdown(shard_manager).await;
});
tokio::spawn(App::new(app_state).serve());
if let Err(why) = client.start_autosharded().await {
log::error!("Client error: {why:?}");
}
Ok(())
}
fn initialize_environment() -> std::io::Result<()> {
for entry in std::fs::read_dir(".")? {
let entry = entry?;
let path = entry.path();
if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
if file_name.starts_with(".env") && !file_name.ends_with(".example") && path.is_file() {
if let Err(err) = from_filename(&file_name) {
eprintln!("Failed to load {}: {}", file_name, err);
} else {
println!("Loaded: {}", file_name);
}
}
}
}
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info"));
Ok(())
}
async fn get_bot_info(http: &Http) -> Result<(Option<UserId>, UserId)> {
match http.get_current_application_info().await {
Ok(info) => {
let bot_owner;
if let Some(team) = info.team {
bot_owner = Some(team.owner_user_id);
} else if let Some(owner) = info.owner {
bot_owner = Some(owner.id);
} else {
bot_owner = None;
}
match http.get_current_user().await {
Ok(bot) => Ok((bot_owner, bot.id)),
Err(why) => Err(Error::new(
500,
format!("Could not access the bot id: {why:?}"),
)),
}
}
Err(why) => Err(Error::new(
500,
format!("Could not access application info: {why:?}"),
)),
}
}
async fn signal_shutdown(shard_manager: Arc<ShardManager>) {
tokio::signal::ctrl_c()
.await
.expect("Failed to listen for shutdown signal");
shard_manager.shutdown_all().await;
log::info!("Bot shutdown gracefully.");
}