diff --git a/.env.TEMPLATE b/.env.TEMPLATE index 3e1ad7b..039f8a6 100644 --- a/.env.TEMPLATE +++ b/.env.TEMPLATE @@ -1,6 +1,8 @@ -DISCORD_TOKEN= RUST_LOG=warn,siren=info -POSTGRES_USER=siren -POSTGRES_PASSWORD= -POSTGRES_DB=siren +DATABASE_USER=siren +DATABASE_PASSWORD= +DATABASE_NAME=siren +DATABASE_HOST=localhost +DATABASE_PORT=5432 +DISCORD_TOKEN= OPENAI_API_KEY= \ No newline at end of file diff --git a/.version b/.version index 4f1b91d..e9f3799 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -SIREN_VERSION=0.2.3 \ No newline at end of file +SIREN_VERSION=0.2.4 \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index dd6eed2..4317c94 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "siren" -version = "0.2.3" +version = "0.2.4" edition = "2021" authors = ["Ben Sherriff "] repository = "https://github.com/bensherriff/siren" @@ -9,9 +9,12 @@ license = "GPL-3.0-or-later" [dependencies] dotenv = "0.15.0" -serde_json = "1.0" -log = "0.4.19" +serde_json = "1.0.107" +log = "0.4.20" env_logger = "0.10.0" +diesel_migrations = { version = "2.1.0", features = ["postgres"] } +r2d2 = "0.8.10" +lazy_static = "1.4.0" [dependencies.serenity] version = "0.11.6" @@ -23,23 +26,23 @@ version = "0.3.2" features = ["builtin-queue", "yt-dlp"] [dependencies.tokio] -version = "1.29.1" +version = "1.32.0" features = ["macros", "rt-multi-thread"] [dependencies.serde] -version = "1.0" +version = "1.0.188" features = ["derive"] [dependencies.reqwest] -version = "0.11.18" +version = "0.11.22" default-features = false features = ["json", "rustls-tls"] [dependencies.diesel] -version = "2.1.0" +version = "2.1.2" default-features = false features = ["postgres", "32-column-tables", "serde_json", "r2d2", "with-deprecated"] [dependencies.pyo3] -version = "0.19.1" +version = "0.19.2" features = ["auto-initialize"] diff --git a/Dockerfile b/Dockerfile index 0af2576..405bf0d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,10 +1,12 @@ -FROM rust:1.70.0 as builder +# Builder +FROM rust:1.72.1-bookworm as builder WORKDIR /siren ADD src ./src/ ADD Cargo.toml ./ RUN apt-get update && apt-get install -y cmake && \ cargo build --release --bin siren +# Packages FROM debian:bullseye-slim as packages WORKDIR /packages RUN apt-get update && apt-get install -y curl tar xz-utils && \ @@ -19,6 +21,7 @@ RUN apt-get update && apt-get install -y curl tar xz-utils && \ # curl -L https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcu117.zip > libtorch.zip && \ # unzip libtorch.zip && rm libtorch.zip +# Runner FROM debian:bullseye-slim as runtime WORKDIR /siren RUN apt-get update && apt-get install -y libopus-dev libpq5 libpq-dev && apt-get auto-remove -y diff --git a/docker-compose.yml b/docker-compose.yml index 40d94b0..574ea46 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,38 +3,36 @@ version: '3' services: siren: image: siren:${SIREN_VERSION:-latest} - container_name: siren + container_name: siren-service build: context: . dockerfile: ./Dockerfile args: - VERSION=${SIREN_VERSION:-latest} - volumes: - - ./app:/siren env_file: - .env environment: - DISCORD_TOKEN: ${DISCORD_TOKEN} - RUST_LOG: ${RUST_LOG} - OPENAI_API_KEY: ${OPENAI_API_KEY} - POSTGRES_USER: ${POSTGRES_USER} - POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} - POSTGRES_DB: ${POSTGRES_DB} - POSTGRES_HOST: db + DATABASE_HOST: db + DATABASE_PORT: 5432 depends_on: - db restart: unless-stopped db: image: postgres:latest - container_name: siren_db + container_name: siren-db env_file: - .env environment: - POSTGRES_USER: ${POSTGRES_USER} - POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} - POSTGRES_DB: ${POSTGRES_DB} + POSTGRES_USER: ${DATABASE_USER} + POSTGRES_PASSWORD: ${DATABASE_PASSWORD} + POSTGRES_DB: ${DATABASE_NAME} volumes: - - ./data:/var/lib/postgresql/data + - db:/var/lib/postgresql/data + - db_logs:/var/log ports: - - "5432:5432" + - ${DATABASE_PORT:-5432}:5432 restart: unless-stopped + +volumes: + db: + db_logs: diff --git a/migrations/000000_create_messages/down.sql b/migrations/000000_create_messages/down.sql new file mode 100644 index 0000000..be13677 --- /dev/null +++ b/migrations/000000_create_messages/down.sql @@ -0,0 +1 @@ +DROP TABLE messages; \ No newline at end of file diff --git a/migrations/create_messages/up.sql b/migrations/000000_create_messages/up.sql similarity index 99% rename from migrations/create_messages/up.sql rename to migrations/000000_create_messages/up.sql index f5965cf..f876ba8 100644 --- a/migrations/create_messages/up.sql +++ b/migrations/000000_create_messages/up.sql @@ -9,4 +9,4 @@ CREATE TABLE IF NOT EXISTS messages ( response TEXT NOT NULL, request_tags TEXT[] NOT NULL, response_tags TEXT[] NOT NULL -) \ No newline at end of file +); \ No newline at end of file diff --git a/migrations/create_messages/down.sql b/migrations/create_messages/down.sql deleted file mode 100644 index 90b6925..0000000 --- a/migrations/create_messages/down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE messages \ No newline at end of file diff --git a/src/commands/oai.rs b/src/commands/oai.rs index 1576be2..7bb702e 100644 --- a/src/commands/oai.rs +++ b/src/commands/oai.rs @@ -2,8 +2,7 @@ use std::error::Error; use std::fmt; -use diesel::{prelude::*, PgConnection, insert_into}; -use diesel::r2d2::{Pool, ConnectionManager}; +use diesel::{prelude::*, insert_into}; use log::{error, debug, trace, warn}; use serde::{Serialize, Deserialize}; @@ -13,7 +12,8 @@ use serenity::model::channel::Message; use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType}; use serenity::prelude::*; -use crate::database::models::{NewMessageDB, MessageDB}; +use crate::db; +use crate::messages::{NewMessageDB, MessageDB}; pub struct OAI { pub client: reqwest::Client, @@ -187,26 +187,26 @@ impl OAI { } } -pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &Pool>) { +pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { debug!("Generating response for message: {}", msg.content); + let mut connection = db::connection().unwrap(); let guild_id = msg.guild_id.unwrap(); let channel_id = msg.channel_id; let author_id = msg.author.id; - let mut connection = pool.get().unwrap(); // Parse out the bot mention from the message let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0); let parsed_content = msg.content.replace(bot_mention.as_str(), ""); // Setup the request messages - let result: Result, diesel::result::Error> = crate::database::schema::messages::table + let result: Result, diesel::result::Error> = crate::schema::messages::table .select(MessageDB::as_select()) - .filter((crate::database::schema::messages::guild_id.eq(guild_id.0 as i64)) - .and(crate::database::schema::messages::channel_id.eq(channel_id.0 as i64)) - .and(crate::database::schema::messages::user_id.eq(author_id.0 as i64)) + .filter((crate::schema::messages::guild_id.eq(guild_id.0 as i64)) + .and(crate::schema::messages::channel_id.eq(channel_id.0 as i64)) + .and(crate::schema::messages::user_id.eq(author_id.0 as i64)) ) - .order(crate::database::schema::messages::created.asc()) + .order(crate::schema::messages::created.asc()) .limit(oai.max_context_questions) .load(&mut connection); @@ -284,7 +284,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P if !r.choices.is_empty() { let res = r.choices[0].message.content.clone(); // Insert the message into the messages database table - if let Err(err) = insert_into(crate::database::schema::messages::table).values(NewMessageDB { + if let Err(err) = insert_into(crate::schema::messages::table).values(NewMessageDB { id: &r.id, guild_id: guild_id.0 as i64, channel_id: response_channel.0 as i64, diff --git a/src/database/mod.rs b/src/database/mod.rs deleted file mode 100644 index 1673f85..0000000 --- a/src/database/mod.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::env; -// use std::path::Path; - -use diesel::RunQueryDsl; -use diesel::r2d2::{Pool, ConnectionManager}; -use diesel::pg::PgConnection; -use log::{error, info}; - -pub mod models; -pub mod schema; - -pub fn run_migrations(pool: &Pool>) { - let mut connection = pool.get().unwrap(); - if let Err(err) = diesel::sql_query("CREATE TABLE IF NOT EXISTS messages ( - id TEXT PRIMARY KEY NOT NULL, - guild_id BIGINT NOT NULL, - channel_id BIGINT NOT NULL, - user_id BIGINT NOT NULL, - created BIGINT NOT NULL, - model TEXT NOT NULL, - request TEXT NOT NULL, - response TEXT NOT NULL, - request_tags TEXT[] NOT NULL, - response_tags TEXT[] NOT NULL - )").execute(&mut connection) { - error!("Could not create messages table: {}", err); - } else { - info!("Successfully created messages table"); - } - // let migrations_dir = Path::new("./migrations"); - // let migrations = std::fs::read_dir(&migrations_dir).unwrap(); - - // for migration in migrations { - // if migration.as_ref().unwrap().file_type().unwrap().is_dir() { - // let migration_paths = std::fs::read_dir(&migration.unwrap().path()).unwrap(); - - // for migration_path in migration_paths { - // if migration_path.as_ref().unwrap().file_name().eq_ignore_ascii_case("up.sql") { - // let path = &migration_path.unwrap().path(); - // let contents = std::fs::read_to_string(path).expect("Unable to read from file"); - // if let Err(err) = diesel::sql_query(&contents).execute(&mut connection) { - // error!("Could not run migration: {}", err); - // } else { - // info!("Successfully ran migration: {}", path.display()); - // } - // } - // } - // } - // } -} - -pub fn establish_connection() -> Pool> { - let database_user = env::var("POSTGRES_USER").expect("Expected a user in the environment"); - let database_password = env::var("POSTGRES_PASSWORD").expect("Expected a password in the environment"); - let database_name = env::var("POSTGRES_DB").expect("Expected a database name in the environment"); - let database_host = env::var("POSTGRES_HOST").unwrap_or("localhost".to_string()); - - let database_url = format!("postgres://{}:{}@{}/{}", database_user, database_password, database_host, database_name); - let manager = ConnectionManager::::new(database_url); - Pool::builder().build(manager).expect("Failed to create pool.") -} \ No newline at end of file diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..bb22870 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,39 @@ +use crate::error_handler::ServiceError; +use diesel::{r2d2::ConnectionManager, PgConnection}; +use crate::diesel_migrations::MigrationHarness; +use lazy_static::lazy_static; +use log::{error, info}; +use r2d2; +use std::env; + +type Pool = r2d2::Pool>; +pub type DbConnection = r2d2::PooledConnection>; + +pub const MIGRATIONS: diesel_migrations::EmbeddedMigrations = embed_migrations!(); + +lazy_static! { + static ref POOL: Pool = { + let username = env::var("DATABASE_USER").expect("DATABASE_USERNAME is not set"); + let password = env::var("DATABASE_PASSWORD").expect("DATABASE_PASSWORD is not set"); + let host = env::var("DATABASE_HOST").unwrap_or("localhost".to_string()); + let name = env::var("DATABASE_NAME").expect("DATABASE_NAME is not set"); + let port = env::var("DATABASE_PORT").unwrap_or("5432".to_string()); + let url = format!("postgres://{}:{}@{}:{}/{}", username, password, host, port, name); + let manager = ConnectionManager::::new(url); + Pool::builder().test_on_check_out(true).build(manager).expect("Failed to create db pool") + }; +} + +pub fn init() { + lazy_static::initialize(&POOL); + let mut pool: DbConnection = connection().expect("Failed to get db connection"); + match pool.run_pending_migrations(MIGRATIONS) { + Ok(_) => info!("Database initialized"), + Err(err) => error!("Failed to initialize database; {}", err) + }; +} + +pub fn connection() -> Result { + POOL.get() + .map_err(|e| ServiceError::new(500, format!("Failed getting db connection: {}", e))) +} diff --git a/src/error_handler.rs b/src/error_handler.rs new file mode 100644 index 0000000..18f4f7c --- /dev/null +++ b/src/error_handler.rs @@ -0,0 +1,36 @@ +use diesel::result::Error as DieselError; +use serde::{Deserialize, Serialize}; +use std::fmt; + +#[derive(Debug, Deserialize, Serialize)] +pub struct ServiceError { + pub error_status_code: u16, + pub error_message: String, +} + +impl ServiceError { + pub fn new(error_status_code: u16, error_message: String) -> ServiceError { + ServiceError { + error_status_code, + error_message, + } + } +} + +impl fmt::Display for ServiceError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.error_message.as_str()) + } +} + +impl From for ServiceError { + fn from(error: DieselError) -> ServiceError { + match error { + DieselError::DatabaseError(_, err) => ServiceError::new(409, err.message().to_string()), + DieselError::NotFound => { + ServiceError::new(404, "The record was not found".to_string()) + } + err => ServiceError::new(500, format!("Unknown Diesel error: {}", err)), + } + } +} diff --git a/src/main.rs b/src/main.rs index 5ec5ef0..2e1a248 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,12 @@ +extern crate diesel; +#[macro_use] +extern crate diesel_migrations; + use std::collections::{HashSet, HashMap}; use std::env; use std::sync::Arc; use commands::audio::{create_response, AudioConfig, AudioConfigs}; -use diesel::r2d2::{Pool, ConnectionManager}; -use diesel::pg::PgConnection; use dotenv::dotenv; use log::{error, warn, info}; @@ -18,11 +20,13 @@ use serenity::prelude::*; use songbird::SerenityInit; mod commands; -mod database; +mod error_handler; +mod db; +mod messages; +mod schema; struct Handler { // Open AI Config - oai: Option, - pool: Pool> + oai: Option } #[async_trait] @@ -46,7 +50,7 @@ impl EventHandler for Handler { Err(_) => false }; if mentioned || bot_in_thread { - commands::oai::generate_response(&ctx, &msg, oai, &self.pool).await; + commands::oai::generate_response(&ctx, &msg, oai).await; } } Err(why) => warn!("Could not check mentions: {:?}", why) @@ -134,20 +138,18 @@ async fn main() { Err(why) => panic!("Could not access application info: {:?}", why) }; - let pool = database::establish_connection(); - database::run_migrations(&pool); + db::init(); let handler = match env::var("OPENAI_API_KEY") { Ok(token) => { info!("Loaded OpenAI token"); Handler { - oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token , max_context_questions: 15 }), - pool + oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token , max_context_questions: 15 }) } } Err(err) => { warn!("Could not load OpenAI token: {}", err); - Handler { oai: None, pool } + Handler { oai: None } } }; diff --git a/src/messages/mod.rs b/src/messages/mod.rs new file mode 100644 index 0000000..4a7ebf6 --- /dev/null +++ b/src/messages/mod.rs @@ -0,0 +1,3 @@ +mod model; + +pub use model::*; diff --git a/src/database/models.rs b/src/messages/model.rs similarity index 95% rename from src/database/models.rs rename to src/messages/model.rs index 93d9bab..f1120ef 100644 --- a/src/database/models.rs +++ b/src/messages/model.rs @@ -1,6 +1,6 @@ use diesel::prelude::*; -use super::schema::messages; +use crate::schema::messages; #[derive(Queryable, Selectable)] #[diesel(table_name = messages)] diff --git a/src/database/schema.rs b/src/schema.rs similarity index 100% rename from src/database/schema.rs rename to src/schema.rs