From 16d8fa5af870dac90cae19f55d67ef5670ec9a05 Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Tue, 3 Oct 2023 14:21:53 -0400 Subject: [PATCH] Spells endpoints --- .env.TEMPLATE | 5 + docker-compose.yml | 2 + migrations/000007_create_spells/up.sql | 6 +- src/commands/oai.rs | 2 +- src/db/mod.rs | 43 ++++++--- src/db/spells/mod.rs | 40 ++++++++ src/db/spells/model.rs | 101 ++++++++++++++------ src/db/spells/routes.rs | 79 +++++++++++++++ src/error_handler.rs | 23 ++++- src/main.rs | 127 +++++++++++++++---------- 10 files changed, 330 insertions(+), 98 deletions(-) create mode 100644 src/db/spells/routes.rs diff --git a/.env.TEMPLATE b/.env.TEMPLATE index 039f8a6..c933137 100644 --- a/.env.TEMPLATE +++ b/.env.TEMPLATE @@ -1,8 +1,13 @@ RUST_LOG=warn,siren=info + DATABASE_USER=siren DATABASE_PASSWORD= DATABASE_NAME=siren DATABASE_HOST=localhost DATABASE_PORT=5432 + +SERVICE_HOST=localhost +SERVICE_PORT=5000 + DISCORD_TOKEN= OPENAI_API_KEY= \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 574ea46..efc5f2e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,6 +14,8 @@ services: environment: DATABASE_HOST: db DATABASE_PORT: 5432 + SERVICE_HOST: siren + SERVICE_PORT: 5000 depends_on: - db restart: unless-stopped diff --git a/migrations/000007_create_spells/up.sql b/migrations/000007_create_spells/up.sql index 8312c69..83beab8 100644 --- a/migrations/000007_create_spells/up.sql +++ b/migrations/000007_create_spells/up.sql @@ -2,17 +2,17 @@ CREATE TABLE IF NOT EXISTS spells ( id INTEGER GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL, school TEXT NOT NULL, - level TEXT NOT NULL, + level INTEGER NOT NULL, ritual BOOLEAN DEFAULT FALSE, casting_time TEXT NOT NULL, range TEXT NOT NULL, components_verbal BOOLEAN DEFAULT FALSE, components_somatic BOOLEAN DEFAULT FALSE, components_material BOOLEAN DEFAULT FALSE, - components_materials_needed TEXT + components_materials_needed TEXT, duration TEXT NOT NULL, classes TEXT[] NOT NULL, sources TEXT[] NOT NULL, - tags TEXT[] + tags TEXT[], description TEXT NOT NULL ); \ No newline at end of file diff --git a/src/commands/oai.rs b/src/commands/oai.rs index 13a2e01..56706aa 100644 --- a/src/commands/oai.rs +++ b/src/commands/oai.rs @@ -12,7 +12,7 @@ use serenity::model::channel::Message; use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType}; use serenity::prelude::*; -use crate::db::{connection, NewMessageDB, MessageDB}; +use crate::db::{connection, messages::{MessageDB, NewMessageDB}}; pub struct OAI { pub client: reqwest::Client, diff --git a/src/db/mod.rs b/src/db/mod.rs index 991ff2b..261221d 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,26 +1,25 @@ use crate::error_handler::ServiceError; use diesel::{r2d2::ConnectionManager, PgConnection}; +use serde::{Deserialize, Serialize}; use crate::diesel_migrations::MigrationHarness; use lazy_static::lazy_static; use log::{error, info}; use r2d2; use std::env; -mod backgrounds; -mod bestiary; -mod classes; -mod conditions; -mod feats; -mod items; -mod messages; -mod options; -mod races; -mod spells; -mod users; +pub mod backgrounds; +pub mod bestiary; +pub mod classes; +pub mod conditions; +pub mod feats; +pub mod items; +pub mod messages; +pub mod options; +pub mod races; +pub mod spells; +pub mod users; pub mod schema; -pub use messages::*; - type Pool = r2d2::Pool>; pub type DbConnection = r2d2::PooledConnection>; @@ -52,3 +51,21 @@ pub fn connection() -> Result { POOL.get() .map_err(|e| ServiceError::new(500, format!("Failed getting db connection: {}", e))) } + +pub fn load_data() { + spells::load_data(); +} + +#[derive(Serialize, Deserialize)] +pub struct GetResponse { + pub data: T, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option +} + +#[derive(Serialize, Deserialize)] +pub struct Metadata { + pub total: i32, + pub limit: i32, + pub page: i32 +} diff --git a/src/db/spells/mod.rs b/src/db/spells/mod.rs index 4a7ebf6..46a85d4 100644 --- a/src/db/spells/mod.rs +++ b/src/db/spells/mod.rs @@ -1,3 +1,43 @@ mod model; +mod routes; pub use model::*; +pub use routes::init_routes; + +pub fn load_data() { + let root_path = std::env::current_dir().unwrap(); + let mut data_path = std::path::PathBuf::from(root_path); + data_path.push("data/spells.json"); + match data_path.to_str() { + Some(path) => { + log::debug!("Loading spells from {}", path); + match std::fs::read_to_string(data_path) { + Ok(data) => { + match serde_json::from_str::(&data) { + Ok(json) => { + match serde_json::from_value::>(json) { + Ok(spells) => { + let count = QuerySpell::get_count().unwrap(); + if count >= spells.len() as i64 { + log::warn!("Spell data is already loaded"); + return; + } + for spell in spells { + match InsertSpell::insert(spell.into()) { + Ok(_) => {}, + Err(err) => log::error!("Failed to insert spell: {}", err) + } + } + }, + Err(err) => log::error!("Failed to parse spells data: {}", err) + } + }, + Err(err) => log::error!("Failed to parse spells data to value: {}", err) + }; + }, + Err(err) => log::error!("Failed to read spells data: {}", err) + }; + }, + None => log::error!("Failed to find spells data directory") + } +} diff --git a/src/db/spells/model.rs b/src/db/spells/model.rs index b5081be..e849104 100644 --- a/src/db/spells/model.rs +++ b/src/db/spells/model.rs @@ -1,7 +1,7 @@ use diesel::prelude::*; use serde::{Deserialize, Serialize}; -use crate::db::schema::spells; +use crate::{db::schema::spells, error_handler::ServiceError}; #[derive(Queryable, QueryableByName)] #[diesel(table_name = spells)] @@ -24,6 +24,38 @@ pub struct QuerySpell { pub description: String } +impl QuerySpell { + pub fn get_all(limit: i32, page: i32) -> Result, ServiceError> { + let mut conn = crate::db::connection()?; + let mut query = spells::table + .limit(limit as i64) + .into_boxed(); + query = query.filter(spells::id.gt(std::cmp::max(0, page - 1) * limit)); + let spells = query.load::(&mut conn)?; + Ok(spells) + } + + pub fn get_by_id(id: i32) -> Result { + let mut conn = crate::db::connection()?; + let spell = spells::table + .filter(spells::id.eq(id)) + .first::(&mut conn)?; + Ok(spell) + } + + pub fn get_count() -> Result { + let mut conn = crate::db::connection()?; + let count = spells::table.count().get_result(&mut conn)?; + Ok(count) + } + + pub fn delete(id: i32) -> Result { + let mut conn = crate::db::connection()?; + let spell = diesel::delete(spells::table.filter(spells::id.eq(id))).get_result(&mut conn)?; + Ok(spell) + } +} + #[derive(Insertable, AsChangeset)] #[diesel(table_name = spells)] pub struct InsertSpell { @@ -44,6 +76,20 @@ pub struct InsertSpell { pub description: String } +impl InsertSpell { + pub fn insert(spell: Self) -> Result { + let mut conn = crate::db::connection()?; + let spell = diesel::insert_into(spells::table).values(spell).get_result(&mut conn)?; + Ok(spell) + } + + pub fn update(id: i32, spell: Self) -> Result { + let mut conn = crate::db::connection()?; + let spell = diesel::update(spells::table.filter(spells::id.eq(id))).set(spell).get_result(&mut conn)?; + Ok(spell) + } +} + #[derive(Serialize, Deserialize)] pub struct Spell { pub name: String, @@ -56,6 +102,7 @@ pub struct Spell { pub duration: String, pub classes: Vec, pub sources: Vec, + pub tags: Vec, pub description: String } @@ -67,30 +114,8 @@ pub struct Components { pub materials_needed: Option } -impl Spell { - /// Convert spell to insertable struct - pub fn to_insert(self) -> InsertSpell { - return InsertSpell { - name: self.name, - school: self.school, - level: self.level, - ritual: self.ritual, - casting_time: self.casting_time, - range: self.range, - components_verbal: self.components.verbal, - components_somatic: self.components.somatic, - components_material: self.components.material, - components_materials_needed: self.components.materials_needed, - duration: self.duration, - classes: self.classes, - sources: self.sources, - tags: vec![], - description: self.description - } - } - - /// Convert query struct to spell - pub fn from_query(query: QuerySpell) -> Self { +impl From for Spell { + fn from(query: QuerySpell) -> Self { return Self { name: query.name, school: query.school, @@ -107,14 +132,30 @@ impl Spell { duration: query.duration, classes: query.classes, sources: query.sources, + tags: query.tags, description: query.description } } +} - /// Convert file to spell - pub fn from_file(file: String) -> Self { - let data = std::fs::read_to_string(file).unwrap(); - let spell: Spell = serde_json::from_str(&data).unwrap(); - return spell; +impl Into for Spell { + fn into(self) -> InsertSpell { + return InsertSpell { + name: self.name, + school: self.school, + level: self.level, + ritual: self.ritual, + casting_time: self.casting_time, + range: self.range, + components_verbal: self.components.verbal, + components_somatic: self.components.somatic, + components_material: self.components.material, + components_materials_needed: self.components.materials_needed, + duration: self.duration, + classes: self.classes, + sources: self.sources, + tags: self.tags, + description: self.description + } } } diff --git a/src/db/spells/routes.rs b/src/db/spells/routes.rs new file mode 100644 index 0000000..ea48000 --- /dev/null +++ b/src/db/spells/routes.rs @@ -0,0 +1,79 @@ +use actix_web::{get, post, put, delete, web, HttpResponse, HttpRequest, ResponseError}; +use serde::{Serialize, Deserialize}; + +use crate::db::{spells::QuerySpell, GetResponse, Metadata}; + +use super::{Spell, InsertSpell}; + +#[derive(Serialize, Deserialize)] +struct GetAllParams { + limit: Option, + page: Option, +} + +#[get("/spells")] +async fn get_all(req: HttpRequest) -> HttpResponse { + let params = web::Query::::from_query(req.query_string()).unwrap(); + let limit = params.limit.unwrap_or(20); + let page = params.page.unwrap_or(1); + match web::block(move || QuerySpell::get_all(limit, page)).await.unwrap() { + Ok(spells) => { + let mut response: Vec = Vec::new(); + for spell in spells { + response.push(Spell::from(spell)); + } + let total_count = QuerySpell::get_count().unwrap(); + HttpResponse::Ok().json(GetResponse { + data: response, + metadata: Some(Metadata { + total: total_count as i32, + limit, + page + }) + }) + }, + Err(err) => ResponseError::error_response(&err) + } +} + +#[get("/spells/{id}")] +async fn get_by_id(id: web::Path) -> HttpResponse { + match web::block(move || QuerySpell::get_by_id(id.into_inner())).await.unwrap() { + Ok(spell) => HttpResponse::Ok().json(GetResponse { + data: Spell::from(spell), + metadata: None + }), + Err(err) => ResponseError::error_response(&err) + } +} + +#[post("/spells")] +async fn create(spell: web::Json) -> HttpResponse { + match InsertSpell::insert(spell.into_inner().into()) { + Ok(spell) => HttpResponse::Created().json(Spell::from(spell)), + Err(err) => ResponseError::error_response(&err) + } +} + +#[put("/spells/{id}")] +async fn update(id: web::Path, spell: web::Json) -> HttpResponse { + match web::block(move || InsertSpell::update(id.into_inner(), spell.into_inner().into())).await.unwrap() { + Ok(spell) => HttpResponse::Ok().json(Spell::from(spell)), + Err(err) => ResponseError::error_response(&err) + } +} + +#[delete("/spells/{id}")] +async fn delete(id: web::Path) -> HttpResponse { + match web::block(move || QuerySpell::delete(id.into_inner())).await.unwrap() { + Ok(spell) => HttpResponse::Ok().json(Spell::from(spell)), + Err(err) => ResponseError::error_response(&err) + } +} + +pub fn init_routes(config: &mut web::ServiceConfig) { + config.service(get_all); + config.service(get_by_id); + config.service(create); + config.service(delete); +} \ No newline at end of file diff --git a/src/error_handler.rs b/src/error_handler.rs index 18f4f7c..4ae1f90 100644 --- a/src/error_handler.rs +++ b/src/error_handler.rs @@ -1,4 +1,6 @@ +use actix_web::{ResponseError, HttpResponse}; use diesel::result::Error as DieselError; +use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use std::fmt; @@ -29,8 +31,27 @@ impl From for ServiceError { DieselError::DatabaseError(_, err) => ServiceError::new(409, err.message().to_string()), DieselError::NotFound => { ServiceError::new(404, "The record was not found".to_string()) - } + }, + DieselError::SerializationError(err) => { + ServiceError::new(422, err.to_string()) + }, err => ServiceError::new(500, format!("Unknown Diesel error: {}", err)), } } } + +impl ResponseError for ServiceError { + fn error_response(&self) -> HttpResponse { + let status_code = match StatusCode::from_u16(self.error_status_code) { + Ok(status_code) => status_code, + Err(_) => StatusCode::INTERNAL_SERVER_ERROR, + }; + + let error_message = match status_code.as_u16() < 500 { + true => self.error_message.clone(), + false => "Internal server error".to_string(), + }; + + HttpResponse::build(status_code).json(serde_json::json!({ "message": error_message })) + } +} diff --git a/src/main.rs b/src/main.rs index 03a755c..f3506aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ use std::collections::{HashSet, HashMap}; use std::env; use std::sync::Arc; +use actix_web::{HttpServer, App}; use commands::audio::{create_response, AudioConfig, AudioConfigs}; use dotenv::dotenv; @@ -111,60 +112,86 @@ impl EventHandler for Handler { } } -#[tokio::main] -async fn main() { +#[actix_web::main] +async fn main() -> std::io::Result<()> { dotenv().ok(); env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info")); - - let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); - let intents: GatewayIntents = GatewayIntents::all(); - - let http: Http = Http::new(&token); - let (owners, _bot_id) = match http.get_current_application_info().await { - Ok(info) => { - let mut owners: HashSet = HashSet::new(); - if let Some(team) = info.team { - owners.insert(team.owner_user_id); - } else { - owners.insert(info.owner.id); - } - match http.get_current_user().await { - Ok(bot) => (owners, bot.id), - Err(why) => panic!("Could not access the bot id: {:?}", why) - } - }, - Err(why) => panic!("Could not access application info: {:?}", why) - }; - db::init(); - - let handler = match env::var("OPENAI_API_KEY") { - Ok(token) => { - info!("Loaded OpenAI token"); - Handler { - oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token , max_context_questions: 15 }) - } - } + db::load_data(); + + // setup_discord_bot(); + + let host = env::var("SERVICE_HOST").unwrap_or("localhost".to_string()); + let port = env::var("SERVICE_PORT").unwrap_or("5000".to_string()); + + match HttpServer::new(|| { + App::new() + .configure(db::spells::init_routes) + }) + .bind(format!("{}:{}", host, port)) { + Ok(b) => { + info!("Binding server to {}:{}", host, port); + b + }, Err(err) => { - warn!("Could not load OpenAI token: {}", err); - Handler { oai: None } + error!("Could not bind server: {}", err); + return Err(err); } - }; - - let mut client = Client::builder(token, intents) - .event_handler(handler) - .framework(StandardFramework::new() - .configure(|c| c.owners(owners))) - .register_songbird() - .await - .expect("Error creating client"); - - { - let mut data = client.data.write().await; - data.insert::(Arc::new(RwLock::new(HashMap::default()))); - } - - if let Err(why) = client.start_autosharded().await { - error!("An error occurred while running the client: {:?}", why); } + .run() + .await +} + +fn setup_discord_bot() { + tokio::spawn(async { + let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); + let intents: GatewayIntents = GatewayIntents::all(); + + let http: Http = Http::new(&token); + let (owners, _bot_id) = match http.get_current_application_info().await { + Ok(info) => { + let mut owners: HashSet = HashSet::new(); + if let Some(team) = info.team { + owners.insert(team.owner_user_id); + } else { + owners.insert(info.owner.id); + } + match http.get_current_user().await { + Ok(bot) => (owners, bot.id), + Err(why) => panic!("Could not access the bot id: {:?}", why) + } + }, + Err(why) => panic!("Could not access application info: {:?}", why) + }; + + let handler = match env::var("OPENAI_API_KEY") { + Ok(token) => { + info!("Loaded OpenAI token"); + Handler { + oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token , max_context_questions: 15 }) + } + } + Err(err) => { + warn!("Could not load OpenAI token: {}", err); + Handler { oai: None } + } + }; + + let mut client = Client::builder(token, intents) + .event_handler(handler) + .framework(StandardFramework::new() + .configure(|c| c.owners(owners))) + .register_songbird() + .await + .expect("Error creating client"); + + { + let mut data = client.data.write().await; + data.insert::(Arc::new(RwLock::new(HashMap::default()))); + } + + if let Err(why) = client.start_autosharded().await { + error!("An error occurred while running the client: {:?}", why); + } + }); }