From 096a47b96eca8a0ecda8c284c3ccd1c2ba336782 Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Mon, 29 Jan 2024 11:39:52 -0500 Subject: [PATCH] Use openai to generate chat thread titles --- service/src/bot/commands/chat.rs | 42 +++++++++++++++++++++++++---- service/src/bot/commands/message.rs | 0 service/src/bot/commands/mod.rs | 1 - service/src/bot/oai/model.rs | 2 +- 4 files changed, 38 insertions(+), 7 deletions(-) delete mode 100644 service/src/bot/commands/message.rs diff --git a/service/src/bot/commands/chat.rs b/service/src/bot/commands/chat.rs index c1705e8..00c9709 100644 --- a/service/src/bot/commands/chat.rs +++ b/service/src/bot/commands/chat.rs @@ -65,8 +65,9 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { }; // Get the thread channel ID + let thread_name = generate_thread_name(oai, &parsed_content, 99).await; let response_channel = match msg.channel_id.create_private_thread(&ctx.http, |thread| { - thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread) + thread.name(thread_name).kind(ChannelType::PublicThread) }).await { Ok(c) => { let allow = Permissions::SEND_MESSAGES; @@ -140,9 +141,40 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { // }; } -fn truncate(s: &str, max_chars: usize) -> &str { - match s.char_indices().nth(max_chars) { +async fn generate_thread_name(oai: &OAI, s: &str, max_chars: usize) -> String { + println!("HERE: {}", s); + let message = ChatCompletionMessage { + role: GPTRole::User, + content: format!("---\n{}\n---\nSummarize the message above into a concise Discord thread title", s) + }; + let request = ChatCompletionRequest { + model: oai.default_model.clone(), + messages: vec![message], + temperature: Some(0.5), + top_p: None, + n: None, + max_tokens: Some(oai.max_tokens), + presence_penalty: Some(0.6), + frequency_penalty: Some(0.0), + user: None + }; + // Truncate the response to the max number of characters + let mut response = match s.char_indices().nth(max_chars) { None => s, - Some((idx, _)) => &s[..idx], - } + Some((idx, _)) => &s[..idx] + }.to_string(); + // Set the response to the OAI response + match oai.chat_completion(request).await { + Ok(r) => { + if !r.choices.is_empty() { + response = r.choices[0].message.content.clone(); + } else { + warn!("No choices received in the response from OpenAI"); + } + } + Err(err) => { + error!("Could not get response from OpenAI: {}", err.message); + } + }; + return response; } diff --git a/service/src/bot/commands/message.rs b/service/src/bot/commands/message.rs deleted file mode 100644 index e69de29..0000000 diff --git a/service/src/bot/commands/mod.rs b/service/src/bot/commands/mod.rs index dc49d80..5479ddd 100644 --- a/service/src/bot/commands/mod.rs +++ b/service/src/bot/commands/mod.rs @@ -1,6 +1,5 @@ pub mod audio; pub mod help; -pub mod message; pub mod chat; pub mod ping; pub mod schedule; diff --git a/service/src/bot/oai/model.rs b/service/src/bot/oai/model.rs index 2669029..2f42644 100644 --- a/service/src/bot/oai/model.rs +++ b/service/src/bot/oai/model.rs @@ -125,4 +125,4 @@ impl OAI { Err(err) => return Err(ServiceError { status: 500, message: format!("Error: {}", err) }) } } -} \ No newline at end of file +}