v0.1.15 Refactoring and updating OpenAI usage again
This commit is contained in:
3
.env.TEMPLATE
Normal file
3
.env.TEMPLATE
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
export POSTGRES_USER=
|
||||||
|
export POSTGRES_PASSWORD=
|
||||||
|
export POSTGRES_DB=
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,5 +1,7 @@
|
|||||||
.idea/
|
.idea/
|
||||||
**/target/
|
**/target/
|
||||||
**/data/
|
**/data/
|
||||||
|
**/app/
|
||||||
**/settings.json
|
**/settings.json
|
||||||
**/logs/
|
**/logs/
|
||||||
|
.env
|
||||||
3
Makefile
3
Makefile
@@ -1,8 +1,9 @@
|
|||||||
SHELL := /bin/bash
|
SHELL := /bin/bash
|
||||||
|
include .version
|
||||||
include .env
|
include .env
|
||||||
|
|
||||||
build:
|
build:
|
||||||
docker rmi siren && docker-compose build
|
if docker inspect siren > /dev/null 2>&1; then docker rmi siren; fi; docker-compose build
|
||||||
|
|
||||||
test:
|
test:
|
||||||
docker run --rm -it siren:latest bash
|
docker run --rm -it siren:latest bash
|
||||||
|
|||||||
@@ -38,13 +38,15 @@ https://discord.com/api/oauth2/authorize?client_id=<Client ID>&permissions=54696
|
|||||||
- applications.commands
|
- applications.commands
|
||||||
```
|
```
|
||||||
|
|
||||||
`docker build -t siren .`
|
```
|
||||||
`docker-compose up -d`
|
make build
|
||||||
|
make up
|
||||||
|
```
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
Build container
|
Build container
|
||||||
|
|
||||||
`docker build -t siren .`
|
`make build`
|
||||||
|
|
||||||
Run container locally
|
Run container locally
|
||||||
|
|
||||||
|
|||||||
@@ -11,5 +11,22 @@ services:
|
|||||||
- JAVA_VERSION=17
|
- JAVA_VERSION=17
|
||||||
- VERSION=${SIREN_VERSION}
|
- VERSION=${SIREN_VERSION}
|
||||||
volumes:
|
volumes:
|
||||||
- ./data:/app
|
- ./app:/app
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: jdbc:postgresql://db:5432/${POSTGRES_DB}
|
||||||
|
DATABASE_USERNAME: ${POSTGRES_USER}
|
||||||
|
DATABASE_PASSWORD: ${POSTGRES_PASSWORD}
|
||||||
|
depends_on:
|
||||||
|
- db
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
db:
|
||||||
|
image: postgres:latest
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: ${POSTGRES_USER}
|
||||||
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||||
|
POSTGRES_DB: ${POSTGRES_DB}
|
||||||
|
volumes:
|
||||||
|
- ./data:/var/lib/postgresql/data
|
||||||
|
ports:
|
||||||
|
- "5432:5432"
|
||||||
|
restart: unless-stopped
|
||||||
9
pom.xml
9
pom.xml
@@ -66,6 +66,7 @@
|
|||||||
<lavaplayer-natives-extra.version>1.3.13</lavaplayer-natives-extra.version>
|
<lavaplayer-natives-extra.version>1.3.13</lavaplayer-natives-extra.version>
|
||||||
<jackson.version>2.14.2</jackson.version>
|
<jackson.version>2.14.2</jackson.version>
|
||||||
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
|
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
|
||||||
|
<postgresql.version>42.6.0</postgresql.version>
|
||||||
<slf4j.version>2.0.6</slf4j.version>
|
<slf4j.version>2.0.6</slf4j.version>
|
||||||
<log4j.version>2.20.0</log4j.version>
|
<log4j.version>2.20.0</log4j.version>
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
@@ -74,7 +75,6 @@
|
|||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<!-- Discord Dependencies -->
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>net.dv8tion</groupId>
|
<groupId>net.dv8tion</groupId>
|
||||||
<artifactId>JDA</artifactId>
|
<artifactId>JDA</artifactId>
|
||||||
@@ -95,13 +95,16 @@
|
|||||||
<artifactId>jackson-databind</artifactId>
|
<artifactId>jackson-databind</artifactId>
|
||||||
<version>${jackson.version}</version>
|
<version>${jackson.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<!-- OpenAI https://github.com/TheoKanning/openai-java -->
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
||||||
<artifactId>service</artifactId>
|
<artifactId>service</artifactId>
|
||||||
<version>${theokanning-openai-gpt3.version}</version>
|
<version>${theokanning-openai-gpt3.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.postgresql</groupId>
|
||||||
|
<artifactId>postgresql</artifactId>
|
||||||
|
<version>${postgresql.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<!-- Logging -->
|
<!-- Logging -->
|
||||||
<dependency>
|
<dependency>
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.bensherriff.siren;
|
|||||||
import com.bensherriff.siren.audio.AudioHandler;
|
import com.bensherriff.siren.audio.AudioHandler;
|
||||||
import com.bensherriff.siren.audio.PlayerManager;
|
import com.bensherriff.siren.audio.PlayerManager;
|
||||||
import com.bensherriff.siren.commands.*;
|
import com.bensherriff.siren.commands.*;
|
||||||
|
import com.bensherriff.siren.database.DatabaseManager;
|
||||||
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
|
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
|
||||||
import com.bensherriff.siren.openai.OpenAIManager;
|
import com.bensherriff.siren.openai.OpenAIManager;
|
||||||
import com.bensherriff.siren.settings.GuildSettings;
|
import com.bensherriff.siren.settings.GuildSettings;
|
||||||
@@ -68,6 +69,8 @@ public class Listener extends ListenerAdapter {
|
|||||||
this.playerManager.initialize();
|
this.playerManager.initialize();
|
||||||
this.openAIManager = new OpenAIManager(this);
|
this.openAIManager = new OpenAIManager(this);
|
||||||
|
|
||||||
|
DatabaseManager.createTables();
|
||||||
|
|
||||||
commands.put("play", new PlayCommand(this));
|
commands.put("play", new PlayCommand(this));
|
||||||
commands.put("stop", new StopCommand(this));
|
commands.put("stop", new StopCommand(this));
|
||||||
commands.put("skip", new SkipCommand(this));
|
commands.put("skip", new SkipCommand(this));
|
||||||
|
|||||||
@@ -0,0 +1,65 @@
|
|||||||
|
package com.bensherriff.siren.database;
|
||||||
|
|
||||||
|
import com.bensherriff.siren.openai.OpenAIManager;
|
||||||
|
import org.apache.logging.log4j.LogManager;
|
||||||
|
import org.apache.logging.log4j.Logger;
|
||||||
|
|
||||||
|
import java.sql.Connection;
|
||||||
|
import java.sql.DriverManager;
|
||||||
|
import java.sql.SQLException;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
public class DatabaseConnection {
|
||||||
|
private static final Logger LOGGER = LogManager.getLogger(DatabaseConnection.class);
|
||||||
|
private static final ExecutorService executorService = Executors.newFixedThreadPool(4);
|
||||||
|
private static final ThreadLocal<Connection> connectionThreadLocal = new ThreadLocal<>();
|
||||||
|
|
||||||
|
public static Connection getConnection() throws SQLException {
|
||||||
|
Connection connection = connectionThreadLocal.get();
|
||||||
|
|
||||||
|
if (connection == null) {
|
||||||
|
Map<String, String> env = System.getenv();
|
||||||
|
String dbUrl = env.get("DATABASE_URL");
|
||||||
|
String dbUsername = env.get("DATABASE_USERNAME");
|
||||||
|
String dbPassword = env.get("DATABASE_PASSWORD");
|
||||||
|
|
||||||
|
connection = DriverManager.getConnection(dbUrl, dbUsername, dbPassword);
|
||||||
|
connectionThreadLocal.set(connection);
|
||||||
|
}
|
||||||
|
|
||||||
|
return connection;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void closeConnection() {
|
||||||
|
Connection connection = connectionThreadLocal.get();
|
||||||
|
connectionThreadLocal.remove();
|
||||||
|
|
||||||
|
if (connection != null) {
|
||||||
|
executorService.submit(() -> {
|
||||||
|
try {
|
||||||
|
connection.close();
|
||||||
|
} catch (SQLException ex) {
|
||||||
|
LOGGER.error(ex.getMessage());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void shutdown() {
|
||||||
|
executorService.shutdown();
|
||||||
|
try {
|
||||||
|
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
|
||||||
|
executorService.shutdownNow();
|
||||||
|
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
|
||||||
|
LOGGER.error("ExecutorService did not terminate");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
executorService.shutdownNow();
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
package com.bensherriff.siren.database;
|
||||||
|
|
||||||
|
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||||
|
import org.apache.logging.log4j.LogManager;
|
||||||
|
import org.apache.logging.log4j.Logger;
|
||||||
|
|
||||||
|
import java.sql.*;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class DatabaseManager {
|
||||||
|
private static final Logger LOGGER = LogManager.getLogger(DatabaseManager.class);
|
||||||
|
private static final String CREATE_MESSAGE_TABLE = "CREATE TABLE IF NOT EXISTS messages (" +
|
||||||
|
"id SERIAL PRIMARY KEY, " +
|
||||||
|
"guild_id BIGINT NOT NULL, " +
|
||||||
|
"thread_id BIGINT NOT NULL, " +
|
||||||
|
"user_id BIGINT NOT NULL, " +
|
||||||
|
"message_type VARCHAR(20) NOT NULL," +
|
||||||
|
"message_text TEXT NOT NULL," +
|
||||||
|
"message_response TEXT, " +
|
||||||
|
"timestamp TIMESTAMP NOT NULL DEFAULT NOW()" +
|
||||||
|
")";
|
||||||
|
private static final String CREATE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS embeddings (" +
|
||||||
|
"id SERIAL PRIMARY KEY, " +
|
||||||
|
"embeddings FLOAT[] NOT NULL" +
|
||||||
|
");";
|
||||||
|
|
||||||
|
private static final String CREATE_MESSAGE_EMBEDDINGS_TABLE = "CREATE TABLE IF NOT EXISTS message_embeddings (" +
|
||||||
|
"id SERIAL PRIMARY KEY, " +
|
||||||
|
"message_id INT NOT NULL, " +
|
||||||
|
"embedding_id INT NOT NULL" +
|
||||||
|
")";
|
||||||
|
private static final String INSERT_MESSAGE = "INSERT INTO messages (" +
|
||||||
|
"message_type, " +
|
||||||
|
"guild_id, " +
|
||||||
|
"thread_id, " +
|
||||||
|
"user_id, " +
|
||||||
|
"message_text, " +
|
||||||
|
"message_response) " +
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?)";
|
||||||
|
private static final String INSERT_EMBEDDING = "INSERT INTO embeddings (" +
|
||||||
|
"embeddings) " +
|
||||||
|
"VALUES (?)";
|
||||||
|
|
||||||
|
private static final String INSERT_MESSAGE_EMBEDDINGS = "INSERT INTO message_embeddings" +
|
||||||
|
"message_id, " +
|
||||||
|
"embeddings_id, " +
|
||||||
|
"VALUES (?, ?)";
|
||||||
|
|
||||||
|
public static void createTables() {
|
||||||
|
createMessageTable();
|
||||||
|
createEmbeddingsTable();
|
||||||
|
createMessageEmbeddingsTable();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void createMessageTable() {
|
||||||
|
try {
|
||||||
|
Connection connection = DatabaseConnection.getConnection();
|
||||||
|
LOGGER.debug("Creating 'messages' database table if it does not exist");
|
||||||
|
Statement statement = connection.createStatement();
|
||||||
|
statement.execute(CREATE_MESSAGE_TABLE);
|
||||||
|
} catch (SQLException ex) {
|
||||||
|
LOGGER.error(ex.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void createEmbeddingsTable() {
|
||||||
|
try {
|
||||||
|
Connection connection = DatabaseConnection.getConnection();
|
||||||
|
LOGGER.debug("Creating 'embeddings' database table if it does not exist");
|
||||||
|
Statement statement = connection.createStatement();
|
||||||
|
statement.execute(CREATE_EMBEDDINGS_TABLE);
|
||||||
|
} catch (SQLException ex) {
|
||||||
|
LOGGER.error(ex.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void createMessageEmbeddingsTable() {
|
||||||
|
try {
|
||||||
|
Connection connection = DatabaseConnection.getConnection();
|
||||||
|
LOGGER.debug("Creating 'message_embeddings' database table if it does not exist");
|
||||||
|
Statement statement = connection.createStatement();
|
||||||
|
statement.execute(CREATE_MESSAGE_EMBEDDINGS_TABLE);
|
||||||
|
} catch (SQLException ex) {
|
||||||
|
LOGGER.error(ex.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static int storeMessage(ChatMessage message, long guildId, long userId, long threadId, String response) {
|
||||||
|
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE)) {
|
||||||
|
preparedStatement.setString(1, message.getRole());
|
||||||
|
preparedStatement.setLong(2, guildId);
|
||||||
|
preparedStatement.setLong(3, threadId);
|
||||||
|
preparedStatement.setLong(4, userId);
|
||||||
|
preparedStatement.setString(5, message.getContent());
|
||||||
|
preparedStatement.setString(6, response);
|
||||||
|
return preparedStatement.executeUpdate();
|
||||||
|
} catch (SQLException ex) {
|
||||||
|
LOGGER.error(ex.getMessage());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static int storeEmbedding(List<Double> data) {
|
||||||
|
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_EMBEDDING)) {
|
||||||
|
preparedStatement.setArray(1, connection.createArrayOf("float8", data.toArray(new Double[0])));
|
||||||
|
return preparedStatement.executeUpdate();
|
||||||
|
} catch (SQLException ex) {
|
||||||
|
LOGGER.error(ex.getMessage());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static int storeMessageEmbeddings(int messageId, int embeddingId) {
|
||||||
|
try (Connection connection = DatabaseConnection.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(INSERT_MESSAGE_EMBEDDINGS)) {
|
||||||
|
preparedStatement.setInt(1, messageId);
|
||||||
|
preparedStatement.setInt(2, embeddingId);
|
||||||
|
return preparedStatement.executeUpdate();
|
||||||
|
} catch (SQLException ex) {
|
||||||
|
LOGGER.error(ex.getMessage());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.bensherriff.siren.openai;
|
package com.bensherriff.siren.openai;
|
||||||
|
|
||||||
import com.bensherriff.siren.Listener;
|
import com.bensherriff.siren.Listener;
|
||||||
|
import com.bensherriff.siren.database.DatabaseManager;
|
||||||
import com.bensherriff.siren.settings.GuildSettings;
|
import com.bensherriff.siren.settings.GuildSettings;
|
||||||
import com.bensherriff.siren.settings.Settings;
|
import com.bensherriff.siren.settings.Settings;
|
||||||
import com.theokanning.openai.completion.CompletionRequest;
|
import com.theokanning.openai.completion.CompletionRequest;
|
||||||
@@ -8,6 +9,9 @@ import com.theokanning.openai.completion.CompletionResult;
|
|||||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||||
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
||||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||||
|
import com.theokanning.openai.embedding.Embedding;
|
||||||
|
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||||
|
import com.theokanning.openai.embedding.EmbeddingResult;
|
||||||
import com.theokanning.openai.service.OpenAiService;
|
import com.theokanning.openai.service.OpenAiService;
|
||||||
import net.dv8tion.jda.api.JDA;
|
import net.dv8tion.jda.api.JDA;
|
||||||
import net.dv8tion.jda.api.entities.Message;
|
import net.dv8tion.jda.api.entities.Message;
|
||||||
@@ -20,6 +24,7 @@ import org.apache.logging.log4j.Logger;
|
|||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.ScheduledExecutorService;
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class OpenAIManager {
|
public class OpenAIManager {
|
||||||
private static final Logger LOGGER = LogManager.getLogger(OpenAIManager.class);
|
private static final Logger LOGGER = LogManager.getLogger(OpenAIManager.class);
|
||||||
@@ -27,13 +32,11 @@ public class OpenAIManager {
|
|||||||
private final Settings settings;
|
private final Settings settings;
|
||||||
private final JDA jda;
|
private final JDA jda;
|
||||||
private final ScheduledExecutorService executor;
|
private final ScheduledExecutorService executor;
|
||||||
private final Map<String, List<ChatMessage>> threadMessages;
|
|
||||||
|
|
||||||
public OpenAIManager(Listener listener) {
|
public OpenAIManager(Listener listener) {
|
||||||
this.settings = listener.getSettings();
|
this.settings = listener.getSettings();
|
||||||
this.jda = listener.getJDA();
|
this.jda = listener.getJDA();
|
||||||
this.executor = listener.getExecutor();
|
this.executor = listener.getExecutor();
|
||||||
this.threadMessages = new HashMap<>();
|
|
||||||
|
|
||||||
if (settings.getOpenAISettings().getToken().isEmpty()) {
|
if (settings.getOpenAISettings().getToken().isEmpty()) {
|
||||||
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
|
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
|
||||||
@@ -62,62 +65,35 @@ public class OpenAIManager {
|
|||||||
String message = parseMessage(event.getMessage().getContentRaw());
|
String message = parseMessage(event.getMessage().getContentRaw());
|
||||||
long guildId = event.getGuild().getIdLong();
|
long guildId = event.getGuild().getIdLong();
|
||||||
Model model = settings.getGuildSettings().get(guildId).getModel();
|
Model model = settings.getGuildSettings().get(guildId).getModel();
|
||||||
GuildSettings guildSettings = settings.getGuildSettings().get(guildId);
|
StringBuilder stringBuilder = new StringBuilder();
|
||||||
|
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||||
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
|
ChatMessage chatMessage = createChatMessage(message, event);
|
||||||
|
List<Embedding> embeddings = new ArrayList<>();
|
||||||
|
|
||||||
if (message.isEmpty() || message.isBlank()) {
|
if (message.isEmpty() || message.isBlank()) {
|
||||||
event.getMessage().reply("Your message is empty. Please try again").queue();
|
event.getMessage().reply("Your message is empty. Please try again").queue();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
StringBuilder stringBuilder = new StringBuilder();
|
|
||||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
|
||||||
ChatMessage chatMessage = createChatMessage(message, event);
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
LOGGER.trace("Guild: <{}> User: <{}> Message: <{}>", guildId, event.getAuthor().getId(), message);
|
||||||
// Send OpenAI Message and get response
|
// Send OpenAI Message and get response
|
||||||
switch (model) {
|
switch (model) {
|
||||||
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
|
||||||
CompletionRequest completionRequest = CompletionRequest.builder()
|
CompletionRequest completionRequest = createCompletionRequest(message, event);
|
||||||
.model(guildSettings.getModel().getName())
|
|
||||||
.maxTokens(guildSettings.getMaxTokens())
|
|
||||||
.user(event.getAuthor().getId())
|
|
||||||
.temperature(settings.getOpenAISettings().getTemperature())
|
|
||||||
.topP(settings.getOpenAISettings().getTopP())
|
|
||||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
|
||||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
|
||||||
.prompt(message)
|
|
||||||
.build();
|
|
||||||
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
||||||
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
||||||
}
|
}
|
||||||
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
|
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
|
||||||
//TODO Handle memories properly
|
|
||||||
// if (event.isFromThread()) {
|
|
||||||
// String channelId = event.getChannel().asThreadChannel().getId();
|
|
||||||
// // Update ThreadMessages with the new message, and add previous messages to be sent out
|
|
||||||
// if (threadMessages.containsKey(channelId)) {
|
|
||||||
// threadMessages.get(channelId).add(chatMessage);
|
|
||||||
// chatMessages.addAll(threadMessages.get(channelId));
|
|
||||||
// } else {
|
|
||||||
// threadMessages.put(channelId, new ArrayList<>(Arrays.asList(chatMessage)));
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
chatMessages.add(chatMessage);
|
chatMessages.add(chatMessage);
|
||||||
|
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessages, event);
|
||||||
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
|
|
||||||
.model(guildSettings.getModel().getName())
|
|
||||||
.maxTokens(guildSettings.getMaxTokens())
|
|
||||||
.user(event.getAuthor().getId())
|
|
||||||
.temperature(settings.getOpenAISettings().getTemperature())
|
|
||||||
.topP(settings.getOpenAISettings().getTopP())
|
|
||||||
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
|
||||||
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
|
||||||
.messages(chatMessages)
|
|
||||||
.build();
|
|
||||||
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
||||||
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
||||||
|
|
||||||
|
//TODO fix embeddings
|
||||||
|
// EmbeddingRequest embeddingRequest = createEmbeddingRequest(chatMessages, event);
|
||||||
|
// EmbeddingResult embeddingResult = openAiService.createEmbeddings(embeddingRequest);
|
||||||
|
// embeddings.addAll(embeddingResult.getData());
|
||||||
}
|
}
|
||||||
default -> {
|
default -> {
|
||||||
event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue();
|
event.getMessage().reply("Unexpected model in settings. Please contact an administrator.").queue();
|
||||||
@@ -126,29 +102,50 @@ public class OpenAIManager {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
handleResponse(chatMessage, event, stringBuilder.toString(), embeddings);
|
||||||
// Respond to user
|
|
||||||
if (event.isFromThread()) {
|
|
||||||
ThreadChannel channel = event.getChannel().asThreadChannel();
|
|
||||||
channel.sendMessage(stringBuilder.toString()).queue();
|
|
||||||
} else {
|
|
||||||
// The max discord title length is 100 characters
|
|
||||||
String threadTitle = message;
|
|
||||||
if (message.length() > 100) {
|
|
||||||
threadTitle = message.substring(0, 100);
|
|
||||||
}
|
|
||||||
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
|
|
||||||
channel.sendMessage(stringBuilder.toString()).queue();
|
|
||||||
// threadMessages.put(channel.getId(), new ArrayList<>(Arrays.asList(chatMessage)));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (Exception ex) {
|
} catch (Exception ex) {
|
||||||
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
|
LOGGER.error("Caught exception while processing message; {}", ex.getMessage());
|
||||||
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue();
|
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||||
|
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||||
|
return EmbeddingRequest.builder()
|
||||||
|
.model(guildSettings.getModel().getName())
|
||||||
|
.user(event.getAuthor().getId())
|
||||||
|
.input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList()))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private ChatCompletionRequest createCompletionRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||||
|
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||||
|
return ChatCompletionRequest.builder()
|
||||||
|
.model(guildSettings.getModel().getName())
|
||||||
|
.maxTokens(guildSettings.getMaxTokens())
|
||||||
|
.user(event.getAuthor().getId())
|
||||||
|
.temperature(settings.getOpenAISettings().getTemperature())
|
||||||
|
.topP(settings.getOpenAISettings().getTopP())
|
||||||
|
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||||
|
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||||
|
.messages(chatMessages)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) {
|
||||||
|
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||||
|
return CompletionRequest.builder()
|
||||||
|
.model(guildSettings.getModel().getName())
|
||||||
|
.maxTokens(guildSettings.getMaxTokens())
|
||||||
|
.user(event.getAuthor().getId())
|
||||||
|
.temperature(settings.getOpenAISettings().getTemperature())
|
||||||
|
.topP(settings.getOpenAISettings().getTopP())
|
||||||
|
.frequencyPenalty(settings.getOpenAISettings().getFrequencyPenalty())
|
||||||
|
.presencePenalty(settings.getOpenAISettings().getPresencePenalty())
|
||||||
|
.prompt(message)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
|
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
|
||||||
ChatMessage chatMessage = new ChatMessage();
|
ChatMessage chatMessage = new ChatMessage();
|
||||||
chatMessage.setContent(message);
|
chatMessage.setContent(message);
|
||||||
@@ -160,6 +157,34 @@ public class OpenAIManager {
|
|||||||
return chatMessage;
|
return chatMessage;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {
|
||||||
|
if (event.isFromThread()) {
|
||||||
|
ThreadChannel channel = event.getChannel().asThreadChannel();
|
||||||
|
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
|
||||||
|
channel.getIdLong(), response);
|
||||||
|
embeddings.forEach(embedding -> {
|
||||||
|
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
||||||
|
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
||||||
|
});
|
||||||
|
channel.sendMessage(response).queue();
|
||||||
|
} else {
|
||||||
|
// The max discord title length is 100 characters
|
||||||
|
String threadTitle = chatMessage.getContent();
|
||||||
|
if (chatMessage.getContent().length() > 100) {
|
||||||
|
threadTitle = chatMessage.getContent().substring(0, 100);
|
||||||
|
}
|
||||||
|
event.getMessage().createThreadChannel(threadTitle).queue(channel -> {
|
||||||
|
channel.sendMessage(response).queue();
|
||||||
|
int messageRow = DatabaseManager.storeMessage(chatMessage, event.getGuild().getIdLong(), event.getAuthor().getIdLong(),
|
||||||
|
channel.getIdLong(), response);
|
||||||
|
embeddings.forEach(embedding -> {
|
||||||
|
int embeddingRow = DatabaseManager.storeEmbedding(embedding.getEmbedding());
|
||||||
|
DatabaseManager.storeMessageEmbeddings(messageRow, embeddingRow);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private String parseMessage(String input) {
|
private String parseMessage(String input) {
|
||||||
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
|
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user