v0.1.17 Added image generation support
This commit is contained in:
@@ -5,7 +5,7 @@ import com.bensherriff.siren.audio.PlayerManager;
|
|||||||
import com.bensherriff.siren.commands.*;
|
import com.bensherriff.siren.commands.*;
|
||||||
import com.bensherriff.siren.database.DatabaseManager;
|
import com.bensherriff.siren.database.DatabaseManager;
|
||||||
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
|
import com.bensherriff.siren.exceptions.EmptyVoiceChannelException;
|
||||||
import com.bensherriff.siren.openai.OpenAIManager;
|
import com.bensherriff.siren.ai.OpenAIManager;
|
||||||
import com.bensherriff.siren.settings.GuildSettings;
|
import com.bensherriff.siren.settings.GuildSettings;
|
||||||
import com.bensherriff.siren.settings.Settings;
|
import com.bensherriff.siren.settings.Settings;
|
||||||
import com.bensherriff.siren.settings.SettingsManager;
|
import com.bensherriff.siren.settings.SettingsManager;
|
||||||
@@ -57,6 +57,10 @@ public class Listener extends ListenerAdapter {
|
|||||||
return settings;
|
return settings;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public String getOwner() {
|
||||||
|
return owner;
|
||||||
|
}
|
||||||
|
|
||||||
public JDA getJDA() {
|
public JDA getJDA() {
|
||||||
return jda;
|
return jda;
|
||||||
}
|
}
|
||||||
@@ -65,8 +69,8 @@ public class Listener extends ListenerAdapter {
|
|||||||
this.jda = jda;
|
this.jda = jda;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getOwner() {
|
public OpenAIManager getOpenAIManager() {
|
||||||
return owner;
|
return openAIManager;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void initialize() {
|
public void initialize() {
|
||||||
@@ -75,15 +79,6 @@ public class Listener extends ListenerAdapter {
|
|||||||
this.openAIManager = new OpenAIManager(this);
|
this.openAIManager = new OpenAIManager(this);
|
||||||
|
|
||||||
DatabaseManager.createTables();
|
DatabaseManager.createTables();
|
||||||
|
|
||||||
commands.put("play", new PlayCommand(this));
|
|
||||||
commands.put("stop", new StopCommand(this));
|
|
||||||
commands.put("skip", new SkipCommand(this));
|
|
||||||
commands.put("volume", new VolumeCommand(this));
|
|
||||||
commands.put("pause", new PauseCommand(this));
|
|
||||||
commands.put("resume", new ResumeCommand(this));
|
|
||||||
commands.put("help", new PauseCommand(this));
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void closeAudioConnection(long guildID) {
|
public void closeAudioConnection(long guildID) {
|
||||||
@@ -130,6 +125,13 @@ public class Listener extends ListenerAdapter {
|
|||||||
@Override
|
@Override
|
||||||
public void onReady(@NotNull ReadyEvent event) {
|
public void onReady(@NotNull ReadyEvent event) {
|
||||||
super.onReady(event);
|
super.onReady(event);
|
||||||
|
commands.put("play", new PlayCommand(this));
|
||||||
|
commands.put("stop", new StopCommand(this));
|
||||||
|
commands.put("skip", new SkipCommand(this));
|
||||||
|
commands.put("volume", new VolumeCommand(this));
|
||||||
|
commands.put("pause", new PauseCommand(this));
|
||||||
|
commands.put("resume", new ResumeCommand(this));
|
||||||
|
commands.put("image", new ImageCommand(this));
|
||||||
jda.getGuilds().forEach(guild -> executor.execute(() -> {
|
jda.getGuilds().forEach(guild -> executor.execute(() -> {
|
||||||
LOGGER.debug("Updating commands for {}", guild.getId());
|
LOGGER.debug("Updating commands for {}", guild.getId());
|
||||||
guild.updateCommands().addCommands(
|
guild.updateCommands().addCommands(
|
||||||
|
|||||||
17
src/main/java/com/bensherriff/siren/ai/ImageSize.java
Normal file
17
src/main/java/com/bensherriff/siren/ai/ImageSize.java
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package com.bensherriff.siren.ai;
|
||||||
|
|
||||||
|
public enum ImageSize {
|
||||||
|
SMALL("256x256"),
|
||||||
|
MEDIUM("512x512"),
|
||||||
|
LARGE("1024x1024");
|
||||||
|
|
||||||
|
private final String size;
|
||||||
|
|
||||||
|
ImageSize(String size) {
|
||||||
|
this.size = size;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getSize() {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.bensherriff.siren.openai;
|
package com.bensherriff.siren.ai;
|
||||||
|
|
||||||
public enum Model {
|
public enum Model {
|
||||||
DAVINCI_3("text-davinci-003"),
|
DAVINCI_3("text-davinci-003"),
|
||||||
@@ -7,15 +7,13 @@ public enum Model {
|
|||||||
BABBAGE_1("text-babbage-001"),
|
BABBAGE_1("text-babbage-001"),
|
||||||
ADA_1("text-ada-001"),
|
ADA_1("text-ada-001"),
|
||||||
GPT_4("gpt-4"),
|
GPT_4("gpt-4"),
|
||||||
GPT_4_0314("gpt-4-0314"),
|
|
||||||
GPT_4_32K("gpt-4-32k"),
|
GPT_4_32K("gpt-4-32k"),
|
||||||
GPT_4_32K_0314("gpt-4-32k-0314"),
|
|
||||||
GPT_35_TURBO("gpt-3.5-turbo"),
|
GPT_35_TURBO("gpt-3.5-turbo"),
|
||||||
GPT_35_TURBO_0301("gpt-3.5-turbo-0301"),
|
ADA_EMBEDDING_2("text-embedding-ada-002")
|
||||||
;
|
;
|
||||||
|
|
||||||
private final String name;
|
private final String name;
|
||||||
private Model(String name) {
|
Model(String name) {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.bensherriff.siren.openai;
|
package com.bensherriff.siren.ai;
|
||||||
|
|
||||||
import edu.stanford.nlp.ling.CoreAnnotations;
|
import edu.stanford.nlp.ling.CoreAnnotations;
|
||||||
import edu.stanford.nlp.ling.CoreLabel;
|
import edu.stanford.nlp.ling.CoreLabel;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.bensherriff.siren.openai;
|
package com.bensherriff.siren.ai;
|
||||||
|
|
||||||
import com.bensherriff.siren.Listener;
|
import com.bensherriff.siren.Listener;
|
||||||
import com.bensherriff.siren.database.DatabaseManager;
|
import com.bensherriff.siren.database.DatabaseManager;
|
||||||
@@ -13,6 +13,8 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
|||||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||||
import com.theokanning.openai.embedding.Embedding;
|
import com.theokanning.openai.embedding.Embedding;
|
||||||
import com.theokanning.openai.embedding.EmbeddingRequest;
|
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||||
|
import com.theokanning.openai.image.CreateImageRequest;
|
||||||
|
import com.theokanning.openai.image.ImageResult;
|
||||||
import com.theokanning.openai.service.OpenAiService;
|
import com.theokanning.openai.service.OpenAiService;
|
||||||
import net.dv8tion.jda.api.JDA;
|
import net.dv8tion.jda.api.JDA;
|
||||||
import net.dv8tion.jda.api.entities.Message;
|
import net.dv8tion.jda.api.entities.Message;
|
||||||
@@ -90,7 +92,7 @@ public class OpenAIManager {
|
|||||||
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
CompletionResult completionResult = openAiService.createCompletion(completionRequest);
|
||||||
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
completionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getText().trim()));
|
||||||
}
|
}
|
||||||
case GPT_4, GPT_4_0314, GPT_4_32K, GPT_4_32K_0314, GPT_35_TURBO, GPT_35_TURBO_0301 -> {
|
case GPT_4, GPT_4_32K, GPT_35_TURBO -> {
|
||||||
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
|
ChatCompletionRequest chatCompletionRequest = createCompletionRequest(chatMessage, event);
|
||||||
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
ChatCompletionResult chatCompletionResult = openAiService.createChatCompletion(chatCompletionRequest);
|
||||||
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
chatCompletionResult.getChoices().forEach(choice -> stringBuilder.append(choice.getMessage().getContent().trim()));
|
||||||
@@ -117,12 +119,21 @@ public class OpenAIManager {
|
|||||||
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
private EmbeddingRequest createEmbeddingRequest(List<ChatMessage> chatMessages, MessageReceivedEvent event) {
|
||||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||||
return EmbeddingRequest.builder()
|
return EmbeddingRequest.builder()
|
||||||
.model(guildSettings.getModel().getName())
|
.model(Model.ADA_EMBEDDING_2.getName())
|
||||||
.user(event.getAuthor().getId())
|
.user(event.getAuthor().getId())
|
||||||
.input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList()))
|
.input(chatMessages.stream().map(ChatMessage::getContent).collect(Collectors.toList()))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ImageResult createImage(String prompt, int count, ImageSize imageSize) {
|
||||||
|
LOGGER.trace("Generating {} image(s) of size {} with prompt <{}>", count, imageSize.getSize(), prompt);
|
||||||
|
return openAiService.createImage(CreateImageRequest.builder()
|
||||||
|
.prompt(prompt)
|
||||||
|
.size(imageSize.getSize())
|
||||||
|
.n(count <= 0 ? 1 : Math.min(count, 3))
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
|
||||||
private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException {
|
private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException {
|
||||||
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
|
||||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.bensherriff.siren.openai;
|
package com.bensherriff.siren.ai;
|
||||||
|
|
||||||
public enum Role {
|
public enum Role {
|
||||||
SYSTEM("system"),
|
SYSTEM("system"),
|
||||||
@@ -7,7 +7,7 @@ public enum Role {
|
|||||||
;
|
;
|
||||||
|
|
||||||
private final String name;
|
private final String name;
|
||||||
private Role(String name) {
|
Role(String name) {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
package com.bensherriff.siren.commands;
|
||||||
|
|
||||||
|
import com.bensherriff.siren.Listener;
|
||||||
|
import com.bensherriff.siren.ai.ImageSize;
|
||||||
|
import com.theokanning.openai.image.ImageResult;
|
||||||
|
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
|
||||||
|
import net.dv8tion.jda.api.interactions.commands.OptionMapping;
|
||||||
|
import net.dv8tion.jda.api.interactions.commands.OptionType;
|
||||||
|
import net.dv8tion.jda.api.interactions.commands.build.Commands;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public class ImageCommand extends Command {
|
||||||
|
|
||||||
|
public ImageCommand(Listener listener) {
|
||||||
|
super(listener);
|
||||||
|
slashCommandData = Commands.slash("image", "Generate an image using DALL-E")
|
||||||
|
.addOption(OptionType.STRING, "prompt", "The prompt for image generation", true)
|
||||||
|
.addOption(OptionType.INTEGER, "count", "The number of images to be generated", false);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void execute(SlashCommandInteractionEvent event) throws IOException {
|
||||||
|
if (event.getUser().getId().equals(listener.getSettings().getOwner())) {
|
||||||
|
String prompt = Objects.requireNonNull(event.getOption("prompt")).getAsString();
|
||||||
|
int count = 1;
|
||||||
|
OptionMapping countOption = event.getOption("count");
|
||||||
|
if (countOption != null) {
|
||||||
|
count = countOption.getAsInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
ImageResult result = listener.getOpenAIManager().createImage(prompt, count, ImageSize.SMALL);
|
||||||
|
StringBuilder responseURLS = new StringBuilder();
|
||||||
|
result.getData().forEach(image -> responseURLS.append(image.getUrl()).append("\n"));
|
||||||
|
event.getHook().sendMessage(responseURLS.toString()).queue();
|
||||||
|
} else {
|
||||||
|
event.getHook().sendMessage("This command is currently only available to " + listener.getOwner() + ".").queue();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
package com.bensherriff.siren.database;
|
package com.bensherriff.siren.database;
|
||||||
|
|
||||||
import com.bensherriff.siren.openai.OpenAIManager;
|
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ public class DatabaseManager {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LOGGER.debug("Successfully created database tables");
|
LOGGER.debug("Databases initialized");
|
||||||
}
|
}
|
||||||
|
|
||||||
private static boolean createTable(String tableName) {
|
private static boolean createTable(String tableName) {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.bensherriff.siren.settings;
|
package com.bensherriff.siren.settings;
|
||||||
|
|
||||||
import com.bensherriff.siren.openai.Model;
|
import com.bensherriff.siren.ai.Model;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ public class GuildSettings {
|
|||||||
return volume;
|
return volume;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setVolume(int volume) throws IOException {
|
public void setVolume(int volume) {
|
||||||
this.volume = volume;
|
this.volume = volume;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.bensherriff.siren.settings;
|
package com.bensherriff.siren.settings;
|
||||||
|
|
||||||
import com.bensherriff.siren.openai.Role;
|
import com.bensherriff.siren.ai.Role;
|
||||||
|
|
||||||
public class OpenAISettings {
|
public class OpenAISettings {
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user