v0.1.20 Error handling and added local track support
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,4 +4,5 @@
|
||||
**/app/
|
||||
**/settings.json
|
||||
**/logs/
|
||||
**/audio/
|
||||
.env
|
||||
@@ -6,9 +6,7 @@ import com.bensherriff.siren.commands.*;
|
||||
import com.bensherriff.siren.database.DatabaseManager;
|
||||
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
|
||||
import com.bensherriff.siren.ai.OpenAIManager;
|
||||
import com.bensherriff.siren.settings.GuildSettings;
|
||||
import com.bensherriff.siren.settings.Settings;
|
||||
import com.bensherriff.siren.settings.SettingsManager;
|
||||
import com.bensherriff.siren.settings.*;
|
||||
import net.dv8tion.jda.api.JDA;
|
||||
import net.dv8tion.jda.api.entities.Guild;
|
||||
import net.dv8tion.jda.api.entities.Member;
|
||||
@@ -22,6 +20,7 @@ import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Executors;
|
||||
@@ -79,6 +78,7 @@ public class Listener extends ListenerAdapter {
|
||||
this.openAIManager = new OpenAIManager(this);
|
||||
|
||||
DatabaseManager.createTables();
|
||||
populateAudioTable();
|
||||
}
|
||||
|
||||
public void closeAudioConnection(long guildID) {
|
||||
@@ -107,7 +107,7 @@ public class Listener extends ListenerAdapter {
|
||||
long guildId = Long.parseLong(guild.getId());
|
||||
AudioHandler audioHandler;
|
||||
|
||||
if (guild.getAudioManager().getSendingHandler() == null) {
|
||||
if (guild.getAudioManager().getSendingHandler() == null || !settings.getGuildSettings().containsKey(guildId)) {
|
||||
LOGGER.info("Creating Audio Handler for guild {}", guildId);
|
||||
if (!settings.getGuildSettings().containsKey(guildId)) {
|
||||
settings.getGuildSettings().put(guildId, new GuildSettings());
|
||||
@@ -122,6 +122,32 @@ public class Listener extends ListenerAdapter {
|
||||
return audioHandler;
|
||||
}
|
||||
|
||||
private void populateAudioTable() {
|
||||
DatabaseManager.clearTable("audio");
|
||||
File directory = new File(SettingsManager.AUDIO_DIRECTORY);
|
||||
try {
|
||||
if (!directory.exists() && !directory.mkdirs()) {
|
||||
LOGGER.error("Failed to create directory at {}", directory.getPath());
|
||||
return;
|
||||
}
|
||||
Tracks tracks = SettingsManager.load(SettingsManager.TRACKS_PATH, Tracks.class);
|
||||
int rows = 0;
|
||||
File[] files = directory.listFiles();
|
||||
if (files != null) {
|
||||
for (Track track : tracks.getTracks()) {
|
||||
for (File file : files) {
|
||||
if (file.exists() && file.getName().equals(track.getFileName())) {
|
||||
rows += DatabaseManager.storeAudio(track.getFileName(), track.getTags());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
LOGGER.debug("Updated with {} local tracks", rows);
|
||||
} catch (IOException ex) {
|
||||
LOGGER.error("Failed to load local tracks; {}", ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onReady(@NotNull ReadyEvent event) {
|
||||
super.onReady(event);
|
||||
@@ -131,7 +157,7 @@ public class Listener extends ListenerAdapter {
|
||||
commands.put("volume", new VolumeCommand(this));
|
||||
commands.put("pause", new PauseCommand(this));
|
||||
commands.put("resume", new ResumeCommand(this));
|
||||
commands.put("image", new ImageCommand(this));
|
||||
// commands.put("image", new ImageCommand(this));
|
||||
jda.getGuilds().forEach(guild -> executor.execute(() -> {
|
||||
LOGGER.debug("Updating commands for \"{}\" <{}>", guild.getName(), guild.getId());
|
||||
guild.updateCommands().addCommands(
|
||||
@@ -151,7 +177,7 @@ public class Listener extends ListenerAdapter {
|
||||
try {
|
||||
commands.get(command).execute(event);
|
||||
} catch (Exception ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
LOGGER.error("Failed to execute command; {}", ex.getMessage());
|
||||
event.getHook().sendMessage("An error occurred while processing your command. Please contact " + owner + ".").queue();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -30,7 +30,7 @@ public class Main {
|
||||
try {
|
||||
start();
|
||||
} catch (Exception ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
LOGGER.error("Caught unhandled exception; {}", ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,8 +23,10 @@ import net.dv8tion.jda.api.entities.channel.concrete.ThreadChannel;
|
||||
import net.dv8tion.jda.api.events.message.MessageReceivedEvent;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.postgresql.jdbc.PgArray;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Timestamp;
|
||||
import java.time.Duration;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
@@ -141,15 +143,37 @@ public class OpenAIManager {
|
||||
|
||||
// Handle System Messages
|
||||
chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
|
||||
chatMessages.add(createSystemMessage("I am a user named " + event.getAuthor().getName()));
|
||||
chatMessages.add(createSystemMessage("My name is " + event.getAuthor().getName()));
|
||||
if (event.isFromThread()) {
|
||||
String query = new QueryBuilder().from("messages")
|
||||
String query = new QueryBuilder("messages")
|
||||
.where("guild_id = ? AND thread_id = ?")
|
||||
.orderBy("timestamp DESC")
|
||||
.limit(10)
|
||||
.build();
|
||||
List<MessageData> previousMessages = DatabaseManager.getMessages(
|
||||
|
||||
// Build MessageData objects from query results
|
||||
List<Map<String, Object>> results = DatabaseManager.query(
|
||||
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
|
||||
List<MessageData> previousMessages = new ArrayList<>();
|
||||
for (Map<String, Object> result : results) {
|
||||
Object[] resultTopicObjects = (Object[]) ((PgArray) result.get("topics")).getArray();
|
||||
Set<String> resultTopics = new HashSet<>();
|
||||
for (Object object : resultTopicObjects) {
|
||||
resultTopics.add((String) object);
|
||||
}
|
||||
MessageData messageData = new MessageData(
|
||||
(long) result.get("guild_id"),
|
||||
(long) result.get("thread_id"),
|
||||
(long) result.get("user_id"),
|
||||
(String) result.get("message_type"),
|
||||
(String) result.get("message_text"),
|
||||
(String) result.get("message_response"),
|
||||
resultTopics,
|
||||
(Timestamp) result.get("timestamp")
|
||||
);
|
||||
previousMessages.add(messageData);
|
||||
}
|
||||
|
||||
Set<String> potentialTopics = new HashSet<>();
|
||||
for (MessageData previousMessage : previousMessages) {
|
||||
ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " +
|
||||
@@ -157,10 +181,10 @@ public class OpenAIManager {
|
||||
"\". You replied with \"" + previousMessage.getMessageResponse() + "\".");
|
||||
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageText()));
|
||||
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageResponse()));
|
||||
chatMessages.add(previousChatMessage);
|
||||
LOGGER.trace("Potential topics: {}", potentialTopics);
|
||||
// chatMessages.add(createSystemMessage("As an AI language model, only give replies that relate to " + topics));
|
||||
// chatMessages.add(previousChatMessage);
|
||||
}
|
||||
LOGGER.trace("Potential topics: {}", potentialTopics);
|
||||
// chatMessages.add(createSystemMessage("As an AI language model, only give replies that relate to " + topics));
|
||||
}
|
||||
|
||||
return ChatCompletionRequest.builder()
|
||||
|
||||
@@ -68,29 +68,29 @@ public class AudioHandler extends AudioEventAdapter implements AudioSendHandler
|
||||
}
|
||||
|
||||
public void setVolume(int volume) {
|
||||
LOGGER.debug("Set volume to {}", volume);
|
||||
LOGGER.trace("Set volume to {}", volume);
|
||||
player.setVolume(volume);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onPlayerPause(AudioPlayer player) {
|
||||
LOGGER.debug("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title);
|
||||
LOGGER.trace("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onPlayerResume(AudioPlayer player) {
|
||||
LOGGER.debug("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title);
|
||||
LOGGER.trace("isPaused: {} for {}", player.isPaused(), player.getPlayingTrack().getInfo().title);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onTrackStart(AudioPlayer player, AudioTrack track) {
|
||||
LOGGER.debug("Starting track {}", track.getInfo().title);
|
||||
LOGGER.trace("Starting track {}", track.getInfo().title);
|
||||
manager.getListener().getJDA().getPresence().setActivity(Activity.playing(track.getInfo().title));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onTrackEnd(AudioPlayer player, AudioTrack track, AudioTrackEndReason endReason) {
|
||||
LOGGER.debug("Track ended due to {}; {} ", endReason.name(), endReason.mayStartNext);
|
||||
LOGGER.trace("Track ended due to {}; {} ", endReason.name(), endReason.mayStartNext);
|
||||
if (queue.isEmpty()) {
|
||||
manager.getListener().closeAudioConnection(guildID);
|
||||
manager.getListener().getJDA().getPresence().setActivity(Activity.playing("nothing"));
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
package com.bensherriff.siren.audio;
|
||||
|
||||
import com.bensherriff.siren.Listener;
|
||||
import com.sedmelluq.discord.lavaplayer.player.AudioConfiguration;
|
||||
import com.sedmelluq.discord.lavaplayer.player.DefaultAudioPlayerManager;
|
||||
import com.sedmelluq.discord.lavaplayer.source.AudioSourceManagers;
|
||||
import com.sedmelluq.discord.lavaplayer.source.bandcamp.BandcampAudioSourceManager;
|
||||
import com.sedmelluq.discord.lavaplayer.source.beam.BeamAudioSourceManager;
|
||||
import com.sedmelluq.discord.lavaplayer.source.http.HttpAudioSourceManager;
|
||||
import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioSourceManager;
|
||||
import com.sedmelluq.discord.lavaplayer.source.soundcloud.SoundCloudAudioSourceManager;
|
||||
import com.sedmelluq.discord.lavaplayer.source.twitch.TwitchStreamAudioSourceManager;
|
||||
import com.sedmelluq.discord.lavaplayer.source.vimeo.VimeoAudioSourceManager;
|
||||
import com.sedmelluq.discord.lavaplayer.source.youtube.YoutubeAudioSourceManager;
|
||||
|
||||
public class PlayerManager extends DefaultAudioPlayerManager {
|
||||
|
||||
@@ -17,6 +26,15 @@ public class PlayerManager extends DefaultAudioPlayerManager {
|
||||
}
|
||||
|
||||
public void initialize() {
|
||||
getConfiguration().setResamplingQuality(AudioConfiguration.ResamplingQuality.MEDIUM);
|
||||
registerSourceManager(new YoutubeAudioSourceManager());
|
||||
registerSourceManager(SoundCloudAudioSourceManager.createDefault());
|
||||
registerSourceManager(new BandcampAudioSourceManager());
|
||||
registerSourceManager(new VimeoAudioSourceManager());
|
||||
registerSourceManager(new TwitchStreamAudioSourceManager());
|
||||
registerSourceManager(new BeamAudioSourceManager());
|
||||
registerSourceManager(new HttpAudioSourceManager());
|
||||
registerSourceManager(new LocalAudioSourceManager());
|
||||
AudioSourceManagers.registerRemoteSources(this);
|
||||
AudioSourceManagers.registerLocalSource(this);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEve
|
||||
import net.dv8tion.jda.api.interactions.commands.OptionMapping;
|
||||
import net.dv8tion.jda.api.interactions.commands.OptionType;
|
||||
import net.dv8tion.jda.api.interactions.commands.build.Commands;
|
||||
import net.dv8tion.jda.api.interactions.commands.build.OptionData;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
@@ -18,7 +19,11 @@ public class ImageCommand extends Command {
|
||||
slashCommandData = Commands.slash("image", "Generate an image using DALL-E")
|
||||
.addOption(OptionType.STRING, "prompt", "The prompt for image generation", true)
|
||||
.addOption(OptionType.INTEGER, "count", "The number of images to be generated", false)
|
||||
.addOption(OptionType.INTEGER, "size", "Size of the picture, either 1 (small), 2 (medium), or 3 (large)", false);
|
||||
.addOptions(new OptionData(OptionType.STRING, "type", "The size of the picture", false)
|
||||
.addChoice("Small", ImageSize.SMALL.getSize())
|
||||
.addChoice("Medium", ImageSize.MEDIUM.getSize())
|
||||
.addChoice("Large", ImageSize.LARGE.getSize())
|
||||
);
|
||||
}
|
||||
|
||||
//TODO Store image in database
|
||||
|
||||
@@ -1,32 +1,94 @@
|
||||
package com.bensherriff.siren.commands;
|
||||
|
||||
import com.bensherriff.siren.audio.AudioHandler;
|
||||
import com.bensherriff.siren.database.DatabaseManager;
|
||||
import com.bensherriff.siren.database.QueryBuilder;
|
||||
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
|
||||
import com.bensherriff.siren.Listener;
|
||||
import com.bensherriff.siren.settings.SettingsManager;
|
||||
import com.sedmelluq.discord.lavaplayer.container.MediaContainerDescriptor;
|
||||
import com.sedmelluq.discord.lavaplayer.player.AudioLoadResultHandler;
|
||||
import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioSourceManager;
|
||||
import com.sedmelluq.discord.lavaplayer.source.local.LocalAudioTrack;
|
||||
import com.sedmelluq.discord.lavaplayer.tools.FriendlyException;
|
||||
import com.sedmelluq.discord.lavaplayer.track.AudioPlaylist;
|
||||
import com.sedmelluq.discord.lavaplayer.track.AudioTrack;
|
||||
import com.sedmelluq.discord.lavaplayer.track.AudioTrackInfo;
|
||||
import com.sedmelluq.discord.lavaplayer.track.BasicAudioPlaylist;
|
||||
import net.dv8tion.jda.api.entities.Guild;
|
||||
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
|
||||
import net.dv8tion.jda.api.interactions.commands.OptionType;
|
||||
import net.dv8tion.jda.api.interactions.commands.build.Commands;
|
||||
import net.dv8tion.jda.api.interactions.commands.build.SubcommandData;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.sql.SQLException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class PlayCommand extends Command {
|
||||
|
||||
public PlayCommand(Listener listener) {
|
||||
super(listener);
|
||||
slashCommandData = Commands.slash("play", "Play a track from a URL")
|
||||
.addOption(OptionType.STRING, "url", "Track URL", true);
|
||||
slashCommandData = Commands.slash("play", "Play a track")
|
||||
.addSubcommands(new SubcommandData("url", "Play a track from a URL")
|
||||
.addOption(OptionType.STRING, "url", "Track URL", true))
|
||||
.addSubcommands(new SubcommandData("local", "Play a track from a local file")
|
||||
.addOption(OptionType.STRING, "file", "Local file name", true))
|
||||
.addSubcommands(new SubcommandData("tags", "Play a track based on tags")
|
||||
.addOption(OptionType.STRING, "tags", "The list of tags separated by a semicolon", true)
|
||||
.addOption(OptionType.BOOLEAN, "inclusive", "If true, tracks can match any tag"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(SlashCommandInteractionEvent event) throws IOException {
|
||||
String trackURL = Objects.requireNonNull(event.getOption("url")).getAsString();
|
||||
listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event));
|
||||
String audioDirectoryPath = SettingsManager.AUDIO_DIRECTORY + SettingsManager.SEPARATOR;
|
||||
if ("url".equals(event.getSubcommandName())) {
|
||||
String trackURL = Objects.requireNonNull(event.getOption("url")).getAsString();
|
||||
listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event));
|
||||
} else if ("local".equals(event.getSubcommandName())) {
|
||||
String fileName = Objects.requireNonNull(event.getOption("file")).getAsString();
|
||||
if (!fileName.contains(".m4a")) {
|
||||
fileName = fileName.concat(".m4a");
|
||||
}
|
||||
String trackURL = audioDirectoryPath + fileName;
|
||||
listener.getPlayerManager().loadItemOrdered(event.getGuild(), trackURL, new ResultHandler(event));
|
||||
} else if ("tags".equals(event.getSubcommandName())) {
|
||||
String query;
|
||||
String tagsString = Objects.requireNonNull(event.getOption("tags")).getAsString();
|
||||
String[] tags = tagsString.split(";,");
|
||||
for (int i = 0; i < tags.length; i++) {
|
||||
tags[i] = tags[i].trim();
|
||||
}
|
||||
if (event.getOption("inclusive") != null && Objects.requireNonNull(event.getOption("inclusive")).getAsBoolean()) {
|
||||
query = new QueryBuilder("audio").where("tags && ARRAY[?]").build();
|
||||
} else {
|
||||
query = new QueryBuilder("audio").where("tags @> ARRAY[?]").build();
|
||||
}
|
||||
try {
|
||||
List<AudioTrack> tracks = new ArrayList<>();
|
||||
List<Map<String, Object>> results = DatabaseManager.query(query, List.of(tags));
|
||||
if (results.isEmpty()) {
|
||||
event.getHook().sendMessage("No tracks found with the those tags.").queue();
|
||||
return;
|
||||
}
|
||||
for (Map<String, Object> result : results) {
|
||||
// String title = (String) result.get("title");
|
||||
// String author = (String) result.get("author");
|
||||
// long length = (Long) result.get("length");
|
||||
// String identifier = (String) result.get("identifier");
|
||||
String fileName = audioDirectoryPath.concat((String) result.get("file_name"));
|
||||
LOGGER.debug("{}", fileName);
|
||||
listener.getPlayerManager().loadItemOrdered(event.getGuild(), fileName, new ResultHandler(event));
|
||||
}
|
||||
// AudioPlaylist playlist = new BasicAudioPlaylist("Playlist based on tags", tracks, null, false);
|
||||
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error("Failed to retrieve audio tags; {}", ex.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class ResultHandler implements AudioLoadResultHandler {
|
||||
@@ -52,7 +114,11 @@ public class PlayCommand extends Command {
|
||||
public void trackLoaded(AudioTrack track) {
|
||||
try {
|
||||
playTrack(guild, userID, audioHandler, track);
|
||||
event.getHook().sendMessage("Adding **" + track.getInfo().title + "** to queue...").queue();
|
||||
String trackTitle = "**" + track.getInfo().title + "**";
|
||||
if (trackTitle.equalsIgnoreCase("Unknown title")) {
|
||||
trackTitle = "track";
|
||||
}
|
||||
event.getHook().sendMessage("Adding " + trackTitle + " to queue...").queue();
|
||||
} catch (EmptyVoiceChannelException e) {
|
||||
event.getHook().sendMessage("You must be connected to a voice channel in order to play tracks!").queue();
|
||||
} catch (Exception e) {
|
||||
@@ -80,7 +146,7 @@ public class PlayCommand extends Command {
|
||||
|
||||
@Override
|
||||
public void noMatches() {
|
||||
event.getHook().sendMessage("Nothing found at that URL").queue();
|
||||
event.getHook().sendMessage("No track found").queue();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -41,7 +41,7 @@ public class DatabaseConnection {
|
||||
try {
|
||||
connection.close();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
LOGGER.error("Failed to close connection; {}", ex.getMessage());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.sql.*;
|
||||
import java.sql.Date;
|
||||
import java.util.*;
|
||||
|
||||
import static java.util.Map.entry;
|
||||
@@ -11,49 +12,38 @@ import static java.util.Map.entry;
|
||||
public class DatabaseManager {
|
||||
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
|
||||
|
||||
private static final Map<String, String> createTableQueries = Map.ofEntries(
|
||||
entry("messages", "CREATE TABLE IF NOT EXISTS messages (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"guild_id BIGINT NOT NULL, " +
|
||||
"thread_id BIGINT NOT NULL, " +
|
||||
"user_id BIGINT NOT NULL, " +
|
||||
"message_type VARCHAR(20) NOT NULL," +
|
||||
"message_text TEXT NOT NULL," +
|
||||
"message_response TEXT, " +
|
||||
"topics TEXT[], " +
|
||||
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
|
||||
")"),
|
||||
entry("embeddings", "CREATE TABLE IF NOT EXISTS embeddings (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"embeddings FLOAT[] NOT NULL" +
|
||||
")"),
|
||||
entry("message_embeddings", "CREATE TABLE IF NOT EXISTS message_embeddings (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"message_id INT NOT NULL, " +
|
||||
"embedding_id INT NOT NULL" +
|
||||
")")
|
||||
private static final Map<String, List<String>> createTableQueries = Map.ofEntries(
|
||||
entry("messages", List.of(
|
||||
"id SERIAL PRIMARY KEY",
|
||||
"guild_id BIGINT NOT NULL",
|
||||
"thread_id BIGINT NOT NULL",
|
||||
"user_id BIGINT NOT NULL",
|
||||
"message_type VARCHAR(20) NOT NULL",
|
||||
"message_text TEXT NOT NULL",
|
||||
"message_response TEXT",
|
||||
"topics TEXT[]",
|
||||
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()")),
|
||||
entry("embeddings", List.of(
|
||||
"id SERIAL PRIMARY KEY",
|
||||
"embeddings FLOAT[] NOT NULL")),
|
||||
entry("message_embeddings", List.of(
|
||||
"id SERIAL PRIMARY KEY",
|
||||
"message_id INT NOT NULL",
|
||||
"embedding_id INT NOT NULL")),
|
||||
entry("audio", List.of(
|
||||
"id SERIAL PRIMARY KEY",
|
||||
"title TEXT NOT NULL",
|
||||
"author TEXT NOT NULL",
|
||||
"length BIGINT NOT NULL",
|
||||
"identifier TEXT NOT NULL",
|
||||
"file_name TEXT NOT NULL",
|
||||
"tags TEXT[]"
|
||||
))
|
||||
);
|
||||
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
|
||||
"message_type, " +
|
||||
"guild_id, " +
|
||||
"thread_id, " +
|
||||
"user_id, " +
|
||||
"message_text, " +
|
||||
"message_response, " +
|
||||
"topics) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?)";
|
||||
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
|
||||
"embeddings) " +
|
||||
"VALUES (?)";
|
||||
|
||||
private static final String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
|
||||
"message_id, " +
|
||||
"embeddings_id, " +
|
||||
"VALUES (?, ?)";
|
||||
|
||||
public static void createTables() {
|
||||
for (Map.Entry<String, String> entry : createTableQueries.entrySet()) {
|
||||
if (!createTable(entry.getKey())) {
|
||||
for (Map.Entry<String, List<String>> entry : createTableQueries.entrySet()) {
|
||||
if (!createTable(entry.getKey(), entry.getValue())) {
|
||||
LOGGER.warn("Failed to create one or more required database tables");
|
||||
return;
|
||||
}
|
||||
@@ -61,79 +51,135 @@ public class DatabaseManager {
|
||||
LOGGER.debug("Databases initialized");
|
||||
}
|
||||
|
||||
private static boolean createTable(String tableName) {
|
||||
private static boolean createTable(String tableName, List<String> columns) {
|
||||
if (tableExists(tableName)) {
|
||||
return true;
|
||||
} else {
|
||||
try {
|
||||
StringBuilder stringBuilder = new StringBuilder("CREATE TABLE IF NOT EXISTS ")
|
||||
.append(tableName).append(" ( ");
|
||||
for (int i = 0; i < columns.size(); i++) {
|
||||
stringBuilder.append(columns.get(i));
|
||||
if (i != columns.size() - 1) {
|
||||
stringBuilder.append(", ");
|
||||
}
|
||||
}
|
||||
stringBuilder.append(")");
|
||||
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
LOGGER.debug("Creating '{}' database table if it does not exist", tableName);
|
||||
Statement statement = connection.createStatement();
|
||||
statement.execute(createTableQueries.get(tableName));
|
||||
statement.execute(stringBuilder.toString());
|
||||
return true;
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
LOGGER.error("Failed to create table; {}", ex.getMessage());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static int storeMessage(MessageData messageData) {
|
||||
if (!tableExists("messages")) {
|
||||
LOGGER.warn("Table 'messages' does not exist");
|
||||
return -1;
|
||||
}
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE);
|
||||
preparedStatement.setString(1, messageData.getMessageType());
|
||||
preparedStatement.setLong(2, messageData.getGuildId());
|
||||
preparedStatement.setLong(3, messageData.getThreadId());
|
||||
preparedStatement.setLong(4, messageData.getUserId());
|
||||
preparedStatement.setString(5, messageData.getMessageText());
|
||||
preparedStatement.setString(6, messageData.getMessageResponse());
|
||||
preparedStatement.setArray(7, connection.createArrayOf("text", messageData.getTopics().toArray()));
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
return -1;
|
||||
}
|
||||
String INSERT_MESSAGE = "INSERT INTO messages (" +
|
||||
"message_type, " +
|
||||
"guild_id, " +
|
||||
"thread_id, " +
|
||||
"user_id, " +
|
||||
"message_text, " +
|
||||
"message_response, " +
|
||||
"topics) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?)";
|
||||
return storeMessage("messages", INSERT_MESSAGE,
|
||||
messageData.getMessageType(),
|
||||
messageData.getGuildId(),
|
||||
messageData.getThreadId(),
|
||||
messageData.getUserId(),
|
||||
messageData.getMessageText(),
|
||||
messageData.getMessageResponse(),
|
||||
messageData.getTopics()
|
||||
);
|
||||
}
|
||||
|
||||
public static int storeEmbedding(List<Double> data) {
|
||||
if (!tableExists("embeddings")) {
|
||||
LOGGER.warn("Table 'embeddings' does not exist");
|
||||
return -1;
|
||||
}
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING);
|
||||
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
return -1;
|
||||
}
|
||||
String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
|
||||
"embeddings) " +
|
||||
"VALUES (?)";
|
||||
return storeMessage("embeddings", INSERT_EMBEDDING, data);
|
||||
}
|
||||
|
||||
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
|
||||
if (!tableExists("message_embeddings")) {
|
||||
LOGGER.warn("Table 'message_embeddings' does not exist");
|
||||
String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
|
||||
"message_id, " +
|
||||
"embeddings_id, " +
|
||||
"VALUES (?, ?)";
|
||||
return storeMessage("message_embeddings", INSERT_MESSAGE_EMBEDDINGS, messageId, embeddingId);
|
||||
}
|
||||
|
||||
public static int storeAudio(String fileName, String[] tags) {
|
||||
String INSERT_AUDIO = "INSERT INTO audio (" +
|
||||
"file_name, " +
|
||||
"tags) " +
|
||||
"VALUES (?, ?)";
|
||||
return storeMessage("audio", INSERT_AUDIO, fileName, tags);
|
||||
}
|
||||
|
||||
private static int storeMessage(String tableName, String insertString, Object... params) {
|
||||
if (!tableExists(tableName)) {
|
||||
LOGGER.warn("Table '{}' does not exist", tableName);
|
||||
return -1;
|
||||
}
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS);
|
||||
preparedStatement.setInt(1, messageId);
|
||||
preparedStatement.setInt(2, embeddingId);
|
||||
PreparedStatement preparedStatement = connection.prepareStatement(insertString);
|
||||
int i = 1;
|
||||
for (Object param : params) {
|
||||
if (param == null) {
|
||||
preparedStatement.setNull(i, Types.NULL);
|
||||
} else if (param instanceof String) {
|
||||
preparedStatement.setString(i, (String) param);
|
||||
} else if (param instanceof Integer) {
|
||||
preparedStatement.setInt(i, (int) param);
|
||||
} else if (param instanceof Double) {
|
||||
preparedStatement.setDouble(i, (double) param);
|
||||
} else if (param instanceof Float) {
|
||||
preparedStatement.setFloat(i, (float) param);
|
||||
} else if (param instanceof Long) {
|
||||
preparedStatement.setLong(i, (long) param);
|
||||
} else if (param instanceof Boolean) {
|
||||
preparedStatement.setBoolean(i, (boolean) param);
|
||||
} else if (param instanceof Date) {
|
||||
preparedStatement.setDate(i, (Date) param);
|
||||
} else if (param instanceof Time) {
|
||||
preparedStatement.setTime(i, (Time) param);
|
||||
} else if (param instanceof Timestamp) {
|
||||
preparedStatement.setTimestamp(i, (Timestamp) param);
|
||||
} else if (param instanceof byte[]) {
|
||||
preparedStatement.setBytes(i, (byte[]) param);
|
||||
} else if (param instanceof Blob) {
|
||||
preparedStatement.setBlob(i, (Blob) param);
|
||||
} else if (param instanceof Clob) {
|
||||
preparedStatement.setClob(i, (Clob) param);
|
||||
} else if (param instanceof List && ((List<?>) param).get(0) instanceof String) {
|
||||
preparedStatement.setArray(i, connection.createArrayOf("text", ((List<String>) param).toArray()));
|
||||
} else if (param instanceof List && ((List<?>) param).get(0) instanceof Double) {
|
||||
preparedStatement.setArray(i, connection.createArrayOf("float8", ((List<Double>) param).toArray(new Double[0])));
|
||||
} else if (param instanceof Set && ((Set<?>) param).toArray()[0] instanceof String) {
|
||||
preparedStatement.setArray(i, connection.createArrayOf("text", ((Set<String>) param).toArray()));
|
||||
} else if (param instanceof String[]) {
|
||||
preparedStatement.setArray(i, connection.createArrayOf("text", ((String[]) param)));
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass());
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
LOGGER.error("Failed to store message; {}", ex.getMessage());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
public static List<MessageData> getMessages(String query, Object... params) throws SQLException, IllegalArgumentException {
|
||||
public static List<Map<String, Object>> query(String query, Object... params) throws SQLException, IllegalArgumentException {
|
||||
LOGGER.trace("Query: <{}>", query);
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
PreparedStatement stmt = connection.prepareStatement(query);
|
||||
@@ -159,21 +205,19 @@ public class DatabaseManager {
|
||||
throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass().getName());
|
||||
}
|
||||
}
|
||||
ResultSet resultSet = stmt.executeQuery();
|
||||
List<MessageData> resultList = new ArrayList<>();
|
||||
while (resultSet.next()) {
|
||||
Array array = resultSet.getArray(8);
|
||||
MessageData messageData = new MessageData(
|
||||
resultSet.getLong(2),
|
||||
resultSet.getLong(3),
|
||||
resultSet.getLong(4),
|
||||
resultSet.getString(5),
|
||||
resultSet.getString(6),
|
||||
resultSet.getString(7),
|
||||
new HashSet<>(Arrays.asList((String[]) array.getArray())),
|
||||
resultSet.getTimestamp(9)
|
||||
);
|
||||
resultList.add(messageData);
|
||||
List<Map<String, Object>> resultList = new ArrayList<>();
|
||||
try (ResultSet resultSet = stmt.executeQuery()) {
|
||||
ResultSetMetaData metaData = resultSet.getMetaData();
|
||||
int columnCount = metaData.getColumnCount();
|
||||
while (resultSet.next()) {
|
||||
Map<String, Object> rowMap = new HashMap<>();
|
||||
for (int j = 1; j <= columnCount; j++) {
|
||||
rowMap.put(metaData.getColumnName(j), resultSet.getObject(j));
|
||||
}
|
||||
resultList.add(rowMap);
|
||||
}
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error("Failed to execute query; {}", ex.getMessage());
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
@@ -185,8 +229,18 @@ public class DatabaseManager {
|
||||
ResultSet resultSet = statement.executeQuery("SELECT tablename FROM pg_tables WHERE tablename = '" + tableName + "'");
|
||||
return resultSet.next();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error("Failed to check if table exists; {}" + ex.getMessage());
|
||||
LOGGER.error("Failed to check if table exists; {}", ex.getMessage());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public static void clearTable(String tableName) {
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
Statement statement = connection.createStatement();
|
||||
statement.executeUpdate("TRUNCATE TABLE " + tableName);
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error("Failed to clear table; {}", ex.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,18 +3,17 @@ package com.bensherriff.siren.database;
|
||||
public class QueryBuilder {
|
||||
private boolean distinct;
|
||||
private String columnList;
|
||||
private String tableName;
|
||||
private final String tableName;
|
||||
private String whereClause;
|
||||
private String orderByClause;
|
||||
private Integer limit;
|
||||
|
||||
public QueryBuilder select(String columnList) {
|
||||
this.columnList = columnList;
|
||||
return this;
|
||||
public QueryBuilder(String tableName) {
|
||||
this.tableName = tableName;
|
||||
}
|
||||
|
||||
public QueryBuilder from(String tableName) {
|
||||
this.tableName = tableName;
|
||||
public QueryBuilder select(String columnList) {
|
||||
this.columnList = columnList;
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
package com.bensherriff.siren.settings;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class Settings {
|
||||
|
||||
private String token = "";
|
||||
private String owner = "";
|
||||
private String owner = "250842261221277697";
|
||||
private int threadPool = 2;
|
||||
private Map<Long, GuildSettings> guildSettings = new HashMap<>();
|
||||
private OpenAISettings openAISettings = new OpenAISettings();
|
||||
private List<TrackSettings> tracks = new ArrayList<>();
|
||||
|
||||
public String getToken() {
|
||||
return token;
|
||||
@@ -54,12 +51,4 @@ public class Settings {
|
||||
public void setOpenAISettings(OpenAISettings openAISettings) {
|
||||
this.openAISettings = openAISettings;
|
||||
}
|
||||
|
||||
public List<TrackSettings> getTracks() {
|
||||
return tracks;
|
||||
}
|
||||
|
||||
public void setTracks(List<TrackSettings> tracks) {
|
||||
this.tracks = tracks;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,11 +10,15 @@ import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
|
||||
public class SettingsManager {
|
||||
private static final Logger LOGGER = LogManager.getLogger(SettingsManager.class);
|
||||
public static final String SEPARATOR = File.separator;
|
||||
public static final String PATH = String.join(SEPARATOR, System.getProperty("user.dir"), "settings.json");
|
||||
public static final String USER_DIRECTORY = System.getProperty("user.dir");
|
||||
public static final String PATH = String.join(SEPARATOR, USER_DIRECTORY, "settings.json");
|
||||
public static final String AUDIO_DIRECTORY = String.join(SEPARATOR, USER_DIRECTORY, "audio");
|
||||
public static final String TRACKS_PATH = String.join(SEPARATOR, AUDIO_DIRECTORY, "tracks.json");
|
||||
private static final ObjectMapper mapper = new ObjectMapper();
|
||||
private static final ObjectWriter writer = mapper.writer(new DefaultPrettyPrinter());
|
||||
|
||||
@@ -23,14 +27,22 @@ public class SettingsManager {
|
||||
}
|
||||
|
||||
public static Settings load(String path) throws IOException {
|
||||
return load(path, Settings.class);
|
||||
}
|
||||
|
||||
public static <T> T load(String path, Class<T> type) throws IOException {
|
||||
File file = new File(path);
|
||||
if (!file.exists()) {
|
||||
LOGGER.warn("Settings file does not exist, creating new file at: {}", file.getPath());
|
||||
write(new Settings());
|
||||
LOGGER.warn("{} file does not exist, creating new file at: {}", type.getSimpleName(), file.getPath());
|
||||
try {
|
||||
write(path, type.getConstructor().newInstance());
|
||||
} catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException ex) {
|
||||
throw new IOException(ex);
|
||||
}
|
||||
}
|
||||
LOGGER.info("Reading settings from {}", file.getPath());
|
||||
LOGGER.info("Reading {} file from {}", type.getSimpleName(), file.getPath());
|
||||
try (InputStream inputStream = new FileInputStream(file)) {
|
||||
return mapper.readValue(inputStream, Settings.class);
|
||||
return mapper.readValue(inputStream, type);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,8 +50,8 @@ public class SettingsManager {
|
||||
write(PATH, settings);
|
||||
}
|
||||
|
||||
public static void write(String path, Settings settings) throws IOException {
|
||||
public static void write(String path, Object object) throws IOException {
|
||||
File file = new File(path);
|
||||
writer.writeValue(file, settings);
|
||||
writer.writeValue(file, object);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package com.bensherriff.siren.settings;
|
||||
|
||||
public class TrackSettings {
|
||||
private String file;
|
||||
public class Track {
|
||||
private String fileName;
|
||||
private String[] tags;
|
||||
|
||||
public String getFile() {
|
||||
return file;
|
||||
public String getFileName() {
|
||||
return fileName;
|
||||
}
|
||||
|
||||
public void setFile(String file) {
|
||||
this.file = file;
|
||||
public void setFileName(String fileName) {
|
||||
this.fileName = fileName;
|
||||
}
|
||||
|
||||
public String[] getTags() {
|
||||
17
src/main/java/com/bensherriff/siren/settings/Tracks.java
Normal file
17
src/main/java/com/bensherriff/siren/settings/Tracks.java
Normal file
@@ -0,0 +1,17 @@
|
||||
package com.bensherriff.siren.settings;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class Tracks {
|
||||
|
||||
private List<Track> tracks = new ArrayList<>();
|
||||
|
||||
public List<Track> getTracks() {
|
||||
return tracks;
|
||||
}
|
||||
|
||||
public void setTracks(List<Track> tracks) {
|
||||
this.tracks = tracks;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user