v0.1.15 Refactoring and updating OpenAI usage again
This commit is contained in:
3
.env.TEMPLATE
Normal file
3
.env.TEMPLATE
Normal file
@@ -0,0 +1,3 @@
|
||||
export POSTGRES_USER=
|
||||
export POSTGRES_PASSWORD=
|
||||
export POSTGRES_DB=
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,5 +1,7 @@
|
||||
.idea/
|
||||
**/target/
|
||||
**/data/
|
||||
**/app/
|
||||
**/settings.json
|
||||
**/logs/
|
||||
**/logs/
|
||||
.env
|
||||
3
Makefile
3
Makefile
@@ -1,8 +1,9 @@
|
||||
SHELL := /bin/bash
|
||||
include .version
|
||||
include .env
|
||||
|
||||
build:
|
||||
docker rmi siren && docker-compose build
|
||||
if docker inspect siren > /dev/null 2>&1; then docker rmi siren; fi; docker-compose build
|
||||
|
||||
test:
|
||||
docker run --rm -it siren:latest bash
|
||||
|
||||
@@ -38,13 +38,15 @@ https://discord.com/api/oauth2/authorize?client_id=<Client ID>&permissions=54696
|
||||
- applications.commands
|
||||
```
|
||||
|
||||
`docker build -t siren .`
|
||||
`docker-compose up -d`
|
||||
```
|
||||
make build
|
||||
make up
|
||||
```
|
||||
|
||||
## Development
|
||||
Build container
|
||||
|
||||
`docker build -t siren .`
|
||||
`make build`
|
||||
|
||||
Run container locally
|
||||
|
||||
|
||||
@@ -11,5 +11,22 @@ services:
|
||||
- JAVA_VERSION=17
|
||||
- VERSION=${SIREN_VERSION}
|
||||
volumes:
|
||||
- ./data:/app
|
||||
- ./app:/app
|
||||
environment:
|
||||
DATABASE_URL: jdbc:postgresql://db:5432/${POSTGRES_DB}
|
||||
DATABASE_USERNAME: ${POSTGRES_USER}
|
||||
DATABASE_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
depends_on:
|
||||
- db
|
||||
restart: unless-stopped
|
||||
db:
|
||||
image: postgres:latest
|
||||
environment:
|
||||
POSTGRES_USER: ${POSTGRES_USER}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
POSTGRES_DB: ${POSTGRES_DB}
|
||||
volumes:
|
||||
- ./data:/var/lib/postgresql/data
|
||||
ports:
|
||||
- "5432:5432"
|
||||
restart: unless-stopped
|
||||
9
pom.xml
9
pom.xml
@@ -66,6 +66,7 @@
|
||||
<lavaplayer-natives-extra.version>1.3.13</lavaplayer-natives-extra.version>
|
||||
<jackson.version>2.14.2</jackson.version>
|
||||
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
|
||||
<postgresql.version>42.6.0</postgresql.version>
|
||||
<slf4j.version>2.0.6</slf4j.version>
|
||||
<log4j.version>2.20.0</log4j.version>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
@@ -74,7 +75,6 @@
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<!-- Discord Dependencies -->
|
||||
<dependency>
|
||||
<groupId>net.dv8tion</groupId>
|
||||
<artifactId>JDA</artifactId>
|
||||
@@ -95,13 +95,16 @@
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- OpenAI https://github.com/TheoKanning/openai-java -->
|
||||
<dependency>
|
||||
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
||||
<artifactId>service</artifactId>
|
||||
<version>${theokanning-openai-gpt3.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.postgresql</groupId>
|
||||
<artifactId>postgresql</artifactId>
|
||||
<version>${postgresql.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Logging -->
|
||||
<dependency>
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.bensherriff.siren;
|
||||
import com.bensherriff.siren.audio.AudioHandler;
|
||||
import com.bensherriff.siren.audio.PlayerManager;
|
||||
import com.bensherriff.siren.commands.*;
|
||||
import com.bensherriff.siren.database.DatabaseManager;
|
||||
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
|
||||
import com.bensherriff.siren.openai.OpenAIManager;
|
||||
import com.bensherriff.siren.settings.GuildSettings;
|
||||
@@ -68,6 +69,8 @@ public class Listener extends ListenerAdapter {
|
||||
this.playerManager.initialize();
|
||||
this.openAIManager = new OpenAIManager(this);
|
||||
|
||||
DatabaseManager.createTables();
|
||||
|
||||
commands.put("play", new PlayCommand(this));
|
||||
commands.put("stop", new StopCommand(this));
|
||||
commands.put("skip", new SkipCommand(this));
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
package com.bensherriff.siren.database;
|
||||
|
||||
import com.bensherriff.siren.openai.OpenAIManager;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.sql.DriverManager;
|
||||
import java.sql.SQLException;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
public class DatabaseConnection {
|
||||
private static final Logger LOGGER = LogManager.getLogger(DatabaseConnection.class);
|
||||
private static final ExecutorService executorService = Executors.newFixedThreadPool(4);
|
||||
private static final ThreadLocal<Connection> connectionThreadLocal = new ThreadLocal<>();
|
||||
|
||||
public static Connection getConnection() throws SQLException {
|
||||
Connection connection = connectionThreadLocal.get();
|
||||
|
||||
if (connection == null) {
|
||||
Map<String, String> env = System.getenv();
|
||||
String dbUrl = env.get("DATABASE_URL");
|
||||
String dbUsername = env.get("DATABASE_USERNAME");
|
||||
String dbPassword = env.get("DATABASE_PASSWORD");
|
||||
|
||||
connection = DriverManager.getConnection(dbUrl, dbUsername, dbPassword);
|
||||
connectionThreadLocal.set(connection);
|
||||
}
|
||||
|
||||
return connection;
|
||||
}
|
||||
|
||||
public static void closeConnection() {
|
||||
Connection connection = connectionThreadLocal.get();
|
||||
connectionThreadLocal.remove();
|
||||
|
||||
if (connection != null) {
|
||||
executorService.submit(() -> {
|
||||
try {
|
||||
connection.close();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public static void shutdown() {
|
||||
executorService.shutdown();
|
||||
try {
|
||||
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
|
||||
executorService.shutdownNow();
|
||||
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
|
||||
LOGGER.error("ExecutorService did not terminate");
|
||||
}
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
executorService.shutdownNow();
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package com.bensherriff.siren.database;
|
||||
|
||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.sql.*;
|
||||
import java.util.List;
|
||||
|
||||
public class DatabaseManager {
|
||||
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
|
||||
private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"guild_id BIGINT NOT NULL, " +
|
||||
"thread_id BIGINT NOT NULL, " +
|
||||
"user_id BIGINT NOT NULL, " +
|
||||
"message_type VARCHAR(20) NOT NULL," +
|
||||
"message_text TEXT NOT NULL," +
|
||||
"message_response TEXT, " +
|
||||
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
|
||||
")";
|
||||
private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"embeddings FLOAT[] NOT NULL" +
|
||||
");";
|
||||
|
||||
private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" +
|
||||
"id SERIAL PRIMARY KEY, " +
|
||||
"message_id INT NOT NULL, " +
|
||||
"embedding_id INT NOT NULL" +
|
||||
")";
|
||||
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
|
||||
"message_type, " +
|
||||
"guild_id, " +
|
||||
"thread_id, " +
|
||||
"user_id, " +
|
||||
"message_text, " +
|
||||
"message_response) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?)";
|
||||
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
|
||||
"embeddings) " +
|
||||
"VALUES (?)";
|
||||
|
||||
private static final String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
|
||||
"message_id, " +
|
||||
"embeddings_id, " +
|
||||
"VALUES (?, ?)";
|
||||
|
||||
public static void createTables() {
|
||||
createMessageTable();
|
||||
createEmbeddingsTable();
|
||||
createMessageEmbeddingsTable();
|
||||
}
|
||||
|
||||
private static void createMessageTable() {
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
LOGGER.debug("Creating 'messages' database table if it does not exist");
|
||||
Statement statement = connection.createStatement();
|
||||
statement.execute(CREATE_MESSAGE_TABLE);
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static void createEmbeddingsTable() {
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
LOGGER.debug("Creating 'embeddings' database table if it does not exist");
|
||||
Statement statement = connection.createStatement();
|
||||
statement.execute(CREATE_EMBEDDINGS_TABLE);
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static void createMessageEmbeddingsTable() {
|
||||
try {
|
||||
Connection connection = DatabaseConnection.getConnection();
|
||||
LOGGER.debug("Creating 'message_embeddings' database table if it does not exist");
|
||||
Statement statement = connection.createStatement();
|
||||
statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE);
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) {
|
||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) {
|
||||
preparedStatement.setString(1, message.getRole());
|
||||
preparedStatement.setLong(2, guildId);
|
||||
preparedStatement.setLong(3, threadId);
|
||||
preparedStatement.setLong(4, userId);
|
||||
preparedStatement.setString(5, message.getContent());
|
||||
preparedStatement.setString(6, response);
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
public static int storeEmbedding(List<Double> data) {
|
||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) {
|
||||
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
|
||||
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) {
|
||||
preparedStatement.setInt(1, messageId);
|
||||
preparedStatement.setInt(2, embeddingId);
|
||||
return preparedStatement.executeUpdate();
|
||||
} catch (SQLException ex) {
|
||||
LOGGER.error(ex.getMessage());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.bensherriff.siren.openai;
|
||||
|
||||
import com.bensherriff.siren.Listener;
|
||||
import com.bensherriff.siren.database.DatabaseManager;
|
||||
import com.bensherriff.siren.settings.GuildSettings;
|
||||
import com.bensherriff.siren.settings.Settings;
|
||||
import com.theokanning.openai.completion.CompletionRequest;
|
||||
@@ -8,6 +9,9 @@ import com.theokanning.openai.completion.CompletionResult;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||
import com.theokanning.openai.embedding.Embedding;
|
||||
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||
import com.theokanning.openai.embedding.EmbeddingResult;
|
||||
import com.theokanning.openai.service.OpenAiService;
|
||||
import net.dv8tion.jda.api.JDA;
|
||||
import net.dv8tion.jda.api.entities.Message;
|
||||
@@ -20,6 +24,7 @@ import org.apache.logging.log4j.Logger;
|
||||
import java.time.Duration;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class OpenAIManager {
|
||||
private static final Logger LOGGER = LogManager.getLogger(OpenAIManager.class);
|
||||
@@ -27,13 +32,11 @@ public class OpenAIManager {
|
||||
private final Settings settings;
|
||||
private final JDA jda;
|
||||
private final ScheduledExecutorService executor;
|
||||
private final Map<String, List<ChatMessage>> threadMessages;
|
||||
|
||||
public OpenAIManager(Listener listener) {
|
||||
this.settings = listener.getSettings();
|
||||
this.jda = listener.getJDA();
|
||||
this.executor = listener.getExecutor();
|
||||
this.threadMessages = new HashMap<>();
|
||||
|
||||
if (settings.getOpenAISettings().getToken().isEmpty()) {
|
||||
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
|
||||
@@ -62,62 +65,35 @@ public class OpenAIManager {
|
||||
String message = parseMessage(event.getMessage().getContentRaw());
|
||||
long guildId = event.getGuild().getIdLong();
|
||||
Model model = settings.getGuildSettings().get(guildId).getModel();
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(guildId);
|
||||
|
||||
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
ChatMessage chatMessage = createChatMessage(message, event);
|
||||
List<Embedding> embeddings = new ArrayList<>();
|
||||
|
||||
if (message.isEmpty() || message.isBlank()) {
|
||||
event.getMessage().reply("Your message is empty. Please try again").queue();
|
||||
return;
|
||||
}
|
||||
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
ChatMessage chatMessage = createChatMessage(message, event);
|
||||
|
||||
try {
|
||||
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
|
||||
// Send OpenAI Message and get response
|
||||
switch (model) {
|
||||
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
||||
CompletionRequest completionRequest = CompletionRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.maxTokens(guildSettings.getMaxTokens())
|
||||
.user(event.getAuthor().getId())
|
||||
.temperature(settings.getOpenAISettings().getTemperature())
|
||||
.topP(settings.getOpenAISettings().getTopP())
|
||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||
.prompt(message)
|
||||
.build();
|
||||
CompletionRequest completionRequest = createCompletionRequest(message, event);
|
||||
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
||||
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
||||
}
|
||||
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
|
||||
//TODO Handle memories properly
|
||||
// if (event.isFromThread()) {
|
||||
// String channelId = event.getChannel().asThreadChannel().getId();
|
||||
// // Update ThreadMessages with the new message, and add previous messages to be sent out
|
||||
// if (threadMessages.containsKey(channelId)) {
|
||||
// threadMessages.get(channelId).add(chatMessage);
|
||||
// chatMessages.addAll(threadMessages.get(channelId));
|
||||
// } else {
|
||||
// threadMessages.put(channelId, new ArrayList<>(Arrays.asList(chatMessage)));
|
||||
// }
|
||||
// }
|
||||
chatMessages.add(chatMessage);
|
||||
|
||||
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.maxTokens(guildSettings.getMaxTokens())
|
||||
.user(event.getAuthor().getId())
|
||||
.temperature(settings.getOpenAISettings().getTemperature())
|
||||
.topP(settings.getOpenAISettings().getTopP())
|
||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||
.messages(chatMessages)
|
||||
.build();
|
||||
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event);
|
||||
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
||||
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
||||
|
||||
//TODO fix embeddings
|
||||
// EmbeddingRequest embeddingRequest = createEmbeddingRequest(chatMessages, event);
|
||||
// EmbeddingResult embeddingResult = openAiService.createEmbeddings(embeddingRequest);
|
||||
// embeddings.addAll(embeddingResult.getData());
|
||||
}
|
||||
default -> {
|
||||
event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue();
|
||||
@@ -126,29 +102,50 @@ public class OpenAIManager {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Respond to user
|
||||
if (event.isFromThread()) {
|
||||
ThreadChannel channel = event.getChannel().asThreadChannel();
|
||||
channel.sendMessage(stringBuilder.toString()).queue();
|
||||
} else {
|
||||
// The max discord title length is 100 characters
|
||||
String threadTitle = message;
|
||||
if (message.length() > 100) {
|
||||
threadTitle = message.substring(0, 100);
|
||||
}
|
||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
|
||||
channel.sendMessage(stringBuilder.toString()).queue();
|
||||
// threadMessages.put(channel.getId(), new ArrayList<>(Arrays.asList(chatMessage)));
|
||||
});
|
||||
}
|
||||
|
||||
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
|
||||
} catch (Exception ex) {
|
||||
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
|
||||
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue();
|
||||
}
|
||||
}
|
||||
|
||||
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||
return EmbeddingRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.user(event.getAuthor().getId())
|
||||
.input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList()))
|
||||
.build();
|
||||
}
|
||||
|
||||
private ChatCompletionRequest createCompletionRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||
return ChatCompletionRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.maxTokens(guildSettings.getMaxTokens())
|
||||
.user(event.getAuthor().getId())
|
||||
.temperature(settings.getOpenAISettings().getTemperature())
|
||||
.topP(settings.getOpenAISettings().getTopP())
|
||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||
.messages(chatMessages)
|
||||
.build();
|
||||
}
|
||||
|
||||
private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) {
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||
return CompletionRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.maxTokens(guildSettings.getMaxTokens())
|
||||
.user(event.getAuthor().getId())
|
||||
.temperature(settings.getOpenAISettings().getTemperature())
|
||||
.topP(settings.getOpenAISettings().getTopP())
|
||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||
.prompt(message)
|
||||
.build();
|
||||
}
|
||||
|
||||
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
|
||||
ChatMessage chatMessage = new ChatMessage();
|
||||
chatMessage.setContent(message);
|
||||
@@ -160,6 +157,34 @@ public class OpenAIManager {
|
||||
return chatMessage;
|
||||
}
|
||||
|
||||
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
|
||||
if (event.isFromThread()) {
|
||||
ThreadChannel channel = event.getChannel().asThreadChannel();
|
||||
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
|
||||
channel.getIdLong(), response);
|
||||
embeddings.forEach(embedding -> {
|
||||
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
||||
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
||||
});
|
||||
channel.sendMessage(response).queue();
|
||||
} else {
|
||||
// The max discord title length is 100 characters
|
||||
String threadTitle = chatMessage.getContent();
|
||||
if (chatMessage.getContent().length() > 100) {
|
||||
threadTitle = chatMessage.getContent().substring(0, 100);
|
||||
}
|
||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
|
||||
channel.sendMessage(response).queue();
|
||||
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
|
||||
channel.getIdLong(), response);
|
||||
embeddings.forEach(embedding -> {
|
||||
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
||||
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private String parseMessage(String input) {
|
||||
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user