From 2cd2715d0d4b51716c0241150fdd554a465c4804 Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Sat, 21 Dec 2024 22:31:11 -0500 Subject: [PATCH] Updated query builder with bindings --- src/api/auth/api_key.rs | 24 ++--- src/api/dice/mod.rs | 17 ++-- src/bot/commands/audio/play.rs | 4 +- src/bot/commands/audio/volume.rs | 5 +- src/bot/commands/fun/roll.rs | 23 +++-- src/bot/handler.rs | 2 +- src/data/guilds/model.rs | 17 ++-- src/data/query.rs | 163 +++++++++++++++++++++++-------- 8 files changed, 169 insertions(+), 86 deletions(-) diff --git a/src/api/auth/api_key.rs b/src/api/auth/api_key.rs index e063176..79856c4 100644 --- a/src/api/auth/api_key.rs +++ b/src/api/auth/api_key.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::api::auth::{csprng, AuthCredential}; use crate::api::auth::AuthorizationMiddleware; use crate::AppState; -use crate::data::query::{Condition, QueryBuilder}; +use crate::data::query::{Condition, QueryBuilder, Value}; use crate::error::{Error, SirenResult}; pub fn get_routes() -> Router> { @@ -25,7 +25,7 @@ pub struct ApiKey { pub user_name: String, pub access_mask: i32, pub created_at: DateTime, - pub last_used_at: Option> + pub last_used_at: Option>, } impl ApiKey { @@ -78,23 +78,23 @@ impl ApiKey { WHERE key = $1", TABLE_NAME )) - .bind(&self.key) - .bind(self.user_id) - .bind(&self.user_name) - .bind(self.access_mask) - .bind(self.created_at) - .bind(self.last_used_at) - .execute(pool) - .await?; + .bind(&self.key) + .bind(self.user_id) + .bind(&self.user_name) + .bind(self.access_mask) + .bind(self.created_at) + .bind(self.last_used_at) + .execute(pool) + .await?; Ok(()) } pub async fn find_by_key(key: &str) -> SirenResult> { let pool = crate::data::pool(); let query = QueryBuilder::new(TABLE_NAME) - .where_condition(Condition::is_equal("key", "$1")) + .where_condition(Condition::is_equal("key", Value::Text(key.to_string()))) .build(); - let item = sqlx::query_as(&query) + let item = sqlx::query_as(&query.0) .bind(key) .fetch_optional(pool) .await?; diff --git a/src/api/dice/mod.rs b/src/api/dice/mod.rs index 496dff8..545205e 100644 --- a/src/api/dice/mod.rs +++ b/src/api/dice/mod.rs @@ -36,7 +36,6 @@ enum TrackDiceOperator { GreaterThanEqual, } - // Implementing the ToString trait for converting the enum to a string impl Display for TrackDiceOperator { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -110,8 +109,7 @@ impl QueryDiceTrack { // ) // ) .build(); - let items: Vec = sqlx::query_as(&query) - .fetch_all(pool).await?; + let items: Vec = sqlx::query_as(&query.0).fetch_all(pool).await?; Ok(items) } @@ -131,7 +129,7 @@ impl InsertDiceTrack { ) VALUES ( $1, $2, $3, $4, $5, $6 ) RETURNING *", - TABLE_NAME + TABLE_NAME ); let item: QueryDiceTrack = match sqlx::query_as(&query) .bind(self.guild_id) @@ -140,9 +138,11 @@ impl InsertDiceTrack { .bind(self.user_id) .bind(self.value) .bind(&self.operator) - .fetch_optional(pool).await? { + .fetch_optional(pool) + .await? + { Some(result) => result, - None => return Err(Error::new(500, "Error storing".to_string())) + None => return Err(Error::new(500, "Error storing".to_string())), }; Ok(item) } @@ -154,7 +154,6 @@ pub async fn add_track_dice( Path(guild_id): Path, Json(payload): Json, ) -> SirenResult> { - // Check if the user exists in the cache let owner_id = credential.user_id(); let owner_id = match state.cache.user(owner_id) { @@ -179,9 +178,9 @@ pub async fn add_track_dice( operator: match payload.operator { None => None, Some(s) => Some(s.to_string()), - } + }, }; let dice_track = dice.insert().await?; Ok(Json(dice_track)) -} \ No newline at end of file +} diff --git a/src/bot/commands/audio/play.rs b/src/bot/commands/audio/play.rs index ea38582..df36057 100644 --- a/src/bot/commands/audio/play.rs +++ b/src/bot/commands/audio/play.rs @@ -89,9 +89,7 @@ pub async fn enqueue_track( let mut playlist_items: Vec = Vec::new(); if let Some(handler_lock) = manager.get(guild_id) { let mut handler = handler_lock.lock().await; - let guild = GuildCache::find_by_id(guild_id.get() as i64) - .await? - .unwrap(); + let guild = GuildCache::find_by_id(guild_id.get() as i64).await.unwrap(); let valid = is_valid_url(&track_url); // Check if the URL is valid diff --git a/src/bot/commands/audio/volume.rs b/src/bot/commands/audio/volume.rs index 98869e8..da0beea 100644 --- a/src/bot/commands/audio/volume.rs +++ b/src/bot/commands/audio/volume.rs @@ -64,10 +64,7 @@ pub async fn set_volume(manager: &Arc, guild_id: &GuildId, volume: i32 let bound_volume = volume as f32 / 100.0; // Update the guild cache - let mut guild_cache = GuildCache::find_by_id(guild_id.get() as i64) - .await - .unwrap() - .unwrap(); + let mut guild_cache = GuildCache::find_by_id(guild_id.get() as i64).await.unwrap(); guild_cache.volume = volume; guild_cache.update().await.unwrap(); diff --git a/src/bot/commands/fun/roll.rs b/src/bot/commands/fun/roll.rs index d4ae64d..752d96d 100644 --- a/src/bot/commands/fun/roll.rs +++ b/src/bot/commands/fun/roll.rs @@ -66,7 +66,6 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { None => edit_response(&ctx, &command, format!("🎲 {}\n-# {}", total, response)).await, }; // Check for dice tracks - } Err(why) => { edit_response(&ctx, &command, format!("Invalid dice string: {}", why)).await; @@ -161,17 +160,23 @@ pub fn parse_dice(dice: &str) -> SirenResult<(u32, u32, i32)> { if [4, 6, 8, 10, 12, 20, 100].contains(&n) { n } else { - return Err(Error::new(400, format!( - "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", - n - ))); + return Err(Error::new( + 400, + format!( + "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", + n + ), + )); } } Err(_) => { - return Err(Error::new(400, format!( - "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", - sides_part - ))) + return Err(Error::new( + 400, + format!( + "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", + sides_part + ), + )) } }; diff --git a/src/bot/handler.rs b/src/bot/handler.rs index 2a2930f..ded87ac 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -116,7 +116,7 @@ impl EventHandler for BotHandler { for guild in ready.guilds { // Check if guild exists in database let guild_id = guild.id.get() as i64; - if let None = GuildCache::find_by_id(guild_id).await.unwrap() { + if let None = GuildCache::find_by_id(guild_id).await { let guild_cache = GuildCache { id: guild_id, name: guild.id.name(&ctx.cache), diff --git a/src/data/guilds/model.rs b/src/data/guilds/model.rs index 524b1b5..17d6a9e 100644 --- a/src/data/guilds/model.rs +++ b/src/data/guilds/model.rs @@ -1,5 +1,7 @@ use serde::{Serialize, Deserialize}; -use crate::data::query::{Condition, QueryBuilder}; +use sqlx::{Database, Postgres}; +use sqlx::query::QueryAs; +use crate::data::query::{Condition, QueryBuilder, Value}; use crate::error::SirenResult; const TABLE_NAME: &str = "guilds"; @@ -35,14 +37,11 @@ impl GuildCache { Ok(()) } - pub async fn find_by_id(id: i64) -> SirenResult> { - let pool = crate::data::pool(); - let query = QueryBuilder::new(TABLE_NAME) - .where_condition(Condition::is_equal("id", "$1")) - .build(); - let item = sqlx::query_as(&query).bind(id).fetch_optional(pool).await?; - - Ok(item) + pub async fn find_by_id(id: i64) -> Option { + QueryBuilder::new(TABLE_NAME) + .where_condition(Condition::is_equal("id", Value::BIGINT(id))) + .fetch_optional() + .await } pub async fn update(&self) -> SirenResult<()> { diff --git a/src/data/query.rs b/src/data/query.rs index ea0f32b..a08efc6 100644 --- a/src/data/query.rs +++ b/src/data/query.rs @@ -1,9 +1,14 @@ +use std::fmt; +use std::fmt::Display; +use sqlx::{FromRow, Postgres}; + pub struct QueryBuilder { table: String, columns: Vec, condition: Option, order_by: Vec, limit: Option, + offset: Option, } impl QueryBuilder { @@ -14,6 +19,7 @@ impl QueryBuilder { condition: None, order_by: Vec::new(), limit: None, + offset: None, } } @@ -37,7 +43,30 @@ impl QueryBuilder { self } - pub fn build(self) -> String { + pub async fn fetch_optional< + T: Send + Unpin + for<'r> FromRow<'r, ::Row>, + >( + self, + ) -> Option { + let (query_string, values) = self.build(); + let mut query_as = sqlx::query_as(&query_string); + for value in values { + match value { + Value::INT(n) => query_as = query_as.bind(n), + Value::BIGINT(n) => query_as = query_as.bind(n), + Value::Bool(n) => query_as = query_as.bind(n), + Value::Text(n) => query_as = query_as.bind(n), + } + } + + let pool = crate::data::pool(); + query_as.fetch_optional(pool).await.unwrap_or_else(|err| { + log::error!("{}", err); + None + }) + } + + pub fn build(self) -> (String, Vec) { let columns = if self.columns.is_empty() { "*".to_string() } else { @@ -46,8 +75,11 @@ impl QueryBuilder { let mut query = format!("SELECT {} FROM {}", columns, self.table); + let mut values: Vec = Vec::new(); if let Some(condition) = self.condition { - query.push_str(&format!(" WHERE {}", condition.to_sql())); + let where_condition = condition.to_sql(&mut 0); + query.push_str(&format!(" WHERE {}", where_condition.0)); + values = where_condition.1; } if !self.order_by.is_empty() { @@ -58,12 +90,35 @@ impl QueryBuilder { query.push_str(&format!(" LIMIT {}", limit)); } - query + if let Some(offset) = self.offset { + query.push_str(&format!(" OFFSET {}", offset)); + } + + (query, values) + } +} + +#[derive(Debug, Clone)] +pub enum Value { + INT(i32), + BIGINT(i64), + Bool(bool), + Text(String), +} + +impl Display for Value { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Value::INT(n) => write!(f, "{}", n), + Value::BIGINT(n) => write!(f, "{}", n), + Value::Bool(n) => write!(f, "{}", n), + Value::Text(s) => write!(f, "'{}'", s), // Wrap strings in quotes for SQL + } } } pub enum Condition { - Simple(String), + Simple(String, Vec), And(Box, Box), Or(Box, Box), Group(Box), @@ -71,7 +126,7 @@ pub enum Condition { impl Condition { pub fn new(condition: &str) -> Self { - Condition::Simple(condition.to_string()) + Condition::Simple(condition.to_string(), vec![]) } pub fn and(self, other: Self) -> Self { @@ -86,78 +141,108 @@ impl Condition { Condition::Group(Box::new(self)) } - pub fn is_equal(left: &str, right: &str) -> Self { - Condition::Simple(format!("{} = {}", left, right)) + pub fn is_equal(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} = ?", left), vec![right.into()]) } - pub fn not_equal(left: &str, right: &str) -> Self { - Condition::Simple(format!("{} != {}", left, right)) + pub fn not_equal(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} != ?", left), vec![right.into()]) } pub fn is_null(value: &str) -> Self { - Condition::Simple(format!("{} IS NULL", value)) + Condition::Simple(format!("{} IS NULL", value), vec![]) } pub fn not_null(value: &str) -> Self { - Condition::Simple(format!("{} IS NOT NULL", value)) + Condition::Simple(format!("{} IS NOT NULL", value), vec![]) } - pub fn is_in(left: &str, right: &[&str]) -> Self { + pub fn is_in(left: &str, right: Vec) -> Self { let right_list = right .iter() - .map(|v| format!("'{}'", v)) + .map(|v| "'?'".to_string()) .collect::>() .join(", "); - Condition::Simple(format!("{} IN ({})", left, right_list)) + Condition::Simple(format!("{} IN ({})", left, right_list), right) } - pub fn not_in(left: &str, right: &[&str]) -> Self { + pub fn not_in(left: &str, right: Vec) -> Self { let right_list = right .iter() - .map(|v| format!("'{}'", v)) + .map(|v| "'?'".to_string()) .collect::>() .join(", "); - Condition::Simple(format!("{} NOT IN ({})", left, right_list)) + Condition::Simple(format!("{} NOT IN ({})", left, right_list), right) } - pub fn like(left: &str, right: &str) -> Self { - Condition::Simple(format!("{} LIKE '{}'", left, right)) + pub fn like(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} LIKE '?'", left), vec![right.into()]) } - pub fn not_like(left: &str, right: &str) -> Self { - Condition::Simple(format!("{} NOT LIKE '{}'", left, right)) + pub fn not_like(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} NOT LIKE '?'", left), vec![right.into()]) } - pub fn i_like(left: &str, right: &str) -> Self { - Condition::Simple(format!("{} ILIKE '{}'", left, right)) + pub fn i_like(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} ILIKE '?'", left), vec![right.into()]) } - pub fn not_i_like(left: &str, right: &str) -> Self { - Condition::Simple(format!("{} NOT ILIKE '{}'", left, right)) + pub fn not_i_like(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} NOT ILIKE '?'", left), vec![right.into()]) } - pub fn gt(left: &str, right: &str) -> Self { - Condition::Simple(format!("{} > {}", left, right)) + pub fn gt(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} > ?", left), vec![right.into()]) } - pub fn gte(left: &str, right: &str) -> Self { - Condition::Simple(format!("{} >= {}", left, right)) + pub fn gte(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} >= ?", left), vec![right.into()]) } - pub fn lt(left: &str, right: &str) -> Self { - Condition::Simple(format!("{} < {}", left, right)) + pub fn lt(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} < ?", left), vec![right.into()]) } - pub fn lte(left: &str, right: &str) -> Self { - Condition::Simple(format!("{} <= {}", left, right)) + pub fn lte(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} <= ?", left), vec![right.into()]) } - fn to_sql(&self) -> String { + fn to_sql(&self, mut counter: &mut usize) -> (String, Vec) { + let mut sql = String::new(); + let mut binds = Vec::new(); + match self { - Condition::Simple(s) => s.to_string(), - Condition::And(a, b) => format!("{} AND {}", a.to_sql(), b.to_sql()), - Condition::Or(a, b) => format!("{} OR {}", a.to_sql(), b.to_sql()), - Condition::Group(a) => format!("({})", a.to_sql()), - } + Condition::Simple(condition, values) => { + // Replace all instances of '?' with an numbered bind + let mut bind_index = *counter; + let numbered_condition = condition.replace("?", { + bind_index += 1; + &format!("${}", bind_index) + }); + sql.push_str(&numbered_condition); + binds.extend(values.clone()); + } + Condition::And(left, right) => { + let (left_sql, left_binds) = left.to_sql(counter); + let (right_sql, right_binds) = right.to_sql(counter); + sql.push_str(&format!("{} AND {}", left_sql, right_sql)); + binds.extend(left_binds); + binds.extend(right_binds); + } + Condition::Or(left, right) => { + let (left_sql, left_binds) = left.to_sql(counter); + let (right_sql, right_binds) = right.to_sql(counter); + sql.push_str(&format!("{} OR {}", left_sql, right_sql)); + binds.extend(left_binds); + binds.extend(right_binds); + } + Condition::Group(inner) => { + let (inner_sql, inner_binds) = inner.to_sql(counter); + sql.push_str(&format!("({})", inner_sql)); + binds.extend(inner_binds); + } + }; + + (sql, binds) } }