v0.1.20 Error handling and added local track support

This commit is contained in:
2023-04-17 08:53:04 -04:00
parent c226095b1a
commit 312f87d91a
16 changed files with 367 additions and 156 deletions

1
.gitignore vendored
View File

@@ -4,4 +4,5 @@
**/app/ **/app/
**/settings.json **/settings.json
**/logs/ **/logs/
**/audio/
.env .env

View File

@@ -1 +1 @@
export SIREN_VERSION=0.1.19 export SIREN_VERSION=0.1.20

View File

@@ -6,9 +6,7 @@ import com.bensherriff.siren.commands.*;
import com.bensherriff.siren.database.DatabaseManager; import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException; import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
import com.bensherriff.siren.ai.OpenAIManager; import com.bensherriff.siren.ai.OpenAIManager;
import com.bensherriff.siren.settings.GuildSettings; import com.bensherriff.siren.settings.*;
import com.bensherriff.siren.settings.Settings;
import com.bensherriff.siren.settings.SettingsManager;
import net.dv8tion.jda.api.JDA; import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.entities.Guild; import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.entities.Member; import net.dv8tion.jda.api.entities.Member;
@@ -22,6 +20,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@@ -79,6 +78,7 @@ public class Listener extends ListenerAdapter {
this.openAIManager = new OpenAIManager(this); this.openAIManager = new OpenAIManager(this);
DatabaseManager.createTables(); DatabaseManager.createTables();
populateAudioTable();
} }
public void closeAudioConnection(long guildID) { public void closeAudioConnection(long guildID) {
@@ -107,7 +107,7 @@ public class Listener extends ListenerAdapter {
long guildId = Long.parseLong(guild.getId()); long guildId = Long.parseLong(guild.getId());
AudioHandler audioHandler; AudioHandler audioHandler;
if (guild.getAudioManager().getSendingHandler() == null) { if (guild.getAudioManager().getSendingHandler() == null || !settings.getGuildSettings().containsKey(guildId)) {
LOGGER.info("Creating Audio Handler for guild {}", guildId); LOGGER.info("Creating Audio Handler for guild {}", guildId);
if (!settings.getGuildSettings().containsKey(guildId)) { if (!settings.getGuildSettings().containsKey(guildId)) {
settings.getGuildSettings().put(guildId, new GuildSettings()); settings.getGuildSettings().put(guildId, new GuildSettings());
@@ -122,6 +122,32 @@ public class Listener extends ListenerAdapter {
return audioHandler; return audioHandler;
} }
private void populateAudioTable() {
DatabaseManager.clearTable("audio");
File directory = new File(SettingsManager.AUDIO_DIRECTORY);
try {
if (!directory.exists() && !directory.mkdirs()) {
LOGGER.error("Failed to create directory at {}", directory.getPath());
return;
}
Tracks tracks = SettingsManager.load(SettingsManager.TRACKS_PATH, Tracks.class);
int rows = 0;
File[] files = directory.listFiles();
if (files != null) {
for (Track track : tracks.getTracks()) {
for (File file : files) {
if (file.exists() && file.getName().equals(track.getFileName())) {
rows += DatabaseManager.storeAudio(track.getFileName(), track.getTags());
}
}
}
}
LOGGER.debug("Updated with {} local tracks", rows);
} catch (IOException ex) {
LOGGER.error("Failed to load local tracks; {}", ex.getMessage());
}
}
@Override @Override
public void onReady(@NotNull ReadyEvent event) { public void onReady(@NotNull ReadyEvent event) {
super.onReady(event); super.onReady(event);
@@ -131,7 +157,7 @@ public class Listener extends ListenerAdapter {
commands.put("volume", new VolumeCommand(this)); commands.put("volume", new VolumeCommand(this));
commands.put("pause", new PauseCommand(this)); commands.put("pause", new PauseCommand(this));
commands.put("resume", new ResumeCommand(this)); commands.put("resume", new ResumeCommand(this));
commands.put("image", new ImageCommand(this)); // commands.put("image", new ImageCommand(this));
jda.getGuilds().forEach(guild -> executor.execute(() -> { jda.getGuilds().forEach(guild -> executor.execute(() -> {
LOGGER.debug("Updating commands for \"{}\" <{}>", guild.getName(), guild.getId()); LOGGER.debug("Updating commands for \"{}\" <{}>", guild.getName(), guild.getId());
guild.updateCommands().addCommands( guild.updateCommands().addCommands(
@@ -151,7 +177,7 @@ public class Listener extends ListenerAdapter {
try { try {
commands.get(command).execute(event); commands.get(command).execute(event);
} catch (Exception ex) { } catch (Exception ex) {
LOGGER.error(ex.getMessage()); LOGGER.error("Failed to execute command; {}", ex.getMessage());
event.getHook().sendMessage("An error occurred while processing your command. Please contact " + owner + ".").queue(); event.getHook().sendMessage("An error occurred while processing your command. Please contact " + owner + ".").queue();
} }
}); });

View File

@@ -30,7 +30,7 @@ public class Main {
try { try {
start(); start();
} catch (Exception ex) { } catch (Exception ex) {
LOGGER.error(ex.getMessage()); LOGGER.error("Caught unhandled exception; {}", ex.getMessage());
} }
} }

View File

@@ -23,8 +23,10 @@ import net.dv8tion.jda.api.entities.channel.concrete.ThreadChannel;
import net.dv8tion.jda.api.events.message.MessageReceivedEvent; import net.dv8tion.jda.api.events.message.MessageReceivedEvent;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.postgresql.jdbc.PgArray;
import java.sql.SQLException; import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Duration; import java.time.Duration;
import java.util.*; import java.util.*;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
@@ -141,15 +143,37 @@ public class OpenAIManager {
// Handle System Messages // Handle System Messages
chatMessages.add(createSystemMessage("You are a discord bot named Siren")); chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
chatMessages.add(createSystemMessage("I am a user named " + event.getAuthor().getName())); chatMessages.add(createSystemMessage("My name is " + event.getAuthor().getName()));
if (event.isFromThread()) { if (event.isFromThread()) {
String query = new QueryBuilder().from("messages") String query = new QueryBuilder("messages")
.where("guild_id = ? AND thread_id = ?") .where("guild_id = ? AND thread_id = ?")
.orderBy("timestamp DESC") .orderBy("timestamp DESC")
.limit(10) .limit(10)
.build(); .build();
List<MessageData> previousMessages = DatabaseManager.getMessages(
// Build MessageData objects from query results
List<Map<String, Object>> results = DatabaseManager.query(
query, event.getGuild().getIdLong(), event.getChannel().getIdLong()); query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
List<MessageData> previousMessages = new ArrayList<>();
for (Map<String, Object> result : results) {
Object[] resultTopicObjects = (Object[]) ((PgArray) result.get("topics")).getArray();
Set<String> resultTopics = new HashSet<>();
for (Object object : resultTopicObjects) {
resultTopics.add((String) object);
}
MessageData messageData = new MessageData(
(long) result.get("guild_id"),
(long) result.get("thread_id"),
(long) result.get("user_id"),
(String) result.get("message_type"),
(String) result.get("message_text"),
(String) result.get("message_response"),
resultTopics,
(Timestamp) result.get("timestamp")
);
previousMessages.add(messageData);
}
Set<String> potentialTopics = new HashSet<>(); Set<String> potentialTopics = new HashSet<>();
for (MessageData previousMessage : previousMessages) { for (MessageData previousMessage : previousMessages) {
ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " + ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " +
@@ -157,10 +181,10 @@ public class OpenAIManager {
"\". You replied with \"" + previousMessage.getMessageResponse() + "\"."); "\". You replied with \"" + previousMessage.getMessageResponse() + "\".");
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageText())); potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageText()));
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageResponse())); potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageResponse()));
chatMessages.add(previousChatMessage); // chatMessages.add(previousChatMessage);
LOGGER.trace("Potential topics: {}", potentialTopics);
// chatMessages.add(createSystemMessage("As an AI language model, only give replies that relate to " + topics));
} }
LOGGER.trace("Potential topics: {}", potentialTopics);
// chatMessages.add(createSystemMessage("As an AI language model, only give replies that relate to " + topics));
} }
return ChatCompletionRequest.builder() return ChatCompletionRequest.builder()

View File

@@ -68,29 +68,29 @@ public class AudioHandler extends AudioEventAdapter implements AudioSendHandler
} }
public void setVolume(int volume) { public void setVolume(int volume) {
LOGGER.debug("Set volume to {}", volume); LOGGER.trace("Set volume to {}", volume);
player.setVolume(volume); player.setVolume(volume);
} }
@Override @Override
public void onPlayerPause(AudioPlayer player) { public void onPlayerPause(AudioPlayer player) {
LOGGER.debug("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title); LOGGER.trace("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title);
} }
@Override @Override
public void onPlayerResume(AudioPlayer player) { public void onPlayerResume(AudioPlayer player) {
LOGGER.debug("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title); LOGGER.trace("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title);
} }
@Override @Override
public void onTrackStart(AudioPlayer player, AudioTrack track) { public void onTrackStart(AudioPlayer player, AudioTrack track) {
LOGGER.debug("Starting track {}", track.getInfo().title); LOGGER.trace("Starting track {}", track.getInfo().title);
manager.getListener().getJDA().getPresence().setActivity(Activity.playing(track.getInfo().title)); manager.getListener().getJDA().getPresence().setActivity(Activity.playing(track.getInfo().title));
} }
@Override @Override
public void onTrackEnd(AudioPlayer player, AudioTrack track, AudioTrackEndReason endReason) { public void onTrackEnd(AudioPlayer player, AudioTrack track, AudioTrackEndReason endReason) {
LOGGER.debug("Track ended due to {}; {} ", endReason.name(), endReason.mayStartNext); LOGGER.trace("Track ended due to {}; {} ", endReason.name(), endReason.mayStartNext);
if (queue.isEmpty()) { if (queue.isEmpty()) {
manager.getListener().closeAudioConnection(guildID); manager.getListener().closeAudioConnection(guildID);
manager.getListener().getJDA().getPresence().setActivity(Activity.playing("nothing")); manager.getListener().getJDA().getPresence().setActivity(Activity.playing("nothing"));

View File

@@ -1,8 +1,17 @@
package com.bensherriff.siren.audio; package com.bensherriff.siren.audio;
import com.bensherriff.siren.Listener; import com.bensherriff.siren.Listener;
import com.sedmelluq.discord.lavaplayer.player.AudioConfiguration;
import com.sedmelluq.discord.lavaplayer.player.DefaultAudioPlayerManager; import com.sedmelluq.discord.lavaplayer.player.DefaultAudioPlayerManager;
import com.sedmelluq.discord.lavaplayer.source.AudioSourceManagers; import com.sedmelluq.discord.lavaplayer.source.AudioSourceManagers;
import com.sedmelluq.discord.lavaplayer.source.bandcamp.BandcampAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.beam.BeamAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.http.HttpAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.soundcloud.SoundCloudAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.twitch.TwitchStreamAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.vimeo.VimeoAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.youtube.YoutubeAudioSourceManager;
public class PlayerManager extends DefaultAudioPlayerManager { public class PlayerManager extends DefaultAudioPlayerManager {
@@ -17,6 +26,15 @@ public class PlayerManager extends DefaultAudioPlayerManager {
} }
public void initialize() { public void initialize() {
getConfiguration().setResamplingQuality(AudioConfiguration.ResamplingQuality.MEDIUM);
registerSourceManager(new YoutubeAudioSourceManager());
registerSourceManager(SoundCloudAudioSourceManager.createDefault());
registerSourceManager(new BandcampAudioSourceManager());
registerSourceManager(new VimeoAudioSourceManager());
registerSourceManager(new TwitchStreamAudioSourceManager());
registerSourceManager(new BeamAudioSourceManager());
registerSourceManager(new HttpAudioSourceManager());
registerSourceManager(new LocalAudioSourceManager());
AudioSourceManagers.registerRemoteSources(this); AudioSourceManagers.registerRemoteSources(this);
AudioSourceManagers.registerLocalSource(this); AudioSourceManagers.registerLocalSource(this);
} }

View File

@@ -7,6 +7,7 @@ import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEve
import net.dv8tion.jda.api.interactions.commands.OptionMapping; import net.dv8tion.jda.api.interactions.commands.OptionMapping;
import net.dv8tion.jda.api.interactions.commands.OptionType; import net.dv8tion.jda.api.interactions.commands.OptionType;
import net.dv8tion.jda.api.interactions.commands.build.Commands; import net.dv8tion.jda.api.interactions.commands.build.Commands;
import net.dv8tion.jda.api.interactions.commands.build.OptionData;
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
@@ -18,7 +19,11 @@ public class ImageCommand extends Command {
slashCommandData = Commands.slash("image", "Generate an image using DALL-E") slashCommandData = Commands.slash("image", "Generate an image using DALL-E")
.addOption(OptionType.STRING, "prompt", "The prompt for image generation", true) .addOption(OptionType.STRING, "prompt", "The prompt for image generation", true)
.addOption(OptionType.INTEGER, "count", "The number of images to be generated", false) .addOption(OptionType.INTEGER, "count", "The number of images to be generated", false)
.addOption(OptionType.INTEGER, "size", "Size of the picture, either 1 (small), 2 (medium), or 3 (large)", false); .addOptions(new OptionData(OptionType.STRING, "type", "The size of the picture", false)
.addChoice("Small", ImageSize.SMALL.getSize())
.addChoice("Medium", ImageSize.MEDIUM.getSize())
.addChoice("Large", ImageSize.LARGE.getSize())
);
} }
//TODO Store image in database //TODO Store image in database

View File

@@ -1,32 +1,94 @@
package com.bensherriff.siren.commands; package com.bensherriff.siren.commands;
import com.bensherriff.siren.audio.AudioHandler; import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.database.QueryBuilder;
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException; import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
import com.bensherriff.siren.Listener; import com.bensherriff.siren.Listener;
import com.bensherriff.siren.settings.SettingsManager;
import com.sedmelluq.discord.lavaplayer.container.MediaContainerDescriptor;
import com.sedmelluq.discord.lavaplayer.player.AudioLoadResultHandler; import com.sedmelluq.discord.lavaplayer.player.AudioLoadResultHandler;
import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioTrack;
import com.sedmelluq.discord.lavaplayer.tools.FriendlyException; import com.sedmelluq.discord.lavaplayer.tools.FriendlyException;
import com.sedmelluq.discord.lavaplayer.track.AudioPlaylist; import com.sedmelluq.discord.lavaplayer.track.AudioPlaylist;
import com.sedmelluq.discord.lavaplayer.track.AudioTrack; import com.sedmelluq.discord.lavaplayer.track.AudioTrack;
import com.sedmelluq.discord.lavaplayer.track.AudioTrackInfo;
import com.sedmelluq.discord.lavaplayer.track.BasicAudioPlaylist;
import net.dv8tion.jda.api.entities.Guild; import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent; import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.interactions.commands.OptionType; import net.dv8tion.jda.api.interactions.commands.OptionType;
import net.dv8tion.jda.api.interactions.commands.build.Commands; import net.dv8tion.jda.api.interactions.commands.build.Commands;
import net.dv8tion.jda.api.interactions.commands.build.SubcommandData;
import java.io.IOException; import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
public class PlayCommand extends Command { public class PlayCommand extends Command {
public PlayCommand(Listener listener) { public PlayCommand(Listener listener) {
super(listener); super(listener);
slashCommandData = Commands.slash("play", "Play a track from a URL") slashCommandData = Commands.slash("play", "Play a track")
.addOption(OptionType.STRING, "url", "Track URL", true); .addSubcommands(new SubcommandData("url", "Play a track from a URL")
.addOption(OptionType.STRING, "url", "Track URL", true))
.addSubcommands(new SubcommandData("local", "Play a track from a local file")
.addOption(OptionType.STRING, "file", "Local file name", true))
.addSubcommands(new SubcommandData("tags", "Play a track based on tags")
.addOption(OptionType.STRING, "tags", "The list of tags separated by a semicolon", true)
.addOption(OptionType.BOOLEAN, "inclusive", "If true, tracks can match any tag"));
} }
@Override @Override
public void execute(SlashCommandInteractionEvent event) throws IOException { public void execute(SlashCommandInteractionEvent event) throws IOException {
String trackURL = Objects.requireNonNull(event.getOption("url")).getAsString(); String audioDirectoryPath = SettingsManager.AUDIO_DIRECTORY + SettingsManager.SEPARATOR;
listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event)); if ("url".equals(event.getSubcommandName())) {
String trackURL = Objects.requireNonNull(event.getOption("url")).getAsString();
listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event));
} else if ("local".equals(event.getSubcommandName())) {
String fileName = Objects.requireNonNull(event.getOption("file")).getAsString();
if (!fileName.contains(".m4a")) {
fileName = fileName.concat(".m4a");
}
String trackURL = audioDirectoryPath + fileName;
listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event));
} else if ("tags".equals(event.getSubcommandName())) {
String query;
String tagsString = Objects.requireNonNull(event.getOption("tags")).getAsString();
String[] tags = tagsString.split(";,");
for (int i = 0; i < tags.length; i++) {
tags[i] = tags[i].trim();
}
if (event.getOption("inclusive") != null && Objects.requireNonNull(event.getOption("inclusive")).getAsBoolean()) {
query = new QueryBuilder("audio").where("tags && ARRAY[?]").build();
} else {
query = new QueryBuilder("audio").where("tags @> ARRAY[?]").build();
}
try {
List<AudioTrack> tracks = new ArrayList<>();
List<Map<String, Object>> results = DatabaseManager.query(query, List.of(tags));
if (results.isEmpty()) {
event.getHook().sendMessage("No tracks found with the those tags.").queue();
return;
}
for (Map<String, Object> result : results) {
// String title = (String) result.get("title");
// String author = (String) result.get("author");
// long length = (Long) result.get("length");
// String identifier = (String) result.get("identifier");
String fileName = audioDirectoryPath.concat((String) result.get("file_name"));
LOGGER.debug("{}", fileName);
listener.getPlayerManager().loadItemOrdered(event.getGuild(), fileName, new ResultHandler(event));
}
// AudioPlaylist playlist = new BasicAudioPlaylist("Playlist based on tags", tracks, null, false);
} catch (SQLException ex) {
LOGGER.error("Failed to retrieve audio tags; {}", ex.getMessage());
}
}
} }
private class ResultHandler implements AudioLoadResultHandler { private class ResultHandler implements AudioLoadResultHandler {
@@ -52,7 +114,11 @@ public class PlayCommand extends Command {
public void trackLoaded(AudioTrack track) { public void trackLoaded(AudioTrack track) {
try { try {
playTrack(guild, userID, audioHandler, track); playTrack(guild, userID, audioHandler, track);
event.getHook().sendMessage("Adding **" + track.getInfo().title + "** to queue...").queue(); String trackTitle = "**" + track.getInfo().title + "**";
if (trackTitle.equalsIgnoreCase("Unknown title")) {
trackTitle = "track";
}
event.getHook().sendMessage("Adding " + trackTitle + " to queue...").queue();
} catch (EmptyVoiceChannelException e) { } catch (EmptyVoiceChannelException e) {
event.getHook().sendMessage("You must be connected to a voice channel in order to play tracks!").queue(); event.getHook().sendMessage("You must be connected to a voice channel in order to play tracks!").queue();
} catch (Exception e) { } catch (Exception e) {
@@ -80,7 +146,7 @@ public class PlayCommand extends Command {
@Override @Override
public void noMatches() { public void noMatches() {
event.getHook().sendMessage("Nothing found at that URL").queue(); event.getHook().sendMessage("No track found").queue();
} }
@Override @Override

View File

@@ -41,7 +41,7 @@ public class DatabaseConnection {
try { try {
connection.close(); connection.close();
} catch (SQLException ex) { } catch (SQLException ex) {
LOGGER.error(ex.getMessage()); LOGGER.error("Failed to close connection; {}", ex.getMessage());
} }
}); });
} }

View File

@@ -4,6 +4,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import java.sql.*; import java.sql.*;
import java.sql.Date;
import java.util.*; import java.util.*;
import static java.util.Map.entry; import static java.util.Map.entry;
@@ -11,49 +12,38 @@ import static java.util.Map.entry;
public class DatabaseManager { public class DatabaseManager {
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class); private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
private static final Map<String, String> createTableQueries = Map.ofEntries( private static final Map<String, List<String>> createTableQueries = Map.ofEntries(
entry("messages", "CREATE TABLE IF NOT EXISTS messages (" + entry("messages", List.of(
"id SERIAL PRIMARY KEY, " + "id SERIAL PRIMARY KEY",
"guild_id BIGINT NOT NULL, " + "guild_id BIGINT NOT NULL",
"thread_id BIGINT NOT NULL, " + "thread_id BIGINT NOT NULL",
"user_id BIGINT NOT NULL, " + "user_id BIGINT NOT NULL",
"message_type VARCHAR(20) NOT NULL," + "message_type VARCHAR(20) NOT NULL",
"message_text TEXT NOT NULL," + "message_text TEXT NOT NULL",
"message_response TEXT, " + "message_response TEXT",
"topics TEXT[], " + "topics TEXT[]",
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" + "timestamp TIMESTAMP NOT NULL DEFAULT NOW()")),
")"), entry("embeddings", List.of(
entry("embeddings", "CREATE TABLE IF NOT EXISTS embeddings (" + "id SERIAL PRIMARY KEY",
"id SERIAL PRIMARY KEY, " + "embeddings FLOAT[] NOT NULL")),
"embeddings FLOAT[] NOT NULL" + entry("message_embeddings", List.of(
")"), "id SERIAL PRIMARY KEY",
entry("message_embeddings", "CREATE TABLE IF NOT EXISTS message_embeddings (" + "message_id INT NOT NULL",
"id SERIAL PRIMARY KEY, " + "embedding_id INT NOT NULL")),
"message_id INT NOT NULL, " + entry("audio", List.of(
"embedding_id INT NOT NULL" + "id SERIAL PRIMARY KEY",
")") "title TEXT NOT NULL",
"author TEXT NOT NULL",
"length BIGINT NOT NULL",
"identifier TEXT NOT NULL",
"file_name TEXT NOT NULL",
"tags TEXT[]"
))
); );
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
"message_type, " +
"guild_id, " +
"thread_id, " +
"user_id, " +
"message_text, " +
"message_response, " +
"topics) " +
"VALUES (?, ?, ?, ?, ?, ?, ?)";
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
"embeddings) " +
"VALUES (?)";
private static final String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
"message_id, " +
"embeddings_id, " +
"VALUES (?, ?)";
public static void createTables() { public static void createTables() {
for (Map.Entry<String, String> entry : createTableQueries.entrySet()) { for (Map.Entry<String, List<String>> entry : createTableQueries.entrySet()) {
if (!createTable(entry.getKey())) { if (!createTable(entry.getKey(), entry.getValue())) {
LOGGER.warn("Failed to create one or more required database tables"); LOGGER.warn("Failed to create one or more required database tables");
return; return;
} }
@@ -61,79 +51,135 @@ public class DatabaseManager {
LOGGER.debug("Databases initialized"); LOGGER.debug("Databases initialized");
} }
private static boolean createTable(String tableName) { private static boolean createTable(String tableName, List<String> columns) {
if (tableExists(tableName)) { if (tableExists(tableName)) {
return true; return true;
} else { } else {
try { try {
StringBuilder stringBuilder = new StringBuilder("CREATE TABLE IF NOT EXISTS ")
.append(tableName).append(" ( ");
for (int i = 0; i < columns.size(); i++) {
stringBuilder.append(columns.get(i));
if (i != columns.size() - 1) {
stringBuilder.append(", ");
}
}
stringBuilder.append(")");
Connection connection = DatabaseConnection.getConnection(); Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating '{}' database table if it does not exist", tableName); LOGGER.debug("Creating '{}' database table if it does not exist", tableName);
Statement statement = connection.createStatement(); Statement statement = connection.createStatement();
statement.execute(createTableQueries.get(tableName)); statement.execute(stringBuilder.toString());
return true; return true;
} catch (SQLException ex) { } catch (SQLException ex) {
LOGGER.error(ex.getMessage()); LOGGER.error("Failed to create table; {}", ex.getMessage());
return false; return false;
} }
} }
} }
public static int storeMessage(MessageData messageData) { public static int storeMessage(MessageData messageData) {
if (!tableExists("messages")) { String INSERT_MESSAGE = "INSERT INTO messages (" +
LOGGER.warn("Table 'messages' does not exist"); "message_type, " +
return -1; "guild_id, " +
} "thread_id, " +
try { "user_id, " +
Connection connection = DatabaseConnection.getConnection(); "message_text, " +
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE); "message_response, " +
preparedStatement.setString(1, messageData.getMessageType()); "topics) " +
preparedStatement.setLong(2, messageData.getGuildId()); "VALUES (?, ?, ?, ?, ?, ?, ?)";
preparedStatement.setLong(3, messageData.getThreadId()); return storeMessage("messages", INSERT_MESSAGE,
preparedStatement.setLong(4, messageData.getUserId()); messageData.getMessageType(),
preparedStatement.setString(5, messageData.getMessageText()); messageData.getGuildId(),
preparedStatement.setString(6, messageData.getMessageResponse()); messageData.getThreadId(),
preparedStatement.setArray(7, connection.createArrayOf("text", messageData.getTopics().toArray())); messageData.getUserId(),
return preparedStatement.executeUpdate(); messageData.getMessageText(),
} catch (SQLException ex) { messageData.getMessageResponse(),
LOGGER.error(ex.getMessage()); messageData.getTopics()
return -1; );
}
} }
public static int storeEmbedding(List<Double> data) { public static int storeEmbedding(List<Double> data) {
if (!tableExists("embeddings")) { String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
LOGGER.warn("Table 'embeddings' does not exist"); "embeddings) " +
return -1; "VALUES (?)";
} return storeMessage("embeddings", INSERT_EMBEDDING, data);
try {
Connection connection = DatabaseConnection.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING);
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return -1;
}
} }
public static int storeMessageEmbeddings(int messageId, int embeddingId) { public static int storeMessageEmbeddings(int messageId, int embeddingId) {
if (!tableExists("message_embeddings")) { String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
LOGGER.warn("Table 'message_embeddings' does not exist"); "message_id, " +
"embeddings_id, " +
"VALUES (?, ?)";
return storeMessage("message_embeddings", INSERT_MESSAGE_EMBEDDINGS, messageId, embeddingId);
}
public static int storeAudio(String fileName, String[] tags) {
String INSERT_AUDIO = "INSERT INTO audio (" +
"file_name, " +
"tags) " +
"VALUES (?, ?)";
return storeMessage("audio", INSERT_AUDIO, fileName, tags);
}
private static int storeMessage(String tableName, String insertString, Object... params) {
if (!tableExists(tableName)) {
LOGGER.warn("Table '{}' does not exist", tableName);
return -1; return -1;
} }
try { try {
Connection connection = DatabaseConnection.getConnection(); Connection connection = DatabaseConnection.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS); PreparedStatement preparedStatement = connection.prepareStatement(insertString);
preparedStatement.setInt(1, messageId); int i = 1;
preparedStatement.setInt(2, embeddingId); for (Object param : params) {
if (param == null) {
preparedStatement.setNull(i, Types.NULL);
} else if (param instanceof String) {
preparedStatement.setString(i, (String) param);
} else if (param instanceof Integer) {
preparedStatement.setInt(i, (int) param);
} else if (param instanceof Double) {
preparedStatement.setDouble(i, (double) param);
} else if (param instanceof Float) {
preparedStatement.setFloat(i, (float) param);
} else if (param instanceof Long) {
preparedStatement.setLong(i, (long) param);
} else if (param instanceof Boolean) {
preparedStatement.setBoolean(i, (boolean) param);
} else if (param instanceof Date) {
preparedStatement.setDate(i, (Date) param);
} else if (param instanceof Time) {
preparedStatement.setTime(i, (Time) param);
} else if (param instanceof Timestamp) {
preparedStatement.setTimestamp(i, (Timestamp) param);
} else if (param instanceof byte[]) {
preparedStatement.setBytes(i, (byte[]) param);
} else if (param instanceof Blob) {
preparedStatement.setBlob(i, (Blob) param);
} else if (param instanceof Clob) {
preparedStatement.setClob(i, (Clob) param);
} else if (param instanceof List && ((List<?>) param).get(0) instanceof String) {
preparedStatement.setArray(i, connection.createArrayOf("text", ((List<String>) param).toArray()));
} else if (param instanceof List && ((List<?>) param).get(0) instanceof Double) {
preparedStatement.setArray(i, connection.createArrayOf("float8", ((List<Double>) param).toArray(new Double[0])));
} else if (param instanceof Set && ((Set<?>) param).toArray()[0] instanceof String) {
preparedStatement.setArray(i, connection.createArrayOf("text", ((Set<String>) param).toArray()));
} else if (param instanceof String[]) {
preparedStatement.setArray(i, connection.createArrayOf("text", ((String[]) param)));
} else {
throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass());
}
i++;
}
return preparedStatement.executeUpdate(); return preparedStatement.executeUpdate();
} catch (SQLException ex) { } catch (SQLException ex) {
LOGGER.error(ex.getMessage()); LOGGER.error("Failed to store message; {}", ex.getMessage());
return -1; return -1;
} }
} }
public static List<MessageData> getMessages(String query, Object... params) throws SQLException, IllegalArgumentException { public static List<Map<String, Object>> query(String query, Object... params) throws SQLException, IllegalArgumentException {
LOGGER.trace("Query: <{}>", query); LOGGER.trace("Query: <{}>", query);
Connection connection = DatabaseConnection.getConnection(); Connection connection = DatabaseConnection.getConnection();
PreparedStatement stmt = connection.prepareStatement(query); PreparedStatement stmt = connection.prepareStatement(query);
@@ -159,21 +205,19 @@ public class DatabaseManager {
throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass().getName()); throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass().getName());
} }
} }
ResultSet resultSet = stmt.executeQuery(); List<Map<String, Object>> resultList = new ArrayList<>();
List<MessageData> resultList = new ArrayList<>(); try (ResultSet resultSet = stmt.executeQuery()) {
while (resultSet.next()) { ResultSetMetaData metaData = resultSet.getMetaData();
Array array = resultSet.getArray(8); int columnCount = metaData.getColumnCount();
MessageData messageData = new MessageData( while (resultSet.next()) {
resultSet.getLong(2), Map<String, Object> rowMap = new HashMap<>();
resultSet.getLong(3), for (int j = 1; j <= columnCount; j++) {
resultSet.getLong(4), rowMap.put(metaData.getColumnName(j), resultSet.getObject(j));
resultSet.getString(5), }
resultSet.getString(6), resultList.add(rowMap);
resultSet.getString(7), }
new HashSet<>(Arrays.asList((String[]) array.getArray())), } catch (SQLException ex) {
resultSet.getTimestamp(9) LOGGER.error("Failed to execute query; {}", ex.getMessage());
);
resultList.add(messageData);
} }
return resultList; return resultList;
} }
@@ -185,8 +229,18 @@ public class DatabaseManager {
ResultSet resultSet = statement.executeQuery("SELECT tablename FROM pg_tables WHERE tablename = '" + tableName + "'"); ResultSet resultSet = statement.executeQuery("SELECT tablename FROM pg_tables WHERE tablename = '" + tableName + "'");
return resultSet.next(); return resultSet.next();
} catch (SQLException ex) { } catch (SQLException ex) {
LOGGER.error("Failed to check if table exists; {}" + ex.getMessage()); LOGGER.error("Failed to check if table exists; {}", ex.getMessage());
return false; return false;
} }
} }
public static void clearTable(String tableName) {
try {
Connection connection = DatabaseConnection.getConnection();
Statement statement = connection.createStatement();
statement.executeUpdate("TRUNCATE TABLE " + tableName);
} catch (SQLException ex) {
LOGGER.error("Failed to clear table; {}", ex.getMessage());
}
}
} }

View File

@@ -3,18 +3,17 @@ package com.bensherriff.siren.database;
public class QueryBuilder { public class QueryBuilder {
private boolean distinct; private boolean distinct;
private String columnList; private String columnList;
private String tableName; private final String tableName;
private String whereClause; private String whereClause;
private String orderByClause; private String orderByClause;
private Integer limit; private Integer limit;
public QueryBuilder select(String columnList) { public QueryBuilder(String tableName) {
this.columnList = columnList; this.tableName = tableName;
return this;
} }
public QueryBuilder from(String tableName) { public QueryBuilder select(String columnList) {
this.tableName = tableName; this.columnList = columnList;
return this; return this;
} }

View File

@@ -1,18 +1,15 @@
package com.bensherriff.siren.settings; package com.bensherriff.siren.settings;
import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
public class Settings { public class Settings {
private String token = ""; private String token = "";
private String owner = ""; private String owner = "250842261221277697";
private int threadPool = 2; private int threadPool = 2;
private Map<Long, GuildSettings> guildSettings = new HashMap<>(); private Map<Long, GuildSettings> guildSettings = new HashMap<>();
private OpenAISettings openAISettings = new OpenAISettings(); private OpenAISettings openAISettings = new OpenAISettings();
private List<TrackSettings> tracks = new ArrayList<>();
public String getToken() { public String getToken() {
return token; return token;
@@ -54,12 +51,4 @@ public class Settings {
public void setOpenAISettings(OpenAISettings openAISettings) { public void setOpenAISettings(OpenAISettings openAISettings) {
this.openAISettings = openAISettings; this.openAISettings = openAISettings;
} }
public List<TrackSettings> getTracks() {
return tracks;
}
public void setTracks(List<TrackSettings> tracks) {
this.tracks = tracks;
}
} }

View File

@@ -10,11 +10,15 @@ import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
public class SettingsManager { public class SettingsManager {
private static final Logger LOGGER = LogManager.getLogger(SettingsManager.class); private static final Logger LOGGER = LogManager.getLogger(SettingsManager.class);
public static final String SEPARATOR = File.separator; public static final String SEPARATOR = File.separator;
public static final String PATH = String.join(SEPARATOR, System.getProperty("user.dir"), "settings.json"); public static final String USER_DIRECTORY = System.getProperty("user.dir");
public static final String PATH = String.join(SEPARATOR, USER_DIRECTORY, "settings.json");
public static final String AUDIO_DIRECTORY = String.join(SEPARATOR, USER_DIRECTORY, "audio");
public static final String TRACKS_PATH = String.join(SEPARATOR, AUDIO_DIRECTORY, "tracks.json");
private static final ObjectMapper mapper = new ObjectMapper(); private static final ObjectMapper mapper = new ObjectMapper();
private static final ObjectWriter writer = mapper.writer(new DefaultPrettyPrinter()); private static final ObjectWriter writer = mapper.writer(new DefaultPrettyPrinter());
@@ -23,14 +27,22 @@ public class SettingsManager {
} }
public static Settings load(String path) throws IOException { public static Settings load(String path) throws IOException {
return load(path, Settings.class);
}
public static <T> T load(String path, Class<T> type) throws IOException {
File file = new File(path); File file = new File(path);
if (!file.exists()) { if (!file.exists()) {
LOGGER.warn("Settings file does not exist, creating new file at: {}", file.getPath()); LOGGER.warn("{} file does not exist, creating new file at: {}", type.getSimpleName(), file.getPath());
write(new Settings()); try {
write(path, type.getConstructor().newInstance());
} catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException ex) {
throw new IOException(ex);
}
} }
LOGGER.info("Reading settings from {}", file.getPath()); LOGGER.info("Reading {} file from {}", type.getSimpleName(), file.getPath());
try (InputStream inputStream = new FileInputStream(file)) { try (InputStream inputStream = new FileInputStream(file)) {
return mapper.readValue(inputStream, Settings.class); return mapper.readValue(inputStream, type);
} }
} }
@@ -38,8 +50,8 @@ public class SettingsManager {
write(PATH, settings); write(PATH, settings);
} }
public static void write(String path, Settings settings) throws IOException { public static void write(String path, Object object) throws IOException {
File file = new File(path); File file = new File(path);
writer.writeValue(file, settings); writer.writeValue(file, object);
} }
} }

View File

@@ -1,15 +1,15 @@
package com.bensherriff.siren.settings; package com.bensherriff.siren.settings;
public class TrackSettings { public class Track {
private String file; private String fileName;
private String[] tags; private String[] tags;
public String getFile() { public String getFileName() {
return file; return fileName;
} }
public void setFile(String file) { public void setFileName(String fileName) {
this.file = file; this.fileName = fileName;
} }
public String[] getTags() { public String[] getTags() {

View File

@@ -0,0 +1,17 @@
package com.bensherriff.siren.settings;
import java.util.ArrayList;
import java.util.List;
public class Tracks {
private List<Track> tracks = new ArrayList<>();
public List<Track> getTracks() {
return tracks;
}
public void setTracks(List<Track> tracks) {
this.tracks = tracks;
}
}