v0.1.12 Added OpenAI capabilities

This commit is contained in:
2023-04-14 22:24:59 -04:00
parent bc0b440a01
commit a8e4f0f99d
17 changed files with 343 additions and 152 deletions

View File

@@ -9,7 +9,7 @@ services:
dockerfile: ./Dockerfile
args:
- JAVA_VERSION=17
- VERSION=0.1.11
- VERSION=0.1.12
volumes:
- ./data:/app
restart: unless-stopped

12
pom.xml
View File

@@ -3,7 +3,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.bensherriff</groupId>
<artifactId>siren</artifactId>
<version>0.1.11</version>
<version>0.1.12</version>
<packaging>jar</packaging>
<repositories>
@@ -65,6 +65,7 @@
<lavaplayer.version>1.4.0</lavaplayer.version>
<lavaplayer-natives-extra.version>1.3.13</lavaplayer-natives-extra.version>
<jackson.version>2.14.2</jackson.version>
<theokanning-openai-gpt3.version>0.12.0</theokanning-openai-gpt3.version>
<slf4j.version>2.0.6</slf4j.version>
<log4j.version>2.20.0</log4j.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
@@ -95,6 +96,13 @@
<version>${jackson.version}</version>
</dependency>
<!-- OpenAI https://github.com/TheoKanning/openai-java -->
<dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>service</artifactId>
<version>${theokanning-openai-gpt3.version}</version>
</dependency>
<!-- Logging -->
<dependency>
<groupId>org.apache.logging.log4j</groupId>
@@ -139,7 +147,7 @@
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<manifestEntries>
<Main-Class>com.bensherriff.siren.MusicBot</Main-Class>
<Main-Class>com.bensherriff.siren.Bot</Main-Class>
<Specification-Title>${project.artifactId}</Specification-Title>
<Specification-Version>${project.version}</Specification-Version>
<Implementation-Title>${project.artifactId}</Implementation-Title>

View File

@@ -1,6 +1,5 @@
package com.bensherriff.siren;
import com.bensherriff.siren.audio.Listener;
import com.bensherriff.siren.settings.Settings;
import com.bensherriff.siren.settings.SettingsManager;
import net.dv8tion.jda.api.JDA;
@@ -14,8 +13,8 @@ import org.apache.logging.log4j.Logger;
import java.io.IOException;
import java.util.Arrays;
public class MusicBot {
private static final Logger LOGGER = LogManager.getLogger(MusicBot.class);
public class Bot {
private static final Logger LOGGER = LogManager.getLogger(Bot.class);
private final static GatewayIntent[] INTENTS = {
GatewayIntent.DIRECT_MESSAGES, GatewayIntent.GUILD_MESSAGES, GatewayIntent.GUILD_MESSAGE_REACTIONS,
GatewayIntent.GUILD_VOICE_STATES, GatewayIntent.MESSAGE_CONTENT

View File

@@ -0,0 +1,240 @@
package com.bensherriff.siren;
import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.audio.PlayerManager;
import com.bensherriff.siren.commands.*;
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
import com.bensherriff.siren.settings.GuildSettings;
import com.bensherriff.siren.settings.Settings;
import com.bensherriff.siren.settings.SettingsManager;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.entities.Member;
import net.dv8tion.jda.api.entities.MessageType;
import net.dv8tion.jda.api.entities.channel.concrete.ThreadChannel;
import net.dv8tion.jda.api.entities.channel.concrete.VoiceChannel;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.events.message.MessageReceivedEvent;
import net.dv8tion.jda.api.events.session.ReadyEvent;
import net.dv8tion.jda.api.hooks.ListenerAdapter;
import net.dv8tion.jda.api.managers.AudioManager;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import java.io.IOException;
import java.time.Duration;
import java.util.*;
import java.util.stream.Collectors;
public class Listener extends ListenerAdapter {
private static final Logger LOGGER = LogManager.getLogger(Listener.class);
private final PlayerManager playerManager;
private final Settings settings;
private final Map<String, Command> commands = new HashMap<>();
private OpenAiService openAiService;
private JDA jda;
public Listener(Settings settings) {
this.settings = settings;
if (settings.getOpenAISettings().getToken().isEmpty()) {
LOGGER.warn("No OpenAI token; OpenAI functionality is disabled");
} else {
openAiService = new OpenAiService(settings.getOpenAISettings().getToken(), Duration.ofMillis(settings.getOpenAISettings().getTimeout()));
}
this.playerManager = new PlayerManager(this);
this.playerManager.initialize();
commands.put("play", new PlayCommand(this));
commands.put("stop", new StopCommand(this));
commands.put("skip", new SkipCommand(this));
commands.put("volume", new VolumeCommand(this));
commands.put("pause", new PauseCommand(this));
commands.put("resume", new ResumeCommand(this));
}
public PlayerManager getPlayerManager() {
return playerManager;
}
public Settings getSettings() {
return settings;
}
public JDA getJDA() {
return jda;
}
public void setJDA(JDA jda) {
this.jda = jda;
}
public void closeAudioConnection(long guildID) {
Guild guild = jda.getGuildById(guildID);
if (guild != null) {
guild.getAudioManager().closeAudioConnection();
}
}
public void connectToVoiceChannel(String userID, AudioManager audioManager) throws EmptyVoiceChannelException {
if (!audioManager.isConnected()) {
Member member = audioManager.getGuild().getMemberById(userID);
if (member != null) {
if (member.getVoiceState() != null && member.getVoiceState().inAudioChannel()) {
VoiceChannel voiceChannel = Objects.requireNonNull(member.getVoiceState().getChannel()).asVoiceChannel();
LOGGER.debug("Connecting to channel {} in guild {}", voiceChannel.getId(), voiceChannel.getGuild().getId());
audioManager.openAudioConnection(voiceChannel);
} else {
throw new EmptyVoiceChannelException("Member {} is not connected to a voice channel");
}
}
}
}
public synchronized AudioHandler getGuildAudioPlayer(Guild guild) throws IOException {
long guildId = Long.parseLong(guild.getId());
AudioHandler audioHandler;
if (guild.getAudioManager().getSendingHandler() == null) {
LOGGER.info("Creating Audio Handler for guild {}", guildId);
if (!settings.getGuildSettings().containsKey(guildId)) {
settings.getGuildSettings().put(guildId, new GuildSettings());
SettingsManager.write(settings);
}
audioHandler = new AudioHandler(playerManager, guildId);
guild.getAudioManager().setSendingHandler(audioHandler);
} else {
audioHandler = (AudioHandler) guild.getAudioManager().getSendingHandler();
}
return audioHandler;
}
@Override
public void onReady(@NotNull ReadyEvent event) {
super.onReady(event);
jda.getGuilds().forEach(guild -> {
LOGGER.debug("Updating commands for {}", guild.getId());
guild.updateCommands().addCommands(commands.values().stream().map(Command::getSlashCommandData).collect(Collectors.toList())).queue();
});
super.onReady(event);
LOGGER.info("Ready!");
}
@Override
public void onSlashCommandInteraction(@NotNull SlashCommandInteractionEvent event) {
String command = event.getName();
event.deferReply().queue();
try {
if (commands.containsKey(command)) {
commands.get(command).execute(event);
} else {
event.getHook().sendMessage("Unexpected command received.").queue();
}
} catch (Exception ex) {
LOGGER.error(ex.getMessage());
event.getHook().sendMessage("An error occurred while processing your command. Please contact your administrator.").queue();
}
super.onSlashCommandInteraction(event);
}
@Override
public void onMessageReceived(@NotNull MessageReceivedEvent event) {
String message = parseMessage(event.getMessage().getContentRaw());
long guildId = event.getGuild().getIdLong();
GuildSettings guildSettings = settings.getGuildSettings().get(guildId);
String model = settings.getGuildSettings().get(guildId).getModel();
if (shouldReply(event)) {
if (openAiService != null) {
LOGGER.trace("{} Sending message: {}", guildId, message);
try {
StringBuilder stringBuilder = new StringBuilder();
if (model.equals("text-ada-001")) {
CompletionRequest request = CompletionRequest.builder()
.model(guildSettings.getModel())
.maxTokens(guildSettings.getMaxTokens())
.user(event.getAuthor().getId())
.prompt(message)
.build();
CompletionResult result = openAiService.createCompletion(request);
result.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
} else if (model.equals("gpt-3.5-turbo")){
ChatMessage chatMessage = new ChatMessage();
chatMessage.setContent(message);
chatMessage.setRole(settings.getOpenAISettings().getDefaultRole());
ChatCompletionRequest request = ChatCompletionRequest.builder()
.model(guildSettings.getModel())
.maxTokens(guildSettings.getMaxTokens())
.user(event.getAuthor().getId())
.messages(List.of(chatMessage))
.build();
ChatCompletionResult result = openAiService.createChatCompletion(request);
result.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
}
if (event.isFromThread()) {
ThreadChannel channel = event.getChannel().asThreadChannel();
channel.sendMessage(stringBuilder.toString()).queue();
} else {
String threadTitle = message;
if (message.length() > 20) {
threadTitle = message.substring(0, 20);
}
event.getMessage().createThreadChannel(threadTitle).queue(threadChannel ->
threadChannel.sendMessage(stringBuilder.toString()).queue());
}
} catch (Exception ex) {
LOGGER.error(ex.getMessage());
event.getMessage().reply("An error occurred while processing your message. Please contact your administrator.").queue();
}
} else {
event.getMessage().reply("OpenAI functionality is not enabled. Please contact an administrator").queue();
}
}
}
private String parseMessage(String input) {
return input.replaceAll("<@.*?>", "").replaceAll(" +", " ").trim();
}
/**
* @param event Message event received
* @return true if the message should be replied to by the bot, otherwise false
*/
private boolean shouldReply(MessageReceivedEvent event) {
boolean shouldReply = false;
if (!event.getAuthor().isBot()) {
// Check if message mentions bot
shouldReply = event.getMessage().getMentions().getMembers().stream().anyMatch(m -> m.getId().equals(jda.getSelfUser().getId()));
// Check if message is a reply
if (!shouldReply) {
shouldReply = event.getMessage().getType().equals(MessageType.INLINE_REPLY) &&
event.getMessage().getReferencedMessage() != null &&
event.getMessage().getReferencedMessage().getAuthor().getId().equals(jda.getSelfUser().getId());
}
// Check if message is from a bot thread
if (!shouldReply) {
shouldReply = event.isFromThread() &&
event.getChannel().asThreadChannel().getOwner() != null &&
Objects.requireNonNull(event.getChannel().asThreadChannel().getOwner()).getId().equals(jda.getSelfUser().getId());
}
}
return shouldReply;
}
}

View File

@@ -1,133 +0,0 @@
package com.bensherriff.siren.audio;
import com.bensherriff.siren.commands.*;
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
import com.bensherriff.siren.settings.GuildSettings;
import com.bensherriff.siren.settings.Settings;
import com.bensherriff.siren.settings.SettingsManager;
import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.entities.Member;
import net.dv8tion.jda.api.entities.channel.concrete.VoiceChannel;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.events.session.ReadyEvent;
import net.dv8tion.jda.api.hooks.ListenerAdapter;
import net.dv8tion.jda.api.managers.AudioManager;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
public class Listener extends ListenerAdapter {
private static final Logger LOGGER = LogManager.getLogger(Listener.class);
private final PlayerManager playerManager;
private final Settings settings;
private final Map<String, Command> commands = new HashMap<>();
private JDA jda;
public Listener(Settings settings) {
this.settings = settings;
this.playerManager = new PlayerManager(this);
this.playerManager.initialize();
commands.put("play", new PlayCommand(this));
commands.put("stop", new StopCommand(this));
commands.put("skip", new SkipCommand(this));
commands.put("volume", new VolumeCommand(this));
commands.put("pause", new PauseCommand(this));
commands.put("resume", new ResumeCommand(this));
}
public PlayerManager getPlayerManager() {
return playerManager;
}
public Settings getSettings() {
return settings;
}
public JDA getJDA() {
return jda;
}
public void setJDA(JDA jda) {
this.jda = jda;
}
public void closeAudioConnection(long guildID) {
Guild guild = jda.getGuildById(guildID);
if (guild != null) {
guild.getAudioManager().closeAudioConnection();
}
}
public void connectToVoiceChannel(String userID, AudioManager audioManager) throws EmptyVoiceChannelException {
if (!audioManager.isConnected()) {
Member member = audioManager.getGuild().getMemberById(userID);
if (member != null) {
if (member.getVoiceState() != null && member.getVoiceState().inAudioChannel()) {
VoiceChannel voiceChannel = Objects.requireNonNull(member.getVoiceState().getChannel()).asVoiceChannel();
LOGGER.debug("Connecting to channel {} in guild {}", voiceChannel.getId(), voiceChannel.getGuild().getId());
audioManager.openAudioConnection(voiceChannel);
} else {
throw new EmptyVoiceChannelException("Member {} is not connected to a voice channel");
}
}
}
}
public synchronized AudioHandler getGuildAudioPlayer(Guild guild) throws IOException {
long guildId = Long.parseLong(guild.getId());
AudioHandler audioHandler;
if (guild.getAudioManager().getSendingHandler() == null) {
LOGGER.info("Creating Audio Handler for guild {}", guildId);
if (!settings.getGuildSettings().containsKey(guildId)) {
settings.getGuildSettings().put(guildId, new GuildSettings());
SettingsManager.write(settings);
}
audioHandler = new AudioHandler(playerManager, guildId);
guild.getAudioManager().setSendingHandler(audioHandler);
} else {
audioHandler = (AudioHandler) guild.getAudioManager().getSendingHandler();
}
return audioHandler;
}
@Override
public void onReady(@NotNull ReadyEvent event) {
super.onReady(event);
jda.getGuilds().forEach(guild -> {
LOGGER.debug("Updating commands for {}", guild.getId());
guild.updateCommands().addCommands(commands.values().stream().map(Command::getSlashCommandData).collect(Collectors.toList())).queue();
});
super.onReady(event);
LOGGER.info("Ready!");
}
@Override
public void onSlashCommandInteraction(@NotNull SlashCommandInteractionEvent event) {
String command = event.getName();
event.deferReply().queue();
try {
if (commands.containsKey(command)) {
commands.get(command).execute(event);
} else {
event.getHook().sendMessage("Unexpected command received.").queue();
}
} catch (Exception ex) {
LOGGER.error(ex.getMessage());
event.getHook().sendMessage("An error occurred while processing your command. Please contact your administrator.").queue();
}
super.onSlashCommandInteraction(event);
}
}

View File

@@ -1,5 +1,6 @@
package com.bensherriff.siren.audio;
import com.bensherriff.siren.Listener;
import com.sedmelluq.discord.lavaplayer.player.DefaultAudioPlayerManager;
import com.sedmelluq.discord.lavaplayer.source.AudioSourceManagers;

View File

@@ -1,6 +1,6 @@
package com.bensherriff.siren.commands;
import com.bensherriff.siren.audio.Listener;
import com.bensherriff.siren.Listener;
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.interactions.commands.build.SlashCommandData;

View File

@@ -1,7 +1,7 @@
package com.bensherriff.siren.commands;
import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.audio.Listener;
import com.bensherriff.siren.Listener;
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.interactions.commands.build.Commands;

View File

@@ -2,7 +2,7 @@ package com.bensherriff.siren.commands;
import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
import com.bensherriff.siren.audio.Listener;
import com.bensherriff.siren.Listener;
import com.sedmelluq.discord.lavaplayer.player.AudioLoadResultHandler;
import com.sedmelluq.discord.lavaplayer.tools.FriendlyException;
import com.sedmelluq.discord.lavaplayer.track.AudioPlaylist;

View File

@@ -1,7 +1,7 @@
package com.bensherriff.siren.commands;
import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.audio.Listener;
import com.bensherriff.siren.Listener;
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.interactions.commands.build.Commands;

View File

@@ -1,7 +1,7 @@
package com.bensherriff.siren.commands;
import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.audio.Listener;
import com.bensherriff.siren.Listener;
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.interactions.commands.build.Commands;

View File

@@ -1,7 +1,7 @@
package com.bensherriff.siren.commands;
import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.audio.Listener;
import com.bensherriff.siren.Listener;
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.interactions.commands.build.Commands;

View File

@@ -1,7 +1,7 @@
package com.bensherriff.siren.commands;
import com.bensherriff.siren.audio.AudioHandler;
import com.bensherriff.siren.audio.Listener;
import com.bensherriff.siren.Listener;
import com.bensherriff.siren.settings.SettingsManager;
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;

View File

@@ -7,6 +7,9 @@ public class GuildSettings {
private String prefix = "!";
private int volume = 100;
private String model = "text-ada-001";
private int maxTokens = 100;
public String getPrefix() {
return prefix;
}
@@ -22,4 +25,20 @@ public class GuildSettings {
public void setVolume(int volume) throws IOException {
this.volume = volume;
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public int getMaxTokens() {
return maxTokens;
}
public void setMaxTokens(int maxTokens) {
this.maxTokens = maxTokens;
}
}

View File

@@ -0,0 +1,52 @@
package com.bensherriff.siren.settings;
import java.util.Arrays;
import java.util.List;
public class OpenAISettings {
private String token = "";
/**
* One of ['system', 'assistant', 'user']
* System: The system role provides full access to all OpenAI APIs and resources, including access to billing information and account management features. An API key with the system role can perform any action that is allowed by the OpenAI API.
* Assistant: The assistant role is designed for use with virtual assistants or chatbots that interact with users. An API key with the assistant role can perform actions related to natural language processing, such as generating text, answering questions, or completing tasks. The assistant role is more limited than the system role, but it still provides access to powerful language models such as GPT-3.
* User: The user role provides limited access to specific OpenAI APIs and resources. An API key with the user role can only perform actions related to specific use cases, such as accessing a particular language model or dataset. The user role is more restricted than the assistant or system roles, but it is suitable for many common use cases.
*/
private String defaultRole = "user";
/**
* In milliseconds
*/
private long timeout = 10000;
private final List<String> availableModels = Arrays.asList("gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-ada-001", "text-davinci-001", "text-davinci-002", "text-davinci-003");
public String getToken() {
return token;
}
public void setToken(String token) {
this.token = token;
}
public List<String> getAvailableModels() {
return availableModels;
}
public String getDefaultRole() {
return defaultRole;
}
public void setDefaultRole(String defaultRole) {
this.defaultRole = defaultRole;
}
public long getTimeout() {
return timeout;
}
public void setTimeout(long timeout) {
this.timeout = timeout;
}
}

View File

@@ -7,8 +7,8 @@ public class Settings {
private String token = "";
private String owner = "";
private Map<Long, GuildSettings> guildSettings = new HashMap<>();
private OpenAISettings openAISettings = new OpenAISettings();
public String getToken() {
return token;
@@ -33,4 +33,13 @@ private Map<Long, GuildSettings> guildSettings = new HashMap<>();
public void setGuildSettings(Map<Long, GuildSettings> guildSettings) {
this.guildSettings = guildSettings;
}
public OpenAISettings getOpenAISettings() {
return openAISettings;
}
public void setOpenAISettings(OpenAISettings openAISettings) {
this.openAISettings = openAISettings;
}
}

View File

@@ -23,10 +23,6 @@ public class SettingsManager {
}
public static Settings load(String path) throws IOException {
// If settings is not available, create new default settings
// try(InputStream inputStream = Thread.currentThread().getContextClassLoader().getResourceAsStream("settings.json")) {
// return mapper.readValue(inputStream, Settings.class);
// }
File file = new File(path);
if (!file.exists()) {
LOGGER.warn("Settings file does not exist, creating new file at: {}", file.getPath());