Placeholder while updating query

This commit is contained in:
2024-12-21 21:36:05 -05:00
parent ceea975836
commit 7718cf19c3
17 changed files with 298 additions and 33 deletions

View File

@@ -5,7 +5,7 @@ meta {
} }
post { post {
url: {{baseUrl}}/audio/1061092965579235398/pause url: {{baseUrl}}/audio/{{server}}/pause
body: json body: json
auth: inherit auth: inherit
} }

View File

@@ -5,7 +5,7 @@ meta {
} }
post { post {
url: {{baseUrl}}/audio/1061092965579235398/play url: {{baseUrl}}/audio/{{server}}/play
body: json body: json
auth: inherit auth: inherit
} }

View File

@@ -5,7 +5,7 @@ meta {
} }
post { post {
url: {{baseUrl}}/audio/1061092965579235398/resume url: {{baseUrl}}/audio/{{server}}/resume
body: json body: json
auth: inherit auth: inherit
} }

View File

@@ -4,9 +4,6 @@ auth {
auth:apikey { auth:apikey {
key: X-API-Key key: X-API-Key
value: rwOS4yMmNpQvL0vLHc1jWQoefJB1bvKvOvBSswiYh0mkhZDc1lsgFZmpXaSUXAa5ZjpRWR117hLQ1l0VPPSGkRXZl7dPRVCc value: {{apiKey}}
placement: header placement: header
} }
vars:pre-request {
baseUrl: http://localhost:3000/api
}

17
bruno/dice/Track.bru Normal file
View File

@@ -0,0 +1,17 @@
meta {
name: Track
type: http
seq: 1
}
post {
url: {{baseUrl}}/dice/{{server}}/track
body: json
auth: inherit
}
body:json {
{
"dice": "1d4"
}
}

View File

@@ -0,0 +1,7 @@
vars {
baseUrl: http://localhost:3000/api
server: 1061092965579235398
}
vars:secret [
apiKey
]

View File

@@ -7,5 +7,9 @@ meta {
post { post {
url: {{baseUrl}}/api-key url: {{baseUrl}}/api-key
body: none body: none
auth: inherit auth: bearer
}
auth:bearer {
token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjI1MDg0MjI2MTIyMTI3NzY5NywibmFtZSI6ImJzaGVycmlmZiIsImlhdCI6MTczNDgwODA5MiwiZXhwIjoxNzM0ODk0NDkyLCJqdGkiOiJsSWFHaU15Wll5cnFVYmFJTGs2dzAyZTY4YkFPZjFZWSJ9.fCeooH2IdtXiy2s23WykXtOaR8dvnUinmSGFcV-fOwQ
} }

View File

@@ -20,10 +20,13 @@ CREATE TABLE IF NOT EXISTS api_keys (
key TEXT PRIMARY KEY NOT NULL, key TEXT PRIMARY KEY NOT NULL,
user_id BIGINT NOT NULL, user_id BIGINT NOT NULL,
user_name TEXT NOT NULL, user_name TEXT NOT NULL,
access_mask INT access_mask INT,
created_at TIMESTAMPTZ NOT NULL,
last_used_at TIMESTAMPTZ
); );
CREATE TABLE IF NOT EXISTS dice_thresholds ( CREATE TABLE IF NOT EXISTS dice_track (
id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(), id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(),
guild_id BIGINT NOT NULL,
owner_id BIGINT NOT NULL, owner_id BIGINT NOT NULL,
dice TEXT NOT NULL, dice TEXT NOT NULL,
user_id BIGINT, user_id BIGINT,

View File

@@ -2,11 +2,9 @@ use std::sync::Arc;
use axum::extract::{Path, State}; use axum::extract::{Path, State};
use axum::middleware::from_extractor; use axum::middleware::from_extractor;
use axum::{Extension, Json, Router}; use axum::{Extension, Json, Router};
use axum::response::IntoResponse;
use axum::routing::post; use axum::routing::post;
use reqwest::StatusCode;
use serde::Deserialize; use serde::Deserialize;
use crate::api::auth::{AuthCredential, AuthorizationMiddleware, Session}; use crate::api::auth::{AuthCredential, AuthorizationMiddleware};
use crate::AppState; use crate::AppState;
use crate::bot::commands::audio::join_voice_channel; use crate::bot::commands::audio::join_voice_channel;
use crate::bot::commands::audio::pause::pause_track; use crate::bot::commands::audio::pause::pause_track;

View File

@@ -2,11 +2,10 @@ use std::sync::Arc;
use axum::{Extension, Router}; use axum::{Extension, Router};
use axum::middleware::from_extractor; use axum::middleware::from_extractor;
use axum::routing::post; use axum::routing::post;
use reqwest::StatusCode; use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::auth::{csprng, AuthCredential}; use crate::api::auth::{csprng, AuthCredential};
use crate::api::auth::AuthorizationMiddleware; use crate::api::auth::AuthorizationMiddleware;
use crate::api::auth::session::Session;
use crate::AppState; use crate::AppState;
use crate::data::query::{Condition, QueryBuilder}; use crate::data::query::{Condition, QueryBuilder};
use crate::error::{Error, SirenResult}; use crate::error::{Error, SirenResult};
@@ -25,6 +24,8 @@ pub struct ApiKey {
pub user_id: i64, pub user_id: i64,
pub user_name: String, pub user_name: String,
pub access_mask: i32, pub access_mask: i32,
pub created_at: DateTime<Utc>,
pub last_used_at: Option<DateTime<Utc>>
} }
impl ApiKey { impl ApiKey {
@@ -34,6 +35,8 @@ impl ApiKey {
user_id: user_id as i64, user_id: user_id as i64,
user_name, user_name,
access_mask, access_mask,
created_at: Utc::now(),
last_used_at: None,
} }
} }
@@ -44,9 +47,11 @@ impl ApiKey {
key, key,
user_id, user_id,
user_name, user_name,
access_mask access_mask,
created_at,
last_used_at
) VALUES ( ) VALUES (
$1, $2, $3, $4 $1, $2, $3, $4, $5, $6
)", )",
TABLE_NAME TABLE_NAME
)) ))
@@ -54,11 +59,36 @@ impl ApiKey {
.bind(self.user_id) .bind(self.user_id)
.bind(&self.user_name) .bind(&self.user_name)
.bind(self.access_mask) .bind(self.access_mask)
.bind(self.created_at)
.bind(self.last_used_at)
.execute(pool) .execute(pool)
.await?; .await?;
Ok(()) Ok(())
} }
pub async fn update(&self) -> SirenResult<()> {
let pool = crate::data::pool();
sqlx::query(&format!(
"UPDATE {} SET
user_id = $2,
user_name = $3,
access_mask = $4,
created_at = $5,
last_used_at = $6
WHERE key = $1",
TABLE_NAME
))
.bind(&self.key)
.bind(self.user_id)
.bind(&self.user_name)
.bind(self.access_mask)
.bind(self.created_at)
.bind(self.last_used_at)
.execute(pool)
.await?;
Ok(())
}
pub async fn find_by_key(key: &str) -> SirenResult<Option<Self>> { pub async fn find_by_key(key: &str) -> SirenResult<Option<Self>> {
let pool = crate::data::pool(); let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME) let query = QueryBuilder::new(TABLE_NAME)
@@ -84,8 +114,13 @@ impl ApiKey {
async fn create_api_key(Extension(credential): Extension<AuthCredential>) -> SirenResult<String> { async fn create_api_key(Extension(credential): Extension<AuthCredential>) -> SirenResult<String> {
let session = match credential { let session = match credential {
AuthCredential::ApiKey(_) => return Err(Error::new(400, "API keys cannot be generated with an API key".to_string())), AuthCredential::ApiKey(_) => {
AuthCredential::Session(session) => session return Err(Error::new(
400,
"API keys cannot be generated using an existing API key for authentication.".to_string(),
))
}
AuthCredential::Session(session) => session,
}; };
log::debug!( log::debug!(
"Generating API key for {} ({})", "Generating API key for {} ({})",

View File

@@ -87,11 +87,14 @@ async fn check_bearer_auth(bearer_token: &str) -> SirenResult<Session> {
} }
async fn check_api_key_auth(key: &str) -> SirenResult<ApiKey> { async fn check_api_key_auth(key: &str) -> SirenResult<ApiKey> {
let mut api_key = match ApiKey::find_by_key(key).await? {
let api_key = match ApiKey::find_by_key(key).await? {
Some(api_key) => api_key, Some(api_key) => api_key,
None => return Err(StatusCode::UNAUTHORIZED.into()), None => return Err(StatusCode::UNAUTHORIZED.into()),
}; };
// Update when the API key was last used
api_key.last_used_at = Some(Utc::now());
api_key.update().await?;
Ok(api_key) Ok(api_key)
} }

187
src/api/dice/mod.rs Normal file
View File

@@ -0,0 +1,187 @@
use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;
use axum::{Extension, Json, Router};
use axum::extract::{Path, State};
use axum::middleware::from_extractor;
use axum::routing::post;
use axum_extra::handler::HandlerCallWithExtractors;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::api::auth::{AuthCredential, AuthorizationMiddleware};
use crate::AppState;
use crate::bot::commands::fun::roll::{format_roll, parse_dice};
use crate::data::query::{Condition, QueryBuilder};
use crate::error::{Error, SirenResult};
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/:guild_id/track", post(add_track_dice))
.route_layer(from_extractor::<AuthorizationMiddleware>())
}
const TABLE_NAME: &str = "dice_track";
#[derive(Serialize, Deserialize, Clone, Debug)]
enum TrackDiceOperator {
#[serde(rename = "eq")]
Equal,
#[serde(rename = "lt")]
LessThan,
#[serde(rename = "lte")]
LessThanEqual,
#[serde(rename = "gt")]
GreaterThan,
#[serde(rename = "gte")]
GreaterThanEqual,
}
// Implementing the ToString trait for converting the enum to a string
impl Display for TrackDiceOperator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let str = match self {
TrackDiceOperator::Equal => "eq".to_string(),
TrackDiceOperator::LessThan => "lt".to_string(),
TrackDiceOperator::LessThanEqual => "lte".to_string(),
TrackDiceOperator::GreaterThan => "gt".to_string(),
TrackDiceOperator::GreaterThanEqual => "gte".to_string(),
};
write!(f, "{}", str)
}
}
// Implementing the FromStr trait for parsing a string into the enum
impl FromStr for TrackDiceOperator {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"eq" => Ok(TrackDiceOperator::Equal),
"lt" => Ok(TrackDiceOperator::LessThan),
"lte" => Ok(TrackDiceOperator::LessThanEqual),
"gt" => Ok(TrackDiceOperator::GreaterThan),
"gte" => Ok(TrackDiceOperator::GreaterThanEqual),
_ => Err(format!("Unknown value for TrackDiceOperator: {}", s)),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug)]
struct DiceTrackPayload {
dice: String,
user_id: Option<i64>,
value: Option<i32>,
operator: Option<TrackDiceOperator>,
}
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
struct InsertDiceTrack {
guild_id: i64,
owner_id: i64,
dice: String,
user_id: Option<i64>,
value: Option<i32>,
operator: Option<String>,
}
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
struct QueryDiceTrack {
id: Uuid,
guild_id: i64,
owner_id: i64,
dice: String,
user_id: Option<i64>,
value: Option<i32>,
operator: Option<String>,
}
impl QueryDiceTrack {
pub async fn find() -> SirenResult<Vec<Self>> {
let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME)
// .where_condition(
// Condition::and(
// Condition::is_equal("guild_id", "$1"),
// Condition::and(
// Condition::is_equal("owner_id", "$2"),
//
// )
// )
// )
.build();
let items: Vec<QueryDiceTrack> = sqlx::query_as(&query)
.fetch_all(pool).await?;
Ok(items)
}
}
impl InsertDiceTrack {
pub async fn insert(&self) -> SirenResult<QueryDiceTrack> {
let pool = crate::data::pool();
let query = format!(
"INSERT INTO {} (
guild_id,
owner_id,
dice,
user_id,
value,
operator
) VALUES (
$1, $2, $3, $4, $5, $6
) RETURNING *",
TABLE_NAME
);
let item: QueryDiceTrack = match sqlx::query_as(&query)
.bind(self.guild_id)
.bind(self.owner_id)
.bind(&self.dice)
.bind(self.user_id)
.bind(self.value)
.bind(&self.operator)
.fetch_optional(pool).await? {
Some(result) => result,
None => return Err(Error::new(500, "Error storing".to_string()))
};
Ok(item)
}
}
pub async fn add_track_dice(
Extension(credential): Extension<AuthCredential>,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
Json(payload): Json<DiceTrackPayload>,
) -> SirenResult<Json<QueryDiceTrack>> {
// Check if the user exists in the cache
let owner_id = credential.user_id();
let owner_id = match state.cache.user(owner_id) {
Some(user) => user.id,
None => return Err(Error::not_found("User not found".to_string())),
};
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
let dice = parse_dice(&payload.dice)?;
let dice = InsertDiceTrack {
guild_id: guild_id.get() as i64,
owner_id: owner_id.get() as i64,
dice: format_roll(dice.0, dice.1, dice.2),
user_id: payload.user_id,
value: payload.value,
operator: match payload.operator {
None => None,
Some(s) => Some(s.to_string()),
}
};
let dice_track = dice.insert().await?;
Ok(Json(dice_track))
}

View File

@@ -2,14 +2,17 @@ pub use app::App;
use std::sync::Arc; use std::sync::Arc;
use axum::Router; use axum::Router;
use serde::Deserialize;
use crate::AppState; use crate::AppState;
mod app; mod app;
mod audio; mod audio;
mod auth; mod auth;
mod dice;
pub fn get_routes() -> Router<Arc<AppState>> { pub fn get_routes() -> Router<Arc<AppState>> {
Router::new() Router::new()
.merge(auth::get_routes()) .merge(auth::get_routes())
.nest("/audio/:guild_id", audio::get_routes()) .nest("/audio/:guild_id", audio::get_routes())
.nest("/dice", dice::get_routes())
} }

View File

@@ -1,7 +1,7 @@
use serenity::all::{ use serenity::all::{
CommandInteraction, Context, CreateInteractionResponse, CreateInteractionResponseMessage, CommandInteraction, Context, CreateInteractionResponse, CreateInteractionResponseMessage,
CreateMessage, EditInteractionResponse, InteractionResponseFlags, Message, ModalInteraction, CreateMessage, EditInteractionResponse, InteractionResponseFlags, Message, ModalInteraction,
User, UserId, UserId,
}; };
pub async fn process_message(ctx: &Context, command: &CommandInteraction, private: bool) { pub async fn process_message(ctx: &Context, command: &CommandInteraction, private: bool) {

View File

@@ -7,6 +7,7 @@ use serenity::all::{
}; };
use crate::bot::chat::{create_message_response, edit_response}; use crate::bot::chat::{create_message_response, edit_response};
use crate::error::{Error, SirenResult};
use crate::utils::{a_or_an, number_to_words}; use crate::utils::{a_or_an, number_to_words};
pub async fn run(ctx: &Context, command: &CommandInteraction) { pub async fn run(ctx: &Context, command: &CommandInteraction) {
@@ -64,6 +65,8 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
} }
None => edit_response(&ctx, &command, format!("🎲 {}\n-# {}", total, response)).await, None => edit_response(&ctx, &command, format!("🎲 {}\n-# {}", total, response)).await,
}; };
// Check for dice tracks
} }
Err(why) => { Err(why) => {
edit_response(&ctx, &command, format!("Invalid dice string: {}", why)).await; edit_response(&ctx, &command, format!("Invalid dice string: {}", why)).await;
@@ -123,7 +126,7 @@ pub fn roll_dice(count: u32, sides: u32, modifier: i32) -> i32 {
total total
} }
pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { pub fn parse_dice(dice: &str) -> SirenResult<(u32, u32, i32)> {
// If the input is just a number (e.g., "20" or "6"), assume it's the number of sides // If the input is just a number (e.g., "20" or "6"), assume it's the number of sides
if let Ok(n) = dice.parse::<u32>() { if let Ok(n) = dice.parse::<u32>() {
return Ok((1, n, 0)); // Assume 1 dice with 0 modifiers return Ok((1, n, 0)); // Assume 1 dice with 0 modifiers
@@ -144,31 +147,31 @@ pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> {
Some("") => 1, // Handle cases like "d6", assume 1 dice Some("") => 1, // Handle cases like "d6", assume 1 dice
Some(c) => match c.parse::<u32>() { Some(c) => match c.parse::<u32>() {
Ok(n) => n, Ok(n) => n,
Err(_) => return Err(format!("Invalid dice count: {}", c)), Err(_) => return Err(Error::new(400, format!("Invalid dice count: {}", c))),
}, },
None => return Err(format!("Invalid dice string: {}", dice)), None => return Err(Error::new(400, format!("Invalid dice string: {}", dice))),
}; };
// Parse the number of sides // Parse the number of sides
let sides_part = parts let sides_part = parts
.next() .next()
.ok_or_else(|| format!("Invalid dice string: {}", dice))?; .ok_or_else(|| Error::new(400, format!("Invalid dice string: {}", dice)))?;
let sides = match sides_part.parse::<u32>() { let sides = match sides_part.parse::<u32>() {
Ok(n) => { Ok(n) => {
if [4, 6, 8, 10, 12, 20, 100].contains(&n) { if [4, 6, 8, 10, 12, 20, 100].contains(&n) {
n n
} else { } else {
return Err(format!( return Err(Error::new(400, format!(
"Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}",
n n
)); )));
} }
} }
Err(_) => { Err(_) => {
return Err(format!( return Err(Error::new(400, format!(
"Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}",
sides_part sides_part
)) )))
} }
}; };
@@ -189,7 +192,7 @@ pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> {
-n -n
} }
} }
Err(_) => return Err(format!("Invalid dice modifier: {}", m)), Err(_) => return Err(Error::new(400, format!("Invalid dice modifier: {}", m))),
}, },
None => 0, // No modifier found None => 0, // No modifier found
}; };

View File

@@ -23,6 +23,7 @@ pub struct BotHandler {
pub oai: Option<OAI>, pub oai: Option<OAI>,
} }
static REGISTERED: OnceLock<bool> = OnceLock::new();
static SONGBIRD: OnceLock<Arc<Songbird>> = OnceLock::new(); static SONGBIRD: OnceLock<Arc<Songbird>> = OnceLock::new();
static CLIENT: OnceLock<reqwest::Client> = OnceLock::new(); static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
@@ -104,6 +105,13 @@ impl EventHandler for BotHandler {
CLIENT.set(http_client).ok(); CLIENT.set(http_client).ok();
} }
// Update registered to prevent reloading the commands
if REGISTERED.get().is_some() {
return;
} else {
REGISTERED.set(true).ok();
}
log::trace!("Handling {} guilds", ready.guilds.len()); log::trace!("Handling {} guilds", ready.guilds.len());
for guild in ready.guilds { for guild in ready.guilds {
// Check if guild exists in database // Check if guild exists in database

View File

@@ -38,7 +38,7 @@ impl GuildCache {
pub async fn find_by_id(id: i64) -> SirenResult<Option<Self>> { pub async fn find_by_id(id: i64) -> SirenResult<Option<Self>> {
let pool = crate::data::pool(); let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME) let query = QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::is_equal("id", "$1")) // Use a placeholder .where_condition(Condition::is_equal("id", "$1"))
.build(); .build();
let item = sqlx::query_as(&query).bind(id).fetch_optional(pool).await?; let item = sqlx::query_as(&query).bind(id).fetch_optional(pool).await?;