v0.1.17 Added image generation support

This commit is contained in:
2023-04-16 09:25:35 -04:00
parent e6fb86c4e0
commit 7e0a2a4e64
12 changed files with 97 additions and 29 deletions

View File

@@ -1 +1 @@
export SIREN_VERSION=0.1.16
export SIREN_VERSION=0.1.17

View File

@@ -5,7 +5,7 @@ import com.bensherriff.siren.audio.PlayerManager;
import com.bensherriff.siren.commands.*;
import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
import com.bensherriff.siren.openai.OpenAIManager;
import com.bensherriff.siren.ai.OpenAIManager;
import com.bensherriff.siren.settings.GuildSettings;
import com.bensherriff.siren.settings.Settings;
import com.bensherriff.siren.settings.SettingsManager;
@@ -57,6 +57,10 @@ public class Listener extends ListenerAdapter {
return settings;
}
public String getOwner() {
return owner;
}
public JDA getJDA() {
return jda;
}
@@ -65,8 +69,8 @@ public class Listener extends ListenerAdapter {
this.jda = jda;
}
public String getOwner() {
return owner;
public OpenAIManager getOpenAIManager() {
return openAIManager;
}
public void initialize() {
@@ -75,15 +79,6 @@ public class Listener extends ListenerAdapter {
this.openAIManager = new OpenAIManager(this);
DatabaseManager.createTables();
commands.put("play", new PlayCommand(this));
commands.put("stop", new StopCommand(this));
commands.put("skip", new SkipCommand(this));
commands.put("volume", new VolumeCommand(this));
commands.put("pause", new PauseCommand(this));
commands.put("resume", new ResumeCommand(this));
commands.put("help", new PauseCommand(this));
}
public void closeAudioConnection(long guildID) {
@@ -130,6 +125,13 @@ public class Listener extends ListenerAdapter {
@Override
public void onReady(@NotNull ReadyEvent event) {
super.onReady(event);
commands.put("play", new PlayCommand(this));
commands.put("stop", new StopCommand(this));
commands.put("skip", new SkipCommand(this));
commands.put("volume", new VolumeCommand(this));
commands.put("pause", new PauseCommand(this));
commands.put("resume", new ResumeCommand(this));
commands.put("image", new ImageCommand(this));
jda.getGuilds().forEach(guild -> executor.execute(() -> {
LOGGER.debug("Updating commands for {}", guild.getId());
guild.updateCommands().addCommands(

View File

@@ -0,0 +1,17 @@
package com.bensherriff.siren.ai;
public enum ImageSize {
SMALL("256x256"),
MEDIUM("512x512"),
LARGE("1024x1024");
private final String size;
ImageSize(String size) {
this.size = size;
}
public String getSize() {
return size;
}
}

View File

@@ -1,4 +1,4 @@
package com.bensherriff.siren.openai;
package com.bensherriff.siren.ai;
public enum Model {
DAVINCI_3("text-davinci-003"),
@@ -7,15 +7,13 @@ public enum Model {
BABBAGE_1("text-babbage-001"),
ADA_1("text-ada-001"),
GPT_4("gpt-4"),
GPT_4_0314("gpt-4-0314"),
GPT_4_32K("gpt-4-32k"),
GPT_4_32K_0314("gpt-4-32k-0314"),
GPT_35_TURBO("gpt-3.5-turbo"),
GPT_35_TURBO_0301("gpt-3.5-turbo-0301"),
ADA_EMBEDDING_2("text-embedding-ada-002")
;
private final String name;
private Model(String name) {
Model(String name) {
this.name = name;
}

View File

@@ -1,4 +1,4 @@
package com.bensherriff.siren.openai;
package com.bensherriff.siren.ai;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;

View File

@@ -1,4 +1,4 @@
package com.bensherriff.siren.openai;
package com.bensherriff.siren.ai;
import com.bensherriff.siren.Listener;
import com.bensherriff.siren.database.DatabaseManager;
@@ -13,6 +13,8 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.embedding.Embedding;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.image.CreateImageRequest;
import com.theokanning.openai.image.ImageResult;
import com.theokanning.openai.service.OpenAiService;
import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.entities.Message;
@@ -90,7 +92,7 @@ public class OpenAIManager {
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
}
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
case GPT_4, GPT_4_32K, GPT_35_TURBO -> {
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
@@ -117,12 +119,21 @@ public class OpenAIManager {
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
return EmbeddingRequest.builder()
.model(guildSettings.getModel().getName())
.model(Model.ADA_EMBEDDING_2.getName())
.user(event.getAuthor().getId())
.input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList()))
.build();
}
public ImageResult createImage(String prompt, int count, ImageSize imageSize) {
LOGGER.trace("Generating {} image(s) of size {} with prompt <{}>", count, imageSize.getSize(), prompt);
return openAiService.createImage(CreateImageRequest.builder()
.prompt(prompt)
.size(imageSize.getSize())
.n(count <= 0 ? 1 : Math.min(count, 3))
.build());
}
private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
List<ChatMessage> chatMessages = new ArrayList<>();

View File

@@ -1,4 +1,4 @@
package com.bensherriff.siren.openai;
package com.bensherriff.siren.ai;
public enum Role {
SYSTEM("system"),
@@ -7,7 +7,7 @@ public enum Role {
;
private final String name;
private Role(String name) {
Role(String name) {
this.name = name;
}

View File

@@ -0,0 +1,41 @@
package com.bensherriff.siren.commands;
import com.bensherriff.siren.Listener;
import com.bensherriff.siren.ai.ImageSize;
import com.theokanning.openai.image.ImageResult;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.interactions.commands.OptionMapping;
import net.dv8tion.jda.api.interactions.commands.OptionType;
import net.dv8tion.jda.api.interactions.commands.build.Commands;
import java.io.IOException;
import java.util.Objects;
public class ImageCommand extends Command {
public ImageCommand(Listener listener) {
super(listener);
slashCommandData = Commands.slash("image", "Generate an image using DALL-E")
.addOption(OptionType.STRING, "prompt", "The prompt for image generation", true)
.addOption(OptionType.INTEGER, "count", "The number of images to be generated", false);
}
@Override
public void execute(SlashCommandInteractionEvent event) throws IOException {
if (event.getUser().getId().equals(listener.getSettings().getOwner())) {
String prompt = Objects.requireNonNull(event.getOption("prompt")).getAsString();
int count = 1;
OptionMapping countOption = event.getOption("count");
if (countOption != null) {
count = countOption.getAsInt();
}
ImageResult result = listener.getOpenAIManager().createImage(prompt, count, ImageSize.SMALL);
StringBuilder responseURLS = new StringBuilder();
result.getData().forEach(image -> responseURLS.append(image.getUrl()).append("\n"));
event.getHook().sendMessage(responseURLS.toString()).queue();
} else {
event.getHook().sendMessage("This command is currently only available to " + listener.getOwner() + ".").queue();
}
}
}

View File

@@ -1,6 +1,5 @@
package com.bensherriff.siren.database;
import com.bensherriff.siren.openai.OpenAIManager;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

View File

@@ -58,7 +58,7 @@ public class DatabaseManager {
return;
}
}
LOGGER.debug("Successfully created database tables");
LOGGER.debug("Databases initialized");
}
private static boolean createTable(String tableName) {

View File

@@ -1,6 +1,6 @@
package com.bensherriff.siren.settings;
import com.bensherriff.siren.openai.Model;
import com.bensherriff.siren.ai.Model;
import java.io.IOException;
@@ -30,7 +30,7 @@ public class GuildSettings {
return volume;
}
public void setVolume(int volume) throws IOException {
public void setVolume(int volume) {
this.volume = volume;
}

View File

@@ -1,6 +1,6 @@
package com.bensherriff.siren.settings;
import com.bensherriff.siren.openai.Role;
import com.bensherriff.siren.ai.Role;
public class OpenAISettings {