diff --git a/.env.TEMPLATE b/.env.TEMPLATE index d789382..3e1ad7b 100644 --- a/.env.TEMPLATE +++ b/.env.TEMPLATE @@ -2,4 +2,5 @@ DISCORD_TOKEN= RUST_LOG=warn,siren=info POSTGRES_USER=siren POSTGRES_PASSWORD= -POSTGRES_DB=siren \ No newline at end of file +POSTGRES_DB=siren +OPENAI_API_KEY= \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 0fd885c..990393f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,10 @@ name = "siren" version = "0.2.1" edition = "2021" +authors = ["Ben Sherriff "] +repository = "https://github.com/bensherriff/siren" +readme = "README.md" +license = "GPL-3.0-or-later" [dependencies] dotenv = "0.15.0" diff --git a/Makefile b/Makefile index af44f7e..5207d97 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,8 @@ include .version export $(shell sed 's/=.*//' .env) export $(shell sed 's/=.*//' .version) +SIREN_IMAGES = $(shell docker images 'siren' -a -q) + .PHONY: help build test up down exec clean build: @@ -23,4 +25,4 @@ exec: docker exec -it siren bash clean: - docker rmi siren + docker rmi $(SIREN_IMAGES) diff --git a/docker-compose.yml b/docker-compose.yml index e86eb79..fdfc898 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,10 +9,11 @@ services: dockerfile: ./Dockerfile args: - VERSION=${SIREN_VERSION} - #volumes: - # - ./app:/siren + volumes: + - ./app:/siren environment: DISCORD_TOKEN: ${DISCORD_TOKEN} + RUST_LOG: ${RUST_LOG} DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db/${POSTGRES_DB} depends_on: - db @@ -24,8 +25,8 @@ services: POSTGRES_USER: ${POSTGRES_USER} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} POSTGRES_DB: ${POSTGRES_DB} - #volumes: - # - ./data:/var/lib/postgresql/data + volumes: + - ./data:/var/lib/postgresql/data ports: - "5432:5432" restart: unless-stopped diff --git a/src/commands/mod.rs b/src/commands/mod.rs index be3685c..fd3b9e5 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -1,2 +1,4 @@ +pub mod audio; +pub mod help; +pub mod oai; pub mod ping; -pub mod audio; \ No newline at end of file diff --git a/src/commands/oai.rs b/src/commands/oai.rs new file mode 100644 index 0000000..0980300 --- /dev/null +++ b/src/commands/oai.rs @@ -0,0 +1,206 @@ + +use std::error::Error; +use std::fmt; + +use log::{error, debug, trace}; + +use serde::{Serialize, Deserialize}; +use serde_json::Value; +use serenity::model::channel::Message; +use serenity::prelude::*; + +pub struct OAI { + pub client: reqwest::Client, + pub base_url: String, + pub max_attempts: u64, + pub token: String +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ChatCompletionRequest { + model: String, + messages: Vec, + /// Value between 0 and 2 + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + /// Value between 0 and 1 + #[serde(skip_serializing_if = "Option::is_none")] + top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, + /// Value between -2.0 and 2.0 + #[serde(skip_serializing_if = "Option::is_none")] + presence_penalty: Option, + /// Value between -2.0 and 2.0 + #[serde(skip_serializing_if = "Option::is_none")] + frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + user: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ChatCompletionMessage { + role: String, + content: String +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +enum Role { + #[serde(rename = "system")] + SYSTEM, + #[serde(rename = "user")] + USER, + #[serde(rename = "assistant")] + ASSISTANT, + #[serde(rename = "function")] + FUNCTION +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ChatCompletionResponse { + id: String, + object: String, + created: u64, + model: String, + usage: Usage, + choices: Vec +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Usage { + prompt_tokens: u64, + completion_tokens: u64, + total_tokens: u64 +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Choice { + message: ChatCompletionMessage, + finish_reason: String, + index: u64 +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ResponseError { + code: Option, + message: Option, + param: Option, + #[serde(rename = "type")] + error_type: Option +} + +#[derive(Debug)] +struct OAIError { + pub message: String +} + +impl fmt::Display for OAIError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "OAIError: {}", self.message) + } +} + +impl Error for OAIError {} + +#[derive(Debug, Clone, Serialize, Deserialize)] +enum ResponseEvent { + ChatCompletionResponse(ChatCompletionResponse), + ResponseError(ResponseError) +} + +impl OAI { + async fn get_request(&self, request: ChatCompletionRequest) -> Result { + let uri = format!("{}/chat/completions", self.base_url); + let body = serde_json::to_string(&request).unwrap(); + trace!("Sending request to {}", uri); + let value = match match self.client + .post(&uri) + .bearer_auth(&self.token) + .header("Content-Type", "application/json".to_string()) + .body(body) + .send() + .await { + Ok(r) => r, + Err(err) => return Err(OAIError { + message: format!("Could not send request to OpenAI: {}", err), + }) + } + .json::() + .await { + Ok(r) => r, + Err(err) => return Err(OAIError { + message: format!("Could not read response from OpenAI: {}", err) + }) + }; + + debug!("Received response from OpenAI: {:?}", value); + + // let response = match serde_json::from_value::(value) { + // Ok(r) => { + // match r { + // OAIResponseEvent::OAIResponse(r) => r, + // OAIResponseEvent::OAIError(e) => return Err(OAIError { message: e.message.unwrap_or("Unknown error".to_string()) }) + // } + // }, + // Err(err) => return Err(OAIError { + // message: format!("Could not parse response from OpenAI: {}", err) + // }) + // }; + let response = match serde_json::from_value::(value) { + Ok(r) => r, + Err(err) => return Err(OAIError { + message: format!("Could not parse response from OpenAI: {}", err) + }) + }; + + Ok(response) + } +} + +pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { + let bot_mention = format!("<@{}>", ctx.cache.current_user_id().0); + let parsed_content = msg.content.replace(bot_mention.as_str(), ""); + debug!("Generating response for message: {}", msg.content); + + let messages = vec![ + ChatCompletionMessage { + role: "system".to_string(), + content: "You are a helpful Discord bot named Siren.".to_string() + }, + ChatCompletionMessage { + role: "user".to_string(), + content: parsed_content + }, + ]; + + let request = ChatCompletionRequest { + model: "gpt-3.5-turbo".to_string(), + messages, + temperature: None, + top_p: None, + n: None, + max_tokens: Some(2000), + presence_penalty: None, + frequency_penalty: None, + user: Some(msg.author.name.clone()) + }; + let response = match oai.get_request(request).await { + Ok(r) => { + debug!("Received response from OpenAI"); + if !r.choices.is_empty() { + r.choices[0].message.content.clone() + } else { + "No reply received".to_string() + } + } + Err(err) => { + error!("Could not get response from OpenAI: {}", err.message); + err.message + } + }; + if let Err(why) = msg.channel_id.say(&ctx.http, response).await { + error!("Cannot send message: {}", why); + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index e4b814f..8fdc3a2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,15 +8,40 @@ use serenity::async_trait; use serenity::framework::StandardFramework; use serenity::model::application::interaction::Interaction; use serenity::model::gateway::Ready; +use serenity::model::channel::Message; use serenity::http::Http; use serenity::prelude::*; use songbird::SerenityInit; mod commands; -struct Handler; +struct Handler { + // Open AI Config + oai: Option +} #[async_trait] impl EventHandler for Handler { + async fn message(&self, ctx: Context, msg: Message) { + // Ignore messages from bots + if msg.author.bot { + return; + } + match &self.oai { + Some(oai) => { + match msg.mentions_me(&ctx.http).await { + Ok(mentioned) => { + if mentioned { + commands::oai::generate_response(&ctx, &msg, oai).await; + } + } + Err(why) => warn!("Could not check mentions: {:?}", why) + }; + } + None => {} + } + + } + async fn interaction_create(&self, ctx: Context, interaction: Interaction) { if let Interaction::ApplicationCommand(command) = interaction { match command.data.name.as_str() { @@ -63,7 +88,7 @@ impl EventHandler for Handler { #[tokio::main] async fn main() { dotenv().ok(); - env_logger::init(); + env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info")); let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); let intents: GatewayIntents = GatewayIntents::all(); @@ -90,9 +115,22 @@ async fn main() { .owners(owners) .prefix("!") ); + + let handler = match env::var("OPENAI_API_KEY") { + Ok(token) => { + info!("Loaded OpenAI token"); + Handler { + oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token }) + } + } + Err(err) => { + warn!("Could not load OpenAI token: {}", err); + Handler { oai: None } + } + }; let mut client = Client::builder(token, intents) - .event_handler(Handler) + .event_handler(handler) .framework(framework) .register_songbird() .await