diff --git a/.version b/.version
index 06172af..00bd0d1 100644
--- a/.version
+++ b/.version
@@ -1 +1 @@
-export SIREN_VERSION=0.1.22
\ No newline at end of file
+export SIREN_VERSION=0.1.23
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 0a366ab..f0f29ed 100644
--- a/pom.xml
+++ b/pom.xml
@@ -67,7 +67,7 @@
2.14.2
0.12.0
42.6.0
- 4.5.3
+ 4.2.0
2.0.7
2.20.0
UTF-8
diff --git a/src/main/java/com/bensherriff/siren/ai/NLP.java b/src/main/java/com/bensherriff/siren/ai/NLP.java
index 0572bcb..b2e3a26 100644
--- a/src/main/java/com/bensherriff/siren/ai/NLP.java
+++ b/src/main/java/com/bensherriff/siren/ai/NLP.java
@@ -4,6 +4,7 @@ import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
+import edu.stanford.nlp.sentiment.SentimentCoreAnnotations;
import edu.stanford.nlp.util.CoreMap;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@@ -17,57 +18,64 @@ public class NLP {
public NLP() {
Properties props = new Properties();
- props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner");
+ props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner, parse, dcoref, sentiment");
pipeline = new StanfordCoreNLP(props);
keywords = new HashMap<>();
keywords.put("dnd", Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid"));
}
- public Set getTopicKeywords(String sentence) {
- Set topics = new LinkedHashSet<>();
- Annotation document = new Annotation(sentence);
- pipeline.annotate(document);
+ private List getSentences(String text) {
+ Annotation annotation = new Annotation(text);
+ pipeline.annotate(annotation);
+ return annotation.get(CoreAnnotations.SentencesAnnotation.class);
+ }
- List sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
- CoreMap sentenceMap = sentences.get(0);
- List tokens = sentenceMap.get(CoreAnnotations.TokensAnnotation.class);
- List namedEntities = new ArrayList<>();
+ public Set getTags(String text) {
+ Set tags = new LinkedHashSet<>();
- for (CoreLabel token : tokens) {
- String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
- if (!ne.equals("0")) {
- namedEntities.add(token);
- }
- }
- for (CoreLabel namedEntity : namedEntities) {
- String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
- String word = namedEntity.word();
- if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
- topics.add(word);
- } else if (ne.equals("LOCATION")) {
- String[] posTags = word.split("_");
- for (String posTag : posTags) {
- if (posTag.startsWith("N")) {
- topics.add(word);
- break;
- }
+ List sentences = getSentences(text);
+ for (CoreMap sentence : sentences) {
+ List tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
+ List namedEntities = new ArrayList<>();
+
+ for (CoreLabel token : tokens) {
+ String ne = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
+ if (!ne.equals("0")) {
+ namedEntities.add(token);
}
- } else {
- String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
- if (pos.startsWith("NN")) {
- topics.add(word);
- for (String keyword : keywords.keySet()) {
- if (keywords.get(keyword).contains(word.toLowerCase())) {
- topics.add(keyword);
+ if (token.equals(tokens.get(tokens.size() - 1)) && token.word().equals("?") ||
+ (List.of("what", "when", "who", "where", "why", "how").contains(token.word()))) {
+ tags.add("question");
+ }
+ }
+
+ for (CoreLabel namedEntity : namedEntities) {
+ String ne = namedEntity.get(CoreAnnotations.NamedEntityTagAnnotation.class);
+ String word = namedEntity.word();
+ if (ne.equals("PERSON") || ne.equals("ORGANIZATION")) {
+ tags.add(word);
+ } else if (ne.equals("LOCATION")) {
+ String[] posTags = word.split("_");
+ for (String posTag : posTags) {
+ if (posTag.startsWith("N")) {
+ tags.add(word);
+ break;
}
}
- if (Arrays.asList("dnd", "dungeons", "dragons", "sorcerer", "warlock", "cleric", "fighter", "rogue", "bard", "wizard", "paladin", "ranger", "druid").contains(word.toLowerCase())) {
- topics.add("dnd");
+ } else {
+ String pos = namedEntity.get(CoreAnnotations.PartOfSpeechAnnotation.class);
+ if (pos.startsWith("NN")) {
+ tags.add(word);
+ for (String keyword : keywords.keySet()) {
+ if (keywords.get(keyword).contains(word.toLowerCase())) {
+ tags.add(keyword);
+ }
+ }
}
}
}
}
- return topics;
+ return tags;
}
public List lemmatize(String documentText) {
@@ -83,13 +91,31 @@ public class NLP {
return lemmas;
}
- // TODO finish this method
- public Set getSynonyms(String targetWord) {
- return Collections.emptySet();
+ /**
+ * Determines the sentiment (tone) of each sentence in the text. For example: positive, negative, or neutral
+ * @param text the input text
+ * @return the list of sentiments
+ */
+ public List sentimentAnalysis(String text) {
+ Annotation document = new Annotation(text);
+ pipeline.annotate(document);
+ List sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
+ List sentiments = new ArrayList<>();
+ for (CoreMap sentence : sentences) {
+ sentiments.add(sentence.get(SentimentCoreAnnotations.SentimentClass.class));
+ }
+ return sentiments;
}
- // TODO finish this method
- public static double calculateSimilarity(String sentence1, String sentence2) {
- return 0.0;
+ public static double cosineSimilarity(double[] vector1, double[] vector2) {
+ double dotProduct = 0.0;
+ double norm1 = 0.0;
+ double norm2 = 0.0;
+ for (int i = 0; i < vector1.length; i++) {
+ dotProduct += vector1[i] * vector2[i];
+ norm1 += Math.pow(vector1[i], 2);
+ norm2 += Math.pow(vector2[i], 2);
+ }
+ return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
}
}
diff --git a/src/main/java/com/bensherriff/siren/ai/OpenAIManager.java b/src/main/java/com/bensherriff/siren/ai/OpenAIManager.java
index f3165a3..724cc10 100644
--- a/src/main/java/com/bensherriff/siren/ai/OpenAIManager.java
+++ b/src/main/java/com/bensherriff/siren/ai/OpenAIManager.java
@@ -10,7 +10,6 @@ import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
-import com.theokanning.openai.embedding.Embedding;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.image.CreateImageRequest;
import com.theokanning.openai.image.ImageResult;
@@ -88,34 +87,39 @@ public class OpenAIManager {
userSettings = settings.getGuildSettings().get(guildId).getUserSettings().get(authorId);
Model model = userSettings.getModel();
ChatMessage chatMessage = createUserMessage(message);
- List embeddings = new ArrayList<>();
+
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, authorId, event.getMessageId(), message);
- // Send OpenAI Message and get response
- switch (model) {
- case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
- CompletionRequest completionRequest = createCompletionRequest(message, event);
- CompletionResult completionResult = openAiService.createCompletion(completionRequest);
- completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
- }
- case GPT_4, GPT_4_32K, GPT_35_TURBO -> {
- ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
- ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
- chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
-
- //TODO fix embeddings
-// EmbeddingRequest embeddingRequest = createEmbeddingRequest(chatMessages, event);
-// EmbeddingResult embeddingResult = openAiService.createEmbeddings(embeddingRequest);
-// embeddings.addAll(embeddingResult.getData());
- }
- default -> {
- event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
- LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
- model, Arrays.toString(Model.values()));
- return;
+ String query = new QueryBuilder("messages")
+ .where("guild_id = ? AND channel_id = ? AND prompt = ?")
+ .orderBy("timestamp DESC")
+ .build();
+ List messages = DatabaseManager.parseResponse(DatabaseManager.query(query,
+ guildId, event.getChannel().getIdLong(), chatMessage.getContent()));
+ if (!messages.isEmpty()) {
+ stringBuilder.append(messages.get(0).getResponse());
+ } else {
+ // Send OpenAI Message and get response
+ switch (model) {
+ case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
+ CompletionRequest completionRequest = createCompletionRequest(message, event);
+ CompletionResult completionResult = openAiService.createCompletion(completionRequest);
+ completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
+ }
+ case GPT_4, GPT_4_32K, GPT_35_TURBO -> {
+ ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
+ ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
+ chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
+ }
+ default -> {
+ event.getMessage().reply("Unexpected model in settings. Please contact " + owner + ".").queue();
+ LOGGER.warn("Unexpected model in settings for guild {}: {}. Expected one of {}", guildId,
+ model, Arrays.toString(Model.values()));
+ return;
+ }
}
}
- handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
+ handleResponse(chatMessage, event, stringBuilder.toString());
} catch (Exception ex) {
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
event.getMessage().reply("An error occurred while processing your message. Please contact " + owner + ".").queue();
@@ -123,7 +127,6 @@ public class OpenAIManager {
}
private EmbeddingRequest createEmbeddingRequest(List chatMessages, MessageReceivedEvent event) {
- GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
return EmbeddingRequest.builder()
.model(Model.ADA_EMBEDDING_2.getName())
.user(event.getAuthor().getId())
@@ -144,66 +147,48 @@ public class OpenAIManager {
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
.getUserSettings().get(event.getAuthor().getIdLong());
List chatMessages = new ArrayList<>();
+// List previousMessages = getPreviousMessages(event);
+// for (MessageData previousMessage : previousMessages) {
+// chatMessages.add(createSystemMessage("In a previous conversation, we discussed the topics " + previousMessage.getTags() +
+// " and " + previousMessage.getResponseTags()));
+// chatMessages.add(createSystemMessage("I sent you the prompt \"" + previousMessage.getPrompt() +
+// "\" and you replied with \"" + previousMessage.getResponse() + "\""));
+// }
chatMessages.add(chatMessage);
// Handle System Messages
chatMessages.add(createSystemMessage("You are a discord bot named Siren"));
chatMessages.add(createSystemMessage("My name is " + event.getAuthor().getName()));
- if (event.isFromThread()) {
- String query = new QueryBuilder("messages")
- .where("guild_id = ? AND thread_id = ?")
- .orderBy("timestamp DESC")
- .limit(10)
- .build();
-
- // Build MessageData objects from query results
- List