Files
siren/src/bot/handler.rs

267 lines
8.6 KiB
Rust

use std::env;
use std::sync::{Arc, OnceLock};
use serenity::all::{
CreateEmbed, CreateInteractionResponse, CreateInteractionResponseMessage,
EditInteractionResponse, Interaction, ResumedEvent, UserId,
};
use serenity::async_trait;
use serenity::model::gateway::Ready;
use serenity::model::channel::Message;
use serenity::prelude::*;
use songbird::Songbird;
use crate::bot::commands::chat::generate_response;
use crate::bot::commands::fun::roll::{format_roll, roll_dice, send_roll_message};
use crate::bot::oai::OAI;
use crate::data::guilds::GuildCache;
use crate::HttpKey;
use crate::utils::{a_or_an, number_to_words};
use super::{commands};
use super::chat::{create_modal_response, user_dm};
/// Event handler state for the bot.
///
/// This type implements serenity's `EventHandler` further down in this file
/// and is built through `BotHandler::new`.
pub struct BotHandler {
/// OpenAI configuration. `BotHandler::new` fills this in only when the
/// `OPENAI_TOKEN` environment variable is present; when it is `None` the
/// `message` handler skips the OpenAI reply path entirely.
pub oai: Option<OAI>,
}
/// Set by the `ready` handler once it has started registering guild commands,
/// so that later `ready` events return early instead of registering again.
static REGISTERED: OnceLock<bool> = OnceLock::new();
/// Songbird instance obtained from the serenity context in the `ready`
/// handler; read through `get_songbird`.
static SONGBIRD: OnceLock<Arc<Songbird>> = OnceLock::new();
/// HTTP client cloned out of the context typemap (`HttpKey`) in the `ready`
/// handler; read through `get_client`.
static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
/// Returns the process-wide Songbird instance stored in `SONGBIRD`.
///
/// # Panics
///
/// Panics if called before the `ready` event handler has stored the
/// instance, i.e. before the bot has received its first `Ready` event.
pub fn get_songbird() -> &'static Arc<Songbird> {
    SONGBIRD
        .get()
        .expect("Songbird requested before the `ready` handler initialized it")
}
/// Returns the process-wide `reqwest::Client` stored in `CLIENT`.
///
/// # Panics
///
/// Panics if called before the `ready` event handler has stored the
/// client, i.e. before the bot has received its first `Ready` event.
pub fn get_client() -> &'static reqwest::Client {
    CLIENT
        .get()
        .expect("HTTP client requested before the `ready` handler initialized it")
}
impl BotHandler {
pub fn new() -> Self {
match env::var("OPENAI_TOKEN") {
Ok(token) => {
log::debug!("OpenAI functionality enabled");
let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string());
let base_url = env::var("OPENAI_BASE_URL").unwrap();
Self {
oai: Some(OAI {
client: reqwest::Client::new(),
base_url,
token,
max_conversation_history: 30,
max_tokens: 8192,
default_model,
}),
}
}
Err(_) => {
log::warn!("OpenAI functionality disabled");
Self { oai: None }
}
}
}
}
#[async_trait]
impl EventHandler for BotHandler {
async fn message(&self, ctx: Context, msg: Message) {
// Ignore bot messages
if msg.author.bot {
return;
}
// Handle direct messages
if let None = msg.guild_id {
log::trace!("Received DM from {}: {}", msg.author, msg.content);
}
// Handle OAI messages
match &self.oai {
Some(oai) => {
handle_oai_messages(oai, &ctx, &msg).await;
}
None => {}
}
}
async fn ready(&self, ctx: Context, ready: Ready) {
if ready.guilds.is_empty() {
log::warn!("No ready guilds found");
}
if SONGBIRD.get().is_none() {
let songbird = songbird::get(&ctx).await.unwrap();
SONGBIRD
.set(songbird.clone())
.expect("Songbird value could not be set");
}
if CLIENT.get().is_none() {
let http_client = {
let data = ctx.data.read().await;
data
.get::<HttpKey>()
.cloned()
.expect("Guaranteed to exist in the typemap.")
};
CLIENT.set(http_client).ok();
}
// Update registered to prevent reloading the commands
if REGISTERED.get().is_some() {
return;
} else {
REGISTERED.set(true).ok();
}
log::trace!("Handling {} guilds", ready.guilds.len());
for guild in ready.guilds {
// Check if guild exists in database
let guild_id = guild.id.get() as i64;
if let None = GuildCache::find_by_id(guild_id).await {
let guild_cache = GuildCache {
id: guild_id,
name: guild.id.name(&ctx.cache),
owner_id: None,
volume: 100,
};
guild_cache.insert().await.unwrap();
}
let commands = guild
.id
.set_commands(
&ctx.http,
vec![
commands::audio::play::register(),
commands::audio::stop::register(),
commands::audio::pause::register(),
commands::audio::resume::register(),
commands::audio::mute::register(),
commands::audio::skip::register(),
commands::audio::volume::register(),
commands::event::schedule::register(),
commands::fun::roll::register(),
commands::fun::request_roll::register(),
commands::utility::ping::register(),
],
)
.await;
match commands {
Ok(c) => log::info!(
"Registered {} commands for guild {}",
c.len(),
guild.id.get(),
),
Err(why) => log::error!(
"Could not register commands for guild {}: {:?}",
guild.id.get(),
why
),
};
}
}
async fn resume(&self, _: Context, _: ResumedEvent) {
log::trace!("Resumed");
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
if let Interaction::Command(command) = interaction {
log::trace!("Received COMMAND");
match command.data.name.as_str() {
// Match commands without returns
"play" => commands::audio::play::run(&ctx, &command).await,
"stop" => commands::audio::stop::run(&ctx, &command).await,
"pause" => commands::audio::pause::run(&ctx, &command).await,
"resume" => commands::audio::resume::run(&ctx, &command).await,
"mute" => commands::audio::mute::run(&ctx, &command).await,
"skip" => commands::audio::skip::run(&ctx, &command).await,
"volume" => commands::audio::volume::run(&ctx, &command).await,
"schedule" => commands::event::schedule::run(&ctx, &command).await,
"roll" => commands::fun::roll::run(&ctx, &command).await,
"requestroll" => commands::fun::request_roll::run(&ctx, &command).await,
"ping" => commands::utility::ping::run(&ctx, &command).await,
_ => {}
}
} else if let Interaction::Component(component) = interaction {
log::trace!("Received COMPONENT");
let custom_id = &component.data.custom_id;
if custom_id.starts_with("request_dice_roll") {
// Acknowledge the interaction
if let Err(err) = component
.create_response(ctx.http.clone(), CreateInteractionResponse::Acknowledge)
.await
{
log::error!("Could not create dice response: {}", err);
};
let parts = custom_id.split('|').collect::<Vec<&str>>();
if parts.len() == 6 {
let count = parts[1].parse().unwrap();
let sides = parts[2].parse().unwrap();
let modifier = parts[3].parse().unwrap();
let result = roll_dice(count, sides, modifier);
let response = format!("(Rolled {})", format_roll(count, sides, modifier));
let user_id = UserId::from(parts[4].parse::<u64>().unwrap());
let roller_id = component.user.id;
let hidden: bool = parts[5].parse().unwrap();
// Prepare the message based on visibility
let new_message = if hidden {
// For hidden rolls, only reveal "results sent" to the requester
format!("🎲 Results sent to {}\n-# {}", user_id.mention(), response)
} else {
// For public rolls, show the roll result
format!(
"🎲 You rolled {} {}\n-# {}",
a_or_an(&number_to_words(result)),
result,
response
)
};
// Edit the message to update the text and remove buttons
if let Err(err) = component
.edit_response(
ctx.http.clone(),
EditInteractionResponse::new()
.content(new_message)
.components(Vec::new()),
)
.await
{
log::error!("Could not update dice roll message: {}", err);
}
// Send message to the requester
send_roll_message(&ctx, result, user_id, roller_id, &response).await;
} else {
log::error!("Could not handle dice click: {}", custom_id);
}
}
} else if let Interaction::Ping(_ping) = interaction {
log::trace!("Received PING");
} else if let Interaction::Autocomplete(_autocomplete) = interaction {
log::trace!("Received AUTOCOMPLETE");
} else if let Interaction::Modal(modal) = interaction {
log::trace!("Received MODAL");
create_modal_response(&ctx, &modal).await;
}
}
}
/// Generates an OpenAI reply when the bot is mentioned in `msg`, or when the
/// bot is a member of the thread the message was posted in.
///
/// If the mention check itself fails, a warning is logged and no reply is
/// generated. A failed thread-member lookup is treated as "not in thread".
async fn handle_oai_messages(oai: &OAI, ctx: &Context, msg: &Message) {
    let mentioned = match msg.mentions_me(&ctx.http).await {
        Ok(mentioned) => mentioned,
        Err(why) => {
            log::warn!("Could not check mentions: {why}");
            return;
        }
    };

    // The thread-member lookup is an extra HTTP request, so it is only made
    // when the mention check has not already decided the outcome.
    let should_respond = mentioned || {
        // Copy the id out first so no cache reference is held across the await.
        let bot_id = ctx.cache.current_user().id;
        match msg.channel_id.get_thread_members(&ctx.http).await {
            Ok(members) => members.iter().any(|member| member.user_id == bot_id),
            Err(_) => false,
        }
    };

    if should_respond {
        generate_response(ctx, msg, oai).await;
    }
}