Implemented API Key creation/usage and changed layout of audio requests

This commit is contained in:
2024-12-21 12:03:02 -05:00
parent 4d9ee81ecf
commit ceea975836
16 changed files with 186 additions and 76 deletions

View File

@@ -5,13 +5,7 @@ meta {
}
post {
url: {{baseUrl}}/audio/pause
url: {{baseUrl}}/audio/1061092965579235398/pause
body: json
auth: inherit
}
body:json {
{
"guild_id": 1061092965579235398
}
}

View File

@@ -5,14 +5,13 @@ meta {
}
post {
url: {{baseUrl}}/audio/play
url: {{baseUrl}}/audio/1061092965579235398/play
body: json
auth: inherit
}
body:json {
{
"url": "https://www.youtube.com/watch?v=V-QDxuknK-Q",
"guild_id": 1061092965579235398
"url": "https://www.youtube.com/watch?v=V-QDxuknK-Q"
}
}

View File

@@ -5,13 +5,7 @@ meta {
}
post {
url: {{baseUrl}}/audio/resume
url: {{baseUrl}}/audio/1061092965579235398/resume
body: json
auth: inherit
}
body:json {
{
"guild_id": 1061092965579235398
}
}

View File

@@ -1,11 +1,12 @@
auth {
mode: bearer
mode: apikey
}
auth:bearer {
token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjI1MDg0MjI2MTIyMTI3NzY5NywibmFtZSI6ImJzaGVycmlmZiIsImlhdCI6MTczNDcwNDI3NSwiZXhwIjoxNzM0NzkwNjc1LCJqdGkiOiJMSnc1Vnk3azZjc1BiYlJRWGlNcVFFVUZlQ29JS2JqcCJ9.sdgb93DmX9_augMdktYr58m5eTIJPuY13d87pckZOns
auth:apikey {
key: X-API-Key
value: rwOS4yMmNpQvL0vLHc1jWQoefJB1bvKvOvBSswiYh0mkhZDc1lsgFZmpXaSUXAa5ZjpRWR117hLQ1l0VPPSGkRXZl7dPRVCc
placement: header
}
vars:pre-request {
baseUrl: http://localhost:3000/api
}

View File

@@ -0,0 +1,11 @@
meta {
name: Create API Key
type: http
seq: 2
}
post {
url: {{baseUrl}}/api-key
body: none
auth: inherit
}

View File

@@ -16,6 +16,12 @@ CREATE TABLE IF NOT EXISTS messages (
request_tags TEXT[] NOT NULL,
response_tags TEXT[] NOT NULL
);
CREATE TABLE IF NOT EXISTS api_keys (
key TEXT PRIMARY KEY NOT NULL,
user_id BIGINT NOT NULL,
user_name TEXT NOT NULL,
access_mask INT
);
CREATE TABLE IF NOT EXISTS dice_thresholds (
id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(),
owner_id BIGINT NOT NULL,

View File

@@ -1,12 +1,12 @@
use std::sync::Arc;
use axum::extract::State;
use axum::extract::{Path, State};
use axum::middleware::from_extractor;
use axum::{Extension, Json, Router};
use axum::response::IntoResponse;
use axum::routing::post;
use reqwest::StatusCode;
use serde::Deserialize;
use crate::api::auth::{AuthorizationMiddleware, Session};
use crate::api::auth::{AuthCredential, AuthorizationMiddleware, Session};
use crate::AppState;
use crate::bot::commands::audio::join_voice_channel;
use crate::bot::commands::audio::pause::pause_track;
@@ -28,24 +28,25 @@ pub fn get_routes() -> Router<Arc<AppState>> {
#[derive(Deserialize)]
struct PlayTrackRequest {
url: String,
guild_id: u64,
}
async fn play_audio(
Extension(session): Extension<Session>,
Extension(credential): Extension<AuthCredential>,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
Json(payload): Json<PlayTrackRequest>,
) -> SirenResult<()> {
log::debug!("Playing audio in guild: {}", payload.guild_id);
log::debug!("Playing audio in guild: {}", guild_id);
// Check if the user exists in the cache
let user_id = match state.cache.user(session.user_id) {
let user_id = credential.user_id();
let user_id = match state.cache.user(user_id) {
Some(user) => user.id,
None => return Err(Error::not_found("User not found".to_string())),
};
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(payload.guild_id) {
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
@@ -57,20 +58,15 @@ async fn play_audio(
Ok(())
}
#[derive(Deserialize)]
struct GuildTrackRequest {
guild_id: u64,
}
async fn pause_audio(
Extension(_): Extension<Session>,
Extension(_): Extension<AuthCredential>,
State(state): State<Arc<AppState>>,
Json(payload): Json<GuildTrackRequest>,
Path(guild_id): Path<u64>,
) -> SirenResult<()> {
log::debug!("Pausing audio in guild: {}", payload.guild_id);
log::debug!("Pausing audio in guild: {}", guild_id);
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(payload.guild_id) {
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
@@ -81,14 +77,14 @@ async fn pause_audio(
}
async fn resume_audio(
Extension(_): Extension<Session>,
Extension(_): Extension<AuthCredential>,
State(state): State<Arc<AppState>>,
Json(payload): Json<GuildTrackRequest>,
Path(guild_id): Path<u64>,
) -> SirenResult<()> {
log::debug!("Pausing audio in guild: {}", payload.guild_id);
log::debug!("Pausing audio in guild: {}", guild_id);
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(payload.guild_id) {
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};

View File

@@ -2,11 +2,14 @@ use std::sync::Arc;
use axum::{Extension, Router};
use axum::middleware::from_extractor;
use axum::routing::post;
use crate::api::auth::csprng;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use crate::api::auth::{csprng, AuthCredential};
use crate::api::auth::AuthorizationMiddleware;
use crate::api::auth::session::Session;
use crate::AppState;
use crate::error::SirenResult;
use crate::data::query::{Condition, QueryBuilder};
use crate::error::{Error, SirenResult};
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
@@ -14,28 +17,82 @@ pub fn get_routes() -> Router<Arc<AppState>> {
.route_layer(from_extractor::<AuthorizationMiddleware>())
}
struct ApiKey {
const TABLE_NAME: &str = "api_keys";
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
pub struct ApiKey {
pub key: String,
pub user_id: u64,
pub access_mask: u32,
pub user_id: i64,
pub user_name: String,
pub access_mask: i32,
}
impl ApiKey {
fn new(user_id: u64, access_mask: u32) -> Self {
fn new(user_id: u64, user_name: String, access_mask: i32) -> Self {
ApiKey {
key: csprng(64),
user_id,
key: csprng(96),
user_id: user_id as i64,
user_name,
access_mask,
}
}
pub async fn insert(&self) -> SirenResult<()> {
let pool = crate::data::pool();
sqlx::query(&format!(
"INSERT INTO {} (
key,
user_id,
user_name,
access_mask
) VALUES (
$1, $2, $3, $4
)",
TABLE_NAME
))
.bind(&self.key)
.bind(self.user_id)
.bind(&self.user_name)
.bind(self.access_mask)
.execute(pool)
.await?;
Ok(())
}
pub async fn find_by_key(key: &str) -> SirenResult<Option<Self>> {
let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::is_equal("key", "$1"))
.build();
let item = sqlx::query_as(&query)
.bind(key)
.fetch_optional(pool)
.await?;
Ok(item)
}
pub async fn delete_by_id(key: &str) -> SirenResult<()> {
let pool = crate::data::pool();
sqlx::query(&format!("DELETE FROM {} WHERE key = $1", TABLE_NAME))
.bind(key)
.execute(pool)
.await?;
Ok(())
}
}
async fn create_api_key(Extension(session): Extension<Session>) -> SirenResult<String> {
async fn create_api_key(Extension(credential): Extension<AuthCredential>) -> SirenResult<String> {
let session = match credential {
AuthCredential::ApiKey(_) => return Err(Error::new(400, "API keys cannot be generated with an API key".to_string())),
AuthCredential::Session(session) => session
};
log::debug!(
"Generating API key for {} ({})",
&session.user_id,
&session.user_name
);
let api_key = ApiKey::new(session.user_id, 0);
let api_key = ApiKey::new(session.user_id, session.user_name, 0);
api_key.insert().await?;
Ok(api_key.key)
}

View File

@@ -8,6 +8,8 @@ use axum_extra::{
};
use chrono::Utc;
use jsonwebtoken::{decode, DecodingKey, Validation};
use crate::api::auth::api_key::ApiKey;
use crate::api::auth::AuthCredential;
use crate::api::auth::bearer_token::BearerTokenClaims;
use crate::api::auth::session::Session;
use crate::error::SirenResult;
@@ -27,32 +29,46 @@ where
return Ok(Self);
}
let Ok(TypedHeader(Authorization(bearer))) =
// Check for a Bearer token in the `Authorization` header.
if let Ok(TypedHeader(Authorization(bearer))) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
else {
return Err(StatusCode::UNAUTHORIZED);
};
match check_auth(bearer).await {
{
return match check_bearer_auth(bearer.token()).await {
Ok(session) => {
parts.extensions.insert(session);
parts.extensions.insert(AuthCredential::Session(session));
Ok(Self)
}
Err(err) => {
log::error!("{:?}", err);
Err(_) => Err(StatusCode::UNAUTHORIZED),
};
}
// Check for an API key in the custom `X-API-Key` header.
if let Some(api_key_header) = parts.headers.get("X-API-Key") {
return if let Ok(api_key) = api_key_header.to_str() {
match check_api_key_auth(api_key).await {
Ok(api_key) => {
parts.extensions.insert(AuthCredential::ApiKey(api_key));
Ok(Self)
}
Err(_) => Err(StatusCode::UNAUTHORIZED),
}
} else {
// Invalid header value
Err(StatusCode::BAD_REQUEST)
};
}
// If neither the Bearer token nor API key is present or valid, return `UNAUTHORIZED`
Err(StatusCode::UNAUTHORIZED)
}
}
}
}
async fn check_auth(bearer: Bearer) -> SirenResult<Session> {
async fn check_bearer_auth(bearer_token: &str) -> SirenResult<Session> {
// Decode and validate the JWT
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
let token_data =
decode::<BearerTokenClaims>(bearer.token(), &decoding_key, &Validation::default())
let token_data = decode::<BearerTokenClaims>(bearer_token, &decoding_key, &Validation::default())
.map_err(|_| StatusCode::UNAUTHORIZED)?;
let claims = token_data.claims;
@@ -69,3 +85,13 @@ async fn check_auth(bearer: Bearer) -> SirenResult<Session> {
_ => Err(StatusCode::UNAUTHORIZED)?,
}
}
async fn check_api_key_auth(key: &str) -> SirenResult<ApiKey> {
let api_key = match ApiKey::find_by_key(key).await? {
Some(api_key) => api_key,
None => return Err(StatusCode::UNAUTHORIZED.into()),
};
Ok(api_key)
}

View File

@@ -3,6 +3,7 @@ use axum::Router;
use rand::Rng;
use rand_chacha::ChaCha20Rng;
use rand_chacha::rand_core::SeedableRng;
use serde::{Deserialize, Serialize};
use crate::AppState;
mod oauth;
@@ -12,6 +13,29 @@ mod api_key;
mod bearer_token;
mod middleware;
pub use middleware::AuthorizationMiddleware;
use crate::api::auth::api_key::ApiKey;
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum AuthCredential {
Session(Session),
ApiKey(ApiKey),
}
impl AuthCredential {
pub fn user_id(&self) -> u64 {
match self {
AuthCredential::Session(session) => session.user_id,
AuthCredential::ApiKey(api_key) => api_key.user_id as u64,
}
}
pub fn user_name(&self) -> String {
match self {
AuthCredential::Session(session) => session.user_name.clone(),
AuthCredential::ApiKey(api_key) => api_key.user_name.clone(),
}
}
}
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()

View File

@@ -11,5 +11,5 @@ mod auth;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.merge(auth::get_routes())
.nest("/audio", audio::get_routes())
.nest("/audio/:guild_id", audio::get_routes())
}

View File

@@ -89,7 +89,9 @@ pub async fn enqueue_track(
let mut playlist_items: Vec<YtDlpItem> = Vec::new();
if let Some(handler_lock) = manager.get(guild_id) {
let mut handler = handler_lock.lock().await;
let guild = GuildCache::get_by_id(guild_id.get() as i64).await?.unwrap();
let guild = GuildCache::find_by_id(guild_id.get() as i64)
.await?
.unwrap();
let valid = is_valid_url(&track_url);
// Check if the URL is valid

View File

@@ -64,7 +64,7 @@ pub async fn set_volume(manager: &Arc<Songbird>, guild_id: &GuildId, volume: i32
let bound_volume = volume as f32 / 100.0;
// Update the guild cache
let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64)
let mut guild_cache = GuildCache::find_by_id(guild_id.get() as i64)
.await
.unwrap()
.unwrap();

View File

@@ -108,7 +108,7 @@ impl EventHandler for BotHandler {
for guild in ready.guilds {
// Check if guild exists in database
let guild_id = guild.id.get() as i64;
if let None = GuildCache::get_by_id(guild_id).await.unwrap() {
if let None = GuildCache::find_by_id(guild_id).await.unwrap() {
let guild_cache = GuildCache {
id: guild_id,
name: guild.id.name(&ctx.cache),

View File

@@ -4,7 +4,7 @@ use crate::error::SirenResult;
const TABLE_NAME: &str = "guilds";
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
#[derive(Serialize, Deserialize, sqlx::FromRow, Debug)]
pub struct GuildCache {
pub id: i64,
pub name: Option<String>,
@@ -35,7 +35,7 @@ impl GuildCache {
Ok(())
}
pub async fn get_by_id(id: i64) -> SirenResult<Option<Self>> {
pub async fn find_by_id(id: i64) -> SirenResult<Option<Self>> {
let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::is_equal("id", "$1")) // Use a placeholder

View File

@@ -7,7 +7,7 @@ use crate::error::SirenResult;
pub mod events;
pub mod guilds;
pub mod messages;
mod query;
pub mod query;
static POOL: OnceLock<Pool<Postgres>> = OnceLock::new();
static REDIS: OnceLock<RedisClient> = OnceLock::new();