diff --git a/bruno/audio/Pause Track.bru b/bruno/audio/Pause Track.bru index 36a8cf6..5aeff7d 100644 --- a/bruno/audio/Pause Track.bru +++ b/bruno/audio/Pause Track.bru @@ -5,7 +5,7 @@ meta { } post { - url: {{baseUrl}}/audio/1061092965579235398/pause + url: {{baseUrl}}/audio/{{server}}/pause body: json auth: inherit } diff --git a/bruno/audio/Play Track.bru b/bruno/audio/Play Track.bru index f2c9821..62b5b60 100644 --- a/bruno/audio/Play Track.bru +++ b/bruno/audio/Play Track.bru @@ -5,7 +5,7 @@ meta { } post { - url: {{baseUrl}}/audio/1061092965579235398/play + url: {{baseUrl}}/audio/{{server}}/play body: json auth: inherit } diff --git a/bruno/audio/Resume Track.bru b/bruno/audio/Resume Track.bru index 33e0186..6248a28 100644 --- a/bruno/audio/Resume Track.bru +++ b/bruno/audio/Resume Track.bru @@ -5,7 +5,7 @@ meta { } post { - url: {{baseUrl}}/audio/1061092965579235398/resume + url: {{baseUrl}}/audio/{{server}}/resume body: json auth: inherit } diff --git a/bruno/collection.bru b/bruno/collection.bru index 0157b98..07977f9 100644 --- a/bruno/collection.bru +++ b/bruno/collection.bru @@ -4,9 +4,6 @@ auth { auth:apikey { key: X-API-Key - value: rwOS4yMmNpQvL0vLHc1jWQoefJB1bvKvOvBSswiYh0mkhZDc1lsgFZmpXaSUXAa5ZjpRWR117hLQ1l0VPPSGkRXZl7dPRVCc + value: {{apiKey}} placement: header -} -vars:pre-request { - baseUrl: http://localhost:3000/api -} +} \ No newline at end of file diff --git a/bruno/dice/Track.bru b/bruno/dice/Track.bru new file mode 100644 index 0000000..9fdcde1 --- /dev/null +++ b/bruno/dice/Track.bru @@ -0,0 +1,17 @@ +meta { + name: Track + type: http + seq: 1 +} + +post { + url: {{baseUrl}}/dice/{{server}}/track + body: json + auth: inherit +} + +body:json { + { + "dice": "1d4" + } +} diff --git a/bruno/environments/Localhost.bru b/bruno/environments/Localhost.bru new file mode 100644 index 0000000..27db5cb --- /dev/null +++ b/bruno/environments/Localhost.bru @@ -0,0 +1,7 @@ +vars { + baseUrl: http://localhost:3000/api + server: 1061092965579235398 +} +vars:secret [ + apiKey +] diff --git a/bruno/oauth/Create API Key.bru b/bruno/oauth/Create API Key.bru index 848acb1..08a5b80 100644 --- a/bruno/oauth/Create API Key.bru +++ b/bruno/oauth/Create API Key.bru @@ -7,5 +7,9 @@ meta { post { url: {{baseUrl}}/api-key body: none - auth: inherit + auth: bearer +} + +auth:bearer { + token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjI1MDg0MjI2MTIyMTI3NzY5NywibmFtZSI6ImJzaGVycmlmZiIsImlhdCI6MTczNDgwODA5MiwiZXhwIjoxNzM0ODk0NDkyLCJqdGkiOiJsSWFHaU15Wll5cnFVYmFJTGs2dzAyZTY4YkFPZjFZWSJ9.fCeooH2IdtXiy2s23WykXtOaR8dvnUinmSGFcV-fOwQ } diff --git a/migrations/000_initial.sql b/migrations/000_initial.sql index 8532b5e..b7f6465 100644 --- a/migrations/000_initial.sql +++ b/migrations/000_initial.sql @@ -20,10 +20,13 @@ CREATE TABLE IF NOT EXISTS api_keys ( key TEXT PRIMARY KEY NOT NULL, user_id BIGINT NOT NULL, user_name TEXT NOT NULL, - access_mask INT + access_mask INT, + created_at TIMESTAMPTZ NOT NULL, + last_used_at TIMESTAMPTZ ); -CREATE TABLE IF NOT EXISTS dice_thresholds ( +CREATE TABLE IF NOT EXISTS dice_track ( id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(), + guild_id BIGINT NOT NULL, owner_id BIGINT NOT NULL, dice TEXT NOT NULL, user_id BIGINT, diff --git a/src/api/audio/mod.rs b/src/api/audio/mod.rs index 02fbb63..b265c1a 100644 --- a/src/api/audio/mod.rs +++ b/src/api/audio/mod.rs @@ -2,11 +2,9 @@ use std::sync::Arc; use axum::extract::{Path, State}; use axum::middleware::from_extractor; use axum::{Extension, Json, Router}; -use axum::response::IntoResponse; use axum::routing::post; -use reqwest::StatusCode; use serde::Deserialize; -use crate::api::auth::{AuthCredential, AuthorizationMiddleware, Session}; +use crate::api::auth::{AuthCredential, AuthorizationMiddleware}; use crate::AppState; use crate::bot::commands::audio::join_voice_channel; use crate::bot::commands::audio::pause::pause_track; diff --git a/src/api/auth/api_key.rs b/src/api/auth/api_key.rs index ea8b31a..e063176 100644 --- a/src/api/auth/api_key.rs +++ b/src/api/auth/api_key.rs @@ -2,11 +2,10 @@ use std::sync::Arc; use axum::{Extension, Router}; use axum::middleware::from_extractor; use axum::routing::post; -use reqwest::StatusCode; +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use crate::api::auth::{csprng, AuthCredential}; use crate::api::auth::AuthorizationMiddleware; -use crate::api::auth::session::Session; use crate::AppState; use crate::data::query::{Condition, QueryBuilder}; use crate::error::{Error, SirenResult}; @@ -25,6 +24,8 @@ pub struct ApiKey { pub user_id: i64, pub user_name: String, pub access_mask: i32, + pub created_at: DateTime, + pub last_used_at: Option> } impl ApiKey { @@ -34,6 +35,8 @@ impl ApiKey { user_id: user_id as i64, user_name, access_mask, + created_at: Utc::now(), + last_used_at: None, } } @@ -44,9 +47,11 @@ impl ApiKey { key, user_id, user_name, - access_mask + access_mask, + created_at, + last_used_at ) VALUES ( - $1, $2, $3, $4 + $1, $2, $3, $4, $5, $6 )", TABLE_NAME )) @@ -54,11 +59,36 @@ impl ApiKey { .bind(self.user_id) .bind(&self.user_name) .bind(self.access_mask) + .bind(self.created_at) + .bind(self.last_used_at) .execute(pool) .await?; Ok(()) } + pub async fn update(&self) -> SirenResult<()> { + let pool = crate::data::pool(); + sqlx::query(&format!( + "UPDATE {} SET + user_id = $2, + user_name = $3, + access_mask = $4, + created_at = $5, + last_used_at = $6 + WHERE key = $1", + TABLE_NAME + )) + .bind(&self.key) + .bind(self.user_id) + .bind(&self.user_name) + .bind(self.access_mask) + .bind(self.created_at) + .bind(self.last_used_at) + .execute(pool) + .await?; + Ok(()) + } + pub async fn find_by_key(key: &str) -> SirenResult> { let pool = crate::data::pool(); let query = QueryBuilder::new(TABLE_NAME) @@ -84,8 +114,13 @@ impl ApiKey { async fn create_api_key(Extension(credential): Extension) -> SirenResult { let session = match credential { - AuthCredential::ApiKey(_) => return Err(Error::new(400, "API keys cannot be generated with an API key".to_string())), - AuthCredential::Session(session) => session + AuthCredential::ApiKey(_) => { + return Err(Error::new( + 400, + "API keys cannot be generated using an existing API key for authentication.".to_string(), + )) + } + AuthCredential::Session(session) => session, }; log::debug!( "Generating API key for {} ({})", diff --git a/src/api/auth/middleware.rs b/src/api/auth/middleware.rs index fcc7477..f00342e 100644 --- a/src/api/auth/middleware.rs +++ b/src/api/auth/middleware.rs @@ -87,11 +87,14 @@ async fn check_bearer_auth(bearer_token: &str) -> SirenResult { } async fn check_api_key_auth(key: &str) -> SirenResult { - - let api_key = match ApiKey::find_by_key(key).await? { + let mut api_key = match ApiKey::find_by_key(key).await? { Some(api_key) => api_key, None => return Err(StatusCode::UNAUTHORIZED.into()), }; + // Update when the API key was last used + api_key.last_used_at = Some(Utc::now()); + api_key.update().await?; + Ok(api_key) } diff --git a/src/api/dice/mod.rs b/src/api/dice/mod.rs new file mode 100644 index 0000000..496dff8 --- /dev/null +++ b/src/api/dice/mod.rs @@ -0,0 +1,187 @@ +use std::fmt::Display; +use std::str::FromStr; +use std::sync::Arc; +use axum::{Extension, Json, Router}; +use axum::extract::{Path, State}; +use axum::middleware::from_extractor; +use axum::routing::post; +use axum_extra::handler::HandlerCallWithExtractors; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; +use crate::api::auth::{AuthCredential, AuthorizationMiddleware}; +use crate::AppState; +use crate::bot::commands::fun::roll::{format_roll, parse_dice}; +use crate::data::query::{Condition, QueryBuilder}; +use crate::error::{Error, SirenResult}; + +pub fn get_routes() -> Router> { + Router::new() + .route("/:guild_id/track", post(add_track_dice)) + .route_layer(from_extractor::()) +} + +const TABLE_NAME: &str = "dice_track"; + +#[derive(Serialize, Deserialize, Clone, Debug)] +enum TrackDiceOperator { + #[serde(rename = "eq")] + Equal, + #[serde(rename = "lt")] + LessThan, + #[serde(rename = "lte")] + LessThanEqual, + #[serde(rename = "gt")] + GreaterThan, + #[serde(rename = "gte")] + GreaterThanEqual, +} + + +// Implementing the ToString trait for converting the enum to a string +impl Display for TrackDiceOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let str = match self { + TrackDiceOperator::Equal => "eq".to_string(), + TrackDiceOperator::LessThan => "lt".to_string(), + TrackDiceOperator::LessThanEqual => "lte".to_string(), + TrackDiceOperator::GreaterThan => "gt".to_string(), + TrackDiceOperator::GreaterThanEqual => "gte".to_string(), + }; + write!(f, "{}", str) + } +} + +// Implementing the FromStr trait for parsing a string into the enum +impl FromStr for TrackDiceOperator { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "eq" => Ok(TrackDiceOperator::Equal), + "lt" => Ok(TrackDiceOperator::LessThan), + "lte" => Ok(TrackDiceOperator::LessThanEqual), + "gt" => Ok(TrackDiceOperator::GreaterThan), + "gte" => Ok(TrackDiceOperator::GreaterThanEqual), + _ => Err(format!("Unknown value for TrackDiceOperator: {}", s)), + } + } +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +struct DiceTrackPayload { + dice: String, + user_id: Option, + value: Option, + operator: Option, +} + +#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)] +struct InsertDiceTrack { + guild_id: i64, + owner_id: i64, + dice: String, + user_id: Option, + value: Option, + operator: Option, +} + +#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)] +struct QueryDiceTrack { + id: Uuid, + guild_id: i64, + owner_id: i64, + dice: String, + user_id: Option, + value: Option, + operator: Option, +} + +impl QueryDiceTrack { + pub async fn find() -> SirenResult> { + let pool = crate::data::pool(); + let query = QueryBuilder::new(TABLE_NAME) + // .where_condition( + // Condition::and( + // Condition::is_equal("guild_id", "$1"), + // Condition::and( + // Condition::is_equal("owner_id", "$2"), + // + // ) + // ) + // ) + .build(); + let items: Vec = sqlx::query_as(&query) + .fetch_all(pool).await?; + + Ok(items) + } +} + +impl InsertDiceTrack { + pub async fn insert(&self) -> SirenResult { + let pool = crate::data::pool(); + let query = format!( + "INSERT INTO {} ( + guild_id, + owner_id, + dice, + user_id, + value, + operator + ) VALUES ( + $1, $2, $3, $4, $5, $6 + ) RETURNING *", + TABLE_NAME + ); + let item: QueryDiceTrack = match sqlx::query_as(&query) + .bind(self.guild_id) + .bind(self.owner_id) + .bind(&self.dice) + .bind(self.user_id) + .bind(self.value) + .bind(&self.operator) + .fetch_optional(pool).await? { + Some(result) => result, + None => return Err(Error::new(500, "Error storing".to_string())) + }; + Ok(item) + } +} + +pub async fn add_track_dice( + Extension(credential): Extension, + State(state): State>, + Path(guild_id): Path, + Json(payload): Json, +) -> SirenResult> { + + // Check if the user exists in the cache + let owner_id = credential.user_id(); + let owner_id = match state.cache.user(owner_id) { + Some(user) => user.id, + None => return Err(Error::not_found("User not found".to_string())), + }; + + // Validate if the guild exists in the cache + let guild_id = match state.cache.guild(guild_id) { + Some(guild) => guild.id, + None => return Err(Error::not_found("Guild not found".to_string())), + }; + + let dice = parse_dice(&payload.dice)?; + + let dice = InsertDiceTrack { + guild_id: guild_id.get() as i64, + owner_id: owner_id.get() as i64, + dice: format_roll(dice.0, dice.1, dice.2), + user_id: payload.user_id, + value: payload.value, + operator: match payload.operator { + None => None, + Some(s) => Some(s.to_string()), + } + }; + + let dice_track = dice.insert().await?; + Ok(Json(dice_track)) +} \ No newline at end of file diff --git a/src/api/mod.rs b/src/api/mod.rs index 2992c01..e4ea122 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -2,14 +2,17 @@ pub use app::App; use std::sync::Arc; use axum::Router; +use serde::Deserialize; use crate::AppState; mod app; mod audio; mod auth; +mod dice; pub fn get_routes() -> Router> { Router::new() .merge(auth::get_routes()) .nest("/audio/:guild_id", audio::get_routes()) + .nest("/dice", dice::get_routes()) } diff --git a/src/bot/chat/mod.rs b/src/bot/chat/mod.rs index 295f3de..e27b803 100644 --- a/src/bot/chat/mod.rs +++ b/src/bot/chat/mod.rs @@ -1,7 +1,7 @@ use serenity::all::{ CommandInteraction, Context, CreateInteractionResponse, CreateInteractionResponseMessage, CreateMessage, EditInteractionResponse, InteractionResponseFlags, Message, ModalInteraction, - User, UserId, + UserId, }; pub async fn process_message(ctx: &Context, command: &CommandInteraction, private: bool) { diff --git a/src/bot/commands/fun/roll.rs b/src/bot/commands/fun/roll.rs index 674e48c..d4ae64d 100644 --- a/src/bot/commands/fun/roll.rs +++ b/src/bot/commands/fun/roll.rs @@ -7,6 +7,7 @@ use serenity::all::{ }; use crate::bot::chat::{create_message_response, edit_response}; +use crate::error::{Error, SirenResult}; use crate::utils::{a_or_an, number_to_words}; pub async fn run(ctx: &Context, command: &CommandInteraction) { @@ -64,6 +65,8 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { } None => edit_response(&ctx, &command, format!("🎲 {}\n-# {}", total, response)).await, }; + // Check for dice tracks + } Err(why) => { edit_response(&ctx, &command, format!("Invalid dice string: {}", why)).await; @@ -123,7 +126,7 @@ pub fn roll_dice(count: u32, sides: u32, modifier: i32) -> i32 { total } -pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { +pub fn parse_dice(dice: &str) -> SirenResult<(u32, u32, i32)> { // If the input is just a number (e.g., "20" or "6"), assume it's the number of sides if let Ok(n) = dice.parse::() { return Ok((1, n, 0)); // Assume 1 dice with 0 modifiers @@ -144,31 +147,31 @@ pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { Some("") => 1, // Handle cases like "d6", assume 1 dice Some(c) => match c.parse::() { Ok(n) => n, - Err(_) => return Err(format!("Invalid dice count: {}", c)), + Err(_) => return Err(Error::new(400, format!("Invalid dice count: {}", c))), }, - None => return Err(format!("Invalid dice string: {}", dice)), + None => return Err(Error::new(400, format!("Invalid dice string: {}", dice))), }; // Parse the number of sides let sides_part = parts .next() - .ok_or_else(|| format!("Invalid dice string: {}", dice))?; + .ok_or_else(|| Error::new(400, format!("Invalid dice string: {}", dice)))?; let sides = match sides_part.parse::() { Ok(n) => { if [4, 6, 8, 10, 12, 20, 100].contains(&n) { n } else { - return Err(format!( + return Err(Error::new(400, format!( "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", n - )); + ))); } } Err(_) => { - return Err(format!( + return Err(Error::new(400, format!( "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", sides_part - )) + ))) } }; @@ -189,7 +192,7 @@ pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { -n } } - Err(_) => return Err(format!("Invalid dice modifier: {}", m)), + Err(_) => return Err(Error::new(400, format!("Invalid dice modifier: {}", m))), }, None => 0, // No modifier found }; diff --git a/src/bot/handler.rs b/src/bot/handler.rs index 6832739..2a2930f 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -23,6 +23,7 @@ pub struct BotHandler { pub oai: Option, } +static REGISTERED: OnceLock = OnceLock::new(); static SONGBIRD: OnceLock> = OnceLock::new(); static CLIENT: OnceLock = OnceLock::new(); @@ -104,6 +105,13 @@ impl EventHandler for BotHandler { CLIENT.set(http_client).ok(); } + // Update registered to prevent reloading the commands + if REGISTERED.get().is_some() { + return; + } else { + REGISTERED.set(true).ok(); + } + log::trace!("Handling {} guilds", ready.guilds.len()); for guild in ready.guilds { // Check if guild exists in database diff --git a/src/data/guilds/model.rs b/src/data/guilds/model.rs index 84f94cb..524b1b5 100644 --- a/src/data/guilds/model.rs +++ b/src/data/guilds/model.rs @@ -38,7 +38,7 @@ impl GuildCache { pub async fn find_by_id(id: i64) -> SirenResult> { let pool = crate::data::pool(); let query = QueryBuilder::new(TABLE_NAME) - .where_condition(Condition::is_equal("id", "$1")) // Use a placeholder + .where_condition(Condition::is_equal("id", "$1")) .build(); let item = sqlx::query_as(&query).bind(id).fetch_optional(pool).await?;