v0.1.15 Refactoring and updating OpenAI usage again

This commit is contained in:
2023-04-15 09:58:30 -04:00
parent 1522dfd99b
commit 414bca2dd6
12 changed files with 312 additions and 68 deletions

1
.env
View File

@@ -1 +0,0 @@
export SIREN_VERSION=0.1.14

3
.env.TEMPLATE Normal file
View File

@@ -0,0 +1,3 @@
export POSTGRES_USER=
export POSTGRES_PASSWORD=
export POSTGRES_DB=

2
.gitignore vendored
View File

@@ -1,5 +1,7 @@
.idea/
**/target/
**/data/
**/app/
**/settings.json
**/logs/
.env

1
.version Normal file
View File

@@ -0,0 +1 @@
export SIREN_VERSION=0.1.15

View File

@@ -1,8 +1,9 @@
SHELL := /bin/bash
include .version
include .env
build:
docker rmi siren && docker-compose build
if docker inspect siren > /dev/null 2>&1; then docker rmi siren; fi; docker-compose build
test:
docker run --rm -it siren:latest bash

View File

@@ -38,13 +38,15 @@ https://discord.com/api/oauth2/authorize?client_id=<Client ID>&permissions=54696
- applications.commands
```
`docker build -t siren .`
`docker-compose up -d`
```
make build
make up
```
## Development
Build container
`docker build -t siren .`
`make build`
Run container locally

View File

@@ -11,5 +11,22 @@ services:
- JAVA_VERSION=17
- VERSION=${SIREN_VERSION}
volumes:
- ./data:/app
- ./app:/app
environment:
DATABASE_URL: jdbc:postgresql://db:5432/${POSTGRES_DB}
DATABASE_USERNAME: ${POSTGRES_USER}
DATABASE_PASSWORD: ${POSTGRES_PASSWORD}
depends_on:
- db
restart: unless-stopped
db:
image: postgres:latest
environment:
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_DB: ${POSTGRES_DB}
volumes:
- ./data:/var/lib/postgresql/data
ports:
- "5432:5432"
restart: unless-stopped

View File

@@ -66,6 +66,7 @@
<lavaplayer-natives-extra.version>1.3.13</lavaplayer-natives-extra.version>
<jackson.version>2.14.2</jackson.version>
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
<postgresql.version>42.6.0</postgresql.version>
<slf4j.version>2.0.6</slf4j.version>
<log4j.version>2.20.0</log4j.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
@@ -74,7 +75,6 @@
</properties>
<dependencies>
<!-- Discord Dependencies -->
<dependency>
<groupId>net.dv8tion</groupId>
<artifactId>JDA</artifactId>
@@ -95,13 +95,16 @@
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<!-- OpenAI https://github.com/TheoKanning/openai-java -->
<dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>service</artifactId>
<version>${theokanning-openai-gpt3.version}</version>
</dependency>
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<version>${postgresql.version}</version>
</dependency>
<!-- Logging -->
<dependency>

View File

@@ -3,6 +3,7 @@ package com.bensherriff.siren;
import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.audio.PlayerManager;
import com.bensherriff.siren.commands.*;
import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
import com.bensherriff.siren.openai.OpenAIManager;
import com.bensherriff.siren.settings.GuildSettings;
@@ -68,6 +69,8 @@ public class Listener extends ListenerAdapter {
this.playerManager.initialize();
this.openAIManager = new OpenAIManager(this);
DatabaseManager.createTables();
commands.put("play", new PlayCommand(this));
commands.put("stop", new StopCommand(this));
commands.put("skip", new SkipCommand(this));

View File

@@ -0,0 +1,65 @@
package com.bensherriff.siren.database;
import com.bensherriff.siren.openai.OpenAIManager;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class DatabaseConnection {
private static final Logger LOGGER = LogManager.getLogger(DatabaseConnection.class);
private static final ExecutorService executorService = Executors.newFixedThreadPool(4);
private static final ThreadLocal<Connection> connectionThreadLocal = new ThreadLocal<>();
public static Connection getConnection() throws SQLException {
Connection connection = connectionThreadLocal.get();
if (connection == null) {
Map<String, String> env = System.getenv();
String dbUrl = env.get("DATABASE_URL");
String dbUsername = env.get("DATABASE_USERNAME");
String dbPassword = env.get("DATABASE_PASSWORD");
connection = DriverManager.getConnection(dbUrl, dbUsername, dbPassword);
connectionThreadLocal.set(connection);
}
return connection;
}
public static void closeConnection() {
Connection connection = connectionThreadLocal.get();
connectionThreadLocal.remove();
if (connection != null) {
executorService.submit(() -> {
try {
connection.close();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
});
}
}
public static void shutdown() {
executorService.shutdown();
try {
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
executorService.shutdownNow();
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
LOGGER.error("ExecutorService did not terminate");
}
}
} catch (InterruptedException e) {
executorService.shutdownNow();
Thread.currentThread().interrupt();
}
}
}

View File

@@ -0,0 +1,123 @@
package com.bensherriff.siren.database;
import com.theokanning.openai.completion.chat.ChatMessage;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.sql.*;
import java.util.List;
public class DatabaseManager {
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" +
"id SERIAL PRIMARY KEY, " +
"guild_id BIGINT NOT NULL, " +
"thread_id BIGINT NOT NULL, " +
"user_id BIGINT NOT NULL, " +
"message_type VARCHAR(20) NOT NULL," +
"message_text TEXT NOT NULL," +
"message_response TEXT, " +
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
")";
private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" +
"id SERIAL PRIMARY KEY, " +
"embeddings FLOAT[] NOT NULL" +
");";
private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" +
"id SERIAL PRIMARY KEY, " +
"message_id INT NOT NULL, " +
"embedding_id INT NOT NULL" +
")";
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
"message_type, " +
"guild_id, " +
"thread_id, " +
"user_id, " +
"message_text, " +
"message_response) " +
"VALUES (?, ?, ?, ?, ?, ?)";
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
"embeddings) " +
"VALUES (?)";
private static final String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
"message_id, " +
"embeddings_id, " +
"VALUES (?, ?)";
public static void createTables() {
createMessageTable();
createEmbeddingsTable();
createMessageEmbeddingsTable();
}
private static void createMessageTable() {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'messages' database table if it does not exist");
Statement statement = connection.createStatement();
statement.execute(CREATE_MESSAGE_TABLE);
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
}
private static void createEmbeddingsTable() {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'embeddings' database table if it does not exist");
Statement statement = connection.createStatement();
statement.execute(CREATE_EMBEDDINGS_TABLE);
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
}
private static void createMessageEmbeddingsTable() {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'message_embeddings' database table if it does not exist");
Statement statement = connection.createStatement();
statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE);
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
}
public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) {
preparedStatement.setString(1, message.getRole());
preparedStatement.setLong(2, guildId);
preparedStatement.setLong(3, threadId);
preparedStatement.setLong(4, userId);
preparedStatement.setString(5, message.getContent());
preparedStatement.setString(6, response);
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return -1;
}
}
public static int storeEmbedding(List<Double> data) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) {
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return -1;
}
}
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) {
preparedStatement.setInt(1, messageId);
preparedStatement.setInt(2, embeddingId);
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return -1;
}
}
}

View File

@@ -1,6 +1,7 @@
package com.bensherriff.siren.openai;
import com.bensherriff.siren.Listener;
import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.settings.GuildSettings;
import com.bensherriff.siren.settings.Settings;
import com.theokanning.openai.completion.CompletionRequest;
@@ -8,6 +9,9 @@ import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.embedding.Embedding;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.service.OpenAiService;
import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.entities.Message;
@@ -20,6 +24,7 @@ import org.apache.logging.log4j.Logger;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
public class OpenAIManager {
private static final Logger LOGGER = LogManager.getLogger(OpenAIManager.class);
@@ -27,13 +32,11 @@ public class OpenAIManager {
private final Settings settings;
private final JDA jda;
private final ScheduledExecutorService executor;
private final Map<String, List<ChatMessage>> threadMessages;
public OpenAIManager(Listener listener) {
this.settings = listener.getSettings();
this.jda = listener.getJDA();
this.executor = listener.getExecutor();
this.threadMessages = new HashMap<>();
if (settings.getOpenAISettings().getToken().isEmpty()) {
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
@@ -62,62 +65,35 @@ public class OpenAIManager {
String message = parseMessage(event.getMessage().getContentRaw());
long guildId = event.getGuild().getIdLong();
Model model = settings.getGuildSettings().get(guildId).getModel();
GuildSettings guildSettings = settings.getGuildSettings().get(guildId);
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
StringBuilder stringBuilder = new StringBuilder();
List<ChatMessage> chatMessages = new ArrayList<>();
ChatMessage chatMessage = createChatMessage(message, event);
List<Embedding> embeddings = new ArrayList<>();
if (message.isEmpty() || message.isBlank()) {
event.getMessage().reply("Your message is empty. Please try again").queue();
return;
}
StringBuilder stringBuilder = new StringBuilder();
List<ChatMessage> chatMessages = new ArrayList<>();
ChatMessage chatMessage = createChatMessage(message, event);
try {
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
// Send OpenAI Message and get response
switch (model) {
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
CompletionRequest completionRequest = CompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
.user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP())
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
.prompt(message)
.build();
CompletionRequest completionRequest = createCompletionRequest(message, event);
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
}
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
//TODO Handle memories properly
// if (event.isFromThread()) {
// String channelId = event.getChannel().asThreadChannel().getId();
// // Update ThreadMessages with the new message, and add previous messages to be sent out
// if (threadMessages.containsKey(channelId)) {
// threadMessages.get(channelId).add(chatMessage);
// chatMessages.addAll(threadMessages.get(channelId));
// } else {
// threadMessages.put(channelId, new ArrayList<>(Arrays.asList(chatMessage)));
// }
// }
chatMessages.add(chatMessage);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
.user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP())
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
.messages(chatMessages)
.build();
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event);
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
//TODO fix embeddings
// EmbeddingRequest embeddingRequest = createEmbeddingRequest(chatMessages, event);
// EmbeddingResult embeddingResult = openAiService.createEmbeddings(embeddingRequest);
// embeddings.addAll(embeddingResult.getData());
}
default -> {
event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue();
@@ -126,29 +102,50 @@ public class OpenAIManager {
return;
}
}
// Respond to user
if (event.isFromThread()) {
ThreadChannel channel = event.getChannel().asThreadChannel();
channel.sendMessage(stringBuilder.toString()).queue();
} else {
// The max discord title length is 100 characters
String threadTitle = message;
if (message.length() > 100) {
threadTitle = message.substring(0, 100);
}
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
channel.sendMessage(stringBuilder.toString()).queue();
// threadMessages.put(channel.getId(), new ArrayList<>(Arrays.asList(chatMessage)));
});
}
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
} catch (Exception ex) {
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue();
}
}
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
return EmbeddingRequest.builder()
.model(guildSettings.getModel().getName())
.user(event.getAuthor().getId())
.input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList()))
.build();
}
private ChatCompletionRequest createCompletionRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
return ChatCompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
.user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP())
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
.messages(chatMessages)
.build();
}
private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
return CompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
.user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP())
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
.prompt(message)
.build();
}
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
ChatMessage chatMessage = new ChatMessage();
chatMessage.setContent(message);
@@ -160,6 +157,34 @@ public class OpenAIManager {
return chatMessage;
}
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
if (event.isFromThread()) {
ThreadChannel channel = event.getChannel().asThreadChannel();
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
channel.getIdLong(), response);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
channel.sendMessage(response).queue();
} else {
// The max discord title length is 100 characters
String threadTitle = chatMessage.getContent();
if (chatMessage.getContent().length() > 100) {
threadTitle = chatMessage.getContent().substring(0, 100);
}
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
channel.sendMessage(response).queue();
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
channel.getIdLong(), response);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
});
}
}
private String parseMessage(String input) {
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
}