diff --git a/src/commands/oai.rs b/src/commands/oai.rs index 366a9c3..99536e7 100644 --- a/src/commands/oai.rs +++ b/src/commands/oai.rs @@ -17,7 +17,8 @@ pub struct OAI { pub client: reqwest::Client, pub base_url: String, pub max_attempts: i64, - pub token: String + pub token: String, + pub max_context_questions: i64 } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -119,6 +120,7 @@ impl OAI { let uri = format!("{}/chat/completions", self.base_url); let body = serde_json::to_string(&request).unwrap(); trace!("Sending request to {}: {}", uri, body); + let value = match match self.client .post(&uri) .bearer_auth(&self.token) @@ -168,28 +170,38 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P let parsed_content = msg.content.replace(bot_mention.as_str(), ""); debug!("Generating response for message: {}", msg.content); - let messages = vec![ + let instructions = vec![ ChatCompletionMessage { role: "system".to_string(), - content: "You are a helpful Discord bot named Siren.".to_string() + content: "You are a Discord user named Siren.".to_string() }, ChatCompletionMessage { - role: "user".to_string(), - content: parsed_content.to_string() + role: "system".to_string(), + content: "Siren is an expert on Dungeons and Dragons.".to_string() }, ]; + let user_request = ChatCompletionMessage { + role: "user".to_string(), + content: parsed_content.to_string() + }; + + let mut messages: Vec = vec![]; + messages.extend(instructions); + // TODO: Get previous messages + messages.push(user_request); + let model = "gpt-3.5-turbo".to_string(); let request = ChatCompletionRequest { model: model.to_string(), messages, - temperature: None, + temperature: Some(0.5), top_p: None, n: None, - max_tokens: Some(2000), - presence_penalty: None, - frequency_penalty: None, + max_tokens: Some(1000), + presence_penalty: Some(0.6), + frequency_penalty: Some(0.0), user: Some(msg.author.name.clone()) }; let response = match oai.get_request(request).await { @@ -227,4 +239,4 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI, pool: &P if let Err(why) = msg.channel_id.say(&ctx.http, response).await { error!("Cannot send message: {}", why); } -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 5854ad8..9e681a8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,9 +18,6 @@ use songbird::SerenityInit; mod commands; mod database; - -pub const MIGRATIONS: diesel_migrations::EmbeddedMigrations = diesel_migrations::embed_migrations!("migrations"); - struct Handler { // Open AI Config oai: Option, @@ -125,7 +122,7 @@ async fn main() { Ok(token) => { info!("Loaded OpenAI token"); Handler { - oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token }), + oai: Some(commands::oai::OAI { client: reqwest::Client::new(), base_url: "https://api.openai.com/v1".to_string(), max_attempts: 5, token , max_context_questions: 10 }), pool } }