diff --git a/src/api/audio/mod.rs b/src/api/audio/mod.rs new file mode 100644 index 0000000..3f9991b --- /dev/null +++ b/src/api/audio/mod.rs @@ -0,0 +1,38 @@ +use std::sync::Arc; +use axum::extract::{Path, State}; +use axum::middleware::from_extractor; +use axum::{Extension, Json, Router}; +use axum::routing::post; +use serde::Deserialize; +use crate::api::auth::{AuthorizationMiddleware, Session}; +use crate::AppState; +use crate::bot::commands::audio::join_voice_channel; +use crate::bot::commands::audio::play::enqueue_track; +use crate::bot::handler::get_songbird; +use crate::error::SirenResult; + +pub fn get_routes() -> Router> { + Router::new() + .route("/play", post(play_audio)) + .route_layer(from_extractor::()) +} + +#[derive(Deserialize)] +struct TrackRequest { + url: String, + guild_id: u64, +} + +async fn play_audio( + Extension(session): Extension, + State(state): State>, + Json(payload): Json, +) -> SirenResult<()> { + log::debug!("Playing audio in guild: {}", payload.guild_id); + let manager = get_songbird(); + let user_id = state.cache.user(session.user_id).unwrap().id; + let guild_id = state.cache.guild(payload.guild_id).unwrap().id; + let _channel_id = join_voice_channel(&state.cache, &manager, &guild_id, &user_id).await?; + enqueue_track(manager, guild_id.to_owned(), &payload.url).await?; + Ok(()) +} diff --git a/src/api/auth/api_key.rs b/src/api/auth/api_key.rs index a8ec8df..0ee1d3b 100644 --- a/src/api/auth/api_key.rs +++ b/src/api/auth/api_key.rs @@ -2,35 +2,40 @@ use std::sync::Arc; use axum::{middleware, Extension, Router}; use axum::middleware::from_extractor; use axum::routing::post; -use crate::api::auth::{authenticate_middleware, csprng}; -use crate::api::auth::middleware::AuthorizationMiddleware; +use crate::api::auth::csprng; +use crate::api::auth::AuthorizationMiddleware; use crate::api::auth::session::Session; use crate::AppState; use crate::error::SirenResult; pub fn get_routes() -> Router> { - Router::new().route("/api-key", post(create_api_key)) + Router::new() + .route("/api-key", post(create_api_key)) .route_layer(from_extractor::()) } struct ApiKey { pub key: String, - pub user_id: String, + pub user_id: u64, pub access_mask: u32, } impl ApiKey { - fn new(user_id: String, access_mask: u32) -> Self { + fn new(user_id: u64, access_mask: u32) -> Self { ApiKey { key: csprng(64), user_id, - access_mask + access_mask, } } } async fn create_api_key(Extension(session): Extension) -> SirenResult { - log::debug!("Generating API key for {} ({})", &session.user_id, &session.user_name); + log::debug!( + "Generating API key for {} ({})", + &session.user_id, + &session.user_name + ); let api_key = ApiKey::new(session.user_id, 0); Ok(api_key.key) -} \ No newline at end of file +} diff --git a/src/api/auth/bearer_token.rs b/src/api/auth/bearer_token.rs index ff12a34..16947ac 100644 --- a/src/api/auth/bearer_token.rs +++ b/src/api/auth/bearer_token.rs @@ -2,9 +2,9 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] pub struct BearerTokenClaims { - pub sub: String, + pub sub: u64, pub name: String, pub iat: i64, pub exp: i64, pub jti: String, -} \ No newline at end of file +} diff --git a/src/api/auth/middleware.rs b/src/api/auth/middleware.rs index 7e54c08..5afb6cc 100644 --- a/src/api/auth/middleware.rs +++ b/src/api/auth/middleware.rs @@ -2,7 +2,11 @@ use axum::async_trait; use axum::extract::FromRequestParts; use axum::http::request::Parts; use axum::http::{Method, StatusCode}; -use axum_extra::{TypedHeader, headers::{Authorization, authorization::Bearer}}; +use axum::middleware::{from_extractor, FromExtractorLayer}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use chrono::Utc; use jsonwebtoken::{decode, DecodingKey, Validation}; use crate::api::auth::bearer_token::BearerTokenClaims; @@ -27,7 +31,6 @@ where let Ok(TypedHeader(Authorization(bearer))) = TypedHeader::>::from_request_parts(parts, state).await else { - log::error!("Could not get Authorization header from the request"); return Err(StatusCode::UNAUTHORIZED); }; @@ -35,7 +38,7 @@ where Ok(session) => { parts.extensions.insert(session); Ok(Self) - }, + } Err(err) => { log::error!("{:?}", err); Err(StatusCode::UNAUTHORIZED) @@ -49,11 +52,9 @@ async fn check_auth(bearer: Bearer) -> SirenResult { let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment"); let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes()); - let token_data = decode::( - bearer.token(), - &decoding_key, - &Validation::default() - ).map_err(|_| StatusCode::UNAUTHORIZED)?; + let token_data = + decode::(bearer.token(), &decoding_key, &Validation::default()) + .map_err(|_| StatusCode::UNAUTHORIZED)?; let claims = token_data.claims; @@ -68,4 +69,4 @@ async fn check_auth(bearer: Bearer) -> SirenResult { Ok(Some(session)) => Ok(session), _ => Err(StatusCode::UNAUTHORIZED)?, } -} \ No newline at end of file +} diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs index 048dcc6..a5438f2 100644 --- a/src/api/auth/mod.rs +++ b/src/api/auth/mod.rs @@ -7,9 +7,11 @@ use crate::AppState; mod oauth; mod session; +pub use session::Session; mod api_key; mod bearer_token; mod middleware; +pub use middleware::AuthorizationMiddleware; pub fn get_routes() -> Router> { Router::new() diff --git a/src/api/auth/oauth.rs b/src/api/auth/oauth.rs index 354b2b4..bf49442 100644 --- a/src/api/auth/oauth.rs +++ b/src/api/auth/oauth.rs @@ -16,6 +16,7 @@ use crate::error::SirenResult; pub fn get_routes() -> Router> { Router::new() .route("/authorize", get(discord_authorize)) + .route("/authorize/redirect", get(discord_authorize_redirect)) .route("/callback", get(oauth_callback)) } @@ -42,14 +43,14 @@ struct DiscordUser { avatar: Option, } -// async fn discord_authorize_redirect(State(state): State>) -> Redirect { -// // Construct the Discord OAuth URL -// let discord_auth_url = format!( -// "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify", -// state.client_id, state.redirect_uri -// ); -// Redirect::temporary(&discord_auth_url) -// } +async fn discord_authorize_redirect(State(state): State>) -> Redirect { + // Construct the Discord OAuth URL + let discord_auth_url = format!( + "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify", + state.client_id, state.redirect_uri + ); + Redirect::temporary(&discord_auth_url) +} async fn discord_authorize(State(state): State>) -> SirenResult { // Construct the Discord OAuth URL @@ -124,13 +125,16 @@ async fn oauth_callback( log::debug!("User authenticated: {:?}", user_data); // Create and insert the session - let session = Session::new(user_data.id.clone(), user_data.username.clone()); + let session = Session::new( + user_data.id.parse::().unwrap(), + user_data.username.clone(), + ); session.insert().await?; let issued_at = chrono::Utc::now(); let claims = BearerTokenClaims { - sub: session.user_id.clone(), + sub: session.user_id, name: session.user_name.clone(), iat: issued_at.timestamp(), exp: session.expires_at.timestamp(), diff --git a/src/api/auth/session.rs b/src/api/auth/session.rs index b67a5f8..0bedbdf 100644 --- a/src/api/auth/session.rs +++ b/src/api/auth/session.rs @@ -22,13 +22,13 @@ fn get_session_ttl() -> i64 { #[derive(Serialize, Deserialize, Clone, Debug)] pub struct Session { pub session_id: String, - pub user_id: String, + pub user_id: u64, pub user_name: String, pub expires_at: DateTime, } impl Session { - pub fn new(user_id: String, user_name: String) -> Session { + pub fn new(user_id: u64, user_name: String) -> Session { let now = Utc::now(); let session_ttl = get_session_ttl(); Session { @@ -70,4 +70,4 @@ impl Session { Err(err) => Err(err.into()), } } -} \ No newline at end of file +} diff --git a/src/api/mod.rs b/src/api/mod.rs index 9b2e9ed..e8ab5a7 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -5,8 +5,11 @@ use axum::Router; use crate::AppState; mod app; +mod audio; mod auth; pub fn get_routes() -> Router> { - Router::new().merge(auth::get_routes()) + Router::new() + .merge(auth::get_routes()) + .nest("/audio", audio::get_routes()) } diff --git a/src/bot/commands/audio/mod.rs b/src/bot/commands/audio/mod.rs index 095472e..f3275f2 100644 --- a/src/bot/commands/audio/mod.rs +++ b/src/bot/commands/audio/mod.rs @@ -1,6 +1,7 @@ -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use reqwest::Url; +use serenity::all::UserId; use serenity::client::Cache; use serenity::model::prelude::{GuildId, ChannelId}; use serenity::model::user::User; @@ -17,12 +18,6 @@ pub mod skip; pub mod stop; pub mod volume; -pub async fn get_songbird(ctx: &Context) -> Arc { - songbird::get(ctx) - .await - .expect("Songbird Voice client placed in at initialization") -} - /** * Finds a voice channel that the user is currently in, and attempts to join it. */ @@ -30,9 +25,9 @@ pub async fn join_voice_channel( cache: &Arc, manager: &Arc, guild_id: &GuildId, - user: &User, + user_id: &UserId, ) -> SirenResult { - let channel_id = find_voice_channel(cache, guild_id, user)?; + let channel_id = find_voice_channel(cache, guild_id, user_id)?; log::debug!("<{}> Joining channel {}", guild_id.get(), channel_id.get()); manager .join(guild_id.to_owned(), channel_id.to_owned()) @@ -66,7 +61,7 @@ fn is_valid_url(url: &str) -> bool { fn find_voice_channel( cache: &Arc, guild_id: &GuildId, - user: &User, + user_id: &UserId, ) -> SirenResult { let guild = match guild_id.to_guild_cached(cache) { Some(g) => g, @@ -75,7 +70,7 @@ fn find_voice_channel( match guild .voice_states - .get(&user.id) + .get(&user_id) .and_then(|voice_state| voice_state.channel_id) { Some(channel) => Ok(channel), diff --git a/src/bot/commands/audio/mute.rs b/src/bot/commands/audio/mute.rs index 07c721c..2990841 100644 --- a/src/bot/commands/audio/mute.rs +++ b/src/bot/commands/audio/mute.rs @@ -3,15 +3,14 @@ use serenity::{ prelude::*, }; use crate::bot::chat::{edit_response, process_message}; - -use super::get_songbird; +use crate::bot::handler::get_songbird; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response process_message(&ctx, &command, false).await; // Get the songbird manager - let manager = get_songbird(ctx).await; + let manager = get_songbird(); // Extract the guild ID let guild_id = match &command.guild_id { diff --git a/src/bot/commands/audio/pause.rs b/src/bot/commands/audio/pause.rs index 80210a4..866d270 100644 --- a/src/bot/commands/audio/pause.rs +++ b/src/bot/commands/audio/pause.rs @@ -4,15 +4,14 @@ use serenity::{ }; use crate::bot::chat::{edit_response, process_message}; - -use super::get_songbird; +use crate::bot::handler::get_songbird; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response process_message(&ctx, &command, false).await; // Get the songbird manager - let manager = get_songbird(ctx).await; + let manager = get_songbird(); // Extract the guild ID let guild_id = match &command.guild_id { diff --git a/src/bot/commands/audio/play.rs b/src/bot/commands/audio/play.rs index a342e0d..6a67729 100644 --- a/src/bot/commands/audio/play.rs +++ b/src/bot/commands/audio/play.rs @@ -1,6 +1,8 @@ use std::sync::Arc; -use serenity::all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption}; +use serenity::all::{ + Cache, CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption, Http, +}; use serenity::model::prelude::GuildId; use serenity::{prelude::*, async_trait}; use songbird::input::{Input, YoutubeDl}; @@ -12,9 +14,10 @@ use crate::bot::ytdlp::{YtDlp, YtDlpItem}; use crate::error::{SirenResult, Error as SirenError}; use crate::{signal_shutdown, HttpKey}; -use super::{get_songbird, is_valid_url, join_voice_channel}; +use super::{is_valid_url, join_voice_channel}; use crate::bot::chat::{create_message_response, edit_response, process_message}; +use crate::bot::handler::{get_client, get_songbird}; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Process the command options @@ -34,7 +37,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { process_message(&ctx, &command, false).await; // Get the songbird manager - let manager = get_songbird(ctx).await; + let manager = get_songbird(); // Extract the guild ID let guild_id = match &command.guild_id { @@ -51,13 +54,13 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { }; // Join the user's voice channel - match join_voice_channel(&ctx.cache, &manager, guild_id, &command.user).await { + match join_voice_channel(&ctx.cache, &manager, guild_id, &command.user.id).await { Ok(channel_id) => { log::debug!( "<{guild_id}> Play command executed on channel {channel_id} with track: {track_url:?}" ); // Handle the track url - match enqueue_track(ctx, manager, guild_id.to_owned(), track_url).await { + match enqueue_track(manager, guild_id.to_owned(), track_url).await { Ok(items) => { let mut message = format!("Added {} tracks", items.len()); if items.len() == 0 { @@ -81,8 +84,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { } pub async fn enqueue_track( - ctx: &Context, - manager: Arc, + manager: &Arc, guild_id: GuildId, track_url: &str, ) -> SirenResult> { @@ -112,15 +114,9 @@ pub async fn enqueue_track( // Add each track to the queue for item in &playlist_items { let volume = guild.volume as f32 / 100.0; - let http_client = { - let data = ctx.data.read().await; - data - .get::() - .cloned() - .expect("Guaranteed to exist in the typemap.") - }; + let http_client = get_client(); - let source = YoutubeDl::new(http_client, item.get_url().to_owned()); + let source = YoutubeDl::new(http_client.to_owned(), item.get_url().to_owned()); let input: Input = source.into(); let track_title = item.get_title().to_owned(); diff --git a/src/bot/commands/audio/resume.rs b/src/bot/commands/audio/resume.rs index 6be9c85..b02472c 100644 --- a/src/bot/commands/audio/resume.rs +++ b/src/bot/commands/audio/resume.rs @@ -4,15 +4,14 @@ use serenity::{ }; use crate::bot::chat::{edit_response, process_message}; - -use super::get_songbird; +use crate::bot::handler::get_songbird; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response process_message(&ctx, &command, false).await; // Get the songbird manager - let manager = get_songbird(ctx).await; + let manager = get_songbird(); // Extract the guild ID let guild_id = match &command.guild_id { diff --git a/src/bot/commands/audio/skip.rs b/src/bot/commands/audio/skip.rs index 49790a9..a0839f8 100644 --- a/src/bot/commands/audio/skip.rs +++ b/src/bot/commands/audio/skip.rs @@ -4,15 +4,14 @@ use serenity::{ }; use crate::bot::chat::{edit_response, process_message}; - -use super::get_songbird; +use crate::bot::handler::get_songbird; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response process_message(&ctx, &command, false).await; // Get the songbird manager - let manager = get_songbird(ctx).await; + let manager = get_songbird(); // Extract the guild ID let guild_id = match &command.guild_id { diff --git a/src/bot/commands/audio/stop.rs b/src/bot/commands/audio/stop.rs index c8087d0..7933c12 100644 --- a/src/bot/commands/audio/stop.rs +++ b/src/bot/commands/audio/stop.rs @@ -4,15 +4,14 @@ use serenity::{ }; use crate::bot::chat::{edit_response, process_message}; - -use super::get_songbird; +use crate::bot::handler::get_songbird; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Create the initial response process_message(&ctx, &command, false).await; // Get the songbird manager - let manager = get_songbird(ctx).await; + let manager = get_songbird(); // Extract the guild ID let guild_id = match command.guild_id { diff --git a/src/bot/commands/audio/volume.rs b/src/bot/commands/audio/volume.rs index 213a6a9..5e0c008 100644 --- a/src/bot/commands/audio/volume.rs +++ b/src/bot/commands/audio/volume.rs @@ -10,8 +10,7 @@ use songbird::Songbird; use crate::data::guilds::GuildCache; use crate::bot::chat::{create_message_response, edit_response, process_message}; - -use super::get_songbird; +use crate::bot::handler::get_songbird; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Process the command options @@ -37,7 +36,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { process_message(&ctx, &command, false).await; // Get the songbird manager - let manager = get_songbird(ctx).await; + let manager = get_songbird(); // Extract the guild ID let guild_id = match &command.guild_id { diff --git a/src/bot/handler.rs b/src/bot/handler.rs index 5048105..75a0969 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -1,11 +1,15 @@ +use std::env; +use std::sync::{Arc, OnceLock}; use serenity::all::{Interaction, ResumedEvent}; use serenity::async_trait; use serenity::model::gateway::Ready; use serenity::model::channel::Message; use serenity::prelude::*; +use songbird::Songbird; use crate::bot::commands::chat::generate_response; use crate::bot::oai::OAI; use crate::data::guilds::GuildCache; +use crate::HttpKey; use super::{commands}; use super::chat::{create_modal_response}; @@ -14,6 +18,43 @@ pub struct BotHandler { pub oai: Option, } +static SONGBIRD: OnceLock> = OnceLock::new(); +static CLIENT: OnceLock = OnceLock::new(); + +pub fn get_songbird() -> &'static Arc { + SONGBIRD.get().unwrap() +} + +pub fn get_client() -> &'static reqwest::Client { + CLIENT.get().unwrap() +} + +impl BotHandler { + pub fn new() -> Self { + match env::var("OPENAI_TOKEN") { + Ok(token) => { + log::debug!("OpenAI functionality enabled"); + let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()); + let base_url = env::var("OPENAI_BASE_URL").unwrap(); + Self { + oai: Some(OAI { + client: reqwest::Client::new(), + base_url, + token, + max_conversation_history: 30, + max_tokens: 8192, + default_model, + }), + } + } + Err(_) => { + log::warn!("OpenAI functionality disabled"); + Self { oai: None } + } + } + } +} + #[async_trait] impl EventHandler for BotHandler { async fn message(&self, ctx: Context, msg: Message) { @@ -40,6 +81,20 @@ impl EventHandler for BotHandler { if ready.guilds.is_empty() { log::warn!("No ready guilds found"); } + + let songbird = songbird::get(&ctx).await.unwrap(); + SONGBIRD + .set(songbird.clone()) + .expect("Songbird value could not be set"); + let http_client = { + let data = ctx.data.read().await; + data + .get::() + .cloned() + .expect("Guaranteed to exist in the typemap.") + }; + CLIENT.set(http_client).ok(); + log::trace!("Handling {} guilds", ready.guilds.len()); for guild in ready.guilds { // Check if guild exists in database diff --git a/src/bot/mod.rs b/src/bot/mod.rs index 993392b..3e64c22 100644 --- a/src/bot/mod.rs +++ b/src/bot/mod.rs @@ -1,4 +1,4 @@ -mod chat; +pub mod chat; pub mod commands; pub mod handler; pub mod oai; diff --git a/src/main.rs b/src/main.rs index c4d302f..5baa556 100644 --- a/src/main.rs +++ b/src/main.rs @@ -41,7 +41,7 @@ async fn main() -> Result<(), Box> { let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); // Set up handler with optional OpenAI integration - let handler = configure_handler(); + let handler = BotHandler::new(); // Set up Songbird for voice functionality let songbird = Songbird::serenity(); @@ -113,30 +113,6 @@ async fn get_bot_info(http: &Http) -> (Option, UserId) { } } -fn configure_handler() -> BotHandler { - match env::var("OPENAI_TOKEN") { - Ok(token) => { - log::debug!("OpenAI functionality enabled"); - let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()); - let base_url = env::var("OPENAI_BASE_URL").unwrap(); - BotHandler { - oai: Some(OAI { - client: reqwest::Client::new(), - base_url, - token, - max_conversation_history: 30, - max_tokens: 8192, - default_model, - }), - } - } - Err(_) => { - log::warn!("OpenAI functionality disabled"); - BotHandler { oai: None } - } - } -} - async fn signal_shutdown(shard_manager: Arc) { tokio::signal::ctrl_c() .await