From 2ecfa92d8b2c866595673f704f1a5c5a090e80e4 Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Wed, 18 Dec 2024 23:38:22 -0500 Subject: [PATCH] Reformat, working on roll --- .env | 14 ++-- Cargo.toml | 9 ++- src/bot/chat/mod.rs | 20 ++++-- src/bot/commands/audio/mute.rs | 4 +- src/bot/commands/audio/pause.rs | 4 +- src/bot/commands/audio/play.rs | 4 +- src/bot/commands/audio/resume.rs | 4 +- src/bot/commands/audio/skip.rs | 2 +- src/bot/commands/audio/stop.rs | 2 +- src/bot/commands/audio/volume.rs | 10 ++- src/bot/commands/fun/roll.rs | 92 ++++++++++++++++++------ src/bot/commands/utility/ping.rs | 7 +- src/bot/handler.rs | 75 ++++++++++--------- src/bot/oai/model.rs | 3 +- src/main.rs | 120 +++++++++++++++++++------------ src/utils/mod.rs | 2 + src/utils/text_utils.rs | 62 ++++++++++++++++ 17 files changed, 304 insertions(+), 130 deletions(-) create mode 100644 src/utils/mod.rs create mode 100644 src/utils/text_utils.rs diff --git a/.env b/.env index 59a5f02..3dd5996 100644 --- a/.env +++ b/.env @@ -1,5 +1,7 @@ RUST_LOG=warn,siren=info +DISCORD_TOKEN= + DATABASE_USER=siren DATABASE_PASSWORD=CHANGEME # Change this to a secure password DATABASE_NAME=siren @@ -17,10 +19,10 @@ MINIO_PORT_INTERNAL=9001 REDIS_HOST=localhost REDIS_PORT=6379 -# OPTIONAL +# Siren Data integration DATA_DIR_PATH= -# Mandatory -DISCORD_TOKEN= -# OPTIONAL -OPENAI_API_KEY= -OPENAI_API_MODEL=gpt-4o-mini + +# OpenAI +OPENAI_BASE_URL=https://api.openai.com/v1 +OPENAI_TOKEN= +OPENAI_MODEL=gpt-4o-mini diff --git a/Cargo.toml b/Cargo.toml index 47f0e4c..9adddc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "siren" -version = "0.2.9" +version = "0.2.10" edition = "2021" authors = ["Ben Sherriff "] description = "A Discord bot for playing music" @@ -15,15 +15,14 @@ env_logger = "0.11.5" serde = { version = "1.0.210", features = ["derive"] } serde_json = "1.0.128" serenity = { version = "0.12.2", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "voice", "cache", "framework", "standard_framework"] } -songbird = { version = "0.4.3", features = ["builtin-queue"] } -symphonia = { version = "0.5.4", features = ["all"] } +songbird = { version = "0.4.6", features = ["builtin-queue"] } sqlx = { version = "0.8.2", features = ["runtime-tokio", "postgres", "chrono", "uuid"] } chrono = { version = "0.4.38", features = ["serde"] } reqwest = { version = "0.11", default-features = false, features = ["json"] } -lazy_static = "1.5.0" uuid = { version = "1.11.0", features = ["serde", "v4"] } redis = { version = "0.27.4", features = ["tokio-comp", "connection-manager", "r2d2"] } rand = "0.8.5" -tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] } regex = "1.11.0" axum = "0.7.7" +lazy_static = "1.5.0" diff --git a/src/bot/chat/mod.rs b/src/bot/chat/mod.rs index 330533d..af3bb08 100644 --- a/src/bot/chat/mod.rs +++ b/src/bot/chat/mod.rs @@ -1,10 +1,11 @@ use serenity::all::{ CommandInteraction, Context, CreateInteractionResponse, CreateInteractionResponseMessage, - CreateMessage, EditInteractionResponse, InteractionResponseFlags, Message, User, UserId, + CreateMessage, EditInteractionResponse, InteractionResponseFlags, Message, ModalInteraction, + User, UserId, }; pub async fn process_message(ctx: &Context, command: &CommandInteraction, private: bool) { - create_response(&ctx, &command, "Processing...".to_string(), private).await; + create_message_response(&ctx, &command, "Processing...".to_string(), private).await; } pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Option { @@ -29,7 +30,7 @@ pub async fn user_dm(ctx: &Context, user: &User, content: String) -> Option {} Err(err) => { - log::error!("Failed to create response for {content}\n{err}"); + log::error!("Failed to create message response for {content}\n{err}"); } }; } +pub async fn create_modal_response(ctx: &Context, modal: &ModalInteraction) { + let mut data = CreateInteractionResponseMessage::new(); + let builder = CreateInteractionResponse::Message(data); + match modal.create_response(&ctx.http, builder).await { + Ok(_) => {} + Err(err) => { + log::error!("Failed to create modal response\n{err}"); + } + } +} + pub async fn edit_response(ctx: &Context, command: &CommandInteraction, content: String) { let builder = EditInteractionResponse::new().content(content.to_owned()); match command.edit_response(&ctx.http, builder).await { diff --git a/src/bot/commands/audio/mute.rs b/src/bot/commands/audio/mute.rs index 80688ec..07c721c 100644 --- a/src/bot/commands/audio/mute.rs +++ b/src/bot/commands/audio/mute.rs @@ -35,10 +35,10 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { Ok(_) => { if is_muted { log::debug!("<{guild_id}> Unmuted"); - edit_response(&ctx, &command, format!("Unmuted")).await; + edit_response(&ctx, &command, "Unmuted".to_string()).await; } else { log::debug!("<{guild_id}> Muted"); - edit_response(&ctx, &command, format!("Muted")).await; + edit_response(&ctx, &command, "Muted".to_string()).await; } } Err(err) => { diff --git a/src/bot/commands/audio/pause.rs b/src/bot/commands/audio/pause.rs index f3a541d..80210a4 100644 --- a/src/bot/commands/audio/pause.rs +++ b/src/bot/commands/audio/pause.rs @@ -35,14 +35,14 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { Some(track) => match track.pause() { Ok(_) => { log::debug!("<{guild_id}> Paused the track"); - edit_response(&ctx, &command, format!("Pausing the track")).await; + edit_response(&ctx, &command, "Pausing the track".to_string()).await; } Err(err) => { edit_response(&ctx, &command, format!("Failed to pause: {}", err)).await; } }, None => { - edit_response(ctx, command, format!("No track currently being played")).await; + edit_response(ctx, command, "No track currently being played".to_string()).await; } } } diff --git a/src/bot/commands/audio/play.rs b/src/bot/commands/audio/play.rs index b3902ef..cb7e4b0 100644 --- a/src/bot/commands/audio/play.rs +++ b/src/bot/commands/audio/play.rs @@ -14,7 +14,7 @@ use crate::HttpKey; use super::{get_songbird, is_valid_url, join_voice_channel}; -use crate::bot::chat::{create_response, edit_response, process_message}; +use crate::bot::chat::{create_message_response, edit_response, process_message}; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Process the command options @@ -25,7 +25,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { "{} attempted to play a track without a track option", command.user.id.get() ); - create_response(&ctx, &command, format!("Track option is missing"), false).await; + create_message_response(&ctx, &command, format!("Track option is missing"), false).await; return; } }; diff --git a/src/bot/commands/audio/resume.rs b/src/bot/commands/audio/resume.rs index 1a5870b..6be9c85 100644 --- a/src/bot/commands/audio/resume.rs +++ b/src/bot/commands/audio/resume.rs @@ -35,14 +35,14 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { Some(track) => match track.play() { Ok(_) => { log::debug!("<{guild_id}> Resumed the track"); - edit_response(&ctx, &command, format!("Resuming the track")).await; + edit_response(&ctx, &command, "Resuming the track".to_string()).await; } Err(err) => { edit_response(&ctx, &command, format!("Failed to resume: {}", err)).await; } }, None => { - edit_response(&ctx, &command, format!("No track is currently playing")).await; + edit_response(&ctx, &command, "No track is currently playing".to_string()).await; return; } } diff --git a/src/bot/commands/audio/skip.rs b/src/bot/commands/audio/skip.rs index 61880a3..49790a9 100644 --- a/src/bot/commands/audio/skip.rs +++ b/src/bot/commands/audio/skip.rs @@ -34,7 +34,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { match handler.queue().skip() { Ok(_) => { log::debug!("<{guild_id}> Skipped the track"); - edit_response(&ctx, &command, format!("Skipping the track")).await; + edit_response(&ctx, &command, "Skipping the track".to_string()).await; } Err(err) => { edit_response(&ctx, &command, format!("Failed to skip: {}", err)).await; diff --git a/src/bot/commands/audio/stop.rs b/src/bot/commands/audio/stop.rs index a542f50..c8087d0 100644 --- a/src/bot/commands/audio/stop.rs +++ b/src/bot/commands/audio/stop.rs @@ -33,7 +33,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { let handler = handler_lock.lock().await; handler.queue().stop(); log::debug!("<{guild_id}> Stopped the track"); - edit_response(&ctx, &command, format!("Stopping the tracks")).await; + edit_response(&ctx, &command, "Stopping the tracks".to_string()).await; } } diff --git a/src/bot/commands/audio/volume.rs b/src/bot/commands/audio/volume.rs index 237c5e4..213a6a9 100644 --- a/src/bot/commands/audio/volume.rs +++ b/src/bot/commands/audio/volume.rs @@ -9,7 +9,7 @@ use songbird::Songbird; use crate::data::guilds::GuildCache; -use crate::bot::chat::{create_response, edit_response, process_message}; +use crate::bot::chat::{create_message_response, edit_response, process_message}; use super::get_songbird; @@ -22,7 +22,13 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { "{} attempted to change the volume without a volume option", command.user.id.get() ); - create_response(&ctx, &command, format!("Volume option is missing"), false).await; + create_message_response( + &ctx, + &command, + "Volume option is missing".to_string(), + false, + ) + .await; return; } }; diff --git a/src/bot/commands/fun/roll.rs b/src/bot/commands/fun/roll.rs index df9dcea..37d2c86 100644 --- a/src/bot/commands/fun/roll.rs +++ b/src/bot/commands/fun/roll.rs @@ -1,10 +1,26 @@ +use std::collections::HashMap; +use std::sync::Mutex; use rand::Rng; use serenity::all::{ - CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption, Mentionable, - UserId, + ButtonStyle, CommandInteraction, CommandOptionType, Context, CreateActionRow, CreateButton, + CreateCommand, CreateCommandOption, CreateEmbed, CreateMessage, Mentionable, UserId, }; -use crate::bot::chat::{create_response, edit_response, user_id_dm}; +use crate::bot::chat::{create_message_response, edit_response}; +use crate::utils::{a_or_an, number_to_words}; + +lazy_static::lazy_static! { + static ref SAVED_ROLLS: Mutex>> = Mutex::new(HashMap::new()); +} + +pub fn temp() { + // // Add to the HashMap after processing the modal + // let mut saved_rolls = SAVED_ROLLS.lock().unwrap(); + // saved_rolls + // .entry(user_id) + // .or_default() + // .push((dice_roll, description.clone())); +} pub async fn run(ctx: &Context, command: &CommandInteraction) { // Check if the roll result is private @@ -24,7 +40,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { .find(|opt| opt.name == "user") .and_then(|o| o.value.as_mentionable()); - create_response(&ctx, &command, "Rolling...".to_string(), private).await; + create_message_response(&ctx, &command, "Rolling...".to_string(), private).await; let dice_string = match command .data @@ -51,29 +67,54 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { total += roll; rolls.push(roll); } - let response = format!( - "🎲 **{}** (Rolled {}d{}{})", + + let response = ( total + (modifier as u32), - count, - sides, - if modifier > 0 { - format!("+{}", modifier) - } else if modifier < 0 { - format!("-{}", modifier) - } else { - "".to_string() - } + format!( + "(Rolled {}d{}{})", + count, + sides, + if modifier > 0 { + format!("+{}", modifier) + } else if modifier < 0 { + format!("-{}", modifier) + } else { + "".to_string() + } + ), ); match user { Some(id) => { let user_id = UserId::new(id.get()); - user_id_dm( - &ctx, - &user_id, - format!("Dice roll from {}: {}", &command.user.mention(), response), - ) - .await; + + // Create the dice roll embed + let a = a_or_an(&number_to_words(response.0 as i32)); + let embed = CreateEmbed::new() + .title("🎲 Received a dice roll! 🎲".to_string()) + .color(0x00FF00) + .description(format!( + "{} rolled {} **{}**\n-# *{}*", + &command.user.mention(), + a, + response.0, + response.1 + )); + + // Create a button with a custom ID + let save_button = CreateButton::new("save_dice_roll") + .label("💾") + .style(ButtonStyle::Primary); + + // Action row to hold the button + let action_row = CreateActionRow::Buttons(vec![save_button]); + + let message = CreateMessage::new() + .embed(embed) + .components(vec![action_row]); + if let Err(err) = user_id.dm(&ctx, message).await { + log::error!("Could not send message: {}", err); + } edit_response( &ctx, command, @@ -81,7 +122,14 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { ) .await; } - None => edit_response(&ctx, &command, response).await, + None => { + edit_response( + &ctx, + &command, + format!("🎲 {}\n-# {}", response.0, response.1), + ) + .await + } }; } Err(why) => { diff --git a/src/bot/commands/utility/ping.rs b/src/bot/commands/utility/ping.rs index 79d1066..e6622e6 100644 --- a/src/bot/commands/utility/ping.rs +++ b/src/bot/commands/utility/ping.rs @@ -1,8 +1,9 @@ -use serenity::all::{CommandDataOption, CreateCommand}; +use serenity::all::{CommandDataOption, CommandInteraction, Context, CreateCommand}; +use crate::bot::chat::create_message_response; -pub fn run(_options: &[CommandDataOption]) -> String { +pub async fn run(ctx: &Context, command: &CommandInteraction) { log::debug!("Ping command executed"); - "pong".to_string() + create_message_response(&ctx, &command, "pong".to_string(), true).await; } pub fn register() -> CreateCommand { diff --git a/src/bot/handler.rs b/src/bot/handler.rs index 3990b2b..484b2a9 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -1,42 +1,36 @@ -use serenity::all::Interaction; +use serenity::all::{CreateInteractionResponse, Interaction}; use serenity::async_trait; use serenity::model::gateway::Ready; use serenity::model::channel::Message; use serenity::prelude::*; - +use crate::bot::commands::chat::generate_response; +use crate::bot::oai::OAI; use crate::data::guilds::GuildCache; -use super::{commands, oai}; -use super::chat::create_response; +use super::{commands}; +use super::chat::{create_message_response, create_modal_response}; pub struct Handler { // Open AI Config - pub oai: Option, + pub oai: Option, } #[async_trait] impl EventHandler for Handler { async fn message(&self, ctx: Context, msg: Message) { - // Ignore messages from bots + // Ignore bot messages if msg.author.bot { return; } + + // Handle direct messages + if let None = msg.guild_id { + log::trace!("Received DM from {}: {}", msg.author, msg.content); + } + + // Handle OAI messages match &self.oai { Some(oai) => { - match msg.mentions_me(&ctx.http).await { - Ok(mentioned) => { - let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await { - Ok(t) => match t.iter().find(|t| t.user_id == ctx.cache.current_user().id) { - Some(_) => true, - None => false, - }, - Err(_) => false, - }; - if mentioned || bot_in_thread { - commands::chat::generate_response(&ctx, &msg, oai).await; - } - } - Err(why) => log::warn!("Could not check mentions: {:?}", why), - }; + handle_oai_messages(oai, &ctx, &msg).await; } None => {} } @@ -80,7 +74,7 @@ impl EventHandler for Handler { Ok(c) => log::info!( "Registered {} commands for guild {}", c.len(), - guild.id.get() + guild.id.get(), ), Err(why) => log::error!( "Could not register commands for guild {}: {:?}", @@ -92,8 +86,10 @@ impl EventHandler for Handler { } async fn interaction_create(&self, ctx: Context, interaction: Interaction) { - if let Interaction::Command(command) = interaction { - log::trace!("Received command interaction: {command:#?}"); + if let Interaction::Ping(ping) = interaction { + log::debug!("Received interaction ping: {:?}", ping); + } else if let Interaction::Command(command) = interaction { + log::debug!("Received command interaction: {command:#?}"); match command.data.name.as_str() { // Match commands without returns "play" => commands::audio::play::run(&ctx, &command).await, @@ -105,15 +101,30 @@ impl EventHandler for Handler { "volume" => commands::audio::volume::run(&ctx, &command).await, "schedule" => commands::event::schedule::run(&ctx, &command).await, "roll" => commands::fun::roll::run(&ctx, &command).await, - _ => { - let content: String = match command.data.name.as_str() { - // Match commands with string returns - "ping" => commands::utility::ping::run(&command.data.options), - _ => "Unknown command".to_string(), - }; - create_response(&ctx, &command, content, false).await; - } + "ping" => commands::utility::ping::run(&ctx, &command).await, + _ => {} } + } else if let Interaction::Modal(modal) = interaction { + log::debug!("Received interaction modal: {:?}", modal); + create_modal_response(&ctx, &modal).await; } } } + +async fn handle_oai_messages(oai: &OAI, ctx: &Context, msg: &Message) { + match msg.mentions_me(&ctx.http).await { + Ok(mentioned) => { + let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await { + Ok(t) => match t.iter().find(|t| t.user_id == ctx.cache.current_user().id) { + Some(_) => true, + None => false, + }, + Err(_) => false, + }; + if mentioned || bot_in_thread { + generate_response(&ctx, &msg, oai).await; + } + } + Err(why) => log::warn!("Could not check mentions: {why}"), + }; +} diff --git a/src/bot/oai/model.rs b/src/bot/oai/model.rs index 52152a6..b580997 100644 --- a/src/bot/oai/model.rs +++ b/src/bot/oai/model.rs @@ -68,7 +68,8 @@ pub struct Choice { pub message: ChatCompletionMessage, pub finish_reason: String, pub index: i64, - pub logprobs: Option, + #[serde(rename = "logprobs")] + pub log_probabilities: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/src/main.rs b/src/main.rs index 9a9d107..1f6c959 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,14 +6,17 @@ use serenity::http::Http; use serenity::prelude::*; use songbird::{SerenityInit, Songbird}; use reqwest::Client as HttpClient; +use serenity::all::{ShardManager, UserId}; use tokio::net::TcpListener; use crate::bot::handler::Handler; +use crate::bot::oai::OAI; mod api; mod bot; mod data; mod error; +mod utils; pub struct HttpKey; @@ -25,15 +28,14 @@ impl TypeMapKey for HttpKey { async fn main() { dotenv::dotenv().ok(); env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info")); + if let Err(err) = data::initialize().await { log::error!("Failed to initialize database: {err}"); return; }; // Start API server - tokio::spawn(async move { - start_api().await; - }); + tokio::spawn(start_api()); // Start Discord bot start_bot().await; @@ -41,8 +43,10 @@ async fn main() { async fn start_api() { let app = Router::new(); - let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap(); - log::debug!("listening on {}", listener.local_addr().unwrap()); + let addr: String = "127.0.0.1:3000".parse().unwrap(); + + let listener = TcpListener::bind(&addr).await.unwrap(); + log::debug!("API is listening on {}", &addr); axum::serve(listener, app).await.unwrap(); } @@ -51,45 +55,21 @@ async fn start_bot() { let intents: GatewayIntents = GatewayIntents::all(); let http: Http = Http::new(&token); - let (_owners, _bot_id) = match http.get_current_application_info().await { - Ok(info) => { - let mut owners: HashSet = HashSet::new(); - if let Some(team) = info.team { - owners.insert(team.owner_user_id); - } else { - owners.insert(info.owner.unwrap().id); - } - match http.get_current_user().await { - Ok(bot) => (owners, bot.id), - Err(why) => panic!("Could not access the bot id: {why:?}"), - } - } - Err(why) => panic!("Could not access application info: {why:?}"), - }; + let (owners, bot_id) = get_bot_info(&http).await; - let handler = match env::var("OPENAI_API_KEY") { - Ok(token) => { - log::info!("OpenAI functionality enabled"); - let default_model = env::var("OPENAI_API_MODEL").unwrap_or("gpt-4o-mini".to_string()); - Handler { - oai: Some(bot::oai::OAI { - client: reqwest::Client::new(), - base_url: "https://api.openai.com/v1".to_string(), - // max_attempts: 5, - token, - max_conversation_history: 30, - max_tokens: 8192, - default_model, - }), - } - } - Err(err) => { - log::trace!("No OPENAI_API_KEY found: {err}"); - log::warn!("OpenAI functionality disabled"); - Handler { oai: None } - } - }; + log::debug!( + "Starting Discord bot with ID: {bot_id} and owners: {}", + owners + .iter() + .map(|id| id.to_string()) + .collect::>() + .join(", ") + ); + // Set up handler with optional OpenAI integration + let handler = configure_handler(); + + // Set up Songbird for voice functionality let songbird = Songbird::serenity(); let mut client = Client::builder(token, intents) @@ -100,14 +80,64 @@ async fn start_bot() { .await .expect("Error creating client"); - // Handle shutdown signals + // Spawn shutdown signal handling let shard_manager = Arc::clone(&client.shard_manager); tokio::spawn(async move { - shard_manager.shutdown_all().await; + signal_shutdown(shard_manager).await; }); - // Start Discord bot + // Start the bot if let Err(why) = client.start_autosharded().await { log::error!("Client error: {why:?}"); } } + +async fn get_bot_info(http: &Http) -> (HashSet, UserId) { + match http.get_current_application_info().await { + Ok(info) => { + let mut owners = HashSet::new(); + if let Some(team) = info.team { + owners.insert(team.owner_user_id); + } else if let Some(owner) = info.owner { + owners.insert(owner.id); + } + match http.get_current_user().await { + Ok(bot) => (owners, bot.id), + Err(why) => panic!("Could not access the bot id: {why:?}"), + } + } + Err(why) => panic!("Could not access application info: {why:?}"), + } +} + +fn configure_handler() -> Handler { + match env::var("OPENAI_TOKEN") { + Ok(token) => { + log::debug!("OpenAI functionality enabled"); + let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()); + let base_url = env::var("OPENAI_BASE_URL").unwrap(); + Handler { + oai: Some(OAI { + client: reqwest::Client::new(), + base_url, + token, + max_conversation_history: 30, + max_tokens: 8192, + default_model, + }), + } + } + Err(_) => { + log::warn!("OpenAI functionality disabled"); + Handler { oai: None } + } + } +} + +async fn signal_shutdown(shard_manager: Arc) { + tokio::signal::ctrl_c() + .await + .expect("Failed to listen for shutdown signal"); + shard_manager.shutdown_all().await; + log::info!("Bot shutdown gracefully."); +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..fa44ed7 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,2 @@ +pub mod text_utils; +pub use text_utils::*; diff --git a/src/utils/text_utils.rs b/src/utils/text_utils.rs new file mode 100644 index 0000000..2075111 --- /dev/null +++ b/src/utils/text_utils.rs @@ -0,0 +1,62 @@ +pub fn a_or_an(word: &str) -> &'static str { + let vowels = ['a', 'e', 'i', 'o', 'u']; + let lowercase_word = word.to_lowercase(); + + // Special cases where the article should be "a" + let special_cases_a = vec!["one"]; + if special_cases_a.contains(&lowercase_word.as_str()) { + return "a"; + } + + // Special cases where the article should be "an" + let special_cases_an = vec!["hour"]; + if special_cases_an.contains(&lowercase_word.as_str()) { + return "an"; + } + + let first_char = lowercase_word.chars().next(); + + match first_char { + // If the first character is a vowel, return "an" + Some(c) if vowels.contains(&c) => "an", + // Otherwise, return "a" + _ => "a", + } +} + +pub fn number_to_words(n: i32) -> String { + let ones = [ + "", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", + ]; + let teens = [ + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ]; + let tens = [ + "", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety", + ]; + + if n < 10 { + ones[n as usize].to_string() + } else if n < 20 { + teens[(n - 10) as usize].to_string() + } else if n < 100 { + let ten_part = tens[(n / 10) as usize]; + let one_part = ones[(n % 10) as usize]; + if n % 10 == 0 { + ten_part.to_string() // e.g., 20 → "twenty" + } else { + format!("{}-{}", ten_part, one_part) // e.g., 42 → "forty-two" + } + } else { + "Number out of range".to_string() // Handle numbers >= 100 (or extend the logic) + } +}