v0.1.23 Tweaking OpenAI settings, database tables
This commit is contained in:
2
pom.xml
2
pom.xml
@@ -67,7 +67,7 @@
|
|||||||
<jackson.version>2.14.2</jackson.version>
|
<jackson.version>2.14.2</jackson.version>
|
||||||
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
|
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
|
||||||
<postgresql.version>42.6.0</postgresql.version>
|
<postgresql.version>42.6.0</postgresql.version>
|
||||||
<corenlp.version>4.5.3</corenlp.version>
|
<corenlp.version>4.2.0</corenlp.version>
|
||||||
<slf4j.version>2.0.7</slf4j.version>
|
<slf4j.version>2.0.7</slf4j.version>
|
||||||
<log4j.version>2.20.0</log4j.version>
|
<log4j.version>2.20.0</log4j.version>
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import edu.stanford.nlp.ling.CoreAnnotations;
|
|||||||
import edu.stanford.nlp.ling.CoreLabel;
|
import edu.stanford.nlp.ling.CoreLabel;
|
||||||
import edu.stanford.nlp.pipeline.Annotation;
|
import edu.stanford.nlp.pipeline.Annotation;
|
||||||
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
|
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
|
||||||
|
import edu.stanford.nlp.sentiment.SentimentCoreAnnotations;
|
||||||
import edu.stanford.nlp.util.CoreMap;
|
import edu.stanford.nlp.util.CoreMap;
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
@@ -17,57 +18,64 @@ public class NLP {
|
|||||||
|
|
||||||
public NLP() {
|
public NLP() {
|
||||||
Properties props = new Properties();
|
Properties props = new Properties();
|
||||||
props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner");
|
props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner, parse, dcoref, sentiment");
|
||||||
pipeline = new StanfordCoreNLP(props);
|
pipeline = new StanfordCoreNLP(props);
|
||||||
keywords = new HashMap<>();
|
keywords = new HashMap<>();
|
||||||
keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid"));
|
keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public Set<String> getTopicKeywords(String sentence) {
|
private List<CoreMap> getSentences(String text) {
|
||||||
Set<String> topics = new LinkedHashSet<>();
|
Annotation annotation = new Annotation(text);
|
||||||
Annotation document = new Annotation(sentence);
|
pipeline.annotate(annotation);
|
||||||
pipeline.annotate(document);
|
return annotation.get(CoreAnnotations.SentencesAnnotation.class);
|
||||||
|
}
|
||||||
|
|
||||||
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
|
public Set<String> getTags(String text) {
|
||||||
CoreMap sentenceMap = sentences.get(0);
|
Set<String> tags = new LinkedHashSet<>();
|
||||||
List<CoreLabel> tokens = sentenceMap.get(CoreAnnotations.TokensAnnotation.class);
|
|
||||||
List<CoreLabel> namedEntities = new ArrayList<>();
|
|
||||||
|
|
||||||
for (CoreLabel token : tokens) {
|
List<CoreMap> sentences = getSentences(text);
|
||||||
String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
for (CoreMap sentence : sentences) {
|
||||||
if (!ne.equals("0")) {
|
List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
|
||||||
namedEntities.add(token);
|
List<CoreLabel> namedEntities = new ArrayList<>();
|
||||||
}
|
|
||||||
}
|
for (CoreLabel token : tokens) {
|
||||||
for (CoreLabel namedEntity : namedEntities) {
|
String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
||||||
String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
if (!ne.equals("0")) {
|
||||||
String word = namedEntity.word();
|
namedEntities.add(token);
|
||||||
if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
|
|
||||||
topics.add(word);
|
|
||||||
} else if (ne.equals("LOCATION")) {
|
|
||||||
String[] posTags = word.split("_");
|
|
||||||
for (String posTag : posTags) {
|
|
||||||
if (posTag.startsWith("N")) {
|
|
||||||
topics.add(word);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
if (token.equals(tokens.get(tokens.size() - 1)) && token.word().equals("?") ||
|
||||||
String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
|
(List.of("what", "when", "who", "where", "why", "how").contains(token.word()))) {
|
||||||
if (pos.startsWith("NN")) {
|
tags.add("question");
|
||||||
topics.add(word);
|
}
|
||||||
for (String keyword : keywords.keySet()) {
|
}
|
||||||
if (keywords.get(keyword).contains(word.toLowerCase())) {
|
|
||||||
topics.add(keyword);
|
for (CoreLabel namedEntity : namedEntities) {
|
||||||
|
String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
|
||||||
|
String word = namedEntity.word();
|
||||||
|
if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
|
||||||
|
tags.add(word);
|
||||||
|
} else if (ne.equals("LOCATION")) {
|
||||||
|
String[] posTags = word.split("_");
|
||||||
|
for (String posTag : posTags) {
|
||||||
|
if (posTag.startsWith("N")) {
|
||||||
|
tags.add(word);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid").contains(word.toLowerCase())) {
|
} else {
|
||||||
topics.add("dnd");
|
String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
|
||||||
|
if (pos.startsWith("NN")) {
|
||||||
|
tags.add(word);
|
||||||
|
for (String keyword : keywords.keySet()) {
|
||||||
|
if (keywords.get(keyword).contains(word.toLowerCase())) {
|
||||||
|
tags.add(keyword);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return topics;
|
return tags;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<String> lemmatize(String documentText) {
|
public List<String> lemmatize(String documentText) {
|
||||||
@@ -83,13 +91,31 @@ public class NLP {
|
|||||||
return lemmas;
|
return lemmas;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO finish this method
|
/**
|
||||||
public Set<String> getSynonyms(String targetWord) {
|
* Determines the sentiment (tone) of each sentence in the text. For example: positive, negative, or neutral
|
||||||
return Collections.emptySet();
|
* @param text the input text
|
||||||
|
* @return the list of sentiments
|
||||||
|
*/
|
||||||
|
public List<String> sentimentAnalysis(String text) {
|
||||||
|
Annotation document = new Annotation(text);
|
||||||
|
pipeline.annotate(document);
|
||||||
|
List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
|
||||||
|
List<String> sentiments = new ArrayList<>();
|
||||||
|
for (CoreMap sentence : sentences) {
|
||||||
|
sentiments.add(sentence.get(SentimentCoreAnnotations.SentimentClass.class));
|
||||||
|
}
|
||||||
|
return sentiments;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO finish this method
|
public static double cosineSimilarity(double[] vector1, double[] vector2) {
|
||||||
public static double calculateSimilarity(String sentence1, String sentence2) {
|
double dotProduct = 0.0;
|
||||||
return 0.0;
|
double norm1 = 0.0;
|
||||||
|
double norm2 = 0.0;
|
||||||
|
for (int i = 0; i < vector1.length; i++) {
|
||||||
|
dotProduct += vector1[i] * vector2[i];
|
||||||
|
norm1 += Math.pow(vector1[i], 2);
|
||||||
|
norm2 += Math.pow(vector2[i], 2);
|
||||||
|
}
|
||||||
|
return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import com.theokanning.openai.completion.CompletionResult;
|
|||||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||||
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
||||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||||
import com.theokanning.openai.embedding.Embedding;
|
|
||||||
import com.theokanning.openai.embedding.EmbeddingRequest;
|
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||||
import com.theokanning.openai.image.CreateImageRequest;
|
import com.theokanning.openai.image.CreateImageRequest;
|
||||||
import com.theokanning.openai.image.ImageResult;
|
import com.theokanning.openai.image.ImageResult;
|
||||||
@@ -88,34 +87,39 @@ public class OpenAIManager {
|
|||||||
userSettings = settings.getGuildSettings().get(guildId).getUserSettings().get(authorId);
|
userSettings = settings.getGuildSettings().get(guildId).getUserSettings().get(authorId);
|
||||||
Model model = userSettings.getModel();
|
Model model = userSettings.getModel();
|
||||||
ChatMessage chatMessage = createUserMessage(message);
|
ChatMessage chatMessage = createUserMessage(message);
|
||||||
List<Embedding> embeddings = new ArrayList<>();
|
|
||||||
|
|
||||||
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, authorId, event.getMessageId(), message);
|
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, authorId, event.getMessageId(), message);
|
||||||
// Send OpenAI Message and get response
|
String query = new QueryBuilder("messages")
|
||||||
switch (model) {
|
.where("guild_id = ? AND channel_id = ? AND prompt = ?")
|
||||||
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
.orderBy("timestamp DESC")
|
||||||
CompletionRequest completionRequest = createCompletionRequest(message, event);
|
.build();
|
||||||
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
List<MessageData> messages = DatabaseManager.parseResponse(DatabaseManager.query(query,
|
||||||
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
guildId, event.getChannel().getIdLong(), chatMessage.getContent()));
|
||||||
}
|
if (!messages.isEmpty()) {
|
||||||
case GPT_4, GPT_4_32K, GPT_35_TURBO -> {
|
stringBuilder.append(messages.get(0).getResponse());
|
||||||
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
|
} else {
|
||||||
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
// Send OpenAI Message and get response
|
||||||
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
switch (model) {
|
||||||
|
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
||||||
//TODO fix embeddings
|
CompletionRequest completionRequest = createCompletionRequest(message, event);
|
||||||
// EmbeddingRequest embeddingRequest = createEmbeddingRequest(chatMessages, event);
|
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
||||||
// EmbeddingResult embeddingResult = openAiService.createEmbeddings(embeddingRequest);
|
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
||||||
// embeddings.addAll(embeddingResult.getData());
|
}
|
||||||
}
|
case GPT_4, GPT_4_32K, GPT_35_TURBO -> {
|
||||||
default -> {
|
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
|
||||||
event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
|
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
||||||
LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
|
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
||||||
model, Arrays.toString(Model.values()));
|
}
|
||||||
return;
|
default -> {
|
||||||
|
event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
|
||||||
|
LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
|
||||||
|
model, Arrays.toString(Model.values()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
|
handleResponse(chatMessage, event, stringBuilder.toString());
|
||||||
} catch (Exception ex) {
|
} catch (Exception ex) {
|
||||||
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
|
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
|
||||||
event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue();
|
event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue();
|
||||||
@@ -123,7 +127,6 @@ public class OpenAIManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
|
||||||
return EmbeddingRequest.builder()
|
return EmbeddingRequest.builder()
|
||||||
.model(Model.ADA_EMBEDDING_2.getName())
|
.model(Model.ADA_EMBEDDING_2.getName())
|
||||||
.user(event.getAuthor().getId())
|
.user(event.getAuthor().getId())
|
||||||
@@ -144,66 +147,48 @@ public class OpenAIManager {
|
|||||||
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
|
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
|
||||||
.getUserSettings().get(event.getAuthor().getIdLong());
|
.getUserSettings().get(event.getAuthor().getIdLong());
|
||||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||||
|
// List<MessageData> previousMessages = getPreviousMessages(event);
|
||||||
|
// for (MessageData previousMessage : previousMessages) {
|
||||||
|
// chatMessages.add(createSystemMessage("In a previous conversation, we discussed the topics " + previousMessage.getTags() +
|
||||||
|
// " and " + previousMessage.getResponseTags()));
|
||||||
|
// chatMessages.add(createSystemMessage("I sent you the prompt \"" + previousMessage.getPrompt() +
|
||||||
|
// "\" and you replied with \"" + previousMessage.getResponse() + "\""));
|
||||||
|
// }
|
||||||
chatMessages.add(chatMessage);
|
chatMessages.add(chatMessage);
|
||||||
|
|
||||||
// Handle System Messages
|
// Handle System Messages
|
||||||
chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
|
chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
|
||||||
chatMessages.add(createSystemMessage("My name is " + event.getAuthor().getName()));
|
chatMessages.add(createSystemMessage("My name is " + event.getAuthor().getName()));
|
||||||
if (event.isFromThread()) {
|
|
||||||
String query = new QueryBuilder("messages")
|
|
||||||
.where("guild_id = ? AND thread_id = ?")
|
|
||||||
.orderBy("timestamp DESC")
|
|
||||||
.limit(10)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
// Build MessageData objects from query results
|
|
||||||
List<Map<String, Object>> results = DatabaseManager.query(
|
|
||||||
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
|
|
||||||
List<MessageData> previousMessages = new ArrayList<>();
|
|
||||||
for (Map<String, Object> result : results) {
|
|
||||||
Object[] resultTopicObjects = (Object[]) ((PgArray) result.get("topics")).getArray();
|
|
||||||
Set<String> resultTopics = new HashSet<>();
|
|
||||||
for (Object object : resultTopicObjects) {
|
|
||||||
resultTopics.add((String) object);
|
|
||||||
}
|
|
||||||
MessageData messageData = new MessageData(
|
|
||||||
(long) result.get("guild_id"),
|
|
||||||
(long) result.get("thread_id"),
|
|
||||||
(long) result.get("user_id"),
|
|
||||||
(String) result.get("message_type"),
|
|
||||||
(String) result.get("message_text"),
|
|
||||||
(String) result.get("message_response"),
|
|
||||||
resultTopics,
|
|
||||||
(Timestamp) result.get("timestamp")
|
|
||||||
);
|
|
||||||
previousMessages.add(messageData);
|
|
||||||
}
|
|
||||||
|
|
||||||
Set<String> potentialTopics = new HashSet<>();
|
|
||||||
for (MessageData previousMessage : previousMessages) {
|
|
||||||
ChatMessage previousChatMessage = createSystemMessage("For context, I previously sent you a message at " +
|
|
||||||
previousMessage.getTimestamp() + " which said \"" + previousMessage.getMessageText() +
|
|
||||||
"\". You replied with \"" + previousMessage.getMessageResponse() + "\".");
|
|
||||||
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageText()));
|
|
||||||
potentialTopics.addAll(NLP.getTopicKeywords(previousMessage.getMessageResponse()));
|
|
||||||
// chatMessages.add(previousChatMessage);
|
|
||||||
}
|
|
||||||
LOGGER.trace("Potential topics: {}", potentialTopics);
|
|
||||||
// chatMessages.add(createSystemMessage("As an AI language model, only give replies that relate to " + topics));
|
|
||||||
}
|
|
||||||
|
|
||||||
return ChatCompletionRequest.builder()
|
return ChatCompletionRequest.builder()
|
||||||
.model(userSettings.getModel().getName())
|
.model(userSettings.getModel().getName())
|
||||||
.maxTokens(userSettings.getMaxTokens())
|
.maxTokens(userSettings.getMaxTokens())
|
||||||
.user(event.getAuthor().getId())
|
.user(event.getAuthor().getId())
|
||||||
.temperature(settings.getOpenAISettings().getTemperature())
|
.temperature(settings.getOpenAISettings().getTemperature())
|
||||||
.topP(settings.getOpenAISettings().getTopP())
|
// .topP(settings.getOpenAISettings().getTopP())
|
||||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||||
.messages(chatMessages)
|
.messages(chatMessages)
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<MessageData> getPreviousMessages(MessageReceivedEvent event) throws SQLException {
|
||||||
|
List<MessageData> previousMessages = new ArrayList<>();
|
||||||
|
if (event.isFromThread()) {
|
||||||
|
String query = new QueryBuilder("messages")
|
||||||
|
.where("guild_id = ? AND channel_id = ?")
|
||||||
|
.orderBy("timestamp DESC")
|
||||||
|
.limit(3)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Build MessageData objects from query results
|
||||||
|
List<Map<String, Object>> results = DatabaseManager.query(
|
||||||
|
query, event.getGuild().getIdLong(), event.getChannel().getIdLong());
|
||||||
|
previousMessages.addAll(DatabaseManager.parseResponse(results));
|
||||||
|
}
|
||||||
|
return previousMessages;
|
||||||
|
}
|
||||||
|
|
||||||
private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) {
|
private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) {
|
||||||
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
|
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
|
||||||
.getUserSettings().get(event.getAuthor().getIdLong());
|
.getUserSettings().get(event.getAuthor().getIdLong());
|
||||||
@@ -231,39 +216,40 @@ public class OpenAIManager {
|
|||||||
return new ChatMessage(Role.USER.getName(), message);
|
return new ChatMessage(Role.USER.getName(), message);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
|
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response) {
|
||||||
LOGGER.trace("Message Response <{}>: {}", event.getMessageId(), response);
|
LOGGER.trace("Message Response <{}>: {}", event.getMessageId(), response);
|
||||||
Set<String> topics = new LinkedHashSet<>();
|
Set<String> tags = NLP.getTags(chatMessage.getContent());
|
||||||
topics.addAll(NLP.getTopicKeywords(chatMessage.getContent()));
|
Set<String> responseTags = NLP.getTags(response);
|
||||||
topics.addAll(NLP.getTopicKeywords(response));
|
|
||||||
LOGGER.trace("Topics: {}", topics);
|
MessageData.MessageDataBuilder builder = new MessageData.MessageDataBuilder()
|
||||||
|
.guildId(event.getGuild().getIdLong())
|
||||||
|
.channelId(event.getChannel().getIdLong())
|
||||||
|
.userId(event.getAuthor().getIdLong())
|
||||||
|
.type(chatMessage.getRole())
|
||||||
|
.prompt(chatMessage.getContent())
|
||||||
|
.response(response)
|
||||||
|
.tags(tags)
|
||||||
|
.responseTags(responseTags);
|
||||||
|
|
||||||
|
LOGGER.trace("Tags: {}", tags);
|
||||||
|
|
||||||
if (event.isFromThread()) {
|
if (event.isFromThread()) {
|
||||||
|
DatabaseManager.storeMessage(builder.build());
|
||||||
ThreadChannel channel = event.getChannel().asThreadChannel();
|
ThreadChannel channel = event.getChannel().asThreadChannel();
|
||||||
storeMessage(chatMessage, event, response, embeddings, topics, channel);
|
channel.sendMessage(response).queue();
|
||||||
} else {
|
} else {
|
||||||
// The max discord title length is 100 characters
|
// The max discord title length is 100 characters
|
||||||
String threadTitle = chatMessage.getContent();
|
String threadTitle = chatMessage.getContent();
|
||||||
if (chatMessage.getContent().length() > 100) {
|
if (chatMessage.getContent().length() > 100) {
|
||||||
threadTitle = chatMessage.getContent().substring(0, 100);
|
threadTitle = chatMessage.getContent().substring(0, 100);
|
||||||
}
|
}
|
||||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> storeMessage(chatMessage, event, response, embeddings, topics, channel));
|
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
|
||||||
|
DatabaseManager.storeMessage(builder.channelId(channel.getIdLong()).build());
|
||||||
|
channel.sendMessage(response).queue();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void storeMessage(ChatMessage chatMessage, MessageReceivedEvent event, String response,
|
|
||||||
List<Embedding> embeddings, Set<String> topics, ThreadChannel channel) {
|
|
||||||
MessageData messageData = new MessageData(event.getGuild().getIdLong(),
|
|
||||||
channel.getIdLong(), event.getAuthor().getIdLong(), chatMessage.getRole(), chatMessage.getContent(),
|
|
||||||
response, topics);
|
|
||||||
int messageRow = DatabaseManager.storeMessage(messageData);
|
|
||||||
embeddings.forEach(embedding -> {
|
|
||||||
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
|
||||||
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
|
||||||
});
|
|
||||||
channel.sendMessage(response).queue();
|
|
||||||
}
|
|
||||||
|
|
||||||
private String parseMessage(String input) {
|
private String parseMessage(String input) {
|
||||||
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
|
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.bensherriff.siren.database;
|
|||||||
import com.bensherriff.siren.settings.LocalTrack;
|
import com.bensherriff.siren.settings.LocalTrack;
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
|
import org.postgresql.jdbc.PgArray;
|
||||||
|
|
||||||
import java.sql.*;
|
import java.sql.*;
|
||||||
import java.sql.Date;
|
import java.sql.Date;
|
||||||
@@ -17,12 +18,13 @@ public class DatabaseManager {
|
|||||||
entry("messages", List.of(
|
entry("messages", List.of(
|
||||||
"id SERIAL PRIMARY KEY",
|
"id SERIAL PRIMARY KEY",
|
||||||
"guild_id BIGINT NOT NULL",
|
"guild_id BIGINT NOT NULL",
|
||||||
"thread_id BIGINT NOT NULL",
|
"channel_id BIGINT NOT NULL",
|
||||||
"user_id BIGINT NOT NULL",
|
"user_id BIGINT NOT NULL",
|
||||||
"message_type VARCHAR(20) NOT NULL",
|
"type VARCHAR(20) NOT NULL",
|
||||||
"message_text TEXT NOT NULL",
|
"prompt TEXT NOT NULL",
|
||||||
"message_response TEXT",
|
"response TEXT",
|
||||||
"topics TEXT[]",
|
"tags TEXT[]",
|
||||||
|
"response_tags TEXT[]",
|
||||||
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()")),
|
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()")),
|
||||||
entry("embeddings", List.of(
|
entry("embeddings", List.of(
|
||||||
"id SERIAL PRIMARY KEY",
|
"id SERIAL PRIMARY KEY",
|
||||||
@@ -81,22 +83,24 @@ public class DatabaseManager {
|
|||||||
|
|
||||||
public static int storeMessage(MessageData messageData) {
|
public static int storeMessage(MessageData messageData) {
|
||||||
String INSERT_MESSAGE = "INSERT INTO messages (" +
|
String INSERT_MESSAGE = "INSERT INTO messages (" +
|
||||||
"message_type, " +
|
|
||||||
"guild_id, " +
|
"guild_id, " +
|
||||||
"thread_id, " +
|
"channel_id, " +
|
||||||
"user_id, " +
|
"user_id, " +
|
||||||
"message_text, " +
|
"type, " +
|
||||||
"message_response, " +
|
"prompt, " +
|
||||||
"topics) " +
|
"response, " +
|
||||||
"VALUES (?, ?, ?, ?, ?, ?, ?)";
|
"tags, " +
|
||||||
|
"response_tags) " +
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
|
||||||
return storeMessage("messages", INSERT_MESSAGE,
|
return storeMessage("messages", INSERT_MESSAGE,
|
||||||
messageData.getMessageType(),
|
|
||||||
messageData.getGuildId(),
|
messageData.getGuildId(),
|
||||||
messageData.getThreadId(),
|
messageData.getChannelId(),
|
||||||
messageData.getUserId(),
|
messageData.getUserId(),
|
||||||
messageData.getMessageText(),
|
messageData.getType(),
|
||||||
messageData.getMessageResponse(),
|
messageData.getPrompt(),
|
||||||
messageData.getTopics()
|
messageData.getResponse(),
|
||||||
|
messageData.getTags(),
|
||||||
|
messageData.getResponseTags()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,11 +168,11 @@ public class DatabaseManager {
|
|||||||
preparedStatement.setBlob(i, (Blob) param);
|
preparedStatement.setBlob(i, (Blob) param);
|
||||||
} else if (param instanceof Clob) {
|
} else if (param instanceof Clob) {
|
||||||
preparedStatement.setClob(i, (Clob) param);
|
preparedStatement.setClob(i, (Clob) param);
|
||||||
} else if (param instanceof List && ((List<?>) param).get(0) instanceof String) {
|
} else if (param instanceof List && !((List<?>) param).isEmpty() && ((List<?>) param).get(0) instanceof String) {
|
||||||
preparedStatement.setArray(i, connection.createArrayOf("text", ((List<String>) param).toArray()));
|
preparedStatement.setArray(i, connection.createArrayOf("text", ((List<String>) param).toArray()));
|
||||||
} else if (param instanceof List && ((List<?>) param).get(0) instanceof Double) {
|
} else if (param instanceof List && !((List<?>) param).isEmpty() && ((List<?>) param).get(0) instanceof Double) {
|
||||||
preparedStatement.setArray(i, connection.createArrayOf("float8", ((List<Double>) param).toArray(new Double[0])));
|
preparedStatement.setArray(i, connection.createArrayOf("float8", ((List<Double>) param).toArray(new Double[0])));
|
||||||
} else if (param instanceof Set && ((Set<?>) param).toArray()[0] instanceof String) {
|
} else if (param instanceof Set && !((Set<?>) param).isEmpty() && ((Set<?>) param).toArray()[0] instanceof String) {
|
||||||
preparedStatement.setArray(i, connection.createArrayOf("text", ((Set<String>) param).toArray()));
|
preparedStatement.setArray(i, connection.createArrayOf("text", ((Set<String>) param).toArray()));
|
||||||
} else if (param instanceof String[]) {
|
} else if (param instanceof String[]) {
|
||||||
preparedStatement.setArray(i, connection.createArrayOf("text", ((String[]) param)));
|
preparedStatement.setArray(i, connection.createArrayOf("text", ((String[]) param)));
|
||||||
@@ -249,4 +253,30 @@ public class DatabaseManager {
|
|||||||
LOGGER.error("Failed to clear table; {}", ex.getMessage());
|
LOGGER.error("Failed to clear table; {}", ex.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static List<MessageData> parseResponse(List<Map<String, Object>> results) throws SQLException {
|
||||||
|
List<MessageData> messageData = new ArrayList<>();
|
||||||
|
for (Map<String, Object> result : results) {
|
||||||
|
Set<String> promptTags = new HashSet<>();
|
||||||
|
for (Object object : (Object[]) ((PgArray) result.get("tags")).getArray()) {
|
||||||
|
promptTags.add((String) object);
|
||||||
|
}
|
||||||
|
Set<String> responseTags = new HashSet<>();
|
||||||
|
for (Object object : (Object[]) ((PgArray) result.get("response_tags")).getArray()) {
|
||||||
|
promptTags.add((String) object);
|
||||||
|
}
|
||||||
|
messageData.add(new MessageData.MessageDataBuilder()
|
||||||
|
.guildId((long) result.get("guild_id"))
|
||||||
|
.channelId((long) result.get("channel_id"))
|
||||||
|
.userId((long) result.get("user_id"))
|
||||||
|
.type((String) result.get("type"))
|
||||||
|
.prompt((String) result.get("prompt"))
|
||||||
|
.response((String) result.get("response"))
|
||||||
|
.tags(promptTags)
|
||||||
|
.responseTags(responseTags)
|
||||||
|
.timestamp((Timestamp) result.get("timestamp"))
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
return messageData;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,74 +1,126 @@
|
|||||||
package com.bensherriff.siren.database;
|
package com.bensherriff.siren.database;
|
||||||
|
|
||||||
import java.sql.Timestamp;
|
import java.sql.Timestamp;
|
||||||
|
import java.util.LinkedHashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
public class MessageData {
|
public class MessageData {
|
||||||
private final Long guildId;
|
private final Long guildId;
|
||||||
private final Long threadId;
|
private final Long channelId;
|
||||||
private final Long userId;
|
private final Long userId;
|
||||||
private final String messageType;
|
private final String type;
|
||||||
private final String messageText;
|
private final String prompt;
|
||||||
private final String messageResponse;
|
private final String response;
|
||||||
private final Set<String> topics;
|
private final Set<String> tags;
|
||||||
private Timestamp timestamp;
|
private final Set<String> responseTags;
|
||||||
|
private final Timestamp timestamp;
|
||||||
|
|
||||||
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
|
private MessageData(MessageDataBuilder builder) {
|
||||||
String messageResponse, Set<String> topics) {
|
this.guildId = builder.guildId;
|
||||||
this.guildId = guildId;
|
this.channelId = builder.channelId;
|
||||||
this.threadId = threadId;
|
this.userId = builder.userId;
|
||||||
this.userId = userId;
|
this.type = builder.type;
|
||||||
this.messageType = messageType;
|
this.prompt = builder.prompt;
|
||||||
this.messageText = messageText;
|
this.response = builder.response;
|
||||||
this.messageResponse = messageResponse;
|
this.tags = builder.tags;
|
||||||
this.topics = topics;
|
this.responseTags = builder.responseTags;
|
||||||
}
|
this.timestamp = builder.timestamp;
|
||||||
|
|
||||||
public MessageData(Long guildId, Long threadId, Long userId, String messageType, String messageText,
|
|
||||||
String messageResponse, Set<String> topics, Timestamp timestamp) {
|
|
||||||
this.guildId = guildId;
|
|
||||||
this.threadId = threadId;
|
|
||||||
this.userId = userId;
|
|
||||||
this.messageType = messageType;
|
|
||||||
this.messageText = messageText;
|
|
||||||
this.messageResponse = messageResponse;
|
|
||||||
this.topics = topics;
|
|
||||||
this.timestamp = timestamp;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long getGuildId() {
|
public Long getGuildId() {
|
||||||
return guildId;
|
return guildId;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long getThreadId() {
|
public Long getChannelId() {
|
||||||
return threadId;
|
return channelId;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long getUserId() {
|
public Long getUserId() {
|
||||||
return userId;
|
return userId;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getMessageType() {
|
public String getType() {
|
||||||
return messageType;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getMessageText() {
|
public String getPrompt() {
|
||||||
return messageText;
|
return prompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getMessageResponse() {
|
public String getResponse() {
|
||||||
return messageResponse;
|
return response;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Set<String> getTopics() {
|
public Set<String> getTags() {
|
||||||
return topics;
|
return tags;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<String> getResponseTags() {
|
||||||
|
return responseTags;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Timestamp getTimestamp() {
|
public Timestamp getTimestamp() {
|
||||||
return timestamp;
|
return timestamp;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setTimestamp(Timestamp timestamp) {
|
public static class MessageDataBuilder {
|
||||||
this.timestamp = timestamp;
|
private Long guildId;
|
||||||
|
private Long channelId;
|
||||||
|
private Long userId;
|
||||||
|
private String type = "";
|
||||||
|
private String prompt = "";
|
||||||
|
private String response = "";
|
||||||
|
private Set<String> tags = new LinkedHashSet<>();
|
||||||
|
private Set<String> responseTags = new LinkedHashSet<>();
|
||||||
|
private Timestamp timestamp = new Timestamp(0);
|
||||||
|
|
||||||
|
public MessageDataBuilder guildId(Long guildId) {
|
||||||
|
this.guildId = guildId;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageDataBuilder channelId(Long channelId) {
|
||||||
|
this.channelId = channelId;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageDataBuilder userId(Long userId) {
|
||||||
|
this.userId = userId;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageDataBuilder type(String type) {
|
||||||
|
this.type = type;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageDataBuilder prompt(String prompt) {
|
||||||
|
this.prompt = prompt;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageDataBuilder response(String response) {
|
||||||
|
this.response = response;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageDataBuilder tags(Set<String> tags) {
|
||||||
|
this.tags = tags;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageDataBuilder responseTags(Set<String> responseTags) {
|
||||||
|
this.responseTags = responseTags;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageDataBuilder timestamp(Timestamp timestamp) {
|
||||||
|
this.timestamp = timestamp;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageData build() {
|
||||||
|
return new MessageData(this);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user