v0.1.23 Tweaking OpenAI settings, database tables

This commit is contained in:
2023-04-17 15:23:36 -04:00
parent 74efbc2352
commit debea833d7
6 changed files with 286 additions and 192 deletions

View File

@@ -1 +1 @@
export SIREN_VERSION=0.1.22 export SIREN_VERSION=0.1.23

View File

@@ -67,7 +67,7 @@
<jackson.version>2.14.2</jackson.version> <jackson.version>2.14.2</jackson.version>
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version> <theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
<postgresql.version>42.6.0</postgresql.version> <postgresql.version>42.6.0</postgresql.version>
<corenlp.version>4.5.3</corenlp.version> <corenlp.version>4.2.0</corenlp.version>
<slf4j.version>2.0.7</slf4j.version> <slf4j.version>2.0.7</slf4j.version>
<log4j.version>2.20.0</log4j.version> <log4j.version>2.20.0</log4j.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>

View File

@@ -4,6 +4,7 @@ import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel; import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.pipeline.Annotation; import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP; import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import edu.stanford.nlp.sentiment.SentimentCoreAnnotations;
import edu.stanford.nlp.util.CoreMap; import edu.stanford.nlp.util.CoreMap;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
@@ -17,20 +18,24 @@ public class NLP {
public NLP() { public NLP() {
Properties props = new Properties(); Properties props = new Properties();
props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner"); props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner, parse, dcoref, sentiment");
pipeline = new StanfordCoreNLP(props); pipeline = new StanfordCoreNLP(props);
keywords = new HashMap<>(); keywords = new HashMap<>();
keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid")); keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid"));
} }
public Set<String> getTopicKeywords(String sentence) { private List<CoreMap> getSentences(String text) {
Set<String> topics = new LinkedHashSet<>(); Annotation annotation = new Annotation(text);
Annotation document = new Annotation(sentence); pipeline.annotate(annotation);
pipeline.annotate(document); return annotation.get(CoreAnnotations.SentencesAnnotation.class);
}
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class); public Set<String> getTags(String text) {
CoreMap sentenceMap = sentences.get(0); Set<String> tags = new LinkedHashSet<>();
List<CoreLabel> tokens = sentenceMap.get(CoreAnnotations.TokensAnnotation.class);
List<CoreMap> sentences = getSentences(text);
for (CoreMap sentence : sentences) {
List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
List<CoreLabel> namedEntities = new ArrayList<>(); List<CoreLabel> namedEntities = new ArrayList<>();
for (CoreLabel token : tokens) { for (CoreLabel token : tokens) {
@@ -38,36 +43,39 @@ public class NLP {
if (!ne.equals("0")) { if (!ne.equals("0")) {
namedEntities.add(token); namedEntities.add(token);
} }
if (token.equals(tokens.get(tokens.size() - 1)) && token.word().equals("?") ||
(List.of("what", "when", "who", "where", "why", "how").contains(token.word()))) {
tags.add("question");
} }
}
for (CoreLabel namedEntity : namedEntities) { for (CoreLabel namedEntity : namedEntities) {
String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class); String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
String word = namedEntity.word(); String word = namedEntity.word();
if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) { if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
topics.add(word); tags.add(word);
} else if (ne.equals("LOCATION")) { } else if (ne.equals("LOCATION")) {
String[] posTags = word.split("_"); String[] posTags = word.split("_");
for (String posTag : posTags) { for (String posTag : posTags) {
if (posTag.startsWith("N")) { if (posTag.startsWith("N")) {
topics.add(word); tags.add(word);
break; break;
} }
} }
} else { } else {
String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class); String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
if (pos.startsWith("NN")) { if (pos.startsWith("NN")) {
topics.add(word); tags.add(word);
for (String keyword : keywords.keySet()) { for (String keyword : keywords.keySet()) {
if (keywords.get(keyword).contains(word.toLowerCase())) { if (keywords.get(keyword).contains(word.toLowerCase())) {
topics.add(keyword); tags.add(keyword);
}
}
if (Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid").contains(word.toLowerCase())) {
topics.add("dnd");
} }
} }
} }
} }
return topics; }
}
return tags;
} }
public List<String> lemmatize(String documentText) { public List<String> lemmatize(String documentText) {
@@ -83,13 +91,31 @@ public class NLP {
return lemmas; return lemmas;
} }
// TODO finish this method /**
public Set<String> getSynonyms(String targetWord) { * Determines the sentiment (tone) of each sentence in the text. For example: positive, negative, or neutral
return Collections.emptySet(); * @param text the input text
* @return the list of sentiments
*/
public List<String> sentimentAnalysis(String text) {
Annotation document = new Annotation(text);
pipeline.annotate(document);
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
List<String> sentiments = new ArrayList<>();
for (CoreMap sentence : sentences) {
sentiments.add(sentence.get(SentimentCoreAnnotations.SentimentClass.class));
}
return sentiments;
} }
// TODO finish this method public static double cosineSimilarity(double[] vector1, double[] vector2) {
public static double calculateSimilarity(String sentence1, String sentence2) { double dotProduct = 0.0;
return 0.0; double norm1 = 0.0;
double norm2 = 0.0;
for (int i = 0; i < vector1.length; i++) {
dotProduct += vector1[i] * vector2[i];
norm1 += Math.pow(vector1[i], 2);
norm2 += Math.pow(vector2[i], 2);
}
return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
} }
} }

View File

@@ -10,7 +10,6 @@ import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult; import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.embedding.Embedding;
import com.theokanning.openai.embedding.EmbeddingRequest; import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.image.CreateImageRequest; import com.theokanning.openai.image.CreateImageRequest;
import com.theokanning.openai.image.ImageResult; import com.theokanning.openai.image.ImageResult;
@@ -88,9 +87,18 @@ public class OpenAIManager {
userSettings = settings.getGuildSettings().get(guildId).getUserSettings().get(authorId); userSettings = settings.getGuildSettings().get(guildId).getUserSettings().get(authorId);
Model model = userSettings.getModel(); Model model = userSettings.getModel();
ChatMessage chatMessage = createUserMessage(message); ChatMessage chatMessage = createUserMessage(message);
List<Embedding> embeddings = new ArrayList<>();
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, authorId, event.getMessageId(), message); LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, authorId, event.getMessageId(), message);
String query = new QueryBuilder("messages")
.where("guild_id = ? AND channel_id = ? AND prompt = ?")
.orderBy("timestamp DESC")
.build();
List<MessageData> messages = DatabaseManager.parseResponse(DatabaseManager.query(query,
guildId, event.getChannel().getIdLong(), chatMessage.getContent()));
if (!messages.isEmpty()) {
stringBuilder.append(messages.get(0).getResponse());
} else {
// Send OpenAI Message and get response // Send OpenAI Message and get response
switch (model) { switch (model) {
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> { case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
@@ -102,11 +110,6 @@ public class OpenAIManager {
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event); ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest); ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim())); chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
//TODO fix embeddings
// EmbeddingRequest embeddingRequest = createEmbeddingRequest(chatMessages, event);
// EmbeddingResult embeddingResult = openAiService.createEmbeddings(embeddingRequest);
// embeddings.addAll(embeddingResult.getData());
} }
default -> { default -> {
event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue(); event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
@@ -115,7 +118,8 @@ public class OpenAIManager {
return; return;
} }
} }
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings); }
handleResponse(chatMessage, event, stringBuilder.toString());
} catch (Exception ex) { } catch (Exception ex) {
LOGGER.error("Caught exception while processing message; {}", ex.getMessage()); LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue(); event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue();
@@ -123,7 +127,6 @@ public class OpenAIManager {
} }
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) { private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
return EmbeddingRequest.builder() return EmbeddingRequest.builder()
.model(Model.ADA_EMBEDDING_2.getName()) .model(Model.ADA_EMBEDDING_2.getName())
.user(event.getAuthor().getId()) .user(event.getAuthor().getId())
@@ -144,66 +147,48 @@ public class OpenAIManager {
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong()) UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
.getUserSettings().get(event.getAuthor().getIdLong()); .getUserSettings().get(event.getAuthor().getIdLong());
List<ChatMessage> chatMessages = new ArrayList<>(); List<ChatMessage> chatMessages = new ArrayList<>();
// List<MessageData> previousMessages = getPreviousMessages(event);
// for (MessageData previousMessage : previousMessages) {
// chatMessages.add(createSystemMessage("In a previous conversation, we discussed the topics " + previousMessage.getTags() +
// " and " + previousMessage.getResponseTags()));
// chatMessages.add(createSystemMessage("I sent you the prompt \"" + previousMessage.getPrompt() +
// "\" and you replied with \"" + previousMessage.getResponse() + "\""));
// }
chatMessages.add(chatMessage); chatMessages.add(chatMessage);
// Handle System Messages // Handle System Messages
chatMessages.add(createSystemMessage("You are a discord bot named Siren")); chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
chatMessages.add(createSystemMessage("My name is " + event.getAuthor().getName())); chatMessages.add(createSystemMessage("My name is " + event.getAuthor().getName()));
if (event.isFromThread()) {
String query = new QueryBuilder("messages")
.where("guild_id = ? AND thread_id = ?")
.orderBy("timestamp DESC")
.limit(10)
.build();
// Build MessageData objects from query results
List<Map<String, Object>> results = DatabaseManager.query(
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
List<MessageData> previousMessages = new ArrayList<>();
for (Map<String, Object> result : results) {
Object[] resultTopicObjects = (Object[]) ((PgArray) result.get("topics")).getArray();
Set<String> resultTopics = new HashSet<>();
for (Object object : resultTopicObjects) {
resultTopics.add((String) object);
}
MessageData messageData = new MessageData(
(long) result.get("guild_id"),
(long) result.get("thread_id"),
(long) result.get("user_id"),
(String) result.get("message_type"),
(String) result.get("message_text"),
(String) result.get("message_response"),
resultTopics,
(Timestamp) result.get("timestamp")
);
previousMessages.add(messageData);
}
Set<String> potentialTopics = new HashSet<>();
for (MessageData previousMessage : previousMessages) {
ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " +
previousMessage.getTimestamp() + " which said \"" + previousMessage.getMessageText() +
"\". You replied with \"" + previousMessage.getMessageResponse() + "\".");
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageText()));
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageResponse()));
// chatMessages.add(previousChatMessage);
}
LOGGER.trace("Potential topics: {}", potentialTopics);
// chatMessages.add(createSystemMessage("As an AI language model, only give replies that relate to " + topics));
}
return ChatCompletionRequest.builder() return ChatCompletionRequest.builder()
.model(userSettings.getModel().getName()) .model(userSettings.getModel().getName())
.maxTokens(userSettings.getMaxTokens()) .maxTokens(userSettings.getMaxTokens())
.user(event.getAuthor().getId()) .user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature()) .temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP()) // .topP(settings.getOpenAISettings().getTopP())
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty()) .frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
.presencePenalty(settings.getOpenAISettings().getPresencePenalty()) .presencePenalty(settings.getOpenAISettings().getPresencePenalty())
.messages(chatMessages) .messages(chatMessages)
.build(); .build();
} }
private List<MessageData> getPreviousMessages(MessageReceivedEvent event) throws SQLException {
List<MessageData> previousMessages = new ArrayList<>();
if (event.isFromThread()) {
String query = new QueryBuilder("messages")
.where("guild_id = ? AND channel_id = ?")
.orderBy("timestamp DESC")
.limit(3)
.build();
// Build MessageData objects from query results
List<Map<String, Object>> results = DatabaseManager.query(
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
previousMessages.addAll(DatabaseManager.parseResponse(results));
}
return previousMessages;
}
private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) { private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) {
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong()) UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
.getUserSettings().get(event.getAuthor().getIdLong()); .getUserSettings().get(event.getAuthor().getIdLong());
@@ -231,37 +216,38 @@ public class OpenAIManager {
return new ChatMessage(Role.USER.getName(), message); return new ChatMessage(Role.USER.getName(), message);
} }
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) { private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response) {
LOGGER.trace("Message Response <{}>: {}", event.getMessageId(), response); LOGGER.trace("Message Response <{}>: {}", event.getMessageId(), response);
Set<String> topics = new LinkedHashSet<>(); Set<String> tags = NLP.getTags(chatMessage.getContent());
topics.addAll(NLP.getTopicKeywords(chatMessage.getContent())); Set<String> responseTags = NLP.getTags(response);
topics.addAll(NLP.getTopicKeywords(response));
LOGGER.trace("Topics: {}", topics); MessageData.MessageDataBuilder builder = new MessageData.MessageDataBuilder()
.guildId(event.getGuild().getIdLong())
.channelId(event.getChannel().getIdLong())
.userId(event.getAuthor().getIdLong())
.type(chatMessage.getRole())
.prompt(chatMessage.getContent())
.response(response)
.tags(tags)
.responseTags(responseTags);
LOGGER.trace("Tags: {}", tags);
if (event.isFromThread()) { if (event.isFromThread()) {
DatabaseManager.storeMessage(builder.build());
ThreadChannel channel = event.getChannel().asThreadChannel(); ThreadChannel channel = event.getChannel().asThreadChannel();
storeMessage(chatMessage, event, response, embeddings, topics, channel); channel.sendMessage(response).queue();
} else { } else {
// The max discord title length is 100 characters // The max discord title length is 100 characters
String threadTitle = chatMessage.getContent(); String threadTitle = chatMessage.getContent();
if (chatMessage.getContent().length() > 100) { if (chatMessage.getContent().length() > 100) {
threadTitle = chatMessage.getContent().substring(0, 100); threadTitle = chatMessage.getContent().substring(0, 100);
} }
event.getMessage().createThreadChannel(threadTitle).queue(channel -> storeMessage(chatMessage, event, response, embeddings, topics, channel)); event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
} DatabaseManager.storeMessage(builder.channelId(channel.getIdLong()).build());
}
private void storeMessage(ChatMessage chatMessage, MessageReceivedEvent event, String response,
List<Embedding> embeddings, Set<String> topics, ThreadChannel channel) {
MessageData messageData = new MessageData(event.getGuild().getIdLong(),
channel.getIdLong(), event.getAuthor().getIdLong(), chatMessage.getRole(), chatMessage.getContent(),
response, topics);
int messageRow = DatabaseManager.storeMessage(messageData);
embeddings.forEach(embedding -> {
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
});
channel.sendMessage(response).queue(); channel.sendMessage(response).queue();
});
}
} }
private String parseMessage(String input) { private String parseMessage(String input) {

View File

@@ -3,6 +3,7 @@ package com.bensherriff.siren.database;
import com.bensherriff.siren.settings.LocalTrack; import com.bensherriff.siren.settings.LocalTrack;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.postgresql.jdbc.PgArray;
import java.sql.*; import java.sql.*;
import java.sql.Date; import java.sql.Date;
@@ -17,12 +18,13 @@ public class DatabaseManager {
entry("messages", List.of( entry("messages", List.of(
"id SERIAL PRIMARY KEY", "id SERIAL PRIMARY KEY",
"guild_id BIGINT NOT NULL", "guild_id BIGINT NOT NULL",
"thread_id BIGINT NOT NULL", "channel_id BIGINT NOT NULL",
"user_id BIGINT NOT NULL", "user_id BIGINT NOT NULL",
"message_type VARCHAR(20) NOT NULL", "type VARCHAR(20) NOT NULL",
"message_text TEXT NOT NULL", "prompt TEXT NOT NULL",
"message_response TEXT", "response TEXT",
"topics TEXT[]", "tags TEXT[]",
"response_tags TEXT[]",
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()")), "timestamp TIMESTAMP NOT NULL DEFAULT NOW()")),
entry("embeddings", List.of( entry("embeddings", List.of(
"id SERIAL PRIMARY KEY", "id SERIAL PRIMARY KEY",
@@ -81,22 +83,24 @@ public class DatabaseManager {
public static int storeMessage(MessageData messageData) { public static int storeMessage(MessageData messageData) {
String INSERT_MESSAGE = "INSERT INTO messages (" + String INSERT_MESSAGE = "INSERT INTO messages (" +
"message_type, " +
"guild_id, " + "guild_id, " +
"thread_id, " + "channel_id, " +
"user_id, " + "user_id, " +
"message_text, " + "type, " +
"message_response, " + "prompt, " +
"topics) " + "response, " +
"VALUES (?, ?, ?, ?, ?, ?, ?)"; "tags, " +
"response_tags) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
return storeMessage("messages", INSERT_MESSAGE, return storeMessage("messages", INSERT_MESSAGE,
messageData.getMessageType(),
messageData.getGuildId(), messageData.getGuildId(),
messageData.getThreadId(), messageData.getChannelId(),
messageData.getUserId(), messageData.getUserId(),
messageData.getMessageText(), messageData.getType(),
messageData.getMessageResponse(), messageData.getPrompt(),
messageData.getTopics() messageData.getResponse(),
messageData.getTags(),
messageData.getResponseTags()
); );
} }
@@ -164,11 +168,11 @@ public class DatabaseManager {
preparedStatement.setBlob(i, (Blob) param); preparedStatement.setBlob(i, (Blob) param);
} else if (param instanceof Clob) { } else if (param instanceof Clob) {
preparedStatement.setClob(i, (Clob) param); preparedStatement.setClob(i, (Clob) param);
} else if (param instanceof List && ((List<?>) param).get(0) instanceof String) { } else if (param instanceof List && !((List<?>) param).isEmpty() && ((List<?>) param).get(0) instanceof String) {
preparedStatement.setArray(i, connection.createArrayOf("text", ((List<String>) param).toArray())); preparedStatement.setArray(i, connection.createArrayOf("text", ((List<String>) param).toArray()));
} else if (param instanceof List && ((List<?>) param).get(0) instanceof Double) { } else if (param instanceof List && !((List<?>) param).isEmpty() && ((List<?>) param).get(0) instanceof Double) {
preparedStatement.setArray(i, connection.createArrayOf("float8", ((List<Double>) param).toArray(new Double[0]))); preparedStatement.setArray(i, connection.createArrayOf("float8", ((List<Double>) param).toArray(new Double[0])));
} else if (param instanceof Set && ((Set<?>) param).toArray()[0] instanceof String) { } else if (param instanceof Set && !((Set<?>) param).isEmpty() && ((Set<?>) param).toArray()[0] instanceof String) {
preparedStatement.setArray(i, connection.createArrayOf("text", ((Set<String>) param).toArray())); preparedStatement.setArray(i, connection.createArrayOf("text", ((Set<String>) param).toArray()));
} else if (param instanceof String[]) { } else if (param instanceof String[]) {
preparedStatement.setArray(i, connection.createArrayOf("text", ((String[]) param))); preparedStatement.setArray(i, connection.createArrayOf("text", ((String[]) param)));
@@ -249,4 +253,30 @@ public class DatabaseManager {
LOGGER.error("Failed to clear table; {}", ex.getMessage()); LOGGER.error("Failed to clear table; {}", ex.getMessage());
} }
} }
public static List<MessageData> parseResponse(List<Map<String, Object>> results) throws SQLException {
List<MessageData> messageData = new ArrayList<>();
for (Map<String, Object> result : results) {
Set<String> promptTags = new HashSet<>();
for (Object object : (Object[]) ((PgArray) result.get("tags")).getArray()) {
promptTags.add((String) object);
}
Set<String> responseTags = new HashSet<>();
for (Object object : (Object[]) ((PgArray) result.get("response_tags")).getArray()) {
promptTags.add((String) object);
}
messageData.add(new MessageData.MessageDataBuilder()
.guildId((long) result.get("guild_id"))
.channelId((long) result.get("channel_id"))
.userId((long) result.get("user_id"))
.type((String) result.get("type"))
.prompt((String) result.get("prompt"))
.response((String) result.get("response"))
.tags(promptTags)
.responseTags(responseTags)
.timestamp((Timestamp) result.get("timestamp"))
.build());
}
return messageData;
}
} }

View File

@@ -1,74 +1,126 @@
package com.bensherriff.siren.database; package com.bensherriff.siren.database;
import java.sql.Timestamp; import java.sql.Timestamp;
import java.util.LinkedHashSet;
import java.util.Set; import java.util.Set;
public class MessageData { public class MessageData {
private final Long guildId; private final Long guildId;
private final Long threadId; private final Long channelId;
private final Long userId; private final Long userId;
private final String messageType; private final String type;
private final String messageText; private final String prompt;
private final String messageResponse; private final String response;
private final Set<String> topics; private final Set<String> tags;
private Timestamp timestamp; private final Set<String> responseTags;
private final Timestamp timestamp;
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText, private MessageData(MessageDataBuilder builder) {
String messageResponse, Set<String> topics) { this.guildId = builder.guildId;
this.guildId = guildId; this.channelId = builder.channelId;
this.threadId = threadId; this.userId = builder.userId;
this.userId = userId; this.type = builder.type;
this.messageType = messageType; this.prompt = builder.prompt;
this.messageText = messageText; this.response = builder.response;
this.messageResponse = messageResponse; this.tags = builder.tags;
this.topics = topics; this.responseTags = builder.responseTags;
} this.timestamp = builder.timestamp;
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
String messageResponse, Set<String> topics, Timestamp timestamp) {
this.guildId = guildId;
this.threadId = threadId;
this.userId = userId;
this.messageType = messageType;
this.messageText = messageText;
this.messageResponse = messageResponse;
this.topics = topics;
this.timestamp = timestamp;
} }
public Long getGuildId() { public Long getGuildId() {
return guildId; return guildId;
} }
public Long getThreadId() { public Long getChannelId() {
return threadId; return channelId;
} }
public Long getUserId() { public Long getUserId() {
return userId; return userId;
} }
public String getMessageType() { public String getType() {
return messageType; return type;
} }
public String getMessageText() { public String getPrompt() {
return messageText; return prompt;
} }
public String getMessageResponse() { public String getResponse() {
return messageResponse; return response;
} }
public Set<String> getTopics() { public Set<String> getTags() {
return topics; return tags;
}
public Set<String> getResponseTags() {
return responseTags;
} }
public Timestamp getTimestamp() { public Timestamp getTimestamp() {
return timestamp; return timestamp;
} }
public void setTimestamp(Timestamp timestamp) { public static class MessageDataBuilder {
private Long guildId;
private Long channelId;
private Long userId;
private String type = "";
private String prompt = "";
private String response = "";
private Set<String> tags = new LinkedHashSet<>();
private Set<String> responseTags = new LinkedHashSet<>();
private Timestamp timestamp = new Timestamp(0);
public MessageDataBuilder guildId(Long guildId) {
this.guildId = guildId;
return this;
}
public MessageDataBuilder channelId(Long channelId) {
this.channelId = channelId;
return this;
}
public MessageDataBuilder userId(Long userId) {
this.userId = userId;
return this;
}
public MessageDataBuilder type(String type) {
this.type = type;
return this;
}
public MessageDataBuilder prompt(String prompt) {
this.prompt = prompt;
return this;
}
public MessageDataBuilder response(String response) {
this.response = response;
return this;
}
public MessageDataBuilder tags(Set<String> tags) {
this.tags = tags;
return this;
}
public MessageDataBuilder responseTags(Set<String> responseTags) {
this.responseTags = responseTags;
return this;
}
public MessageDataBuilder timestamp(Timestamp timestamp) {
this.timestamp = timestamp; this.timestamp = timestamp;
return this;
}
public MessageData build() {
return new MessageData(this);
}
} }
} }