Switched from diesel to sqlx

This commit is contained in:
2024-09-05 11:52:14 -04:00
parent bce363db7e
commit d08800f9e0
42 changed files with 365 additions and 687 deletions

View File

@@ -17,9 +17,7 @@ serde_json = "1.0.127"
serenity = { version = "0.12.2", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "voice", "cache", "framework", "standard_framework"] } serenity = { version = "0.12.2", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "voice", "cache", "framework", "standard_framework"] }
songbird = { version = "0.4.3", features = ["builtin-queue"] } songbird = { version = "0.4.3", features = ["builtin-queue"] }
symphonia = { version = "0.5.4", features = ["all"] } symphonia = { version = "0.5.4", features = ["all"] }
diesel = { version = "2.1.5", default-features = false, features = ["postgres", "chrono", "r2d2", "32-column-tables", "serde_json", "with-deprecated"] } sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "chrono", "uuid"] }
diesel_migrations = { version = "2.1.0", features = ["postgres"] }
r2d2 = "0.8.10"
chrono = { version = "0.4.38", features = ["serde"] } chrono = { version = "0.4.38", features = ["serde"] }
reqwest = { version = "0.11", default-features = false, features = ["json"] } reqwest = { version = "0.11", default-features = false, features = ["json"] }
lazy_static = "1.5.0" lazy_static = "1.5.0"

View File

@@ -50,4 +50,4 @@ docker-clean: ## Stop the docker containers and remove volumes
docker-refresh: docker-clean backend-up ## Refresh the docker containers docker-refresh: docker-clean backend-up ## Refresh the docker containers
psql: ## Connect to the database psql: ## Connect to the database
@docker exec -it siren-db psql -U ${DATABASE_USER} -P pager=off @docker exec -it siren-postgres psql -U ${DATABASE_USER} -P pager=off

View File

@@ -6,51 +6,51 @@ x-env_file: &env
name: siren name: siren
services: services:
# bot: bot:
# image: siren-service:${SIREN_VERSION:-latest} image: siren-service:${SIREN_VERSION:-latest}
# container_name: siren-service container_name: siren-service
# build: build:
# context: . context: .
# dockerfile: ./Dockerfile dockerfile: ./Dockerfile
# args: args:
# - VERSION=${SIREN_VERSION:-latest} - VERSION=${SIREN_VERSION:-latest}
# env_file: *env env_file: *env
# environment: environment:
# DATABASE_HOST: db DATABASE_HOST: db
# DATABASE_PORT: 5432 DATABASE_PORT: 5432
# REDIS_HOST: redis REDIS_HOST: redis
# REDIS_PORT: 6379 REDIS_PORT: 6379
# MINIO_HOST: minio MINIO_HOST: minio
# MINIO_PORT: 9000 MINIO_PORT: 9000
# SERVICE_HOST: service SERVICE_HOST: service
# SERVICE_PORT: 5000 SERVICE_PORT: 5000
# DATA_DIR_PATH: /data DATA_DIR_PATH: /data
# volumes: volumes:
# - ${DATA_DIR_PATH:-/data}:/data - ${DATA_DIR_PATH:-/data}:/data
# ports: ports:
# - ${SERVICE_PORT:-5000}:5000 - ${SERVICE_PORT:-5000}:5000
# depends_on: depends_on:
# - db - db
# - redis - redis
# - minio - minio
# networks: networks:
# - frontend - frontend
# - backend - backend
# restart: unless-stopped restart: unless-stopped
# profiles: profiles:
# - bot - bot
db: postgres:
image: postgres:latest image: postgres:latest
container_name: siren-db container_name: siren-postgres
env_file: *env env_file: *env
environment: environment:
POSTGRES_USER: ${DATABASE_USER} POSTGRES_USER: ${DATABASE_USER}
POSTGRES_PASSWORD: ${DATABASE_PASSWORD} POSTGRES_PASSWORD: ${DATABASE_PASSWORD}
POSTGRES_DB: ${DATABASE_NAME} POSTGRES_DB: ${DATABASE_NAME}
volumes: volumes:
- db:/var/lib/postgresql/data - postgres:/var/lib/postgresql/data
- db_logs:/var/log - postgres_logs:/var/log
ports: ports:
- ${DATABASE_PORT:-5432}:5432 - ${DATABASE_PORT:-5432}:5432
networks: networks:
@@ -73,8 +73,8 @@ services:
restart: unless-stopped restart: unless-stopped
volumes: volumes:
db: postgres:
db_logs: postgres_logs:
redis: redis:
networks: networks:

View File

@@ -1 +0,0 @@
DROP TABLE messages;

View File

@@ -1,12 +0,0 @@
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY NOT NULL,
guild_id BIGINT NOT NULL,
channel_id BIGINT NOT NULL,
user_id BIGINT NOT NULL,
created BIGINT NOT NULL,
model TEXT NOT NULL,
request TEXT NOT NULL,
response TEXT NOT NULL,
request_tags TEXT[] NOT NULL,
response_tags TEXT[] NOT NULL
);

View File

@@ -1 +0,0 @@
DROP TABLE races;

View File

@@ -1,7 +0,0 @@
CREATE TABLE IF NOT EXISTS races (
id INTEGER GENERATED ALWAYS AS IDENTITY,
name TEXT NOT NULL,
size TEXT NOT NULL,
source TEXT NOT NULL,
data JSON NOT NULL
);

View File

@@ -1 +0,0 @@
DROP TABLE classes;

View File

@@ -1,3 +0,0 @@
CREATE TABLE IF NOT EXISTS classes (
id INTEGER GENERATED ALWAYS AS IDENTITY
);

View File

@@ -1 +0,0 @@
DROP TABLE feats;

View File

@@ -1,3 +0,0 @@
CREATE TABLE IF NOT EXISTS feats (
id INTEGER GENERATED ALWAYS AS IDENTITY
);

View File

@@ -1 +0,0 @@
DROP TABLE options_features;

View File

@@ -1,3 +0,0 @@
CREATE TABLE IF NOT EXISTS options_features (
id INTEGER GENERATED ALWAYS AS IDENTITY
);

View File

@@ -1 +0,0 @@
DROP TABLE backgrounds;

View File

@@ -1,3 +0,0 @@
CREATE TABLE IF NOT EXISTS backgrounds (
id INTEGER GENERATED ALWAYS AS IDENTITY
);

View File

@@ -1 +0,0 @@
DROP TABLE items;

View File

@@ -1,3 +0,0 @@
CREATE TABLE IF NOT EXISTS items (
id INTEGER GENERATED ALWAYS AS IDENTITY
);

View File

@@ -1 +0,0 @@
DROP TABLE spells;

View File

@@ -1,15 +0,0 @@
CREATE TABLE IF NOT EXISTS spells (
id INTEGER GENERATED ALWAYS AS IDENTITY,
name TEXT NOT NULL,
school TEXT NOT NULL,
level INTEGER NOT NULL,
ritual BOOLEAN DEFAULT FALSE,
concentration BOOLEAN DEFAULT FALSE,
classes TEXT[] NOT NULL,
damage_inflict TEXT[] NOT NULL,
damage_resist TEXT[] NOT NULL,
conditions TEXT[] NOT NULL,
saving_throw TEXT[] NOT NULL,
attack_type TEXT,
data JSONB NOT NULL
);

View File

@@ -1 +0,0 @@
DROP TABLE conditions;

View File

@@ -1,3 +0,0 @@
CREATE TABLE IF NOT EXISTS conditions (
id INTEGER GENERATED ALWAYS AS IDENTITY
);

View File

@@ -1 +0,0 @@
DROP TABLE bestiary;

View File

@@ -1,3 +0,0 @@
CREATE TABLE IF NOT EXISTS bestiary (
id INTEGER GENERATED ALWAYS AS IDENTITY
);

View File

@@ -1 +0,0 @@
DROP TABLE guilds;

View File

@@ -1,5 +0,0 @@
CREATE TABLE IF NOT EXISTS guilds (
id BIGINT PRIMARY KEY NOT NULL,
bot_id BIGINT NOT NULL,
volume INTEGER NOT NULL
);

View File

@@ -1 +0,0 @@
DROP TABLE users;

View File

@@ -1,11 +0,0 @@
CREATE TABLE IF NOT EXISTS users (
email TEXT PRIMARY KEY NOT NULL,
hash TEXT NOT NULL,
role TEXT NOT NULL,
first_name TEXT NOT NULL,
last_name TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
profile_picture TEXT,
verified BOOLEAN NOT NULL DEFAULT FALSE
);

28
migrations/000_base.sql Normal file
View File

@@ -0,0 +1,28 @@
CREATE TABLE IF NOT EXISTS guilds (
id BIGINT PRIMARY KEY NOT NULL,
bot_id BIGINT NOT NULL,
volume INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS users (
email TEXT PRIMARY KEY NOT NULL,
hash TEXT NOT NULL,
role TEXT NOT NULL,
first_name TEXT NOT NULL,
last_name TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
profile_picture TEXT,
verified BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY NOT NULL,
guild_id BIGINT NOT NULL,
channel_id BIGINT NOT NULL,
user_id BIGINT NOT NULL,
created BIGINT NOT NULL,
model TEXT NOT NULL,
request TEXT NOT NULL,
response TEXT NOT NULL,
request_tags TEXT[] NOT NULL,
response_tags TEXT[] NOT NULL
);

View File

@@ -0,0 +1,43 @@
CREATE TABLE IF NOT EXISTS races (
id INTEGER GENERATED ALWAYS AS IDENTITY,
name TEXT NOT NULL,
size TEXT NOT NULL,
source TEXT NOT NULL,
data JSON NOT NULL
);
CREATE TABLE IF NOT EXISTS classes (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS feats (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS options_features (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS backgrounds (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS items (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS spells (
id INTEGER GENERATED ALWAYS AS IDENTITY,
name TEXT NOT NULL,
school TEXT NOT NULL,
level INTEGER NOT NULL,
ritual BOOLEAN DEFAULT FALSE,
concentration BOOLEAN DEFAULT FALSE,
classes TEXT[] NOT NULL,
damage_inflict TEXT[] NOT NULL,
damage_resist TEXT[] NOT NULL,
conditions TEXT[] NOT NULL,
saving_throw TEXT[] NOT NULL,
attack_type TEXT,
data JSONB NOT NULL
);
CREATE TABLE IF NOT EXISTS conditions (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS bestiary (
id INTEGER GENERATED ALWAYS AS IDENTITY
);

View File

@@ -106,7 +106,7 @@ pub async fn play_track(
let call_handler = handler_lock.lock().await; let call_handler = handler_lock.lock().await;
call_handler.queue().is_empty() call_handler.queue().is_empty()
}; };
let guild = GuildCache::get(guild_id.get() as i64)?; let guild = GuildCache::get_by_id(guild_id.get() as i64).await?.unwrap();
let valid = is_valid_url(&track_url); let valid = is_valid_url(&track_url);
// Check if the URL is valid // Check if the URL is valid
if !valid.0 { if !valid.0 {

View File

@@ -62,7 +62,9 @@ pub async fn set_volume(manager: Arc<Songbird>, guild_id: GuildId, volume: i32)
// Format volume to f32 bound between 0.0 and 1.0 // Format volume to f32 bound between 0.0 and 1.0
let volume = std::cmp::min(100, std::cmp::max(0, volume)); let volume = std::cmp::min(100, std::cmp::max(0, volume));
let bound_volume = volume as f32 / 100.0; let bound_volume = volume as f32 / 100.0;
let _ = GuildCache::update_audio(guild_id.get() as i64, volume); let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64).await.unwrap().unwrap();
guild_cache.volume = volume;
guild_cache.update().await.unwrap();
if let Some(handler_lock) = manager.get(guild_id) { if let Some(handler_lock) = manager.get(guild_id) {
let handler = handler_lock.lock().await; let handler = handler_lock.lock().await;

View File

@@ -6,7 +6,7 @@ use serenity::model::channel::Message;
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType}; use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
use serenity::prelude::*; use serenity::prelude::*;
use crate::bot::messages::{QueryFilters, QueryMessage}; use crate::bot::messages::MessageCache;
use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI}; use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI};
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
@@ -27,30 +27,30 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
}, },
]; ];
match QueryMessage::get_all( // match MessageCache::get_all(
&QueryFilters { // &QueryFilters {
by_guild_id: Some(guild_id.get() as i64), // by_guild_id: Some(guild_id.get() as i64),
by_channel_id: Some(channel_id.get() as i64), // by_channel_id: Some(channel_id.get() as i64),
by_user_id: Some(author_id.get() as i64), // by_user_id: Some(author_id.get() as i64),
..Default::default() // ..Default::default()
}, // },
100, // 100,
1, // 1,
) { // ) {
Ok(m) => { // Ok(m) => {
for message in m { // for message in m {
messages.push(ChatCompletionMessage { // messages.push(ChatCompletionMessage {
role: GPTRole::User, // role: GPTRole::User,
content: format!("{}", message.request), // content: format!("{}", message.request),
}); // });
messages.push(ChatCompletionMessage { // messages.push(ChatCompletionMessage {
role: GPTRole::Assistant, // role: GPTRole::Assistant,
content: format!("{}", message.response), // content: format!("{}", message.response),
}); // });
} // }
} // }
Err(err) => warn!("Could not load previous messages: {}", err), // Err(err) => warn!("Could not load previous messages: {}", err),
}; // };
messages.push(ChatCompletionMessage { messages.push(ChatCompletionMessage {
role: GPTRole::User, role: GPTRole::User,
content: parsed_content.clone(), content: parsed_content.clone(),
@@ -98,7 +98,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
trace!("Processing response received from OpenAI"); trace!("Processing response received from OpenAI");
if !r.choices.is_empty() { if !r.choices.is_empty() {
let res = r.choices[0].message.content.clone(); let res = r.choices[0].message.content.clone();
if let Err(err) = QueryMessage::insert(QueryMessage { let message_cache = MessageCache {
id: r.id, id: r.id,
guild_id: guild_id.get() as i64, guild_id: guild_id.get() as i64,
channel_id: response_channel.get() as i64, channel_id: response_channel.get() as i64,
@@ -109,7 +109,8 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
response: res.clone(), response: res.clone(),
request_tags: vec![], request_tags: vec![],
response_tags: vec![], response_tags: vec![],
}) { };
if let Err(err) = message_cache.insert().await {
warn!("{}", err); warn!("{}", err);
} }
res res

View File

@@ -1,11 +1,9 @@
use diesel::prelude::*;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::error::SirenResult; use crate::error::SirenResult;
use crate::storage::{schema::guilds, connection}; const TABLE_NAME: &str = "guilds";
#[derive(Insertable, AsChangeset, Queryable, QueryableByName, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
#[diesel(table_name = guilds)]
pub struct GuildCache { pub struct GuildCache {
pub id: i64, pub id: i64,
pub bot_id: i64, pub bot_id: i64,
@@ -13,25 +11,50 @@ pub struct GuildCache {
} }
impl GuildCache { impl GuildCache {
pub fn insert(&self) -> SirenResult<Self> { pub async fn insert(&self) -> SirenResult<()> {
let mut conn = connection()?; let pool = crate::database::pool();
let guild = diesel::insert_into(guilds::table) sqlx::query(&format!(
.values(self) "INSERT INTO {} (
.get_result(&mut conn)?; id,
Ok(guild) bot_id,
volume
) VALUES (
$1, $2, $3
)",
TABLE_NAME
))
.bind(self.id)
.bind(self.bot_id)
.bind(self.volume)
.execute(pool)
.await?;
Ok(())
} }
pub fn get(id: i64) -> SirenResult<Self> { pub async fn get_by_id(id: i64) -> SirenResult<Option<Self>> {
let mut conn = connection()?; let pool = crate::database::pool();
let guild = guilds::table.filter(guilds::id.eq(id)).first(&mut conn)?; let item =
Ok(guild) sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME))
.bind(id)
.fetch_optional(pool)
.await?;
Ok(item)
} }
pub fn update_audio(id: i64, volume: i32) -> SirenResult<Self> { pub async fn update(&self) -> SirenResult<()> {
let mut conn = connection()?; let pool = crate::database::pool();
let guild = diesel::update(guilds::table.filter(guilds::id.eq(id))) sqlx::query(&format!(
.set(guilds::volume.eq(volume)) "UPDATE {} SET
.get_result(&mut conn)?; bot_id = $2,
Ok(guild) volume = $3
WHERE id = $1",
TABLE_NAME))
.bind(self.id)
.bind(self.bot_id)
.bind(self.volume)
.execute(pool)
.await?;
Ok(())
} }
} }

View File

@@ -82,13 +82,13 @@ impl EventHandler for Handler {
for guild in ready.guilds { for guild in ready.guilds {
// Check if guild exists in database // Check if guild exists in database
let guild_id = guild.id.get() as i64; let guild_id = guild.id.get() as i64;
if let Err(why) = GuildCache::get(guild_id) { if let None = GuildCache::get_by_id(guild_id).await.unwrap() {
let guild_cache = GuildCache { let guild_cache = GuildCache {
id: guild_id, id: guild_id,
bot_id: 1, bot_id: 1,
volume: 100 volume: 100
}; };
guild_cache.insert(); guild_cache.insert().await.unwrap();
} }
let commands = guild let commands = guild
.id .id

View File

@@ -1,15 +1,10 @@
use diesel::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::SirenResult; use crate::error::SirenResult;
use crate::storage::{ const TABLE_NAME: &str = "messages";
schema::messages::{self},
connection,
};
#[derive(Queryable, Selectable, Insertable, AsChangeset, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
#[diesel(table_name = messages)] pub struct MessageCache {
pub struct QueryMessage {
pub id: String, pub id: String,
pub guild_id: i64, pub guild_id: i64,
pub channel_id: i64, pub channel_id: i64,
@@ -22,118 +17,38 @@ pub struct QueryMessage {
pub response_tags: Vec<String>, pub response_tags: Vec<String>,
} }
pub struct QueryFilters { impl MessageCache {
pub by_id: Option<String>, pub async fn insert(&self) -> SirenResult<()> {
pub by_guild_id: Option<i64>, let pool = crate::database::pool();
pub by_channel_id: Option<i64>, sqlx::query(&format!(
pub by_user_id: Option<i64>, "INSERT INTO {} (
pub by_model: Option<String>, id,
pub by_request: Option<String>, guild_id,
pub by_response: Option<String>, channel_id,
pub by_request_tags: Option<Vec<String>>, user_id,
pub by_response_tags: Option<Vec<String>>, created,
} model,
request,
impl Default for QueryFilters { response,
fn default() -> Self { request_tags,
QueryFilters { response_tags
by_id: None, ) VALUES (
by_guild_id: None, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10
by_channel_id: None, )",
by_user_id: None, TABLE_NAME
by_model: None, ))
by_request: None, .bind(&self.id)
by_response: None, .bind(self.guild_id)
by_request_tags: None, .bind(self.channel_id)
by_response_tags: None, .bind(self.user_id)
} .bind(self.created)
} .bind(&self.model)
} .bind(&self.request)
.bind(&self.response)
impl QueryMessage { .bind(&self.request_tags)
pub fn get_all(filters: &QueryFilters, limit: i32, page: i32) -> SirenResult<Vec<Self>> { .bind(&self.response_tags)
let mut conn = connection()?; .execute(pool)
let mut query = messages::table .await?;
.limit(limit as i64) Ok(())
.order(messages::created.asc())
.into_boxed();
// Limit query to page and limit
let offset = (page - 1) * limit;
query = query.offset(offset as i64);
// Apply filters
if let Some(id) = &filters.by_id {
query = query.filter(messages::id.eq(id));
}
if let Some(guild_id) = &filters.by_guild_id {
query = query.filter(messages::guild_id.eq(guild_id));
}
if let Some(channel_id) = &filters.by_channel_id {
query = query.filter(messages::channel_id.eq(channel_id));
}
if let Some(user_id) = &filters.by_user_id {
query = query.filter(messages::user_id.eq(user_id));
}
if let Some(model) = &filters.by_model {
query = query.filter(messages::model.eq(model));
}
if let Some(request) = &filters.by_request {
query = query.filter(messages::request.eq(request));
}
if let Some(response) = &filters.by_response {
query = query.filter(messages::response.eq(response));
}
if let Some(request_tags) = &filters.by_request_tags {
query = query.filter(messages::request_tags.eq(request_tags));
}
if let Some(response_tags) = &filters.by_response_tags {
query = query.filter(messages::response_tags.eq(response_tags));
}
// Execute query
let messages = query.load::<Self>(&mut conn)?;
Ok(messages)
}
pub fn get_count(fitlers: &QueryFilters) -> SirenResult<i64> {
let mut conn = connection()?;
let mut query = messages::table.into_boxed();
// Apply filters
if let Some(id) = &fitlers.by_id {
query = query.filter(messages::id.eq(id));
}
if let Some(guild_id) = &fitlers.by_guild_id {
query = query.filter(messages::guild_id.eq(guild_id));
}
if let Some(channel_id) = &fitlers.by_channel_id {
query = query.filter(messages::channel_id.eq(channel_id));
}
if let Some(user_id) = &fitlers.by_user_id {
query = query.filter(messages::user_id.eq(user_id));
}
if let Some(model) = &fitlers.by_model {
query = query.filter(messages::model.eq(model));
}
if let Some(request) = &fitlers.by_request {
query = query.filter(messages::request.eq(request));
}
if let Some(response) = &fitlers.by_response {
query = query.filter(messages::response.eq(response));
}
if let Some(request_tags) = &fitlers.by_request_tags {
query = query.filter(messages::request_tags.eq(request_tags));
}
if let Some(response_tags) = &fitlers.by_response_tags {
query = query.filter(messages::response_tags.eq(response_tags));
}
// Execute query
let count = query.count().get_result::<i64>(&mut conn)?;
Ok(count)
}
pub fn insert(message: Self) -> SirenResult<QueryMessage> {
let mut conn = connection()?;
let message = diesel::insert_into(messages::table)
.values(message)
.get_result(&mut conn)?;
Ok(message)
} }
} }

81
src/database/mod.rs Normal file
View File

@@ -0,0 +1,81 @@
use std::{sync::OnceLock, time::Duration};
use redis::{aio::MultiplexedConnection as RedisConnection, Client as RedisClient, RedisResult};
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
use crate::error::SirenResult;
static POOL: OnceLock<Pool<Postgres>> = OnceLock::new();
static REDIS: OnceLock<RedisClient> = OnceLock::new();
pub async fn initialize() -> SirenResult<()> {
log::info!("Initializing database...");
let db_user = std::env::var("DATABASE_USER").unwrap_or("siren".to_string());
let db_password = std::env::var("DATABASE_PASSWORD").expect("DATABASE_PASSWORD must be set");
let db_host: String = std::env::var("DATABASE_HOST").expect("DATABASE_HOST must be set");
let db_port = std::env::var("DATABASE_PORT").unwrap_or("5432".to_string());
let db_name = std::env::var("DATABASE_NAME").unwrap_or("siren".to_string());
// Setup Postgres pool connection
let pool = PgPoolOptions::new()
.max_connections(5)
.acquire_timeout(Duration::from_secs(30))
.connect(&format!(
"postgres://{}:{}@{}:{}/{}",
db_user, db_password, db_host, db_port, db_name
))
.await?;
match POOL.set(pool) {
Ok(_) => {}
Err(_) => {
log::warn!("Database pool already initialized");
}
}
// Setup Redis connection
let redis = {
let host = std::env::var("REDIS_HOST").unwrap_or("localhost".to_string());
let port = std::env::var("REDIS_PORT").unwrap_or("6379".to_string());
let url = format!("redis://{}:{}", host, port);
RedisClient::open(url).expect("Failed to create redis client")
};
match REDIS.set(redis) {
Ok(_) => {}
Err(_) => {
log::warn!("Redis client already initialized");
}
}
// Run migrations
match run_migrations().await {
Ok(_) => log::debug!("Successfully ran migrations"),
Err(e) => log::error!("Failed to run migrations: {}", e),
}
log::info!("Database initialized");
Ok(())
}
pub fn pool() -> &'static Pool<Postgres> {
POOL.get().unwrap()
}
fn redis() -> &'static RedisClient {
REDIS.get().unwrap()
}
pub fn redis_connection() -> RedisResult<redis::Connection> {
let conn = redis().get_connection()?;
Ok(conn)
}
pub async fn redis_async_connection() -> RedisResult<RedisConnection> {
let conn = redis().get_multiplexed_async_connection().await?;
Ok(conn)
}
async fn run_migrations() -> SirenResult<()> {
log::debug!("Running migrations");
let pool = pool();
sqlx::migrate!().run(pool).await?;
Ok(())
}

View File

@@ -7,7 +7,6 @@ use std::{
io::BufReader, io::BufReader,
}; };
use log::{warn, trace};
pub use model::*; pub use model::*;
pub use types::*; pub use types::*;
@@ -29,25 +28,25 @@ pub fn load_data(data_dir_path: &str) {
match result { match result {
Ok(spells) => { Ok(spells) => {
for spell in spells { for spell in spells {
let mut filters = QueryFilters::default(); // let mut filters = QueryFilters::default();
filters.by_name = Some(spell.name.clone()); // filters.by_name = Some(spell.name.clone());
match QuerySpell::get_all(&filters, 100, 1) { // match QuerySpell::get_all(&filters, 100, 1) {
Ok(spells) => { // Ok(spells) => {
if spells.len() > 0 { // if spells.len() > 0 {
trace!("Spell '{}' already exists", spell.name); // trace!("Spell '{}' already exists", spell.name);
continue; // continue;
// }
// }
// Err(err) => {
// warn!("Error checking if spell '{}' exists: {}", spell.name, err);
// continue;
// }
// };
// let spell = InsertSpell::insert(spell.into()).unwrap();
// trace!("Inserted spell: {}", spell.name);
} }
} }
Err(err) => { Err(err) => log::warn!("Error reading spells from file: {}", err),
warn!("Error checking if spell '{}' exists: {}", spell.name, err);
continue;
}
};
let spell = InsertSpell::insert(spell.into()).unwrap();
trace!("Inserted spell: {}", spell.name);
}
}
Err(err) => warn!("Error reading spells from file: {}", err),
}; };
} }
} }
@@ -55,7 +54,7 @@ pub fn load_data(data_dir_path: &str) {
} }
} }
} else { } else {
warn!( log::warn!(
"Data path '{}' does not exist, no data imported", "Data path '{}' does not exist, no data imported",
data_dir_path data_dir_path
); );

View File

@@ -1,9 +1,5 @@
use diesel::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::SirenResult;
use crate::storage::connection;
use crate::storage::schema::spells::{self};
use crate::dnd::{classes::AbilityType, conditions::ConditionType}; use crate::dnd::{classes::AbilityType, conditions::ConditionType};
use super::{ use super::{
@@ -11,8 +7,7 @@ use super::{
Source, Description, DurationType, Effect, Source, Description, DurationType, Effect,
}; };
#[derive(Debug, Queryable, QueryableByName, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[diesel(table_name = spells)]
pub struct QuerySpell { pub struct QuerySpell {
pub id: i32, pub id: i32,
pub name: String, pub name: String,
@@ -30,216 +25,6 @@ pub struct QuerySpell {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct QueryFilters {
pub by_name: Option<String>,
pub like_name: Option<String>,
pub by_schools: Option<Vec<String>>,
pub by_levels: Option<Vec<i32>>,
pub by_ritual: Option<bool>,
pub by_concentration: Option<bool>,
pub by_classes: Option<Vec<String>>,
pub by_damage_inflict: Option<Vec<String>>,
pub by_damage_resist: Option<Vec<String>>,
pub by_conditions: Option<Vec<String>>,
pub by_saving_throw: Option<Vec<String>>,
pub by_attack_type: Option<String>,
}
impl Default for QueryFilters {
fn default() -> Self {
Self {
by_name: None,
like_name: None,
by_schools: None,
by_levels: None,
by_ritual: None,
by_concentration: None,
by_classes: None,
by_damage_inflict: None,
by_damage_resist: None,
by_conditions: None,
by_saving_throw: None,
by_attack_type: None,
}
}
}
impl QuerySpell {
pub fn get_all(filters: &QueryFilters, limit: i32, page: i32) -> SirenResult<Vec<Self>> {
let mut conn = connection()?;
let mut query = spells::table.limit(limit as i64).into_boxed();
// Limit query to page and limit
let offset = (page - 1) * limit;
query = query.offset(offset as i64);
if let Some(name) = &filters.by_name {
query = query.filter(spells::name.eq(name));
}
if let Some(name) = &filters.like_name {
query = query.filter(spells::name.ilike(format!("%{}%", name)));
}
if let Some(schools) = &filters.by_schools {
query = query.filter(
spells::school.eq_any(
schools
.iter()
.map(|school| school.to_string())
.collect::<Vec<String>>(),
),
);
}
if let Some(levels) = &filters.by_levels {
query = query.filter(spells::level.eq_any(levels));
}
if let Some(ritual) = filters.by_ritual {
query = query.filter(spells::ritual.eq(ritual));
}
if let Some(concentration) = filters.by_concentration {
query = query.filter(spells::concentration.eq(concentration));
}
if let Some(classes) = &filters.by_classes {
query = query.filter(spells::classes.overlaps_with(classes));
}
if let Some(damage_inflict) = &filters.by_damage_inflict {
query = query.filter(
spells::damage_inflict.overlaps_with(
damage_inflict
.iter()
.map(|damage_inflict| damage_inflict.to_string())
.collect::<Vec<String>>(),
),
);
}
if let Some(damage_resist) = &filters.by_damage_resist {
query = query.filter(
spells::damage_resist.overlaps_with(
damage_resist
.iter()
.map(|damage_resist| damage_resist.to_string())
.collect::<Vec<String>>(),
),
);
}
if let Some(conditions) = &filters.by_conditions {
query = query.filter(
spells::conditions.overlaps_with(
conditions
.iter()
.map(|condition| condition.to_string())
.collect::<Vec<String>>(),
),
);
}
if let Some(saving_throw) = &filters.by_saving_throw {
query = query.filter(
spells::saving_throw.overlaps_with(
saving_throw
.iter()
.map(|saving_throw| saving_throw.to_string())
.collect::<Vec<String>>(),
),
);
}
if let Some(attack_type) = &filters.by_attack_type {
query = query.filter(spells::attack_type.eq(attack_type.to_string()));
}
let spells = query.load::<QuerySpell>(&mut conn)?;
Ok(spells)
}
pub fn get_count(filters: &QueryFilters) -> SirenResult<i64> {
let mut conn = connection()?;
let mut query = spells::table.count().into_boxed();
if let Some(name) = &filters.by_name {
query = query.filter(spells::name.ilike(format!("%{}%", name)));
}
if let Some(schools) = &filters.by_schools {
query = query.filter(
spells::school.eq_any(
schools
.iter()
.map(|school| school.to_string())
.collect::<Vec<String>>(),
),
);
}
if let Some(levels) = &filters.by_levels {
query = query.filter(spells::level.eq_any(levels));
}
if let Some(ritual) = filters.by_ritual {
query = query.filter(spells::ritual.eq(ritual));
}
if let Some(concentration) = filters.by_concentration {
query = query.filter(spells::concentration.eq(concentration));
}
if let Some(classes) = &filters.by_classes {
query = query.filter(spells::classes.overlaps_with(classes));
}
if let Some(damage_inflict) = &filters.by_damage_inflict {
query = query.filter(
spells::damage_inflict.overlaps_with(
damage_inflict
.iter()
.map(|damage_inflict| damage_inflict.to_string())
.collect::<Vec<String>>(),
),
);
}
if let Some(damage_resist) = &filters.by_damage_resist {
query = query.filter(
spells::damage_resist.overlaps_with(
damage_resist
.iter()
.map(|damage_resist| damage_resist.to_string())
.collect::<Vec<String>>(),
),
);
}
if let Some(conditions) = &filters.by_conditions {
query = query.filter(
spells::conditions.overlaps_with(
conditions
.iter()
.map(|condition| condition.to_string())
.collect::<Vec<String>>(),
),
);
}
if let Some(saving_throw) = &filters.by_saving_throw {
query = query.filter(
spells::saving_throw.overlaps_with(
saving_throw
.iter()
.map(|saving_throw| saving_throw.to_string())
.collect::<Vec<String>>(),
),
);
}
if let Some(attack_type) = &filters.by_attack_type {
query = query.filter(spells::attack_type.eq(attack_type.to_string()));
}
let count = query.get_result(&mut conn)?;
Ok(count)
}
pub fn get_by_id(id: i32) -> SirenResult<Self> {
let mut conn = connection()?;
let spell = spells::table
.filter(spells::id.eq(id))
.first::<QuerySpell>(&mut conn)?;
Ok(spell)
}
pub fn delete(id: i32) -> SirenResult<Self> {
let mut conn = connection()?;
let spell = diesel::delete(spells::table.filter(spells::id.eq(id))).get_result(&mut conn)?;
Ok(spell)
}
}
#[derive(Debug, Insertable, AsChangeset)]
#[diesel(table_name = spells)]
pub struct InsertSpell { pub struct InsertSpell {
pub name: String, pub name: String,
pub school: String, pub school: String,
@@ -255,24 +40,6 @@ pub struct InsertSpell {
pub data: serde_json::Value, pub data: serde_json::Value,
} }
impl InsertSpell {
pub fn insert(spell: Self) -> SirenResult<QuerySpell> {
let mut conn = connection()?;
let spell = diesel::insert_into(spells::table)
.values(spell)
.get_result(&mut conn)?;
Ok(spell)
}
pub fn update(id: i32, spell: Self) -> SirenResult<QuerySpell> {
let mut conn = connection()?;
let spell = diesel::update(spells::table.filter(spells::id.eq(id)))
.set(spell)
.get_result(&mut conn)?;
Ok(spell)
}
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Spell { pub struct Spell {
pub id: Option<i32>, pub id: Option<i32>,

View File

@@ -1,5 +1,4 @@
use std::fmt; use std::fmt;
use diesel::result::Error as DieselError;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub type SirenResult<T> = Result<T, Error>; pub type SirenResult<T> = Result<T, Error>;
@@ -37,19 +36,44 @@ impl From<std::string::FromUtf8Error> for Error {
} }
} }
impl From<DieselError> for Error { impl From<sqlx::Error> for Error {
fn from(error: DieselError) -> Self { fn from(error: sqlx::Error) -> Self {
match error { match error {
DieselError::DatabaseError(kind, err) => match kind { sqlx::Error::RowNotFound => Error::new(404, "Not found".to_string()),
diesel::result::DatabaseErrorKind::UniqueViolation => { sqlx::Error::ColumnIndexOutOfBounds { .. } => Error::new(422, error.to_string()),
Self::new(409, err.message().to_string()) sqlx::Error::ColumnNotFound { .. } => Error::new(422, error.to_string()),
sqlx::Error::ColumnDecode { .. } => Error::new(422, error.to_string()),
sqlx::Error::Decode(_) => Error::new(422, error.to_string()),
sqlx::Error::PoolTimedOut => Error::new(503, error.to_string()),
sqlx::Error::PoolClosed => Error::new(503, error.to_string()),
sqlx::Error::Tls(_) => Error::new(500, error.to_string()),
sqlx::Error::Io(_) => Error::new(500, error.to_string()),
sqlx::Error::Protocol(_) => Error::new(500, error.to_string()),
sqlx::Error::Configuration(_) => Error::new(500, error.to_string()),
sqlx::Error::AnyDriverError(_) => Error::new(500, error.to_string()),
sqlx::Error::Database(err) => {
if let Some(code) = err.code() {
match code.trim() {
// Unique violation
"23505" => return Error::new(409, err.to_string()),
_ => (),
} }
_ => Self::new(500, err.message().to_string()),
},
DieselError::NotFound => Self::new(404, "The record was not found".to_string()),
DieselError::SerializationError(err) => Self::new(422, err.to_string()),
err => Self::new(500, format!("Unknown database error: {}", err)),
} }
Error::new(500, err.to_string())
}
sqlx::Error::Migrate(_) => Error::new(500, error.to_string()),
sqlx::Error::TypeNotFound { type_name } => {
Error::new(500, format!("Type not found: {}", type_name))
}
sqlx::Error::WorkerCrashed => Error::new(500, error.to_string()),
_ => Error::new(500, error.to_string()),
}
}
}
impl From<sqlx::migrate::MigrateError> for Error {
fn from(error: sqlx::migrate::MigrateError) -> Self {
Error::new(500, error.to_string())
} }
} }

View File

@@ -1,7 +1,3 @@
extern crate diesel;
#[macro_use]
extern crate diesel_migrations;
use std::env; use std::env;
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
@@ -13,9 +9,9 @@ use reqwest::Client as HttpClient;
use crate::bot::handler::Handler; use crate::bot::handler::Handler;
mod bot; mod bot;
mod database;
mod dnd; mod dnd;
mod error; mod error;
mod storage;
pub struct HttpKey; pub struct HttpKey;
@@ -27,7 +23,10 @@ impl TypeMapKey for HttpKey {
async fn main() { async fn main() {
dotenv::dotenv().ok(); dotenv::dotenv().ok();
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info")); env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info"));
storage::init().await; if let Err(err) = database::initialize().await {
log::error!("Failed to initialize database: {err}");
return;
};
let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment");
let intents: GatewayIntents = GatewayIntents::all(); let intents: GatewayIntents = GatewayIntents::all();

View File

@@ -1,65 +0,0 @@
use diesel::{r2d2::ConnectionManager as DieselConnectionManager, PgConnection};
use redis::{aio::MultiplexedConnection, Client as RedisClient};
use crate::{diesel_migrations::MigrationHarness, error::{Error as SirenError, SirenResult}};
use lazy_static::lazy_static;
use log::{error, info};
use r2d2;
use std::env;
pub mod schema;
type DbPool = r2d2::Pool<DieselConnectionManager<PgConnection>>;
pub type DbConnection = r2d2::PooledConnection<DieselConnectionManager<PgConnection>>;
pub const MIGRATIONS: diesel_migrations::EmbeddedMigrations = embed_migrations!();
lazy_static! {
static ref POOL: DbPool = {
let username = env::var("DATABASE_USER").expect("DATABASE_USERNAME is not set");
let password = env::var("DATABASE_PASSWORD").expect("DATABASE_PASSWORD is not set");
let host = env::var("DATABASE_HOST").unwrap_or("localhost".to_string());
let name = env::var("DATABASE_NAME").expect("DATABASE_NAME is not set");
let port = env::var("DATABASE_PORT").unwrap_or("5432".to_string());
let url = format!(
"postgres://{}:{}@{}:{}/{}",
username, password, host, port, name
);
let manager = DieselConnectionManager::<PgConnection>::new(url);
DbPool::builder()
.test_on_check_out(true)
.build(manager)
.expect("Failed to create db pool")
};
static ref REDIS: RedisClient = {
let host = env::var("REDIS_HOST").unwrap_or("localhost".to_string());
let port = env::var("REDIS_PORT").unwrap_or("6379".to_string());
let url = format!("redis://{}:{}", host, port);
RedisClient::open(url).expect("Failed to create redis client")
};
}
pub async fn init() {
lazy_static::initialize(&POOL);
lazy_static::initialize(&REDIS);
let mut pool: DbConnection = connection().expect("Failed to get db connection");
match pool.run_pending_migrations(MIGRATIONS) {
Ok(_) => info!("Database initialized"),
Err(err) => error!("Failed to initialize database; {}", err),
};
}
pub fn connection() -> SirenResult<DbConnection> {
POOL
.get()
.map_err(|e| SirenError::new(500, format!("Failed getting db connection: {}", e)))
}
pub fn redis_connection() -> SirenResult<redis::Connection> {
let conn = REDIS.get_connection()?;
Ok(conn)
}
pub async fn redis_async_connection() -> SirenResult<MultiplexedConnection> {
let conn = REDIS.get_multiplexed_async_connection().await?;
Ok(conn)
}

View File

@@ -1,54 +0,0 @@
diesel::table! {
messages (id) {
id -> Text,
guild_id -> BigInt,
channel_id -> BigInt,
user_id -> BigInt,
created -> BigInt,
model -> Text,
request -> Text,
response -> Text,
request_tags -> Array<Text>,
response_tags -> Array<Text>,
}
}
diesel::table! {
spells (id) {
id -> Integer,
name -> Text,
school -> Text,
level -> Integer,
ritual -> Bool,
concentration -> Bool,
classes -> Array<Text>,
damage_inflict -> Array<Text>,
damage_resist -> Array<Text>,
conditions -> Array<Text>,
saving_throw -> Array<Text>,
attack_type -> Nullable<Text>,
data -> Jsonb
}
}
diesel::table! {
guilds (id) {
id -> BigInt,
bot_id -> BigInt,
volume -> Integer,
}
}
diesel::table! {
users (email) {
email -> Text,
hash -> Text,
role -> Text,
first_name -> Text,
last_name -> Text,
updated_at -> Timestamp,
created_at -> Timestamp,
profile_picture -> Nullable<Text>,
verified -> Bool,
}
}