v0.1.15 Refactoring and updating OpenAI usage again
This commit is contained in:
@@ -3,6 +3,7 @@ package com.bensherriff.siren;
|
||||
import com.bensherriff.siren.audio.AudioHandler;
|
||||
import com.bensherriff.siren.audio.PlayerManager;
|
||||
import com.bensherriff.siren.commands.*;
|
||||
import com.bensherriff.siren.database.DatabaseManager;
|
||||
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
|
||||
import com.bensherriff.siren.openai.OpenAIManager;
|
||||
import com.bensherriff.siren.settings.GuildSettings;
|
||||
@@ -68,6 +69,8 @@ public class Listener extends ListenerAdapter {
|
||||
this.playerManager.initialize();
|
||||
this.openAIManager = new OpenAIManager(this);
|
||||
|
||||
DatabaseManager.createTables();
|
||||
|
||||
commands.put("play", new PlayCommand(this));
|
||||
commands.put("stop", new StopCommand(this));
|
||||
commands.put("skip", new SkipCommand(this));
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
package com.bensherriff.siren.database;
|
||||
|
||||
import com.bensherriff.siren.openai.OpenAIManager;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.sql.DriverManager;
|
||||
import java.sql.SQLException;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
public class DatabaseConnection {
|
||||
private static final Logger LOGGER = LogManager.getLogger(DatabaseConnection.class);
|
||||
private static final ExecutorService executorService = Executors.newFixedThreadPool(4);
|
||||
private static final ThreadLocal<Connection> connectionThreadLocal = new ThreadLocal<>();
|
||||
|
||||
public static Connection getConnection() throws SQLException {
|
||||
Connection connection = connectionThreadLocal.get();
|
||||
|
||||
if (connection == null) {
|
||||
Map<String, String> env = System.getenv();
|
||||
String dbUrl = env.get("DATABASE_URL");
|
||||
String dbUsername = env.get("DATABASE_USERNAME");
|
||||
String dbPassword = env.get("DATABASE_PASSWORD");
|
||||
|
||||
connection = DriverManager.getConnection(dbUrl, dbUsername, dbPassword);
|
||||
connectionThreadLocal.set(connection);
|
||||
}
|
||||
|
||||
return connection;
|
||||
}
|
||||
|
||||
public static void closeConnection() {
|
||||
Connection connection = connectionThreadLocal.get();
|
||||
connectionThreadLocal.remove();
|
||||
|
||||
if (connection != null) {
|
||||
executorService.submit(() -> {
|
||||
try {
|
||||
connection.close();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public static void shutdown() {
|
||||
executorService.shutdown();
|
||||
try {
|
||||
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
|
||||
executorService.shutdownNow();
|
||||
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
|
||||
LOGGER.error("ExecutorService did not terminate");
|
||||
}
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
executorService.shutdownNow();
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package com.bensherriff.siren.database;
|
||||
|
||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.sql.*;
|
||||
import java.util.List;
|
||||
|
||||
public class DatabaseManager {
|
||||
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
|
||||
private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"guild_id BIGINT NOT NULL, " +
|
||||
"thread_id BIGINT NOT NULL, " +
|
||||
"user_id BIGINT NOT NULL, " +
|
||||
"message_type VARCHAR(20) NOT NULL," +
|
||||
"message_text TEXT NOT NULL," +
|
||||
"message_response TEXT, " +
|
||||
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
|
||||
")";
|
||||
private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"embeddings FLOAT[] NOT NULL" +
|
||||
");";
|
||||
|
||||
private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"message_id INT NOT NULL, " +
|
||||
"embedding_id INT NOT NULL" +
|
||||
")";
|
||||
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
|
||||
"message_type, " +
|
||||
"guild_id, " +
|
||||
"thread_id, " +
|
||||
"user_id, " +
|
||||
"message_text, " +
|
||||
"message_response) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?)";
|
||||
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
|
||||
"embeddings) " +
|
||||
"VALUES (?)";
|
||||
|
||||
private static final String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
|
||||
"message_id, " +
|
||||
"embeddings_id, " +
|
||||
"VALUES (?, ?)";
|
||||
|
||||
public static void createTables() {
|
||||
createMessageTable();
|
||||
createEmbeddingsTable();
|
||||
createMessageEmbeddingsTable();
|
||||
}
|
||||
|
||||
private static void createMessageTable() {
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
LOGGER.debug("Creating 'messages' database table if it does not exist");
|
||||
Statement statement = connection.createStatement();
|
||||
statement.execute(CREATE_MESSAGE_TABLE);
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static void createEmbeddingsTable() {
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
LOGGER.debug("Creating 'embeddings' database table if it does not exist");
|
||||
Statement statement = connection.createStatement();
|
||||
statement.execute(CREATE_EMBEDDINGS_TABLE);
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static void createMessageEmbeddingsTable() {
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
LOGGER.debug("Creating 'message_embeddings' database table if it does not exist");
|
||||
Statement statement = connection.createStatement();
|
||||
statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE);
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) {
|
||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) {
|
||||
preparedStatement.setString(1, message.getRole());
|
||||
preparedStatement.setLong(2, guildId);
|
||||
preparedStatement.setLong(3, threadId);
|
||||
preparedStatement.setLong(4, userId);
|
||||
preparedStatement.setString(5, message.getContent());
|
||||
preparedStatement.setString(6, response);
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
public static int storeEmbedding(List<Double> data) {
|
||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) {
|
||||
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
|
||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) {
|
||||
preparedStatement.setInt(1, messageId);
|
||||
preparedStatement.setInt(2, embeddingId);
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.bensherriff.siren.openai;
|
||||
|
||||
import com.bensherriff.siren.Listener;
|
||||
import com.bensherriff.siren.database.DatabaseManager;
|
||||
import com.bensherriff.siren.settings.GuildSettings;
|
||||
import com.bensherriff.siren.settings.Settings;
|
||||
import com.theokanning.openai.completion.CompletionRequest;
|
||||
@@ -8,6 +9,9 @@ import com.theokanning.openai.completion.CompletionResult;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||
import com.theokanning.openai.embedding.Embedding;
|
||||
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||
import com.theokanning.openai.embedding.EmbeddingResult;
|
||||
import com.theokanning.openai.service.OpenAiService;
|
||||
import net.dv8tion.jda.api.JDA;
|
||||
import net.dv8tion.jda.api.entities.Message;
|
||||
@@ -20,6 +24,7 @@ import org.apache.logging.log4j.Logger;
|
||||
import java.time.Duration;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class OpenAIManager {
|
||||
private static final Logger LOGGER = LogManager.getLogger(OpenAIManager.class);
|
||||
@@ -27,13 +32,11 @@ public class OpenAIManager {
|
||||
private final Settings settings;
|
||||
private final JDA jda;
|
||||
private final ScheduledExecutorService executor;
|
||||
private final Map<String, List<ChatMessage>> threadMessages;
|
||||
|
||||
public OpenAIManager(Listener listener) {
|
||||
this.settings = listener.getSettings();
|
||||
this.jda = listener.getJDA();
|
||||
this.executor = listener.getExecutor();
|
||||
this.threadMessages = new HashMap<>();
|
||||
|
||||
if (settings.getOpenAISettings().getToken().isEmpty()) {
|
||||
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
|
||||
@@ -62,62 +65,35 @@ public class OpenAIManager {
|
||||
String message = parseMessage(event.getMessage().getContentRaw());
|
||||
long guildId = event.getGuild().getIdLong();
|
||||
Model model = settings.getGuildSettings().get(guildId).getModel();
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(guildId);
|
||||
|
||||
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
ChatMessage chatMessage = createChatMessage(message, event);
|
||||
List<Embedding> embeddings = new ArrayList<>();
|
||||
|
||||
if (message.isEmpty() || message.isBlank()) {
|
||||
event.getMessage().reply("Your message is empty. Please try again").queue();
|
||||
return;
|
||||
}
|
||||
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
ChatMessage chatMessage = createChatMessage(message, event);
|
||||
|
||||
try {
|
||||
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
|
||||
// Send OpenAI Message and get response
|
||||
switch (model) {
|
||||
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
||||
CompletionRequest completionRequest = CompletionRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.maxTokens(guildSettings.getMaxTokens())
|
||||
.user(event.getAuthor().getId())
|
||||
.temperature(settings.getOpenAISettings().getTemperature())
|
||||
.topP(settings.getOpenAISettings().getTopP())
|
||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||
.prompt(message)
|
||||
.build();
|
||||
CompletionRequest completionRequest = createCompletionRequest(message, event);
|
||||
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
||||
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
||||
}
|
||||
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
|
||||
//TODO Handle memories properly
|
||||
// if (event.isFromThread()) {
|
||||
// String channelId = event.getChannel().asThreadChannel().getId();
|
||||
// // Update ThreadMessages with the new message, and add previous messages to be sent out
|
||||
// if (threadMessages.containsKey(channelId)) {
|
||||
// threadMessages.get(channelId).add(chatMessage);
|
||||
// chatMessages.addAll(threadMessages.get(channelId));
|
||||
// } else {
|
||||
// threadMessages.put(channelId, new ArrayList<>(Arrays.asList(chatMessage)));
|
||||
// }
|
||||
// }
|
||||
chatMessages.add(chatMessage);
|
||||
|
||||
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.maxTokens(guildSettings.getMaxTokens())
|
||||
.user(event.getAuthor().getId())
|
||||
.temperature(settings.getOpenAISettings().getTemperature())
|
||||
.topP(settings.getOpenAISettings().getTopP())
|
||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||
.messages(chatMessages)
|
||||
.build();
|
||||
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event);
|
||||
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
||||
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
||||
|
||||
//TODO fix embeddings
|
||||
// EmbeddingRequest embeddingRequest = createEmbeddingRequest(chatMessages, event);
|
||||
// EmbeddingResult embeddingResult = openAiService.createEmbeddings(embeddingRequest);
|
||||
// embeddings.addAll(embeddingResult.getData());
|
||||
}
|
||||
default -> {
|
||||
event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue();
|
||||
@@ -126,29 +102,50 @@ public class OpenAIManager {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Respond to user
|
||||
if (event.isFromThread()) {
|
||||
ThreadChannel channel = event.getChannel().asThreadChannel();
|
||||
channel.sendMessage(stringBuilder.toString()).queue();
|
||||
} else {
|
||||
// The max discord title length is 100 characters
|
||||
String threadTitle = message;
|
||||
if (message.length() > 100) {
|
||||
threadTitle = message.substring(0, 100);
|
||||
}
|
||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
|
||||
channel.sendMessage(stringBuilder.toString()).queue();
|
||||
// threadMessages.put(channel.getId(), new ArrayList<>(Arrays.asList(chatMessage)));
|
||||
});
|
||||
}
|
||||
|
||||
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
|
||||
} catch (Exception ex) {
|
||||
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
|
||||
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue();
|
||||
}
|
||||
}
|
||||
|
||||
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||
return EmbeddingRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.user(event.getAuthor().getId())
|
||||
.input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList()))
|
||||
.build();
|
||||
}
|
||||
|
||||
private ChatCompletionRequest createCompletionRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||
return ChatCompletionRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.maxTokens(guildSettings.getMaxTokens())
|
||||
.user(event.getAuthor().getId())
|
||||
.temperature(settings.getOpenAISettings().getTemperature())
|
||||
.topP(settings.getOpenAISettings().getTopP())
|
||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||
.messages(chatMessages)
|
||||
.build();
|
||||
}
|
||||
|
||||
private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) {
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||
return CompletionRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.maxTokens(guildSettings.getMaxTokens())
|
||||
.user(event.getAuthor().getId())
|
||||
.temperature(settings.getOpenAISettings().getTemperature())
|
||||
.topP(settings.getOpenAISettings().getTopP())
|
||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||
.prompt(message)
|
||||
.build();
|
||||
}
|
||||
|
||||
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
|
||||
ChatMessage chatMessage = new ChatMessage();
|
||||
chatMessage.setContent(message);
|
||||
@@ -160,6 +157,34 @@ public class OpenAIManager {
|
||||
return chatMessage;
|
||||
}
|
||||
|
||||
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
|
||||
if (event.isFromThread()) {
|
||||
ThreadChannel channel = event.getChannel().asThreadChannel();
|
||||
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
|
||||
channel.getIdLong(), response);
|
||||
embeddings.forEach(embedding -> {
|
||||
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
||||
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
||||
});
|
||||
channel.sendMessage(response).queue();
|
||||
} else {
|
||||
// The max discord title length is 100 characters
|
||||
String threadTitle = chatMessage.getContent();
|
||||
if (chatMessage.getContent().length() > 100) {
|
||||
threadTitle = chatMessage.getContent().substring(0, 100);
|
||||
}
|
||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
|
||||
channel.sendMessage(response).queue();
|
||||
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
|
||||
channel.getIdLong(), response);
|
||||
embeddings.forEach(embedding -> {
|
||||
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
||||
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private String parseMessage(String input) {
|
||||
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user