v0.1.15 Refactoring and updating OpenAI usage again

This commit is contained in:
2023-04-15 09:58:30 -04:00
parent 1522dfd99b
commit 414bca2dd6
12 changed files with 312 additions and 68 deletions

View File

@@ -3,6 +3,7 @@ package com.bensherriff.siren;
import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.audio.PlayerManager;
import com.bensherriff.siren.commands.*;
import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
import com.bensherriff.siren.openai.OpenAIManager;
import com.bensherriff.siren.settings.GuildSettings;
@@ -68,6 +69,8 @@ public class Listener extends ListenerAdapter {
this.playerManager.initialize();
this.openAIManager = new OpenAIManager(this);
DatabaseManager.createTables();
commands.put("play", new PlayCommand(this));
commands.put("stop", new StopCommand(this));
commands.put("skip", new SkipCommand(this));

View File

@@ -0,0 +1,65 @@
package com.bensherriff.siren.database;
import com.bensherriff.siren.openai.OpenAIManager;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class DatabaseConnection {
private static final Logger LOGGER = LogManager.getLogger(DatabaseConnection.class);
private static final ExecutorService executorService = Executors.newFixedThreadPool(4);
private static final ThreadLocal<Connection> connectionThreadLocal = new ThreadLocal<>();
public static Connection getConnection() throws SQLException {
Connection connection = connectionThreadLocal.get();
if (connection == null) {
Map<String, String> env = System.getenv();
String dbUrl = env.get("DATABASE_URL");
String dbUsername = env.get("DATABASE_USERNAME");
String dbPassword = env.get("DATABASE_PASSWORD");
connection = DriverManager.getConnection(dbUrl, dbUsername, dbPassword);
connectionThreadLocal.set(connection);
}
return connection;
}
public static void closeConnection() {
Connection connection = connectionThreadLocal.get();
connectionThreadLocal.remove();
if (connection != null) {
executorService.submit(() -> {
try {
connection.close();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
});
}
}
public static void shutdown() {
executorService.shutdown();
try {
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
executorService.shutdownNow();
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
LOGGER.error("ExecutorService did not terminate");
}
}
} catch (InterruptedException e) {
executorService.shutdownNow();
Thread.currentThread().interrupt();
}
}
}

View File

@@ -0,0 +1,123 @@
package com.bensherriff.siren.database;
import com.theokanning.openai.completion.chat.ChatMessage;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.sql.*;
import java.util.List;
public class DatabaseManager {
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" +
"id SERIAL PRIMARY KEY, " +
"guild_id BIGINT NOT NULL, " +
"thread_id BIGINT NOT NULL, " +
"user_id BIGINT NOT NULL, " +
"message_type VARCHAR(20) NOT NULL," +
"message_text TEXT NOT NULL," +
"message_response TEXT, " +
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
")";
private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" +
"id SERIAL PRIMARY KEY, " +
"embeddings FLOAT[] NOT NULL" +
");";
private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" +
"id SERIAL PRIMARY KEY, " +
"message_id INT NOT NULL, " +
"embedding_id INT NOT NULL" +
")";
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
"message_type, " +
"guild_id, " +
"thread_id, " +
"user_id, " +
"message_text, " +
"message_response) " +
"VALUES (?, ?, ?, ?, ?, ?)";
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
"embeddings) " +
"VALUES (?)";
private static final String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
"message_id, " +
"embeddings_id, " +
"VALUES (?, ?)";
public static void createTables() {
createMessageTable();
createEmbeddingsTable();
createMessageEmbeddingsTable();
}
private static void createMessageTable() {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'messages' database table if it does not exist");
Statement statement = connection.createStatement();
statement.execute(CREATE_MESSAGE_TABLE);
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
}
private static void createEmbeddingsTable() {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'embeddings' database table if it does not exist");
Statement statement = connection.createStatement();
statement.execute(CREATE_EMBEDDINGS_TABLE);
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
}
private static void createMessageEmbeddingsTable() {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'message_embeddings' database table if it does not exist");
Statement statement = connection.createStatement();
statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE);
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
}
public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) {
preparedStatement.setString(1, message.getRole());
preparedStatement.setLong(2, guildId);
preparedStatement.setLong(3, threadId);
preparedStatement.setLong(4, userId);
preparedStatement.setString(5, message.getContent());
preparedStatement.setString(6, response);
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return -1;
}
}
public static int storeEmbedding(List<Double> data) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) {
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return -1;
}
}
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) {
preparedStatement.setInt(1, messageId);
preparedStatement.setInt(2, embeddingId);
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return -1;
}
}
}

View File

@@ -1,6 +1,7 @@
package com.bensherriff.siren.openai;
import com.bensherriff.siren.Listener;
import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.settings.GuildSettings;
import com.bensherriff.siren.settings.Settings;
import com.theokanning.openai.completion.CompletionRequest;
@@ -8,6 +9,9 @@ import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.embedding.Embedding;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.service.OpenAiService;
import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.entities.Message;
@@ -20,6 +24,7 @@ import org.apache.logging.log4j.Logger;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
public class OpenAIManager {
private static final Logger LOGGER = LogManager.getLogger(OpenAIManager.class);
@@ -27,13 +32,11 @@ public class OpenAIManager {
private final Settings settings;
private final JDA jda;
private final ScheduledExecutorService executor;
private final Map<String, List<ChatMessage>> threadMessages;
public OpenAIManager(Listener listener) {
this.settings = listener.getSettings();
this.jda = listener.getJDA();
this.executor = listener.getExecutor();
this.threadMessages = new HashMap<>();
if (settings.getOpenAISettings().getToken().isEmpty()) {
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
@@ -62,62 +65,35 @@ public class OpenAIManager {
String message = parseMessage(event.getMessage().getContentRaw());
long guildId = event.getGuild().getIdLong();
Model model = settings.getGuildSettings().get(guildId).getModel();
GuildSettings guildSettings = settings.getGuildSettings().get(guildId);
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
StringBuilder stringBuilder = new StringBuilder();
List<ChatMessage> chatMessages = new ArrayList<>();
ChatMessage chatMessage = createChatMessage(message, event);
List<Embedding> embeddings = new ArrayList<>();
if (message.isEmpty() || message.isBlank()) {
event.getMessage().reply("Your message is empty. Please try again").queue();
return;
}
StringBuilder stringBuilder = new StringBuilder();
List<ChatMessage> chatMessages = new ArrayList<>();
ChatMessage chatMessage = createChatMessage(message, event);
try {
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
// Send OpenAI Message and get response
switch (model) {
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
CompletionRequest completionRequest = CompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
.user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP())
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
.prompt(message)
.build();
CompletionRequest completionRequest = createCompletionRequest(message, event);
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
}
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
//TODO Handle memories properly
// if (event.isFromThread()) {
// String channelId = event.getChannel().asThreadChannel().getId();
// // Update ThreadMessages with the new message, and add previous messages to be sent out
// if (threadMessages.containsKey(channelId)) {
// threadMessages.get(channelId).add(chatMessage);
// chatMessages.addAll(threadMessages.get(channelId));
// } else {
// threadMessages.put(channelId, new ArrayList<>(Arrays.asList(chatMessage)));
// }
// }
chatMessages.add(chatMessage);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
.user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP())
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
.messages(chatMessages)
.build();
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event);
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
//TODO fix embeddings
// EmbeddingRequest embeddingRequest = createEmbeddingRequest(chatMessages, event);
// EmbeddingResult embeddingResult = openAiService.createEmbeddings(embeddingRequest);
// embeddings.addAll(embeddingResult.getData());
}
default -> {
event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue();
@@ -126,29 +102,50 @@ public class OpenAIManager {
return;
}
}
// Respond to user
if (event.isFromThread()) {
ThreadChannel channel = event.getChannel().asThreadChannel();
channel.sendMessage(stringBuilder.toString()).queue();
} else {
// The max discord title length is 100 characters
String threadTitle = message;
if (message.length() > 100) {
threadTitle = message.substring(0, 100);
}
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
channel.sendMessage(stringBuilder.toString()).queue();
// threadMessages.put(channel.getId(), new ArrayList<>(Arrays.asList(chatMessage)));
});
}
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
} catch (Exception ex) {
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue();
}
}
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
return EmbeddingRequest.builder()
.model(guildSettings.getModel().getName())
.user(event.getAuthor().getId())
.input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList()))
.build();
}
private ChatCompletionRequest createCompletionRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
return ChatCompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
.user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP())
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
.messages(chatMessages)
.build();
}
private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
return CompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
.user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP())
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
.prompt(message)
.build();
}
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
ChatMessage chatMessage = new ChatMessage();
chatMessage.setContent(message);
@@ -160,6 +157,34 @@ public class OpenAIManager {
return chatMessage;
}
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
if (event.isFromThread()) {
ThreadChannel channel = event.getChannel().asThreadChannel();
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
channel.getIdLong(), response);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
channel.sendMessage(response).queue();
} else {
// The max discord title length is 100 characters
String threadTitle = chatMessage.getContent();
if (chatMessage.getContent().length() > 100) {
threadTitle = chatMessage.getContent().substring(0, 100);
}
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
channel.sendMessage(response).queue();
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
channel.getIdLong(), response);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
});
}
}
private String parseMessage(String input) {
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
}