v0.1.16 Query builder and working on memory management
This commit is contained in:
19
pom.xml
19
pom.xml
@@ -61,13 +61,14 @@
|
||||
</distributionManagement>
|
||||
|
||||
<properties>
|
||||
<jda.version>5.0.0-beta.6</jda.version>
|
||||
<jda.version>5.0.0-beta.8</jda.version>
|
||||
<lavaplayer.version>1.4.0</lavaplayer.version>
|
||||
<lavaplayer-natives-extra.version>1.3.13</lavaplayer-natives-extra.version>
|
||||
<jackson.version>2.14.2</jackson.version>
|
||||
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
|
||||
<postgresql.version>42.6.0</postgresql.version>
|
||||
<slf4j.version>2.0.6</slf4j.version>
|
||||
<corenlp.version>4.5.3</corenlp.version>
|
||||
<slf4j.version>2.0.7</slf4j.version>
|
||||
<log4j.version>2.20.0</log4j.version>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>17</maven.compiler.source>
|
||||
@@ -105,6 +106,18 @@
|
||||
<artifactId>postgresql</artifactId>
|
||||
<version>${postgresql.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>edu.stanford.nlp</groupId>
|
||||
<artifactId>stanford-corenlp</artifactId>
|
||||
<version>${corenlp.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>edu.stanford.nlp</groupId>
|
||||
<artifactId>stanford-corenlp</artifactId>
|
||||
<version>${corenlp.version}</version>
|
||||
<classifier>models</classifier>
|
||||
<scope>runtime</scope>
|
||||
</dependency>
|
||||
|
||||
<!-- Logging -->
|
||||
<dependency>
|
||||
@@ -150,7 +163,7 @@
|
||||
</transformer>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||
<manifestEntries>
|
||||
<Main-Class>com.bensherriff.siren.Bot</Main-Class>
|
||||
<Main-Class>com.bensherriff.siren.Main</Main-Class>
|
||||
<Specification-Title>${project.artifactId}</Specification-Title>
|
||||
<Specification-Version>${project.version}</Specification-Version>
|
||||
<Implementation-Title>${project.artifactId}</Implementation-Title>
|
||||
|
||||
@@ -34,14 +34,15 @@ public class Listener extends ListenerAdapter {
|
||||
|
||||
private final Settings settings;
|
||||
private final Map<String, Command> commands = new HashMap<>();
|
||||
private final String owner;
|
||||
private PlayerManager playerManager;
|
||||
private OpenAIManager openAIManager;
|
||||
private JDA jda;
|
||||
|
||||
public Listener(Settings settings) {
|
||||
this.settings = settings;
|
||||
|
||||
this.executor = Executors.newScheduledThreadPool(this.settings.getThreadPool());
|
||||
this.owner = "@bsherriff";
|
||||
}
|
||||
|
||||
public ScheduledExecutorService getExecutor() {
|
||||
@@ -64,6 +65,10 @@ public class Listener extends ListenerAdapter {
|
||||
this.jda = jda;
|
||||
}
|
||||
|
||||
public String getOwner() {
|
||||
return owner;
|
||||
}
|
||||
|
||||
public void initialize() {
|
||||
this.playerManager = new PlayerManager(this);
|
||||
this.playerManager.initialize();
|
||||
@@ -78,6 +83,7 @@ public class Listener extends ListenerAdapter {
|
||||
commands.put("pause", new PauseCommand(this));
|
||||
commands.put("resume", new ResumeCommand(this));
|
||||
commands.put("help", new PauseCommand(this));
|
||||
|
||||
}
|
||||
|
||||
public void closeAudioConnection(long guildID) {
|
||||
@@ -144,7 +150,7 @@ public class Listener extends ListenerAdapter {
|
||||
commands.get(command).execute(event);
|
||||
} catch (Exception ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
event.getHook().sendMessage("An error occurred while processing your command. Please contact your administrator.").queue();
|
||||
event.getHook().sendMessage("An error occurred while processing your command. Please contact " + owner + ".").queue();
|
||||
}
|
||||
});
|
||||
} else {
|
||||
|
||||
@@ -13,8 +13,8 @@ import org.apache.logging.log4j.Logger;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
||||
public class Bot {
|
||||
private static final Logger LOGGER = LogManager.getLogger(Bot.class);
|
||||
public class Main {
|
||||
private static final Logger LOGGER = LogManager.getLogger(Main.class);
|
||||
private final static GatewayIntent[] INTENTS = {
|
||||
GatewayIntent.DIRECT_MESSAGES, GatewayIntent.GUILD_MESSAGES, GatewayIntent.GUILD_MESSAGE_REACTIONS,
|
||||
GatewayIntent.GUILD_VOICE_STATES, GatewayIntent.MESSAGE_CONTENT
|
||||
@@ -89,7 +89,7 @@ public class PlayCommand extends Command {
|
||||
if (exception.getMessage().contains("Unknown file format")) {
|
||||
event.getHook().sendMessage(errorMsg + ". " + exception.getMessage()).queue();
|
||||
} else {
|
||||
event.getHook().sendMessage(errorMsg + ". Please contact your administrator.").queue();
|
||||
event.getHook().sendMessage(errorMsg + ". Please contact " + listener.getOwner() + ".").queue();
|
||||
}
|
||||
LOGGER.error("{}: {}", errorMsg, exception.getMessage());
|
||||
}
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
package com.bensherriff.siren.database;
|
||||
|
||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.sql.*;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
import static java.util.Map.entry;
|
||||
|
||||
public class DatabaseManager {
|
||||
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
|
||||
private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" +
|
||||
|
||||
private static final Map<String, String> createTableQueries = Map.ofEntries(
|
||||
entry("messages", "CREATE TABLE IF NOT EXISTS messages (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"guild_id BIGINT NOT NULL, " +
|
||||
"thread_id BIGINT NOT NULL, " +
|
||||
@@ -17,26 +20,28 @@ public class DatabaseManager {
|
||||
"message_type VARCHAR(20) NOT NULL," +
|
||||
"message_text TEXT NOT NULL," +
|
||||
"message_response TEXT, " +
|
||||
"topics TEXT[], " +
|
||||
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
|
||||
")";
|
||||
private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" +
|
||||
")"),
|
||||
entry("embeddings", "CREATE TABLE IF NOT EXISTS embeddings (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"embeddings FLOAT[] NOT NULL" +
|
||||
");";
|
||||
|
||||
private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" +
|
||||
")"),
|
||||
entry("message_embeddings", "CREATE TABLE IF NOT EXISTS message_embeddings (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"message_id INT NOT NULL, " +
|
||||
"embedding_id INT NOT NULL" +
|
||||
")";
|
||||
")")
|
||||
);
|
||||
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
|
||||
"message_type, " +
|
||||
"guild_id, " +
|
||||
"thread_id, " +
|
||||
"user_id, " +
|
||||
"message_text, " +
|
||||
"message_response) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?)";
|
||||
"message_response, " +
|
||||
"topics) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?)";
|
||||
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
|
||||
"embeddings) " +
|
||||
"VALUES (?)";
|
||||
@@ -47,52 +52,47 @@ public class DatabaseManager {
|
||||
"VALUES (?, ?)";
|
||||
|
||||
public static void createTables() {
|
||||
createMessageTable();
|
||||
createEmbeddingsTable();
|
||||
createMessageEmbeddingsTable();
|
||||
for (Map.Entry<String, String> entry : createTableQueries.entrySet()) {
|
||||
if (!createTable(entry.getKey())) {
|
||||
LOGGER.warn("Failed to create one or more required database tables");
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOGGER.debug("Successfully created database tables");
|
||||
}
|
||||
|
||||
private static void createMessageTable() {
|
||||
private static boolean createTable(String tableName) {
|
||||
if (tableExists(tableName)) {
|
||||
return true;
|
||||
} else {
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
LOGGER.debug("Creating 'messages' database table if it does not exist");
|
||||
LOGGER.debug("Creating '{}' database table if it does not exist", tableName);
|
||||
Statement statement = connection.createStatement();
|
||||
statement.execute(CREATE_MESSAGE_TABLE);
|
||||
statement.execute(createTableQueries.get(tableName));
|
||||
return true;
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void createEmbeddingsTable() {
|
||||
public static int storeMessage(MessageData messageData) {
|
||||
if (!tableExists("messages")) {
|
||||
LOGGER.warn("Table 'messages' does not exist");
|
||||
return -1;
|
||||
}
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
LOGGER.debug("Creating 'embeddings' database table if it does not exist");
|
||||
Statement statement = connection.createStatement();
|
||||
statement.execute(CREATE_EMBEDDINGS_TABLE);
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static void createMessageEmbeddingsTable() {
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
LOGGER.debug("Creating 'message_embeddings' database table if it does not exist");
|
||||
Statement statement = connection.createStatement();
|
||||
statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE);
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) {
|
||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) {
|
||||
preparedStatement.setString(1, message.getRole());
|
||||
preparedStatement.setLong(2, guildId);
|
||||
preparedStatement.setLong(3, threadId);
|
||||
preparedStatement.setLong(4, userId);
|
||||
preparedStatement.setString(5, message.getContent());
|
||||
preparedStatement.setString(6, response);
|
||||
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE);
|
||||
preparedStatement.setString(1, messageData.getMessageType());
|
||||
preparedStatement.setLong(2, messageData.getGuildId());
|
||||
preparedStatement.setLong(3, messageData.getThreadId());
|
||||
preparedStatement.setLong(4, messageData.getUserId());
|
||||
preparedStatement.setString(5, messageData.getMessageText());
|
||||
preparedStatement.setString(6, messageData.getMessageResponse());
|
||||
preparedStatement.setArray(7, connection.createArrayOf("text", messageData.getTopics().toArray()));
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
@@ -101,7 +101,13 @@ public class DatabaseManager {
|
||||
}
|
||||
|
||||
public static int storeEmbedding(List<Double> data) {
|
||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) {
|
||||
if (!tableExists("embeddings")) {
|
||||
LOGGER.warn("Table 'embeddings' does not exist");
|
||||
return -1;
|
||||
}
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING);
|
||||
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
@@ -111,7 +117,13 @@ public class DatabaseManager {
|
||||
}
|
||||
|
||||
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
|
||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) {
|
||||
if (!tableExists("message_embeddings")) {
|
||||
LOGGER.warn("Table 'message_embeddings' does not exist");
|
||||
return -1;
|
||||
}
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS);
|
||||
preparedStatement.setInt(1, messageId);
|
||||
preparedStatement.setInt(2, embeddingId);
|
||||
return preparedStatement.executeUpdate();
|
||||
@@ -120,4 +132,61 @@ public class DatabaseManager {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
public static List<MessageData> getMessages(String query, Object... params) throws SQLException, IllegalArgumentException {
|
||||
LOGGER.trace("Query: <{}>", query);
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
PreparedStatement stmt = connection.prepareStatement(query);
|
||||
int i = 1;
|
||||
for (Object param : params) {
|
||||
if (param instanceof String) {
|
||||
stmt.setString(i++, (String) param);
|
||||
} else if (param instanceof Integer) {
|
||||
stmt.setInt(i++, (Integer) param);
|
||||
} else if (param instanceof Long) {
|
||||
stmt.setLong(i++, (Long) param);
|
||||
} else if (param instanceof Double) {
|
||||
stmt.setDouble(i++, (Double) param);
|
||||
} else if (param instanceof Float) {
|
||||
stmt.setFloat(i++, (Float) param);
|
||||
} else if (param instanceof Timestamp) {
|
||||
stmt.setTimestamp(i++, (Timestamp) param);
|
||||
} else if (param instanceof Boolean) {
|
||||
stmt.setBoolean(i++, (Boolean) param);
|
||||
} else if (param instanceof List) {
|
||||
stmt.setArray(i++, connection.createArrayOf("text", ((List<?>) param).toArray()));
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass().getName());
|
||||
}
|
||||
}
|
||||
ResultSet resultSet = stmt.executeQuery();
|
||||
List<MessageData> resultList = new ArrayList<>();
|
||||
while (resultSet.next()) {
|
||||
Array array = resultSet.getArray(8);
|
||||
MessageData messageData = new MessageData(
|
||||
resultSet.getLong(2),
|
||||
resultSet.getLong(3),
|
||||
resultSet.getLong(4),
|
||||
resultSet.getString(5),
|
||||
resultSet.getString(6),
|
||||
resultSet.getString(7),
|
||||
new HashSet<>(Arrays.asList((String[]) array.getArray())),
|
||||
resultSet.getTimestamp(9)
|
||||
);
|
||||
resultList.add(messageData);
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
|
||||
private static boolean tableExists(String tableName) {
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
Statement statement = connection.createStatement();
|
||||
ResultSet resultSet = statement.executeQuery("SELECT tablename FROM pg_tables WHERE tablename = '" + tableName + "'");
|
||||
return resultSet.next();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error("Failed to check if table exists; {}" + ex.getMessage());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package com.bensherriff.siren.database;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
import java.util.Set;
|
||||
|
||||
public class MessageData {
|
||||
private final Long guildId;
|
||||
private final Long threadId;
|
||||
private final Long userId;
|
||||
private final String messageType;
|
||||
private final String messageText;
|
||||
private final String messageResponse;
|
||||
private final Set<String> topics;
|
||||
private Timestamp timestamp;
|
||||
|
||||
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
|
||||
String messageResponse, Set<String> topics) {
|
||||
this.guildId = guildId;
|
||||
this.threadId = threadId;
|
||||
this.userId = userId;
|
||||
this.messageType = messageType;
|
||||
this.messageText = messageText;
|
||||
this.messageResponse = messageResponse;
|
||||
this.topics = topics;
|
||||
}
|
||||
|
||||
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
|
||||
String messageResponse, Set<String> topics, Timestamp timestamp) {
|
||||
this.guildId = guildId;
|
||||
this.threadId = threadId;
|
||||
this.userId = userId;
|
||||
this.messageType = messageType;
|
||||
this.messageText = messageText;
|
||||
this.messageResponse = messageResponse;
|
||||
this.topics = topics;
|
||||
this.timestamp = timestamp;
|
||||
}
|
||||
|
||||
public Long getGuildId() {
|
||||
return guildId;
|
||||
}
|
||||
|
||||
public Long getThreadId() {
|
||||
return threadId;
|
||||
}
|
||||
|
||||
public Long getUserId() {
|
||||
return userId;
|
||||
}
|
||||
|
||||
public String getMessageType() {
|
||||
return messageType;
|
||||
}
|
||||
|
||||
public String getMessageText() {
|
||||
return messageText;
|
||||
}
|
||||
|
||||
public String getMessageResponse() {
|
||||
return messageResponse;
|
||||
}
|
||||
|
||||
public Set<String> getTopics() {
|
||||
return topics;
|
||||
}
|
||||
|
||||
public Timestamp getTimestamp() {
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
public void setTimestamp(Timestamp timestamp) {
|
||||
this.timestamp = timestamp;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package com.bensherriff.siren.database;
|
||||
|
||||
public class QueryBuilder {
|
||||
private boolean distinct;
|
||||
private String columnList;
|
||||
private String tableName;
|
||||
private String whereClause;
|
||||
private String orderByClause;
|
||||
private Integer limit;
|
||||
|
||||
public QueryBuilder select(String columnList) {
|
||||
this.columnList = columnList;
|
||||
return this;
|
||||
}
|
||||
|
||||
public QueryBuilder from(String tableName) {
|
||||
this.tableName = tableName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public QueryBuilder where(String whereClause) {
|
||||
this.whereClause = whereClause;
|
||||
return this;
|
||||
}
|
||||
|
||||
public QueryBuilder orderBy(String orderByClause) {
|
||||
this.orderByClause = orderByClause;
|
||||
return this;
|
||||
}
|
||||
|
||||
public QueryBuilder limit(int limit) {
|
||||
this.limit = limit;
|
||||
return this;
|
||||
}
|
||||
|
||||
public QueryBuilder distinct(boolean distinct) {
|
||||
this.distinct = distinct;
|
||||
return this;
|
||||
}
|
||||
|
||||
public String build() {
|
||||
StringBuilder queryBuilder = new StringBuilder();
|
||||
queryBuilder.append("SELECT ");
|
||||
if (distinct) {
|
||||
queryBuilder.append("DISTINCT ");
|
||||
}
|
||||
if (columnList != null && !columnList.isEmpty()) {
|
||||
queryBuilder.append(columnList);
|
||||
} else {
|
||||
queryBuilder.append("*");
|
||||
}
|
||||
queryBuilder.append(" FROM ");
|
||||
queryBuilder.append(tableName);
|
||||
if (whereClause != null) {
|
||||
queryBuilder.append(" WHERE ");
|
||||
queryBuilder.append(whereClause);
|
||||
}
|
||||
if (orderByClause != null) {
|
||||
queryBuilder.append(" ORDER BY ");
|
||||
queryBuilder.append(orderByClause);
|
||||
}
|
||||
if (limit != null) {
|
||||
queryBuilder.append(" LIMIT ");
|
||||
queryBuilder.append(limit);
|
||||
}
|
||||
return queryBuilder.toString();
|
||||
}
|
||||
}
|
||||
95
src/main/java/com/bensherriff/siren/openai/NLP.java
Normal file
95
src/main/java/com/bensherriff/siren/openai/NLP.java
Normal file
@@ -0,0 +1,95 @@
|
||||
package com.bensherriff.siren.openai;
|
||||
|
||||
import edu.stanford.nlp.ling.CoreAnnotations;
|
||||
import edu.stanford.nlp.ling.CoreLabel;
|
||||
import edu.stanford.nlp.pipeline.Annotation;
|
||||
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
|
||||
import edu.stanford.nlp.util.CoreMap;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class NLP {
|
||||
private static final Logger LOGGER = LogManager.getLogger(NLP.class);
|
||||
private final StanfordCoreNLP pipeline;
|
||||
private final Map<String, List<String>> keywords;
|
||||
|
||||
public NLP() {
|
||||
Properties props = new Properties();
|
||||
props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner");
|
||||
pipeline = new StanfordCoreNLP(props);
|
||||
keywords = new HashMap<>();
|
||||
keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid"));
|
||||
}
|
||||
|
||||
public Set<String> getTopicKeywords(String sentence) {
|
||||
Set<String> topics = new LinkedHashSet<>();
|
||||
Annotation document = new Annotation(sentence);
|
||||
pipeline.annotate(document);
|
||||
|
||||
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
|
||||
CoreMap sentenceMap = sentences.get(0);
|
||||
List<CoreLabel> tokens = sentenceMap.get(CoreAnnotations.TokensAnnotation.class);
|
||||
List<CoreLabel> namedEntities = new ArrayList<>();
|
||||
|
||||
for (CoreLabel token : tokens) {
|
||||
String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
||||
if (!ne.equals("0")) {
|
||||
namedEntities.add(token);
|
||||
}
|
||||
}
|
||||
for (CoreLabel namedEntity : namedEntities) {
|
||||
String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
||||
String word = namedEntity.word();
|
||||
if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
|
||||
topics.add(word);
|
||||
} else if (ne.equals("LOCATION")) {
|
||||
String[] posTags = word.split("_");
|
||||
for (String posTag : posTags) {
|
||||
if (posTag.startsWith("N")) {
|
||||
topics.add(word);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
|
||||
if (pos.startsWith("NN")) {
|
||||
topics.add(word);
|
||||
for (String keyword : keywords.keySet()) {
|
||||
if (keywords.get(keyword).contains(word.toLowerCase())) {
|
||||
topics.add(keyword);
|
||||
}
|
||||
}
|
||||
if (Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid").contains(word.toLowerCase())) {
|
||||
topics.add("dnd");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return topics;
|
||||
}
|
||||
|
||||
public List<String> lemmatize(String documentText) {
|
||||
List<String> lemmas = new ArrayList<>();
|
||||
Annotation document = new Annotation(documentText);
|
||||
pipeline.annotate(document);
|
||||
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
|
||||
for (CoreMap sentence : sentences) {
|
||||
for (CoreLabel token : sentence.get(CoreAnnotations.TokensAnnotation.class)) {
|
||||
lemmas.add(token.get(CoreAnnotations.LemmaAnnotation.class));
|
||||
}
|
||||
}
|
||||
return lemmas;
|
||||
}
|
||||
|
||||
// TODO finish this method
|
||||
public Set<String> getSynonyms(String targetWord) {
|
||||
return Collections.emptySet();
|
||||
}
|
||||
|
||||
// TODO finish this method
|
||||
public static double calculateSimilarity(String sentence1, String sentence2) {
|
||||
return 0.0;
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package com.bensherriff.siren.openai;
|
||||
|
||||
import com.bensherriff.siren.Listener;
|
||||
import com.bensherriff.siren.database.DatabaseManager;
|
||||
import com.bensherriff.siren.database.MessageData;
|
||||
import com.bensherriff.siren.database.QueryBuilder;
|
||||
import com.bensherriff.siren.settings.GuildSettings;
|
||||
import com.bensherriff.siren.settings.Settings;
|
||||
import com.theokanning.openai.completion.CompletionRequest;
|
||||
@@ -11,7 +13,6 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||
import com.theokanning.openai.embedding.Embedding;
|
||||
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||
import com.theokanning.openai.embedding.EmbeddingResult;
|
||||
import com.theokanning.openai.service.OpenAiService;
|
||||
import net.dv8tion.jda.api.JDA;
|
||||
import net.dv8tion.jda.api.entities.Message;
|
||||
@@ -21,6 +22,7 @@ import net.dv8tion.jda.api.events.message.MessageReceivedEvent;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.time.Duration;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
@@ -32,11 +34,15 @@ public class OpenAIManager {
|
||||
private final Settings settings;
|
||||
private final JDA jda;
|
||||
private final ScheduledExecutorService executor;
|
||||
private final String owner;
|
||||
private final NLP NLP;
|
||||
|
||||
public OpenAIManager(Listener listener) {
|
||||
this.settings = listener.getSettings();
|
||||
this.jda = listener.getJDA();
|
||||
this.executor = listener.getExecutor();
|
||||
this.owner = listener.getOwner();
|
||||
this.NLP = new NLP();
|
||||
|
||||
if (settings.getOpenAISettings().getToken().isEmpty()) {
|
||||
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
|
||||
@@ -56,19 +62,13 @@ public class OpenAIManager {
|
||||
if (openAiService != null) {
|
||||
executor.execute(() -> sendMessage(event));
|
||||
} else {
|
||||
event.getMessage().reply("OpenAI functionality is disabled. Please contact an administrator").queue();
|
||||
event.getMessage().reply("OpenAI functionality is disabled. Please contact " + owner + ".").queue();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void sendMessage(MessageReceivedEvent event) {
|
||||
String message = parseMessage(event.getMessage().getContentRaw());
|
||||
long guildId = event.getGuild().getIdLong();
|
||||
Model model = settings.getGuildSettings().get(guildId).getModel();
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
ChatMessage chatMessage = createChatMessage(message, event);
|
||||
List<Embedding> embeddings = new ArrayList<>();
|
||||
|
||||
if (message.isEmpty() || message.isBlank()) {
|
||||
event.getMessage().reply("Your message is empty. Please try again").queue();
|
||||
@@ -76,7 +76,13 @@ public class OpenAIManager {
|
||||
}
|
||||
|
||||
try {
|
||||
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
long guildId = event.getGuild().getIdLong();
|
||||
Model model = settings.getGuildSettings().get(guildId).getModel();
|
||||
ChatMessage chatMessage = createChatMessage(message, event);
|
||||
List<Embedding> embeddings = new ArrayList<>();
|
||||
|
||||
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, event.getAuthor().getId(), event.getMessageId(), message);
|
||||
// Send OpenAI Message and get response
|
||||
switch (model) {
|
||||
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
||||
@@ -85,8 +91,7 @@ public class OpenAIManager {
|
||||
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
||||
}
|
||||
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
|
||||
chatMessages.add(chatMessage);
|
||||
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event);
|
||||
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
|
||||
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
||||
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
||||
|
||||
@@ -96,7 +101,7 @@ public class OpenAIManager {
|
||||
// embeddings.addAll(embeddingResult.getData());
|
||||
}
|
||||
default -> {
|
||||
event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue();
|
||||
event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
|
||||
LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
|
||||
model, Arrays.toString(Model.values()));
|
||||
return;
|
||||
@@ -105,7 +110,7 @@ public class OpenAIManager {
|
||||
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
|
||||
} catch (Exception ex) {
|
||||
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
|
||||
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue();
|
||||
event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,8 +123,30 @@ public class OpenAIManager {
|
||||
.build();
|
||||
}
|
||||
|
||||
private ChatCompletionRequest createCompletionRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||
private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException {
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
chatMessages.add(chatMessage);
|
||||
|
||||
// Handle System Messages
|
||||
chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
|
||||
chatMessages.add(createSystemMessage("I am a user named " + event.getAuthor().getName()));
|
||||
if (event.isFromThread()) {
|
||||
String query = new QueryBuilder().from("messages")
|
||||
.where("guild_id = ? AND thread_id = ?")
|
||||
.orderBy("timestamp DESC")
|
||||
.limit(10)
|
||||
.build();
|
||||
List<MessageData> previousMessages = DatabaseManager.getMessages(
|
||||
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
|
||||
for (MessageData previousMessage : previousMessages) {
|
||||
ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " +
|
||||
previousMessage.getTimestamp() + " which said \"" + previousMessage.getMessageText() +
|
||||
"\". You replied with \"" + previousMessage.getMessageResponse() + "\".");
|
||||
chatMessages.add(previousChatMessage);
|
||||
}
|
||||
}
|
||||
|
||||
return ChatCompletionRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.maxTokens(guildSettings.getMaxTokens())
|
||||
@@ -146,6 +173,10 @@ public class OpenAIManager {
|
||||
.build();
|
||||
}
|
||||
|
||||
private ChatMessage createSystemMessage(String message) {
|
||||
return new ChatMessage(Role.SYSTEM.getName(), message);
|
||||
}
|
||||
|
||||
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
|
||||
ChatMessage chatMessage = new ChatMessage();
|
||||
chatMessage.setContent(message);
|
||||
@@ -158,31 +189,36 @@ public class OpenAIManager {
|
||||
}
|
||||
|
||||
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
|
||||
LOGGER.trace("Message Response <{}>: {}", event.getMessageId(), response);
|
||||
Set<String> topics = new LinkedHashSet<>();
|
||||
topics.addAll(NLP.getTopicKeywords(chatMessage.getContent()));
|
||||
topics.addAll(NLP.getTopicKeywords(response));
|
||||
LOGGER.trace("Topics: {}", topics);
|
||||
|
||||
if (event.isFromThread()) {
|
||||
ThreadChannel channel = event.getChannel().asThreadChannel();
|
||||
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
|
||||
channel.getIdLong(), response);
|
||||
embeddings.forEach(embedding -> {
|
||||
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
||||
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
||||
});
|
||||
channel.sendMessage(response).queue();
|
||||
storeMessage(chatMessage, event, response, embeddings, topics, channel);
|
||||
} else {
|
||||
// The max discord title length is 100 characters
|
||||
String threadTitle = chatMessage.getContent();
|
||||
if (chatMessage.getContent().length() > 100) {
|
||||
threadTitle = chatMessage.getContent().substring(0, 100);
|
||||
}
|
||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
|
||||
channel.sendMessage(response).queue();
|
||||
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
|
||||
channel.getIdLong(), response);
|
||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> storeMessage(chatMessage, event, response, embeddings, topics, channel));
|
||||
}
|
||||
}
|
||||
|
||||
private void storeMessage(ChatMessage chatMessage, MessageReceivedEvent event, String response,
|
||||
List<Embedding> embeddings, Set<String> topics, ThreadChannel channel) {
|
||||
MessageData messageData = new MessageData(event.getGuild().getIdLong(),
|
||||
channel.getIdLong(), event.getAuthor().getIdLong(), chatMessage.getRole(), chatMessage.getContent(),
|
||||
response, topics);
|
||||
int messageRow = DatabaseManager.storeMessage(messageData);
|
||||
embeddings.forEach(embedding -> {
|
||||
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
||||
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
||||
});
|
||||
});
|
||||
}
|
||||
channel.sendMessage(response).queue();
|
||||
}
|
||||
|
||||
private String parseMessage(String input) {
|
||||
|
||||
Reference in New Issue
Block a user