191 lines
5.7 KiB
Rust
191 lines
5.7 KiB
Rust
use std::env;
|
|
use std::sync::{Arc, OnceLock};
|
|
use serenity::all::{Interaction, ResumedEvent};
|
|
use serenity::async_trait;
|
|
use serenity::model::gateway::Ready;
|
|
use serenity::model::channel::Message;
|
|
use serenity::prelude::*;
|
|
use songbird::Songbird;
|
|
use crate::bot::commands::chat::generate_response;
|
|
use crate::bot::oai::OAI;
|
|
use crate::data::guilds::GuildCache;
|
|
use crate::HttpKey;
|
|
use super::{commands};
|
|
use super::chat::{create_modal_response};
|
|
|
|
pub struct BotHandler {
|
|
// Open AI Config
|
|
pub oai: Option<OAI>,
|
|
}
|
|
|
|
static SONGBIRD: OnceLock<Arc<Songbird>> = OnceLock::new();
|
|
static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
|
|
|
|
pub fn get_songbird() -> &'static Arc<Songbird> {
|
|
SONGBIRD.get().unwrap()
|
|
}
|
|
|
|
pub fn get_client() -> &'static reqwest::Client {
|
|
CLIENT.get().unwrap()
|
|
}
|
|
|
|
impl BotHandler {
|
|
pub fn new() -> Self {
|
|
match env::var("OPENAI_TOKEN") {
|
|
Ok(token) => {
|
|
log::debug!("OpenAI functionality enabled");
|
|
let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string());
|
|
let base_url = env::var("OPENAI_BASE_URL").unwrap();
|
|
Self {
|
|
oai: Some(OAI {
|
|
client: reqwest::Client::new(),
|
|
base_url,
|
|
token,
|
|
max_conversation_history: 30,
|
|
max_tokens: 8192,
|
|
default_model,
|
|
}),
|
|
}
|
|
}
|
|
Err(_) => {
|
|
log::warn!("OpenAI functionality disabled");
|
|
Self { oai: None }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl EventHandler for BotHandler {
|
|
async fn message(&self, ctx: Context, msg: Message) {
|
|
// Ignore bot messages
|
|
if msg.author.bot {
|
|
return;
|
|
}
|
|
|
|
// Handle direct messages
|
|
if let None = msg.guild_id {
|
|
log::trace!("Received DM from {}: {}", msg.author, msg.content);
|
|
}
|
|
|
|
// Handle OAI messages
|
|
match &self.oai {
|
|
Some(oai) => {
|
|
handle_oai_messages(oai, &ctx, &msg).await;
|
|
}
|
|
None => {}
|
|
}
|
|
}
|
|
|
|
async fn ready(&self, ctx: Context, ready: Ready) {
|
|
if ready.guilds.is_empty() {
|
|
log::warn!("No ready guilds found");
|
|
}
|
|
|
|
let songbird = songbird::get(&ctx).await.unwrap();
|
|
SONGBIRD
|
|
.set(songbird.clone())
|
|
.expect("Songbird value could not be set");
|
|
let http_client = {
|
|
let data = ctx.data.read().await;
|
|
data
|
|
.get::<HttpKey>()
|
|
.cloned()
|
|
.expect("Guaranteed to exist in the typemap.")
|
|
};
|
|
CLIENT.set(http_client).ok();
|
|
|
|
log::trace!("Handling {} guilds", ready.guilds.len());
|
|
for guild in ready.guilds {
|
|
// Check if guild exists in database
|
|
let guild_id = guild.id.get() as i64;
|
|
if let None = GuildCache::get_by_id(guild_id).await.unwrap() {
|
|
let guild_cache = GuildCache {
|
|
id: guild_id,
|
|
name: guild.id.name(&ctx.cache),
|
|
owner_id: None,
|
|
volume: 100,
|
|
};
|
|
guild_cache.insert().await.unwrap();
|
|
}
|
|
let commands = guild
|
|
.id
|
|
.set_commands(
|
|
&ctx.http,
|
|
vec![
|
|
commands::audio::play::register(),
|
|
commands::audio::stop::register(),
|
|
commands::audio::pause::register(),
|
|
commands::audio::resume::register(),
|
|
commands::audio::mute::register(),
|
|
commands::audio::skip::register(),
|
|
commands::audio::volume::register(),
|
|
commands::event::schedule::register(),
|
|
commands::fun::roll::register(),
|
|
commands::utility::ping::register(),
|
|
],
|
|
)
|
|
.await;
|
|
match commands {
|
|
Ok(c) => log::info!(
|
|
"Registered {} commands for guild {}",
|
|
c.len(),
|
|
guild.id.get(),
|
|
),
|
|
Err(why) => log::error!(
|
|
"Could not register commands for guild {}: {:?}",
|
|
guild.id.get(),
|
|
why
|
|
),
|
|
};
|
|
}
|
|
}
|
|
|
|
async fn resume(&self, _: Context, _: ResumedEvent) {
|
|
log::debug!("Resumed");
|
|
}
|
|
|
|
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
|
|
if let Interaction::Ping(ping) = interaction {
|
|
log::trace!("Received interaction ping: {:?}", ping);
|
|
} else if let Interaction::Command(command) = interaction {
|
|
log::trace!("Received command interaction: {command:#?}");
|
|
match command.data.name.as_str() {
|
|
// Match commands without returns
|
|
"play" => commands::audio::play::run(&ctx, &command).await,
|
|
"stop" => commands::audio::stop::run(&ctx, &command).await,
|
|
"pause" => commands::audio::pause::run(&ctx, &command).await,
|
|
"resume" => commands::audio::resume::run(&ctx, &command).await,
|
|
"mute" => commands::audio::mute::run(&ctx, &command).await,
|
|
"skip" => commands::audio::skip::run(&ctx, &command).await,
|
|
"volume" => commands::audio::volume::run(&ctx, &command).await,
|
|
"schedule" => commands::event::schedule::run(&ctx, &command).await,
|
|
"roll" => commands::fun::roll::run(&ctx, &command).await,
|
|
"ping" => commands::utility::ping::run(&ctx, &command).await,
|
|
_ => {}
|
|
}
|
|
} else if let Interaction::Modal(modal) = interaction {
|
|
log::trace!("Received interaction modal: {:?}", modal);
|
|
create_modal_response(&ctx, &modal).await;
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_oai_messages(oai: &OAI, ctx: &Context, msg: &Message) {
|
|
match msg.mentions_me(&ctx.http).await {
|
|
Ok(mentioned) => {
|
|
let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await {
|
|
Ok(t) => match t.iter().find(|t| t.user_id == ctx.cache.current_user().id) {
|
|
Some(_) => true,
|
|
None => false,
|
|
},
|
|
Err(_) => false,
|
|
};
|
|
if mentioned || bot_in_thread {
|
|
generate_response(&ctx, &msg, oai).await;
|
|
}
|
|
}
|
|
Err(why) => log::warn!("Could not check mentions: {why}"),
|
|
};
|
|
}
|