Added auth route (temp api key route)

This commit is contained in:
2024-12-19 19:54:05 -05:00
parent 4840c7c001
commit a97505ea5e
15 changed files with 288 additions and 137 deletions

2
.env
View File

@@ -3,6 +3,8 @@ RUST_LOG=warn,siren=info
DISCORD_TOKEN= DISCORD_TOKEN=
DISCORD_SECRET= DISCORD_SECRET=
JWT_SECRET=CHANGEME
DATABASE_USER=siren DATABASE_USER=siren
DATABASE_PASSWORD=CHANGEME # Change this to a secure password DATABASE_PASSWORD=CHANGEME # Change this to a secure password
DATABASE_NAME=siren DATABASE_NAME=siren

View File

@@ -27,7 +27,6 @@ rand_chacha = "0.3.1"
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] } tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] }
regex = "1.11.0" regex = "1.11.0"
axum = "0.7.7" axum = "0.7.7"
axum-extra = { version = "0.9.6", features = ["typed-header"] }
lazy_static = "1.5.0" lazy_static = "1.5.0"
futures = "0.3.31" jsonwebtoken = "9.3.0"
axum-login = "0.16.0"
sqlx-postgres = "0.8.2"

36
src/api/auth/api_key.rs Normal file
View File

@@ -0,0 +1,36 @@
use std::sync::Arc;
use axum::{middleware, Extension, Router};
use axum::middleware::from_extractor;
use axum::routing::post;
use crate::api::auth::{authenticate_middleware, csprng};
use crate::api::auth::middleware::AuthorizationMiddleware;
use crate::api::auth::session::Session;
use crate::AppState;
use crate::error::SirenResult;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new().route("/api-key", post(create_api_key))
.route_layer(from_extractor::<AuthorizationMiddleware>())
}
struct ApiKey {
pub key: String,
pub user_id: String,
pub access_mask: u32,
}
impl ApiKey {
fn new(user_id: String, access_mask: u32) -> Self {
ApiKey {
key: csprng(64),
user_id,
access_mask
}
}
}
async fn create_api_key(Extension(session): Extension<Session>) -> SirenResult<String> {
log::debug!("Generating API key for {} ({})", &session.user_id, &session.user_name);
let api_key = ApiKey::new(session.user_id, 0);
Ok(api_key.key)
}

View File

@@ -0,0 +1,10 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct BearerTokenClaims {
pub sub: String,
pub name: String,
pub iat: i64,
pub exp: i64,
pub jti: String,
}

View File

@@ -0,0 +1,71 @@
use axum::async_trait;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::{Method, StatusCode};
use axum_extra::{TypedHeader, headers::{Authorization, authorization::Bearer}};
use chrono::Utc;
use jsonwebtoken::{decode, DecodingKey, Validation};
use crate::api::auth::bearer_token::BearerTokenClaims;
use crate::api::auth::session::Session;
use crate::error::SirenResult;
pub struct AuthorizationMiddleware;
#[async_trait]
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// For options requests browsers will not send the authorization header.
if parts.method == Method::OPTIONS {
return Ok(Self);
}
let Ok(TypedHeader(Authorization(bearer))) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
else {
log::error!("Could not get Authorization header from the request");
return Err(StatusCode::UNAUTHORIZED);
};
match check_auth(bearer).await {
Ok(session) => {
parts.extensions.insert(session);
Ok(Self)
},
Err(err) => {
log::error!("{:?}", err);
Err(StatusCode::UNAUTHORIZED)
}
}
}
}
async fn check_auth(bearer: Bearer) -> SirenResult<Session> {
// Decode and validate the JWT
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
let token_data = decode::<BearerTokenClaims>(
bearer.token(),
&decoding_key,
&Validation::default()
).map_err(|_| StatusCode::UNAUTHORIZED)?;
let claims = token_data.claims;
// Check if the token has expired
let now = Utc::now().timestamp();
if claims.exp < now {
return Err(StatusCode::UNAUTHORIZED.into());
}
// Confirm the session exists in the session store (based on `jti`)
match Session::find(&claims.jti).await {
Ok(Some(session)) => Ok(session),
_ => Err(StatusCode::UNAUTHORIZED)?,
}
}

28
src/api/auth/mod.rs Normal file
View File

@@ -0,0 +1,28 @@
use std::sync::Arc;
use axum::Router;
use rand::Rng;
use rand_chacha::ChaCha20Rng;
use rand_chacha::rand_core::SeedableRng;
use crate::AppState;
mod oauth;
mod session;
mod api_key;
mod bearer_token;
mod middleware;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.nest("/oauth", oauth::get_routes())
.merge(api_key::get_routes())
}
pub fn csprng(take: usize) -> String {
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
let rng = ChaCha20Rng::from_entropy();
rng
.sample_iter(rand::distributions::Alphanumeric)
.take(take)
.map(char::from)
.collect()
}

View File

@@ -1,48 +1,24 @@
use std::env; use std::env;
use std::sync::{Arc, OnceLock}; use std::sync::Arc;
use axum::extract::{Query, State}; use axum::extract::{Query, State};
use axum::http::{HeaderMap, HeaderValue, StatusCode}; use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::{Json, Router}; use axum::{Json, Router};
use axum::http::header::SET_COOKIE; use axum::http::header::SET_COOKIE;
use axum::response::Redirect; use axum::response::Redirect;
use axum::routing::get; use axum::routing::get;
use chrono::{DateTime, Utc};
use rand::Rng;
use rand_chacha::ChaCha20Rng;
use rand_chacha::rand_core::SeedableRng;
use redis::{AsyncCommands, RedisResult};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{data, AppState}; use crate::api::auth::bearer_token::BearerTokenClaims;
use crate::AppState;
use crate::api::auth::csprng;
use crate::api::auth::session::Session;
use crate::error::SirenResult; use crate::error::SirenResult;
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
pub fn get_routes() -> Router<Arc<AppState>> { pub fn get_routes() -> Router<Arc<AppState>> {
Router::new() Router::new()
.route("/authorize", get(discord_authorize)) .route("/authorize", get(discord_authorize))
.route("/callback", get(oauth_callback)) .route("/callback", get(oauth_callback))
} }
fn get_session_ttl() -> i64 {
// Initialize the SESSION_TTL value lazily
*SESSION_TTL.get_or_init(|| {
env::var("SESSION_TTL")
.ok()
.and_then(|val| val.parse::<i64>().ok())
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
})
}
pub fn csprng(take: usize) -> String {
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
let rng = ChaCha20Rng::from_entropy();
rng
.sample_iter(rand::distributions::Alphanumeric)
.take(take)
.map(char::from)
.collect()
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct AuthQuery { struct AuthQuery {
code: String, code: String,
@@ -66,59 +42,6 @@ struct DiscordUser {
avatar: Option<String>, avatar: Option<String>,
} }
#[derive(Serialize, Deserialize, Debug)]
struct Session {
session_id: String,
user_id: String,
user_name: String,
pub expires_at: DateTime<Utc>,
}
impl Session {
fn new(id: String, user_id: String, user_name: String) -> Session {
let now = Utc::now();
let session_ttl = get_session_ttl();
Session {
session_id: id,
user_id,
user_name,
expires_at: now + chrono::Duration::seconds(session_ttl),
}
}
async fn insert(&self) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?;
let session_id = self.session_id.clone();
redis
.set_ex(
session_id,
serde_json::to_string(self)?,
self.expires_at.timestamp() as u64,
)
.await?;
Ok(())
}
async fn get(session_id: String) -> SirenResult<Option<Session>> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<Option<String>> = redis.get(session_id).await;
match result {
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
Ok(None) => Ok(None),
Err(err) => Err(err.into()),
}
}
async fn delete(session_id: String) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<()> = redis.del(session_id).await;
match result {
Ok(_) => Ok(()),
Err(err) => Err(err.into()),
}
}
}
// async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect { // async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect {
// // Construct the Discord OAuth URL // // Construct the Discord OAuth URL
// let discord_auth_url = format!( // let discord_auth_url = format!(
@@ -137,10 +60,17 @@ async fn discord_authorize(State(state): State<Arc<AppState>>) -> SirenResult<St
Ok(discord_auth_url) Ok(discord_auth_url)
} }
#[derive(Debug, Serialize)]
pub struct BearerTokenResponse {
pub access_token: String,
pub token_type: String,
pub expires_in: u64,
}
async fn oauth_callback( async fn oauth_callback(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Query(query): Query<AuthQuery>, Query(query): Query<AuthQuery>,
) -> SirenResult<(HeaderMap, Json<DiscordUser>)> { ) -> SirenResult<Json<BearerTokenResponse>> {
// Exchange code for an access token // Exchange code for an access token
let token_response = state let token_response = state
.client .client
@@ -193,28 +123,32 @@ async fn oauth_callback(
log::debug!("User authenticated: {:?}", user_data); log::debug!("User authenticated: {:?}", user_data);
// Generate a session token
let session_token = csprng(16);
let expiration = env::var("API_SESSION_TTL")
.expect("Expected a session ttl in the environment")
.parse::<u64>()
.unwrap();
// Create and insert the session // Create and insert the session
let session = Session::new( let session = Session::new(user_data.id.clone(), user_data.username.clone());
session_token.clone(),
user_data.id.clone(),
user_data.username.clone(),
);
session.insert().await?; session.insert().await?;
let cookie_value = format!( let issued_at = chrono::Utc::now();
"session={}; HttpOnly; Path=/; Max-Age={}",
session_token, expiration
);
let mut headers = HeaderMap::new(); let claims = BearerTokenClaims {
headers.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap()); sub: session.user_id.clone(),
name: session.user_name.clone(),
iat: issued_at.timestamp(),
exp: session.expires_at.timestamp(),
jti: session.session_id.clone(),
};
Ok((headers, Json(user_data))) // Create the JWT
let jwt_secret = env::var("JWT_SECRET").expect("Expected a JWT secret in the environment");
let encoding_key = jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes());
let token = jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &encoding_key)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// Return the bearer token and user information
let response = BearerTokenResponse {
access_token: token,
token_type: "Bearer".to_string(),
expires_in: (session.expires_at.timestamp() - issued_at.timestamp()) as u64,
};
Ok(Json(response))
} }

73
src/api/auth/session.rs Normal file
View File

@@ -0,0 +1,73 @@
use std::env;
use std::sync::OnceLock;
use chrono::{DateTime, Utc};
use redis::{AsyncCommands, RedisResult};
use serde::{Deserialize, Serialize};
use crate::api::auth::csprng;
use crate::data;
use crate::error::SirenResult;
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
fn get_session_ttl() -> i64 {
// Initialize the SESSION_TTL value lazily
*SESSION_TTL.get_or_init(|| {
env::var("API_SESSION_TTL")
.ok()
.and_then(|val| val.parse::<i64>().ok())
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
})
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Session {
pub session_id: String,
pub user_id: String,
pub user_name: String,
pub expires_at: DateTime<Utc>,
}
impl Session {
pub fn new(user_id: String, user_name: String) -> Session {
let now = Utc::now();
let session_ttl = get_session_ttl();
Session {
session_id: csprng(32),
user_id,
user_name,
expires_at: now + chrono::Duration::seconds(session_ttl),
}
}
pub async fn insert(&self) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?;
let session_id = self.session_id.clone();
redis
.set_ex(
session_id,
serde_json::to_string(self)?,
self.expires_at.timestamp() as u64,
)
.await?;
Ok(())
}
pub async fn find(session_id: &str) -> SirenResult<Option<Session>> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<Option<String>> = redis.get(session_id).await;
match result {
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
Ok(None) => Ok(None),
Err(err) => Err(err.into()),
}
}
pub async fn delete(session_id: &str) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<()> = redis.del(session_id).await;
match result {
Ok(_) => Ok(()),
Err(err) => Err(err.into()),
}
}
}

View File

@@ -1,13 +1,12 @@
mod app;
mod oauth;
pub use app::App; pub use app::App;
use std::sync::Arc; use std::sync::Arc;
use axum::Router; use axum::Router;
use serde::{Deserialize, Serialize};
use crate::AppState; use crate::AppState;
mod app;
mod auth;
pub fn get_routes() -> Router<Arc<AppState>> { pub fn get_routes() -> Router<Arc<AppState>> {
Router::new().nest("/oauth", oauth::get_routes()) Router::new().merge(auth::get_routes())
} }

View File

@@ -10,24 +10,24 @@ pub async fn process_message(ctx: &Context, command: &CommandInteraction, privat
pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Option<Message> { pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Option<Message> {
let data = CreateMessage::new().content(content.to_owned()); let data = CreateMessage::new().content(content.to_owned());
return match user_id.dm(ctx, data).await { match user_id.dm(ctx, data).await {
Ok(message) => Some(message), Ok(message) => Some(message),
Err(err) => { Err(err) => {
log::error!("Failed to create direct message for {content}\n{err}"); log::error!("Failed to create direct message for {content}\n{err}");
None None
} }
}; }
} }
pub async fn user_dm(ctx: &Context, user: &User, content: String) -> Option<Message> { pub async fn user_dm(ctx: &Context, user: &User, content: String) -> Option<Message> {
let data = CreateMessage::new().content(content.to_owned()); let data = CreateMessage::new().content(content.to_owned());
return match user.direct_message(ctx, data).await { match user.direct_message(ctx, data).await {
Ok(message) => Some(message), Ok(message) => Some(message),
Err(err) => { Err(err) => {
log::error!("Failed to create direct message for {content}\n{err}"); log::error!("Failed to create direct message for {content}\n{err}");
None None
} }
}; }
} }
pub async fn create_message_response( pub async fn create_message_response(
@@ -50,7 +50,7 @@ pub async fn create_message_response(
} }
pub async fn create_modal_response(ctx: &Context, modal: &ModalInteraction) { pub async fn create_modal_response(ctx: &Context, modal: &ModalInteraction) {
let mut data = CreateInteractionResponseMessage::new(); let data = CreateInteractionResponseMessage::new();
let builder = CreateInteractionResponse::Message(data); let builder = CreateInteractionResponse::Message(data);
match modal.create_response(&ctx.http, builder).await { match modal.create_response(&ctx.http, builder).await {
Ok(_) => {} Ok(_) => {}

View File

@@ -1,4 +1,4 @@
use serenity::all::{CommandDataOption, CommandInteraction, Context, CreateCommand}; use serenity::all::{CommandInteraction, Context, CreateCommand};
use crate::bot::chat::create_message_response; use crate::bot::chat::create_message_response;
pub async fn run(ctx: &Context, command: &CommandInteraction) { pub async fn run(ctx: &Context, command: &CommandInteraction) {

View File

@@ -1,4 +1,4 @@
use serenity::all::{CreateInteractionResponse, Interaction, ResumedEvent}; use serenity::all::{Interaction, ResumedEvent};
use serenity::async_trait; use serenity::async_trait;
use serenity::model::gateway::Ready; use serenity::model::gateway::Ready;
use serenity::model::channel::Message; use serenity::model::channel::Message;
@@ -7,7 +7,7 @@ use crate::bot::commands::chat::generate_response;
use crate::bot::oai::OAI; use crate::bot::oai::OAI;
use crate::data::guilds::GuildCache; use crate::data::guilds::GuildCache;
use super::{commands}; use super::{commands};
use super::chat::{create_message_response, create_modal_response}; use super::chat::{create_modal_response};
pub struct BotHandler { pub struct BotHandler {
// Open AI Config // Open AI Config

View File

@@ -26,12 +26,12 @@ impl GuildCache {
)", )",
TABLE_NAME TABLE_NAME
)) ))
.bind(self.id) .bind(self.id)
.bind(&self.name) .bind(&self.name)
.bind(self.owner_id) .bind(self.owner_id)
.bind(self.volume) .bind(self.volume)
.execute(pool) .execute(pool)
.await?; .await?;
Ok(()) Ok(())
} }
@@ -40,10 +40,7 @@ impl GuildCache {
let query = QueryBuilder::new(TABLE_NAME) let query = QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::is_equal("id", "$1")) // Use a placeholder .where_condition(Condition::is_equal("id", "$1")) // Use a placeholder
.build(); .build();
let item = sqlx::query_as(&query) let item = sqlx::query_as(&query).bind(id).fetch_optional(pool).await?;
.bind(id)
.fetch_optional(pool)
.await?;
Ok(item) Ok(item)
} }
@@ -58,12 +55,12 @@ impl GuildCache {
WHERE id = $1", WHERE id = $1",
TABLE_NAME TABLE_NAME
)) ))
.bind(self.id) .bind(self.id)
.bind(&self.name) .bind(&self.name)
.bind(self.owner_id) .bind(self.owner_id)
.bind(self.volume) .bind(self.volume)
.execute(pool) .execute(pool)
.await?; .await?;
Ok(()) Ok(())
} }
} }

View File

@@ -160,4 +160,4 @@ impl Condition {
Condition::Group(a) => format!("({})", a.to_sql()), Condition::Group(a) => format!("({})", a.to_sql()),
} }
} }
} }

View File

@@ -1,5 +1,4 @@
use std::env; use std::env;
use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
use serenity::http::Http; use serenity::http::Http;
use serenity::prelude::*; use serenity::prelude::*;
@@ -72,7 +71,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
cache: Arc::clone(&client.cache), cache: Arc::clone(&client.cache),
}; };
log::debug!("Starting Siren with ID: {bot_id} (Contact: {:?})", bot_owner); log::debug!(
"Starting Siren with ID: {bot_id} (Contact: {:?})",
bot_owner
);
// Spawn shutdown signal handling // Spawn shutdown signal handling
let shard_manager = Arc::clone(&client.shard_manager); let shard_manager = Arc::clone(&client.shard_manager);