diff --git a/.env b/.env index e45f6a4..3d93168 100644 --- a/.env +++ b/.env @@ -1,4 +1,4 @@ -RUST_LOG=warn,service=debug +RUST_LOG=warn,siren=info DATABASE_USER=siren DATABASE_PASSWORD=CHANGEME # Change this to a secure password @@ -21,6 +21,6 @@ SERVICE_HOST=localhost SERVICE_PORT=5000 DATA_DIR_PATH= # OPTIONAL -DISCORD_TOKEN= # OPTIONAL +DISCORD_TOKEN= OPENAI_API_KEY= # OPTIONAL OPENAI_API_MODEL=gpt-3.5-turbo diff --git a/.version b/.version index 752f734..e0b6561 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -SIREN_VERSION=0.2.7 \ No newline at end of file +SIREN_VERSION=0.2.8 \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 1aa45a6..1506a54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,29 +1,27 @@ [package] -name = "service" +name = "siren" version = "0.2.8" edition = "2021" authors = ["Ben Sherriff "] +description = "A Discord bot for playing music" repository = "https://github.com/bensherriff/siren" readme = "README.md" license = "GPL-3.0-or-later" -[lib] -name = "siren" -path = "src/lib.rs" - [dependencies] dotenv = "0.15.0" log = "0.4.22" env_logger = "0.11.5" serde = { version = "1.0.209", features = ["derive"] } serde_json = "1.0.127" -serenity = { version = "0.11.6", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "voice", "cache", "framework", "standard_framework"] } -songbird = { version = "0.3.2", features = ["builtin-queue", "yt-dlp"] } +serenity = { version = "0.12.2", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "voice", "cache", "framework", "standard_framework"] } +songbird = { version = "0.4.3", features = ["builtin-queue"] } +symphonia = { version = "0.5.4", features = ["all"] } diesel = { version = "2.1.5", default-features = false, features = ["postgres", "chrono", "r2d2", "32-column-tables", "serde_json", "with-deprecated"] } diesel_migrations = { version = "2.1.0", features = ["postgres"] } r2d2 = "0.8.10" chrono = { version = "0.4.38", features = ["serde"] } -reqwest = { version = "0.12.7", default-features = false, features = ["json"] } +reqwest = { version = "0.11", default-features = false, features = ["json"] } lazy_static = "1.5.0" uuid = { version = "1.10.0", features = ["serde", "v4"] } redis = { version = "0.26.1", features = ["tokio-comp", "connection-manager", "r2d2"] } diff --git a/Makefile b/Makefile index 161e4ce..61a1dbc 100644 --- a/Makefile +++ b/Makefile @@ -47,4 +47,7 @@ docker-clean: ## Stop the docker containers and remove volumes @docker compose --profile backend --profile siren down -v @echo "Docker container stopped and volumes removed" -docker-refresh: docker-clean backend-up ## Refresh the docker containers \ No newline at end of file +docker-refresh: docker-clean backend-up ## Refresh the docker containers + +psql: ## Connect to the database + @docker exec -it siren-db psql -U ${DATABASE_USER} -P pager=off \ No newline at end of file diff --git a/src/bot/commands/audio/mod.rs b/src/bot/commands/audio/mod.rs index 5f94cbc..1de1f6a 100644 --- a/src/bot/commands/audio/mod.rs +++ b/src/bot/commands/audio/mod.rs @@ -1,20 +1,14 @@ use std::sync::Arc; -use log::{debug, warn}; - use reqwest::Url; +use serenity::all::{CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditInteractionResponse}; use serenity::client::Cache; -use serenity::model::application::interaction::{ - InteractionResponseType, application_command::ApplicationCommandInteraction, -}; use serenity::model::prelude::{GuildId, ChannelId}; use serenity::model::user::User; use serenity::prelude::*; -use siren::ServiceError; -use songbird::{Call, Songbird}; -use songbird::input::{Restartable, Input, Metadata, error::Error as SongbirdError}; +use songbird::Songbird; -use crate::bot::ytdlp::{PlaylistItem, YtDlp}; +use crate::error::{SirenResult, Error as SirenError}; pub mod pause; pub mod play; @@ -23,57 +17,41 @@ pub mod skip; pub mod stop; pub mod volume; +/** + * Finds a voice channel that the user is currently in, and attempts to join it. + */ pub async fn join_by_user( cache: &Arc, - manager: Arc, - guild_id_option: &Option, + manager: &Arc, + guild_id: &GuildId, user: &User, -) -> Result<(), ServiceError> { - let guild_id = match guild_id_option { - Some(g) => g, - None => { - return Err(ServiceError { - status: 422, - message: format!("{}", "No guild ID set"), - }) - } - }; - - let channel_id = match find_voice_channel(cache, &guild_id, &user) { +) -> SirenResult { + let channel_id = match find_voice_channel(cache, guild_id, user) { Ok(channel) => channel, - Err(err) => { - return Err(ServiceError { - status: 500, - message: err.to_string(), - }) - } + Err(err) => return Err(SirenError::new(500, err.to_string())) }; - join(manager, guild_id, &channel_id).await + join_voice_channel(manager, guild_id, &channel_id).await?; + Ok(channel_id) } -pub async fn join( - manager: Arc, +/** + * Joins a voice channel. + */ +async fn join_voice_channel( + manager: &Arc, guild_id: &GuildId, channel_id: &ChannelId, -) -> Result<(), ServiceError> { - debug!("<{}> Joining channel {}", guild_id.0, channel_id.0); - let (_handle_lock, success) = manager - .join(guild_id.to_owned(), channel_id.to_owned()) - .await; - match success { - Ok(s) => Ok(s), - Err(err) => { - warn!("Failed to join channel: {:?}", err); - Err(ServiceError { - status: 500, - message: err.to_string(), - }) - } - } +) -> SirenResult<()> { + log::debug!("<{}> Joining channel {}", guild_id.get(), channel_id.get()); + manager.join(guild_id.to_owned(), channel_id.to_owned()).await?; + Ok(()) } -pub async fn leave( +/** + * Leaves a voice channel. + */ +pub async fn leave_voice_channel( manager: Arc, guild_id_option: &Option, ) -> Result<(), String> { @@ -85,7 +63,7 @@ pub async fn leave( }; if manager.get(*guild_id).is_some() { - debug!("<{}> Disconnecting from channel", guild_id.0); + log::debug!("<{}> Disconnecting from channel", guild_id.get()); if let Err(e) = manager.remove(*guild_id).await { return Err(format!("{}", e)); } @@ -93,12 +71,15 @@ pub async fn leave( Ok(()) } +/** + * Finds a voice channel that the user is currently in. + */ fn find_voice_channel( cache: &Arc, guild_id: &GuildId, user: &User, ) -> Result { - let guild = match guild_id.to_guild_cached(cache.to_owned()) { + let guild = match guild_id.to_guild_cached(cache) { Some(g) => g, None => return Err(format!("Guild not found")), }; @@ -115,86 +96,29 @@ fn find_voice_channel( pub async fn create_response( ctx: &Context, - command: &ApplicationCommandInteraction, + command: &CommandInteraction, content: String, ) -> Result<(), SerenityError> { - command - .create_interaction_response( - &ctx.http, - |response: &mut serenity::builder::CreateInteractionResponse<'_>| { - response - .kind(InteractionResponseType::ChannelMessageWithSource) - .interaction_response_data( - |message: &mut serenity::builder::CreateInteractionResponseData<'_>| { - message.content(content) - }, - ) - }, - ) - .await + let data = CreateInteractionResponseMessage::new().content(content); + let builder = CreateInteractionResponse::Message(data); + command.create_response(&ctx.http, builder).await?; + Ok(()) } pub async fn edit_response( ctx: &Context, - command: &ApplicationCommandInteraction, + command: &CommandInteraction, content: String, ) -> Result { - command - .edit_original_interaction_response( - &ctx.http, - |response: &mut serenity::builder::EditInteractionResponse| response.content(content), - ) - .await -} - -pub async fn add_song( - call: Arc>, - url: &str, - lazy: bool, - volume: Option, -) -> Result { - let source = Restartable::ytdl(url.to_owned(), lazy).await?; - let mut handler = call.lock().await; - let track: Input = source.into(); - let metadata = *track.metadata.clone(); - let track_handle = handler.enqueue_source(track); - if let Some(volume) = volume { - let _ = track_handle.set_volume(volume); - } - Ok(metadata) -} - -pub fn get_playlist_urls(url: &str) -> Result, ServiceError> { - let output = YtDlp::new() - .arg("--flat-playlist") - .arg("--dump-json") - .arg(url) - .execute()?; - let items: Vec = String::from_utf8(output.stdout)? - .split('\n') - .filter_map(|line| { - if line.is_empty() { - None - } else { - Some( - serde_json::from_slice::(line.as_bytes()).map_err(|err| ServiceError { - status: 500, - message: err.to_string(), - }), - ) - } - }) - .filter_map(|parsed| match parsed { - Ok(item) => Some(item), - Err(err) => { - warn!("Failed to parse playlist item: {}", err); - None - } - }) - .collect(); - Ok(items) + let builder = EditInteractionResponse::new().content(content); + command.edit_response(&ctx.http, builder).await } +/** + * Checks if a URL is valid and if it is a playlist. + * 1st tuple value is if the URL is valid. + * 2nd tuple value is if the URL is a playlist. + */ fn is_valid_url(url: &str) -> (bool, bool) { Url::parse(url).ok().map_or((false, false), |valid_url| { let is_playlist: bool = valid_url diff --git a/src/bot/commands/audio/pause.rs b/src/bot/commands/audio/pause.rs index 213bfa9..d646c83 100644 --- a/src/bot/commands/audio/pause.rs +++ b/src/bot/commands/audio/pause.rs @@ -1,12 +1,10 @@ use log::{debug, error}; -use serenity::prelude::*; -use serenity::builder::CreateApplicationCommand; -use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; +use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; use super::{get_songbird, create_response, edit_response}; -pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { +pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await { error!("Failed to create response message: {}", why); @@ -40,6 +38,6 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { } } -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command.name("pause").description("Pause the current track") +pub fn register() -> CreateCommand { + CreateCommand::new("pause").description("Pause the current track") } diff --git a/src/bot/commands/audio/play.rs b/src/bot/commands/audio/play.rs index a960eaa..38abc12 100644 --- a/src/bot/commands/audio/play.rs +++ b/src/bot/commands/audio/play.rs @@ -1,51 +1,39 @@ use std::sync::Arc; -use log::{debug, warn, error}; - +use serenity::all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption}; use serenity::model::prelude::GuildId; use serenity::{prelude::*, async_trait}; -use serenity::builder::CreateApplicationCommand; -use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; -use siren::ServiceError; -use songbird::{EventHandler, Songbird}; +use songbird::input::{AuxMetadata, Input, YoutubeDl}; +use songbird::tracks::TrackHandle; +use songbird::{Call, EventHandler, Songbird}; -use crate::bot::guilds::QueryGuild; -use crate::bot::ytdlp::PlaylistItem; -use crate::bot::{ - commands::audio::{leave, get_playlist_urls, add_song, get_songbird}, -}; +use crate::bot::guilds::GuildCache; +use crate::bot::ytdlp::{PlaylistItem, YtDlp}; +use crate::bot::commands::audio::{create_response, edit_response, leave_voice_channel}; +use crate::error::{SirenResult, Error as SirenError}; +use crate::HttpKey; -use super::{create_response, edit_response, is_valid_url, join_by_user}; +use super::{get_songbird, is_valid_url, join_by_user}; -pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { +pub async fn run(ctx: &Context, command: &CommandInteraction) { // Get the track url let track_url = match command.data.options.get(0) { - Some(t) => match &t.value { - Some(v) => match v.as_str() { + Some(o) => match &o.value.as_str() { Some(s) => s.to_owned(), None => { - warn!("Missing track option"); + log::warn!("Missing track option"); if let Err(why) = create_response(&ctx, &command, format!("Track option is missing")).await { - error!("Failed to create response message: {}", why); + log::error!("Failed to create response message: {}", why); } return; } }, - None => { - warn!("Missing track option"); - if let Err(why) = create_response(&ctx, &command, format!("Track option is missing")).await - { - error!("Failed to create response message: {}", why); - } - return; - } - }, None => { - warn!("Missing track option"); + log::warn!("Missing track option"); if let Err(why) = create_response(&ctx, &command, format!("Track option is missing")).await { - error!("Failed to create response message: {}", why); + log::error!("Failed to create response message: {}", why); } return; } @@ -53,27 +41,29 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { // Create the initial response if let Err(why) = create_response(&ctx, &command, format!("Processing command...")).await { - error!("Failed to create response message: {}", why); + log::error!("Failed to create response message: {}", why); return; } let manager = get_songbird(ctx).await; - match join_by_user(&ctx.cache, manager, &command.guild_id, &command.user).await { + // Extract the guild ID + let guild_id = match &command.guild_id { + Some(guild_id) => guild_id, + None => { + if let Err(why) = + edit_response(&ctx, &command, "Unable to join voice channel".to_string()).await + { + log::error!("Failed to edit response message: {}", why); + } + return; + } + }; + // Join the user's voice channel + match join_by_user(&ctx.cache, &manager, guild_id, &command.user).await { Ok(_) => { - let guild_id = match command.guild_id { - Some(g) => g, - None => { - if let Err(why) = - edit_response(&ctx, &command, "Unable to join voice channel".to_string()).await - { - error!("Failed to edit response message: {}", why); - } - return; - } - }; - debug!("Play command executed with track: {:?}", track_url); - let manager = get_songbird(ctx).await; - match play_track(manager, guild_id, track_url).await { + log::debug!("Play command executed with track: {:?}", track_url); + // Handle the track url + match play_track(ctx, manager, guild_id.to_owned(), track_url).await { Ok(count) => { let mut message = format!("Playing {} tracks", count); if count == 0 { @@ -82,72 +72,71 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { message = "Playing 1 track".to_string(); } if let Err(why) = edit_response(&ctx, &command, message).await { - error!("Failed to edit response message: {}", why); + log::error!("Failed to edit response message: {}", why); } } Err(err) => { - warn!("Failed to play track: {}", err); + log::warn!("Failed to play track: {}", err); if let Err(why) = edit_response(&ctx, &command, format!("Failed to play track: {}", err)).await { - error!("Failed to edit response message: {}", why); + log::error!("Failed to edit response message: {}", why); } } }; } Err(err) => { - warn!("{}", err); + log::warn!("{}", err); if let Err(why) = edit_response(&ctx, &command, format!("{}", err)).await { - error!("Failed to edit response message: {}", why); + log::error!("Failed to edit response message: {}", why); } } } } pub async fn play_track( + ctx: &Context, manager: Arc, guild_id: GuildId, - track_url: String, -) -> Result { + track_url: &str, +) -> SirenResult { let mut track_count = 0; if let Some(handler_lock) = manager.get(guild_id) { let is_queue_empty = { let call_handler = handler_lock.lock().await; call_handler.queue().is_empty() }; - let guild = QueryGuild::get(guild_id.0 as i64)?; + let guild = GuildCache::get(guild_id.get() as i64)?; let valid = is_valid_url(&track_url); + // Check if the URL is valid if !valid.0 { - warn!("Invalid track url: {}", track_url); - return Err(ServiceError { - status: 422, - message: format!("Invalid track url: {}", track_url), - }); + log::warn!("Invalid track url: {}", track_url); + return Err(SirenError::new(422, format!("Invalid track url: {}", track_url))); } let mut playlist_items: Vec = Vec::new(); + // Check if the URL is a playlist or a single track if valid.1 { playlist_items = match get_playlist_urls(&track_url) { Ok(items) => items, Err(err) => { - warn!("Failed to get playlist urls: {}", err); - return Err(ServiceError { - status: 422, - message: err.to_string(), - }); + log::warn!("Failed to get playlist urls: {}", err); + return Err(SirenError::new(422,err.to_string())); } }; } else { let playlist_item = PlaylistItem { id: "".to_string(), - url: track_url, + url: track_url.to_string(), title: "".to_string(), duration: 0, playlist_index: 0, }; playlist_items.push(playlist_item); } + // Add each track to the queue for item in playlist_items { match add_song( + ctx, handler_lock.clone(), &item.url, is_queue_empty, @@ -157,7 +146,7 @@ pub async fn play_track( { Ok(added_song) => { let track_title = added_song.title.unwrap(); - debug!("Added track: {}", track_title); + log::debug!("Added track: {}", track_title); let mut handler = handler_lock.lock().await; handler.remove_all_global_events(); handler.add_global_event( @@ -170,14 +159,11 @@ pub async fn play_track( track_count += 1; } Err(err) => { - warn!("Failed to add song: {}", err); - if let Err(why) = leave(manager, &Some(guild_id)).await { - error!("Failed to leave voice channel: {}", why); + log::warn!("Failed to add song: {}", err); + if let Err(why) = leave_voice_channel(manager, &Some(guild_id)).await { + log::error!("Failed to leave voice channel: {}", why); } - return Err(ServiceError { - status: 422, - message: err.to_string(), - }); + return Err(SirenError::new(422, err.to_string())); } } } @@ -185,17 +171,68 @@ pub async fn play_track( Ok(track_count) } -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command - .name("play") - .description("Plays the given track") - .create_option(|option| { - option - .name("track") - .description("The track to be played") - .kind(serenity::model::prelude::command::CommandOptionType::String) - .required(true) +async fn add_song( + ctx: &Context, + call: Arc>, + url: &str, + lazy: bool, + volume: Option, +) -> SirenResult { + // let source = Restartable::ytdl(url.to_owned(), lazy).await?; + let http_client = { + let data = ctx.data.read().await; + data.get::() + .cloned() + .expect("Guaranteed to exist in the typemap.") +}; + let source = YoutubeDl::new(http_client, url.to_owned()); + let mut handler = call.lock().await; + let mut track: Input = source.into(); + let metadata = track.aux_metadata().await.unwrap(); + let track_handle: TrackHandle; + if lazy { + track_handle = handler.play_input(track); + } else { + track_handle = handler.enqueue_input(track).await; + } + if let Some(volume) = volume { + let _ = track_handle.set_volume(volume); + } + Ok(metadata) +} + +pub fn get_playlist_urls(url: &str) -> SirenResult> { + let output = YtDlp::new() + .arg("--flat-playlist") + .arg("--dump-json") + .arg(url) + .execute()?; + let items: Vec = String::from_utf8(output.stdout)? + .split('\n') + .filter_map(|line| { + if line.is_empty() { + None + } else { + Some( + serde_json::from_slice::(line.as_bytes()).map_err(|err| SirenError::new(500, err.to_string())), + ) + } }) + .filter_map(|parsed| match parsed { + Ok(item) => Some(item), + Err(err) => { + log::warn!("Failed to parse playlist item: {}", err); + None + } + }) + .collect(); + Ok(items) +} + +pub fn register() -> CreateCommand { + CreateCommand::new("play") + .description("Plays the given track") + .add_option(CreateCommandOption::new(CommandOptionType::String, "track", "The track to be played").required(true)) } struct TrackEndNotifier { @@ -210,7 +247,7 @@ impl EventHandler for TrackEndNotifier { if let Some(call) = self.call.get(self.guild_id) { let mut handler = call.lock().await; if handler.queue().is_empty() { - debug!("Queue is empty, leaving voice channel"); + log::debug!("Queue is empty, leaving voice channel"); handler.leave().await.unwrap(); } } diff --git a/src/bot/commands/audio/resume.rs b/src/bot/commands/audio/resume.rs index 0a10a20..722f2bc 100644 --- a/src/bot/commands/audio/resume.rs +++ b/src/bot/commands/audio/resume.rs @@ -1,12 +1,10 @@ use log::{debug, error}; -use serenity::prelude::*; -use serenity::builder::CreateApplicationCommand; -use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; +use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; use super::{get_songbird, create_response, edit_response}; -pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { +pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await { error!("Failed to create response message: {}", why); @@ -40,8 +38,6 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { } } -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command - .name("resume") - .description("Resume the current track") +pub fn register() -> CreateCommand { + CreateCommand::new("resume").description("Resume the current track") } diff --git a/src/bot/commands/audio/skip.rs b/src/bot/commands/audio/skip.rs index c91ffdd..ef5ff8c 100644 --- a/src/bot/commands/audio/skip.rs +++ b/src/bot/commands/audio/skip.rs @@ -1,12 +1,10 @@ use log::{debug, error}; -use serenity::prelude::*; -use serenity::builder::CreateApplicationCommand; -use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; +use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; use super::{get_songbird, create_response, edit_response}; -pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { +pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await { error!("Failed to create response message: {}", why); @@ -40,6 +38,6 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { } } -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command.name("skip").description("Skip the current track") +pub fn register() -> CreateCommand { + CreateCommand::new("skip").description("Skip the current track") } diff --git a/src/bot/commands/audio/stop.rs b/src/bot/commands/audio/stop.rs index b4943dd..1bee71c 100644 --- a/src/bot/commands/audio/stop.rs +++ b/src/bot/commands/audio/stop.rs @@ -1,12 +1,10 @@ use log::{debug, error}; -use serenity::prelude::*; -use serenity::builder::CreateApplicationCommand; -use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; +use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; use super::{get_songbird, create_response, edit_response}; -pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { +pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await { error!("Failed to create response message: {}", why); @@ -35,8 +33,6 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { } } -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command - .name("stop") - .description("Stop the current track and clear the queue") +pub fn register() -> CreateCommand { + CreateCommand::new("stop").description("Stop the current track and clear the queue") } diff --git a/src/bot/commands/audio/volume.rs b/src/bot/commands/audio/volume.rs index 624cbdd..eaff139 100644 --- a/src/bot/commands/audio/volume.rs +++ b/src/bot/commands/audio/volume.rs @@ -2,34 +2,22 @@ use std::sync::Arc; use log::{error, warn}; -use serenity::{prelude::*, model::prelude::GuildId}; -use serenity::builder::CreateApplicationCommand; -use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; +use serenity::{all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption}, model::prelude::GuildId, prelude::*}; use songbird::Songbird; -use crate::bot::guilds::InsertGuild; +use crate::bot::guilds::GuildCache; use super::{get_songbird, create_response, edit_response}; -pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { +pub async fn run(ctx: &Context, command: &CommandInteraction) { // Get the volume let volume = match command.data.options.get(0) { - Some(t) => match &t.value { - Some(v) => match v.as_i64() { - Some(p) => p as i32, - None => { - warn!("Unable to get volume option as a string"); - if let Err(why) = - create_response(&ctx, &command, format!("Volume option is missing")).await - { - error!("Failed to create response message: {}", why); - } - return; - } - }, + Some(o) => match o.value.as_i64() { + Some(p) => p as i32, None => { - warn!("Missing volume option value"); - if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await + warn!("Unable to get volume option as a string"); + if let Err(why) = + create_response(&ctx, &command, format!("Volume option is missing")).await { error!("Failed to create response message: {}", why); } @@ -74,7 +62,7 @@ pub async fn set_volume(manager: Arc, guild_id: GuildId, volume: i32) // Format volume to f32 bound between 0.0 and 1.0 let volume = std::cmp::min(100, std::cmp::max(0, volume)); let bound_volume = volume as f32 / 100.0; - let _ = InsertGuild::update_audio(guild_id.0 as i64, volume); + let _ = GuildCache::update_audio(guild_id.get() as i64, volume); if let Some(handler_lock) = manager.get(guild_id) { let handler = handler_lock.lock().await; @@ -84,15 +72,8 @@ pub async fn set_volume(manager: Arc, guild_id: GuildId, volume: i32) } } -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command - .name("volume") +pub fn register() -> CreateCommand { + CreateCommand::new("volume") .description("Set the audio player volume") - .create_option(|option| { - option - .name("volume") - .description("Volume between 0 and 100") - .kind(serenity::model::prelude::command::CommandOptionType::Integer) - .required(true) - }) + .add_option(CreateCommandOption::new(CommandOptionType::Integer, "volume", "Volume between 0 and 100").required(true)) } diff --git a/src/bot/commands/chat.rs b/src/bot/commands/chat.rs index cb3f92b..3bf09bf 100644 --- a/src/bot/commands/chat.rs +++ b/src/bot/commands/chat.rs @@ -1,5 +1,6 @@ use log::{error, trace, warn}; +use serenity::all::CreateThread; use serenity::model::Permissions; use serenity::model::channel::Message; use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType}; @@ -16,7 +17,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { let author_id = msg.author.id; // Parse out the bot mention from the message - let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0); + let bot_mention: String = format!("<@{}>", ctx.cache.current_user().id); let parsed_content = msg.content.replace(bot_mention.as_str(), ""); let mut messages = vec![ @@ -28,9 +29,9 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { match QueryMessage::get_all( &QueryFilters { - by_guild_id: Some(guild_id.0 as i64), - by_channel_id: Some(channel_id.0 as i64), - by_user_id: Some(author_id.0 as i64), + by_guild_id: Some(guild_id.get() as i64), + by_channel_id: Some(channel_id.get() as i64), + by_user_id: Some(author_id.get() as i64), ..Default::default() }, 100, @@ -71,9 +72,8 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { let thread_name = generate_thread_name(oai, &parsed_content, 99).await; let response_channel = match msg .channel_id - .create_private_thread(&ctx.http, |thread| { - thread.name(thread_name).kind(ChannelType::PublicThread) - }) + .create_thread(&ctx.http, CreateThread::new(thread_name).kind(ChannelType::PublicThread) + ) .await { Ok(c) => { @@ -84,13 +84,13 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { deny, kind: PermissionOverwriteType::Member(msg.author.id), }; - let _ = c.create_permission(&ctx.http, &overwrite).await; + let _ = c.create_permission(&ctx.http, overwrite).await; c.id } Err(_) => channel_id, }; - let typing = response_channel.start_typing(&ctx.http).unwrap(); + let typing = response_channel.start_typing(&ctx.http); // Get the OAI response and store message/response into the database let response = match oai.chat_completion(request).await { @@ -100,9 +100,9 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { let res = r.choices[0].message.content.clone(); if let Err(err) = QueryMessage::insert(QueryMessage { id: r.id, - guild_id: guild_id.0 as i64, - channel_id: response_channel.0 as i64, - user_id: author_id.0 as i64, + guild_id: guild_id.get() as i64, + channel_id: response_channel.get() as i64, + user_id: author_id.get() as i64, created: r.created, model: serde_json::to_string(&r.model).unwrap(), request: parsed_content, diff --git a/src/bot/commands/ping.rs b/src/bot/commands/ping.rs index 951f712..b4b8881 100644 --- a/src/bot/commands/ping.rs +++ b/src/bot/commands/ping.rs @@ -1,14 +1,10 @@ -use log::debug; -use serenity::{ - model::prelude::interaction::application_command::CommandDataOption, - builder::CreateApplicationCommand, -}; +use serenity::all::{CommandDataOption, CreateCommand}; pub fn run(_options: &[CommandDataOption]) -> String { - debug!("Ping command executed"); + log::debug!("Ping command executed"); "pong".to_string() } -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command.name("ping").description("Replies with pong") +pub fn register() -> CreateCommand { + CreateCommand::new("ping").description("Replies with pong") } diff --git a/src/bot/commands/roll.rs b/src/bot/commands/roll.rs index e929bf5..d5f10dd 100644 --- a/src/bot/commands/roll.rs +++ b/src/bot/commands/roll.rs @@ -1,25 +1,19 @@ use log::{error, warn}; use rand::Rng; -use serenity::{ - builder::CreateApplicationCommand, - client::Context, - model::application::{ - command::CommandOptionType, interaction::application_command::ApplicationCommandInteraction, - }, -}; +use serenity::all::{CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption}; use crate::bot::commands::audio::edit_response; use super::audio::create_response; -pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { +pub async fn run(ctx: &Context, command: &CommandInteraction) { if let Err(why) = create_response(&ctx, &command, format!("Processing command...")).await { error!("Failed to create response message: {}", why); return; } - let dice_string: String = match command.data.options.get(0) { - Some(o) => match &o.value { - Some(v) => match v.as_str() { + let dice_string = match command.data.options.get(0) { + Some(o) => { + match o.value.as_str() { Some(s) => s.split_whitespace().collect::(), None => { warn!("Missing dice option"); @@ -28,21 +22,14 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { } return; } - }, - None => { - warn!("Missing dice option"); - if let Err(why) = edit_response(&ctx, &command, format!("Dice option is missing")).await { - error!("Failed to create response message: {}", why); - } - return; } }, None => { warn!("Missing dice option"); - if let Err(why) = edit_response(&ctx, &command, format!("Dice option is missing")).await { - error!("Failed to create response message: {}", why); - } - return; + if let Err(why) = edit_response(&ctx, &command, format!("Dice option is missing")).await { + error!("Failed to create response message: {}", why); + } + return; } }; let dice = parse_dice(dice_string.as_str()); @@ -135,15 +122,8 @@ fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { Ok((count, sides, modifier)) } -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command - .name("roll") +pub fn register() -> CreateCommand { + CreateCommand::new("roll") .description("Rolls D&D dice") - .create_option(|option| { - option - .name("dice") - .description("Dice to roll") - .kind(CommandOptionType::String) - .required(true) - }) + .add_option(CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll").required(true)) } diff --git a/src/bot/guilds/model.rs b/src/bot/guilds/model.rs index d79bcb3..74d8f98 100644 --- a/src/bot/guilds/model.rs +++ b/src/bot/guilds/model.rs @@ -1,43 +1,33 @@ use diesel::prelude::*; use serde::{Serialize, Deserialize}; -use siren::ServiceError; +use crate::error::SirenResult; use crate::storage::{schema::guilds, connection}; -#[derive(Queryable, QueryableByName, Serialize, Deserialize)] +#[derive(Insertable, AsChangeset, Queryable, QueryableByName, Serialize, Deserialize)] #[diesel(table_name = guilds)] -pub struct QueryGuild { +pub struct GuildCache { pub id: i64, pub bot_id: i64, pub volume: i32, } -impl QueryGuild { - pub fn get(id: i64) -> Result { - let mut conn = connection()?; - let guild = guilds::table.filter(guilds::id.eq(id)).first(&mut conn)?; - Ok(guild) - } -} - -#[derive(Insertable, AsChangeset, Serialize, Deserialize)] -#[diesel(table_name = guilds)] -pub struct InsertGuild { - pub id: i64, - pub bot_id: i64, - pub volume: i32, -} - -impl InsertGuild { - pub fn insert(guild: Self) -> Result { +impl GuildCache { + pub fn insert(&self) -> SirenResult { let mut conn = connection()?; let guild = diesel::insert_into(guilds::table) - .values(guild) + .values(self) .get_result(&mut conn)?; Ok(guild) } - pub fn update_audio(id: i64, volume: i32) -> Result { + pub fn get(id: i64) -> SirenResult { + let mut conn = connection()?; + let guild = guilds::table.filter(guilds::id.eq(id)).first(&mut conn)?; + Ok(guild) + } + + pub fn update_audio(id: i64, volume: i32) -> SirenResult { let mut conn = connection()?; let guild = diesel::update(guilds::table.filter(guilds::id.eq(id))) .set(guilds::volume.eq(volume)) diff --git a/src/bot/handler.rs b/src/bot/handler.rs index 9de3e32..9d01c85 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -1,10 +1,11 @@ use log::{warn, info, error}; +use serenity::all::Interaction; use serenity::async_trait; -use serenity::model::application::interaction::Interaction; use serenity::model::gateway::Ready; use serenity::model::channel::Message; use serenity::prelude::*; +use super::guilds::GuildCache; use super::{commands, oai}; use super::commands::audio::create_response; @@ -28,7 +29,7 @@ impl EventHandler for Handler { Ok(t) => { match t .iter() - .find(|t| t.user_id.unwrap().0 == ctx.cache.current_user_id().0) + .find(|t| t.user_id == ctx.cache.current_user().id) { Some(_) => true, None => false, @@ -48,8 +49,10 @@ impl EventHandler for Handler { } async fn interaction_create(&self, ctx: Context, interaction: Interaction) { - if let Interaction::ApplicationCommand(command) = interaction { + if let Interaction::Command(command) = interaction { + log::trace!("Received command interaction: {command:#?}"); match command.data.name.as_str() { + // Match commands without returns "roll" => commands::roll::run(&ctx, &command).await, "play" => commands::audio::play::run(&ctx, &command).await, "stop" => commands::audio::stop::run(&ctx, &command).await, @@ -59,6 +62,7 @@ impl EventHandler for Handler { "volume" => commands::audio::volume::run(&ctx, &command).await, _ => { let content: String = match command.data.name.as_str() { + // Match commands with string returns "ping" => commands::ping::run(&command.data.options), _ => "Unknown command".to_string(), }; @@ -76,57 +80,34 @@ impl EventHandler for Handler { warn!("No ready guilds found"); } for guild in ready.guilds { + // Check if guild exists in database + let guild_id = guild.id.get() as i64; + if let Err(why) = GuildCache::get(guild_id) { + let guild_cache = GuildCache { + id: guild_id, + bot_id: 1, + volume: 100 + }; + guild_cache.insert(); + } let commands = guild .id - .set_application_commands(&ctx.http, |commands| { - commands - .create_application_command( - |command: &mut serenity::builder::CreateApplicationCommand| { - commands::ping::register(command) - }, - ) - .create_application_command( - |command: &mut serenity::builder::CreateApplicationCommand| { - commands::roll::register(command) - }, - ) - .create_application_command( - |command: &mut serenity::builder::CreateApplicationCommand| { - commands::audio::play::register(command) - }, - ) - .create_application_command( - |command: &mut serenity::builder::CreateApplicationCommand| { - commands::audio::stop::register(command) - }, - ) - .create_application_command( - |command: &mut serenity::builder::CreateApplicationCommand| { - commands::audio::pause::register(command) - }, - ) - .create_application_command( - |command: &mut serenity::builder::CreateApplicationCommand| { - commands::audio::resume::register(command) - }, - ) - .create_application_command( - |command: &mut serenity::builder::CreateApplicationCommand| { - commands::audio::skip::register(command) - }, - ) - .create_application_command( - |command: &mut serenity::builder::CreateApplicationCommand| { - commands::audio::volume::register(command) - }, - ) - }) + .set_commands(&ctx.http, vec![ + commands::ping::register(), + commands::roll::register(), + commands::audio::play::register(), + commands::audio::stop::register(), + commands::audio::pause::register(), + commands::audio::resume::register(), + commands::audio::skip::register(), + commands::audio::volume::register(), + ]) .await; match commands { - Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.0), + Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.get()), Err(why) => error!( "Could not register commands for guild {}: {:?}", - guild.id.0, why + guild.id.get(), why ), }; } diff --git a/src/bot/messages/model.rs b/src/bot/messages/model.rs index 4cc4e7c..d2d840c 100644 --- a/src/bot/messages/model.rs +++ b/src/bot/messages/model.rs @@ -1,6 +1,6 @@ use diesel::prelude::*; use serde::{Deserialize, Serialize}; -use siren::ServiceError; +use crate::error::SirenResult; use crate::storage::{ schema::messages::{self}, @@ -51,7 +51,7 @@ impl Default for QueryFilters { } impl QueryMessage { - pub fn get_all(filters: &QueryFilters, limit: i32, page: i32) -> Result, ServiceError> { + pub fn get_all(filters: &QueryFilters, limit: i32, page: i32) -> SirenResult> { let mut conn = connection()?; let mut query = messages::table .limit(limit as i64) @@ -93,7 +93,7 @@ impl QueryMessage { Ok(messages) } - pub fn get_count(fitlers: &QueryFilters) -> Result { + pub fn get_count(fitlers: &QueryFilters) -> SirenResult { let mut conn = connection()?; let mut query = messages::table.into_boxed(); // Apply filters @@ -129,7 +129,7 @@ impl QueryMessage { Ok(count) } - pub fn insert(message: Self) -> Result { + pub fn insert(message: Self) -> SirenResult { let mut conn = connection()?; let message = diesel::insert_into(messages::table) .values(message) diff --git a/src/bot/oai/model.rs b/src/bot/oai/model.rs index 91fab69..a1c4bae 100644 --- a/src/bot/oai/model.rs +++ b/src/bot/oai/model.rs @@ -1,6 +1,7 @@ use serde::{Serialize, Deserialize}; use serde_json::Value; -use siren::ServiceError; + +use crate::error::{SirenResult, Error as SirenError}; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum GPTRole { @@ -123,7 +124,7 @@ impl OAI { pub async fn chat_completion( &self, request: ChatCompletionRequest, - ) -> Result { + ) -> SirenResult { let url = format!("{}/chat/completions", self.base_url); let response = self .client @@ -142,7 +143,7 @@ impl OAI { return Ok(response); }, ResponseEvent::ResponseError(error) => { - return Err(ServiceError { + return Err(SirenError { status: 500, message: format!("Error: {}", error.message.unwrap()), }); @@ -150,7 +151,7 @@ impl OAI { } } Err(err) => { - return Err(ServiceError { + return Err(SirenError { status: 500, message: format!("Error: {}", err), }) diff --git a/src/dnd/spells/model.rs b/src/dnd/spells/model.rs index 611483e..65bca7d 100644 --- a/src/dnd/spells/model.rs +++ b/src/dnd/spells/model.rs @@ -1,7 +1,7 @@ use diesel::prelude::*; use serde::{Deserialize, Serialize}; -use siren::ServiceError; +use crate::error::SirenResult; use crate::storage::connection; use crate::storage::schema::spells::{self}; use crate::dnd::{classes::AbilityType, conditions::ConditionType}; @@ -65,7 +65,7 @@ impl Default for QueryFilters { } impl QuerySpell { - pub fn get_all(filters: &QueryFilters, limit: i32, page: i32) -> Result, ServiceError> { + pub fn get_all(filters: &QueryFilters, limit: i32, page: i32) -> SirenResult> { let mut conn = connection()?; let mut query = spells::table.limit(limit as i64).into_boxed(); // Limit query to page and limit @@ -147,7 +147,7 @@ impl QuerySpell { Ok(spells) } - pub fn get_count(filters: &QueryFilters) -> Result { + pub fn get_count(filters: &QueryFilters) -> SirenResult { let mut conn = connection()?; let mut query = spells::table.count().into_boxed(); if let Some(name) = &filters.by_name { @@ -223,7 +223,7 @@ impl QuerySpell { Ok(count) } - pub fn get_by_id(id: i32) -> Result { + pub fn get_by_id(id: i32) -> SirenResult { let mut conn = connection()?; let spell = spells::table .filter(spells::id.eq(id)) @@ -231,7 +231,7 @@ impl QuerySpell { Ok(spell) } - pub fn delete(id: i32) -> Result { + pub fn delete(id: i32) -> SirenResult { let mut conn = connection()?; let spell = diesel::delete(spells::table.filter(spells::id.eq(id))).get_result(&mut conn)?; Ok(spell) @@ -256,7 +256,7 @@ pub struct InsertSpell { } impl InsertSpell { - pub fn insert(spell: Self) -> Result { + pub fn insert(spell: Self) -> SirenResult { let mut conn = connection()?; let spell = diesel::insert_into(spells::table) .values(spell) @@ -264,7 +264,7 @@ impl InsertSpell { Ok(spell) } - pub fn update(id: i32, spell: Self) -> Result { + pub fn update(id: i32, spell: Self) -> SirenResult { let mut conn = connection()?; let spell = diesel::update(spells::table.filter(spells::id.eq(id))) .set(spell) diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..0f81a78 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,96 @@ +use std::fmt; +use diesel::result::Error as DieselError; +use serde::{Deserialize, Serialize}; + +pub type SirenResult = Result; + +#[derive(Debug, Deserialize, Serialize)] +pub struct Error { + pub status: u16, + pub message: String, +} + +impl Error { + pub fn new(error_status_code: u16, error_message: String) -> Self { + Self { + status: error_status_code, + message: error_message, + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.message.as_str()) + } +} + +impl From for Error { + fn from(error: std::io::Error) -> Self { + Self::new(500, format!("Unknown io error: {}", error)) + } +} + +impl From for Error { + fn from(error: std::string::FromUtf8Error) -> Self { + Self::new(500, format!("Unknown from utf8 error: {}", error)) + } +} + +impl From for Error { + fn from(error: DieselError) -> Self { + match error { + DieselError::DatabaseError(kind, err) => match kind { + diesel::result::DatabaseErrorKind::UniqueViolation => { + Self::new(409, err.message().to_string()) + } + _ => Self::new(500, err.message().to_string()), + }, + DieselError::NotFound => Self::new(404, "The record was not found".to_string()), + DieselError::SerializationError(err) => Self::new(422, err.to_string()), + err => Self::new(500, format!("Unknown database error: {}", err)), + } + } +} + +impl From for Error { + fn from(error: reqwest::Error) -> Self { + Self::new(500, format!("Unknown reqwest error: {}", error)) + } +} + +impl From for Error { + fn from(error: serde_json::Error) -> Self { + Self::new(500, format!("Unknown serde_json error: {}", error)) + } +} + +impl From for Error { + fn from(error: serenity::Error) -> Self { + Self::new(500, format!("Unknown serenity error: {}", error)) + } +} + +impl From for Error { + fn from(error: redis::RedisError) -> Self { + Self::new(500, format!("Unknown redis error: {}", error)) + } +} + +impl From for Error { + fn from(error: uuid::Error) -> Self { + Self::new(500, format!("Unknown uuid error: {}", error)) + } +} + +impl From for Error { + fn from(error: std::env::VarError) -> Self { + Self::new(500, format!("Unknown env error: {}", error)) + } +} + +impl From for Error { + fn from(error: songbird::error::JoinError) -> Self { + Self::new(500, format!("Unable to join channel: {}", error)) + } +} diff --git a/src/lib.rs b/src/lib.rs deleted file mode 100644 index dcf1d08..0000000 --- a/src/lib.rs +++ /dev/null @@ -1,117 +0,0 @@ -use diesel::result::Error as DieselError; -use serde::{Serialize, Deserialize}; -use std::fmt; - -#[derive(Serialize, Deserialize)] -pub struct Message { - pub id: String, - pub guild_id: i64, - pub channel_id: i64, - pub user_id: i64, - pub created: i64, - pub model: String, - pub request: String, - pub response: String, - pub request_tags: Vec, - pub response_tags: Vec, -} - -#[derive(Serialize, Deserialize)] -pub struct Response { - pub data: T, - #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option, -} - -#[derive(Serialize, Deserialize)] -pub struct Metadata { - pub total: i32, - pub limit: i32, - pub page: i32, - pub pages: i32, -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct ServiceError { - pub status: u16, - pub message: String, -} - -impl ServiceError { - pub fn new(error_status_code: u16, error_message: String) -> ServiceError { - ServiceError { - status: error_status_code, - message: error_message, - } - } -} - -impl fmt::Display for ServiceError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(self.message.as_str()) - } -} - -impl From for ServiceError { - fn from(error: std::io::Error) -> ServiceError { - ServiceError::new(500, format!("Unknown io error: {}", error)) - } -} - -impl From for ServiceError { - fn from(error: std::string::FromUtf8Error) -> ServiceError { - ServiceError::new(500, format!("Unknown from utf8 error: {}", error)) - } -} - -impl From for ServiceError { - fn from(error: DieselError) -> ServiceError { - match error { - DieselError::DatabaseError(kind, err) => match kind { - diesel::result::DatabaseErrorKind::UniqueViolation => { - ServiceError::new(409, err.message().to_string()) - } - _ => ServiceError::new(500, err.message().to_string()), - }, - DieselError::NotFound => ServiceError::new(404, "The record was not found".to_string()), - DieselError::SerializationError(err) => ServiceError::new(422, err.to_string()), - err => ServiceError::new(500, format!("Unknown database error: {}", err)), - } - } -} - -impl From for ServiceError { - fn from(error: reqwest::Error) -> ServiceError { - ServiceError::new(500, format!("Unknown reqwest error: {}", error)) - } -} - -impl From for ServiceError { - fn from(error: serde_json::Error) -> ServiceError { - ServiceError::new(500, format!("Unknown serde_json error: {}", error)) - } -} - -impl From for ServiceError { - fn from(error: serenity::Error) -> ServiceError { - ServiceError::new(500, format!("Unknown serenity error: {}", error)) - } -} - -impl From for ServiceError { - fn from(error: redis::RedisError) -> ServiceError { - ServiceError::new(500, format!("Unknown redis error: {}", error)) - } -} - -impl From for ServiceError { - fn from(error: uuid::Error) -> ServiceError { - ServiceError::new(500, format!("Unknown uuid error: {}", error)) - } -} - -impl From for ServiceError { - fn from(error: std::env::VarError) -> ServiceError { - ServiceError::new(500, format!("Unknown env error: {}", error)) - } -} diff --git a/src/main.rs b/src/main.rs index 479314d..c491b3e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,18 +5,24 @@ extern crate diesel_migrations; use std::env; use std::collections::HashSet; use std::sync::Arc; -use serenity::client::Cache; -use serenity::framework::StandardFramework; use serenity::http::Http; use serenity::prelude::*; use songbird::{SerenityInit, Songbird}; +use reqwest::Client as HttpClient; use crate::bot::handler::Handler; mod bot; mod dnd; +mod error; mod storage; +pub struct HttpKey; + +impl TypeMapKey for HttpKey { + type Value = HttpClient; +} + #[tokio::main] async fn main() { dotenv::dotenv().ok(); @@ -27,25 +33,25 @@ async fn main() { let intents: GatewayIntents = GatewayIntents::all(); let http: Http = Http::new(&token); - let (owners, _bot_id) = match http.get_current_application_info().await { + let (_owners, _bot_id) = match http.get_current_application_info().await { Ok(info) => { let mut owners: HashSet = HashSet::new(); if let Some(team) = info.team { owners.insert(team.owner_user_id); } else { - owners.insert(info.owner.id); + owners.insert(info.owner.unwrap().id); } match http.get_current_user().await { Ok(bot) => (owners, bot.id), - Err(why) => panic!("Could not access the bot id: {:?}", why), + Err(why) => panic!("Could not access the bot id: {why:?}"), } } - Err(why) => panic!("Could not access application info: {:?}", why), + Err(why) => panic!("Could not access application info: {why:?}"), }; let handler = match env::var("OPENAI_API_KEY") { Ok(token) => { - log::info!("Loaded OpenAI token"); + log::info!("OpenAI functionality enabled"); let default_model = env::var("OPENAI_API_MODEL").unwrap_or("gpt-3.5-turbo".to_string()); Handler { oai: Some(bot::oai::OAI { @@ -61,7 +67,8 @@ async fn main() { } } Err(err) => { - log::warn!("Could not load OpenAI token: {}", err); + log::trace!("No OPENAI_API_KEY found: {err}"); + log::warn!("OpenAI functionality disabled"); Handler { oai: None } } }; @@ -70,8 +77,9 @@ async fn main() { let mut client = Client::builder(token, intents) .event_handler(handler) - .framework(StandardFramework::new().configure(|c| c.owners(owners))) + // .framework(StandardFramework::new().configure(|c| c.owners(owners))) .register_songbird_with(Arc::clone(&songbird)) + .type_map_insert::(HttpClient::new()) .await .expect("Error creating client"); @@ -81,11 +89,4 @@ async fn main() { if let Err(why) = client.start_autosharded().await { log::error!("Client error: {why:?}"); } - -} - -pub struct AppState { - pub http: Arc, - pub cache: Arc, - pub songbird: Arc, } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index b3bfd2e..0350a48 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,7 +1,6 @@ use diesel::{r2d2::ConnectionManager as DieselConnectionManager, PgConnection}; -use redis::{Client as RedisClient, aio::Connection as RedisConnection}; -use siren::ServiceError; -use crate::diesel_migrations::MigrationHarness; +use redis::{aio::MultiplexedConnection, Client as RedisClient}; +use crate::{diesel_migrations::MigrationHarness, error::{Error as SirenError, SirenResult}}; use lazy_static::lazy_static; use log::{error, info}; use r2d2; @@ -49,18 +48,18 @@ pub async fn init() { }; } -pub fn connection() -> Result { +pub fn connection() -> SirenResult { POOL .get() - .map_err(|e| ServiceError::new(500, format!("Failed getting db connection: {}", e))) + .map_err(|e| SirenError::new(500, format!("Failed getting db connection: {}", e))) } -pub fn redis_connection() -> Result { +pub fn redis_connection() -> SirenResult { let conn = REDIS.get_connection()?; Ok(conn) } -pub async fn redis_async_connection() -> Result { - let conn = REDIS.get_async_connection().await?; +pub async fn redis_async_connection() -> SirenResult { + let conn = REDIS.get_multiplexed_async_connection().await?; Ok(conn) }