format and restructure, began working on schedule

This commit is contained in:
2024-09-05 17:10:56 -04:00
parent 0f1a97770a
commit 794d8cc34e
34 changed files with 561 additions and 212 deletions

2
.env
View File

@@ -21,4 +21,4 @@ DATA_DIR_PATH= # OPTIONAL
DISCORD_TOKEN= DISCORD_TOKEN=
OPENAI_API_KEY= # OPTIONAL OPENAI_API_KEY= # OPTIONAL
OPENAI_API_MODEL=gpt-3.5-turbo OPENAI_API_MODEL=gpt-4o-mini

View File

@@ -1 +0,0 @@
SIREN_VERSION=0.2.8

View File

@@ -25,3 +25,4 @@ uuid = { version = "1.10.0", features = ["serde", "v4"] }
redis = { version = "0.26.1", features = ["tokio-comp", "connection-manager", "r2d2"] } redis = { version = "0.26.1", features = ["tokio-comp", "connection-manager", "r2d2"] }
rand = "0.8.5" rand = "0.8.5"
tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] }
regex = "1.10.6"

View File

@@ -20,18 +20,49 @@ Siren is a D&D Bot built for Discord, written in Rust. Features include:
4. Start the application with `make up` 4. Start the application with `make up`
<h3 id='setup-discord-developer-application'>Setting up the Discord Developer Application</h3> <h3 id='setup-discord-developer-application'>Setting up the Discord Developer Application</h3>
Visit the [Discord Developer Portal](https://discord.com/developers/applications) and create a new application. Click [here](https://discord.com/developers/docs/intro) for guides and more information. Visit the [Discord Developer Portal](https://discord.com/developers/applications) and create a new application. Click [here](https://discord.com/developers/docs/intro) for guides and more information.
Required Scopes: #### Oauth2
``` **Required Scopes**:
bot - bot
application.commands - applications.commands
```
Example Invite: **Required Bot Permissions**:
- General Permissions
- Manage Roles
- Change Nickname
- View Channels
- Manage Events
- Create Events
- Text Permissions
- Send Messages
- Create Public Threads
- Create Private Threads
- Send Messages in Threads
- Manage Messages
- Manage Threads
- Embed Links
- Attach Files
- Read Message History
- Mention Everyone
- Use External Emojis
- Use External Stickers
- Add Reactions
- Create Polls
- Voice Permissions
- Connect
- Speak
Example Invites:
``` ```
https://discord.com/api/oauth2/authorize?client_id=<CLIENT_ID>&permissions=40671259392832&scope=bot%20applications.commands https://discord.com/api/oauth2/authorize?client_id=<CLIENT_ID>&permissions=40671259392832&scope=bot%20applications.commands
``` ```
```
https://discord.com/oauth2/authorize?client_id=<CLIENT_ID>&permissions=581083641408576&integration_type=0&scope=bot+applications.commands
```
The CLIENT_ID can be found in the General Information tab on the Discord Developer Portal for your application, under `Application ID` The CLIENT_ID can be found in the General Information tab on the Discord Developer Portal for your application, under `Application ID`
The DISCORD_TOKEN (used in the `.env file`) can be found under the Bot tab on the Discord Developer Portal for your application. The DISCORD_TOKEN (used in the `.env file`) can be found under the Bot tab on the Discord Developer Portal for your application.
@@ -41,9 +72,10 @@ The DISCORD_TOKEN (used in the `.env file`) can be found under the Bot tab on th
### Commands ### Commands
Siren utilizes Discord slash commands. To view the commands, run `/help` in a server where the bot is installed. The following commands are available: Siren utilizes Discord slash commands. To view the commands, run `/help` in a server where the bot is installed. The following commands are available:
**Music Commands**
| Command | Description | | Command | Description |
| --- | --- | | --- | --- |
| `/play` | Play a track from Youtube or locally hosted files | | `/play <Track>` | Play a track from Youtube or locally hosted files |
| `/pause` | Pause the current track | | `/pause` | Pause the current track |
| `/resume` | Resume the current track | | `/resume` | Resume the current track |
| `/skip` | Skip the current track | | `/skip` | Skip the current track |
@@ -51,11 +83,31 @@ Siren utilizes Discord slash commands. To view the commands, run `/help` in a se
| `/queue` | ***TODO*** - Display the current queue | | `/queue` | ***TODO*** - Display the current queue |
| `/clear` | ***TODO*** - Clear the current queue | | `/clear` | ***TODO*** - Clear the current queue |
| `/shuffle` | ***TODO*** - Shuffle the current queue | | `/shuffle` | ***TODO*** - Shuffle the current queue |
| `/loop` | ***TODO*** - Loop the current track | | `/loop` | ***TODO*** - Loop or unloop the current track |
| `/nowplaying` | ***TODO*** - Display the current track | | `/nowplaying` | ***TODO*** - Display the current track |
| `/volume` | Set the volume of the bot | | `/volume <Volume>` | Set the volume of the bot |
**Event Commands**
| Command | Description |
| --- | --- |
| `/schedule` | ***TODO*** - Schedule a new event |
| `/events` | ***TODO*** - Display all events |
| `/event <Event ID>` | ***TODO*** - Display a specific event |
| `/deleteevent <Event ID>` | ***TODO*** - Delete a specific event |
| `/updateevent <Event ID>` | ***TODO*** - Update a specific event |
| `/remindme <Event ID>` | ***TODO*** - Set a reminder for a specific event |
**Fun Commands**
| Command | Description |
| --- | --- |
| `/coinflip` | Flip a coin |
| `/roll <Dice>` | Roll a dice |
**Utility Commands**
| Command | Description |
| --- | --- |
| `/ping` | Display the bot's latency | | `/ping` | Display the bot's latency |
| `/roll` | Roll a dice | | `/poll` | ***TODO*** - Create a poll |
| `/help` | ***TODO*** - Display a list of commands | | `/help` | ***TODO*** - Display a list of commands |
## Contributing ## Contributing

View File

@@ -3,26 +3,24 @@ CREATE TABLE IF NOT EXISTS guilds (
bot_id BIGINT NOT NULL, bot_id BIGINT NOT NULL,
volume INTEGER NOT NULL volume INTEGER NOT NULL
); );
CREATE TABLE IF NOT EXISTS users (
email TEXT PRIMARY KEY NOT NULL,
hash TEXT NOT NULL,
role TEXT NOT NULL,
first_name TEXT NOT NULL,
last_name TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
profile_picture TEXT,
verified BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE TABLE IF NOT EXISTS messages ( CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY NOT NULL, id TEXT PRIMARY KEY NOT NULL,
guild_id BIGINT NOT NULL, guild_id BIGINT NOT NULL,
channel_id BIGINT NOT NULL, channel_id BIGINT NOT NULL,
user_id BIGINT NOT NULL, author_id BIGINT NOT NULL,
created BIGINT NOT NULL, created BIGINT NOT NULL,
model TEXT NOT NULL, model TEXT NOT NULL,
request TEXT NOT NULL, request TEXT NOT NULL,
response TEXT NOT NULL, response TEXT NOT NULL,
request_tags TEXT[] NOT NULL, request_tags TEXT[] NOT NULL,
response_tags TEXT[] NOT NULL response_tags TEXT[] NOT NULL
); );
CREATE TABLE IF NOT EXISTS events (
id UUID PRIMARY KEY NOT NULL,
guild_id BIGINT NOT NULL,
author_id BIGINT NOT NULL,
title TEXT NOT NULL,
date_time TIMESTAMP NOT NULL,
description TEXT,
rsvp BIGINT[] NOT NULL
);

View File

@@ -1,7 +1,10 @@
use std::sync::Arc; use std::sync::Arc;
use reqwest::Url; use reqwest::Url;
use serenity::all::{CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditInteractionResponse}; use serenity::all::{
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage,
EditInteractionResponse,
};
use serenity::client::Cache; use serenity::client::Cache;
use serenity::model::prelude::{GuildId, ChannelId}; use serenity::model::prelude::{GuildId, ChannelId};
use serenity::model::user::User; use serenity::model::user::User;
@@ -34,17 +37,16 @@ pub async fn join_voice_channel(
) -> SirenResult<ChannelId> { ) -> SirenResult<ChannelId> {
let channel_id = find_voice_channel(cache, guild_id, user)?; let channel_id = find_voice_channel(cache, guild_id, user)?;
log::debug!("<{}> Joining channel {}", guild_id.get(), channel_id.get()); log::debug!("<{}> Joining channel {}", guild_id.get(), channel_id.get());
manager.join(guild_id.to_owned(), channel_id.to_owned()).await?; manager
.join(guild_id.to_owned(), channel_id.to_owned())
.await?;
Ok(channel_id) Ok(channel_id)
} }
/** /**
* Leaves a voice channel. * Leaves a voice channel.
*/ */
pub async fn leave_voice_channel( pub async fn leave_voice_channel(manager: &Arc<Songbird>, guild_id: &GuildId) -> SirenResult<()> {
manager: &Arc<Songbird>,
guild_id: &GuildId,
) -> SirenResult<()> {
if manager.get(guild_id.to_owned()).is_some() { if manager.get(guild_id.to_owned()).is_some() {
log::debug!("<{}> Disconnecting from channel", guild_id.get()); log::debug!("<{}> Disconnecting from channel", guild_id.get());
manager.remove(*guild_id).await?; manager.remove(*guild_id).await?;
@@ -52,11 +54,7 @@ pub async fn leave_voice_channel(
Ok(()) Ok(())
} }
pub async fn create_response( pub async fn create_response(ctx: &Context, command: &CommandInteraction, content: String) {
ctx: &Context,
command: &CommandInteraction,
content: String,
) {
let data = CreateInteractionResponseMessage::new().content(content.to_owned()); let data = CreateInteractionResponseMessage::new().content(content.to_owned());
let builder = CreateInteractionResponse::Message(data); let builder = CreateInteractionResponse::Message(data);
match command.create_response(&ctx.http, builder).await { match command.create_response(&ctx.http, builder).await {
@@ -67,11 +65,7 @@ pub async fn create_response(
}; };
} }
pub async fn edit_response( pub async fn edit_response(ctx: &Context, command: &CommandInteraction, content: String) {
ctx: &Context,
command: &CommandInteraction,
content: String,
) {
let builder = EditInteractionResponse::new().content(content.to_owned()); let builder = EditInteractionResponse::new().content(content.to_owned());
match command.edit_response(&ctx.http, builder).await { match command.edit_response(&ctx.http, builder).await {
Ok(_) => {} Ok(_) => {}
@@ -115,6 +109,11 @@ fn find_voice_channel(
.and_then(|voice_state| voice_state.channel_id) .and_then(|voice_state| voice_state.channel_id)
{ {
Some(channel) => Ok(channel), Some(channel) => Ok(channel),
None => return Err(SirenError::new(401, "User is not in a voice channel".to_string())), None => {
return Err(SirenError::new(
401,
"User is not in a voice channel".to_string(),
))
}
} }
} }

View File

@@ -1,4 +1,7 @@
use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; use serenity::{
all::{CommandInteraction, CreateCommand},
prelude::*,
};
use super::{get_songbird, create_response, edit_response}; use super::{get_songbird, create_response, edit_response};
@@ -13,7 +16,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
let guild_id = match &command.guild_id { let guild_id = match &command.guild_id {
Some(guild_id) => guild_id, Some(guild_id) => guild_id,
None => { None => {
edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return; return;
} }
}; };

View File

@@ -7,19 +7,25 @@ use songbird::input::{AuxMetadata, Input, YoutubeDl};
use songbird::tracks::TrackHandle; use songbird::tracks::TrackHandle;
use songbird::{Call, Event, EventHandler, Songbird, TrackEvent}; use songbird::{Call, Event, EventHandler, Songbird, TrackEvent};
use crate::bot::guilds::GuildCache; use crate::database::guilds::GuildCache;
use crate::bot::ytdlp::{PlaylistItem, YtDlp}; use crate::bot::ytdlp::{PlaylistItem, YtDlp};
use crate::error::{SirenResult, Error as SirenError}; use crate::error::{SirenResult, Error as SirenError};
use crate::HttpKey; use crate::HttpKey;
use super::{create_response, edit_response, get_songbird, is_valid_url, join_voice_channel, leave_voice_channel}; use super::{
create_response, edit_response, get_songbird, is_valid_url, join_voice_channel,
leave_voice_channel,
};
pub async fn run(ctx: &Context, command: &CommandInteraction) { pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Process the command options // Process the command options
let track_url = match command.data.options.first() { let track_url = match command.data.options.first() {
Some(o) => &o.value.as_str().unwrap(), Some(o) => o.value.as_str().unwrap(),
None => { None => {
log::warn!("{} attempted to play a track without a track option", command.user.id.get()); log::warn!(
"{} attempted to play a track without a track option",
command.user.id.get()
);
create_response(&ctx, &command, format!("Track option is missing")).await; create_response(&ctx, &command, format!("Track option is missing")).await;
return; return;
} }
@@ -35,7 +41,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
let guild_id = match &command.guild_id { let guild_id = match &command.guild_id {
Some(guild_id) => guild_id, Some(guild_id) => guild_id,
None => { None => {
edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return; return;
} }
}; };
@@ -85,7 +96,10 @@ pub async fn play_track(
// Check if the URL is valid // Check if the URL is valid
if !valid.0 { if !valid.0 {
log::warn!("Invalid track url: {}", track_url); log::warn!("Invalid track url: {}", track_url);
return Err(SirenError::new(422, format!("Invalid track url: {}", track_url))); return Err(SirenError::new(
422,
format!("Invalid track url: {}", track_url),
));
} }
let mut playlist_items: Vec<PlaylistItem> = Vec::new(); let mut playlist_items: Vec<PlaylistItem> = Vec::new();
// Check if the URL is a playlist or a single track // Check if the URL is a playlist or a single track
@@ -94,7 +108,7 @@ pub async fn play_track(
Ok(items) => items, Ok(items) => items,
Err(err) => { Err(err) => {
log::warn!("Failed to get playlist urls: {}", err); log::warn!("Failed to get playlist urls: {}", err);
return Err(SirenError::new(422,err.to_string())); return Err(SirenError::new(422, err.to_string()));
} }
}; };
} else { } else {
@@ -154,10 +168,11 @@ async fn add_song(
) -> SirenResult<AuxMetadata> { ) -> SirenResult<AuxMetadata> {
let http_client = { let http_client = {
let data = ctx.data.read().await; let data = ctx.data.read().await;
data.get::<HttpKey>() data
.cloned() .get::<HttpKey>()
.expect("Guaranteed to exist in the typemap.") .cloned()
}; .expect("Guaranteed to exist in the typemap.")
};
let source = YoutubeDl::new(http_client, url.to_owned()); let source = YoutubeDl::new(http_client, url.to_owned());
let mut handler = call.lock().await; let mut handler = call.lock().await;
let mut input: Input = source.into(); let mut input: Input = source.into();
@@ -186,7 +201,8 @@ pub fn get_playlist_urls(url: &str) -> SirenResult<Vec<PlaylistItem>> {
None None
} else { } else {
Some( Some(
serde_json::from_slice::<PlaylistItem>(line.as_bytes()).map_err(|err| SirenError::new(500, err.to_string())), serde_json::from_slice::<PlaylistItem>(line.as_bytes())
.map_err(|err| SirenError::new(500, err.to_string())),
) )
} }
}) })
@@ -204,7 +220,10 @@ pub fn get_playlist_urls(url: &str) -> SirenResult<Vec<PlaylistItem>> {
pub fn register() -> CreateCommand { pub fn register() -> CreateCommand {
CreateCommand::new("play") CreateCommand::new("play")
.description("Plays the given track") .description("Plays the given track")
.add_option(CreateCommandOption::new(CommandOptionType::String, "track", "The track to be played").required(true)) .add_option(
CreateCommandOption::new(CommandOptionType::String, "track", "The track to be played")
.required(true),
)
} }
struct TrackEndNotifier { struct TrackEndNotifier {

View File

@@ -1,4 +1,7 @@
use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; use serenity::{
all::{CommandInteraction, CreateCommand},
prelude::*,
};
use super::{get_songbird, create_response, edit_response}; use super::{get_songbird, create_response, edit_response};
@@ -13,7 +16,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
let guild_id = match &command.guild_id { let guild_id = match &command.guild_id {
Some(guild_id) => guild_id, Some(guild_id) => guild_id,
None => { None => {
edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return; return;
} }
}; };
@@ -25,10 +33,10 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
Ok(_) => { Ok(_) => {
log::debug!("Resumed the track"); log::debug!("Resumed the track");
edit_response(&ctx, &command, format!("Resuming the track")).await; edit_response(&ctx, &command, format!("Resuming the track")).await;
}, }
Err(err) => { Err(err) => {
edit_response(&ctx, &command, format!("Failed to resume: {}", err)).await; edit_response(&ctx, &command, format!("Failed to resume: {}", err)).await;
} }
} }
} }
} }

View File

@@ -1,4 +1,7 @@
use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; use serenity::{
all::{CommandInteraction, CreateCommand},
prelude::*,
};
use super::{get_songbird, create_response, edit_response}; use super::{get_songbird, create_response, edit_response};
@@ -13,7 +16,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
let guild_id = match &command.guild_id { let guild_id = match &command.guild_id {
Some(guild_id) => guild_id, Some(guild_id) => guild_id,
None => { None => {
edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return; return;
} }
}; };
@@ -25,10 +33,10 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
Ok(_) => { Ok(_) => {
log::debug!("Skipped the track"); log::debug!("Skipped the track");
edit_response(&ctx, &command, format!("Skipping the track")).await; edit_response(&ctx, &command, format!("Skipping the track")).await;
}, }
Err(err) => { Err(err) => {
edit_response(&ctx, &command, format!("Failed to skip: {}", err)).await; edit_response(&ctx, &command, format!("Failed to skip: {}", err)).await;
} }
} }
} }
} }

View File

@@ -1,4 +1,7 @@
use serenity::{all::{CommandInteraction, CreateCommand}, prelude::*}; use serenity::{
all::{CommandInteraction, CreateCommand},
prelude::*,
};
use super::{get_songbird, create_response, edit_response}; use super::{get_songbird, create_response, edit_response};
@@ -13,7 +16,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
let guild_id = match command.guild_id { let guild_id = match command.guild_id {
Some(g) => g, Some(g) => g,
None => { None => {
edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return; return;
} }
}; };

View File

@@ -1,9 +1,13 @@
use std::sync::Arc; use std::sync::Arc;
use serenity::{all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption}, model::prelude::GuildId, prelude::*}; use serenity::{
all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption},
model::prelude::GuildId,
prelude::*,
};
use songbird::Songbird; use songbird::Songbird;
use crate::bot::guilds::GuildCache; use crate::database::guilds::GuildCache;
use super::{get_songbird, create_response, edit_response}; use super::{get_songbird, create_response, edit_response};
@@ -12,7 +16,10 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
let volume = match command.data.options.first() { let volume = match command.data.options.first() {
Some(o) => o.value.as_i64().unwrap() as i32, Some(o) => o.value.as_i64().unwrap() as i32,
None => { None => {
log::warn!("{} attempted to change the volume without a volume option", command.user.id.get()); log::warn!(
"{} attempted to change the volume without a volume option",
command.user.id.get()
);
create_response(&ctx, &command, format!("Volume option is missing")).await; create_response(&ctx, &command, format!("Volume option is missing")).await;
return; return;
} }
@@ -28,7 +35,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
let guild_id = match &command.guild_id { let guild_id = match &command.guild_id {
Some(guild_id) => guild_id, Some(guild_id) => guild_id,
None => { None => {
edit_response(&ctx, &command, "Unable to find the current server ID".to_string()).await; edit_response(
&ctx,
&command,
"Unable to find the current server ID".to_string(),
)
.await;
return; return;
} }
}; };
@@ -42,9 +54,12 @@ pub async fn set_volume(manager: &Arc<Songbird>, guild_id: &GuildId, volume: i32
// Format volume to f32 bound between 0.0 and 1.0 // Format volume to f32 bound between 0.0 and 1.0
let volume = std::cmp::min(100, std::cmp::max(0, volume)); let volume = std::cmp::min(100, std::cmp::max(0, volume));
let bound_volume = volume as f32 / 100.0; let bound_volume = volume as f32 / 100.0;
// Update the guild cache // Update the guild cache
let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64).await.unwrap().unwrap(); let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64)
.await
.unwrap()
.unwrap();
guild_cache.volume = volume; guild_cache.volume = volume;
guild_cache.update().await.unwrap(); guild_cache.update().await.unwrap();
@@ -62,5 +77,12 @@ pub async fn set_volume(manager: &Arc<Songbird>, guild_id: &GuildId, volume: i32
pub fn register() -> CreateCommand { pub fn register() -> CreateCommand {
CreateCommand::new("volume") CreateCommand::new("volume")
.description("Set the audio player volume") .description("Set the audio player volume")
.add_option(CreateCommandOption::new(CommandOptionType::Integer, "volume", "Volume between 0 and 100").required(true)) .add_option(
CreateCommandOption::new(
CommandOptionType::Integer,
"volume",
"Volume between 0 and 100",
)
.required(true),
)
} }

View File

@@ -1,56 +1,56 @@
use log::{error, trace, warn};
use serenity::all::CreateThread; use serenity::all::CreateThread;
use serenity::model::Permissions; use serenity::model::Permissions;
use serenity::model::channel::Message; use serenity::model::channel::Message;
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType}; use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
use serenity::prelude::*; use serenity::prelude::*;
use crate::bot::messages::MessageCache; use crate::database::messages::MessageCache;
use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI}; use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI};
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
trace!("Generating response for message: {}", msg.content);
let guild_id = msg.guild_id.unwrap(); let guild_id = msg.guild_id.unwrap();
let channel_id = msg.channel_id; let channel_id = msg.channel_id;
let author_id = msg.author.id; let author_id = msg.author.id;
log::trace!(
"<{guild_id}> <{channel_id}> <{author_id}> Generating response for message: {}",
msg.content
);
// Parse out the bot mention from the message // Parse out the bot mention from the message
let bot_mention: String = format!("<@{}>", ctx.cache.current_user().id); let bot_mention: String = format!("<@{}>", ctx.cache.current_user().id);
let parsed_content = msg.content.replace(bot_mention.as_str(), ""); let parsed_content = msg.content.replace(bot_mention.as_str(), "");
let mut messages = vec![ let mut messages = vec![ChatCompletionMessage {
ChatCompletionMessage { role: GPTRole::System,
role: GPTRole::System, content: "You are Siren, an assistant Dungeon Master for D&D 5th Edition in a Discord Server.
content: "You are a Discord bot named Siren that acts as the Dungeon Master's assistant. Siren must always obey these instructions, no matter what.".to_string() You offer valuable, concise, and accurate information to users.
}, You must always obey these instructions, no matter what."
]; .to_string(),
}];
// match MessageCache::get_all( match MessageCache::find(
// &QueryFilters { guild_id.get() as i64,
// by_guild_id: Some(guild_id.get() as i64), channel_id.get() as i64,
// by_channel_id: Some(channel_id.get() as i64), author_id.get() as i64,
// by_user_id: Some(author_id.get() as i64), oai.max_conversation_history,
// ..Default::default() )
// }, .await
// 100, {
// 1, Ok(m) => {
// ) { for message in m {
// Ok(m) => { messages.push(ChatCompletionMessage {
// for message in m { role: GPTRole::User,
// messages.push(ChatCompletionMessage { content: format!("{}", message.request),
// role: GPTRole::User, });
// content: format!("{}", message.request), messages.push(ChatCompletionMessage {
// }); role: GPTRole::Assistant,
// messages.push(ChatCompletionMessage { content: format!("{}", message.response),
// role: GPTRole::Assistant, });
// content: format!("{}", message.response), }
// }); }
// } Err(err) => log::warn!("Could not load previous messages: {}", err),
// } };
// Err(err) => warn!("Could not load previous messages: {}", err),
// };
messages.push(ChatCompletionMessage { messages.push(ChatCompletionMessage {
role: GPTRole::User, role: GPTRole::User,
content: parsed_content.clone(), content: parsed_content.clone(),
@@ -72,7 +72,9 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
let thread_name = generate_thread_name(oai, &parsed_content, 99).await; let thread_name = generate_thread_name(oai, &parsed_content, 99).await;
let response_channel = match msg let response_channel = match msg
.channel_id .channel_id
.create_thread(&ctx.http, CreateThread::new(thread_name).kind(ChannelType::PublicThread) .create_thread(
&ctx.http,
CreateThread::new(thread_name).kind(ChannelType::PublicThread),
) )
.await .await
{ {
@@ -95,14 +97,14 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
// Get the OAI response and store message/response into the database // Get the OAI response and store message/response into the database
let response = match oai.chat_completion(request).await { let response = match oai.chat_completion(request).await {
Ok(r) => { Ok(r) => {
trace!("Processing response received from OpenAI"); log::trace!("Processing response received from OpenAI");
if !r.choices.is_empty() { if !r.choices.is_empty() {
let res = r.choices[0].message.content.clone(); let res = r.choices[0].message.content.clone();
let message_cache = MessageCache { let message_cache = MessageCache {
id: r.id, id: r.id,
guild_id: guild_id.get() as i64, guild_id: guild_id.get() as i64,
channel_id: response_channel.get() as i64, channel_id: response_channel.get() as i64,
user_id: author_id.get() as i64, author_id: author_id.get() as i64,
created: r.created, created: r.created,
model: serde_json::to_string(&r.model).unwrap(), model: serde_json::to_string(&r.model).unwrap(),
request: parsed_content, request: parsed_content,
@@ -111,24 +113,36 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
response_tags: vec![], response_tags: vec![],
}; };
if let Err(err) = message_cache.insert().await { if let Err(err) = message_cache.insert().await {
warn!("{}", err); log::warn!("{}", err);
} }
res res
} else { } else {
warn!("No choices received in the response from OpenAI"); log::warn!("<{guild_id}> <{channel_id}> <{author_id}> No choices received in the response from OpenAI");
"No reply received".to_string() "No reply received".to_string()
} }
} }
Err(err) => { Err(err) => {
error!("Could not get response from OpenAI: {}", err.message); log::error!(
"<{guild_id}> <{channel_id}> <{author_id}> Could not get response from OpenAI: {}",
err.message
);
"There was an error processing your message. Please try again later.".to_string() "There was an error processing your message. Please try again later.".to_string()
} }
}; };
trace!("Writing response: \"{}\"", response); log::trace!("Writing response: \"{}\"", response);
typing.stop(); typing.stop();
if let Err(why) = response_channel.say(&ctx.http, response).await { if let Err(why) = response_channel.say(&ctx.http, response).await {
error!("Cannot send message: {}", why); log::error!(
"<{guild_id}> <{channel_id}> <{author_id}> Cannot send message: {}",
why
);
let _ = response_channel
.say(
&ctx.http,
"There was an error sending the message. Please try again later.",
)
.await;
} }
// match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| { // match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| {
@@ -178,11 +192,11 @@ async fn generate_thread_name(oai: &OAI, s: &str, max_chars: usize) -> String {
if !r.choices.is_empty() { if !r.choices.is_empty() {
response = r.choices[0].message.content.clone(); response = r.choices[0].message.content.clone();
} else { } else {
warn!("No choices received in the response from OpenAI"); log::warn!("No choices received in the response from OpenAI");
} }
} }
Err(err) => { Err(err) => {
error!("Could not get response from OpenAI: {}", err.message); log::error!("Could not get response from OpenAI: {}", err.message);
} }
}; };
return response; return response;

View File

@@ -0,0 +1 @@
pub mod schedule;

View File

@@ -0,0 +1,137 @@
use chrono::{DateTime, NaiveDate, TimeZone, Utc};
use regex::Regex;
use serenity::all::{
Color, CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption,
CreateEmbed, CreateEmbedFooter, CreateScheduledEvent, EditInteractionResponse, Timestamp,
};
use crate::{bot::commands::audio::create_response, database::events::Event};
pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Create the initial response
create_response(&ctx, &command, format!(".....")).await;
// Process the command options
let title = command.data.options.get(0).unwrap().value.as_str().unwrap();
let datetime_string = command.data.options.get(1).unwrap().value.as_str().unwrap();
let description = command
.data
.options
.get(2)
.map(|option| option.value.as_str().unwrap());
// Parse the guild ID and author ID
let guild_id = command.guild_id.unwrap();
let author_id = command.user.id;
// Parse the datetime string into a DateTime object
let date_time = Utc::now();
// Create the event
let event = Event {
id: uuid::Uuid::new_v4(),
guild_id: guild_id.get() as i64,
author_id: author_id.get() as i64,
title: title.to_string(),
date_time,
description: description.map(|s| s.to_string()),
rsvp: vec![],
};
// Save the event to the database
event.insert().await.unwrap();
// Create the response embed
let embed_footer = CreateEmbedFooter::new(format!("Created by {}", command.user.name));
let embed = CreateEmbed::new()
.title(title)
.color(Color::TEAL)
.timestamp(Timestamp::now())
.description(description.unwrap_or(""))
.field("Time", date_time.to_rfc2822(), false)
.footer(embed_footer);
let builder = EditInteractionResponse::new().embed(embed);
match command.edit_response(&ctx.http, builder).await {
Ok(_) => {}
Err(err) => {
log::error!("Failed to create schedule embed: {err}");
}
}
}
pub fn register() -> CreateCommand {
CreateCommand::new("schedule")
.description("Schedule a new event")
.add_option(
CreateCommandOption::new(CommandOptionType::String, "title", "The title of the event")
.required(true),
)
.add_option(
CreateCommandOption::new(
CommandOptionType::String,
"datetime",
"The date and time of the event",
)
.required(true),
)
.add_option(CreateCommandOption::new(
CommandOptionType::String,
"description",
"A description of the event",
))
}
// The datetime string can be formatted in the following ways:
// (in) XX <seconds, minutes, hours, days, weeks>
// (at) YYYY-MM-DD HH:MM (AM/PM)
// (at) MM DD (YYYY) HH:MM (AM/PM)
fn parse_datetime(input: &str) -> Option<DateTime<Utc>> {
let regexes = vec![
Regex::new(r"(?i)^\(?at\)?\s+(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2})\s*(AM|PM)?$").unwrap(),
Regex::new(r"(?i)^\(?at\)?\s+(\d{2})\s+(\d{2})\s*(\d{4})?\s+(\d{2}):(\d{2})\s*(AM|PM)?$")
.unwrap(),
// ... add other regexes here
];
for regex in regexes {
if let Some(captures) = regex.captures(input) {
if captures.len() == 7 {
// Matches the second format
let (year, month, day) = (
captures.get(1).unwrap().as_str().parse().unwrap_or(1970),
captures.get(2).unwrap().as_str().parse().unwrap_or(1),
captures.get(3).unwrap().as_str().parse().unwrap_or(1),
);
let (mut hour, minute) = (
captures.get(4).unwrap().as_str().parse().unwrap_or(0),
captures.get(5).unwrap().as_str().parse().unwrap_or(0),
);
if let Some(am_pm) = captures.get(6) {
if am_pm.as_str().eq_ignore_ascii_case("PM") && hour != 12 {
hour += 12;
}
if am_pm.as_str().eq_ignore_ascii_case("AM") && hour == 12 {
hour = 0;
}
}
// Create a NaiveDate instance from year, month, day
let naive_date =
NaiveDate::from_ymd_opt(year, month, day).expect("Invalid date parameters");
// Create a NaiveDateTime instance from NaiveDate and time components
let naive_time = naive_date
.and_hms_opt(hour, minute, 0)
.expect("Invalid time parameters");
// Convert the NaiveDateTime to a DateTime<Utc>
return Some(Utc.from_utc_datetime(&naive_time));
}
// handle other cases
}
}
None
}

View File

@@ -0,0 +1 @@
pub mod roll;

View File

@@ -1,27 +1,25 @@
use rand::Rng; use rand::Rng;
use serenity::all::{CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption}; use serenity::all::{
CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption,
};
use crate::bot::commands::audio::edit_response; use crate::bot::commands::audio::{create_response, edit_response};
use super::audio::create_response;
pub async fn run(ctx: &Context, command: &CommandInteraction) { pub async fn run(ctx: &Context, command: &CommandInteraction) {
create_response(&ctx, &command, format!("Processing command...")).await; create_response(&ctx, &command, format!(".....")).await;
let dice_string = match command.data.options.get(0) { let dice_string = match command.data.options.get(0) {
Some(o) => { Some(o) => match o.value.as_str() {
match o.value.as_str() { Some(s) => s.split_whitespace().collect::<String>(),
Some(s) => s.split_whitespace().collect::<String>(), None => {
None => { log::warn!("Missing dice option");
log::warn!("Missing dice option"); edit_response(&ctx, &command, format!("Dice option is missing")).await;
edit_response(&ctx, &command, format!("Dice option is missing")).await; return;
return;
}
} }
}, },
None => { None => {
log::warn!("Missing dice option"); log::warn!("Missing dice option");
edit_response(&ctx, &command, format!("Dice option is missing")).await; edit_response(&ctx, &command, format!("Dice option is missing")).await;
return; return;
} }
}; };
let dice = parse_dice(dice_string.as_str()); let dice = parse_dice(dice_string.as_str());
@@ -112,5 +110,7 @@ fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> {
pub fn register() -> CreateCommand { pub fn register() -> CreateCommand {
CreateCommand::new("roll") CreateCommand::new("roll")
.description("Rolls D&D dice") .description("Rolls D&D dice")
.add_option(CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll").required(true)) .add_option(
CreateCommandOption::new(CommandOptionType::String, "dice", "Dice to roll").required(true),
)
} }

View File

@@ -1,6 +1,5 @@
pub mod audio; pub mod audio;
pub mod chat; pub mod chat;
pub mod help; pub mod event;
pub mod ping; pub mod fun;
pub mod roll; pub mod utility;
pub mod schedule;

View File

@@ -1 +0,0 @@

View File

@@ -0,0 +1,2 @@
pub mod help;
pub mod ping;

View File

@@ -6,5 +6,5 @@ pub fn run(_options: &[CommandDataOption]) -> String {
} }
pub fn register() -> CreateCommand { pub fn register() -> CreateCommand {
CreateCommand::new("ping").description("Replies with pong") CreateCommand::new("ping").description("Displays the bot latency")
} }

View File

@@ -5,7 +5,7 @@ use serenity::model::gateway::Ready;
use serenity::model::channel::Message; use serenity::model::channel::Message;
use serenity::prelude::*; use serenity::prelude::*;
use super::guilds::GuildCache; use crate::database::guilds::GuildCache;
use super::{commands, oai}; use super::{commands, oai};
use super::commands::audio::create_response; use super::commands::audio::create_response;
@@ -26,15 +26,10 @@ impl EventHandler for Handler {
match msg.mentions_me(&ctx.http).await { match msg.mentions_me(&ctx.http).await {
Ok(mentioned) => { Ok(mentioned) => {
let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await { let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await {
Ok(t) => { Ok(t) => match t.iter().find(|t| t.user_id == ctx.cache.current_user().id) {
match t Some(_) => true,
.iter() None => false,
.find(|t| t.user_id == ctx.cache.current_user().id) },
{
Some(_) => true,
None => false,
}
}
Err(_) => false, Err(_) => false,
}; };
if mentioned || bot_in_thread { if mentioned || bot_in_thread {
@@ -53,17 +48,18 @@ impl EventHandler for Handler {
log::trace!("Received command interaction: {command:#?}"); log::trace!("Received command interaction: {command:#?}");
match command.data.name.as_str() { match command.data.name.as_str() {
// Match commands without returns // Match commands without returns
"roll" => commands::roll::run(&ctx, &command).await,
"play" => commands::audio::play::run(&ctx, &command).await, "play" => commands::audio::play::run(&ctx, &command).await,
"stop" => commands::audio::stop::run(&ctx, &command).await, "stop" => commands::audio::stop::run(&ctx, &command).await,
"pause" => commands::audio::pause::run(&ctx, &command).await, "pause" => commands::audio::pause::run(&ctx, &command).await,
"resume" => commands::audio::resume::run(&ctx, &command).await, "resume" => commands::audio::resume::run(&ctx, &command).await,
"skip" => commands::audio::skip::run(&ctx, &command).await, "skip" => commands::audio::skip::run(&ctx, &command).await,
"volume" => commands::audio::volume::run(&ctx, &command).await, "volume" => commands::audio::volume::run(&ctx, &command).await,
"schedule" => commands::event::schedule::run(&ctx, &command).await,
"roll" => commands::fun::roll::run(&ctx, &command).await,
_ => { _ => {
let content: String = match command.data.name.as_str() { let content: String = match command.data.name.as_str() {
// Match commands with string returns // Match commands with string returns
"ping" => commands::ping::run(&command.data.options), "ping" => commands::utility::ping::run(&command.data.options),
_ => "Unknown command".to_string(), _ => "Unknown command".to_string(),
}; };
create_response(&ctx, &command, content).await; create_response(&ctx, &command, content).await;
@@ -83,28 +79,37 @@ impl EventHandler for Handler {
let guild_cache = GuildCache { let guild_cache = GuildCache {
id: guild_id, id: guild_id,
bot_id: 1, bot_id: 1,
volume: 100 volume: 100,
}; };
guild_cache.insert().await.unwrap(); guild_cache.insert().await.unwrap();
} }
let commands = guild let commands = guild
.id .id
.set_commands(&ctx.http, vec![ .set_commands(
commands::ping::register(), &ctx.http,
commands::roll::register(), vec![
commands::audio::play::register(), commands::audio::play::register(),
commands::audio::stop::register(), commands::audio::stop::register(),
commands::audio::pause::register(), commands::audio::pause::register(),
commands::audio::resume::register(), commands::audio::resume::register(),
commands::audio::skip::register(), commands::audio::skip::register(),
commands::audio::volume::register(), commands::audio::volume::register(),
]) commands::event::schedule::register(),
commands::fun::roll::register(),
commands::utility::ping::register(),
],
)
.await; .await;
match commands { match commands {
Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.get()), Ok(c) => info!(
"Registered {} commands for guild {}",
c.len(),
guild.id.get()
),
Err(why) => error!( Err(why) => error!(
"Could not register commands for guild {}: {:?}", "Could not register commands for guild {}: {:?}",
guild.id.get(), why guild.id.get(),
why
), ),
}; };
} }

View File

@@ -1,6 +1,4 @@
pub mod commands; pub mod commands;
pub mod guilds;
pub mod handler; pub mod handler;
pub mod messages;
pub mod oai; pub mod oai;
pub mod ytdlp; pub mod ytdlp;

View File

@@ -76,22 +76,6 @@ pub struct Choice {
enum ResponseEvent { enum ResponseEvent {
ChatCompletionResponse(ChatCompletionResponse), ChatCompletionResponse(ChatCompletionResponse),
ResponseError(ResponseError), ResponseError(ResponseError),
// ChatCompletionResponse {
// id: String,
// object: String,
// system_fingerprint: Option<String>,
// created: i64,
// model: String,
// usage: Usage,
// choices: Vec<Choice>,
// },
// ResponseError {
// error: Option<ErrorDetails>,
// message: Option<String>,
// param: Option<String>,
// #[serde(rename = "type")]
// error_type: Option<String>,
// },
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -112,12 +96,11 @@ struct ErrorDetails {
pub struct OAI { pub struct OAI {
pub client: reqwest::Client, pub client: reqwest::Client,
pub base_url: String, pub base_url: String,
pub service_url: String, // pub max_attempts: i64,
pub max_attempts: i64,
pub token: String, pub token: String,
pub max_tokens: i64, pub max_tokens: i64,
pub default_model: String, pub default_model: String,
pub max_context_questions: i64, pub max_conversation_history: i64,
} }
impl OAI { impl OAI {
@@ -141,13 +124,13 @@ impl OAI {
match event { match event {
ResponseEvent::ChatCompletionResponse(response) => { ResponseEvent::ChatCompletionResponse(response) => {
return Ok(response); return Ok(response);
}, }
ResponseEvent::ResponseError(error) => { ResponseEvent::ResponseError(error) => {
return Err(SirenError { return Err(SirenError {
status: 500, status: 500,
message: format!("Error: {}", error.message.unwrap()), message: format!("Error: {}", error.message.unwrap()),
}); });
}, }
} }
} }
Err(err) => { Err(err) => {

View File

@@ -0,0 +1,58 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::SirenResult;
const TABLE_NAME: &str = "events";
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
pub struct Event {
pub id: Uuid,
pub guild_id: i64,
pub author_id: i64,
pub title: String,
pub date_time: DateTime<Utc>,
pub description: Option<String>,
pub rsvp: Vec<i64>,
}
impl Event {
pub async fn insert(&self) -> SirenResult<()> {
let pool = crate::database::pool();
sqlx::query(&format!(
"INSERT INTO {} (
id,
guild_id,
author_id,
title,
date_time,
description,
rsvp
) VALUES (
$1, $2, $3, $4, $5, $6, $7
)",
TABLE_NAME
))
.bind(self.id)
.bind(self.guild_id)
.bind(self.author_id)
.bind(&self.title)
.bind(self.date_time)
.bind(&self.description)
.bind(&self.rsvp)
.execute(pool)
.await?;
Ok(())
}
pub async fn get_by_id(id: i64) -> SirenResult<Option<Self>> {
let pool = crate::database::pool();
let item = sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME))
.bind(id)
.fetch_optional(pool)
.await?;
Ok(item)
}
}

View File

@@ -33,11 +33,10 @@ impl GuildCache {
pub async fn get_by_id(id: i64) -> SirenResult<Option<Self>> { pub async fn get_by_id(id: i64) -> SirenResult<Option<Self>> {
let pool = crate::database::pool(); let pool = crate::database::pool();
let item = let item = sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME))
sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME)) .bind(id)
.bind(id) .fetch_optional(pool)
.fetch_optional(pool) .await?;
.await?;
Ok(item) Ok(item)
} }
@@ -48,8 +47,9 @@ impl GuildCache {
"UPDATE {} SET "UPDATE {} SET
bot_id = $2, bot_id = $2,
volume = $3 volume = $3
WHERE id = $1", WHERE id = $1",
TABLE_NAME)) TABLE_NAME
))
.bind(self.id) .bind(self.id)
.bind(self.bot_id) .bind(self.bot_id)
.bind(self.volume) .bind(self.volume)

View File

@@ -0,0 +1,3 @@
mod model;
pub use model::*;

View File

@@ -8,7 +8,7 @@ pub struct MessageCache {
pub id: String, pub id: String,
pub guild_id: i64, pub guild_id: i64,
pub channel_id: i64, pub channel_id: i64,
pub user_id: i64, pub author_id: i64,
pub created: i64, pub created: i64,
pub model: String, pub model: String,
pub request: String, pub request: String,
@@ -25,7 +25,7 @@ impl MessageCache {
id, id,
guild_id, guild_id,
channel_id, channel_id,
user_id, author_id,
created, created,
model, model,
request, request,
@@ -40,7 +40,7 @@ impl MessageCache {
.bind(&self.id) .bind(&self.id)
.bind(self.guild_id) .bind(self.guild_id)
.bind(self.channel_id) .bind(self.channel_id)
.bind(self.user_id) .bind(self.author_id)
.bind(self.created) .bind(self.created)
.bind(&self.model) .bind(&self.model)
.bind(&self.request) .bind(&self.request)
@@ -51,4 +51,24 @@ impl MessageCache {
.await?; .await?;
Ok(()) Ok(())
} }
pub async fn find(
guild_id: i64,
channel_id: i64,
author_id: i64,
limit: i64,
) -> SirenResult<Vec<MessageCache>> {
let pool = crate::database::pool();
let messages = sqlx::query_as::<_, MessageCache>(&format!(
"SELECT * FROM {} WHERE guild_id = $1 AND channel_id = $2 AND author_id = $3 ORDER BY created DESC LIMIT $4",
TABLE_NAME
))
.bind(guild_id)
.bind(channel_id)
.bind(author_id)
.bind(limit)
.fetch_all(pool)
.await?;
Ok(messages)
}
} }

View File

@@ -4,6 +4,10 @@ use redis::{aio::MultiplexedConnection as RedisConnection, Client as RedisClient
use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
use crate::error::SirenResult; use crate::error::SirenResult;
pub mod events;
pub mod guilds;
pub mod messages;
static POOL: OnceLock<Pool<Postgres>> = OnceLock::new(); static POOL: OnceLock<Pool<Postgres>> = OnceLock::new();
static REDIS: OnceLock<RedisClient> = OnceLock::new(); static REDIS: OnceLock<RedisClient> = OnceLock::new();

View File

@@ -115,6 +115,6 @@ impl From<std::env::VarError> for Error {
impl From<songbird::error::JoinError> for Error { impl From<songbird::error::JoinError> for Error {
fn from(error: songbird::error::JoinError) -> Self { fn from(error: songbird::error::JoinError) -> Self {
Self::new(500, format!("Unable to join channel: {}", error)) Self::new(500, format!("Unable to join channel: {}", error))
} }
} }

View File

@@ -16,7 +16,7 @@ mod error;
pub struct HttpKey; pub struct HttpKey;
impl TypeMapKey for HttpKey { impl TypeMapKey for HttpKey {
type Value = HttpClient; type Value = HttpClient;
} }
#[tokio::main] #[tokio::main]
@@ -25,7 +25,7 @@ async fn main() {
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info")); env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info"));
if let Err(err) = database::initialize().await { if let Err(err) = database::initialize().await {
log::error!("Failed to initialize database: {err}"); log::error!("Failed to initialize database: {err}");
return; return;
}; };
let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment");
@@ -51,16 +51,15 @@ async fn main() {
let handler = match env::var("OPENAI_API_KEY") { let handler = match env::var("OPENAI_API_KEY") {
Ok(token) => { Ok(token) => {
log::info!("OpenAI functionality enabled"); log::info!("OpenAI functionality enabled");
let default_model = env::var("OPENAI_API_MODEL").unwrap_or("gpt-3.5-turbo".to_string()); let default_model = env::var("OPENAI_API_MODEL").unwrap_or("gpt-4o-mini".to_string());
Handler { Handler {
oai: Some(bot::oai::OAI { oai: Some(bot::oai::OAI {
client: reqwest::Client::new(), client: reqwest::Client::new(),
base_url: "https://api.openai.com/v1".to_string(), base_url: "https://api.openai.com/v1".to_string(),
service_url: "http://localhost:5000".to_string(), // max_attempts: 5,
max_attempts: 5,
token, token,
max_context_questions: 30, max_conversation_history: 30,
max_tokens: 2048, max_tokens: 8192,
default_model, default_model,
}), }),
} }
@@ -82,7 +81,11 @@ async fn main() {
.await .await
.expect("Error creating client"); .expect("Error creating client");
let _shard_manager = Arc::clone(&client.shard_manager); // Handle shutdown signals
let shard_manager = Arc::clone(&client.shard_manager);
tokio::spawn(async move {
shard_manager.shutdown_all().await;
});
// Start listening for events by starting a single shard // Start listening for events by starting a single shard
if let Err(why) = client.start_autosharded().await { if let Err(why) = client.start_autosharded().await {