v0.1.16 Query builder and working on memory management
This commit is contained in:
19
pom.xml
19
pom.xml
@@ -61,13 +61,14 @@
|
|||||||
</distributionManagement>
|
</distributionManagement>
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
<jda.version>5.0.0-beta.6</jda.version>
|
<jda.version>5.0.0-beta.8</jda.version>
|
||||||
<lavaplayer.version>1.4.0</lavaplayer.version>
|
<lavaplayer.version>1.4.0</lavaplayer.version>
|
||||||
<lavaplayer-natives-extra.version>1.3.13</lavaplayer-natives-extra.version>
|
<lavaplayer-natives-extra.version>1.3.13</lavaplayer-natives-extra.version>
|
||||||
<jackson.version>2.14.2</jackson.version>
|
<jackson.version>2.14.2</jackson.version>
|
||||||
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
|
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
|
||||||
<postgresql.version>42.6.0</postgresql.version>
|
<postgresql.version>42.6.0</postgresql.version>
|
||||||
<slf4j.version>2.0.6</slf4j.version>
|
<corenlp.version>4.5.3</corenlp.version>
|
||||||
|
<slf4j.version>2.0.7</slf4j.version>
|
||||||
<log4j.version>2.20.0</log4j.version>
|
<log4j.version>2.20.0</log4j.version>
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
<maven.compiler.source>17</maven.compiler.source>
|
<maven.compiler.source>17</maven.compiler.source>
|
||||||
@@ -105,6 +106,18 @@
|
|||||||
<artifactId>postgresql</artifactId>
|
<artifactId>postgresql</artifactId>
|
||||||
<version>${postgresql.version}</version>
|
<version>${postgresql.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>edu.stanford.nlp</groupId>
|
||||||
|
<artifactId>stanford-corenlp</artifactId>
|
||||||
|
<version>${corenlp.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>edu.stanford.nlp</groupId>
|
||||||
|
<artifactId>stanford-corenlp</artifactId>
|
||||||
|
<version>${corenlp.version}</version>
|
||||||
|
<classifier>models</classifier>
|
||||||
|
<scope>runtime</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<!-- Logging -->
|
<!-- Logging -->
|
||||||
<dependency>
|
<dependency>
|
||||||
@@ -150,7 +163,7 @@
|
|||||||
</transformer>
|
</transformer>
|
||||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||||
<manifestEntries>
|
<manifestEntries>
|
||||||
<Main-Class>com.bensherriff.siren.Bot</Main-Class>
|
<Main-Class>com.bensherriff.siren.Main</Main-Class>
|
||||||
<Specification-Title>${project.artifactId}</Specification-Title>
|
<Specification-Title>${project.artifactId}</Specification-Title>
|
||||||
<Specification-Version>${project.version}</Specification-Version>
|
<Specification-Version>${project.version}</Specification-Version>
|
||||||
<Implementation-Title>${project.artifactId}</Implementation-Title>
|
<Implementation-Title>${project.artifactId}</Implementation-Title>
|
||||||
|
|||||||
@@ -34,14 +34,15 @@ public class Listener extends ListenerAdapter {
|
|||||||
|
|
||||||
private final Settings settings;
|
private final Settings settings;
|
||||||
private final Map<String, Command> commands = new HashMap<>();
|
private final Map<String, Command> commands = new HashMap<>();
|
||||||
|
private final String owner;
|
||||||
private PlayerManager playerManager;
|
private PlayerManager playerManager;
|
||||||
private OpenAIManager openAIManager;
|
private OpenAIManager openAIManager;
|
||||||
private JDA jda;
|
private JDA jda;
|
||||||
|
|
||||||
public Listener(Settings settings) {
|
public Listener(Settings settings) {
|
||||||
this.settings = settings;
|
this.settings = settings;
|
||||||
|
|
||||||
this.executor = Executors.newScheduledThreadPool(this.settings.getThreadPool());
|
this.executor = Executors.newScheduledThreadPool(this.settings.getThreadPool());
|
||||||
|
this.owner = "@bsherriff";
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScheduledExecutorService getExecutor() {
|
public ScheduledExecutorService getExecutor() {
|
||||||
@@ -64,6 +65,10 @@ public class Listener extends ListenerAdapter {
|
|||||||
this.jda = jda;
|
this.jda = jda;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public String getOwner() {
|
||||||
|
return owner;
|
||||||
|
}
|
||||||
|
|
||||||
public void initialize() {
|
public void initialize() {
|
||||||
this.playerManager = new PlayerManager(this);
|
this.playerManager = new PlayerManager(this);
|
||||||
this.playerManager.initialize();
|
this.playerManager.initialize();
|
||||||
@@ -78,6 +83,7 @@ public class Listener extends ListenerAdapter {
|
|||||||
commands.put("pause", new PauseCommand(this));
|
commands.put("pause", new PauseCommand(this));
|
||||||
commands.put("resume", new ResumeCommand(this));
|
commands.put("resume", new ResumeCommand(this));
|
||||||
commands.put("help", new PauseCommand(this));
|
commands.put("help", new PauseCommand(this));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void closeAudioConnection(long guildID) {
|
public void closeAudioConnection(long guildID) {
|
||||||
@@ -144,7 +150,7 @@ public class Listener extends ListenerAdapter {
|
|||||||
commands.get(command).execute(event);
|
commands.get(command).execute(event);
|
||||||
} catch (Exception ex) {
|
} catch (Exception ex) {
|
||||||
LOGGER.error(ex.getMessage());
|
LOGGER.error(ex.getMessage());
|
||||||
event.getHook().sendMessage("An error occurred while processing your command. Please contact your administrator.").queue();
|
event.getHook().sendMessage("An error occurred while processing your command. Please contact " + owner + ".").queue();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ import org.apache.logging.log4j.Logger;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
public class Bot {
|
public class Main {
|
||||||
private static final Logger LOGGER = LogManager.getLogger(Bot.class);
|
private static final Logger LOGGER = LogManager.getLogger(Main.class);
|
||||||
private final static GatewayIntent[] INTENTS = {
|
private final static GatewayIntent[] INTENTS = {
|
||||||
GatewayIntent.DIRECT_MESSAGES, GatewayIntent.GUILD_MESSAGES, GatewayIntent.GUILD_MESSAGE_REACTIONS,
|
GatewayIntent.DIRECT_MESSAGES, GatewayIntent.GUILD_MESSAGES, GatewayIntent.GUILD_MESSAGE_REACTIONS,
|
||||||
GatewayIntent.GUILD_VOICE_STATES, GatewayIntent.MESSAGE_CONTENT
|
GatewayIntent.GUILD_VOICE_STATES, GatewayIntent.MESSAGE_CONTENT
|
||||||
@@ -89,7 +89,7 @@ public class PlayCommand extends Command {
|
|||||||
if (exception.getMessage().contains("Unknown file format")) {
|
if (exception.getMessage().contains("Unknown file format")) {
|
||||||
event.getHook().sendMessage(errorMsg + ". " + exception.getMessage()).queue();
|
event.getHook().sendMessage(errorMsg + ". " + exception.getMessage()).queue();
|
||||||
} else {
|
} else {
|
||||||
event.getHook().sendMessage(errorMsg + ". Please contact your administrator.").queue();
|
event.getHook().sendMessage(errorMsg + ". Please contact " + listener.getOwner() + ".").queue();
|
||||||
}
|
}
|
||||||
LOGGER.error("{}: {}", errorMsg, exception.getMessage());
|
LOGGER.error("{}: {}", errorMsg, exception.getMessage());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,42 +1,47 @@
|
|||||||
package com.bensherriff.siren.database;
|
package com.bensherriff.siren.database;
|
||||||
|
|
||||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
|
|
||||||
import java.sql.*;
|
import java.sql.*;
|
||||||
import java.util.List;
|
import java.util.*;
|
||||||
|
|
||||||
|
import static java.util.Map.entry;
|
||||||
|
|
||||||
public class DatabaseManager {
|
public class DatabaseManager {
|
||||||
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
|
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
|
||||||
private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" +
|
|
||||||
"id SERIAL PRIMARY KEY, " +
|
|
||||||
"guild_id BIGINT NOT NULL, " +
|
|
||||||
"thread_id BIGINT NOT NULL, " +
|
|
||||||
"user_id BIGINT NOT NULL, " +
|
|
||||||
"message_type VARCHAR(20) NOT NULL," +
|
|
||||||
"message_text TEXT NOT NULL," +
|
|
||||||
"message_response TEXT, " +
|
|
||||||
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
|
|
||||||
")";
|
|
||||||
private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" +
|
|
||||||
"id SERIAL PRIMARY KEY, " +
|
|
||||||
"embeddings FLOAT[] NOT NULL" +
|
|
||||||
");";
|
|
||||||
|
|
||||||
private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" +
|
private static final Map<String, String> createTableQueries = Map.ofEntries(
|
||||||
"id SERIAL PRIMARY KEY, " +
|
entry("messages", "CREATE TABLE IF NOT EXISTS messages (" +
|
||||||
"message_id INT NOT NULL, " +
|
"id SERIAL PRIMARY KEY, " +
|
||||||
"embedding_id INT NOT NULL" +
|
"guild_id BIGINT NOT NULL, " +
|
||||||
")";
|
"thread_id BIGINT NOT NULL, " +
|
||||||
|
"user_id BIGINT NOT NULL, " +
|
||||||
|
"message_type VARCHAR(20) NOT NULL," +
|
||||||
|
"message_text TEXT NOT NULL," +
|
||||||
|
"message_response TEXT, " +
|
||||||
|
"topics TEXT[], " +
|
||||||
|
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
|
||||||
|
")"),
|
||||||
|
entry("embeddings", "CREATE TABLE IF NOT EXISTS embeddings (" +
|
||||||
|
"id SERIAL PRIMARY KEY, " +
|
||||||
|
"embeddings FLOAT[] NOT NULL" +
|
||||||
|
")"),
|
||||||
|
entry("message_embeddings", "CREATE TABLE IF NOT EXISTS message_embeddings (" +
|
||||||
|
"id SERIAL PRIMARY KEY, " +
|
||||||
|
"message_id INT NOT NULL, " +
|
||||||
|
"embedding_id INT NOT NULL" +
|
||||||
|
")")
|
||||||
|
);
|
||||||
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
|
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
|
||||||
"message_type, " +
|
"message_type, " +
|
||||||
"guild_id, " +
|
"guild_id, " +
|
||||||
"thread_id, " +
|
"thread_id, " +
|
||||||
"user_id, " +
|
"user_id, " +
|
||||||
"message_text, " +
|
"message_text, " +
|
||||||
"message_response) " +
|
"message_response, " +
|
||||||
"VALUES (?, ?, ?, ?, ?, ?)";
|
"topics) " +
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, ?)";
|
||||||
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
|
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
|
||||||
"embeddings) " +
|
"embeddings) " +
|
||||||
"VALUES (?)";
|
"VALUES (?)";
|
||||||
@@ -47,52 +52,47 @@ public class DatabaseManager {
|
|||||||
"VALUES (?, ?)";
|
"VALUES (?, ?)";
|
||||||
|
|
||||||
public static void createTables() {
|
public static void createTables() {
|
||||||
createMessageTable();
|
for (Map.Entry<String, String> entry : createTableQueries.entrySet()) {
|
||||||
createEmbeddingsTable();
|
if (!createTable(entry.getKey())) {
|
||||||
createMessageEmbeddingsTable();
|
LOGGER.warn("Failed to create one or more required database tables");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOGGER.debug("Successfully created database tables");
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void createMessageTable() {
|
private static boolean createTable(String tableName) {
|
||||||
try {
|
if (tableExists(tableName)) {
|
||||||
Connection connection = DatabaseConnection.getConnection();
|
return true;
|
||||||
LOGGER.debug("Creating 'messages' database table if it does not exist");
|
} else {
|
||||||
Statement statement = connection.createStatement();
|
try {
|
||||||
statement.execute(CREATE_MESSAGE_TABLE);
|
Connection connection = DatabaseConnection.getConnection();
|
||||||
} catch (SQLException ex) {
|
LOGGER.debug("Creating '{}' database table if it does not exist", tableName);
|
||||||
LOGGER.error(ex.getMessage());
|
Statement statement = connection.createStatement();
|
||||||
|
statement.execute(createTableQueries.get(tableName));
|
||||||
|
return true;
|
||||||
|
} catch (SQLException ex) {
|
||||||
|
LOGGER.error(ex.getMessage());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void createEmbeddingsTable() {
|
public static int storeMessage(MessageData messageData) {
|
||||||
|
if (!tableExists("messages")) {
|
||||||
|
LOGGER.warn("Table 'messages' does not exist");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
Connection connection = DatabaseConnection.getConnection();
|
Connection connection = DatabaseConnection.getConnection();
|
||||||
LOGGER.debug("Creating 'embeddings' database table if it does not exist");
|
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE);
|
||||||
Statement statement = connection.createStatement();
|
preparedStatement.setString(1, messageData.getMessageType());
|
||||||
statement.execute(CREATE_EMBEDDINGS_TABLE);
|
preparedStatement.setLong(2, messageData.getGuildId());
|
||||||
} catch (SQLException ex) {
|
preparedStatement.setLong(3, messageData.getThreadId());
|
||||||
LOGGER.error(ex.getMessage());
|
preparedStatement.setLong(4, messageData.getUserId());
|
||||||
}
|
preparedStatement.setString(5, messageData.getMessageText());
|
||||||
}
|
preparedStatement.setString(6, messageData.getMessageResponse());
|
||||||
|
preparedStatement.setArray(7, connection.createArrayOf("text", messageData.getTopics().toArray()));
|
||||||
private static void createMessageEmbeddingsTable() {
|
|
||||||
try {
|
|
||||||
Connection connection = DatabaseConnection.getConnection();
|
|
||||||
LOGGER.debug("Creating 'message_embeddings' database table if it does not exist");
|
|
||||||
Statement statement = connection.createStatement();
|
|
||||||
statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE);
|
|
||||||
} catch (SQLException ex) {
|
|
||||||
LOGGER.error(ex.getMessage());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) {
|
|
||||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) {
|
|
||||||
preparedStatement.setString(1, message.getRole());
|
|
||||||
preparedStatement.setLong(2, guildId);
|
|
||||||
preparedStatement.setLong(3, threadId);
|
|
||||||
preparedStatement.setLong(4, userId);
|
|
||||||
preparedStatement.setString(5, message.getContent());
|
|
||||||
preparedStatement.setString(6, response);
|
|
||||||
return preparedStatement.executeUpdate();
|
return preparedStatement.executeUpdate();
|
||||||
} catch (SQLException ex) {
|
} catch (SQLException ex) {
|
||||||
LOGGER.error(ex.getMessage());
|
LOGGER.error(ex.getMessage());
|
||||||
@@ -101,7 +101,13 @@ public class DatabaseManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static int storeEmbedding(List<Double> data) {
|
public static int storeEmbedding(List<Double> data) {
|
||||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) {
|
if (!tableExists("embeddings")) {
|
||||||
|
LOGGER.warn("Table 'embeddings' does not exist");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
Connection connection = DatabaseConnection.getConnection();
|
||||||
|
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING);
|
||||||
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
|
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
|
||||||
return preparedStatement.executeUpdate();
|
return preparedStatement.executeUpdate();
|
||||||
} catch (SQLException ex) {
|
} catch (SQLException ex) {
|
||||||
@@ -111,7 +117,13 @@ public class DatabaseManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
|
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
|
||||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) {
|
if (!tableExists("message_embeddings")) {
|
||||||
|
LOGGER.warn("Table 'message_embeddings' does not exist");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
Connection connection = DatabaseConnection.getConnection();
|
||||||
|
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS);
|
||||||
preparedStatement.setInt(1, messageId);
|
preparedStatement.setInt(1, messageId);
|
||||||
preparedStatement.setInt(2, embeddingId);
|
preparedStatement.setInt(2, embeddingId);
|
||||||
return preparedStatement.executeUpdate();
|
return preparedStatement.executeUpdate();
|
||||||
@@ -120,4 +132,61 @@ public class DatabaseManager {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static List<MessageData> getMessages(String query, Object... params) throws SQLException, IllegalArgumentException {
|
||||||
|
LOGGER.trace("Query: <{}>", query);
|
||||||
|
Connection connection = DatabaseConnection.getConnection();
|
||||||
|
PreparedStatement stmt = connection.prepareStatement(query);
|
||||||
|
int i = 1;
|
||||||
|
for (Object param : params) {
|
||||||
|
if (param instanceof String) {
|
||||||
|
stmt.setString(i++, (String) param);
|
||||||
|
} else if (param instanceof Integer) {
|
||||||
|
stmt.setInt(i++, (Integer) param);
|
||||||
|
} else if (param instanceof Long) {
|
||||||
|
stmt.setLong(i++, (Long) param);
|
||||||
|
} else if (param instanceof Double) {
|
||||||
|
stmt.setDouble(i++, (Double) param);
|
||||||
|
} else if (param instanceof Float) {
|
||||||
|
stmt.setFloat(i++, (Float) param);
|
||||||
|
} else if (param instanceof Timestamp) {
|
||||||
|
stmt.setTimestamp(i++, (Timestamp) param);
|
||||||
|
} else if (param instanceof Boolean) {
|
||||||
|
stmt.setBoolean(i++, (Boolean) param);
|
||||||
|
} else if (param instanceof List) {
|
||||||
|
stmt.setArray(i++, connection.createArrayOf("text", ((List<?>) param).toArray()));
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass().getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ResultSet resultSet = stmt.executeQuery();
|
||||||
|
List<MessageData> resultList = new ArrayList<>();
|
||||||
|
while (resultSet.next()) {
|
||||||
|
Array array = resultSet.getArray(8);
|
||||||
|
MessageData messageData = new MessageData(
|
||||||
|
resultSet.getLong(2),
|
||||||
|
resultSet.getLong(3),
|
||||||
|
resultSet.getLong(4),
|
||||||
|
resultSet.getString(5),
|
||||||
|
resultSet.getString(6),
|
||||||
|
resultSet.getString(7),
|
||||||
|
new HashSet<>(Arrays.asList((String[]) array.getArray())),
|
||||||
|
resultSet.getTimestamp(9)
|
||||||
|
);
|
||||||
|
resultList.add(messageData);
|
||||||
|
}
|
||||||
|
return resultList;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean tableExists(String tableName) {
|
||||||
|
try {
|
||||||
|
Connection connection = DatabaseConnection.getConnection();
|
||||||
|
Statement statement = connection.createStatement();
|
||||||
|
ResultSet resultSet = statement.executeQuery("SELECT tablename FROM pg_tables WHERE tablename = '" + tableName + "'");
|
||||||
|
return resultSet.next();
|
||||||
|
} catch (SQLException ex) {
|
||||||
|
LOGGER.error("Failed to check if table exists; {}" + ex.getMessage());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package com.bensherriff.siren.database;
|
||||||
|
|
||||||
|
import java.sql.Timestamp;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
public class MessageData {
|
||||||
|
private final Long guildId;
|
||||||
|
private final Long threadId;
|
||||||
|
private final Long userId;
|
||||||
|
private final String messageType;
|
||||||
|
private final String messageText;
|
||||||
|
private final String messageResponse;
|
||||||
|
private final Set<String> topics;
|
||||||
|
private Timestamp timestamp;
|
||||||
|
|
||||||
|
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
|
||||||
|
String messageResponse, Set<String> topics) {
|
||||||
|
this.guildId = guildId;
|
||||||
|
this.threadId = threadId;
|
||||||
|
this.userId = userId;
|
||||||
|
this.messageType = messageType;
|
||||||
|
this.messageText = messageText;
|
||||||
|
this.messageResponse = messageResponse;
|
||||||
|
this.topics = topics;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
|
||||||
|
String messageResponse, Set<String> topics, Timestamp timestamp) {
|
||||||
|
this.guildId = guildId;
|
||||||
|
this.threadId = threadId;
|
||||||
|
this.userId = userId;
|
||||||
|
this.messageType = messageType;
|
||||||
|
this.messageText = messageText;
|
||||||
|
this.messageResponse = messageResponse;
|
||||||
|
this.topics = topics;
|
||||||
|
this.timestamp = timestamp;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Long getGuildId() {
|
||||||
|
return guildId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Long getThreadId() {
|
||||||
|
return threadId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Long getUserId() {
|
||||||
|
return userId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getMessageType() {
|
||||||
|
return messageType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getMessageText() {
|
||||||
|
return messageText;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getMessageResponse() {
|
||||||
|
return messageResponse;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<String> getTopics() {
|
||||||
|
return topics;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Timestamp getTimestamp() {
|
||||||
|
return timestamp;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTimestamp(Timestamp timestamp) {
|
||||||
|
this.timestamp = timestamp;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
package com.bensherriff.siren.database;
|
||||||
|
|
||||||
|
public class QueryBuilder {
|
||||||
|
private boolean distinct;
|
||||||
|
private String columnList;
|
||||||
|
private String tableName;
|
||||||
|
private String whereClause;
|
||||||
|
private String orderByClause;
|
||||||
|
private Integer limit;
|
||||||
|
|
||||||
|
public QueryBuilder select(String columnList) {
|
||||||
|
this.columnList = columnList;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public QueryBuilder from(String tableName) {
|
||||||
|
this.tableName = tableName;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public QueryBuilder where(String whereClause) {
|
||||||
|
this.whereClause = whereClause;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public QueryBuilder orderBy(String orderByClause) {
|
||||||
|
this.orderByClause = orderByClause;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public QueryBuilder limit(int limit) {
|
||||||
|
this.limit = limit;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public QueryBuilder distinct(boolean distinct) {
|
||||||
|
this.distinct = distinct;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String build() {
|
||||||
|
StringBuilder queryBuilder = new StringBuilder();
|
||||||
|
queryBuilder.append("SELECT ");
|
||||||
|
if (distinct) {
|
||||||
|
queryBuilder.append("DISTINCT ");
|
||||||
|
}
|
||||||
|
if (columnList != null && !columnList.isEmpty()) {
|
||||||
|
queryBuilder.append(columnList);
|
||||||
|
} else {
|
||||||
|
queryBuilder.append("*");
|
||||||
|
}
|
||||||
|
queryBuilder.append(" FROM ");
|
||||||
|
queryBuilder.append(tableName);
|
||||||
|
if (whereClause != null) {
|
||||||
|
queryBuilder.append(" WHERE ");
|
||||||
|
queryBuilder.append(whereClause);
|
||||||
|
}
|
||||||
|
if (orderByClause != null) {
|
||||||
|
queryBuilder.append(" ORDER BY ");
|
||||||
|
queryBuilder.append(orderByClause);
|
||||||
|
}
|
||||||
|
if (limit != null) {
|
||||||
|
queryBuilder.append(" LIMIT ");
|
||||||
|
queryBuilder.append(limit);
|
||||||
|
}
|
||||||
|
return queryBuilder.toString();
|
||||||
|
}
|
||||||
|
}
|
||||||
95
src/main/java/com/bensherriff/siren/openai/NLP.java
Normal file
95
src/main/java/com/bensherriff/siren/openai/NLP.java
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package com.bensherriff.siren.openai;
|
||||||
|
|
||||||
|
import edu.stanford.nlp.ling.CoreAnnotations;
|
||||||
|
import edu.stanford.nlp.ling.CoreLabel;
|
||||||
|
import edu.stanford.nlp.pipeline.Annotation;
|
||||||
|
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
|
||||||
|
import edu.stanford.nlp.util.CoreMap;
|
||||||
|
import org.apache.logging.log4j.LogManager;
|
||||||
|
import org.apache.logging.log4j.Logger;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
public class NLP {
|
||||||
|
private static final Logger LOGGER = LogManager.getLogger(NLP.class);
|
||||||
|
private final StanfordCoreNLP pipeline;
|
||||||
|
private final Map<String, List<String>> keywords;
|
||||||
|
|
||||||
|
public NLP() {
|
||||||
|
Properties props = new Properties();
|
||||||
|
props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner");
|
||||||
|
pipeline = new StanfordCoreNLP(props);
|
||||||
|
keywords = new HashMap<>();
|
||||||
|
keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<String> getTopicKeywords(String sentence) {
|
||||||
|
Set<String> topics = new LinkedHashSet<>();
|
||||||
|
Annotation document = new Annotation(sentence);
|
||||||
|
pipeline.annotate(document);
|
||||||
|
|
||||||
|
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
|
||||||
|
CoreMap sentenceMap = sentences.get(0);
|
||||||
|
List<CoreLabel> tokens = sentenceMap.get(CoreAnnotations.TokensAnnotation.class);
|
||||||
|
List<CoreLabel> namedEntities = new ArrayList<>();
|
||||||
|
|
||||||
|
for (CoreLabel token : tokens) {
|
||||||
|
String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
||||||
|
if (!ne.equals("0")) {
|
||||||
|
namedEntities.add(token);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (CoreLabel namedEntity : namedEntities) {
|
||||||
|
String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
||||||
|
String word = namedEntity.word();
|
||||||
|
if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
|
||||||
|
topics.add(word);
|
||||||
|
} else if (ne.equals("LOCATION")) {
|
||||||
|
String[] posTags = word.split("_");
|
||||||
|
for (String posTag : posTags) {
|
||||||
|
if (posTag.startsWith("N")) {
|
||||||
|
topics.add(word);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
|
||||||
|
if (pos.startsWith("NN")) {
|
||||||
|
topics.add(word);
|
||||||
|
for (String keyword : keywords.keySet()) {
|
||||||
|
if (keywords.get(keyword).contains(word.toLowerCase())) {
|
||||||
|
topics.add(keyword);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid").contains(word.toLowerCase())) {
|
||||||
|
topics.add("dnd");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return topics;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> lemmatize(String documentText) {
|
||||||
|
List<String> lemmas = new ArrayList<>();
|
||||||
|
Annotation document = new Annotation(documentText);
|
||||||
|
pipeline.annotate(document);
|
||||||
|
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
|
||||||
|
for (CoreMap sentence : sentences) {
|
||||||
|
for (CoreLabel token : sentence.get(CoreAnnotations.TokensAnnotation.class)) {
|
||||||
|
lemmas.add(token.get(CoreAnnotations.LemmaAnnotation.class));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lemmas;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO finish this method
|
||||||
|
public Set<String> getSynonyms(String targetWord) {
|
||||||
|
return Collections.emptySet();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO finish this method
|
||||||
|
public static double calculateSimilarity(String sentence1, String sentence2) {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,8 @@ package com.bensherriff.siren.openai;
|
|||||||
|
|
||||||
import com.bensherriff.siren.Listener;
|
import com.bensherriff.siren.Listener;
|
||||||
import com.bensherriff.siren.database.DatabaseManager;
|
import com.bensherriff.siren.database.DatabaseManager;
|
||||||
|
import com.bensherriff.siren.database.MessageData;
|
||||||
|
import com.bensherriff.siren.database.QueryBuilder;
|
||||||
import com.bensherriff.siren.settings.GuildSettings;
|
import com.bensherriff.siren.settings.GuildSettings;
|
||||||
import com.bensherriff.siren.settings.Settings;
|
import com.bensherriff.siren.settings.Settings;
|
||||||
import com.theokanning.openai.completion.CompletionRequest;
|
import com.theokanning.openai.completion.CompletionRequest;
|
||||||
@@ -11,7 +13,6 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
|||||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||||
import com.theokanning.openai.embedding.Embedding;
|
import com.theokanning.openai.embedding.Embedding;
|
||||||
import com.theokanning.openai.embedding.EmbeddingRequest;
|
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||||
import com.theokanning.openai.embedding.EmbeddingResult;
|
|
||||||
import com.theokanning.openai.service.OpenAiService;
|
import com.theokanning.openai.service.OpenAiService;
|
||||||
import net.dv8tion.jda.api.JDA;
|
import net.dv8tion.jda.api.JDA;
|
||||||
import net.dv8tion.jda.api.entities.Message;
|
import net.dv8tion.jda.api.entities.Message;
|
||||||
@@ -21,6 +22,7 @@ import net.dv8tion.jda.api.events.message.MessageReceivedEvent;
|
|||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
|
|
||||||
|
import java.sql.SQLException;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.ScheduledExecutorService;
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
@@ -32,11 +34,15 @@ public class OpenAIManager {
|
|||||||
private final Settings settings;
|
private final Settings settings;
|
||||||
private final JDA jda;
|
private final JDA jda;
|
||||||
private final ScheduledExecutorService executor;
|
private final ScheduledExecutorService executor;
|
||||||
|
private final String owner;
|
||||||
|
private final NLP NLP;
|
||||||
|
|
||||||
public OpenAIManager(Listener listener) {
|
public OpenAIManager(Listener listener) {
|
||||||
this.settings = listener.getSettings();
|
this.settings = listener.getSettings();
|
||||||
this.jda = listener.getJDA();
|
this.jda = listener.getJDA();
|
||||||
this.executor = listener.getExecutor();
|
this.executor = listener.getExecutor();
|
||||||
|
this.owner = listener.getOwner();
|
||||||
|
this.NLP = new NLP();
|
||||||
|
|
||||||
if (settings.getOpenAISettings().getToken().isEmpty()) {
|
if (settings.getOpenAISettings().getToken().isEmpty()) {
|
||||||
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
|
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
|
||||||
@@ -56,19 +62,13 @@ public class OpenAIManager {
|
|||||||
if (openAiService != null) {
|
if (openAiService != null) {
|
||||||
executor.execute(() -> sendMessage(event));
|
executor.execute(() -> sendMessage(event));
|
||||||
} else {
|
} else {
|
||||||
event.getMessage().reply("OpenAI functionality is disabled. Please contact an administrator").queue();
|
event.getMessage().reply("OpenAI functionality is disabled. Please contact " + owner + ".").queue();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void sendMessage(MessageReceivedEvent event) {
|
private void sendMessage(MessageReceivedEvent event) {
|
||||||
String message = parseMessage(event.getMessage().getContentRaw());
|
String message = parseMessage(event.getMessage().getContentRaw());
|
||||||
long guildId = event.getGuild().getIdLong();
|
|
||||||
Model model = settings.getGuildSettings().get(guildId).getModel();
|
|
||||||
StringBuilder stringBuilder = new StringBuilder();
|
|
||||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
|
||||||
ChatMessage chatMessage = createChatMessage(message, event);
|
|
||||||
List<Embedding> embeddings = new ArrayList<>();
|
|
||||||
|
|
||||||
if (message.isEmpty() || message.isBlank()) {
|
if (message.isEmpty() || message.isBlank()) {
|
||||||
event.getMessage().reply("Your message is empty. Please try again").queue();
|
event.getMessage().reply("Your message is empty. Please try again").queue();
|
||||||
@@ -76,7 +76,13 @@ public class OpenAIManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
|
StringBuilder stringBuilder = new StringBuilder();
|
||||||
|
long guildId = event.getGuild().getIdLong();
|
||||||
|
Model model = settings.getGuildSettings().get(guildId).getModel();
|
||||||
|
ChatMessage chatMessage = createChatMessage(message, event);
|
||||||
|
List<Embedding> embeddings = new ArrayList<>();
|
||||||
|
|
||||||
|
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, event.getAuthor().getId(), event.getMessageId(), message);
|
||||||
// Send OpenAI Message and get response
|
// Send OpenAI Message and get response
|
||||||
switch (model) {
|
switch (model) {
|
||||||
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
||||||
@@ -85,8 +91,7 @@ public class OpenAIManager {
|
|||||||
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
||||||
}
|
}
|
||||||
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
|
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
|
||||||
chatMessages.add(chatMessage);
|
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
|
||||||
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event);
|
|
||||||
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
||||||
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
||||||
|
|
||||||
@@ -96,7 +101,7 @@ public class OpenAIManager {
|
|||||||
// embeddings.addAll(embeddingResult.getData());
|
// embeddings.addAll(embeddingResult.getData());
|
||||||
}
|
}
|
||||||
default -> {
|
default -> {
|
||||||
event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue();
|
event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
|
||||||
LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
|
LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
|
||||||
model, Arrays.toString(Model.values()));
|
model, Arrays.toString(Model.values()));
|
||||||
return;
|
return;
|
||||||
@@ -105,7 +110,7 @@ public class OpenAIManager {
|
|||||||
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
|
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
|
||||||
} catch (Exception ex) {
|
} catch (Exception ex) {
|
||||||
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
|
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
|
||||||
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue();
|
event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,8 +123,30 @@ public class OpenAIManager {
|
|||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private ChatCompletionRequest createCompletionRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException {
|
||||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||||
|
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||||
|
chatMessages.add(chatMessage);
|
||||||
|
|
||||||
|
// Handle System Messages
|
||||||
|
chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
|
||||||
|
chatMessages.add(createSystemMessage("I am a user named " + event.getAuthor().getName()));
|
||||||
|
if (event.isFromThread()) {
|
||||||
|
String query = new QueryBuilder().from("messages")
|
||||||
|
.where("guild_id = ? AND thread_id = ?")
|
||||||
|
.orderBy("timestamp DESC")
|
||||||
|
.limit(10)
|
||||||
|
.build();
|
||||||
|
List<MessageData> previousMessages = DatabaseManager.getMessages(
|
||||||
|
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
|
||||||
|
for (MessageData previousMessage : previousMessages) {
|
||||||
|
ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " +
|
||||||
|
previousMessage.getTimestamp() + " which said \"" + previousMessage.getMessageText() +
|
||||||
|
"\". You replied with \"" + previousMessage.getMessageResponse() + "\".");
|
||||||
|
chatMessages.add(previousChatMessage);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return ChatCompletionRequest.builder()
|
return ChatCompletionRequest.builder()
|
||||||
.model(guildSettings.getModel().getName())
|
.model(guildSettings.getModel().getName())
|
||||||
.maxTokens(guildSettings.getMaxTokens())
|
.maxTokens(guildSettings.getMaxTokens())
|
||||||
@@ -146,6 +173,10 @@ public class OpenAIManager {
|
|||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private ChatMessage createSystemMessage(String message) {
|
||||||
|
return new ChatMessage(Role.SYSTEM.getName(), message);
|
||||||
|
}
|
||||||
|
|
||||||
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
|
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
|
||||||
ChatMessage chatMessage = new ChatMessage();
|
ChatMessage chatMessage = new ChatMessage();
|
||||||
chatMessage.setContent(message);
|
chatMessage.setContent(message);
|
||||||
@@ -158,33 +189,38 @@ public class OpenAIManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
|
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
|
||||||
|
LOGGER.trace("Message Response <{}>: {}", event.getMessageId(), response);
|
||||||
|
Set<String> topics = new LinkedHashSet<>();
|
||||||
|
topics.addAll(NLP.getTopicKeywords(chatMessage.getContent()));
|
||||||
|
topics.addAll(NLP.getTopicKeywords(response));
|
||||||
|
LOGGER.trace("Topics: {}", topics);
|
||||||
|
|
||||||
if (event.isFromThread()) {
|
if (event.isFromThread()) {
|
||||||
ThreadChannel channel = event.getChannel().asThreadChannel();
|
ThreadChannel channel = event.getChannel().asThreadChannel();
|
||||||
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
|
storeMessage(chatMessage, event, response, embeddings, topics, channel);
|
||||||
channel.getIdLong(), response);
|
|
||||||
embeddings.forEach(embedding -> {
|
|
||||||
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
|
||||||
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
|
||||||
});
|
|
||||||
channel.sendMessage(response).queue();
|
|
||||||
} else {
|
} else {
|
||||||
// The max discord title length is 100 characters
|
// The max discord title length is 100 characters
|
||||||
String threadTitle = chatMessage.getContent();
|
String threadTitle = chatMessage.getContent();
|
||||||
if (chatMessage.getContent().length() > 100) {
|
if (chatMessage.getContent().length() > 100) {
|
||||||
threadTitle = chatMessage.getContent().substring(0, 100);
|
threadTitle = chatMessage.getContent().substring(0, 100);
|
||||||
}
|
}
|
||||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
|
event.getMessage().createThreadChannel(threadTitle).queue(channel -> storeMessage(chatMessage, event, response, embeddings, topics, channel));
|
||||||
channel.sendMessage(response).queue();
|
|
||||||
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
|
|
||||||
channel.getIdLong(), response);
|
|
||||||
embeddings.forEach(embedding -> {
|
|
||||||
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
|
||||||
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void storeMessage(ChatMessage chatMessage, MessageReceivedEvent event, String response,
|
||||||
|
List<Embedding> embeddings, Set<String> topics, ThreadChannel channel) {
|
||||||
|
MessageData messageData = new MessageData(event.getGuild().getIdLong(),
|
||||||
|
channel.getIdLong(), event.getAuthor().getIdLong(), chatMessage.getRole(), chatMessage.getContent(),
|
||||||
|
response, topics);
|
||||||
|
int messageRow = DatabaseManager.storeMessage(messageData);
|
||||||
|
embeddings.forEach(embedding -> {
|
||||||
|
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
||||||
|
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
||||||
|
});
|
||||||
|
channel.sendMessage(response).queue();
|
||||||
|
}
|
||||||
|
|
||||||
private String parseMessage(String input) {
|
private String parseMessage(String input) {
|
||||||
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
|
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user