From 312f87d91a56d755e1b6ddd37a3584f02c6d8ca3 Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Mon, 17 Apr 2023 08:53:04 -0400 Subject: [PATCH] v0.1.20 Error handling and added local track support --- .gitignore | 1 + .version | 2 +- .../java/com/bensherriff/siren/Listener.java | 38 ++- src/main/java/com/bensherriff/siren/Main.java | 2 +- .../bensherriff/siren/ai/OpenAIManager.java | 36 ++- .../bensherriff/siren/audio/AudioHandler.java | 10 +- .../siren/audio/PlayerManager.java | 18 ++ .../siren/commands/ImageCommand.java | 7 +- .../siren/commands/PlayCommand.java | 78 +++++- .../siren/database/DatabaseConnection.java | 2 +- .../siren/database/DatabaseManager.java | 250 +++++++++++------- .../siren/database/QueryBuilder.java | 11 +- .../bensherriff/siren/settings/Settings.java | 13 +- .../siren/settings/SettingsManager.java | 26 +- .../{TrackSettings.java => Track.java} | 12 +- .../bensherriff/siren/settings/Tracks.java | 17 ++ 16 files changed, 367 insertions(+), 156 deletions(-) rename src/main/java/com/bensherriff/siren/settings/{TrackSettings.java => Track.java} (52%) create mode 100644 src/main/java/com/bensherriff/siren/settings/Tracks.java diff --git a/.gitignore b/.gitignore index 7fa04e9..a6cd868 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ **/app/ **/settings.json **/logs/ +**/audio/ .env \ No newline at end of file diff --git a/.version b/.version index d3a16cc..02db736 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -export SIREN_VERSION=0.1.19 \ No newline at end of file +export SIREN_VERSION=0.1.20 \ No newline at end of file diff --git a/src/main/java/com/bensherriff/siren/Listener.java b/src/main/java/com/bensherriff/siren/Listener.java index fe7449d..4e6af48 100644 --- a/src/main/java/com/bensherriff/siren/Listener.java +++ b/src/main/java/com/bensherriff/siren/Listener.java @@ -6,9 +6,7 @@ import com.bensherriff.siren.commands.*; import com.bensherriff.siren.database.DatabaseManager; import com.bensherriff.siren.exceptions.EmptyVoiceChannelException; import com.bensherriff.siren.ai.OpenAIManager; -import com.bensherriff.siren.settings.GuildSettings; -import com.bensherriff.siren.settings.Settings; -import com.bensherriff.siren.settings.SettingsManager; +import com.bensherriff.siren.settings.*; import net.dv8tion.jda.api.JDA; import net.dv8tion.jda.api.entities.Guild; import net.dv8tion.jda.api.entities.Member; @@ -22,6 +20,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.jetbrains.annotations.NotNull; +import java.io.File; import java.io.IOException; import java.util.*; import java.util.concurrent.Executors; @@ -79,6 +78,7 @@ public class Listener extends ListenerAdapter { this.openAIManager = new OpenAIManager(this); DatabaseManager.createTables(); + populateAudioTable(); } public void closeAudioConnection(long guildID) { @@ -107,7 +107,7 @@ public class Listener extends ListenerAdapter { long guildId = Long.parseLong(guild.getId()); AudioHandler audioHandler; - if (guild.getAudioManager().getSendingHandler() == null) { + if (guild.getAudioManager().getSendingHandler() == null || !settings.getGuildSettings().containsKey(guildId)) { LOGGER.info("Creating Audio Handler for guild {}", guildId); if (!settings.getGuildSettings().containsKey(guildId)) { settings.getGuildSettings().put(guildId, new GuildSettings()); @@ -122,6 +122,32 @@ public class Listener extends ListenerAdapter { return audioHandler; } + private void populateAudioTable() { + DatabaseManager.clearTable("audio"); + File directory = new File(SettingsManager.AUDIO_DIRECTORY); + try { + if (!directory.exists() && !directory.mkdirs()) { + LOGGER.error("Failed to create directory at {}", directory.getPath()); + return; + } + Tracks tracks = SettingsManager.load(SettingsManager.TRACKS_PATH, Tracks.class); + int rows = 0; + File[] files = directory.listFiles(); + if (files != null) { + for (Track track : tracks.getTracks()) { + for (File file : files) { + if (file.exists() && file.getName().equals(track.getFileName())) { + rows += DatabaseManager.storeAudio(track.getFileName(), track.getTags()); + } + } + } + } + LOGGER.debug("Updated with {} local tracks", rows); + } catch (IOException ex) { + LOGGER.error("Failed to load local tracks; {}", ex.getMessage()); + } + } + @Override public void onReady(@NotNull ReadyEvent event) { super.onReady(event); @@ -131,7 +157,7 @@ public class Listener extends ListenerAdapter { commands.put("volume", new VolumeCommand(this)); commands.put("pause", new PauseCommand(this)); commands.put("resume", new ResumeCommand(this)); - commands.put("image", new ImageCommand(this)); +// commands.put("image", new ImageCommand(this)); jda.getGuilds().forEach(guild -> executor.execute(() -> { LOGGER.debug("Updating commands for \"{}\" <{}>", guild.getName(), guild.getId()); guild.updateCommands().addCommands( @@ -151,7 +177,7 @@ public class Listener extends ListenerAdapter { try { commands.get(command).execute(event); } catch (Exception ex) { - LOGGER.error(ex.getMessage()); + LOGGER.error("Failed to execute command; {}", ex.getMessage()); event.getHook().sendMessage("An error occurred while processing your command. Please contact " + owner + ".").queue(); } }); diff --git a/src/main/java/com/bensherriff/siren/Main.java b/src/main/java/com/bensherriff/siren/Main.java index 7e9b0d7..521f1a5 100644 --- a/src/main/java/com/bensherriff/siren/Main.java +++ b/src/main/java/com/bensherriff/siren/Main.java @@ -30,7 +30,7 @@ public class Main { try { start(); } catch (Exception ex) { - LOGGER.error(ex.getMessage()); + LOGGER.error("Caught unhandled exception; {}", ex.getMessage()); } } diff --git a/src/main/java/com/bensherriff/siren/ai/OpenAIManager.java b/src/main/java/com/bensherriff/siren/ai/OpenAIManager.java index 1efb75b..e296f4d 100644 --- a/src/main/java/com/bensherriff/siren/ai/OpenAIManager.java +++ b/src/main/java/com/bensherriff/siren/ai/OpenAIManager.java @@ -23,8 +23,10 @@ import net.dv8tion.jda.api.entities.channel.concrete.ThreadChannel; import net.dv8tion.jda.api.events.message.MessageReceivedEvent; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.postgresql.jdbc.PgArray; import java.sql.SQLException; +import java.sql.Timestamp; import java.time.Duration; import java.util.*; import java.util.concurrent.ScheduledExecutorService; @@ -141,15 +143,37 @@ public class OpenAIManager { // Handle System Messages chatMessages.add(createSystemMessage("You are a discord bot named Siren")); - chatMessages.add(createSystemMessage("I am a user named " + event.getAuthor().getName())); + chatMessages.add(createSystemMessage("My name is " + event.getAuthor().getName())); if (event.isFromThread()) { - String query = new QueryBuilder().from("messages") + String query = new QueryBuilder("messages") .where("guild_id = ? AND thread_id = ?") .orderBy("timestamp DESC") .limit(10) .build(); - List previousMessages = DatabaseManager.getMessages( + + // Build MessageData objects from query results + List> results = DatabaseManager.query( query, event.getGuild().getIdLong(), event.getChannel().getIdLong()); + List previousMessages = new ArrayList<>(); + for (Map result : results) { + Object[] resultTopicObjects = (Object[]) ((PgArray) result.get("topics")).getArray(); + Set resultTopics = new HashSet<>(); + for (Object object : resultTopicObjects) { + resultTopics.add((String) object); + } + MessageData messageData = new MessageData( + (long) result.get("guild_id"), + (long) result.get("thread_id"), + (long) result.get("user_id"), + (String) result.get("message_type"), + (String) result.get("message_text"), + (String) result.get("message_response"), + resultTopics, + (Timestamp) result.get("timestamp") + ); + previousMessages.add(messageData); + } + Set potentialTopics = new HashSet<>(); for (MessageData previousMessage : previousMessages) { ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " + @@ -157,10 +181,10 @@ public class OpenAIManager { "\". You replied with \"" + previousMessage.getMessageResponse() + "\"."); potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageText())); potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageResponse())); - chatMessages.add(previousChatMessage); - LOGGER.trace("Potential topics: {}", potentialTopics); -// chatMessages.add(createSystemMessage("As an AI language model, only give replies that relate to " + topics)); +// chatMessages.add(previousChatMessage); } + LOGGER.trace("Potential topics: {}", potentialTopics); +// chatMessages.add(createSystemMessage("As an AI language model, only give replies that relate to " + topics)); } return ChatCompletionRequest.builder() diff --git a/src/main/java/com/bensherriff/siren/audio/AudioHandler.java b/src/main/java/com/bensherriff/siren/audio/AudioHandler.java index 4a633bc..c98385c 100644 --- a/src/main/java/com/bensherriff/siren/audio/AudioHandler.java +++ b/src/main/java/com/bensherriff/siren/audio/AudioHandler.java @@ -68,29 +68,29 @@ public class AudioHandler extends AudioEventAdapter implements AudioSendHandler } public void setVolume(int volume) { - LOGGER.debug("Set volume to {}", volume); + LOGGER.trace("Set volume to {}", volume); player.setVolume(volume); } @Override public void onPlayerPause(AudioPlayer player) { - LOGGER.debug("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title); + LOGGER.trace("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title); } @Override public void onPlayerResume(AudioPlayer player) { - LOGGER.debug("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title); + LOGGER.trace("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title); } @Override public void onTrackStart(AudioPlayer player, AudioTrack track) { - LOGGER.debug("Starting track {}", track.getInfo().title); + LOGGER.trace("Starting track {}", track.getInfo().title); manager.getListener().getJDA().getPresence().setActivity(Activity.playing(track.getInfo().title)); } @Override public void onTrackEnd(AudioPlayer player, AudioTrack track, AudioTrackEndReason endReason) { - LOGGER.debug("Track ended due to {}; {} ", endReason.name(), endReason.mayStartNext); + LOGGER.trace("Track ended due to {}; {} ", endReason.name(), endReason.mayStartNext); if (queue.isEmpty()) { manager.getListener().closeAudioConnection(guildID); manager.getListener().getJDA().getPresence().setActivity(Activity.playing("nothing")); diff --git a/src/main/java/com/bensherriff/siren/audio/PlayerManager.java b/src/main/java/com/bensherriff/siren/audio/PlayerManager.java index 30ec233..18bb082 100644 --- a/src/main/java/com/bensherriff/siren/audio/PlayerManager.java +++ b/src/main/java/com/bensherriff/siren/audio/PlayerManager.java @@ -1,8 +1,17 @@ package com.bensherriff.siren.audio; import com.bensherriff.siren.Listener; +import com.sedmelluq.discord.lavaplayer.player.AudioConfiguration; import com.sedmelluq.discord.lavaplayer.player.DefaultAudioPlayerManager; import com.sedmelluq.discord.lavaplayer.source.AudioSourceManagers; +import com.sedmelluq.discord.lavaplayer.source.bandcamp.BandcampAudioSourceManager; +import com.sedmelluq.discord.lavaplayer.source.beam.BeamAudioSourceManager; +import com.sedmelluq.discord.lavaplayer.source.http.HttpAudioSourceManager; +import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioSourceManager; +import com.sedmelluq.discord.lavaplayer.source.soundcloud.SoundCloudAudioSourceManager; +import com.sedmelluq.discord.lavaplayer.source.twitch.TwitchStreamAudioSourceManager; +import com.sedmelluq.discord.lavaplayer.source.vimeo.VimeoAudioSourceManager; +import com.sedmelluq.discord.lavaplayer.source.youtube.YoutubeAudioSourceManager; public class PlayerManager extends DefaultAudioPlayerManager { @@ -17,6 +26,15 @@ public class PlayerManager extends DefaultAudioPlayerManager { } public void initialize() { + getConfiguration().setResamplingQuality(AudioConfiguration.ResamplingQuality.MEDIUM); + registerSourceManager(new YoutubeAudioSourceManager()); + registerSourceManager(SoundCloudAudioSourceManager.createDefault()); + registerSourceManager(new BandcampAudioSourceManager()); + registerSourceManager(new VimeoAudioSourceManager()); + registerSourceManager(new TwitchStreamAudioSourceManager()); + registerSourceManager(new BeamAudioSourceManager()); + registerSourceManager(new HttpAudioSourceManager()); + registerSourceManager(new LocalAudioSourceManager()); AudioSourceManagers.registerRemoteSources(this); AudioSourceManagers.registerLocalSource(this); } diff --git a/src/main/java/com/bensherriff/siren/commands/ImageCommand.java b/src/main/java/com/bensherriff/siren/commands/ImageCommand.java index f77523a..d4e74d1 100644 --- a/src/main/java/com/bensherriff/siren/commands/ImageCommand.java +++ b/src/main/java/com/bensherriff/siren/commands/ImageCommand.java @@ -7,6 +7,7 @@ import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEve import net.dv8tion.jda.api.interactions.commands.OptionMapping; import net.dv8tion.jda.api.interactions.commands.OptionType; import net.dv8tion.jda.api.interactions.commands.build.Commands; +import net.dv8tion.jda.api.interactions.commands.build.OptionData; import java.io.IOException; import java.util.Objects; @@ -18,7 +19,11 @@ public class ImageCommand extends Command { slashCommandData = Commands.slash("image", "Generate an image using DALL-E") .addOption(OptionType.STRING, "prompt", "The prompt for image generation", true) .addOption(OptionType.INTEGER, "count", "The number of images to be generated", false) - .addOption(OptionType.INTEGER, "size", "Size of the picture, either 1 (small), 2 (medium), or 3 (large)", false); + .addOptions(new OptionData(OptionType.STRING, "type", "The size of the picture", false) + .addChoice("Small", ImageSize.SMALL.getSize()) + .addChoice("Medium", ImageSize.MEDIUM.getSize()) + .addChoice("Large", ImageSize.LARGE.getSize()) + ); } //TODO Store image in database diff --git a/src/main/java/com/bensherriff/siren/commands/PlayCommand.java b/src/main/java/com/bensherriff/siren/commands/PlayCommand.java index 09f9a72..cf531b4 100644 --- a/src/main/java/com/bensherriff/siren/commands/PlayCommand.java +++ b/src/main/java/com/bensherriff/siren/commands/PlayCommand.java @@ -1,32 +1,94 @@ package com.bensherriff.siren.commands; import com.bensherriff.siren.audio.AudioHandler; +import com.bensherriff.siren.database.DatabaseManager; +import com.bensherriff.siren.database.QueryBuilder; import com.bensherriff.siren.exceptions.EmptyVoiceChannelException; import com.bensherriff.siren.Listener; +import com.bensherriff.siren.settings.SettingsManager; +import com.sedmelluq.discord.lavaplayer.container.MediaContainerDescriptor; import com.sedmelluq.discord.lavaplayer.player.AudioLoadResultHandler; +import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioSourceManager; +import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioTrack; import com.sedmelluq.discord.lavaplayer.tools.FriendlyException; import com.sedmelluq.discord.lavaplayer.track.AudioPlaylist; import com.sedmelluq.discord.lavaplayer.track.AudioTrack; +import com.sedmelluq.discord.lavaplayer.track.AudioTrackInfo; +import com.sedmelluq.discord.lavaplayer.track.BasicAudioPlaylist; import net.dv8tion.jda.api.entities.Guild; import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent; import net.dv8tion.jda.api.interactions.commands.OptionType; import net.dv8tion.jda.api.interactions.commands.build.Commands; +import net.dv8tion.jda.api.interactions.commands.build.SubcommandData; import java.io.IOException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import java.util.Objects; public class PlayCommand extends Command { public PlayCommand(Listener listener) { super(listener); - slashCommandData = Commands.slash("play", "Play a track from a URL") - .addOption(OptionType.STRING, "url", "Track URL", true); + slashCommandData = Commands.slash("play", "Play a track") + .addSubcommands(new SubcommandData("url", "Play a track from a URL") + .addOption(OptionType.STRING, "url", "Track URL", true)) + .addSubcommands(new SubcommandData("local", "Play a track from a local file") + .addOption(OptionType.STRING, "file", "Local file name", true)) + .addSubcommands(new SubcommandData("tags", "Play a track based on tags") + .addOption(OptionType.STRING, "tags", "The list of tags separated by a semicolon", true) + .addOption(OptionType.BOOLEAN, "inclusive", "If true, tracks can match any tag")); } @Override public void execute(SlashCommandInteractionEvent event) throws IOException { - String trackURL = Objects.requireNonNull(event.getOption("url")).getAsString(); - listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event)); + String audioDirectoryPath = SettingsManager.AUDIO_DIRECTORY + SettingsManager.SEPARATOR; + if ("url".equals(event.getSubcommandName())) { + String trackURL = Objects.requireNonNull(event.getOption("url")).getAsString(); + listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event)); + } else if ("local".equals(event.getSubcommandName())) { + String fileName = Objects.requireNonNull(event.getOption("file")).getAsString(); + if (!fileName.contains(".m4a")) { + fileName = fileName.concat(".m4a"); + } + String trackURL = audioDirectoryPath + fileName; + listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event)); + } else if ("tags".equals(event.getSubcommandName())) { + String query; + String tagsString = Objects.requireNonNull(event.getOption("tags")).getAsString(); + String[] tags = tagsString.split(";,"); + for (int i = 0; i < tags.length; i++) { + tags[i] = tags[i].trim(); + } + if (event.getOption("inclusive") != null && Objects.requireNonNull(event.getOption("inclusive")).getAsBoolean()) { + query = new QueryBuilder("audio").where("tags && ARRAY[?]").build(); + } else { + query = new QueryBuilder("audio").where("tags @> ARRAY[?]").build(); + } + try { + List tracks = new ArrayList<>(); + List> results = DatabaseManager.query(query, List.of(tags)); + if (results.isEmpty()) { + event.getHook().sendMessage("No tracks found with the those tags.").queue(); + return; + } + for (Map result : results) { +// String title = (String) result.get("title"); +// String author = (String) result.get("author"); +// long length = (Long) result.get("length"); +// String identifier = (String) result.get("identifier"); + String fileName = audioDirectoryPath.concat((String) result.get("file_name")); + LOGGER.debug("{}", fileName); + listener.getPlayerManager().loadItemOrdered(event.getGuild(), fileName, new ResultHandler(event)); + } +// AudioPlaylist playlist = new BasicAudioPlaylist("Playlist based on tags", tracks, null, false); + + } catch (SQLException ex) { + LOGGER.error("Failed to retrieve audio tags; {}", ex.getMessage()); + } + } } private class ResultHandler implements AudioLoadResultHandler { @@ -52,7 +114,11 @@ public class PlayCommand extends Command { public void trackLoaded(AudioTrack track) { try { playTrack(guild, userID, audioHandler, track); - event.getHook().sendMessage("Adding **" + track.getInfo().title + "** to queue...").queue(); + String trackTitle = "**" + track.getInfo().title + "**"; + if (trackTitle.equalsIgnoreCase("Unknown title")) { + trackTitle = "track"; + } + event.getHook().sendMessage("Adding " + trackTitle + " to queue...").queue(); } catch (EmptyVoiceChannelException e) { event.getHook().sendMessage("You must be connected to a voice channel in order to play tracks!").queue(); } catch (Exception e) { @@ -80,7 +146,7 @@ public class PlayCommand extends Command { @Override public void noMatches() { - event.getHook().sendMessage("Nothing found at that URL").queue(); + event.getHook().sendMessage("No track found").queue(); } @Override diff --git a/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java b/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java index e773237..6813881 100644 --- a/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java +++ b/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java @@ -41,7 +41,7 @@ public class DatabaseConnection { try { connection.close(); } catch (SQLException ex) { - LOGGER.error(ex.getMessage()); + LOGGER.error("Failed to close connection; {}", ex.getMessage()); } }); } diff --git a/src/main/java/com/bensherriff/siren/database/DatabaseManager.java b/src/main/java/com/bensherriff/siren/database/DatabaseManager.java index a06e690..1fb126b 100644 --- a/src/main/java/com/bensherriff/siren/database/DatabaseManager.java +++ b/src/main/java/com/bensherriff/siren/database/DatabaseManager.java @@ -4,6 +4,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import java.sql.*; +import java.sql.Date; import java.util.*; import static java.util.Map.entry; @@ -11,49 +12,38 @@ import static java.util.Map.entry; public class DatabaseManager { private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class); - private static final Map createTableQueries = Map.ofEntries( - entry("messages", "CREATE TABLE IF NOT EXISTS messages (" + - "id SERIAL PRIMARY KEY, " + - "guild_id BIGINT NOT NULL, " + - "thread_id BIGINT NOT NULL, " + - "user_id BIGINT NOT NULL, " + - "message_type VARCHAR(20) NOT NULL," + - "message_text TEXT NOT NULL," + - "message_response TEXT, " + - "topics TEXT[], " + - "timestamp TIMESTAMP NOT NULL DEFAULT NOW()" + - ")"), - entry("embeddings", "CREATE TABLE IF NOT EXISTS embeddings (" + - "id SERIAL PRIMARY KEY, " + - "embeddings FLOAT[] NOT NULL" + - ")"), - entry("message_embeddings", "CREATE TABLE IF NOT EXISTS message_embeddings (" + - "id SERIAL PRIMARY KEY, " + - "message_id INT NOT NULL, " + - "embedding_id INT NOT NULL" + - ")") + private static final Map> createTableQueries = Map.ofEntries( + entry("messages", List.of( + "id SERIAL PRIMARY KEY", + "guild_id BIGINT NOT NULL", + "thread_id BIGINT NOT NULL", + "user_id BIGINT NOT NULL", + "message_type VARCHAR(20) NOT NULL", + "message_text TEXT NOT NULL", + "message_response TEXT", + "topics TEXT[]", + "timestamp TIMESTAMP NOT NULL DEFAULT NOW()")), + entry("embeddings", List.of( + "id SERIAL PRIMARY KEY", + "embeddings FLOAT[] NOT NULL")), + entry("message_embeddings", List.of( + "id SERIAL PRIMARY KEY", + "message_id INT NOT NULL", + "embedding_id INT NOT NULL")), + entry("audio", List.of( + "id SERIAL PRIMARY KEY", + "title TEXT NOT NULL", + "author TEXT NOT NULL", + "length BIGINT NOT NULL", + "identifier TEXT NOT NULL", + "file_name TEXT NOT NULL", + "tags TEXT[]" + )) ); - private static final String INSERT_MESSAGE = "INSERT INTO messages (" + - "message_type, " + - "guild_id, " + - "thread_id, " + - "user_id, " + - "message_text, " + - "message_response, " + - "topics) " + - "VALUES (?, ?, ?, ?, ?, ?, ?)"; - private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" + - "embeddings) " + - "VALUES (?)"; - - private static final String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" + - "message_id, " + - "embeddings_id, " + - "VALUES (?, ?)"; public static void createTables() { - for (Map.Entry entry : createTableQueries.entrySet()) { - if (!createTable(entry.getKey())) { + for (Map.Entry> entry : createTableQueries.entrySet()) { + if (!createTable(entry.getKey(), entry.getValue())) { LOGGER.warn("Failed to create one or more required database tables"); return; } @@ -61,79 +51,135 @@ public class DatabaseManager { LOGGER.debug("Databases initialized"); } - private static boolean createTable(String tableName) { + private static boolean createTable(String tableName, List columns) { if (tableExists(tableName)) { return true; } else { try { + StringBuilder stringBuilder = new StringBuilder("CREATE TABLE IF NOT EXISTS ") + .append(tableName).append(" ( "); + for (int i = 0; i < columns.size(); i++) { + stringBuilder.append(columns.get(i)); + if (i != columns.size() - 1) { + stringBuilder.append(", "); + } + } + stringBuilder.append(")"); + Connection connection = DatabaseConnection.getConnection(); LOGGER.debug("Creating '{}' database table if it does not exist", tableName); Statement statement = connection.createStatement(); - statement.execute(createTableQueries.get(tableName)); + statement.execute(stringBuilder.toString()); return true; } catch (SQLException ex) { - LOGGER.error(ex.getMessage()); + LOGGER.error("Failed to create table; {}", ex.getMessage()); return false; } } } public static int storeMessage(MessageData messageData) { - if (!tableExists("messages")) { - LOGGER.warn("Table 'messages' does not exist"); - return -1; - } - try { - Connection connection = DatabaseConnection.getConnection(); - PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE); - preparedStatement.setString(1, messageData.getMessageType()); - preparedStatement.setLong(2, messageData.getGuildId()); - preparedStatement.setLong(3, messageData.getThreadId()); - preparedStatement.setLong(4, messageData.getUserId()); - preparedStatement.setString(5, messageData.getMessageText()); - preparedStatement.setString(6, messageData.getMessageResponse()); - preparedStatement.setArray(7, connection.createArrayOf("text", messageData.getTopics().toArray())); - return preparedStatement.executeUpdate(); - } catch (SQLException ex) { - LOGGER.error(ex.getMessage()); - return -1; - } + String INSERT_MESSAGE = "INSERT INTO messages (" + + "message_type, " + + "guild_id, " + + "thread_id, " + + "user_id, " + + "message_text, " + + "message_response, " + + "topics) " + + "VALUES (?, ?, ?, ?, ?, ?, ?)"; + return storeMessage("messages", INSERT_MESSAGE, + messageData.getMessageType(), + messageData.getGuildId(), + messageData.getThreadId(), + messageData.getUserId(), + messageData.getMessageText(), + messageData.getMessageResponse(), + messageData.getTopics() + ); } public static int storeEmbedding(List data) { - if (!tableExists("embeddings")) { - LOGGER.warn("Table 'embeddings' does not exist"); - return -1; - } - try { - Connection connection = DatabaseConnection.getConnection(); - PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING); - preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0]))); - return preparedStatement.executeUpdate(); - } catch (SQLException ex) { - LOGGER.error(ex.getMessage()); - return -1; - } + String INSERT_EMBEDDING = "INSERT INTO embeddings (" + + "embeddings) " + + "VALUES (?)"; + return storeMessage("embeddings", INSERT_EMBEDDING, data); } public static int storeMessageEmbeddings(int messageId, int embeddingId) { - if (!tableExists("message_embeddings")) { - LOGGER.warn("Table 'message_embeddings' does not exist"); + String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" + + "message_id, " + + "embeddings_id, " + + "VALUES (?, ?)"; + return storeMessage("message_embeddings", INSERT_MESSAGE_EMBEDDINGS, messageId, embeddingId); + } + + public static int storeAudio(String fileName, String[] tags) { + String INSERT_AUDIO = "INSERT INTO audio (" + + "file_name, " + + "tags) " + + "VALUES (?, ?)"; + return storeMessage("audio", INSERT_AUDIO, fileName, tags); + } + + private static int storeMessage(String tableName, String insertString, Object... params) { + if (!tableExists(tableName)) { + LOGGER.warn("Table '{}' does not exist", tableName); return -1; } try { Connection connection = DatabaseConnection.getConnection(); - PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS); - preparedStatement.setInt(1, messageId); - preparedStatement.setInt(2, embeddingId); + PreparedStatement preparedStatement = connection.prepareStatement(insertString); + int i = 1; + for (Object param : params) { + if (param == null) { + preparedStatement.setNull(i, Types.NULL); + } else if (param instanceof String) { + preparedStatement.setString(i, (String) param); + } else if (param instanceof Integer) { + preparedStatement.setInt(i, (int) param); + } else if (param instanceof Double) { + preparedStatement.setDouble(i, (double) param); + } else if (param instanceof Float) { + preparedStatement.setFloat(i, (float) param); + } else if (param instanceof Long) { + preparedStatement.setLong(i, (long) param); + } else if (param instanceof Boolean) { + preparedStatement.setBoolean(i, (boolean) param); + } else if (param instanceof Date) { + preparedStatement.setDate(i, (Date) param); + } else if (param instanceof Time) { + preparedStatement.setTime(i, (Time) param); + } else if (param instanceof Timestamp) { + preparedStatement.setTimestamp(i, (Timestamp) param); + } else if (param instanceof byte[]) { + preparedStatement.setBytes(i, (byte[]) param); + } else if (param instanceof Blob) { + preparedStatement.setBlob(i, (Blob) param); + } else if (param instanceof Clob) { + preparedStatement.setClob(i, (Clob) param); + } else if (param instanceof List && ((List) param).get(0) instanceof String) { + preparedStatement.setArray(i, connection.createArrayOf("text", ((List) param).toArray())); + } else if (param instanceof List && ((List) param).get(0) instanceof Double) { + preparedStatement.setArray(i, connection.createArrayOf("float8", ((List) param).toArray(new Double[0]))); + } else if (param instanceof Set && ((Set) param).toArray()[0] instanceof String) { + preparedStatement.setArray(i, connection.createArrayOf("text", ((Set) param).toArray())); + } else if (param instanceof String[]) { + preparedStatement.setArray(i, connection.createArrayOf("text", ((String[]) param))); + } else { + throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass()); + } + i++; + } + return preparedStatement.executeUpdate(); } catch (SQLException ex) { - LOGGER.error(ex.getMessage()); + LOGGER.error("Failed to store message; {}", ex.getMessage()); return -1; } } - public static List getMessages(String query, Object... params) throws SQLException, IllegalArgumentException { + public static List> query(String query, Object... params) throws SQLException, IllegalArgumentException { LOGGER.trace("Query: <{}>", query); Connection connection = DatabaseConnection.getConnection(); PreparedStatement stmt = connection.prepareStatement(query); @@ -159,21 +205,19 @@ public class DatabaseManager { throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass().getName()); } } - ResultSet resultSet = stmt.executeQuery(); - List resultList = new ArrayList<>(); - while (resultSet.next()) { - Array array = resultSet.getArray(8); - MessageData messageData = new MessageData( - resultSet.getLong(2), - resultSet.getLong(3), - resultSet.getLong(4), - resultSet.getString(5), - resultSet.getString(6), - resultSet.getString(7), - new HashSet<>(Arrays.asList((String[]) array.getArray())), - resultSet.getTimestamp(9) - ); - resultList.add(messageData); + List> resultList = new ArrayList<>(); + try (ResultSet resultSet = stmt.executeQuery()) { + ResultSetMetaData metaData = resultSet.getMetaData(); + int columnCount = metaData.getColumnCount(); + while (resultSet.next()) { + Map rowMap = new HashMap<>(); + for (int j = 1; j <= columnCount; j++) { + rowMap.put(metaData.getColumnName(j), resultSet.getObject(j)); + } + resultList.add(rowMap); + } + } catch (SQLException ex) { + LOGGER.error("Failed to execute query; {}", ex.getMessage()); } return resultList; } @@ -185,8 +229,18 @@ public class DatabaseManager { ResultSet resultSet = statement.executeQuery("SELECT tablename FROM pg_tables WHERE tablename = '" + tableName + "'"); return resultSet.next(); } catch (SQLException ex) { - LOGGER.error("Failed to check if table exists; {}" + ex.getMessage()); + LOGGER.error("Failed to check if table exists; {}", ex.getMessage()); return false; } } + + public static void clearTable(String tableName) { + try { + Connection connection = DatabaseConnection.getConnection(); + Statement statement = connection.createStatement(); + statement.executeUpdate("TRUNCATE TABLE " + tableName); + } catch (SQLException ex) { + LOGGER.error("Failed to clear table; {}", ex.getMessage()); + } + } } diff --git a/src/main/java/com/bensherriff/siren/database/QueryBuilder.java b/src/main/java/com/bensherriff/siren/database/QueryBuilder.java index 996f810..edaebd8 100644 --- a/src/main/java/com/bensherriff/siren/database/QueryBuilder.java +++ b/src/main/java/com/bensherriff/siren/database/QueryBuilder.java @@ -3,18 +3,17 @@ package com.bensherriff.siren.database; public class QueryBuilder { private boolean distinct; private String columnList; - private String tableName; + private final String tableName; private String whereClause; private String orderByClause; private Integer limit; - public QueryBuilder select(String columnList) { - this.columnList = columnList; - return this; + public QueryBuilder(String tableName) { + this.tableName = tableName; } - public QueryBuilder from(String tableName) { - this.tableName = tableName; + public QueryBuilder select(String columnList) { + this.columnList = columnList; return this; } diff --git a/src/main/java/com/bensherriff/siren/settings/Settings.java b/src/main/java/com/bensherriff/siren/settings/Settings.java index da32a43..2332caa 100644 --- a/src/main/java/com/bensherriff/siren/settings/Settings.java +++ b/src/main/java/com/bensherriff/siren/settings/Settings.java @@ -1,18 +1,15 @@ package com.bensherriff.siren.settings; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; public class Settings { private String token = ""; - private String owner = ""; + private String owner = "250842261221277697"; private int threadPool = 2; private Map guildSettings = new HashMap<>(); private OpenAISettings openAISettings = new OpenAISettings(); - private List tracks = new ArrayList<>(); public String getToken() { return token; @@ -54,12 +51,4 @@ public class Settings { public void setOpenAISettings(OpenAISettings openAISettings) { this.openAISettings = openAISettings; } - - public List getTracks() { - return tracks; - } - - public void setTracks(List tracks) { - this.tracks = tracks; - } } diff --git a/src/main/java/com/bensherriff/siren/settings/SettingsManager.java b/src/main/java/com/bensherriff/siren/settings/SettingsManager.java index 943fab3..fe4ceaa 100644 --- a/src/main/java/com/bensherriff/siren/settings/SettingsManager.java +++ b/src/main/java/com/bensherriff/siren/settings/SettingsManager.java @@ -10,11 +10,15 @@ import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; public class SettingsManager { private static final Logger LOGGER = LogManager.getLogger(SettingsManager.class); public static final String SEPARATOR = File.separator; - public static final String PATH = String.join(SEPARATOR, System.getProperty("user.dir"), "settings.json"); + public static final String USER_DIRECTORY = System.getProperty("user.dir"); + public static final String PATH = String.join(SEPARATOR, USER_DIRECTORY, "settings.json"); + public static final String AUDIO_DIRECTORY = String.join(SEPARATOR, USER_DIRECTORY, "audio"); + public static final String TRACKS_PATH = String.join(SEPARATOR, AUDIO_DIRECTORY, "tracks.json"); private static final ObjectMapper mapper = new ObjectMapper(); private static final ObjectWriter writer = mapper.writer(new DefaultPrettyPrinter()); @@ -23,14 +27,22 @@ public class SettingsManager { } public static Settings load(String path) throws IOException { + return load(path, Settings.class); + } + + public static T load(String path, Class type) throws IOException { File file = new File(path); if (!file.exists()) { - LOGGER.warn("Settings file does not exist, creating new file at: {}", file.getPath()); - write(new Settings()); + LOGGER.warn("{} file does not exist, creating new file at: {}", type.getSimpleName(), file.getPath()); + try { + write(path, type.getConstructor().newInstance()); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException ex) { + throw new IOException(ex); + } } - LOGGER.info("Reading settings from {}", file.getPath()); + LOGGER.info("Reading {} file from {}", type.getSimpleName(), file.getPath()); try (InputStream inputStream = new FileInputStream(file)) { - return mapper.readValue(inputStream, Settings.class); + return mapper.readValue(inputStream, type); } } @@ -38,8 +50,8 @@ public class SettingsManager { write(PATH, settings); } - public static void write(String path, Settings settings) throws IOException { + public static void write(String path, Object object) throws IOException { File file = new File(path); - writer.writeValue(file, settings); + writer.writeValue(file, object); } } diff --git a/src/main/java/com/bensherriff/siren/settings/TrackSettings.java b/src/main/java/com/bensherriff/siren/settings/Track.java similarity index 52% rename from src/main/java/com/bensherriff/siren/settings/TrackSettings.java rename to src/main/java/com/bensherriff/siren/settings/Track.java index c944937..5769dce 100644 --- a/src/main/java/com/bensherriff/siren/settings/TrackSettings.java +++ b/src/main/java/com/bensherriff/siren/settings/Track.java @@ -1,15 +1,15 @@ package com.bensherriff.siren.settings; -public class TrackSettings { - private String file; +public class Track { + private String fileName; private String[] tags; - public String getFile() { - return file; + public String getFileName() { + return fileName; } - public void setFile(String file) { - this.file = file; + public void setFileName(String fileName) { + this.fileName = fileName; } public String[] getTags() { diff --git a/src/main/java/com/bensherriff/siren/settings/Tracks.java b/src/main/java/com/bensherriff/siren/settings/Tracks.java new file mode 100644 index 0000000..74dba7e --- /dev/null +++ b/src/main/java/com/bensherriff/siren/settings/Tracks.java @@ -0,0 +1,17 @@ +package com.bensherriff.siren.settings; + +import java.util.ArrayList; +import java.util.List; + +public class Tracks { + + private List tracks = new ArrayList<>(); + + public List getTracks() { + return tracks; + } + + public void setTracks(List tracks) { + this.tracks = tracks; + } +}