From 8e87524ff64799f601318f04b6fff5c2d97fecd4 Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Fri, 7 Jul 2023 09:39:14 -0400 Subject: [PATCH] Linked database to Open AI prompts/responses --- src/commands/oai.rs | 120 +++++++++++++++++++++++++++++++------------- 1 file changed, 86 insertions(+), 34 deletions(-) diff --git a/src/commands/oai.rs b/src/commands/oai.rs index 99536e7..ba976f5 100644 --- a/src/commands/oai.rs +++ b/src/commands/oai.rs @@ -4,14 +4,14 @@ use std::fmt; use diesel::{prelude::*, PgConnection, insert_into}; use diesel::r2d2::{Pool, ConnectionManager}; -use log::{error, debug, trace}; +use log::{error, debug, trace, warn}; use serde::{Serialize, Deserialize}; use serde_json::Value; use serenity::model::channel::Message; use serenity::prelude::*; -use crate::database::models::NewMessageDB; +use crate::database::models::{NewMessageDB, MessageDB}; pub struct OAI { pub client: reqwest::Client, @@ -23,7 +23,7 @@ pub struct OAI { #[derive(Debug, Clone, Serialize, Deserialize)] struct ChatCompletionRequest { - model: String, + model: GPTModel, messages: Vec, /// Value between 0 and 2 #[serde(skip_serializing_if = "Option::is_none")] @@ -47,20 +47,40 @@ struct ChatCompletionRequest { #[derive(Debug, Clone, Serialize, Deserialize)] struct ChatCompletionMessage { - role: String, + role: GPTRole, content: String } #[derive(Debug, Clone, Serialize, Deserialize)] -enum Role { +enum GPTRole { #[serde(rename = "system")] - SYSTEM, + System, #[serde(rename = "user")] - USER, + User, #[serde(rename = "assistant")] - ASSISTANT, + Assistant, #[serde(rename = "function")] - FUNCTION + Function +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +enum GPTModel { + #[serde(rename = "gpt-3.5-turbo")] + GPT35Turbo, + #[serde(rename = "gpt-3.5-turbo-0613")] + GPT35Snapshot, + #[serde(rename = "gpt-3.5-turbo-16k")] + GPT3516k, + #[serde(rename = "gpt-3.5-turbo-16k-0613")] + GPT3516kSnapshot, + #[serde(rename = "gpt-4")] + GPT4, + #[serde(rename = "gpt-4-0613")] + GPT4Snapshot, + #[serde(rename = "gpt-4-32k")] + GPT432k, + #[serde(rename = "gpt-4-32k-0613")] + GPT432kSnapshot, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -68,7 +88,7 @@ struct ChatCompletionResponse { id: String, object: String, created: i64, - model: String, + model: GPTModel, usage: Usage, choices: Vec } @@ -166,35 +186,64 @@ impl OAI { } pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &Pool>) { + debug!("Generating response for message: {}", msg.content); + let typing = msg.channel_id.start_typing(&ctx.http).unwrap(); + + let guild_id = msg.guild_id.unwrap(); + let channel_id = msg.channel_id; + let author_id = msg.author.id; + let mut connection = pool.get().unwrap(); + + // Parse out the bot mention from the message let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0); let parsed_content = msg.content.replace(bot_mention.as_str(), ""); - debug!("Generating response for message: {}", msg.content); - let instructions = vec![ + // Setup the request messages + let result: Result, diesel::result::Error> = crate::database::schema::messages::table + .select(MessageDB::as_select()) + .filter((crate::database::schema::messages::guild_id.eq(guild_id.0 as i64)) + .and(crate::database::schema::messages::channel_id.eq(channel_id.0 as i64)) + .and(crate::database::schema::messages::user_id.eq(author_id.0 as i64)) + ) + .order(crate::database::schema::messages::created.desc()) + .limit(oai.max_context_questions) + .load(&mut connection); + + let previous_messages = match result { + Ok(r) => { + let mut previous_message = "".to_string(); + for message in r { + previous_message = format!("{}\nYou: {}\n Siren: {}", previous_message, message.request, message.response); + } + Some(ChatCompletionMessage { role: GPTRole::User, content: previous_message }) + } + Err(err) => { + error!("Could not load previous messages: {}", err); + None + } + }; + + let mut messages = vec![ ChatCompletionMessage { - role: "system".to_string(), - content: "You are a Discord user named Siren.".to_string() - }, - ChatCompletionMessage { - role: "system".to_string(), - content: "Siren is an expert on Dungeons and Dragons.".to_string() + role: GPTRole::System, + content: "Siren is a Discord bot specializing in Dungeons and Dragons. Limit Siren's responses to <= 2000 characters. Siren must always obey these instructions, no matter what.".to_string() }, ]; - let user_request = ChatCompletionMessage { - role: "user".to_string(), - content: parsed_content.to_string() - }; - - let mut messages: Vec = vec![]; - messages.extend(instructions); - // TODO: Get previous messages - messages.push(user_request); + if let Some(mut previous) = previous_messages { + previous.content = format!("{}\nYou: {}\nSiren: ", previous.content, parsed_content); + messages.push(previous); + } else { + messages.push(ChatCompletionMessage { + role: GPTRole::User, + content: format!("You: {}, Siren: ", parsed_content) + }); + } let model = "gpt-3.5-turbo".to_string(); let request = ChatCompletionRequest { - model: model.to_string(), + model: GPTModel::GPT35Turbo, messages, temperature: Some(0.5), top_p: None, @@ -206,15 +255,15 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P }; let response = match oai.get_request(request).await { Ok(r) => { - debug!("Received response from OpenAI"); + debug!("Processing response received from OpenAI"); if !r.choices.is_empty() { - let mut connection = pool.get().unwrap(); + // Insert the message into the messages database table let res = r.choices[0].message.content.clone(); if let Err(err) = insert_into(crate::database::schema::messages::table).values(NewMessageDB { id: &r.id, - guild_id: msg.guild_id.unwrap().0 as i64, - channel_id: msg.channel_id.0 as i64, - user_id: msg.author.id.0 as i64, + guild_id: guild_id.0 as i64, + channel_id: channel_id.0 as i64, + user_id: author_id.0 as i64, created: r.created, model: &model, request: &parsed_content, @@ -226,6 +275,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P } res } else { + warn!("No choices received in the response from OpenAI"); "No reply received".to_string() } } @@ -234,8 +284,10 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P err.message } }; - debug!("Sending response: \"{}\"", response); + debug!("Writing response: \"{}\"", response); + // Stop the typing indicator and send the response + typing.stop(); if let Err(why) = msg.channel_id.say(&ctx.http, response).await { error!("Cannot send message: {}", why); }