Re-implementing the API

This commit is contained in:
2024-12-19 13:50:31 -05:00
parent 9344979d72
commit 4a18af9014
17 changed files with 486 additions and 152 deletions

5
.env
View File

@@ -1,6 +1,7 @@
RUST_LOG=warn,siren=info RUST_LOG=warn,siren=info
DISCORD_TOKEN= DISCORD_TOKEN=
DISCORD_SECRET=
DATABASE_USER=siren DATABASE_USER=siren
DATABASE_PASSWORD=CHANGEME # Change this to a secure password DATABASE_PASSWORD=CHANGEME # Change this to a secure password
@@ -8,7 +9,9 @@ DATABASE_NAME=siren
DATABASE_HOST=localhost DATABASE_HOST=localhost
DATABASE_PORT=5432 DATABASE_PORT=5432
SESSION_TTL=1440 API_CALLBACK_URI=http://localhost:3000/api/oauth/callback
API_PORT=3000
API_SESSION_TTL=86400
MINIO_ROOT_USER=siren MINIO_ROOT_USER=siren
MINIO_ROOT_PASSWORD=CHANGEME # Change this to a secure password MINIO_ROOT_PASSWORD=CHANGEME # Change this to a secure password

View File

@@ -23,8 +23,10 @@ reqwest = { version = "0.11", default-features = false, features = ["json"] }
uuid = { version = "1.11.0", features = ["serde", "v4"] } uuid = { version = "1.11.0", features = ["serde", "v4"] }
redis = { version = "0.27.4", features = ["tokio-comp", "connection-manager", "r2d2"] } redis = { version = "0.27.4", features = ["tokio-comp", "connection-manager", "r2d2"] }
rand = "0.8.5" rand = "0.8.5"
rand_chacha = "0.3.1"
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] } tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] }
regex = "1.11.0" regex = "1.11.0"
axum = "0.7.7" axum = "0.7.7"
lazy_static = "1.5.0" lazy_static = "1.5.0"
futures = "0.3.31" futures = "0.3.31"
axum-login = "0.16.0"

View File

@@ -13,7 +13,7 @@ services:
environment: environment:
DATABASE_HOST: siren-postgres DATABASE_HOST: siren-postgres
DATABASE_PORT: 5432 DATABASE_PORT: 5432
REDIS_HOST: redis REDIS_HOST: siren-redis
REDIS_PORT: 6379 REDIS_PORT: 6379
DATA_DIR_PATH: /data DATA_DIR_PATH: /data
volumes: volumes:
@@ -42,14 +42,27 @@ services:
- ${DATABASE_PORT:-5432}:5432 - ${DATABASE_PORT:-5432}:5432
networks: networks:
- backend - backend
restart: unless-stopped
profiles: profiles:
- backend - backend
restart: unless-stopped
redis:
image: redis:latest
container_name: siren-redis
volumes:
- redis:/data
ports:
- ${REDIS_PORT:-6379}:6379
networks:
- backend
restart: unless-stopped
profiles:
- backend
volumes: volumes:
postgres: postgres:
postgres_logs: postgres_logs:
redis:
networks: networks:
frontend: frontend:

View File

@@ -1,26 +0,0 @@
CREATE TABLE IF NOT EXISTS guilds (
id BIGINT PRIMARY KEY NOT NULL,
bot_id BIGINT NOT NULL,
volume INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY NOT NULL,
guild_id BIGINT NOT NULL,
channel_id BIGINT NOT NULL,
author_id BIGINT NOT NULL,
created BIGINT NOT NULL,
model TEXT NOT NULL,
request TEXT NOT NULL,
response TEXT NOT NULL,
request_tags TEXT[] NOT NULL,
response_tags TEXT[] NOT NULL
);
CREATE TABLE IF NOT EXISTS events (
id UUID PRIMARY KEY NOT NULL,
guild_id BIGINT NOT NULL,
author_id BIGINT NOT NULL,
title TEXT NOT NULL,
date_time TIMESTAMP NOT NULL,
description TEXT,
rsvp BIGINT[] NOT NULL
);

View File

@@ -0,0 +1,73 @@
CREATE TABLE IF NOT EXISTS guilds (
id BIGINT PRIMARY KEY NOT NULL,
name TEXT,
owner_id BIGINT,
volume INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY NOT NULL,
guild_id BIGINT NOT NULL,
channel_id BIGINT NOT NULL,
author_id BIGINT NOT NULL,
created BIGINT NOT NULL,
model TEXT NOT NULL,
request TEXT NOT NULL,
response TEXT NOT NULL,
request_tags TEXT[] NOT NULL,
response_tags TEXT[] NOT NULL
);
CREATE TABLE IF NOT EXISTS dice_rolls (
id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid()
);
CREATE TABLE IF NOT EXISTS events (
id UUID PRIMARY KEY NOT NULL,
guild_id BIGINT NOT NULL,
author_id BIGINT NOT NULL,
title TEXT NOT NULL,
date_time TIMESTAMP NOT NULL,
description TEXT,
rsvp BIGINT[] NOT NULL
);
CREATE TABLE IF NOT EXISTS races (
id INTEGER GENERATED ALWAYS AS IDENTITY,
name TEXT NOT NULL,
size TEXT NOT NULL,
source TEXT NOT NULL,
data JSON NOT NULL
);
CREATE TABLE IF NOT EXISTS classes (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS feats (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS options_features (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS backgrounds (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS items (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS spells (
id INTEGER GENERATED ALWAYS AS IDENTITY,
name TEXT NOT NULL,
school TEXT NOT NULL,
level INTEGER NOT NULL,
ritual BOOLEAN DEFAULT FALSE,
concentration BOOLEAN DEFAULT FALSE,
classes TEXT[] NOT NULL,
damage_inflict TEXT[] NOT NULL,
damage_resist TEXT[] NOT NULL,
conditions TEXT[] NOT NULL,
saving_throw TEXT[] NOT NULL,
attack_type TEXT,
data JSONB NOT NULL
);
CREATE TABLE IF NOT EXISTS conditions (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS bestiary (
id INTEGER GENERATED ALWAYS AS IDENTITY
);

View File

@@ -1,43 +0,0 @@
CREATE TABLE IF NOT EXISTS races (
id INTEGER GENERATED ALWAYS AS IDENTITY,
name TEXT NOT NULL,
size TEXT NOT NULL,
source TEXT NOT NULL,
data JSON NOT NULL
);
CREATE TABLE IF NOT EXISTS classes (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS feats (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS options_features (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS backgrounds (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS items (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS spells (
id INTEGER GENERATED ALWAYS AS IDENTITY,
name TEXT NOT NULL,
school TEXT NOT NULL,
level INTEGER NOT NULL,
ritual BOOLEAN DEFAULT FALSE,
concentration BOOLEAN DEFAULT FALSE,
classes TEXT[] NOT NULL,
damage_inflict TEXT[] NOT NULL,
damage_resist TEXT[] NOT NULL,
conditions TEXT[] NOT NULL,
saving_throw TEXT[] NOT NULL,
attack_type TEXT,
data JSONB NOT NULL
);
CREATE TABLE IF NOT EXISTS conditions (
id INTEGER GENERATED ALWAYS AS IDENTITY
);
CREATE TABLE IF NOT EXISTS bestiary (
id INTEGER GENERATED ALWAYS AS IDENTITY
);

29
src/api/app.rs Normal file
View File

@@ -0,0 +1,29 @@
use std::env;
use std::sync::Arc;
use axum::Router;
use tokio::net::TcpListener;
use crate::{api, AppState};
use crate::error::SirenResult;
pub struct App {
app_state: AppState,
}
impl App {
pub fn new(app_state: AppState) -> Self {
Self { app_state }
}
pub async fn serve(self) -> SirenResult<()> {
let app = Router::new()
.nest("/api", api::get_routes())
.with_state(Arc::new(self.app_state));
let api_port: String = env::var("API_PORT").expect("Expected a port in the environment");
let addr = format!("0.0.0.0:{}", api_port);
let listener = TcpListener::bind(&addr).await?;
log::info!("API is listening on {}", &addr);
Ok(axum::serve(listener, app).await?)
}
}

View File

@@ -1,5 +1,13 @@
use axum::Router; mod app;
mod oauth;
pub fn get_routes() -> Router { pub use app::App;
Router::new()
use std::sync::Arc;
use axum::Router;
use serde::{Deserialize, Serialize};
use crate::AppState;
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new().nest("/oauth", oauth::get_routes())
} }

220
src/api/oauth.rs Normal file
View File

@@ -0,0 +1,220 @@
use std::env;
use std::sync::{Arc, OnceLock};
use axum::extract::{Query, State};
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::{Json, Router};
use axum::http::header::SET_COOKIE;
use axum::response::Redirect;
use axum::routing::get;
use chrono::{DateTime, Utc};
use rand::Rng;
use rand_chacha::ChaCha20Rng;
use rand_chacha::rand_core::SeedableRng;
use redis::{AsyncCommands, RedisResult};
use serde::{Deserialize, Serialize};
use crate::{data, AppState};
use crate::error::SirenResult;
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
pub fn get_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/authorize", get(discord_authorize))
.route("/callback", get(oauth_callback))
}
fn get_session_ttl() -> i64 {
// Initialize the SESSION_TTL value lazily
*SESSION_TTL.get_or_init(|| {
env::var("SESSION_TTL")
.ok()
.and_then(|val| val.parse::<i64>().ok())
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
})
}
pub fn csprng(take: usize) -> String {
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
let rng = ChaCha20Rng::from_entropy();
rng
.sample_iter(rand::distributions::Alphanumeric)
.take(take)
.map(char::from)
.collect()
}
#[derive(Deserialize)]
struct AuthQuery {
code: String,
state: Option<String>,
}
#[derive(Serialize, Deserialize)]
struct TokenResponse {
access_token: String,
token_type: String,
expires_in: u64,
refresh_token: String,
scope: String,
}
#[derive(Serialize, Deserialize, Debug)]
struct DiscordUser {
id: String,
username: String,
discriminator: String,
avatar: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
struct Session {
session_id: String,
user_id: String,
user_name: String,
pub expires_at: DateTime<Utc>,
}
impl Session {
fn new(id: String, user_id: String, user_name: String) -> Session {
let now = Utc::now();
let session_ttl = get_session_ttl();
Session {
session_id: id,
user_id,
user_name,
expires_at: now + chrono::Duration::seconds(session_ttl),
}
}
async fn insert(&self) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?;
let session_id = self.session_id.clone();
redis
.set_ex(
session_id,
serde_json::to_string(self)?,
self.expires_at.timestamp() as u64,
)
.await?;
Ok(())
}
async fn get(session_id: String) -> SirenResult<Option<Session>> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<Option<String>> = redis.get(session_id).await;
match result {
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
Ok(None) => Ok(None),
Err(err) => Err(err.into()),
}
}
async fn delete(session_id: String) -> SirenResult<()> {
let mut redis = data::redis_async_connection().await?;
let result: RedisResult<()> = redis.del(session_id).await;
match result {
Ok(_) => Ok(()),
Err(err) => Err(err.into()),
}
}
}
// async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect {
// // Construct the Discord OAuth URL
// let discord_auth_url = format!(
// "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify",
// state.client_id, state.redirect_uri
// );
// Redirect::temporary(&discord_auth_url)
// }
async fn discord_authorize(State(state): State<Arc<AppState>>) -> SirenResult<String> {
// Construct the Discord OAuth URL
let discord_auth_url = format!(
"https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify",
state.client_id, state.redirect_uri
);
Ok(discord_auth_url)
}
async fn oauth_callback(
State(state): State<Arc<AppState>>,
Query(query): Query<AuthQuery>,
) -> SirenResult<(HeaderMap, Json<DiscordUser>)> {
// Exchange code for an access token
let token_response = state
.client
.post("https://discord.com/api/oauth2/token")
.form(&[
("client_id", state.client_id.as_str()),
("client_secret", state.client_secret.as_str()),
("grant_type", "authorization_code"),
("code", query.code.as_str()),
("redirect_uri", state.redirect_uri.as_str()),
])
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if !token_response.status().is_success() {
log::error!(
"Failed to exchange token: {:?}",
token_response.text().await
);
return Err(StatusCode::INTERNAL_SERVER_ERROR.into());
}
let token_data: TokenResponse = token_response
.json()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// Fetch user information
let user_response = state
.client
.get("https://discord.com/api/users/@me")
.bearer_auth(token_data.access_token)
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if !user_response.status().is_success() {
log::error!(
"Failed to fetch user information: {:?}",
user_response.text().await
);
return Err(StatusCode::INTERNAL_SERVER_ERROR.into());
}
let user_data: DiscordUser = user_response
.json()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
log::debug!("User authenticated: {:?}", user_data);
// Generate a session token
let session_token = csprng(16);
let expiration = env::var("API_SESSION_TTL")
.expect("Expected a session ttl in the environment")
.parse::<u64>()
.unwrap();
// Create and insert the session
let session = Session::new(
session_token.clone(),
user_data.id.clone(),
user_data.username.clone(),
);
session.insert().await?;
let cookie_value = format!(
"session={}; HttpOnly; Path=/; Max-Age={}",
session_token, expiration
);
let mut headers = HeaderMap::new();
headers.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
Ok((headers, Json(user_data)))
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use serenity::all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption}; use serenity::all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption};
use serenity::model::prelude::GuildId; use serenity::model::prelude::GuildId;
use serenity::{prelude::*, async_trait, futures}; use serenity::{prelude::*, async_trait};
use songbird::input::{Input, YoutubeDl}; use songbird::input::{Input, YoutubeDl};
use songbird::tracks::TrackHandle; use songbird::tracks::TrackHandle;
use songbird::{Event, EventHandler, Songbird, TrackEvent}; use songbird::{Event, EventHandler, Songbird, TrackEvent};
@@ -25,7 +25,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
"{} attempted to play a track without a track option", "{} attempted to play a track without a track option",
command.user.id.get() command.user.id.get()
); );
create_message_response(&ctx, &command, format!("Track option is missing"), false).await; create_message_response(&ctx, &command, "Track option is missing".to_string(), false).await;
return; return;
} }
}; };
@@ -53,7 +53,9 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
// Join the user's voice channel // Join the user's voice channel
match join_voice_channel(&ctx.cache, &manager, guild_id, &command.user).await { match join_voice_channel(&ctx.cache, &manager, guild_id, &command.user).await {
Ok(channel_id) => { Ok(channel_id) => {
log::debug!("<{guild_id}> Play command executed on channel {channel_id} with track: {track_url:?}"); log::debug!(
"<{guild_id}> Play command executed on channel {channel_id} with track: {track_url:?}"
);
// Handle the track url // Handle the track url
match enqueue_track(ctx, manager, guild_id.to_owned(), track_url).await { match enqueue_track(ctx, manager, guild_id.to_owned(), track_url).await {
Ok(items) => { Ok(items) => {

View File

@@ -124,7 +124,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
Err(err) => { Err(err) => {
log::error!( log::error!(
"<{guild_id}> <{channel_id}> <{author_id}> Could not get response from OpenAI: {}", "<{guild_id}> <{channel_id}> <{author_id}> Could not get response from OpenAI: {}",
err.message err.details
); );
"There was an error processing your message. Please try again later.".to_string() "There was an error processing your message. Please try again later.".to_string()
} }
@@ -196,7 +196,7 @@ async fn generate_thread_name(oai: &OAI, s: &str, max_chars: usize) -> String {
} }
} }
Err(err) => { Err(err) => {
log::error!("Could not get response from OpenAI: {}", err.message); log::error!("Could not get response from OpenAI: {}", err.details);
} }
}; };
return response; return response;

View File

@@ -3,6 +3,14 @@ use crate::bot::chat::create_message_response;
pub async fn run(ctx: &Context, command: &CommandInteraction) { pub async fn run(ctx: &Context, command: &CommandInteraction) {
log::debug!("Ping command executed"); log::debug!("Ping command executed");
if let Some(guild_id) = command.guild_id {
if let Some(guild) = guild_id.to_guild_cached(&ctx.cache) {
let owner_id = guild.owner_id;
if command.user.id == owner_id {}
}
}
create_message_response(&ctx, &command, "pong".to_string(), true).await; create_message_response(&ctx, &command, "pong".to_string(), true).await;
} }

View File

@@ -9,13 +9,13 @@ use crate::data::guilds::GuildCache;
use super::{commands}; use super::{commands};
use super::chat::{create_message_response, create_modal_response}; use super::chat::{create_message_response, create_modal_response};
pub struct Handler { pub struct BotHandler {
// Open AI Config // Open AI Config
pub oai: Option<OAI>, pub oai: Option<OAI>,
} }
#[async_trait] #[async_trait]
impl EventHandler for Handler { impl EventHandler for BotHandler {
async fn message(&self, ctx: Context, msg: Message) { async fn message(&self, ctx: Context, msg: Message) {
// Ignore bot messages // Ignore bot messages
if msg.author.bot { if msg.author.bot {
@@ -47,7 +47,8 @@ impl EventHandler for Handler {
if let None = GuildCache::get_by_id(guild_id).await.unwrap() { if let None = GuildCache::get_by_id(guild_id).await.unwrap() {
let guild_cache = GuildCache { let guild_cache = GuildCache {
id: guild_id, id: guild_id,
bot_id: 1, name: guild.id.name(&ctx.cache),
owner_id: None,
volume: 100, volume: 100,
}; };
guild_cache.insert().await.unwrap(); guild_cache.insert().await.unwrap();

View File

@@ -129,7 +129,7 @@ impl OAI {
ResponseEvent::ResponseError(error) => { ResponseEvent::ResponseError(error) => {
return Err(SirenError { return Err(SirenError {
status: 500, status: 500,
message: format!("Error: {}", error.message.unwrap()), details: format!("Error: {}", error.message.unwrap()),
}); });
} }
} }
@@ -137,7 +137,7 @@ impl OAI {
Err(err) => { Err(err) => {
return Err(SirenError { return Err(SirenError {
status: 500, status: 500,
message: format!("Error: {}", err), details: format!("Error: {}", err),
}) })
} }
} }

View File

@@ -6,7 +6,8 @@ const TABLE_NAME: &str = "guilds";
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)] #[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
pub struct GuildCache { pub struct GuildCache {
pub id: i64, pub id: i64,
pub bot_id: i64, pub name: Option<String>,
pub owner_id: Option<i64>,
pub volume: i32, pub volume: i32,
} }
@@ -16,15 +17,17 @@ impl GuildCache {
sqlx::query(&format!( sqlx::query(&format!(
"INSERT INTO {} ( "INSERT INTO {} (
id, id,
bot_id, name,
owner_id,
volume volume
) VALUES ( ) VALUES (
$1, $2, $3 $1, $2, $3, $4
)", )",
TABLE_NAME TABLE_NAME
)) ))
.bind(self.id) .bind(self.id)
.bind(self.bot_id) .bind(&self.name)
.bind(self.owner_id)
.bind(self.volume) .bind(self.volume)
.execute(pool) .execute(pool)
.await?; .await?;
@@ -45,13 +48,15 @@ impl GuildCache {
let pool = crate::data::pool(); let pool = crate::data::pool();
sqlx::query(&format!( sqlx::query(&format!(
"UPDATE {} SET "UPDATE {} SET
bot_id = $2, name = $2,
volume = $3 owner_id = $3,
volume = $4
WHERE id = $1", WHERE id = $1",
TABLE_NAME TABLE_NAME
)) ))
.bind(self.id) .bind(self.id)
.bind(self.bot_id) .bind(&self.name)
.bind(self.owner_id)
.bind(self.volume) .bind(self.volume)
.execute(pool) .execute(pool)
.await?; .await?;

View File

@@ -1,4 +1,7 @@
use std::fmt; use std::fmt;
use axum::http::StatusCode;
use axum::Json;
use axum::response::{IntoResponse, Response};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub type SirenResult<T> = Result<T, Error>; pub type SirenResult<T> = Result<T, Error>;
@@ -6,21 +9,44 @@ pub type SirenResult<T> = Result<T, Error>;
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
pub struct Error { pub struct Error {
pub status: u16, pub status: u16,
pub message: String, pub details: String,
} }
impl Error { impl Error {
pub fn new(error_status_code: u16, error_message: String) -> Self { pub fn new(error_status_code: u16, error_message: String) -> Self {
Self { Self {
status: error_status_code, status: error_status_code,
message: error_message, details: error_message,
} }
} }
} }
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(self.message.as_str()) f.write_str(self.details.as_str())
}
}
impl std::error::Error for Error {
fn description(&self) -> &str {
&self.details
}
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
// Create a JSON response with the structured error
let body = Json(serde_json::json!({
"error": {
"status": self.status,
"details": self.details,
}
}));
// Return the response with the proper status and error body
(status, body).into_response()
} }
} }
@@ -30,6 +56,18 @@ impl From<std::io::Error> for Error {
} }
} }
impl From<StatusCode> for Error {
fn from(status: StatusCode) -> Self {
Error {
status: status.as_u16(),
details: status
.canonical_reason()
.unwrap_or("Unknown error")
.to_string(),
}
}
}
impl From<std::string::FromUtf8Error> for Error { impl From<std::string::FromUtf8Error> for Error {
fn from(error: std::string::FromUtf8Error) -> Self { fn from(error: std::string::FromUtf8Error) -> Self {
Self::new(500, format!("Unknown from utf8 error: {}", error)) Self::new(500, format!("Unknown from utf8 error: {}", error))

View File

@@ -1,15 +1,13 @@
use std::env; use std::env;
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
use axum::Router;
use serenity::http::Http; use serenity::http::Http;
use serenity::prelude::*; use serenity::prelude::*;
use songbird::{SerenityInit, Songbird}; use songbird::{SerenityInit, Songbird};
use reqwest::Client as HttpClient; use reqwest::Client as HttpClient;
use serenity::all::{ShardManager, UserId}; use serenity::all::{Cache, ShardManager, UserId};
use tokio::net::TcpListener; use crate::api::App;
use crate::bot::handler::BotHandler;
use crate::bot::handler::Handler;
use crate::bot::oai::OAI; use crate::bot::oai::OAI;
mod api; mod api;
@@ -24,47 +22,24 @@ impl TypeMapKey for HttpKey {
type Value = HttpClient; type Value = HttpClient;
} }
#[derive(Clone)]
struct AppState {
client: reqwest::Client,
client_id: String,
client_secret: String,
redirect_uri: String,
http: Arc<Http>,
cache: Arc<Cache>,
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> Result<(), Box<dyn std::error::Error>> {
dotenv::dotenv().ok(); dotenv::dotenv().ok();
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info")); env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info"));
if let Err(err) = data::initialize().await { data::initialize().await?;
log::error!("Failed to initialize database: {err}");
return;
};
// Start API server
tokio::spawn(start_api());
// Start Discord bot
start_bot().await;
}
async fn start_api() {
let app = Router::new();
let addr: String = "127.0.0.1:3000".parse().unwrap();
let listener = TcpListener::bind(&addr).await.unwrap();
log::debug!("API is listening on {}", &addr);
axum::serve(listener, app).await.unwrap();
}
async fn start_bot() {
let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment");
let intents: GatewayIntents = GatewayIntents::all();
let http: Http = Http::new(&token);
let (owners, bot_id) = get_bot_info(&http).await;
log::debug!(
"Starting Discord bot with ID: {bot_id} and owners: {}",
owners
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join(", ")
);
// Set up handler with optional OpenAI integration // Set up handler with optional OpenAI integration
let handler = configure_handler(); let handler = configure_handler();
@@ -72,6 +47,8 @@ async fn start_bot() {
// Set up Songbird for voice functionality // Set up Songbird for voice functionality
let songbird = Songbird::serenity(); let songbird = Songbird::serenity();
let intents: GatewayIntents = GatewayIntents::all();
let mut client = Client::builder(token, intents) let mut client = Client::builder(token, intents)
.event_handler(handler) .event_handler(handler)
// .framework(StandardFramework::new().configure(|c| c.owners(owners))) // .framework(StandardFramework::new().configure(|c| c.owners(owners)))
@@ -80,29 +57,53 @@ async fn start_bot() {
.await .await
.expect("Error creating client"); .expect("Error creating client");
let (bot_owner, bot_id) = get_bot_info(&client.http).await;
let client_secret: String =
env::var("DISCORD_SECRET").expect("Expected a secret in the environment");
let redirect_uri: String =
env::var("API_CALLBACK_URI").expect("Expected a secret in the environment");
let app_state = AppState {
client: HttpClient::new(),
client_id: bot_id.to_string(),
client_secret,
redirect_uri,
http: Arc::clone(&client.http),
cache: Arc::clone(&client.cache),
};
log::debug!("Starting Siren with ID: {bot_id} (Contact: {:?})", bot_owner);
// Spawn shutdown signal handling // Spawn shutdown signal handling
let shard_manager = Arc::clone(&client.shard_manager); let shard_manager = Arc::clone(&client.shard_manager);
tokio::spawn(async move { tokio::spawn(async move {
signal_shutdown(shard_manager).await; signal_shutdown(shard_manager).await;
}); });
// Start the bot // Start API server
tokio::spawn(App::new(app_state).serve());
// Start Discord bot
if let Err(why) = client.start_autosharded().await { if let Err(why) = client.start_autosharded().await {
log::error!("Client error: {why:?}"); log::error!("Client error: {why:?}");
} }
Ok(())
} }
async fn get_bot_info(http: &Http) -> (HashSet<UserId>, UserId) { async fn get_bot_info(http: &Http) -> (Option<UserId>, UserId) {
match http.get_current_application_info().await { match http.get_current_application_info().await {
Ok(info) => { Ok(info) => {
let mut owners = HashSet::new(); let bot_owner;
if let Some(team) = info.team { if let Some(team) = info.team {
owners.insert(team.owner_user_id); bot_owner = Some(team.owner_user_id);
} else if let Some(owner) = info.owner { } else if let Some(owner) = info.owner {
owners.insert(owner.id); bot_owner = Some(owner.id);
} else {
bot_owner = None;
} }
match http.get_current_user().await { match http.get_current_user().await {
Ok(bot) => (owners, bot.id), Ok(bot) => (bot_owner, bot.id),
Err(why) => panic!("Could not access the bot id: {why:?}"), Err(why) => panic!("Could not access the bot id: {why:?}"),
} }
} }
@@ -110,13 +111,13 @@ async fn get_bot_info(http: &Http) -> (HashSet<UserId>, UserId) {
} }
} }
fn configure_handler() -> Handler { fn configure_handler() -> BotHandler {
match env::var("OPENAI_TOKEN") { match env::var("OPENAI_TOKEN") {
Ok(token) => { Ok(token) => {
log::debug!("OpenAI functionality enabled"); log::debug!("OpenAI functionality enabled");
let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()); let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string());
let base_url = env::var("OPENAI_BASE_URL").unwrap(); let base_url = env::var("OPENAI_BASE_URL").unwrap();
Handler { BotHandler {
oai: Some(OAI { oai: Some(OAI {
client: reqwest::Client::new(), client: reqwest::Client::new(),
base_url, base_url,
@@ -129,7 +130,7 @@ fn configure_handler() -> Handler {
} }
Err(_) => { Err(_) => {
log::warn!("OpenAI functionality disabled"); log::warn!("OpenAI functionality disabled");
Handler { oai: None } BotHandler { oai: None }
} }
} }
} }