Updated spells schema

This commit is contained in:
Benjamin Sherriff
2023-10-03 21:00:14 -04:00
parent 16d8fa5af8
commit 75a71410a5
13 changed files with 890 additions and 143 deletions

View File

@@ -1,7 +1,3 @@
use std::error::Error;
use std::fmt;
use diesel::{prelude::*, insert_into};
use log::{error, debug, trace, warn};
@@ -13,12 +9,15 @@ use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverw
use serenity::prelude::*;
use crate::db::{connection, messages::{MessageDB, NewMessageDB}};
use crate::error_handler::ServiceError;
pub struct OAI {
pub client: reqwest::Client,
pub base_url: String,
pub max_attempts: i64,
pub token: String,
pub max_tokens: i64,
pub default_model: GPTModel,
pub max_context_questions: i64
}
@@ -65,7 +64,7 @@ enum GPTRole {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
enum GPTModel {
pub enum GPTModel {
#[serde(rename = "gpt-3.5-turbo")]
GPT35Turbo,
#[serde(rename = "gpt-3.5-turbo-0613")]
@@ -110,26 +109,18 @@ struct Choice {
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ResponseError {
code: Option<i64>,
error: Option<ErrorDetails>,
message: Option<String>,
param: Option<String>,
#[serde(rename = "type")]
error_type: Option<String>
}
#[derive(Debug)]
struct OAIError {
pub message: String
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ErrorDetails {
code: Option<String>
}
impl fmt::Display for OAIError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "OAIError: {}", self.message)
}
}
impl Error for OAIError {}
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ResponseEvent {
ChatCompletionResponse(ChatCompletionResponse),
@@ -137,7 +128,7 @@ enum ResponseEvent {
}
impl OAI {
async fn get_request(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, OAIError> {
async fn get_request(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, ServiceError> {
let uri = format!("{}/chat/completions", self.base_url);
let body = serde_json::to_string(&request).unwrap();
trace!("Sending request to {}: {}", uri, body);
@@ -150,35 +141,39 @@ impl OAI {
.send()
.await {
Ok(r) => r,
Err(err) => return Err(OAIError {
Err(err) => return Err(ServiceError {
message: format!("Could not send request to OpenAI: {}", err),
status: 500
})
}
.json::<Value>()
.await {
Ok(r) => r,
Err(err) => return Err(OAIError {
message: format!("Could not read response from OpenAI: {}", err)
Err(err) => return Err(ServiceError {
message: format!("Could not read response from OpenAI: {}", err),
status: 500
})
};
trace!("Received response from OpenAI: {:?}", value);
// let response = match serde_json::from_value::<OAIResponseEvent>(value) {
// let response = match serde_json::from_value::<ResponseEvent>(value) {
// Ok(r) => {
// match r {
// OAIResponseEvent::OAIResponse(r) => r,
// OAIResponseEvent::OAIError(e) => return Err(OAIError { message: e.message.unwrap_or("Unknown error".to_string()) })
// ResponseEvent::ChatCompletionResponse(r) => r,
// ResponseEvent::ResponseError(e) => return Err(ServiceError { message: e.message.unwrap_or("Unknown error".to_string()), status: 500 }),
// }
// },
// Err(err) => return Err(OAIError {
// message: format!("Could not parse response from OpenAI: {}", err)
// Err(err) => return Err(ServiceError {
// message: format!("Could not parse response from OpenAI: {}", err),
// status: 500
// })
// };
let response = match serde_json::from_value::<ChatCompletionResponse>(value) {
Ok(r) => r,
Err(err) => return Err(OAIError {
message: format!("Could not parse response from OpenAI: {}", err)
Err(err) => return Err(ServiceError {
message: format!("Could not parse response from OpenAI: {}", err),
status: 500
})
};
@@ -208,20 +203,6 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
.order(crate::db::schema::messages::created.asc())
.limit(oai.max_context_questions)
.load(&mut connection);
let previous_messages = match result {
Ok(r) => {
let mut previous_message = "".to_string();
for message in r {
previous_message = format!("{}You: {}\n Siren: {}\n", previous_message, message.request, message.response);
}
Some(ChatCompletionMessage { role: GPTRole::User, content: previous_message })
}
Err(err) => {
error!("Could not load previous messages: {}", err);
None
}
};
let mut messages = vec![
ChatCompletionMessage {
@@ -230,25 +211,34 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
},
];
if let Some(mut previous) = previous_messages {
previous.content = format!("{}You: {}\nSiren: ", previous.content, parsed_content);
messages.push(previous);
} else {
messages.push(ChatCompletionMessage {
role: GPTRole::User,
content: format!("You: {}, Siren: ", parsed_content)
});
}
let model = "gpt-3.5-turbo".to_string();
match result {
Ok(r) => {
for message in r {
messages.push(
ChatCompletionMessage {
role: GPTRole::User,
content: format!("{}", message.request)
}
);
messages.push(
ChatCompletionMessage {
role: GPTRole::Assistant,
content: format!("{}", message.response)
}
);
}
},
Err(err) => error!("Could not load previous messages: {}", err)
};
messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() });
let request = ChatCompletionRequest {
model: GPTModel::GPT35Turbo,
model: oai.default_model.clone(),
messages,
temperature: Some(0.5),
top_p: None,
n: None,
max_tokens: Some(1000),
max_tokens: Some(oai.max_tokens),
presence_penalty: Some(0.6),
frequency_penalty: Some(0.0),
user: Some(msg.author.name.clone())
@@ -289,7 +279,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
channel_id: response_channel.0 as i64,
user_id: author_id.0 as i64,
created: r.created,
model: &model,
model: &serde_json::to_string(&r.model).unwrap(),
request: &parsed_content,
response: &res,
request_tags: vec![],
@@ -305,7 +295,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
}
Err(err) => {
error!("Could not get response from OpenAI: {}", err.message);
err.message
"There was an error processing your message. Please try again later.".to_string()
}
};
debug!("Writing response: \"{}\"", response);