From cee9dbdc813aa6970a850da9b6671c70ff86eef2 Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Wed, 4 Oct 2023 19:05:24 -0400 Subject: [PATCH] Split bot and service --- .gitignore | 3 - .vscode/settings.json | 6 + bot/.env.TEMPLATE | 8 + bot/Cargo.toml | 41 +++++ bot/src/commands/audio/mod.rs | 190 ++++++++++++++++++++++ bot/src/commands/audio/pause.rs | 43 +++++ bot/src/commands/audio/play.rs | 134 +++++++++++++++ bot/src/commands/audio/resume.rs | 43 +++++ bot/src/commands/audio/skip.rs | 43 +++++ bot/src/commands/audio/stop.rs | 38 +++++ bot/src/commands/audio/volume.rs | 85 ++++++++++ {service => bot}/src/commands/help.rs | 0 {service => bot}/src/commands/mod.rs | 0 {service => bot}/src/commands/oai.rs | 143 ++++++++++------ {service => bot}/src/commands/ping.rs | 0 {service => bot}/src/commands/schedule.rs | 0 bot/src/error_handler.rs | 35 ++++ bot/src/main.rs | 175 ++++++++++++++++++++ service/Cargo.toml | 13 -- service/src/db/messages/mod.rs | 2 + service/src/db/messages/model.rs | 138 ++++++++++++++-- service/src/db/messages/routes.rs | 79 +++++++++ service/src/db/spells/routes.rs | 2 +- service/src/main.rs | 169 +------------------ 24 files changed, 1144 insertions(+), 246 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 bot/.env.TEMPLATE create mode 100644 bot/Cargo.toml create mode 100644 bot/src/commands/audio/mod.rs create mode 100644 bot/src/commands/audio/pause.rs create mode 100644 bot/src/commands/audio/play.rs create mode 100644 bot/src/commands/audio/resume.rs create mode 100644 bot/src/commands/audio/skip.rs create mode 100644 bot/src/commands/audio/stop.rs create mode 100644 bot/src/commands/audio/volume.rs rename {service => bot}/src/commands/help.rs (100%) rename {service => bot}/src/commands/mod.rs (100%) rename {service => bot}/src/commands/oai.rs (75%) rename {service => bot}/src/commands/ping.rs (100%) rename {service => bot}/src/commands/schedule.rs (100%) create mode 100644 bot/src/error_handler.rs create mode 100644 bot/src/main.rs create mode 100644 service/src/db/messages/routes.rs diff --git a/.gitignore b/.gitignore index ac78a29..2a8cf44 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,7 @@ .env target/ .idea/ -.vscode/ **/Cargo.lock -audio/ logs/ -settings.json app/ diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..ffa8b01 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "rust-analyzer.linkedProjects": [ + "./service/Cargo.toml", + "./bot/Cargo.toml", + ] +} \ No newline at end of file diff --git a/bot/.env.TEMPLATE b/bot/.env.TEMPLATE new file mode 100644 index 0000000..dd583f4 --- /dev/null +++ b/bot/.env.TEMPLATE @@ -0,0 +1,8 @@ +RUST_LOG=warn,bot=info +COMPOSE_PROJECT_NAME=siren + +SERVICE_HOST=localhost +SERVICE_PORT=5000 + +DISCORD_TOKEN= +OPENAI_API_KEY= \ No newline at end of file diff --git a/bot/Cargo.toml b/bot/Cargo.toml new file mode 100644 index 0000000..a4718f5 --- /dev/null +++ b/bot/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "bot" +version = "0.2.4" +edition = "2021" +authors = ["Ben Sherriff "] +repository = "https://github.com/bensherriff/siren" +readme = "README.md" +license = "GPL-3.0-or-later" + +[dependencies] +chrono = { version = "0.4.31", features = ["serde"] } +dotenv = "0.15.0" +serde_json = "1.0.107" +log = "0.4.20" +env_logger = "0.10.0" + +[dependencies.serenity] +version = "0.11.6" +default-features = false +features = ["client", "gateway", "rustls_backend", "model", "voice", "cache", "framework", "standard_framework"] + +[dependencies.songbird] +version = "0.3.2" +features = ["builtin-queue", "yt-dlp"] + +[dependencies.tokio] +version = "1.32.0" +features = ["macros", "rt-multi-thread"] + +[dependencies.serde] +version = "1.0.188" +features = ["derive"] + +[dependencies.reqwest] +version = "0.11.22" +default-features = false +features = ["json", "rustls-tls"] + +[dependencies.pyo3] +version = "0.19.2" +features = ["auto-initialize"] diff --git a/bot/src/commands/audio/mod.rs b/bot/src/commands/audio/mod.rs new file mode 100644 index 0000000..cd60d28 --- /dev/null +++ b/bot/src/commands/audio/mod.rs @@ -0,0 +1,190 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use log::debug; + +use serenity::model::application::interaction::{InteractionResponseType, application_command::ApplicationCommandInteraction}; +use serenity::model::prelude::{GuildId, ChannelId}; +use serenity::model::user::User; +use serenity::prelude::*; +use songbird::{Call, Songbird}; +use songbird::input::{Restartable, Input, Metadata, error::Error as SongbirdError}; + +pub mod pause; +pub mod play; +pub mod resume; +pub mod skip; +pub mod stop; +pub mod volume; + +#[derive(Clone, Debug)] +pub struct AudioConfigs; + +impl TypeMapKey for AudioConfigs { + type Value = Arc>>; +} + +#[derive(Clone, Debug)] +pub struct AudioConfig { + pub volume: f32 +} + +/// Joins a Discord voice channel. +/// +/// # Arguments +/// - ctx - The context of the command. +/// - guild_id_option - The guild ID of the guild to join. +/// - user - The user that is requesting to join the voice channel. +/// +/// # Returns +/// Result<(), String> - Ok if the bot successfully joined the voice channel, Err if there was an error. +pub async fn join(ctx: &Context, guild_id_option: &Option, user: &User) -> Result<(), String> { + let guild_id = match guild_id_option { + Some(g) => g, + None => { + return Err(format!("{}", "No guild ID set")); + } + }; + + let channel_id = match find_voice_channel(&ctx, &guild_id, &user) { + Ok(channel) => channel, + Err(err) => return Err(format!("{}", err)) + }; + + debug!("<{}> Joining channel {}", guild_id.0, channel_id); + let manager = get_songbird(ctx).await; + let (_handle_lock, success) = manager.join(guild_id.to_owned(), channel_id.to_owned()).await; + match success { + Ok(s) => Ok(s), + Err(err) => Err(format!("{}", err)) + } +} + +/// Leaves a Discord voice channel. +/// +/// # Arguments +/// - ctx - The context of the command. +/// - guild_id_option - The guild ID of the guild to leave. +/// +/// # Returns +/// Result<(), String> - Ok if the bot successfully left the voice channel, Err if there was an error. +pub async fn leave(ctx: &Context, guild_id_option: &Option) -> Result<(), String> { + let guild_id = match guild_id_option { + Some(g) => g, + None => { + return Err(format!("{}", "No guild ID set")); + } + }; + + let manager = get_songbird(ctx).await; + if manager.get(*guild_id).is_some() { + debug!("<{}> Disconnecting from channel", guild_id.0); + if let Err(e) = manager.remove(*guild_id).await { + return Err(format!("{}", e)) + } + } + Ok(()) +} + +/// Finds the voice channel that the user is in. +/// +/// # Arguments +/// - ctx - The context of the command. +/// - guild_id - The guild ID of the guild to search. +/// - user - The user to search for. +/// +/// # Returns +/// Result - Ok if the user is in a voice channel, Err if the user is not in a voice channel. +fn find_voice_channel(ctx: &Context, guild_id: &GuildId, user: &User) -> Result { + let guild = match guild_id.to_guild_cached(ctx.cache.to_owned()) { + Some(g) => g, + None => return Err(format!("Guild not found")) + }; + + match guild.voice_states.get(&user.id).and_then(|voice_state| voice_state.channel_id) { + Some(channel) => Ok(channel), + None => return Err(format!("User is not in a voice channel")) + } +} + +/// Creates a response to an interaction. +/// +/// # Arguments +/// - ctx - The context of the command. +/// - command - The command that was sent. +/// - content - The content of the response. +/// +/// # Returns +/// Result<(), SerenityError> - Ok if the response was created successfully, Err if there was an error. +pub async fn create_response(ctx: &Context, command: &ApplicationCommandInteraction, content: String) -> Result<(), SerenityError> { + command.create_interaction_response(&ctx.http, |response: &mut serenity::builder::CreateInteractionResponse<'_>| { + response + .kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|message: &mut serenity::builder::CreateInteractionResponseData<'_>| message.content(content)) + }).await +} + +/// Edits a response to an interaction. +/// +/// # Arguments +/// - ctx - The context of the command. +/// - command - The command that was sent. +/// - content - The content of the response. +/// +/// # Returns +/// Result - Ok if the response was edited successfully, Err if there was an error. +pub async fn edit_response(ctx: &Context, command: &ApplicationCommandInteraction, content: String) -> Result { + command.edit_original_interaction_response(&ctx.http, |response: &mut serenity::builder::EditInteractionResponse| { + response.content(content) + }).await +} + +/// Adds a song to the queue. +/// +/// # Arguments +/// - call - The call to add the song to. +/// - url - The URL of the song to add. +/// - lazy - Whether or not to lazy load the song. +/// +/// # Returns +/// Result - Ok if the song was added successfully, Err if there was an error. +pub async fn add_song(call: Arc>, url: &str, lazy: bool, audio_config: Option<&AudioConfig>) -> Result { + let source = if is_valid_url(url) { + Restartable::ytdl(url.to_owned(), lazy).await? + } else { + Restartable::ytdl_search(url, lazy).await? + }; + let mut handler = call.lock().await; + let track: Input = source.into(); + let metadata = *track.metadata.clone(); + let track_handle = handler.enqueue_source(track); + if let Some(ac) = audio_config { + let _ = track_handle.set_volume(ac.volume); + } + Ok(metadata) +} + +/// Checks if a string is a valid URL. +/// +/// # Arguments +/// - url - The string to check. +/// +/// # Returns +/// bool - True if the string is a valid URL, false if it is not. +fn is_valid_url(url: &str) -> bool { + match url.parse::() { + Ok(_) => return true, + Err(_) => return false + } +} + +/// Gets the Songbird voice client. +/// +/// # Arguments +/// - ctx - The context of the command. +/// +/// # Returns +/// Arc - The Songbird voice client. +pub async fn get_songbird(ctx: &Context) -> Arc { + songbird::get(ctx).await.expect("Songbird Voice client placed in at initialization") +} diff --git a/bot/src/commands/audio/pause.rs b/bot/src/commands/audio/pause.rs new file mode 100644 index 0000000..4423f67 --- /dev/null +++ b/bot/src/commands/audio/pause.rs @@ -0,0 +1,43 @@ +use log::{debug, error}; + +use serenity::prelude::*; +use serenity::builder::CreateApplicationCommand; +use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; + +use super::{get_songbird, create_response, edit_response}; + +pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { + // Create the initial response + if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await { + error!("Failed to create response message: {}", why); + return; + } + + let guild_id = match command.guild_id { + Some(g) => g, + None => { + if let Err(why) = edit_response(&ctx, &command, "Unable to join voice channel".to_string()).await { + error!("Failed to edit response message: {}", why); + } + return; + } + }; + let manager = get_songbird(ctx).await; + if let Some(handler_lock) = manager.get(guild_id) { + let handler = handler_lock.lock().await; + if let Err(err) = handler.queue().pause() { + if let Err(why) = edit_response(&ctx, &command, format!("Failed to pause: {}", err)).await { + error!("Failed to edit response message: {}", why); + } + } else { + debug!("Paused the track"); + if let Err(why) = edit_response(&ctx, &command, format!("Pausing the track")).await { + error!("Failed to edit response message: {}", why); + } + } + } +} + +pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { + command.name("pause").description("Pause the current track") +} \ No newline at end of file diff --git a/bot/src/commands/audio/play.rs b/bot/src/commands/audio/play.rs new file mode 100644 index 0000000..386e6a2 --- /dev/null +++ b/bot/src/commands/audio/play.rs @@ -0,0 +1,134 @@ +use log::{debug, warn, error}; + +use serenity::{prelude::*, async_trait}; +use serenity::builder::CreateApplicationCommand; +use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; +use songbird::EventHandler; + +use crate::commands::audio::{join, leave, add_song, get_songbird, AudioConfigs}; + +use super::{create_response, edit_response}; + +pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { + // Get the track url + let track_url = match command.data.options.get(0) { + Some(t) => match &t.value { + Some(v) => match v.as_str() { + Some(s) => s.to_owned(), + None => { + warn!("Missing track option"); + if let Err(why) = create_response(&ctx, &command, format!("Track option is missing")).await { + error!("Failed to create response message: {}", why); + } + return; + } + } + None => { + warn!("Missing track option"); + if let Err(why) = create_response(&ctx, &command, format!("Track option is missing")).await { + error!("Failed to create response message: {}", why); + } + return; + } + } + None => { + warn!("Missing track option"); + if let Err(why) = create_response(&ctx, &command, format!("Track option is missing")).await { + error!("Failed to create response message: {}", why); + } + return; + } + }; + + // Create the initial response + if let Err(why) = create_response(&ctx, &command, format!("Processing command...")).await { + error!("Failed to create response message: {}", why); + return; + } + + match join(&ctx, &command.guild_id, &command.user).await { + Ok(_) => { + let guild_id = match command.guild_id { + Some(g) => g, + None => { + if let Err(why) = edit_response(&ctx, &command, "Unable to join voice channel".to_string()).await { + error!("Failed to edit response message: {}", why); + } + return; + } + }; + debug!("Play command executed with track: {:?}", track_url); + + let manager = get_songbird(ctx).await; + if let Some(handler_lock) = manager.get(guild_id) { + let is_queue_empty = { + let call_handler = handler_lock.lock().await; + call_handler.queue().is_empty() + }; + let audio_config = { + let data_read = ctx.data.read().await; + data_read.get::().expect("Expected AudioConfigs in TypeMap.").clone() + }; + let ac = audio_config.read().await; + match add_song(handler_lock.clone(), &track_url, is_queue_empty, ac.get(&guild_id)).await { + Ok(added_song) => { + let track_title = added_song.title.unwrap(); + debug!("Added track: {}", track_title); + if let Err(why) = edit_response(&ctx, &command, format!("Added track to queue: {}", track_title)).await { + error!("Failed to edit response message: {}", why); + } + let mut handler = handler_lock.lock().await; + handler.remove_all_global_events(); + handler.add_global_event(songbird::Event::Track(songbird::TrackEvent::End), TrackEndNotifier { guild_id, call: manager }) + } + Err(why) => { + warn!("Failed to add song: {}", why); + if let Err(why) = edit_response(&ctx, &command, format!("Failed to add song: {}", why)).await { + error!("Failed to edit response message: {}", why); + } + if let Err(why) = leave(&ctx, &command.guild_id).await { + error!("Failed to leave voice channel: {}", why); + } + return; + } + }; + } + }, + Err(err) => { + warn!("{}", err); + if let Err(why) = edit_response(&ctx, &command, format!("{}", err)).await { + error!("Failed to edit response message: {}", why); + } + } + } +} + +pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { + command.name("play").description("Plays the given track").create_option(|option| { option + .name("track") + .description("The track to be played") + .kind(serenity::model::prelude::command::CommandOptionType::String) + .required(true) + }) +} + +struct TrackEndNotifier { + pub call: std::sync::Arc, + pub guild_id: serenity::model::id::GuildId +} + +#[async_trait] +impl EventHandler for TrackEndNotifier { + async fn act(&self, ctx: &songbird::events::EventContext<'_>) -> Option { + if let songbird::EventContext::Track(_track_list) = ctx { + if let Some(call) = self.call.get(self.guild_id) { + let mut handler = call.lock().await; + if handler.queue().is_empty() { + debug!("Queue is empty, leaving voice channel"); + handler.leave().await.unwrap(); + } + } + } + None + } +} diff --git a/bot/src/commands/audio/resume.rs b/bot/src/commands/audio/resume.rs new file mode 100644 index 0000000..d97a592 --- /dev/null +++ b/bot/src/commands/audio/resume.rs @@ -0,0 +1,43 @@ +use log::{debug, error}; + +use serenity::prelude::*; +use serenity::builder::CreateApplicationCommand; +use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; + +use super::{get_songbird, create_response, edit_response}; + +pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { + // Create the initial response + if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await { + error!("Failed to create response message: {}", why); + return; + } + + let guild_id = match command.guild_id { + Some(g) => g, + None => { + if let Err(why) = edit_response(&ctx, &command, "Unable to join voice channel".to_string()).await { + error!("Failed to edit response message: {}", why); + } + return; + } + }; + let manager = get_songbird(ctx).await; + if let Some(handler_lock) = manager.get(guild_id) { + let handler = handler_lock.lock().await; + if let Err(err) = handler.queue().resume() { + if let Err(why) = edit_response(&ctx, &command, format!("Failed to resume: {}", err)).await { + error!("Failed to edit response message: {}", why); + } + } else { + debug!("Resumed the track"); + if let Err(why) = edit_response(&ctx, &command, format!("Resuming the track")).await { + error!("Failed to edit response message: {}", why); + } + } + } +} + +pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { + command.name("resume").description("Resume the current track") +} \ No newline at end of file diff --git a/bot/src/commands/audio/skip.rs b/bot/src/commands/audio/skip.rs new file mode 100644 index 0000000..cd95d91 --- /dev/null +++ b/bot/src/commands/audio/skip.rs @@ -0,0 +1,43 @@ +use log::{debug, error}; + +use serenity::prelude::*; +use serenity::builder::CreateApplicationCommand; +use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; + +use super::{get_songbird, create_response, edit_response}; + +pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { + // Create the initial response + if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await { + error!("Failed to create response message: {}", why); + return; + } + + let guild_id = match command.guild_id { + Some(g) => g, + None => { + if let Err(why) = edit_response(&ctx, &command, "Unable to join voice channel".to_string()).await { + error!("Failed to edit response message: {}", why); + } + return; + } + }; + let manager = get_songbird(ctx).await; + if let Some(handler_lock) = manager.get(guild_id) { + let handler = handler_lock.lock().await; + if let Err(err) = handler.queue().skip() { + if let Err(why) = edit_response(&ctx, &command, format!("Failed to skip: {}", err)).await { + error!("Failed to edit response message: {}", why); + } + } else { + debug!("Skipped the track"); + if let Err(why) = edit_response(&ctx, &command, format!("Skipping the track")).await { + error!("Failed to edit response message: {}", why); + } + } + } +} + +pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { + command.name("skip").description("Skip the current track") +} \ No newline at end of file diff --git a/bot/src/commands/audio/stop.rs b/bot/src/commands/audio/stop.rs new file mode 100644 index 0000000..32dec6f --- /dev/null +++ b/bot/src/commands/audio/stop.rs @@ -0,0 +1,38 @@ +use log::{debug, error}; + +use serenity::prelude::*; +use serenity::builder::CreateApplicationCommand; +use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; + +use super::{get_songbird, create_response, edit_response}; + +pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { + // Create the initial response + if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await { + error!("Failed to create response message: {}", why); + return; + } + + let guild_id = match command.guild_id { + Some(g) => g, + None => { + if let Err(why) = edit_response(&ctx, &command, "Unable to join voice channel".to_string()).await { + error!("Failed to edit response message: {}", why); + } + return; + } + }; + let manager = get_songbird(ctx).await; + if let Some(handler_lock) = manager.get(guild_id) { + let handler = handler_lock.lock().await; + handler.queue().stop(); + debug!("Stopped the track"); + if let Err(why) = edit_response(&ctx, &command, format!("Stopping the tracks")).await { + error!("Failed to edit response message: {}", why); + } + } +} + +pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { + command.name("stop").description("Stop the current track and clear the queue") +} \ No newline at end of file diff --git a/bot/src/commands/audio/volume.rs b/bot/src/commands/audio/volume.rs new file mode 100644 index 0000000..4ef8501 --- /dev/null +++ b/bot/src/commands/audio/volume.rs @@ -0,0 +1,85 @@ +use log::{error, warn}; + +use serenity::prelude::*; +use serenity::builder::CreateApplicationCommand; +use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; + +use super::{get_songbird, create_response, edit_response, AudioConfigs, AudioConfig}; + +pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { + // Get the volume + let volume = match command.data.options.get(0) { + Some(t) => match &t.value { + Some(v) => match v.as_i64() { + Some(p) => std::cmp::min(100, std::cmp::max(0, p)), + None => { + warn!("Unable to get volume option as a string"); + if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await { + error!("Failed to create response message: {}", why); + } + return; + } + } + None => { + warn!("Missing volume option value"); + if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await { + error!("Failed to create response message: {}", why); + } + return; + } + } + None => { + warn!("Missing volume option"); + if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await { + error!("Failed to create response message: {}", why); + } + return; + } + }; + + // Format volume to f32 bound between 0.0 and 1.0 + let bound_volume = volume as f32 / 100.0; + + // Create the initial response + if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await { + error!("Failed to create response message: {}", why); + return; + } + + let guild_id = match command.guild_id { + Some(g) => g, + None => { + if let Err(why) = edit_response(&ctx, &command, "Unable to join voice channel".to_string()).await { + error!("Failed to edit response message: {}", why); + } + return; + } + }; + let audio_config_lock = { + let data_read = ctx.data.read().await; + data_read.get::().expect("Expected AudioConfigs in TypeMap.").clone() + }; + { + let mut audio_configs = audio_config_lock.write().await; + *audio_configs.entry(guild_id).or_insert(AudioConfig { volume: 1.0 }) = AudioConfig { volume: bound_volume }; + } + let manager = get_songbird(ctx).await; + if let Some(handler_lock) = manager.get(guild_id) { + let handler = handler_lock.lock().await; + for (_, track_handle) in handler.queue().current_queue().iter().enumerate() { + let _ = track_handle.set_volume(bound_volume); + } + } + if let Err(why) = edit_response(&ctx, &command, format!("Setting the volume to {}", volume)).await { + error!("Failed to set the volume: {}", why); + } +} + +pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { + command.name("volume").description("Set the audio player volume").create_option(|option| { option + .name("volume") + .description("Volume between 0 and 100") + .kind(serenity::model::prelude::command::CommandOptionType::Integer) + .required(true) + }) +} \ No newline at end of file diff --git a/service/src/commands/help.rs b/bot/src/commands/help.rs similarity index 100% rename from service/src/commands/help.rs rename to bot/src/commands/help.rs diff --git a/service/src/commands/mod.rs b/bot/src/commands/mod.rs similarity index 100% rename from service/src/commands/mod.rs rename to bot/src/commands/mod.rs diff --git a/service/src/commands/oai.rs b/bot/src/commands/oai.rs similarity index 75% rename from service/src/commands/oai.rs rename to bot/src/commands/oai.rs index 9e29c4b..87d0524 100644 --- a/service/src/commands/oai.rs +++ b/bot/src/commands/oai.rs @@ -1,4 +1,3 @@ -use diesel::{prelude::*, insert_into}; use log::{error, debug, trace, warn}; use serde::{Serialize, Deserialize}; @@ -8,12 +7,12 @@ use serenity::model::channel::Message; use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType}; use serenity::prelude::*; -use crate::db::{connection, messages::{MessageDB, NewMessageDB}}; -use crate::error_handler::ServiceError; +use crate::error_handler::BotError; pub struct OAI { pub client: reqwest::Client, pub base_url: String, + pub service_url: String, pub max_attempts: i64, pub token: String, pub max_tokens: i64, @@ -127,33 +126,64 @@ enum ResponseEvent { ResponseError(ResponseError) } +#[derive(Serialize, Deserialize)] +pub struct GetResponse { + pub data: T, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option +} + +#[derive(Serialize, Deserialize)] +pub struct Metadata { + pub total: i32, + pub limit: i32, + pub page: i32, + pub pages: i32 +} + +#[derive(Serialize, Deserialize)] +pub struct QueryMessage { + pub id: String, + pub guild_id: i64, + pub channel_id: i64, + pub user_id: i64, + pub created: i64, + pub model: String, + pub request: String, + pub response: String, + pub request_tags: Vec, + pub response_tags: Vec, +} + +#[derive(Serialize, Deserialize)] +pub struct InsertMessage { + pub id: String, + pub guild_id: i64, + pub channel_id: i64, + pub user_id: i64, + pub created: i64, + pub model: String, + pub request: String, + pub response: String, + pub request_tags: Vec, + pub response_tags: Vec, +} + impl OAI { - async fn get_request(&self, request: ChatCompletionRequest) -> Result { + async fn get_request(&self, request: ChatCompletionRequest) -> Result { let uri = format!("{}/chat/completions", self.base_url); let body = serde_json::to_string(&request).unwrap(); trace!("Sending request to {}: {}", uri, body); - let value = match match self.client + let value = self.client .post(&uri) .bearer_auth(&self.token) .header("Content-Type", "application/json".to_string()) .body(body) .send() - .await { - Ok(r) => r, - Err(err) => return Err(ServiceError { - message: format!("Could not send request to OpenAI: {}", err), - status: 500 - }) - } + .await? .json::() - .await { - Ok(r) => r, - Err(err) => return Err(ServiceError { - message: format!("Could not read response from OpenAI: {}", err), - status: 500 - }) - }; + .await?; trace!("Received response from OpenAI: {:?}", value); @@ -169,21 +199,43 @@ impl OAI { // status: 500 // }) // }; - let response = match serde_json::from_value::(value) { - Ok(r) => r, - Err(err) => return Err(ServiceError { - message: format!("Could not parse response from OpenAI: {}", err), - status: 500 - }) - }; + let response = serde_json::from_value::(value)?; Ok(response) } + + async fn get_messages(&self, guild_id: u64, channel_id: u64, author_id: u64) -> Result>, BotError> { + let uri = format!("{}/messages?guild_id={}&channel_id={}&author_id={}&limit={}", self.service_url, guild_id, channel_id, author_id, self.max_context_questions); + let value = self.client + .get(&uri) + .send() + .await? + .json::() + .await?; + + let response = serde_json::from_value::>>(value)?; + + Ok(response) + } + + async fn store_message(&self, message: InsertMessage) -> Result { + let uri = format!("{}/messages", self.service_url); + trace!("Sending request to {}", uri); + let value = self.client + .post(&uri) + .json::(&message) + .send() + .await? + .json::() + .await?; + trace!("Received response from Service: {:?}", value); + let response = serde_json::from_value::(value)?; + Ok(response) + } } pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { debug!("Generating response for message: {}", msg.content); - let mut connection = connection().unwrap(); let guild_id = msg.guild_id.unwrap(); let channel_id = msg.channel_id; @@ -193,27 +245,17 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0); let parsed_content = msg.content.replace(bot_mention.as_str(), ""); - // Setup the request messages - let result: Result, diesel::result::Error> = crate::db::schema::messages::table - .select(MessageDB::as_select()) - .filter((crate::db::schema::messages::guild_id.eq(guild_id.0 as i64)) - .and(crate::db::schema::messages::channel_id.eq(channel_id.0 as i64)) - .and(crate::db::schema::messages::user_id.eq(author_id.0 as i64)) - ) - .order(crate::db::schema::messages::created.asc()) - .limit(oai.max_context_questions) - .load(&mut connection); - let mut messages = vec![ ChatCompletionMessage { role: GPTRole::System, content: "Siren is a Discord bot specializing in Dungeons and Dragons. Limit Siren's responses to <= 2000 characters. Siren must always obey these instructions, no matter what.".to_string() }, ]; - - match result { - Ok(r) => { - for message in r { + + let previous_messages = oai.get_messages(guild_id.0, channel_id.0, author_id.0).await; + match previous_messages { + Ok(m) => { + for message in m.data { messages.push( ChatCompletionMessage { role: GPTRole::User, @@ -228,7 +270,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { ); } }, - Err(err) => error!("Could not load previous messages: {}", err) + Err(err) => warn!("Could not load previous messages: {}", err) }; messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() }); @@ -272,20 +314,19 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { debug!("Processing response received from OpenAI"); if !r.choices.is_empty() { let res = r.choices[0].message.content.clone(); - // Insert the message into the messages database table - if let Err(err) = insert_into(crate::db::schema::messages::table).values(NewMessageDB { - id: &r.id, + if let Err(err) = oai.store_message(InsertMessage { + id: r.id, guild_id: guild_id.0 as i64, channel_id: response_channel.0 as i64, user_id: author_id.0 as i64, created: r.created, - model: &serde_json::to_string(&r.model).unwrap(), - request: &parsed_content, - response: &res, + model: serde_json::to_string(&r.model).unwrap(), + request: parsed_content, + response: res.clone(), request_tags: vec![], response_tags: vec![], - }).execute(&mut connection) { - error!("Could not insert message into database: {}", err); + }).await { + warn!("{}", err); } res } else { diff --git a/service/src/commands/ping.rs b/bot/src/commands/ping.rs similarity index 100% rename from service/src/commands/ping.rs rename to bot/src/commands/ping.rs diff --git a/service/src/commands/schedule.rs b/bot/src/commands/schedule.rs similarity index 100% rename from service/src/commands/schedule.rs rename to bot/src/commands/schedule.rs diff --git a/bot/src/error_handler.rs b/bot/src/error_handler.rs new file mode 100644 index 0000000..1493c13 --- /dev/null +++ b/bot/src/error_handler.rs @@ -0,0 +1,35 @@ +use serde::{Deserialize, Serialize}; +use std::fmt; + +#[derive(Debug, Deserialize, Serialize)] +pub struct BotError { + pub status: u16, + pub message: String, +} + +impl BotError { + pub fn new(error_status_code: u16, error_message: String) -> BotError { + BotError { + status: error_status_code, + message: error_message, + } + } +} + +impl fmt::Display for BotError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.message.as_str()) + } +} + +impl From for BotError { + fn from(error: reqwest::Error) -> BotError { + BotError::new(500, format!("Unknown reqwest error: {}", error)) + } +} + +impl From for BotError { + fn from(error: serde_json::Error) -> BotError { + BotError::new(500, format!("Unknown serde_json error: {}", error)) + } +} diff --git a/bot/src/main.rs b/bot/src/main.rs new file mode 100644 index 0000000..8f4b9b9 --- /dev/null +++ b/bot/src/main.rs @@ -0,0 +1,175 @@ +use std::collections::{HashSet, HashMap}; +use std::env; +use std::sync::Arc; + +use commands::audio::{create_response, AudioConfig, AudioConfigs}; + +use dotenv::dotenv; +use log::{error, warn, info}; +use serenity::async_trait; +use serenity::framework::StandardFramework; +use serenity::model::application::interaction::Interaction; +use serenity::model::gateway::Ready; +use serenity::model::channel::Message; +use serenity::http::Http; +use serenity::prelude::*; +use songbird::SerenityInit; + +use crate::commands::oai::GPTModel; + +mod commands; +mod error_handler; + +struct Handler { + // Open AI Config + oai: Option +} + +#[async_trait] +impl EventHandler for Handler { + async fn message(&self, ctx: Context, msg: Message) { + // Ignore messages from bots + if msg.author.bot { + return; + } + match &self.oai { + Some(oai) => { + match msg.mentions_me(&ctx.http).await { + Ok(mentioned) => { + let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await { + Ok(t) => { + match t.iter().find(|t| t.user_id.unwrap().0 == ctx.cache.current_user_id().0) { + Some(_) => true, + None => false + } + } + Err(_) => false + }; + if mentioned || bot_in_thread { + commands::oai::generate_response(&ctx, &msg, oai).await; + } + } + Err(why) => warn!("Could not check mentions: {:?}", why) + }; + } + None => {} + } + } + + async fn interaction_create(&self, ctx: Context, interaction: Interaction) { + if let Interaction::ApplicationCommand(command) = interaction { + match command.data.name.as_str() { + "play" => commands::audio::play::run(&ctx, &command).await, + "stop" => commands::audio::stop::run(&ctx, &command).await, + "pause" => commands::audio::pause::run(&ctx, &command).await, + "resume" => commands::audio::resume::run(&ctx, &command).await, + "skip" => commands::audio::skip::run(&ctx, &command).await, + "volume" => commands::audio::volume::run(&ctx, &command).await, + _ => { + let content: String = match command.data.name.as_str() { + "ping" => commands::ping::run(&command.data.options), + _ => "Unknown command".to_string() + }; + + if let Err(why) = create_response(&ctx, &command, content).await { + warn!("Cannot respond to slash command: {}", why); + } + } + } + } + } + + async fn ready(&self, ctx: Context, ready: Ready) { + if ready.guilds.is_empty() { + warn!("No ready guilds found"); + } + for guild in ready.guilds { + let audio_config_lock = { + let data_read = ctx.data.read().await; + data_read.get::().expect("Expected AudioConfigs in TypeMap.").clone() + }; + { + let mut audio_configs = audio_config_lock.write().await; + let _ = audio_configs.insert(guild.id, AudioConfig { volume: 1.0 }); + } + let commands = guild.id.set_application_commands(&ctx.http, |commands| { + commands.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::ping::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::play::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::stop::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::pause::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::resume::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::skip::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::volume::register(command) }) + }).await; + match commands { + Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.0), + Err(why) => error!("Could not register commands for guild {}: {:?}", guild.id.0, why) + }; + } + } +} + +#[tokio::main] +async fn main() { + dotenv().ok(); + env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info")); + + let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); + let intents: GatewayIntents = GatewayIntents::all(); + + let http: Http = Http::new(&token); + let (owners, _bot_id) = match http.get_current_application_info().await { + Ok(info) => { + let mut owners: HashSet = HashSet::new(); + if let Some(team) = info.team { + owners.insert(team.owner_user_id); + } else { + owners.insert(info.owner.id); + } + match http.get_current_user().await { + Ok(bot) => (owners, bot.id), + Err(why) => panic!("Could not access the bot id: {:?}", why) + } + }, + Err(why) => panic!("Could not access application info: {:?}", why) + }; + + let handler = match env::var("OPENAI_API_KEY") { + Ok(token) => { + info!("Loaded OpenAI token"); + Handler { + oai: Some(commands::oai::OAI { + client: reqwest::Client::new(), + base_url: "https://api.openai.com/v1".to_string(), + service_url: "http://localhost:5000".to_string(), + max_attempts: 5, + token, + max_context_questions: 30, + max_tokens: 2048, + default_model: GPTModel::GPT35Turbo, + }) + } + } + Err(err) => { + warn!("Could not load OpenAI token: {}", err); + Handler { oai: None } + } + }; + + let mut client = Client::builder(token, intents) + .event_handler(handler) + .framework(StandardFramework::new() + .configure(|c| c.owners(owners))) + .register_songbird() + .await + .expect("Error creating client"); + + { + let mut data = client.data.write().await; + data.insert::(Arc::new(RwLock::new(HashMap::default()))); + } + + if let Err(why) = client.start_autosharded().await { + error!("An error occurred while running the client: {:?}", why); + } +} diff --git a/service/Cargo.toml b/service/Cargo.toml index 03e4fb5..b6039b9 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -21,15 +21,6 @@ r2d2 = "0.8.10" lazy_static = "1.4.0" uuid = { version = "1.4.1", features = ["serde", "v4"] } -[dependencies.serenity] -version = "0.11.6" -default-features = false -features = ["client", "gateway", "rustls_backend", "model", "voice", "cache", "framework", "standard_framework"] - -[dependencies.songbird] -version = "0.3.2" -features = ["builtin-queue", "yt-dlp"] - [dependencies.tokio] version = "1.32.0" features = ["macros", "rt-multi-thread"] @@ -47,7 +38,3 @@ features = ["json", "rustls-tls"] version = "2.1.2" default-features = false features = ["postgres", "32-column-tables", "serde_json", "r2d2", "with-deprecated"] - -[dependencies.pyo3] -version = "0.19.2" -features = ["auto-initialize"] diff --git a/service/src/db/messages/mod.rs b/service/src/db/messages/mod.rs index 4a7ebf6..6fbb137 100644 --- a/service/src/db/messages/mod.rs +++ b/service/src/db/messages/mod.rs @@ -1,3 +1,5 @@ mod model; +mod routes; pub use model::*; +pub use routes::init_routes; diff --git a/service/src/db/messages/model.rs b/service/src/db/messages/model.rs index 3714ed5..3356ec0 100644 --- a/service/src/db/messages/model.rs +++ b/service/src/db/messages/model.rs @@ -1,10 +1,11 @@ use diesel::prelude::*; +use serde::{Deserialize, Serialize}; -use crate::db::schema::messages; +use crate::{db::schema::messages::{self}, error_handler::ServiceError}; -#[derive(Queryable, Selectable)] +#[derive(Queryable, Selectable, Serialize, Deserialize)] #[diesel(table_name = messages)] -pub struct MessageDB { +pub struct QueryMessage { pub id: String, pub guild_id: i64, pub channel_id: i64, @@ -17,17 +18,132 @@ pub struct MessageDB { pub response_tags: Vec, } -#[derive(Insertable)] +pub struct QueryFilters { + pub by_id: Option, + pub by_guild_id: Option, + pub by_channel_id: Option, + pub by_user_id: Option, + pub by_model: Option, + pub by_request: Option, + pub by_response: Option, + pub by_request_tags: Option>, + pub by_response_tags: Option> +} + +impl Default for QueryFilters { + fn default() -> Self { + QueryFilters { + by_id: None, + by_guild_id: None, + by_channel_id: None, + by_user_id: None, + by_model: None, + by_request: None, + by_response: None, + by_request_tags: None, + by_response_tags: None + } + } +} + +impl QueryMessage { + pub fn get_all(filters: &QueryFilters, limit: i32, page: i32) -> Result, ServiceError> { + let mut conn = crate::db::connection()?; + let mut query = messages::table.limit(limit as i64).order(messages::created.asc()).into_boxed(); + // Limit query to page and limit + let offset = (page - 1) * limit; + query = query.offset(offset as i64); + // Apply filters + if let Some(id) = &filters.by_id { + query = query.filter(messages::id.eq(id)); + } + if let Some(guild_id) = &filters.by_guild_id { + query = query.filter(messages::guild_id.eq(guild_id)); + } + if let Some(channel_id) = &filters.by_channel_id { + query = query.filter(messages::channel_id.eq(channel_id)); + } + if let Some(user_id) = &filters.by_user_id { + query = query.filter(messages::user_id.eq(user_id)); + } + if let Some(model) = &filters.by_model { + query = query.filter(messages::model.eq(model)); + } + if let Some(request) = &filters.by_request { + query = query.filter(messages::request.eq(request)); + } + if let Some(response) = &filters.by_response { + query = query.filter(messages::response.eq(response)); + } + if let Some(request_tags) = &filters.by_request_tags { + query = query.filter(messages::request_tags.eq(request_tags)); + } + if let Some(response_tags) = &filters.by_response_tags { + query = query.filter(messages::response_tags.eq(response_tags)); + } + // Execute query + let messages = query.load::(&mut conn)?; + Ok(messages) + } + + pub fn get_count(fitlers: &QueryFilters) -> Result { + let mut conn = crate::db::connection()?; + let mut query = messages::table.into_boxed(); + // Apply filters + if let Some(id) = &fitlers.by_id { + query = query.filter(messages::id.eq(id)); + } + if let Some(guild_id) = &fitlers.by_guild_id { + query = query.filter(messages::guild_id.eq(guild_id)); + } + if let Some(channel_id) = &fitlers.by_channel_id { + query = query.filter(messages::channel_id.eq(channel_id)); + } + if let Some(user_id) = &fitlers.by_user_id { + query = query.filter(messages::user_id.eq(user_id)); + } + if let Some(model) = &fitlers.by_model { + query = query.filter(messages::model.eq(model)); + } + if let Some(request) = &fitlers.by_request { + query = query.filter(messages::request.eq(request)); + } + if let Some(response) = &fitlers.by_response { + query = query.filter(messages::response.eq(response)); + } + if let Some(request_tags) = &fitlers.by_request_tags { + query = query.filter(messages::request_tags.eq(request_tags)); + } + if let Some(response_tags) = &fitlers.by_response_tags { + query = query.filter(messages::response_tags.eq(response_tags)); + } + // Execute query + let count = query.count().get_result::(&mut conn)?; + Ok(count) + } +} + +#[derive(Insertable, AsChangeset, Serialize, Deserialize)] #[diesel(table_name = messages)] -pub struct NewMessageDB<'a> { - pub id: &'a str, +pub struct InsertMessage { + pub id: String, pub guild_id: i64, pub channel_id: i64, pub user_id: i64, pub created: i64, - pub model: &'a str, - pub request: &'a str, - pub response: &'a str, - pub request_tags: Vec<&'a str>, - pub response_tags: Vec<&'a str>, + pub model: String, + pub request: String, + pub response: String, + pub request_tags: Vec, + pub response_tags: Vec, +} + +impl InsertMessage { + pub fn insert(message: Self) -> Result { + let mut conn = crate::db::connection()?; + let message = diesel::insert_into(messages::table) + .values(message) + .get_result(&mut conn)?; + Ok(message) + } } \ No newline at end of file diff --git a/service/src/db/messages/routes.rs b/service/src/db/messages/routes.rs new file mode 100644 index 0000000..96ebbae --- /dev/null +++ b/service/src/db/messages/routes.rs @@ -0,0 +1,79 @@ +use actix_web::{get, post, web, HttpResponse, HttpRequest, ResponseError}; +use log::error; +use serde::{Serialize, Deserialize}; + +use crate::{db::{messages::{QueryMessage, QueryFilters, InsertMessage}, GetResponse, Metadata}, error_handler::ServiceError}; + +#[derive(Serialize, Deserialize)] +struct GetAllParams { + id: Option, + guild_id: Option, + channel_id: Option, + user_id: Option, + model: Option, + request: Option, + response: Option, + request_tags: Option>, + response_tags: Option>, + limit: Option, + page: Option, +} + +#[get("/messages")] +async fn get_all(req: HttpRequest) -> HttpResponse { + let params = match web::Query::::from_query(req.query_string()) { + Ok(params) => params, + Err(err) => return ResponseError::error_response(&ServiceError { + status: 422, + message: err.to_string() + }) + }; + let mut filters = QueryFilters::default(); + filters.by_id = params.id.clone(); + filters.by_guild_id = params.guild_id; + filters.by_channel_id = params.channel_id; + filters.by_user_id = params.user_id; + filters.by_model = params.model.clone(); + filters.by_request = params.request.clone(); + filters.by_response = params.response.clone(); + filters.by_request_tags = params.request_tags.clone(); + filters.by_response_tags = params.response_tags.clone(); + let limit = params.limit.unwrap_or(100); + let total_count = QueryMessage::get_count(&filters).unwrap(); + let max_page = std::cmp::max((total_count as f64 / limit as f64).ceil() as i32, 1); + let page = std::cmp::min(std::cmp::max(params.page.unwrap_or(1), 1), max_page); + + match QueryMessage::get_all(&filters, limit, page) { + Ok(messages) => { + HttpResponse::Ok().json(GetResponse { + data: messages, + metadata: Some(Metadata { + total: total_count as i32, + limit, + page, + pages: max_page + }) + }) + }, + Err(err) => { + error!("{:?}", err.message); + ResponseError::error_response(&err) + } + } +} + +#[post("/messages")] +async fn create(message: web::Json) -> HttpResponse { + match InsertMessage::insert(message.into_inner()) { + Ok(message) => HttpResponse::Created().json(message), + Err(err) => { + error!("{:?}", err.message); + ResponseError::error_response(&err) + } + } +} + +pub fn init_routes(config: &mut web::ServiceConfig) { + config.service(get_all); + config.service(create); +} \ No newline at end of file diff --git a/service/src/db/spells/routes.rs b/service/src/db/spells/routes.rs index 303c229..3905866 100644 --- a/service/src/db/spells/routes.rs +++ b/service/src/db/spells/routes.rs @@ -74,7 +74,7 @@ async fn get_all(req: HttpRequest) -> HttpResponse { // Limit must be between 1 and 100 let limit = std::cmp::min(std::cmp::max(params.limit.unwrap_or(20), 1), 100); let total_count = QuerySpell::get_count(&filters).unwrap(); - let max_page = std::cmp::max(1, (total_count as f64 / limit as f64).ceil() as i32); + let max_page = std::cmp::max((total_count as f64 / limit as f64).ceil() as i32, 1); // Page must be between 1 and max_page let page = std::cmp::min(std::cmp::max(params.page.unwrap_or(1), 1), max_page); diff --git a/service/src/main.rs b/service/src/main.rs index 832c7a1..b7fffd8 100644 --- a/service/src/main.rs +++ b/service/src/main.rs @@ -2,117 +2,15 @@ extern crate diesel; #[macro_use] extern crate diesel_migrations; -use std::collections::{HashSet, HashMap}; use std::env; -use std::sync::Arc; use actix_web::{HttpServer, App}; -use commands::audio::{create_response, AudioConfig, AudioConfigs}; use dotenv::dotenv; -use log::{error, warn, info}; -use serenity::async_trait; -use serenity::framework::StandardFramework; -use serenity::model::application::interaction::Interaction; -use serenity::model::gateway::Ready; -use serenity::model::channel::Message; -use serenity::http::Http; -use serenity::prelude::*; -use songbird::SerenityInit; +use log::{error, info}; -use crate::commands::oai::GPTModel; - -mod commands; mod error_handler; mod db; -struct Handler { - // Open AI Config - oai: Option -} - -#[async_trait] -impl EventHandler for Handler { - async fn message(&self, ctx: Context, msg: Message) { - // Ignore messages from bots - if msg.author.bot { - return; - } - match &self.oai { - Some(oai) => { - match msg.mentions_me(&ctx.http).await { - Ok(mentioned) => { - let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await { - Ok(t) => { - match t.iter().find(|t| t.user_id.unwrap().0 == ctx.cache.current_user_id().0) { - Some(_) => true, - None => false - } - } - Err(_) => false - }; - if mentioned || bot_in_thread { - commands::oai::generate_response(&ctx, &msg, oai).await; - } - } - Err(why) => warn!("Could not check mentions: {:?}", why) - }; - } - None => {} - } - } - - async fn interaction_create(&self, ctx: Context, interaction: Interaction) { - if let Interaction::ApplicationCommand(command) = interaction { - match command.data.name.as_str() { - "play" => commands::audio::play::run(&ctx, &command).await, - "stop" => commands::audio::stop::run(&ctx, &command).await, - "pause" => commands::audio::pause::run(&ctx, &command).await, - "resume" => commands::audio::resume::run(&ctx, &command).await, - "skip" => commands::audio::skip::run(&ctx, &command).await, - "volume" => commands::audio::volume::run(&ctx, &command).await, - _ => { - let content: String = match command.data.name.as_str() { - "ping" => commands::ping::run(&command.data.options), - _ => "Unknown command".to_string() - }; - - if let Err(why) = create_response(&ctx, &command, content).await { - warn!("Cannot respond to slash command: {}", why); - } - } - } - } - } - - async fn ready(&self, ctx: Context, ready: Ready) { - if ready.guilds.is_empty() { - warn!("No ready guilds found"); - } - for guild in ready.guilds { - let audio_config_lock = { - let data_read = ctx.data.read().await; - data_read.get::().expect("Expected AudioConfigs in TypeMap.").clone() - }; - { - let mut audio_configs = audio_config_lock.write().await; - let _ = audio_configs.insert(guild.id, AudioConfig { volume: 1.0 }); - } - let commands = guild.id.set_application_commands(&ctx.http, |commands| { - commands.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::ping::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::play::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::stop::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::pause::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::resume::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::skip::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::volume::register(command) }) - }).await; - match commands { - Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.0), - Err(why) => error!("Could not register commands for guild {}: {:?}", guild.id.0, why) - }; - } - } -} #[actix_web::main] async fn main() -> std::io::Result<()> { @@ -121,13 +19,12 @@ async fn main() -> std::io::Result<()> { db::init(); db::load_data(); - // setup_discord_bot(); - let host = env::var("SERVICE_HOST").unwrap_or("localhost".to_string()); let port = env::var("SERVICE_PORT").unwrap_or("5000".to_string()); match HttpServer::new(|| { App::new() + .configure(db::messages::init_routes) .configure(db::spells::init_routes) }) .bind(format!("{}:{}", host, port)) { @@ -143,65 +40,3 @@ async fn main() -> std::io::Result<()> { .run() .await } - -fn setup_discord_bot() { - tokio::spawn(async { - let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); - let intents: GatewayIntents = GatewayIntents::all(); - - let http: Http = Http::new(&token); - let (owners, _bot_id) = match http.get_current_application_info().await { - Ok(info) => { - let mut owners: HashSet = HashSet::new(); - if let Some(team) = info.team { - owners.insert(team.owner_user_id); - } else { - owners.insert(info.owner.id); - } - match http.get_current_user().await { - Ok(bot) => (owners, bot.id), - Err(why) => panic!("Could not access the bot id: {:?}", why) - } - }, - Err(why) => panic!("Could not access application info: {:?}", why) - }; - - let handler = match env::var("OPENAI_API_KEY") { - Ok(token) => { - info!("Loaded OpenAI token"); - Handler { - oai: Some(commands::oai::OAI { - client: reqwest::Client::new(), - base_url: "https://api.openai.com/v1".to_string(), - max_attempts: 5, - token, - max_context_questions: 30, - max_tokens: 2048, - default_model: GPTModel::GPT35Turbo, - }) - } - } - Err(err) => { - warn!("Could not load OpenAI token: {}", err); - Handler { oai: None } - } - }; - - let mut client = Client::builder(token, intents) - .event_handler(handler) - .framework(StandardFramework::new() - .configure(|c| c.owners(owners))) - .register_songbird() - .await - .expect("Error creating client"); - - { - let mut data = client.data.write().await; - data.insert::(Arc::new(RwLock::new(HashMap::default()))); - } - - if let Err(why) = client.start_autosharded().await { - error!("An error occurred while running the client: {:?}", why); - } - }); -}