v0.1.16 Query builder and working on memory management

This commit is contained in:
2023-04-16 07:58:16 -04:00
parent 414bca2dd6
commit e6fb86c4e0
10 changed files with 464 additions and 103 deletions

View File

@@ -1 +1 @@
export SIREN_VERSION=0.1.15 export SIREN_VERSION=0.1.16

19
pom.xml
View File

@@ -61,13 +61,14 @@
</distributionManagement> </distributionManagement>
<properties> <properties>
<jda.version>5.0.0-beta.6</jda.version> <jda.version>5.0.0-beta.8</jda.version>
<lavaplayer.version>1.4.0</lavaplayer.version> <lavaplayer.version>1.4.0</lavaplayer.version>
<lavaplayer-natives-extra.version>1.3.13</lavaplayer-natives-extra.version> <lavaplayer-natives-extra.version>1.3.13</lavaplayer-natives-extra.version>
<jackson.version>2.14.2</jackson.version> <jackson.version>2.14.2</jackson.version>
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version> <theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
<postgresql.version>42.6.0</postgresql.version> <postgresql.version>42.6.0</postgresql.version>
<slf4j.version>2.0.6</slf4j.version> <corenlp.version>4.5.3</corenlp.version>
<slf4j.version>2.0.7</slf4j.version>
<log4j.version>2.20.0</log4j.version> <log4j.version>2.20.0</log4j.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>17</maven.compiler.source> <maven.compiler.source>17</maven.compiler.source>
@@ -105,6 +106,18 @@
<artifactId>postgresql</artifactId> <artifactId>postgresql</artifactId>
<version>${postgresql.version}</version> <version>${postgresql.version}</version>
</dependency> </dependency>
<dependency>
<groupId>edu.stanford.nlp</groupId>
<artifactId>stanford-corenlp</artifactId>
<version>${corenlp.version}</version>
</dependency>
<dependency>
<groupId>edu.stanford.nlp</groupId>
<artifactId>stanford-corenlp</artifactId>
<version>${corenlp.version}</version>
<classifier>models</classifier>
<scope>runtime</scope>
</dependency>
<!-- Logging --> <!-- Logging -->
<dependency> <dependency>
@@ -150,7 +163,7 @@
</transformer> </transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<manifestEntries> <manifestEntries>
<Main-Class>com.bensherriff.siren.Bot</Main-Class> <Main-Class>com.bensherriff.siren.Main</Main-Class>
<Specification-Title>${project.artifactId}</Specification-Title> <Specification-Title>${project.artifactId}</Specification-Title>
<Specification-Version>${project.version}</Specification-Version> <Specification-Version>${project.version}</Specification-Version>
<Implementation-Title>${project.artifactId}</Implementation-Title> <Implementation-Title>${project.artifactId}</Implementation-Title>

View File

@@ -34,14 +34,15 @@ public class Listener extends ListenerAdapter {
private final Settings settings; private final Settings settings;
private final Map<String, Command> commands = new HashMap<>(); private final Map<String, Command> commands = new HashMap<>();
private final String owner;
private PlayerManager playerManager; private PlayerManager playerManager;
private OpenAIManager openAIManager; private OpenAIManager openAIManager;
private JDA jda; private JDA jda;
public Listener(Settings settings) { public Listener(Settings settings) {
this.settings = settings; this.settings = settings;
this.executor = Executors.newScheduledThreadPool(this.settings.getThreadPool()); this.executor = Executors.newScheduledThreadPool(this.settings.getThreadPool());
this.owner = "@bsherriff";
} }
public ScheduledExecutorService getExecutor() { public ScheduledExecutorService getExecutor() {
@@ -64,6 +65,10 @@ public class Listener extends ListenerAdapter {
this.jda = jda; this.jda = jda;
} }
public String getOwner() {
return owner;
}
public void initialize() { public void initialize() {
this.playerManager = new PlayerManager(this); this.playerManager = new PlayerManager(this);
this.playerManager.initialize(); this.playerManager.initialize();
@@ -78,6 +83,7 @@ public class Listener extends ListenerAdapter {
commands.put("pause", new PauseCommand(this)); commands.put("pause", new PauseCommand(this));
commands.put("resume", new ResumeCommand(this)); commands.put("resume", new ResumeCommand(this));
commands.put("help", new PauseCommand(this)); commands.put("help", new PauseCommand(this));
} }
public void closeAudioConnection(long guildID) { public void closeAudioConnection(long guildID) {
@@ -144,7 +150,7 @@ public class Listener extends ListenerAdapter {
commands.get(command).execute(event); commands.get(command).execute(event);
} catch (Exception ex) { } catch (Exception ex) {
LOGGER.error(ex.getMessage()); LOGGER.error(ex.getMessage());
event.getHook().sendMessage("An error occurred while processing your command. Please contact your administrator.").queue(); event.getHook().sendMessage("An error occurred while processing your command. Please contact " + owner + ".").queue();
} }
}); });
} else { } else {

View File

@@ -13,8 +13,8 @@ import org.apache.logging.log4j.Logger;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
public class Bot { public class Main {
private static final Logger LOGGER = LogManager.getLogger(Bot.class); private static final Logger LOGGER = LogManager.getLogger(Main.class);
private final static GatewayIntent[] INTENTS = { private final static GatewayIntent[] INTENTS = {
GatewayIntent.DIRECT_MESSAGES, GatewayIntent.GUILD_MESSAGES, GatewayIntent.GUILD_MESSAGE_REACTIONS, GatewayIntent.DIRECT_MESSAGES, GatewayIntent.GUILD_MESSAGES, GatewayIntent.GUILD_MESSAGE_REACTIONS,
GatewayIntent.GUILD_VOICE_STATES, GatewayIntent.MESSAGE_CONTENT GatewayIntent.GUILD_VOICE_STATES, GatewayIntent.MESSAGE_CONTENT

View File

@@ -89,7 +89,7 @@ public class PlayCommand extends Command {
if (exception.getMessage().contains("Unknown file format")) { if (exception.getMessage().contains("Unknown file format")) {
event.getHook().sendMessage(errorMsg + ". " + exception.getMessage()).queue(); event.getHook().sendMessage(errorMsg + ". " + exception.getMessage()).queue();
} else { } else {
event.getHook().sendMessage(errorMsg + ". Please contact your administrator.").queue(); event.getHook().sendMessage(errorMsg + ". Please contact " + listener.getOwner() + ".").queue();
} }
LOGGER.error("{}: {}", errorMsg, exception.getMessage()); LOGGER.error("{}: {}", errorMsg, exception.getMessage());
} }

View File

@@ -1,42 +1,47 @@
package com.bensherriff.siren.database; package com.bensherriff.siren.database;
import com.theokanning.openai.completion.chat.ChatMessage;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import java.sql.*; import java.sql.*;
import java.util.List; import java.util.*;
import static java.util.Map.entry;
public class DatabaseManager { public class DatabaseManager {
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class); private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" +
"id SERIAL PRIMARY KEY, " +
"guild_id BIGINT NOT NULL, " +
"thread_id BIGINT NOT NULL, " +
"user_id BIGINT NOT NULL, " +
"message_type VARCHAR(20) NOT NULL," +
"message_text TEXT NOT NULL," +
"message_response TEXT, " +
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
")";
private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" +
"id SERIAL PRIMARY KEY, " +
"embeddings FLOAT[] NOT NULL" +
");";
private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" + private static final Map<String, String> createTableQueries = Map.ofEntries(
"id SERIAL PRIMARY KEY, " + entry("messages", "CREATE TABLE IF NOT EXISTS messages (" +
"message_id INT NOT NULL, " + "id SERIAL PRIMARY KEY, " +
"embedding_id INT NOT NULL" + "guild_id BIGINT NOT NULL, " +
")"; "thread_id BIGINT NOT NULL, " +
"user_id BIGINT NOT NULL, " +
"message_type VARCHAR(20) NOT NULL," +
"message_text TEXT NOT NULL," +
"message_response TEXT, " +
"topics TEXT[], " +
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
")"),
entry("embeddings", "CREATE TABLE IF NOT EXISTS embeddings (" +
"id SERIAL PRIMARY KEY, " +
"embeddings FLOAT[] NOT NULL" +
")"),
entry("message_embeddings", "CREATE TABLE IF NOT EXISTS message_embeddings (" +
"id SERIAL PRIMARY KEY, " +
"message_id INT NOT NULL, " +
"embedding_id INT NOT NULL" +
")")
);
private static final String INSERT_MESSAGE = "INSERT INTO messages (" + private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
"message_type, " + "message_type, " +
"guild_id, " + "guild_id, " +
"thread_id, " + "thread_id, " +
"user_id, " + "user_id, " +
"message_text, " + "message_text, " +
"message_response) " + "message_response, " +
"VALUES (?, ?, ?, ?, ?, ?)"; "topics) " +
"VALUES (?, ?, ?, ?, ?, ?, ?)";
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" + private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
"embeddings) " + "embeddings) " +
"VALUES (?)"; "VALUES (?)";
@@ -47,52 +52,47 @@ public class DatabaseManager {
"VALUES (?, ?)"; "VALUES (?, ?)";
public static void createTables() { public static void createTables() {
createMessageTable(); for (Map.Entry<String, String> entry : createTableQueries.entrySet()) {
createEmbeddingsTable(); if (!createTable(entry.getKey())) {
createMessageEmbeddingsTable(); LOGGER.warn("Failed to create one or more required database tables");
return;
}
}
LOGGER.debug("Successfully created database tables");
} }
private static void createMessageTable() { private static boolean createTable(String tableName) {
try { if (tableExists(tableName)) {
Connection connection = DatabaseConnection.getConnection(); return true;
LOGGER.debug("Creating 'messages' database table if it does not exist"); } else {
Statement statement = connection.createStatement(); try {
statement.execute(CREATE_MESSAGE_TABLE); Connection connection = DatabaseConnection.getConnection();
} catch (SQLException ex) { LOGGER.debug("Creating '{}' database table if it does not exist", tableName);
LOGGER.error(ex.getMessage()); Statement statement = connection.createStatement();
statement.execute(createTableQueries.get(tableName));
return true;
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
return false;
}
} }
} }
private static void createEmbeddingsTable() { public static int storeMessage(MessageData messageData) {
if (!tableExists("messages")) {
LOGGER.warn("Table 'messages' does not exist");
return -1;
}
try { try {
Connection connection = DatabaseConnection.getConnection(); Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'embeddings' database table if it does not exist"); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE);
Statement statement = connection.createStatement(); preparedStatement.setString(1, messageData.getMessageType());
statement.execute(CREATE_EMBEDDINGS_TABLE); preparedStatement.setLong(2, messageData.getGuildId());
} catch (SQLException ex) { preparedStatement.setLong(3, messageData.getThreadId());
LOGGER.error(ex.getMessage()); preparedStatement.setLong(4, messageData.getUserId());
} preparedStatement.setString(5, messageData.getMessageText());
} preparedStatement.setString(6, messageData.getMessageResponse());
preparedStatement.setArray(7, connection.createArrayOf("text", messageData.getTopics().toArray()));
private static void createMessageEmbeddingsTable() {
try {
Connection connection = DatabaseConnection.getConnection();
LOGGER.debug("Creating 'message_embeddings' database table if it does not exist");
Statement statement = connection.createStatement();
statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE);
} catch (SQLException ex) {
LOGGER.error(ex.getMessage());
}
}
public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) {
preparedStatement.setString(1, message.getRole());
preparedStatement.setLong(2, guildId);
preparedStatement.setLong(3, threadId);
preparedStatement.setLong(4, userId);
preparedStatement.setString(5, message.getContent());
preparedStatement.setString(6, response);
return preparedStatement.executeUpdate(); return preparedStatement.executeUpdate();
} catch (SQLException ex) { } catch (SQLException ex) {
LOGGER.error(ex.getMessage()); LOGGER.error(ex.getMessage());
@@ -101,7 +101,13 @@ public class DatabaseManager {
} }
public static int storeEmbedding(List<Double> data) { public static int storeEmbedding(List<Double> data) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) { if (!tableExists("embeddings")) {
LOGGER.warn("Table 'embeddings' does not exist");
return -1;
}
try {
Connection connection = DatabaseConnection.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING);
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0]))); preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
return preparedStatement.executeUpdate(); return preparedStatement.executeUpdate();
} catch (SQLException ex) { } catch (SQLException ex) {
@@ -111,7 +117,13 @@ public class DatabaseManager {
} }
public static int storeMessageEmbeddings(int messageId, int embeddingId) { public static int storeMessageEmbeddings(int messageId, int embeddingId) {
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) { if (!tableExists("message_embeddings")) {
LOGGER.warn("Table 'message_embeddings' does not exist");
return -1;
}
try {
Connection connection = DatabaseConnection.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS);
preparedStatement.setInt(1, messageId); preparedStatement.setInt(1, messageId);
preparedStatement.setInt(2, embeddingId); preparedStatement.setInt(2, embeddingId);
return preparedStatement.executeUpdate(); return preparedStatement.executeUpdate();
@@ -120,4 +132,61 @@ public class DatabaseManager {
return -1; return -1;
} }
} }
public static List<MessageData> getMessages(String query, Object... params) throws SQLException, IllegalArgumentException {
LOGGER.trace("Query: <{}>", query);
Connection connection = DatabaseConnection.getConnection();
PreparedStatement stmt = connection.prepareStatement(query);
int i = 1;
for (Object param : params) {
if (param instanceof String) {
stmt.setString(i++, (String) param);
} else if (param instanceof Integer) {
stmt.setInt(i++, (Integer) param);
} else if (param instanceof Long) {
stmt.setLong(i++, (Long) param);
} else if (param instanceof Double) {
stmt.setDouble(i++, (Double) param);
} else if (param instanceof Float) {
stmt.setFloat(i++, (Float) param);
} else if (param instanceof Timestamp) {
stmt.setTimestamp(i++, (Timestamp) param);
} else if (param instanceof Boolean) {
stmt.setBoolean(i++, (Boolean) param);
} else if (param instanceof List) {
stmt.setArray(i++, connection.createArrayOf("text", ((List<?>) param).toArray()));
} else {
throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass().getName());
}
}
ResultSet resultSet = stmt.executeQuery();
List<MessageData> resultList = new ArrayList<>();
while (resultSet.next()) {
Array array = resultSet.getArray(8);
MessageData messageData = new MessageData(
resultSet.getLong(2),
resultSet.getLong(3),
resultSet.getLong(4),
resultSet.getString(5),
resultSet.getString(6),
resultSet.getString(7),
new HashSet<>(Arrays.asList((String[]) array.getArray())),
resultSet.getTimestamp(9)
);
resultList.add(messageData);
}
return resultList;
}
private static boolean tableExists(String tableName) {
try {
Connection connection = DatabaseConnection.getConnection();
Statement statement = connection.createStatement();
ResultSet resultSet = statement.executeQuery("SELECT tablename FROM pg_tables WHERE tablename = '" + tableName + "'");
return resultSet.next();
} catch (SQLException ex) {
LOGGER.error("Failed to check if table exists; {}" + ex.getMessage());
return false;
}
}
} }

View File

@@ -0,0 +1,74 @@
package com.bensherriff.siren.database;
import java.sql.Timestamp;
import java.util.Set;
public class MessageData {
private final Long guildId;
private final Long threadId;
private final Long userId;
private final String messageType;
private final String messageText;
private final String messageResponse;
private final Set<String> topics;
private Timestamp timestamp;
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
String messageResponse, Set<String> topics) {
this.guildId = guildId;
this.threadId = threadId;
this.userId = userId;
this.messageType = messageType;
this.messageText = messageText;
this.messageResponse = messageResponse;
this.topics = topics;
}
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
String messageResponse, Set<String> topics, Timestamp timestamp) {
this.guildId = guildId;
this.threadId = threadId;
this.userId = userId;
this.messageType = messageType;
this.messageText = messageText;
this.messageResponse = messageResponse;
this.topics = topics;
this.timestamp = timestamp;
}
public Long getGuildId() {
return guildId;
}
public Long getThreadId() {
return threadId;
}
public Long getUserId() {
return userId;
}
public String getMessageType() {
return messageType;
}
public String getMessageText() {
return messageText;
}
public String getMessageResponse() {
return messageResponse;
}
public Set<String> getTopics() {
return topics;
}
public Timestamp getTimestamp() {
return timestamp;
}
public void setTimestamp(Timestamp timestamp) {
this.timestamp = timestamp;
}
}

View File

@@ -0,0 +1,68 @@
package com.bensherriff.siren.database;
public class QueryBuilder {
private boolean distinct;
private String columnList;
private String tableName;
private String whereClause;
private String orderByClause;
private Integer limit;
public QueryBuilder select(String columnList) {
this.columnList = columnList;
return this;
}
public QueryBuilder from(String tableName) {
this.tableName = tableName;
return this;
}
public QueryBuilder where(String whereClause) {
this.whereClause = whereClause;
return this;
}
public QueryBuilder orderBy(String orderByClause) {
this.orderByClause = orderByClause;
return this;
}
public QueryBuilder limit(int limit) {
this.limit = limit;
return this;
}
public QueryBuilder distinct(boolean distinct) {
this.distinct = distinct;
return this;
}
public String build() {
StringBuilder queryBuilder = new StringBuilder();
queryBuilder.append("SELECT ");
if (distinct) {
queryBuilder.append("DISTINCT ");
}
if (columnList != null && !columnList.isEmpty()) {
queryBuilder.append(columnList);
} else {
queryBuilder.append("*");
}
queryBuilder.append(" FROM ");
queryBuilder.append(tableName);
if (whereClause != null) {
queryBuilder.append(" WHERE ");
queryBuilder.append(whereClause);
}
if (orderByClause != null) {
queryBuilder.append(" ORDER BY ");
queryBuilder.append(orderByClause);
}
if (limit != null) {
queryBuilder.append(" LIMIT ");
queryBuilder.append(limit);
}
return queryBuilder.toString();
}
}

View File

@@ -0,0 +1,95 @@
package com.bensherriff.siren.openai;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import edu.stanford.nlp.util.CoreMap;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.*;
public class NLP {
private static final Logger LOGGER = LogManager.getLogger(NLP.class);
private final StanfordCoreNLP pipeline;
private final Map<String, List<String>> keywords;
public NLP() {
Properties props = new Properties();
props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner");
pipeline = new StanfordCoreNLP(props);
keywords = new HashMap<>();
keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid"));
}
public Set<String> getTopicKeywords(String sentence) {
Set<String> topics = new LinkedHashSet<>();
Annotation document = new Annotation(sentence);
pipeline.annotate(document);
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
CoreMap sentenceMap = sentences.get(0);
List<CoreLabel> tokens = sentenceMap.get(CoreAnnotations.TokensAnnotation.class);
List<CoreLabel> namedEntities = new ArrayList<>();
for (CoreLabel token : tokens) {
String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
if (!ne.equals("0")) {
namedEntities.add(token);
}
}
for (CoreLabel namedEntity : namedEntities) {
String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
String word = namedEntity.word();
if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
topics.add(word);
} else if (ne.equals("LOCATION")) {
String[] posTags = word.split("_");
for (String posTag : posTags) {
if (posTag.startsWith("N")) {
topics.add(word);
break;
}
}
} else {
String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
if (pos.startsWith("NN")) {
topics.add(word);
for (String keyword : keywords.keySet()) {
if (keywords.get(keyword).contains(word.toLowerCase())) {
topics.add(keyword);
}
}
if (Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid").contains(word.toLowerCase())) {
topics.add("dnd");
}
}
}
}
return topics;
}
public List<String> lemmatize(String documentText) {
List<String> lemmas = new ArrayList<>();
Annotation document = new Annotation(documentText);
pipeline.annotate(document);
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
for (CoreMap sentence : sentences) {
for (CoreLabel token : sentence.get(CoreAnnotations.TokensAnnotation.class)) {
lemmas.add(token.get(CoreAnnotations.LemmaAnnotation.class));
}
}
return lemmas;
}
// TODO finish this method
public Set<String> getSynonyms(String targetWord) {
return Collections.emptySet();
}
// TODO finish this method
public static double calculateSimilarity(String sentence1, String sentence2) {
return 0.0;
}
}

View File

@@ -2,6 +2,8 @@ package com.bensherriff.siren.openai;
import com.bensherriff.siren.Listener; import com.bensherriff.siren.Listener;
import com.bensherriff.siren.database.DatabaseManager; import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.database.MessageData;
import com.bensherriff.siren.database.QueryBuilder;
import com.bensherriff.siren.settings.GuildSettings; import com.bensherriff.siren.settings.GuildSettings;
import com.bensherriff.siren.settings.Settings; import com.bensherriff.siren.settings.Settings;
import com.theokanning.openai.completion.CompletionRequest; import com.theokanning.openai.completion.CompletionRequest;
@@ -11,7 +13,6 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.embedding.Embedding; import com.theokanning.openai.embedding.Embedding;
import com.theokanning.openai.embedding.EmbeddingRequest; import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.service.OpenAiService; import com.theokanning.openai.service.OpenAiService;
import net.dv8tion.jda.api.JDA; import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.entities.Message; import net.dv8tion.jda.api.entities.Message;
@@ -21,6 +22,7 @@ import net.dv8tion.jda.api.events.message.MessageReceivedEvent;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import java.sql.SQLException;
import java.time.Duration; import java.time.Duration;
import java.util.*; import java.util.*;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
@@ -32,11 +34,15 @@ public class OpenAIManager {
private final Settings settings; private final Settings settings;
private final JDA jda; private final JDA jda;
private final ScheduledExecutorService executor; private final ScheduledExecutorService executor;
private final String owner;
private final NLP NLP;
public OpenAIManager(Listener listener) { public OpenAIManager(Listener listener) {
this.settings = listener.getSettings(); this.settings = listener.getSettings();
this.jda = listener.getJDA(); this.jda = listener.getJDA();
this.executor = listener.getExecutor(); this.executor = listener.getExecutor();
this.owner = listener.getOwner();
this.NLP = new NLP();
if (settings.getOpenAISettings().getToken().isEmpty()) { if (settings.getOpenAISettings().getToken().isEmpty()) {
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled"); LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
@@ -56,19 +62,13 @@ public class OpenAIManager {
if (openAiService != null) { if (openAiService != null) {
executor.execute(() -> sendMessage(event)); executor.execute(() -> sendMessage(event));
} else { } else {
event.getMessage().reply("OpenAI functionality is disabled. Please contact an administrator").queue(); event.getMessage().reply("OpenAI functionality is disabled. Please contact " + owner + ".").queue();
} }
} }
} }
private void sendMessage(MessageReceivedEvent event) { private void sendMessage(MessageReceivedEvent event) {
String message = parseMessage(event.getMessage().getContentRaw()); String message = parseMessage(event.getMessage().getContentRaw());
long guildId = event.getGuild().getIdLong();
Model model = settings.getGuildSettings().get(guildId).getModel();
StringBuilder stringBuilder = new StringBuilder();
List<ChatMessage> chatMessages = new ArrayList<>();
ChatMessage chatMessage = createChatMessage(message, event);
List<Embedding> embeddings = new ArrayList<>();
if (message.isEmpty() || message.isBlank()) { if (message.isEmpty() || message.isBlank()) {
event.getMessage().reply("Your message is empty. Please try again").queue(); event.getMessage().reply("Your message is empty. Please try again").queue();
@@ -76,7 +76,13 @@ public class OpenAIManager {
} }
try { try {
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message); StringBuilder stringBuilder = new StringBuilder();
long guildId = event.getGuild().getIdLong();
Model model = settings.getGuildSettings().get(guildId).getModel();
ChatMessage chatMessage = createChatMessage(message, event);
List<Embedding> embeddings = new ArrayList<>();
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, event.getAuthor().getId(), event.getMessageId(), message);
// Send OpenAI Message and get response // Send OpenAI Message and get response
switch (model) { switch (model) {
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> { case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
@@ -85,8 +91,7 @@ public class OpenAIManager {
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim())); completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
} }
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> { case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
chatMessages.add(chatMessage); ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event);
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest); ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim())); chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
@@ -96,7 +101,7 @@ public class OpenAIManager {
// embeddings.addAll(embeddingResult.getData()); // embeddings.addAll(embeddingResult.getData());
} }
default -> { default -> {
event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue(); event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId, LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
model, Arrays.toString(Model.values())); model, Arrays.toString(Model.values()));
return; return;
@@ -105,7 +110,7 @@ public class OpenAIManager {
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings); handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
} catch (Exception ex) { } catch (Exception ex) {
LOGGER.error("Caught exception while processing message; {}", ex.getMessage()); LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue(); event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue();
} }
} }
@@ -118,8 +123,30 @@ public class OpenAIManager {
.build(); .build();
} }
private ChatCompletionRequest createCompletionRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) { private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong()); GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(chatMessage);
// Handle System Messages
chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
chatMessages.add(createSystemMessage("I am a user named " + event.getAuthor().getName()));
if (event.isFromThread()) {
String query = new QueryBuilder().from("messages")
.where("guild_id = ? AND thread_id = ?")
.orderBy("timestamp DESC")
.limit(10)
.build();
List<MessageData> previousMessages = DatabaseManager.getMessages(
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
for (MessageData previousMessage : previousMessages) {
ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " +
previousMessage.getTimestamp() + " which said \"" + previousMessage.getMessageText() +
"\". You replied with \"" + previousMessage.getMessageResponse() + "\".");
chatMessages.add(previousChatMessage);
}
}
return ChatCompletionRequest.builder() return ChatCompletionRequest.builder()
.model(guildSettings.getModel().getName()) .model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens()) .maxTokens(guildSettings.getMaxTokens())
@@ -146,6 +173,10 @@ public class OpenAIManager {
.build(); .build();
} }
private ChatMessage createSystemMessage(String message) {
return new ChatMessage(Role.SYSTEM.getName(), message);
}
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) { private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
ChatMessage chatMessage = new ChatMessage(); ChatMessage chatMessage = new ChatMessage();
chatMessage.setContent(message); chatMessage.setContent(message);
@@ -158,33 +189,38 @@ public class OpenAIManager {
} }
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) { private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
LOGGER.trace("Message Response <{}>: {}", event.getMessageId(), response);
Set<String> topics = new LinkedHashSet<>();
topics.addAll(NLP.getTopicKeywords(chatMessage.getContent()));
topics.addAll(NLP.getTopicKeywords(response));
LOGGER.trace("Topics: {}", topics);
if (event.isFromThread()) { if (event.isFromThread()) {
ThreadChannel channel = event.getChannel().asThreadChannel(); ThreadChannel channel = event.getChannel().asThreadChannel();
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(), storeMessage(chatMessage, event, response, embeddings, topics, channel);
channel.getIdLong(), response);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
channel.sendMessage(response).queue();
} else { } else {
// The max discord title length is 100 characters // The max discord title length is 100 characters
String threadTitle = chatMessage.getContent(); String threadTitle = chatMessage.getContent();
if (chatMessage.getContent().length() > 100) { if (chatMessage.getContent().length() > 100) {
threadTitle = chatMessage.getContent().substring(0, 100); threadTitle = chatMessage.getContent().substring(0, 100);
} }
event.getMessage().createThreadChannel(threadTitle).queue(channel -> { event.getMessage().createThreadChannel(threadTitle).queue(channel -> storeMessage(chatMessage, event, response, embeddings, topics, channel));
channel.sendMessage(response).queue();
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
channel.getIdLong(), response);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
});
} }
} }
private void storeMessage(ChatMessage chatMessage, MessageReceivedEvent event, String response,
List<Embedding> embeddings, Set<String> topics, ThreadChannel channel) {
MessageData messageData = new MessageData(event.getGuild().getIdLong(),
channel.getIdLong(), event.getAuthor().getIdLong(), chatMessage.getRole(), chatMessage.getContent(),
response, topics);
int messageRow = DatabaseManager.storeMessage(messageData);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
channel.sendMessage(response).queue();
}
private String parseMessage(String input) { private String parseMessage(String input) {
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim(); return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
} }