200 lines
5.5 KiB
Rust
200 lines
5.5 KiB
Rust
use std::fmt::Display;
|
|
use std::str::FromStr;
|
|
use std::sync::Arc;
|
|
use axum::{Extension, Json, Router};
|
|
use axum::extract::{Path, State};
|
|
use axum::middleware::from_extractor;
|
|
use axum::routing::post;
|
|
use serde::{Deserialize, Serialize};
|
|
use uuid::Uuid;
|
|
use crate::api::auth::{AuthCredential, AuthorizationMiddleware};
|
|
use crate::AppState;
|
|
use crate::bot::commands::fun::roll::{format_roll, parse_dice};
|
|
use crate::data::condition::Condition;
|
|
use crate::data::{ExecutableQuery, Value};
|
|
use crate::data::query::QueryBuilder;
|
|
use crate::error::{Error, SirenResult};
|
|
|
|
pub fn get_routes() -> Router<Arc<AppState>> {
|
|
Router::new()
|
|
.route("/:guild_id/track", post(add_track_dice))
|
|
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
|
}
|
|
|
|
/// Database table holding dice-track rules.
///
/// Used both by the lookup in `QueryDiceTrack::find` and by the
/// `INSERT ... RETURNING *` in `InsertDiceTrack::insert`.
const TABLE_NAME: &str = "dice_track";
|
|
|
|
#[derive(Serialize, Deserialize, Clone, Debug)]
|
|
enum TrackDiceOperator {
|
|
#[serde(rename = "eq")]
|
|
Equal,
|
|
#[serde(rename = "lt")]
|
|
LessThan,
|
|
#[serde(rename = "lte")]
|
|
LessThanEqual,
|
|
#[serde(rename = "gt")]
|
|
GreaterThan,
|
|
#[serde(rename = "gte")]
|
|
GreaterThanEqual,
|
|
}
|
|
|
|
// Implementing the ToString trait for converting the enum to a string
|
|
impl Display for TrackDiceOperator {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
let str = match self {
|
|
TrackDiceOperator::Equal => "eq".to_string(),
|
|
TrackDiceOperator::LessThan => "lt".to_string(),
|
|
TrackDiceOperator::LessThanEqual => "lte".to_string(),
|
|
TrackDiceOperator::GreaterThan => "gt".to_string(),
|
|
TrackDiceOperator::GreaterThanEqual => "gte".to_string(),
|
|
};
|
|
write!(f, "{}", str)
|
|
}
|
|
}
|
|
|
|
// Implementing the FromStr trait for parsing a string into the enum
|
|
impl FromStr for TrackDiceOperator {
|
|
type Err = String;
|
|
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
match s {
|
|
"eq" => Ok(TrackDiceOperator::Equal),
|
|
"lt" => Ok(TrackDiceOperator::LessThan),
|
|
"lte" => Ok(TrackDiceOperator::LessThanEqual),
|
|
"gt" => Ok(TrackDiceOperator::GreaterThan),
|
|
"gte" => Ok(TrackDiceOperator::GreaterThanEqual),
|
|
_ => Err(format!("Unknown value for TrackDiceOperator: {}", s)),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// JSON request body accepted by `add_track_dice` (`POST /:guild_id/track`).
#[derive(Serialize, Deserialize, Clone, Debug)]
struct DiceTrackPayload {
    /// Dice expression. `add_track_dice` validates it with `parse_dice` and
    /// stores the string re-rendered by `format_roll`, not this raw text.
    dice: String,
    /// Optional user id copied as-is into the stored track. Presumably this is
    /// the user whose rolls are watched; confirm against the consumer of `dice_track`.
    user_id: Option<i64>,
    /// Optional number copied as-is into the stored track. Presumably this is
    /// the threshold compared via `operator`; confirm.
    value: Option<i32>,
    /// Optional comparison operator, sent as `eq`/`lt`/`lte`/`gt`/`gte`.
    /// NOTE(review): nothing in this module requires `value` and `operator`
    /// to be provided together.
    operator: Option<TrackDiceOperator>,
}
|
|
|
|
/// Column values for a new `dice_track` row.
///
/// Built by `add_track_dice`, written by `InsertDiceTrack::insert`, and also
/// used as the full-equality match key for `QueryDiceTrack::find`.
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
struct InsertDiceTrack {
    /// Id of the cached guild, converted with an `as i64` cast.
    guild_id: i64,
    /// Id of the authenticated user creating the track, converted with `as i64`.
    owner_id: i64,
    /// Dice expression as re-rendered by `format_roll` from the parsed request.
    dice: String,
    /// Copied from the request payload.
    user_id: Option<i64>,
    /// Copied from the request payload.
    value: Option<i32>,
    /// Operator short code (`eq`, `lt`, ...) produced by `TrackDiceOperator`'s `Display`.
    operator: Option<String>,
}
|
|
|
|
/// A stored `dice_track` row.
///
/// Decoded via `sqlx::FromRow` from `INSERT ... RETURNING *`, and presumably
/// also by the query builder in `find` (confirm). `add_track_dice` returns it
/// directly as JSON, so the field order here is the key order clients see.
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
struct QueryDiceTrack {
    /// Row id. It is not supplied on insert, so presumably the database
    /// generates it; confirm against the schema.
    id: Uuid,
    guild_id: i64,
    owner_id: i64,
    /// Dice expression in `format_roll` form.
    dice: String,
    user_id: Option<i64>,
    value: Option<i32>,
    /// Operator short code (`eq`, `lt`, `lte`, `gt`, `gte`), if any.
    operator: Option<String>,
}
|
|
|
|
impl QueryDiceTrack {
|
|
pub async fn find(dice: &InsertDiceTrack) -> Option<Self> {
|
|
QueryBuilder::new(TABLE_NAME)
|
|
.where_condition(Condition::and(
|
|
Condition::is_equal("guild_id", Value::BigInt(dice.guild_id)),
|
|
Condition::and(
|
|
Condition::is_equal("owner_id", Value::BigInt(dice.owner_id)),
|
|
Condition::and(
|
|
Condition::is_equal("dice", Value::Text(dice.dice.clone())),
|
|
Condition::and(
|
|
Condition::is_equal("user_id", Value::OptionalBigInt(dice.user_id)),
|
|
Condition::and(
|
|
Condition::is_equal("value", Value::OptionalInt(dice.value)),
|
|
Condition::is_equal("operator", Value::OptionalText(dice.operator.clone())),
|
|
),
|
|
),
|
|
),
|
|
),
|
|
))
|
|
.fetch_optional()
|
|
.await
|
|
}
|
|
}
|
|
|
|
impl InsertDiceTrack {
|
|
pub async fn insert(&self) -> SirenResult<QueryDiceTrack> {
|
|
let pool = crate::data::pool();
|
|
let query = format!(
|
|
"INSERT INTO {} (
|
|
guild_id,
|
|
owner_id,
|
|
dice,
|
|
user_id,
|
|
value,
|
|
operator
|
|
) VALUES (
|
|
$1, $2, $3, $4, $5, $6
|
|
) RETURNING *",
|
|
TABLE_NAME
|
|
);
|
|
let item: QueryDiceTrack = match sqlx::query_as(&query)
|
|
.bind(self.guild_id)
|
|
.bind(self.owner_id)
|
|
.bind(&self.dice)
|
|
.bind(self.user_id)
|
|
.bind(self.value)
|
|
.bind(&self.operator)
|
|
.fetch_optional(pool)
|
|
.await?
|
|
{
|
|
Some(result) => result,
|
|
None => return Err(Error::new(500, "Error storing".to_string())),
|
|
};
|
|
Ok(item)
|
|
}
|
|
}
|
|
|
|
pub async fn add_track_dice(
|
|
Extension(credential): Extension<AuthCredential>,
|
|
State(state): State<Arc<AppState>>,
|
|
Path(guild_id): Path<u64>,
|
|
Json(payload): Json<DiceTrackPayload>,
|
|
) -> SirenResult<Json<QueryDiceTrack>> {
|
|
// Check if the user exists in the cache
|
|
let owner_id = credential.user_id();
|
|
let owner_id = match state.cache.user(owner_id) {
|
|
Some(user) => user.id,
|
|
None => return Err(Error::not_found("User not found".to_string())),
|
|
};
|
|
|
|
// Validate if the guild exists in the cache
|
|
let guild_id = match state.cache.guild(guild_id) {
|
|
Some(guild) => guild.id,
|
|
None => return Err(Error::not_found("Guild not found".to_string())),
|
|
};
|
|
|
|
let dice = parse_dice(&payload.dice)?;
|
|
|
|
let insert_dice = InsertDiceTrack {
|
|
guild_id: guild_id.get() as i64,
|
|
owner_id: owner_id.get() as i64,
|
|
dice: format_roll(dice.0, dice.1, dice.2),
|
|
user_id: payload.user_id,
|
|
value: payload.value,
|
|
operator: match payload.operator {
|
|
None => None,
|
|
Some(s) => Some(s.to_string()),
|
|
},
|
|
};
|
|
|
|
// Check for existing dice tracks
|
|
let results = QueryDiceTrack::find(&insert_dice).await;
|
|
|
|
match results {
|
|
Some(dice_track) => Ok(Json(dice_track)),
|
|
None => {
|
|
let dice_track = insert_dice.insert().await?;
|
|
Ok(Json(dice_track))
|
|
}
|
|
}
|
|
}
|