Updated query builder with bindings

This commit is contained in:
2024-12-21 22:31:11 -05:00
parent 7718cf19c3
commit 2cd2715d0d
8 changed files with 169 additions and 86 deletions

View File

@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use crate::api::auth::{csprng, AuthCredential}; use crate::api::auth::{csprng, AuthCredential};
use crate::api::auth::AuthorizationMiddleware; use crate::api::auth::AuthorizationMiddleware;
use crate::AppState; use crate::AppState;
use crate::data::query::{Condition, QueryBuilder}; use crate::data::query::{Condition, QueryBuilder, Value};
use crate::error::{Error, SirenResult}; use crate::error::{Error, SirenResult};
pub fn get_routes() -> Router<Arc<AppState>> { pub fn get_routes() -> Router<Arc<AppState>> {
@@ -25,7 +25,7 @@ pub struct ApiKey {
pub user_name: String, pub user_name: String,
pub access_mask: i32, pub access_mask: i32,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub last_used_at: Option<DateTime<Utc>> pub last_used_at: Option<DateTime<Utc>>,
} }
impl ApiKey { impl ApiKey {
@@ -92,9 +92,9 @@ impl ApiKey {
pub async fn find_by_key(key: &str) -> SirenResult<Option<Self>> { pub async fn find_by_key(key: &str) -> SirenResult<Option<Self>> {
let pool = crate::data::pool(); let pool = crate::data::pool();
let query = QueryBuilder::new(TABLE_NAME) let query = QueryBuilder::new(TABLE_NAME)
.where_condition(Condition::is_equal("key", "$1")) .where_condition(Condition::is_equal("key", Value::Text(key.to_string())))
.build(); .build();
let item = sqlx::query_as(&query) let item = sqlx::query_as(&query.0)
.bind(key) .bind(key)
.fetch_optional(pool) .fetch_optional(pool)
.await?; .await?;

View File

@@ -36,7 +36,6 @@ enum TrackDiceOperator {
GreaterThanEqual, GreaterThanEqual,
} }
// Implementing the ToString trait for converting the enum to a string // Implementing the ToString trait for converting the enum to a string
impl Display for TrackDiceOperator { impl Display for TrackDiceOperator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -110,8 +109,7 @@ impl QueryDiceTrack {
// ) // )
// ) // )
.build(); .build();
let items: Vec<QueryDiceTrack> = sqlx::query_as(&query) let items: Vec<QueryDiceTrack> = sqlx::query_as(&query.0).fetch_all(pool).await?;
.fetch_all(pool).await?;
Ok(items) Ok(items)
} }
@@ -140,9 +138,11 @@ impl InsertDiceTrack {
.bind(self.user_id) .bind(self.user_id)
.bind(self.value) .bind(self.value)
.bind(&self.operator) .bind(&self.operator)
.fetch_optional(pool).await? { .fetch_optional(pool)
.await?
{
Some(result) => result, Some(result) => result,
None => return Err(Error::new(500, "Error storing".to_string())) None => return Err(Error::new(500, "Error storing".to_string())),
}; };
Ok(item) Ok(item)
} }
@@ -154,7 +154,6 @@ pub async fn add_track_dice(
Path(guild_id): Path<u64>, Path(guild_id): Path<u64>,
Json(payload): Json<DiceTrackPayload>, Json(payload): Json<DiceTrackPayload>,
) -> SirenResult<Json<QueryDiceTrack>> { ) -> SirenResult<Json<QueryDiceTrack>> {
// Check if the user exists in the cache // Check if the user exists in the cache
let owner_id = credential.user_id(); let owner_id = credential.user_id();
let owner_id = match state.cache.user(owner_id) { let owner_id = match state.cache.user(owner_id) {
@@ -179,7 +178,7 @@ pub async fn add_track_dice(
operator: match payload.operator { operator: match payload.operator {
None => None, None => None,
Some(s) => Some(s.to_string()), Some(s) => Some(s.to_string()),
} },
}; };
let dice_track = dice.insert().await?; let dice_track = dice.insert().await?;

View File

@@ -89,9 +89,7 @@ pub async fn enqueue_track(
let mut playlist_items: Vec<YtDlpItem> = Vec::new(); let mut playlist_items: Vec<YtDlpItem> = Vec::new();
if let Some(handler_lock) = manager.get(guild_id) { if let Some(handler_lock) = manager.get(guild_id) {
let mut handler = handler_lock.lock().await; let mut handler = handler_lock.lock().await;
let guild = GuildCache::find_by_id(guild_id.get() as i64) let guild = GuildCache::find_by_id(guild_id.get() as i64).await.unwrap();
.await?
.unwrap();
let valid = is_valid_url(&track_url); let valid = is_valid_url(&track_url);
// Check if the URL is valid // Check if the URL is valid

View File

@@ -64,10 +64,7 @@ pub async fn set_volume(manager: &Arc<Songbird>, guild_id: &GuildId, volume: i32
let bound_volume = volume as f32 / 100.0; let bound_volume = volume as f32 / 100.0;
// Update the guild cache // Update the guild cache
let mut guild_cache = GuildCache::find_by_id(guild_id.get() as i64) let mut guild_cache = GuildCache::find_by_id(guild_id.get() as i64).await.unwrap();
.await
.unwrap()
.unwrap();
guild_cache.volume = volume; guild_cache.volume = volume;
guild_cache.update().await.unwrap(); guild_cache.update().await.unwrap();

View File

@@ -66,7 +66,6 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
None => edit_response(&ctx, &command, format!("🎲 {}\n-# {}", total, response)).await, None => edit_response(&ctx, &command, format!("🎲 {}\n-# {}", total, response)).await,
}; };
// Check for dice tracks // Check for dice tracks
} }
Err(why) => { Err(why) => {
edit_response(&ctx, &command, format!("Invalid dice string: {}", why)).await; edit_response(&ctx, &command, format!("Invalid dice string: {}", why)).await;
@@ -161,17 +160,23 @@ pub fn parse_dice(dice: &str) -> SirenResult<(u32, u32, i32)> {
if [4, 6, 8, 10, 12, 20, 100].contains(&n) { if [4, 6, 8, 10, 12, 20, 100].contains(&n) {
n n
} else { } else {
return Err(Error::new(400, format!( return Err(Error::new(
400,
format!(
"Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}",
n n
))); ),
));
} }
} }
Err(_) => { Err(_) => {
return Err(Error::new(400, format!( return Err(Error::new(
400,
format!(
"Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}",
sides_part sides_part
))) ),
))
} }
}; };

View File

@@ -116,7 +116,7 @@ impl EventHandler for BotHandler {
for guild in ready.guilds { for guild in ready.guilds {
// Check if guild exists in database // Check if guild exists in database
let guild_id = guild.id.get() as i64; let guild_id = guild.id.get() as i64;
if let None = GuildCache::find_by_id(guild_id).await.unwrap() { if let None = GuildCache::find_by_id(guild_id).await {
let guild_cache = GuildCache { let guild_cache = GuildCache {
id: guild_id, id: guild_id,
name: guild.id.name(&ctx.cache), name: guild.id.name(&ctx.cache),

View File

@@ -1,5 +1,7 @@
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::data::query::{Condition, QueryBuilder}; use sqlx::{Database, Postgres};
use sqlx::query::QueryAs;
use crate::data::query::{Condition, QueryBuilder, Value};
use crate::error::SirenResult; use crate::error::SirenResult;
const TABLE_NAME: &str = "guilds"; const TABLE_NAME: &str = "guilds";
@@ -35,14 +37,11 @@ impl GuildCache {
Ok(()) Ok(())
} }
pub async fn find_by_id(id: i64) -> SirenResult<Option<Self>> { pub async fn find_by_id(id: i64) -> Option<Self> {
let pool = crate::data::pool(); QueryBuilder::new(TABLE_NAME)
let query = QueryBuilder::new(TABLE_NAME) .where_condition(Condition::is_equal("id", Value::BIGINT(id)))
.where_condition(Condition::is_equal("id", "$1")) .fetch_optional()
.build(); .await
let item = sqlx::query_as(&query).bind(id).fetch_optional(pool).await?;
Ok(item)
} }
pub async fn update(&self) -> SirenResult<()> { pub async fn update(&self) -> SirenResult<()> {

View File

@@ -1,9 +1,14 @@
use std::fmt;
use std::fmt::Display;
use sqlx::{FromRow, Postgres};
pub struct QueryBuilder { pub struct QueryBuilder {
table: String, table: String,
columns: Vec<String>, columns: Vec<String>,
condition: Option<Condition>, condition: Option<Condition>,
order_by: Vec<String>, order_by: Vec<String>,
limit: Option<usize>, limit: Option<usize>,
offset: Option<usize>,
} }
impl QueryBuilder { impl QueryBuilder {
@@ -14,6 +19,7 @@ impl QueryBuilder {
condition: None, condition: None,
order_by: Vec::new(), order_by: Vec::new(),
limit: None, limit: None,
offset: None,
} }
} }
@@ -37,7 +43,30 @@ impl QueryBuilder {
self self
} }
pub fn build(self) -> String { pub async fn fetch_optional<
T: Send + Unpin + for<'r> FromRow<'r, <Postgres as sqlx::Database>::Row>,
>(
self,
) -> Option<T> {
let (query_string, values) = self.build();
let mut query_as = sqlx::query_as(&query_string);
for value in values {
match value {
Value::INT(n) => query_as = query_as.bind(n),
Value::BIGINT(n) => query_as = query_as.bind(n),
Value::Bool(n) => query_as = query_as.bind(n),
Value::Text(n) => query_as = query_as.bind(n),
}
}
let pool = crate::data::pool();
query_as.fetch_optional(pool).await.unwrap_or_else(|err| {
log::error!("{}", err);
None
})
}
pub fn build(self) -> (String, Vec<Value>) {
let columns = if self.columns.is_empty() { let columns = if self.columns.is_empty() {
"*".to_string() "*".to_string()
} else { } else {
@@ -46,8 +75,11 @@ impl QueryBuilder {
let mut query = format!("SELECT {} FROM {}", columns, self.table); let mut query = format!("SELECT {} FROM {}", columns, self.table);
let mut values: Vec<Value> = Vec::new();
if let Some(condition) = self.condition { if let Some(condition) = self.condition {
query.push_str(&format!(" WHERE {}", condition.to_sql())); let where_condition = condition.to_sql(&mut 0);
query.push_str(&format!(" WHERE {}", where_condition.0));
values = where_condition.1;
} }
if !self.order_by.is_empty() { if !self.order_by.is_empty() {
@@ -58,12 +90,35 @@ impl QueryBuilder {
query.push_str(&format!(" LIMIT {}", limit)); query.push_str(&format!(" LIMIT {}", limit));
} }
query if let Some(offset) = self.offset {
query.push_str(&format!(" OFFSET {}", offset));
}
(query, values)
}
}
#[derive(Debug, Clone)]
pub enum Value {
INT(i32),
BIGINT(i64),
Bool(bool),
Text(String),
}
impl Display for Value {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Value::INT(n) => write!(f, "{}", n),
Value::BIGINT(n) => write!(f, "{}", n),
Value::Bool(n) => write!(f, "{}", n),
Value::Text(s) => write!(f, "'{}'", s), // Wrap strings in quotes for SQL
}
} }
} }
pub enum Condition { pub enum Condition {
Simple(String), Simple(String, Vec<Value>),
And(Box<Condition>, Box<Condition>), And(Box<Condition>, Box<Condition>),
Or(Box<Condition>, Box<Condition>), Or(Box<Condition>, Box<Condition>),
Group(Box<Condition>), Group(Box<Condition>),
@@ -71,7 +126,7 @@ pub enum Condition {
impl Condition { impl Condition {
pub fn new(condition: &str) -> Self { pub fn new(condition: &str) -> Self {
Condition::Simple(condition.to_string()) Condition::Simple(condition.to_string(), vec![])
} }
pub fn and(self, other: Self) -> Self { pub fn and(self, other: Self) -> Self {
@@ -86,78 +141,108 @@ impl Condition {
Condition::Group(Box::new(self)) Condition::Group(Box::new(self))
} }
pub fn is_equal(left: &str, right: &str) -> Self { pub fn is_equal(left: &str, right: impl Into<Value>) -> Self {
Condition::Simple(format!("{} = {}", left, right)) Condition::Simple(format!("{} = ?", left), vec![right.into()])
} }
pub fn not_equal(left: &str, right: &str) -> Self { pub fn not_equal(left: &str, right: impl Into<Value>) -> Self {
Condition::Simple(format!("{} != {}", left, right)) Condition::Simple(format!("{} != ?", left), vec![right.into()])
} }
pub fn is_null(value: &str) -> Self { pub fn is_null(value: &str) -> Self {
Condition::Simple(format!("{} IS NULL", value)) Condition::Simple(format!("{} IS NULL", value), vec![])
} }
pub fn not_null(value: &str) -> Self { pub fn not_null(value: &str) -> Self {
Condition::Simple(format!("{} IS NOT NULL", value)) Condition::Simple(format!("{} IS NOT NULL", value), vec![])
} }
pub fn is_in(left: &str, right: &[&str]) -> Self { pub fn is_in(left: &str, right: Vec<Value>) -> Self {
let right_list = right let right_list = right
.iter() .iter()
.map(|v| format!("'{}'", v)) .map(|v| "'?'".to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "); .join(", ");
Condition::Simple(format!("{} IN ({})", left, right_list)) Condition::Simple(format!("{} IN ({})", left, right_list), right)
} }
pub fn not_in(left: &str, right: &[&str]) -> Self { pub fn not_in(left: &str, right: Vec<Value>) -> Self {
let right_list = right let right_list = right
.iter() .iter()
.map(|v| format!("'{}'", v)) .map(|v| "'?'".to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "); .join(", ");
Condition::Simple(format!("{} NOT IN ({})", left, right_list)) Condition::Simple(format!("{} NOT IN ({})", left, right_list), right)
} }
pub fn like(left: &str, right: &str) -> Self { pub fn like(left: &str, right: impl Into<Value>) -> Self {
Condition::Simple(format!("{} LIKE '{}'", left, right)) Condition::Simple(format!("{} LIKE '?'", left), vec![right.into()])
} }
pub fn not_like(left: &str, right: &str) -> Self { pub fn not_like(left: &str, right: impl Into<Value>) -> Self {
Condition::Simple(format!("{} NOT LIKE '{}'", left, right)) Condition::Simple(format!("{} NOT LIKE '?'", left), vec![right.into()])
} }
pub fn i_like(left: &str, right: &str) -> Self { pub fn i_like(left: &str, right: impl Into<Value>) -> Self {
Condition::Simple(format!("{} ILIKE '{}'", left, right)) Condition::Simple(format!("{} ILIKE '?'", left), vec![right.into()])
} }
pub fn not_i_like(left: &str, right: &str) -> Self { pub fn not_i_like(left: &str, right: impl Into<Value>) -> Self {
Condition::Simple(format!("{} NOT ILIKE '{}'", left, right)) Condition::Simple(format!("{} NOT ILIKE '?'", left), vec![right.into()])
} }
pub fn gt(left: &str, right: &str) -> Self { pub fn gt(left: &str, right: impl Into<Value>) -> Self {
Condition::Simple(format!("{} > {}", left, right)) Condition::Simple(format!("{} > ?", left), vec![right.into()])
} }
pub fn gte(left: &str, right: &str) -> Self { pub fn gte(left: &str, right: impl Into<Value>) -> Self {
Condition::Simple(format!("{} >= {}", left, right)) Condition::Simple(format!("{} >= ?", left), vec![right.into()])
} }
pub fn lt(left: &str, right: &str) -> Self { pub fn lt(left: &str, right: impl Into<Value>) -> Self {
Condition::Simple(format!("{} < {}", left, right)) Condition::Simple(format!("{} < ?", left), vec![right.into()])
} }
pub fn lte(left: &str, right: &str) -> Self { pub fn lte(left: &str, right: impl Into<Value>) -> Self {
Condition::Simple(format!("{} <= {}", left, right)) Condition::Simple(format!("{} <= ?", left), vec![right.into()])
} }
fn to_sql(&self) -> String { fn to_sql(&self, mut counter: &mut usize) -> (String, Vec<Value>) {
let mut sql = String::new();
let mut binds = Vec::new();
match self { match self {
Condition::Simple(s) => s.to_string(), Condition::Simple(condition, values) => {
Condition::And(a, b) => format!("{} AND {}", a.to_sql(), b.to_sql()), // Replace all instances of '?' with an numbered bind
Condition::Or(a, b) => format!("{} OR {}", a.to_sql(), b.to_sql()), let mut bind_index = *counter;
Condition::Group(a) => format!("({})", a.to_sql()), let numbered_condition = condition.replace("?", {
bind_index += 1;
&format!("${}", bind_index)
});
sql.push_str(&numbered_condition);
binds.extend(values.clone());
} }
Condition::And(left, right) => {
let (left_sql, left_binds) = left.to_sql(counter);
let (right_sql, right_binds) = right.to_sql(counter);
sql.push_str(&format!("{} AND {}", left_sql, right_sql));
binds.extend(left_binds);
binds.extend(right_binds);
}
Condition::Or(left, right) => {
let (left_sql, left_binds) = left.to_sql(counter);
let (right_sql, right_binds) = right.to_sql(counter);
sql.push_str(&format!("{} OR {}", left_sql, right_sql));
binds.extend(left_binds);
binds.extend(right_binds);
}
Condition::Group(inner) => {
let (inner_sql, inner_binds) = inner.to_sql(counter);
sql.push_str(&format!("({})", inner_sql));
binds.extend(inner_binds);
}
};
(sql, binds)
} }
} }