v0.1.23 Tweaking OpenAI settings, database tables
This commit is contained in:
@@ -4,6 +4,7 @@ import edu.stanford.nlp.ling.CoreAnnotations;
|
||||
import edu.stanford.nlp.ling.CoreLabel;
|
||||
import edu.stanford.nlp.pipeline.Annotation;
|
||||
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
|
||||
import edu.stanford.nlp.sentiment.SentimentCoreAnnotations;
|
||||
import edu.stanford.nlp.util.CoreMap;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
@@ -17,57 +18,64 @@ public class NLP {
|
||||
|
||||
public NLP() {
|
||||
Properties props = new Properties();
|
||||
props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner");
|
||||
props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner, parse, dcoref, sentiment");
|
||||
pipeline = new StanfordCoreNLP(props);
|
||||
keywords = new HashMap<>();
|
||||
keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid"));
|
||||
}
|
||||
|
||||
public Set<String> getTopicKeywords(String sentence) {
|
||||
Set<String> topics = new LinkedHashSet<>();
|
||||
Annotation document = new Annotation(sentence);
|
||||
pipeline.annotate(document);
|
||||
private List<CoreMap> getSentences(String text) {
|
||||
Annotation annotation = new Annotation(text);
|
||||
pipeline.annotate(annotation);
|
||||
return annotation.get(CoreAnnotations.SentencesAnnotation.class);
|
||||
}
|
||||
|
||||
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
|
||||
CoreMap sentenceMap = sentences.get(0);
|
||||
List<CoreLabel> tokens = sentenceMap.get(CoreAnnotations.TokensAnnotation.class);
|
||||
List<CoreLabel> namedEntities = new ArrayList<>();
|
||||
public Set<String> getTags(String text) {
|
||||
Set<String> tags = new LinkedHashSet<>();
|
||||
|
||||
for (CoreLabel token : tokens) {
|
||||
String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
||||
if (!ne.equals("0")) {
|
||||
namedEntities.add(token);
|
||||
}
|
||||
}
|
||||
for (CoreLabel namedEntity : namedEntities) {
|
||||
String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
||||
String word = namedEntity.word();
|
||||
if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
|
||||
topics.add(word);
|
||||
} else if (ne.equals("LOCATION")) {
|
||||
String[] posTags = word.split("_");
|
||||
for (String posTag : posTags) {
|
||||
if (posTag.startsWith("N")) {
|
||||
topics.add(word);
|
||||
break;
|
||||
}
|
||||
List<CoreMap> sentences = getSentences(text);
|
||||
for (CoreMap sentence : sentences) {
|
||||
List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
|
||||
List<CoreLabel> namedEntities = new ArrayList<>();
|
||||
|
||||
for (CoreLabel token : tokens) {
|
||||
String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
||||
if (!ne.equals("0")) {
|
||||
namedEntities.add(token);
|
||||
}
|
||||
} else {
|
||||
String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
|
||||
if (pos.startsWith("NN")) {
|
||||
topics.add(word);
|
||||
for (String keyword : keywords.keySet()) {
|
||||
if (keywords.get(keyword).contains(word.toLowerCase())) {
|
||||
topics.add(keyword);
|
||||
if (token.equals(tokens.get(tokens.size() - 1)) && token.word().equals("?") ||
|
||||
(List.of("what", "when", "who", "where", "why", "how").contains(token.word()))) {
|
||||
tags.add("question");
|
||||
}
|
||||
}
|
||||
|
||||
for (CoreLabel namedEntity : namedEntities) {
|
||||
String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
||||
String word = namedEntity.word();
|
||||
if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
|
||||
tags.add(word);
|
||||
} else if (ne.equals("LOCATION")) {
|
||||
String[] posTags = word.split("_");
|
||||
for (String posTag : posTags) {
|
||||
if (posTag.startsWith("N")) {
|
||||
tags.add(word);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid").contains(word.toLowerCase())) {
|
||||
topics.add("dnd");
|
||||
} else {
|
||||
String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
|
||||
if (pos.startsWith("NN")) {
|
||||
tags.add(word);
|
||||
for (String keyword : keywords.keySet()) {
|
||||
if (keywords.get(keyword).contains(word.toLowerCase())) {
|
||||
tags.add(keyword);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return topics;
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<String> lemmatize(String documentText) {
|
||||
@@ -83,13 +91,31 @@ public class NLP {
|
||||
return lemmas;
|
||||
}
|
||||
|
||||
// TODO finish this method
|
||||
public Set<String> getSynonyms(String targetWord) {
|
||||
return Collections.emptySet();
|
||||
/**
|
||||
* Determines the sentiment (tone) of each sentence in the text. For example: positive, negative, or neutral
|
||||
* @param text the input text
|
||||
* @return the list of sentiments
|
||||
*/
|
||||
public List<String> sentimentAnalysis(String text) {
|
||||
Annotation document = new Annotation(text);
|
||||
pipeline.annotate(document);
|
||||
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
|
||||
List<String> sentiments = new ArrayList<>();
|
||||
for (CoreMap sentence : sentences) {
|
||||
sentiments.add(sentence.get(SentimentCoreAnnotations.SentimentClass.class));
|
||||
}
|
||||
return sentiments;
|
||||
}
|
||||
|
||||
// TODO finish this method
|
||||
public static double calculateSimilarity(String sentence1, String sentence2) {
|
||||
return 0.0;
|
||||
public static double cosineSimilarity(double[] vector1, double[] vector2) {
|
||||
double dotProduct = 0.0;
|
||||
double norm1 = 0.0;
|
||||
double norm2 = 0.0;
|
||||
for (int i = 0; i < vector1.length; i++) {
|
||||
dotProduct += vector1[i] * vector2[i];
|
||||
norm1 += Math.pow(vector1[i], 2);
|
||||
norm2 += Math.pow(vector2[i], 2);
|
||||
}
|
||||
return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import com.theokanning.openai.completion.CompletionResult;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||
import com.theokanning.openai.embedding.Embedding;
|
||||
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||
import com.theokanning.openai.image.CreateImageRequest;
|
||||
import com.theokanning.openai.image.ImageResult;
|
||||
@@ -88,34 +87,39 @@ public class OpenAIManager {
|
||||
userSettings = settings.getGuildSettings().get(guildId).getUserSettings().get(authorId);
|
||||
Model model = userSettings.getModel();
|
||||
ChatMessage chatMessage = createUserMessage(message);
|
||||
List<Embedding> embeddings = new ArrayList<>();
|
||||
|
||||
|
||||
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, authorId, event.getMessageId(), message);
|
||||
// Send OpenAI Message and get response
|
||||
switch (model) {
|
||||
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
||||
CompletionRequest completionRequest = createCompletionRequest(message, event);
|
||||
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
||||
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
||||
}
|
||||
case GPT_4, GPT_4_32K, GPT_35_TURBO -> {
|
||||
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
|
||||
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
||||
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
||||
|
||||
//TODO fix embeddings
|
||||
// EmbeddingRequest embeddingRequest = createEmbeddingRequest(chatMessages, event);
|
||||
// EmbeddingResult embeddingResult = openAiService.createEmbeddings(embeddingRequest);
|
||||
// embeddings.addAll(embeddingResult.getData());
|
||||
}
|
||||
default -> {
|
||||
event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
|
||||
LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
|
||||
model, Arrays.toString(Model.values()));
|
||||
return;
|
||||
String query = new QueryBuilder("messages")
|
||||
.where("guild_id = ? AND channel_id = ? AND prompt = ?")
|
||||
.orderBy("timestamp DESC")
|
||||
.build();
|
||||
List<MessageData> messages = DatabaseManager.parseResponse(DatabaseManager.query(query,
|
||||
guildId, event.getChannel().getIdLong(), chatMessage.getContent()));
|
||||
if (!messages.isEmpty()) {
|
||||
stringBuilder.append(messages.get(0).getResponse());
|
||||
} else {
|
||||
// Send OpenAI Message and get response
|
||||
switch (model) {
|
||||
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
||||
CompletionRequest completionRequest = createCompletionRequest(message, event);
|
||||
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
||||
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
||||
}
|
||||
case GPT_4, GPT_4_32K, GPT_35_TURBO -> {
|
||||
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
|
||||
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
||||
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
||||
}
|
||||
default -> {
|
||||
event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
|
||||
LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
|
||||
model, Arrays.toString(Model.values()));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
|
||||
handleResponse(chatMessage, event, stringBuilder.toString());
|
||||
} catch (Exception ex) {
|
||||
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
|
||||
event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue();
|
||||
@@ -123,7 +127,6 @@ public class OpenAIManager {
|
||||
}
|
||||
|
||||
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||
return EmbeddingRequest.builder()
|
||||
.model(Model.ADA_EMBEDDING_2.getName())
|
||||
.user(event.getAuthor().getId())
|
||||
@@ -144,66 +147,48 @@ public class OpenAIManager {
|
||||
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
|
||||
.getUserSettings().get(event.getAuthor().getIdLong());
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
// List<MessageData> previousMessages = getPreviousMessages(event);
|
||||
// for (MessageData previousMessage : previousMessages) {
|
||||
// chatMessages.add(createSystemMessage("In a previous conversation, we discussed the topics " + previousMessage.getTags() +
|
||||
// " and " + previousMessage.getResponseTags()));
|
||||
// chatMessages.add(createSystemMessage("I sent you the prompt \"" + previousMessage.getPrompt() +
|
||||
// "\" and you replied with \"" + previousMessage.getResponse() + "\""));
|
||||
// }
|
||||
chatMessages.add(chatMessage);
|
||||
|
||||
// Handle System Messages
|
||||
chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
|
||||
chatMessages.add(createSystemMessage("My name is " + event.getAuthor().getName()));
|
||||
if (event.isFromThread()) {
|
||||
String query = new QueryBuilder("messages")
|
||||
.where("guild_id = ? AND thread_id = ?")
|
||||
.orderBy("timestamp DESC")
|
||||
.limit(10)
|
||||
.build();
|
||||
|
||||
// Build MessageData objects from query results
|
||||
List<Map<String, Object>> results = DatabaseManager.query(
|
||||
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
|
||||
List<MessageData> previousMessages = new ArrayList<>();
|
||||
for (Map<String, Object> result : results) {
|
||||
Object[] resultTopicObjects = (Object[]) ((PgArray) result.get("topics")).getArray();
|
||||
Set<String> resultTopics = new HashSet<>();
|
||||
for (Object object : resultTopicObjects) {
|
||||
resultTopics.add((String) object);
|
||||
}
|
||||
MessageData messageData = new MessageData(
|
||||
(long) result.get("guild_id"),
|
||||
(long) result.get("thread_id"),
|
||||
(long) result.get("user_id"),
|
||||
(String) result.get("message_type"),
|
||||
(String) result.get("message_text"),
|
||||
(String) result.get("message_response"),
|
||||
resultTopics,
|
||||
(Timestamp) result.get("timestamp")
|
||||
);
|
||||
previousMessages.add(messageData);
|
||||
}
|
||||
|
||||
Set<String> potentialTopics = new HashSet<>();
|
||||
for (MessageData previousMessage : previousMessages) {
|
||||
ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " +
|
||||
previousMessage.getTimestamp() + " which said \"" + previousMessage.getMessageText() +
|
||||
"\". You replied with \"" + previousMessage.getMessageResponse() + "\".");
|
||||
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageText()));
|
||||
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageResponse()));
|
||||
// chatMessages.add(previousChatMessage);
|
||||
}
|
||||
LOGGER.trace("Potential topics: {}", potentialTopics);
|
||||
// chatMessages.add(createSystemMessage("As an AI language model, only give replies that relate to " + topics));
|
||||
}
|
||||
|
||||
return ChatCompletionRequest.builder()
|
||||
.model(userSettings.getModel().getName())
|
||||
.maxTokens(userSettings.getMaxTokens())
|
||||
.user(event.getAuthor().getId())
|
||||
.temperature(settings.getOpenAISettings().getTemperature())
|
||||
.topP(settings.getOpenAISettings().getTopP())
|
||||
// .topP(settings.getOpenAISettings().getTopP())
|
||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||
.messages(chatMessages)
|
||||
.build();
|
||||
}
|
||||
|
||||
private List<MessageData> getPreviousMessages(MessageReceivedEvent event) throws SQLException {
|
||||
List<MessageData> previousMessages = new ArrayList<>();
|
||||
if (event.isFromThread()) {
|
||||
String query = new QueryBuilder("messages")
|
||||
.where("guild_id = ? AND channel_id = ?")
|
||||
.orderBy("timestamp DESC")
|
||||
.limit(3)
|
||||
.build();
|
||||
|
||||
// Build MessageData objects from query results
|
||||
List<Map<String, Object>> results = DatabaseManager.query(
|
||||
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
|
||||
previousMessages.addAll(DatabaseManager.parseResponse(results));
|
||||
}
|
||||
return previousMessages;
|
||||
}
|
||||
|
||||
private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) {
|
||||
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
|
||||
.getUserSettings().get(event.getAuthor().getIdLong());
|
||||
@@ -231,39 +216,40 @@ public class OpenAIManager {
|
||||
return new ChatMessage(Role.USER.getName(), message);
|
||||
}
|
||||
|
||||
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
|
||||
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response) {
|
||||
LOGGER.trace("Message Response <{}>: {}", event.getMessageId(), response);
|
||||
Set<String> topics = new LinkedHashSet<>();
|
||||
topics.addAll(NLP.getTopicKeywords(chatMessage.getContent()));
|
||||
topics.addAll(NLP.getTopicKeywords(response));
|
||||
LOGGER.trace("Topics: {}", topics);
|
||||
Set<String> tags = NLP.getTags(chatMessage.getContent());
|
||||
Set<String> responseTags = NLP.getTags(response);
|
||||
|
||||
MessageData.MessageDataBuilder builder = new MessageData.MessageDataBuilder()
|
||||
.guildId(event.getGuild().getIdLong())
|
||||
.channelId(event.getChannel().getIdLong())
|
||||
.userId(event.getAuthor().getIdLong())
|
||||
.type(chatMessage.getRole())
|
||||
.prompt(chatMessage.getContent())
|
||||
.response(response)
|
||||
.tags(tags)
|
||||
.responseTags(responseTags);
|
||||
|
||||
LOGGER.trace("Tags: {}", tags);
|
||||
|
||||
if (event.isFromThread()) {
|
||||
DatabaseManager.storeMessage(builder.build());
|
||||
ThreadChannel channel = event.getChannel().asThreadChannel();
|
||||
storeMessage(chatMessage, event, response, embeddings, topics, channel);
|
||||
channel.sendMessage(response).queue();
|
||||
} else {
|
||||
// The max discord title length is 100 characters
|
||||
String threadTitle = chatMessage.getContent();
|
||||
if (chatMessage.getContent().length() > 100) {
|
||||
threadTitle = chatMessage.getContent().substring(0, 100);
|
||||
}
|
||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> storeMessage(chatMessage, event, response, embeddings, topics, channel));
|
||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
|
||||
DatabaseManager.storeMessage(builder.channelId(channel.getIdLong()).build());
|
||||
channel.sendMessage(response).queue();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private void storeMessage(ChatMessage chatMessage, MessageReceivedEvent event, String response,
|
||||
List<Embedding> embeddings, Set<String> topics, ThreadChannel channel) {
|
||||
MessageData messageData = new MessageData(event.getGuild().getIdLong(),
|
||||
channel.getIdLong(), event.getAuthor().getIdLong(), chatMessage.getRole(), chatMessage.getContent(),
|
||||
response, topics);
|
||||
int messageRow = DatabaseManager.storeMessage(messageData);
|
||||
embeddings.forEach(embedding -> {
|
||||
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
||||
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
||||
});
|
||||
channel.sendMessage(response).queue();
|
||||
}
|
||||
|
||||
private String parseMessage(String input) {
|
||||
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.bensherriff.siren.database;
|
||||
import com.bensherriff.siren.settings.LocalTrack;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.postgresql.jdbc.PgArray;
|
||||
|
||||
import java.sql.*;
|
||||
import java.sql.Date;
|
||||
@@ -17,12 +18,13 @@ public class DatabaseManager {
|
||||
entry("messages", List.of(
|
||||
"id SERIAL PRIMARY KEY",
|
||||
"guild_id BIGINT NOT NULL",
|
||||
"thread_id BIGINT NOT NULL",
|
||||
"channel_id BIGINT NOT NULL",
|
||||
"user_id BIGINT NOT NULL",
|
||||
"message_type VARCHAR(20) NOT NULL",
|
||||
"message_text TEXT NOT NULL",
|
||||
"message_response TEXT",
|
||||
"topics TEXT[]",
|
||||
"type VARCHAR(20) NOT NULL",
|
||||
"prompt TEXT NOT NULL",
|
||||
"response TEXT",
|
||||
"tags TEXT[]",
|
||||
"response_tags TEXT[]",
|
||||
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()")),
|
||||
entry("embeddings", List.of(
|
||||
"id SERIAL PRIMARY KEY",
|
||||
@@ -81,22 +83,24 @@ public class DatabaseManager {
|
||||
|
||||
public static int storeMessage(MessageData messageData) {
|
||||
String INSERT_MESSAGE = "INSERT INTO messages (" +
|
||||
"message_type, " +
|
||||
"guild_id, " +
|
||||
"thread_id, " +
|
||||
"channel_id, " +
|
||||
"user_id, " +
|
||||
"message_text, " +
|
||||
"message_response, " +
|
||||
"topics) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?)";
|
||||
"type, " +
|
||||
"prompt, " +
|
||||
"response, " +
|
||||
"tags, " +
|
||||
"response_tags) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
|
||||
return storeMessage("messages", INSERT_MESSAGE,
|
||||
messageData.getMessageType(),
|
||||
messageData.getGuildId(),
|
||||
messageData.getThreadId(),
|
||||
messageData.getChannelId(),
|
||||
messageData.getUserId(),
|
||||
messageData.getMessageText(),
|
||||
messageData.getMessageResponse(),
|
||||
messageData.getTopics()
|
||||
messageData.getType(),
|
||||
messageData.getPrompt(),
|
||||
messageData.getResponse(),
|
||||
messageData.getTags(),
|
||||
messageData.getResponseTags()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -164,11 +168,11 @@ public class DatabaseManager {
|
||||
preparedStatement.setBlob(i, (Blob) param);
|
||||
} else if (param instanceof Clob) {
|
||||
preparedStatement.setClob(i, (Clob) param);
|
||||
} else if (param instanceof List && ((List<?>) param).get(0) instanceof String) {
|
||||
} else if (param instanceof List && !((List<?>) param).isEmpty() && ((List<?>) param).get(0) instanceof String) {
|
||||
preparedStatement.setArray(i, connection.createArrayOf("text", ((List<String>) param).toArray()));
|
||||
} else if (param instanceof List && ((List<?>) param).get(0) instanceof Double) {
|
||||
} else if (param instanceof List && !((List<?>) param).isEmpty() && ((List<?>) param).get(0) instanceof Double) {
|
||||
preparedStatement.setArray(i, connection.createArrayOf("float8", ((List<Double>) param).toArray(new Double[0])));
|
||||
} else if (param instanceof Set && ((Set<?>) param).toArray()[0] instanceof String) {
|
||||
} else if (param instanceof Set && !((Set<?>) param).isEmpty() && ((Set<?>) param).toArray()[0] instanceof String) {
|
||||
preparedStatement.setArray(i, connection.createArrayOf("text", ((Set<String>) param).toArray()));
|
||||
} else if (param instanceof String[]) {
|
||||
preparedStatement.setArray(i, connection.createArrayOf("text", ((String[]) param)));
|
||||
@@ -249,4 +253,30 @@ public class DatabaseManager {
|
||||
LOGGER.error("Failed to clear table; {}", ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public static List<MessageData> parseResponse(List<Map<String, Object>> results) throws SQLException {
|
||||
List<MessageData> messageData = new ArrayList<>();
|
||||
for (Map<String, Object> result : results) {
|
||||
Set<String> promptTags = new HashSet<>();
|
||||
for (Object object : (Object[]) ((PgArray) result.get("tags")).getArray()) {
|
||||
promptTags.add((String) object);
|
||||
}
|
||||
Set<String> responseTags = new HashSet<>();
|
||||
for (Object object : (Object[]) ((PgArray) result.get("response_tags")).getArray()) {
|
||||
promptTags.add((String) object);
|
||||
}
|
||||
messageData.add(new MessageData.MessageDataBuilder()
|
||||
.guildId((long) result.get("guild_id"))
|
||||
.channelId((long) result.get("channel_id"))
|
||||
.userId((long) result.get("user_id"))
|
||||
.type((String) result.get("type"))
|
||||
.prompt((String) result.get("prompt"))
|
||||
.response((String) result.get("response"))
|
||||
.tags(promptTags)
|
||||
.responseTags(responseTags)
|
||||
.timestamp((Timestamp) result.get("timestamp"))
|
||||
.build());
|
||||
}
|
||||
return messageData;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,74 +1,126 @@
|
||||
package com.bensherriff.siren.database;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Set;
|
||||
|
||||
public class MessageData {
|
||||
private final Long guildId;
|
||||
private final Long threadId;
|
||||
private final Long channelId;
|
||||
private final Long userId;
|
||||
private final String messageType;
|
||||
private final String messageText;
|
||||
private final String messageResponse;
|
||||
private final Set<String> topics;
|
||||
private Timestamp timestamp;
|
||||
private final String type;
|
||||
private final String prompt;
|
||||
private final String response;
|
||||
private final Set<String> tags;
|
||||
private final Set<String> responseTags;
|
||||
private final Timestamp timestamp;
|
||||
|
||||
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
|
||||
String messageResponse, Set<String> topics) {
|
||||
this.guildId = guildId;
|
||||
this.threadId = threadId;
|
||||
this.userId = userId;
|
||||
this.messageType = messageType;
|
||||
this.messageText = messageText;
|
||||
this.messageResponse = messageResponse;
|
||||
this.topics = topics;
|
||||
}
|
||||
|
||||
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
|
||||
String messageResponse, Set<String> topics, Timestamp timestamp) {
|
||||
this.guildId = guildId;
|
||||
this.threadId = threadId;
|
||||
this.userId = userId;
|
||||
this.messageType = messageType;
|
||||
this.messageText = messageText;
|
||||
this.messageResponse = messageResponse;
|
||||
this.topics = topics;
|
||||
this.timestamp = timestamp;
|
||||
private MessageData(MessageDataBuilder builder) {
|
||||
this.guildId = builder.guildId;
|
||||
this.channelId = builder.channelId;
|
||||
this.userId = builder.userId;
|
||||
this.type = builder.type;
|
||||
this.prompt = builder.prompt;
|
||||
this.response = builder.response;
|
||||
this.tags = builder.tags;
|
||||
this.responseTags = builder.responseTags;
|
||||
this.timestamp = builder.timestamp;
|
||||
}
|
||||
|
||||
public Long getGuildId() {
|
||||
return guildId;
|
||||
}
|
||||
|
||||
public Long getThreadId() {
|
||||
return threadId;
|
||||
public Long getChannelId() {
|
||||
return channelId;
|
||||
}
|
||||
|
||||
public Long getUserId() {
|
||||
return userId;
|
||||
}
|
||||
|
||||
public String getMessageType() {
|
||||
return messageType;
|
||||
public String getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
public String getMessageText() {
|
||||
return messageText;
|
||||
public String getPrompt() {
|
||||
return prompt;
|
||||
}
|
||||
|
||||
public String getMessageResponse() {
|
||||
return messageResponse;
|
||||
public String getResponse() {
|
||||
return response;
|
||||
}
|
||||
|
||||
public Set<String> getTopics() {
|
||||
return topics;
|
||||
public Set<String> getTags() {
|
||||
return tags;
|
||||
}
|
||||
|
||||
public Set<String> getResponseTags() {
|
||||
return responseTags;
|
||||
}
|
||||
|
||||
public Timestamp getTimestamp() {
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
public void setTimestamp(Timestamp timestamp) {
|
||||
this.timestamp = timestamp;
|
||||
public static class MessageDataBuilder {
|
||||
private Long guildId;
|
||||
private Long channelId;
|
||||
private Long userId;
|
||||
private String type = "";
|
||||
private String prompt = "";
|
||||
private String response = "";
|
||||
private Set<String> tags = new LinkedHashSet<>();
|
||||
private Set<String> responseTags = new LinkedHashSet<>();
|
||||
private Timestamp timestamp = new Timestamp(0);
|
||||
|
||||
public MessageDataBuilder guildId(Long guildId) {
|
||||
this.guildId = guildId;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MessageDataBuilder channelId(Long channelId) {
|
||||
this.channelId = channelId;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MessageDataBuilder userId(Long userId) {
|
||||
this.userId = userId;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MessageDataBuilder type(String type) {
|
||||
this.type = type;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MessageDataBuilder prompt(String prompt) {
|
||||
this.prompt = prompt;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MessageDataBuilder response(String response) {
|
||||
this.response = response;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MessageDataBuilder tags(Set<String> tags) {
|
||||
this.tags = tags;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MessageDataBuilder responseTags(Set<String> responseTags) {
|
||||
this.responseTags = responseTags;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MessageDataBuilder timestamp(Timestamp timestamp) {
|
||||
this.timestamp = timestamp;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MessageData build() {
|
||||
return new MessageData(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user