Files
siren/src/api/dice/mod.rs

200 lines
5.5 KiB
Rust

use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;
use axum::{Extension, Json, Router};
use axum::extract::{Path, State};
use axum::middleware::from_extractor;
use axum::routing::post;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::api::auth::{AuthCredential, AuthorizationMiddleware};
use crate::AppState;
use crate::bot::commands::fun::roll::{format_roll, parse_dice};
use crate::data::condition::Condition;
use crate::data::{ExecutableQuery, Value};
use crate::data::query::QueryBuilder;
use crate::error::{Error, SirenResult};
/// Builds the router for the dice-tracking endpoints.
///
/// Registers `POST /:guild_id/track`, handled by [`add_track_dice`], and then
/// applies a route layer that runs the [`AuthorizationMiddleware`] extractor
/// for the routes registered here.
pub fn get_routes() -> Router<Arc<AppState>> {
    let track_routes = Router::new().route("/:guild_id/track", post(add_track_dice));
    track_routes.route_layer(from_extractor::<AuthorizationMiddleware>())
}
/// Name of the database table holding dice-track rows; used by
/// `QueryDiceTrack::find` and `InsertDiceTrack::insert`.
const TABLE_NAME: &str = "dice_track";
/// Comparison operator that can accompany a tracked dice entry.
///
/// With serde, each variant is (de)serialized under the short name given by
/// its `rename` attribute (`eq`, `lt`, `lte`, `gt`, `gte`).
///
/// NOTE(review): presumably the operator is compared against the entry's
/// `value` when a roll is evaluated; that comparison is not in this file.
#[derive(Serialize, Deserialize, Clone, Debug)]
enum TrackDiceOperator {
    /// Serialized as `"eq"`.
    #[serde(rename = "eq")]
    Equal,
    /// Serialized as `"lt"`.
    #[serde(rename = "lt")]
    LessThan,
    /// Serialized as `"lte"`.
    #[serde(rename = "lte")]
    LessThanEqual,
    /// Serialized as `"gt"`.
    #[serde(rename = "gt")]
    GreaterThan,
    /// Serialized as `"gte"`.
    #[serde(rename = "gte")]
    GreaterThanEqual,
}
// Implementing the ToString trait for converting the enum to a string
impl Display for TrackDiceOperator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let str = match self {
TrackDiceOperator::Equal => "eq".to_string(),
TrackDiceOperator::LessThan => "lt".to_string(),
TrackDiceOperator::LessThanEqual => "lte".to_string(),
TrackDiceOperator::GreaterThan => "gt".to_string(),
TrackDiceOperator::GreaterThanEqual => "gte".to_string(),
};
write!(f, "{}", str)
}
}
// Implementing the FromStr trait for parsing a string into the enum
impl FromStr for TrackDiceOperator {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"eq" => Ok(TrackDiceOperator::Equal),
"lt" => Ok(TrackDiceOperator::LessThan),
"lte" => Ok(TrackDiceOperator::LessThanEqual),
"gt" => Ok(TrackDiceOperator::GreaterThan),
"gte" => Ok(TrackDiceOperator::GreaterThanEqual),
_ => Err(format!("Unknown value for TrackDiceOperator: {}", s)),
}
}
}
/// JSON request body accepted by `add_track_dice`.
#[derive(Serialize, Deserialize, Clone, Debug)]
struct DiceTrackPayload {
    /// Dice expression as sent by the client; the handler passes it to
    /// `parse_dice` and stores the `format_roll` rendering of the result.
    dice: String,
    /// Optional user id stored with the entry.
    /// NOTE(review): presumably the user whose rolls are tracked — confirm.
    user_id: Option<i64>,
    /// Optional value stored with the entry.
    /// NOTE(review): presumably the number `operator` compares against — confirm.
    value: Option<i32>,
    /// Optional comparison operator; the handler stores its string form.
    operator: Option<TrackDiceOperator>,
}
/// Field values for a dice-track row that is about to be looked up or
/// inserted. It carries no `id`; the INSERT issued by `insert` does not
/// supply one.
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
struct InsertDiceTrack {
    /// Guild id, stored as a signed 64-bit integer.
    guild_id: i64,
    /// Id of the user creating the entry (the authenticated caller in `add_track_dice`).
    owner_id: i64,
    /// Dice expression in the form produced by `format_roll`.
    dice: String,
    /// Optional user id, copied from the request payload.
    user_id: Option<i64>,
    /// Optional value, copied from the request payload.
    value: Option<i32>,
    /// Optional operator in its string form (see `TrackDiceOperator`'s `Display`).
    operator: Option<String>,
}
/// A dice-track row as read back from the database (through `sqlx::FromRow`),
/// including its `id`. Also the JSON response body of `add_track_dice`.
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
struct QueryDiceTrack {
    /// Row identifier; not provided by the INSERT in `InsertDiceTrack::insert`.
    id: Uuid,
    /// Guild id, stored as a signed 64-bit integer.
    guild_id: i64,
    /// Id of the user who created the entry.
    owner_id: i64,
    /// Stored dice expression.
    dice: String,
    /// Optional user id stored with the entry.
    user_id: Option<i64>,
    /// Optional value stored with the entry.
    value: Option<i32>,
    /// Optional operator in its string form.
    operator: Option<String>,
}
impl QueryDiceTrack {
    /// Looks up a stored row whose `guild_id`, `owner_id`, `dice`, `user_id`,
    /// `value` and `operator` columns all equal the corresponding fields of
    /// `dice`.
    ///
    /// Returns whatever the query builder's `fetch_optional` yields: `Some`
    /// row on a match, otherwise `None`.
    ///
    /// NOTE(review): how an equality condition on a `None` optional value is
    /// rendered (e.g. `= NULL` vs `IS NULL`) depends on `Condition` /
    /// `QueryBuilder`, which are not visible here — confirm that rows with
    /// NULL columns can actually be matched.
    pub async fn find(dice: &InsertDiceTrack) -> Option<Self> {
        // One equality condition per column, built in column order.
        let same_guild = Condition::is_equal("guild_id", Value::BigInt(dice.guild_id));
        let same_owner = Condition::is_equal("owner_id", Value::BigInt(dice.owner_id));
        let same_dice = Condition::is_equal("dice", Value::Text(dice.dice.clone()));
        let same_user = Condition::is_equal("user_id", Value::OptionalBigInt(dice.user_id));
        let same_value = Condition::is_equal("value", Value::OptionalInt(dice.value));
        let same_operator =
            Condition::is_equal("operator", Value::OptionalText(dice.operator.clone()));

        // Compose from the innermost pair outwards, giving the right-nested tree
        // guild AND (owner AND (dice AND (user AND (value AND operator)))).
        let filter = Condition::and(same_value, same_operator);
        let filter = Condition::and(same_user, filter);
        let filter = Condition::and(same_dice, filter);
        let filter = Condition::and(same_owner, filter);
        let filter = Condition::and(same_guild, filter);

        QueryBuilder::new(TABLE_NAME)
            .where_condition(filter)
            .fetch_optional()
            .await
    }
}
impl InsertDiceTrack {
pub async fn insert(&self) -> SirenResult<QueryDiceTrack> {
let pool = crate::data::pool();
let query = format!(
"INSERT INTO {} (
guild_id,
owner_id,
dice,
user_id,
value,
operator
) VALUES (
$1, $2, $3, $4, $5, $6
) RETURNING *",
TABLE_NAME
);
let item: QueryDiceTrack = match sqlx::query_as(&query)
.bind(self.guild_id)
.bind(self.owner_id)
.bind(&self.dice)
.bind(self.user_id)
.bind(self.value)
.bind(&self.operator)
.fetch_optional(pool)
.await?
{
Some(result) => result,
None => return Err(Error::new(500, "Error storing".to_string())),
};
Ok(item)
}
}
pub async fn add_track_dice(
Extension(credential): Extension<AuthCredential>,
State(state): State<Arc<AppState>>,
Path(guild_id): Path<u64>,
Json(payload): Json<DiceTrackPayload>,
) -> SirenResult<Json<QueryDiceTrack>> {
// Check if the user exists in the cache
let owner_id = credential.user_id();
let owner_id = match state.cache.user(owner_id) {
Some(user) => user.id,
None => return Err(Error::not_found("User not found".to_string())),
};
// Validate if the guild exists in the cache
let guild_id = match state.cache.guild(guild_id) {
Some(guild) => guild.id,
None => return Err(Error::not_found("Guild not found".to_string())),
};
let dice = parse_dice(&payload.dice)?;
let insert_dice = InsertDiceTrack {
guild_id: guild_id.get() as i64,
owner_id: owner_id.get() as i64,
dice: format_roll(dice.0, dice.1, dice.2),
user_id: payload.user_id,
value: payload.value,
operator: match payload.operator {
None => None,
Some(s) => Some(s.to_string()),
},
};
// Check for existing dice tracks
let results = QueryDiceTrack::find(&insert_dice).await;
match results {
Some(dice_track) => Ok(Json(dice_track)),
None => {
let dice_track = insert_dice.insert().await?;
Ok(Json(dice_track))
}
}
}