From 4a18af9014b680fc946ab1dddd03d9999feaa3fa Mon Sep 17 00:00:00 2001 From: Benjamin Sherriff Date: Thu, 19 Dec 2024 13:50:31 -0500 Subject: [PATCH] Re-implementing the API --- .env | 5 +- Cargo.toml | 2 + docker-compose.yml | 17 ++- migrations/000_base.sql | 26 ---- migrations/000_initial.sql | 73 ++++++++++ migrations/001_dnd_tables.sql | 43 ------ src/api/app.rs | 29 ++++ src/api/mod.rs | 14 +- src/api/oauth.rs | 220 +++++++++++++++++++++++++++++++ src/bot/commands/audio/play.rs | 8 +- src/bot/commands/chat.rs | 4 +- src/bot/commands/utility/ping.rs | 8 ++ src/bot/handler.rs | 7 +- src/bot/oai/model.rs | 4 +- src/data/guilds/model.rs | 35 ++--- src/error.rs | 44 ++++++- src/main.rs | 99 +++++++------- 17 files changed, 486 insertions(+), 152 deletions(-) delete mode 100644 migrations/000_base.sql create mode 100644 migrations/000_initial.sql delete mode 100644 migrations/001_dnd_tables.sql create mode 100644 src/api/app.rs create mode 100644 src/api/oauth.rs diff --git a/.env b/.env index 3dd5996..8a212e6 100644 --- a/.env +++ b/.env @@ -1,6 +1,7 @@ RUST_LOG=warn,siren=info DISCORD_TOKEN= +DISCORD_SECRET= DATABASE_USER=siren DATABASE_PASSWORD=CHANGEME # Change this to a secure password @@ -8,7 +9,9 @@ DATABASE_NAME=siren DATABASE_HOST=localhost DATABASE_PORT=5432 -SESSION_TTL=1440 +API_CALLBACK_URI=http://localhost:3000/api/oauth/callback +API_PORT=3000 +API_SESSION_TTL=86400 MINIO_ROOT_USER=siren MINIO_ROOT_PASSWORD=CHANGEME # Change this to a secure password diff --git a/Cargo.toml b/Cargo.toml index 2cc2e3a..838d75e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,8 +23,10 @@ reqwest = { version = "0.11", default-features = false, features = ["json"] } uuid = { version = "1.11.0", features = ["serde", "v4"] } redis = { version = "0.27.4", features = ["tokio-comp", "connection-manager", "r2d2"] } rand = "0.8.5" +rand_chacha = "0.3.1" tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] } regex = "1.11.0" axum = "0.7.7" lazy_static = "1.5.0" futures = "0.3.31" +axum-login = "0.16.0" diff --git a/docker-compose.yml b/docker-compose.yml index a34fabd..e166fe7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,7 +13,7 @@ services: environment: DATABASE_HOST: siren-postgres DATABASE_PORT: 5432 - REDIS_HOST: redis + REDIS_HOST: siren-redis REDIS_PORT: 6379 DATA_DIR_PATH: /data volumes: @@ -42,14 +42,27 @@ services: - ${DATABASE_PORT:-5432}:5432 networks: - backend + restart: unless-stopped profiles: - backend - restart: unless-stopped + redis: + image: redis:latest + container_name: siren-redis + volumes: + - redis:/data + ports: + - ${REDIS_PORT:-6379}:6379 + networks: + - backend + restart: unless-stopped + profiles: + - backend volumes: postgres: postgres_logs: + redis: networks: frontend: diff --git a/migrations/000_base.sql b/migrations/000_base.sql deleted file mode 100644 index 0fc6125..0000000 --- a/migrations/000_base.sql +++ /dev/null @@ -1,26 +0,0 @@ -CREATE TABLE IF NOT EXISTS guilds ( - id BIGINT PRIMARY KEY NOT NULL, - bot_id BIGINT NOT NULL, - volume INTEGER NOT NULL -); -CREATE TABLE IF NOT EXISTS messages ( - id TEXT PRIMARY KEY NOT NULL, - guild_id BIGINT NOT NULL, - channel_id BIGINT NOT NULL, - author_id BIGINT NOT NULL, - created BIGINT NOT NULL, - model TEXT NOT NULL, - request TEXT NOT NULL, - response TEXT NOT NULL, - request_tags TEXT[] NOT NULL, - response_tags TEXT[] NOT NULL -); -CREATE TABLE IF NOT EXISTS events ( - id UUID PRIMARY KEY NOT NULL, - guild_id BIGINT NOT NULL, - author_id BIGINT NOT NULL, - title TEXT NOT NULL, - date_time TIMESTAMP NOT NULL, - description TEXT, - rsvp BIGINT[] NOT NULL -); diff --git a/migrations/000_initial.sql b/migrations/000_initial.sql new file mode 100644 index 0000000..3559238 --- /dev/null +++ b/migrations/000_initial.sql @@ -0,0 +1,73 @@ +CREATE TABLE IF NOT EXISTS guilds ( + id BIGINT PRIMARY KEY NOT NULL, + name TEXT, + owner_id BIGINT, + volume INTEGER NOT NULL +); +CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY NOT NULL, + guild_id BIGINT NOT NULL, + channel_id BIGINT NOT NULL, + author_id BIGINT NOT NULL, + created BIGINT NOT NULL, + model TEXT NOT NULL, + request TEXT NOT NULL, + response TEXT NOT NULL, + request_tags TEXT[] NOT NULL, + response_tags TEXT[] NOT NULL +); +CREATE TABLE IF NOT EXISTS dice_rolls ( + id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid() +); +CREATE TABLE IF NOT EXISTS events ( + id UUID PRIMARY KEY NOT NULL, + guild_id BIGINT NOT NULL, + author_id BIGINT NOT NULL, + title TEXT NOT NULL, + date_time TIMESTAMP NOT NULL, + description TEXT, + rsvp BIGINT[] NOT NULL +); +CREATE TABLE IF NOT EXISTS races ( + id INTEGER GENERATED ALWAYS AS IDENTITY, + name TEXT NOT NULL, + size TEXT NOT NULL, + source TEXT NOT NULL, + data JSON NOT NULL +); +CREATE TABLE IF NOT EXISTS classes ( + id INTEGER GENERATED ALWAYS AS IDENTITY +); +CREATE TABLE IF NOT EXISTS feats ( + id INTEGER GENERATED ALWAYS AS IDENTITY +); +CREATE TABLE IF NOT EXISTS options_features ( + id INTEGER GENERATED ALWAYS AS IDENTITY +); +CREATE TABLE IF NOT EXISTS backgrounds ( + id INTEGER GENERATED ALWAYS AS IDENTITY +); +CREATE TABLE IF NOT EXISTS items ( + id INTEGER GENERATED ALWAYS AS IDENTITY +); +CREATE TABLE IF NOT EXISTS spells ( + id INTEGER GENERATED ALWAYS AS IDENTITY, + name TEXT NOT NULL, + school TEXT NOT NULL, + level INTEGER NOT NULL, + ritual BOOLEAN DEFAULT FALSE, + concentration BOOLEAN DEFAULT FALSE, + classes TEXT[] NOT NULL, + damage_inflict TEXT[] NOT NULL, + damage_resist TEXT[] NOT NULL, + conditions TEXT[] NOT NULL, + saving_throw TEXT[] NOT NULL, + attack_type TEXT, + data JSONB NOT NULL +); +CREATE TABLE IF NOT EXISTS conditions ( + id INTEGER GENERATED ALWAYS AS IDENTITY +); +CREATE TABLE IF NOT EXISTS bestiary ( + id INTEGER GENERATED ALWAYS AS IDENTITY +); diff --git a/migrations/001_dnd_tables.sql b/migrations/001_dnd_tables.sql deleted file mode 100644 index 98ca9dd..0000000 --- a/migrations/001_dnd_tables.sql +++ /dev/null @@ -1,43 +0,0 @@ -CREATE TABLE IF NOT EXISTS races ( - id INTEGER GENERATED ALWAYS AS IDENTITY, - name TEXT NOT NULL, - size TEXT NOT NULL, - source TEXT NOT NULL, - data JSON NOT NULL -); -CREATE TABLE IF NOT EXISTS classes ( - id INTEGER GENERATED ALWAYS AS IDENTITY -); -CREATE TABLE IF NOT EXISTS feats ( - id INTEGER GENERATED ALWAYS AS IDENTITY -); -CREATE TABLE IF NOT EXISTS options_features ( - id INTEGER GENERATED ALWAYS AS IDENTITY -); -CREATE TABLE IF NOT EXISTS backgrounds ( - id INTEGER GENERATED ALWAYS AS IDENTITY -); -CREATE TABLE IF NOT EXISTS items ( - id INTEGER GENERATED ALWAYS AS IDENTITY -); -CREATE TABLE IF NOT EXISTS spells ( - id INTEGER GENERATED ALWAYS AS IDENTITY, - name TEXT NOT NULL, - school TEXT NOT NULL, - level INTEGER NOT NULL, - ritual BOOLEAN DEFAULT FALSE, - concentration BOOLEAN DEFAULT FALSE, - classes TEXT[] NOT NULL, - damage_inflict TEXT[] NOT NULL, - damage_resist TEXT[] NOT NULL, - conditions TEXT[] NOT NULL, - saving_throw TEXT[] NOT NULL, - attack_type TEXT, - data JSONB NOT NULL -); -CREATE TABLE IF NOT EXISTS conditions ( - id INTEGER GENERATED ALWAYS AS IDENTITY -); -CREATE TABLE IF NOT EXISTS bestiary ( - id INTEGER GENERATED ALWAYS AS IDENTITY -); \ No newline at end of file diff --git a/src/api/app.rs b/src/api/app.rs new file mode 100644 index 0000000..b906b36 --- /dev/null +++ b/src/api/app.rs @@ -0,0 +1,29 @@ +use std::env; +use std::sync::Arc; +use axum::Router; +use tokio::net::TcpListener; +use crate::{api, AppState}; +use crate::error::SirenResult; + +pub struct App { + app_state: AppState, +} + +impl App { + pub fn new(app_state: AppState) -> Self { + Self { app_state } + } + + pub async fn serve(self) -> SirenResult<()> { + let app = Router::new() + .nest("/api", api::get_routes()) + .with_state(Arc::new(self.app_state)); + + let api_port: String = env::var("API_PORT").expect("Expected a port in the environment"); + let addr = format!("0.0.0.0:{}", api_port); + + let listener = TcpListener::bind(&addr).await?; + log::info!("API is listening on {}", &addr); + Ok(axum::serve(listener, app).await?) + } +} diff --git a/src/api/mod.rs b/src/api/mod.rs index 429f314..7c5bd9b 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,5 +1,13 @@ -use axum::Router; +mod app; +mod oauth; -pub fn get_routes() -> Router { - Router::new() +pub use app::App; + +use std::sync::Arc; +use axum::Router; +use serde::{Deserialize, Serialize}; +use crate::AppState; + +pub fn get_routes() -> Router> { + Router::new().nest("/oauth", oauth::get_routes()) } diff --git a/src/api/oauth.rs b/src/api/oauth.rs new file mode 100644 index 0000000..cd7dbaa --- /dev/null +++ b/src/api/oauth.rs @@ -0,0 +1,220 @@ +use std::env; +use std::sync::{Arc, OnceLock}; +use axum::extract::{Query, State}; +use axum::http::{HeaderMap, HeaderValue, StatusCode}; +use axum::{Json, Router}; +use axum::http::header::SET_COOKIE; +use axum::response::Redirect; +use axum::routing::get; +use chrono::{DateTime, Utc}; +use rand::Rng; +use rand_chacha::ChaCha20Rng; +use rand_chacha::rand_core::SeedableRng; +use redis::{AsyncCommands, RedisResult}; +use serde::{Deserialize, Serialize}; +use crate::{data, AppState}; +use crate::error::SirenResult; + +static SESSION_TTL: OnceLock = OnceLock::new(); + +pub fn get_routes() -> Router> { + Router::new() + .route("/authorize", get(discord_authorize)) + .route("/callback", get(oauth_callback)) +} + +fn get_session_ttl() -> i64 { + // Initialize the SESSION_TTL value lazily + *SESSION_TTL.get_or_init(|| { + env::var("SESSION_TTL") + .ok() + .and_then(|val| val.parse::().ok()) + .unwrap_or(3600) // Default to 3600 seconds (1 hour) + }) +} + +pub fn csprng(take: usize) -> String { + // Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9) + let rng = ChaCha20Rng::from_entropy(); + rng + .sample_iter(rand::distributions::Alphanumeric) + .take(take) + .map(char::from) + .collect() +} + +#[derive(Deserialize)] +struct AuthQuery { + code: String, + state: Option, +} + +#[derive(Serialize, Deserialize)] +struct TokenResponse { + access_token: String, + token_type: String, + expires_in: u64, + refresh_token: String, + scope: String, +} + +#[derive(Serialize, Deserialize, Debug)] +struct DiscordUser { + id: String, + username: String, + discriminator: String, + avatar: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +struct Session { + session_id: String, + user_id: String, + user_name: String, + pub expires_at: DateTime, +} + +impl Session { + fn new(id: String, user_id: String, user_name: String) -> Session { + let now = Utc::now(); + let session_ttl = get_session_ttl(); + Session { + session_id: id, + user_id, + user_name, + expires_at: now + chrono::Duration::seconds(session_ttl), + } + } + + async fn insert(&self) -> SirenResult<()> { + let mut redis = data::redis_async_connection().await?; + let session_id = self.session_id.clone(); + redis + .set_ex( + session_id, + serde_json::to_string(self)?, + self.expires_at.timestamp() as u64, + ) + .await?; + Ok(()) + } + + async fn get(session_id: String) -> SirenResult> { + let mut redis = data::redis_async_connection().await?; + let result: RedisResult> = redis.get(session_id).await; + match result { + Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)), + Ok(None) => Ok(None), + Err(err) => Err(err.into()), + } + } + + async fn delete(session_id: String) -> SirenResult<()> { + let mut redis = data::redis_async_connection().await?; + let result: RedisResult<()> = redis.del(session_id).await; + match result { + Ok(_) => Ok(()), + Err(err) => Err(err.into()), + } + } +} + +// async fn discord_authorize_redirect(State(state): State>) -> Redirect { +// // Construct the Discord OAuth URL +// let discord_auth_url = format!( +// "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify", +// state.client_id, state.redirect_uri +// ); +// Redirect::temporary(&discord_auth_url) +// } + +async fn discord_authorize(State(state): State>) -> SirenResult { + // Construct the Discord OAuth URL + let discord_auth_url = format!( + "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify", + state.client_id, state.redirect_uri + ); + Ok(discord_auth_url) +} + +async fn oauth_callback( + State(state): State>, + Query(query): Query, +) -> SirenResult<(HeaderMap, Json)> { + // Exchange code for an access token + let token_response = state + .client + .post("https://discord.com/api/oauth2/token") + .form(&[ + ("client_id", state.client_id.as_str()), + ("client_secret", state.client_secret.as_str()), + ("grant_type", "authorization_code"), + ("code", query.code.as_str()), + ("redirect_uri", state.redirect_uri.as_str()), + ]) + .send() + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if !token_response.status().is_success() { + log::error!( + "Failed to exchange token: {:?}", + token_response.text().await + ); + return Err(StatusCode::INTERNAL_SERVER_ERROR.into()); + } + + let token_data: TokenResponse = token_response + .json() + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // Fetch user information + let user_response = state + .client + .get("https://discord.com/api/users/@me") + .bearer_auth(token_data.access_token) + .send() + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if !user_response.status().is_success() { + log::error!( + "Failed to fetch user information: {:?}", + user_response.text().await + ); + return Err(StatusCode::INTERNAL_SERVER_ERROR.into()); + } + + let user_data: DiscordUser = user_response + .json() + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + log::debug!("User authenticated: {:?}", user_data); + + // Generate a session token + let session_token = csprng(16); + let expiration = env::var("API_SESSION_TTL") + .expect("Expected a session ttl in the environment") + .parse::() + .unwrap(); + + // Create and insert the session + let session = Session::new( + session_token.clone(), + user_data.id.clone(), + user_data.username.clone(), + ); + session.insert().await?; + + let cookie_value = format!( + "session={}; HttpOnly; Path=/; Max-Age={}", + session_token, expiration + ); + + let mut headers = HeaderMap::new(); + headers.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap()); + + Ok((headers, Json(user_data))) +} diff --git a/src/bot/commands/audio/play.rs b/src/bot/commands/audio/play.rs index fc73461..a342e0d 100644 --- a/src/bot/commands/audio/play.rs +++ b/src/bot/commands/audio/play.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use serenity::all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption}; use serenity::model::prelude::GuildId; -use serenity::{prelude::*, async_trait, futures}; +use serenity::{prelude::*, async_trait}; use songbird::input::{Input, YoutubeDl}; use songbird::tracks::TrackHandle; use songbird::{Event, EventHandler, Songbird, TrackEvent}; @@ -25,7 +25,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { "{} attempted to play a track without a track option", command.user.id.get() ); - create_message_response(&ctx, &command, format!("Track option is missing"), false).await; + create_message_response(&ctx, &command, "Track option is missing".to_string(), false).await; return; } }; @@ -53,7 +53,9 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) { // Join the user's voice channel match join_voice_channel(&ctx.cache, &manager, guild_id, &command.user).await { Ok(channel_id) => { - log::debug!("<{guild_id}> Play command executed on channel {channel_id} with track: {track_url:?}"); + log::debug!( + "<{guild_id}> Play command executed on channel {channel_id} with track: {track_url:?}" + ); // Handle the track url match enqueue_track(ctx, manager, guild_id.to_owned(), track_url).await { Ok(items) => { diff --git a/src/bot/commands/chat.rs b/src/bot/commands/chat.rs index a72fb34..ab2716f 100644 --- a/src/bot/commands/chat.rs +++ b/src/bot/commands/chat.rs @@ -124,7 +124,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) { Err(err) => { log::error!( "<{guild_id}> <{channel_id}> <{author_id}> Could not get response from OpenAI: {}", - err.message + err.details ); "There was an error processing your message. Please try again later.".to_string() } @@ -196,7 +196,7 @@ async fn generate_thread_name(oai: &OAI, s: &str, max_chars: usize) -> String { } } Err(err) => { - log::error!("Could not get response from OpenAI: {}", err.message); + log::error!("Could not get response from OpenAI: {}", err.details); } }; return response; diff --git a/src/bot/commands/utility/ping.rs b/src/bot/commands/utility/ping.rs index e6622e6..6e24f70 100644 --- a/src/bot/commands/utility/ping.rs +++ b/src/bot/commands/utility/ping.rs @@ -3,6 +3,14 @@ use crate::bot::chat::create_message_response; pub async fn run(ctx: &Context, command: &CommandInteraction) { log::debug!("Ping command executed"); + + if let Some(guild_id) = command.guild_id { + if let Some(guild) = guild_id.to_guild_cached(&ctx.cache) { + let owner_id = guild.owner_id; + if command.user.id == owner_id {} + } + } + create_message_response(&ctx, &command, "pong".to_string(), true).await; } diff --git a/src/bot/handler.rs b/src/bot/handler.rs index 447dbfd..b7a4c83 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -9,13 +9,13 @@ use crate::data::guilds::GuildCache; use super::{commands}; use super::chat::{create_message_response, create_modal_response}; -pub struct Handler { +pub struct BotHandler { // Open AI Config pub oai: Option, } #[async_trait] -impl EventHandler for Handler { +impl EventHandler for BotHandler { async fn message(&self, ctx: Context, msg: Message) { // Ignore bot messages if msg.author.bot { @@ -47,7 +47,8 @@ impl EventHandler for Handler { if let None = GuildCache::get_by_id(guild_id).await.unwrap() { let guild_cache = GuildCache { id: guild_id, - bot_id: 1, + name: guild.id.name(&ctx.cache), + owner_id: None, volume: 100, }; guild_cache.insert().await.unwrap(); diff --git a/src/bot/oai/model.rs b/src/bot/oai/model.rs index b580997..b5b0dec 100644 --- a/src/bot/oai/model.rs +++ b/src/bot/oai/model.rs @@ -129,7 +129,7 @@ impl OAI { ResponseEvent::ResponseError(error) => { return Err(SirenError { status: 500, - message: format!("Error: {}", error.message.unwrap()), + details: format!("Error: {}", error.message.unwrap()), }); } } @@ -137,7 +137,7 @@ impl OAI { Err(err) => { return Err(SirenError { status: 500, - message: format!("Error: {}", err), + details: format!("Error: {}", err), }) } } diff --git a/src/data/guilds/model.rs b/src/data/guilds/model.rs index 8d7f06e..54f5cde 100644 --- a/src/data/guilds/model.rs +++ b/src/data/guilds/model.rs @@ -6,7 +6,8 @@ const TABLE_NAME: &str = "guilds"; #[derive(Debug, Serialize, Deserialize, sqlx::FromRow)] pub struct GuildCache { pub id: i64, - pub bot_id: i64, + pub name: Option, + pub owner_id: Option, pub volume: i32, } @@ -16,18 +17,20 @@ impl GuildCache { sqlx::query(&format!( "INSERT INTO {} ( id, - bot_id, + name, + owner_id, volume ) VALUES ( - $1, $2, $3 + $1, $2, $3, $4 )", TABLE_NAME )) - .bind(self.id) - .bind(self.bot_id) - .bind(self.volume) - .execute(pool) - .await?; + .bind(self.id) + .bind(&self.name) + .bind(self.owner_id) + .bind(self.volume) + .execute(pool) + .await?; Ok(()) } @@ -45,16 +48,18 @@ impl GuildCache { let pool = crate::data::pool(); sqlx::query(&format!( "UPDATE {} SET - bot_id = $2, - volume = $3 + name = $2, + owner_id = $3, + volume = $4 WHERE id = $1", TABLE_NAME )) - .bind(self.id) - .bind(self.bot_id) - .bind(self.volume) - .execute(pool) - .await?; + .bind(self.id) + .bind(&self.name) + .bind(self.owner_id) + .bind(self.volume) + .execute(pool) + .await?; Ok(()) } } diff --git a/src/error.rs b/src/error.rs index 28a14a5..60ff4d2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,7 @@ use std::fmt; +use axum::http::StatusCode; +use axum::Json; +use axum::response::{IntoResponse, Response}; use serde::{Deserialize, Serialize}; pub type SirenResult = Result; @@ -6,21 +9,44 @@ pub type SirenResult = Result; #[derive(Debug, Deserialize, Serialize)] pub struct Error { pub status: u16, - pub message: String, + pub details: String, } impl Error { pub fn new(error_status_code: u16, error_message: String) -> Self { Self { status: error_status_code, - message: error_message, + details: error_message, } } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(self.message.as_str()) + f.write_str(self.details.as_str()) + } +} + +impl std::error::Error for Error { + fn description(&self) -> &str { + &self.details + } +} + +impl IntoResponse for Error { + fn into_response(self) -> Response { + let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + // Create a JSON response with the structured error + let body = Json(serde_json::json!({ + "error": { + "status": self.status, + "details": self.details, + } + })); + + // Return the response with the proper status and error body + (status, body).into_response() } } @@ -30,6 +56,18 @@ impl From for Error { } } +impl From for Error { + fn from(status: StatusCode) -> Self { + Error { + status: status.as_u16(), + details: status + .canonical_reason() + .unwrap_or("Unknown error") + .to_string(), + } + } +} + impl From for Error { fn from(error: std::string::FromUtf8Error) -> Self { Self::new(500, format!("Unknown from utf8 error: {}", error)) diff --git a/src/main.rs b/src/main.rs index 1f6c959..ad8597a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,13 @@ use std::env; use std::collections::HashSet; use std::sync::Arc; -use axum::Router; use serenity::http::Http; use serenity::prelude::*; use songbird::{SerenityInit, Songbird}; use reqwest::Client as HttpClient; -use serenity::all::{ShardManager, UserId}; -use tokio::net::TcpListener; - -use crate::bot::handler::Handler; +use serenity::all::{Cache, ShardManager, UserId}; +use crate::api::App; +use crate::bot::handler::BotHandler; use crate::bot::oai::OAI; mod api; @@ -24,47 +22,24 @@ impl TypeMapKey for HttpKey { type Value = HttpClient; } +#[derive(Clone)] +struct AppState { + client: reqwest::Client, + client_id: String, + client_secret: String, + redirect_uri: String, + http: Arc, + cache: Arc, +} + #[tokio::main] -async fn main() { +async fn main() -> Result<(), Box> { dotenv::dotenv().ok(); env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info")); - if let Err(err) = data::initialize().await { - log::error!("Failed to initialize database: {err}"); - return; - }; + data::initialize().await?; - // Start API server - tokio::spawn(start_api()); - - // Start Discord bot - start_bot().await; -} - -async fn start_api() { - let app = Router::new(); - let addr: String = "127.0.0.1:3000".parse().unwrap(); - - let listener = TcpListener::bind(&addr).await.unwrap(); - log::debug!("API is listening on {}", &addr); - axum::serve(listener, app).await.unwrap(); -} - -async fn start_bot() { let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); - let intents: GatewayIntents = GatewayIntents::all(); - - let http: Http = Http::new(&token); - let (owners, bot_id) = get_bot_info(&http).await; - - log::debug!( - "Starting Discord bot with ID: {bot_id} and owners: {}", - owners - .iter() - .map(|id| id.to_string()) - .collect::>() - .join(", ") - ); // Set up handler with optional OpenAI integration let handler = configure_handler(); @@ -72,6 +47,8 @@ async fn start_bot() { // Set up Songbird for voice functionality let songbird = Songbird::serenity(); + let intents: GatewayIntents = GatewayIntents::all(); + let mut client = Client::builder(token, intents) .event_handler(handler) // .framework(StandardFramework::new().configure(|c| c.owners(owners))) @@ -80,29 +57,53 @@ async fn start_bot() { .await .expect("Error creating client"); + let (bot_owner, bot_id) = get_bot_info(&client.http).await; + + let client_secret: String = + env::var("DISCORD_SECRET").expect("Expected a secret in the environment"); + let redirect_uri: String = + env::var("API_CALLBACK_URI").expect("Expected a secret in the environment"); + let app_state = AppState { + client: HttpClient::new(), + client_id: bot_id.to_string(), + client_secret, + redirect_uri, + http: Arc::clone(&client.http), + cache: Arc::clone(&client.cache), + }; + + log::debug!("Starting Siren with ID: {bot_id} (Contact: {:?})", bot_owner); + // Spawn shutdown signal handling let shard_manager = Arc::clone(&client.shard_manager); tokio::spawn(async move { signal_shutdown(shard_manager).await; }); - // Start the bot + // Start API server + tokio::spawn(App::new(app_state).serve()); + + // Start Discord bot if let Err(why) = client.start_autosharded().await { log::error!("Client error: {why:?}"); } + + Ok(()) } -async fn get_bot_info(http: &Http) -> (HashSet, UserId) { +async fn get_bot_info(http: &Http) -> (Option, UserId) { match http.get_current_application_info().await { Ok(info) => { - let mut owners = HashSet::new(); + let bot_owner; if let Some(team) = info.team { - owners.insert(team.owner_user_id); + bot_owner = Some(team.owner_user_id); } else if let Some(owner) = info.owner { - owners.insert(owner.id); + bot_owner = Some(owner.id); + } else { + bot_owner = None; } match http.get_current_user().await { - Ok(bot) => (owners, bot.id), + Ok(bot) => (bot_owner, bot.id), Err(why) => panic!("Could not access the bot id: {why:?}"), } } @@ -110,13 +111,13 @@ async fn get_bot_info(http: &Http) -> (HashSet, UserId) { } } -fn configure_handler() -> Handler { +fn configure_handler() -> BotHandler { match env::var("OPENAI_TOKEN") { Ok(token) => { log::debug!("OpenAI functionality enabled"); let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()); let base_url = env::var("OPENAI_BASE_URL").unwrap(); - Handler { + BotHandler { oai: Some(OAI { client: reqwest::Client::new(), base_url, @@ -129,7 +130,7 @@ fn configure_handler() -> Handler { } Err(_) => { log::warn!("OpenAI functionality disabled"); - Handler { oai: None } + BotHandler { oai: None } } } }