Added lib to service
This commit is contained in:
@@ -6,8 +6,7 @@ use serenity::model::Permissions;
|
||||
use serenity::model::channel::Message;
|
||||
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
|
||||
use serenity::prelude::*;
|
||||
|
||||
use crate::error_handler::BotError;
|
||||
use siren::{GetResponse, ServiceError};
|
||||
|
||||
pub struct OAI {
|
||||
pub client: reqwest::Client,
|
||||
@@ -126,51 +125,8 @@ enum ResponseEvent {
|
||||
ResponseError(ResponseError)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct GetResponse<T> {
|
||||
pub data: T,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<Metadata>
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Metadata {
|
||||
pub total: i32,
|
||||
pub limit: i32,
|
||||
pub page: i32,
|
||||
pub pages: i32
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct QueryMessage {
|
||||
pub id: String,
|
||||
pub guild_id: i64,
|
||||
pub channel_id: i64,
|
||||
pub user_id: i64,
|
||||
pub created: i64,
|
||||
pub model: String,
|
||||
pub request: String,
|
||||
pub response: String,
|
||||
pub request_tags: Vec<String>,
|
||||
pub response_tags: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct InsertMessage {
|
||||
pub id: String,
|
||||
pub guild_id: i64,
|
||||
pub channel_id: i64,
|
||||
pub user_id: i64,
|
||||
pub created: i64,
|
||||
pub model: String,
|
||||
pub request: String,
|
||||
pub response: String,
|
||||
pub request_tags: Vec<String>,
|
||||
pub response_tags: Vec<String>,
|
||||
}
|
||||
|
||||
impl OAI {
|
||||
async fn get_request(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, BotError> {
|
||||
async fn get_request(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, ServiceError> {
|
||||
let uri = format!("{}/chat/completions", self.base_url);
|
||||
let body = serde_json::to_string(&request).unwrap();
|
||||
trace!("Sending request to {}: {}", uri, body);
|
||||
@@ -204,7 +160,7 @@ impl OAI {
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn get_messages(&self, guild_id: u64, channel_id: u64, author_id: u64) -> Result<GetResponse<Vec<QueryMessage>>, BotError> {
|
||||
async fn get_messages(&self, guild_id: u64, channel_id: u64, author_id: u64) -> Result<GetResponse<Vec<siren::Message>>, ServiceError> {
|
||||
let uri = format!("{}/messages?guild_id={}&channel_id={}&author_id={}&limit={}", self.service_url, guild_id, channel_id, author_id, self.max_context_questions);
|
||||
let value = self.client
|
||||
.get(&uri)
|
||||
@@ -213,23 +169,23 @@ impl OAI {
|
||||
.json::<Value>()
|
||||
.await?;
|
||||
|
||||
let response = serde_json::from_value::<GetResponse<Vec<QueryMessage>>>(value)?;
|
||||
let response = serde_json::from_value::<GetResponse<Vec<siren::Message>>>(value)?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn store_message(&self, message: InsertMessage) -> Result<QueryMessage, BotError> {
|
||||
async fn store_message(&self, message: siren::Message) -> Result<siren::Message, ServiceError> {
|
||||
let uri = format!("{}/messages", self.service_url);
|
||||
trace!("Sending request to {}", uri);
|
||||
let value = self.client
|
||||
.post(&uri)
|
||||
.json::<InsertMessage>(&message)
|
||||
.json::<siren::Message>(&message)
|
||||
.send()
|
||||
.await?
|
||||
.json::<Value>()
|
||||
.await?;
|
||||
trace!("Received response from Service: {:?}", value);
|
||||
let response = serde_json::from_value::<QueryMessage>(value)?;
|
||||
let response = serde_json::from_value::<siren::Message>(value)?;
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
@@ -314,7 +270,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
debug!("Processing response received from OpenAI");
|
||||
if !r.choices.is_empty() {
|
||||
let res = r.choices[0].message.content.clone();
|
||||
if let Err(err) = oai.store_message(InsertMessage {
|
||||
if let Err(err) = oai.store_message(siren::Message {
|
||||
id: r.id,
|
||||
guild_id: guild_id.0 as i64,
|
||||
channel_id: response_channel.0 as i64,
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct BotError {
|
||||
pub status: u16,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl BotError {
|
||||
pub fn new(error_status_code: u16, error_message: String) -> BotError {
|
||||
BotError {
|
||||
status: error_status_code,
|
||||
message: error_message,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for BotError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(self.message.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for BotError {
|
||||
fn from(error: reqwest::Error) -> BotError {
|
||||
BotError::new(500, format!("Unknown reqwest error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for BotError {
|
||||
fn from(error: serde_json::Error) -> BotError {
|
||||
BotError::new(500, format!("Unknown serde_json error: {}", error))
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,6 @@ use songbird::SerenityInit;
|
||||
use crate::commands::oai::GPTModel;
|
||||
|
||||
mod commands;
|
||||
mod error_handler;
|
||||
|
||||
struct Handler {
|
||||
// Open AI Config
|
||||
|
||||
Reference in New Issue
Block a user