Implemented API Key creation/usage and changed layout of audio requests

This commit is contained in:
2024-12-21 12:03:02 -05:00
parent 4d9ee81ecf
commit ceea975836
16 changed files with 186 additions and 76 deletions

View File

@@ -5,13 +5,7 @@ meta {
} }
post { post {
url: {{baseUrl}}/audio/pause url: {{baseUrl}}/audio/1061092965579235398/pause
body: json body: json
auth: inherit auth: inherit
} }
body:json {
{
"guild_id": 1061092965579235398
}
}

View File

@@ -5,14 +5,13 @@ meta {
} }
post { post {
url: {{baseUrl}}/audio/play url: {{baseUrl}}/audio/1061092965579235398/play
body: json body: json
auth: inherit auth: inherit
} }
body:json { body:json {
{ {
"url": "https://www.youtube.com/watch?v=V-QDxuknK-Q", "url": "https://www.youtube.com/watch?v=V-QDxuknK-Q"
"guild_id": 1061092965579235398
} }
} }

View File

@@ -5,13 +5,7 @@ meta {
} }
post { post {
url: {{baseUrl}}/audio/resume url: {{baseUrl}}/audio/1061092965579235398/resume
body: json body: json
auth: inherit auth: inherit
} }
body:json {
{
"guild_id": 1061092965579235398
}
}

View File

@@ -1,11 +1,12 @@
auth { auth {
mode: bearer mode: apikey
} }
auth:bearer { auth:apikey {
token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjI1MDg0MjI2MTIyMTI3NzY5NywibmFtZSI6ImJzaGVycmlmZiIsImlhdCI6MTczNDcwNDI3NSwiZXhwIjoxNzM0NzkwNjc1LCJqdGkiOiJMSnc1Vnk3azZjc1BiYlJRWGlNcVFFVUZlQ29JS2JqcCJ9.sdgb93DmX9_augMdktYr58m5eTIJPuY13d87pckZOns key: X-API-Key
value: rwOS4yMmNpQvL0vLHc1jWQoefJB1bvKvOvBSswiYh0mkhZDc1lsgFZmpXaSUXAa5ZjpRWR117hLQ1l0VPPSGkRXZl7dPRVCc
placement: header
} }
vars:pre-request { vars:pre-request {
baseUrl: http://localhost:3000/api baseUrl: http://localhost:3000/api
} }

View File

@@ -0,0 +1,11 @@
meta {
name: Create API Key
type: http
seq: 2
}
post {
url: {{baseUrl}}/api-key
body: none
auth: inherit
}

View File

@@ -16,6 +16,12 @@ CREATE TABLE IF NOT EXISTS messages (
request_tags TEXT[] NOT NULL, request_tags TEXT[] NOT NULL,
response_tags TEXT[] NOT NULL response_tags TEXT[] NOT NULL
); );
CREATE TABLE IF NOT EXISTS api_keys (
key TEXT PRIMARY KEY NOT NULL,
user_id BIGINT NOT NULL,
user_name TEXT NOT NULL,
access_mask INT
);
CREATE TABLE IF NOT EXISTS dice_thresholds ( CREATE TABLE IF NOT EXISTS dice_thresholds (
id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(), id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(),
owner_id BIGINT NOT NULL, owner_id BIGINT NOT NULL,

View File

@@ -1,12 +1,12 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::State; use axum::extract::{Path, State};
use axum::middleware::from_extractor; use axum::middleware::from_extractor;
use axum::{Extension, Json, Router}; use axum::{Extension, Json, Router};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::post; use axum::routing::post;
use reqwest::StatusCode; use reqwest::StatusCode;
use serde::Deserialize; use serde::Deserialize;
use crate::api::auth::{AuthorizationMiddleware, Session}; use crate::api::auth::{AuthCredential, AuthorizationMiddleware, Session};
use crate::AppState; use crate::AppState;
use crate::bot::commands::audio::join_voice_channel; use crate::bot::commands::audio::join_voice_channel;
use crate::bot::commands::audio::pause::pause_track; use crate::bot::commands::audio::pause::pause_track;
@@ -28,24 +28,25 @@ pub fn get_routes() -> Router<Arc<AppState>> {
#[derive(Deserialize)] #[derive(Deserialize)]
struct PlayTrackRequest { struct PlayTrackRequest {
url: String, url: String,
guild_id: u64,
} }
async fn play_audio( async fn play_audio(
Extension(session): Extension<Session>, Extension(credential): Extension<AuthCredential>,
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
Json(payload): Json<PlayTrackRequest>, Json(payload): Json<PlayTrackRequest>,
) -> SirenResult<()> { ) -> SirenResult<()> {
log::debug!("Playing audio in guild: {}", payload.guild_id); log::debug!("Playing audio in guild: {}", guild_id);
// Check if the user exists in the cache // Check if the user exists in the cache
let user_id = match state.cache.user(session.user_id) { let user_id = credential.user_id();
let user_id = match state.cache.user(user_id) {
Some(user) => user.id, Some(user) => user.id,
None => return Err(Error::not_found("User not found".to_string())), None => return Err(Error::not_found("User not found".to_string())),
}; };
// Validate if the guild exists in the cache // Validate if the guild exists in the cache
let guild_id = match state.cache.guild(payload.guild_id) { let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id, Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())), None => return Err(Error::not_found("Guild not found".to_string())),
}; };
@@ -57,20 +58,15 @@ async fn play_audio(
Ok(()) Ok(())
} }
#[derive(Deserialize)]
struct GuildTrackRequest {
guild_id: u64,
}
async fn pause_audio( async fn pause_audio(
Extension(_): Extension<Session>, Extension(_): Extension<AuthCredential>,
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(payload): Json<GuildTrackRequest>, Path(guild_id): Path<u64>,
) -> SirenResult<()> { ) -> SirenResult<()> {
log::debug!("Pausing audio in guild: {}", payload.guild_id); log::debug!("Pausing audio in guild: {}", guild_id);
// Validate if the guild exists in the cache // Validate if the guild exists in the cache
let guild_id = match state.cache.guild(payload.guild_id) { let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id, Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())), None => return Err(Error::not_found("Guild not found".to_string())),
}; };
@@ -81,14 +77,14 @@ async fn pause_audio(
} }
async fn resume_audio( async fn resume_audio(
Extension(_): Extension<Session>, Extension(_): Extension<AuthCredential>,
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(payload): Json<GuildTrackRequest>, Path(guild_id): Path<u64>,
) -> SirenResult<()> { ) -> SirenResult<()> {
log::debug!("Pausing audio in guild: {}", payload.guild_id); log::debug!("Pausing audio in guild: {}", guild_id);
// Validate if the guild exists in the cache // Validate if the guild exists in the cache
let guild_id = match state.cache.guild(payload.guild_id) { let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id, Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())), None => return Err(Error::not_found("Guild not found".to_string())),
}; };

View File

@@ -2,11 +2,14 @@ use std::sync::Arc;
use axum::{Extension, Router}; use axum::{Extension, Router};
use axum::middleware::from_extractor; use axum::middleware::from_extractor;
use axum::routing::post; use axum::routing::post;
use crate::api::auth::csprng; use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use crate::api::auth::{csprng, AuthCredential};
use crate::api::auth::AuthorizationMiddleware; use crate::api::auth::AuthorizationMiddleware;
use crate::api::auth::session::Session; use crate::api::auth::session::Session;
use crate::AppState; use crate::AppState;
use crate::error::SirenResult; use crate::data::query::{Condition, QueryBuilder};
use crate::error::{Error, SirenResult};
pub fn get_routes() -> Router<Arc<AppState>> { pub fn get_routes() -> Router<Arc<AppState>> {
Router::new() Router::new()
@@ -14,28 +17,82 @@ pub fn get_routes() -> Router<Arc<AppState>> {
.route_layer(from_extractor::<AuthorizationMiddleware>()) .route_layer(from_extractor::<AuthorizationMiddleware>())
} }
struct ApiKey { const TABLE_NAME: &str = "api_keys";
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
pub struct ApiKey {
pub key: String, pub key: String,
pub user_id: u64, pub user_id: i64,
pub access_mask: u32, pub user_name: String,
pub access_mask: i32,
} }
impl ApiKey { impl ApiKey {
fn new(user_id: u64, access_mask: u32) -> Self { fn new(user_id: u64, user_name: String, access_mask: i32) -> Self {
ApiKey { ApiKey {
key: csprng(64), key: csprng(96),
user_id, user_id: user_id as i64,
user_name,
access_mask, access_mask,
} }
} }
pub async fn insert(&self) -> SirenResult<()> {
let pool = crate::data::pool();
sqlx::query(&format!(
"INSERT INTO {} (
key,
user_id,
user_name,
access_mask
) VALUES (
$1, $2, $3, $4
)",
TABLE_NAME
))
.bind(&self.key)
.bind(self.user_id)
.bind(&self.user_name)
.bind(self.access_mask)
.execute(pool)
.await?;
Ok(())
}
pub async fn find_by_key(key: &str) -> SirenResult<Option<Self>> {
let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::is_equal("key", "$1"))
.build();
let item = sqlx::query_as(&query)
.bind(key)
.fetch_optional(pool)
.await?;
Ok(item)
}
pub async fn delete_by_id(key: &str) -> SirenResult<()> {
let pool = crate::data::pool();
sqlx::query(&format!("DELETE FROM {} WHERE key = $1", TABLE_NAME))
.bind(key)
.execute(pool)
.await?;
Ok(())
}
} }
async fn create_api_key(Extension(session): Extension<Session>) -> SirenResult<String> { async fn create_api_key(Extension(credential): Extension<AuthCredential>) -> SirenResult<String> {
let session = match credential {
AuthCredential::ApiKey(_) => return Err(Error::new(400, "API keys cannot be generated with an API key".to_string())),
AuthCredential::Session(session) => session
};
log::debug!( log::debug!(
"Generating API key for {} ({})", "Generating API key for {} ({})",
&session.user_id, &session.user_id,
&session.user_name &session.user_name
); );
let api_key = ApiKey::new(session.user_id, 0); let api_key = ApiKey::new(session.user_id, session.user_name, 0);
api_key.insert().await?;
Ok(api_key.key) Ok(api_key.key)
} }

View File

@@ -8,6 +8,8 @@ use axum_extra::{
}; };
use chrono::Utc; use chrono::Utc;
use jsonwebtoken::{decode, DecodingKey, Validation}; use jsonwebtoken::{decode, DecodingKey, Validation};
use crate::api::auth::api_key::ApiKey;
use crate::api::auth::AuthCredential;
use crate::api::auth::bearer_token::BearerTokenClaims; use crate::api::auth::bearer_token::BearerTokenClaims;
use crate::api::auth::session::Session; use crate::api::auth::session::Session;
use crate::error::SirenResult; use crate::error::SirenResult;
@@ -27,32 +29,46 @@ where
return Ok(Self); return Ok(Self);
} }
let Ok(TypedHeader(Authorization(bearer))) = // Check for a Bearer token in the `Authorization` header.
if let Ok(TypedHeader(Authorization(bearer))) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
else { {
return Err(StatusCode::UNAUTHORIZED); return match check_bearer_auth(bearer.token()).await {
};
match check_auth(bearer).await {
Ok(session) => { Ok(session) => {
parts.extensions.insert(session); parts.extensions.insert(AuthCredential::Session(session));
Ok(Self) Ok(Self)
} }
Err(err) => { Err(_) => Err(StatusCode::UNAUTHORIZED),
log::error!("{:?}", err); };
}
// Check for an API key in the custom `X-API-Key` header.
if let Some(api_key_header) = parts.headers.get("X-API-Key") {
return if let Ok(api_key) = api_key_header.to_str() {
match check_api_key_auth(api_key).await {
Ok(api_key) => {
parts.extensions.insert(AuthCredential::ApiKey(api_key));
Ok(Self)
}
Err(_) => Err(StatusCode::UNAUTHORIZED),
}
} else {
// Invalid header value
Err(StatusCode::BAD_REQUEST)
};
}
// If neither the Bearer token nor API key is present or valid, return `UNAUTHORIZED`
Err(StatusCode::UNAUTHORIZED) Err(StatusCode::UNAUTHORIZED)
} }
}
}
} }
async fn check_auth(bearer: Bearer) -> SirenResult<Session> { async fn check_bearer_auth(bearer_token: &str) -> SirenResult<Session> {
// Decode and validate the JWT // Decode and validate the JWT
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment"); let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes()); let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
let token_data = let token_data = decode::<BearerTokenClaims>(bearer_token, &decoding_key, &Validation::default())
decode::<BearerTokenClaims>(bearer.token(), &decoding_key, &Validation::default())
.map_err(|_| StatusCode::UNAUTHORIZED)?; .map_err(|_| StatusCode::UNAUTHORIZED)?;
let claims = token_data.claims; let claims = token_data.claims;
@@ -69,3 +85,13 @@ async fn check_auth(bearer: Bearer) -> SirenResult<Session> {
_ => Err(StatusCode::UNAUTHORIZED)?, _ => Err(StatusCode::UNAUTHORIZED)?,
} }
} }
async fn check_api_key_auth(key: &str) -> SirenResult<ApiKey> {
let api_key = match ApiKey::find_by_key(key).await? {
Some(api_key) => api_key,
None => return Err(StatusCode::UNAUTHORIZED.into()),
};
Ok(api_key)
}

View File

@@ -3,6 +3,7 @@ use axum::Router;
use rand::Rng; use rand::Rng;
use rand_chacha::ChaCha20Rng; use rand_chacha::ChaCha20Rng;
use rand_chacha::rand_core::SeedableRng; use rand_chacha::rand_core::SeedableRng;
use serde::{Deserialize, Serialize};
use crate::AppState; use crate::AppState;
mod oauth; mod oauth;
@@ -12,6 +13,29 @@ mod api_key;
mod bearer_token; mod bearer_token;
mod middleware; mod middleware;
pub use middleware::AuthorizationMiddleware; pub use middleware::AuthorizationMiddleware;
use crate::api::auth::api_key::ApiKey;
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum AuthCredential {
Session(Session),
ApiKey(ApiKey),
}
impl AuthCredential {
pub fn user_id(&self) -> u64 {
match self {
AuthCredential::Session(session) => session.user_id,
AuthCredential::ApiKey(api_key) => api_key.user_id as u64,
}
}
pub fn user_name(&self) -> String {
match self {
AuthCredential::Session(session) => session.user_name.clone(),
AuthCredential::ApiKey(api_key) => api_key.user_name.clone(),
}
}
}
pub fn get_routes() -> Router<Arc<AppState>> { pub fn get_routes() -> Router<Arc<AppState>> {
Router::new() Router::new()

View File

@@ -11,5 +11,5 @@ mod auth;
pub fn get_routes() -> Router<Arc<AppState>> { pub fn get_routes() -> Router<Arc<AppState>> {
Router::new() Router::new()
.merge(auth::get_routes()) .merge(auth::get_routes())
.nest("/audio", audio::get_routes()) .nest("/audio/:guild_id", audio::get_routes())
} }

View File

@@ -89,7 +89,9 @@ pub async fn enqueue_track(
let mut playlist_items: Vec<YtDlpItem> = Vec::new(); let mut playlist_items: Vec<YtDlpItem> = Vec::new();
if let Some(handler_lock) = manager.get(guild_id) { if let Some(handler_lock) = manager.get(guild_id) {
let mut handler = handler_lock.lock().await; let mut handler = handler_lock.lock().await;
let guild = GuildCache::get_by_id(guild_id.get() as i64).await?.unwrap(); let guild = GuildCache::find_by_id(guild_id.get() as i64)
.await?
.unwrap();
let valid = is_valid_url(&track_url); let valid = is_valid_url(&track_url);
// Check if the URL is valid // Check if the URL is valid

View File

@@ -64,7 +64,7 @@ pub async fn set_volume(manager: &Arc<Songbird>, guild_id: &GuildId, volume: i32
let bound_volume = volume as f32 / 100.0; let bound_volume = volume as f32 / 100.0;
// Update the guild cache // Update the guild cache
let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64) let mut guild_cache = GuildCache::find_by_id(guild_id.get() as i64)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();

View File

@@ -108,7 +108,7 @@ impl EventHandler for BotHandler {
for guild in ready.guilds { for guild in ready.guilds {
// Check if guild exists in database // Check if guild exists in database
let guild_id = guild.id.get() as i64; let guild_id = guild.id.get() as i64;
if let None = GuildCache::get_by_id(guild_id).await.unwrap() { if let None = GuildCache::find_by_id(guild_id).await.unwrap() {
let guild_cache = GuildCache { let guild_cache = GuildCache {
id: guild_id, id: guild_id,
name: guild.id.name(&ctx.cache), name: guild.id.name(&ctx.cache),

View File

@@ -4,7 +4,7 @@ use crate::error::SirenResult;
const TABLE_NAME: &str = "guilds"; const TABLE_NAME: &str = "guilds";
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)] #[derive(Serialize, Deserialize, sqlx::FromRow, Debug)]
pub struct GuildCache { pub struct GuildCache {
pub id: i64, pub id: i64,
pub name: Option<String>, pub name: Option<String>,
@@ -35,7 +35,7 @@ impl GuildCache {
Ok(()) Ok(())
} }
pub async fn get_by_id(id: i64) -> SirenResult<Option<Self>> { pub async fn find_by_id(id: i64) -> SirenResult<Option<Self>> {
let pool = crate::data::pool(); let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME) let query = QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::is_equal("id", "$1")) // Use a placeholder .where_condition(Condition::is_equal("id", "$1")) // Use a placeholder

View File

@@ -7,7 +7,7 @@ use crate::error::SirenResult;
pub mod events; pub mod events;
pub mod guilds; pub mod guilds;
pub mod messages; pub mod messages;
mod query; pub mod query;
static POOL: OnceLock<Pool<Postgres>> = OnceLock::new(); static POOL: OnceLock<Pool<Postgres>> = OnceLock::new();
static REDIS: OnceLock<RedisClient> = OnceLock::new(); static REDIS: OnceLock<RedisClient> = OnceLock::new();