@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "siren"
|
||||
version = "0.2.2"
|
||||
version = "0.2.3"
|
||||
edition = "2021"
|
||||
authors = ["Ben Sherriff <hello@bensherriff.com>"]
|
||||
repository = "https://github.com/bensherriff/siren"
|
||||
|
||||
28
Dockerfile
28
Dockerfile
@@ -1,21 +1,33 @@
|
||||
FROM rust:1.67 as builder
|
||||
FROM rust:1.70.0 as builder
|
||||
WORKDIR /siren
|
||||
RUN apt-get update && apt-get install -y cmake && apt-get auto-remove -y
|
||||
ADD src ./src/
|
||||
ADD Cargo.toml ./
|
||||
RUN cargo build --release --bin siren
|
||||
RUN apt-get update && apt-get install -y cmake && \
|
||||
cargo build --release --bin siren
|
||||
|
||||
FROM debian:bullseye-slim as packages
|
||||
WORKDIR /packages
|
||||
RUN apt-get update && apt-get install -y libopus-dev libpq5 libpq-dev curl tar xz-utils
|
||||
RUN curl -L https://github.com/yt-dlp/yt-dlp/releases/latest/download/yt-dlp_linux > yt-dlp && \
|
||||
chmod +x yt-dlp
|
||||
RUN curl -L https://github.com/yt-dlp/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz > ffmpeg.tar.xz && \
|
||||
RUN apt-get update && apt-get install -y curl tar xz-utils && \
|
||||
curl -L https://github.com/yt-dlp/yt-dlp/releases/latest/download/yt-dlp_linux > yt-dlp && \
|
||||
chmod +x yt-dlp && \
|
||||
curl -L https://github.com/yt-dlp/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz > ffmpeg.tar.xz && \
|
||||
tar -xJf ffmpeg.tar.xz --wildcards */bin/ffmpeg --transform='s/^.*\///' && rm ffmpeg.tar.xz
|
||||
|
||||
# FROM debian:bullseye-slim as libraries
|
||||
# WORKDIR /libraries
|
||||
# RUN apt-get update && apt-get install -y unzip && \
|
||||
# curl -L https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcu117.zip > libtorch.zip && \
|
||||
# unzip libtorch.zip && rm libtorch.zip
|
||||
|
||||
FROM debian:bullseye-slim as runtime
|
||||
WORKDIR /siren
|
||||
RUN apt-get update && apt-get install -y libopus-dev libpq5 libpq-dev && apt-get auto-remove -y
|
||||
COPY --from=builder /siren/target/release/siren /usr/local/bin/siren
|
||||
COPY --from=packages /packages /usr/bin
|
||||
# ADD migrations ./migrations/
|
||||
# COPY --from=libraries /libraries /usr/lib
|
||||
|
||||
# ARG LIBTORCH=/usr/lib/libtorch
|
||||
# ARG LD_LIBRARY_PATH=${LIBTORCH}/lib:${LD_LIBRARY_PATH}
|
||||
|
||||
# ADD migrations ./
|
||||
CMD ["siren"]
|
||||
|
||||
18
Makefile
18
Makefile
@@ -1,5 +1,6 @@
|
||||
#!make
|
||||
SHELL := /bin/bash
|
||||
|
||||
include .env
|
||||
include .version
|
||||
export $(shell sed 's/=.*//' .env)
|
||||
@@ -9,20 +10,25 @@ SIREN_IMAGES = $(shell docker images 'siren' -a -q)
|
||||
|
||||
.PHONY: help build test up down exec clean
|
||||
|
||||
build:
|
||||
help: ## Help command
|
||||
@echo
|
||||
@cat Makefile | grep -E '^[a-zA-Z\/_-]+:.*?## .*$$' | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
@echo
|
||||
|
||||
build: ## Build the docker image
|
||||
docker build -t siren:${SIREN_VERSION} .
|
||||
|
||||
test:
|
||||
test: ## Run the docker app as a container
|
||||
docker run --env-file .env -it --rm --name siren siren:${SIREN_VERSION}
|
||||
|
||||
up:
|
||||
up: ## Start the app
|
||||
docker compose up -d
|
||||
|
||||
down:
|
||||
down: ## Stop the app
|
||||
docker compose down
|
||||
|
||||
exec:
|
||||
exec: ## Enter running docker container
|
||||
docker exec -it siren bash
|
||||
|
||||
clean:
|
||||
clean: ## Cleanup docker images
|
||||
docker rmi $(SIREN_IMAGES)
|
||||
|
||||
@@ -2,25 +2,33 @@ version: '3'
|
||||
|
||||
services:
|
||||
siren:
|
||||
image: siren:${SIREN_VERSION}
|
||||
image: siren:${SIREN_VERSION:-latest}
|
||||
container_name: siren
|
||||
build:
|
||||
context: .
|
||||
dockerfile: ./Dockerfile
|
||||
args:
|
||||
- VERSION=${SIREN_VERSION}
|
||||
- VERSION=${SIREN_VERSION:-latest}
|
||||
volumes:
|
||||
- ./app:/siren
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
DISCORD_TOKEN: ${DISCORD_TOKEN}
|
||||
RUST_LOG: ${RUST_LOG}
|
||||
DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db/${POSTGRES_DB}
|
||||
OPENAI_API_KEY: ${OPENAI_API_KEY}
|
||||
POSTGRES_USER: ${POSTGRES_USER}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
POSTGRES_DB: ${POSTGRES_DB}
|
||||
POSTGRES_HOST: db
|
||||
depends_on:
|
||||
- db
|
||||
restart: unless-stopped
|
||||
db:
|
||||
image: postgres:latest
|
||||
container_name: siren_db
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
POSTGRES_USER: ${POSTGRES_USER}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use log::debug;
|
||||
@@ -16,6 +17,18 @@ pub mod skip;
|
||||
pub mod stop;
|
||||
pub mod volume;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AudioConfigs;
|
||||
|
||||
impl TypeMapKey for AudioConfigs {
|
||||
type Value = Arc<RwLock<HashMap<GuildId, AudioConfig>>>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AudioConfig {
|
||||
pub volume: f32
|
||||
}
|
||||
|
||||
/// Joins a Discord voice channel.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -135,7 +148,7 @@ pub async fn edit_response(ctx: &Context, command: &ApplicationCommandInteractio
|
||||
///
|
||||
/// # Returns
|
||||
/// Result<Metadata, SongbirdError> - Ok if the song was added successfully, Err if there was an error.
|
||||
pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool) -> Result<Metadata, SongbirdError> {
|
||||
pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool, audio_config: Option<&AudioConfig>) -> Result<Metadata, SongbirdError> {
|
||||
let source = if is_valid_url(url) {
|
||||
Restartable::ytdl(url.to_owned(), lazy).await?
|
||||
} else {
|
||||
@@ -144,7 +157,10 @@ pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool) -> Result<M
|
||||
let mut handler = call.lock().await;
|
||||
let track: Input = source.into();
|
||||
let metadata = *track.metadata.clone();
|
||||
handler.enqueue_source(track);
|
||||
let track_handle = handler.enqueue_source(track);
|
||||
if let Some(ac) = audio_config {
|
||||
let _ = track_handle.set_volume(ac.volume);
|
||||
}
|
||||
Ok(metadata)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ use serenity::builder::CreateApplicationCommand;
|
||||
use serenity::model::application::interaction::application_command::ApplicationCommandInteraction;
|
||||
use songbird::EventHandler;
|
||||
|
||||
use crate::commands::audio::{join, leave, add_song, get_songbird};
|
||||
use crate::commands::audio::{join, leave, add_song, get_songbird, AudioConfigs};
|
||||
|
||||
use super::{create_response, edit_response};
|
||||
|
||||
@@ -65,11 +65,16 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
|
||||
let call_handler = handler_lock.lock().await;
|
||||
call_handler.queue().is_empty()
|
||||
};
|
||||
match add_song(handler_lock.clone(), &track_url, is_queue_empty).await {
|
||||
let audio_config = {
|
||||
let data_read = ctx.data.read().await;
|
||||
data_read.get::<AudioConfigs>().expect("Expected AudioConfigs in TypeMap.").clone()
|
||||
};
|
||||
let ac = audio_config.read().await;
|
||||
match add_song(handler_lock.clone(), &track_url, is_queue_empty, ac.get(&guild_id)).await {
|
||||
Ok(added_song) => {
|
||||
let track_title = added_song.title.unwrap();
|
||||
debug!("Added song: {}", track_title);
|
||||
if let Err(why) = edit_response(&ctx, &command, format!("Added song to queue: {}", track_title)).await {
|
||||
debug!("Added track: {}", track_title);
|
||||
if let Err(why) = edit_response(&ctx, &command, format!("Added track to queue: {}", track_title)).await {
|
||||
error!("Failed to edit response message: {}", why);
|
||||
}
|
||||
let mut handler = handler_lock.lock().await;
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
use log::{error, warn};
|
||||
|
||||
use serenity::prelude::*;
|
||||
use serenity::builder::CreateApplicationCommand;
|
||||
use serenity::model::application::interaction::application_command::ApplicationCommandInteraction;
|
||||
|
||||
use super::{get_songbird, create_response, edit_response, AudioConfigs, AudioConfig};
|
||||
|
||||
pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
|
||||
// Get the volume
|
||||
let volume = match command.data.options.get(0) {
|
||||
Some(t) => match &t.value {
|
||||
Some(v) => match v.as_i64() {
|
||||
Some(p) => std::cmp::min(100, std::cmp::max(0, p)),
|
||||
None => {
|
||||
warn!("Unable to get volume option as a string");
|
||||
if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await {
|
||||
error!("Failed to create response message: {}", why);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("Missing volume option value");
|
||||
if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await {
|
||||
error!("Failed to create response message: {}", why);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("Missing volume option");
|
||||
if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await {
|
||||
error!("Failed to create response message: {}", why);
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Format volume to f32 bound between 0.0 and 1.0
|
||||
let bound_volume = volume as f32 / 100.0;
|
||||
|
||||
// Create the initial response
|
||||
if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await {
|
||||
error!("Failed to create response message: {}", why);
|
||||
return;
|
||||
}
|
||||
|
||||
let guild_id = match command.guild_id {
|
||||
Some(g) => g,
|
||||
None => {
|
||||
if let Err(why) = edit_response(&ctx, &command, "Unable to join voice channel".to_string()).await {
|
||||
error!("Failed to edit response message: {}", why);
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
let audio_config_lock = {
|
||||
let data_read = ctx.data.read().await;
|
||||
data_read.get::<AudioConfigs>().expect("Expected AudioConfigs in TypeMap.").clone()
|
||||
};
|
||||
{
|
||||
let mut audio_configs = audio_config_lock.write().await;
|
||||
*audio_configs.entry(guild_id).or_insert(AudioConfig { volume: 1.0 }) = AudioConfig { volume: bound_volume };
|
||||
}
|
||||
let manager = get_songbird(ctx).await;
|
||||
if let Some(handler_lock) = manager.get(guild_id) {
|
||||
let handler = handler_lock.lock().await;
|
||||
for (_, track_handle) in handler.queue().current_queue().iter().enumerate() {
|
||||
let _ = track_handle.set_volume(bound_volume);
|
||||
}
|
||||
}
|
||||
if let Err(why) = edit_response(&ctx, &command, format!("Setting the volume to {}", volume)).await {
|
||||
error!("Failed to set the volume: {}", why);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand {
|
||||
command.name("volume").description("Set the audio player volume").create_option(|option| { option
|
||||
.name("volume")
|
||||
.description("Volume between 0 and 100")
|
||||
.kind(serenity::model::prelude::command::CommandOptionType::Integer)
|
||||
.required(true)
|
||||
})
|
||||
}
|
||||
@@ -8,7 +8,9 @@ use log::{error, debug, trace, warn};
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde_json::Value;
|
||||
use serenity::model::Permissions;
|
||||
use serenity::model::channel::Message;
|
||||
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
|
||||
use serenity::prelude::*;
|
||||
|
||||
use crate::database::models::{NewMessageDB, MessageDB};
|
||||
@@ -187,7 +189,6 @@ impl OAI {
|
||||
|
||||
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &Pool<ConnectionManager<PgConnection>>) {
|
||||
debug!("Generating response for message: {}", msg.content);
|
||||
let typing = msg.channel_id.start_typing(&ctx.http).unwrap();
|
||||
|
||||
let guild_id = msg.guild_id.unwrap();
|
||||
let channel_id = msg.channel_id;
|
||||
@@ -205,7 +206,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P
|
||||
.and(crate::database::schema::messages::channel_id.eq(channel_id.0 as i64))
|
||||
.and(crate::database::schema::messages::user_id.eq(author_id.0 as i64))
|
||||
)
|
||||
.order(crate::database::schema::messages::created.desc())
|
||||
.order(crate::database::schema::messages::created.asc())
|
||||
.limit(oai.max_context_questions)
|
||||
.load(&mut connection);
|
||||
|
||||
@@ -213,7 +214,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P
|
||||
Ok(r) => {
|
||||
let mut previous_message = "".to_string();
|
||||
for message in r {
|
||||
previous_message = format!("{}\nYou: {}\n Siren: {}", previous_message, message.request, message.response);
|
||||
previous_message = format!("{}You: {}\n Siren: {}\n", previous_message, message.request, message.response);
|
||||
}
|
||||
Some(ChatCompletionMessage { role: GPTRole::User, content: previous_message })
|
||||
}
|
||||
@@ -231,7 +232,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P
|
||||
];
|
||||
|
||||
if let Some(mut previous) = previous_messages {
|
||||
previous.content = format!("{}\nYou: {}\nSiren: ", previous.content, parsed_content);
|
||||
previous.content = format!("{}You: {}\nSiren: ", previous.content, parsed_content);
|
||||
messages.push(previous);
|
||||
} else {
|
||||
messages.push(ChatCompletionMessage {
|
||||
@@ -253,16 +254,40 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P
|
||||
frequency_penalty: Some(0.0),
|
||||
user: Some(msg.author.name.clone())
|
||||
};
|
||||
|
||||
// Get the thread channel ID
|
||||
let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| {
|
||||
thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
|
||||
}).await {
|
||||
Ok(c) => {
|
||||
let allow = Permissions::SEND_MESSAGES;
|
||||
let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES;
|
||||
let overwrite = PermissionOverwrite {
|
||||
allow,
|
||||
deny,
|
||||
kind: PermissionOverwriteType::Member(msg.author.id),
|
||||
};
|
||||
let _ = c.create_permission(&ctx.http, &overwrite).await;
|
||||
c.id
|
||||
}
|
||||
Err(_) => {
|
||||
channel_id
|
||||
}
|
||||
};
|
||||
|
||||
let typing = response_channel.start_typing(&ctx.http).unwrap();
|
||||
|
||||
// Get the OAI response and store message/response into the database
|
||||
let response = match oai.get_request(request).await {
|
||||
Ok(r) => {
|
||||
debug!("Processing response received from OpenAI");
|
||||
if !r.choices.is_empty() {
|
||||
// Insert the message into the messages database table
|
||||
let res = r.choices[0].message.content.clone();
|
||||
// Insert the message into the messages database table
|
||||
if let Err(err) = insert_into(crate::database::schema::messages::table).values(NewMessageDB {
|
||||
id: &r.id,
|
||||
guild_id: guild_id.0 as i64,
|
||||
channel_id: channel_id.0 as i64,
|
||||
channel_id: response_channel.0 as i64,
|
||||
user_id: author_id.0 as i64,
|
||||
created: r.created,
|
||||
model: &model,
|
||||
@@ -286,9 +311,30 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P
|
||||
};
|
||||
debug!("Writing response: \"{}\"", response);
|
||||
|
||||
// Stop the typing indicator and send the response
|
||||
typing.stop();
|
||||
if let Err(why) = msg.channel_id.say(&ctx.http, response).await {
|
||||
if let Err(why) = response_channel.say(&ctx.http, response).await {
|
||||
error!("Cannot send message: {}", why);
|
||||
}
|
||||
|
||||
// match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| {
|
||||
// thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
|
||||
// }).await {
|
||||
// Ok(c) => {
|
||||
// if let Err(why) = c.say(&ctx.http, response).await {
|
||||
// error!("Cannot send message: {}", why);
|
||||
// }
|
||||
// }
|
||||
// Err(_) => {
|
||||
// if let Err(why) = channel_id.say(&ctx.http, response).await {
|
||||
// error!("Cannot send message: {}", why);
|
||||
// }
|
||||
// }
|
||||
// };
|
||||
}
|
||||
|
||||
fn truncate(s: &str, max_chars: usize) -> &str {
|
||||
match s.char_indices().nth(max_chars) {
|
||||
None => s,
|
||||
Some((idx, _)) => &s[..idx],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
// use std::path::Path;
|
||||
|
||||
use diesel::RunQueryDsl;
|
||||
use diesel::r2d2::{Pool, ConnectionManager};
|
||||
@@ -11,34 +11,51 @@ pub mod schema;
|
||||
|
||||
pub fn run_migrations(pool: &Pool<ConnectionManager<PgConnection>>) {
|
||||
let mut connection = pool.get().unwrap();
|
||||
let migrations_dir = Path::new("./migrations");
|
||||
let migrations = std::fs::read_dir(&migrations_dir).unwrap();
|
||||
|
||||
for migration in migrations {
|
||||
if migration.as_ref().unwrap().file_type().unwrap().is_dir() {
|
||||
let migration_paths = std::fs::read_dir(&migration.unwrap().path()).unwrap();
|
||||
|
||||
for migration_path in migration_paths {
|
||||
if migration_path.as_ref().unwrap().file_name().eq_ignore_ascii_case("up.sql") {
|
||||
let path = &migration_path.unwrap().path();
|
||||
let contents = std::fs::read_to_string(path).expect("Unable to read from file");
|
||||
if let Err(err) = diesel::sql_query(&contents).execute(&mut connection) {
|
||||
error!("Could not run migration: {}", err);
|
||||
} else {
|
||||
info!("Successfully ran migration: {}", path.display());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Err(err) = diesel::sql_query("CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY NOT NULL,
|
||||
guild_id BIGINT NOT NULL,
|
||||
channel_id BIGINT NOT NULL,
|
||||
user_id BIGINT NOT NULL,
|
||||
created BIGINT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
request TEXT NOT NULL,
|
||||
response TEXT NOT NULL,
|
||||
request_tags TEXT[] NOT NULL,
|
||||
response_tags TEXT[] NOT NULL
|
||||
)").execute(&mut connection) {
|
||||
error!("Could not create messages table: {}", err);
|
||||
} else {
|
||||
info!("Successfully created messages table");
|
||||
}
|
||||
// let migrations_dir = Path::new("./migrations");
|
||||
// let migrations = std::fs::read_dir(&migrations_dir).unwrap();
|
||||
|
||||
// for migration in migrations {
|
||||
// if migration.as_ref().unwrap().file_type().unwrap().is_dir() {
|
||||
// let migration_paths = std::fs::read_dir(&migration.unwrap().path()).unwrap();
|
||||
|
||||
// for migration_path in migration_paths {
|
||||
// if migration_path.as_ref().unwrap().file_name().eq_ignore_ascii_case("up.sql") {
|
||||
// let path = &migration_path.unwrap().path();
|
||||
// let contents = std::fs::read_to_string(path).expect("Unable to read from file");
|
||||
// if let Err(err) = diesel::sql_query(&contents).execute(&mut connection) {
|
||||
// error!("Could not run migration: {}", err);
|
||||
// } else {
|
||||
// info!("Successfully ran migration: {}", path.display());
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn establish_connection() -> Pool<ConnectionManager<PgConnection>> {
|
||||
let database_user = env::var("POSTGRES_USER").expect("Expected a user in the environment");
|
||||
let database_password = env::var("POSTGRES_PASSWORD").expect("Expected a password in the environment");
|
||||
let database_name = env::var("POSTGRES_DB").expect("Expected a database name in the environment");
|
||||
let database_host = env::var("POSTGRES_HOST").unwrap_or("localhost".to_string());
|
||||
|
||||
let database_url = format!("postgres://{}:{}@localhost/{}", database_user, database_password, database_name);
|
||||
let database_url = format!("postgres://{}:{}@{}/{}", database_user, database_password, database_host, database_name);
|
||||
let manager = ConnectionManager::<PgConnection>::new(database_url);
|
||||
Pool::builder().build(manager).expect("Failed to create pool.")
|
||||
}
|
||||
35
src/main.rs
35
src/main.rs
@@ -1,7 +1,8 @@
|
||||
use std::collections::HashSet;
|
||||
use std::collections::{HashSet, HashMap};
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
|
||||
use commands::audio::create_response;
|
||||
use commands::audio::{create_response, AudioConfig, AudioConfigs};
|
||||
use diesel::r2d2::{Pool, ConnectionManager};
|
||||
use diesel::pg::PgConnection;
|
||||
|
||||
@@ -35,7 +36,17 @@ impl EventHandler for Handler {
|
||||
Some(oai) => {
|
||||
match msg.mentions_me(&ctx.http).await {
|
||||
Ok(mentioned) => {
|
||||
if mentioned {
|
||||
let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await {
|
||||
Ok(t) => {
|
||||
match t.iter().find(|t| t.user_id.unwrap().0 == ctx.cache.current_user_id().0) {
|
||||
Some(_) => true,
|
||||
None => false
|
||||
}
|
||||
}
|
||||
Err(_) => false
|
||||
};
|
||||
// let has_bot = msg.channel_id.get_thread_members(&ctx.http).await.unwrap().contains(ctx.cache.current_user_id().0);
|
||||
if mentioned || bot_in_thread {
|
||||
commands::oai::generate_response(&ctx, &msg, oai, &self.pool).await;
|
||||
}
|
||||
}
|
||||
@@ -44,7 +55,6 @@ impl EventHandler for Handler {
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
|
||||
@@ -55,6 +65,7 @@ impl EventHandler for Handler {
|
||||
"pause" => commands::audio::pause::run(&ctx, &command).await,
|
||||
"resume" => commands::audio::resume::run(&ctx, &command).await,
|
||||
"skip" => commands::audio::skip::run(&ctx, &command).await,
|
||||
"volume" => commands::audio::volume::run(&ctx, &command).await,
|
||||
_ => {
|
||||
let content: String = match command.data.name.as_str() {
|
||||
"ping" => commands::ping::run(&command.data.options),
|
||||
@@ -74,6 +85,14 @@ impl EventHandler for Handler {
|
||||
warn!("No ready guilds found");
|
||||
}
|
||||
for guild in ready.guilds {
|
||||
let audio_config_lock = {
|
||||
let data_read = ctx.data.read().await;
|
||||
data_read.get::<AudioConfigs>().expect("Expected AudioConfigs in TypeMap.").clone()
|
||||
};
|
||||
{
|
||||
let mut audio_configs = audio_config_lock.write().await;
|
||||
let _ = audio_configs.insert(guild.id, AudioConfig { volume: 1.0 });
|
||||
}
|
||||
let commands = guild.id.set_application_commands(&ctx.http, |commands| {
|
||||
commands.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::ping::register(command) })
|
||||
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::play::register(command) })
|
||||
@@ -81,6 +100,7 @@ impl EventHandler for Handler {
|
||||
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::pause::register(command) })
|
||||
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::resume::register(command) })
|
||||
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::skip::register(command) })
|
||||
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::volume::register(command) })
|
||||
}).await;
|
||||
match commands {
|
||||
Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.0),
|
||||
@@ -122,7 +142,7 @@ async fn main() {
|
||||
Ok(token) => {
|
||||
info!("Loaded OpenAI token");
|
||||
Handler {
|
||||
oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token , max_context_questions: 10 }),
|
||||
oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token , max_context_questions: 15 }),
|
||||
pool
|
||||
}
|
||||
}
|
||||
@@ -140,6 +160,11 @@ async fn main() {
|
||||
.await
|
||||
.expect("Error creating client");
|
||||
|
||||
{
|
||||
let mut data = client.data.write().await;
|
||||
data.insert::<AudioConfigs>(Arc::new(RwLock::new(HashMap::default())));
|
||||
}
|
||||
|
||||
if let Err(why) = client.start_autosharded().await {
|
||||
error!("An error occurred while running the client: {:?}", why);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user