Merge pull request #9 from bensherriff/develop

v0.2.3
This commit is contained in:
Ben Sherriff
2023-07-26 18:37:06 -04:00
committed by GitHub
11 changed files with 279 additions and 59 deletions

View File

@@ -1 +1 @@
SIREN_VERSION=0.2.2 SIREN_VERSION=0.2.3

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "siren" name = "siren"
version = "0.2.2" version = "0.2.3"
edition = "2021" edition = "2021"
authors = ["Ben Sherriff <hello@bensherriff.com>"] authors = ["Ben Sherriff <hello@bensherriff.com>"]
repository = "https://github.com/bensherriff/siren" repository = "https://github.com/bensherriff/siren"

View File

@@ -1,21 +1,33 @@
FROM rust:1.67 as builder FROM rust:1.70.0 as builder
WORKDIR /siren WORKDIR /siren
RUN apt-get update && apt-get install -y cmake && apt-get auto-remove -y
ADD src ./src/ ADD src ./src/
ADD Cargo.toml ./ ADD Cargo.toml ./
RUN cargo build --release --bin siren RUN apt-get update && apt-get install -y cmake && \
cargo build --release --bin siren
FROM debian:bullseye-slim as packages FROM debian:bullseye-slim as packages
WORKDIR /packages WORKDIR /packages
RUN apt-get update && apt-get install -y libopus-dev libpq5 libpq-dev curl tar xz-utils RUN apt-get update && apt-get install -y curl tar xz-utils && \
RUN curl -L https://github.com/yt-dlp/yt-dlp/releases/latest/download/yt-dlp_linux > yt-dlp && \ curl -L https://github.com/yt-dlp/yt-dlp/releases/latest/download/yt-dlp_linux > yt-dlp && \
chmod +x yt-dlp chmod +x yt-dlp && \
RUN curl -L https://github.com/yt-dlp/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz > ffmpeg.tar.xz && \ curl -L https://github.com/yt-dlp/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz > ffmpeg.tar.xz && \
tar -xJf ffmpeg.tar.xz --wildcards */bin/ffmpeg --transform='s/^.*\///' && rm ffmpeg.tar.xz tar -xJf ffmpeg.tar.xz --wildcards */bin/ffmpeg --transform='s/^.*\///' && rm ffmpeg.tar.xz
# FROM debian:bullseye-slim as libraries
# WORKDIR /libraries
# RUN apt-get update && apt-get install -y unzip && \
# curl -L https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcu117.zip > libtorch.zip && \
# unzip libtorch.zip && rm libtorch.zip
FROM debian:bullseye-slim as runtime FROM debian:bullseye-slim as runtime
WORKDIR /siren WORKDIR /siren
RUN apt-get update && apt-get install -y libopus-dev libpq5 libpq-dev && apt-get auto-remove -y
COPY --from=builder /siren/target/release/siren /usr/local/bin/siren COPY --from=builder /siren/target/release/siren /usr/local/bin/siren
COPY --from=packages /packages /usr/bin COPY --from=packages /packages /usr/bin
# ADD migrations ./migrations/ # COPY --from=libraries /libraries /usr/lib
# ARG LIBTORCH=/usr/lib/libtorch
# ARG LD_LIBRARY_PATH=${LIBTORCH}/lib:${LD_LIBRARY_PATH}
# ADD migrations ./
CMD ["siren"] CMD ["siren"]

View File

@@ -1,5 +1,6 @@
#!make #!make
SHELL := /bin/bash SHELL := /bin/bash
include .env include .env
include .version include .version
export $(shell sed 's/=.*//' .env) export $(shell sed 's/=.*//' .env)
@@ -9,20 +10,25 @@ SIREN_IMAGES = $(shell docker images 'siren' -a -q)
.PHONY: help build test up down exec clean .PHONY: help build test up down exec clean
build: help: ## Help command
@echo
@cat Makefile | grep -E '^[a-zA-Z\/_-]+:.*?## .*$$' | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
@echo
build: ## Build the docker image
docker build -t siren:${SIREN_VERSION} . docker build -t siren:${SIREN_VERSION} .
test: test: ## Run the docker app as a container
docker run --env-file .env -it --rm --name siren siren:${SIREN_VERSION} docker run --env-file .env -it --rm --name siren siren:${SIREN_VERSION}
up: up: ## Start the app
docker compose up -d docker compose up -d
down: down: ## Stop the app
docker compose down docker compose down
exec: exec: ## Enter running docker container
docker exec -it siren bash docker exec -it siren bash
clean: clean: ## Cleanup docker images
docker rmi $(SIREN_IMAGES) docker rmi $(SIREN_IMAGES)

View File

@@ -2,25 +2,33 @@ version: '3'
services: services:
siren: siren:
image: siren:${SIREN_VERSION} image: siren:${SIREN_VERSION:-latest}
container_name: siren container_name: siren
build: build:
context: . context: .
dockerfile: ./Dockerfile dockerfile: ./Dockerfile
args: args:
- VERSION=${SIREN_VERSION} - VERSION=${SIREN_VERSION:-latest}
volumes: volumes:
- ./app:/siren - ./app:/siren
env_file:
- .env
environment: environment:
DISCORD_TOKEN: ${DISCORD_TOKEN} DISCORD_TOKEN: ${DISCORD_TOKEN}
RUST_LOG: ${RUST_LOG} RUST_LOG: ${RUST_LOG}
DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db/${POSTGRES_DB} OPENAI_API_KEY: ${OPENAI_API_KEY}
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_DB: ${POSTGRES_DB}
POSTGRES_HOST: db
depends_on: depends_on:
- db - db
restart: unless-stopped restart: unless-stopped
db: db:
image: postgres:latest image: postgres:latest
container_name: siren_db container_name: siren_db
env_file:
- .env
environment: environment:
POSTGRES_USER: ${POSTGRES_USER} POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}

View File

@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use log::debug; use log::debug;
@@ -16,6 +17,18 @@ pub mod skip;
pub mod stop; pub mod stop;
pub mod volume; pub mod volume;
#[derive(Clone, Debug)]
pub struct AudioConfigs;
impl TypeMapKey for AudioConfigs {
type Value = Arc<RwLock<HashMap<GuildId, AudioConfig>>>;
}
#[derive(Clone, Debug)]
pub struct AudioConfig {
pub volume: f32
}
/// Joins a Discord voice channel. /// Joins a Discord voice channel.
/// ///
/// # Arguments /// # Arguments
@@ -135,7 +148,7 @@ pub async fn edit_response(ctx: &Context, command: &ApplicationCommandInteractio
/// ///
/// # Returns /// # Returns
/// Result<Metadata, SongbirdError> - Ok if the song was added successfully, Err if there was an error. /// Result<Metadata, SongbirdError> - Ok if the song was added successfully, Err if there was an error.
pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool) -> Result<Metadata, SongbirdError> { pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool, audio_config: Option<&AudioConfig>) -> Result<Metadata, SongbirdError> {
let source = if is_valid_url(url) { let source = if is_valid_url(url) {
Restartable::ytdl(url.to_owned(), lazy).await? Restartable::ytdl(url.to_owned(), lazy).await?
} else { } else {
@@ -144,7 +157,10 @@ pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool) -> Result<M
let mut handler = call.lock().await; let mut handler = call.lock().await;
let track: Input = source.into(); let track: Input = source.into();
let metadata = *track.metadata.clone(); let metadata = *track.metadata.clone();
handler.enqueue_source(track); let track_handle = handler.enqueue_source(track);
if let Some(ac) = audio_config {
let _ = track_handle.set_volume(ac.volume);
}
Ok(metadata) Ok(metadata)
} }

View File

@@ -5,7 +5,7 @@ use serenity::builder::CreateApplicationCommand;
use serenity::model::application::interaction::application_command::ApplicationCommandInteraction; use serenity::model::application::interaction::application_command::ApplicationCommandInteraction;
use songbird::EventHandler; use songbird::EventHandler;
use crate::commands::audio::{join, leave, add_song, get_songbird}; use crate::commands::audio::{join, leave, add_song, get_songbird, AudioConfigs};
use super::{create_response, edit_response}; use super::{create_response, edit_response};
@@ -65,11 +65,16 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
let call_handler = handler_lock.lock().await; let call_handler = handler_lock.lock().await;
call_handler.queue().is_empty() call_handler.queue().is_empty()
}; };
match add_song(handler_lock.clone(), &track_url, is_queue_empty).await { let audio_config = {
let data_read = ctx.data.read().await;
data_read.get::<AudioConfigs>().expect("Expected AudioConfigs in TypeMap.").clone()
};
let ac = audio_config.read().await;
match add_song(handler_lock.clone(), &track_url, is_queue_empty, ac.get(&guild_id)).await {
Ok(added_song) => { Ok(added_song) => {
let track_title = added_song.title.unwrap(); let track_title = added_song.title.unwrap();
debug!("Added song: {}", track_title); debug!("Added track: {}", track_title);
if let Err(why) = edit_response(&ctx, &command, format!("Added song to queue: {}", track_title)).await { if let Err(why) = edit_response(&ctx, &command, format!("Added track to queue: {}", track_title)).await {
error!("Failed to edit response message: {}", why); error!("Failed to edit response message: {}", why);
} }
let mut handler = handler_lock.lock().await; let mut handler = handler_lock.lock().await;

View File

@@ -0,0 +1,85 @@
use log::{error, warn};
use serenity::prelude::*;
use serenity::builder::CreateApplicationCommand;
use serenity::model::application::interaction::application_command::ApplicationCommandInteraction;
use super::{get_songbird, create_response, edit_response, AudioConfigs, AudioConfig};
pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
// Get the volume
let volume = match command.data.options.get(0) {
Some(t) => match &t.value {
Some(v) => match v.as_i64() {
Some(p) => std::cmp::min(100, std::cmp::max(0, p)),
None => {
warn!("Unable to get volume option as a string");
if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await {
error!("Failed to create response message: {}", why);
}
return;
}
}
None => {
warn!("Missing volume option value");
if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await {
error!("Failed to create response message: {}", why);
}
return;
}
}
None => {
warn!("Missing volume option");
if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await {
error!("Failed to create response message: {}", why);
}
return;
}
};
// Format volume to f32 bound between 0.0 and 1.0
let bound_volume = volume as f32 / 100.0;
// Create the initial response
if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await {
error!("Failed to create response message: {}", why);
return;
}
let guild_id = match command.guild_id {
Some(g) => g,
None => {
if let Err(why) = edit_response(&ctx, &command, "Unable to join voice channel".to_string()).await {
error!("Failed to edit response message: {}", why);
}
return;
}
};
let audio_config_lock = {
let data_read = ctx.data.read().await;
data_read.get::<AudioConfigs>().expect("Expected AudioConfigs in TypeMap.").clone()
};
{
let mut audio_configs = audio_config_lock.write().await;
*audio_configs.entry(guild_id).or_insert(AudioConfig { volume: 1.0 }) = AudioConfig { volume: bound_volume };
}
let manager = get_songbird(ctx).await;
if let Some(handler_lock) = manager.get(guild_id) {
let handler = handler_lock.lock().await;
for (_, track_handle) in handler.queue().current_queue().iter().enumerate() {
let _ = track_handle.set_volume(bound_volume);
}
}
if let Err(why) = edit_response(&ctx, &command, format!("Setting the volume to {}", volume)).await {
error!("Failed to set the volume: {}", why);
}
}
pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand {
command.name("volume").description("Set the audio player volume").create_option(|option| { option
.name("volume")
.description("Volume between 0 and 100")
.kind(serenity::model::prelude::command::CommandOptionType::Integer)
.required(true)
})
}

View File

@@ -8,7 +8,9 @@ use log::{error, debug, trace, warn};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use serde_json::Value; use serde_json::Value;
use serenity::model::Permissions;
use serenity::model::channel::Message; use serenity::model::channel::Message;
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
use serenity::prelude::*; use serenity::prelude::*;
use crate::database::models::{NewMessageDB, MessageDB}; use crate::database::models::{NewMessageDB, MessageDB};
@@ -187,7 +189,6 @@ impl OAI {
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &Pool<ConnectionManager<PgConnection>>) { pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &Pool<ConnectionManager<PgConnection>>) {
debug!("Generating response for message: {}", msg.content); debug!("Generating response for message: {}", msg.content);
let typing = msg.channel_id.start_typing(&ctx.http).unwrap();
let guild_id = msg.guild_id.unwrap(); let guild_id = msg.guild_id.unwrap();
let channel_id = msg.channel_id; let channel_id = msg.channel_id;
@@ -205,7 +206,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P
.and(crate::database::schema::messages::channel_id.eq(channel_id.0 as i64)) .and(crate::database::schema::messages::channel_id.eq(channel_id.0 as i64))
.and(crate::database::schema::messages::user_id.eq(author_id.0 as i64)) .and(crate::database::schema::messages::user_id.eq(author_id.0 as i64))
) )
.order(crate::database::schema::messages::created.desc()) .order(crate::database::schema::messages::created.asc())
.limit(oai.max_context_questions) .limit(oai.max_context_questions)
.load(&mut connection); .load(&mut connection);
@@ -213,7 +214,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P
Ok(r) => { Ok(r) => {
let mut previous_message = "".to_string(); let mut previous_message = "".to_string();
for message in r { for message in r {
previous_message = format!("{}\nYou: {}\n Siren: {}", previous_message, message.request, message.response); previous_message = format!("{}You: {}\n Siren: {}\n", previous_message, message.request, message.response);
} }
Some(ChatCompletionMessage { role: GPTRole::User, content: previous_message }) Some(ChatCompletionMessage { role: GPTRole::User, content: previous_message })
} }
@@ -231,7 +232,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P
]; ];
if let Some(mut previous) = previous_messages { if let Some(mut previous) = previous_messages {
previous.content = format!("{}\nYou: {}\nSiren: ", previous.content, parsed_content); previous.content = format!("{}You: {}\nSiren: ", previous.content, parsed_content);
messages.push(previous); messages.push(previous);
} else { } else {
messages.push(ChatCompletionMessage { messages.push(ChatCompletionMessage {
@@ -253,16 +254,40 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P
frequency_penalty: Some(0.0), frequency_penalty: Some(0.0),
user: Some(msg.author.name.clone()) user: Some(msg.author.name.clone())
}; };
// Get the thread channel ID
let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| {
thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
}).await {
Ok(c) => {
let allow = Permissions::SEND_MESSAGES;
let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES;
let overwrite = PermissionOverwrite {
allow,
deny,
kind: PermissionOverwriteType::Member(msg.author.id),
};
let _ = c.create_permission(&ctx.http, &overwrite).await;
c.id
}
Err(_) => {
channel_id
}
};
let typing = response_channel.start_typing(&ctx.http).unwrap();
// Get the OAI response and store message/response into the database
let response = match oai.get_request(request).await { let response = match oai.get_request(request).await {
Ok(r) => { Ok(r) => {
debug!("Processing response received from OpenAI"); debug!("Processing response received from OpenAI");
if !r.choices.is_empty() { if !r.choices.is_empty() {
// Insert the message into the messages database table
let res = r.choices[0].message.content.clone(); let res = r.choices[0].message.content.clone();
// Insert the message into the messages database table
if let Err(err) = insert_into(crate::database::schema::messages::table).values(NewMessageDB { if let Err(err) = insert_into(crate::database::schema::messages::table).values(NewMessageDB {
id: &r.id, id: &r.id,
guild_id: guild_id.0 as i64, guild_id: guild_id.0 as i64,
channel_id: channel_id.0 as i64, channel_id: response_channel.0 as i64,
user_id: author_id.0 as i64, user_id: author_id.0 as i64,
created: r.created, created: r.created,
model: &model, model: &model,
@@ -286,9 +311,30 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P
}; };
debug!("Writing response: \"{}\"", response); debug!("Writing response: \"{}\"", response);
// Stop the typing indicator and send the response
typing.stop(); typing.stop();
if let Err(why) = msg.channel_id.say(&ctx.http, response).await { if let Err(why) = response_channel.say(&ctx.http, response).await {
error!("Cannot send message: {}", why); error!("Cannot send message: {}", why);
} }
// match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| {
// thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
// }).await {
// Ok(c) => {
// if let Err(why) = c.say(&ctx.http, response).await {
// error!("Cannot send message: {}", why);
// }
// }
// Err(_) => {
// if let Err(why) = channel_id.say(&ctx.http, response).await {
// error!("Cannot send message: {}", why);
// }
// }
// };
}
fn truncate(s: &str, max_chars: usize) -> &str {
match s.char_indices().nth(max_chars) {
None => s,
Some((idx, _)) => &s[..idx],
}
} }

View File

@@ -1,5 +1,5 @@
use std::env; use std::env;
use std::path::Path; // use std::path::Path;
use diesel::RunQueryDsl; use diesel::RunQueryDsl;
use diesel::r2d2::{Pool, ConnectionManager}; use diesel::r2d2::{Pool, ConnectionManager};
@@ -11,34 +11,51 @@ pub mod schema;
pub fn run_migrations(pool: &Pool<ConnectionManager<PgConnection>>) { pub fn run_migrations(pool: &Pool<ConnectionManager<PgConnection>>) {
let mut connection = pool.get().unwrap(); let mut connection = pool.get().unwrap();
let migrations_dir = Path::new("./migrations"); if let Err(err) = diesel::sql_query("CREATE TABLE IF NOT EXISTS messages (
let migrations = std::fs::read_dir(&migrations_dir).unwrap(); id TEXT PRIMARY KEY NOT NULL,
guild_id BIGINT NOT NULL,
for migration in migrations { channel_id BIGINT NOT NULL,
if migration.as_ref().unwrap().file_type().unwrap().is_dir() { user_id BIGINT NOT NULL,
let migration_paths = std::fs::read_dir(&migration.unwrap().path()).unwrap(); created BIGINT NOT NULL,
model TEXT NOT NULL,
for migration_path in migration_paths { request TEXT NOT NULL,
if migration_path.as_ref().unwrap().file_name().eq_ignore_ascii_case("up.sql") { response TEXT NOT NULL,
let path = &migration_path.unwrap().path(); request_tags TEXT[] NOT NULL,
let contents = std::fs::read_to_string(path).expect("Unable to read from file"); response_tags TEXT[] NOT NULL
if let Err(err) = diesel::sql_query(&contents).execute(&mut connection) { )").execute(&mut connection) {
error!("Could not run migration: {}", err); error!("Could not create messages table: {}", err);
} else { } else {
info!("Successfully ran migration: {}", path.display()); info!("Successfully created messages table");
}
}
}
}
} }
// let migrations_dir = Path::new("./migrations");
// let migrations = std::fs::read_dir(&migrations_dir).unwrap();
// for migration in migrations {
// if migration.as_ref().unwrap().file_type().unwrap().is_dir() {
// let migration_paths = std::fs::read_dir(&migration.unwrap().path()).unwrap();
// for migration_path in migration_paths {
// if migration_path.as_ref().unwrap().file_name().eq_ignore_ascii_case("up.sql") {
// let path = &migration_path.unwrap().path();
// let contents = std::fs::read_to_string(path).expect("Unable to read from file");
// if let Err(err) = diesel::sql_query(&contents).execute(&mut connection) {
// error!("Could not run migration: {}", err);
// } else {
// info!("Successfully ran migration: {}", path.display());
// }
// }
// }
// }
// }
} }
pub fn establish_connection() -> Pool<ConnectionManager<PgConnection>> { pub fn establish_connection() -> Pool<ConnectionManager<PgConnection>> {
let database_user = env::var("POSTGRES_USER").expect("Expected a user in the environment"); let database_user = env::var("POSTGRES_USER").expect("Expected a user in the environment");
let database_password = env::var("POSTGRES_PASSWORD").expect("Expected a password in the environment"); let database_password = env::var("POSTGRES_PASSWORD").expect("Expected a password in the environment");
let database_name = env::var("POSTGRES_DB").expect("Expected a database name in the environment"); let database_name = env::var("POSTGRES_DB").expect("Expected a database name in the environment");
let database_host = env::var("POSTGRES_HOST").unwrap_or("localhost".to_string());
let database_url = format!("postgres://{}:{}@localhost/{}", database_user, database_password, database_name); let database_url = format!("postgres://{}:{}@{}/{}", database_user, database_password, database_host, database_name);
let manager = ConnectionManager::<PgConnection>::new(database_url); let manager = ConnectionManager::<PgConnection>::new(database_url);
Pool::builder().build(manager).expect("Failed to create pool.") Pool::builder().build(manager).expect("Failed to create pool.")
} }

View File

@@ -1,7 +1,8 @@
use std::collections::HashSet; use std::collections::{HashSet, HashMap};
use std::env; use std::env;
use std::sync::Arc;
use commands::audio::create_response; use commands::audio::{create_response, AudioConfig, AudioConfigs};
use diesel::r2d2::{Pool, ConnectionManager}; use diesel::r2d2::{Pool, ConnectionManager};
use diesel::pg::PgConnection; use diesel::pg::PgConnection;
@@ -35,7 +36,17 @@ impl EventHandler for Handler {
Some(oai) => { Some(oai) => {
match msg.mentions_me(&ctx.http).await { match msg.mentions_me(&ctx.http).await {
Ok(mentioned) => { Ok(mentioned) => {
if mentioned { let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await {
Ok(t) => {
match t.iter().find(|t| t.user_id.unwrap().0 == ctx.cache.current_user_id().0) {
Some(_) => true,
None => false
}
}
Err(_) => false
};
// let has_bot = msg.channel_id.get_thread_members(&ctx.http).await.unwrap().contains(ctx.cache.current_user_id().0);
if mentioned || bot_in_thread {
commands::oai::generate_response(&ctx, &msg, oai, &self.pool).await; commands::oai::generate_response(&ctx, &msg, oai, &self.pool).await;
} }
} }
@@ -44,7 +55,6 @@ impl EventHandler for Handler {
} }
None => {} None => {}
} }
} }
async fn interaction_create(&self, ctx: Context, interaction: Interaction) { async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
@@ -55,6 +65,7 @@ impl EventHandler for Handler {
"pause" => commands::audio::pause::run(&ctx, &command).await, "pause" => commands::audio::pause::run(&ctx, &command).await,
"resume" => commands::audio::resume::run(&ctx, &command).await, "resume" => commands::audio::resume::run(&ctx, &command).await,
"skip" => commands::audio::skip::run(&ctx, &command).await, "skip" => commands::audio::skip::run(&ctx, &command).await,
"volume" => commands::audio::volume::run(&ctx, &command).await,
_ => { _ => {
let content: String = match command.data.name.as_str() { let content: String = match command.data.name.as_str() {
"ping" => commands::ping::run(&command.data.options), "ping" => commands::ping::run(&command.data.options),
@@ -74,6 +85,14 @@ impl EventHandler for Handler {
warn!("No ready guilds found"); warn!("No ready guilds found");
} }
for guild in ready.guilds { for guild in ready.guilds {
let audio_config_lock = {
let data_read = ctx.data.read().await;
data_read.get::<AudioConfigs>().expect("Expected AudioConfigs in TypeMap.").clone()
};
{
let mut audio_configs = audio_config_lock.write().await;
let _ = audio_configs.insert(guild.id, AudioConfig { volume: 1.0 });
}
let commands = guild.id.set_application_commands(&ctx.http, |commands| { let commands = guild.id.set_application_commands(&ctx.http, |commands| {
commands.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::ping::register(command) }) commands.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::ping::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::play::register(command) }) .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::play::register(command) })
@@ -81,6 +100,7 @@ impl EventHandler for Handler {
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::pause::register(command) }) .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::pause::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::resume::register(command) }) .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::resume::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::skip::register(command) }) .create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::skip::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::volume::register(command) })
}).await; }).await;
match commands { match commands {
Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.0), Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.0),
@@ -122,7 +142,7 @@ async fn main() {
Ok(token) => { Ok(token) => {
info!("Loaded OpenAI token"); info!("Loaded OpenAI token");
Handler { Handler {
oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token , max_context_questions: 10 }), oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token , max_context_questions: 15 }),
pool pool
} }
} }
@@ -140,6 +160,11 @@ async fn main() {
.await .await
.expect("Error creating client"); .expect("Error creating client");
{
let mut data = client.data.write().await;
data.insert::<AudioConfigs>(Arc::new(RwLock::new(HashMap::default())));
}
if let Err(why) = client.start_autosharded().await { if let Err(why) = client.start_autosharded().await {
error!("An error occurred while running the client: {:?}", why); error!("An error occurred while running the client: {:?}", why);
} }