From 7e0a2a4e6433ef5d16f691cab88671f1576388c1 Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Sun, 16 Apr 2023 09:25:35 -0400 Subject: [PATCH] v0.1.17 Added image generation support --- .version | 2 +- .../java/com/bensherriff/siren/Listener.java | 26 ++++++------ .../com/bensherriff/siren/ai/ImageSize.java | 17 ++++++++ .../siren/{openai => ai}/Model.java | 8 ++-- .../bensherriff/siren/{openai => ai}/NLP.java | 2 +- .../siren/{openai => ai}/OpenAIManager.java | 17 ++++++-- .../siren/{openai => ai}/Role.java | 4 +- .../siren/commands/ImageCommand.java | 41 +++++++++++++++++++ .../siren/database/DatabaseConnection.java | 1 - .../siren/database/DatabaseManager.java | 2 +- .../siren/settings/GuildSettings.java | 4 +- .../siren/settings/OpenAISettings.java | 2 +- 12 files changed, 97 insertions(+), 29 deletions(-) create mode 100644 src/main/java/com/bensherriff/siren/ai/ImageSize.java rename src/main/java/com/bensherriff/siren/{openai => ai}/Model.java (67%) rename src/main/java/com/bensherriff/siren/{openai => ai}/NLP.java (98%) rename src/main/java/com/bensherriff/siren/{openai => ai}/OpenAIManager.java (94%) rename src/main/java/com/bensherriff/siren/{openai => ai}/Role.java (74%) create mode 100644 src/main/java/com/bensherriff/siren/commands/ImageCommand.java diff --git a/.version b/.version index d649612..0ad30f2 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -export SIREN_VERSION=0.1.16 \ No newline at end of file +export SIREN_VERSION=0.1.17 \ No newline at end of file diff --git a/src/main/java/com/bensherriff/siren/Listener.java b/src/main/java/com/bensherriff/siren/Listener.java index e781384..045143c 100644 --- a/src/main/java/com/bensherriff/siren/Listener.java +++ b/src/main/java/com/bensherriff/siren/Listener.java @@ -5,7 +5,7 @@ import com.bensherriff.siren.audio.PlayerManager; import com.bensherriff.siren.commands.*; import com.bensherriff.siren.database.DatabaseManager; import com.bensherriff.siren.exceptions.EmptyVoiceChannelException; -import com.bensherriff.siren.openai.OpenAIManager; +import com.bensherriff.siren.ai.OpenAIManager; import com.bensherriff.siren.settings.GuildSettings; import com.bensherriff.siren.settings.Settings; import com.bensherriff.siren.settings.SettingsManager; @@ -57,6 +57,10 @@ public class Listener extends ListenerAdapter { return settings; } + public String getOwner() { + return owner; + } + public JDA getJDA() { return jda; } @@ -65,8 +69,8 @@ public class Listener extends ListenerAdapter { this.jda = jda; } - public String getOwner() { - return owner; + public OpenAIManager getOpenAIManager() { + return openAIManager; } public void initialize() { @@ -75,15 +79,6 @@ public class Listener extends ListenerAdapter { this.openAIManager = new OpenAIManager(this); DatabaseManager.createTables(); - - commands.put("play", new PlayCommand(this)); - commands.put("stop", new StopCommand(this)); - commands.put("skip", new SkipCommand(this)); - commands.put("volume", new VolumeCommand(this)); - commands.put("pause", new PauseCommand(this)); - commands.put("resume", new ResumeCommand(this)); - commands.put("help", new PauseCommand(this)); - } public void closeAudioConnection(long guildID) { @@ -130,6 +125,13 @@ public class Listener extends ListenerAdapter { @Override public void onReady(@NotNull ReadyEvent event) { super.onReady(event); + commands.put("play", new PlayCommand(this)); + commands.put("stop", new StopCommand(this)); + commands.put("skip", new SkipCommand(this)); + commands.put("volume", new VolumeCommand(this)); + commands.put("pause", new PauseCommand(this)); + commands.put("resume", new ResumeCommand(this)); + commands.put("image", new ImageCommand(this)); jda.getGuilds().forEach(guild -> executor.execute(() -> { LOGGER.debug("Updating commands for {}", guild.getId()); guild.updateCommands().addCommands( diff --git a/src/main/java/com/bensherriff/siren/ai/ImageSize.java b/src/main/java/com/bensherriff/siren/ai/ImageSize.java new file mode 100644 index 0000000..0540580 --- /dev/null +++ b/src/main/java/com/bensherriff/siren/ai/ImageSize.java @@ -0,0 +1,17 @@ +package com.bensherriff.siren.ai; + +public enum ImageSize { + SMALL("256x256"), + MEDIUM("512x512"), + LARGE("1024x1024"); + + private final String size; + + ImageSize(String size) { + this.size = size; + } + + public String getSize() { + return size; + } +} diff --git a/src/main/java/com/bensherriff/siren/openai/Model.java b/src/main/java/com/bensherriff/siren/ai/Model.java similarity index 67% rename from src/main/java/com/bensherriff/siren/openai/Model.java rename to src/main/java/com/bensherriff/siren/ai/Model.java index d7ca1b1..a939521 100644 --- a/src/main/java/com/bensherriff/siren/openai/Model.java +++ b/src/main/java/com/bensherriff/siren/ai/Model.java @@ -1,4 +1,4 @@ -package com.bensherriff.siren.openai; +package com.bensherriff.siren.ai; public enum Model { DAVINCI_3("text-davinci-003"), @@ -7,15 +7,13 @@ public enum Model { BABBAGE_1("text-babbage-001"), ADA_1("text-ada-001"), GPT_4("gpt-4"), - GPT_4_0314("gpt-4-0314"), GPT_4_32K("gpt-4-32k"), - GPT_4_32K_0314("gpt-4-32k-0314"), GPT_35_TURBO("gpt-3.5-turbo"), - GPT_35_TURBO_0301("gpt-3.5-turbo-0301"), + ADA_EMBEDDING_2("text-embedding-ada-002") ; private final String name; - private Model(String name) { + Model(String name) { this.name = name; } diff --git a/src/main/java/com/bensherriff/siren/openai/NLP.java b/src/main/java/com/bensherriff/siren/ai/NLP.java similarity index 98% rename from src/main/java/com/bensherriff/siren/openai/NLP.java rename to src/main/java/com/bensherriff/siren/ai/NLP.java index b792201..0572bcb 100644 --- a/src/main/java/com/bensherriff/siren/openai/NLP.java +++ b/src/main/java/com/bensherriff/siren/ai/NLP.java @@ -1,4 +1,4 @@ -package com.bensherriff.siren.openai; +package com.bensherriff.siren.ai; import edu.stanford.nlp.ling.CoreAnnotations; import edu.stanford.nlp.ling.CoreLabel; diff --git a/src/main/java/com/bensherriff/siren/openai/OpenAIManager.java b/src/main/java/com/bensherriff/siren/ai/OpenAIManager.java similarity index 94% rename from src/main/java/com/bensherriff/siren/openai/OpenAIManager.java rename to src/main/java/com/bensherriff/siren/ai/OpenAIManager.java index 311c2c9..a9be659 100644 --- a/src/main/java/com/bensherriff/siren/openai/OpenAIManager.java +++ b/src/main/java/com/bensherriff/siren/ai/OpenAIManager.java @@ -1,4 +1,4 @@ -package com.bensherriff.siren.openai; +package com.bensherriff.siren.ai; import com.bensherriff.siren.Listener; import com.bensherriff.siren.database.DatabaseManager; @@ -13,6 +13,8 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult; import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.embedding.Embedding; import com.theokanning.openai.embedding.EmbeddingRequest; +import com.theokanning.openai.image.CreateImageRequest; +import com.theokanning.openai.image.ImageResult; import com.theokanning.openai.service.OpenAiService; import net.dv8tion.jda.api.JDA; import net.dv8tion.jda.api.entities.Message; @@ -90,7 +92,7 @@ public class OpenAIManager { CompletionResult completionResult = openAiService.createCompletion(completionRequest); completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim())); } - case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> { + case GPT_4, GPT_4_32K, GPT_35_TURBO -> { ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event); ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest); chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim())); @@ -117,12 +119,21 @@ public class OpenAIManager { private EmbeddingRequest createEmbeddingRequest(List chatMessages, MessageReceivedEvent event) { GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong()); return EmbeddingRequest.builder() - .model(guildSettings.getModel().getName()) + .model(Model.ADA_EMBEDDING_2.getName()) .user(event.getAuthor().getId()) .input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList())) .build(); } + public ImageResult createImage(String prompt, int count, ImageSize imageSize) { + LOGGER.trace("Generating {} image(s) of size {} with prompt <{}>", count, imageSize.getSize(), prompt); + return openAiService.createImage(CreateImageRequest.builder() + .prompt(prompt) + .size(imageSize.getSize()) + .n(count <= 0 ? 1 : Math.min(count, 3)) + .build()); + } + private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException { GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong()); List chatMessages = new ArrayList<>(); diff --git a/src/main/java/com/bensherriff/siren/openai/Role.java b/src/main/java/com/bensherriff/siren/ai/Role.java similarity index 74% rename from src/main/java/com/bensherriff/siren/openai/Role.java rename to src/main/java/com/bensherriff/siren/ai/Role.java index 0121c72..e777072 100644 --- a/src/main/java/com/bensherriff/siren/openai/Role.java +++ b/src/main/java/com/bensherriff/siren/ai/Role.java @@ -1,4 +1,4 @@ -package com.bensherriff.siren.openai; +package com.bensherriff.siren.ai; public enum Role { SYSTEM("system"), @@ -7,7 +7,7 @@ public enum Role { ; private final String name; - private Role(String name) { + Role(String name) { this.name = name; } diff --git a/src/main/java/com/bensherriff/siren/commands/ImageCommand.java b/src/main/java/com/bensherriff/siren/commands/ImageCommand.java new file mode 100644 index 0000000..25cff0d --- /dev/null +++ b/src/main/java/com/bensherriff/siren/commands/ImageCommand.java @@ -0,0 +1,41 @@ +package com.bensherriff.siren.commands; + +import com.bensherriff.siren.Listener; +import com.bensherriff.siren.ai.ImageSize; +import com.theokanning.openai.image.ImageResult; +import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent; +import net.dv8tion.jda.api.interactions.commands.OptionMapping; +import net.dv8tion.jda.api.interactions.commands.OptionType; +import net.dv8tion.jda.api.interactions.commands.build.Commands; + +import java.io.IOException; +import java.util.Objects; + +public class ImageCommand extends Command { + + public ImageCommand(Listener listener) { + super(listener); + slashCommandData = Commands.slash("image", "Generate an image using DALL-E") + .addOption(OptionType.STRING, "prompt", "The prompt for image generation", true) + .addOption(OptionType.INTEGER, "count", "The number of images to be generated", false); + } + + @Override + public void execute(SlashCommandInteractionEvent event) throws IOException { + if (event.getUser().getId().equals(listener.getSettings().getOwner())) { + String prompt = Objects.requireNonNull(event.getOption("prompt")).getAsString(); + int count = 1; + OptionMapping countOption = event.getOption("count"); + if (countOption != null) { + count = countOption.getAsInt(); + } + + ImageResult result = listener.getOpenAIManager().createImage(prompt, count, ImageSize.SMALL); + StringBuilder responseURLS = new StringBuilder(); + result.getData().forEach(image -> responseURLS.append(image.getUrl()).append("\n")); + event.getHook().sendMessage(responseURLS.toString()).queue(); + } else { + event.getHook().sendMessage("This command is currently only available to " + listener.getOwner() + ".").queue(); + } + } +} diff --git a/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java b/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java index 46bf275..e773237 100644 --- a/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java +++ b/src/main/java/com/bensherriff/siren/database/DatabaseConnection.java @@ -1,6 +1,5 @@ package com.bensherriff.siren.database; -import com.bensherriff.siren.openai.OpenAIManager; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; diff --git a/src/main/java/com/bensherriff/siren/database/DatabaseManager.java b/src/main/java/com/bensherriff/siren/database/DatabaseManager.java index 2e4f075..a06e690 100644 --- a/src/main/java/com/bensherriff/siren/database/DatabaseManager.java +++ b/src/main/java/com/bensherriff/siren/database/DatabaseManager.java @@ -58,7 +58,7 @@ public class DatabaseManager { return; } } - LOGGER.debug("Successfully created database tables"); + LOGGER.debug("Databases initialized"); } private static boolean createTable(String tableName) { diff --git a/src/main/java/com/bensherriff/siren/settings/GuildSettings.java b/src/main/java/com/bensherriff/siren/settings/GuildSettings.java index ed3474a..febae94 100644 --- a/src/main/java/com/bensherriff/siren/settings/GuildSettings.java +++ b/src/main/java/com/bensherriff/siren/settings/GuildSettings.java @@ -1,6 +1,6 @@ package com.bensherriff.siren.settings; -import com.bensherriff.siren.openai.Model; +import com.bensherriff.siren.ai.Model; import java.io.IOException; @@ -30,7 +30,7 @@ public class GuildSettings { return volume; } - public void setVolume(int volume) throws IOException { + public void setVolume(int volume) { this.volume = volume; } diff --git a/src/main/java/com/bensherriff/siren/settings/OpenAISettings.java b/src/main/java/com/bensherriff/siren/settings/OpenAISettings.java index b5cce34..98265f3 100644 --- a/src/main/java/com/bensherriff/siren/settings/OpenAISettings.java +++ b/src/main/java/com/bensherriff/siren/settings/OpenAISettings.java @@ -1,6 +1,6 @@ package com.bensherriff.siren.settings; -import com.bensherriff.siren.openai.Role; +import com.bensherriff.siren.ai.Role; public class OpenAISettings {