diff --git a/.gitignore b/.gitignore index c6f6245..ac78a29 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,3 @@ audio/ logs/ settings.json app/ -data/ diff --git a/data/spells/cantrips.json b/data/spells/cantrips.json new file mode 100644 index 0000000..eee31ef --- /dev/null +++ b/data/spells/cantrips.json @@ -0,0 +1,37 @@ +[ + { + "name": "Acid Splash", + "school": "conjuration", + "level": 0, + "ritual": false, + "casting_time": { + "amount": 1, + "type": "action" + }, + "range": { + "type": "point", + "amount": 60, + "unit": "feet" + }, + "components": { + "verbal": true, + "somatic": true, + "material": false + }, + "duration": { + "type": "instantaneous" + }, + "classes": ["artificer", "sorcerer", "wizard"], + "sources": [ + { + "source": "PHB", + "page": 211 + } + ], + "description": { + "entries": [ + "You hurl a bubble of acid. Choose one creature within range, or choose two creatures within range that are within 5 feet of each other. A target must succeed on a Dexterity saving throw or take {@damage 1d6} acid damage.", + "This spell's damage increases by {@damage 1d6} when you reach 5th level ({@damage 2d6}), 11th level ({@damage 3d6}), and 17th level ({@damage 4d6})."] + } + } +] \ No newline at end of file diff --git a/data/spells/level_1.json b/data/spells/level_1.json new file mode 100644 index 0000000..e69de29 diff --git a/migrations/000007_create_spells/up.sql b/migrations/000007_create_spells/up.sql index 83beab8..ee9c5d7 100644 --- a/migrations/000007_create_spells/up.sql +++ b/migrations/000007_create_spells/up.sql @@ -4,15 +4,29 @@ CREATE TABLE IF NOT EXISTS spells ( school TEXT NOT NULL, level INTEGER NOT NULL, ritual BOOLEAN DEFAULT FALSE, - casting_time TEXT NOT NULL, - range TEXT NOT NULL, + casting_time_amount INTEGER NOT NULL, + casting_time_unit TEXT NOT NULL, + saving_throw TEXT[], + attack_type TEXT, + damage_type TEXT, + conditions TEXT[], + range_type TEXT NOT NULL, + range_amount INTEGER, + range_unit TEXT, + area_type TEXT, + area_amount INTEGER, + area_unit TEXT, components_verbal BOOLEAN DEFAULT FALSE, components_somatic BOOLEAN DEFAULT FALSE, components_material BOOLEAN DEFAULT FALSE, components_materials_needed TEXT, - duration TEXT NOT NULL, + components_materials_cost INTEGER, + components_materials_consumed BOOLEAN DEFAULT FALSE, + duration_type TEXT NOT NULL, + duration_amount INTEGER, + duration_unit TEXT, classes TEXT[] NOT NULL, sources TEXT[] NOT NULL, tags TEXT[], - description TEXT NOT NULL + description JSONB NOT NULL ); \ No newline at end of file diff --git a/src/commands/oai.rs b/src/commands/oai.rs index 56706aa..9e29c4b 100644 --- a/src/commands/oai.rs +++ b/src/commands/oai.rs @@ -1,7 +1,3 @@ - -use std::error::Error; -use std::fmt; - use diesel::{prelude::*, insert_into}; use log::{error, debug, trace, warn}; @@ -13,12 +9,15 @@ use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverw use serenity::prelude::*; use crate::db::{connection, messages::{MessageDB, NewMessageDB}}; +use crate::error_handler::ServiceError; pub struct OAI { pub client: reqwest::Client, pub base_url: String, pub max_attempts: i64, pub token: String, + pub max_tokens: i64, + pub default_model: GPTModel, pub max_context_questions: i64 } @@ -65,7 +64,7 @@ enum GPTRole { } #[derive(Debug, Clone, Serialize, Deserialize)] -enum GPTModel { +pub enum GPTModel { #[serde(rename = "gpt-3.5-turbo")] GPT35Turbo, #[serde(rename = "gpt-3.5-turbo-0613")] @@ -110,26 +109,18 @@ struct Choice { #[derive(Debug, Clone, Serialize, Deserialize)] struct ResponseError { - code: Option, + error: Option, message: Option, param: Option, #[serde(rename = "type")] error_type: Option } -#[derive(Debug)] -struct OAIError { - pub message: String +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ErrorDetails { + code: Option } -impl fmt::Display for OAIError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "OAIError: {}", self.message) - } -} - -impl Error for OAIError {} - #[derive(Debug, Clone, Serialize, Deserialize)] enum ResponseEvent { ChatCompletionResponse(ChatCompletionResponse), @@ -137,7 +128,7 @@ enum ResponseEvent { } impl OAI { - async fn get_request(&self, request: ChatCompletionRequest) -> Result { + async fn get_request(&self, request: ChatCompletionRequest) -> Result { let uri = format!("{}/chat/completions", self.base_url); let body = serde_json::to_string(&request).unwrap(); trace!("Sending request to {}: {}", uri, body); @@ -150,35 +141,39 @@ impl OAI { .send() .await { Ok(r) => r, - Err(err) => return Err(OAIError { + Err(err) => return Err(ServiceError { message: format!("Could not send request to OpenAI: {}", err), + status: 500 }) } .json::() .await { Ok(r) => r, - Err(err) => return Err(OAIError { - message: format!("Could not read response from OpenAI: {}", err) + Err(err) => return Err(ServiceError { + message: format!("Could not read response from OpenAI: {}", err), + status: 500 }) }; trace!("Received response from OpenAI: {:?}", value); - // let response = match serde_json::from_value::(value) { + // let response = match serde_json::from_value::(value) { // Ok(r) => { // match r { - // OAIResponseEvent::OAIResponse(r) => r, - // OAIResponseEvent::OAIError(e) => return Err(OAIError { message: e.message.unwrap_or("Unknown error".to_string()) }) + // ResponseEvent::ChatCompletionResponse(r) => r, + // ResponseEvent::ResponseError(e) => return Err(ServiceError { message: e.message.unwrap_or("Unknown error".to_string()), status: 500 }), // } // }, - // Err(err) => return Err(OAIError { - // message: format!("Could not parse response from OpenAI: {}", err) + // Err(err) => return Err(ServiceError { + // message: format!("Could not parse response from OpenAI: {}", err), + // status: 500 // }) // }; let response = match serde_json::from_value::(value) { Ok(r) => r, - Err(err) => return Err(OAIError { - message: format!("Could not parse response from OpenAI: {}", err) + Err(err) => return Err(ServiceError { + message: format!("Could not parse response from OpenAI: {}", err), + status: 500 }) }; @@ -208,20 +203,6 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { .order(crate::db::schema::messages::created.asc()) .limit(oai.max_context_questions) .load(&mut connection); - - let previous_messages = match result { - Ok(r) => { - let mut previous_message = "".to_string(); - for message in r { - previous_message = format!("{}You: {}\n Siren: {}\n", previous_message, message.request, message.response); - } - Some(ChatCompletionMessage { role: GPTRole::User, content: previous_message }) - } - Err(err) => { - error!("Could not load previous messages: {}", err); - None - } - }; let mut messages = vec![ ChatCompletionMessage { @@ -230,25 +211,34 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { }, ]; - if let Some(mut previous) = previous_messages { - previous.content = format!("{}You: {}\nSiren: ", previous.content, parsed_content); - messages.push(previous); - } else { - messages.push(ChatCompletionMessage { - role: GPTRole::User, - content: format!("You: {}, Siren: ", parsed_content) - }); - } - - let model = "gpt-3.5-turbo".to_string(); + match result { + Ok(r) => { + for message in r { + messages.push( + ChatCompletionMessage { + role: GPTRole::User, + content: format!("{}", message.request) + } + ); + messages.push( + ChatCompletionMessage { + role: GPTRole::Assistant, + content: format!("{}", message.response) + } + ); + } + }, + Err(err) => error!("Could not load previous messages: {}", err) + }; + messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() }); let request = ChatCompletionRequest { - model: GPTModel::GPT35Turbo, + model: oai.default_model.clone(), messages, temperature: Some(0.5), top_p: None, n: None, - max_tokens: Some(1000), + max_tokens: Some(oai.max_tokens), presence_penalty: Some(0.6), frequency_penalty: Some(0.0), user: Some(msg.author.name.clone()) @@ -289,7 +279,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { channel_id: response_channel.0 as i64, user_id: author_id.0 as i64, created: r.created, - model: &model, + model: &serde_json::to_string(&r.model).unwrap(), request: &parsed_content, response: &res, request_tags: vec![], @@ -305,7 +295,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { } Err(err) => { error!("Could not get response from OpenAI: {}", err.message); - err.message + "There was an error processing your message. Please try again later.".to_string() } }; debug!("Writing response: \"{}\"", response); diff --git a/src/db/classes/mod.rs b/src/db/classes/mod.rs index e69de29..24e3024 100644 --- a/src/db/classes/mod.rs +++ b/src/db/classes/mod.rs @@ -0,0 +1,3 @@ +mod model; + +pub use model::*; \ No newline at end of file diff --git a/src/db/classes/model.rs b/src/db/classes/model.rs new file mode 100644 index 0000000..c76b89a --- /dev/null +++ b/src/db/classes/model.rs @@ -0,0 +1,41 @@ +use std::str::FromStr; + +use serde::{Serialize, Deserialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub enum AbilityType { + Strength, + Dexterity, + Constitution, + Intelligence, + Wisdom, + Charisma +} + +impl AbilityType { + pub fn to_string(&self) -> String { + match self { + AbilityType::Strength => "Strength".to_string(), + AbilityType::Dexterity => "Dexterity".to_string(), + AbilityType::Constitution => "Constitution".to_string(), + AbilityType::Intelligence => "Intelligence".to_string(), + AbilityType::Wisdom => "Wisdom".to_string(), + AbilityType::Charisma => "Charisma".to_string() + } + } +} + +impl FromStr for AbilityType { + type Err = (); + fn from_str(s: &str) -> Result { + match s { + "Strength" => Ok(AbilityType::Strength), + "Dexterity" => Ok(AbilityType::Dexterity), + "Constitution" => Ok(AbilityType::Constitution), + "Intelligence" => Ok(AbilityType::Intelligence), + "Wisdom" => Ok(AbilityType::Wisdom), + "Charisma" => Ok(AbilityType::Charisma), + _ => Err(()) + } + } +} \ No newline at end of file diff --git a/src/db/conditions/mod.rs b/src/db/conditions/mod.rs index e69de29..ca4e835 100644 --- a/src/db/conditions/mod.rs +++ b/src/db/conditions/mod.rs @@ -0,0 +1,69 @@ +use std::str::FromStr; + +use serde::{Deserialize, Serialize}; + + +#[derive(Debug, Serialize, Deserialize)] +pub enum ConditionType { + Blinded, + Charmed, + Deafened, + Exhaustion, + Frightened, + Grappled, + Incapacitated, + Invisible, + Paralyzed, + Petrified, + Poisoned, + Prone, + Restrained, + Stunned, + Unconscious +} + +impl ConditionType { + pub fn to_string(&self) -> String { + match self { + ConditionType::Blinded => "Blinded".to_string(), + ConditionType::Charmed => "Charmed".to_string(), + ConditionType::Deafened => "Deafened".to_string(), + ConditionType::Exhaustion => "Exhaustion".to_string(), + ConditionType::Frightened => "Frightened".to_string(), + ConditionType::Grappled => "Grappled".to_string(), + ConditionType::Incapacitated => "Incapacitated".to_string(), + ConditionType::Invisible => "Invisible".to_string(), + ConditionType::Paralyzed => "Paralyzed".to_string(), + ConditionType::Petrified => "Petrified".to_string(), + ConditionType::Poisoned => "Poisoned".to_string(), + ConditionType::Prone => "Prone".to_string(), + ConditionType::Restrained => "Restrained".to_string(), + ConditionType::Stunned => "Stunned".to_string(), + ConditionType::Unconscious => "Unconscious".to_string() + } + } +} + +impl FromStr for ConditionType { + type Err = (); + fn from_str(s: &str) -> Result { + match s { + "Blinded" => Ok(ConditionType::Blinded), + "Charmed" => Ok(ConditionType::Charmed), + "Deafened" => Ok(ConditionType::Deafened), + "Exhaustion" => Ok(ConditionType::Exhaustion), + "Frightened" => Ok(ConditionType::Frightened), + "Grappled" => Ok(ConditionType::Grappled), + "Incapacitated" => Ok(ConditionType::Incapacitated), + "Invisible" => Ok(ConditionType::Invisible), + "Paralyzed" => Ok(ConditionType::Paralyzed), + "Petrified" => Ok(ConditionType::Petrified), + "Poisoned" => Ok(ConditionType::Poisoned), + "Prone" => Ok(ConditionType::Prone), + "Restrained" => Ok(ConditionType::Restrained), + "Stunned" => Ok(ConditionType::Stunned), + "Unconscious" => Ok(ConditionType::Unconscious), + _ => Err(()) + } + } +} diff --git a/src/db/schema.rs b/src/db/schema.rs index 17e2dc8..f44816b 100644 --- a/src/db/schema.rs +++ b/src/db/schema.rs @@ -20,16 +20,30 @@ diesel::table! { school -> Text, level -> Integer, ritual -> Bool, - casting_time -> Text, - range -> Text, + casting_time_amount -> Integer, + casting_time_unit -> Text, + saving_throw -> Nullable>, + attack_type -> Nullable, + damage_type -> Nullable, + conditions -> Nullable>, + range_type -> Text, + range_amount -> Nullable, + range_unit -> Nullable, + area_type -> Nullable, + area_amount -> Nullable, + area_unit -> Nullable, components_verbal -> Bool, components_somatic -> Bool, components_material -> Bool, components_materials_needed -> Nullable, - duration -> Text, + components_materials_cost -> Nullable, + components_materials_consumed -> Nullable, + duration_type -> Text, + duration_amount -> Nullable, + duration_unit -> Nullable, classes -> Array, sources -> Array, tags -> Array, - description -> Text + description -> Jsonb } } \ No newline at end of file diff --git a/src/db/spells/mod.rs b/src/db/spells/mod.rs index 46a85d4..6dfedea 100644 --- a/src/db/spells/mod.rs +++ b/src/db/spells/mod.rs @@ -6,38 +6,40 @@ pub use routes::init_routes; pub fn load_data() { let root_path = std::env::current_dir().unwrap(); - let mut data_path = std::path::PathBuf::from(root_path); - data_path.push("data/spells.json"); - match data_path.to_str() { - Some(path) => { - log::debug!("Loading spells from {}", path); - match std::fs::read_to_string(data_path) { - Ok(data) => { - match serde_json::from_str::(&data) { - Ok(json) => { - match serde_json::from_value::>(json) { - Ok(spells) => { - let count = QuerySpell::get_count().unwrap(); - if count >= spells.len() as i64 { - log::warn!("Spell data is already loaded"); - return; - } - for spell in spells { - match InsertSpell::insert(spell.into()) { - Ok(_) => {}, - Err(err) => log::error!("Failed to insert spell: {}", err) - } - } - }, - Err(err) => log::error!("Failed to parse spells data: {}", err) - } - }, - Err(err) => log::error!("Failed to parse spells data to value: {}", err) - }; - }, - Err(err) => log::error!("Failed to read spells data: {}", err) - }; - }, - None => log::error!("Failed to find spells data directory") + let files = [ + "cantrips.json", "level_1.json", "level_2.json", "level_3.json", "level_4.json", "level_5.json", "level_6.json", "level_7.json", "level_8.json", "level_9.json" + ]; + let mut spells: Vec = vec![]; + for file in files { + let mut data_path = std::path::PathBuf::from(&root_path); + data_path.push(format!("data/spells/{}", file)); + let path = data_path.to_str().unwrap(); + match std::fs::read_to_string(path) { + Ok(data) => { + log::debug!("Loading spells from {}", path); + match serde_json::from_str::(&data) { + Ok(json) => { + match serde_json::from_value::>(json) { + Ok(mut new_spells) => spells.append(&mut new_spells), + Err(err) => log::error!("Failed to parse spells data: {}", err) + } + }, + Err(err) => log::error!("Failed to parse spells data to value: {}", err) + }; + }, + Err(err) => log::error!("Failed to read from {}: {}", file, err) + }; + } + let count = QuerySpell::get_count().unwrap(); + if count >= spells.len() as i64 { + log::warn!("Spell data is already loaded"); + return; + } + for spell in spells { + let spell_name = spell.name.clone(); + match InsertSpell::insert(spell.into()) { + Ok(_) => {}, + Err(err) => log::error!("Failed to insert '{}' spell: {}", spell_name, err) + } } } diff --git a/src/db/spells/model.rs b/src/db/spells/model.rs index e849104..792b2bb 100644 --- a/src/db/spells/model.rs +++ b/src/db/spells/model.rs @@ -1,7 +1,9 @@ -use diesel::prelude::*; -use serde::{Deserialize, Serialize}; +use std::str::FromStr; -use crate::{db::schema::spells, error_handler::ServiceError}; +use diesel::prelude::*; +use serde::{Deserialize, Serialize, ser::SerializeMap}; + +use crate::{db::{schema::spells, classes::AbilityType, conditions::ConditionType}, error_handler::ServiceError}; #[derive(Queryable, QueryableByName)] #[diesel(table_name = spells)] @@ -11,17 +13,31 @@ pub struct QuerySpell { pub school: String, pub level: i32, pub ritual: bool, - pub casting_time: String, - pub range: String, + pub casting_time_amount: i32, + pub casting_time_unit: String, + pub saving_throw: Option>, + pub attack_type: Option, + pub damage_type: Option, + pub conditions: Option>, + pub range_type: String, + pub range_amount: Option, + pub range_unit: Option, + pub area_type: Option, + pub area_amount: Option, + pub area_unit: Option, pub components_verbal: bool, pub components_somatic: bool, pub components_material: bool, pub components_materials_needed: Option, - pub duration: String, + pub components_materials_cost: Option, + pub components_materials_consumed: Option, + pub duration_type: String, + pub duration_amount: Option, + pub duration_unit: Option, pub classes: Vec, pub sources: Vec, pub tags: Vec, - pub description: String + pub description: serde_json::Value } impl QuerySpell { @@ -63,17 +79,31 @@ pub struct InsertSpell { pub school: String, pub level: i32, pub ritual: bool, - pub casting_time: String, - pub range: String, + pub casting_time_amount: i32, + pub casting_time_unit: String, + pub saving_throw: Option>, + pub attack_type: Option, + pub damage_type: Option, + pub conditions: Option>, + pub range_type: String, + pub range_amount: Option, + pub range_unit: Option, + pub area_type: Option, + pub area_amount: Option, + pub area_unit: Option, pub components_verbal: bool, pub components_somatic: bool, pub components_material: bool, pub components_materials_needed: Option, - pub duration: String, + pub components_materials_cost: Option, + pub components_materials_consumed: Option, + pub duration_type: String, + pub duration_amount: Option, + pub duration_unit: Option, pub classes: Vec, pub sources: Vec, pub tags: Vec, - pub description: String + pub description: serde_json::Value } impl InsertSpell { @@ -90,50 +120,541 @@ impl InsertSpell { } } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct Spell { pub name: String, - pub school: String, + pub school: SchoolType, pub level: i32, pub ritual: bool, - pub casting_time: String, - pub range: String, + pub casting_time: CastingTime, + #[serde(skip_serializing_if = "Option::is_none")] + pub saving_throw: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub attack_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub damage_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub conditions: Option>, + pub range: Range, + #[serde(skip_serializing_if = "Option::is_none")] + pub area: Option, pub components: Components, - pub duration: String, + pub duration: Duration, pub classes: Vec, - pub sources: Vec, - pub tags: Vec, - pub description: String + pub sources: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub tags: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] +pub enum SchoolType { + #[serde(rename = "abjuration")] + Abjuration, + #[serde(rename = "conjuration")] + Conjuration, + #[serde(rename = "divination")] + Divination, + #[serde(rename = "enchantment")] + Enchantment, + #[serde(rename = "evocation")] + Evocation, + #[serde(rename = "illusion")] + Illusion, + #[serde(rename = "necromancy")] + Necromancy, + #[serde(rename = "transmutation")] + Transmutation +} + +impl SchoolType { + pub fn to_string(&self) -> String { + match self { + SchoolType::Abjuration => "abjuration".to_string(), + SchoolType::Conjuration => "conjuration".to_string(), + SchoolType::Divination => "divination".to_string(), + SchoolType::Enchantment => "enchantment".to_string(), + SchoolType::Evocation => "evocation".to_string(), + SchoolType::Illusion => "illusion".to_string(), + SchoolType::Necromancy => "necromancy".to_string(), + SchoolType::Transmutation => "transmutation".to_string() + } + } +} + +impl FromStr for SchoolType { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "abjuration" => Ok(SchoolType::Abjuration), + "conjuration" => Ok(SchoolType::Conjuration), + "divination" => Ok(SchoolType::Divination), + "enchantment" => Ok(SchoolType::Enchantment), + "evocation" => Ok(SchoolType::Evocation), + "illusion" => Ok(SchoolType::Illusion), + "necromancy" => Ok(SchoolType::Necromancy), + "transmutation" => Ok(SchoolType::Transmutation), + _ => Err(()) + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CastingTime { + pub amount: i32, + #[serde(rename = "type")] + pub casting_type: CastingType +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum CastingType { + #[serde(rename = "action")] + Action, + #[serde(rename = "bonus")] + BonusAction, + #[serde(rename = "reaction")] + Reaction, + #[serde(rename = "minutes")] + Minutes, + #[serde(rename = "hours")] + Hours +} + +impl CastingType { + pub fn to_string(&self) -> String { + match self { + CastingType::Action => "action".to_string(), + CastingType::BonusAction => "bonus".to_string(), + CastingType::Reaction => "reaction".to_string(), + CastingType::Minutes => "minutes".to_string(), + CastingType::Hours => "hours".to_string() + } + } +} + +impl FromStr for CastingType { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "action" => Ok(CastingType::Action), + "bonus" => Ok(CastingType::BonusAction), + "reaction" => Ok(CastingType::Reaction), + "minutes" => Ok(CastingType::Minutes), + "hours" => Ok(CastingType::Hours), + _ => Err(()) + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum SpellAttackType { + #[serde(rename = "melee")] + Melee, + #[serde(rename = "ranged")] + Ranged, +} + +impl SpellAttackType { + pub fn to_string(&self) -> String { + match self { + SpellAttackType::Melee => "melee".to_string(), + SpellAttackType::Ranged => "ranged".to_string() + } + } +} + +impl FromStr for SpellAttackType { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "melee" => Ok(SpellAttackType::Melee), + "ranged" => Ok(SpellAttackType::Ranged), + _ => Err(()) + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum SpellDamageType { + #[serde(rename = "acid")] + Acid, + #[serde(rename = "bludgeoning")] + Bludgeoning, + #[serde(rename = "cold")] + Cold, + #[serde(rename = "fire")] + Fire, + #[serde(rename = "force")] + Force, + #[serde(rename = "lightning")] + Lightning, + #[serde(rename = "necrotic")] + Necrotic, + #[serde(rename = "piercing")] + Piercing, + #[serde(rename = "poison")] + Poison, + #[serde(rename = "psychic")] + Psychic, + #[serde(rename = "radiant")] + Radiant, + #[serde(rename = "slashing")] + Slashing, + #[serde(rename = "thunder")] + Thunder +} + +impl SpellDamageType { + pub fn to_string(&self) -> String { + match self { + SpellDamageType::Acid => "acid".to_string(), + SpellDamageType::Bludgeoning => "bludgeoning".to_string(), + SpellDamageType::Cold => "cold".to_string(), + SpellDamageType::Fire => "fire".to_string(), + SpellDamageType::Force => "force".to_string(), + SpellDamageType::Lightning => "lightning".to_string(), + SpellDamageType::Necrotic => "necrotic".to_string(), + SpellDamageType::Piercing => "piercing".to_string(), + SpellDamageType::Poison => "poison".to_string(), + SpellDamageType::Psychic => "psychic".to_string(), + SpellDamageType::Radiant => "radiant".to_string(), + SpellDamageType::Slashing => "slashing".to_string(), + SpellDamageType::Thunder => "thunder".to_string() + } + } +} + +impl FromStr for SpellDamageType { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "acid" => Ok(SpellDamageType::Acid), + "bludgeoning" => Ok(SpellDamageType::Bludgeoning), + "cold" => Ok(SpellDamageType::Cold), + "fire" => Ok(SpellDamageType::Fire), + "force" => Ok(SpellDamageType::Force), + "lightning" => Ok(SpellDamageType::Lightning), + "necrotic" => Ok(SpellDamageType::Necrotic), + "piercing" => Ok(SpellDamageType::Piercing), + "poison" => Ok(SpellDamageType::Poison), + "psychic" => Ok(SpellDamageType::Psychic), + "radiant" => Ok(SpellDamageType::Radiant), + "slashing" => Ok(SpellDamageType::Slashing), + "thunder" => Ok(SpellDamageType::Thunder), + _ => Err(()) + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Range { + #[serde(rename = "type")] + pub range_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub amount: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub unit: Option +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Area { + #[serde(rename = "type")] + pub area_type: AreaType, + #[serde(skip_serializing_if = "Option::is_none")] + pub amount: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub unit: Option +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum AreaType { + #[serde(rename = "cone")] + Cone, + #[serde(rename = "cube")] + Cube, + #[serde(rename = "cylinder")] + Cylinder, + #[serde(rename = "line")] + Line, + #[serde(rename = "sphere")] + Sphere +} + +impl AreaType { + pub fn to_string(&self) -> String { + match self { + AreaType::Cone => "cone".to_string(), + AreaType::Cube => "cube".to_string(), + AreaType::Cylinder => "cylinder".to_string(), + AreaType::Line => "line".to_string(), + AreaType::Sphere => "sphere".to_string() + } + } +} + +impl FromStr for AreaType { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "cone" => Ok(AreaType::Cone), + "cube" => Ok(AreaType::Cube), + "cylinder" => Ok(AreaType::Cylinder), + "line" => Ok(AreaType::Line), + "sphere" => Ok(AreaType::Sphere), + _ => Err(()) + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Duration { + #[serde(rename = "type")] + pub duration_type: DurationType, + #[serde(skip_serializing_if = "Option::is_none")] + pub amount: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub unit: Option +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum DurationType { + #[serde(rename = "concentration")] + Concentration, + #[serde(rename = "instantaneous")] + Instantaneous, + #[serde(rename = "timed")] + Timed, + #[serde(rename = "dispelled")] + UntilDispelled, + #[serde(rename = "special")] + Special +} + +impl DurationType { + pub fn to_string(&self) -> String { + match self { + DurationType::Concentration => "concentration".to_string(), + DurationType::Instantaneous => "instantaneous".to_string(), + DurationType::Timed => "timed".to_string(), + DurationType::UntilDispelled => "dispelled".to_string(), + DurationType::Special => "special".to_string() + } + } +} + +impl FromStr for DurationType { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "concentration" => Ok(DurationType::Concentration), + "instantaneous" => Ok(DurationType::Instantaneous), + "timed" => Ok(DurationType::Timed), + "dispelled" => Ok(DurationType::UntilDispelled), + "special" => Ok(DurationType::Special), + _ => Err(()) + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Source { + pub source: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub page: Option +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Description { + pub entries: Vec +} + +#[derive(Debug)] +pub struct Entry { + pub entry_type: String, + pub items: Vec +} + +impl<'de> Deserialize<'de> for Entry { + fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { + let value = serde_json::Value::deserialize(deserializer)?; + match value { + serde_json::Value::String(s) => Ok(Entry { + entry_type: "string".to_string(), + items: vec![s] + }), + serde_json::Value::Object(o) => { + let entry_type = match o.get("type") { + Some(t) => match t.as_str() { + Some(s) => s.to_string(), + None => return Err(serde::de::Error::custom("Invalid entry type")) + }, + None => return Err(serde::de::Error::custom("Missing entry type")) + }; + let items = match o.get("items") { + Some(i) => match i.as_array() { + Some(a) => { + let mut items = Vec::new(); + for item in a { + match item.as_str() { + Some(s) => items.push(s.to_string()), + None => return Err(serde::de::Error::custom("Invalid entry item")) + } + } + items + }, + None => return Err(serde::de::Error::custom("Invalid entry items")) + }, + None => return Err(serde::de::Error::custom("Missing entry items")) + }; + Ok(Entry { + entry_type, + items + }) + }, + _ => Err(serde::de::Error::custom("Invalid entry")) + } + } +} + +impl Serialize for Entry { + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + match self.entry_type.as_str() { + "string" => serializer.serialize_str(&self.items[0]), + _ => { + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("type", &self.entry_type)?; + map.serialize_entry("items", &self.items)?; + map.end() + } + } + } +} + +#[derive(Debug, Serialize, Deserialize)] pub struct Components { pub verbal: bool, pub somatic: bool, pub material: bool, - pub materials_needed: Option + #[serde(skip_serializing_if = "Option::is_none")] + pub materials_needed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub materials_cost: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub materials_consumed: Option } impl From for Spell { fn from(query: QuerySpell) -> Self { return Self { name: query.name, - school: query.school, + school: match SchoolType::from_str(&query.school) { + Ok(school_type) => school_type, + Err(_) => { + log::error!("Failed to parse spell school type: {}", query.school); + SchoolType::Abjuration + } + }, level: query.level, ritual: query.ritual, - casting_time: query.casting_time, - range: query.range, + casting_time: CastingTime { + amount: query.casting_time_amount, + casting_type: match CastingType::from_str(&query.casting_time_unit) { + Ok(casting_type) => casting_type, + Err(_) => { + log::error!("Failed to parse spell casting type: {}", query.casting_time_unit); + CastingType::Action + } + } + }, + saving_throw: query.saving_throw.map(|saving_throw| saving_throw.iter().map(|ability_type| match AbilityType::from_str(&ability_type) { + Ok(ability_type) => ability_type, + Err(_) => { + log::error!("Failed to parse spell saving throw: {}", ability_type); + AbilityType::Strength + } + }).collect()), + attack_type: match query.attack_type { + Some(attack_type) => match SpellAttackType::from_str(&attack_type) { + Ok(attack_type) => Some(attack_type), + Err(_) => { + log::error!("Failed to parse spell attack type: {}", attack_type); + None + } + }, + None => None + }, + damage_type: query.damage_type.map(|damage_type| match SpellDamageType::from_str(&damage_type) { + Ok(damage_type) => damage_type, + Err(_) => { + log::error!("Failed to parse spell damage type: {}", damage_type); + SpellDamageType::Acid + } + }), + conditions: query.conditions.map(|conditions| conditions.iter().map(|condition_type| match ConditionType::from_str(&condition_type) { + Ok(condition_type) => condition_type, + Err(_) => { + log::error!("Failed to parse spell condition type: {}", condition_type); + ConditionType::Blinded + } + }).collect()), + range: Range { + range_type: query.range_type, + amount: query.range_amount, + unit: query.range_unit + }, + area: match query.area_type { + Some(area_type) => Some(Area { + area_type: match AreaType::from_str(&area_type) { + Ok(area_type) => area_type, + Err(_) => { + log::error!("Failed to parse spell area type: {}", area_type); + AreaType::Cone + } + }, + amount: query.area_amount, + unit: query.area_unit + }), + None => None + }, components: Components { verbal: query.components_verbal, somatic: query.components_somatic, material: query.components_material, - materials_needed: query.components_materials_needed + materials_needed: query.components_materials_needed, + materials_cost: query.components_materials_cost, + materials_consumed: query.components_materials_consumed + }, + duration: Duration { + duration_type: match DurationType::from_str(&query.duration_type) { + Ok(duration_type) => duration_type, + Err(_) => { + log::error!("Failed to parse spell duration type: {}", query.duration_type); + DurationType::Special + } + }, + amount: query.duration_amount, + unit: query.duration_unit }, - duration: query.duration, classes: query.classes, - sources: query.sources, - tags: query.tags, - description: query.description + sources: query.sources.iter().map(|source| Source { + source: source.to_string(), + page: None + }).collect(), + tags: Some(query.tags), + description: match serde_json::from_value(query.description) { + Ok(description) => description, + Err(err) => { + log::error!("Failed to parse spell description: {}", err); + None + } + } } } } @@ -142,20 +663,67 @@ impl Into for Spell { fn into(self) -> InsertSpell { return InsertSpell { name: self.name, - school: self.school, + school: self.school.to_string(), level: self.level, ritual: self.ritual, - casting_time: self.casting_time, - range: self.range, + casting_time_amount: self.casting_time.amount, + casting_time_unit: self.casting_time.casting_type.to_string(), + saving_throw: match self.saving_throw { + Some(saving_throw) => Some(saving_throw.iter().map(|ability_type| ability_type.to_string()).collect()), + None => None + }, + attack_type: match self.attack_type { + Some(attack_type) => Some(attack_type.to_string()), + None => None + }, + damage_type: match self.damage_type { + Some(damage_type) => Some(damage_type.to_string()), + None => None + }, + conditions: match self.conditions { + Some(conditions) => Some(conditions.iter().map(|condition_type| condition_type.to_string()).collect()), + None => None + }, + range_type: self.range.range_type.to_string(), + range_amount: self.range.amount, + range_unit: self.range.unit, + area_type: match &self.area { + Some(area) => Some(area.area_type.to_string()), + None => None + }, + area_amount: match &self.area { + Some(area) => area.amount, + None => None + }, + area_unit: match &self.area { + Some(area) => match &area.unit { + Some(unit) => Some(unit.to_string()), + None => None + }, + None => None + }, components_verbal: self.components.verbal, components_somatic: self.components.somatic, components_material: self.components.material, components_materials_needed: self.components.materials_needed, - duration: self.duration, + components_materials_cost: self.components.materials_cost, + components_materials_consumed: self.components.materials_consumed, + duration_type: self.duration.duration_type.to_string(), + duration_amount: self.duration.amount, + duration_unit: self.duration.unit, classes: self.classes, - sources: self.sources, - tags: self.tags, - description: self.description + sources: self.sources.iter().map(|source| match source.page { Some(page) => format!("{} {}", source.source, page), None => source.source.to_string() }).collect(), + tags: match self.tags { + Some(tags) => tags, + None => Vec::new() + }, + description: match serde_json::to_value(self.description) { + Ok(description) => description, + Err(err) => { + log::error!("Failed to serialize spell description: {}", err); + serde_json::Value::Null + } + } } } } diff --git a/src/error_handler.rs b/src/error_handler.rs index 4ae1f90..1d20939 100644 --- a/src/error_handler.rs +++ b/src/error_handler.rs @@ -6,22 +6,22 @@ use std::fmt; #[derive(Debug, Deserialize, Serialize)] pub struct ServiceError { - pub error_status_code: u16, - pub error_message: String, + pub status: u16, + pub message: String, } impl ServiceError { pub fn new(error_status_code: u16, error_message: String) -> ServiceError { ServiceError { - error_status_code, - error_message, + status: error_status_code, + message: error_message, } } } impl fmt::Display for ServiceError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(self.error_message.as_str()) + f.write_str(self.message.as_str()) } } @@ -42,13 +42,13 @@ impl From for ServiceError { impl ResponseError for ServiceError { fn error_response(&self) -> HttpResponse { - let status_code = match StatusCode::from_u16(self.error_status_code) { + let status_code = match StatusCode::from_u16(self.status) { Ok(status_code) => status_code, Err(_) => StatusCode::INTERNAL_SERVER_ERROR, }; let error_message = match status_code.as_u16() < 500 { - true => self.error_message.clone(), + true => self.message.clone(), false => "Internal server error".to_string(), }; diff --git a/src/main.rs b/src/main.rs index f3506aa..832c7a1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,6 +20,8 @@ use serenity::http::Http; use serenity::prelude::*; use songbird::SerenityInit; +use crate::commands::oai::GPTModel; + mod commands; mod error_handler; mod db; @@ -168,7 +170,15 @@ fn setup_discord_bot() { Ok(token) => { info!("Loaded OpenAI token"); Handler { - oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token , max_context_questions: 15 }) + oai: Some(commands::oai::OAI { + client: reqwest::Client::new(), + base_url: "https://api.openai.com/v1".to_string(), + max_attempts: 5, + token, + max_context_questions: 30, + max_tokens: 2048, + default_model: GPTModel::GPT35Turbo, + }) } } Err(err) => {