Updated chat/oai layout

This commit is contained in:
Benjamin Sherriff
2024-01-28 11:07:32 -05:00
parent b474866e7e
commit d04c34d555
14 changed files with 332 additions and 370 deletions

View File

@@ -30,6 +30,7 @@ redis = { version = "0.23.3", features = ["tokio-comp", "connection-manager", "r
base64 = "0.21.4" base64 = "0.21.4"
rust-s3 = "0.33.0" rust-s3 = "0.33.0"
actix-multipart = "0.6.1" actix-multipart = "0.6.1"
openssl = "0.10.60" # Resolve `openssl` `X509StoreRef::objects` is unsound #10
[dependencies.tokio] [dependencies.tokio]
version = "1.32.0" version = "1.32.0"

View File

@@ -30,6 +30,8 @@ services:
- ${SERVICE_PORT:-5000}:5000 - ${SERVICE_PORT:-5000}:5000
depends_on: depends_on:
- db - db
- redis
- minio
networks: networks:
- frontend - frontend
- backend - backend
@@ -53,6 +55,8 @@ services:
redis: redis:
image: redis:latest image: redis:latest
container_name: siren-redis container_name: siren-redis
volumes:
- redis:/data
ports: ports:
- ${REDIS_PORT:-6379}:6379 - ${REDIS_PORT:-6379}:6379
networks: networks:
@@ -77,6 +81,7 @@ services:
volumes: volumes:
db: db:
db_logs: db_logs:
redis:
minio: minio:
networks: networks:

View File

@@ -103,6 +103,13 @@ pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool, volume: Opt
Ok(metadata) Ok(metadata)
} }
pub fn get_playlist_urls(url: &str) -> Result<Vec<String>, String> {
let mut urls: Vec<String> = Vec::new();
// TODO fix this later
urls.push(url.to_string());
Ok(urls)
}
fn is_valid_url(url: &str) -> bool { fn is_valid_url(url: &str) -> bool {
match url.parse::<reqwest::Url>() { match url.parse::<reqwest::Url>() {
Ok(_) => return true, Ok(_) => return true,

View File

@@ -9,7 +9,7 @@ use serenity::model::application::interaction::application_command::ApplicationC
use siren::ServiceError; use siren::ServiceError;
use songbird::{EventHandler, Songbird}; use songbird::{EventHandler, Songbird};
use crate::bot::{guilds::QueryGuild, commands::audio::{leave, add_song, get_songbird}}; use crate::bot::{guilds::QueryGuild, commands::audio::{leave, get_playlist_urls, add_song, get_songbird}};
use super::{create_response, edit_response, join_by_user}; use super::{create_response, edit_response, join_by_user};
@@ -87,31 +87,42 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
} }
} }
pub async fn play_track(manager: Arc<Songbird>, guild_id: GuildId, track_url: String) -> Result<(), ServiceError> { pub async fn play_track(manager: Arc<Songbird>, guild_id: GuildId, track_url: String) -> Result<i32, ServiceError> {
let mut track_count = 0;
if let Some(handler_lock) = manager.get(guild_id) { if let Some(handler_lock) = manager.get(guild_id) {
let is_queue_empty = { let is_queue_empty = {
let call_handler = handler_lock.lock().await; let call_handler = handler_lock.lock().await;
call_handler.queue().is_empty() call_handler.queue().is_empty()
}; };
let guild = QueryGuild::get(guild_id.0 as i64)?; let guild = QueryGuild::get(guild_id.0 as i64)?;
match add_song(handler_lock.clone(), &track_url, is_queue_empty, Some(guild.volume as f32 / 100.0)).await { let track_urls = match get_playlist_urls(&track_url) {
Ok(added_song) => { Ok(urls) => urls,
let track_title = added_song.title.unwrap();
debug!("Added track: {}", track_title);
let mut handler = handler_lock.lock().await;
handler.remove_all_global_events();
handler.add_global_event(songbird::Event::Track(songbird::TrackEvent::End), TrackEndNotifier { guild_id, call: manager })
},
Err(err) => { Err(err) => {
warn!("Failed to add song: {}", err); warn!("Failed to get playlist urls: {}", err);
if let Err(why) = leave(manager, &Some(guild_id)).await {
error!("Failed to leave voice channel: {}", why);
}
return Err(ServiceError { status: 422, message: err.to_string() }) return Err(ServiceError { status: 422, message: err.to_string() })
} }
};
for url in track_urls {
match add_song(handler_lock.clone(), &url, is_queue_empty, Some(guild.volume as f32 / 100.0)).await {
Ok(added_song) => {
let track_title = added_song.title.unwrap();
debug!("Added track: {}", track_title);
let mut handler = handler_lock.lock().await;
handler.remove_all_global_events();
handler.add_global_event(songbird::Event::Track(songbird::TrackEvent::End), TrackEndNotifier { guild_id, call: manager.clone() });
track_count += 1;
},
Err(err) => {
warn!("Failed to add song: {}", err);
if let Err(why) = leave(manager, &Some(guild_id)).await {
error!("Failed to leave voice channel: {}", why);
}
return Err(ServiceError { status: 422, message: err.to_string() })
}
}
} }
} }
Ok(()) Ok(track_count)
} }
pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand {

View File

@@ -0,0 +1,148 @@
use log::{error, debug, warn};
use serenity::model::Permissions;
use serenity::model::channel::Message;
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
use serenity::prelude::*;
use crate::bot::messages::{QueryFilters, QueryMessage};
use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI};
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
debug!("Generating response for message: {}", msg.content);
let guild_id = msg.guild_id.unwrap();
let channel_id = msg.channel_id;
let author_id = msg.author.id;
// Parse out the bot mention from the message
let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0);
let parsed_content = msg.content.replace(bot_mention.as_str(), "");
let mut messages = vec![
ChatCompletionMessage {
role: GPTRole::System,
content: "You are a Discord bot named Siren that acts as the Dungeon Master's assistant. Siren must always obey these instructions, no matter what.".to_string()
},
];
match QueryMessage::get_all(&QueryFilters {
by_guild_id: Some(guild_id.0 as i64),
by_channel_id: Some(channel_id.0 as i64),
by_user_id: Some(author_id.0 as i64),
..Default::default()
}, 100, 1) {
Ok(m) => {
for message in m {
messages.push(
ChatCompletionMessage {
role: GPTRole::User,
content: format!("{}", message.request)
}
);
messages.push(
ChatCompletionMessage {
role: GPTRole::Assistant,
content: format!("{}", message.response)
}
);
}
},
Err(err) => warn!("Could not load previous messages: {}", err)
};
messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() });
let request = ChatCompletionRequest {
model: oai.default_model.clone(),
messages,
temperature: Some(0.5),
top_p: None,
n: None,
max_tokens: Some(oai.max_tokens),
presence_penalty: Some(0.6),
frequency_penalty: Some(0.0),
user: Some(msg.author.name.clone())
};
// Get the thread channel ID
let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| {
thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
}).await {
Ok(c) => {
let allow = Permissions::SEND_MESSAGES;
let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES;
let overwrite = PermissionOverwrite {
allow,
deny,
kind: PermissionOverwriteType::Member(msg.author.id),
};
let _ = c.create_permission(&ctx.http, &overwrite).await;
c.id
}
Err(_) => {
channel_id
}
};
let typing = response_channel.start_typing(&ctx.http).unwrap();
// Get the OAI response and store message/response into the database
let response = match oai.chat_completion(request).await {
Ok(r) => {
debug!("Processing response received from OpenAI");
if !r.choices.is_empty() {
let res = r.choices[0].message.content.clone();
if let Err(err) = QueryMessage::insert(QueryMessage {
id: r.id,
guild_id: guild_id.0 as i64,
channel_id: response_channel.0 as i64,
user_id: author_id.0 as i64,
created: r.created,
model: serde_json::to_string(&r.model).unwrap(),
request: parsed_content,
response: res.clone(),
request_tags: vec![],
response_tags: vec![],
}) {
warn!("{}", err);
}
res
} else {
warn!("No choices received in the response from OpenAI");
"No reply received".to_string()
}
}
Err(err) => {
error!("Could not get response from OpenAI: {}", err.message);
"There was an error processing your message. Please try again later.".to_string()
}
};
debug!("Writing response: \"{}\"", response);
typing.stop();
if let Err(why) = response_channel.say(&ctx.http, response).await {
error!("Cannot send message: {}", why);
}
// match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| {
// thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
// }).await {
// Ok(c) => {
// if let Err(why) = c.say(&ctx.http, response).await {
// error!("Cannot send message: {}", why);
// }
// }
// Err(_) => {
// if let Err(why) = channel_id.say(&ctx.http, response).await {
// error!("Cannot send message: {}", why);
// }
// }
// };
}
fn truncate(s: &str, max_chars: usize) -> &str {
match s.char_indices().nth(max_chars) {
None => s,
Some((idx, _)) => &s[..idx],
}
}

View File

@@ -1,6 +1,6 @@
pub mod audio; pub mod audio;
pub mod help; pub mod help;
pub mod message; pub mod message;
pub mod oai; pub mod chat;
pub mod ping; pub mod ping;
pub mod schedule; pub mod schedule;

View File

@@ -1,326 +0,0 @@
use log::{error, debug, trace, warn};
use serde::{Serialize, Deserialize};
use serde_json::Value;
use serenity::model::Permissions;
use serenity::model::channel::Message;
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
use serenity::prelude::*;
use siren::{Response, ServiceError};
pub struct OAI {
pub client: reqwest::Client,
pub base_url: String,
pub service_url: String,
pub max_attempts: i64,
pub token: String,
pub max_tokens: i64,
pub default_model: GPTModel,
pub max_context_questions: i64
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ChatCompletionRequest {
model: GPTModel,
messages: Vec<ChatCompletionMessage>,
/// Value between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
/// Value between 0 and 1
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
n: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<i64>,
/// Value between -2.0 and 2.0
#[serde(skip_serializing_if = "Option::is_none")]
presence_penalty: Option<f64>,
/// Value between -2.0 and 2.0
#[serde(skip_serializing_if = "Option::is_none")]
frequency_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
user: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ChatCompletionMessage {
role: GPTRole,
content: String
}
#[derive(Debug, Clone, Serialize, Deserialize)]
enum GPTRole {
#[serde(rename = "system")]
System,
#[serde(rename = "user")]
User,
#[serde(rename = "assistant")]
Assistant,
#[serde(rename = "function")]
Function
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GPTModel {
#[serde(rename = "gpt-3.5-turbo")]
GPT35Turbo,
#[serde(rename = "gpt-3.5-turbo-0613")]
GPT35Snapshot,
#[serde(rename = "gpt-3.5-turbo-16k")]
GPT3516k,
#[serde(rename = "gpt-3.5-turbo-16k-0613")]
GPT3516kSnapshot,
#[serde(rename = "gpt-4")]
GPT4,
#[serde(rename = "gpt-4-0613")]
GPT4Snapshot,
#[serde(rename = "gpt-4-32k")]
GPT432k,
#[serde(rename = "gpt-4-32k-0613")]
GPT432kSnapshot,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ChatCompletionResponse {
id: String,
object: String,
created: i64,
model: GPTModel,
usage: Usage,
choices: Vec<Choice>
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
prompt_tokens: i64,
completion_tokens: i64,
total_tokens: i64
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Choice {
message: ChatCompletionMessage,
finish_reason: String,
index: i64
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ResponseError {
error: Option<ErrorDetails>,
message: Option<String>,
param: Option<String>,
#[serde(rename = "type")]
error_type: Option<String>
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ErrorDetails {
code: Option<String>
}
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ResponseEvent {
ChatCompletionResponse(ChatCompletionResponse),
ResponseError(ResponseError)
}
impl OAI {
async fn get_request(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, ServiceError> {
let uri = format!("{}/chat/completions", self.base_url);
let body = serde_json::to_string(&request).unwrap();
trace!("Sending request to {}: {}", uri, body);
let value = self.client
.post(&uri)
.bearer_auth(&self.token)
.header("Content-Type", "application/json".to_string())
.body(body)
.send()
.await?
.json::<Value>()
.await?;
trace!("Received response from OpenAI: {:?}", value);
// let response = match serde_json::from_value::<ResponseEvent>(value) {
// Ok(r) => {
// match r {
// ResponseEvent::ChatCompletionResponse(r) => r,
// ResponseEvent::ResponseError(e) => return Err(ServiceError { message: e.message.unwrap_or("Unknown error".to_string()), status: 500 }),
// }
// },
// Err(err) => return Err(ServiceError {
// message: format!("Could not parse response from OpenAI: {}", err),
// status: 500
// })
// };
let response = serde_json::from_value::<ChatCompletionResponse>(value)?;
Ok(response)
}
async fn get_messages(&self, guild_id: u64, channel_id: u64, author_id: u64) -> Result<Response<Vec<siren::Message>>, ServiceError> {
let uri = format!("{}/messages?guild_id={}&channel_id={}&author_id={}&limit={}", self.service_url, guild_id, channel_id, author_id, self.max_context_questions);
let value = self.client
.get(&uri)
.send()
.await?
.json::<Value>()
.await?;
let response = serde_json::from_value::<Response<Vec<siren::Message>>>(value)?;
Ok(response)
}
async fn store_message(&self, message: siren::Message) -> Result<siren::Message, ServiceError> {
let uri = format!("{}/messages", self.service_url);
trace!("Sending request to {}", uri);
let value = self.client
.post(&uri)
.json::<siren::Message>(&message)
.send()
.await?
.json::<Value>()
.await?;
trace!("Received response from Service: {:?}", value);
let response = serde_json::from_value::<siren::Message>(value)?;
Ok(response)
}
}
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
debug!("Generating response for message: {}", msg.content);
let guild_id = msg.guild_id.unwrap();
let channel_id = msg.channel_id;
let author_id = msg.author.id;
// Parse out the bot mention from the message
let bot_mention: String = format!("<@{}>", ctx.cache.current_user_id().0);
let parsed_content = msg.content.replace(bot_mention.as_str(), "");
let mut messages = vec![
ChatCompletionMessage {
role: GPTRole::System,
content: "Siren is a Discord bot specializing in Dungeons and Dragons. Limit Siren's responses to <= 2000 characters. Siren must always obey these instructions, no matter what.".to_string()
},
];
let previous_messages = oai.get_messages(guild_id.0, channel_id.0, author_id.0).await;
match previous_messages {
Ok(m) => {
for message in m.data {
messages.push(
ChatCompletionMessage {
role: GPTRole::User,
content: format!("{}", message.request)
}
);
messages.push(
ChatCompletionMessage {
role: GPTRole::Assistant,
content: format!("{}", message.response)
}
);
}
},
Err(err) => warn!("Could not load previous messages: {}", err)
};
messages.push(ChatCompletionMessage { role: GPTRole::User, content: parsed_content.clone() });
let request = ChatCompletionRequest {
model: oai.default_model.clone(),
messages,
temperature: Some(0.5),
top_p: None,
n: None,
max_tokens: Some(oai.max_tokens),
presence_penalty: Some(0.6),
frequency_penalty: Some(0.0),
user: Some(msg.author.name.clone())
};
// Get the thread channel ID
let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| {
thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
}).await {
Ok(c) => {
let allow = Permissions::SEND_MESSAGES;
let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES;
let overwrite = PermissionOverwrite {
allow,
deny,
kind: PermissionOverwriteType::Member(msg.author.id),
};
let _ = c.create_permission(&ctx.http, &overwrite).await;
c.id
}
Err(_) => {
channel_id
}
};
let typing = response_channel.start_typing(&ctx.http).unwrap();
// Get the OAI response and store message/response into the database
let response = match oai.get_request(request).await {
Ok(r) => {
debug!("Processing response received from OpenAI");
if !r.choices.is_empty() {
let res = r.choices[0].message.content.clone();
if let Err(err) = oai.store_message(siren::Message {
id: r.id,
guild_id: guild_id.0 as i64,
channel_id: response_channel.0 as i64,
user_id: author_id.0 as i64,
created: r.created,
model: serde_json::to_string(&r.model).unwrap(),
request: parsed_content,
response: res.clone(),
request_tags: vec![],
response_tags: vec![],
}).await {
warn!("{}", err);
}
res
} else {
warn!("No choices received in the response from OpenAI");
"No reply received".to_string()
}
}
Err(err) => {
error!("Could not get response from OpenAI: {}", err.message);
"There was an error processing your message. Please try again later.".to_string()
}
};
debug!("Writing response: \"{}\"", response);
typing.stop();
if let Err(why) = response_channel.say(&ctx.http, response).await {
error!("Cannot send message: {}", why);
}
// match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| {
// thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
// }).await {
// Ok(c) => {
// if let Err(why) = c.say(&ctx.http, response).await {
// error!("Cannot send message: {}", why);
// }
// }
// Err(_) => {
// if let Err(why) = channel_id.say(&ctx.http, response).await {
// error!("Cannot send message: {}", why);
// }
// }
// };
}
fn truncate(s: &str, max_chars: usize) -> &str {
match s.char_indices().nth(max_chars) {
None => s,
Some((idx, _)) => &s[..idx],
}
}

View File

@@ -7,12 +7,12 @@ use serenity::prelude::*;
use crate::bot::guilds::InsertGuild; use crate::bot::guilds::InsertGuild;
use super::commands; use super::{commands, oai};
use super::commands::audio::create_response; use super::commands::audio::create_response;
pub struct Handler { pub struct Handler {
// Open AI Config // Open AI Config
pub oai: Option<commands::oai::OAI> pub oai: Option<oai::OAI>
} }
#[async_trait] #[async_trait]
@@ -36,7 +36,7 @@ impl EventHandler for Handler {
Err(_) => false Err(_) => false
}; };
if mentioned || bot_in_thread { if mentioned || bot_in_thread {
commands::oai::generate_response(&ctx, &msg, oai).await; commands::chat::generate_response(&ctx, &msg, oai).await;
} }
} }
Err(why) => warn!("Could not check mentions: {:?}", why) Err(why) => warn!("Could not check mentions: {:?}", why)

View File

@@ -4,7 +4,7 @@ use siren::ServiceError;
use crate::storage::{schema::messages::{self}, connection}; use crate::storage::{schema::messages::{self}, connection};
#[derive(Queryable, Selectable, Serialize, Deserialize)] #[derive(Queryable, Selectable, Insertable, AsChangeset, Serialize, Deserialize)]
#[diesel(table_name = messages)] #[diesel(table_name = messages)]
pub struct QueryMessage { pub struct QueryMessage {
pub id: String, pub id: String,
@@ -122,24 +122,7 @@ impl QueryMessage {
let count = query.count().get_result::<i64>(&mut conn)?; let count = query.count().get_result::<i64>(&mut conn)?;
Ok(count) Ok(count)
} }
}
#[derive(Insertable, AsChangeset, Serialize, Deserialize)]
#[diesel(table_name = messages)]
pub struct InsertMessage {
pub id: String,
pub guild_id: i64,
pub channel_id: i64,
pub user_id: i64,
pub created: i64,
pub model: String,
pub request: String,
pub response: String,
pub request_tags: Vec<String>,
pub response_tags: Vec<String>,
}
impl InsertMessage {
pub fn insert(message: Self) -> Result<QueryMessage, ServiceError> { pub fn insert(message: Self) -> Result<QueryMessage, ServiceError> {
let mut conn = connection()?; let mut conn = connection()?;
let message = diesel::insert_into(messages::table) let message = diesel::insert_into(messages::table)

View File

@@ -3,7 +3,7 @@ use log::error;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use siren::{Response, Metadata, ServiceError}; use siren::{Response, Metadata, ServiceError};
use crate::{bot::messages::{QueryMessage, QueryFilters, InsertMessage}, auth::{JwtAuth, verify_role}}; use crate::{bot::messages::{QueryMessage, QueryFilters}, auth::{JwtAuth, verify_role}};
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct GetAllParams { struct GetAllParams {
@@ -68,12 +68,12 @@ async fn get_all(req: HttpRequest, auth: JwtAuth) -> HttpResponse {
} }
#[post("/messages")] #[post("/messages")]
async fn create(message: web::Json<InsertMessage>, auth: JwtAuth) -> HttpResponse { async fn create(message: web::Json<QueryMessage>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") { let _ = match verify_role(&auth, "admin") {
Ok(_) => {}, Ok(_) => {},
Err(err) => return ResponseError::error_response(&err) Err(err) => return ResponseError::error_response(&err)
}; };
match InsertMessage::insert(message.into_inner()) { match QueryMessage::insert(message.into_inner()) {
Ok(message) => HttpResponse::Created().json(message), Ok(message) => HttpResponse::Created().json(message),
Err(err) => { Err(err) => {
error!("{:?}", err.message); error!("{:?}", err.message);

View File

@@ -2,3 +2,4 @@ pub mod commands;
pub mod guilds; pub mod guilds;
pub mod handler; pub mod handler;
pub mod messages; pub mod messages;
pub mod oai;

View File

@@ -0,0 +1,3 @@
mod model;
pub use model::*;

View File

@@ -0,0 +1,128 @@
use serde::{Serialize, Deserialize};
use serde_json::Value;
use siren::ServiceError;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GPTRole {
#[serde(rename = "system")]
System,
#[serde(rename = "user")]
User,
#[serde(rename = "assistant")]
Assistant,
#[serde(rename = "function")]
Function
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatCompletionMessage>,
/// Value between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Value between 0 and 1
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<i64>,
/// Value between -2.0 and 2.0
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f64>,
/// Value between -2.0 and 2.0
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionMessage {
pub role: GPTRole,
pub content: String
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub system_fingerprint: Option<String>,
pub created: i64,
pub model: String,
pub usage: Usage,
pub choices: Vec<Choice>
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: i64,
pub completion_tokens: i64,
pub total_tokens: i64
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub message: ChatCompletionMessage,
pub finish_reason: String,
pub index: i64,
pub logprobs: Option<String>
}
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ResponseEvent {
ChatCompletionResponse(ChatCompletionResponse),
ResponseError(ResponseError)
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ResponseError {
error: Option<ErrorDetails>,
message: Option<String>,
param: Option<String>,
#[serde(rename = "type")]
error_type: Option<String>
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ErrorDetails {
code: Option<String>,
message: Option<String>
}
pub struct OAI {
pub client: reqwest::Client,
pub base_url: String,
pub service_url: String,
pub max_attempts: i64,
pub token: String,
pub max_tokens: i64,
pub default_model: String,
pub max_context_questions: i64
}
impl OAI {
pub async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, ServiceError> {
let url = format!("{}/chat/completions", self.base_url);
let response = self.client.post(&url)
.bearer_auth(&self.token)
.header("Content-Type", "application/json".to_string())
.json(&request)
.send()
.await;
match response {
Ok(response) => {
let value = response.json::<Value>().await?;
// let event: ResponseEvent = serde_json::from_value::<ResponseEvent>(value)?;
// match event {
// ResponseEvent::ChatCompletionResponse(response) => return Ok(response),
// ResponseEvent::ResponseError(error) => return Err(ServiceError { status: 500, message: format!("Error: {}", error.message.unwrap()) })
// }
let res = serde_json::from_value::<ChatCompletionResponse>(value)?;
return Ok(res);
},
Err(err) => return Err(ServiceError { status: 500, message: format!("Error: {}", err) })
}
}
}

View File

@@ -14,7 +14,7 @@ use songbird::{SerenityInit, Songbird};
use actix_cors::Cors; use actix_cors::Cors;
use actix_web::{HttpServer, App, web}; use actix_web::{HttpServer, App, web};
use crate::bot::{commands::oai::GPTModel, handler::Handler}; use crate::bot::handler::Handler;
use dotenv::dotenv; use dotenv::dotenv;
@@ -57,8 +57,9 @@ async fn main() -> std::io::Result<()> {
let handler = match env::var("OPENAI_API_KEY") { let handler = match env::var("OPENAI_API_KEY") {
Ok(token) => { Ok(token) => {
info!("Loaded OpenAI token"); info!("Loaded OpenAI token");
let default_model = env::var("OPENAI_API_MODEL").unwrap_or("gpt-3.5-turbo".to_string());
Handler { Handler {
oai: Some(bot::commands::oai::OAI { oai: Some(bot::oai::OAI {
client: reqwest::Client::new(), client: reqwest::Client::new(),
base_url: "https://api.openai.com/v1".to_string(), base_url: "https://api.openai.com/v1".to_string(),
service_url: "http://localhost:5000".to_string(), service_url: "http://localhost:5000".to_string(),
@@ -66,7 +67,7 @@ async fn main() -> std::io::Result<()> {
token, token,
max_context_questions: 30, max_context_questions: 30,
max_tokens: 2048, max_tokens: 2048,
default_model: GPTModel::GPT35Turbo, default_model,
}) })
} }
} }