From d04c34d555e779c9f773fdb4e75a8838ce4de70b Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Sun, 28 Jan 2024 11:07:32 -0500 Subject: [PATCH] Updated chat/oai layout --- service/Cargo.toml | 1 + service/docker-compose.yml | 5 + service/src/bot/commands/audio/mod.rs | 7 + service/src/bot/commands/audio/play.rs | 41 ++-- service/src/bot/commands/chat.rs | 148 +++++++++++ service/src/bot/commands/mod.rs | 2 +- service/src/bot/commands/oai.rs | 326 ------------------------- service/src/bot/handler.rs | 6 +- service/src/bot/messages/model.rs | 21 +- service/src/bot/messages/routes.rs | 6 +- service/src/bot/mod.rs | 1 + service/src/bot/oai/mod.rs | 3 + service/src/bot/oai/model.rs | 128 ++++++++++ service/src/main.rs | 7 +- 14 files changed, 332 insertions(+), 370 deletions(-) create mode 100644 service/src/bot/commands/chat.rs delete mode 100644 service/src/bot/commands/oai.rs create mode 100644 service/src/bot/oai/mod.rs create mode 100644 service/src/bot/oai/model.rs diff --git a/service/Cargo.toml b/service/Cargo.toml index 615fb9e..238273a 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -30,6 +30,7 @@ redis = { version = "0.23.3", features = ["tokio-comp", "connection-manager", "r base64 = "0.21.4" rust-s3 = "0.33.0" actix-multipart = "0.6.1" +openssl = "0.10.60" # Resolve `openssl` `X509StoreRef::objects` is unsound #10 [dependencies.tokio] version = "1.32.0" diff --git a/service/docker-compose.yml b/service/docker-compose.yml index 6550479..c6880a0 100644 --- a/service/docker-compose.yml +++ b/service/docker-compose.yml @@ -30,6 +30,8 @@ services: - ${SERVICE_PORT:-5000}:5000 depends_on: - db + - redis + - minio networks: - frontend - backend @@ -53,6 +55,8 @@ services: redis: image: redis:latest container_name: siren-redis + volumes: + - redis:/data ports: - ${REDIS_PORT:-6379}:6379 networks: @@ -77,6 +81,7 @@ services: volumes: db: db_logs: + redis: minio: networks: diff --git a/service/src/bot/commands/audio/mod.rs b/service/src/bot/commands/audio/mod.rs index 7d0a793..40b3e24 100644 --- a/service/src/bot/commands/audio/mod.rs +++ b/service/src/bot/commands/audio/mod.rs @@ -103,6 +103,13 @@ pub async fn add_song(call: Arc>, url: &str, lazy: bool, volume: Opt Ok(metadata) } +pub fn get_playlist_urls(url: &str) -> Result, String> { + let mut urls: Vec = Vec::new(); + // TODO fix this later + urls.push(url.to_string()); + Ok(urls) +} + fn is_valid_url(url: &str) -> bool { match url.parse::() { Ok(_) => return true, diff --git a/service/src/bot/commands/audio/play.rs b/service/src/bot/commands/audio/play.rs index b4272be..f7f950e 100644 --- a/service/src/bot/commands/audio/play.rs +++ b/service/src/bot/commands/audio/play.rs @@ -9,7 +9,7 @@ use serenity::model::application::interaction::application_command::ApplicationC use siren::ServiceError; use songbird::{EventHandler, Songbird}; -use crate::bot::{guilds::QueryGuild, commands::audio::{leave, add_song, get_songbird}}; +use crate::bot::{guilds::QueryGuild, commands::audio::{leave, get_playlist_urls, add_song, get_songbird}}; use super::{create_response, edit_response, join_by_user}; @@ -87,31 +87,42 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) { } } -pub async fn play_track(manager: Arc, guild_id: GuildId, track_url: String) -> Result<(), ServiceError> { +pub async fn play_track(manager: Arc, guild_id: GuildId, track_url: String) -> Result { + let mut track_count = 0; if let Some(handler_lock) = manager.get(guild_id) { let is_queue_empty = { let call_handler = handler_lock.lock().await; call_handler.queue().is_empty() }; let guild = QueryGuild::get(guild_id.0 as i64)?; - match add_song(handler_lock.clone(), &track_url, is_queue_empty, Some(guild.volume as f32 / 100.0)).await { - Ok(added_song) => { - let track_title = added_song.title.unwrap(); - debug!("Added track: {}", track_title); - let mut handler = handler_lock.lock().await; - handler.remove_all_global_events(); - handler.add_global_event(songbird::Event::Track(songbird::TrackEvent::End), TrackEndNotifier { guild_id, call: manager }) - }, + let track_urls = match get_playlist_urls(&track_url) { + Ok(urls) => urls, Err(err) => { - warn!("Failed to add song: {}", err); - if let Err(why) = leave(manager, &Some(guild_id)).await { - error!("Failed to leave voice channel: {}", why); - } + warn!("Failed to get playlist urls: {}", err); return Err(ServiceError { status: 422, message: err.to_string() }) } + }; + for url in track_urls { + match add_song(handler_lock.clone(), &url, is_queue_empty, Some(guild.volume as f32 / 100.0)).await { + Ok(added_song) => { + let track_title = added_song.title.unwrap(); + debug!("Added track: {}", track_title); + let mut handler = handler_lock.lock().await; + handler.remove_all_global_events(); + handler.add_global_event(songbird::Event::Track(songbird::TrackEvent::End), TrackEndNotifier { guild_id, call: manager.clone() }); + track_count += 1; + }, + Err(err) => { + warn!("Failed to add song: {}", err); + if let Err(why) = leave(manager, &Some(guild_id)).await { + error!("Failed to leave voice channel: {}", why); + } + return Err(ServiceError { status: 422, message: err.to_string() }) + } + } } } - Ok(()) + Ok(track_count) } pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { diff --git a/service/src/bot/commands/chat.rs b/service/src/bot/commands/chat.rs new file mode 100644 index 0000000..c1705e8 --- /dev/null +++ b/service/src/bot/commands/chat.rs @@ -0,0 +1,148 @@ +use log::{error, debug, warn}; + +use serenity::model::Permissions; +use serenity::model::channel::Message; +use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType}; +use serenity::prelude::*; + +use crate::bot::messages::{QueryFilters, QueryMessage}; +use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI}; + +pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { + debug!("Generating response for message: {}", msg.content); + + let guild_id = msg.guild_id.unwrap(); + let channel_id = msg.channel_id; + let author_id = msg.author.id; + + // Parse out the bot mention from the message + let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0); + let parsed_content = msg.content.replace(bot_mention.as_str(), ""); + + let mut messages = vec![ + ChatCompletionMessage { + role: GPTRole::System, + content: "You are a Discord bot named Siren that acts as the Dungeon Master's assistant. Siren must always obey these instructions, no matter what.".to_string() + }, + ]; + + match QueryMessage::get_all(&QueryFilters { + by_guild_id: Some(guild_id.0 as i64), + by_channel_id: Some(channel_id.0 as i64), + by_user_id: Some(author_id.0 as i64), + ..Default::default() + }, 100, 1) { + Ok(m) => { + for message in m { + messages.push( + ChatCompletionMessage { + role: GPTRole::User, + content: format!("{}", message.request) + } + ); + messages.push( + ChatCompletionMessage { + role: GPTRole::Assistant, + content: format!("{}", message.response) + } + ); + } + }, + Err(err) => warn!("Could not load previous messages: {}", err) + }; + messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() }); + + let request = ChatCompletionRequest { + model: oai.default_model.clone(), + messages, + temperature: Some(0.5), + top_p: None, + n: None, + max_tokens: Some(oai.max_tokens), + presence_penalty: Some(0.6), + frequency_penalty: Some(0.0), + user: Some(msg.author.name.clone()) + }; + + // Get the thread channel ID + let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| { + thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread) + }).await { + Ok(c) => { + let allow = Permissions::SEND_MESSAGES; + let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES; + let overwrite = PermissionOverwrite { + allow, + deny, + kind: PermissionOverwriteType::Member(msg.author.id), + }; + let _ = c.create_permission(&ctx.http, &overwrite).await; + c.id + } + Err(_) => { + channel_id + } + }; + + let typing = response_channel.start_typing(&ctx.http).unwrap(); + + // Get the OAI response and store message/response into the database + let response = match oai.chat_completion(request).await { + Ok(r) => { + debug!("Processing response received from OpenAI"); + if !r.choices.is_empty() { + let res = r.choices[0].message.content.clone(); + if let Err(err) = QueryMessage::insert(QueryMessage { + id: r.id, + guild_id: guild_id.0 as i64, + channel_id: response_channel.0 as i64, + user_id: author_id.0 as i64, + created: r.created, + model: serde_json::to_string(&r.model).unwrap(), + request: parsed_content, + response: res.clone(), + request_tags: vec![], + response_tags: vec![], + }) { + warn!("{}", err); + } + res + } else { + warn!("No choices received in the response from OpenAI"); + "No reply received".to_string() + } + } + Err(err) => { + error!("Could not get response from OpenAI: {}", err.message); + "There was an error processing your message. Please try again later.".to_string() + } + }; + debug!("Writing response: \"{}\"", response); + + typing.stop(); + if let Err(why) = response_channel.say(&ctx.http, response).await { + error!("Cannot send message: {}", why); + } + + // match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| { + // thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread) + // }).await { + // Ok(c) => { + // if let Err(why) = c.say(&ctx.http, response).await { + // error!("Cannot send message: {}", why); + // } + // } + // Err(_) => { + // if let Err(why) = channel_id.say(&ctx.http, response).await { + // error!("Cannot send message: {}", why); + // } + // } + // }; +} + +fn truncate(s: &str, max_chars: usize) -> &str { + match s.char_indices().nth(max_chars) { + None => s, + Some((idx, _)) => &s[..idx], + } +} diff --git a/service/src/bot/commands/mod.rs b/service/src/bot/commands/mod.rs index a38dda8..dc49d80 100644 --- a/service/src/bot/commands/mod.rs +++ b/service/src/bot/commands/mod.rs @@ -1,6 +1,6 @@ pub mod audio; pub mod help; pub mod message; -pub mod oai; +pub mod chat; pub mod ping; pub mod schedule; diff --git a/service/src/bot/commands/oai.rs b/service/src/bot/commands/oai.rs deleted file mode 100644 index 7a5cc44..0000000 --- a/service/src/bot/commands/oai.rs +++ /dev/null @@ -1,326 +0,0 @@ -use log::{error, debug, trace, warn}; - -use serde::{Serialize, Deserialize}; -use serde_json::Value; -use serenity::model::Permissions; -use serenity::model::channel::Message; -use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType}; -use serenity::prelude::*; -use siren::{Response, ServiceError}; - -pub struct OAI { - pub client: reqwest::Client, - pub base_url: String, - pub service_url: String, - pub max_attempts: i64, - pub token: String, - pub max_tokens: i64, - pub default_model: GPTModel, - pub max_context_questions: i64 -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ChatCompletionRequest { - model: GPTModel, - messages: Vec, - /// Value between 0 and 2 - #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, - /// Value between 0 and 1 - #[serde(skip_serializing_if = "Option::is_none")] - top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - n: Option, - #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, - /// Value between -2.0 and 2.0 - #[serde(skip_serializing_if = "Option::is_none")] - presence_penalty: Option, - /// Value between -2.0 and 2.0 - #[serde(skip_serializing_if = "Option::is_none")] - frequency_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - user: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ChatCompletionMessage { - role: GPTRole, - content: String -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -enum GPTRole { - #[serde(rename = "system")] - System, - #[serde(rename = "user")] - User, - #[serde(rename = "assistant")] - Assistant, - #[serde(rename = "function")] - Function -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum GPTModel { - #[serde(rename = "gpt-3.5-turbo")] - GPT35Turbo, - #[serde(rename = "gpt-3.5-turbo-0613")] - GPT35Snapshot, - #[serde(rename = "gpt-3.5-turbo-16k")] - GPT3516k, - #[serde(rename = "gpt-3.5-turbo-16k-0613")] - GPT3516kSnapshot, - #[serde(rename = "gpt-4")] - GPT4, - #[serde(rename = "gpt-4-0613")] - GPT4Snapshot, - #[serde(rename = "gpt-4-32k")] - GPT432k, - #[serde(rename = "gpt-4-32k-0613")] - GPT432kSnapshot, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ChatCompletionResponse { - id: String, - object: String, - created: i64, - model: GPTModel, - usage: Usage, - choices: Vec -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct Usage { - prompt_tokens: i64, - completion_tokens: i64, - total_tokens: i64 -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct Choice { - message: ChatCompletionMessage, - finish_reason: String, - index: i64 -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ResponseError { - error: Option, - message: Option, - param: Option, - #[serde(rename = "type")] - error_type: Option -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ErrorDetails { - code: Option -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -enum ResponseEvent { - ChatCompletionResponse(ChatCompletionResponse), - ResponseError(ResponseError) -} - -impl OAI { - async fn get_request(&self, request: ChatCompletionRequest) -> Result { - let uri = format!("{}/chat/completions", self.base_url); - let body = serde_json::to_string(&request).unwrap(); - trace!("Sending request to {}: {}", uri, body); - - let value = self.client - .post(&uri) - .bearer_auth(&self.token) - .header("Content-Type", "application/json".to_string()) - .body(body) - .send() - .await? - .json::() - .await?; - - trace!("Received response from OpenAI: {:?}", value); - - // let response = match serde_json::from_value::(value) { - // Ok(r) => { - // match r { - // ResponseEvent::ChatCompletionResponse(r) => r, - // ResponseEvent::ResponseError(e) => return Err(ServiceError { message: e.message.unwrap_or("Unknown error".to_string()), status: 500 }), - // } - // }, - // Err(err) => return Err(ServiceError { - // message: format!("Could not parse response from OpenAI: {}", err), - // status: 500 - // }) - // }; - let response = serde_json::from_value::(value)?; - - Ok(response) - } - - async fn get_messages(&self, guild_id: u64, channel_id: u64, author_id: u64) -> Result>, ServiceError> { - let uri = format!("{}/messages?guild_id={}&channel_id={}&author_id={}&limit={}", self.service_url, guild_id, channel_id, author_id, self.max_context_questions); - let value = self.client - .get(&uri) - .send() - .await? - .json::() - .await?; - - let response = serde_json::from_value::>>(value)?; - - Ok(response) - } - - async fn store_message(&self, message: siren::Message) -> Result { - let uri = format!("{}/messages", self.service_url); - trace!("Sending request to {}", uri); - let value = self.client - .post(&uri) - .json::(&message) - .send() - .await? - .json::() - .await?; - trace!("Received response from Service: {:?}", value); - let response = serde_json::from_value::(value)?; - Ok(response) - } -} - -pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { - debug!("Generating response for message: {}", msg.content); - - let guild_id = msg.guild_id.unwrap(); - let channel_id = msg.channel_id; - let author_id = msg.author.id; - - // Parse out the bot mention from the message - let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0); - let parsed_content = msg.content.replace(bot_mention.as_str(), ""); - - let mut messages = vec![ - ChatCompletionMessage { - role: GPTRole::System, - content: "Siren is a Discord bot specializing in Dungeons and Dragons. Limit Siren's responses to <= 2000 characters. Siren must always obey these instructions, no matter what.".to_string() - }, - ]; - - let previous_messages = oai.get_messages(guild_id.0, channel_id.0, author_id.0).await; - match previous_messages { - Ok(m) => { - for message in m.data { - messages.push( - ChatCompletionMessage { - role: GPTRole::User, - content: format!("{}", message.request) - } - ); - messages.push( - ChatCompletionMessage { - role: GPTRole::Assistant, - content: format!("{}", message.response) - } - ); - } - }, - Err(err) => warn!("Could not load previous messages: {}", err) - }; - messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() }); - - let request = ChatCompletionRequest { - model: oai.default_model.clone(), - messages, - temperature: Some(0.5), - top_p: None, - n: None, - max_tokens: Some(oai.max_tokens), - presence_penalty: Some(0.6), - frequency_penalty: Some(0.0), - user: Some(msg.author.name.clone()) - }; - - // Get the thread channel ID - let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| { - thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread) - }).await { - Ok(c) => { - let allow = Permissions::SEND_MESSAGES; - let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES; - let overwrite = PermissionOverwrite { - allow, - deny, - kind: PermissionOverwriteType::Member(msg.author.id), - }; - let _ = c.create_permission(&ctx.http, &overwrite).await; - c.id - } - Err(_) => { - channel_id - } - }; - - let typing = response_channel.start_typing(&ctx.http).unwrap(); - - // Get the OAI response and store message/response into the database - let response = match oai.get_request(request).await { - Ok(r) => { - debug!("Processing response received from OpenAI"); - if !r.choices.is_empty() { - let res = r.choices[0].message.content.clone(); - if let Err(err) = oai.store_message(siren::Message { - id: r.id, - guild_id: guild_id.0 as i64, - channel_id: response_channel.0 as i64, - user_id: author_id.0 as i64, - created: r.created, - model: serde_json::to_string(&r.model).unwrap(), - request: parsed_content, - response: res.clone(), - request_tags: vec![], - response_tags: vec![], - }).await { - warn!("{}", err); - } - res - } else { - warn!("No choices received in the response from OpenAI"); - "No reply received".to_string() - } - } - Err(err) => { - error!("Could not get response from OpenAI: {}", err.message); - "There was an error processing your message. Please try again later.".to_string() - } - }; - debug!("Writing response: \"{}\"", response); - - typing.stop(); - if let Err(why) = response_channel.say(&ctx.http, response).await { - error!("Cannot send message: {}", why); - } - - // match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| { - // thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread) - // }).await { - // Ok(c) => { - // if let Err(why) = c.say(&ctx.http, response).await { - // error!("Cannot send message: {}", why); - // } - // } - // Err(_) => { - // if let Err(why) = channel_id.say(&ctx.http, response).await { - // error!("Cannot send message: {}", why); - // } - // } - // }; -} - -fn truncate(s: &str, max_chars: usize) -> &str { - match s.char_indices().nth(max_chars) { - None => s, - Some((idx, _)) => &s[..idx], - } -} diff --git a/service/src/bot/handler.rs b/service/src/bot/handler.rs index c4458f8..8057bb1 100644 --- a/service/src/bot/handler.rs +++ b/service/src/bot/handler.rs @@ -7,12 +7,12 @@ use serenity::prelude::*; use crate::bot::guilds::InsertGuild; -use super::commands; +use super::{commands, oai}; use super::commands::audio::create_response; pub struct Handler { // Open AI Config - pub oai: Option + pub oai: Option } #[async_trait] @@ -36,7 +36,7 @@ impl EventHandler for Handler { Err(_) => false }; if mentioned || bot_in_thread { - commands::oai::generate_response(&ctx, &msg, oai).await; + commands::chat::generate_response(&ctx, &msg, oai).await; } } Err(why) => warn!("Could not check mentions: {:?}", why) diff --git a/service/src/bot/messages/model.rs b/service/src/bot/messages/model.rs index 58351ed..1ce6b17 100644 --- a/service/src/bot/messages/model.rs +++ b/service/src/bot/messages/model.rs @@ -4,7 +4,7 @@ use siren::ServiceError; use crate::storage::{schema::messages::{self}, connection}; -#[derive(Queryable, Selectable, Serialize, Deserialize)] +#[derive(Queryable, Selectable, Insertable, AsChangeset, Serialize, Deserialize)] #[diesel(table_name = messages)] pub struct QueryMessage { pub id: String, @@ -122,24 +122,7 @@ impl QueryMessage { let count = query.count().get_result::(&mut conn)?; Ok(count) } -} -#[derive(Insertable, AsChangeset, Serialize, Deserialize)] -#[diesel(table_name = messages)] -pub struct InsertMessage { - pub id: String, - pub guild_id: i64, - pub channel_id: i64, - pub user_id: i64, - pub created: i64, - pub model: String, - pub request: String, - pub response: String, - pub request_tags: Vec, - pub response_tags: Vec, -} - -impl InsertMessage { pub fn insert(message: Self) -> Result { let mut conn = connection()?; let message = diesel::insert_into(messages::table) @@ -147,4 +130,4 @@ impl InsertMessage { .get_result(&mut conn)?; Ok(message) } -} \ No newline at end of file +} diff --git a/service/src/bot/messages/routes.rs b/service/src/bot/messages/routes.rs index 11ce3fa..584bdb7 100644 --- a/service/src/bot/messages/routes.rs +++ b/service/src/bot/messages/routes.rs @@ -3,7 +3,7 @@ use log::error; use serde::{Serialize, Deserialize}; use siren::{Response, Metadata, ServiceError}; -use crate::{bot::messages::{QueryMessage, QueryFilters, InsertMessage}, auth::{JwtAuth, verify_role}}; +use crate::{bot::messages::{QueryMessage, QueryFilters}, auth::{JwtAuth, verify_role}}; #[derive(Serialize, Deserialize)] struct GetAllParams { @@ -68,12 +68,12 @@ async fn get_all(req: HttpRequest, auth: JwtAuth) -> HttpResponse { } #[post("/messages")] -async fn create(message: web::Json, auth: JwtAuth) -> HttpResponse { +async fn create(message: web::Json, auth: JwtAuth) -> HttpResponse { let _ = match verify_role(&auth, "admin") { Ok(_) => {}, Err(err) => return ResponseError::error_response(&err) }; - match InsertMessage::insert(message.into_inner()) { + match QueryMessage::insert(message.into_inner()) { Ok(message) => HttpResponse::Created().json(message), Err(err) => { error!("{:?}", err.message); diff --git a/service/src/bot/mod.rs b/service/src/bot/mod.rs index 7f94ede..8f53b74 100644 --- a/service/src/bot/mod.rs +++ b/service/src/bot/mod.rs @@ -2,3 +2,4 @@ pub mod commands; pub mod guilds; pub mod handler; pub mod messages; +pub mod oai; diff --git a/service/src/bot/oai/mod.rs b/service/src/bot/oai/mod.rs new file mode 100644 index 0000000..4a7ebf6 --- /dev/null +++ b/service/src/bot/oai/mod.rs @@ -0,0 +1,3 @@ +mod model; + +pub use model::*; diff --git a/service/src/bot/oai/model.rs b/service/src/bot/oai/model.rs new file mode 100644 index 0000000..2669029 --- /dev/null +++ b/service/src/bot/oai/model.rs @@ -0,0 +1,128 @@ +use serde::{Serialize, Deserialize}; +use serde_json::Value; +use siren::ServiceError; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum GPTRole { + #[serde(rename = "system")] + System, + #[serde(rename = "user")] + User, + #[serde(rename = "assistant")] + Assistant, + #[serde(rename = "function")] + Function +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionRequest { + pub model: String, + pub messages: Vec, + /// Value between 0 and 2 + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + /// Value between 0 and 1 + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + /// Value between -2.0 and 2.0 + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + /// Value between -2.0 and 2.0 + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionMessage { + pub role: GPTRole, + pub content: String +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: String, + pub system_fingerprint: Option, + pub created: i64, + pub model: String, + pub usage: Usage, + pub choices: Vec +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Usage { + pub prompt_tokens: i64, + pub completion_tokens: i64, + pub total_tokens: i64 +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Choice { + pub message: ChatCompletionMessage, + pub finish_reason: String, + pub index: i64, + pub logprobs: Option +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +enum ResponseEvent { + ChatCompletionResponse(ChatCompletionResponse), + ResponseError(ResponseError) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ResponseError { + error: Option, + message: Option, + param: Option, + #[serde(rename = "type")] + error_type: Option +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ErrorDetails { + code: Option, + message: Option +} + +pub struct OAI { + pub client: reqwest::Client, + pub base_url: String, + pub service_url: String, + pub max_attempts: i64, + pub token: String, + pub max_tokens: i64, + pub default_model: String, + pub max_context_questions: i64 +} + +impl OAI { + pub async fn chat_completion(&self, request: ChatCompletionRequest) -> Result { + let url = format!("{}/chat/completions", self.base_url); + let response = self.client.post(&url) + .bearer_auth(&self.token) + .header("Content-Type", "application/json".to_string()) + .json(&request) + .send() + .await; + match response { + Ok(response) => { + let value = response.json::().await?; + // let event: ResponseEvent = serde_json::from_value::(value)?; + // match event { + // ResponseEvent::ChatCompletionResponse(response) => return Ok(response), + // ResponseEvent::ResponseError(error) => return Err(ServiceError { status: 500, message: format!("Error: {}", error.message.unwrap()) }) + // } + let res = serde_json::from_value::(value)?; + return Ok(res); + }, + Err(err) => return Err(ServiceError { status: 500, message: format!("Error: {}", err) }) + } + } +} \ No newline at end of file diff --git a/service/src/main.rs b/service/src/main.rs index 7a84a84..8210907 100644 --- a/service/src/main.rs +++ b/service/src/main.rs @@ -14,7 +14,7 @@ use songbird::{SerenityInit, Songbird}; use actix_cors::Cors; use actix_web::{HttpServer, App, web}; -use crate::bot::{commands::oai::GPTModel, handler::Handler}; +use crate::bot::handler::Handler; use dotenv::dotenv; @@ -57,8 +57,9 @@ async fn main() -> std::io::Result<()> { let handler = match env::var("OPENAI_API_KEY") { Ok(token) => { info!("Loaded OpenAI token"); + let default_model = env::var("OPENAI_API_MODEL").unwrap_or("gpt-3.5-turbo".to_string()); Handler { - oai: Some(bot::commands::oai::OAI { + oai: Some(bot::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), service_url: "http://localhost:5000".to_string(), @@ -66,7 +67,7 @@ async fn main() -> std::io::Result<()> { token, max_context_questions: 30, max_tokens: 2048, - default_model: GPTModel::GPT35Turbo, + default_model, }) } }