From aa7bad945a16b157bb7d35df0f4aa71dfe3b8f0e Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Wed, 18 Dec 2024 19:05:13 -0500 Subject: [PATCH] Simplified enqueue_track to parse track type (playlist or single track) --- src/bot/commands/audio/mod.rs | 10 +-- src/bot/commands/audio/play.rs | 77 +++++++++------------ src/bot/commands/fun/roll.rs | 121 +++++++++++++++++++++------------ src/bot/commands/mod.rs | 15 ++-- src/bot/ytdlp/model.rs | 37 ++++++++-- 5 files changed, 144 insertions(+), 116 deletions(-) diff --git a/src/bot/commands/audio/mod.rs b/src/bot/commands/audio/mod.rs index 6e91387..6dbc101 100644 --- a/src/bot/commands/audio/mod.rs +++ b/src/bot/commands/audio/mod.rs @@ -56,14 +56,8 @@ pub async fn leave_voice_channel(manager: &Arc, guild_id: &GuildId) -> * 1st tuple value is if the URL is valid. * 2nd tuple value is if the URL is a playlist. */ -fn is_valid_url(url: &str) -> (bool, bool) { - Url::parse(url).ok().map_or((false, false), |valid_url| { - let is_playlist: bool = valid_url - .query_pairs() - .find(|(key, _)| key == "list") - .map_or(false, |_| true); - (true, is_playlist) - }) +fn is_valid_url(url: &str) -> bool { + Url::parse(url).ok().map_or(false, |valid_url| true) } /** diff --git a/src/bot/commands/audio/play.rs b/src/bot/commands/audio/play.rs index 355a029..766fdd2 100644 --- a/src/bot/commands/audio/play.rs +++ b/src/bot/commands/audio/play.rs @@ -9,7 +9,7 @@ use songbird::{Event, EventHandler, Songbird, TrackEvent}; use crate::bot::commands::audio::leave_voice_channel; use crate::data::guilds::GuildCache; -use crate::bot::ytdlp::{PlaylistItem, YtDlp}; +use crate::bot::ytdlp::{YtDlp, YtDlpItem}; use crate::error::{SirenResult, Error as SirenError}; use crate::HttpKey; @@ -57,12 +57,12 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { log::debug!("<{guild_id}> Play command executed on {channel_id} with track: {track_url:?}"); // Handle the track url match enqueue_track(ctx, manager, guild_id.to_owned(), track_url).await { - Ok(count) => { - let mut message = format!("Playing {} tracks", count); - if count == 0 { + Ok(items) => { + let mut message = format!("Added {} tracks", items.len()); + if items.len() == 0 { message = "No tracks were played".to_string(); - } else if count == 1 { - message = "Playing 1 track".to_string(); + } else if items.len() == 1 { + message = format!("Added **{}**", items[0].get_title()); } edit_response(&ctx, &command, message).await; } @@ -84,42 +84,32 @@ pub async fn enqueue_track( manager: Arc, guild_id: GuildId, track_url: &str, -) -> SirenResult { - let mut track_count = 0; +) -> SirenResult> { + let mut playlist_items: Vec = Vec::new(); if let Some(handler_lock) = manager.get(guild_id) { let mut handler = handler_lock.lock().await; let guild = GuildCache::get_by_id(guild_id.get() as i64).await?.unwrap(); let valid = is_valid_url(&track_url); + // Check if the URL is valid - if !valid.0 { + if !valid { log::warn!("<{guild_id}> Invalid track url: {}", track_url); return Err(SirenError::new( 422, format!("Invalid track url: {}", track_url), )); } - let mut playlist_items: Vec = Vec::new(); - // Check if the URL is a playlist or a single track - if valid.1 { - playlist_items = match get_playlist_urls(&track_url) { - Ok(items) => items, - Err(err) => { - log::warn!("<{guild_id}> Failed to get playlist urls: {}", err); - return Err(SirenError::new(422, err.to_string())); - } - }; - } else { - let playlist_item = PlaylistItem { - id: "".to_string(), - url: track_url.to_string(), - title: "".to_string(), - duration: 0, - playlist_index: 0, - }; - playlist_items.push(playlist_item); - } + + playlist_items = match get_ytdlp_items(&track_url) { + Ok(items) => items, + Err(err) => { + log::warn!("<{guild_id}> Failed to get playlist urls: {}", err); + return Err(SirenError::new(422, err.to_string())); + } + }; + // Add each track to the queue - for item in playlist_items { + for item in &playlist_items { let volume = guild.volume as f32 / 100.0; let http_client = { let data = ctx.data.read().await; @@ -128,21 +118,17 @@ pub async fn enqueue_track( .cloned() .expect("Guaranteed to exist in the typemap.") }; - let source = YoutubeDl::new(http_client, item.url.to_owned()); - let mut input: Input = source.into(); - let metadata = match input.aux_metadata().await { - Ok(metadata) => metadata, - Err(err) => { - log::warn!("<{guild_id}> Failed to get metadata for track: {err}"); - let _ = leave_voice_channel(&manager, &guild_id).await; - return Err(SirenError::new(422, err.to_string())); - } - }; + + let source = YoutubeDl::new(http_client, item.get_url().to_owned()); + let input: Input = source.into(); + let track_title = item.get_title().to_owned(); + let track_handle: TrackHandle; track_handle = handler.enqueue_input(input).await; + // Set the volume let _ = track_handle.set_volume(volume); - let track_title = metadata.title.unwrap(); + log::debug!("<{guild_id}> Added track: {}", track_title); handler.remove_all_global_events(); handler.add_global_event( @@ -152,29 +138,28 @@ pub async fn enqueue_track( call: manager.clone(), }, ); - track_count += 1; } if handler.queue().is_empty() { let _ = handler.queue().resume(); } } - Ok(track_count) + Ok(playlist_items) } -pub fn get_playlist_urls(url: &str) -> SirenResult> { +pub fn get_ytdlp_items(url: &str) -> SirenResult> { let output = YtDlp::new() .arg("--flat-playlist") .arg("--dump-json") .arg(url) .execute()?; - let items: Vec = String::from_utf8(output.stdout)? + let items: Vec = String::from_utf8(output.stdout)? .split('\n') .filter_map(|line| { if line.is_empty() { None } else { Some( - serde_json::from_slice::(line.as_bytes()) + serde_json::from_slice::(line.as_bytes()) .map_err(|err| SirenError::new(500, err.to_string())), ) } diff --git a/src/bot/commands/fun/roll.rs b/src/bot/commands/fun/roll.rs index 34fe946..7ce3c9b 100644 --- a/src/bot/commands/fun/roll.rs +++ b/src/bot/commands/fun/roll.rs @@ -1,32 +1,45 @@ use rand::Rng; use serenity::all::{ - CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption, Mentionable, UserId + CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption, Mentionable, + UserId, }; use crate::bot::commands::{create_response, edit_response, user_id_dm}; pub async fn run(ctx: &Context, command: &CommandInteraction) { // Check if the roll result is private - let private = command.data.options.iter().find(|opt| opt.name == "private") - .and_then(|o| o.value.as_bool()) - .unwrap_or(true); + let private = command + .data + .options + .iter() + .find(|opt| opt.name == "private") + .and_then(|o| o.value.as_bool()) + .unwrap_or(true); // Retrieve the DM's name or ID from the options (optional) - let user = command.data.options.iter().find(|opt| opt.name == "user") + let user = command + .data + .options + .iter() + .find(|opt| opt.name == "user") .and_then(|o| o.value.as_mentionable()); create_response(&ctx, &command, format!("Rolling..."), private).await; - let dice_string = match command.data.options.get(0) + let dice_string = match command + .data + .options + .get(0) .and_then(|o| o.value.as_str()) - .map(|s| s.split_whitespace().collect::()) { - Some(dice_value) => dice_value, - None => { - log::warn!("Missing or invalid dice option"); - let _ = edit_response(&ctx, &command, "Dice option is missing".to_string()).await; - return; - } - }; + .map(|s| s.split_whitespace().collect::()) + { + Some(dice_value) => dice_value, + None => { + log::warn!("Missing or invalid dice option"); + let _ = edit_response(&ctx, &command, "Dice option is missing".to_string()).await; + return; + } + }; let dice = parse_dice(dice_string.as_str()); match dice { @@ -51,14 +64,24 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { "".to_string() } ); - + match user { Some(id) => { let user_id = UserId::new(id.get()); - user_id_dm(&ctx, &user_id, format!("Dice roll from {}: {}", &command.user.mention(), response)).await; - edit_response(&ctx, command, format!("Sending dice roll results to {}", &user_id.mention())).await; - }, - None => edit_response(&ctx, &command, response).await + user_id_dm( + &ctx, + &user_id, + format!("Dice roll from {}: {}", &command.user.mention(), response), + ) + .await; + edit_response( + &ctx, + command, + format!("Sending dice roll results to {}", &user_id.mention()), + ) + .await; + } + None => edit_response(&ctx, &command, response).await, }; } Err(why) => { @@ -77,7 +100,7 @@ fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { let dice = if dice.starts_with("d") { format!("1{}", dice) // Prepend "1" } else { - dice.to_string() + dice.to_string() }; let mut parts = dice.split(['d', '+', '-'].as_ref()); @@ -87,49 +110,53 @@ fn parse_dice(dice: &str) -> Result<(u32, u32, i32), String> { let count = match parts.next() { Some("") => 1, // Handle cases like "d6", assume 1 dice Some(c) => match c.parse::() { - Ok(n) => n, - Err(_) => return Err(format!("Invalid dice count: {}", c)), + Ok(n) => n, + Err(_) => return Err(format!("Invalid dice count: {}", c)), }, None => return Err(format!("Invalid dice string: {}", dice)), }; // Parse the number of sides - let sides_part = parts.next().ok_or_else(|| format!("Invalid dice string: {}", dice))?; + let sides_part = parts + .next() + .ok_or_else(|| format!("Invalid dice string: {}", dice))?; let sides = match sides_part.parse::() { - Ok(n) => { - if [4, 6, 8, 10, 12, 20, 100].contains(&n) { - n - } else { - return Err(format!( - "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", - n - )); - } - } - Err(_) => return Err(format!( + Ok(n) => { + if [4, 6, 8, 10, 12, 20, 100].contains(&n) { + n + } else { + return Err(format!( "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", - sides_part - )), + n + )); + } + } + Err(_) => { + return Err(format!( + "Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}", + sides_part + )) + } }; // Determine if there's a modifier (+ or -) if dice.contains('+') { positive_modifier = true; } else if dice.contains('-') { - positive_modifier = false; + positive_modifier = false; } // Parse the modifier, if present let modifier = match parts.next() { Some(m) => match m.parse::() { - Ok(n) => { - if positive_modifier { - n - } else { - -n - } + Ok(n) => { + if positive_modifier { + n + } else { + -n } - Err(_) => return Err(format!("Invalid dice modifier: {}", m)), + } + Err(_) => return Err(format!("Invalid dice modifier: {}", m)), }, None => 0, // No modifier found }; @@ -152,7 +179,11 @@ pub fn register() -> CreateCommand { .required(false), ) .add_option( - CreateCommandOption::new(CommandOptionType::Mentionable, "user", "User to receive the roll results") + CreateCommandOption::new( + CommandOptionType::Mentionable, + "user", + "User to receive the roll results", + ) .required(false), ) } diff --git a/src/bot/commands/mod.rs b/src/bot/commands/mod.rs index c44df9b..21c0703 100644 --- a/src/bot/commands/mod.rs +++ b/src/bot/commands/mod.rs @@ -1,6 +1,7 @@ use serenity::prelude::*; use serenity::all::{ - CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, CreateMessage, EditInteractionResponse, InteractionResponseFlags, Message, User, UserId + CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, CreateMessage, + EditInteractionResponse, InteractionResponseFlags, Message, User, UserId, }; pub mod audio; @@ -13,11 +14,7 @@ pub async fn process_message(ctx: &Context, command: &CommandInteraction, privat create_response(&ctx, &command, format!("Processing..."), private).await; } -pub async fn user_id_dm( - ctx: &Context, - user_id: &UserId, - content: String, -) -> Option { +pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Option { let data = CreateMessage::new().content(content.to_owned()); return match user_id.dm(ctx, data).await { Ok(message) => Some(message), @@ -28,11 +25,7 @@ pub async fn user_id_dm( }; } -pub async fn user_dm( - ctx: &Context, - user: &User, - content: String, -) -> Option { +pub async fn user_dm(ctx: &Context, user: &User, content: String) -> Option { let data = CreateMessage::new().content(content.to_owned()); return match user.direct_message(ctx, data).await { Ok(message) => Some(message), diff --git a/src/bot/ytdlp/model.rs b/src/bot/ytdlp/model.rs index 8f83b98..df7d7f6 100644 --- a/src/bot/ytdlp/model.rs +++ b/src/bot/ytdlp/model.rs @@ -1,10 +1,35 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] -pub struct PlaylistItem { - pub id: String, - pub url: String, - pub title: String, - pub duration: i32, - pub playlist_index: i32, +#[serde(untagged)] +pub enum YtDlpItem { + PlaylistItem { + id: String, + url: String, + title: String, + duration: i32, + playlist_index: i32, + }, + VideoItem { + id: String, + webpage_url: String, + title: String, + duration: i32, + }, +} + +impl YtDlpItem { + pub fn get_title(&self) -> &str { + match self { + YtDlpItem::PlaylistItem { title, .. } => title, + YtDlpItem::VideoItem { title, .. } => title, + } + } + + pub fn get_url(&self) -> &str { + match self { + YtDlpItem::PlaylistItem { url, .. } => url, + YtDlpItem::VideoItem { webpage_url, .. } => webpage_url, + } + } }