diff --git a/.version b/.version index 4a0b6e2..d649612 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -export SIREN_VERSION=0.1.15 \ No newline at end of file +export SIREN_VERSION=0.1.16 \ No newline at end of file diff --git a/pom.xml b/pom.xml index 6fdb07b..0a366ab 100644 --- a/pom.xml +++ b/pom.xml @@ -61,13 +61,14 @@ - 5.0.0-beta.6 + 5.0.0-beta.8 1.4.0 1.3.13 2.14.2 0.12.0 42.6.0 - 2.0.6 + 4.5.3 + 2.0.7 2.20.0 UTF-8 17 @@ -105,6 +106,18 @@ postgresql ${postgresql.version} + + edu.stanford.nlp + stanford-corenlp + ${corenlp.version} + + + edu.stanford.nlp + stanford-corenlp + ${corenlp.version} + models + runtime + @@ -150,7 +163,7 @@ - com.bensherriff.siren.Bot + com.bensherriff.siren.Main ${project.artifactId} ${project.version} ${project.artifactId} diff --git a/src/main/java/com/bensherriff/siren/Listener.java b/src/main/java/com/bensherriff/siren/Listener.java index 7c3feb7..e781384 100644 --- a/src/main/java/com/bensherriff/siren/Listener.java +++ b/src/main/java/com/bensherriff/siren/Listener.java @@ -34,14 +34,15 @@ public class Listener extends ListenerAdapter { private final Settings settings; private final Map commands = new HashMap<>(); + private final String owner; private PlayerManager playerManager; private OpenAIManager openAIManager; private JDA jda; public Listener(Settings settings) { this.settings = settings; - this.executor = Executors.newScheduledThreadPool(this.settings.getThreadPool()); + this.owner = "@bsherriff"; } public ScheduledExecutorService getExecutor() { @@ -64,6 +65,10 @@ public class Listener extends ListenerAdapter { this.jda = jda; } + public String getOwner() { + return owner; + } + public void initialize() { this.playerManager = new PlayerManager(this); this.playerManager.initialize(); @@ -78,6 +83,7 @@ public class Listener extends ListenerAdapter { commands.put("pause", new PauseCommand(this)); commands.put("resume", new ResumeCommand(this)); commands.put("help", new PauseCommand(this)); + } public void closeAudioConnection(long guildID) { @@ -144,7 +150,7 @@ public class Listener extends ListenerAdapter { commands.get(command).execute(event); } catch (Exception ex) { LOGGER.error(ex.getMessage()); - event.getHook().sendMessage("An error occurred while processing your command. Please contact your administrator.").queue(); + event.getHook().sendMessage("An error occurred while processing your command. Please contact " + owner + ".").queue(); } }); } else { diff --git a/src/main/java/com/bensherriff/siren/Bot.java b/src/main/java/com/bensherriff/siren/Main.java similarity index 95% rename from src/main/java/com/bensherriff/siren/Bot.java rename to src/main/java/com/bensherriff/siren/Main.java index 59ef4b5..7e9b0d7 100644 --- a/src/main/java/com/bensherriff/siren/Bot.java +++ b/src/main/java/com/bensherriff/siren/Main.java @@ -13,8 +13,8 @@ import org.apache.logging.log4j.Logger; import java.io.IOException; import java.util.Arrays; -public class Bot { - private static final Logger LOGGER = LogManager.getLogger(Bot.class); +public class Main { + private static final Logger LOGGER = LogManager.getLogger(Main.class); private final static GatewayIntent[] INTENTS = { GatewayIntent.DIRECT_MESSAGES, GatewayIntent.GUILD_MESSAGES, GatewayIntent.GUILD_MESSAGE_REACTIONS, GatewayIntent.GUILD_VOICE_STATES, GatewayIntent.MESSAGE_CONTENT diff --git a/src/main/java/com/bensherriff/siren/commands/PlayCommand.java b/src/main/java/com/bensherriff/siren/commands/PlayCommand.java index 20a74b9..09f9a72 100644 --- a/src/main/java/com/bensherriff/siren/commands/PlayCommand.java +++ b/src/main/java/com/bensherriff/siren/commands/PlayCommand.java @@ -89,7 +89,7 @@ public class PlayCommand extends Command { if (exception.getMessage().contains("Unknown file format")) { event.getHook().sendMessage(errorMsg + ". " + exception.getMessage()).queue(); } else { - event.getHook().sendMessage(errorMsg + ". Please contact your administrator.").queue(); + event.getHook().sendMessage(errorMsg + ". Please contact " + listener.getOwner() + ".").queue(); } LOGGER.error("{}: {}", errorMsg, exception.getMessage()); } diff --git a/src/main/java/com/bensherriff/siren/database/DatabaseManager.java b/src/main/java/com/bensherriff/siren/database/DatabaseManager.java index b4f449d..2e4f075 100644 --- a/src/main/java/com/bensherriff/siren/database/DatabaseManager.java +++ b/src/main/java/com/bensherriff/siren/database/DatabaseManager.java @@ -1,42 +1,47 @@ package com.bensherriff.siren.database; -import com.theokanning.openai.completion.chat.ChatMessage; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import java.sql.*; -import java.util.List; +import java.util.*; + +import static java.util.Map.entry; public class DatabaseManager { private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class); - private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" + - "id SERIAL PRIMARY KEY, " + - "guild_id BIGINT NOT NULL, " + - "thread_id BIGINT NOT NULL, " + - "user_id BIGINT NOT NULL, " + - "message_type VARCHAR(20) NOT NULL," + - "message_text TEXT NOT NULL," + - "message_response TEXT, " + - "timestamp TIMESTAMP NOT NULL DEFAULT NOW()" + - ")"; - private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" + - "id SERIAL PRIMARY KEY, " + - "embeddings FLOAT[] NOT NULL" + - ");"; - private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" + - "id SERIAL PRIMARY KEY, " + - "message_id INT NOT NULL, " + - "embedding_id INT NOT NULL" + - ")"; + private static final Map createTableQueries = Map.ofEntries( + entry("messages", "CREATE TABLE IF NOT EXISTS messages (" + + "id SERIAL PRIMARY KEY, " + + "guild_id BIGINT NOT NULL, " + + "thread_id BIGINT NOT NULL, " + + "user_id BIGINT NOT NULL, " + + "message_type VARCHAR(20) NOT NULL," + + "message_text TEXT NOT NULL," + + "message_response TEXT, " + + "topics TEXT[], " + + "timestamp TIMESTAMP NOT NULL DEFAULT NOW()" + + ")"), + entry("embeddings", "CREATE TABLE IF NOT EXISTS embeddings (" + + "id SERIAL PRIMARY KEY, " + + "embeddings FLOAT[] NOT NULL" + + ")"), + entry("message_embeddings", "CREATE TABLE IF NOT EXISTS message_embeddings (" + + "id SERIAL PRIMARY KEY, " + + "message_id INT NOT NULL, " + + "embedding_id INT NOT NULL" + + ")") + ); private static final String INSERT_MESSAGE = "INSERT INTO messages (" + "message_type, " + "guild_id, " + "thread_id, " + "user_id, " + "message_text, " + - "message_response) " + - "VALUES (?, ?, ?, ?, ?, ?)"; + "message_response, " + + "topics) " + + "VALUES (?, ?, ?, ?, ?, ?, ?)"; private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" + "embeddings) " + "VALUES (?)"; @@ -47,52 +52,47 @@ public class DatabaseManager { "VALUES (?, ?)"; public static void createTables() { - createMessageTable(); - createEmbeddingsTable(); - createMessageEmbeddingsTable(); + for (Map.Entry entry : createTableQueries.entrySet()) { + if (!createTable(entry.getKey())) { + LOGGER.warn("Failed to create one or more required database tables"); + return; + } + } + LOGGER.debug("Successfully created database tables"); } - private static void createMessageTable() { - try { - Connection connection = DatabaseConnection.getConnection(); - LOGGER.debug("Creating 'messages' database table if it does not exist"); - Statement statement = connection.createStatement(); - statement.execute(CREATE_MESSAGE_TABLE); - } catch (SQLException ex) { - LOGGER.error(ex.getMessage()); + private static boolean createTable(String tableName) { + if (tableExists(tableName)) { + return true; + } else { + try { + Connection connection = DatabaseConnection.getConnection(); + LOGGER.debug("Creating '{}' database table if it does not exist", tableName); + Statement statement = connection.createStatement(); + statement.execute(createTableQueries.get(tableName)); + return true; + } catch (SQLException ex) { + LOGGER.error(ex.getMessage()); + return false; + } } } - private static void createEmbeddingsTable() { + public static int storeMessage(MessageData messageData) { + if (!tableExists("messages")) { + LOGGER.warn("Table 'messages' does not exist"); + return -1; + } try { Connection connection = DatabaseConnection.getConnection(); - LOGGER.debug("Creating 'embeddings' database table if it does not exist"); - Statement statement = connection.createStatement(); - statement.execute(CREATE_EMBEDDINGS_TABLE); - } catch (SQLException ex) { - LOGGER.error(ex.getMessage()); - } - } - - private static void createMessageEmbeddingsTable() { - try { - Connection connection = DatabaseConnection.getConnection(); - LOGGER.debug("Creating 'message_embeddings' database table if it does not exist"); - Statement statement = connection.createStatement(); - statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE); - } catch (SQLException ex) { - LOGGER.error(ex.getMessage()); - } - } - - public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) { - try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) { - preparedStatement.setString(1, message.getRole()); - preparedStatement.setLong(2, guildId); - preparedStatement.setLong(3, threadId); - preparedStatement.setLong(4, userId); - preparedStatement.setString(5, message.getContent()); - preparedStatement.setString(6, response); + PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE); + preparedStatement.setString(1, messageData.getMessageType()); + preparedStatement.setLong(2, messageData.getGuildId()); + preparedStatement.setLong(3, messageData.getThreadId()); + preparedStatement.setLong(4, messageData.getUserId()); + preparedStatement.setString(5, messageData.getMessageText()); + preparedStatement.setString(6, messageData.getMessageResponse()); + preparedStatement.setArray(7, connection.createArrayOf("text", messageData.getTopics().toArray())); return preparedStatement.executeUpdate(); } catch (SQLException ex) { LOGGER.error(ex.getMessage()); @@ -101,7 +101,13 @@ public class DatabaseManager { } public static int storeEmbedding(List data) { - try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) { + if (!tableExists("embeddings")) { + LOGGER.warn("Table 'embeddings' does not exist"); + return -1; + } + try { + Connection connection = DatabaseConnection.getConnection(); + PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING); preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0]))); return preparedStatement.executeUpdate(); } catch (SQLException ex) { @@ -111,7 +117,13 @@ public class DatabaseManager { } public static int storeMessageEmbeddings(int messageId, int embeddingId) { - try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) { + if (!tableExists("message_embeddings")) { + LOGGER.warn("Table 'message_embeddings' does not exist"); + return -1; + } + try { + Connection connection = DatabaseConnection.getConnection(); + PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS); preparedStatement.setInt(1, messageId); preparedStatement.setInt(2, embeddingId); return preparedStatement.executeUpdate(); @@ -120,4 +132,61 @@ public class DatabaseManager { return -1; } } + + public static List getMessages(String query, Object... params) throws SQLException, IllegalArgumentException { + LOGGER.trace("Query: <{}>", query); + Connection connection = DatabaseConnection.getConnection(); + PreparedStatement stmt = connection.prepareStatement(query); + int i = 1; + for (Object param : params) { + if (param instanceof String) { + stmt.setString(i++, (String) param); + } else if (param instanceof Integer) { + stmt.setInt(i++, (Integer) param); + } else if (param instanceof Long) { + stmt.setLong(i++, (Long) param); + } else if (param instanceof Double) { + stmt.setDouble(i++, (Double) param); + } else if (param instanceof Float) { + stmt.setFloat(i++, (Float) param); + } else if (param instanceof Timestamp) { + stmt.setTimestamp(i++, (Timestamp) param); + } else if (param instanceof Boolean) { + stmt.setBoolean(i++, (Boolean) param); + } else if (param instanceof List) { + stmt.setArray(i++, connection.createArrayOf("text", ((List) param).toArray())); + } else { + throw new IllegalArgumentException("Unsupported parameter type: " + param.getClass().getName()); + } + } + ResultSet resultSet = stmt.executeQuery(); + List resultList = new ArrayList<>(); + while (resultSet.next()) { + Array array = resultSet.getArray(8); + MessageData messageData = new MessageData( + resultSet.getLong(2), + resultSet.getLong(3), + resultSet.getLong(4), + resultSet.getString(5), + resultSet.getString(6), + resultSet.getString(7), + new HashSet<>(Arrays.asList((String[]) array.getArray())), + resultSet.getTimestamp(9) + ); + resultList.add(messageData); + } + return resultList; + } + + private static boolean tableExists(String tableName) { + try { + Connection connection = DatabaseConnection.getConnection(); + Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery("SELECT tablename FROM pg_tables WHERE tablename = '" + tableName + "'"); + return resultSet.next(); + } catch (SQLException ex) { + LOGGER.error("Failed to check if table exists; {}" + ex.getMessage()); + return false; + } + } } diff --git a/src/main/java/com/bensherriff/siren/database/MessageData.java b/src/main/java/com/bensherriff/siren/database/MessageData.java new file mode 100644 index 0000000..8a42728 --- /dev/null +++ b/src/main/java/com/bensherriff/siren/database/MessageData.java @@ -0,0 +1,74 @@ +package com.bensherriff.siren.database; + +import java.sql.Timestamp; +import java.util.Set; + +public class MessageData { + private final Long guildId; + private final Long threadId; + private final Long userId; + private final String messageType; + private final String messageText; + private final String messageResponse; + private final Set topics; + private Timestamp timestamp; + + public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText, + String messageResponse, Set topics) { + this.guildId = guildId; + this.threadId = threadId; + this.userId = userId; + this.messageType = messageType; + this.messageText = messageText; + this.messageResponse = messageResponse; + this.topics = topics; + } + + public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText, + String messageResponse, Set topics, Timestamp timestamp) { + this.guildId = guildId; + this.threadId = threadId; + this.userId = userId; + this.messageType = messageType; + this.messageText = messageText; + this.messageResponse = messageResponse; + this.topics = topics; + this.timestamp = timestamp; + } + + public Long getGuildId() { + return guildId; + } + + public Long getThreadId() { + return threadId; + } + + public Long getUserId() { + return userId; + } + + public String getMessageType() { + return messageType; + } + + public String getMessageText() { + return messageText; + } + + public String getMessageResponse() { + return messageResponse; + } + + public Set getTopics() { + return topics; + } + + public Timestamp getTimestamp() { + return timestamp; + } + + public void setTimestamp(Timestamp timestamp) { + this.timestamp = timestamp; + } +} diff --git a/src/main/java/com/bensherriff/siren/database/QueryBuilder.java b/src/main/java/com/bensherriff/siren/database/QueryBuilder.java new file mode 100644 index 0000000..996f810 --- /dev/null +++ b/src/main/java/com/bensherriff/siren/database/QueryBuilder.java @@ -0,0 +1,68 @@ +package com.bensherriff.siren.database; + +public class QueryBuilder { + private boolean distinct; + private String columnList; + private String tableName; + private String whereClause; + private String orderByClause; + private Integer limit; + + public QueryBuilder select(String columnList) { + this.columnList = columnList; + return this; + } + + public QueryBuilder from(String tableName) { + this.tableName = tableName; + return this; + } + + public QueryBuilder where(String whereClause) { + this.whereClause = whereClause; + return this; + } + + public QueryBuilder orderBy(String orderByClause) { + this.orderByClause = orderByClause; + return this; + } + + public QueryBuilder limit(int limit) { + this.limit = limit; + return this; + } + + public QueryBuilder distinct(boolean distinct) { + this.distinct = distinct; + return this; + } + + public String build() { + StringBuilder queryBuilder = new StringBuilder(); + queryBuilder.append("SELECT "); + if (distinct) { + queryBuilder.append("DISTINCT "); + } + if (columnList != null && !columnList.isEmpty()) { + queryBuilder.append(columnList); + } else { + queryBuilder.append("*"); + } + queryBuilder.append(" FROM "); + queryBuilder.append(tableName); + if (whereClause != null) { + queryBuilder.append(" WHERE "); + queryBuilder.append(whereClause); + } + if (orderByClause != null) { + queryBuilder.append(" ORDER BY "); + queryBuilder.append(orderByClause); + } + if (limit != null) { + queryBuilder.append(" LIMIT "); + queryBuilder.append(limit); + } + return queryBuilder.toString(); + } +} diff --git a/src/main/java/com/bensherriff/siren/openai/NLP.java b/src/main/java/com/bensherriff/siren/openai/NLP.java new file mode 100644 index 0000000..b792201 --- /dev/null +++ b/src/main/java/com/bensherriff/siren/openai/NLP.java @@ -0,0 +1,95 @@ +package com.bensherriff.siren.openai; + +import edu.stanford.nlp.ling.CoreAnnotations; +import edu.stanford.nlp.ling.CoreLabel; +import edu.stanford.nlp.pipeline.Annotation; +import edu.stanford.nlp.pipeline.StanfordCoreNLP; +import edu.stanford.nlp.util.CoreMap; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.util.*; + +public class NLP { + private static final Logger LOGGER = LogManager.getLogger(NLP.class); + private final StanfordCoreNLP pipeline; + private final Map> keywords; + + public NLP() { + Properties props = new Properties(); + props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner"); + pipeline = new StanfordCoreNLP(props); + keywords = new HashMap<>(); + keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid")); + } + + public Set getTopicKeywords(String sentence) { + Set topics = new LinkedHashSet<>(); + Annotation document = new Annotation(sentence); + pipeline.annotate(document); + + List sentences = document.get(CoreAnnotations.SentencesAnnotation.class); + CoreMap sentenceMap = sentences.get(0); + List tokens = sentenceMap.get(CoreAnnotations.TokensAnnotation.class); + List namedEntities = new ArrayList<>(); + + for (CoreLabel token : tokens) { + String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class); + if (!ne.equals("0")) { + namedEntities.add(token); + } + } + for (CoreLabel namedEntity : namedEntities) { + String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class); + String word = namedEntity.word(); + if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) { + topics.add(word); + } else if (ne.equals("LOCATION")) { + String[] posTags = word.split("_"); + for (String posTag : posTags) { + if (posTag.startsWith("N")) { + topics.add(word); + break; + } + } + } else { + String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class); + if (pos.startsWith("NN")) { + topics.add(word); + for (String keyword : keywords.keySet()) { + if (keywords.get(keyword).contains(word.toLowerCase())) { + topics.add(keyword); + } + } + if (Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid").contains(word.toLowerCase())) { + topics.add("dnd"); + } + } + } + } + return topics; + } + + public List lemmatize(String documentText) { + List lemmas = new ArrayList<>(); + Annotation document = new Annotation(documentText); + pipeline.annotate(document); + List sentences = document.get(CoreAnnotations.SentencesAnnotation.class); + for (CoreMap sentence : sentences) { + for (CoreLabel token : sentence.get(CoreAnnotations.TokensAnnotation.class)) { + lemmas.add(token.get(CoreAnnotations.LemmaAnnotation.class)); + } + } + return lemmas; + } + + // TODO finish this method + public Set getSynonyms(String targetWord) { + return Collections.emptySet(); + } + + // TODO finish this method + public static double calculateSimilarity(String sentence1, String sentence2) { + return 0.0; + } +} diff --git a/src/main/java/com/bensherriff/siren/openai/OpenAIManager.java b/src/main/java/com/bensherriff/siren/openai/OpenAIManager.java index ff10672..311c2c9 100644 --- a/src/main/java/com/bensherriff/siren/openai/OpenAIManager.java +++ b/src/main/java/com/bensherriff/siren/openai/OpenAIManager.java @@ -2,6 +2,8 @@ package com.bensherriff.siren.openai; import com.bensherriff.siren.Listener; import com.bensherriff.siren.database.DatabaseManager; +import com.bensherriff.siren.database.MessageData; +import com.bensherriff.siren.database.QueryBuilder; import com.bensherriff.siren.settings.GuildSettings; import com.bensherriff.siren.settings.Settings; import com.theokanning.openai.completion.CompletionRequest; @@ -11,7 +13,6 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult; import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.embedding.Embedding; import com.theokanning.openai.embedding.EmbeddingRequest; -import com.theokanning.openai.embedding.EmbeddingResult; import com.theokanning.openai.service.OpenAiService; import net.dv8tion.jda.api.JDA; import net.dv8tion.jda.api.entities.Message; @@ -21,6 +22,7 @@ import net.dv8tion.jda.api.events.message.MessageReceivedEvent; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import java.sql.SQLException; import java.time.Duration; import java.util.*; import java.util.concurrent.ScheduledExecutorService; @@ -32,11 +34,15 @@ public class OpenAIManager { private final Settings settings; private final JDA jda; private final ScheduledExecutorService executor; + private final String owner; + private final NLP NLP; public OpenAIManager(Listener listener) { this.settings = listener.getSettings(); this.jda = listener.getJDA(); this.executor = listener.getExecutor(); + this.owner = listener.getOwner(); + this.NLP = new NLP(); if (settings.getOpenAISettings().getToken().isEmpty()) { LOGGER.warn("No OpenAI token; OpenAI functionality is disabled"); @@ -56,19 +62,13 @@ public class OpenAIManager { if (openAiService != null) { executor.execute(() -> sendMessage(event)); } else { - event.getMessage().reply("OpenAI functionality is disabled. Please contact an administrator").queue(); + event.getMessage().reply("OpenAI functionality is disabled. Please contact " + owner + ".").queue(); } } } private void sendMessage(MessageReceivedEvent event) { String message = parseMessage(event.getMessage().getContentRaw()); - long guildId = event.getGuild().getIdLong(); - Model model = settings.getGuildSettings().get(guildId).getModel(); - StringBuilder stringBuilder = new StringBuilder(); - List chatMessages = new ArrayList<>(); - ChatMessage chatMessage = createChatMessage(message, event); - List embeddings = new ArrayList<>(); if (message.isEmpty() || message.isBlank()) { event.getMessage().reply("Your message is empty. Please try again").queue(); @@ -76,7 +76,13 @@ public class OpenAIManager { } try { - LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message); + StringBuilder stringBuilder = new StringBuilder(); + long guildId = event.getGuild().getIdLong(); + Model model = settings.getGuildSettings().get(guildId).getModel(); + ChatMessage chatMessage = createChatMessage(message, event); + List embeddings = new ArrayList<>(); + + LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, event.getAuthor().getId(), event.getMessageId(), message); // Send OpenAI Message and get response switch (model) { case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> { @@ -85,8 +91,7 @@ public class OpenAIManager { completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim())); } case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> { - chatMessages.add(chatMessage); - ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event); + ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event); ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest); chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim())); @@ -96,7 +101,7 @@ public class OpenAIManager { // embeddings.addAll(embeddingResult.getData()); } default -> { - event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue(); + event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue(); LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId, model, Arrays.toString(Model.values())); return; @@ -105,7 +110,7 @@ public class OpenAIManager { handleResponse(chatMessage, event, stringBuilder.toString(), embeddings); } catch (Exception ex) { LOGGER.error("Caught exception while processing message; {}", ex.getMessage()); - event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue(); + event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue(); } } @@ -118,8 +123,30 @@ public class OpenAIManager { .build(); } - private ChatCompletionRequest createCompletionRequest(List chatMessages, MessageReceivedEvent event) { + private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException { GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong()); + List chatMessages = new ArrayList<>(); + chatMessages.add(chatMessage); + + // Handle System Messages + chatMessages.add(createSystemMessage("You are a discord bot named Siren")); + chatMessages.add(createSystemMessage("I am a user named " + event.getAuthor().getName())); + if (event.isFromThread()) { + String query = new QueryBuilder().from("messages") + .where("guild_id = ? AND thread_id = ?") + .orderBy("timestamp DESC") + .limit(10) + .build(); + List previousMessages = DatabaseManager.getMessages( + query, event.getGuild().getIdLong(), event.getChannel().getIdLong()); + for (MessageData previousMessage : previousMessages) { + ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " + + previousMessage.getTimestamp() + " which said \"" + previousMessage.getMessageText() + + "\". You replied with \"" + previousMessage.getMessageResponse() + "\"."); + chatMessages.add(previousChatMessage); + } + } + return ChatCompletionRequest.builder() .model(guildSettings.getModel().getName()) .maxTokens(guildSettings.getMaxTokens()) @@ -146,6 +173,10 @@ public class OpenAIManager { .build(); } + private ChatMessage createSystemMessage(String message) { + return new ChatMessage(Role.SYSTEM.getName(), message); + } + private ChatMessage createChatMessage(String message, MessageReceivedEvent event) { ChatMessage chatMessage = new ChatMessage(); chatMessage.setContent(message); @@ -158,33 +189,38 @@ public class OpenAIManager { } private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List embeddings) { + LOGGER.trace("Message Response <{}>: {}", event.getMessageId(), response); + Set topics = new LinkedHashSet<>(); + topics.addAll(NLP.getTopicKeywords(chatMessage.getContent())); + topics.addAll(NLP.getTopicKeywords(response)); + LOGGER.trace("Topics: {}", topics); + if (event.isFromThread()) { ThreadChannel channel = event.getChannel().asThreadChannel(); - int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(), - channel.getIdLong(), response); - embeddings.forEach(embedding -> { - int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding()); - DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow); - }); - channel.sendMessage(response).queue(); + storeMessage(chatMessage, event, response, embeddings, topics, channel); } else { // The max discord title length is 100 characters String threadTitle = chatMessage.getContent(); if (chatMessage.getContent().length() > 100) { threadTitle = chatMessage.getContent().substring(0, 100); } - event.getMessage().createThreadChannel(threadTitle).queue(channel -> { - channel.sendMessage(response).queue(); - int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(), - channel.getIdLong(), response); - embeddings.forEach(embedding -> { - int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding()); - DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow); - }); - }); + event.getMessage().createThreadChannel(threadTitle).queue(channel -> storeMessage(chatMessage, event, response, embeddings, topics, channel)); } } + private void storeMessage(ChatMessage chatMessage, MessageReceivedEvent event, String response, + List embeddings, Set topics, ThreadChannel channel) { + MessageData messageData = new MessageData(event.getGuild().getIdLong(), + channel.getIdLong(), event.getAuthor().getIdLong(), chatMessage.getRole(), chatMessage.getContent(), + response, topics); + int messageRow = DatabaseManager.storeMessage(messageData); + embeddings.forEach(embedding -> { + int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding()); + DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow); + }); + channel.sendMessage(response).queue(); + } + private String parseMessage(String input) { return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim(); }