Added auth route (temp api key route)

This commit is contained in:
2024-12-19 19:54:05 -05:00
parent 4840c7c001
commit a97505ea5e
15 changed files with 288 additions and 137 deletions

2
.env
View File

@@ -3,6 +3,8 @@ RUST_LOG=warn,siren=info
DISCORD_TOKEN=
DISCORD_SECRET=
JWT_SECRET=CHANGEME
DATABASE_USER=siren
DATABASE_PASSWORD=CHANGEME # Change this to a secure password
DATABASE_NAME=siren

View File

@@ -27,7 +27,6 @@ rand_chacha = "0.3.1"
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] }
regex = "1.11.0"
axum = "0.7.7"
axum-extra = { version = "0.9.6", features = ["typed-header"] }
lazy_static = "1.5.0"
futures = "0.3.31"
axum-login = "0.16.0"
sqlx-postgres = "0.8.2"
jsonwebtoken = "9.3.0"

36
src/api/auth/api_key.rs Normal file
View File

@@ -0,0 +1,36 @@
use std::sync::Arc;
use axum::{middleware, Extension, Router};
use axum::middleware::from_extractor;
use axum::routing::post;
use crate::api::auth::{authenticate_middleware, csprng};
use crate::api::auth::middleware::AuthorizationMiddleware;
use crate::api::auth::session::Session;
use crate::AppState;
use crate::error::SirenResult;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new().route("/api-key", post(create_api_key))
.route_layer(from_extractor::<AuthorizationMiddleware>())
}
struct ApiKey {
pub key: String,
pub user_id: String,
pub access_mask: u32,
}
impl ApiKey {
fn new(user_id: String, access_mask: u32) -> Self {
ApiKey {
key: csprng(64),
user_id,
access_mask
}
}
}
async fn create_api_key(Extension(session): Extension<Session>) -> SirenResult<String> {
log::debug!("Generating API key for {} ({})", &session.user_id, &session.user_name);
let api_key = ApiKey::new(session.user_id, 0);
Ok(api_key.key)
}

View File

@@ -0,0 +1,10 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct BearerTokenClaims {
pub sub: String,
pub name: String,
pub iat: i64,
pub exp: i64,
pub jti: String,
}

View File

@@ -0,0 +1,71 @@
use axum::async_trait;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::{Method, StatusCode};
use axum_extra::{TypedHeader, headers::{Authorization, authorization::Bearer}};
use chrono::Utc;
use jsonwebtoken::{decode, DecodingKey, Validation};
use crate::api::auth::bearer_token::BearerTokenClaims;
use crate::api::auth::session::Session;
use crate::error::SirenResult;
pub struct AuthorizationMiddleware;
#[async_trait]
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// For options requests browsers will not send the authorization header.
if parts.method == Method::OPTIONS {
return Ok(Self);
}
let Ok(TypedHeader(Authorization(bearer))) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
else {
log::error!("Could not get Authorization header from the request");
return Err(StatusCode::UNAUTHORIZED);
};
match check_auth(bearer).await {
Ok(session) => {
parts.extensions.insert(session);
Ok(Self)
},
Err(err) => {
log::error!("{:?}", err);
Err(StatusCode::UNAUTHORIZED)
}
}
}
}
async fn check_auth(bearer: Bearer) -> SirenResult<Session> {
// Decode and validate the JWT
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
let token_data = decode::<BearerTokenClaims>(
bearer.token(),
&decoding_key,
&Validation::default()
).map_err(|_| StatusCode::UNAUTHORIZED)?;
let claims = token_data.claims;
// Check if the token has expired
let now = Utc::now().timestamp();
if claims.exp < now {
return Err(StatusCode::UNAUTHORIZED.into());
}
// Confirm the session exists in the session store (based on `jti`)
match Session::find(&claims.jti).await {
Ok(Some(session)) => Ok(session),
_ => Err(StatusCode::UNAUTHORIZED)?,
}
}

28
src/api/auth/mod.rs Normal file
View File

@@ -0,0 +1,28 @@
use std::sync::Arc;
use axum::Router;
use rand::Rng;
use rand_chacha::ChaCha20Rng;
use rand_chacha::rand_core::SeedableRng;
use crate::AppState;
mod oauth;
mod session;
mod api_key;
mod bearer_token;
mod middleware;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.nest("/oauth", oauth::get_routes())
.merge(api_key::get_routes())
}
pub fn csprng(take: usize) -> String {
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
let rng = ChaCha20Rng::from_entropy();
rng
.sample_iter(rand::distributions::Alphanumeric)
.take(take)
.map(char::from)
.collect()
}

View File

@@ -1,48 +1,24 @@
use std::env;
use std::sync::{Arc, OnceLock};
use std::sync::Arc;
use axum::extract::{Query, State};
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::{Json, Router};
use axum::http::header::SET_COOKIE;
use axum::response::Redirect;
use axum::routing::get;
use chrono::{DateTime, Utc};
use rand::Rng;
use rand_chacha::ChaCha20Rng;
use rand_chacha::rand_core::SeedableRng;
use redis::{AsyncCommands, RedisResult};
use serde::{Deserialize, Serialize};
use crate::{data, AppState};
use crate::api::auth::bearer_token::BearerTokenClaims;
use crate::AppState;
use crate::api::auth::csprng;
use crate::api::auth::session::Session;
use crate::error::SirenResult;
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/authorize", get(discord_authorize))
.route("/callback", get(oauth_callback))
}
fn get_session_ttl() -> i64 {
// Initialize the SESSION_TTL value lazily
*SESSION_TTL.get_or_init(|| {
env::var("SESSION_TTL")
.ok()
.and_then(|val| val.parse::<i64>().ok())
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
})
}
pub fn csprng(take: usize) -> String {
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
let rng = ChaCha20Rng::from_entropy();
rng
.sample_iter(rand::distributions::Alphanumeric)
.take(take)
.map(char::from)
.collect()
}
#[derive(Deserialize)]
struct AuthQuery {
code: String,
@@ -66,59 +42,6 @@ struct DiscordUser {
avatar: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
struct Session {
session_id: String,
user_id: String,
user_name: String,
pub expires_at: DateTime<Utc>,
}
impl Session {
fn new(id: String, user_id: String, user_name: String) -> Session {
let now = Utc::now();
let session_ttl = get_session_ttl();
Session {
session_id: id,
user_id,
user_name,
expires_at: now + chrono::Duration::seconds(session_ttl),
}
}
async fn insert(&self) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?;
let session_id = self.session_id.clone();
redis
.set_ex(
session_id,
serde_json::to_string(self)?,
self.expires_at.timestamp() as u64,
)
.await?;
Ok(())
}
async fn get(session_id: String) -> SirenResult<Option<Session>> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<Option<String>> = redis.get(session_id).await;
match result {
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
Ok(None) => Ok(None),
Err(err) => Err(err.into()),
}
}
async fn delete(session_id: String) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<()> = redis.del(session_id).await;
match result {
Ok(_) => Ok(()),
Err(err) => Err(err.into()),
}
}
}
// async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect {
// // Construct the Discord OAuth URL
// let discord_auth_url = format!(
@@ -137,10 +60,17 @@ async fn discord_authorize(State(state): State<Arc<AppState>>) -> SirenResult<St
Ok(discord_auth_url)
}
#[derive(Debug, Serialize)]
pub struct BearerTokenResponse {
pub access_token: String,
pub token_type: String,
pub expires_in: u64,
}
async fn oauth_callback(
State(state): State<Arc<AppState>>,
Query(query): Query<AuthQuery>,
) -> SirenResult<(HeaderMap, Json<DiscordUser>)> {
) -> SirenResult<Json<BearerTokenResponse>> {
// Exchange code for an access token
let token_response = state
.client
@@ -193,28 +123,32 @@ async fn oauth_callback(
log::debug!("User authenticated: {:?}", user_data);
// Generate a session token
let session_token = csprng(16);
let expiration = env::var("API_SESSION_TTL")
.expect("Expected a session ttl in the environment")
.parse::<u64>()
.unwrap();
// Create and insert the session
let session = Session::new(
session_token.clone(),
user_data.id.clone(),
user_data.username.clone(),
);
let session = Session::new(user_data.id.clone(), user_data.username.clone());
session.insert().await?;
let cookie_value = format!(
"session={}; HttpOnly; Path=/; Max-Age={}",
session_token, expiration
);
let issued_at = chrono::Utc::now();
let mut headers = HeaderMap::new();
headers.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
let claims = BearerTokenClaims {
sub: session.user_id.clone(),
name: session.user_name.clone(),
iat: issued_at.timestamp(),
exp: session.expires_at.timestamp(),
jti: session.session_id.clone(),
};
Ok((headers, Json(user_data)))
// Create the JWT
let jwt_secret = env::var("JWT_SECRET").expect("Expected a JWT secret in the environment");
let encoding_key = jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes());
let token = jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &encoding_key)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// Return the bearer token and user information
let response = BearerTokenResponse {
access_token: token,
token_type: "Bearer".to_string(),
expires_in: (session.expires_at.timestamp() - issued_at.timestamp()) as u64,
};
Ok(Json(response))
}

73
src/api/auth/session.rs Normal file
View File

@@ -0,0 +1,73 @@
use std::env;
use std::sync::OnceLock;
use chrono::{DateTime, Utc};
use redis::{AsyncCommands, RedisResult};
use serde::{Deserialize, Serialize};
use crate::api::auth::csprng;
use crate::data;
use crate::error::SirenResult;
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
fn get_session_ttl() -> i64 {
// Initialize the SESSION_TTL value lazily
*SESSION_TTL.get_or_init(|| {
env::var("API_SESSION_TTL")
.ok()
.and_then(|val| val.parse::<i64>().ok())
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
})
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Session {
pub session_id: String,
pub user_id: String,
pub user_name: String,
pub expires_at: DateTime<Utc>,
}
impl Session {
pub fn new(user_id: String, user_name: String) -> Session {
let now = Utc::now();
let session_ttl = get_session_ttl();
Session {
session_id: csprng(32),
user_id,
user_name,
expires_at: now + chrono::Duration::seconds(session_ttl),
}
}
pub async fn insert(&self) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?;
let session_id = self.session_id.clone();
redis
.set_ex(
session_id,
serde_json::to_string(self)?,
self.expires_at.timestamp() as u64,
)
.await?;
Ok(())
}
pub async fn find(session_id: &str) -> SirenResult<Option<Session>> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<Option<String>> = redis.get(session_id).await;
match result {
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
Ok(None) => Ok(None),
Err(err) => Err(err.into()),
}
}
pub async fn delete(session_id: &str) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<()> = redis.del(session_id).await;
match result {
Ok(_) => Ok(()),
Err(err) => Err(err.into()),
}
}
}

View File

@@ -1,13 +1,12 @@
mod app;
mod oauth;
pub use app::App;
use std::sync::Arc;
use axum::Router;
use serde::{Deserialize, Serialize};
use crate::AppState;
mod app;
mod auth;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new().nest("/oauth", oauth::get_routes())
Router::new().merge(auth::get_routes())
}

View File

@@ -10,24 +10,24 @@ pub async fn process_message(ctx: &Context, command: &CommandInteraction, privat
pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Option<Message> {
let data = CreateMessage::new().content(content.to_owned());
return match user_id.dm(ctx, data).await {
match user_id.dm(ctx, data).await {
Ok(message) => Some(message),
Err(err) => {
log::error!("Failed to create direct message for {content}\n{err}");
None
}
};
}
}
pub async fn user_dm(ctx: &Context, user: &User, content: String) -> Option<Message> {
let data = CreateMessage::new().content(content.to_owned());
return match user.direct_message(ctx, data).await {
match user.direct_message(ctx, data).await {
Ok(message) => Some(message),
Err(err) => {
log::error!("Failed to create direct message for {content}\n{err}");
None
}
};
}
}
pub async fn create_message_response(
@@ -50,7 +50,7 @@ pub async fn create_message_response(
}
pub async fn create_modal_response(ctx: &Context, modal: &ModalInteraction) {
let mut data = CreateInteractionResponseMessage::new();
let data = CreateInteractionResponseMessage::new();
let builder = CreateInteractionResponse::Message(data);
match modal.create_response(&ctx.http, builder).await {
Ok(_) => {}

View File

@@ -1,4 +1,4 @@
use serenity::all::{CommandDataOption, CommandInteraction, Context, CreateCommand};
use serenity::all::{CommandInteraction, Context, CreateCommand};
use crate::bot::chat::create_message_response;
pub async fn run(ctx: &Context, command: &CommandInteraction) {

View File

@@ -1,4 +1,4 @@
use serenity::all::{CreateInteractionResponse, Interaction, ResumedEvent};
use serenity::all::{Interaction, ResumedEvent};
use serenity::async_trait;
use serenity::model::gateway::Ready;
use serenity::model::channel::Message;
@@ -7,7 +7,7 @@ use crate::bot::commands::chat::generate_response;
use crate::bot::oai::OAI;
use crate::data::guilds::GuildCache;
use super::{commands};
use super::chat::{create_message_response, create_modal_response};
use super::chat::{create_modal_response};
pub struct BotHandler {
// Open AI Config

View File

@@ -26,12 +26,12 @@ impl GuildCache {
)",
TABLE_NAME
))
.bind(self.id)
.bind(&self.name)
.bind(self.owner_id)
.bind(self.volume)
.execute(pool)
.await?;
.bind(self.id)
.bind(&self.name)
.bind(self.owner_id)
.bind(self.volume)
.execute(pool)
.await?;
Ok(())
}
@@ -40,10 +40,7 @@ impl GuildCache {
let query = QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::is_equal("id", "$1")) // Use a placeholder
.build();
let item = sqlx::query_as(&query)
.bind(id)
.fetch_optional(pool)
.await?;
let item = sqlx::query_as(&query).bind(id).fetch_optional(pool).await?;
Ok(item)
}
@@ -58,12 +55,12 @@ impl GuildCache {
WHERE id = $1",
TABLE_NAME
))
.bind(self.id)
.bind(&self.name)
.bind(self.owner_id)
.bind(self.volume)
.execute(pool)
.await?;
.bind(self.id)
.bind(&self.name)
.bind(self.owner_id)
.bind(self.volume)
.execute(pool)
.await?;
Ok(())
}
}

View File

@@ -160,4 +160,4 @@ impl Condition {
Condition::Group(a) => format!("({})", a.to_sql()),
}
}
}
}

View File

@@ -1,5 +1,4 @@
use std::env;
use std::collections::HashSet;
use std::sync::Arc;
use serenity::http::Http;
use serenity::prelude::*;
@@ -72,7 +71,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
cache: Arc::clone(&client.cache),
};
log::debug!("Starting Siren with ID: {bot_id} (Contact: {:?})", bot_owner);
log::debug!(
"Starting Siren with ID: {bot_id} (Contact: {:?})",
bot_owner
);
// Spawn shutdown signal handling
let shard_manager = Arc::clone(&client.shard_manager);