diff --git a/.env b/.env index 8a212e6..ed45b5c 100644 --- a/.env +++ b/.env @@ -3,6 +3,8 @@ RUST_LOG=warn,siren=info DISCORD_TOKEN= DISCORD_SECRET= +JWT_SECRET=CHANGEME + DATABASE_USER=siren DATABASE_PASSWORD=CHANGEME # Change this to a secure password DATABASE_NAME=siren diff --git a/Cargo.toml b/Cargo.toml index 1f50d84..15d4041 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ rand_chacha = "0.3.1" tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] } regex = "1.11.0" axum = "0.7.7" +axum-extra = { version = "0.9.6", features = ["typed-header"] } lazy_static = "1.5.0" -futures = "0.3.31" -axum-login = "0.16.0" -sqlx-postgres = "0.8.2" +jsonwebtoken = "9.3.0" diff --git a/src/api/auth/api_key.rs b/src/api/auth/api_key.rs new file mode 100644 index 0000000..a8ec8df --- /dev/null +++ b/src/api/auth/api_key.rs @@ -0,0 +1,36 @@ +use std::sync::Arc; +use axum::{middleware, Extension, Router}; +use axum::middleware::from_extractor; +use axum::routing::post; +use crate::api::auth::{authenticate_middleware, csprng}; +use crate::api::auth::middleware::AuthorizationMiddleware; +use crate::api::auth::session::Session; +use crate::AppState; +use crate::error::SirenResult; + +pub fn get_routes() -> Router> { + Router::new().route("/api-key", post(create_api_key)) + .route_layer(from_extractor::()) +} + +struct ApiKey { + pub key: String, + pub user_id: String, + pub access_mask: u32, +} + +impl ApiKey { + fn new(user_id: String, access_mask: u32) -> Self { + ApiKey { + key: csprng(64), + user_id, + access_mask + } + } +} + +async fn create_api_key(Extension(session): Extension) -> SirenResult { + log::debug!("Generating API key for {} ({})", &session.user_id, &session.user_name); + let api_key = ApiKey::new(session.user_id, 0); + Ok(api_key.key) +} \ No newline at end of file diff --git a/src/api/auth/bearer_token.rs b/src/api/auth/bearer_token.rs new file mode 100644 index 0000000..ff12a34 --- /dev/null +++ b/src/api/auth/bearer_token.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct BearerTokenClaims { + pub sub: String, + pub name: String, + pub iat: i64, + pub exp: i64, + pub jti: String, +} \ No newline at end of file diff --git a/src/api/auth/middleware.rs b/src/api/auth/middleware.rs new file mode 100644 index 0000000..7e54c08 --- /dev/null +++ b/src/api/auth/middleware.rs @@ -0,0 +1,71 @@ +use axum::async_trait; +use axum::extract::FromRequestParts; +use axum::http::request::Parts; +use axum::http::{Method, StatusCode}; +use axum_extra::{TypedHeader, headers::{Authorization, authorization::Bearer}}; +use chrono::Utc; +use jsonwebtoken::{decode, DecodingKey, Validation}; +use crate::api::auth::bearer_token::BearerTokenClaims; +use crate::api::auth::session::Session; +use crate::error::SirenResult; + +pub struct AuthorizationMiddleware; + +#[async_trait] +impl FromRequestParts for AuthorizationMiddleware +where + S: Send + Sync, +{ + type Rejection = StatusCode; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + // For options requests browsers will not send the authorization header. + if parts.method == Method::OPTIONS { + return Ok(Self); + } + + let Ok(TypedHeader(Authorization(bearer))) = + TypedHeader::>::from_request_parts(parts, state).await + else { + log::error!("Could not get Authorization header from the request"); + return Err(StatusCode::UNAUTHORIZED); + }; + + match check_auth(bearer).await { + Ok(session) => { + parts.extensions.insert(session); + Ok(Self) + }, + Err(err) => { + log::error!("{:?}", err); + Err(StatusCode::UNAUTHORIZED) + } + } + } +} + +async fn check_auth(bearer: Bearer) -> SirenResult { + // Decode and validate the JWT + let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment"); + let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes()); + + let token_data = decode::( + bearer.token(), + &decoding_key, + &Validation::default() + ).map_err(|_| StatusCode::UNAUTHORIZED)?; + + let claims = token_data.claims; + + // Check if the token has expired + let now = Utc::now().timestamp(); + if claims.exp < now { + return Err(StatusCode::UNAUTHORIZED.into()); + } + + // Confirm the session exists in the session store (based on `jti`) + match Session::find(&claims.jti).await { + Ok(Some(session)) => Ok(session), + _ => Err(StatusCode::UNAUTHORIZED)?, + } +} \ No newline at end of file diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs new file mode 100644 index 0000000..048dcc6 --- /dev/null +++ b/src/api/auth/mod.rs @@ -0,0 +1,28 @@ +use std::sync::Arc; +use axum::Router; +use rand::Rng; +use rand_chacha::ChaCha20Rng; +use rand_chacha::rand_core::SeedableRng; +use crate::AppState; + +mod oauth; +mod session; +mod api_key; +mod bearer_token; +mod middleware; + +pub fn get_routes() -> Router> { + Router::new() + .nest("/oauth", oauth::get_routes()) + .merge(api_key::get_routes()) +} + +pub fn csprng(take: usize) -> String { + // Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9) + let rng = ChaCha20Rng::from_entropy(); + rng + .sample_iter(rand::distributions::Alphanumeric) + .take(take) + .map(char::from) + .collect() +} diff --git a/src/api/oauth.rs b/src/api/auth/oauth.rs similarity index 52% rename from src/api/oauth.rs rename to src/api/auth/oauth.rs index cd7dbaa..354b2b4 100644 --- a/src/api/oauth.rs +++ b/src/api/auth/oauth.rs @@ -1,48 +1,24 @@ use std::env; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use axum::extract::{Query, State}; use axum::http::{HeaderMap, HeaderValue, StatusCode}; use axum::{Json, Router}; use axum::http::header::SET_COOKIE; use axum::response::Redirect; use axum::routing::get; -use chrono::{DateTime, Utc}; -use rand::Rng; -use rand_chacha::ChaCha20Rng; -use rand_chacha::rand_core::SeedableRng; -use redis::{AsyncCommands, RedisResult}; use serde::{Deserialize, Serialize}; -use crate::{data, AppState}; +use crate::api::auth::bearer_token::BearerTokenClaims; +use crate::AppState; +use crate::api::auth::csprng; +use crate::api::auth::session::Session; use crate::error::SirenResult; -static SESSION_TTL: OnceLock = OnceLock::new(); - pub fn get_routes() -> Router> { Router::new() .route("/authorize", get(discord_authorize)) .route("/callback", get(oauth_callback)) } -fn get_session_ttl() -> i64 { - // Initialize the SESSION_TTL value lazily - *SESSION_TTL.get_or_init(|| { - env::var("SESSION_TTL") - .ok() - .and_then(|val| val.parse::().ok()) - .unwrap_or(3600) // Default to 3600 seconds (1 hour) - }) -} - -pub fn csprng(take: usize) -> String { - // Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9) - let rng = ChaCha20Rng::from_entropy(); - rng - .sample_iter(rand::distributions::Alphanumeric) - .take(take) - .map(char::from) - .collect() -} - #[derive(Deserialize)] struct AuthQuery { code: String, @@ -66,59 +42,6 @@ struct DiscordUser { avatar: Option, } -#[derive(Serialize, Deserialize, Debug)] -struct Session { - session_id: String, - user_id: String, - user_name: String, - pub expires_at: DateTime, -} - -impl Session { - fn new(id: String, user_id: String, user_name: String) -> Session { - let now = Utc::now(); - let session_ttl = get_session_ttl(); - Session { - session_id: id, - user_id, - user_name, - expires_at: now + chrono::Duration::seconds(session_ttl), - } - } - - async fn insert(&self) -> SirenResult<()> { - let mut redis = data::redis_async_connection().await?; - let session_id = self.session_id.clone(); - redis - .set_ex( - session_id, - serde_json::to_string(self)?, - self.expires_at.timestamp() as u64, - ) - .await?; - Ok(()) - } - - async fn get(session_id: String) -> SirenResult> { - let mut redis = data::redis_async_connection().await?; - let result: RedisResult> = redis.get(session_id).await; - match result { - Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)), - Ok(None) => Ok(None), - Err(err) => Err(err.into()), - } - } - - async fn delete(session_id: String) -> SirenResult<()> { - let mut redis = data::redis_async_connection().await?; - let result: RedisResult<()> = redis.del(session_id).await; - match result { - Ok(_) => Ok(()), - Err(err) => Err(err.into()), - } - } -} - // async fn discord_authorize_redirect(State(state): State>) -> Redirect { // // Construct the Discord OAuth URL // let discord_auth_url = format!( @@ -137,10 +60,17 @@ async fn discord_authorize(State(state): State>) -> SirenResult>, Query(query): Query, -) -> SirenResult<(HeaderMap, Json)> { +) -> SirenResult> { // Exchange code for an access token let token_response = state .client @@ -193,28 +123,32 @@ async fn oauth_callback( log::debug!("User authenticated: {:?}", user_data); - // Generate a session token - let session_token = csprng(16); - let expiration = env::var("API_SESSION_TTL") - .expect("Expected a session ttl in the environment") - .parse::() - .unwrap(); - // Create and insert the session - let session = Session::new( - session_token.clone(), - user_data.id.clone(), - user_data.username.clone(), - ); + let session = Session::new(user_data.id.clone(), user_data.username.clone()); session.insert().await?; - let cookie_value = format!( - "session={}; HttpOnly; Path=/; Max-Age={}", - session_token, expiration - ); + let issued_at = chrono::Utc::now(); - let mut headers = HeaderMap::new(); - headers.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap()); + let claims = BearerTokenClaims { + sub: session.user_id.clone(), + name: session.user_name.clone(), + iat: issued_at.timestamp(), + exp: session.expires_at.timestamp(), + jti: session.session_id.clone(), + }; - Ok((headers, Json(user_data))) + // Create the JWT + let jwt_secret = env::var("JWT_SECRET").expect("Expected a JWT secret in the environment"); + let encoding_key = jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes()); + let token = jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &encoding_key) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // Return the bearer token and user information + let response = BearerTokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: (session.expires_at.timestamp() - issued_at.timestamp()) as u64, + }; + + Ok(Json(response)) } diff --git a/src/api/auth/session.rs b/src/api/auth/session.rs new file mode 100644 index 0000000..b67a5f8 --- /dev/null +++ b/src/api/auth/session.rs @@ -0,0 +1,73 @@ +use std::env; +use std::sync::OnceLock; +use chrono::{DateTime, Utc}; +use redis::{AsyncCommands, RedisResult}; +use serde::{Deserialize, Serialize}; +use crate::api::auth::csprng; +use crate::data; +use crate::error::SirenResult; + +static SESSION_TTL: OnceLock = OnceLock::new(); + +fn get_session_ttl() -> i64 { + // Initialize the SESSION_TTL value lazily + *SESSION_TTL.get_or_init(|| { + env::var("API_SESSION_TTL") + .ok() + .and_then(|val| val.parse::().ok()) + .unwrap_or(3600) // Default to 3600 seconds (1 hour) + }) +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct Session { + pub session_id: String, + pub user_id: String, + pub user_name: String, + pub expires_at: DateTime, +} + +impl Session { + pub fn new(user_id: String, user_name: String) -> Session { + let now = Utc::now(); + let session_ttl = get_session_ttl(); + Session { + session_id: csprng(32), + user_id, + user_name, + expires_at: now + chrono::Duration::seconds(session_ttl), + } + } + + pub async fn insert(&self) -> SirenResult<()> { + let mut redis = data::redis_async_connection().await?; + let session_id = self.session_id.clone(); + redis + .set_ex( + session_id, + serde_json::to_string(self)?, + self.expires_at.timestamp() as u64, + ) + .await?; + Ok(()) + } + + pub async fn find(session_id: &str) -> SirenResult> { + let mut redis = data::redis_async_connection().await?; + let result: RedisResult> = redis.get(session_id).await; + match result { + Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)), + Ok(None) => Ok(None), + Err(err) => Err(err.into()), + } + } + + pub async fn delete(session_id: &str) -> SirenResult<()> { + let mut redis = data::redis_async_connection().await?; + let result: RedisResult<()> = redis.del(session_id).await; + match result { + Ok(_) => Ok(()), + Err(err) => Err(err.into()), + } + } +} \ No newline at end of file diff --git a/src/api/mod.rs b/src/api/mod.rs index 7c5bd9b..9b2e9ed 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,13 +1,12 @@ -mod app; -mod oauth; - pub use app::App; use std::sync::Arc; use axum::Router; -use serde::{Deserialize, Serialize}; use crate::AppState; +mod app; +mod auth; + pub fn get_routes() -> Router> { - Router::new().nest("/oauth", oauth::get_routes()) + Router::new().merge(auth::get_routes()) } diff --git a/src/bot/chat/mod.rs b/src/bot/chat/mod.rs index af3bb08..dfde377 100644 --- a/src/bot/chat/mod.rs +++ b/src/bot/chat/mod.rs @@ -10,24 +10,24 @@ pub async fn process_message(ctx: &Context, command: &CommandInteraction, privat pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Option { let data = CreateMessage::new().content(content.to_owned()); - return match user_id.dm(ctx, data).await { + match user_id.dm(ctx, data).await { Ok(message) => Some(message), Err(err) => { log::error!("Failed to create direct message for {content}\n{err}"); None } - }; + } } pub async fn user_dm(ctx: &Context, user: &User, content: String) -> Option { let data = CreateMessage::new().content(content.to_owned()); - return match user.direct_message(ctx, data).await { + match user.direct_message(ctx, data).await { Ok(message) => Some(message), Err(err) => { log::error!("Failed to create direct message for {content}\n{err}"); None } - }; + } } pub async fn create_message_response( @@ -50,7 +50,7 @@ pub async fn create_message_response( } pub async fn create_modal_response(ctx: &Context, modal: &ModalInteraction) { - let mut data = CreateInteractionResponseMessage::new(); + let data = CreateInteractionResponseMessage::new(); let builder = CreateInteractionResponse::Message(data); match modal.create_response(&ctx.http, builder).await { Ok(_) => {} diff --git a/src/bot/commands/utility/ping.rs b/src/bot/commands/utility/ping.rs index 6e24f70..ec95f7e 100644 --- a/src/bot/commands/utility/ping.rs +++ b/src/bot/commands/utility/ping.rs @@ -1,4 +1,4 @@ -use serenity::all::{CommandDataOption, CommandInteraction, Context, CreateCommand}; +use serenity::all::{CommandInteraction, Context, CreateCommand}; use crate::bot::chat::create_message_response; pub async fn run(ctx: &Context, command: &CommandInteraction) { diff --git a/src/bot/handler.rs b/src/bot/handler.rs index b7a4c83..5048105 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -1,4 +1,4 @@ -use serenity::all::{CreateInteractionResponse, Interaction, ResumedEvent}; +use serenity::all::{Interaction, ResumedEvent}; use serenity::async_trait; use serenity::model::gateway::Ready; use serenity::model::channel::Message; @@ -7,7 +7,7 @@ use crate::bot::commands::chat::generate_response; use crate::bot::oai::OAI; use crate::data::guilds::GuildCache; use super::{commands}; -use super::chat::{create_message_response, create_modal_response}; +use super::chat::{create_modal_response}; pub struct BotHandler { // Open AI Config diff --git a/src/data/guilds/model.rs b/src/data/guilds/model.rs index b3b03cd..145607c 100644 --- a/src/data/guilds/model.rs +++ b/src/data/guilds/model.rs @@ -26,12 +26,12 @@ impl GuildCache { )", TABLE_NAME )) - .bind(self.id) - .bind(&self.name) - .bind(self.owner_id) - .bind(self.volume) - .execute(pool) - .await?; + .bind(self.id) + .bind(&self.name) + .bind(self.owner_id) + .bind(self.volume) + .execute(pool) + .await?; Ok(()) } @@ -40,10 +40,7 @@ impl GuildCache { let query = QueryBuilder::new(TABLE_NAME) .where_condition(Condition::is_equal("id", "$1")) // Use a placeholder .build(); - let item = sqlx::query_as(&query) - .bind(id) - .fetch_optional(pool) - .await?; + let item = sqlx::query_as(&query).bind(id).fetch_optional(pool).await?; Ok(item) } @@ -58,12 +55,12 @@ impl GuildCache { WHERE id = $1", TABLE_NAME )) - .bind(self.id) - .bind(&self.name) - .bind(self.owner_id) - .bind(self.volume) - .execute(pool) - .await?; + .bind(self.id) + .bind(&self.name) + .bind(self.owner_id) + .bind(self.volume) + .execute(pool) + .await?; Ok(()) } } diff --git a/src/data/query.rs b/src/data/query.rs index 37bd1a0..ea0f32b 100644 --- a/src/data/query.rs +++ b/src/data/query.rs @@ -160,4 +160,4 @@ impl Condition { Condition::Group(a) => format!("({})", a.to_sql()), } } -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index ad8597a..c4d302f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,4 @@ use std::env; -use std::collections::HashSet; use std::sync::Arc; use serenity::http::Http; use serenity::prelude::*; @@ -72,7 +71,10 @@ async fn main() -> Result<(), Box> { cache: Arc::clone(&client.cache), }; - log::debug!("Starting Siren with ID: {bot_id} (Contact: {:?})", bot_owner); + log::debug!( + "Starting Siren with ID: {bot_id} (Contact: {:?})", + bot_owner + ); // Spawn shutdown signal handling let shard_manager = Arc::clone(&client.shard_manager);