diff --git a/.vscode/settings.json b/.vscode/settings.json index 92e36aa..fbc8bad 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,6 @@ { "rust-analyzer.linkedProjects": [ "./service/Cargo.toml" - ] + ], + "rust-analyzer.showUnlinkedFileNotification": false } \ No newline at end of file diff --git a/service/Makefile b/service/Makefile index 895c6dc..b7fecae 100644 --- a/service/Makefile +++ b/service/Makefile @@ -13,8 +13,9 @@ help: ## Help command build: ## Build the docker image docker compose build -db: ## Start the docker database +utils: ## Start the utils docker compose up -d db + docker compose up -d rbmq up: ## Start the app docker compose up -d diff --git a/service/src/bot/api/mod.rs b/service/src/bot/api/mod.rs new file mode 100644 index 0000000..6666fdc --- /dev/null +++ b/service/src/bot/api/mod.rs @@ -0,0 +1,5 @@ +mod model; +mod routes; + +pub use model::*; +pub use routes::init_routes; \ No newline at end of file diff --git a/service/src/bot/api/model.rs b/service/src/bot/api/model.rs new file mode 100644 index 0000000..e69de29 diff --git a/service/src/bot/api/routes.rs b/service/src/bot/api/routes.rs new file mode 100644 index 0000000..2f83df0 --- /dev/null +++ b/service/src/bot/api/routes.rs @@ -0,0 +1,118 @@ +use std::sync::Arc; + +use actix_web::{get, post, put, delete, web, HttpResponse, HttpRequest, ResponseError}; +use log::warn; +use serde::{Serialize, Deserialize}; +use serenity::{http::Http, model::prelude::{GuildChannel, ChannelType}}; +use siren::ServiceError; + +#[get("/guilds")] +async fn get_guilds(data: web::Data>) -> HttpResponse { + let guild_results = &data.get_guilds(None, None).await; + let guilds = match guild_results { + Ok(guilds) => guilds, + Err(err) => return ResponseError::error_response(&ServiceError { + status: 422, + message: err.to_string() + }) + }; + HttpResponse::Ok().json(guilds) +} + +#[get("/{id}/text")] +async fn get_text_channels(id: web::Path, data: web::Data>) -> HttpResponse { + let channel_results = &data.get_channels(id.parse::().unwrap()).await; + let channels = match channel_results { + Ok(channels) => channels.iter().filter(|c| c.kind == ChannelType::Text).collect::>(), + Err(err) => return ResponseError::error_response(&ServiceError { + status: 422, + message: err.to_string() + }) + }; + HttpResponse::Ok().json(channels) +} + +#[get("/{id}/voice")] +async fn get_voice_channels(id: web::Path, data: web::Data>) -> HttpResponse { + let channel_results = &data.get_channels(id.parse::().unwrap()).await; + let channels = match channel_results { + Ok(channels) => channels.iter().filter(|c| c.kind == ChannelType::Voice).collect::>(), + Err(err) => return ResponseError::error_response(&ServiceError { + status: 422, + message: err.to_string() + }) + }; + HttpResponse::Ok().json(channels) +} + +#[derive(Serialize, Deserialize)] +struct ChannelMessage { + message: String +} + +#[post("/{guild_id}/text/{channel_id}/message")] +async fn send_message(path: web::Path<(String, String)>, text: web::Json, data: web::Data>) -> HttpResponse { + let (guild_id, channel_id) = path.into_inner(); + let guild_id = match guild_id.parse::() { + Ok(id) => id, + Err(err) => { + warn!("Could not parse guild id: {:?}", err); + return ResponseError::error_response(&ServiceError { + status: 422, + message: err.to_string() + }) + } + }; + let channel_id = match channel_id.parse::() { + Ok(id) => id, + Err(err) => { + warn!("Could not parse channel id: {:?}", err); + return ResponseError::error_response(&ServiceError { + status: 422, + message: err.to_string() + }) + } + }; + let channel_results = &data.get_channels(guild_id).await; + let channels = match channel_results { + Ok(channels) => channels, + Err(err) => { + warn!("Could not get channels: {:?}", err); + return ResponseError::error_response(&ServiceError { + status: 422, + message: err.to_string() + }) + } + }; + + let channel = match channels.iter().find(|c| c.id.0 == channel_id) { + Some(channel) => channel, + None => { + warn!("Could not find channel with id {}", channel_id); + return ResponseError::error_response(&ServiceError { + status: 422, + message: format!("Could not find channel with id {}", channel_id) + }) + } + }; + + if let Err(err) = channel.say(&data.get_ref(), &text.message).await { + warn!("Could not send message: {:?}", err); + return ResponseError::error_response(&ServiceError { + status: 422, + message: err.to_string() + }) + }; + + HttpResponse::Ok().finish() +} + +pub fn init_routes(config: &mut web::ServiceConfig) { + config + .service(get_guilds) + .service(web::scope("guilds") + .service(get_text_channels) + .service(get_voice_channels) + .service(send_message) + ); +} \ No newline at end of file diff --git a/service/src/bot/commands/message.rs b/service/src/bot/commands/message.rs new file mode 100644 index 0000000..e69de29 diff --git a/service/src/bot/commands/mod.rs b/service/src/bot/commands/mod.rs index 6f92333..a38dda8 100644 --- a/service/src/bot/commands/mod.rs +++ b/service/src/bot/commands/mod.rs @@ -1,5 +1,6 @@ pub mod audio; pub mod help; +pub mod message; pub mod oai; pub mod ping; pub mod schedule; diff --git a/service/src/bot/mod.rs b/service/src/bot/mod.rs index 43dcce0..eb1e203 100644 --- a/service/src/bot/mod.rs +++ b/service/src/bot/mod.rs @@ -1,169 +1,2 @@ -use std::collections::{HashSet, HashMap}; -use std::env; -use std::sync::Arc; - -use commands::audio::{create_response, AudioConfig, AudioConfigs}; - -use log::{error, warn, info}; -use serenity::async_trait; -use serenity::framework::StandardFramework; -use serenity::model::application::interaction::Interaction; -use serenity::model::gateway::Ready; -use serenity::model::channel::Message; -use serenity::http::Http; -use serenity::prelude::*; -use songbird::SerenityInit; - -use crate::bot::commands::oai::GPTModel; - +pub mod api; pub mod commands; - -struct Handler { - // Open AI Config - oai: Option -} - -#[async_trait] -impl EventHandler for Handler { - async fn message(&self, ctx: Context, msg: Message) { - // Ignore messages from bots - if msg.author.bot { - return; - } - match &self.oai { - Some(oai) => { - match msg.mentions_me(&ctx.http).await { - Ok(mentioned) => { - let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await { - Ok(t) => { - match t.iter().find(|t| t.user_id.unwrap().0 == ctx.cache.current_user_id().0) { - Some(_) => true, - None => false - } - } - Err(_) => false - }; - if mentioned || bot_in_thread { - commands::oai::generate_response(&ctx, &msg, oai).await; - } - } - Err(why) => warn!("Could not check mentions: {:?}", why) - }; - } - None => {} - } - } - - async fn interaction_create(&self, ctx: Context, interaction: Interaction) { - if let Interaction::ApplicationCommand(command) = interaction { - match command.data.name.as_str() { - "play" => commands::audio::play::run(&ctx, &command).await, - "stop" => commands::audio::stop::run(&ctx, &command).await, - "pause" => commands::audio::pause::run(&ctx, &command).await, - "resume" => commands::audio::resume::run(&ctx, &command).await, - "skip" => commands::audio::skip::run(&ctx, &command).await, - "volume" => commands::audio::volume::run(&ctx, &command).await, - _ => { - let content: String = match command.data.name.as_str() { - "ping" => commands::ping::run(&command.data.options), - _ => "Unknown command".to_string() - }; - - if let Err(why) = create_response(&ctx, &command, content).await { - warn!("Cannot respond to slash command: {}", why); - } - } - } - } - } - - async fn ready(&self, ctx: Context, ready: Ready) { - if ready.guilds.is_empty() { - warn!("No ready guilds found"); - } - for guild in ready.guilds { - let audio_config_lock = { - let data_read = ctx.data.read().await; - data_read.get::().expect("Expected AudioConfigs in TypeMap.").clone() - }; - { - let mut audio_configs = audio_config_lock.write().await; - let _ = audio_configs.insert(guild.id, AudioConfig { volume: 1.0 }); - } - let commands = guild.id.set_application_commands(&ctx.http, |commands| { - commands.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::ping::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::play::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::stop::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::pause::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::resume::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::skip::register(command) }) - .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::volume::register(command) }) - }).await; - match commands { - Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.0), - Err(why) => error!("Could not register commands for guild {}: {:?}", guild.id.0, why) - }; - } - } -} - -pub async fn run() { - let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); - let intents: GatewayIntents = GatewayIntents::all(); - - let http: Http = Http::new(&token); - let (owners, _bot_id) = match http.get_current_application_info().await { - Ok(info) => { - let mut owners: HashSet = HashSet::new(); - if let Some(team) = info.team { - owners.insert(team.owner_user_id); - } else { - owners.insert(info.owner.id); - } - match http.get_current_user().await { - Ok(bot) => (owners, bot.id), - Err(why) => panic!("Could not access the bot id: {:?}", why) - } - }, - Err(why) => panic!("Could not access application info: {:?}", why) - }; - - let handler = match env::var("OPENAI_API_KEY") { - Ok(token) => { - info!("Loaded OpenAI token"); - Handler { - oai: Some(commands::oai::OAI { - client: reqwest::Client::new(), - base_url: "https://api.openai.com/v1".to_string(), - service_url: "http://localhost:5000".to_string(), - max_attempts: 5, - token, - max_context_questions: 30, - max_tokens: 2048, - default_model: GPTModel::GPT35Turbo, - }) - } - } - Err(err) => { - warn!("Could not load OpenAI token: {}", err); - Handler { oai: None } - } - }; - - let mut client = Client::builder(token, intents) - .event_handler(handler) - .framework(StandardFramework::new() - .configure(|c| c.owners(owners))) - .register_songbird() - .await - .expect("Error creating client"); - - { - let mut data = client.data.write().await; - data.insert::(Arc::new(RwLock::new(HashMap::default()))); - } - - if let Err(why) = client.start_autosharded().await { - error!("An error occurred while running the client: {:?}", why); - } -} \ No newline at end of file diff --git a/service/src/db/spells/routes.rs b/service/src/db/spells/routes.rs index cf61605..5069476 100644 --- a/service/src/db/spells/routes.rs +++ b/service/src/db/spells/routes.rs @@ -181,8 +181,10 @@ async fn delete(id: web::Path) -> HttpResponse { } pub fn init_routes(config: &mut web::ServiceConfig) { - config.service(get_all); - config.service(get_by_id); - config.service(create); - config.service(delete); + config.service(web::scope("dnd") + .service(get_all) + .service(get_by_id) + .service(create) + .service(update) + ); } \ No newline at end of file diff --git a/service/src/lib.rs b/service/src/lib.rs index f1d97fa..b8e64bd 100644 --- a/service/src/lib.rs +++ b/service/src/lib.rs @@ -81,6 +81,12 @@ impl From for ServiceError { } } +impl From for ServiceError { + fn from(error: serenity::Error) -> ServiceError { + ServiceError::new(500, format!("Unknown serenity error: {}", error)) + } +} + impl ResponseError for ServiceError { fn error_response(&self) -> HttpResponse { let status_code = match StatusCode::from_u16(self.status) { diff --git a/service/src/main.rs b/service/src/main.rs index bd2a483..1969cf3 100644 --- a/service/src/main.rs +++ b/service/src/main.rs @@ -3,16 +3,119 @@ extern crate diesel; extern crate diesel_migrations; use std::env; +use std::collections::{HashSet, HashMap}; +use std::sync::Arc; +use bot::commands::audio::{create_response, AudioConfig, AudioConfigs}; + +use log::{error, warn, info}; +use serenity::async_trait; +use serenity::framework::StandardFramework; +use serenity::model::application::interaction::Interaction; +use serenity::model::gateway::Ready; +use serenity::model::channel::Message; +use serenity::http::Http; +use serenity::prelude::*; +use songbird::SerenityInit; + +use crate::bot::commands::oai::GPTModel; use actix_cors::Cors; -use actix_web::{HttpServer, App}; +use actix_web::{HttpServer, App, web}; use dotenv::dotenv; -use log::{error, info, warn}; mod bot; mod db; +struct Handler { + // Open AI Config + oai: Option +} + +#[async_trait] +impl EventHandler for Handler { + async fn message(&self, ctx: Context, msg: Message) { + // Ignore messages from bots + if msg.author.bot { + return; + } + match &self.oai { + Some(oai) => { + match msg.mentions_me(&ctx.http).await { + Ok(mentioned) => { + let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await { + Ok(t) => { + match t.iter().find(|t| t.user_id.unwrap().0 == ctx.cache.current_user_id().0) { + Some(_) => true, + None => false + } + } + Err(_) => false + }; + if mentioned || bot_in_thread { + bot::commands::oai::generate_response(&ctx, &msg, oai).await; + } + } + Err(why) => warn!("Could not check mentions: {:?}", why) + }; + } + None => {} + } + } + + async fn interaction_create(&self, ctx: Context, interaction: Interaction) { + if let Interaction::ApplicationCommand(command) = interaction { + match command.data.name.as_str() { + "play" => bot::commands::audio::play::run(&ctx, &command).await, + "stop" => bot::commands::audio::stop::run(&ctx, &command).await, + "pause" => bot::commands::audio::pause::run(&ctx, &command).await, + "resume" => bot::commands::audio::resume::run(&ctx, &command).await, + "skip" => bot::commands::audio::skip::run(&ctx, &command).await, + "volume" => bot::commands::audio::volume::run(&ctx, &command).await, + _ => { + let content: String = match command.data.name.as_str() { + "ping" => bot::commands::ping::run(&command.data.options), + _ => "Unknown command".to_string() + }; + + if let Err(why) = create_response(&ctx, &command, content).await { + warn!("Cannot respond to slash command: {}", why); + } + } + } + } + } + + async fn ready(&self, ctx: Context, ready: Ready) { + if ready.guilds.is_empty() { + warn!("No ready guilds found"); + } + for guild in ready.guilds { + let audio_config_lock = { + let data_read = ctx.data.read().await; + data_read.get::().expect("Expected AudioConfigs in TypeMap.").clone() + }; + { + let mut audio_configs = audio_config_lock.write().await; + let _ = audio_configs.insert(guild.id, AudioConfig { volume: 1.0 }); + } + let commands = guild.id.set_application_commands(&ctx.http, |commands| { + commands.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { bot::commands::ping::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { bot::commands::audio::play::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { bot::commands::audio::stop::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { bot::commands::audio::pause::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { bot::commands::audio::resume::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { bot::commands::audio::skip::register(command) }) + .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { bot::commands::audio::volume::register(command) }) + }).await; + match commands { + Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.0), + Err(why) => error!("Could not register commands for guild {}: {:?}", guild.id.0, why) + }; + } + } +} + #[actix_web::main] async fn main() -> std::io::Result<()> { dotenv().ok(); @@ -23,20 +126,83 @@ async fn main() -> std::io::Result<()> { Err(err) => warn!("Unable to load initial database data: {}", err) }; + let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); + let intents: GatewayIntents = GatewayIntents::all(); + + let http: Http = Http::new(&token); + let (owners, _bot_id) = match http.get_current_application_info().await { + Ok(info) => { + let mut owners: HashSet = HashSet::new(); + if let Some(team) = info.team { + owners.insert(team.owner_user_id); + } else { + owners.insert(info.owner.id); + } + match http.get_current_user().await { + Ok(bot) => (owners, bot.id), + Err(why) => panic!("Could not access the bot id: {:?}", why) + } + }, + Err(why) => panic!("Could not access application info: {:?}", why) + }; + + let handler = match env::var("OPENAI_API_KEY") { + Ok(token) => { + info!("Loaded OpenAI token"); + Handler { + oai: Some(bot::commands::oai::OAI { + client: reqwest::Client::new(), + base_url: "https://api.openai.com/v1".to_string(), + service_url: "http://localhost:5000".to_string(), + max_attempts: 5, + token, + max_context_questions: 30, + max_tokens: 2048, + default_model: GPTModel::GPT35Turbo, + }) + } + } + Err(err) => { + warn!("Could not load OpenAI token: {}", err); + Handler { oai: None } + } + }; + + let mut client = Client::builder(token, intents) + .event_handler(handler) + .framework(StandardFramework::new() + .configure(|c| c.owners(owners))) + .register_songbird() + .await + .expect("Error creating client"); + + { + let mut data = client.data.write().await; + data.insert::(Arc::new(RwLock::new(HashMap::default()))); + } + + let bot_http = Arc::clone(&client.cache_and_http.http); + + tokio::spawn(async move { + if let Err(why) = client.start_autosharded().await { + error!("An error occurred while running the client: {:?}", why); + } + }); + let host = env::var("SERVICE_HOST").unwrap_or("localhost".to_string()); let port = env::var("SERVICE_PORT").unwrap_or("5000".to_string()); - tokio::spawn(bot::run()); - - match HttpServer::new(|| { + let server = match HttpServer::new(move || { let cors = Cors::default() .allow_any_origin() .allow_any_method() .allow_any_header() .max_age(3600); App::new() - .configure(db::messages::init_routes) - .configure(db::spells::init_routes) + .app_data(web::Data::new(Arc::clone(&bot_http))) + .configure(crate::db::messages::init_routes) + .configure(crate::db::spells::init_routes) + .configure(crate::bot::api::init_routes) .wrap(cors) }) .bind(format!("{}:{}", host, port)) { @@ -48,7 +214,8 @@ async fn main() -> std::io::Result<()> { error!("Could not bind server: {}", err); return Err(err); } - } - .run() + }; + + server.run() .await }