Placeholder while updating query

This commit is contained in:
2024-12-21 21:36:05 -05:00
parent ceea975836
commit 7718cf19c3
17 changed files with 298 additions and 33 deletions

View File

@@ -5,7 +5,7 @@ meta {
}
post {
url: {{baseUrl}}/audio/1061092965579235398/pause
url: {{baseUrl}}/audio/{{server}}/pause
body: json
auth: inherit
}

View File

@@ -5,7 +5,7 @@ meta {
}
post {
url: {{baseUrl}}/audio/1061092965579235398/play
url: {{baseUrl}}/audio/{{server}}/play
body: json
auth: inherit
}

View File

@@ -5,7 +5,7 @@ meta {
}
post {
url: {{baseUrl}}/audio/1061092965579235398/resume
url: {{baseUrl}}/audio/{{server}}/resume
body: json
auth: inherit
}

View File

@@ -4,9 +4,6 @@ auth {
auth:apikey {
key: X-API-Key
value: rwOS4yMmNpQvL0vLHc1jWQoefJB1bvKvOvBSswiYh0mkhZDc1lsgFZmpXaSUXAa5ZjpRWR117hLQ1l0VPPSGkRXZl7dPRVCc
value: {{apiKey}}
placement: header
}
vars:pre-request {
baseUrl: http://localhost:3000/api
}
}

17
bruno/dice/Track.bru Normal file
View File

@@ -0,0 +1,17 @@
meta {
name: Track
type: http
seq: 1
}
post {
url: {{baseUrl}}/dice/{{server}}/track
body: json
auth: inherit
}
body:json {
{
"dice": "1d4"
}
}

View File

@@ -0,0 +1,7 @@
vars {
baseUrl: http://localhost:3000/api
server: 1061092965579235398
}
vars:secret [
apiKey
]

View File

@@ -7,5 +7,9 @@ meta {
post {
url: {{baseUrl}}/api-key
body: none
auth: inherit
auth: bearer
}
auth:bearer {
token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjI1MDg0MjI2MTIyMTI3NzY5NywibmFtZSI6ImJzaGVycmlmZiIsImlhdCI6MTczNDgwODA5MiwiZXhwIjoxNzM0ODk0NDkyLCJqdGkiOiJsSWFHaU15Wll5cnFVYmFJTGs2dzAyZTY4YkFPZjFZWSJ9.fCeooH2IdtXiy2s23WykXtOaR8dvnUinmSGFcV-fOwQ
}

View File

@@ -20,10 +20,13 @@ CREATE TABLE IF NOT EXISTS api_keys (
key TEXT PRIMARY KEY NOT NULL,
user_id BIGINT NOT NULL,
user_name TEXT NOT NULL,
access_mask INT
access_mask INT,
created_at TIMESTAMPTZ NOT NULL,
last_used_at TIMESTAMPTZ
);
CREATE TABLE IF NOT EXISTS dice_thresholds (
CREATE TABLE IF NOT EXISTS dice_track (
id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(),
guild_id BIGINT NOT NULL,
owner_id BIGINT NOT NULL,
dice TEXT NOT NULL,
user_id BIGINT,

View File

@@ -2,11 +2,9 @@ use std::sync::Arc;
use axum::extract::{Path, State};
use axum::middleware::from_extractor;
use axum::{Extension, Json, Router};
use axum::response::IntoResponse;
use axum::routing::post;
use reqwest::StatusCode;
use serde::Deserialize;
use crate::api::auth::{AuthCredential, AuthorizationMiddleware, Session};
use crate::api::auth::{AuthCredential, AuthorizationMiddleware};
use crate::AppState;
use crate::bot::commands::audio::join_voice_channel;
use crate::bot::commands::audio::pause::pause_track;

View File

@@ -2,11 +2,10 @@ use std::sync::Arc;
use axum::{Extension, Router};
use axum::middleware::from_extractor;
use axum::routing::post;
use reqwest::StatusCode;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::api::auth::{csprng, AuthCredential};
use crate::api::auth::AuthorizationMiddleware;
use crate::api::auth::session::Session;
use crate::AppState;
use crate::data::query::{Condition, QueryBuilder};
use crate::error::{Error, SirenResult};
@@ -25,6 +24,8 @@ pub struct ApiKey {
pub user_id: i64,
pub user_name: String,
pub access_mask: i32,
pub created_at: DateTime<Utc>,
pub last_used_at: Option<DateTime<Utc>>
}
impl ApiKey {
@@ -34,6 +35,8 @@ impl ApiKey {
user_id: user_id as i64,
user_name,
access_mask,
created_at: Utc::now(),
last_used_at: None,
}
}
@@ -44,9 +47,11 @@ impl ApiKey {
key,
user_id,
user_name,
access_mask
access_mask,
created_at,
last_used_at
) VALUES (
$1, $2, $3, $4
$1, $2, $3, $4, $5, $6
)",
TABLE_NAME
))
@@ -54,11 +59,36 @@ impl ApiKey {
.bind(self.user_id)
.bind(&self.user_name)
.bind(self.access_mask)
.bind(self.created_at)
.bind(self.last_used_at)
.execute(pool)
.await?;
Ok(())
}
pub async fn update(&self) -> SirenResult<()> {
let pool = crate::data::pool();
sqlx::query(&format!(
"UPDATE {} SET
user_id = $2,
user_name = $3,
access_mask = $4,
created_at = $5,
last_used_at = $6
WHERE key = $1",
TABLE_NAME
))
.bind(&self.key)
.bind(self.user_id)
.bind(&self.user_name)
.bind(self.access_mask)
.bind(self.created_at)
.bind(self.last_used_at)
.execute(pool)
.await?;
Ok(())
}
pub async fn find_by_key(key: &str) -> SirenResult<Option<Self>> {
let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME)
@@ -84,8 +114,13 @@ impl ApiKey {
async fn create_api_key(Extension(credential): Extension<AuthCredential>) -> SirenResult<String> {
let session = match credential {
AuthCredential::ApiKey(_) => return Err(Error::new(400, "API keys cannot be generated with an API key".to_string())),
AuthCredential::Session(session) => session
AuthCredential::ApiKey(_) => {
return Err(Error::new(
400,
"API keys cannot be generated using an existing API key for authentication.".to_string(),
))
}
AuthCredential::Session(session) => session,
};
log::debug!(
"Generating API key for {} ({})",

View File

@@ -87,11 +87,14 @@ async fn check_bearer_auth(bearer_token: &str) -> SirenResult<Session> {
}
async fn check_api_key_auth(key: &str) -> SirenResult<ApiKey> {
let api_key = match ApiKey::find_by_key(key).await? {
let mut api_key = match ApiKey::find_by_key(key).await? {
Some(api_key) => api_key,
None => return Err(StatusCode::UNAUTHORIZED.into()),
};
// Update when the API key was last used
api_key.last_used_at = Some(Utc::now());
api_key.update().await?;
Ok(api_key)
}

187
src/api/dice/mod.rs Normal file
View File

@@ -0,0 +1,187 @@
use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;
use axum::{Extension, Json, Router};
use axum::extract::{Path, State};
use axum::middleware::from_extractor;
use axum::routing::post;
use axum_extra::handler::HandlerCallWithExtractors;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::api::auth::{AuthCredential, AuthorizationMiddleware};
use crate::AppState;
use crate::bot::commands::fun::roll::{format_roll, parse_dice};
use crate::data::query::{Condition, QueryBuilder};
use crate::error::{Error, SirenResult};
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/:guild_id/track", post(add_track_dice))
.route_layer(from_extractor::<AuthorizationMiddleware>())
}
const TABLE_NAME: &str = "dice_track";
#[derive(Serialize, Deserialize, Clone, Debug)]
enum TrackDiceOperator {
#[serde(rename = "eq")]
Equal,
#[serde(rename = "lt")]
LessThan,
#[serde(rename = "lte")]
LessThanEqual,
#[serde(rename = "gt")]
GreaterThan,
#[serde(rename = "gte")]
GreaterThanEqual,
}
// Implementing the ToString trait for converting the enum to a string
impl Display for TrackDiceOperator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let str = match self {
TrackDiceOperator::Equal => "eq".to_string(),
TrackDiceOperator::LessThan => "lt".to_string(),
TrackDiceOperator::LessThanEqual => "lte".to_string(),
TrackDiceOperator::GreaterThan => "gt".to_string(),
TrackDiceOperator::GreaterThanEqual => "gte".to_string(),
};
write!(f, "{}", str)
}
}
// Implementing the FromStr trait for parsing a string into the enum
impl FromStr for TrackDiceOperator {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"eq" => Ok(TrackDiceOperator::Equal),
"lt" => Ok(TrackDiceOperator::LessThan),
"lte" => Ok(TrackDiceOperator::LessThanEqual),
"gt" => Ok(TrackDiceOperator::GreaterThan),
"gte" => Ok(TrackDiceOperator::GreaterThanEqual),
_ => Err(format!("Unknown value for TrackDiceOperator: {}", s)),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug)]
struct DiceTrackPayload {
dice: String,
user_id: Option<i64>,
value: Option<i32>,
operator: Option<TrackDiceOperator>,
}
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
struct InsertDiceTrack {
guild_id: i64,
owner_id: i64,
dice: String,
user_id: Option<i64>,
value: Option<i32>,
operator: Option<String>,
}
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
struct QueryDiceTrack {
id: Uuid,
guild_id: i64,
owner_id: i64,
dice: String,
user_id: Option<i64>,
value: Option<i32>,
operator: Option<String>,
}
impl QueryDiceTrack {
pub async fn find() -> SirenResult<Vec<Self>> {
let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME)
// .where_condition(
// Condition::and(
// Condition::is_equal("guild_id", "$1"),
// Condition::and(
// Condition::is_equal("owner_id", "$2"),
//
// )
// )
// )
.build();
let items: Vec<QueryDiceTrack> = sqlx::query_as(&query)
.fetch_all(pool).await?;
Ok(items)
}
}
impl InsertDiceTrack {
pub async fn insert(&self) -> SirenResult<QueryDiceTrack> {
let pool = crate::data::pool();
let query = format!(
"INSERT INTO {} (
guild_id,
owner_id,
dice,
user_id,
value,
operator
) VALUES (
$1, $2, $3, $4, $5, $6
) RETURNING *",
TABLE_NAME
);
let item: QueryDiceTrack = match sqlx::query_as(&query)
.bind(self.guild_id)
.bind(self.owner_id)
.bind(&self.dice)
.bind(self.user_id)
.bind(self.value)
.bind(&self.operator)
.fetch_optional(pool).await? {
Some(result) => result,
None => return Err(Error::new(500, "Error storing".to_string()))
};
Ok(item)
}
}
pub async fn add_track_dice(
Extension(credential): Extension<AuthCredential>,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
Json(payload): Json<DiceTrackPayload>,
) -> SirenResult<Json<QueryDiceTrack>> {
// Check if the user exists in the cache
let owner_id = credential.user_id();
let owner_id = match state.cache.user(owner_id) {
Some(user) => user.id,
None => return Err(Error::not_found("User not found".to_string())),
};
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
let dice = parse_dice(&payload.dice)?;
let dice = InsertDiceTrack {
guild_id: guild_id.get() as i64,
owner_id: owner_id.get() as i64,
dice: format_roll(dice.0, dice.1, dice.2),
user_id: payload.user_id,
value: payload.value,
operator: match payload.operator {
None => None,
Some(s) => Some(s.to_string()),
}
};
let dice_track = dice.insert().await?;
Ok(Json(dice_track))
}

View File

@@ -2,14 +2,17 @@ pub use app::App;
use std::sync::Arc;
use axum::Router;
use serde::Deserialize;
use crate::AppState;
mod app;
mod audio;
mod auth;
mod dice;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.merge(auth::get_routes())
.nest("/audio/:guild_id", audio::get_routes())
.nest("/dice", dice::get_routes())
}

View File

@@ -1,7 +1,7 @@
use serenity::all::{
CommandInteraction, Context, CreateInteractionResponse, CreateInteractionResponseMessage,
CreateMessage, EditInteractionResponse, InteractionResponseFlags, Message, ModalInteraction,
User, UserId,
UserId,
};
pub async fn process_message(ctx: &Context, command: &CommandInteraction, private: bool) {

View File

@@ -7,6 +7,7 @@ use serenity::all::{
};
use crate::bot::chat::{create_message_response, edit_response};
use crate::error::{Error, SirenResult};
use crate::utils::{a_or_an, number_to_words};
pub async fn run(ctx: &Context, command: &CommandInteraction) {
@@ -64,6 +65,8 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
}
None => edit_response(&ctx, &command, format!("🎲 {}\n-# {}", total, response)).await,
};
// Check for dice tracks
}
Err(why) => {
edit_response(&ctx, &command, format!("Invalid dice string: {}", why)).await;
@@ -123,7 +126,7 @@ pub fn roll_dice(count: u32, sides: u32, modifier: i32) -> i32 {
total
}
pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> {
pub fn parse_dice(dice: &str) -> SirenResult<(u32, u32, i32)> {
// If the input is just a number (e.g., "20" or "6"), assume it's the number of sides
if let Ok(n) = dice.parse::<u32>() {
return Ok((1, n, 0)); // Assume 1 dice with 0 modifiers
@@ -144,31 +147,31 @@ pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> {
Some("") => 1, // Handle cases like "d6", assume 1 dice
Some(c) => match c.parse::<u32>() {
Ok(n) => n,
Err(_) => return Err(format!("Invalid dice count: {}", c)),
Err(_) => return Err(Error::new(400, format!("Invalid dice count: {}", c))),
},
None => return Err(format!("Invalid dice string: {}", dice)),
None => return Err(Error::new(400, format!("Invalid dice string: {}", dice))),
};
// Parse the number of sides
let sides_part = parts
.next()
.ok_or_else(|| format!("Invalid dice string: {}", dice))?;
.ok_or_else(|| Error::new(400, format!("Invalid dice string: {}", dice)))?;
let sides = match sides_part.parse::<u32>() {
Ok(n) => {
if [4, 6, 8, 10, 12, 20, 100].contains(&n) {
n
} else {
return Err(format!(
return Err(Error::new(400, format!(
"Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}",
n
));
)));
}
}
Err(_) => {
return Err(format!(
return Err(Error::new(400, format!(
"Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}",
sides_part
))
)))
}
};
@@ -189,7 +192,7 @@ pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> {
-n
}
}
Err(_) => return Err(format!("Invalid dice modifier: {}", m)),
Err(_) => return Err(Error::new(400, format!("Invalid dice modifier: {}", m))),
},
None => 0, // No modifier found
};

View File

@@ -23,6 +23,7 @@ pub struct BotHandler {
pub oai: Option<OAI>,
}
static REGISTERED: OnceLock<bool> = OnceLock::new();
static SONGBIRD: OnceLock<Arc<Songbird>> = OnceLock::new();
static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
@@ -104,6 +105,13 @@ impl EventHandler for BotHandler {
CLIENT.set(http_client).ok();
}
// Update registered to prevent reloading the commands
if REGISTERED.get().is_some() {
return;
} else {
REGISTERED.set(true).ok();
}
log::trace!("Handling {} guilds", ready.guilds.len());
for guild in ready.guilds {
// Check if guild exists in database

View File

@@ -38,7 +38,7 @@ impl GuildCache {
pub async fn find_by_id(id: i64) -> SirenResult<Option<Self>> {
let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::is_equal("id", "$1")) // Use a placeholder
.where_condition(Condition::is_equal("id", "$1"))
.build();
let item = sqlx::query_as(&query).bind(id).fetch_optional(pool).await?;