Updated spells schema
This commit is contained in:
@@ -1,7 +1,3 @@
|
||||
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
use diesel::{prelude::*, insert_into};
|
||||
use log::{error, debug, trace, warn};
|
||||
|
||||
@@ -13,12 +9,15 @@ use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverw
|
||||
use serenity::prelude::*;
|
||||
|
||||
use crate::db::{connection, messages::{MessageDB, NewMessageDB}};
|
||||
use crate::error_handler::ServiceError;
|
||||
|
||||
pub struct OAI {
|
||||
pub client: reqwest::Client,
|
||||
pub base_url: String,
|
||||
pub max_attempts: i64,
|
||||
pub token: String,
|
||||
pub max_tokens: i64,
|
||||
pub default_model: GPTModel,
|
||||
pub max_context_questions: i64
|
||||
}
|
||||
|
||||
@@ -65,7 +64,7 @@ enum GPTRole {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
enum GPTModel {
|
||||
pub enum GPTModel {
|
||||
#[serde(rename = "gpt-3.5-turbo")]
|
||||
GPT35Turbo,
|
||||
#[serde(rename = "gpt-3.5-turbo-0613")]
|
||||
@@ -110,26 +109,18 @@ struct Choice {
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ResponseError {
|
||||
code: Option<i64>,
|
||||
error: Option<ErrorDetails>,
|
||||
message: Option<String>,
|
||||
param: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
error_type: Option<String>
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct OAIError {
|
||||
pub message: String
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ErrorDetails {
|
||||
code: Option<String>
|
||||
}
|
||||
|
||||
impl fmt::Display for OAIError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "OAIError: {}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for OAIError {}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
enum ResponseEvent {
|
||||
ChatCompletionResponse(ChatCompletionResponse),
|
||||
@@ -137,7 +128,7 @@ enum ResponseEvent {
|
||||
}
|
||||
|
||||
impl OAI {
|
||||
async fn get_request(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, OAIError> {
|
||||
async fn get_request(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, ServiceError> {
|
||||
let uri = format!("{}/chat/completions", self.base_url);
|
||||
let body = serde_json::to_string(&request).unwrap();
|
||||
trace!("Sending request to {}: {}", uri, body);
|
||||
@@ -150,35 +141,39 @@ impl OAI {
|
||||
.send()
|
||||
.await {
|
||||
Ok(r) => r,
|
||||
Err(err) => return Err(OAIError {
|
||||
Err(err) => return Err(ServiceError {
|
||||
message: format!("Could not send request to OpenAI: {}", err),
|
||||
status: 500
|
||||
})
|
||||
}
|
||||
.json::<Value>()
|
||||
.await {
|
||||
Ok(r) => r,
|
||||
Err(err) => return Err(OAIError {
|
||||
message: format!("Could not read response from OpenAI: {}", err)
|
||||
Err(err) => return Err(ServiceError {
|
||||
message: format!("Could not read response from OpenAI: {}", err),
|
||||
status: 500
|
||||
})
|
||||
};
|
||||
|
||||
trace!("Received response from OpenAI: {:?}", value);
|
||||
|
||||
// let response = match serde_json::from_value::<OAIResponseEvent>(value) {
|
||||
// let response = match serde_json::from_value::<ResponseEvent>(value) {
|
||||
// Ok(r) => {
|
||||
// match r {
|
||||
// OAIResponseEvent::OAIResponse(r) => r,
|
||||
// OAIResponseEvent::OAIError(e) => return Err(OAIError { message: e.message.unwrap_or("Unknown error".to_string()) })
|
||||
// ResponseEvent::ChatCompletionResponse(r) => r,
|
||||
// ResponseEvent::ResponseError(e) => return Err(ServiceError { message: e.message.unwrap_or("Unknown error".to_string()), status: 500 }),
|
||||
// }
|
||||
// },
|
||||
// Err(err) => return Err(OAIError {
|
||||
// message: format!("Could not parse response from OpenAI: {}", err)
|
||||
// Err(err) => return Err(ServiceError {
|
||||
// message: format!("Could not parse response from OpenAI: {}", err),
|
||||
// status: 500
|
||||
// })
|
||||
// };
|
||||
let response = match serde_json::from_value::<ChatCompletionResponse>(value) {
|
||||
Ok(r) => r,
|
||||
Err(err) => return Err(OAIError {
|
||||
message: format!("Could not parse response from OpenAI: {}", err)
|
||||
Err(err) => return Err(ServiceError {
|
||||
message: format!("Could not parse response from OpenAI: {}", err),
|
||||
status: 500
|
||||
})
|
||||
};
|
||||
|
||||
@@ -208,20 +203,6 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
.order(crate::db::schema::messages::created.asc())
|
||||
.limit(oai.max_context_questions)
|
||||
.load(&mut connection);
|
||||
|
||||
let previous_messages = match result {
|
||||
Ok(r) => {
|
||||
let mut previous_message = "".to_string();
|
||||
for message in r {
|
||||
previous_message = format!("{}You: {}\n Siren: {}\n", previous_message, message.request, message.response);
|
||||
}
|
||||
Some(ChatCompletionMessage { role: GPTRole::User, content: previous_message })
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Could not load previous messages: {}", err);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let mut messages = vec![
|
||||
ChatCompletionMessage {
|
||||
@@ -230,25 +211,34 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
},
|
||||
];
|
||||
|
||||
if let Some(mut previous) = previous_messages {
|
||||
previous.content = format!("{}You: {}\nSiren: ", previous.content, parsed_content);
|
||||
messages.push(previous);
|
||||
} else {
|
||||
messages.push(ChatCompletionMessage {
|
||||
role: GPTRole::User,
|
||||
content: format!("You: {}, Siren: ", parsed_content)
|
||||
});
|
||||
}
|
||||
|
||||
let model = "gpt-3.5-turbo".to_string();
|
||||
match result {
|
||||
Ok(r) => {
|
||||
for message in r {
|
||||
messages.push(
|
||||
ChatCompletionMessage {
|
||||
role: GPTRole::User,
|
||||
content: format!("{}", message.request)
|
||||
}
|
||||
);
|
||||
messages.push(
|
||||
ChatCompletionMessage {
|
||||
role: GPTRole::Assistant,
|
||||
content: format!("{}", message.response)
|
||||
}
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(err) => error!("Could not load previous messages: {}", err)
|
||||
};
|
||||
messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() });
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: GPTModel::GPT35Turbo,
|
||||
model: oai.default_model.clone(),
|
||||
messages,
|
||||
temperature: Some(0.5),
|
||||
top_p: None,
|
||||
n: None,
|
||||
max_tokens: Some(1000),
|
||||
max_tokens: Some(oai.max_tokens),
|
||||
presence_penalty: Some(0.6),
|
||||
frequency_penalty: Some(0.0),
|
||||
user: Some(msg.author.name.clone())
|
||||
@@ -289,7 +279,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
channel_id: response_channel.0 as i64,
|
||||
user_id: author_id.0 as i64,
|
||||
created: r.created,
|
||||
model: &model,
|
||||
model: &serde_json::to_string(&r.model).unwrap(),
|
||||
request: &parsed_content,
|
||||
response: &res,
|
||||
request_tags: vec![],
|
||||
@@ -305,7 +295,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Could not get response from OpenAI: {}", err.message);
|
||||
err.message
|
||||
"There was an error processing your message. Please try again later.".to_string()
|
||||
}
|
||||
};
|
||||
debug!("Writing response: \"{}\"", response);
|
||||
|
||||
Reference in New Issue
Block a user