From 794d8cc34e4e5f831fe05645fb8234f27fd0cfce Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Thu, 5 Sep 2024 17:10:56 -0400 Subject: [PATCH] format and restructure, began working on schedule --- .env | 2 +- .version | 1 - Cargo.toml | 1 + README.md | 74 ++++++++-- migrations/000_base.sql | 24 ++-- src/bot/commands/audio/mod.rs | 33 +++-- src/bot/commands/audio/pause.rs | 12 +- src/bot/commands/audio/play.rs | 45 ++++-- src/bot/commands/audio/resume.rs | 18 ++- src/bot/commands/audio/skip.rs | 18 ++- src/bot/commands/audio/stop.rs | 12 +- src/bot/commands/audio/volume.rs | 36 ++++- src/bot/commands/chat.rs | 104 ++++++++------ src/bot/commands/event/mod.rs | 1 + src/bot/commands/event/schedule.rs | 137 +++++++++++++++++++ src/bot/commands/fun/mod.rs | 1 + src/bot/commands/{ => fun}/roll.rs | 32 ++--- src/bot/commands/mod.rs | 7 +- src/bot/commands/schedule.rs | 1 - src/bot/commands/{ => utility}/help.rs | 0 src/bot/commands/utility/mod.rs | 2 + src/bot/commands/{ => utility}/ping.rs | 2 +- src/bot/handler.rs | 55 ++++---- src/bot/mod.rs | 2 - src/bot/oai/model.rs | 25 +--- src/{bot/guilds => database/events}/mod.rs | 0 src/database/events/model.rs | 58 ++++++++ src/{bot/messages => database/guilds}/mod.rs | 0 src/{bot => database}/guilds/model.rs | 14 +- src/database/messages/mod.rs | 3 + src/{bot => database}/messages/model.rs | 26 +++- src/database/mod.rs | 4 + src/error.rs | 4 +- src/main.rs | 19 +-- 34 files changed, 561 insertions(+), 212 deletions(-) delete mode 100644 .version create mode 100644 src/bot/commands/event/mod.rs create mode 100644 src/bot/commands/event/schedule.rs create mode 100644 src/bot/commands/fun/mod.rs rename src/bot/commands/{ => fun}/roll.rs (77%) delete mode 100644 src/bot/commands/schedule.rs rename src/bot/commands/{ => utility}/help.rs (100%) create mode 100644 src/bot/commands/utility/mod.rs rename src/bot/commands/{ => utility}/ping.rs (75%) rename src/{bot/guilds => database/events}/mod.rs (100%) create mode 100644 src/database/events/model.rs rename src/{bot/messages => database/guilds}/mod.rs (100%) rename src/{bot => database}/guilds/model.rs (83%) create mode 100644 src/database/messages/mod.rs rename src/{bot => database}/messages/model.rs (65%) diff --git a/.env b/.env index 48bf773..eb896b2 100644 --- a/.env +++ b/.env @@ -21,4 +21,4 @@ DATA_DIR_PATH= # OPTIONAL DISCORD_TOKEN= OPENAI_API_KEY= # OPTIONAL -OPENAI_API_MODEL=gpt-3.5-turbo +OPENAI_API_MODEL=gpt-4o-mini diff --git a/.version b/.version deleted file mode 100644 index e0b6561..0000000 --- a/.version +++ /dev/null @@ -1 +0,0 @@ -SIREN_VERSION=0.2.8 \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 069b687..eb12b65 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,4 @@ uuid = { version = "1.10.0", features = ["serde", "v4"] } redis = { version = "0.26.1", features = ["tokio-comp", "connection-manager", "r2d2"] } rand = "0.8.5" tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } +regex = "1.10.6" diff --git a/README.md b/README.md index fd80c4b..eb42a8b 100644 --- a/README.md +++ b/README.md @@ -20,18 +20,49 @@ Siren is a D&D Bot built for Discord, written in Rust. Features include: 4. Start the application with `make up`

Setting up the Discord Developer Application

+ Visit the [Discord Developer Portal](https://discord.com/developers/applications) and create a new application. Click [here](https://discord.com/developers/docs/intro) for guides and more information. -Required Scopes: -``` -bot -application.commands -``` +#### Oauth2 +**Required Scopes**: +- bot +- applications.commands -Example Invite: +**Required Bot Permissions**: +- General Permissions + - Manage Roles + - Change Nickname + - View Channels + - Manage Events + - Create Events +- Text Permissions + - Send Messages + - Create Public Threads + - Create Private Threads + - Send Messages in Threads + - Manage Messages + - Manage Threads + - Embed Links + - Attach Files + - Read Message History + - Mention Everyone + - Use External Emojis + - Use External Stickers + - Add Reactions + - Create Polls +- Voice Permissions + - Connect + - Speak + +Example Invites: ``` https://discord.com/api/oauth2/authorize?client_id=&permissions=40671259392832&scope=bot%20applications.commands ``` + +``` +https://discord.com/oauth2/authorize?client_id=&permissions=581083641408576&integration_type=0&scope=bot+applications.commands +``` + The CLIENT_ID can be found in the General Information tab on the Discord Developer Portal for your application, under `Application ID` The DISCORD_TOKEN (used in the `.env file`) can be found under the Bot tab on the Discord Developer Portal for your application. @@ -41,9 +72,10 @@ The DISCORD_TOKEN (used in the `.env file`) can be found under the Bot tab on th ### Commands Siren utilizes Discord slash commands. To view the commands, run `/help` in a server where the bot is installed. The following commands are available: +**Music Commands** | Command | Description | -| --- | --- | -| `/play` | Play a track from Youtube or locally hosted files | +| --- | --- | +| `/play ` | Play a track from Youtube or locally hosted files | | `/pause` | Pause the current track | | `/resume` | Resume the current track | | `/skip` | Skip the current track | @@ -51,11 +83,31 @@ Siren utilizes Discord slash commands. To view the commands, run `/help` in a se | `/queue` | ***TODO*** - Display the current queue | | `/clear` | ***TODO*** - Clear the current queue | | `/shuffle` | ***TODO*** - Shuffle the current queue | -| `/loop` | ***TODO*** - Loop the current track | +| `/loop` | ***TODO*** - Loop or unloop the current track | | `/nowplaying` | ***TODO*** - Display the current track | -| `/volume` | Set the volume of the bot | +| `/volume ` | Set the volume of the bot | + +**Event Commands** +| Command | Description | +| --- | --- | +| `/schedule` | ***TODO*** - Schedule a new event | +| `/events` | ***TODO*** - Display all events | +| `/event ` | ***TODO*** - Display a specific event | +| `/deleteevent ` | ***TODO*** - Delete a specific event | +| `/updateevent ` | ***TODO*** - Update a specific event | +| `/remindme ` | ***TODO*** - Set a reminder for a specific event | + +**Fun Commands** +| Command | Description | +| --- | --- | +| `/coinflip` | Flip a coin | +| `/roll ` | Roll a dice | + +**Utility Commands** +| Command | Description | +| --- | --- | | `/ping` | Display the bot's latency | -| `/roll` | Roll a dice | +| `/poll` | ***TODO*** - Create a poll | | `/help` | ***TODO*** - Display a list of commands | ## Contributing diff --git a/migrations/000_base.sql b/migrations/000_base.sql index 7cfbfa7..0fc6125 100644 --- a/migrations/000_base.sql +++ b/migrations/000_base.sql @@ -3,26 +3,24 @@ CREATE TABLE IF NOT EXISTS guilds ( bot_id BIGINT NOT NULL, volume INTEGER NOT NULL ); -CREATE TABLE IF NOT EXISTS users ( - email TEXT PRIMARY KEY NOT NULL, - hash TEXT NOT NULL, - role TEXT NOT NULL, - first_name TEXT NOT NULL, - last_name TEXT NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT NOW(), - updated_at TIMESTAMP NOT NULL DEFAULT NOW(), - profile_picture TEXT, - verified BOOLEAN NOT NULL DEFAULT FALSE -); CREATE TABLE IF NOT EXISTS messages ( id TEXT PRIMARY KEY NOT NULL, guild_id BIGINT NOT NULL, channel_id BIGINT NOT NULL, - user_id BIGINT NOT NULL, + author_id BIGINT NOT NULL, created BIGINT NOT NULL, model TEXT NOT NULL, request TEXT NOT NULL, response TEXT NOT NULL, request_tags TEXT[] NOT NULL, response_tags TEXT[] NOT NULL -); \ No newline at end of file +); +CREATE TABLE IF NOT EXISTS events ( + id UUID PRIMARY KEY NOT NULL, + guild_id BIGINT NOT NULL, + author_id BIGINT NOT NULL, + title TEXT NOT NULL, + date_time TIMESTAMP NOT NULL, + description TEXT, + rsvp BIGINT[] NOT NULL +); diff --git a/src/bot/commands/audio/mod.rs b/src/bot/commands/audio/mod.rs index d91e16b..f4ea6d0 100644 --- a/src/bot/commands/audio/mod.rs +++ b/src/bot/commands/audio/mod.rs @@ -1,7 +1,10 @@ use std::sync::Arc; use reqwest::Url; -use serenity::all::{CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditInteractionResponse}; +use serenity::all::{ + CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, + EditInteractionResponse, +}; use serenity::client::Cache; use serenity::model::prelude::{GuildId, ChannelId}; use serenity::model::user::User; @@ -34,17 +37,16 @@ pub async fn join_voice_channel( ) -> SirenResult { let channel_id = find_voice_channel(cache, guild_id, user)?; log::debug!("<{}> Joining channel {}", guild_id.get(), channel_id.get()); - manager.join(guild_id.to_owned(), channel_id.to_owned()).await?; + manager + .join(guild_id.to_owned(), channel_id.to_owned()) + .await?; Ok(channel_id) } /** * Leaves a voice channel. */ -pub async fn leave_voice_channel( - manager: &Arc, - guild_id: &GuildId, -) -> SirenResult<()> { +pub async fn leave_voice_channel(manager: &Arc, guild_id: &GuildId) -> SirenResult<()> { if manager.get(guild_id.to_owned()).is_some() { log::debug!("<{}> Disconnecting from channel", guild_id.get()); manager.remove(*guild_id).await?; @@ -52,11 +54,7 @@ pub async fn leave_voice_channel( Ok(()) } -pub async fn create_response( - ctx: &Context, - command: &CommandInteraction, - content: String, -) { +pub async fn create_response(ctx: &Context, command: &CommandInteraction, content: String) { let data = CreateInteractionResponseMessage::new().content(content.to_owned()); let builder = CreateInteractionResponse::Message(data); match command.create_response(&ctx.http, builder).await { @@ -67,11 +65,7 @@ pub async fn create_response( }; } -pub async fn edit_response( - ctx: &Context, - command: &CommandInteraction, - content: String, -) { +pub async fn edit_response(ctx: &Context, command: &CommandInteraction, content: String) { let builder = EditInteractionResponse::new().content(content.to_owned()); match command.edit_response(&ctx.http, builder).await { Ok(_) => {} @@ -115,6 +109,11 @@ fn find_voice_channel( .and_then(|voice_state| voice_state.channel_id) { Some(channel) => Ok(channel), - None => return Err(SirenError::new(401, "User is not in a voice channel".to_string())), + None => { + return Err(SirenError::new( + 401, + "User is not in a voice channel".to_string(), + )) + } } } diff --git a/src/bot/commands/audio/pause.rs b/src/bot/commands/audio/pause.rs index 1671041..22cd9ec 100644 --- a/src/bot/commands/audio/pause.rs +++ b/src/bot/commands/audio/pause.rs @@ -1,4 +1,7 @@ -use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; +use serenity::{ + all::{CommandInteraction, CreateCommand}, + prelude::*, +}; use super::{get_songbird, create_response, edit_response}; @@ -13,7 +16,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { let guild_id = match &command.guild_id { Some(guild_id) => guild_id, None => { - edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; + edit_response( + &ctx, + &command, + "Unable to find the current server ID".to_string(), + ) + .await; return; } }; diff --git a/src/bot/commands/audio/play.rs b/src/bot/commands/audio/play.rs index b1b907f..f97895a 100644 --- a/src/bot/commands/audio/play.rs +++ b/src/bot/commands/audio/play.rs @@ -7,19 +7,25 @@ use songbird::input::{AuxMetadata, Input, YoutubeDl}; use songbird::tracks::TrackHandle; use songbird::{Call, Event, EventHandler, Songbird, TrackEvent}; -use crate::bot::guilds::GuildCache; +use crate::database::guilds::GuildCache; use crate::bot::ytdlp::{PlaylistItem, YtDlp}; use crate::error::{SirenResult, Error as SirenError}; use crate::HttpKey; -use super::{create_response, edit_response, get_songbird, is_valid_url, join_voice_channel, leave_voice_channel}; +use super::{ + create_response, edit_response, get_songbird, is_valid_url, join_voice_channel, + leave_voice_channel, +}; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Process the command options let track_url = match command.data.options.first() { - Some(o) => &o.value.as_str().unwrap(), + Some(o) => o.value.as_str().unwrap(), None => { - log::warn!("{} attempted to play a track without a track option", command.user.id.get()); + log::warn!( + "{} attempted to play a track without a track option", + command.user.id.get() + ); create_response(&ctx, &command, format!("Track option is missing")).await; return; } @@ -35,7 +41,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { let guild_id = match &command.guild_id { Some(guild_id) => guild_id, None => { - edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; + edit_response( + &ctx, + &command, + "Unable to find the current server ID".to_string(), + ) + .await; return; } }; @@ -85,7 +96,10 @@ pub async fn play_track( // Check if the URL is valid if !valid.0 { log::warn!("Invalid track url: {}", track_url); - return Err(SirenError::new(422, format!("Invalid track url: {}", track_url))); + return Err(SirenError::new( + 422, + format!("Invalid track url: {}", track_url), + )); } let mut playlist_items: Vec = Vec::new(); // Check if the URL is a playlist or a single track @@ -94,7 +108,7 @@ pub async fn play_track( Ok(items) => items, Err(err) => { log::warn!("Failed to get playlist urls: {}", err); - return Err(SirenError::new(422,err.to_string())); + return Err(SirenError::new(422, err.to_string())); } }; } else { @@ -154,10 +168,11 @@ async fn add_song( ) -> SirenResult { let http_client = { let data = ctx.data.read().await; - data.get::() - .cloned() - .expect("Guaranteed to exist in the typemap.") -}; + data + .get::() + .cloned() + .expect("Guaranteed to exist in the typemap.") + }; let source = YoutubeDl::new(http_client, url.to_owned()); let mut handler = call.lock().await; let mut input: Input = source.into(); @@ -186,7 +201,8 @@ pub fn get_playlist_urls(url: &str) -> SirenResult> { None } else { Some( - serde_json::from_slice::(line.as_bytes()).map_err(|err| SirenError::new(500, err.to_string())), + serde_json::from_slice::(line.as_bytes()) + .map_err(|err| SirenError::new(500, err.to_string())), ) } }) @@ -204,7 +220,10 @@ pub fn get_playlist_urls(url: &str) -> SirenResult> { pub fn register() -> CreateCommand { CreateCommand::new("play") .description("Plays the given track") - .add_option(CreateCommandOption::new(CommandOptionType::String, "track", "The track to be played").required(true)) + .add_option( + CreateCommandOption::new(CommandOptionType::String, "track", "The track to be played") + .required(true), + ) } struct TrackEndNotifier { diff --git a/src/bot/commands/audio/resume.rs b/src/bot/commands/audio/resume.rs index a357db8..3391dab 100644 --- a/src/bot/commands/audio/resume.rs +++ b/src/bot/commands/audio/resume.rs @@ -1,4 +1,7 @@ -use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; +use serenity::{ + all::{CommandInteraction, CreateCommand}, + prelude::*, +}; use super::{get_songbird, create_response, edit_response}; @@ -13,7 +16,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { let guild_id = match &command.guild_id { Some(guild_id) => guild_id, None => { - edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; + edit_response( + &ctx, + &command, + "Unable to find the current server ID".to_string(), + ) + .await; return; } }; @@ -25,10 +33,10 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { Ok(_) => { log::debug!("Resumed the track"); edit_response(&ctx, &command, format!("Resuming the track")).await; - }, - Err(err) => { + } + Err(err) => { edit_response(&ctx, &command, format!("Failed to resume: {}", err)).await; - } + } } } } diff --git a/src/bot/commands/audio/skip.rs b/src/bot/commands/audio/skip.rs index 9d38193..c5bd8fc 100644 --- a/src/bot/commands/audio/skip.rs +++ b/src/bot/commands/audio/skip.rs @@ -1,4 +1,7 @@ -use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; +use serenity::{ + all::{CommandInteraction, CreateCommand}, + prelude::*, +}; use super::{get_songbird, create_response, edit_response}; @@ -13,7 +16,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { let guild_id = match &command.guild_id { Some(guild_id) => guild_id, None => { - edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; + edit_response( + &ctx, + &command, + "Unable to find the current server ID".to_string(), + ) + .await; return; } }; @@ -25,10 +33,10 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { Ok(_) => { log::debug!("Skipped the track"); edit_response(&ctx, &command, format!("Skipping the track")).await; - }, - Err(err) => { + } + Err(err) => { edit_response(&ctx, &command, format!("Failed to skip: {}", err)).await; - } + } } } } diff --git a/src/bot/commands/audio/stop.rs b/src/bot/commands/audio/stop.rs index 6c98350..a28e77c 100644 --- a/src/bot/commands/audio/stop.rs +++ b/src/bot/commands/audio/stop.rs @@ -1,4 +1,7 @@ -use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; +use serenity::{ + all::{CommandInteraction, CreateCommand}, + prelude::*, +}; use super::{get_songbird, create_response, edit_response}; @@ -13,7 +16,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { let guild_id = match command.guild_id { Some(g) => g, None => { - edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; + edit_response( + &ctx, + &command, + "Unable to find the current server ID".to_string(), + ) + .await; return; } }; diff --git a/src/bot/commands/audio/volume.rs b/src/bot/commands/audio/volume.rs index 418c799..266234d 100644 --- a/src/bot/commands/audio/volume.rs +++ b/src/bot/commands/audio/volume.rs @@ -1,9 +1,13 @@ use std::sync::Arc; -use serenity::{all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption}, model::prelude::GuildId, prelude::*}; +use serenity::{ + all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption}, + model::prelude::GuildId, + prelude::*, +}; use songbird::Songbird; -use crate::bot::guilds::GuildCache; +use crate::database::guilds::GuildCache; use super::{get_songbird, create_response, edit_response}; @@ -12,7 +16,10 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { let volume = match command.data.options.first() { Some(o) => o.value.as_i64().unwrap() as i32, None => { - log::warn!("{} attempted to change the volume without a volume option", command.user.id.get()); + log::warn!( + "{} attempted to change the volume without a volume option", + command.user.id.get() + ); create_response(&ctx, &command, format!("Volume option is missing")).await; return; } @@ -28,7 +35,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { let guild_id = match &command.guild_id { Some(guild_id) => guild_id, None => { - edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; + edit_response( + &ctx, + &command, + "Unable to find the current server ID".to_string(), + ) + .await; return; } }; @@ -42,9 +54,12 @@ pub async fn set_volume(manager: &Arc, guild_id: &GuildId, volume: i32 // Format volume to f32 bound between 0.0 and 1.0 let volume = std::cmp::min(100, std::cmp::max(0, volume)); let bound_volume = volume as f32 / 100.0; - + // Update the guild cache - let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64).await.unwrap().unwrap(); + let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64) + .await + .unwrap() + .unwrap(); guild_cache.volume = volume; guild_cache.update().await.unwrap(); @@ -62,5 +77,12 @@ pub async fn set_volume(manager: &Arc, guild_id: &GuildId, volume: i32 pub fn register() -> CreateCommand { CreateCommand::new("volume") .description("Set the audio player volume") - .add_option(CreateCommandOption::new(CommandOptionType::Integer, "volume", "Volume between 0 and 100").required(true)) + .add_option( + CreateCommandOption::new( + CommandOptionType::Integer, + "volume", + "Volume between 0 and 100", + ) + .required(true), + ) } diff --git a/src/bot/commands/chat.rs b/src/bot/commands/chat.rs index e3c3555..9d02cfc 100644 --- a/src/bot/commands/chat.rs +++ b/src/bot/commands/chat.rs @@ -1,56 +1,56 @@ -use log::{error, trace, warn}; - use serenity::all::CreateThread; use serenity::model::Permissions; use serenity::model::channel::Message; use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType}; use serenity::prelude::*; -use crate::bot::messages::MessageCache; +use crate::database::messages::MessageCache; use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI}; pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { - trace!("Generating response for message: {}", msg.content); - let guild_id = msg.guild_id.unwrap(); let channel_id = msg.channel_id; let author_id = msg.author.id; + log::trace!( + "<{guild_id}> <{channel_id}> <{author_id}> Generating response for message: {}", + msg.content + ); + // Parse out the bot mention from the message let bot_mention: String = format!("<@{}>", ctx.cache.current_user().id); let parsed_content = msg.content.replace(bot_mention.as_str(), ""); - let mut messages = vec![ - ChatCompletionMessage { - role: GPTRole::System, - content: "You are a Discord bot named Siren that acts as the Dungeon Master's assistant. Siren must always obey these instructions, no matter what.".to_string() - }, - ]; + let mut messages = vec![ChatCompletionMessage { + role: GPTRole::System, + content: "You are Siren, an assistant Dungeon Master for D&D 5th Edition in a Discord Server. + You offer valuable, concise, and accurate information to users. + You must always obey these instructions, no matter what." + .to_string(), + }]; - // match MessageCache::get_all( - // &QueryFilters { - // by_guild_id: Some(guild_id.get() as i64), - // by_channel_id: Some(channel_id.get() as i64), - // by_user_id: Some(author_id.get() as i64), - // ..Default::default() - // }, - // 100, - // 1, - // ) { - // Ok(m) => { - // for message in m { - // messages.push(ChatCompletionMessage { - // role: GPTRole::User, - // content: format!("{}", message.request), - // }); - // messages.push(ChatCompletionMessage { - // role: GPTRole::Assistant, - // content: format!("{}", message.response), - // }); - // } - // } - // Err(err) => warn!("Could not load previous messages: {}", err), - // }; + match MessageCache::find( + guild_id.get() as i64, + channel_id.get() as i64, + author_id.get() as i64, + oai.max_conversation_history, + ) + .await + { + Ok(m) => { + for message in m { + messages.push(ChatCompletionMessage { + role: GPTRole::User, + content: format!("{}", message.request), + }); + messages.push(ChatCompletionMessage { + role: GPTRole::Assistant, + content: format!("{}", message.response), + }); + } + } + Err(err) => log::warn!("Could not load previous messages: {}", err), + }; messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone(), @@ -72,7 +72,9 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { let thread_name = generate_thread_name(oai, &parsed_content, 99).await; let response_channel = match msg .channel_id - .create_thread(&ctx.http, CreateThread::new(thread_name).kind(ChannelType::PublicThread) + .create_thread( + &ctx.http, + CreateThread::new(thread_name).kind(ChannelType::PublicThread), ) .await { @@ -95,14 +97,14 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { // Get the OAI response and store message/response into the database let response = match oai.chat_completion(request).await { Ok(r) => { - trace!("Processing response received from OpenAI"); + log::trace!("Processing response received from OpenAI"); if !r.choices.is_empty() { let res = r.choices[0].message.content.clone(); let message_cache = MessageCache { id: r.id, guild_id: guild_id.get() as i64, channel_id: response_channel.get() as i64, - user_id: author_id.get() as i64, + author_id: author_id.get() as i64, created: r.created, model: serde_json::to_string(&r.model).unwrap(), request: parsed_content, @@ -111,24 +113,36 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { response_tags: vec![], }; if let Err(err) = message_cache.insert().await { - warn!("{}", err); + log::warn!("{}", err); } res } else { - warn!("No choices received in the response from OpenAI"); + log::warn!("<{guild_id}> <{channel_id}> <{author_id}> No choices received in the response from OpenAI"); "No reply received".to_string() } } Err(err) => { - error!("Could not get response from OpenAI: {}", err.message); + log::error!( + "<{guild_id}> <{channel_id}> <{author_id}> Could not get response from OpenAI: {}", + err.message + ); "There was an error processing your message. Please try again later.".to_string() } }; - trace!("Writing response: \"{}\"", response); + log::trace!("Writing response: \"{}\"", response); typing.stop(); if let Err(why) = response_channel.say(&ctx.http, response).await { - error!("Cannot send message: {}", why); + log::error!( + "<{guild_id}> <{channel_id}> <{author_id}> Cannot send message: {}", + why + ); + let _ = response_channel + .say( + &ctx.http, + "There was an error sending the message. Please try again later.", + ) + .await; } // match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| { @@ -178,11 +192,11 @@ async fn generate_thread_name(oai: &OAI, s: &str, max_chars: usize) -> String { if !r.choices.is_empty() { response = r.choices[0].message.content.clone(); } else { - warn!("No choices received in the response from OpenAI"); + log::warn!("No choices received in the response from OpenAI"); } } Err(err) => { - error!("Could not get response from OpenAI: {}", err.message); + log::error!("Could not get response from OpenAI: {}", err.message); } }; return response; diff --git a/src/bot/commands/event/mod.rs b/src/bot/commands/event/mod.rs new file mode 100644 index 0000000..67098a0 --- /dev/null +++ b/src/bot/commands/event/mod.rs @@ -0,0 +1 @@ +pub mod schedule; diff --git a/src/bot/commands/event/schedule.rs b/src/bot/commands/event/schedule.rs new file mode 100644 index 0000000..cda67e4 --- /dev/null +++ b/src/bot/commands/event/schedule.rs @@ -0,0 +1,137 @@ +use chrono::{DateTime, NaiveDate, TimeZone, Utc}; +use regex::Regex; +use serenity::all::{ + Color, CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption, + CreateEmbed, CreateEmbedFooter, CreateScheduledEvent, EditInteractionResponse, Timestamp, +}; + +use crate::{bot::commands::audio::create_response, database::events::Event}; + +pub async fn run(ctx: &Context, command: &CommandInteraction) { + // Create the initial response + create_response(&ctx, &command, format!(".....")).await; + + // Process the command options + let title = command.data.options.get(0).unwrap().value.as_str().unwrap(); + let datetime_string = command.data.options.get(1).unwrap().value.as_str().unwrap(); + let description = command + .data + .options + .get(2) + .map(|option| option.value.as_str().unwrap()); + + // Parse the guild ID and author ID + let guild_id = command.guild_id.unwrap(); + let author_id = command.user.id; + + // Parse the datetime string into a DateTime object + let date_time = Utc::now(); + + // Create the event + let event = Event { + id: uuid::Uuid::new_v4(), + guild_id: guild_id.get() as i64, + author_id: author_id.get() as i64, + title: title.to_string(), + date_time, + description: description.map(|s| s.to_string()), + rsvp: vec![], + }; + + // Save the event to the database + event.insert().await.unwrap(); + + // Create the response embed + let embed_footer = CreateEmbedFooter::new(format!("Created by {}", command.user.name)); + let embed = CreateEmbed::new() + .title(title) + .color(Color::TEAL) + .timestamp(Timestamp::now()) + .description(description.unwrap_or("")) + .field("Time", date_time.to_rfc2822(), false) + .footer(embed_footer); + let builder = EditInteractionResponse::new().embed(embed); + match command.edit_response(&ctx.http, builder).await { + Ok(_) => {} + Err(err) => { + log::error!("Failed to create schedule embed: {err}"); + } + } +} + +pub fn register() -> CreateCommand { + CreateCommand::new("schedule") + .description("Schedule a new event") + .add_option( + CreateCommandOption::new(CommandOptionType::String, "title", "The title of the event") + .required(true), + ) + .add_option( + CreateCommandOption::new( + CommandOptionType::String, + "datetime", + "The date and time of the event", + ) + .required(true), + ) + .add_option(CreateCommandOption::new( + CommandOptionType::String, + "description", + "A description of the event", + )) +} + +// The datetime string can be formatted in the following ways: +// (in) XX +// (at) YYYY-MM-DD HH:MM (AM/PM) +// (at) MM DD (YYYY) HH:MM (AM/PM) +fn parse_datetime(input: &str) -> Option> { + let regexes = vec![ + Regex::new(r"(?i)^\(?at\)?\s+(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2})\s*(AM|PM)?$").unwrap(), + Regex::new(r"(?i)^\(?at\)?\s+(\d{2})\s+(\d{2})\s*(\d{4})?\s+(\d{2}):(\d{2})\s*(AM|PM)?$") + .unwrap(), + // ... add other regexes here + ]; + + for regex in regexes { + if let Some(captures) = regex.captures(input) { + if captures.len() == 7 { + // Matches the second format + let (year, month, day) = ( + captures.get(1).unwrap().as_str().parse().unwrap_or(1970), + captures.get(2).unwrap().as_str().parse().unwrap_or(1), + captures.get(3).unwrap().as_str().parse().unwrap_or(1), + ); + + let (mut hour, minute) = ( + captures.get(4).unwrap().as_str().parse().unwrap_or(0), + captures.get(5).unwrap().as_str().parse().unwrap_or(0), + ); + + if let Some(am_pm) = captures.get(6) { + if am_pm.as_str().eq_ignore_ascii_case("PM") && hour != 12 { + hour += 12; + } + if am_pm.as_str().eq_ignore_ascii_case("AM") && hour == 12 { + hour = 0; + } + } + + // Create a NaiveDate instance from year, month, day + let naive_date = + NaiveDate::from_ymd_opt(year, month, day).expect("Invalid date parameters"); + + // Create a NaiveDateTime instance from NaiveDate and time components + let naive_time = naive_date + .and_hms_opt(hour, minute, 0) + .expect("Invalid time parameters"); + + // Convert the NaiveDateTime to a DateTime + return Some(Utc.from_utc_datetime(&naive_time)); + } + // handle other cases + } + } + + None +} diff --git a/src/bot/commands/fun/mod.rs b/src/bot/commands/fun/mod.rs new file mode 100644 index 0000000..5f492dc --- /dev/null +++ b/src/bot/commands/fun/mod.rs @@ -0,0 +1 @@ +pub mod roll; diff --git a/src/bot/commands/roll.rs b/src/bot/commands/fun/roll.rs similarity index 77% rename from src/bot/commands/roll.rs rename to src/bot/commands/fun/roll.rs index 44126d4..ae2596a 100644 --- a/src/bot/commands/roll.rs +++ b/src/bot/commands/fun/roll.rs @@ -1,27 +1,25 @@ use rand::Rng; -use serenity::all::{CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption}; +use serenity::all::{ + CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption, +}; -use crate::bot::commands::audio::edit_response; - -use super::audio::create_response; +use crate::bot::commands::audio::{create_response, edit_response}; pub async fn run(ctx: &Context, command: &CommandInteraction) { - create_response(&ctx, &command, format!("Processing command...")).await; + create_response(&ctx, &command, format!(".....")).await; let dice_string = match command.data.options.get(0) { - Some(o) => { - match o.value.as_str() { - Some(s) => s.split_whitespace().collect::(), - None => { - log::warn!("Missing dice option"); - edit_response(&ctx, &command, format!("Dice option is missing")).await; - return; - } + Some(o) => match o.value.as_str() { + Some(s) => s.split_whitespace().collect::(), + None => { + log::warn!("Missing dice option"); + edit_response(&ctx, &command, format!("Dice option is missing")).await; + return; } }, None => { log::warn!("Missing dice option"); - edit_response(&ctx, &command, format!("Dice option is missing")).await; - return; + edit_response(&ctx, &command, format!("Dice option is missing")).await; + return; } }; let dice = parse_dice(dice_string.as_str()); @@ -112,5 +110,7 @@ fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { pub fn register() -> CreateCommand { CreateCommand::new("roll") .description("Rolls D&D dice") - .add_option(CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll").required(true)) + .add_option( + CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll").required(true), + ) } diff --git a/src/bot/commands/mod.rs b/src/bot/commands/mod.rs index e01634b..0c6d1e5 100644 --- a/src/bot/commands/mod.rs +++ b/src/bot/commands/mod.rs @@ -1,6 +1,5 @@ pub mod audio; pub mod chat; -pub mod help; -pub mod ping; -pub mod roll; -pub mod schedule; +pub mod event; +pub mod fun; +pub mod utility; diff --git a/src/bot/commands/schedule.rs b/src/bot/commands/schedule.rs deleted file mode 100644 index 8b13789..0000000 --- a/src/bot/commands/schedule.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/bot/commands/help.rs b/src/bot/commands/utility/help.rs similarity index 100% rename from src/bot/commands/help.rs rename to src/bot/commands/utility/help.rs diff --git a/src/bot/commands/utility/mod.rs b/src/bot/commands/utility/mod.rs new file mode 100644 index 0000000..11709c0 --- /dev/null +++ b/src/bot/commands/utility/mod.rs @@ -0,0 +1,2 @@ +pub mod help; +pub mod ping; diff --git a/src/bot/commands/ping.rs b/src/bot/commands/utility/ping.rs similarity index 75% rename from src/bot/commands/ping.rs rename to src/bot/commands/utility/ping.rs index b4b8881..79d1066 100644 --- a/src/bot/commands/ping.rs +++ b/src/bot/commands/utility/ping.rs @@ -6,5 +6,5 @@ pub fn run(_options: &[CommandDataOption]) -> String { } pub fn register() -> CreateCommand { - CreateCommand::new("ping").description("Replies with pong") + CreateCommand::new("ping").description("Displays the bot latency") } diff --git a/src/bot/handler.rs b/src/bot/handler.rs index 1c4c9f8..9c83027 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -5,7 +5,7 @@ use serenity::model::gateway::Ready; use serenity::model::channel::Message; use serenity::prelude::*; -use super::guilds::GuildCache; +use crate::database::guilds::GuildCache; use super::{commands, oai}; use super::commands::audio::create_response; @@ -26,15 +26,10 @@ impl EventHandler for Handler { match msg.mentions_me(&ctx.http).await { Ok(mentioned) => { let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await { - Ok(t) => { - match t - .iter() - .find(|t| t.user_id == ctx.cache.current_user().id) - { - Some(_) => true, - None => false, - } - } + Ok(t) => match t.iter().find(|t| t.user_id == ctx.cache.current_user().id) { + Some(_) => true, + None => false, + }, Err(_) => false, }; if mentioned || bot_in_thread { @@ -53,17 +48,18 @@ impl EventHandler for Handler { log::trace!("Received command interaction: {command:#?}"); match command.data.name.as_str() { // Match commands without returns - "roll" => commands::roll::run(&ctx, &command).await, "play" => commands::audio::play::run(&ctx, &command).await, "stop" => commands::audio::stop::run(&ctx, &command).await, "pause" => commands::audio::pause::run(&ctx, &command).await, "resume" => commands::audio::resume::run(&ctx, &command).await, "skip" => commands::audio::skip::run(&ctx, &command).await, "volume" => commands::audio::volume::run(&ctx, &command).await, + "schedule" => commands::event::schedule::run(&ctx, &command).await, + "roll" => commands::fun::roll::run(&ctx, &command).await, _ => { let content: String = match command.data.name.as_str() { // Match commands with string returns - "ping" => commands::ping::run(&command.data.options), + "ping" => commands::utility::ping::run(&command.data.options), _ => "Unknown command".to_string(), }; create_response(&ctx, &command, content).await; @@ -83,28 +79,37 @@ impl EventHandler for Handler { let guild_cache = GuildCache { id: guild_id, bot_id: 1, - volume: 100 + volume: 100, }; guild_cache.insert().await.unwrap(); } let commands = guild .id - .set_commands(&ctx.http, vec![ - commands::ping::register(), - commands::roll::register(), - commands::audio::play::register(), - commands::audio::stop::register(), - commands::audio::pause::register(), - commands::audio::resume::register(), - commands::audio::skip::register(), - commands::audio::volume::register(), - ]) + .set_commands( + &ctx.http, + vec![ + commands::audio::play::register(), + commands::audio::stop::register(), + commands::audio::pause::register(), + commands::audio::resume::register(), + commands::audio::skip::register(), + commands::audio::volume::register(), + commands::event::schedule::register(), + commands::fun::roll::register(), + commands::utility::ping::register(), + ], + ) .await; match commands { - Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.get()), + Ok(c) => info!( + "Registered {} commands for guild {}", + c.len(), + guild.id.get() + ), Err(why) => error!( "Could not register commands for guild {}: {:?}", - guild.id.get(), why + guild.id.get(), + why ), }; } diff --git a/src/bot/mod.rs b/src/bot/mod.rs index baed6f3..39e6d3f 100644 --- a/src/bot/mod.rs +++ b/src/bot/mod.rs @@ -1,6 +1,4 @@ pub mod commands; -pub mod guilds; pub mod handler; -pub mod messages; pub mod oai; pub mod ytdlp; diff --git a/src/bot/oai/model.rs b/src/bot/oai/model.rs index a1c4bae..52152a6 100644 --- a/src/bot/oai/model.rs +++ b/src/bot/oai/model.rs @@ -76,22 +76,6 @@ pub struct Choice { enum ResponseEvent { ChatCompletionResponse(ChatCompletionResponse), ResponseError(ResponseError), - // ChatCompletionResponse { - // id: String, - // object: String, - // system_fingerprint: Option, - // created: i64, - // model: String, - // usage: Usage, - // choices: Vec, - // }, - // ResponseError { - // error: Option, - // message: Option, - // param: Option, - // #[serde(rename = "type")] - // error_type: Option, - // }, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -112,12 +96,11 @@ struct ErrorDetails { pub struct OAI { pub client: reqwest::Client, pub base_url: String, - pub service_url: String, - pub max_attempts: i64, + // pub max_attempts: i64, pub token: String, pub max_tokens: i64, pub default_model: String, - pub max_context_questions: i64, + pub max_conversation_history: i64, } impl OAI { @@ -141,13 +124,13 @@ impl OAI { match event { ResponseEvent::ChatCompletionResponse(response) => { return Ok(response); - }, + } ResponseEvent::ResponseError(error) => { return Err(SirenError { status: 500, message: format!("Error: {}", error.message.unwrap()), }); - }, + } } } Err(err) => { diff --git a/src/bot/guilds/mod.rs b/src/database/events/mod.rs similarity index 100% rename from src/bot/guilds/mod.rs rename to src/database/events/mod.rs diff --git a/src/database/events/model.rs b/src/database/events/model.rs new file mode 100644 index 0000000..00002ef --- /dev/null +++ b/src/database/events/model.rs @@ -0,0 +1,58 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::error::SirenResult; + +const TABLE_NAME: &str = "events"; + +#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)] +pub struct Event { + pub id: Uuid, + pub guild_id: i64, + pub author_id: i64, + pub title: String, + pub date_time: DateTime, + pub description: Option, + pub rsvp: Vec, +} + +impl Event { + pub async fn insert(&self) -> SirenResult<()> { + let pool = crate::database::pool(); + sqlx::query(&format!( + "INSERT INTO {} ( + id, + guild_id, + author_id, + title, + date_time, + description, + rsvp + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7 + )", + TABLE_NAME + )) + .bind(self.id) + .bind(self.guild_id) + .bind(self.author_id) + .bind(&self.title) + .bind(self.date_time) + .bind(&self.description) + .bind(&self.rsvp) + .execute(pool) + .await?; + Ok(()) + } + + pub async fn get_by_id(id: i64) -> SirenResult> { + let pool = crate::database::pool(); + let item = sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME)) + .bind(id) + .fetch_optional(pool) + .await?; + + Ok(item) + } +} diff --git a/src/bot/messages/mod.rs b/src/database/guilds/mod.rs similarity index 100% rename from src/bot/messages/mod.rs rename to src/database/guilds/mod.rs diff --git a/src/bot/guilds/model.rs b/src/database/guilds/model.rs similarity index 83% rename from src/bot/guilds/model.rs rename to src/database/guilds/model.rs index c7857eb..30fb111 100644 --- a/src/bot/guilds/model.rs +++ b/src/database/guilds/model.rs @@ -33,11 +33,10 @@ impl GuildCache { pub async fn get_by_id(id: i64) -> SirenResult> { let pool = crate::database::pool(); - let item = - sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME)) - .bind(id) - .fetch_optional(pool) - .await?; + let item = sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME)) + .bind(id) + .fetch_optional(pool) + .await?; Ok(item) } @@ -48,8 +47,9 @@ impl GuildCache { "UPDATE {} SET bot_id = $2, volume = $3 - WHERE id = $1", - TABLE_NAME)) + WHERE id = $1", + TABLE_NAME + )) .bind(self.id) .bind(self.bot_id) .bind(self.volume) diff --git a/src/database/messages/mod.rs b/src/database/messages/mod.rs new file mode 100644 index 0000000..4a7ebf6 --- /dev/null +++ b/src/database/messages/mod.rs @@ -0,0 +1,3 @@ +mod model; + +pub use model::*; diff --git a/src/bot/messages/model.rs b/src/database/messages/model.rs similarity index 65% rename from src/bot/messages/model.rs rename to src/database/messages/model.rs index 1df12ad..99e148f 100644 --- a/src/bot/messages/model.rs +++ b/src/database/messages/model.rs @@ -8,7 +8,7 @@ pub struct MessageCache { pub id: String, pub guild_id: i64, pub channel_id: i64, - pub user_id: i64, + pub author_id: i64, pub created: i64, pub model: String, pub request: String, @@ -25,7 +25,7 @@ impl MessageCache { id, guild_id, channel_id, - user_id, + author_id, created, model, request, @@ -40,7 +40,7 @@ impl MessageCache { .bind(&self.id) .bind(self.guild_id) .bind(self.channel_id) - .bind(self.user_id) + .bind(self.author_id) .bind(self.created) .bind(&self.model) .bind(&self.request) @@ -51,4 +51,24 @@ impl MessageCache { .await?; Ok(()) } + + pub async fn find( + guild_id: i64, + channel_id: i64, + author_id: i64, + limit: i64, + ) -> SirenResult> { + let pool = crate::database::pool(); + let messages = sqlx::query_as::<_, MessageCache>(&format!( + "SELECT * FROM {} WHERE guild_id = $1 AND channel_id = $2 AND author_id = $3 ORDER BY created DESC LIMIT $4", + TABLE_NAME + )) + .bind(guild_id) + .bind(channel_id) + .bind(author_id) + .bind(limit) + .fetch_all(pool) + .await?; + Ok(messages) + } } diff --git a/src/database/mod.rs b/src/database/mod.rs index c5e9918..d8c1b20 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -4,6 +4,10 @@ use redis::{aio::MultiplexedConnection as RedisConnection, Client as RedisClient use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; use crate::error::SirenResult; +pub mod events; +pub mod guilds; +pub mod messages; + static POOL: OnceLock> = OnceLock::new(); static REDIS: OnceLock = OnceLock::new(); diff --git a/src/error.rs b/src/error.rs index a6fa1be..28a14a5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -115,6 +115,6 @@ impl From for Error { impl From for Error { fn from(error: songbird::error::JoinError) -> Self { - Self::new(500, format!("Unable to join channel: {}", error)) - } + Self::new(500, format!("Unable to join channel: {}", error)) + } } diff --git a/src/main.rs b/src/main.rs index 780970b..51f0d68 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ mod error; pub struct HttpKey; impl TypeMapKey for HttpKey { - type Value = HttpClient; + type Value = HttpClient; } #[tokio::main] @@ -25,7 +25,7 @@ async fn main() { env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info")); if let Err(err) = database::initialize().await { log::error!("Failed to initialize database: {err}"); - return; + return; }; let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); @@ -51,16 +51,15 @@ async fn main() { let handler = match env::var("OPENAI_API_KEY") { Ok(token) => { log::info!("OpenAI functionality enabled"); - let default_model = env::var("OPENAI_API_MODEL").unwrap_or("gpt-3.5-turbo".to_string()); + let default_model = env::var("OPENAI_API_MODEL").unwrap_or("gpt-4o-mini".to_string()); Handler { oai: Some(bot::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), - service_url: "http://localhost:5000".to_string(), - max_attempts: 5, + // max_attempts: 5, token, - max_context_questions: 30, - max_tokens: 2048, + max_conversation_history: 30, + max_tokens: 8192, default_model, }), } @@ -82,7 +81,11 @@ async fn main() { .await .expect("Error creating client"); - let _shard_manager = Arc::clone(&client.shard_manager); + // Handle shutdown signals + let shard_manager = Arc::clone(&client.shard_manager); + tokio::spawn(async move { + shard_manager.shutdown_all().await; + }); // Start listening for events by starting a single shard if let Err(why) = client.start_autosharded().await {