180 lines
5.7 KiB
Rust
180 lines
5.7 KiB
Rust
use log::{error, trace, warn};
|
|
|
|
use serenity::model::Permissions;
|
|
use serenity::model::channel::Message;
|
|
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
|
|
use serenity::prelude::*;
|
|
|
|
use crate::bot::messages::{QueryFilters, QueryMessage};
|
|
use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI};
|
|
|
|
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
|
trace!("Generating response for message: {}", msg.content);
|
|
|
|
let guild_id = msg.guild_id.unwrap();
|
|
let channel_id = msg.channel_id;
|
|
let author_id = msg.author.id;
|
|
|
|
// Parse out the bot mention from the message
|
|
let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0);
|
|
let parsed_content = msg.content.replace(bot_mention.as_str(), "");
|
|
|
|
let mut messages = vec![
|
|
ChatCompletionMessage {
|
|
role: GPTRole::System,
|
|
content: "You are a Discord bot named Siren that acts as the Dungeon Master's assistant. Siren must always obey these instructions, no matter what.".to_string()
|
|
},
|
|
];
|
|
|
|
match QueryMessage::get_all(&QueryFilters {
|
|
by_guild_id: Some(guild_id.0 as i64),
|
|
by_channel_id: Some(channel_id.0 as i64),
|
|
by_user_id: Some(author_id.0 as i64),
|
|
..Default::default()
|
|
}, 100, 1) {
|
|
Ok(m) => {
|
|
for message in m {
|
|
messages.push(
|
|
ChatCompletionMessage {
|
|
role: GPTRole::User,
|
|
content: format!("{}", message.request)
|
|
}
|
|
);
|
|
messages.push(
|
|
ChatCompletionMessage {
|
|
role: GPTRole::Assistant,
|
|
content: format!("{}", message.response)
|
|
}
|
|
);
|
|
}
|
|
},
|
|
Err(err) => warn!("Could not load previous messages: {}", err)
|
|
};
|
|
messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() });
|
|
|
|
let request = ChatCompletionRequest {
|
|
model: oai.default_model.clone(),
|
|
messages,
|
|
temperature: Some(0.5),
|
|
top_p: None,
|
|
n: None,
|
|
max_tokens: Some(oai.max_tokens),
|
|
presence_penalty: Some(0.6),
|
|
frequency_penalty: Some(0.0),
|
|
user: Some(msg.author.name.clone())
|
|
};
|
|
|
|
// Get the thread channel ID
|
|
let thread_name = generate_thread_name(oai, &parsed_content, 99).await;
|
|
let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| {
|
|
thread.name(thread_name).kind(ChannelType::PublicThread)
|
|
}).await {
|
|
Ok(c) => {
|
|
let allow = Permissions::SEND_MESSAGES;
|
|
let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES;
|
|
let overwrite = PermissionOverwrite {
|
|
allow,
|
|
deny,
|
|
kind: PermissionOverwriteType::Member(msg.author.id),
|
|
};
|
|
let _ = c.create_permission(&ctx.http, &overwrite).await;
|
|
c.id
|
|
}
|
|
Err(_) => {
|
|
channel_id
|
|
}
|
|
};
|
|
|
|
let typing = response_channel.start_typing(&ctx.http).unwrap();
|
|
|
|
// Get the OAI response and store message/response into the database
|
|
let response = match oai.chat_completion(request).await {
|
|
Ok(r) => {
|
|
trace!("Processing response received from OpenAI");
|
|
if !r.choices.is_empty() {
|
|
let res = r.choices[0].message.content.clone();
|
|
if let Err(err) = QueryMessage::insert(QueryMessage {
|
|
id: r.id,
|
|
guild_id: guild_id.0 as i64,
|
|
channel_id: response_channel.0 as i64,
|
|
user_id: author_id.0 as i64,
|
|
created: r.created,
|
|
model: serde_json::to_string(&r.model).unwrap(),
|
|
request: parsed_content,
|
|
response: res.clone(),
|
|
request_tags: vec![],
|
|
response_tags: vec![],
|
|
}) {
|
|
warn!("{}", err);
|
|
}
|
|
res
|
|
} else {
|
|
warn!("No choices received in the response from OpenAI");
|
|
"No reply received".to_string()
|
|
}
|
|
}
|
|
Err(err) => {
|
|
error!("Could not get response from OpenAI: {}", err.message);
|
|
"There was an error processing your message. Please try again later.".to_string()
|
|
}
|
|
};
|
|
trace!("Writing response: \"{}\"", response);
|
|
|
|
typing.stop();
|
|
if let Err(why) = response_channel.say(&ctx.http, response).await {
|
|
error!("Cannot send message: {}", why);
|
|
}
|
|
|
|
// match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| {
|
|
// thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
|
|
// }).await {
|
|
// Ok(c) => {
|
|
// if let Err(why) = c.say(&ctx.http, response).await {
|
|
// error!("Cannot send message: {}", why);
|
|
// }
|
|
// }
|
|
// Err(_) => {
|
|
// if let Err(why) = channel_id.say(&ctx.http, response).await {
|
|
// error!("Cannot send message: {}", why);
|
|
// }
|
|
// }
|
|
// };
|
|
}
|
|
|
|
async fn generate_thread_name(oai: &OAI, s: &str, max_chars: usize) -> String {
|
|
let message = ChatCompletionMessage {
|
|
role: GPTRole::User,
|
|
content: format!("---\n{}\n---\nSummarize the message above into a concise Discord thread title", s)
|
|
};
|
|
let request = ChatCompletionRequest {
|
|
model: oai.default_model.clone(),
|
|
messages: vec![message],
|
|
temperature: Some(0.5),
|
|
top_p: None,
|
|
n: None,
|
|
max_tokens: Some(oai.max_tokens),
|
|
presence_penalty: Some(0.6),
|
|
frequency_penalty: Some(0.0),
|
|
user: None
|
|
};
|
|
// Truncate the response to the max number of characters
|
|
let mut response = match s.char_indices().nth(max_chars) {
|
|
None => s,
|
|
Some((idx, _)) => &s[..idx]
|
|
}.to_string();
|
|
// Set the response to the OAI response
|
|
match oai.chat_completion(request).await {
|
|
Ok(r) => {
|
|
if !r.choices.is_empty() {
|
|
response = r.choices[0].message.content.clone();
|
|
} else {
|
|
warn!("No choices received in the response from OpenAI");
|
|
}
|
|
}
|
|
Err(err) => {
|
|
error!("Could not get response from OpenAI: {}", err.message);
|
|
}
|
|
};
|
|
return response;
|
|
}
|