diff --git a/.env b/.env index 62346c4..7bdbf80 100644 --- a/.env +++ b/.env @@ -7,7 +7,7 @@ JWT_SECRET=CHANGEME # Change this to a secure secret DATABASE_USER=siren DATABASE_PASSWORD=CHANGEME # Change this to a secure password -DATABASE_NAME=siren +DATABASE_NAME=siren_db DATABASE_HOST=localhost DATABASE_PORT=5432 diff --git a/Cargo.toml b/Cargo.toml index 15d4041..6bab4a1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ rand = "0.8.5" rand_chacha = "0.3.1" tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] } regex = "1.11.0" -axum = "0.7.7" +axum = { version = "0.7.7", features = ["json"] } axum-extra = { version = "0.9.6", features = ["typed-header"] } lazy_static = "1.5.0" jsonwebtoken = "9.3.0" diff --git a/README.md b/README.md index fa06b60..166919b 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ Siren utilizes Discord slash commands. To view the commands, run `/help` in a se | --- | --- | | `/coinflip` | Flip a coin | | `/roll ` | Roll a dice | +| `/requestroll ` | Request a dice roll from a user | **Utility Commands** | Command | Description | diff --git a/bruno/audio/Pause Track.bru b/bruno/audio/Pause Track.bru new file mode 100644 index 0000000..ea1f83e --- /dev/null +++ b/bruno/audio/Pause Track.bru @@ -0,0 +1,17 @@ +meta { + name: Pause Track + type: http + seq: 2 +} + +post { + url: {{baseUrl}}/audio/pause + body: json + auth: inherit +} + +body:json { + { + "guild_id": 1061092965579235398 + } +} diff --git a/bruno/audio/Play Track.bru b/bruno/audio/Play Track.bru new file mode 100644 index 0000000..67800e7 --- /dev/null +++ b/bruno/audio/Play Track.bru @@ -0,0 +1,18 @@ +meta { + name: Play Track + type: http + seq: 1 +} + +post { + url: {{baseUrl}}/audio/play + body: json + auth: inherit +} + +body:json { + { + "url": "https://www.youtube.com/watch?v=V-QDxuknK-Q", + "guild_id": 1061092965579235398 + } +} diff --git a/bruno/audio/Resume Track.bru b/bruno/audio/Resume Track.bru new file mode 100644 index 0000000..5c633a9 --- /dev/null +++ b/bruno/audio/Resume Track.bru @@ -0,0 +1,17 @@ +meta { + name: Resume Track + type: http + seq: 3 +} + +post { + url: {{baseUrl}}/audio/resume + body: json + auth: inherit +} + +body:json { + { + "guild_id": 1061092965579235398 + } +} diff --git a/bruno/bruno.json b/bruno/bruno.json new file mode 100644 index 0000000..3a72085 --- /dev/null +++ b/bruno/bruno.json @@ -0,0 +1,9 @@ +{ + "version": "1", + "name": "Siren", + "type": "collection", + "ignore": [ + "node_modules", + ".git" + ] +} \ No newline at end of file diff --git a/bruno/collection.bru b/bruno/collection.bru new file mode 100644 index 0000000..2a7569a --- /dev/null +++ b/bruno/collection.bru @@ -0,0 +1,11 @@ +auth { + mode: bearer +} + +auth:bearer { + token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjI1MDg0MjI2MTIyMTI3NzY5NywibmFtZSI6ImJzaGVycmlmZiIsImlhdCI6MTczNDcwNDI3NSwiZXhwIjoxNzM0NzkwNjc1LCJqdGkiOiJMSnc1Vnk3azZjc1BiYlJRWGlNcVFFVUZlQ29JS2JqcCJ9.sdgb93DmX9_augMdktYr58m5eTIJPuY13d87pckZOns +} + +vars:pre-request { + baseUrl: http://localhost:3000/api +} diff --git a/bruno/oauth/Authorize.bru b/bruno/oauth/Authorize.bru new file mode 100644 index 0000000..433ba3a --- /dev/null +++ b/bruno/oauth/Authorize.bru @@ -0,0 +1,11 @@ +meta { + name: Authorize + type: http + seq: 1 +} + +get { + url: {{baseUrl}}/oauth/authorize + body: none + auth: inherit +} diff --git a/migrations/000_initial.sql b/migrations/000_initial.sql index 3559238..59d5246 100644 --- a/migrations/000_initial.sql +++ b/migrations/000_initial.sql @@ -16,8 +16,13 @@ CREATE TABLE IF NOT EXISTS messages ( request_tags TEXT[] NOT NULL, response_tags TEXT[] NOT NULL ); -CREATE TABLE IF NOT EXISTS dice_rolls ( - id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid() +CREATE TABLE IF NOT EXISTS dice_thresholds ( + id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(), + owner_id BIGINT NOT NULL, + dice TEXT NOT NULL, + user_id BIGINT, + value INT, + operator TEXT ); CREATE TABLE IF NOT EXISTS events ( id UUID PRIMARY KEY NOT NULL, diff --git a/src/api/audio/mod.rs b/src/api/audio/mod.rs index b9031e6..ae71db2 100644 --- a/src/api/audio/mod.rs +++ b/src/api/audio/mod.rs @@ -2,23 +2,31 @@ use std::sync::Arc; use axum::extract::State; use axum::middleware::from_extractor; use axum::{Extension, Json, Router}; +use axum::response::IntoResponse; use axum::routing::post; +use reqwest::StatusCode; use serde::Deserialize; use crate::api::auth::{AuthorizationMiddleware, Session}; use crate::AppState; use crate::bot::commands::audio::join_voice_channel; +use crate::bot::commands::audio::pause::pause_track; use crate::bot::commands::audio::play::enqueue_track; +use crate::bot::commands::audio::resume::resume_track; use crate::bot::handler::get_songbird; -use crate::error::SirenResult; +use crate::error::{Error, SirenResult}; pub fn get_routes() -> Router> { Router::new() .route("/play", post(play_audio)) .route_layer(from_extractor::()) + .route("/pause", post(pause_audio)) + .route_layer(from_extractor::()) + .route("/resume", post(resume_audio)) + .route_layer(from_extractor::()) } #[derive(Deserialize)] -struct TrackRequest { +struct PlayTrackRequest { url: String, guild_id: u64, } @@ -26,13 +34,66 @@ struct TrackRequest { async fn play_audio( Extension(session): Extension, State(state): State>, - Json(payload): Json, + Json(payload): Json, ) -> SirenResult<()> { log::debug!("Playing audio in guild: {}", payload.guild_id); + + // Check if the user exists in the cache + let user_id = match state.cache.user(session.user_id) { + Some(user) => user.id, + None => return Err(Error::not_found("User not found".to_string())), + }; + + // Validate if the guild exists in the cache + let guild_id = match state.cache.guild(payload.guild_id) { + Some(guild) => guild.id, + None => return Err(Error::not_found("Guild not found".to_string())), + }; + + // Play the track let manager = get_songbird(); - let user_id = state.cache.user(session.user_id).unwrap().id; - let guild_id = state.cache.guild(payload.guild_id).unwrap().id; let _channel_id = join_voice_channel(&state.cache, &manager, &guild_id, &user_id).await?; enqueue_track(manager, guild_id.to_owned(), &payload.url).await?; Ok(()) } + +#[derive(Deserialize)] +struct GuildTrackRequest { + guild_id: u64, +} + +async fn pause_audio( + Extension(_): Extension, + State(state): State>, + Json(payload): Json, +) -> SirenResult<()> { + log::debug!("Pausing audio in guild: {}", payload.guild_id); + + // Validate if the guild exists in the cache + let guild_id = match state.cache.guild(payload.guild_id) { + Some(guild) => guild.id, + None => return Err(Error::not_found("Guild not found".to_string())), + }; + + // Pause the track + let manager = get_songbird(); + pause_track(manager, &guild_id).await +} + +async fn resume_audio( + Extension(_): Extension, + State(state): State>, + Json(payload): Json, +) -> SirenResult<()> { + log::debug!("Pausing audio in guild: {}", payload.guild_id); + + // Validate if the guild exists in the cache + let guild_id = match state.cache.guild(payload.guild_id) { + Some(guild) => guild.id, + None => return Err(Error::not_found("Guild not found".to_string())), + }; + + // Pause the track + let manager = get_songbird(); + resume_track(manager, &guild_id).await +} diff --git a/src/api/auth/oauth.rs b/src/api/auth/oauth.rs index 84f4d7b..b93a435 100644 --- a/src/api/auth/oauth.rs +++ b/src/api/auth/oauth.rs @@ -7,6 +7,7 @@ use axum::response::Redirect; use axum::routing::get; use serde::{Deserialize, Serialize}; use crate::api::auth::bearer_token::BearerTokenClaims; +use crate::api::auth::csprng; use crate::AppState; use crate::api::auth::session::Session; use crate::error::SirenResult; @@ -42,19 +43,27 @@ struct DiscordUser { } async fn discord_authorize_redirect(State(state): State>) -> Redirect { + // Store the state + let oauth_state = csprng(16); + state.oauth_states.lock().await.insert(oauth_state.clone()); + // Construct the Discord OAuth URL let discord_auth_url = format!( - "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify", - state.client_id, state.redirect_uri + "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify&state={}", + state.client_id, state.redirect_uri, oauth_state ); Redirect::temporary(&discord_auth_url) } async fn discord_authorize(State(state): State>) -> SirenResult { + // Store the state + let oauth_state = csprng(16); + state.oauth_states.lock().await.insert(oauth_state.clone()); + // Construct the Discord OAuth URL let discord_auth_url = format!( - "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify", - state.client_id, state.redirect_uri + "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify&state={}", + state.client_id, state.redirect_uri, oauth_state ); Ok(discord_auth_url) } @@ -70,6 +79,18 @@ async fn oauth_callback( State(state): State>, Query(query): Query, ) -> SirenResult> { + // Validate the state + let mut oauth_states = state.oauth_states.lock().await; + match query.state { + Some(oauth_state) => { + match oauth_states.get(&oauth_state) { + Some(_) => oauth_states.remove(&oauth_state), + None => return Err(StatusCode::UNAUTHORIZED.into()), + } + } + None => return Err(StatusCode::UNAUTHORIZED)?, + }; + // Exchange code for an access token let token_response = state .client diff --git a/src/api/auth/session.rs b/src/api/auth/session.rs index 0bedbdf..9edb92c 100644 --- a/src/api/auth/session.rs +++ b/src/api/auth/session.rs @@ -42,11 +42,12 @@ impl Session { pub async fn insert(&self) -> SirenResult<()> { let mut redis = data::redis_async_connection().await?; let session_id = self.session_id.clone(); + let session_ttl = get_session_ttl(); redis .set_ex( session_id, serde_json::to_string(self)?, - self.expires_at.timestamp() as u64, + session_ttl as u64, ) .await?; Ok(()) diff --git a/src/bot/chat/mod.rs b/src/bot/chat/mod.rs index dfde377..295f3de 100644 --- a/src/bot/chat/mod.rs +++ b/src/bot/chat/mod.rs @@ -8,7 +8,7 @@ pub async fn process_message(ctx: &Context, command: &CommandInteraction, privat create_message_response(&ctx, &command, "Processing...".to_string(), private).await; } -pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Option { +pub async fn user_dm(ctx: &Context, user_id: &UserId, content: String) -> Option { let data = CreateMessage::new().content(content.to_owned()); match user_id.dm(ctx, data).await { Ok(message) => Some(message), @@ -19,17 +19,6 @@ pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Opt } } -pub async fn user_dm(ctx: &Context, user: &User, content: String) -> Option { - let data = CreateMessage::new().content(content.to_owned()); - match user.direct_message(ctx, data).await { - Ok(message) => Some(message), - Err(err) => { - log::error!("Failed to create direct message for {content}\n{err}"); - None - } - } -} - pub async fn create_message_response( ctx: &Context, command: &CommandInteraction, diff --git a/src/bot/commands/audio/mod.rs b/src/bot/commands/audio/mod.rs index eca77ab..33f4c1c 100644 --- a/src/bot/commands/audio/mod.rs +++ b/src/bot/commands/audio/mod.rs @@ -73,8 +73,8 @@ fn find_voice_channel( { Some(channel) => Ok(channel), None => { - return Err(SirenError::new( - 401, + Err(SirenError::new( + 400, "User is not in a voice channel".to_string(), )) } diff --git a/src/bot/commands/audio/pause.rs b/src/bot/commands/audio/pause.rs index 866d270..c2054b1 100644 --- a/src/bot/commands/audio/pause.rs +++ b/src/bot/commands/audio/pause.rs @@ -1,10 +1,12 @@ +use std::sync::Arc; use serenity::{ - all::{CommandInteraction, CreateCommand}, + all::{CommandInteraction, CreateCommand, GuildId}, prelude::*, }; - +use songbird::Songbird; use crate::bot::chat::{edit_response, process_message}; use crate::bot::handler::get_songbird; +use crate::error::{Error, SirenResult}; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response @@ -28,23 +30,25 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { }; // Pause the track + match pause_track(manager, guild_id).await { + Ok(_) => { + log::debug!("<{guild_id}> Paused the track"); + edit_response(&ctx, &command, "Pausing the track".to_string()).await; + } + Err(err) => edit_response(&ctx, &command, format!("Failed to pause: {}", err)).await + } +} + +pub async fn pause_track(manager: &Arc, guild_id: &GuildId) -> SirenResult<()> { if let Some(handler_lock) = manager.get(guild_id.to_owned()) { let handler = handler_lock.lock().await; match handler.queue().current() { - Some(track) => match track.pause() { - Ok(_) => { - log::debug!("<{guild_id}> Paused the track"); - edit_response(&ctx, &command, "Pausing the track".to_string()).await; - } - Err(err) => { - edit_response(&ctx, &command, format!("Failed to pause: {}", err)).await; - } - }, - None => { - edit_response(ctx, command, "No track currently being played".to_string()).await; - } + Some(track) => track.pause()?, + None => return Err(Error { status: 404, details: "No track is currently playing".to_string() }) } - } + }; + + Ok(()) } pub fn register() -> CreateCommand { diff --git a/src/bot/commands/audio/resume.rs b/src/bot/commands/audio/resume.rs index b02472c..032b7e9 100644 --- a/src/bot/commands/audio/resume.rs +++ b/src/bot/commands/audio/resume.rs @@ -1,10 +1,14 @@ +use std::sync::Arc; use serenity::{ all::{CommandInteraction, CreateCommand}, prelude::*, }; - +use serenity::all::GuildId; +use songbird::Songbird; use crate::bot::chat::{edit_response, process_message}; +use crate::bot::commands::audio::pause::pause_track; use crate::bot::handler::get_songbird; +use crate::error::{Error, SirenResult}; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response @@ -28,24 +32,25 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { }; // Resume the track + match resume_track(manager, guild_id).await { + Ok(_) => { + log::debug!("<{guild_id}> Resumed the track"); + edit_response(&ctx, &command, "resuming the track".to_string()).await; + } + Err(err) => edit_response(&ctx, &command, format!("Failed to resume: {}", err)).await + } +} + +pub async fn resume_track(manager: &Arc, guild_id: &GuildId) -> SirenResult<()> { if let Some(handler_lock) = manager.get(guild_id.to_owned()) { let handler = handler_lock.lock().await; match handler.queue().current() { - Some(track) => match track.play() { - Ok(_) => { - log::debug!("<{guild_id}> Resumed the track"); - edit_response(&ctx, &command, "Resuming the track".to_string()).await; - } - Err(err) => { - edit_response(&ctx, &command, format!("Failed to resume: {}", err)).await; - } - }, - None => { - edit_response(&ctx, &command, "No track is currently playing".to_string()).await; - return; - } + Some(track) => track.play()?, + None => return Err(Error { status: 404, details: "No track is currently playing".to_string() }) } - } + }; + + Ok(()) } pub fn register() -> CreateCommand { diff --git a/src/bot/commands/fun/mod.rs b/src/bot/commands/fun/mod.rs index 5f492dc..21cdb3e 100644 --- a/src/bot/commands/fun/mod.rs +++ b/src/bot/commands/fun/mod.rs @@ -1 +1,2 @@ pub mod roll; +pub mod request_roll; diff --git a/src/bot/commands/fun/request_roll.rs b/src/bot/commands/fun/request_roll.rs new file mode 100644 index 0000000..1735d3f --- /dev/null +++ b/src/bot/commands/fun/request_roll.rs @@ -0,0 +1,95 @@ +use serenity::all::{ButtonStyle, CommandInteraction, CommandOptionType, Context, CreateActionRow, CreateButton, CreateCommand, CreateCommandOption, CreateMessage, Mentionable, UserId}; +use serenity::builder::CreateEmbed; +use crate::bot::chat::{create_message_response, edit_response}; +use crate::bot::commands::fun::roll::parse_dice; + +pub async fn run(ctx: &Context, command: &CommandInteraction) { + // Check if the roll result is hidden + let hidden = command + .data + .options + .iter() + .find(|opt| opt.name == "hidden") + .and_then(|o| o.value.as_bool()) + .unwrap_or(false); + + // Retrieve the user + let user = command + .data + .options + .iter() + .find(|opt| opt.name == "user") + .and_then(|o| o.value.as_mentionable()).unwrap(); + + let user_id = UserId::new(user.get()); + + create_message_response(ctx, &command, format!("Sending request to {}", user_id.mention()), true).await; + + let dice_string = command + .data + .options + .get(0) + .and_then(|o| o.value.as_str()) + .map(|s| s.split_whitespace().collect::()).unwrap(); + + let dice_result = parse_dice(dice_string.as_str()); + match dice_result { + Ok(dice) => { + // let roll_button = CreateButton::new(format!("request_dice_roll|{}|{}|{}|{}|{}", dice.0, dice.1, dice.2, command.user.id.get(), hidden)) + // .label("Roll") + // .style(ButtonStyle::Primary); + // let action_row = CreateActionRow::Buttons(vec![roll_button]); + // + // let embed = CreateEmbed::new() + // .title("🎲 Dice roll request! 🎲".to_string()) + // .color(0x00FF00) + // .description(format!("{} Requested a dice roll of {}", command.user.mention(), dice_string)); + // + // let message = CreateMessage::new() + // .embed(embed) + // .components(vec![action_row]); + + let roll_button = CreateButton::new(format!("request_dice_roll|{}|{}|{}|{}|{}", dice.0, dice.1, dice.2, command.user.id.get(), hidden)) + .label(format!("🎲 Roll {} 🎲", dice_string)) // The label you want on the button + .style(ButtonStyle::Primary); + + let action_row = CreateActionRow::Buttons(vec![roll_button]); + + let message = CreateMessage::new() + .content(format!("-# Roll requested from {}", command.user.mention())) + .components(vec![action_row]); + + if let Err(why) = user_id.dm(ctx, message).await { + log::error!("failed to send request due to {}", why); + edit_response(ctx, command, "Unable to send dice request".to_string()).await; + }; + } + Err(why) => { + edit_response(ctx, &command, why.to_string()).await; + } + } +} + +pub fn register() -> CreateCommand { + CreateCommand::new("requestroll") + .description("Request a dice roll from a user") + .add_option( + CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll") + .required(true), + ) + .add_option( + CreateCommandOption::new( + CommandOptionType::Mentionable, + "user", + "User to receive the dice roll request" + ) + .required(true), + ) + .add_option( + CreateCommandOption::new( + CommandOptionType::Boolean, + "hidden", + "Hide the dice roll from the user (Default: False") + .required(false), + ) +} \ No newline at end of file diff --git a/src/bot/commands/fun/roll.rs b/src/bot/commands/fun/roll.rs index 37d2c86..bfbc671 100644 --- a/src/bot/commands/fun/roll.rs +++ b/src/bot/commands/fun/roll.rs @@ -9,19 +9,6 @@ use serenity::all::{ use crate::bot::chat::{create_message_response, edit_response}; use crate::utils::{a_or_an, number_to_words}; -lazy_static::lazy_static! { - static ref SAVED_ROLLS: Mutex>> = Mutex::new(HashMap::new()); -} - -pub fn temp() { - // // Add to the HashMap after processing the modal - // let mut saved_rolls = SAVED_ROLLS.lock().unwrap(); - // saved_rolls - // .entry(user_id) - // .or_default() - // .push((dice_roll, description.clone())); -} - pub async fn run(ctx: &Context, command: &CommandInteraction) { // Check if the roll result is private let private = command @@ -32,7 +19,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { .and_then(|o| o.value.as_bool()) .unwrap_or(true); - // Retrieve the DM's name or ID from the options (optional) + // Retrieve the user if present let user = command .data .options @@ -40,7 +27,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { .find(|opt| opt.name == "user") .and_then(|o| o.value.as_mentionable()); - create_message_response(&ctx, &command, "Rolling...".to_string(), private).await; + create_message_response(ctx, &command, "Rolling...".to_string(), private).await; let dice_string = match command .data @@ -60,61 +47,14 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { let dice = parse_dice(dice_string.as_str()); match dice { Ok((count, sides, modifier)) => { - let mut rolls = Vec::new(); - let mut total = 0; - for _ in 0..count { - let roll = rand::thread_rng().gen_range(1..=sides); - total += roll; - rolls.push(roll); - } - - let response = ( - total + (modifier as u32), - format!( - "(Rolled {}d{}{})", - count, - sides, - if modifier > 0 { - format!("+{}", modifier) - } else if modifier < 0 { - format!("-{}", modifier) - } else { - "".to_string() - } - ), - ); + let total = roll_dice(count, sides, modifier); + let response = format!("(Rolled {})", format_roll(count, sides, modifier)); match user { Some(id) => { let user_id = UserId::new(id.get()); - - // Create the dice roll embed - let a = a_or_an(&number_to_words(response.0 as i32)); - let embed = CreateEmbed::new() - .title("🎲 Received a dice roll! 🎲".to_string()) - .color(0x00FF00) - .description(format!( - "{} rolled {} **{}**\n-# *{}*", - &command.user.mention(), - a, - response.0, - response.1 - )); - - // Create a button with a custom ID - let save_button = CreateButton::new("save_dice_roll") - .label("💾") - .style(ButtonStyle::Primary); - - // Action row to hold the button - let action_row = CreateActionRow::Buttons(vec![save_button]); - - let message = CreateMessage::new() - .embed(embed) - .components(vec![action_row]); - if let Err(err) = user_id.dm(&ctx, message).await { - log::error!("Could not send message: {}", err); - } + let roller_id = command.user.id; + send_roll_message(ctx, total, user_id, roller_id, &response).await; edit_response( &ctx, command, @@ -126,7 +66,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { edit_response( &ctx, &command, - format!("🎲 {}\n-# {}", response.0, response.1), + format!("🎲 {}\n-# {}", total, response), ) .await } @@ -138,7 +78,52 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { } } -fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { +pub async fn send_roll_message(ctx: &Context, total: i32, user_id: UserId, roller_id: UserId, dice_string: &str) { + // Create the dice roll embed + let a = a_or_an(&number_to_words(total)); + let embed = CreateEmbed::new() + .title("🎲 Received a dice roll! 🎲".to_string()) + .color(0x00FF00) + .description(format!( + "{} rolled {} **{}**\n-# *{}*", + &roller_id.mention(), + a, + total, + dice_string + )); + + let message = CreateMessage::new().embed(embed); + if let Err(err) = user_id.dm(ctx, message).await { + log::error!("Could not send message: {}", err); + } +} + +pub fn format_roll(count: u32, sides: u32, modifier: i32) -> String { + format!( + "{}d{}{}", + count, + sides, + if modifier > 0 { + format!("+{}", modifier) + } else if modifier < 0 { + format!("-{}", modifier) + } else { + "".to_string() + }) +} + +pub fn roll_dice(count: u32, sides: u32, modifier: i32) -> i32 { + let mut rolls = Vec::new(); + let mut total = modifier; + for _ in 0..count { + let roll = rand::thread_rng().gen_range(1..=sides as i32); + total += roll; + rolls.push(roll); + } + total +} + +pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { // If the input is just a number (e.g., "20" or "6"), assume it's the number of sides if let Ok(n) = dice.parse::() { return Ok((1, n, 0)); // Assume 1 dice with 0 modifiers @@ -214,7 +199,7 @@ fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { pub fn register() -> CreateCommand { CreateCommand::new("roll") - .description("Rolls D&D dice") + .description("Roll dice") .add_option( CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll").required(true), ) diff --git a/src/bot/handler.rs b/src/bot/handler.rs index 75a0969..0080ca3 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -1,17 +1,19 @@ use std::env; use std::sync::{Arc, OnceLock}; -use serenity::all::{Interaction, ResumedEvent}; +use serenity::all::{CreateEmbed, CreateInteractionResponse, CreateInteractionResponseMessage, EditInteractionResponse, Interaction, ResumedEvent, UserId}; use serenity::async_trait; use serenity::model::gateway::Ready; use serenity::model::channel::Message; use serenity::prelude::*; use songbird::Songbird; use crate::bot::commands::chat::generate_response; +use crate::bot::commands::fun::roll::{format_roll, roll_dice, send_roll_message}; use crate::bot::oai::OAI; use crate::data::guilds::GuildCache; use crate::HttpKey; +use crate::utils::{a_or_an, number_to_words}; use super::{commands}; -use super::chat::{create_modal_response}; +use super::chat::{create_modal_response, user_dm}; pub struct BotHandler { // Open AI Config @@ -82,18 +84,22 @@ impl EventHandler for BotHandler { log::warn!("No ready guilds found"); } - let songbird = songbird::get(&ctx).await.unwrap(); - SONGBIRD - .set(songbird.clone()) - .expect("Songbird value could not be set"); - let http_client = { - let data = ctx.data.read().await; - data - .get::() - .cloned() - .expect("Guaranteed to exist in the typemap.") - }; - CLIENT.set(http_client).ok(); + if SONGBIRD.get().is_none() { + let songbird = songbird::get(&ctx).await.unwrap(); + SONGBIRD + .set(songbird.clone()) + .expect("Songbird value could not be set"); + } + if CLIENT.get().is_none() { + let http_client = { + let data = ctx.data.read().await; + data + .get::() + .cloned() + .expect("Guaranteed to exist in the typemap.") + }; + CLIENT.set(http_client).ok(); + } log::trace!("Handling {} guilds", ready.guilds.len()); for guild in ready.guilds { @@ -122,6 +128,7 @@ impl EventHandler for BotHandler { commands::audio::volume::register(), commands::event::schedule::register(), commands::fun::roll::register(), + commands::fun::request_roll::register(), commands::utility::ping::register(), ], ) @@ -146,10 +153,8 @@ impl EventHandler for BotHandler { } async fn interaction_create(&self, ctx: Context, interaction: Interaction) { - if let Interaction::Ping(ping) = interaction { - log::trace!("Received interaction ping: {:?}", ping); - } else if let Interaction::Command(command) = interaction { - log::trace!("Received command interaction: {command:#?}"); + if let Interaction::Command(command) = interaction { + log::trace!("Received COMMAND"); match command.data.name.as_str() { // Match commands without returns "play" => commands::audio::play::run(&ctx, &command).await, @@ -161,11 +166,47 @@ impl EventHandler for BotHandler { "volume" => commands::audio::volume::run(&ctx, &command).await, "schedule" => commands::event::schedule::run(&ctx, &command).await, "roll" => commands::fun::roll::run(&ctx, &command).await, + "requestroll" => commands::fun::request_roll::run(&ctx, &command).await, "ping" => commands::utility::ping::run(&ctx, &command).await, _ => {} } + } else if let Interaction::Component(component) = interaction { + log::trace!("Received COMPONENT"); + let custom_id = &component.data.custom_id; + if custom_id.starts_with("request_dice_roll") { + // Acknowledge the interaction + if let Err(err) = component.create_response(ctx.http.clone(), CreateInteractionResponse::Acknowledge).await { + log::error!("Could not create dice response: {}", err); + }; + let parts = custom_id.split('|').collect::>(); + if parts.len() == 6 { + let count = parts[1].parse().unwrap(); + let sides = parts[2].parse().unwrap(); + let modifier = parts[3].parse().unwrap(); + let result = roll_dice(count, sides, modifier); + let response = format!("(Rolled {})", format_roll(count, sides, modifier)); + let user_id = UserId::from(parts[4].parse::().unwrap()); + let roller_id = component.user.id; + let hidden: bool = parts[5].parse().unwrap(); + send_roll_message(&ctx, result, user_id, roller_id, &response).await; + component.delete_response(ctx.http.clone()).await.ok(); + let message; + if hidden { + message = format!("Results sent to {}", user_id.mention()); + } else { + message = format!("🎲 You rolled {} {}\n-# {}", a_or_an(&number_to_words(result)), result, response); + } + user_dm(&ctx, &component.user.id, message).await; + } else { + log::error!("Could not handle dice click: {}", custom_id); + } + } + } else if let Interaction::Ping(_ping) = interaction { + log::trace!("Received PING"); + } else if let Interaction::Autocomplete(_autocomplete) = interaction { + log::trace!("Received AUTOCOMPLETE"); } else if let Interaction::Modal(modal) = interaction { - log::trace!("Received interaction modal: {:?}", modal); + log::trace!("Received MODAL"); create_modal_response(&ctx, &modal).await; } } diff --git a/src/error.rs b/src/error.rs index 60ff4d2..89c2033 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ use std::fmt; use axum::http::StatusCode; -use axum::Json; +use axum::{http, Json}; use axum::response::{IntoResponse, Response}; use serde::{Deserialize, Serialize}; @@ -13,12 +13,20 @@ pub struct Error { } impl Error { - pub fn new(error_status_code: u16, error_message: String) -> Self { + pub fn new(status: u16, details: String) -> Self { Self { - status: error_status_code, - details: error_message, + status, + details, } } + + pub fn not_found(details: String) -> Self { + Self::new(404, details) + } + + pub fn internal_server_error(details: String) -> Self { + Self::new(500, details) + } } impl fmt::Display for Error { @@ -56,6 +64,12 @@ impl From for Error { } } +impl From for Error { + fn from(error: songbird::tracks::ControlError) -> Self { + Self::new(500, format!("Unknown control error: {}", error)) + } +} + impl From for Error { fn from(status: StatusCode) -> Self { Error { diff --git a/src/main.rs b/src/main.rs index b39a621..583eac2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::env; use std::sync::Arc; use dotenv::{dotenv, from_filename}; @@ -27,6 +28,7 @@ struct AppState { client_id: String, client_secret: String, redirect_uri: String, + oauth_states: Arc>>, http: Arc, cache: Arc, } @@ -71,6 +73,7 @@ async fn main() -> Result<(), Box> { client_id: bot_id.to_string(), client_secret, redirect_uri, + oauth_states: Arc::new(Mutex::new(HashSet::new())), http: Arc::clone(&client.http), cache: Arc::clone(&client.cache), };