Switched from diesel to sqlx
This commit is contained in:
@@ -106,7 +106,7 @@ pub async fn play_track(
|
||||
let call_handler = handler_lock.lock().await;
|
||||
call_handler.queue().is_empty()
|
||||
};
|
||||
let guild = GuildCache::get(guild_id.get() as i64)?;
|
||||
let guild = GuildCache::get_by_id(guild_id.get() as i64).await?.unwrap();
|
||||
let valid = is_valid_url(&track_url);
|
||||
// Check if the URL is valid
|
||||
if !valid.0 {
|
||||
|
||||
@@ -62,7 +62,9 @@ pub async fn set_volume(manager: Arc<Songbird>, guild_id: GuildId, volume: i32)
|
||||
// Format volume to f32 bound between 0.0 and 1.0
|
||||
let volume = std::cmp::min(100, std::cmp::max(0, volume));
|
||||
let bound_volume = volume as f32 / 100.0;
|
||||
let _ = GuildCache::update_audio(guild_id.get() as i64, volume);
|
||||
let mut guild_cache = GuildCache::get_by_id(guild_id.get() as i64).await.unwrap().unwrap();
|
||||
guild_cache.volume = volume;
|
||||
guild_cache.update().await.unwrap();
|
||||
|
||||
if let Some(handler_lock) = manager.get(guild_id) {
|
||||
let handler = handler_lock.lock().await;
|
||||
|
||||
@@ -6,7 +6,7 @@ use serenity::model::channel::Message;
|
||||
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
|
||||
use serenity::prelude::*;
|
||||
|
||||
use crate::bot::messages::{QueryFilters, QueryMessage};
|
||||
use crate::bot::messages::MessageCache;
|
||||
use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI};
|
||||
|
||||
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
@@ -27,30 +27,30 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
},
|
||||
];
|
||||
|
||||
match QueryMessage::get_all(
|
||||
&QueryFilters {
|
||||
by_guild_id: Some(guild_id.get() as i64),
|
||||
by_channel_id: Some(channel_id.get() as i64),
|
||||
by_user_id: Some(author_id.get() as i64),
|
||||
..Default::default()
|
||||
},
|
||||
100,
|
||||
1,
|
||||
) {
|
||||
Ok(m) => {
|
||||
for message in m {
|
||||
messages.push(ChatCompletionMessage {
|
||||
role: GPTRole::User,
|
||||
content: format!("{}", message.request),
|
||||
});
|
||||
messages.push(ChatCompletionMessage {
|
||||
role: GPTRole::Assistant,
|
||||
content: format!("{}", message.response),
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(err) => warn!("Could not load previous messages: {}", err),
|
||||
};
|
||||
// match MessageCache::get_all(
|
||||
// &QueryFilters {
|
||||
// by_guild_id: Some(guild_id.get() as i64),
|
||||
// by_channel_id: Some(channel_id.get() as i64),
|
||||
// by_user_id: Some(author_id.get() as i64),
|
||||
// ..Default::default()
|
||||
// },
|
||||
// 100,
|
||||
// 1,
|
||||
// ) {
|
||||
// Ok(m) => {
|
||||
// for message in m {
|
||||
// messages.push(ChatCompletionMessage {
|
||||
// role: GPTRole::User,
|
||||
// content: format!("{}", message.request),
|
||||
// });
|
||||
// messages.push(ChatCompletionMessage {
|
||||
// role: GPTRole::Assistant,
|
||||
// content: format!("{}", message.response),
|
||||
// });
|
||||
// }
|
||||
// }
|
||||
// Err(err) => warn!("Could not load previous messages: {}", err),
|
||||
// };
|
||||
messages.push(ChatCompletionMessage {
|
||||
role: GPTRole::User,
|
||||
content: parsed_content.clone(),
|
||||
@@ -98,7 +98,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
trace!("Processing response received from OpenAI");
|
||||
if !r.choices.is_empty() {
|
||||
let res = r.choices[0].message.content.clone();
|
||||
if let Err(err) = QueryMessage::insert(QueryMessage {
|
||||
let message_cache = MessageCache {
|
||||
id: r.id,
|
||||
guild_id: guild_id.get() as i64,
|
||||
channel_id: response_channel.get() as i64,
|
||||
@@ -109,7 +109,8 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
response: res.clone(),
|
||||
request_tags: vec![],
|
||||
response_tags: vec![],
|
||||
}) {
|
||||
};
|
||||
if let Err(err) = message_cache.insert().await {
|
||||
warn!("{}", err);
|
||||
}
|
||||
res
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
use diesel::prelude::*;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::error::SirenResult;
|
||||
|
||||
use crate::storage::{schema::guilds, connection};
|
||||
const TABLE_NAME: &str = "guilds";
|
||||
|
||||
#[derive(Insertable, AsChangeset, Queryable, QueryableByName, Serialize, Deserialize)]
|
||||
#[diesel(table_name = guilds)]
|
||||
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct GuildCache {
|
||||
pub id: i64,
|
||||
pub bot_id: i64,
|
||||
@@ -13,25 +11,50 @@ pub struct GuildCache {
|
||||
}
|
||||
|
||||
impl GuildCache {
|
||||
pub fn insert(&self) -> SirenResult<Self> {
|
||||
let mut conn = connection()?;
|
||||
let guild = diesel::insert_into(guilds::table)
|
||||
.values(self)
|
||||
.get_result(&mut conn)?;
|
||||
Ok(guild)
|
||||
pub async fn insert(&self) -> SirenResult<()> {
|
||||
let pool = crate::database::pool();
|
||||
sqlx::query(&format!(
|
||||
"INSERT INTO {} (
|
||||
id,
|
||||
bot_id,
|
||||
volume
|
||||
) VALUES (
|
||||
$1, $2, $3
|
||||
)",
|
||||
TABLE_NAME
|
||||
))
|
||||
.bind(self.id)
|
||||
.bind(self.bot_id)
|
||||
.bind(self.volume)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get(id: i64) -> SirenResult<Self> {
|
||||
let mut conn = connection()?;
|
||||
let guild = guilds::table.filter(guilds::id.eq(id)).first(&mut conn)?;
|
||||
Ok(guild)
|
||||
pub async fn get_by_id(id: i64) -> SirenResult<Option<Self>> {
|
||||
let pool = crate::database::pool();
|
||||
let item =
|
||||
sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME))
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
Ok(item)
|
||||
}
|
||||
|
||||
pub fn update_audio(id: i64, volume: i32) -> SirenResult<Self> {
|
||||
let mut conn = connection()?;
|
||||
let guild = diesel::update(guilds::table.filter(guilds::id.eq(id)))
|
||||
.set(guilds::volume.eq(volume))
|
||||
.get_result(&mut conn)?;
|
||||
Ok(guild)
|
||||
pub async fn update(&self) -> SirenResult<()> {
|
||||
let pool = crate::database::pool();
|
||||
sqlx::query(&format!(
|
||||
"UPDATE {} SET
|
||||
bot_id = $2,
|
||||
volume = $3
|
||||
WHERE id = $1",
|
||||
TABLE_NAME))
|
||||
.bind(self.id)
|
||||
.bind(self.bot_id)
|
||||
.bind(self.volume)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,13 +82,13 @@ impl EventHandler for Handler {
|
||||
for guild in ready.guilds {
|
||||
// Check if guild exists in database
|
||||
let guild_id = guild.id.get() as i64;
|
||||
if let Err(why) = GuildCache::get(guild_id) {
|
||||
if let None = GuildCache::get_by_id(guild_id).await.unwrap() {
|
||||
let guild_cache = GuildCache {
|
||||
id: guild_id,
|
||||
bot_id: 1,
|
||||
volume: 100
|
||||
};
|
||||
guild_cache.insert();
|
||||
guild_cache.insert().await.unwrap();
|
||||
}
|
||||
let commands = guild
|
||||
.id
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
use diesel::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::error::SirenResult;
|
||||
|
||||
use crate::storage::{
|
||||
schema::messages::{self},
|
||||
connection,
|
||||
};
|
||||
const TABLE_NAME: &str = "messages";
|
||||
|
||||
#[derive(Queryable, Selectable, Insertable, AsChangeset, Serialize, Deserialize)]
|
||||
#[diesel(table_name = messages)]
|
||||
pub struct QueryMessage {
|
||||
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct MessageCache {
|
||||
pub id: String,
|
||||
pub guild_id: i64,
|
||||
pub channel_id: i64,
|
||||
@@ -22,118 +17,38 @@ pub struct QueryMessage {
|
||||
pub response_tags: Vec<String>,
|
||||
}
|
||||
|
||||
pub struct QueryFilters {
|
||||
pub by_id: Option<String>,
|
||||
pub by_guild_id: Option<i64>,
|
||||
pub by_channel_id: Option<i64>,
|
||||
pub by_user_id: Option<i64>,
|
||||
pub by_model: Option<String>,
|
||||
pub by_request: Option<String>,
|
||||
pub by_response: Option<String>,
|
||||
pub by_request_tags: Option<Vec<String>>,
|
||||
pub by_response_tags: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Default for QueryFilters {
|
||||
fn default() -> Self {
|
||||
QueryFilters {
|
||||
by_id: None,
|
||||
by_guild_id: None,
|
||||
by_channel_id: None,
|
||||
by_user_id: None,
|
||||
by_model: None,
|
||||
by_request: None,
|
||||
by_response: None,
|
||||
by_request_tags: None,
|
||||
by_response_tags: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryMessage {
|
||||
pub fn get_all(filters: &QueryFilters, limit: i32, page: i32) -> SirenResult<Vec<Self>> {
|
||||
let mut conn = connection()?;
|
||||
let mut query = messages::table
|
||||
.limit(limit as i64)
|
||||
.order(messages::created.asc())
|
||||
.into_boxed();
|
||||
// Limit query to page and limit
|
||||
let offset = (page - 1) * limit;
|
||||
query = query.offset(offset as i64);
|
||||
// Apply filters
|
||||
if let Some(id) = &filters.by_id {
|
||||
query = query.filter(messages::id.eq(id));
|
||||
}
|
||||
if let Some(guild_id) = &filters.by_guild_id {
|
||||
query = query.filter(messages::guild_id.eq(guild_id));
|
||||
}
|
||||
if let Some(channel_id) = &filters.by_channel_id {
|
||||
query = query.filter(messages::channel_id.eq(channel_id));
|
||||
}
|
||||
if let Some(user_id) = &filters.by_user_id {
|
||||
query = query.filter(messages::user_id.eq(user_id));
|
||||
}
|
||||
if let Some(model) = &filters.by_model {
|
||||
query = query.filter(messages::model.eq(model));
|
||||
}
|
||||
if let Some(request) = &filters.by_request {
|
||||
query = query.filter(messages::request.eq(request));
|
||||
}
|
||||
if let Some(response) = &filters.by_response {
|
||||
query = query.filter(messages::response.eq(response));
|
||||
}
|
||||
if let Some(request_tags) = &filters.by_request_tags {
|
||||
query = query.filter(messages::request_tags.eq(request_tags));
|
||||
}
|
||||
if let Some(response_tags) = &filters.by_response_tags {
|
||||
query = query.filter(messages::response_tags.eq(response_tags));
|
||||
}
|
||||
// Execute query
|
||||
let messages = query.load::<Self>(&mut conn)?;
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
pub fn get_count(fitlers: &QueryFilters) -> SirenResult<i64> {
|
||||
let mut conn = connection()?;
|
||||
let mut query = messages::table.into_boxed();
|
||||
// Apply filters
|
||||
if let Some(id) = &fitlers.by_id {
|
||||
query = query.filter(messages::id.eq(id));
|
||||
}
|
||||
if let Some(guild_id) = &fitlers.by_guild_id {
|
||||
query = query.filter(messages::guild_id.eq(guild_id));
|
||||
}
|
||||
if let Some(channel_id) = &fitlers.by_channel_id {
|
||||
query = query.filter(messages::channel_id.eq(channel_id));
|
||||
}
|
||||
if let Some(user_id) = &fitlers.by_user_id {
|
||||
query = query.filter(messages::user_id.eq(user_id));
|
||||
}
|
||||
if let Some(model) = &fitlers.by_model {
|
||||
query = query.filter(messages::model.eq(model));
|
||||
}
|
||||
if let Some(request) = &fitlers.by_request {
|
||||
query = query.filter(messages::request.eq(request));
|
||||
}
|
||||
if let Some(response) = &fitlers.by_response {
|
||||
query = query.filter(messages::response.eq(response));
|
||||
}
|
||||
if let Some(request_tags) = &fitlers.by_request_tags {
|
||||
query = query.filter(messages::request_tags.eq(request_tags));
|
||||
}
|
||||
if let Some(response_tags) = &fitlers.by_response_tags {
|
||||
query = query.filter(messages::response_tags.eq(response_tags));
|
||||
}
|
||||
// Execute query
|
||||
let count = query.count().get_result::<i64>(&mut conn)?;
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
pub fn insert(message: Self) -> SirenResult<QueryMessage> {
|
||||
let mut conn = connection()?;
|
||||
let message = diesel::insert_into(messages::table)
|
||||
.values(message)
|
||||
.get_result(&mut conn)?;
|
||||
Ok(message)
|
||||
impl MessageCache {
|
||||
pub async fn insert(&self) -> SirenResult<()> {
|
||||
let pool = crate::database::pool();
|
||||
sqlx::query(&format!(
|
||||
"INSERT INTO {} (
|
||||
id,
|
||||
guild_id,
|
||||
channel_id,
|
||||
user_id,
|
||||
created,
|
||||
model,
|
||||
request,
|
||||
response,
|
||||
request_tags,
|
||||
response_tags
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10
|
||||
)",
|
||||
TABLE_NAME
|
||||
))
|
||||
.bind(&self.id)
|
||||
.bind(self.guild_id)
|
||||
.bind(self.channel_id)
|
||||
.bind(self.user_id)
|
||||
.bind(self.created)
|
||||
.bind(&self.model)
|
||||
.bind(&self.request)
|
||||
.bind(&self.response)
|
||||
.bind(&self.request_tags)
|
||||
.bind(&self.response_tags)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user