Added lib to service

This commit is contained in:
Benjamin Sherriff
2023-10-04 19:58:54 -04:00
parent cee9dbdc81
commit ecc65222b6
20 changed files with 169 additions and 166 deletions

View File

@@ -6,8 +6,7 @@ use serenity::model::Permissions;
use serenity::model::channel::Message;
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
use serenity::prelude::*;
use crate::error_handler::BotError;
use siren::{GetResponse, ServiceError};
pub struct OAI {
pub client: reqwest::Client,
@@ -126,51 +125,8 @@ enum ResponseEvent {
ResponseError(ResponseError)
}
#[derive(Serialize, Deserialize)]
pub struct GetResponse<T> {
pub data: T,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<Metadata>
}
#[derive(Serialize, Deserialize)]
pub struct Metadata {
pub total: i32,
pub limit: i32,
pub page: i32,
pub pages: i32
}
#[derive(Serialize, Deserialize)]
pub struct QueryMessage {
pub id: String,
pub guild_id: i64,
pub channel_id: i64,
pub user_id: i64,
pub created: i64,
pub model: String,
pub request: String,
pub response: String,
pub request_tags: Vec<String>,
pub response_tags: Vec<String>,
}
#[derive(Serialize, Deserialize)]
pub struct InsertMessage {
pub id: String,
pub guild_id: i64,
pub channel_id: i64,
pub user_id: i64,
pub created: i64,
pub model: String,
pub request: String,
pub response: String,
pub request_tags: Vec<String>,
pub response_tags: Vec<String>,
}
impl OAI {
async fn get_request(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, BotError> {
async fn get_request(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, ServiceError> {
let uri = format!("{}/chat/completions", self.base_url);
let body = serde_json::to_string(&request).unwrap();
trace!("Sending request to {}: {}", uri, body);
@@ -204,7 +160,7 @@ impl OAI {
Ok(response)
}
async fn get_messages(&self, guild_id: u64, channel_id: u64, author_id: u64) -> Result<GetResponse<Vec<QueryMessage>>, BotError> {
async fn get_messages(&self, guild_id: u64, channel_id: u64, author_id: u64) -> Result<GetResponse<Vec<siren::Message>>, ServiceError> {
let uri = format!("{}/messages?guild_id={}&channel_id={}&author_id={}&limit={}", self.service_url, guild_id, channel_id, author_id, self.max_context_questions);
let value = self.client
.get(&uri)
@@ -213,23 +169,23 @@ impl OAI {
.json::<Value>()
.await?;
let response = serde_json::from_value::<GetResponse<Vec<QueryMessage>>>(value)?;
let response = serde_json::from_value::<GetResponse<Vec<siren::Message>>>(value)?;
Ok(response)
}
async fn store_message(&self, message: InsertMessage) -> Result<QueryMessage, BotError> {
async fn store_message(&self, message: siren::Message) -> Result<siren::Message, ServiceError> {
let uri = format!("{}/messages", self.service_url);
trace!("Sending request to {}", uri);
let value = self.client
.post(&uri)
.json::<InsertMessage>(&message)
.json::<siren::Message>(&message)
.send()
.await?
.json::<Value>()
.await?;
trace!("Received response from Service: {:?}", value);
let response = serde_json::from_value::<QueryMessage>(value)?;
let response = serde_json::from_value::<siren::Message>(value)?;
Ok(response)
}
}
@@ -314,7 +270,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
debug!("Processing response received from OpenAI");
if !r.choices.is_empty() {
let res = r.choices[0].message.content.clone();
if let Err(err) = oai.store_message(InsertMessage {
if let Err(err) = oai.store_message(siren::Message {
id: r.id,
guild_id: guild_id.0 as i64,
channel_id: response_channel.0 as i64,

View File

@@ -1,35 +0,0 @@
use serde::{Deserialize, Serialize};
use std::fmt;
#[derive(Debug, Deserialize, Serialize)]
pub struct BotError {
pub status: u16,
pub message: String,
}
impl BotError {
pub fn new(error_status_code: u16, error_message: String) -> BotError {
BotError {
status: error_status_code,
message: error_message,
}
}
}
impl fmt::Display for BotError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(self.message.as_str())
}
}
impl From<reqwest::Error> for BotError {
fn from(error: reqwest::Error) -> BotError {
BotError::new(500, format!("Unknown reqwest error: {}", error))
}
}
impl From<serde_json::Error> for BotError {
fn from(error: serde_json::Error) -> BotError {
BotError::new(500, format!("Unknown serde_json error: {}", error))
}
}

View File

@@ -18,7 +18,6 @@ use songbird::SerenityInit;
use crate::commands::oai::GPTModel;
mod commands;
mod error_handler;
struct Handler {
// Open AI Config