Updated chat/oai layout
This commit is contained in:
@@ -30,6 +30,7 @@ redis = { version = "0.23.3", features = ["tokio-comp", "connection-manager", "r
|
|||||||
base64 = "0.21.4"
|
base64 = "0.21.4"
|
||||||
rust-s3 = "0.33.0"
|
rust-s3 = "0.33.0"
|
||||||
actix-multipart = "0.6.1"
|
actix-multipart = "0.6.1"
|
||||||
|
openssl = "0.10.60" # Resolve `openssl` `X509StoreRef::objects` is unsound #10
|
||||||
|
|
||||||
[dependencies.tokio]
|
[dependencies.tokio]
|
||||||
version = "1.32.0"
|
version = "1.32.0"
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ services:
|
|||||||
- ${SERVICE_PORT:-5000}:5000
|
- ${SERVICE_PORT:-5000}:5000
|
||||||
depends_on:
|
depends_on:
|
||||||
- db
|
- db
|
||||||
|
- redis
|
||||||
|
- minio
|
||||||
networks:
|
networks:
|
||||||
- frontend
|
- frontend
|
||||||
- backend
|
- backend
|
||||||
@@ -53,6 +55,8 @@ services:
|
|||||||
redis:
|
redis:
|
||||||
image: redis:latest
|
image: redis:latest
|
||||||
container_name: siren-redis
|
container_name: siren-redis
|
||||||
|
volumes:
|
||||||
|
- redis:/data
|
||||||
ports:
|
ports:
|
||||||
- ${REDIS_PORT:-6379}:6379
|
- ${REDIS_PORT:-6379}:6379
|
||||||
networks:
|
networks:
|
||||||
@@ -77,6 +81,7 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
db:
|
db:
|
||||||
db_logs:
|
db_logs:
|
||||||
|
redis:
|
||||||
minio:
|
minio:
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
|
|||||||
@@ -103,6 +103,13 @@ pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool, volume: Opt
|
|||||||
Ok(metadata)
|
Ok(metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_playlist_urls(url: &str) -> Result<Vec<String>, String> {
|
||||||
|
let mut urls: Vec<String> = Vec::new();
|
||||||
|
// TODO fix this later
|
||||||
|
urls.push(url.to_string());
|
||||||
|
Ok(urls)
|
||||||
|
}
|
||||||
|
|
||||||
fn is_valid_url(url: &str) -> bool {
|
fn is_valid_url(url: &str) -> bool {
|
||||||
match url.parse::<reqwest::Url>() {
|
match url.parse::<reqwest::Url>() {
|
||||||
Ok(_) => return true,
|
Ok(_) => return true,
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use serenity::model::application::interaction::application_command::ApplicationC
|
|||||||
use siren::ServiceError;
|
use siren::ServiceError;
|
||||||
use songbird::{EventHandler, Songbird};
|
use songbird::{EventHandler, Songbird};
|
||||||
|
|
||||||
use crate::bot::{guilds::QueryGuild, commands::audio::{leave, add_song, get_songbird}};
|
use crate::bot::{guilds::QueryGuild, commands::audio::{leave, get_playlist_urls, add_song, get_songbird}};
|
||||||
|
|
||||||
use super::{create_response, edit_response, join_by_user};
|
use super::{create_response, edit_response, join_by_user};
|
||||||
|
|
||||||
@@ -87,31 +87,42 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn play_track(manager: Arc<Songbird>, guild_id: GuildId, track_url: String) -> Result<(), ServiceError> {
|
pub async fn play_track(manager: Arc<Songbird>, guild_id: GuildId, track_url: String) -> Result<i32, ServiceError> {
|
||||||
|
let mut track_count = 0;
|
||||||
if let Some(handler_lock) = manager.get(guild_id) {
|
if let Some(handler_lock) = manager.get(guild_id) {
|
||||||
let is_queue_empty = {
|
let is_queue_empty = {
|
||||||
let call_handler = handler_lock.lock().await;
|
let call_handler = handler_lock.lock().await;
|
||||||
call_handler.queue().is_empty()
|
call_handler.queue().is_empty()
|
||||||
};
|
};
|
||||||
let guild = QueryGuild::get(guild_id.0 as i64)?;
|
let guild = QueryGuild::get(guild_id.0 as i64)?;
|
||||||
match add_song(handler_lock.clone(), &track_url, is_queue_empty, Some(guild.volume as f32 / 100.0)).await {
|
let track_urls = match get_playlist_urls(&track_url) {
|
||||||
Ok(added_song) => {
|
Ok(urls) => urls,
|
||||||
let track_title = added_song.title.unwrap();
|
|
||||||
debug!("Added track: {}", track_title);
|
|
||||||
let mut handler = handler_lock.lock().await;
|
|
||||||
handler.remove_all_global_events();
|
|
||||||
handler.add_global_event(songbird::Event::Track(songbird::TrackEvent::End), TrackEndNotifier { guild_id, call: manager })
|
|
||||||
},
|
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!("Failed to add song: {}", err);
|
warn!("Failed to get playlist urls: {}", err);
|
||||||
if let Err(why) = leave(manager, &Some(guild_id)).await {
|
|
||||||
error!("Failed to leave voice channel: {}", why);
|
|
||||||
}
|
|
||||||
return Err(ServiceError { status: 422, message: err.to_string() })
|
return Err(ServiceError { status: 422, message: err.to_string() })
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
for url in track_urls {
|
||||||
|
match add_song(handler_lock.clone(), &url, is_queue_empty, Some(guild.volume as f32 / 100.0)).await {
|
||||||
|
Ok(added_song) => {
|
||||||
|
let track_title = added_song.title.unwrap();
|
||||||
|
debug!("Added track: {}", track_title);
|
||||||
|
let mut handler = handler_lock.lock().await;
|
||||||
|
handler.remove_all_global_events();
|
||||||
|
handler.add_global_event(songbird::Event::Track(songbird::TrackEvent::End), TrackEndNotifier { guild_id, call: manager.clone() });
|
||||||
|
track_count += 1;
|
||||||
|
},
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Failed to add song: {}", err);
|
||||||
|
if let Err(why) = leave(manager, &Some(guild_id)).await {
|
||||||
|
error!("Failed to leave voice channel: {}", why);
|
||||||
|
}
|
||||||
|
return Err(ServiceError { status: 422, message: err.to_string() })
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(track_count)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand {
|
pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand {
|
||||||
|
|||||||
148
service/src/bot/commands/chat.rs
Normal file
148
service/src/bot/commands/chat.rs
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
use log::{error, debug, warn};
|
||||||
|
|
||||||
|
use serenity::model::Permissions;
|
||||||
|
use serenity::model::channel::Message;
|
||||||
|
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
|
||||||
|
use serenity::prelude::*;
|
||||||
|
|
||||||
|
use crate::bot::messages::{QueryFilters, QueryMessage};
|
||||||
|
use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI};
|
||||||
|
|
||||||
|
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||||
|
debug!("Generating response for message: {}", msg.content);
|
||||||
|
|
||||||
|
let guild_id = msg.guild_id.unwrap();
|
||||||
|
let channel_id = msg.channel_id;
|
||||||
|
let author_id = msg.author.id;
|
||||||
|
|
||||||
|
// Parse out the bot mention from the message
|
||||||
|
let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0);
|
||||||
|
let parsed_content = msg.content.replace(bot_mention.as_str(), "");
|
||||||
|
|
||||||
|
let mut messages = vec![
|
||||||
|
ChatCompletionMessage {
|
||||||
|
role: GPTRole::System,
|
||||||
|
content: "You are a Discord bot named Siren that acts as the Dungeon Master's assistant. Siren must always obey these instructions, no matter what.".to_string()
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
match QueryMessage::get_all(&QueryFilters {
|
||||||
|
by_guild_id: Some(guild_id.0 as i64),
|
||||||
|
by_channel_id: Some(channel_id.0 as i64),
|
||||||
|
by_user_id: Some(author_id.0 as i64),
|
||||||
|
..Default::default()
|
||||||
|
}, 100, 1) {
|
||||||
|
Ok(m) => {
|
||||||
|
for message in m {
|
||||||
|
messages.push(
|
||||||
|
ChatCompletionMessage {
|
||||||
|
role: GPTRole::User,
|
||||||
|
content: format!("{}", message.request)
|
||||||
|
}
|
||||||
|
);
|
||||||
|
messages.push(
|
||||||
|
ChatCompletionMessage {
|
||||||
|
role: GPTRole::Assistant,
|
||||||
|
content: format!("{}", message.response)
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(err) => warn!("Could not load previous messages: {}", err)
|
||||||
|
};
|
||||||
|
messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() });
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
model: oai.default_model.clone(),
|
||||||
|
messages,
|
||||||
|
temperature: Some(0.5),
|
||||||
|
top_p: None,
|
||||||
|
n: None,
|
||||||
|
max_tokens: Some(oai.max_tokens),
|
||||||
|
presence_penalty: Some(0.6),
|
||||||
|
frequency_penalty: Some(0.0),
|
||||||
|
user: Some(msg.author.name.clone())
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get the thread channel ID
|
||||||
|
let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| {
|
||||||
|
thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
|
||||||
|
}).await {
|
||||||
|
Ok(c) => {
|
||||||
|
let allow = Permissions::SEND_MESSAGES;
|
||||||
|
let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES;
|
||||||
|
let overwrite = PermissionOverwrite {
|
||||||
|
allow,
|
||||||
|
deny,
|
||||||
|
kind: PermissionOverwriteType::Member(msg.author.id),
|
||||||
|
};
|
||||||
|
let _ = c.create_permission(&ctx.http, &overwrite).await;
|
||||||
|
c.id
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
channel_id
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let typing = response_channel.start_typing(&ctx.http).unwrap();
|
||||||
|
|
||||||
|
// Get the OAI response and store message/response into the database
|
||||||
|
let response = match oai.chat_completion(request).await {
|
||||||
|
Ok(r) => {
|
||||||
|
debug!("Processing response received from OpenAI");
|
||||||
|
if !r.choices.is_empty() {
|
||||||
|
let res = r.choices[0].message.content.clone();
|
||||||
|
if let Err(err) = QueryMessage::insert(QueryMessage {
|
||||||
|
id: r.id,
|
||||||
|
guild_id: guild_id.0 as i64,
|
||||||
|
channel_id: response_channel.0 as i64,
|
||||||
|
user_id: author_id.0 as i64,
|
||||||
|
created: r.created,
|
||||||
|
model: serde_json::to_string(&r.model).unwrap(),
|
||||||
|
request: parsed_content,
|
||||||
|
response: res.clone(),
|
||||||
|
request_tags: vec![],
|
||||||
|
response_tags: vec![],
|
||||||
|
}) {
|
||||||
|
warn!("{}", err);
|
||||||
|
}
|
||||||
|
res
|
||||||
|
} else {
|
||||||
|
warn!("No choices received in the response from OpenAI");
|
||||||
|
"No reply received".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
error!("Could not get response from OpenAI: {}", err.message);
|
||||||
|
"There was an error processing your message. Please try again later.".to_string()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
debug!("Writing response: \"{}\"", response);
|
||||||
|
|
||||||
|
typing.stop();
|
||||||
|
if let Err(why) = response_channel.say(&ctx.http, response).await {
|
||||||
|
error!("Cannot send message: {}", why);
|
||||||
|
}
|
||||||
|
|
||||||
|
// match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| {
|
||||||
|
// thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
|
||||||
|
// }).await {
|
||||||
|
// Ok(c) => {
|
||||||
|
// if let Err(why) = c.say(&ctx.http, response).await {
|
||||||
|
// error!("Cannot send message: {}", why);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// Err(_) => {
|
||||||
|
// if let Err(why) = channel_id.say(&ctx.http, response).await {
|
||||||
|
// error!("Cannot send message: {}", why);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate(s: &str, max_chars: usize) -> &str {
|
||||||
|
match s.char_indices().nth(max_chars) {
|
||||||
|
None => s,
|
||||||
|
Some((idx, _)) => &s[..idx],
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
pub mod audio;
|
pub mod audio;
|
||||||
pub mod help;
|
pub mod help;
|
||||||
pub mod message;
|
pub mod message;
|
||||||
pub mod oai;
|
pub mod chat;
|
||||||
pub mod ping;
|
pub mod ping;
|
||||||
pub mod schedule;
|
pub mod schedule;
|
||||||
|
|||||||
@@ -1,326 +0,0 @@
|
|||||||
use log::{error, debug, trace, warn};
|
|
||||||
|
|
||||||
use serde::{Serialize, Deserialize};
|
|
||||||
use serde_json::Value;
|
|
||||||
use serenity::model::Permissions;
|
|
||||||
use serenity::model::channel::Message;
|
|
||||||
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
|
|
||||||
use serenity::prelude::*;
|
|
||||||
use siren::{Response, ServiceError};
|
|
||||||
|
|
||||||
pub struct OAI {
|
|
||||||
pub client: reqwest::Client,
|
|
||||||
pub base_url: String,
|
|
||||||
pub service_url: String,
|
|
||||||
pub max_attempts: i64,
|
|
||||||
pub token: String,
|
|
||||||
pub max_tokens: i64,
|
|
||||||
pub default_model: GPTModel,
|
|
||||||
pub max_context_questions: i64
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct ChatCompletionRequest {
|
|
||||||
model: GPTModel,
|
|
||||||
messages: Vec<ChatCompletionMessage>,
|
|
||||||
/// Value between 0 and 2
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
/// Value between 0 and 1
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
n: Option<f64>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
max_tokens: Option<i64>,
|
|
||||||
/// Value between -2.0 and 2.0
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
presence_penalty: Option<f64>,
|
|
||||||
/// Value between -2.0 and 2.0
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
frequency_penalty: Option<f64>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
user: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct ChatCompletionMessage {
|
|
||||||
role: GPTRole,
|
|
||||||
content: String
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
enum GPTRole {
|
|
||||||
#[serde(rename = "system")]
|
|
||||||
System,
|
|
||||||
#[serde(rename = "user")]
|
|
||||||
User,
|
|
||||||
#[serde(rename = "assistant")]
|
|
||||||
Assistant,
|
|
||||||
#[serde(rename = "function")]
|
|
||||||
Function
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub enum GPTModel {
|
|
||||||
#[serde(rename = "gpt-3.5-turbo")]
|
|
||||||
GPT35Turbo,
|
|
||||||
#[serde(rename = "gpt-3.5-turbo-0613")]
|
|
||||||
GPT35Snapshot,
|
|
||||||
#[serde(rename = "gpt-3.5-turbo-16k")]
|
|
||||||
GPT3516k,
|
|
||||||
#[serde(rename = "gpt-3.5-turbo-16k-0613")]
|
|
||||||
GPT3516kSnapshot,
|
|
||||||
#[serde(rename = "gpt-4")]
|
|
||||||
GPT4,
|
|
||||||
#[serde(rename = "gpt-4-0613")]
|
|
||||||
GPT4Snapshot,
|
|
||||||
#[serde(rename = "gpt-4-32k")]
|
|
||||||
GPT432k,
|
|
||||||
#[serde(rename = "gpt-4-32k-0613")]
|
|
||||||
GPT432kSnapshot,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct ChatCompletionResponse {
|
|
||||||
id: String,
|
|
||||||
object: String,
|
|
||||||
created: i64,
|
|
||||||
model: GPTModel,
|
|
||||||
usage: Usage,
|
|
||||||
choices: Vec<Choice>
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct Usage {
|
|
||||||
prompt_tokens: i64,
|
|
||||||
completion_tokens: i64,
|
|
||||||
total_tokens: i64
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct Choice {
|
|
||||||
message: ChatCompletionMessage,
|
|
||||||
finish_reason: String,
|
|
||||||
index: i64
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct ResponseError {
|
|
||||||
error: Option<ErrorDetails>,
|
|
||||||
message: Option<String>,
|
|
||||||
param: Option<String>,
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
error_type: Option<String>
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct ErrorDetails {
|
|
||||||
code: Option<String>
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
enum ResponseEvent {
|
|
||||||
ChatCompletionResponse(ChatCompletionResponse),
|
|
||||||
ResponseError(ResponseError)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OAI {
|
|
||||||
async fn get_request(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, ServiceError> {
|
|
||||||
let uri = format!("{}/chat/completions", self.base_url);
|
|
||||||
let body = serde_json::to_string(&request).unwrap();
|
|
||||||
trace!("Sending request to {}: {}", uri, body);
|
|
||||||
|
|
||||||
let value = self.client
|
|
||||||
.post(&uri)
|
|
||||||
.bearer_auth(&self.token)
|
|
||||||
.header("Content-Type", "application/json".to_string())
|
|
||||||
.body(body)
|
|
||||||
.send()
|
|
||||||
.await?
|
|
||||||
.json::<Value>()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
trace!("Received response from OpenAI: {:?}", value);
|
|
||||||
|
|
||||||
// let response = match serde_json::from_value::<ResponseEvent>(value) {
|
|
||||||
// Ok(r) => {
|
|
||||||
// match r {
|
|
||||||
// ResponseEvent::ChatCompletionResponse(r) => r,
|
|
||||||
// ResponseEvent::ResponseError(e) => return Err(ServiceError { message: e.message.unwrap_or("Unknown error".to_string()), status: 500 }),
|
|
||||||
// }
|
|
||||||
// },
|
|
||||||
// Err(err) => return Err(ServiceError {
|
|
||||||
// message: format!("Could not parse response from OpenAI: {}", err),
|
|
||||||
// status: 500
|
|
||||||
// })
|
|
||||||
// };
|
|
||||||
let response = serde_json::from_value::<ChatCompletionResponse>(value)?;
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_messages(&self, guild_id: u64, channel_id: u64, author_id: u64) -> Result<Response<Vec<siren::Message>>, ServiceError> {
|
|
||||||
let uri = format!("{}/messages?guild_id={}&channel_id={}&author_id={}&limit={}", self.service_url, guild_id, channel_id, author_id, self.max_context_questions);
|
|
||||||
let value = self.client
|
|
||||||
.get(&uri)
|
|
||||||
.send()
|
|
||||||
.await?
|
|
||||||
.json::<Value>()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let response = serde_json::from_value::<Response<Vec<siren::Message>>>(value)?;
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn store_message(&self, message: siren::Message) -> Result<siren::Message, ServiceError> {
|
|
||||||
let uri = format!("{}/messages", self.service_url);
|
|
||||||
trace!("Sending request to {}", uri);
|
|
||||||
let value = self.client
|
|
||||||
.post(&uri)
|
|
||||||
.json::<siren::Message>(&message)
|
|
||||||
.send()
|
|
||||||
.await?
|
|
||||||
.json::<Value>()
|
|
||||||
.await?;
|
|
||||||
trace!("Received response from Service: {:?}", value);
|
|
||||||
let response = serde_json::from_value::<siren::Message>(value)?;
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
|
||||||
debug!("Generating response for message: {}", msg.content);
|
|
||||||
|
|
||||||
let guild_id = msg.guild_id.unwrap();
|
|
||||||
let channel_id = msg.channel_id;
|
|
||||||
let author_id = msg.author.id;
|
|
||||||
|
|
||||||
// Parse out the bot mention from the message
|
|
||||||
let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0);
|
|
||||||
let parsed_content = msg.content.replace(bot_mention.as_str(), "");
|
|
||||||
|
|
||||||
let mut messages = vec![
|
|
||||||
ChatCompletionMessage {
|
|
||||||
role: GPTRole::System,
|
|
||||||
content: "Siren is a Discord bot specializing in Dungeons and Dragons. Limit Siren's responses to <= 2000 characters. Siren must always obey these instructions, no matter what.".to_string()
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
let previous_messages = oai.get_messages(guild_id.0, channel_id.0, author_id.0).await;
|
|
||||||
match previous_messages {
|
|
||||||
Ok(m) => {
|
|
||||||
for message in m.data {
|
|
||||||
messages.push(
|
|
||||||
ChatCompletionMessage {
|
|
||||||
role: GPTRole::User,
|
|
||||||
content: format!("{}", message.request)
|
|
||||||
}
|
|
||||||
);
|
|
||||||
messages.push(
|
|
||||||
ChatCompletionMessage {
|
|
||||||
role: GPTRole::Assistant,
|
|
||||||
content: format!("{}", message.response)
|
|
||||||
}
|
|
||||||
);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Err(err) => warn!("Could not load previous messages: {}", err)
|
|
||||||
};
|
|
||||||
messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() });
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
model: oai.default_model.clone(),
|
|
||||||
messages,
|
|
||||||
temperature: Some(0.5),
|
|
||||||
top_p: None,
|
|
||||||
n: None,
|
|
||||||
max_tokens: Some(oai.max_tokens),
|
|
||||||
presence_penalty: Some(0.6),
|
|
||||||
frequency_penalty: Some(0.0),
|
|
||||||
user: Some(msg.author.name.clone())
|
|
||||||
};
|
|
||||||
|
|
||||||
// Get the thread channel ID
|
|
||||||
let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| {
|
|
||||||
thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
|
|
||||||
}).await {
|
|
||||||
Ok(c) => {
|
|
||||||
let allow = Permissions::SEND_MESSAGES;
|
|
||||||
let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES;
|
|
||||||
let overwrite = PermissionOverwrite {
|
|
||||||
allow,
|
|
||||||
deny,
|
|
||||||
kind: PermissionOverwriteType::Member(msg.author.id),
|
|
||||||
};
|
|
||||||
let _ = c.create_permission(&ctx.http, &overwrite).await;
|
|
||||||
c.id
|
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
channel_id
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let typing = response_channel.start_typing(&ctx.http).unwrap();
|
|
||||||
|
|
||||||
// Get the OAI response and store message/response into the database
|
|
||||||
let response = match oai.get_request(request).await {
|
|
||||||
Ok(r) => {
|
|
||||||
debug!("Processing response received from OpenAI");
|
|
||||||
if !r.choices.is_empty() {
|
|
||||||
let res = r.choices[0].message.content.clone();
|
|
||||||
if let Err(err) = oai.store_message(siren::Message {
|
|
||||||
id: r.id,
|
|
||||||
guild_id: guild_id.0 as i64,
|
|
||||||
channel_id: response_channel.0 as i64,
|
|
||||||
user_id: author_id.0 as i64,
|
|
||||||
created: r.created,
|
|
||||||
model: serde_json::to_string(&r.model).unwrap(),
|
|
||||||
request: parsed_content,
|
|
||||||
response: res.clone(),
|
|
||||||
request_tags: vec![],
|
|
||||||
response_tags: vec![],
|
|
||||||
}).await {
|
|
||||||
warn!("{}", err);
|
|
||||||
}
|
|
||||||
res
|
|
||||||
} else {
|
|
||||||
warn!("No choices received in the response from OpenAI");
|
|
||||||
"No reply received".to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
error!("Could not get response from OpenAI: {}", err.message);
|
|
||||||
"There was an error processing your message. Please try again later.".to_string()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
debug!("Writing response: \"{}\"", response);
|
|
||||||
|
|
||||||
typing.stop();
|
|
||||||
if let Err(why) = response_channel.say(&ctx.http, response).await {
|
|
||||||
error!("Cannot send message: {}", why);
|
|
||||||
}
|
|
||||||
|
|
||||||
// match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| {
|
|
||||||
// thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
|
|
||||||
// }).await {
|
|
||||||
// Ok(c) => {
|
|
||||||
// if let Err(why) = c.say(&ctx.http, response).await {
|
|
||||||
// error!("Cannot send message: {}", why);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// Err(_) => {
|
|
||||||
// if let Err(why) = channel_id.say(&ctx.http, response).await {
|
|
||||||
// error!("Cannot send message: {}", why);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// };
|
|
||||||
}
|
|
||||||
|
|
||||||
fn truncate(s: &str, max_chars: usize) -> &str {
|
|
||||||
match s.char_indices().nth(max_chars) {
|
|
||||||
None => s,
|
|
||||||
Some((idx, _)) => &s[..idx],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -7,12 +7,12 @@ use serenity::prelude::*;
|
|||||||
|
|
||||||
use crate::bot::guilds::InsertGuild;
|
use crate::bot::guilds::InsertGuild;
|
||||||
|
|
||||||
use super::commands;
|
use super::{commands, oai};
|
||||||
use super::commands::audio::create_response;
|
use super::commands::audio::create_response;
|
||||||
|
|
||||||
pub struct Handler {
|
pub struct Handler {
|
||||||
// Open AI Config
|
// Open AI Config
|
||||||
pub oai: Option<commands::oai::OAI>
|
pub oai: Option<oai::OAI>
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -36,7 +36,7 @@ impl EventHandler for Handler {
|
|||||||
Err(_) => false
|
Err(_) => false
|
||||||
};
|
};
|
||||||
if mentioned || bot_in_thread {
|
if mentioned || bot_in_thread {
|
||||||
commands::oai::generate_response(&ctx, &msg, oai).await;
|
commands::chat::generate_response(&ctx, &msg, oai).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(why) => warn!("Could not check mentions: {:?}", why)
|
Err(why) => warn!("Could not check mentions: {:?}", why)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use siren::ServiceError;
|
|||||||
|
|
||||||
use crate::storage::{schema::messages::{self}, connection};
|
use crate::storage::{schema::messages::{self}, connection};
|
||||||
|
|
||||||
#[derive(Queryable, Selectable, Serialize, Deserialize)]
|
#[derive(Queryable, Selectable, Insertable, AsChangeset, Serialize, Deserialize)]
|
||||||
#[diesel(table_name = messages)]
|
#[diesel(table_name = messages)]
|
||||||
pub struct QueryMessage {
|
pub struct QueryMessage {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@@ -122,24 +122,7 @@ impl QueryMessage {
|
|||||||
let count = query.count().get_result::<i64>(&mut conn)?;
|
let count = query.count().get_result::<i64>(&mut conn)?;
|
||||||
Ok(count)
|
Ok(count)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Insertable, AsChangeset, Serialize, Deserialize)]
|
|
||||||
#[diesel(table_name = messages)]
|
|
||||||
pub struct InsertMessage {
|
|
||||||
pub id: String,
|
|
||||||
pub guild_id: i64,
|
|
||||||
pub channel_id: i64,
|
|
||||||
pub user_id: i64,
|
|
||||||
pub created: i64,
|
|
||||||
pub model: String,
|
|
||||||
pub request: String,
|
|
||||||
pub response: String,
|
|
||||||
pub request_tags: Vec<String>,
|
|
||||||
pub response_tags: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl InsertMessage {
|
|
||||||
pub fn insert(message: Self) -> Result<QueryMessage, ServiceError> {
|
pub fn insert(message: Self) -> Result<QueryMessage, ServiceError> {
|
||||||
let mut conn = connection()?;
|
let mut conn = connection()?;
|
||||||
let message = diesel::insert_into(messages::table)
|
let message = diesel::insert_into(messages::table)
|
||||||
@@ -147,4 +130,4 @@ impl InsertMessage {
|
|||||||
.get_result(&mut conn)?;
|
.get_result(&mut conn)?;
|
||||||
Ok(message)
|
Ok(message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use log::error;
|
|||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
use siren::{Response, Metadata, ServiceError};
|
use siren::{Response, Metadata, ServiceError};
|
||||||
|
|
||||||
use crate::{bot::messages::{QueryMessage, QueryFilters, InsertMessage}, auth::{JwtAuth, verify_role}};
|
use crate::{bot::messages::{QueryMessage, QueryFilters}, auth::{JwtAuth, verify_role}};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
struct GetAllParams {
|
struct GetAllParams {
|
||||||
@@ -68,12 +68,12 @@ async fn get_all(req: HttpRequest, auth: JwtAuth) -> HttpResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[post("/messages")]
|
#[post("/messages")]
|
||||||
async fn create(message: web::Json<InsertMessage>, auth: JwtAuth) -> HttpResponse {
|
async fn create(message: web::Json<QueryMessage>, auth: JwtAuth) -> HttpResponse {
|
||||||
let _ = match verify_role(&auth, "admin") {
|
let _ = match verify_role(&auth, "admin") {
|
||||||
Ok(_) => {},
|
Ok(_) => {},
|
||||||
Err(err) => return ResponseError::error_response(&err)
|
Err(err) => return ResponseError::error_response(&err)
|
||||||
};
|
};
|
||||||
match InsertMessage::insert(message.into_inner()) {
|
match QueryMessage::insert(message.into_inner()) {
|
||||||
Ok(message) => HttpResponse::Created().json(message),
|
Ok(message) => HttpResponse::Created().json(message),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error!("{:?}", err.message);
|
error!("{:?}", err.message);
|
||||||
|
|||||||
@@ -2,3 +2,4 @@ pub mod commands;
|
|||||||
pub mod guilds;
|
pub mod guilds;
|
||||||
pub mod handler;
|
pub mod handler;
|
||||||
pub mod messages;
|
pub mod messages;
|
||||||
|
pub mod oai;
|
||||||
|
|||||||
3
service/src/bot/oai/mod.rs
Normal file
3
service/src/bot/oai/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
mod model;
|
||||||
|
|
||||||
|
pub use model::*;
|
||||||
128
service/src/bot/oai/model.rs
Normal file
128
service/src/bot/oai/model.rs
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use siren::ServiceError;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub enum GPTRole {
|
||||||
|
#[serde(rename = "system")]
|
||||||
|
System,
|
||||||
|
#[serde(rename = "user")]
|
||||||
|
User,
|
||||||
|
#[serde(rename = "assistant")]
|
||||||
|
Assistant,
|
||||||
|
#[serde(rename = "function")]
|
||||||
|
Function
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<ChatCompletionMessage>,
|
||||||
|
/// Value between 0 and 2
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
/// Value between 0 and 1
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<i64>,
|
||||||
|
/// Value between -2.0 and 2.0
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f64>,
|
||||||
|
/// Value between -2.0 and 2.0
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub user: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionMessage {
|
||||||
|
pub role: GPTRole,
|
||||||
|
pub content: String
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
pub created: i64,
|
||||||
|
pub model: String,
|
||||||
|
pub usage: Usage,
|
||||||
|
pub choices: Vec<Choice>
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: i64,
|
||||||
|
pub completion_tokens: i64,
|
||||||
|
pub total_tokens: i64
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Choice {
|
||||||
|
pub message: ChatCompletionMessage,
|
||||||
|
pub finish_reason: String,
|
||||||
|
pub index: i64,
|
||||||
|
pub logprobs: Option<String>
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
enum ResponseEvent {
|
||||||
|
ChatCompletionResponse(ChatCompletionResponse),
|
||||||
|
ResponseError(ResponseError)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
struct ResponseError {
|
||||||
|
error: Option<ErrorDetails>,
|
||||||
|
message: Option<String>,
|
||||||
|
param: Option<String>,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
error_type: Option<String>
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
struct ErrorDetails {
|
||||||
|
code: Option<String>,
|
||||||
|
message: Option<String>
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OAI {
|
||||||
|
pub client: reqwest::Client,
|
||||||
|
pub base_url: String,
|
||||||
|
pub service_url: String,
|
||||||
|
pub max_attempts: i64,
|
||||||
|
pub token: String,
|
||||||
|
pub max_tokens: i64,
|
||||||
|
pub default_model: String,
|
||||||
|
pub max_context_questions: i64
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAI {
|
||||||
|
pub async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, ServiceError> {
|
||||||
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
let response = self.client.post(&url)
|
||||||
|
.bearer_auth(&self.token)
|
||||||
|
.header("Content-Type", "application/json".to_string())
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
match response {
|
||||||
|
Ok(response) => {
|
||||||
|
let value = response.json::<Value>().await?;
|
||||||
|
// let event: ResponseEvent = serde_json::from_value::<ResponseEvent>(value)?;
|
||||||
|
// match event {
|
||||||
|
// ResponseEvent::ChatCompletionResponse(response) => return Ok(response),
|
||||||
|
// ResponseEvent::ResponseError(error) => return Err(ServiceError { status: 500, message: format!("Error: {}", error.message.unwrap()) })
|
||||||
|
// }
|
||||||
|
let res = serde_json::from_value::<ChatCompletionResponse>(value)?;
|
||||||
|
return Ok(res);
|
||||||
|
},
|
||||||
|
Err(err) => return Err(ServiceError { status: 500, message: format!("Error: {}", err) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,7 +14,7 @@ use songbird::{SerenityInit, Songbird};
|
|||||||
|
|
||||||
use actix_cors::Cors;
|
use actix_cors::Cors;
|
||||||
use actix_web::{HttpServer, App, web};
|
use actix_web::{HttpServer, App, web};
|
||||||
use crate::bot::{commands::oai::GPTModel, handler::Handler};
|
use crate::bot::handler::Handler;
|
||||||
|
|
||||||
use dotenv::dotenv;
|
use dotenv::dotenv;
|
||||||
|
|
||||||
@@ -57,8 +57,9 @@ async fn main() -> std::io::Result<()> {
|
|||||||
let handler = match env::var("OPENAI_API_KEY") {
|
let handler = match env::var("OPENAI_API_KEY") {
|
||||||
Ok(token) => {
|
Ok(token) => {
|
||||||
info!("Loaded OpenAI token");
|
info!("Loaded OpenAI token");
|
||||||
|
let default_model = env::var("OPENAI_API_MODEL").unwrap_or("gpt-3.5-turbo".to_string());
|
||||||
Handler {
|
Handler {
|
||||||
oai: Some(bot::commands::oai::OAI {
|
oai: Some(bot::oai::OAI {
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
base_url: "https://api.openai.com/v1".to_string(),
|
base_url: "https://api.openai.com/v1".to_string(),
|
||||||
service_url: "http://localhost:5000".to_string(),
|
service_url: "http://localhost:5000".to_string(),
|
||||||
@@ -66,7 +67,7 @@ async fn main() -> std::io::Result<()> {
|
|||||||
token,
|
token,
|
||||||
max_context_questions: 30,
|
max_context_questions: 30,
|
||||||
max_tokens: 2048,
|
max_tokens: 2048,
|
||||||
default_model: GPTModel::GPT35Turbo,
|
default_model,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user