Switched from diesel to sqlx

This commit is contained in:
2024-09-05 11:52:14 -04:00
parent bce363db7e
commit d08800f9e0
42 changed files with 365 additions and 687 deletions

View File

@@ -106,7 +106,7 @@ pub async fn play_track(
let call_handler = handler_lock.lock().await;
call_handler.queue().is_empty()
};
let guild = GuildCache::get(guild_id.get() as i64)?;
let guild = GuildCache::get_by_id(guild_id.get() as i64).await?.unwrap();
let valid = is_valid_url(&track_url);
// Check if the URL is valid
if !valid.0 {

View File

@@ -62,7 +62,9 @@ pub async fn set_volume(manager: Arc<Songbird>, guild_id: GuildId, volume: i32)
// Format volume to f32 bound between 0.0 and 1.0
let volume = std::cmp::min(100, std::cmp::max(0, volume));
let bound_volume = volume as f32 / 100.0;
let _ = GuildCache::update_audio(guild_id.get() as i64, volume);
let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64).await.unwrap().unwrap();
guild_cache.volume = volume;
guild_cache.update().await.unwrap();
if let Some(handler_lock) = manager.get(guild_id) {
let handler = handler_lock.lock().await;

View File

@@ -6,7 +6,7 @@ use serenity::model::channel::Message;
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
use serenity::prelude::*;
use crate::bot::messages::{QueryFilters, QueryMessage};
use crate::bot::messages::MessageCache;
use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI};
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
@@ -27,30 +27,30 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
},
];
match QueryMessage::get_all(
&QueryFilters {
by_guild_id: Some(guild_id.get() as i64),
by_channel_id: Some(channel_id.get() as i64),
by_user_id: Some(author_id.get() as i64),
..Default::default()
},
100,
1,
) {
Ok(m) => {
for message in m {
messages.push(ChatCompletionMessage {
role: GPTRole::User,
content: format!("{}", message.request),
});
messages.push(ChatCompletionMessage {
role: GPTRole::Assistant,
content: format!("{}", message.response),
});
}
}
Err(err) => warn!("Could not load previous messages: {}", err),
};
// match MessageCache::get_all(
// &QueryFilters {
// by_guild_id: Some(guild_id.get() as i64),
// by_channel_id: Some(channel_id.get() as i64),
// by_user_id: Some(author_id.get() as i64),
// ..Default::default()
// },
// 100,
// 1,
// ) {
// Ok(m) => {
// for message in m {
// messages.push(ChatCompletionMessage {
// role: GPTRole::User,
// content: format!("{}", message.request),
// });
// messages.push(ChatCompletionMessage {
// role: GPTRole::Assistant,
// content: format!("{}", message.response),
// });
// }
// }
// Err(err) => warn!("Could not load previous messages: {}", err),
// };
messages.push(ChatCompletionMessage {
role: GPTRole::User,
content: parsed_content.clone(),
@@ -98,7 +98,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
trace!("Processing response received from OpenAI");
if !r.choices.is_empty() {
let res = r.choices[0].message.content.clone();
if let Err(err) = QueryMessage::insert(QueryMessage {
let message_cache = MessageCache {
id: r.id,
guild_id: guild_id.get() as i64,
channel_id: response_channel.get() as i64,
@@ -109,7 +109,8 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
response: res.clone(),
request_tags: vec![],
response_tags: vec![],
}) {
};
if let Err(err) = message_cache.insert().await {
warn!("{}", err);
}
res

View File

@@ -1,11 +1,9 @@
use diesel::prelude::*;
use serde::{Serialize, Deserialize};
use crate::error::SirenResult;
use crate::storage::{schema::guilds, connection};
const TABLE_NAME: &str = "guilds";
#[derive(Insertable, AsChangeset, Queryable, QueryableByName, Serialize, Deserialize)]
#[diesel(table_name = guilds)]
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
pub struct GuildCache {
pub id: i64,
pub bot_id: i64,
@@ -13,25 +11,50 @@ pub struct GuildCache {
}
impl GuildCache {
pub fn insert(&self) -> SirenResult<Self> {
let mut conn = connection()?;
let guild = diesel::insert_into(guilds::table)
.values(self)
.get_result(&mut conn)?;
Ok(guild)
pub async fn insert(&self) -> SirenResult<()> {
let pool = crate::database::pool();
sqlx::query(&format!(
"INSERT INTO {} (
id,
bot_id,
volume
) VALUES (
$1, $2, $3
)",
TABLE_NAME
))
.bind(self.id)
.bind(self.bot_id)
.bind(self.volume)
.execute(pool)
.await?;
Ok(())
}
pub fn get(id: i64) -> SirenResult<Self> {
let mut conn = connection()?;
let guild = guilds::table.filter(guilds::id.eq(id)).first(&mut conn)?;
Ok(guild)
pub async fn get_by_id(id: i64) -> SirenResult<Option<Self>> {
let pool = crate::database::pool();
let item =
sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME))
.bind(id)
.fetch_optional(pool)
.await?;
Ok(item)
}
pub fn update_audio(id: i64, volume: i32) -> SirenResult<Self> {
let mut conn = connection()?;
let guild = diesel::update(guilds::table.filter(guilds::id.eq(id)))
.set(guilds::volume.eq(volume))
.get_result(&mut conn)?;
Ok(guild)
pub async fn update(&self) -> SirenResult<()> {
let pool = crate::database::pool();
sqlx::query(&format!(
"UPDATE {} SET
bot_id = $2,
volume = $3
WHERE id = $1",
TABLE_NAME))
.bind(self.id)
.bind(self.bot_id)
.bind(self.volume)
.execute(pool)
.await?;
Ok(())
}
}

View File

@@ -82,13 +82,13 @@ impl EventHandler for Handler {
for guild in ready.guilds {
// Check if guild exists in database
let guild_id = guild.id.get() as i64;
if let Err(why) = GuildCache::get(guild_id) {
if let None = GuildCache::get_by_id(guild_id).await.unwrap() {
let guild_cache = GuildCache {
id: guild_id,
bot_id: 1,
volume: 100
};
guild_cache.insert();
guild_cache.insert().await.unwrap();
}
let commands = guild
.id

View File

@@ -1,15 +1,10 @@
use diesel::prelude::*;
use serde::{Deserialize, Serialize};
use crate::error::SirenResult;
use crate::storage::{
schema::messages::{self},
connection,
};
const TABLE_NAME: &str = "messages";
#[derive(Queryable, Selectable, Insertable, AsChangeset, Serialize, Deserialize)]
#[diesel(table_name = messages)]
pub struct QueryMessage {
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageCache {
pub id: String,
pub guild_id: i64,
pub channel_id: i64,
@@ -22,118 +17,38 @@ pub struct QueryMessage {
pub response_tags: Vec<String>,
}
pub struct QueryFilters {
pub by_id: Option<String>,
pub by_guild_id: Option<i64>,
pub by_channel_id: Option<i64>,
pub by_user_id: Option<i64>,
pub by_model: Option<String>,
pub by_request: Option<String>,
pub by_response: Option<String>,
pub by_request_tags: Option<Vec<String>>,
pub by_response_tags: Option<Vec<String>>,
}
impl Default for QueryFilters {
fn default() -> Self {
QueryFilters {
by_id: None,
by_guild_id: None,
by_channel_id: None,
by_user_id: None,
by_model: None,
by_request: None,
by_response: None,
by_request_tags: None,
by_response_tags: None,
}
}
}
impl QueryMessage {
pub fn get_all(filters: &QueryFilters, limit: i32, page: i32) -> SirenResult<Vec<Self>> {
let mut conn = connection()?;
let mut query = messages::table
.limit(limit as i64)
.order(messages::created.asc())
.into_boxed();
// Limit query to page and limit
let offset = (page - 1) * limit;
query = query.offset(offset as i64);
// Apply filters
if let Some(id) = &filters.by_id {
query = query.filter(messages::id.eq(id));
}
if let Some(guild_id) = &filters.by_guild_id {
query = query.filter(messages::guild_id.eq(guild_id));
}
if let Some(channel_id) = &filters.by_channel_id {
query = query.filter(messages::channel_id.eq(channel_id));
}
if let Some(user_id) = &filters.by_user_id {
query = query.filter(messages::user_id.eq(user_id));
}
if let Some(model) = &filters.by_model {
query = query.filter(messages::model.eq(model));
}
if let Some(request) = &filters.by_request {
query = query.filter(messages::request.eq(request));
}
if let Some(response) = &filters.by_response {
query = query.filter(messages::response.eq(response));
}
if let Some(request_tags) = &filters.by_request_tags {
query = query.filter(messages::request_tags.eq(request_tags));
}
if let Some(response_tags) = &filters.by_response_tags {
query = query.filter(messages::response_tags.eq(response_tags));
}
// Execute query
let messages = query.load::<Self>(&mut conn)?;
Ok(messages)
}
pub fn get_count(fitlers: &QueryFilters) -> SirenResult<i64> {
let mut conn = connection()?;
let mut query = messages::table.into_boxed();
// Apply filters
if let Some(id) = &fitlers.by_id {
query = query.filter(messages::id.eq(id));
}
if let Some(guild_id) = &fitlers.by_guild_id {
query = query.filter(messages::guild_id.eq(guild_id));
}
if let Some(channel_id) = &fitlers.by_channel_id {
query = query.filter(messages::channel_id.eq(channel_id));
}
if let Some(user_id) = &fitlers.by_user_id {
query = query.filter(messages::user_id.eq(user_id));
}
if let Some(model) = &fitlers.by_model {
query = query.filter(messages::model.eq(model));
}
if let Some(request) = &fitlers.by_request {
query = query.filter(messages::request.eq(request));
}
if let Some(response) = &fitlers.by_response {
query = query.filter(messages::response.eq(response));
}
if let Some(request_tags) = &fitlers.by_request_tags {
query = query.filter(messages::request_tags.eq(request_tags));
}
if let Some(response_tags) = &fitlers.by_response_tags {
query = query.filter(messages::response_tags.eq(response_tags));
}
// Execute query
let count = query.count().get_result::<i64>(&mut conn)?;
Ok(count)
}
pub fn insert(message: Self) -> SirenResult<QueryMessage> {
let mut conn = connection()?;
let message = diesel::insert_into(messages::table)
.values(message)
.get_result(&mut conn)?;
Ok(message)
impl MessageCache {
pub async fn insert(&self) -> SirenResult<()> {
let pool = crate::database::pool();
sqlx::query(&format!(
"INSERT INTO {} (
id,
guild_id,
channel_id,
user_id,
created,
model,
request,
response,
request_tags,
response_tags
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10
)",
TABLE_NAME
))
.bind(&self.id)
.bind(self.guild_id)
.bind(self.channel_id)
.bind(self.user_id)
.bind(self.created)
.bind(&self.model)
.bind(&self.request)
.bind(&self.response)
.bind(&self.request_tags)
.bind(&self.response_tags)
.execute(pool)
.await?;
Ok(())
}
}