diff --git a/.env b/.env deleted file mode 100644 index d249bd5..0000000 --- a/.env +++ /dev/null @@ -1 +0,0 @@ -export SIREN_VERSION=0.1.14 \ No newline at end of file diff --git a/.env.TEMPLATE b/.env.TEMPLATE new file mode 100644 index 0000000..79fec62 --- /dev/null +++ b/.env.TEMPLATE @@ -0,0 +1,3 @@ +export POSTGRES_USER= +export POSTGRES_PASSWORD= +export POSTGRES_DB= \ No newline at end of file diff --git a/.gitignore b/.gitignore index f94094e..7fa04e9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ .idea/ **/target/ **/data/ +**/app/ **/settings.json -**/logs/ \ No newline at end of file +**/logs/ +.env \ No newline at end of file diff --git a/.version b/.version new file mode 100644 index 0000000..4a0b6e2 --- /dev/null +++ b/.version @@ -0,0 +1 @@ +export SIREN_VERSION=0.1.15 \ No newline at end of file diff --git a/Makefile b/Makefile index ce6fe43..ab6dbf5 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,9 @@ SHELL := /bin/bash +include .version include .env build: - docker rmi siren && docker-compose build + if docker inspect siren > /dev/null 2>&1; then docker rmi siren; fi; docker-compose build test: docker run --rm -it siren:latest bash diff --git a/README.md b/README.md index f57524d..88c4167 100644 --- a/README.md +++ b/README.md @@ -38,13 +38,15 @@ https://discord.com/api/oauth2/authorize?client_id=&permissions=54696 - applications.commands ``` -`docker build -t siren .` -`docker-compose up -d` +``` +make build +make up +``` ## Development Build container -`docker build -t siren .` +`make build` Run container locally diff --git a/docker-compose.yml b/docker-compose.yml index 138ff0e..9310f70 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,5 +11,22 @@ services: - JAVA_VERSION=17 - VERSION=${SIREN_VERSION} volumes: - - ./data:/app + - ./app:/app + environment: + DATABASE_URL: jdbc:postgresql://db:5432/${POSTGRES_DB} + DATABASE_USERNAME: ${POSTGRES_USER} + DATABASE_PASSWORD: ${POSTGRES_PASSWORD} + depends_on: + - db restart: unless-stopped + db: + image: postgres:latest + environment: + POSTGRES_USER: ${POSTGRES_USER} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} + POSTGRES_DB: ${POSTGRES_DB} + volumes: + - ./data:/var/lib/postgresql/data + ports: + - "5432:5432" + restart: unless-stopped \ No newline at end of file diff --git a/pom.xml b/pom.xml index d86395b..6fdb07b 100644 --- a/pom.xml +++ b/pom.xml @@ -66,6 +66,7 @@ 1.3.13 2.14.2 0.12.0 + 42.6.0 2.0.6 2.20.0 UTF-8 @@ -74,7 +75,6 @@ - net.dv8tion JDA @@ -95,13 +95,16 @@ jackson-databind ${jackson.version} - - com.theokanning.openai-gpt3-java service ${theokanning-openai-gpt3.version} + + org.postgresql + postgresql + ${postgresql.version} + diff --git a/src/main/java/com/bensherriff/siren/Listener.java b/src/main/java/com/bensherriff/siren/Listener.java index 0e9543c..7c3feb7 100644 --- a/src/main/java/com/bensherriff/siren/Listener.java +++ b/src/main/java/com/bensherriff/siren/Listener.java @@ -3,6 +3,7 @@ package com.bensherriff.siren; import com.bensherriff.siren.audio.AudioHandler; import com.bensherriff.siren.audio.PlayerManager; import com.bensherriff.siren.commands.*; +import com.bensherriff.siren.database.DatabaseManager; import com.bensherriff.siren.exceptions.EmptyVoiceChannelException; import com.bensherriff.siren.openai.OpenAIManager; import com.bensherriff.siren.settings.GuildSettings; @@ -68,6 +69,8 @@ public class Listener extends ListenerAdapter { this.playerManager.initialize(); this.openAIManager = new OpenAIManager(this); + DatabaseManager.createTables(); + commands.put("play", new PlayCommand(this)); commands.put("stop", new StopCommand(this)); commands.put("skip", new SkipCommand(this)); diff --git a/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java b/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java new file mode 100644 index 0000000..46bf275 --- /dev/null +++ b/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java @@ -0,0 +1,65 @@ +package com.bensherriff.siren.database; + +import com.bensherriff.siren.openai.OpenAIManager; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +public class DatabaseConnection { + private static final Logger LOGGER = LogManager.getLogger(DatabaseConnection.class); + private static final ExecutorService executorService = Executors.newFixedThreadPool(4); + private static final ThreadLocal connectionThreadLocal = new ThreadLocal<>(); + + public static Connection getConnection() throws SQLException { + Connection connection = connectionThreadLocal.get(); + + if (connection == null) { + Map env = System.getenv(); + String dbUrl = env.get("DATABASE_URL"); + String dbUsername = env.get("DATABASE_USERNAME"); + String dbPassword = env.get("DATABASE_PASSWORD"); + + connection = DriverManager.getConnection(dbUrl, dbUsername, dbPassword); + connectionThreadLocal.set(connection); + } + + return connection; + } + + public static void closeConnection() { + Connection connection = connectionThreadLocal.get(); + connectionThreadLocal.remove(); + + if (connection != null) { + executorService.submit(() -> { + try { + connection.close(); + } catch (SQLException ex) { + LOGGER.error(ex.getMessage()); + } + }); + } + } + + public static void shutdown() { + executorService.shutdown(); + try { + if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) { + executorService.shutdownNow(); + if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) { + LOGGER.error("ExecutorService did not terminate"); + } + } + } catch (InterruptedException e) { + executorService.shutdownNow(); + Thread.currentThread().interrupt(); + } + } +} diff --git a/src/main/java/com/bensherriff/siren/database/DatabaseManager.java b/src/main/java/com/bensherriff/siren/database/DatabaseManager.java new file mode 100644 index 0000000..b4f449d --- /dev/null +++ b/src/main/java/com/bensherriff/siren/database/DatabaseManager.java @@ -0,0 +1,123 @@ +package com.bensherriff.siren.database; + +import com.theokanning.openai.completion.chat.ChatMessage; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.sql.*; +import java.util.List; + +public class DatabaseManager { + private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class); + private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" + + "id SERIAL PRIMARY KEY, " + + "guild_id BIGINT NOT NULL, " + + "thread_id BIGINT NOT NULL, " + + "user_id BIGINT NOT NULL, " + + "message_type VARCHAR(20) NOT NULL," + + "message_text TEXT NOT NULL," + + "message_response TEXT, " + + "timestamp TIMESTAMP NOT NULL DEFAULT NOW()" + + ")"; + private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" + + "id SERIAL PRIMARY KEY, " + + "embeddings FLOAT[] NOT NULL" + + ");"; + + private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" + + "id SERIAL PRIMARY KEY, " + + "message_id INT NOT NULL, " + + "embedding_id INT NOT NULL" + + ")"; + private static final String INSERT_MESSAGE = "INSERT INTO messages (" + + "message_type, " + + "guild_id, " + + "thread_id, " + + "user_id, " + + "message_text, " + + "message_response) " + + "VALUES (?, ?, ?, ?, ?, ?)"; + private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" + + "embeddings) " + + "VALUES (?)"; + + private static final String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" + + "message_id, " + + "embeddings_id, " + + "VALUES (?, ?)"; + + public static void createTables() { + createMessageTable(); + createEmbeddingsTable(); + createMessageEmbeddingsTable(); + } + + private static void createMessageTable() { + try { + Connection connection = DatabaseConnection.getConnection(); + LOGGER.debug("Creating 'messages' database table if it does not exist"); + Statement statement = connection.createStatement(); + statement.execute(CREATE_MESSAGE_TABLE); + } catch (SQLException ex) { + LOGGER.error(ex.getMessage()); + } + } + + private static void createEmbeddingsTable() { + try { + Connection connection = DatabaseConnection.getConnection(); + LOGGER.debug("Creating 'embeddings' database table if it does not exist"); + Statement statement = connection.createStatement(); + statement.execute(CREATE_EMBEDDINGS_TABLE); + } catch (SQLException ex) { + LOGGER.error(ex.getMessage()); + } + } + + private static void createMessageEmbeddingsTable() { + try { + Connection connection = DatabaseConnection.getConnection(); + LOGGER.debug("Creating 'message_embeddings' database table if it does not exist"); + Statement statement = connection.createStatement(); + statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE); + } catch (SQLException ex) { + LOGGER.error(ex.getMessage()); + } + } + + public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) { + try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) { + preparedStatement.setString(1, message.getRole()); + preparedStatement.setLong(2, guildId); + preparedStatement.setLong(3, threadId); + preparedStatement.setLong(4, userId); + preparedStatement.setString(5, message.getContent()); + preparedStatement.setString(6, response); + return preparedStatement.executeUpdate(); + } catch (SQLException ex) { + LOGGER.error(ex.getMessage()); + return -1; + } + } + + public static int storeEmbedding(List data) { + try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) { + preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0]))); + return preparedStatement.executeUpdate(); + } catch (SQLException ex) { + LOGGER.error(ex.getMessage()); + return -1; + } + } + + public static int storeMessageEmbeddings(int messageId, int embeddingId) { + try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) { + preparedStatement.setInt(1, messageId); + preparedStatement.setInt(2, embeddingId); + return preparedStatement.executeUpdate(); + } catch (SQLException ex) { + LOGGER.error(ex.getMessage()); + return -1; + } + } +} diff --git a/src/main/java/com/bensherriff/siren/openai/OpenAIManager.java b/src/main/java/com/bensherriff/siren/openai/OpenAIManager.java index 28a2b78..ff10672 100644 --- a/src/main/java/com/bensherriff/siren/openai/OpenAIManager.java +++ b/src/main/java/com/bensherriff/siren/openai/OpenAIManager.java @@ -1,6 +1,7 @@ package com.bensherriff.siren.openai; import com.bensherriff.siren.Listener; +import com.bensherriff.siren.database.DatabaseManager; import com.bensherriff.siren.settings.GuildSettings; import com.bensherriff.siren.settings.Settings; import com.theokanning.openai.completion.CompletionRequest; @@ -8,6 +9,9 @@ import com.theokanning.openai.completion.CompletionResult; import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatCompletionResult; import com.theokanning.openai.completion.chat.ChatMessage; +import com.theokanning.openai.embedding.Embedding; +import com.theokanning.openai.embedding.EmbeddingRequest; +import com.theokanning.openai.embedding.EmbeddingResult; import com.theokanning.openai.service.OpenAiService; import net.dv8tion.jda.api.JDA; import net.dv8tion.jda.api.entities.Message; @@ -20,6 +24,7 @@ import org.apache.logging.log4j.Logger; import java.time.Duration; import java.util.*; import java.util.concurrent.ScheduledExecutorService; +import java.util.stream.Collectors; public class OpenAIManager { private static final Logger LOGGER = LogManager.getLogger(OpenAIManager.class); @@ -27,13 +32,11 @@ public class OpenAIManager { private final Settings settings; private final JDA jda; private final ScheduledExecutorService executor; - private final Map> threadMessages; public OpenAIManager(Listener listener) { this.settings = listener.getSettings(); this.jda = listener.getJDA(); this.executor = listener.getExecutor(); - this.threadMessages = new HashMap<>(); if (settings.getOpenAISettings().getToken().isEmpty()) { LOGGER.warn("No OpenAI token; OpenAI functionality is disabled"); @@ -62,62 +65,35 @@ public class OpenAIManager { String message = parseMessage(event.getMessage().getContentRaw()); long guildId = event.getGuild().getIdLong(); Model model = settings.getGuildSettings().get(guildId).getModel(); - GuildSettings guildSettings = settings.getGuildSettings().get(guildId); - - LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message); + StringBuilder stringBuilder = new StringBuilder(); + List chatMessages = new ArrayList<>(); + ChatMessage chatMessage = createChatMessage(message, event); + List embeddings = new ArrayList<>(); if (message.isEmpty() || message.isBlank()) { event.getMessage().reply("Your message is empty. Please try again").queue(); return; } - StringBuilder stringBuilder = new StringBuilder(); - List chatMessages = new ArrayList<>(); - ChatMessage chatMessage = createChatMessage(message, event); - try { + LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message); // Send OpenAI Message and get response switch (model) { case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> { - CompletionRequest completionRequest = CompletionRequest.builder() - .model(guildSettings.getModel().getName()) - .maxTokens(guildSettings.getMaxTokens()) - .user(event.getAuthor().getId()) - .temperature(settings.getOpenAISettings().getTemperature()) - .topP(settings.getOpenAISettings().getTopP()) - .frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty()) - .presencePenalty(settings.getOpenAISettings().getPresencePenalty()) - .prompt(message) - .build(); + CompletionRequest completionRequest = createCompletionRequest(message, event); CompletionResult completionResult = openAiService.createCompletion(completionRequest); completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim())); } case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> { - //TODO Handle memories properly -// if (event.isFromThread()) { -// String channelId = event.getChannel().asThreadChannel().getId(); -// // Update ThreadMessages with the new message, and add previous messages to be sent out -// if (threadMessages.containsKey(channelId)) { -// threadMessages.get(channelId).add(chatMessage); -// chatMessages.addAll(threadMessages.get(channelId)); -// } else { -// threadMessages.put(channelId, new ArrayList<>(Arrays.asList(chatMessage))); -// } -// } chatMessages.add(chatMessage); - - ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .model(guildSettings.getModel().getName()) - .maxTokens(guildSettings.getMaxTokens()) - .user(event.getAuthor().getId()) - .temperature(settings.getOpenAISettings().getTemperature()) - .topP(settings.getOpenAISettings().getTopP()) - .frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty()) - .presencePenalty(settings.getOpenAISettings().getPresencePenalty()) - .messages(chatMessages) - .build(); + ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event); ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest); chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim())); + + //TODO fix embeddings +// EmbeddingRequest embeddingRequest = createEmbeddingRequest(chatMessages, event); +// EmbeddingResult embeddingResult = openAiService.createEmbeddings(embeddingRequest); +// embeddings.addAll(embeddingResult.getData()); } default -> { event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue(); @@ -126,29 +102,50 @@ public class OpenAIManager { return; } } - - // Respond to user - if (event.isFromThread()) { - ThreadChannel channel = event.getChannel().asThreadChannel(); - channel.sendMessage(stringBuilder.toString()).queue(); - } else { - // The max discord title length is 100 characters - String threadTitle = message; - if (message.length() > 100) { - threadTitle = message.substring(0, 100); - } - event.getMessage().createThreadChannel(threadTitle).queue(channel -> { - channel.sendMessage(stringBuilder.toString()).queue(); -// threadMessages.put(channel.getId(), new ArrayList<>(Arrays.asList(chatMessage))); - }); - } - + handleResponse(chatMessage, event, stringBuilder.toString(), embeddings); } catch (Exception ex) { LOGGER.error("Caught exception while processing message; {}", ex.getMessage()); event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue(); } } + private EmbeddingRequest createEmbeddingRequest(List chatMessages, MessageReceivedEvent event) { + GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong()); + return EmbeddingRequest.builder() + .model(guildSettings.getModel().getName()) + .user(event.getAuthor().getId()) + .input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList())) + .build(); + } + + private ChatCompletionRequest createCompletionRequest(List chatMessages, MessageReceivedEvent event) { + GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong()); + return ChatCompletionRequest.builder() + .model(guildSettings.getModel().getName()) + .maxTokens(guildSettings.getMaxTokens()) + .user(event.getAuthor().getId()) + .temperature(settings.getOpenAISettings().getTemperature()) + .topP(settings.getOpenAISettings().getTopP()) + .frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty()) + .presencePenalty(settings.getOpenAISettings().getPresencePenalty()) + .messages(chatMessages) + .build(); + } + + private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) { + GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong()); + return CompletionRequest.builder() + .model(guildSettings.getModel().getName()) + .maxTokens(guildSettings.getMaxTokens()) + .user(event.getAuthor().getId()) + .temperature(settings.getOpenAISettings().getTemperature()) + .topP(settings.getOpenAISettings().getTopP()) + .frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty()) + .presencePenalty(settings.getOpenAISettings().getPresencePenalty()) + .prompt(message) + .build(); + } + private ChatMessage createChatMessage(String message, MessageReceivedEvent event) { ChatMessage chatMessage = new ChatMessage(); chatMessage.setContent(message); @@ -160,6 +157,34 @@ public class OpenAIManager { return chatMessage; } + private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List embeddings) { + if (event.isFromThread()) { + ThreadChannel channel = event.getChannel().asThreadChannel(); + int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(), + channel.getIdLong(), response); + embeddings.forEach(embedding -> { + int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding()); + DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow); + }); + channel.sendMessage(response).queue(); + } else { + // The max discord title length is 100 characters + String threadTitle = chatMessage.getContent(); + if (chatMessage.getContent().length() > 100) { + threadTitle = chatMessage.getContent().substring(0, 100); + } + event.getMessage().createThreadChannel(threadTitle).queue(channel -> { + channel.sendMessage(response).queue(); + int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(), + channel.getIdLong(), response); + embeddings.forEach(embedding -> { + int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding()); + DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow); + }); + }); + } + } + private String parseMessage(String input) { return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim(); }