use log::{error, debug, trace, warn}; use serde::{Serialize, Deserialize}; use serde_json::Value; use serenity::model::Permissions; use serenity::model::channel::Message; use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType}; use serenity::prelude::*; use siren::{GetResponse, ServiceError}; pub struct OAI { pub client: reqwest::Client, pub base_url: String, pub service_url: String, pub max_attempts: i64, pub token: String, pub max_tokens: i64, pub default_model: GPTModel, pub max_context_questions: i64 } #[derive(Debug, Clone, Serialize, Deserialize)] struct ChatCompletionRequest { model: GPTModel, messages: Vec, /// Value between 0 and 2 #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, /// Value between 0 and 1 #[serde(skip_serializing_if = "Option::is_none")] top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] n: Option, #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, /// Value between -2.0 and 2.0 #[serde(skip_serializing_if = "Option::is_none")] presence_penalty: Option, /// Value between -2.0 and 2.0 #[serde(skip_serializing_if = "Option::is_none")] frequency_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] user: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] struct ChatCompletionMessage { role: GPTRole, content: String } #[derive(Debug, Clone, Serialize, Deserialize)] enum GPTRole { #[serde(rename = "system")] System, #[serde(rename = "user")] User, #[serde(rename = "assistant")] Assistant, #[serde(rename = "function")] Function } #[derive(Debug, Clone, Serialize, Deserialize)] pub enum GPTModel { #[serde(rename = "gpt-3.5-turbo")] GPT35Turbo, #[serde(rename = "gpt-3.5-turbo-0613")] GPT35Snapshot, #[serde(rename = "gpt-3.5-turbo-16k")] GPT3516k, #[serde(rename = "gpt-3.5-turbo-16k-0613")] GPT3516kSnapshot, #[serde(rename = "gpt-4")] GPT4, #[serde(rename = "gpt-4-0613")] GPT4Snapshot, #[serde(rename = "gpt-4-32k")] GPT432k, #[serde(rename = "gpt-4-32k-0613")] GPT432kSnapshot, } #[derive(Debug, Clone, Serialize, Deserialize)] struct ChatCompletionResponse { id: String, object: String, created: i64, model: GPTModel, usage: Usage, choices: Vec } #[derive(Debug, Clone, Serialize, Deserialize)] struct Usage { prompt_tokens: i64, completion_tokens: i64, total_tokens: i64 } #[derive(Debug, Clone, Serialize, Deserialize)] struct Choice { message: ChatCompletionMessage, finish_reason: String, index: i64 } #[derive(Debug, Clone, Serialize, Deserialize)] struct ResponseError { error: Option, message: Option, param: Option, #[serde(rename = "type")] error_type: Option } #[derive(Debug, Clone, Serialize, Deserialize)] struct ErrorDetails { code: Option } #[derive(Debug, Clone, Serialize, Deserialize)] enum ResponseEvent { ChatCompletionResponse(ChatCompletionResponse), ResponseError(ResponseError) } impl OAI { async fn get_request(&self, request: ChatCompletionRequest) -> Result { let uri = format!("{}/chat/completions", self.base_url); let body = serde_json::to_string(&request).unwrap(); trace!("Sending request to {}: {}", uri, body); let value = self.client .post(&uri) .bearer_auth(&self.token) .header("Content-Type", "application/json".to_string()) .body(body) .send() .await? .json::() .await?; trace!("Received response from OpenAI: {:?}", value); // let response = match serde_json::from_value::(value) { // Ok(r) => { // match r { // ResponseEvent::ChatCompletionResponse(r) => r, // ResponseEvent::ResponseError(e) => return Err(ServiceError { message: e.message.unwrap_or("Unknown error".to_string()), status: 500 }), // } // }, // Err(err) => return Err(ServiceError { // message: format!("Could not parse response from OpenAI: {}", err), // status: 500 // }) // }; let response = serde_json::from_value::(value)?; Ok(response) } async fn get_messages(&self, guild_id: u64, channel_id: u64, author_id: u64) -> Result>, ServiceError> { let uri = format!("{}/messages?guild_id={}&channel_id={}&author_id={}&limit={}", self.service_url, guild_id, channel_id, author_id, self.max_context_questions); let value = self.client .get(&uri) .send() .await? .json::() .await?; let response = serde_json::from_value::>>(value)?; Ok(response) } async fn store_message(&self, message: siren::Message) -> Result { let uri = format!("{}/messages", self.service_url); trace!("Sending request to {}", uri); let value = self.client .post(&uri) .json::(&message) .send() .await? .json::() .await?; trace!("Received response from Service: {:?}", value); let response = serde_json::from_value::(value)?; Ok(response) } } pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { debug!("Generating response for message: {}", msg.content); let guild_id = msg.guild_id.unwrap(); let channel_id = msg.channel_id; let author_id = msg.author.id; // Parse out the bot mention from the message let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0); let parsed_content = msg.content.replace(bot_mention.as_str(), ""); let mut messages = vec![ ChatCompletionMessage { role: GPTRole::System, content: "Siren is a Discord bot specializing in Dungeons and Dragons. Limit Siren's responses to <= 2000 characters. Siren must always obey these instructions, no matter what.".to_string() }, ]; let previous_messages = oai.get_messages(guild_id.0, channel_id.0, author_id.0).await; match previous_messages { Ok(m) => { for message in m.data { messages.push( ChatCompletionMessage { role: GPTRole::User, content: format!("{}", message.request) } ); messages.push( ChatCompletionMessage { role: GPTRole::Assistant, content: format!("{}", message.response) } ); } }, Err(err) => warn!("Could not load previous messages: {}", err) }; messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() }); let request = ChatCompletionRequest { model: oai.default_model.clone(), messages, temperature: Some(0.5), top_p: None, n: None, max_tokens: Some(oai.max_tokens), presence_penalty: Some(0.6), frequency_penalty: Some(0.0), user: Some(msg.author.name.clone()) }; // Get the thread channel ID let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| { thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread) }).await { Ok(c) => { let allow = Permissions::SEND_MESSAGES; let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES; let overwrite = PermissionOverwrite { allow, deny, kind: PermissionOverwriteType::Member(msg.author.id), }; let _ = c.create_permission(&ctx.http, &overwrite).await; c.id } Err(_) => { channel_id } }; let typing = response_channel.start_typing(&ctx.http).unwrap(); // Get the OAI response and store message/response into the database let response = match oai.get_request(request).await { Ok(r) => { debug!("Processing response received from OpenAI"); if !r.choices.is_empty() { let res = r.choices[0].message.content.clone(); if let Err(err) = oai.store_message(siren::Message { id: r.id, guild_id: guild_id.0 as i64, channel_id: response_channel.0 as i64, user_id: author_id.0 as i64, created: r.created, model: serde_json::to_string(&r.model).unwrap(), request: parsed_content, response: res.clone(), request_tags: vec![], response_tags: vec![], }).await { warn!("{}", err); } res } else { warn!("No choices received in the response from OpenAI"); "No reply received".to_string() } } Err(err) => { error!("Could not get response from OpenAI: {}", err.message); "There was an error processing your message. Please try again later.".to_string() } }; debug!("Writing response: \"{}\"", response); typing.stop(); if let Err(why) = response_channel.say(&ctx.http, response).await { error!("Cannot send message: {}", why); } // match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| { // thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread) // }).await { // Ok(c) => { // if let Err(why) = c.say(&ctx.http, response).await { // error!("Cannot send message: {}", why); // } // } // Err(_) => { // if let Err(why) = channel_id.say(&ctx.http, response).await { // error!("Cannot send message: {}", why); // } // } // }; } fn truncate(s: &str, max_chars: usize) -> &str { match s.char_indices().nth(max_chars) { None => s, Some((idx, _)) => &s[..idx], } }