diff --git a/bruno/audio/Pause Track.bru b/bruno/audio/Pause Track.bru index ea1f83e..36a8cf6 100644 --- a/bruno/audio/Pause Track.bru +++ b/bruno/audio/Pause Track.bru @@ -5,13 +5,7 @@ meta { } post { - url: {{baseUrl}}/audio/pause + url: {{baseUrl}}/audio/1061092965579235398/pause body: json auth: inherit } - -body:json { - { - "guild_id": 1061092965579235398 - } -} diff --git a/bruno/audio/Play Track.bru b/bruno/audio/Play Track.bru index 67800e7..f2c9821 100644 --- a/bruno/audio/Play Track.bru +++ b/bruno/audio/Play Track.bru @@ -5,14 +5,13 @@ meta { } post { - url: {{baseUrl}}/audio/play + url: {{baseUrl}}/audio/1061092965579235398/play body: json auth: inherit } body:json { { - "url": "https://www.youtube.com/watch?v=V-QDxuknK-Q", - "guild_id": 1061092965579235398 + "url": "https://www.youtube.com/watch?v=V-QDxuknK-Q" } } diff --git a/bruno/audio/Resume Track.bru b/bruno/audio/Resume Track.bru index 5c633a9..33e0186 100644 --- a/bruno/audio/Resume Track.bru +++ b/bruno/audio/Resume Track.bru @@ -5,13 +5,7 @@ meta { } post { - url: {{baseUrl}}/audio/resume + url: {{baseUrl}}/audio/1061092965579235398/resume body: json auth: inherit } - -body:json { - { - "guild_id": 1061092965579235398 - } -} diff --git a/bruno/collection.bru b/bruno/collection.bru index 2a7569a..0157b98 100644 --- a/bruno/collection.bru +++ b/bruno/collection.bru @@ -1,11 +1,12 @@ auth { - mode: bearer + mode: apikey } -auth:bearer { - token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjI1MDg0MjI2MTIyMTI3NzY5NywibmFtZSI6ImJzaGVycmlmZiIsImlhdCI6MTczNDcwNDI3NSwiZXhwIjoxNzM0NzkwNjc1LCJqdGkiOiJMSnc1Vnk3azZjc1BiYlJRWGlNcVFFVUZlQ29JS2JqcCJ9.sdgb93DmX9_augMdktYr58m5eTIJPuY13d87pckZOns +auth:apikey { + key: X-API-Key + value: rwOS4yMmNpQvL0vLHc1jWQoefJB1bvKvOvBSswiYh0mkhZDc1lsgFZmpXaSUXAa5ZjpRWR117hLQ1l0VPPSGkRXZl7dPRVCc + placement: header } - vars:pre-request { baseUrl: http://localhost:3000/api } diff --git a/bruno/oauth/Create API Key.bru b/bruno/oauth/Create API Key.bru new file mode 100644 index 0000000..848acb1 --- /dev/null +++ b/bruno/oauth/Create API Key.bru @@ -0,0 +1,11 @@ +meta { + name: Create API Key + type: http + seq: 2 +} + +post { + url: {{baseUrl}}/api-key + body: none + auth: inherit +} diff --git a/migrations/000_initial.sql b/migrations/000_initial.sql index 59d5246..8532b5e 100644 --- a/migrations/000_initial.sql +++ b/migrations/000_initial.sql @@ -16,6 +16,12 @@ CREATE TABLE IF NOT EXISTS messages ( request_tags TEXT[] NOT NULL, response_tags TEXT[] NOT NULL ); +CREATE TABLE IF NOT EXISTS api_keys ( + key TEXT PRIMARY KEY NOT NULL, + user_id BIGINT NOT NULL, + user_name TEXT NOT NULL, + access_mask INT +); CREATE TABLE IF NOT EXISTS dice_thresholds ( id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(), owner_id BIGINT NOT NULL, diff --git a/src/api/audio/mod.rs b/src/api/audio/mod.rs index ae71db2..02fbb63 100644 --- a/src/api/audio/mod.rs +++ b/src/api/audio/mod.rs @@ -1,12 +1,12 @@ use std::sync::Arc; -use axum::extract::State; +use axum::extract::{Path, State}; use axum::middleware::from_extractor; use axum::{Extension, Json, Router}; use axum::response::IntoResponse; use axum::routing::post; use reqwest::StatusCode; use serde::Deserialize; -use crate::api::auth::{AuthorizationMiddleware, Session}; +use crate::api::auth::{AuthCredential, AuthorizationMiddleware, Session}; use crate::AppState; use crate::bot::commands::audio::join_voice_channel; use crate::bot::commands::audio::pause::pause_track; @@ -28,24 +28,25 @@ pub fn get_routes() -> Router> { #[derive(Deserialize)] struct PlayTrackRequest { url: String, - guild_id: u64, } async fn play_audio( - Extension(session): Extension, + Extension(credential): Extension, State(state): State>, + Path(guild_id): Path, Json(payload): Json, ) -> SirenResult<()> { - log::debug!("Playing audio in guild: {}", payload.guild_id); + log::debug!("Playing audio in guild: {}", guild_id); // Check if the user exists in the cache - let user_id = match state.cache.user(session.user_id) { + let user_id = credential.user_id(); + let user_id = match state.cache.user(user_id) { Some(user) => user.id, None => return Err(Error::not_found("User not found".to_string())), }; // Validate if the guild exists in the cache - let guild_id = match state.cache.guild(payload.guild_id) { + let guild_id = match state.cache.guild(guild_id) { Some(guild) => guild.id, None => return Err(Error::not_found("Guild not found".to_string())), }; @@ -57,20 +58,15 @@ async fn play_audio( Ok(()) } -#[derive(Deserialize)] -struct GuildTrackRequest { - guild_id: u64, -} - async fn pause_audio( - Extension(_): Extension, + Extension(_): Extension, State(state): State>, - Json(payload): Json, + Path(guild_id): Path, ) -> SirenResult<()> { - log::debug!("Pausing audio in guild: {}", payload.guild_id); + log::debug!("Pausing audio in guild: {}", guild_id); // Validate if the guild exists in the cache - let guild_id = match state.cache.guild(payload.guild_id) { + let guild_id = match state.cache.guild(guild_id) { Some(guild) => guild.id, None => return Err(Error::not_found("Guild not found".to_string())), }; @@ -81,14 +77,14 @@ async fn pause_audio( } async fn resume_audio( - Extension(_): Extension, + Extension(_): Extension, State(state): State>, - Json(payload): Json, + Path(guild_id): Path, ) -> SirenResult<()> { - log::debug!("Pausing audio in guild: {}", payload.guild_id); + log::debug!("Pausing audio in guild: {}", guild_id); // Validate if the guild exists in the cache - let guild_id = match state.cache.guild(payload.guild_id) { + let guild_id = match state.cache.guild(guild_id) { Some(guild) => guild.id, None => return Err(Error::not_found("Guild not found".to_string())), }; diff --git a/src/api/auth/api_key.rs b/src/api/auth/api_key.rs index 213cf19..ea8b31a 100644 --- a/src/api/auth/api_key.rs +++ b/src/api/auth/api_key.rs @@ -2,11 +2,14 @@ use std::sync::Arc; use axum::{Extension, Router}; use axum::middleware::from_extractor; use axum::routing::post; -use crate::api::auth::csprng; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; +use crate::api::auth::{csprng, AuthCredential}; use crate::api::auth::AuthorizationMiddleware; use crate::api::auth::session::Session; use crate::AppState; -use crate::error::SirenResult; +use crate::data::query::{Condition, QueryBuilder}; +use crate::error::{Error, SirenResult}; pub fn get_routes() -> Router> { Router::new() @@ -14,28 +17,82 @@ pub fn get_routes() -> Router> { .route_layer(from_extractor::()) } -struct ApiKey { +const TABLE_NAME: &str = "api_keys"; + +#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)] +pub struct ApiKey { pub key: String, - pub user_id: u64, - pub access_mask: u32, + pub user_id: i64, + pub user_name: String, + pub access_mask: i32, } impl ApiKey { - fn new(user_id: u64, access_mask: u32) -> Self { + fn new(user_id: u64, user_name: String, access_mask: i32) -> Self { ApiKey { - key: csprng(64), - user_id, + key: csprng(96), + user_id: user_id as i64, + user_name, access_mask, } } + + pub async fn insert(&self) -> SirenResult<()> { + let pool = crate::data::pool(); + sqlx::query(&format!( + "INSERT INTO {} ( + key, + user_id, + user_name, + access_mask + ) VALUES ( + $1, $2, $3, $4 + )", + TABLE_NAME + )) + .bind(&self.key) + .bind(self.user_id) + .bind(&self.user_name) + .bind(self.access_mask) + .execute(pool) + .await?; + Ok(()) + } + + pub async fn find_by_key(key: &str) -> SirenResult> { + let pool = crate::data::pool(); + let query = QueryBuilder::new(TABLE_NAME) + .where_condition(Condition::is_equal("key", "$1")) + .build(); + let item = sqlx::query_as(&query) + .bind(key) + .fetch_optional(pool) + .await?; + + Ok(item) + } + + pub async fn delete_by_id(key: &str) -> SirenResult<()> { + let pool = crate::data::pool(); + sqlx::query(&format!("DELETE FROM {} WHERE key = $1", TABLE_NAME)) + .bind(key) + .execute(pool) + .await?; + Ok(()) + } } -async fn create_api_key(Extension(session): Extension) -> SirenResult { +async fn create_api_key(Extension(credential): Extension) -> SirenResult { + let session = match credential { + AuthCredential::ApiKey(_) => return Err(Error::new(400, "API keys cannot be generated with an API key".to_string())), + AuthCredential::Session(session) => session + }; log::debug!( "Generating API key for {} ({})", &session.user_id, &session.user_name ); - let api_key = ApiKey::new(session.user_id, 0); + let api_key = ApiKey::new(session.user_id, session.user_name, 0); + api_key.insert().await?; Ok(api_key.key) } diff --git a/src/api/auth/middleware.rs b/src/api/auth/middleware.rs index 1defe2e..fcc7477 100644 --- a/src/api/auth/middleware.rs +++ b/src/api/auth/middleware.rs @@ -8,6 +8,8 @@ use axum_extra::{ }; use chrono::Utc; use jsonwebtoken::{decode, DecodingKey, Validation}; +use crate::api::auth::api_key::ApiKey; +use crate::api::auth::AuthCredential; use crate::api::auth::bearer_token::BearerTokenClaims; use crate::api::auth::session::Session; use crate::error::SirenResult; @@ -27,33 +29,47 @@ where return Ok(Self); } - let Ok(TypedHeader(Authorization(bearer))) = + // Check for a Bearer token in the `Authorization` header. + if let Ok(TypedHeader(Authorization(bearer))) = TypedHeader::>::from_request_parts(parts, state).await - else { - return Err(StatusCode::UNAUTHORIZED); - }; - - match check_auth(bearer).await { - Ok(session) => { - parts.extensions.insert(session); - Ok(Self) - } - Err(err) => { - log::error!("{:?}", err); - Err(StatusCode::UNAUTHORIZED) - } + { + return match check_bearer_auth(bearer.token()).await { + Ok(session) => { + parts.extensions.insert(AuthCredential::Session(session)); + Ok(Self) + } + Err(_) => Err(StatusCode::UNAUTHORIZED), + }; } + + // Check for an API key in the custom `X-API-Key` header. + if let Some(api_key_header) = parts.headers.get("X-API-Key") { + return if let Ok(api_key) = api_key_header.to_str() { + match check_api_key_auth(api_key).await { + Ok(api_key) => { + parts.extensions.insert(AuthCredential::ApiKey(api_key)); + Ok(Self) + } + Err(_) => Err(StatusCode::UNAUTHORIZED), + } + } else { + // Invalid header value + Err(StatusCode::BAD_REQUEST) + }; + } + + // If neither the Bearer token nor API key is present or valid, return `UNAUTHORIZED` + Err(StatusCode::UNAUTHORIZED) } } -async fn check_auth(bearer: Bearer) -> SirenResult { +async fn check_bearer_auth(bearer_token: &str) -> SirenResult { // Decode and validate the JWT let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment"); let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes()); - let token_data = - decode::(bearer.token(), &decoding_key, &Validation::default()) - .map_err(|_| StatusCode::UNAUTHORIZED)?; + let token_data = decode::(bearer_token, &decoding_key, &Validation::default()) + .map_err(|_| StatusCode::UNAUTHORIZED)?; let claims = token_data.claims; @@ -69,3 +85,13 @@ async fn check_auth(bearer: Bearer) -> SirenResult { _ => Err(StatusCode::UNAUTHORIZED)?, } } + +async fn check_api_key_auth(key: &str) -> SirenResult { + + let api_key = match ApiKey::find_by_key(key).await? { + Some(api_key) => api_key, + None => return Err(StatusCode::UNAUTHORIZED.into()), + }; + + Ok(api_key) +} diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs index a5438f2..c5ff9b1 100644 --- a/src/api/auth/mod.rs +++ b/src/api/auth/mod.rs @@ -3,6 +3,7 @@ use axum::Router; use rand::Rng; use rand_chacha::ChaCha20Rng; use rand_chacha::rand_core::SeedableRng; +use serde::{Deserialize, Serialize}; use crate::AppState; mod oauth; @@ -12,6 +13,29 @@ mod api_key; mod bearer_token; mod middleware; pub use middleware::AuthorizationMiddleware; +use crate::api::auth::api_key::ApiKey; + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub enum AuthCredential { + Session(Session), + ApiKey(ApiKey), +} + +impl AuthCredential { + pub fn user_id(&self) -> u64 { + match self { + AuthCredential::Session(session) => session.user_id, + AuthCredential::ApiKey(api_key) => api_key.user_id as u64, + } + } + + pub fn user_name(&self) -> String { + match self { + AuthCredential::Session(session) => session.user_name.clone(), + AuthCredential::ApiKey(api_key) => api_key.user_name.clone(), + } + } +} pub fn get_routes() -> Router> { Router::new() diff --git a/src/api/mod.rs b/src/api/mod.rs index e8ab5a7..2992c01 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -11,5 +11,5 @@ mod auth; pub fn get_routes() -> Router> { Router::new() .merge(auth::get_routes()) - .nest("/audio", audio::get_routes()) + .nest("/audio/:guild_id", audio::get_routes()) } diff --git a/src/bot/commands/audio/play.rs b/src/bot/commands/audio/play.rs index 7f90348..ea38582 100644 --- a/src/bot/commands/audio/play.rs +++ b/src/bot/commands/audio/play.rs @@ -89,7 +89,9 @@ pub async fn enqueue_track( let mut playlist_items: Vec = Vec::new(); if let Some(handler_lock) = manager.get(guild_id) { let mut handler = handler_lock.lock().await; - let guild = GuildCache::get_by_id(guild_id.get() as i64).await?.unwrap(); + let guild = GuildCache::find_by_id(guild_id.get() as i64) + .await? + .unwrap(); let valid = is_valid_url(&track_url); // Check if the URL is valid diff --git a/src/bot/commands/audio/volume.rs b/src/bot/commands/audio/volume.rs index 5e0c008..98869e8 100644 --- a/src/bot/commands/audio/volume.rs +++ b/src/bot/commands/audio/volume.rs @@ -64,7 +64,7 @@ pub async fn set_volume(manager: &Arc, guild_id: &GuildId, volume: i32 let bound_volume = volume as f32 / 100.0; // Update the guild cache - let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64) + let mut guild_cache = GuildCache::find_by_id(guild_id.get() as i64) .await .unwrap() .unwrap(); diff --git a/src/bot/handler.rs b/src/bot/handler.rs index e514b77..6832739 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -108,7 +108,7 @@ impl EventHandler for BotHandler { for guild in ready.guilds { // Check if guild exists in database let guild_id = guild.id.get() as i64; - if let None = GuildCache::get_by_id(guild_id).await.unwrap() { + if let None = GuildCache::find_by_id(guild_id).await.unwrap() { let guild_cache = GuildCache { id: guild_id, name: guild.id.name(&ctx.cache), diff --git a/src/data/guilds/model.rs b/src/data/guilds/model.rs index 145607c..84f94cb 100644 --- a/src/data/guilds/model.rs +++ b/src/data/guilds/model.rs @@ -4,7 +4,7 @@ use crate::error::SirenResult; const TABLE_NAME: &str = "guilds"; -#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)] +#[derive(Serialize, Deserialize, sqlx::FromRow, Debug)] pub struct GuildCache { pub id: i64, pub name: Option, @@ -35,7 +35,7 @@ impl GuildCache { Ok(()) } - pub async fn get_by_id(id: i64) -> SirenResult> { + pub async fn find_by_id(id: i64) -> SirenResult> { let pool = crate::data::pool(); let query = QueryBuilder::new(TABLE_NAME) .where_condition(Condition::is_equal("id", "$1")) // Use a placeholder diff --git a/src/data/mod.rs b/src/data/mod.rs index c90e036..15ae33a 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -7,7 +7,7 @@ use crate::error::SirenResult; pub mod events; pub mod guilds; pub mod messages; -mod query; +pub mod query; static POOL: OnceLock> = OnceLock::new(); static REDIS: OnceLock = OnceLock::new();