v0.1.17 Added image generation support
This commit is contained in:
@@ -5,7 +5,7 @@ import com.bensherriff.siren.audio.PlayerManager;
|
||||
import com.bensherriff.siren.commands.*;
|
||||
import com.bensherriff.siren.database.DatabaseManager;
|
||||
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
|
||||
import com.bensherriff.siren.openai.OpenAIManager;
|
||||
import com.bensherriff.siren.ai.OpenAIManager;
|
||||
import com.bensherriff.siren.settings.GuildSettings;
|
||||
import com.bensherriff.siren.settings.Settings;
|
||||
import com.bensherriff.siren.settings.SettingsManager;
|
||||
@@ -57,6 +57,10 @@ public class Listener extends ListenerAdapter {
|
||||
return settings;
|
||||
}
|
||||
|
||||
public String getOwner() {
|
||||
return owner;
|
||||
}
|
||||
|
||||
public JDA getJDA() {
|
||||
return jda;
|
||||
}
|
||||
@@ -65,8 +69,8 @@ public class Listener extends ListenerAdapter {
|
||||
this.jda = jda;
|
||||
}
|
||||
|
||||
public String getOwner() {
|
||||
return owner;
|
||||
public OpenAIManager getOpenAIManager() {
|
||||
return openAIManager;
|
||||
}
|
||||
|
||||
public void initialize() {
|
||||
@@ -75,15 +79,6 @@ public class Listener extends ListenerAdapter {
|
||||
this.openAIManager = new OpenAIManager(this);
|
||||
|
||||
DatabaseManager.createTables();
|
||||
|
||||
commands.put("play", new PlayCommand(this));
|
||||
commands.put("stop", new StopCommand(this));
|
||||
commands.put("skip", new SkipCommand(this));
|
||||
commands.put("volume", new VolumeCommand(this));
|
||||
commands.put("pause", new PauseCommand(this));
|
||||
commands.put("resume", new ResumeCommand(this));
|
||||
commands.put("help", new PauseCommand(this));
|
||||
|
||||
}
|
||||
|
||||
public void closeAudioConnection(long guildID) {
|
||||
@@ -130,6 +125,13 @@ public class Listener extends ListenerAdapter {
|
||||
@Override
|
||||
public void onReady(@NotNull ReadyEvent event) {
|
||||
super.onReady(event);
|
||||
commands.put("play", new PlayCommand(this));
|
||||
commands.put("stop", new StopCommand(this));
|
||||
commands.put("skip", new SkipCommand(this));
|
||||
commands.put("volume", new VolumeCommand(this));
|
||||
commands.put("pause", new PauseCommand(this));
|
||||
commands.put("resume", new ResumeCommand(this));
|
||||
commands.put("image", new ImageCommand(this));
|
||||
jda.getGuilds().forEach(guild -> executor.execute(() -> {
|
||||
LOGGER.debug("Updating commands for {}", guild.getId());
|
||||
guild.updateCommands().addCommands(
|
||||
|
||||
17
src/main/java/com/bensherriff/siren/ai/ImageSize.java
Normal file
17
src/main/java/com/bensherriff/siren/ai/ImageSize.java
Normal file
@@ -0,0 +1,17 @@
|
||||
package com.bensherriff.siren.ai;
|
||||
|
||||
public enum ImageSize {
|
||||
SMALL("256x256"),
|
||||
MEDIUM("512x512"),
|
||||
LARGE("1024x1024");
|
||||
|
||||
private final String size;
|
||||
|
||||
ImageSize(String size) {
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
public String getSize() {
|
||||
return size;
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.bensherriff.siren.openai;
|
||||
package com.bensherriff.siren.ai;
|
||||
|
||||
public enum Model {
|
||||
DAVINCI_3("text-davinci-003"),
|
||||
@@ -7,15 +7,13 @@ public enum Model {
|
||||
BABBAGE_1("text-babbage-001"),
|
||||
ADA_1("text-ada-001"),
|
||||
GPT_4("gpt-4"),
|
||||
GPT_4_0314("gpt-4-0314"),
|
||||
GPT_4_32K("gpt-4-32k"),
|
||||
GPT_4_32K_0314("gpt-4-32k-0314"),
|
||||
GPT_35_TURBO("gpt-3.5-turbo"),
|
||||
GPT_35_TURBO_0301("gpt-3.5-turbo-0301"),
|
||||
ADA_EMBEDDING_2("text-embedding-ada-002")
|
||||
;
|
||||
|
||||
private final String name;
|
||||
private Model(String name) {
|
||||
Model(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.bensherriff.siren.openai;
|
||||
package com.bensherriff.siren.ai;
|
||||
|
||||
import edu.stanford.nlp.ling.CoreAnnotations;
|
||||
import edu.stanford.nlp.ling.CoreLabel;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.bensherriff.siren.openai;
|
||||
package com.bensherriff.siren.ai;
|
||||
|
||||
import com.bensherriff.siren.Listener;
|
||||
import com.bensherriff.siren.database.DatabaseManager;
|
||||
@@ -13,6 +13,8 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||
import com.theokanning.openai.embedding.Embedding;
|
||||
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||
import com.theokanning.openai.image.CreateImageRequest;
|
||||
import com.theokanning.openai.image.ImageResult;
|
||||
import com.theokanning.openai.service.OpenAiService;
|
||||
import net.dv8tion.jda.api.JDA;
|
||||
import net.dv8tion.jda.api.entities.Message;
|
||||
@@ -90,7 +92,7 @@ public class OpenAIManager {
|
||||
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
||||
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
||||
}
|
||||
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
|
||||
case GPT_4, GPT_4_32K, GPT_35_TURBO -> {
|
||||
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
|
||||
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
||||
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
||||
@@ -117,12 +119,21 @@ public class OpenAIManager {
|
||||
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||
return EmbeddingRequest.builder()
|
||||
.model(guildSettings.getModel().getName())
|
||||
.model(Model.ADA_EMBEDDING_2.getName())
|
||||
.user(event.getAuthor().getId())
|
||||
.input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList()))
|
||||
.build();
|
||||
}
|
||||
|
||||
public ImageResult createImage(String prompt, int count, ImageSize imageSize) {
|
||||
LOGGER.trace("Generating {} image(s) of size {} with prompt <{}>", count, imageSize.getSize(), prompt);
|
||||
return openAiService.createImage(CreateImageRequest.builder()
|
||||
.prompt(prompt)
|
||||
.size(imageSize.getSize())
|
||||
.n(count <= 0 ? 1 : Math.min(count, 3))
|
||||
.build());
|
||||
}
|
||||
|
||||
private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException {
|
||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.bensherriff.siren.openai;
|
||||
package com.bensherriff.siren.ai;
|
||||
|
||||
public enum Role {
|
||||
SYSTEM("system"),
|
||||
@@ -7,7 +7,7 @@ public enum Role {
|
||||
;
|
||||
|
||||
private final String name;
|
||||
private Role(String name) {
|
||||
Role(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package com.bensherriff.siren.commands;
|
||||
|
||||
import com.bensherriff.siren.Listener;
|
||||
import com.bensherriff.siren.ai.ImageSize;
|
||||
import com.theokanning.openai.image.ImageResult;
|
||||
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
|
||||
import net.dv8tion.jda.api.interactions.commands.OptionMapping;
|
||||
import net.dv8tion.jda.api.interactions.commands.OptionType;
|
||||
import net.dv8tion.jda.api.interactions.commands.build.Commands;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ImageCommand extends Command {
|
||||
|
||||
public ImageCommand(Listener listener) {
|
||||
super(listener);
|
||||
slashCommandData = Commands.slash("image", "Generate an image using DALL-E")
|
||||
.addOption(OptionType.STRING, "prompt", "The prompt for image generation", true)
|
||||
.addOption(OptionType.INTEGER, "count", "The number of images to be generated", false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(SlashCommandInteractionEvent event) throws IOException {
|
||||
if (event.getUser().getId().equals(listener.getSettings().getOwner())) {
|
||||
String prompt = Objects.requireNonNull(event.getOption("prompt")).getAsString();
|
||||
int count = 1;
|
||||
OptionMapping countOption = event.getOption("count");
|
||||
if (countOption != null) {
|
||||
count = countOption.getAsInt();
|
||||
}
|
||||
|
||||
ImageResult result = listener.getOpenAIManager().createImage(prompt, count, ImageSize.SMALL);
|
||||
StringBuilder responseURLS = new StringBuilder();
|
||||
result.getData().forEach(image -> responseURLS.append(image.getUrl()).append("\n"));
|
||||
event.getHook().sendMessage(responseURLS.toString()).queue();
|
||||
} else {
|
||||
event.getHook().sendMessage("This command is currently only available to " + listener.getOwner() + ".").queue();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.bensherriff.siren.database;
|
||||
|
||||
import com.bensherriff.siren.openai.OpenAIManager;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ public class DatabaseManager {
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOGGER.debug("Successfully created database tables");
|
||||
LOGGER.debug("Databases initialized");
|
||||
}
|
||||
|
||||
private static boolean createTable(String tableName) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.bensherriff.siren.settings;
|
||||
|
||||
import com.bensherriff.siren.openai.Model;
|
||||
import com.bensherriff.siren.ai.Model;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@@ -30,7 +30,7 @@ public class GuildSettings {
|
||||
return volume;
|
||||
}
|
||||
|
||||
public void setVolume(int volume) throws IOException {
|
||||
public void setVolume(int volume) {
|
||||
this.volume = volume;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.bensherriff.siren.settings;
|
||||
|
||||
import com.bensherriff.siren.openai.Role;
|
||||
import com.bensherriff.siren.ai.Role;
|
||||
|
||||
public class OpenAISettings {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user