v0.1.20 Error handling and added local track support

This commit is contained in:
2023-04-17 08:53:04 -04:00
parent c226095b1a
commit 312f87d91a
16 changed files with 367 additions and 156 deletions

1
.gitignore vendored
View File

@@ -4,4 +4,5 @@
**/app/
**/settings.json
**/logs/
**/audio/
.env

View File

@@ -1 +1 @@
export SIREN_VERSION=0.1.19
export SIREN_VERSION=0.1.20

View File

@@ -6,9 +6,7 @@ import com.bensherriff.siren.commands.*;
import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
import com.bensherriff.siren.ai.OpenAIManager;
import com.bensherriff.siren.settings.GuildSettings;
import com.bensherriff.siren.settings.Settings;
import com.bensherriff.siren.settings.SettingsManager;
import com.bensherriff.siren.settings.*;
import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.entities.Member;
@@ -22,6 +20,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.Executors;
@@ -79,6 +78,7 @@ public class Listener extends ListenerAdapter {
this.openAIManager = new OpenAIManager(this);
DatabaseManager.createTables();
populateAudioTable();
}
public void closeAudioConnection(long guildID) {
@@ -107,7 +107,7 @@ public class Listener extends ListenerAdapter {
long guildId = Long.parseLong(guild.getId());
AudioHandler audioHandler;
if (guild.getAudioManager().getSendingHandler() == null) {
if (guild.getAudioManager().getSendingHandler() == null || !settings.getGuildSettings().containsKey(guildId)) {
LOGGER.info("Creating Audio Handler for guild {}", guildId);
if (!settings.getGuildSettings().containsKey(guildId)) {
settings.getGuildSettings().put(guildId, new GuildSettings());
@@ -122,6 +122,32 @@ public class Listener extends ListenerAdapter {
return audioHandler;
}
private void populateAudioTable() {
DatabaseManager.clearTable("audio");
File directory = new File(SettingsManager.AUDIO_DIRECTORY);
try {
if (!directory.exists() && !directory.mkdirs()) {
LOGGER.error("Failed to create directory at {}", directory.getPath());
return;
}
Tracks tracks = SettingsManager.load(SettingsManager.TRACKS_PATH, Tracks.class);
int rows = 0;
File[] files = directory.listFiles();
if (files != null) {
for (Track track : tracks.getTracks()) {
for (File file : files) {
if (file.exists() && file.getName().equals(track.getFileName())) {
rows += DatabaseManager.storeAudio(track.getFileName(), track.getTags());
}
}
}
}
LOGGER.debug("Updated with {} local tracks", rows);
} catch (IOException ex) {
LOGGER.error("Failed to load local tracks; {}", ex.getMessage());
}
}
@Override
public void onReady(@NotNull ReadyEvent event) {
super.onReady(event);
@@ -131,7 +157,7 @@ public class Listener extends ListenerAdapter {
commands.put("volume", new VolumeCommand(this));
commands.put("pause", new PauseCommand(this));
commands.put("resume", new ResumeCommand(this));
commands.put("image", new ImageCommand(this));
// commands.put("image", new ImageCommand(this));
jda.getGuilds().forEach(guild -> executor.execute(() -> {
LOGGER.debug("Updating commands for \"{}\" <{}>", guild.getName(), guild.getId());
guild.updateCommands().addCommands(
@@ -151,7 +177,7 @@ public class Listener extends ListenerAdapter {
try {
commands.get(command).execute(event);
} catch (Exception ex) {
LOGGER.error(ex.getMessage());
LOGGER.error("Failed to execute command; {}", ex.getMessage());
event.getHook().sendMessage("An error occurred while processing your command. Please contact " + owner + ".").queue();
}
});

View File

@@ -30,7 +30,7 @@ public class Main {
try {
start();
} catch (Exception ex) {
LOGGER.error(ex.getMessage());
LOGGER.error("Caught unhandled exception; {}", ex.getMessage());
}
}

View File

@@ -23,8 +23,10 @@ import net.dv8tion.jda.api.entities.channel.concrete.ThreadChannel;
import net.dv8tion.jda.api.events.message.MessageReceivedEvent;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.postgresql.jdbc.PgArray;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.ScheduledExecutorService;
@@ -141,15 +143,37 @@ public class OpenAIManager {
// Handle System Messages
chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
chatMessages.add(createSystemMessage("I am a user named " + event.getAuthor().getName()));
chatMessages.add(createSystemMessage("My name is " + event.getAuthor().getName()));
if (event.isFromThread()) {
String query = new QueryBuilder().from("messages")
String query = new QueryBuilder("messages")
.where("guild_id = ? AND thread_id = ?")
.orderBy("timestamp DESC")
.limit(10)
.build();
List<MessageData> previousMessages = DatabaseManager.getMessages(
// Build MessageData objects from query results
List<Map<String, Object>> results = DatabaseManager.query(
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
List<MessageData> previousMessages = new ArrayList<>();
for (Map<String, Object> result : results) {
Object[] resultTopicObjects = (Object[]) ((PgArray) result.get("topics")).getArray();
Set<String> resultTopics = new HashSet<>();
for (Object object : resultTopicObjects) {
resultTopics.add((String) object);
}
MessageData messageData = new MessageData(
(long) result.get("guild_id"),
(long) result.get("thread_id"),
(long) result.get("user_id"),
(String) result.get("message_type"),
(String) result.get("message_text"),
(String) result.get("message_response"),
resultTopics,
(Timestamp) result.get("timestamp")
);
previousMessages.add(messageData);
}
Set<String> potentialTopics = new HashSet<>();
for (MessageData previousMessage : previousMessages) {
ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " +
@@ -157,11 +181,11 @@ public class OpenAIManager {
"\". You replied with \"" + previousMessage.getMessageResponse() + "\".");
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageText()));
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageResponse()));
chatMessages.add(previousChatMessage);
// chatMessages.add(previousChatMessage);
}
LOGGER.trace("Potential topics: {}", potentialTopics);
// chatMessages.add(createSystemMessage("As an AI language model, only give replies that relate to " + topics));
}
}
return ChatCompletionRequest.builder()
.model(guildSettings.getModel().getName())

View File

@@ -68,29 +68,29 @@ public class AudioHandler extends AudioEventAdapter implements AudioSendHandler
}
public void setVolume(int volume) {
LOGGER.debug("Set volume to {}", volume);
LOGGER.trace("Set volume to {}", volume);
player.setVolume(volume);
}
@Override
public void onPlayerPause(AudioPlayer player) {
LOGGER.debug("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title);
LOGGER.trace("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title);
}
@Override
public void onPlayerResume(AudioPlayer player) {
LOGGER.debug("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title);
LOGGER.trace("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title);
}
@Override
public void onTrackStart(AudioPlayer player, AudioTrack track) {
LOGGER.debug("Starting track {}", track.getInfo().title);
LOGGER.trace("Starting track {}", track.getInfo().title);
manager.getListener().getJDA().getPresence().setActivity(Activity.playing(track.getInfo().title));
}
@Override
public void onTrackEnd(AudioPlayer player, AudioTrack track, AudioTrackEndReason endReason) {
LOGGER.debug("Track ended due to {}; {} ", endReason.name(), endReason.mayStartNext);
LOGGER.trace("Track ended due to {}; {} ", endReason.name(), endReason.mayStartNext);
if (queue.isEmpty()) {
manager.getListener().closeAudioConnection(guildID);
manager.getListener().getJDA().getPresence().setActivity(Activity.playing("nothing"));

View File

@@ -1,8 +1,17 @@
package com.bensherriff.siren.audio;
import com.bensherriff.siren.Listener;
import com.sedmelluq.discord.lavaplayer.player.AudioConfiguration;
import com.sedmelluq.discord.lavaplayer.player.DefaultAudioPlayerManager;
import com.sedmelluq.discord.lavaplayer.source.AudioSourceManagers;
import com.sedmelluq.discord.lavaplayer.source.bandcamp.BandcampAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.beam.BeamAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.http.HttpAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.soundcloud.SoundCloudAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.twitch.TwitchStreamAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.vimeo.VimeoAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.youtube.YoutubeAudioSourceManager;
public class PlayerManager extends DefaultAudioPlayerManager {
@@ -17,6 +26,15 @@ public class PlayerManager extends DefaultAudioPlayerManager {
}
public void initialize() {
getConfiguration().setResamplingQuality(AudioConfiguration.ResamplingQuality.MEDIUM);
registerSourceManager(new YoutubeAudioSourceManager());
registerSourceManager(SoundCloudAudioSourceManager.createDefault());
registerSourceManager(new BandcampAudioSourceManager());
registerSourceManager(new VimeoAudioSourceManager());
registerSourceManager(new TwitchStreamAudioSourceManager());
registerSourceManager(new BeamAudioSourceManager());
registerSourceManager(new HttpAudioSourceManager());
registerSourceManager(new LocalAudioSourceManager());
AudioSourceManagers.registerRemoteSources(this);
AudioSourceManagers.registerLocalSource(this);
}

View File

@@ -7,6 +7,7 @@ import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEve
import net.dv8tion.jda.api.interactions.commands.OptionMapping;
import net.dv8tion.jda.api.interactions.commands.OptionType;
import net.dv8tion.jda.api.interactions.commands.build.Commands;
import net.dv8tion.jda.api.interactions.commands.build.OptionData;
import java.io.IOException;
import java.util.Objects;
@@ -18,7 +19,11 @@ public class ImageCommand extends Command {
slashCommandData = Commands.slash("image", "Generate an image using DALL-E")
.addOption(OptionType.STRING, "prompt", "The prompt for image generation", true)
.addOption(OptionType.INTEGER, "count", "The number of images to be generated", false)
.addOption(OptionType.INTEGER, "size", "Size of the picture, either 1 (small), 2 (medium), or 3 (large)", false);
.addOptions(new OptionData(OptionType.STRING, "type", "The size of the picture", false)
.addChoice("Small", ImageSize.SMALL.getSize())
.addChoice("Medium", ImageSize.MEDIUM.getSize())
.addChoice("Large", ImageSize.LARGE.getSize())
);
}
//TODO Store image in database

View File

@@ -1,32 +1,94 @@
package com.bensherriff.siren.commands;
import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.database.QueryBuilder;
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
import com.bensherriff.siren.Listener;
import com.bensherriff.siren.settings.SettingsManager;
import com.sedmelluq.discord.lavaplayer.container.MediaContainerDescriptor;
import com.sedmelluq.discord.lavaplayer.player.AudioLoadResultHandler;
import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioSourceManager;
import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioTrack;
import com.sedmelluq.discord.lavaplayer.tools.FriendlyException;
import com.sedmelluq.discord.lavaplayer.track.AudioPlaylist;
import com.sedmelluq.discord.lavaplayer.track.AudioTrack;
import com.sedmelluq.discord.lavaplayer.track.AudioTrackInfo;
import com.sedmelluq.discord.lavaplayer.track.BasicAudioPlaylist;
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.interactions.commands.OptionType;
import net.dv8tion.jda.api.interactions.commands.build.Commands;
import net.dv8tion.jda.api.interactions.commands.build.SubcommandData;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class PlayCommand extends Command {
public PlayCommand(Listener listener) {
super(listener);
slashCommandData = Commands.slash("play", "Play a track from a URL")
.addOption(OptionType.STRING, "url", "Track URL", true);
slashCommandData = Commands.slash("play", "Play a track")
.addSubcommands(new SubcommandData("url", "Play a track from a URL")
.addOption(OptionType.STRING, "url", "Track URL", true))
.addSubcommands(new SubcommandData("local", "Play a track from a local file")
.addOption(OptionType.STRING, "file", "Local file name", true))
.addSubcommands(new SubcommandData("tags", "Play a track based on tags")
.addOption(OptionType.STRING, "tags", "The list of tags separated by a semicolon", true)
.addOption(OptionType.BOOLEAN, "inclusive", "If true, tracks can match any tag"));
}
@Override
public void execute(SlashCommandInteractionEvent event) throws IOException {
String audioDirectoryPath = SettingsManager.AUDIO_DIRECTORY + SettingsManager.SEPARATOR;
if ("url".equals(event.getSubcommandName())) {
String trackURL = Objects.requireNonNull(event.getOption("url")).getAsString();
listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event));
} else if ("local".equals(event.getSubcommandName())) {
String fileName = Objects.requireNonNull(event.getOption("file")).getAsString();
if (!fileName.contains(".m4a")) {
fileName = fileName.concat(".m4a");
}
String trackURL = audioDirectoryPath + fileName;
listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event));
} else if ("tags".equals(event.getSubcommandName())) {
String query;
String tagsString = Objects.requireNonNull(event.getOption("tags")).getAsString();
String[] tags = tagsString.split(";,");
for (int i = 0; i < tags.length; i++) {
tags[i] = tags[i].trim();
}
if (event.getOption("inclusive") != null && Objects.requireNonNull(event.getOption("inclusive")).getAsBoolean()) {
query = new QueryBuilder("audio").where("tags && ARRAY[?]").build();
} else {
query = new QueryBuilder("audio").where("tags @> ARRAY[?]").build();
}
try {
List<AudioTrack> tracks = new ArrayList<>();
List<Map<String, Object>> results = DatabaseManager.query(query, List.of(tags));
if (results.isEmpty()) {
event.getHook().sendMessage("No tracks found with the those tags.").queue();
return;
}
for (Map<String, Object> result : results) {
// String title = (String) result.get("title");
// String author = (String) result.get("author");
// long length = (Long) result.get("length");
// String identifier = (String) result.get("identifier");
String fileName = audioDirectoryPath.concat((String) result.get("file_name"));
LOGGER.debug("{}", fileName);
listener.getPlayerManager().loadItemOrdered(event.getGuild(), fileName, new ResultHandler(event));
}
// AudioPlaylist playlist = new BasicAudioPlaylist("Playlist based on tags", tracks, null, false);
} catch (SQLException ex) {
LOGGER.error("Failed to retrieve audio tags; {}", ex.getMessage());
}
}
}
private class ResultHandler implements AudioLoadResultHandler {
@@ -52,7 +114,11 @@ public class PlayCommand extends Command {
public void trackLoaded(AudioTrack track) {
try {
playTrack(guild, userID, audioHandler, track);
event.getHook().sendMessage("Adding **" + track.getInfo().title + "** to queue...").queue();
String trackTitle = "**" + track.getInfo().title + "**";
if (trackTitle.equalsIgnoreCase("Unknown title")) {
trackTitle = "track";
}
event.getHook().sendMessage("Adding " + trackTitle + " to queue...").queue();
} catch (EmptyVoiceChannelException e) {
event.getHook().sendMessage("You must be connected to a voice channel in order to play tracks!").queue();
} catch (Exception e) {
@@ -80,7 +146,7 @@ public class PlayCommand extends Command {
@Override
public void noMatches() {
event.getHook().sendMessage("Nothing found at that URL").queue();
event.getHook().sendMessage("No track found").queue();
}
@Override

View File

@@ -41,7 +41,7 @@ public class DatabaseConnection {
try {
connection.close();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
LOGGER.error("Failed to close connection; {}", ex.getMessage());
}
});
}

View File

@@ -4,6 +4,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.sql.*;
import java.sql.Date;
import java.util.*;
import static java.util.Map.entry;
@@ -11,29 +12,74 @@ import static java.util.Map.entry;
public class DatabaseManager {
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
private static final Map<String, String> createTableQueries = Map.ofEntries(
entry("messages", "CREATE TABLE IF NOT EXISTS messages (" +
"id SERIAL PRIMARY KEY, " +
"guild_id BIGINT NOT NULL, " +
"thread_id BIGINT NOT NULL, " +
"user_id BIGINT NOT NULL, " +
"message_type VARCHAR(20) NOT NULL," +
"message_text TEXT NOT NULL," +
"message_response TEXT, " +
"topics TEXT[], " +
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
")"),
entry("embeddings", "CREATE TABLE IF NOT EXISTS embeddings (" +
"id SERIAL PRIMARY KEY, " +
"embeddings FLOAT[] NOT NULL" +
")"),
entry("message_embeddings", "CREATE TABLE IF NOT EXISTS message_embeddings (" +
"id SERIAL PRIMARY KEY, " +
"message_id INT NOT NULL, " +
"embedding_id INT NOT NULL" +
")")
private static final Map<String, List<String>> createTableQueries = Map.ofEntries(
entry("messages", List.of(
"id SERIAL PRIMARY KEY",
"guild_id BIGINT NOT NULL",
"thread_id BIGINT NOT NULL",
"user_id BIGINT NOT NULL",
"message_type VARCHAR(20) NOT NULL",
"message_text TEXT NOT NULL",
"message_response TEXT",
"topics TEXT[]",
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()")),
entry("embeddings", List.of(
"id SERIAL PRIMARY KEY",
"embeddings FLOAT[] NOT NULL")),
entry("message_embeddings", List.of(
"id SERIAL PRIMARY KEY",
"message_id INT NOT NULL",
"embedding_id INT NOT NULL")),
entry("audio", List.of(
"id SERIAL PRIMARY KEY",
"title TEXT NOT NULL",
"author TEXT NOT NULL",
"length BIGINT NOT NULL",
"identifier TEXT NOT NULL",
"file_name TEXT NOT NULL",
"tags TEXT[]"
))
);
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
public static void createTables() {
for (Map.Entry<String, List<String>> entry : createTableQueries.entrySet()) {
if (!createTable(entry.getKey(), entry.getValue())) {
LOGGER.warn("Failed to create one or more required database tables");
return;
}
}
LOGGER.debug("Databases initialized");
}
private static boolean createTable(String tableName, List<String> columns) {
if (tableExists(tableName)) {
return true;
} else {
try {
StringBuilder stringBuilder = new StringBuilder("CREATE TABLE IF NOT EXISTS ")
.append(tableName).append(" ( ");
for (int i = 0; i < columns.size(); i++) {
stringBuilder.append(columns.get(i));
if (i != columns.size() - 1) {
stringBuilder.append(", ");
}
}
stringBuilder.append(")");
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating '{}' database table if it does not exist", tableName);
Statement statement = connection.createStatement();
statement.execute(stringBuilder.toString());
return true;
} catch (SQLException ex) {
LOGGER.error("Failed to create table; {}", ex.getMessage());
return false;
}
}
}
public static int storeMessage(MessageData messageData) {
String INSERT_MESSAGE = "INSERT INTO messages (" +
"message_type, " +
"guild_id, " +
"thread_id, " +
@@ -42,98 +88,98 @@ public class DatabaseManager {
"message_response, " +
"topics) " +
"VALUES (?, ?, ?, ?, ?, ?, ?)";
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
"embeddings) " +
"VALUES (?)";
private static final String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
"message_id, " +
"embeddings_id, " +
"VALUES (?, ?)";
public static void createTables() {
for (Map.Entry<String, String> entry : createTableQueries.entrySet()) {
if (!createTable(entry.getKey())) {
LOGGER.warn("Failed to create one or more required database tables");
return;
}
}
LOGGER.debug("Databases initialized");
}
private static boolean createTable(String tableName) {
if (tableExists(tableName)) {
return true;
} else {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating '{}' database table if it does not exist", tableName);
Statement statement = connection.createStatement();
statement.execute(createTableQueries.get(tableName));
return true;
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return false;
}
}
}
public static int storeMessage(MessageData messageData) {
if (!tableExists("messages")) {
LOGGER.warn("Table 'messages' does not exist");
return -1;
}
try {
Connection connection = DatabaseConnection.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE);
preparedStatement.setString(1, messageData.getMessageType());
preparedStatement.setLong(2, messageData.getGuildId());
preparedStatement.setLong(3, messageData.getThreadId());
preparedStatement.setLong(4, messageData.getUserId());
preparedStatement.setString(5, messageData.getMessageText());
preparedStatement.setString(6, messageData.getMessageResponse());
preparedStatement.setArray(7, connection.createArrayOf("text", messageData.getTopics().toArray()));
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return -1;
}
return storeMessage("messages", INSERT_MESSAGE,
messageData.getMessageType(),
messageData.getGuildId(),
messageData.getThreadId(),
messageData.getUserId(),
messageData.getMessageText(),
messageData.getMessageResponse(),
messageData.getTopics()
);
}
public static int storeEmbedding(List<Double> data) {
if (!tableExists("embeddings")) {
LOGGER.warn("Table 'embeddings' does not exist");
return -1;
}
try {
Connection connection = DatabaseConnection.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING);
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return -1;
}
String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
"embeddings) " +
"VALUES (?)";
return storeMessage("embeddings", INSERT_EMBEDDING, data);
}
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
if (!tableExists("message_embeddings")) {
LOGGER.warn("Table 'message_embeddings' does not exist");
String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
"message_id, " +
"embeddings_id, " +
"VALUES (?, ?)";
return storeMessage("message_embeddings", INSERT_MESSAGE_EMBEDDINGS, messageId, embeddingId);
}
public static int storeAudio(String fileName, String[] tags) {
String INSERT_AUDIO = "INSERT INTO audio (" +
"file_name, " +
"tags) " +
"VALUES (?, ?)";
return storeMessage("audio", INSERT_AUDIO, fileName, tags);
}
private static int storeMessage(String tableName, String insertString, Object... params) {
if (!tableExists(tableName)) {
LOGGER.warn("Table '{}' does not exist", tableName);
return -1;
}
try {
Connection connection = DatabaseConnection.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS);
preparedStatement.setInt(1, messageId);
preparedStatement.setInt(2, embeddingId);
PreparedStatement preparedStatement = connection.prepareStatement(insertString);
int i = 1;
for (Object param : params) {
if (param == null) {
preparedStatement.setNull(i, Types.NULL);
} else if (param instanceof String) {
preparedStatement.setString(i, (String) param);
} else if (param instanceof Integer) {
preparedStatement.setInt(i, (int) param);
} else if (param instanceof Double) {
preparedStatement.setDouble(i, (double) param);
} else if (param instanceof Float) {
preparedStatement.setFloat(i, (float) param);
} else if (param instanceof Long) {
preparedStatement.setLong(i, (long) param);
} else if (param instanceof Boolean) {
preparedStatement.setBoolean(i, (boolean) param);
} else if (param instanceof Date) {
preparedStatement.setDate(i, (Date) param);
} else if (param instanceof Time) {
preparedStatement.setTime(i, (Time) param);
} else if (param instanceof Timestamp) {
preparedStatement.setTimestamp(i, (Timestamp) param);
} else if (param instanceof byte[]) {
preparedStatement.setBytes(i, (byte[]) param);
} else if (param instanceof Blob) {
preparedStatement.setBlob(i, (Blob) param);
} else if (param instanceof Clob) {
preparedStatement.setClob(i, (Clob) param);
} else if (param instanceof List && ((List<?>) param).get(0) instanceof String) {
preparedStatement.setArray(i, connection.createArrayOf("text", ((List<String>) param).toArray()));
} else if (param instanceof List && ((List<?>) param).get(0) instanceof Double) {
preparedStatement.setArray(i, connection.createArrayOf("float8", ((List<Double>) param).toArray(new Double[0])));
} else if (param instanceof Set && ((Set<?>) param).toArray()[0] instanceof String) {
preparedStatement.setArray(i, connection.createArrayOf("text", ((Set<String>) param).toArray()));
} else if (param instanceof String[]) {
preparedStatement.setArray(i, connection.createArrayOf("text", ((String[]) param)));
} else {
throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass());
}
i++;
}
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
LOGGER.error("Failed to store message; {}", ex.getMessage());
return -1;
}
}
public static List<MessageData> getMessages(String query, Object... params) throws SQLException, IllegalArgumentException {
public static List<Map<String, Object>> query(String query, Object... params) throws SQLException, IllegalArgumentException {
LOGGER.trace("Query: <{}>", query);
Connection connection = DatabaseConnection.getConnection();
PreparedStatement stmt = connection.prepareStatement(query);
@@ -159,21 +205,19 @@ public class DatabaseManager {
throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass().getName());
}
}
ResultSet resultSet = stmt.executeQuery();
List<MessageData> resultList = new ArrayList<>();
List<Map<String, Object>> resultList = new ArrayList<>();
try (ResultSet resultSet = stmt.executeQuery()) {
ResultSetMetaData metaData = resultSet.getMetaData();
int columnCount = metaData.getColumnCount();
while (resultSet.next()) {
Array array = resultSet.getArray(8);
MessageData messageData = new MessageData(
resultSet.getLong(2),
resultSet.getLong(3),
resultSet.getLong(4),
resultSet.getString(5),
resultSet.getString(6),
resultSet.getString(7),
new HashSet<>(Arrays.asList((String[]) array.getArray())),
resultSet.getTimestamp(9)
);
resultList.add(messageData);
Map<String, Object> rowMap = new HashMap<>();
for (int j = 1; j <= columnCount; j++) {
rowMap.put(metaData.getColumnName(j), resultSet.getObject(j));
}
resultList.add(rowMap);
}
} catch (SQLException ex) {
LOGGER.error("Failed to execute query; {}", ex.getMessage());
}
return resultList;
}
@@ -185,8 +229,18 @@ public class DatabaseManager {
ResultSet resultSet = statement.executeQuery("SELECT tablename FROM pg_tables WHERE tablename = '" + tableName + "'");
return resultSet.next();
} catch (SQLException ex) {
LOGGER.error("Failed to check if table exists; {}" + ex.getMessage());
LOGGER.error("Failed to check if table exists; {}", ex.getMessage());
return false;
}
}
public static void clearTable(String tableName) {
try {
Connection connection = DatabaseConnection.getConnection();
Statement statement = connection.createStatement();
statement.executeUpdate("TRUNCATE TABLE " + tableName);
} catch (SQLException ex) {
LOGGER.error("Failed to clear table; {}", ex.getMessage());
}
}
}

View File

@@ -3,18 +3,17 @@ package com.bensherriff.siren.database;
public class QueryBuilder {
private boolean distinct;
private String columnList;
private String tableName;
private final String tableName;
private String whereClause;
private String orderByClause;
private Integer limit;
public QueryBuilder select(String columnList) {
this.columnList = columnList;
return this;
public QueryBuilder(String tableName) {
this.tableName = tableName;
}
public QueryBuilder from(String tableName) {
this.tableName = tableName;
public QueryBuilder select(String columnList) {
this.columnList = columnList;
return this;
}

View File

@@ -1,18 +1,15 @@
package com.bensherriff.siren.settings;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class Settings {
private String token = "";
private String owner = "";
private String owner = "250842261221277697";
private int threadPool = 2;
private Map<Long, GuildSettings> guildSettings = new HashMap<>();
private OpenAISettings openAISettings = new OpenAISettings();
private List<TrackSettings> tracks = new ArrayList<>();
public String getToken() {
return token;
@@ -54,12 +51,4 @@ public class Settings {
public void setOpenAISettings(OpenAISettings openAISettings) {
this.openAISettings = openAISettings;
}
public List<TrackSettings> getTracks() {
return tracks;
}
public void setTracks(List<TrackSettings> tracks) {
this.tracks = tracks;
}
}

View File

@@ -10,11 +10,15 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
public class SettingsManager {
private static final Logger LOGGER = LogManager.getLogger(SettingsManager.class);
public static final String SEPARATOR = File.separator;
public static final String PATH = String.join(SEPARATOR, System.getProperty("user.dir"), "settings.json");
public static final String USER_DIRECTORY = System.getProperty("user.dir");
public static final String PATH = String.join(SEPARATOR, USER_DIRECTORY, "settings.json");
public static final String AUDIO_DIRECTORY = String.join(SEPARATOR, USER_DIRECTORY, "audio");
public static final String TRACKS_PATH = String.join(SEPARATOR, AUDIO_DIRECTORY, "tracks.json");
private static final ObjectMapper mapper = new ObjectMapper();
private static final ObjectWriter writer = mapper.writer(new DefaultPrettyPrinter());
@@ -23,14 +27,22 @@ public class SettingsManager {
}
public static Settings load(String path) throws IOException {
return load(path, Settings.class);
}
public static <T> T load(String path, Class<T> type) throws IOException {
File file = new File(path);
if (!file.exists()) {
LOGGER.warn("Settings file does not exist, creating new file at: {}", file.getPath());
write(new Settings());
LOGGER.warn("{} file does not exist, creating new file at: {}", type.getSimpleName(), file.getPath());
try {
write(path, type.getConstructor().newInstance());
} catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException ex) {
throw new IOException(ex);
}
LOGGER.info("Reading settings from {}", file.getPath());
}
LOGGER.info("Reading {} file from {}", type.getSimpleName(), file.getPath());
try (InputStream inputStream = new FileInputStream(file)) {
return mapper.readValue(inputStream, Settings.class);
return mapper.readValue(inputStream, type);
}
}
@@ -38,8 +50,8 @@ public class SettingsManager {
write(PATH, settings);
}
public static void write(String path, Settings settings) throws IOException {
public static void write(String path, Object object) throws IOException {
File file = new File(path);
writer.writeValue(file, settings);
writer.writeValue(file, object);
}
}

View File

@@ -1,15 +1,15 @@
package com.bensherriff.siren.settings;
public class TrackSettings {
private String file;
public class Track {
private String fileName;
private String[] tags;
public String getFile() {
return file;
public String getFileName() {
return fileName;
}
public void setFile(String file) {
this.file = file;
public void setFileName(String fileName) {
this.fileName = fileName;
}
public String[] getTags() {

View File

@@ -0,0 +1,17 @@
package com.bensherriff.siren.settings;
import java.util.ArrayList;
import java.util.List;
public class Tracks {
private List<Track> tracks = new ArrayList<>();
public List<Track> getTracks() {
return tracks;
}
public void setTracks(List<Track> tracks) {
this.tracks = tracks;
}
}