Implemented API Key creation/usage and changed layout of audio requests
This commit is contained in:
@@ -5,13 +5,7 @@ meta {
|
||||
}
|
||||
|
||||
post {
|
||||
url: {{baseUrl}}/audio/pause
|
||||
url: {{baseUrl}}/audio/1061092965579235398/pause
|
||||
body: json
|
||||
auth: inherit
|
||||
}
|
||||
|
||||
body:json {
|
||||
{
|
||||
"guild_id": 1061092965579235398
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,14 +5,13 @@ meta {
|
||||
}
|
||||
|
||||
post {
|
||||
url: {{baseUrl}}/audio/play
|
||||
url: {{baseUrl}}/audio/1061092965579235398/play
|
||||
body: json
|
||||
auth: inherit
|
||||
}
|
||||
|
||||
body:json {
|
||||
{
|
||||
"url": "https://www.youtube.com/watch?v=V-QDxuknK-Q",
|
||||
"guild_id": 1061092965579235398
|
||||
"url": "https://www.youtube.com/watch?v=V-QDxuknK-Q"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,13 +5,7 @@ meta {
|
||||
}
|
||||
|
||||
post {
|
||||
url: {{baseUrl}}/audio/resume
|
||||
url: {{baseUrl}}/audio/1061092965579235398/resume
|
||||
body: json
|
||||
auth: inherit
|
||||
}
|
||||
|
||||
body:json {
|
||||
{
|
||||
"guild_id": 1061092965579235398
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
auth {
|
||||
mode: bearer
|
||||
mode: apikey
|
||||
}
|
||||
|
||||
auth:bearer {
|
||||
token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjI1MDg0MjI2MTIyMTI3NzY5NywibmFtZSI6ImJzaGVycmlmZiIsImlhdCI6MTczNDcwNDI3NSwiZXhwIjoxNzM0NzkwNjc1LCJqdGkiOiJMSnc1Vnk3azZjc1BiYlJRWGlNcVFFVUZlQ29JS2JqcCJ9.sdgb93DmX9_augMdktYr58m5eTIJPuY13d87pckZOns
|
||||
auth:apikey {
|
||||
key: X-API-Key
|
||||
value: rwOS4yMmNpQvL0vLHc1jWQoefJB1bvKvOvBSswiYh0mkhZDc1lsgFZmpXaSUXAa5ZjpRWR117hLQ1l0VPPSGkRXZl7dPRVCc
|
||||
placement: header
|
||||
}
|
||||
|
||||
vars:pre-request {
|
||||
baseUrl: http://localhost:3000/api
|
||||
}
|
||||
|
||||
11
bruno/oauth/Create API Key.bru
Normal file
11
bruno/oauth/Create API Key.bru
Normal file
@@ -0,0 +1,11 @@
|
||||
meta {
|
||||
name: Create API Key
|
||||
type: http
|
||||
seq: 2
|
||||
}
|
||||
|
||||
post {
|
||||
url: {{baseUrl}}/api-key
|
||||
body: none
|
||||
auth: inherit
|
||||
}
|
||||
@@ -16,6 +16,12 @@ CREATE TABLE IF NOT EXISTS messages (
|
||||
request_tags TEXT[] NOT NULL,
|
||||
response_tags TEXT[] NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS api_keys (
|
||||
key TEXT PRIMARY KEY NOT NULL,
|
||||
user_id BIGINT NOT NULL,
|
||||
user_name TEXT NOT NULL,
|
||||
access_mask INT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS dice_thresholds (
|
||||
id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(),
|
||||
owner_id BIGINT NOT NULL,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use std::sync::Arc;
|
||||
use axum::extract::State;
|
||||
use axum::extract::{Path, State};
|
||||
use axum::middleware::from_extractor;
|
||||
use axum::{Extension, Json, Router};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::post;
|
||||
use reqwest::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use crate::api::auth::{AuthorizationMiddleware, Session};
|
||||
use crate::api::auth::{AuthCredential, AuthorizationMiddleware, Session};
|
||||
use crate::AppState;
|
||||
use crate::bot::commands::audio::join_voice_channel;
|
||||
use crate::bot::commands::audio::pause::pause_track;
|
||||
@@ -28,24 +28,25 @@ pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
#[derive(Deserialize)]
|
||||
struct PlayTrackRequest {
|
||||
url: String,
|
||||
guild_id: u64,
|
||||
}
|
||||
|
||||
async fn play_audio(
|
||||
Extension(session): Extension<Session>,
|
||||
Extension(credential): Extension<AuthCredential>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
Json(payload): Json<PlayTrackRequest>,
|
||||
) -> SirenResult<()> {
|
||||
log::debug!("Playing audio in guild: {}", payload.guild_id);
|
||||
log::debug!("Playing audio in guild: {}", guild_id);
|
||||
|
||||
// Check if the user exists in the cache
|
||||
let user_id = match state.cache.user(session.user_id) {
|
||||
let user_id = credential.user_id();
|
||||
let user_id = match state.cache.user(user_id) {
|
||||
Some(user) => user.id,
|
||||
None => return Err(Error::not_found("User not found".to_string())),
|
||||
};
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
let guild_id = match state.cache.guild(payload.guild_id) {
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
@@ -57,20 +58,15 @@ async fn play_audio(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct GuildTrackRequest {
|
||||
guild_id: u64,
|
||||
}
|
||||
|
||||
async fn pause_audio(
|
||||
Extension(_): Extension<Session>,
|
||||
Extension(_): Extension<AuthCredential>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(payload): Json<GuildTrackRequest>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> SirenResult<()> {
|
||||
log::debug!("Pausing audio in guild: {}", payload.guild_id);
|
||||
log::debug!("Pausing audio in guild: {}", guild_id);
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
let guild_id = match state.cache.guild(payload.guild_id) {
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
@@ -81,14 +77,14 @@ async fn pause_audio(
|
||||
}
|
||||
|
||||
async fn resume_audio(
|
||||
Extension(_): Extension<Session>,
|
||||
Extension(_): Extension<AuthCredential>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(payload): Json<GuildTrackRequest>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> SirenResult<()> {
|
||||
log::debug!("Pausing audio in guild: {}", payload.guild_id);
|
||||
log::debug!("Pausing audio in guild: {}", guild_id);
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
let guild_id = match state.cache.guild(payload.guild_id) {
|
||||
let guild_id = match state.cache.guild(guild_id) {
|
||||
Some(guild) => guild.id,
|
||||
None => return Err(Error::not_found("Guild not found".to_string())),
|
||||
};
|
||||
|
||||
@@ -2,11 +2,14 @@ use std::sync::Arc;
|
||||
use axum::{Extension, Router};
|
||||
use axum::middleware::from_extractor;
|
||||
use axum::routing::post;
|
||||
use crate::api::auth::csprng;
|
||||
use reqwest::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::api::auth::{csprng, AuthCredential};
|
||||
use crate::api::auth::AuthorizationMiddleware;
|
||||
use crate::api::auth::session::Session;
|
||||
use crate::AppState;
|
||||
use crate::error::SirenResult;
|
||||
use crate::data::query::{Condition, QueryBuilder};
|
||||
use crate::error::{Error, SirenResult};
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
@@ -14,28 +17,82 @@ pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
}
|
||||
|
||||
struct ApiKey {
|
||||
const TABLE_NAME: &str = "api_keys";
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct ApiKey {
|
||||
pub key: String,
|
||||
pub user_id: u64,
|
||||
pub access_mask: u32,
|
||||
pub user_id: i64,
|
||||
pub user_name: String,
|
||||
pub access_mask: i32,
|
||||
}
|
||||
|
||||
impl ApiKey {
|
||||
fn new(user_id: u64, access_mask: u32) -> Self {
|
||||
fn new(user_id: u64, user_name: String, access_mask: i32) -> Self {
|
||||
ApiKey {
|
||||
key: csprng(64),
|
||||
user_id,
|
||||
key: csprng(96),
|
||||
user_id: user_id as i64,
|
||||
user_name,
|
||||
access_mask,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn insert(&self) -> SirenResult<()> {
|
||||
let pool = crate::data::pool();
|
||||
sqlx::query(&format!(
|
||||
"INSERT INTO {} (
|
||||
key,
|
||||
user_id,
|
||||
user_name,
|
||||
access_mask
|
||||
) VALUES (
|
||||
$1, $2, $3, $4
|
||||
)",
|
||||
TABLE_NAME
|
||||
))
|
||||
.bind(&self.key)
|
||||
.bind(self.user_id)
|
||||
.bind(&self.user_name)
|
||||
.bind(self.access_mask)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn find_by_key(key: &str) -> SirenResult<Option<Self>> {
|
||||
let pool = crate::data::pool();
|
||||
let query = QueryBuilder::new(TABLE_NAME)
|
||||
.where_condition(Condition::is_equal("key", "$1"))
|
||||
.build();
|
||||
let item = sqlx::query_as(&query)
|
||||
.bind(key)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
Ok(item)
|
||||
}
|
||||
|
||||
pub async fn delete_by_id(key: &str) -> SirenResult<()> {
|
||||
let pool = crate::data::pool();
|
||||
sqlx::query(&format!("DELETE FROM {} WHERE key = $1", TABLE_NAME))
|
||||
.bind(key)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_api_key(Extension(session): Extension<Session>) -> SirenResult<String> {
|
||||
async fn create_api_key(Extension(credential): Extension<AuthCredential>) -> SirenResult<String> {
|
||||
let session = match credential {
|
||||
AuthCredential::ApiKey(_) => return Err(Error::new(400, "API keys cannot be generated with an API key".to_string())),
|
||||
AuthCredential::Session(session) => session
|
||||
};
|
||||
log::debug!(
|
||||
"Generating API key for {} ({})",
|
||||
&session.user_id,
|
||||
&session.user_name
|
||||
);
|
||||
let api_key = ApiKey::new(session.user_id, 0);
|
||||
let api_key = ApiKey::new(session.user_id, session.user_name, 0);
|
||||
api_key.insert().await?;
|
||||
Ok(api_key.key)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ use axum_extra::{
|
||||
};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||||
use crate::api::auth::api_key::ApiKey;
|
||||
use crate::api::auth::AuthCredential;
|
||||
use crate::api::auth::bearer_token::BearerTokenClaims;
|
||||
use crate::api::auth::session::Session;
|
||||
use crate::error::SirenResult;
|
||||
@@ -27,33 +29,47 @@ where
|
||||
return Ok(Self);
|
||||
}
|
||||
|
||||
let Ok(TypedHeader(Authorization(bearer))) =
|
||||
// Check for a Bearer token in the `Authorization` header.
|
||||
if let Ok(TypedHeader(Authorization(bearer))) =
|
||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
|
||||
else {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
};
|
||||
|
||||
match check_auth(bearer).await {
|
||||
Ok(session) => {
|
||||
parts.extensions.insert(session);
|
||||
Ok(Self)
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("{:?}", err);
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
{
|
||||
return match check_bearer_auth(bearer.token()).await {
|
||||
Ok(session) => {
|
||||
parts.extensions.insert(AuthCredential::Session(session));
|
||||
Ok(Self)
|
||||
}
|
||||
Err(_) => Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
}
|
||||
|
||||
// Check for an API key in the custom `X-API-Key` header.
|
||||
if let Some(api_key_header) = parts.headers.get("X-API-Key") {
|
||||
return if let Ok(api_key) = api_key_header.to_str() {
|
||||
match check_api_key_auth(api_key).await {
|
||||
Ok(api_key) => {
|
||||
parts.extensions.insert(AuthCredential::ApiKey(api_key));
|
||||
Ok(Self)
|
||||
}
|
||||
Err(_) => Err(StatusCode::UNAUTHORIZED),
|
||||
}
|
||||
} else {
|
||||
// Invalid header value
|
||||
Err(StatusCode::BAD_REQUEST)
|
||||
};
|
||||
}
|
||||
|
||||
// If neither the Bearer token nor API key is present or valid, return `UNAUTHORIZED`
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_auth(bearer: Bearer) -> SirenResult<Session> {
|
||||
async fn check_bearer_auth(bearer_token: &str) -> SirenResult<Session> {
|
||||
// Decode and validate the JWT
|
||||
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
|
||||
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
|
||||
|
||||
let token_data =
|
||||
decode::<BearerTokenClaims>(bearer.token(), &decoding_key, &Validation::default())
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
let token_data = decode::<BearerTokenClaims>(bearer_token, &decoding_key, &Validation::default())
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let claims = token_data.claims;
|
||||
|
||||
@@ -69,3 +85,13 @@ async fn check_auth(bearer: Bearer) -> SirenResult<Session> {
|
||||
_ => Err(StatusCode::UNAUTHORIZED)?,
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_api_key_auth(key: &str) -> SirenResult<ApiKey> {
|
||||
|
||||
let api_key = match ApiKey::find_by_key(key).await? {
|
||||
Some(api_key) => api_key,
|
||||
None => return Err(StatusCode::UNAUTHORIZED.into()),
|
||||
};
|
||||
|
||||
Ok(api_key)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use axum::Router;
|
||||
use rand::Rng;
|
||||
use rand_chacha::ChaCha20Rng;
|
||||
use rand_chacha::rand_core::SeedableRng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::AppState;
|
||||
|
||||
mod oauth;
|
||||
@@ -12,6 +13,29 @@ mod api_key;
|
||||
mod bearer_token;
|
||||
mod middleware;
|
||||
pub use middleware::AuthorizationMiddleware;
|
||||
use crate::api::auth::api_key::ApiKey;
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub enum AuthCredential {
|
||||
Session(Session),
|
||||
ApiKey(ApiKey),
|
||||
}
|
||||
|
||||
impl AuthCredential {
|
||||
pub fn user_id(&self) -> u64 {
|
||||
match self {
|
||||
AuthCredential::Session(session) => session.user_id,
|
||||
AuthCredential::ApiKey(api_key) => api_key.user_id as u64,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user_name(&self) -> String {
|
||||
match self {
|
||||
AuthCredential::Session(session) => session.user_name.clone(),
|
||||
AuthCredential::ApiKey(api_key) => api_key.user_name.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
|
||||
@@ -11,5 +11,5 @@ mod auth;
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.merge(auth::get_routes())
|
||||
.nest("/audio", audio::get_routes())
|
||||
.nest("/audio/:guild_id", audio::get_routes())
|
||||
}
|
||||
|
||||
@@ -89,7 +89,9 @@ pub async fn enqueue_track(
|
||||
let mut playlist_items: Vec<YtDlpItem> = Vec::new();
|
||||
if let Some(handler_lock) = manager.get(guild_id) {
|
||||
let mut handler = handler_lock.lock().await;
|
||||
let guild = GuildCache::get_by_id(guild_id.get() as i64).await?.unwrap();
|
||||
let guild = GuildCache::find_by_id(guild_id.get() as i64)
|
||||
.await?
|
||||
.unwrap();
|
||||
let valid = is_valid_url(&track_url);
|
||||
|
||||
// Check if the URL is valid
|
||||
|
||||
@@ -64,7 +64,7 @@ pub async fn set_volume(manager: &Arc<Songbird>, guild_id: &GuildId, volume: i32
|
||||
let bound_volume = volume as f32 / 100.0;
|
||||
|
||||
// Update the guild cache
|
||||
let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64)
|
||||
let mut guild_cache = GuildCache::find_by_id(guild_id.get() as i64)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
@@ -108,7 +108,7 @@ impl EventHandler for BotHandler {
|
||||
for guild in ready.guilds {
|
||||
// Check if guild exists in database
|
||||
let guild_id = guild.id.get() as i64;
|
||||
if let None = GuildCache::get_by_id(guild_id).await.unwrap() {
|
||||
if let None = GuildCache::find_by_id(guild_id).await.unwrap() {
|
||||
let guild_cache = GuildCache {
|
||||
id: guild_id,
|
||||
name: guild.id.name(&ctx.cache),
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::error::SirenResult;
|
||||
|
||||
const TABLE_NAME: &str = "guilds";
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Debug)]
|
||||
pub struct GuildCache {
|
||||
pub id: i64,
|
||||
pub name: Option<String>,
|
||||
@@ -35,7 +35,7 @@ impl GuildCache {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_by_id(id: i64) -> SirenResult<Option<Self>> {
|
||||
pub async fn find_by_id(id: i64) -> SirenResult<Option<Self>> {
|
||||
let pool = crate::data::pool();
|
||||
let query = QueryBuilder::new(TABLE_NAME)
|
||||
.where_condition(Condition::is_equal("id", "$1")) // Use a placeholder
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::error::SirenResult;
|
||||
pub mod events;
|
||||
pub mod guilds;
|
||||
pub mod messages;
|
||||
mod query;
|
||||
pub mod query;
|
||||
|
||||
static POOL: OnceLock<Pool<Postgres>> = OnceLock::new();
|
||||
static REDIS: OnceLock<RedisClient> = OnceLock::new();
|
||||
|
||||
Reference in New Issue
Block a user