From 4c5300fa5929d19b06cb11f70ebc3bb6878cfbde Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Sat, 21 Dec 2024 23:31:41 -0500 Subject: [PATCH] Built insert and update builders --- src/api/auth/api_key.rs | 4 +- src/api/dice/mod.rs | 3 +- src/data/condition.rs | 131 ++++++++++++++++++++++++++++++ src/data/guilds/model.rs | 58 +++++--------- src/data/insert.rs | 87 ++++++++++++++++++++ src/data/mod.rs | 48 ++++++++++- src/data/query.rs | 169 ++++----------------------------------- src/data/update.rs | 84 +++++++++++++++++++ 8 files changed, 390 insertions(+), 194 deletions(-) create mode 100644 src/data/condition.rs create mode 100644 src/data/insert.rs create mode 100644 src/data/update.rs diff --git a/src/api/auth/api_key.rs b/src/api/auth/api_key.rs index 79856c4..0b1c430 100644 --- a/src/api/auth/api_key.rs +++ b/src/api/auth/api_key.rs @@ -7,7 +7,9 @@ use serde::{Deserialize, Serialize}; use crate::api::auth::{csprng, AuthCredential}; use crate::api::auth::AuthorizationMiddleware; use crate::AppState; -use crate::data::query::{Condition, QueryBuilder, Value}; +use crate::data::condition::Condition; +use crate::data::query::QueryBuilder; +use crate::data::Value; use crate::error::{Error, SirenResult}; pub fn get_routes() -> Router> { diff --git a/src/api/dice/mod.rs b/src/api/dice/mod.rs index 545205e..e0dff08 100644 --- a/src/api/dice/mod.rs +++ b/src/api/dice/mod.rs @@ -5,13 +5,12 @@ use axum::{Extension, Json, Router}; use axum::extract::{Path, State}; use axum::middleware::from_extractor; use axum::routing::post; -use axum_extra::handler::HandlerCallWithExtractors; use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::api::auth::{AuthCredential, AuthorizationMiddleware}; use crate::AppState; use crate::bot::commands::fun::roll::{format_roll, parse_dice}; -use crate::data::query::{Condition, QueryBuilder}; +use crate::data::query::QueryBuilder; use crate::error::{Error, SirenResult}; pub fn get_routes() -> Router> { diff --git a/src/data/condition.rs b/src/data/condition.rs new file mode 100644 index 0000000..9b50bfb --- /dev/null +++ b/src/data/condition.rs @@ -0,0 +1,131 @@ +use crate::data::Value; + +pub enum Condition { + Simple(String, Vec), + And(Box, Box), + Or(Box, Box), + Group(Box), +} + +impl Condition { + pub fn new(condition: &str) -> Self { + Condition::Simple(condition.to_string(), vec![]) + } + + pub fn and(self, other: Self) -> Self { + Condition::And(Box::new(self), Box::new(other)) + } + + pub fn or(self, other: Self) -> Self { + Condition::Or(Box::new(self), Box::new(other)) + } + + pub fn group(self) -> Self { + Condition::Group(Box::new(self)) + } + + pub fn is_equal(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} = ?", left), vec![right.into()]) + } + + pub fn not_equal(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} != ?", left), vec![right.into()]) + } + + pub fn is_null(value: &str) -> Self { + Condition::Simple(format!("{} IS NULL", value), vec![]) + } + + pub fn not_null(value: &str) -> Self { + Condition::Simple(format!("{} IS NOT NULL", value), vec![]) + } + + pub fn is_in(left: &str, right: Vec) -> Self { + let right_list = right + .iter() + .map(|v| "'?'".to_string()) + .collect::>() + .join(", "); + Condition::Simple(format!("{} IN ({})", left, right_list), right) + } + + pub fn not_in(left: &str, right: Vec) -> Self { + let right_list = right + .iter() + .map(|v| "'?'".to_string()) + .collect::>() + .join(", "); + Condition::Simple(format!("{} NOT IN ({})", left, right_list), right) + } + + pub fn like(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} LIKE '?'", left), vec![right.into()]) + } + + pub fn not_like(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} NOT LIKE '?'", left), vec![right.into()]) + } + + pub fn i_like(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} ILIKE '?'", left), vec![right.into()]) + } + + pub fn not_i_like(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} NOT ILIKE '?'", left), vec![right.into()]) + } + + pub fn gt(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} > ?", left), vec![right.into()]) + } + + pub fn gte(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} >= ?", left), vec![right.into()]) + } + + pub fn lt(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} < ?", left), vec![right.into()]) + } + + pub fn lte(left: &str, right: impl Into) -> Self { + Condition::Simple(format!("{} <= ?", left), vec![right.into()]) + } + + pub fn to_sql(&self, mut counter: &mut usize) -> (String, Vec) { + let mut sql = String::new(); + let mut binds = Vec::new(); + + match self { + Condition::Simple(condition, values) => { + // Replace all instances of '?' with a numbered bind + let mut bind_index = *counter; + let numbered_condition = condition.replace("?", { + bind_index += 1; + &format!("${}", bind_index) + }); + sql.push_str(&numbered_condition); + binds.extend(values.clone()); + } + Condition::And(left, right) => { + let (left_sql, left_binds) = left.to_sql(counter); + let (right_sql, right_binds) = right.to_sql(counter); + sql.push_str(&format!("{} AND {}", left_sql, right_sql)); + binds.extend(left_binds); + binds.extend(right_binds); + } + Condition::Or(left, right) => { + let (left_sql, left_binds) = left.to_sql(counter); + let (right_sql, right_binds) = right.to_sql(counter); + sql.push_str(&format!("{} OR {}", left_sql, right_sql)); + binds.extend(left_binds); + binds.extend(right_binds); + } + Condition::Group(inner) => { + let (inner_sql, inner_binds) = inner.to_sql(counter); + sql.push_str(&format!("({})", inner_sql)); + binds.extend(inner_binds); + } + }; + + (sql, binds) + } +} diff --git a/src/data/guilds/model.rs b/src/data/guilds/model.rs index 17d6a9e..1752a56 100644 --- a/src/data/guilds/model.rs +++ b/src/data/guilds/model.rs @@ -1,7 +1,10 @@ use serde::{Serialize, Deserialize}; -use sqlx::{Database, Postgres}; -use sqlx::query::QueryAs; -use crate::data::query::{Condition, QueryBuilder, Value}; +use sqlx::Database; +use crate::data::condition::Condition; +use crate::data::insert::InsertBuilder; +use crate::data::query::QueryBuilder; +use crate::data::update::UpdateBuilder; +use crate::data::Value; use crate::error::SirenResult; const TABLE_NAME: &str = "guilds"; @@ -16,50 +19,31 @@ pub struct GuildCache { impl GuildCache { pub async fn insert(&self) -> SirenResult<()> { - let pool = crate::data::pool(); - sqlx::query(&format!( - "INSERT INTO {} ( - id, - name, - owner_id, - volume - ) VALUES ( - $1, $2, $3, $4 - )", - TABLE_NAME - )) - .bind(self.id) - .bind(&self.name) - .bind(self.owner_id) - .bind(self.volume) - .execute(pool) - .await?; + InsertBuilder::new(TABLE_NAME) + .column("id", Value::BigInt(self.id)) + .column("name", Value::OptionalText(self.name.clone())) + .column("owner_id", Value::OptionalBigInt(self.owner_id)) + .column("volume", Value::Int(self.volume)) + .execute() + .await?; Ok(()) } pub async fn find_by_id(id: i64) -> Option { QueryBuilder::new(TABLE_NAME) - .where_condition(Condition::is_equal("id", Value::BIGINT(id))) + .where_condition(Condition::is_equal("id", Value::BigInt(id))) .fetch_optional() .await } pub async fn update(&self) -> SirenResult<()> { - let pool = crate::data::pool(); - sqlx::query(&format!( - "UPDATE {} SET - name = $2, - owner_id = $3, - volume = $4 - WHERE id = $1", - TABLE_NAME - )) - .bind(self.id) - .bind(&self.name) - .bind(self.owner_id) - .bind(self.volume) - .execute(pool) - .await?; + UpdateBuilder::new(TABLE_NAME) + .column("name", Value::OptionalText(self.name.clone())) + .column("owner_id", Value::OptionalBigInt(self.owner_id)) + .column("volume", Value::Int(self.volume)) + .where_condition(Condition::is_equal("id", Value::BigInt(self.id))) + .execute() + .await?; Ok(()) } } diff --git a/src/data/insert.rs b/src/data/insert.rs new file mode 100644 index 0000000..174353d --- /dev/null +++ b/src/data/insert.rs @@ -0,0 +1,87 @@ +use crate::data::Value; + +pub struct InsertBuilder { + table: String, + columns: Vec, + returning: Vec, + values: Vec, +} + +impl InsertBuilder { + pub fn new(table: &str) -> Self { + Self { + table: table.to_string(), + columns: Vec::new(), + returning: Vec::new(), + values: Vec::new(), + } + } + + pub fn column(mut self, column: &str, value: Value) -> Self { + self.columns.push(column.to_string()); + self.values.push(value); + self + } + + pub fn returning(mut self, columns: &[&str]) -> Self { + self.returning = columns.iter().map(|s| s.to_string()).collect(); + self + } + + pub async fn execute(self) -> Result { + // Build the SQL query and its values + let (query_string, values) = self.build(); + + // Start constructing the query + let mut query = sqlx::query(&query_string); + + // Bind each value to its respective placeholder + for value in values { + match value { + Value::Int(n) => query = query.bind(n), + Value::OptionalInt(n) => query = query.bind(n), + Value::BigInt(n) => query = query.bind(n), + Value::OptionalBigInt(n) => query = query.bind(n), + Value::Float(n) => query = query.bind(n), + Value::OptionalFloat(n) => query = query.bind(n), + Value::Double(n) => query = query.bind(n), + Value::OptionalDouble(n) => query = query.bind(n), + Value::Bool(n) => query = query.bind(n), + Value::OptionalBool(n) => query = query.bind(n), + Value::Text(n) => query = query.bind(n), + Value::OptionalText(n) => query = query.bind(n), + } + } + + let pool = crate::data::pool(); + query.execute(pool).await + } + + fn build(self) -> (String, Vec) { + if self.columns.is_empty() || self.values.is_empty() { + panic!("Cannot build insert query without columns and values"); + } + + // Create the list of column names + let columns = self.columns.join(", "); + + // Generate placeholders for values ($1, $2, etc.) + let placeholders = (1..=self.values.len()) + .map(|i| format!("${}", i)) + .collect::>() + .join(", "); + + // Create the basic INSERT statement + let mut query = format!( + "INSERT INTO {} ({}) VALUES ({})", + self.table, columns, placeholders + ); + + // Add RETURNING clause if specified + if !self.returning.is_empty() { + query.push_str(&format!(" RETURNING {}", self.returning.join(", "))); + } + + (query, self.values) + } +} diff --git a/src/data/mod.rs b/src/data/mod.rs index 15ae33a..bdbe59f 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -1,13 +1,16 @@ -use std::{sync::OnceLock, time::Duration}; - +use std::{fmt, sync::OnceLock, time::Duration}; +use std::fmt::Display; use redis::{aio::MultiplexedConnection as RedisConnection, Client as RedisClient, RedisResult}; use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; use crate::error::SirenResult; +pub mod condition; pub mod events; pub mod guilds; +pub mod insert; pub mod messages; pub mod query; +pub mod update; static POOL: OnceLock> = OnceLock::new(); static REDIS: OnceLock = OnceLock::new(); @@ -84,3 +87,44 @@ async fn run_migrations() -> SirenResult<()> { sqlx::migrate!().run(pool).await?; Ok(()) } + +#[derive(Debug, Clone)] +pub enum Value { + Int(i32), + OptionalInt(Option), + BigInt(i64), + OptionalBigInt(Option), + Float(f32), + OptionalFloat(Option), + Double(f64), + OptionalDouble(Option), + Bool(bool), + OptionalBool(Option), + Text(String), + OptionalText(Option), +} + +impl Display for Value { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Value::Int(n) => write!(f, "{}", n), + Value::OptionalInt(Some(n)) => write!(f, "{}", n), + Value::OptionalInt(None) => write!(f, "NULL"), + Value::BigInt(n) => write!(f, "{}", n), + Value::OptionalBigInt(Some(n)) => write!(f, "{}", n), + Value::OptionalBigInt(None) => write!(f, "NULL"), + Value::Float(n) => write!(f, "{}", n), + Value::OptionalFloat(Some(n)) => write!(f, "{}", n), + Value::OptionalFloat(None) => write!(f, "NULL"), + Value::Double(n) => write!(f, "{}", n), + Value::OptionalDouble(Some(n)) => write!(f, "{}", n), + Value::OptionalDouble(None) => write!(f, "NULL"), + Value::Bool(n) => write!(f, "{}", n), + Value::OptionalBool(Some(n)) => write!(f, "{}", n), + Value::OptionalBool(None) => write!(f, "NULL"), + Value::Text(s) => write!(f, "'{}'", s.replace("'", "''")), + Value::OptionalText(Some(s)) => write!(f, "'{}'", s.replace("'", "''")), + Value::OptionalText(None) => write!(f, "NULL"), + } + } +} diff --git a/src/data/query.rs b/src/data/query.rs index a08efc6..4cf3607 100644 --- a/src/data/query.rs +++ b/src/data/query.rs @@ -1,6 +1,8 @@ use std::fmt; use std::fmt::Display; use sqlx::{FromRow, Postgres}; +use crate::data::condition::Condition; +use crate::data::Value; pub struct QueryBuilder { table: String, @@ -52,16 +54,28 @@ impl QueryBuilder { let mut query_as = sqlx::query_as(&query_string); for value in values { match value { - Value::INT(n) => query_as = query_as.bind(n), - Value::BIGINT(n) => query_as = query_as.bind(n), + Value::Int(n) => query_as = query_as.bind(n), + Value::OptionalInt(n) => query_as = query_as.bind(n), + Value::BigInt(n) => query_as = query_as.bind(n), + Value::OptionalBigInt(n) => query_as = query_as.bind(n), + Value::Float(n) => query_as = query_as.bind(n), + Value::OptionalFloat(n) => query_as = query_as.bind(n), + Value::Double(n) => query_as = query_as.bind(n), + Value::OptionalDouble(n) => query_as = query_as.bind(n), Value::Bool(n) => query_as = query_as.bind(n), + Value::OptionalBool(n) => query_as = query_as.bind(n), Value::Text(n) => query_as = query_as.bind(n), + Value::OptionalText(n) => query_as = query_as.bind(n), } } let pool = crate::data::pool(); query_as.fetch_optional(pool).await.unwrap_or_else(|err| { - log::error!("{}", err); + log::error!( + "Unable to fetch optional on query '{}': {}", + query_string, + err + ); None }) } @@ -97,152 +111,3 @@ impl QueryBuilder { (query, values) } } - -#[derive(Debug, Clone)] -pub enum Value { - INT(i32), - BIGINT(i64), - Bool(bool), - Text(String), -} - -impl Display for Value { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Value::INT(n) => write!(f, "{}", n), - Value::BIGINT(n) => write!(f, "{}", n), - Value::Bool(n) => write!(f, "{}", n), - Value::Text(s) => write!(f, "'{}'", s), // Wrap strings in quotes for SQL - } - } -} - -pub enum Condition { - Simple(String, Vec), - And(Box, Box), - Or(Box, Box), - Group(Box), -} - -impl Condition { - pub fn new(condition: &str) -> Self { - Condition::Simple(condition.to_string(), vec![]) - } - - pub fn and(self, other: Self) -> Self { - Condition::And(Box::new(self), Box::new(other)) - } - - pub fn or(self, other: Self) -> Self { - Condition::Or(Box::new(self), Box::new(other)) - } - - pub fn group(self) -> Self { - Condition::Group(Box::new(self)) - } - - pub fn is_equal(left: &str, right: impl Into) -> Self { - Condition::Simple(format!("{} = ?", left), vec![right.into()]) - } - - pub fn not_equal(left: &str, right: impl Into) -> Self { - Condition::Simple(format!("{} != ?", left), vec![right.into()]) - } - - pub fn is_null(value: &str) -> Self { - Condition::Simple(format!("{} IS NULL", value), vec![]) - } - - pub fn not_null(value: &str) -> Self { - Condition::Simple(format!("{} IS NOT NULL", value), vec![]) - } - - pub fn is_in(left: &str, right: Vec) -> Self { - let right_list = right - .iter() - .map(|v| "'?'".to_string()) - .collect::>() - .join(", "); - Condition::Simple(format!("{} IN ({})", left, right_list), right) - } - - pub fn not_in(left: &str, right: Vec) -> Self { - let right_list = right - .iter() - .map(|v| "'?'".to_string()) - .collect::>() - .join(", "); - Condition::Simple(format!("{} NOT IN ({})", left, right_list), right) - } - - pub fn like(left: &str, right: impl Into) -> Self { - Condition::Simple(format!("{} LIKE '?'", left), vec![right.into()]) - } - - pub fn not_like(left: &str, right: impl Into) -> Self { - Condition::Simple(format!("{} NOT LIKE '?'", left), vec![right.into()]) - } - - pub fn i_like(left: &str, right: impl Into) -> Self { - Condition::Simple(format!("{} ILIKE '?'", left), vec![right.into()]) - } - - pub fn not_i_like(left: &str, right: impl Into) -> Self { - Condition::Simple(format!("{} NOT ILIKE '?'", left), vec![right.into()]) - } - - pub fn gt(left: &str, right: impl Into) -> Self { - Condition::Simple(format!("{} > ?", left), vec![right.into()]) - } - - pub fn gte(left: &str, right: impl Into) -> Self { - Condition::Simple(format!("{} >= ?", left), vec![right.into()]) - } - - pub fn lt(left: &str, right: impl Into) -> Self { - Condition::Simple(format!("{} < ?", left), vec![right.into()]) - } - - pub fn lte(left: &str, right: impl Into) -> Self { - Condition::Simple(format!("{} <= ?", left), vec![right.into()]) - } - - fn to_sql(&self, mut counter: &mut usize) -> (String, Vec) { - let mut sql = String::new(); - let mut binds = Vec::new(); - - match self { - Condition::Simple(condition, values) => { - // Replace all instances of '?' with an numbered bind - let mut bind_index = *counter; - let numbered_condition = condition.replace("?", { - bind_index += 1; - &format!("${}", bind_index) - }); - sql.push_str(&numbered_condition); - binds.extend(values.clone()); - } - Condition::And(left, right) => { - let (left_sql, left_binds) = left.to_sql(counter); - let (right_sql, right_binds) = right.to_sql(counter); - sql.push_str(&format!("{} AND {}", left_sql, right_sql)); - binds.extend(left_binds); - binds.extend(right_binds); - } - Condition::Or(left, right) => { - let (left_sql, left_binds) = left.to_sql(counter); - let (right_sql, right_binds) = right.to_sql(counter); - sql.push_str(&format!("{} OR {}", left_sql, right_sql)); - binds.extend(left_binds); - binds.extend(right_binds); - } - Condition::Group(inner) => { - let (inner_sql, inner_binds) = inner.to_sql(counter); - sql.push_str(&format!("({})", inner_sql)); - binds.extend(inner_binds); - } - }; - - (sql, binds) - } -} diff --git a/src/data/update.rs b/src/data/update.rs new file mode 100644 index 0000000..b33848e --- /dev/null +++ b/src/data/update.rs @@ -0,0 +1,84 @@ +use crate::data::condition::Condition; +use crate::data::Value; + +pub struct UpdateBuilder { + table: String, + columns: Vec<(String, Value)>, + condition: Option, +} + +impl UpdateBuilder { + pub fn new(table: &str) -> Self { + Self { + table: table.to_string(), + columns: Vec::new(), + condition: None, + } + } + + pub fn column(mut self, column: &str, value: Value) -> Self { + self.columns.push((column.to_string(), value)); + self + } + + pub fn where_condition(mut self, condition: Condition) -> Self { + self.condition = Some(condition); + self + } + + pub async fn execute(self) -> Result { + // Build the SQL query and its values + let (query_string, values) = self.build(); + + // Start constructing the query + let mut query = sqlx::query(&query_string); + + // Bind each value to its respective placeholder + for value in values { + match value { + Value::Int(n) => query = query.bind(n), + Value::OptionalInt(n) => query = query.bind(n), + Value::BigInt(n) => query = query.bind(n), + Value::OptionalBigInt(n) => query = query.bind(n), + Value::Float(n) => query = query.bind(n), + Value::OptionalFloat(n) => query = query.bind(n), + Value::Double(n) => query = query.bind(n), + Value::OptionalDouble(n) => query = query.bind(n), + Value::Bool(n) => query = query.bind(n), + Value::OptionalBool(n) => query = query.bind(n), + Value::Text(n) => query = query.bind(n), + Value::OptionalText(n) => query = query.bind(n), + } + } + + let pool = crate::data::pool(); + query.execute(pool).await + } + + fn build(self) -> (String, Vec) { + if self.columns.is_empty() { + panic!("Cannot build update query without columns to set"); + } + + // Generate the SET clause + let set_clause = self + .columns + .iter() + .enumerate() + .map(|(i, (col, _))| format!("{} = ${}", col, i + 1)) + .collect::>() + .join(", "); + + // Prepare the WHERE clause if conditions are present + let mut query = format!("UPDATE {} SET {}", self.table, set_clause); + + let mut values: Vec = Vec::new(); + if let Some(condition) = self.condition { + let where_condition = condition.to_sql(&mut 0); + query.push_str(&format!(" WHERE {}", where_condition.0)); + values = where_condition.1; + } + + (query, values) + } +}