v0.1.16 Query builder and working on memory management

This commit is contained in:
2023-04-16 07:58:16 -04:00
parent 414bca2dd6
commit e6fb86c4e0
10 changed files with 464 additions and 103 deletions

View File

@@ -34,14 +34,15 @@ public class Listener extends ListenerAdapter {
private final Settings settings;
private final Map<String, Command> commands = new HashMap<>();
private final String owner;
private PlayerManager playerManager;
private OpenAIManager openAIManager;
private JDA jda;
public Listener(Settings settings) {
this.settings = settings;
this.executor = Executors.newScheduledThreadPool(this.settings.getThreadPool());
this.owner = "@bsherriff";
}
public ScheduledExecutorService getExecutor() {
@@ -64,6 +65,10 @@ public class Listener extends ListenerAdapter {
this.jda = jda;
}
public String getOwner() {
return owner;
}
public void initialize() {
this.playerManager = new PlayerManager(this);
this.playerManager.initialize();
@@ -78,6 +83,7 @@ public class Listener extends ListenerAdapter {
commands.put("pause", new PauseCommand(this));
commands.put("resume", new ResumeCommand(this));
commands.put("help", new PauseCommand(this));
}
public void closeAudioConnection(long guildID) {
@@ -144,7 +150,7 @@ public class Listener extends ListenerAdapter {
commands.get(command).execute(event);
} catch (Exception ex) {
LOGGER.error(ex.getMessage());
event.getHook().sendMessage("An error occurred while processing your command. Please contact your administrator.").queue();
event.getHook().sendMessage("An error occurred while processing your command. Please contact " + owner + ".").queue();
}
});
} else {

View File

@@ -13,8 +13,8 @@ import org.apache.logging.log4j.Logger;
import java.io.IOException;
import java.util.Arrays;
public class Bot {
private static final Logger LOGGER = LogManager.getLogger(Bot.class);
public class Main {
private static final Logger LOGGER = LogManager.getLogger(Main.class);
private final static GatewayIntent[] INTENTS = {
GatewayIntent.DIRECT_MESSAGES, GatewayIntent.GUILD_MESSAGES, GatewayIntent.GUILD_MESSAGE_REACTIONS,
GatewayIntent.GUILD_VOICE_STATES, GatewayIntent.MESSAGE_CONTENT

View File

@@ -89,7 +89,7 @@ public class PlayCommand extends Command {
if (exception.getMessage().contains("Unknown file format")) {
event.getHook().sendMessage(errorMsg + ". " + exception.getMessage()).queue();
} else {
event.getHook().sendMessage(errorMsg + ". Please contact your administrator.").queue();
event.getHook().sendMessage(errorMsg + ". Please contact " + listener.getOwner() + ".").queue();
}
LOGGER.error("{}: {}", errorMsg, exception.getMessage());
}

View File

@@ -1,42 +1,47 @@
package com.bensherriff.siren.database;
import com.theokanning.openai.completion.chat.ChatMessage;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.sql.*;
import java.util.List;
import java.util.*;
import static java.util.Map.entry;
public class DatabaseManager {
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" +
"id SERIAL PRIMARY KEY, " +
"guild_id BIGINT NOT NULL, " +
"thread_id BIGINT NOT NULL, " +
"user_id BIGINT NOT NULL, " +
"message_type VARCHAR(20) NOT NULL," +
"message_text TEXT NOT NULL," +
"message_response TEXT, " +
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
")";
private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" +
"id SERIAL PRIMARY KEY, " +
"embeddings FLOAT[] NOT NULL" +
");";
private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" +
"id SERIAL PRIMARY KEY, " +
"message_id INT NOT NULL, " +
"embedding_id INT NOT NULL" +
")";
private static final Map<String, String> createTableQueries = Map.ofEntries(
entry("messages", "CREATE TABLE IF NOT EXISTS messages (" +
"id SERIAL PRIMARY KEY, " +
"guild_id BIGINT NOT NULL, " +
"thread_id BIGINT NOT NULL, " +
"user_id BIGINT NOT NULL, " +
"message_type VARCHAR(20) NOT NULL," +
"message_text TEXT NOT NULL," +
"message_response TEXT, " +
"topics TEXT[], " +
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
")"),
entry("embeddings", "CREATE TABLE IF NOT EXISTS embeddings (" +
"id SERIAL PRIMARY KEY, " +
"embeddings FLOAT[] NOT NULL" +
")"),
entry("message_embeddings", "CREATE TABLE IF NOT EXISTS message_embeddings (" +
"id SERIAL PRIMARY KEY, " +
"message_id INT NOT NULL, " +
"embedding_id INT NOT NULL" +
")")
);
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
"message_type, " +
"guild_id, " +
"thread_id, " +
"user_id, " +
"message_text, " +
"message_response) " +
"VALUES (?, ?, ?, ?, ?, ?)";
"message_response, " +
"topics) " +
"VALUES (?, ?, ?, ?, ?, ?, ?)";
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
"embeddings) " +
"VALUES (?)";
@@ -47,52 +52,47 @@ public class DatabaseManager {
"VALUES (?, ?)";
public static void createTables() {
createMessageTable();
createEmbeddingsTable();
createMessageEmbeddingsTable();
for (Map.Entry<String, String> entry : createTableQueries.entrySet()) {
if (!createTable(entry.getKey())) {
LOGGER.warn("Failed to create one or more required database tables");
return;
}
}
LOGGER.debug("Successfully created database tables");
}
private static void createMessageTable() {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'messages' database table if it does not exist");
Statement statement = connection.createStatement();
statement.execute(CREATE_MESSAGE_TABLE);
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
private static boolean createTable(String tableName) {
if (tableExists(tableName)) {
return true;
} else {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating '{}' database table if it does not exist", tableName);
Statement statement = connection.createStatement();
statement.execute(createTableQueries.get(tableName));
return true;
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return false;
}
}
}
private static void createEmbeddingsTable() {
public static int storeMessage(MessageData messageData) {
if (!tableExists("messages")) {
LOGGER.warn("Table 'messages' does not exist");
return -1;
}
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'embeddings' database table if it does not exist");
Statement statement = connection.createStatement();
statement.execute(CREATE_EMBEDDINGS_TABLE);
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
}
private static void createMessageEmbeddingsTable() {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'message_embeddings' database table if it does not exist");
Statement statement = connection.createStatement();
statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE);
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
}
public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) {
preparedStatement.setString(1, message.getRole());
preparedStatement.setLong(2, guildId);
preparedStatement.setLong(3, threadId);
preparedStatement.setLong(4, userId);
preparedStatement.setString(5, message.getContent());
preparedStatement.setString(6, response);
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE);
preparedStatement.setString(1, messageData.getMessageType());
preparedStatement.setLong(2, messageData.getGuildId());
preparedStatement.setLong(3, messageData.getThreadId());
preparedStatement.setLong(4, messageData.getUserId());
preparedStatement.setString(5, messageData.getMessageText());
preparedStatement.setString(6, messageData.getMessageResponse());
preparedStatement.setArray(7, connection.createArrayOf("text", messageData.getTopics().toArray()));
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
@@ -101,7 +101,13 @@ public class DatabaseManager {
}
public static int storeEmbedding(List<Double> data) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) {
if (!tableExists("embeddings")) {
LOGGER.warn("Table 'embeddings' does not exist");
return -1;
}
try {
Connection connection = DatabaseConnection.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING);
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
return preparedStatement.executeUpdate();
} catch (SQLException ex) {
@@ -111,7 +117,13 @@ public class DatabaseManager {
}
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) {
if (!tableExists("message_embeddings")) {
LOGGER.warn("Table 'message_embeddings' does not exist");
return -1;
}
try {
Connection connection = DatabaseConnection.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS);
preparedStatement.setInt(1, messageId);
preparedStatement.setInt(2, embeddingId);
return preparedStatement.executeUpdate();
@@ -120,4 +132,61 @@ public class DatabaseManager {
return -1;
}
}
public static List<MessageData> getMessages(String query, Object... params) throws SQLException, IllegalArgumentException {
LOGGER.trace("Query: <{}>", query);
Connection connection = DatabaseConnection.getConnection();
PreparedStatement stmt = connection.prepareStatement(query);
int i = 1;
for (Object param : params) {
if (param instanceof String) {
stmt.setString(i++, (String) param);
} else if (param instanceof Integer) {
stmt.setInt(i++, (Integer) param);
} else if (param instanceof Long) {
stmt.setLong(i++, (Long) param);
} else if (param instanceof Double) {
stmt.setDouble(i++, (Double) param);
} else if (param instanceof Float) {
stmt.setFloat(i++, (Float) param);
} else if (param instanceof Timestamp) {
stmt.setTimestamp(i++, (Timestamp) param);
} else if (param instanceof Boolean) {
stmt.setBoolean(i++, (Boolean) param);
} else if (param instanceof List) {
stmt.setArray(i++, connection.createArrayOf("text", ((List<?>) param).toArray()));
} else {
throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass().getName());
}
}
ResultSet resultSet = stmt.executeQuery();
List<MessageData> resultList = new ArrayList<>();
while (resultSet.next()) {
Array array = resultSet.getArray(8);
MessageData messageData = new MessageData(
resultSet.getLong(2),
resultSet.getLong(3),
resultSet.getLong(4),
resultSet.getString(5),
resultSet.getString(6),
resultSet.getString(7),
new HashSet<>(Arrays.asList((String[]) array.getArray())),
resultSet.getTimestamp(9)
);
resultList.add(messageData);
}
return resultList;
}
private static boolean tableExists(String tableName) {
try {
Connection connection = DatabaseConnection.getConnection();
Statement statement = connection.createStatement();
ResultSet resultSet = statement.executeQuery("SELECT tablename FROM pg_tables WHERE tablename = '" + tableName + "'");
return resultSet.next();
} catch (SQLException ex) {
LOGGER.error("Failed to check if table exists; {}" + ex.getMessage());
return false;
}
}
}

View File

@@ -0,0 +1,74 @@
package com.bensherriff.siren.database;
import java.sql.Timestamp;
import java.util.Set;
public class MessageData {
private final Long guildId;
private final Long threadId;
private final Long userId;
private final String messageType;
private final String messageText;
private final String messageResponse;
private final Set<String> topics;
private Timestamp timestamp;
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
String messageResponse, Set<String> topics) {
this.guildId = guildId;
this.threadId = threadId;
this.userId = userId;
this.messageType = messageType;
this.messageText = messageText;
this.messageResponse = messageResponse;
this.topics = topics;
}
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
String messageResponse, Set<String> topics, Timestamp timestamp) {
this.guildId = guildId;
this.threadId = threadId;
this.userId = userId;
this.messageType = messageType;
this.messageText = messageText;
this.messageResponse = messageResponse;
this.topics = topics;
this.timestamp = timestamp;
}
public Long getGuildId() {
return guildId;
}
public Long getThreadId() {
return threadId;
}
public Long getUserId() {
return userId;
}
public String getMessageType() {
return messageType;
}
public String getMessageText() {
return messageText;
}
public String getMessageResponse() {
return messageResponse;
}
public Set<String> getTopics() {
return topics;
}
public Timestamp getTimestamp() {
return timestamp;
}
public void setTimestamp(Timestamp timestamp) {
this.timestamp = timestamp;
}
}

View File

@@ -0,0 +1,68 @@
package com.bensherriff.siren.database;
public class QueryBuilder {
private boolean distinct;
private String columnList;
private String tableName;
private String whereClause;
private String orderByClause;
private Integer limit;
public QueryBuilder select(String columnList) {
this.columnList = columnList;
return this;
}
public QueryBuilder from(String tableName) {
this.tableName = tableName;
return this;
}
public QueryBuilder where(String whereClause) {
this.whereClause = whereClause;
return this;
}
public QueryBuilder orderBy(String orderByClause) {
this.orderByClause = orderByClause;
return this;
}
public QueryBuilder limit(int limit) {
this.limit = limit;
return this;
}
public QueryBuilder distinct(boolean distinct) {
this.distinct = distinct;
return this;
}
public String build() {
StringBuilder queryBuilder = new StringBuilder();
queryBuilder.append("SELECT ");
if (distinct) {
queryBuilder.append("DISTINCT ");
}
if (columnList != null && !columnList.isEmpty()) {
queryBuilder.append(columnList);
} else {
queryBuilder.append("*");
}
queryBuilder.append(" FROM ");
queryBuilder.append(tableName);
if (whereClause != null) {
queryBuilder.append(" WHERE ");
queryBuilder.append(whereClause);
}
if (orderByClause != null) {
queryBuilder.append(" ORDER BY ");
queryBuilder.append(orderByClause);
}
if (limit != null) {
queryBuilder.append(" LIMIT ");
queryBuilder.append(limit);
}
return queryBuilder.toString();
}
}

View File

@@ -0,0 +1,95 @@
package com.bensherriff.siren.openai;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import edu.stanford.nlp.util.CoreMap;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.*;
public class NLP {
private static final Logger LOGGER = LogManager.getLogger(NLP.class);
private final StanfordCoreNLP pipeline;
private final Map<String, List<String>> keywords;
public NLP() {
Properties props = new Properties();
props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner");
pipeline = new StanfordCoreNLP(props);
keywords = new HashMap<>();
keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid"));
}
public Set<String> getTopicKeywords(String sentence) {
Set<String> topics = new LinkedHashSet<>();
Annotation document = new Annotation(sentence);
pipeline.annotate(document);
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
CoreMap sentenceMap = sentences.get(0);
List<CoreLabel> tokens = sentenceMap.get(CoreAnnotations.TokensAnnotation.class);
List<CoreLabel> namedEntities = new ArrayList<>();
for (CoreLabel token : tokens) {
String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
if (!ne.equals("0")) {
namedEntities.add(token);
}
}
for (CoreLabel namedEntity : namedEntities) {
String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
String word = namedEntity.word();
if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
topics.add(word);
} else if (ne.equals("LOCATION")) {
String[] posTags = word.split("_");
for (String posTag : posTags) {
if (posTag.startsWith("N")) {
topics.add(word);
break;
}
}
} else {
String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
if (pos.startsWith("NN")) {
topics.add(word);
for (String keyword : keywords.keySet()) {
if (keywords.get(keyword).contains(word.toLowerCase())) {
topics.add(keyword);
}
}
if (Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid").contains(word.toLowerCase())) {
topics.add("dnd");
}
}
}
}
return topics;
}
public List<String> lemmatize(String documentText) {
List<String> lemmas = new ArrayList<>();
Annotation document = new Annotation(documentText);
pipeline.annotate(document);
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
for (CoreMap sentence : sentences) {
for (CoreLabel token : sentence.get(CoreAnnotations.TokensAnnotation.class)) {
lemmas.add(token.get(CoreAnnotations.LemmaAnnotation.class));
}
}
return lemmas;
}
// TODO finish this method
public Set<String> getSynonyms(String targetWord) {
return Collections.emptySet();
}
// TODO finish this method
public static double calculateSimilarity(String sentence1, String sentence2) {
return 0.0;
}
}

View File

@@ -2,6 +2,8 @@ package com.bensherriff.siren.openai;
import com.bensherriff.siren.Listener;
import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.database.MessageData;
import com.bensherriff.siren.database.QueryBuilder;
import com.bensherriff.siren.settings.GuildSettings;
import com.bensherriff.siren.settings.Settings;
import com.theokanning.openai.completion.CompletionRequest;
@@ -11,7 +13,6 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.embedding.Embedding;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.service.OpenAiService;
import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.entities.Message;
@@ -21,6 +22,7 @@ import net.dv8tion.jda.api.events.message.MessageReceivedEvent;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.sql.SQLException;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.ScheduledExecutorService;
@@ -32,11 +34,15 @@ public class OpenAIManager {
private final Settings settings;
private final JDA jda;
private final ScheduledExecutorService executor;
private final String owner;
private final NLP NLP;
public OpenAIManager(Listener listener) {
this.settings = listener.getSettings();
this.jda = listener.getJDA();
this.executor = listener.getExecutor();
this.owner = listener.getOwner();
this.NLP = new NLP();
if (settings.getOpenAISettings().getToken().isEmpty()) {
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
@@ -56,19 +62,13 @@ public class OpenAIManager {
if (openAiService != null) {
executor.execute(() -> sendMessage(event));
} else {
event.getMessage().reply("OpenAI functionality is disabled. Please contact an administrator").queue();
event.getMessage().reply("OpenAI functionality is disabled. Please contact " + owner + ".").queue();
}
}
}
private void sendMessage(MessageReceivedEvent event) {
String message = parseMessage(event.getMessage().getContentRaw());
long guildId = event.getGuild().getIdLong();
Model model = settings.getGuildSettings().get(guildId).getModel();
StringBuilder stringBuilder = new StringBuilder();
List<ChatMessage> chatMessages = new ArrayList<>();
ChatMessage chatMessage = createChatMessage(message, event);
List<Embedding> embeddings = new ArrayList<>();
if (message.isEmpty() || message.isBlank()) {
event.getMessage().reply("Your message is empty. Please try again").queue();
@@ -76,7 +76,13 @@ public class OpenAIManager {
}
try {
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
StringBuilder stringBuilder = new StringBuilder();
long guildId = event.getGuild().getIdLong();
Model model = settings.getGuildSettings().get(guildId).getModel();
ChatMessage chatMessage = createChatMessage(message, event);
List<Embedding> embeddings = new ArrayList<>();
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, event.getAuthor().getId(), event.getMessageId(), message);
// Send OpenAI Message and get response
switch (model) {
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
@@ -85,8 +91,7 @@ public class OpenAIManager {
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
}
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
chatMessages.add(chatMessage);
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event);
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
@@ -96,7 +101,7 @@ public class OpenAIManager {
// embeddings.addAll(embeddingResult.getData());
}
default -> {
event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue();
event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
model, Arrays.toString(Model.values()));
return;
@@ -105,7 +110,7 @@ public class OpenAIManager {
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
} catch (Exception ex) {
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue();
event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue();
}
}
@@ -118,8 +123,30 @@ public class OpenAIManager {
.build();
}
private ChatCompletionRequest createCompletionRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(chatMessage);
// Handle System Messages
chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
chatMessages.add(createSystemMessage("I am a user named " + event.getAuthor().getName()));
if (event.isFromThread()) {
String query = new QueryBuilder().from("messages")
.where("guild_id = ? AND thread_id = ?")
.orderBy("timestamp DESC")
.limit(10)
.build();
List<MessageData> previousMessages = DatabaseManager.getMessages(
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
for (MessageData previousMessage : previousMessages) {
ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " +
previousMessage.getTimestamp() + " which said \"" + previousMessage.getMessageText() +
"\". You replied with \"" + previousMessage.getMessageResponse() + "\".");
chatMessages.add(previousChatMessage);
}
}
return ChatCompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
@@ -146,6 +173,10 @@ public class OpenAIManager {
.build();
}
private ChatMessage createSystemMessage(String message) {
return new ChatMessage(Role.SYSTEM.getName(), message);
}
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
ChatMessage chatMessage = new ChatMessage();
chatMessage.setContent(message);
@@ -158,33 +189,38 @@ public class OpenAIManager {
}
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
LOGGER.trace("Message Response <{}>: {}", event.getMessageId(), response);
Set<String> topics = new LinkedHashSet<>();
topics.addAll(NLP.getTopicKeywords(chatMessage.getContent()));
topics.addAll(NLP.getTopicKeywords(response));
LOGGER.trace("Topics: {}", topics);
if (event.isFromThread()) {
ThreadChannel channel = event.getChannel().asThreadChannel();
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
channel.getIdLong(), response);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
channel.sendMessage(response).queue();
storeMessage(chatMessage, event, response, embeddings, topics, channel);
} else {
// The max discord title length is 100 characters
String threadTitle = chatMessage.getContent();
if (chatMessage.getContent().length() > 100) {
threadTitle = chatMessage.getContent().substring(0, 100);
}
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
channel.sendMessage(response).queue();
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
channel.getIdLong(), response);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
});
event.getMessage().createThreadChannel(threadTitle).queue(channel -> storeMessage(chatMessage, event, response, embeddings, topics, channel));
}
}
private void storeMessage(ChatMessage chatMessage, MessageReceivedEvent event, String response,
List<Embedding> embeddings, Set<String> topics, ThreadChannel channel) {
MessageData messageData = new MessageData(event.getGuild().getIdLong(),
channel.getIdLong(), event.getAuthor().getIdLong(), chatMessage.getRole(), chatMessage.getContent(),
response, topics);
int messageRow = DatabaseManager.storeMessage(messageData);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
channel.sendMessage(response).queue();
}
private String parseMessage(String input) {
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
}