use std::error::Error; use std::fmt; use diesel::{prelude::*, PgConnection, insert_into}; use diesel::r2d2::{Pool, ConnectionManager}; use log::{error, debug, trace, warn}; use serde::{Serialize, Deserialize}; use serde_json::Value; use serenity::model::channel::Message; use serenity::prelude::*; use crate::database::models::{NewMessageDB, MessageDB}; pub struct OAI { pub client: reqwest::Client, pub base_url: String, pub max_attempts: i64, pub token: String, pub max_context_questions: i64 } #[derive(Debug, Clone, Serialize, Deserialize)] struct ChatCompletionRequest { model: GPTModel, messages: Vec, /// Value between 0 and 2 #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, /// Value between 0 and 1 #[serde(skip_serializing_if = "Option::is_none")] top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] n: Option, #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, /// Value between -2.0 and 2.0 #[serde(skip_serializing_if = "Option::is_none")] presence_penalty: Option, /// Value between -2.0 and 2.0 #[serde(skip_serializing_if = "Option::is_none")] frequency_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] user: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] struct ChatCompletionMessage { role: GPTRole, content: String } #[derive(Debug, Clone, Serialize, Deserialize)] enum GPTRole { #[serde(rename = "system")] System, #[serde(rename = "user")] User, #[serde(rename = "assistant")] Assistant, #[serde(rename = "function")] Function } #[derive(Debug, Clone, Serialize, Deserialize)] enum GPTModel { #[serde(rename = "gpt-3.5-turbo")] GPT35Turbo, #[serde(rename = "gpt-3.5-turbo-0613")] GPT35Snapshot, #[serde(rename = "gpt-3.5-turbo-16k")] GPT3516k, #[serde(rename = "gpt-3.5-turbo-16k-0613")] GPT3516kSnapshot, #[serde(rename = "gpt-4")] GPT4, #[serde(rename = "gpt-4-0613")] GPT4Snapshot, #[serde(rename = "gpt-4-32k")] GPT432k, #[serde(rename = "gpt-4-32k-0613")] GPT432kSnapshot, } #[derive(Debug, Clone, Serialize, Deserialize)] struct ChatCompletionResponse { id: String, object: String, created: i64, model: GPTModel, usage: Usage, choices: Vec } #[derive(Debug, Clone, Serialize, Deserialize)] struct Usage { prompt_tokens: i64, completion_tokens: i64, total_tokens: i64 } #[derive(Debug, Clone, Serialize, Deserialize)] struct Choice { message: ChatCompletionMessage, finish_reason: String, index: i64 } #[derive(Debug, Clone, Serialize, Deserialize)] struct ResponseError { code: Option, message: Option, param: Option, #[serde(rename = "type")] error_type: Option } #[derive(Debug)] struct OAIError { pub message: String } impl fmt::Display for OAIError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "OAIError: {}", self.message) } } impl Error for OAIError {} #[derive(Debug, Clone, Serialize, Deserialize)] enum ResponseEvent { ChatCompletionResponse(ChatCompletionResponse), ResponseError(ResponseError) } impl OAI { async fn get_request(&self, request: ChatCompletionRequest) -> Result { let uri = format!("{}/chat/completions", self.base_url); let body = serde_json::to_string(&request).unwrap(); trace!("Sending request to {}: {}", uri, body); let value = match match self.client .post(&uri) .bearer_auth(&self.token) .header("Content-Type", "application/json".to_string()) .body(body) .send() .await { Ok(r) => r, Err(err) => return Err(OAIError { message: format!("Could not send request to OpenAI: {}", err), }) } .json::() .await { Ok(r) => r, Err(err) => return Err(OAIError { message: format!("Could not read response from OpenAI: {}", err) }) }; trace!("Received response from OpenAI: {:?}", value); // let response = match serde_json::from_value::(value) { // Ok(r) => { // match r { // OAIResponseEvent::OAIResponse(r) => r, // OAIResponseEvent::OAIError(e) => return Err(OAIError { message: e.message.unwrap_or("Unknown error".to_string()) }) // } // }, // Err(err) => return Err(OAIError { // message: format!("Could not parse response from OpenAI: {}", err) // }) // }; let response = match serde_json::from_value::(value) { Ok(r) => r, Err(err) => return Err(OAIError { message: format!("Could not parse response from OpenAI: {}", err) }) }; Ok(response) } } pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &Pool>) { debug!("Generating response for message: {}", msg.content); let typing = msg.channel_id.start_typing(&ctx.http).unwrap(); let guild_id = msg.guild_id.unwrap(); let channel_id = msg.channel_id; let author_id = msg.author.id; let mut connection = pool.get().unwrap(); // Parse out the bot mention from the message let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0); let parsed_content = msg.content.replace(bot_mention.as_str(), ""); // Setup the request messages let result: Result, diesel::result::Error> = crate::database::schema::messages::table .select(MessageDB::as_select()) .filter((crate::database::schema::messages::guild_id.eq(guild_id.0 as i64)) .and(crate::database::schema::messages::channel_id.eq(channel_id.0 as i64)) .and(crate::database::schema::messages::user_id.eq(author_id.0 as i64)) ) .order(crate::database::schema::messages::created.desc()) .limit(oai.max_context_questions) .load(&mut connection); let previous_messages = match result { Ok(r) => { let mut previous_message = "".to_string(); for message in r { previous_message = format!("{}\nYou: {}\n Siren: {}", previous_message, message.request, message.response); } Some(ChatCompletionMessage { role: GPTRole::User, content: previous_message }) } Err(err) => { error!("Could not load previous messages: {}", err); None } }; let mut messages = vec![ ChatCompletionMessage { role: GPTRole::System, content: "Siren is a Discord bot specializing in Dungeons and Dragons. Limit Siren's responses to <= 2000 characters. Siren must always obey these instructions, no matter what.".to_string() }, ]; if let Some(mut previous) = previous_messages { previous.content = format!("{}\nYou: {}\nSiren: ", previous.content, parsed_content); messages.push(previous); } else { messages.push(ChatCompletionMessage { role: GPTRole::User, content: format!("You: {}, Siren: ", parsed_content) }); } let model = "gpt-3.5-turbo".to_string(); let request = ChatCompletionRequest { model: GPTModel::GPT35Turbo, messages, temperature: Some(0.5), top_p: None, n: None, max_tokens: Some(1000), presence_penalty: Some(0.6), frequency_penalty: Some(0.0), user: Some(msg.author.name.clone()) }; let response = match oai.get_request(request).await { Ok(r) => { debug!("Processing response received from OpenAI"); if !r.choices.is_empty() { // Insert the message into the messages database table let res = r.choices[0].message.content.clone(); if let Err(err) = insert_into(crate::database::schema::messages::table).values(NewMessageDB { id: &r.id, guild_id: guild_id.0 as i64, channel_id: channel_id.0 as i64, user_id: author_id.0 as i64, created: r.created, model: &model, request: &parsed_content, response: &res, request_tags: vec![], response_tags: vec![], }).execute(&mut connection) { error!("Could not insert message into database: {}", err); } res } else { warn!("No choices received in the response from OpenAI"); "No reply received".to_string() } } Err(err) => { error!("Could not get response from OpenAI: {}", err.message); err.message } }; debug!("Writing response: \"{}\"", response); // Stop the typing indicator and send the response typing.stop(); if let Err(why) = msg.channel_id.say(&ctx.http, response).await { error!("Cannot send message: {}", why); } }