Implemented roll request, updated API requests

This commit is contained in:
2024-12-20 15:13:31 -05:00
parent bb03654d5f
commit 8ac0e59b8c
23 changed files with 459 additions and 150 deletions

2
.env
View File

@@ -7,7 +7,7 @@ JWT_SECRET=CHANGEME # Change this to a secure secret
DATABASE_USER=siren DATABASE_USER=siren
DATABASE_PASSWORD=CHANGEME # Change this to a secure password DATABASE_PASSWORD=CHANGEME # Change this to a secure password
DATABASE_NAME=siren DATABASE_NAME=siren_db
DATABASE_HOST=localhost DATABASE_HOST=localhost
DATABASE_PORT=5432 DATABASE_PORT=5432

View File

@@ -26,7 +26,7 @@ rand = "0.8.5"
rand_chacha = "0.3.1" rand_chacha = "0.3.1"
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] } tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] }
regex = "1.11.0" regex = "1.11.0"
axum = "0.7.7" axum = { version = "0.7.7", features = ["json"] }
axum-extra = { version = "0.9.6", features = ["typed-header"] } axum-extra = { version = "0.9.6", features = ["typed-header"] }
lazy_static = "1.5.0" lazy_static = "1.5.0"
jsonwebtoken = "9.3.0" jsonwebtoken = "9.3.0"

View File

@@ -105,6 +105,7 @@ Siren utilizes Discord slash commands. To view the commands, run `/help` in a se
| --- | --- | | --- | --- |
| `/coinflip` | Flip a coin | | `/coinflip` | Flip a coin |
| `/roll <Dice>` | Roll a dice | | `/roll <Dice>` | Roll a dice |
| `/requestroll <User> <Dice>` | Request a dice roll from a user |
**Utility Commands** **Utility Commands**
| Command | Description | | Command | Description |

View File

@@ -0,0 +1,17 @@
meta {
name: Pause Track
type: http
seq: 2
}
post {
url: {{baseUrl}}/audio/pause
body: json
auth: inherit
}
body:json {
{
"guild_id": 1061092965579235398
}
}

View File

@@ -0,0 +1,18 @@
meta {
name: Play Track
type: http
seq: 1
}
post {
url: {{baseUrl}}/audio/play
body: json
auth: inherit
}
body:json {
{
"url": "https://www.youtube.com/watch?v=V-QDxuknK-Q",
"guild_id": 1061092965579235398
}
}

View File

@@ -0,0 +1,17 @@
meta {
name: Resume Track
type: http
seq: 3
}
post {
url: {{baseUrl}}/audio/resume
body: json
auth: inherit
}
body:json {
{
"guild_id": 1061092965579235398
}
}

9
bruno/bruno.json Normal file
View File

@@ -0,0 +1,9 @@
{
"version": "1",
"name": "Siren",
"type": "collection",
"ignore": [
"node_modules",
".git"
]
}

11
bruno/collection.bru Normal file
View File

@@ -0,0 +1,11 @@
auth {
mode: bearer
}
auth:bearer {
token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjI1MDg0MjI2MTIyMTI3NzY5NywibmFtZSI6ImJzaGVycmlmZiIsImlhdCI6MTczNDcwNDI3NSwiZXhwIjoxNzM0NzkwNjc1LCJqdGkiOiJMSnc1Vnk3azZjc1BiYlJRWGlNcVFFVUZlQ29JS2JqcCJ9.sdgb93DmX9_augMdktYr58m5eTIJPuY13d87pckZOns
}
vars:pre-request {
baseUrl: http://localhost:3000/api
}

11
bruno/oauth/Authorize.bru Normal file
View File

@@ -0,0 +1,11 @@
meta {
name: Authorize
type: http
seq: 1
}
get {
url: {{baseUrl}}/oauth/authorize
body: none
auth: inherit
}

View File

@@ -16,8 +16,13 @@ CREATE TABLE IF NOT EXISTS messages (
request_tags TEXT[] NOT NULL, request_tags TEXT[] NOT NULL,
response_tags TEXT[] NOT NULL response_tags TEXT[] NOT NULL
); );
CREATE TABLE IF NOT EXISTS dice_rolls ( CREATE TABLE IF NOT EXISTS dice_thresholds (
id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid() id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(),
owner_id BIGINT NOT NULL,
dice TEXT NOT NULL,
user_id BIGINT,
value INT,
operator TEXT
); );
CREATE TABLE IF NOT EXISTS events ( CREATE TABLE IF NOT EXISTS events (
id UUID PRIMARY KEY NOT NULL, id UUID PRIMARY KEY NOT NULL,

View File

@@ -2,23 +2,31 @@ use std::sync::Arc;
use axum::extract::State; use axum::extract::State;
use axum::middleware::from_extractor; use axum::middleware::from_extractor;
use axum::{Extension, Json, Router}; use axum::{Extension, Json, Router};
use axum::response::IntoResponse;
use axum::routing::post; use axum::routing::post;
use reqwest::StatusCode;
use serde::Deserialize; use serde::Deserialize;
use crate::api::auth::{AuthorizationMiddleware, Session}; use crate::api::auth::{AuthorizationMiddleware, Session};
use crate::AppState; use crate::AppState;
use crate::bot::commands::audio::join_voice_channel; use crate::bot::commands::audio::join_voice_channel;
use crate::bot::commands::audio::pause::pause_track;
use crate::bot::commands::audio::play::enqueue_track; use crate::bot::commands::audio::play::enqueue_track;
use crate::bot::commands::audio::resume::resume_track;
use crate::bot::handler::get_songbird; use crate::bot::handler::get_songbird;
use crate::error::SirenResult; use crate::error::{Error, SirenResult};
pub fn get_routes() -> Router<Arc<AppState>> { pub fn get_routes() -> Router<Arc<AppState>> {
Router::new() Router::new()
.route("/play", post(play_audio)) .route("/play", post(play_audio))
.route_layer(from_extractor::<AuthorizationMiddleware>()) .route_layer(from_extractor::<AuthorizationMiddleware>())
.route("/pause", post(pause_audio))
.route_layer(from_extractor::<AuthorizationMiddleware>())
.route("/resume", post(resume_audio))
.route_layer(from_extractor::<AuthorizationMiddleware>())
} }
#[derive(Deserialize)] #[derive(Deserialize)]
struct TrackRequest { struct PlayTrackRequest {
url: String, url: String,
guild_id: u64, guild_id: u64,
} }
@@ -26,13 +34,66 @@ struct TrackRequest {
async fn play_audio( async fn play_audio(
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(payload): Json<TrackRequest>, Json(payload): Json<PlayTrackRequest>,
) -> SirenResult<()> { ) -> SirenResult<()> {
log::debug!("Playing audio in guild: {}", payload.guild_id); log::debug!("Playing audio in guild: {}", payload.guild_id);
// Check if the user exists in the cache
let user_id = match state.cache.user(session.user_id) {
Some(user) => user.id,
None => return Err(Error::not_found("User not found".to_string())),
};
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(payload.guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
// Play the track
let manager = get_songbird(); let manager = get_songbird();
let user_id = state.cache.user(session.user_id).unwrap().id;
let guild_id = state.cache.guild(payload.guild_id).unwrap().id;
let _channel_id = join_voice_channel(&state.cache, &manager, &guild_id, &user_id).await?; let _channel_id = join_voice_channel(&state.cache, &manager, &guild_id, &user_id).await?;
enqueue_track(manager, guild_id.to_owned(), &payload.url).await?; enqueue_track(manager, guild_id.to_owned(), &payload.url).await?;
Ok(()) Ok(())
} }
#[derive(Deserialize)]
struct GuildTrackRequest {
guild_id: u64,
}
async fn pause_audio(
Extension(_): Extension<Session>,
State(state): State<Arc<AppState>>,
Json(payload): Json<GuildTrackRequest>,
) -> SirenResult<()> {
log::debug!("Pausing audio in guild: {}", payload.guild_id);
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(payload.guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
// Pause the track
let manager = get_songbird();
pause_track(manager, &guild_id).await
}
async fn resume_audio(
Extension(_): Extension<Session>,
State(state): State<Arc<AppState>>,
Json(payload): Json<GuildTrackRequest>,
) -> SirenResult<()> {
log::debug!("Pausing audio in guild: {}", payload.guild_id);
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(payload.guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
// Pause the track
let manager = get_songbird();
resume_track(manager, &guild_id).await
}

View File

@@ -7,6 +7,7 @@ use axum::response::Redirect;
use axum::routing::get; use axum::routing::get;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::auth::bearer_token::BearerTokenClaims; use crate::api::auth::bearer_token::BearerTokenClaims;
use crate::api::auth::csprng;
use crate::AppState; use crate::AppState;
use crate::api::auth::session::Session; use crate::api::auth::session::Session;
use crate::error::SirenResult; use crate::error::SirenResult;
@@ -42,19 +43,27 @@ struct DiscordUser {
} }
async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect { async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect {
// Store the state
let oauth_state = csprng(16);
state.oauth_states.lock().await.insert(oauth_state.clone());
// Construct the Discord OAuth URL // Construct the Discord OAuth URL
let discord_auth_url = format!( let discord_auth_url = format!(
"https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify", "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify&state={}",
state.client_id, state.redirect_uri state.client_id, state.redirect_uri, oauth_state
); );
Redirect::temporary(&discord_auth_url) Redirect::temporary(&discord_auth_url)
} }
async fn discord_authorize(State(state): State<Arc<AppState>>) -> SirenResult<String> { async fn discord_authorize(State(state): State<Arc<AppState>>) -> SirenResult<String> {
// Store the state
let oauth_state = csprng(16);
state.oauth_states.lock().await.insert(oauth_state.clone());
// Construct the Discord OAuth URL // Construct the Discord OAuth URL
let discord_auth_url = format!( let discord_auth_url = format!(
"https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify", "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify&state={}",
state.client_id, state.redirect_uri state.client_id, state.redirect_uri, oauth_state
); );
Ok(discord_auth_url) Ok(discord_auth_url)
} }
@@ -70,6 +79,18 @@ async fn oauth_callback(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Query(query): Query<AuthQuery>, Query(query): Query<AuthQuery>,
) -> SirenResult<Json<BearerTokenResponse>> { ) -> SirenResult<Json<BearerTokenResponse>> {
// Validate the state
let mut oauth_states = state.oauth_states.lock().await;
match query.state {
Some(oauth_state) => {
match oauth_states.get(&oauth_state) {
Some(_) => oauth_states.remove(&oauth_state),
None => return Err(StatusCode::UNAUTHORIZED.into()),
}
}
None => return Err(StatusCode::UNAUTHORIZED)?,
};
// Exchange code for an access token // Exchange code for an access token
let token_response = state let token_response = state
.client .client

View File

@@ -42,11 +42,12 @@ impl Session {
pub async fn insert(&self) -> SirenResult<()> { pub async fn insert(&self) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?; let mut redis = data::redis_async_connection().await?;
let session_id = self.session_id.clone(); let session_id = self.session_id.clone();
let session_ttl = get_session_ttl();
redis redis
.set_ex( .set_ex(
session_id, session_id,
serde_json::to_string(self)?, serde_json::to_string(self)?,
self.expires_at.timestamp() as u64, session_ttl as u64,
) )
.await?; .await?;
Ok(()) Ok(())

View File

@@ -8,7 +8,7 @@ pub async fn process_message(ctx: &Context, command: &CommandInteraction, privat
create_message_response(&ctx, &command, "Processing...".to_string(), private).await; create_message_response(&ctx, &command, "Processing...".to_string(), private).await;
} }
pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Option<Message> { pub async fn user_dm(ctx: &Context, user_id: &UserId, content: String) -> Option<Message> {
let data = CreateMessage::new().content(content.to_owned()); let data = CreateMessage::new().content(content.to_owned());
match user_id.dm(ctx, data).await { match user_id.dm(ctx, data).await {
Ok(message) => Some(message), Ok(message) => Some(message),
@@ -19,17 +19,6 @@ pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Opt
} }
} }
pub async fn user_dm(ctx: &Context, user: &User, content: String) -> Option<Message> {
let data = CreateMessage::new().content(content.to_owned());
match user.direct_message(ctx, data).await {
Ok(message) => Some(message),
Err(err) => {
log::error!("Failed to create direct message for {content}\n{err}");
None
}
}
}
pub async fn create_message_response( pub async fn create_message_response(
ctx: &Context, ctx: &Context,
command: &CommandInteraction, command: &CommandInteraction,

View File

@@ -73,8 +73,8 @@ fn find_voice_channel(
{ {
Some(channel) => Ok(channel), Some(channel) => Ok(channel),
None => { None => {
return Err(SirenError::new( Err(SirenError::new(
401, 400,
"User is not in a voice channel".to_string(), "User is not in a voice channel".to_string(),
)) ))
} }

View File

@@ -1,10 +1,12 @@
use std::sync::Arc;
use serenity::{ use serenity::{
all::{CommandInteraction, CreateCommand}, all::{CommandInteraction, CreateCommand, GuildId},
prelude::*, prelude::*,
}; };
use songbird::Songbird;
use crate::bot::chat::{edit_response, process_message}; use crate::bot::chat::{edit_response, process_message};
use crate::bot::handler::get_songbird; use crate::bot::handler::get_songbird;
use crate::error::{Error, SirenResult};
pub async fn run(ctx: &Context, command: &CommandInteraction) { pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Create the initial response // Create the initial response
@@ -28,23 +30,25 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
}; };
// Pause the track // Pause the track
if let Some(handler_lock) = manager.get(guild_id.to_owned()) { match pause_track(manager, guild_id).await {
let handler = handler_lock.lock().await;
match handler.queue().current() {
Some(track) => match track.pause() {
Ok(_) => { Ok(_) => {
log::debug!("<{guild_id}> Paused the track"); log::debug!("<{guild_id}> Paused the track");
edit_response(&ctx, &command, "Pausing the track".to_string()).await; edit_response(&ctx, &command, "Pausing the track".to_string()).await;
} }
Err(err) => { Err(err) => edit_response(&ctx, &command, format!("Failed to pause: {}", err)).await
edit_response(&ctx, &command, format!("Failed to pause: {}", err)).await;
}
},
None => {
edit_response(ctx, command, "No track currently being played".to_string()).await;
} }
} }
pub async fn pause_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> SirenResult<()> {
if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
let handler = handler_lock.lock().await;
match handler.queue().current() {
Some(track) => track.pause()?,
None => return Err(Error { status: 404, details: "No track is currently playing".to_string() })
} }
};
Ok(())
} }
pub fn register() -> CreateCommand { pub fn register() -> CreateCommand {

View File

@@ -1,10 +1,14 @@
use std::sync::Arc;
use serenity::{ use serenity::{
all::{CommandInteraction, CreateCommand}, all::{CommandInteraction, CreateCommand},
prelude::*, prelude::*,
}; };
use serenity::all::GuildId;
use songbird::Songbird;
use crate::bot::chat::{edit_response, process_message}; use crate::bot::chat::{edit_response, process_message};
use crate::bot::commands::audio::pause::pause_track;
use crate::bot::handler::get_songbird; use crate::bot::handler::get_songbird;
use crate::error::{Error, SirenResult};
pub async fn run(ctx: &Context, command: &CommandInteraction) { pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Create the initial response // Create the initial response
@@ -28,24 +32,25 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
}; };
// Resume the track // Resume the track
match resume_track(manager, guild_id).await {
Ok(_) => {
log::debug!("<{guild_id}> Resumed the track");
edit_response(&ctx, &command, "resuming the track".to_string()).await;
}
Err(err) => edit_response(&ctx, &command, format!("Failed to resume: {}", err)).await
}
}
pub async fn resume_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> SirenResult<()> {
if let Some(handler_lock) = manager.get(guild_id.to_owned()) { if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
let handler = handler_lock.lock().await; let handler = handler_lock.lock().await;
match handler.queue().current() { match handler.queue().current() {
Some(track) => match track.play() { Some(track) => track.play()?,
Ok(_) => { None => return Err(Error { status: 404, details: "No track is currently playing".to_string() })
log::debug!("<{guild_id}> Resumed the track");
edit_response(&ctx, &command, "Resuming the track".to_string()).await;
}
Err(err) => {
edit_response(&ctx, &command, format!("Failed to resume: {}", err)).await;
}
},
None => {
edit_response(&ctx, &command, "No track is currently playing".to_string()).await;
return;
}
}
} }
};
Ok(())
} }
pub fn register() -> CreateCommand { pub fn register() -> CreateCommand {

View File

@@ -1 +1,2 @@
pub mod roll; pub mod roll;
pub mod request_roll;

View File

@@ -0,0 +1,95 @@
use serenity::all::{ButtonStyle, CommandInteraction, CommandOptionType, Context, CreateActionRow, CreateButton, CreateCommand, CreateCommandOption, CreateMessage, Mentionable, UserId};
use serenity::builder::CreateEmbed;
use crate::bot::chat::{create_message_response, edit_response};
use crate::bot::commands::fun::roll::parse_dice;
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Check if the roll result is hidden
let hidden = command
.data
.options
.iter()
.find(|opt| opt.name == "hidden")
.and_then(|o| o.value.as_bool())
.unwrap_or(false);
// Retrieve the user
let user = command
.data
.options
.iter()
.find(|opt| opt.name == "user")
.and_then(|o| o.value.as_mentionable()).unwrap();
let user_id = UserId::new(user.get());
create_message_response(ctx, &command, format!("Sending request to {}", user_id.mention()), true).await;
let dice_string = command
.data
.options
.get(0)
.and_then(|o| o.value.as_str())
.map(|s| s.split_whitespace().collect::<String>()).unwrap();
let dice_result = parse_dice(dice_string.as_str());
match dice_result {
Ok(dice) => {
// let roll_button = CreateButton::new(format!("request_dice_roll|{}|{}|{}|{}|{}", dice.0, dice.1, dice.2, command.user.id.get(), hidden))
// .label("Roll")
// .style(ButtonStyle::Primary);
// let action_row = CreateActionRow::Buttons(vec![roll_button]);
//
// let embed = CreateEmbed::new()
// .title("🎲 Dice roll request! 🎲".to_string())
// .color(0x00FF00)
// .description(format!("{} Requested a dice roll of {}", command.user.mention(), dice_string));
//
// let message = CreateMessage::new()
// .embed(embed)
// .components(vec![action_row]);
let roll_button = CreateButton::new(format!("request_dice_roll|{}|{}|{}|{}|{}", dice.0, dice.1, dice.2, command.user.id.get(), hidden))
.label(format!("🎲 Roll {} 🎲", dice_string)) // The label you want on the button
.style(ButtonStyle::Primary);
let action_row = CreateActionRow::Buttons(vec![roll_button]);
let message = CreateMessage::new()
.content(format!("-# Roll requested from {}", command.user.mention()))
.components(vec![action_row]);
if let Err(why) = user_id.dm(ctx, message).await {
log::error!("failed to send request due to {}", why);
edit_response(ctx, command, "Unable to send dice request".to_string()).await;
};
}
Err(why) => {
edit_response(ctx, &command, why.to_string()).await;
}
}
}
pub fn register() -> CreateCommand {
CreateCommand::new("requestroll")
.description("Request a dice roll from a user")
.add_option(
CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll")
.required(true),
)
.add_option(
CreateCommandOption::new(
CommandOptionType::Mentionable,
"user",
"User to receive the dice roll request"
)
.required(true),
)
.add_option(
CreateCommandOption::new(
CommandOptionType::Boolean,
"hidden",
"Hide the dice roll from the user (Default: False")
.required(false),
)
}

View File

@@ -9,19 +9,6 @@ use serenity::all::{
use crate::bot::chat::{create_message_response, edit_response}; use crate::bot::chat::{create_message_response, edit_response};
use crate::utils::{a_or_an, number_to_words}; use crate::utils::{a_or_an, number_to_words};
lazy_static::lazy_static! {
static ref SAVED_ROLLS: Mutex<HashMap<UserId, Vec<(i32, String)>>> = Mutex::new(HashMap::new());
}
pub fn temp() {
// // Add to the HashMap after processing the modal
// let mut saved_rolls = SAVED_ROLLS.lock().unwrap();
// saved_rolls
// .entry(user_id)
// .or_default()
// .push((dice_roll, description.clone()));
}
pub async fn run(ctx: &Context, command: &CommandInteraction) { pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Check if the roll result is private // Check if the roll result is private
let private = command let private = command
@@ -32,7 +19,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
.and_then(|o| o.value.as_bool()) .and_then(|o| o.value.as_bool())
.unwrap_or(true); .unwrap_or(true);
// Retrieve the DM's name or ID from the options (optional) // Retrieve the user if present
let user = command let user = command
.data .data
.options .options
@@ -40,7 +27,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
.find(|opt| opt.name == "user") .find(|opt| opt.name == "user")
.and_then(|o| o.value.as_mentionable()); .and_then(|o| o.value.as_mentionable());
create_message_response(&ctx, &command, "Rolling...".to_string(), private).await; create_message_response(ctx, &command, "Rolling...".to_string(), private).await;
let dice_string = match command let dice_string = match command
.data .data
@@ -60,61 +47,14 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
let dice = parse_dice(dice_string.as_str()); let dice = parse_dice(dice_string.as_str());
match dice { match dice {
Ok((count, sides, modifier)) => { Ok((count, sides, modifier)) => {
let mut rolls = Vec::new(); let total = roll_dice(count, sides, modifier);
let mut total = 0; let response = format!("(Rolled {})", format_roll(count, sides, modifier));
for _ in 0..count {
let roll = rand::thread_rng().gen_range(1..=sides);
total += roll;
rolls.push(roll);
}
let response = (
total + (modifier as u32),
format!(
"(Rolled {}d{}{})",
count,
sides,
if modifier > 0 {
format!("+{}", modifier)
} else if modifier < 0 {
format!("-{}", modifier)
} else {
"".to_string()
}
),
);
match user { match user {
Some(id) => { Some(id) => {
let user_id = UserId::new(id.get()); let user_id = UserId::new(id.get());
let roller_id = command.user.id;
// Create the dice roll embed send_roll_message(ctx, total, user_id, roller_id, &response).await;
let a = a_or_an(&number_to_words(response.0 as i32));
let embed = CreateEmbed::new()
.title("🎲 Received a dice roll! 🎲".to_string())
.color(0x00FF00)
.description(format!(
"{} rolled {} **{}**\n-# *{}*",
&command.user.mention(),
a,
response.0,
response.1
));
// Create a button with a custom ID
let save_button = CreateButton::new("save_dice_roll")
.label("💾")
.style(ButtonStyle::Primary);
// Action row to hold the button
let action_row = CreateActionRow::Buttons(vec![save_button]);
let message = CreateMessage::new()
.embed(embed)
.components(vec![action_row]);
if let Err(err) = user_id.dm(&ctx, message).await {
log::error!("Could not send message: {}", err);
}
edit_response( edit_response(
&ctx, &ctx,
command, command,
@@ -126,7 +66,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
edit_response( edit_response(
&ctx, &ctx,
&command, &command,
format!("🎲 {}\n-# {}", response.0, response.1), format!("🎲 {}\n-# {}", total, response),
) )
.await .await
} }
@@ -138,7 +78,52 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
} }
} }
fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { pub async fn send_roll_message(ctx: &Context, total: i32, user_id: UserId, roller_id: UserId, dice_string: &str) {
// Create the dice roll embed
let a = a_or_an(&number_to_words(total));
let embed = CreateEmbed::new()
.title("🎲 Received a dice roll! 🎲".to_string())
.color(0x00FF00)
.description(format!(
"{} rolled {} **{}**\n-# *{}*",
&roller_id.mention(),
a,
total,
dice_string
));
let message = CreateMessage::new().embed(embed);
if let Err(err) = user_id.dm(ctx, message).await {
log::error!("Could not send message: {}", err);
}
}
pub fn format_roll(count: u32, sides: u32, modifier: i32) -> String {
format!(
"{}d{}{}",
count,
sides,
if modifier > 0 {
format!("+{}", modifier)
} else if modifier < 0 {
format!("-{}", modifier)
} else {
"".to_string()
})
}
pub fn roll_dice(count: u32, sides: u32, modifier: i32) -> i32 {
let mut rolls = Vec::new();
let mut total = modifier;
for _ in 0..count {
let roll = rand::thread_rng().gen_range(1..=sides as i32);
total += roll;
rolls.push(roll);
}
total
}
pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> {
// If the input is just a number (e.g., "20" or "6"), assume it's the number of sides // If the input is just a number (e.g., "20" or "6"), assume it's the number of sides
if let Ok(n) = dice.parse::<u32>() { if let Ok(n) = dice.parse::<u32>() {
return Ok((1, n, 0)); // Assume 1 dice with 0 modifiers return Ok((1, n, 0)); // Assume 1 dice with 0 modifiers
@@ -214,7 +199,7 @@ fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> {
pub fn register() -> CreateCommand { pub fn register() -> CreateCommand {
CreateCommand::new("roll") CreateCommand::new("roll")
.description("Rolls D&D dice") .description("Roll dice")
.add_option( .add_option(
CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll").required(true), CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll").required(true),
) )

View File

@@ -1,17 +1,19 @@
use std::env; use std::env;
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
use serenity::all::{Interaction, ResumedEvent}; use serenity::all::{CreateEmbed, CreateInteractionResponse, CreateInteractionResponseMessage, EditInteractionResponse, Interaction, ResumedEvent, UserId};
use serenity::async_trait; use serenity::async_trait;
use serenity::model::gateway::Ready; use serenity::model::gateway::Ready;
use serenity::model::channel::Message; use serenity::model::channel::Message;
use serenity::prelude::*; use serenity::prelude::*;
use songbird::Songbird; use songbird::Songbird;
use crate::bot::commands::chat::generate_response; use crate::bot::commands::chat::generate_response;
use crate::bot::commands::fun::roll::{format_roll, roll_dice, send_roll_message};
use crate::bot::oai::OAI; use crate::bot::oai::OAI;
use crate::data::guilds::GuildCache; use crate::data::guilds::GuildCache;
use crate::HttpKey; use crate::HttpKey;
use crate::utils::{a_or_an, number_to_words};
use super::{commands}; use super::{commands};
use super::chat::{create_modal_response}; use super::chat::{create_modal_response, user_dm};
pub struct BotHandler { pub struct BotHandler {
// Open AI Config // Open AI Config
@@ -82,10 +84,13 @@ impl EventHandler for BotHandler {
log::warn!("No ready guilds found"); log::warn!("No ready guilds found");
} }
if SONGBIRD.get().is_none() {
let songbird = songbird::get(&ctx).await.unwrap(); let songbird = songbird::get(&ctx).await.unwrap();
SONGBIRD SONGBIRD
.set(songbird.clone()) .set(songbird.clone())
.expect("Songbird value could not be set"); .expect("Songbird value could not be set");
}
if CLIENT.get().is_none() {
let http_client = { let http_client = {
let data = ctx.data.read().await; let data = ctx.data.read().await;
data data
@@ -94,6 +99,7 @@ impl EventHandler for BotHandler {
.expect("Guaranteed to exist in the typemap.") .expect("Guaranteed to exist in the typemap.")
}; };
CLIENT.set(http_client).ok(); CLIENT.set(http_client).ok();
}
log::trace!("Handling {} guilds", ready.guilds.len()); log::trace!("Handling {} guilds", ready.guilds.len());
for guild in ready.guilds { for guild in ready.guilds {
@@ -122,6 +128,7 @@ impl EventHandler for BotHandler {
commands::audio::volume::register(), commands::audio::volume::register(),
commands::event::schedule::register(), commands::event::schedule::register(),
commands::fun::roll::register(), commands::fun::roll::register(),
commands::fun::request_roll::register(),
commands::utility::ping::register(), commands::utility::ping::register(),
], ],
) )
@@ -146,10 +153,8 @@ impl EventHandler for BotHandler {
} }
async fn interaction_create(&self, ctx: Context, interaction: Interaction) { async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
if let Interaction::Ping(ping) = interaction { if let Interaction::Command(command) = interaction {
log::trace!("Received interaction ping: {:?}", ping); log::trace!("Received COMMAND");
} else if let Interaction::Command(command) = interaction {
log::trace!("Received command interaction: {command:#?}");
match command.data.name.as_str() { match command.data.name.as_str() {
// Match commands without returns // Match commands without returns
"play" => commands::audio::play::run(&ctx, &command).await, "play" => commands::audio::play::run(&ctx, &command).await,
@@ -161,11 +166,47 @@ impl EventHandler for BotHandler {
"volume" => commands::audio::volume::run(&ctx, &command).await, "volume" => commands::audio::volume::run(&ctx, &command).await,
"schedule" => commands::event::schedule::run(&ctx, &command).await, "schedule" => commands::event::schedule::run(&ctx, &command).await,
"roll" => commands::fun::roll::run(&ctx, &command).await, "roll" => commands::fun::roll::run(&ctx, &command).await,
"requestroll" => commands::fun::request_roll::run(&ctx, &command).await,
"ping" => commands::utility::ping::run(&ctx, &command).await, "ping" => commands::utility::ping::run(&ctx, &command).await,
_ => {} _ => {}
} }
} else if let Interaction::Component(component) = interaction {
log::trace!("Received COMPONENT");
let custom_id = &component.data.custom_id;
if custom_id.starts_with("request_dice_roll") {
// Acknowledge the interaction
if let Err(err) = component.create_response(ctx.http.clone(), CreateInteractionResponse::Acknowledge).await {
log::error!("Could not create dice response: {}", err);
};
let parts = custom_id.split('|').collect::<Vec<&str>>();
if parts.len() == 6 {
let count = parts[1].parse().unwrap();
let sides = parts[2].parse().unwrap();
let modifier = parts[3].parse().unwrap();
let result = roll_dice(count, sides, modifier);
let response = format!("(Rolled {})", format_roll(count, sides, modifier));
let user_id = UserId::from(parts[4].parse::<u64>().unwrap());
let roller_id = component.user.id;
let hidden: bool = parts[5].parse().unwrap();
send_roll_message(&ctx, result, user_id, roller_id, &response).await;
component.delete_response(ctx.http.clone()).await.ok();
let message;
if hidden {
message = format!("Results sent to {}", user_id.mention());
} else {
message = format!("🎲 You rolled {} {}\n-# {}", a_or_an(&number_to_words(result)), result, response);
}
user_dm(&ctx, &component.user.id, message).await;
} else {
log::error!("Could not handle dice click: {}", custom_id);
}
}
} else if let Interaction::Ping(_ping) = interaction {
log::trace!("Received PING");
} else if let Interaction::Autocomplete(_autocomplete) = interaction {
log::trace!("Received AUTOCOMPLETE");
} else if let Interaction::Modal(modal) = interaction { } else if let Interaction::Modal(modal) = interaction {
log::trace!("Received interaction modal: {:?}", modal); log::trace!("Received MODAL");
create_modal_response(&ctx, &modal).await; create_modal_response(&ctx, &modal).await;
} }
} }

View File

@@ -1,6 +1,6 @@
use std::fmt; use std::fmt;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::Json; use axum::{http, Json};
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -13,12 +13,20 @@ pub struct Error {
} }
impl Error { impl Error {
pub fn new(error_status_code: u16, error_message: String) -> Self { pub fn new(status: u16, details: String) -> Self {
Self { Self {
status: error_status_code, status,
details: error_message, details,
} }
} }
pub fn not_found(details: String) -> Self {
Self::new(404, details)
}
pub fn internal_server_error(details: String) -> Self {
Self::new(500, details)
}
} }
impl fmt::Display for Error { impl fmt::Display for Error {
@@ -56,6 +64,12 @@ impl From<std::io::Error> for Error {
} }
} }
impl From<songbird::tracks::ControlError> for Error {
fn from(error: songbird::tracks::ControlError) -> Self {
Self::new(500, format!("Unknown control error: {}", error))
}
}
impl From<StatusCode> for Error { impl From<StatusCode> for Error {
fn from(status: StatusCode) -> Self { fn from(status: StatusCode) -> Self {
Error { Error {

View File

@@ -1,3 +1,4 @@
use std::collections::HashSet;
use std::env; use std::env;
use std::sync::Arc; use std::sync::Arc;
use dotenv::{dotenv, from_filename}; use dotenv::{dotenv, from_filename};
@@ -27,6 +28,7 @@ struct AppState {
client_id: String, client_id: String,
client_secret: String, client_secret: String,
redirect_uri: String, redirect_uri: String,
oauth_states: Arc<Mutex<HashSet<String>>>,
http: Arc<Http>, http: Arc<Http>,
cache: Arc<Cache>, cache: Arc<Cache>,
} }
@@ -71,6 +73,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
client_id: bot_id.to_string(), client_id: bot_id.to_string(),
client_secret, client_secret,
redirect_uri, redirect_uri,
oauth_states: Arc::new(Mutex::new(HashSet::new())),
http: Arc::clone(&client.http), http: Arc::clone(&client.http),
cache: Arc::clone(&client.cache), cache: Arc::clone(&client.cache),
}; };