v0.1.22 Updated settings layout

This commit is contained in:
2023-04-17 10:33:18 -04:00
parent 7e87ca5030
commit 74efbc2352
12 changed files with 156 additions and 99 deletions

View File

@@ -130,14 +130,14 @@ public class Listener extends ListenerAdapter {
LOGGER.error("Failed to create directory at {}", directory.getPath());
return;
}
Tracks tracks = SettingsManager.load(SettingsManager.TRACKS_PATH, Tracks.class);
LocalTracks tracks = SettingsManager.load(SettingsManager.TRACKS_PATH, LocalTracks.class);
int rows = 0;
File[] files = directory.listFiles();
if (files != null) {
for (Track track : tracks.getTracks()) {
for (LocalTrack track : tracks.getTracks()) {
for (File file : files) {
if (file.exists() && file.getName().equals(track.getFileName())) {
rows += DatabaseManager.storeAudio(track.getFileName(), track.getTags());
rows += DatabaseManager.storeAudio(track);
}
}
}

View File

@@ -1,5 +1,11 @@
package com.bensherriff.siren.ai;
/**
* CompletionRequest Models:
* text-davinci-003, text-davinci-002, text-curie-001, text-babbage-001, text-ada-001
* ChatCompletionRequest Models:
* gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
*/
public enum Model {
DAVINCI_3("text-davinci-003"),
DAVINCI_2("text-davinci-002"),

View File

@@ -4,8 +4,7 @@ import com.bensherriff.siren.Listener;
import com.bensherriff.siren.database.DatabaseManager;
import com.bensherriff.siren.database.MessageData;
import com.bensherriff.siren.database.QueryBuilder;
import com.bensherriff.siren.settings.GuildSettings;
import com.bensherriff.siren.settings.Settings;
import com.bensherriff.siren.settings.*;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
@@ -82,11 +81,16 @@ public class OpenAIManager {
try {
StringBuilder stringBuilder = new StringBuilder();
long guildId = event.getGuild().getIdLong();
Model model = settings.getGuildSettings().get(guildId).getModel();
ChatMessage chatMessage = createChatMessage(message, event);
long authorId = event.getAuthor().getIdLong();
UserSettings userSettings = new UserSettings();
settings.getGuildSettings().get(guildId).getUserSettings().putIfAbsent(authorId, userSettings);
SettingsManager.write(settings);
userSettings = settings.getGuildSettings().get(guildId).getUserSettings().get(authorId);
Model model = userSettings.getModel();
ChatMessage chatMessage = createUserMessage(message);
List<Embedding> embeddings = new ArrayList<>();
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, event.getAuthor().getId(), event.getMessageId(), message);
LOGGER.trace("Guild: <{}> User: <{}> Message <{}>: {}", guildId, authorId, event.getMessageId(), message);
// Send OpenAI Message and get response
switch (model) {
case DAVINCI_3, DAVINCI_2, CURIE_1, BABBAGE_1, ADA_1 -> {
@@ -137,7 +141,8 @@ public class OpenAIManager {
}
private ChatCompletionRequest createCompletionRequest(ChatMessage chatMessage, MessageReceivedEvent event) throws SQLException {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
.getUserSettings().get(event.getAuthor().getIdLong());
List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(chatMessage);
@@ -188,8 +193,8 @@ public class OpenAIManager {
}
return ChatCompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
.model(userSettings.getModel().getName())
.maxTokens(userSettings.getMaxTokens())
.user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP())
@@ -200,10 +205,11 @@ public class OpenAIManager {
}
private CompletionRequest createCompletionRequest(String message, MessageReceivedEvent event) {
GuildSettings guildSettings = settings.getGuildSettings().get(event.getGuild().getIdLong());
UserSettings userSettings = settings.getGuildSettings().get(event.getGuild().getIdLong())
.getUserSettings().get(event.getAuthor().getIdLong());
return CompletionRequest.builder()
.model(guildSettings.getModel().getName())
.maxTokens(guildSettings.getMaxTokens())
.model(userSettings.getModel().getName())
.maxTokens(userSettings.getMaxTokens())
.user(event.getAuthor().getId())
.temperature(settings.getOpenAISettings().getTemperature())
.topP(settings.getOpenAISettings().getTopP())
@@ -217,15 +223,12 @@ public class OpenAIManager {
return new ChatMessage(Role.SYSTEM.getName(), message);
}
private ChatMessage createChatMessage(String message, MessageReceivedEvent event) {
ChatMessage chatMessage = new ChatMessage();
chatMessage.setContent(message);
if (event.getAuthor().getId().equals(settings.getOwner())) {
chatMessage.setRole(Role.ASSISTANT.getName());
} else {
chatMessage.setRole(settings.getOpenAISettings().getDefaultRole().getName());
}
return chatMessage;
private ChatMessage createAssistantMessage(String message) {
return new ChatMessage(Role.ASSISTANT.getName(), message);
}
private ChatMessage createUserMessage(String message) {
return new ChatMessage(Role.USER.getName(), message);
}
private void handleResponse(ChatMessage chatMessage, MessageReceivedEvent event, String response, List<Embedding> embeddings) {

View File

@@ -1,5 +1,6 @@
package com.bensherriff.siren.database;
import com.bensherriff.siren.settings.LocalTrack;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@@ -114,12 +115,17 @@ public class DatabaseManager {
return storeMessage("message_embeddings", INSERT_MESSAGE_EMBEDDINGS, messageId, embeddingId);
}
public static int storeAudio(String fileName, String[] tags) {
public static int storeAudio(LocalTrack track) {
String INSERT_AUDIO = "INSERT INTO audio (" +
"author, " +
"title, " +
"length, " +
"identifier, " +
"file_name, " +
"tags) " +
"VALUES (?, ?)";
return storeMessage("audio", INSERT_AUDIO, fileName, tags);
"VALUES (?, ?, ?, ?, ?, ?)";
return storeMessage("audio", INSERT_AUDIO, track.getAuthor(), track.getTitle(), track.getLength(),
track.getIdentifier(), track.getFileName(), track.getTags());
}
private static int storeMessage(String tableName, String insertString, Object... params) {

View File

@@ -2,21 +2,14 @@ package com.bensherriff.siren.settings;
import com.bensherriff.siren.ai.Model;
import java.io.IOException;
import java.util.*;
public class GuildSettings {
private String prefix = "!";
private int volume = 100;
/**
* CompletionRequest Models:
* text-davinci-003, text-davinci-002, text-curie-001, text-babbage-001, text-ada-001
* ChatCompletionRequest Models:
* gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
*/
private Model model = Model.ADA_1;
private int maxTokens = 100;
private Map<Long, UserSettings> userSettings = new LinkedHashMap<>();
public String getPrefix() {
return prefix;
@@ -34,19 +27,11 @@ public class GuildSettings {
this.volume = volume;
}
public Model getModel() {
return model;
public Map<Long, UserSettings> getUserSettings() {
return userSettings;
}
public void setModel(Model model) {
this.model = model;
}
public int getMaxTokens() {
return maxTokens;
}
public void setMaxTokens(int maxTokens) {
this.maxTokens = maxTokens;
public void setUserSettings(Map<Long, UserSettings> userSettings) {
this.userSettings = userSettings;
}
}

View File

@@ -0,0 +1,61 @@
package com.bensherriff.siren.settings;
import java.util.ArrayList;
import java.util.List;
public class LocalTrack {
private String title = "";
private String author = "";
private long length;
private String identifier = "";
private String fileName = "";
private List<String> tags = new ArrayList<>();
public String getFileName() {
return fileName;
}
public void setFileName(String fileName) {
this.fileName = fileName;
}
public List<String> getTags() {
return tags;
}
public void setTags(List<String> tags) {
this.tags = tags;
}
public String getTitle() {
return title;
}
public void setTitle(String title) {
this.title = title;
}
public String getAuthor() {
return author;
}
public void setAuthor(String author) {
this.author = author;
}
public long getLength() {
return length;
}
public void setLength(long length) {
this.length = length;
}
public String getIdentifier() {
return identifier;
}
public void setIdentifier(String identifier) {
this.identifier = identifier;
}
}

View File

@@ -0,0 +1,17 @@
package com.bensherriff.siren.settings;
import java.util.ArrayList;
import java.util.List;
public class LocalTracks {
private List<LocalTrack> tracks = new ArrayList<>();
public List<LocalTrack> getTracks() {
return tracks;
}
public void setTracks(List<LocalTrack> localTracks) {
this.tracks = localTracks;
}
}

View File

@@ -1,18 +1,20 @@
package com.bensherriff.siren.settings;
import com.bensherriff.siren.ai.Model;
import com.bensherriff.siren.ai.Role;
public class OpenAISettings {
private String token = "";
/**
* One of ['system', 'assistant', 'user']
* System: The system role provides full access to all OpenAI APIs and resources, including access to billing information and account management features. An API key with the system role can perform any action that is allowed by the OpenAI API.
* Assistant: The assistant role is designed for use with virtual assistants or chatbots that interact with users. An API key with the assistant role can perform actions related to natural language processing, such as generating text, answering questions, or completing tasks. The assistant role is more limited than the system role, but it still provides access to powerful language models such as GPT-3.
* User: The user role provides limited access to specific OpenAI APIs and resources. An API key with the user role can only perform actions related to specific use cases, such as accessing a particular language model or dataset. The user role is more restricted than the assistant or system roles, but it is suitable for many common use cases.
*/
private Role defaultRole = Role.USER;
public static Role defaultRole = Role.USER;
public static Model defaultModel = Model.ADA_1;
public static int defaultMaxTokens = 100;
private String token = "";
/**
* In milliseconds
@@ -31,14 +33,6 @@ public class OpenAISettings {
this.token = token;
}
public Role getDefaultRole() {
return defaultRole;
}
public void setDefaultRole(Role defaultRole) {
this.defaultRole = defaultRole;
}
public long getTimeout() {
return timeout;
}

View File

@@ -1,22 +0,0 @@
package com.bensherriff.siren.settings;
public class Track {
private String fileName;
private String[] tags;
public String getFileName() {
return fileName;
}
public void setFileName(String fileName) {
this.fileName = fileName;
}
public String[] getTags() {
return tags;
}
public void setTags(String[] tags) {
this.tags = tags;
}
}

View File

@@ -1,17 +0,0 @@
package com.bensherriff.siren.settings;
import java.util.ArrayList;
import java.util.List;
public class Tracks {
private List<Track> tracks = new ArrayList<>();
public List<Track> getTracks() {
return tracks;
}
public void setTracks(List<Track> tracks) {
this.tracks = tracks;
}
}

View File

@@ -0,0 +1,24 @@
package com.bensherriff.siren.settings;
import com.bensherriff.siren.ai.Model;
public class UserSettings {
private Model model = OpenAISettings.defaultModel;
private int maxTokens = OpenAISettings.defaultMaxTokens;
public Model getModel() {
return model;
}
public void setModel(Model model) {
this.model = model;
}
public int getMaxTokens() {
return maxTokens;
}
public void setMaxTokens(int maxTokens) {
this.maxTokens = maxTokens;
}
}