Re-implementing the API
This commit is contained in:
5
.env
5
.env
@@ -1,6 +1,7 @@
|
|||||||
RUST_LOG=warn,siren=info
|
RUST_LOG=warn,siren=info
|
||||||
|
|
||||||
DISCORD_TOKEN=
|
DISCORD_TOKEN=
|
||||||
|
DISCORD_SECRET=
|
||||||
|
|
||||||
DATABASE_USER=siren
|
DATABASE_USER=siren
|
||||||
DATABASE_PASSWORD=CHANGEME # Change this to a secure password
|
DATABASE_PASSWORD=CHANGEME # Change this to a secure password
|
||||||
@@ -8,7 +9,9 @@ DATABASE_NAME=siren
|
|||||||
DATABASE_HOST=localhost
|
DATABASE_HOST=localhost
|
||||||
DATABASE_PORT=5432
|
DATABASE_PORT=5432
|
||||||
|
|
||||||
SESSION_TTL=1440
|
API_CALLBACK_URI=http://localhost:3000/api/oauth/callback
|
||||||
|
API_PORT=3000
|
||||||
|
API_SESSION_TTL=86400
|
||||||
|
|
||||||
MINIO_ROOT_USER=siren
|
MINIO_ROOT_USER=siren
|
||||||
MINIO_ROOT_PASSWORD=CHANGEME # Change this to a secure password
|
MINIO_ROOT_PASSWORD=CHANGEME # Change this to a secure password
|
||||||
|
|||||||
@@ -23,8 +23,10 @@ reqwest = { version = "0.11", default-features = false, features = ["json"] }
|
|||||||
uuid = { version = "1.11.0", features = ["serde", "v4"] }
|
uuid = { version = "1.11.0", features = ["serde", "v4"] }
|
||||||
redis = { version = "0.27.4", features = ["tokio-comp", "connection-manager", "r2d2"] }
|
redis = { version = "0.27.4", features = ["tokio-comp", "connection-manager", "r2d2"] }
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
|
rand_chacha = "0.3.1"
|
||||||
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] }
|
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] }
|
||||||
regex = "1.11.0"
|
regex = "1.11.0"
|
||||||
axum = "0.7.7"
|
axum = "0.7.7"
|
||||||
lazy_static = "1.5.0"
|
lazy_static = "1.5.0"
|
||||||
futures = "0.3.31"
|
futures = "0.3.31"
|
||||||
|
axum-login = "0.16.0"
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
DATABASE_HOST: siren-postgres
|
DATABASE_HOST: siren-postgres
|
||||||
DATABASE_PORT: 5432
|
DATABASE_PORT: 5432
|
||||||
REDIS_HOST: redis
|
REDIS_HOST: siren-redis
|
||||||
REDIS_PORT: 6379
|
REDIS_PORT: 6379
|
||||||
DATA_DIR_PATH: /data
|
DATA_DIR_PATH: /data
|
||||||
volumes:
|
volumes:
|
||||||
@@ -42,14 +42,27 @@ services:
|
|||||||
- ${DATABASE_PORT:-5432}:5432
|
- ${DATABASE_PORT:-5432}:5432
|
||||||
networks:
|
networks:
|
||||||
- backend
|
- backend
|
||||||
|
restart: unless-stopped
|
||||||
profiles:
|
profiles:
|
||||||
- backend
|
- backend
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:latest
|
||||||
|
container_name: siren-redis
|
||||||
|
volumes:
|
||||||
|
- redis:/data
|
||||||
|
ports:
|
||||||
|
- ${REDIS_PORT:-6379}:6379
|
||||||
|
networks:
|
||||||
|
- backend
|
||||||
|
restart: unless-stopped
|
||||||
|
profiles:
|
||||||
|
- backend
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres:
|
postgres:
|
||||||
postgres_logs:
|
postgres_logs:
|
||||||
|
redis:
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
frontend:
|
frontend:
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
CREATE TABLE IF NOT EXISTS guilds (
|
|
||||||
id BIGINT PRIMARY KEY NOT NULL,
|
|
||||||
bot_id BIGINT NOT NULL,
|
|
||||||
volume INTEGER NOT NULL
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS messages (
|
|
||||||
id TEXT PRIMARY KEY NOT NULL,
|
|
||||||
guild_id BIGINT NOT NULL,
|
|
||||||
channel_id BIGINT NOT NULL,
|
|
||||||
author_id BIGINT NOT NULL,
|
|
||||||
created BIGINT NOT NULL,
|
|
||||||
model TEXT NOT NULL,
|
|
||||||
request TEXT NOT NULL,
|
|
||||||
response TEXT NOT NULL,
|
|
||||||
request_tags TEXT[] NOT NULL,
|
|
||||||
response_tags TEXT[] NOT NULL
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS events (
|
|
||||||
id UUID PRIMARY KEY NOT NULL,
|
|
||||||
guild_id BIGINT NOT NULL,
|
|
||||||
author_id BIGINT NOT NULL,
|
|
||||||
title TEXT NOT NULL,
|
|
||||||
date_time TIMESTAMP NOT NULL,
|
|
||||||
description TEXT,
|
|
||||||
rsvp BIGINT[] NOT NULL
|
|
||||||
);
|
|
||||||
73
migrations/000_initial.sql
Normal file
73
migrations/000_initial.sql
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS guilds (
|
||||||
|
id BIGINT PRIMARY KEY NOT NULL,
|
||||||
|
name TEXT,
|
||||||
|
owner_id BIGINT,
|
||||||
|
volume INTEGER NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id TEXT PRIMARY KEY NOT NULL,
|
||||||
|
guild_id BIGINT NOT NULL,
|
||||||
|
channel_id BIGINT NOT NULL,
|
||||||
|
author_id BIGINT NOT NULL,
|
||||||
|
created BIGINT NOT NULL,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
request TEXT NOT NULL,
|
||||||
|
response TEXT NOT NULL,
|
||||||
|
request_tags TEXT[] NOT NULL,
|
||||||
|
response_tags TEXT[] NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS dice_rolls (
|
||||||
|
id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid()
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS events (
|
||||||
|
id UUID PRIMARY KEY NOT NULL,
|
||||||
|
guild_id BIGINT NOT NULL,
|
||||||
|
author_id BIGINT NOT NULL,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
date_time TIMESTAMP NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
rsvp BIGINT[] NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS races (
|
||||||
|
id INTEGER GENERATED ALWAYS AS IDENTITY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
size TEXT NOT NULL,
|
||||||
|
source TEXT NOT NULL,
|
||||||
|
data JSON NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS classes (
|
||||||
|
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS feats (
|
||||||
|
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS options_features (
|
||||||
|
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS backgrounds (
|
||||||
|
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS items (
|
||||||
|
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS spells (
|
||||||
|
id INTEGER GENERATED ALWAYS AS IDENTITY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
school TEXT NOT NULL,
|
||||||
|
level INTEGER NOT NULL,
|
||||||
|
ritual BOOLEAN DEFAULT FALSE,
|
||||||
|
concentration BOOLEAN DEFAULT FALSE,
|
||||||
|
classes TEXT[] NOT NULL,
|
||||||
|
damage_inflict TEXT[] NOT NULL,
|
||||||
|
damage_resist TEXT[] NOT NULL,
|
||||||
|
conditions TEXT[] NOT NULL,
|
||||||
|
saving_throw TEXT[] NOT NULL,
|
||||||
|
attack_type TEXT,
|
||||||
|
data JSONB NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS conditions (
|
||||||
|
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS bestiary (
|
||||||
|
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||||
|
);
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
CREATE TABLE IF NOT EXISTS races (
|
|
||||||
id INTEGER GENERATED ALWAYS AS IDENTITY,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
size TEXT NOT NULL,
|
|
||||||
source TEXT NOT NULL,
|
|
||||||
data JSON NOT NULL
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS classes (
|
|
||||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS feats (
|
|
||||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS options_features (
|
|
||||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS backgrounds (
|
|
||||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS items (
|
|
||||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS spells (
|
|
||||||
id INTEGER GENERATED ALWAYS AS IDENTITY,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
school TEXT NOT NULL,
|
|
||||||
level INTEGER NOT NULL,
|
|
||||||
ritual BOOLEAN DEFAULT FALSE,
|
|
||||||
concentration BOOLEAN DEFAULT FALSE,
|
|
||||||
classes TEXT[] NOT NULL,
|
|
||||||
damage_inflict TEXT[] NOT NULL,
|
|
||||||
damage_resist TEXT[] NOT NULL,
|
|
||||||
conditions TEXT[] NOT NULL,
|
|
||||||
saving_throw TEXT[] NOT NULL,
|
|
||||||
attack_type TEXT,
|
|
||||||
data JSONB NOT NULL
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS conditions (
|
|
||||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS bestiary (
|
|
||||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
|
||||||
);
|
|
||||||
29
src/api/app.rs
Normal file
29
src/api/app.rs
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
use std::env;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use axum::Router;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use crate::{api, AppState};
|
||||||
|
use crate::error::SirenResult;
|
||||||
|
|
||||||
|
pub struct App {
|
||||||
|
app_state: AppState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl App {
|
||||||
|
pub fn new(app_state: AppState) -> Self {
|
||||||
|
Self { app_state }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn serve(self) -> SirenResult<()> {
|
||||||
|
let app = Router::new()
|
||||||
|
.nest("/api", api::get_routes())
|
||||||
|
.with_state(Arc::new(self.app_state));
|
||||||
|
|
||||||
|
let api_port: String = env::var("API_PORT").expect("Expected a port in the environment");
|
||||||
|
let addr = format!("0.0.0.0:{}", api_port);
|
||||||
|
|
||||||
|
let listener = TcpListener::bind(&addr).await?;
|
||||||
|
log::info!("API is listening on {}", &addr);
|
||||||
|
Ok(axum::serve(listener, app).await?)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,13 @@
|
|||||||
use axum::Router;
|
mod app;
|
||||||
|
mod oauth;
|
||||||
|
|
||||||
pub fn get_routes() -> Router {
|
pub use app::App;
|
||||||
Router::new()
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use axum::Router;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||||
|
Router::new().nest("/oauth", oauth::get_routes())
|
||||||
}
|
}
|
||||||
|
|||||||
220
src/api/oauth.rs
Normal file
220
src/api/oauth.rs
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
use std::env;
|
||||||
|
use std::sync::{Arc, OnceLock};
|
||||||
|
use axum::extract::{Query, State};
|
||||||
|
use axum::http::{HeaderMap, HeaderValue, StatusCode};
|
||||||
|
use axum::{Json, Router};
|
||||||
|
use axum::http::header::SET_COOKIE;
|
||||||
|
use axum::response::Redirect;
|
||||||
|
use axum::routing::get;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use rand::Rng;
|
||||||
|
use rand_chacha::ChaCha20Rng;
|
||||||
|
use rand_chacha::rand_core::SeedableRng;
|
||||||
|
use redis::{AsyncCommands, RedisResult};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use crate::{data, AppState};
|
||||||
|
use crate::error::SirenResult;
|
||||||
|
|
||||||
|
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
|
||||||
|
|
||||||
|
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||||
|
Router::new()
|
||||||
|
.route("/authorize", get(discord_authorize))
|
||||||
|
.route("/callback", get(oauth_callback))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_session_ttl() -> i64 {
|
||||||
|
// Initialize the SESSION_TTL value lazily
|
||||||
|
*SESSION_TTL.get_or_init(|| {
|
||||||
|
env::var("SESSION_TTL")
|
||||||
|
.ok()
|
||||||
|
.and_then(|val| val.parse::<i64>().ok())
|
||||||
|
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn csprng(take: usize) -> String {
|
||||||
|
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
|
||||||
|
let rng = ChaCha20Rng::from_entropy();
|
||||||
|
rng
|
||||||
|
.sample_iter(rand::distributions::Alphanumeric)
|
||||||
|
.take(take)
|
||||||
|
.map(char::from)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct AuthQuery {
|
||||||
|
code: String,
|
||||||
|
state: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct TokenResponse {
|
||||||
|
access_token: String,
|
||||||
|
token_type: String,
|
||||||
|
expires_in: u64,
|
||||||
|
refresh_token: String,
|
||||||
|
scope: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
struct DiscordUser {
|
||||||
|
id: String,
|
||||||
|
username: String,
|
||||||
|
discriminator: String,
|
||||||
|
avatar: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
struct Session {
|
||||||
|
session_id: String,
|
||||||
|
user_id: String,
|
||||||
|
user_name: String,
|
||||||
|
pub expires_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Session {
|
||||||
|
fn new(id: String, user_id: String, user_name: String) -> Session {
|
||||||
|
let now = Utc::now();
|
||||||
|
let session_ttl = get_session_ttl();
|
||||||
|
Session {
|
||||||
|
session_id: id,
|
||||||
|
user_id,
|
||||||
|
user_name,
|
||||||
|
expires_at: now + chrono::Duration::seconds(session_ttl),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn insert(&self) -> SirenResult<()> {
|
||||||
|
let mut redis = data::redis_async_connection().await?;
|
||||||
|
let session_id = self.session_id.clone();
|
||||||
|
redis
|
||||||
|
.set_ex(
|
||||||
|
session_id,
|
||||||
|
serde_json::to_string(self)?,
|
||||||
|
self.expires_at.timestamp() as u64,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(session_id: String) -> SirenResult<Option<Session>> {
|
||||||
|
let mut redis = data::redis_async_connection().await?;
|
||||||
|
let result: RedisResult<Option<String>> = redis.get(session_id).await;
|
||||||
|
match result {
|
||||||
|
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
|
||||||
|
Ok(None) => Ok(None),
|
||||||
|
Err(err) => Err(err.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete(session_id: String) -> SirenResult<()> {
|
||||||
|
let mut redis = data::redis_async_connection().await?;
|
||||||
|
let result: RedisResult<()> = redis.del(session_id).await;
|
||||||
|
match result {
|
||||||
|
Ok(_) => Ok(()),
|
||||||
|
Err(err) => Err(err.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect {
|
||||||
|
// // Construct the Discord OAuth URL
|
||||||
|
// let discord_auth_url = format!(
|
||||||
|
// "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify",
|
||||||
|
// state.client_id, state.redirect_uri
|
||||||
|
// );
|
||||||
|
// Redirect::temporary(&discord_auth_url)
|
||||||
|
// }
|
||||||
|
|
||||||
|
async fn discord_authorize(State(state): State<Arc<AppState>>) -> SirenResult<String> {
|
||||||
|
// Construct the Discord OAuth URL
|
||||||
|
let discord_auth_url = format!(
|
||||||
|
"https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify",
|
||||||
|
state.client_id, state.redirect_uri
|
||||||
|
);
|
||||||
|
Ok(discord_auth_url)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn oauth_callback(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Query(query): Query<AuthQuery>,
|
||||||
|
) -> SirenResult<(HeaderMap, Json<DiscordUser>)> {
|
||||||
|
// Exchange code for an access token
|
||||||
|
let token_response = state
|
||||||
|
.client
|
||||||
|
.post("https://discord.com/api/oauth2/token")
|
||||||
|
.form(&[
|
||||||
|
("client_id", state.client_id.as_str()),
|
||||||
|
("client_secret", state.client_secret.as_str()),
|
||||||
|
("grant_type", "authorization_code"),
|
||||||
|
("code", query.code.as_str()),
|
||||||
|
("redirect_uri", state.redirect_uri.as_str()),
|
||||||
|
])
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||||
|
|
||||||
|
if !token_response.status().is_success() {
|
||||||
|
log::error!(
|
||||||
|
"Failed to exchange token: {:?}",
|
||||||
|
token_response.text().await
|
||||||
|
);
|
||||||
|
return Err(StatusCode::INTERNAL_SERVER_ERROR.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_data: TokenResponse = token_response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||||
|
|
||||||
|
// Fetch user information
|
||||||
|
let user_response = state
|
||||||
|
.client
|
||||||
|
.get("https://discord.com/api/users/@me")
|
||||||
|
.bearer_auth(token_data.access_token)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||||
|
|
||||||
|
if !user_response.status().is_success() {
|
||||||
|
log::error!(
|
||||||
|
"Failed to fetch user information: {:?}",
|
||||||
|
user_response.text().await
|
||||||
|
);
|
||||||
|
return Err(StatusCode::INTERNAL_SERVER_ERROR.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let user_data: DiscordUser = user_response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||||
|
|
||||||
|
log::debug!("User authenticated: {:?}", user_data);
|
||||||
|
|
||||||
|
// Generate a session token
|
||||||
|
let session_token = csprng(16);
|
||||||
|
let expiration = env::var("API_SESSION_TTL")
|
||||||
|
.expect("Expected a session ttl in the environment")
|
||||||
|
.parse::<u64>()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Create and insert the session
|
||||||
|
let session = Session::new(
|
||||||
|
session_token.clone(),
|
||||||
|
user_data.id.clone(),
|
||||||
|
user_data.username.clone(),
|
||||||
|
);
|
||||||
|
session.insert().await?;
|
||||||
|
|
||||||
|
let cookie_value = format!(
|
||||||
|
"session={}; HttpOnly; Path=/; Max-Age={}",
|
||||||
|
session_token, expiration
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
|
||||||
|
|
||||||
|
Ok((headers, Json(user_data)))
|
||||||
|
}
|
||||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use serenity::all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption};
|
use serenity::all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption};
|
||||||
use serenity::model::prelude::GuildId;
|
use serenity::model::prelude::GuildId;
|
||||||
use serenity::{prelude::*, async_trait, futures};
|
use serenity::{prelude::*, async_trait};
|
||||||
use songbird::input::{Input, YoutubeDl};
|
use songbird::input::{Input, YoutubeDl};
|
||||||
use songbird::tracks::TrackHandle;
|
use songbird::tracks::TrackHandle;
|
||||||
use songbird::{Event, EventHandler, Songbird, TrackEvent};
|
use songbird::{Event, EventHandler, Songbird, TrackEvent};
|
||||||
@@ -25,7 +25,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
|||||||
"{} attempted to play a track without a track option",
|
"{} attempted to play a track without a track option",
|
||||||
command.user.id.get()
|
command.user.id.get()
|
||||||
);
|
);
|
||||||
create_message_response(&ctx, &command, format!("Track option is missing"), false).await;
|
create_message_response(&ctx, &command, "Track option is missing".to_string(), false).await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -53,7 +53,9 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
|||||||
// Join the user's voice channel
|
// Join the user's voice channel
|
||||||
match join_voice_channel(&ctx.cache, &manager, guild_id, &command.user).await {
|
match join_voice_channel(&ctx.cache, &manager, guild_id, &command.user).await {
|
||||||
Ok(channel_id) => {
|
Ok(channel_id) => {
|
||||||
log::debug!("<{guild_id}> Play command executed on channel {channel_id} with track: {track_url:?}");
|
log::debug!(
|
||||||
|
"<{guild_id}> Play command executed on channel {channel_id} with track: {track_url:?}"
|
||||||
|
);
|
||||||
// Handle the track url
|
// Handle the track url
|
||||||
match enqueue_track(ctx, manager, guild_id.to_owned(), track_url).await {
|
match enqueue_track(ctx, manager, guild_id.to_owned(), track_url).await {
|
||||||
Ok(items) => {
|
Ok(items) => {
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
|||||||
Err(err) => {
|
Err(err) => {
|
||||||
log::error!(
|
log::error!(
|
||||||
"<{guild_id}> <{channel_id}> <{author_id}> Could not get response from OpenAI: {}",
|
"<{guild_id}> <{channel_id}> <{author_id}> Could not get response from OpenAI: {}",
|
||||||
err.message
|
err.details
|
||||||
);
|
);
|
||||||
"There was an error processing your message. Please try again later.".to_string()
|
"There was an error processing your message. Please try again later.".to_string()
|
||||||
}
|
}
|
||||||
@@ -196,7 +196,7 @@ async fn generate_thread_name(oai: &OAI, s: &str, max_chars: usize) -> String {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
log::error!("Could not get response from OpenAI: {}", err.message);
|
log::error!("Could not get response from OpenAI: {}", err.details);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
return response;
|
return response;
|
||||||
|
|||||||
@@ -3,6 +3,14 @@ use crate::bot::chat::create_message_response;
|
|||||||
|
|
||||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||||
log::debug!("Ping command executed");
|
log::debug!("Ping command executed");
|
||||||
|
|
||||||
|
if let Some(guild_id) = command.guild_id {
|
||||||
|
if let Some(guild) = guild_id.to_guild_cached(&ctx.cache) {
|
||||||
|
let owner_id = guild.owner_id;
|
||||||
|
if command.user.id == owner_id {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
create_message_response(&ctx, &command, "pong".to_string(), true).await;
|
create_message_response(&ctx, &command, "pong".to_string(), true).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ use crate::data::guilds::GuildCache;
|
|||||||
use super::{commands};
|
use super::{commands};
|
||||||
use super::chat::{create_message_response, create_modal_response};
|
use super::chat::{create_message_response, create_modal_response};
|
||||||
|
|
||||||
pub struct Handler {
|
pub struct BotHandler {
|
||||||
// Open AI Config
|
// Open AI Config
|
||||||
pub oai: Option<OAI>,
|
pub oai: Option<OAI>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl EventHandler for Handler {
|
impl EventHandler for BotHandler {
|
||||||
async fn message(&self, ctx: Context, msg: Message) {
|
async fn message(&self, ctx: Context, msg: Message) {
|
||||||
// Ignore bot messages
|
// Ignore bot messages
|
||||||
if msg.author.bot {
|
if msg.author.bot {
|
||||||
@@ -47,7 +47,8 @@ impl EventHandler for Handler {
|
|||||||
if let None = GuildCache::get_by_id(guild_id).await.unwrap() {
|
if let None = GuildCache::get_by_id(guild_id).await.unwrap() {
|
||||||
let guild_cache = GuildCache {
|
let guild_cache = GuildCache {
|
||||||
id: guild_id,
|
id: guild_id,
|
||||||
bot_id: 1,
|
name: guild.id.name(&ctx.cache),
|
||||||
|
owner_id: None,
|
||||||
volume: 100,
|
volume: 100,
|
||||||
};
|
};
|
||||||
guild_cache.insert().await.unwrap();
|
guild_cache.insert().await.unwrap();
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ impl OAI {
|
|||||||
ResponseEvent::ResponseError(error) => {
|
ResponseEvent::ResponseError(error) => {
|
||||||
return Err(SirenError {
|
return Err(SirenError {
|
||||||
status: 500,
|
status: 500,
|
||||||
message: format!("Error: {}", error.message.unwrap()),
|
details: format!("Error: {}", error.message.unwrap()),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -137,7 +137,7 @@ impl OAI {
|
|||||||
Err(err) => {
|
Err(err) => {
|
||||||
return Err(SirenError {
|
return Err(SirenError {
|
||||||
status: 500,
|
status: 500,
|
||||||
message: format!("Error: {}", err),
|
details: format!("Error: {}", err),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ const TABLE_NAME: &str = "guilds";
|
|||||||
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
|
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
|
||||||
pub struct GuildCache {
|
pub struct GuildCache {
|
||||||
pub id: i64,
|
pub id: i64,
|
||||||
pub bot_id: i64,
|
pub name: Option<String>,
|
||||||
|
pub owner_id: Option<i64>,
|
||||||
pub volume: i32,
|
pub volume: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -16,15 +17,17 @@ impl GuildCache {
|
|||||||
sqlx::query(&format!(
|
sqlx::query(&format!(
|
||||||
"INSERT INTO {} (
|
"INSERT INTO {} (
|
||||||
id,
|
id,
|
||||||
bot_id,
|
name,
|
||||||
|
owner_id,
|
||||||
volume
|
volume
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1, $2, $3
|
$1, $2, $3, $4
|
||||||
)",
|
)",
|
||||||
TABLE_NAME
|
TABLE_NAME
|
||||||
))
|
))
|
||||||
.bind(self.id)
|
.bind(self.id)
|
||||||
.bind(self.bot_id)
|
.bind(&self.name)
|
||||||
|
.bind(self.owner_id)
|
||||||
.bind(self.volume)
|
.bind(self.volume)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -45,13 +48,15 @@ impl GuildCache {
|
|||||||
let pool = crate::data::pool();
|
let pool = crate::data::pool();
|
||||||
sqlx::query(&format!(
|
sqlx::query(&format!(
|
||||||
"UPDATE {} SET
|
"UPDATE {} SET
|
||||||
bot_id = $2,
|
name = $2,
|
||||||
volume = $3
|
owner_id = $3,
|
||||||
|
volume = $4
|
||||||
WHERE id = $1",
|
WHERE id = $1",
|
||||||
TABLE_NAME
|
TABLE_NAME
|
||||||
))
|
))
|
||||||
.bind(self.id)
|
.bind(self.id)
|
||||||
.bind(self.bot_id)
|
.bind(&self.name)
|
||||||
|
.bind(self.owner_id)
|
||||||
.bind(self.volume)
|
.bind(self.volume)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
44
src/error.rs
44
src/error.rs
@@ -1,4 +1,7 @@
|
|||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::Json;
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub type SirenResult<T> = Result<T, Error>;
|
pub type SirenResult<T> = Result<T, Error>;
|
||||||
@@ -6,21 +9,44 @@ pub type SirenResult<T> = Result<T, Error>;
|
|||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
pub struct Error {
|
pub struct Error {
|
||||||
pub status: u16,
|
pub status: u16,
|
||||||
pub message: String,
|
pub details: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Error {
|
impl Error {
|
||||||
pub fn new(error_status_code: u16, error_message: String) -> Self {
|
pub fn new(error_status_code: u16, error_message: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
status: error_status_code,
|
status: error_status_code,
|
||||||
message: error_message,
|
details: error_message,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for Error {
|
impl fmt::Display for Error {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
f.write_str(self.message.as_str())
|
f.write_str(self.details.as_str())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for Error {
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
&self.details
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for Error {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
|
// Create a JSON response with the structured error
|
||||||
|
let body = Json(serde_json::json!({
|
||||||
|
"error": {
|
||||||
|
"status": self.status,
|
||||||
|
"details": self.details,
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Return the response with the proper status and error body
|
||||||
|
(status, body).into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -30,6 +56,18 @@ impl From<std::io::Error> for Error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<StatusCode> for Error {
|
||||||
|
fn from(status: StatusCode) -> Self {
|
||||||
|
Error {
|
||||||
|
status: status.as_u16(),
|
||||||
|
details: status
|
||||||
|
.canonical_reason()
|
||||||
|
.unwrap_or("Unknown error")
|
||||||
|
.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<std::string::FromUtf8Error> for Error {
|
impl From<std::string::FromUtf8Error> for Error {
|
||||||
fn from(error: std::string::FromUtf8Error) -> Self {
|
fn from(error: std::string::FromUtf8Error) -> Self {
|
||||||
Self::new(500, format!("Unknown from utf8 error: {}", error))
|
Self::new(500, format!("Unknown from utf8 error: {}", error))
|
||||||
|
|||||||
99
src/main.rs
99
src/main.rs
@@ -1,15 +1,13 @@
|
|||||||
use std::env;
|
use std::env;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use axum::Router;
|
|
||||||
use serenity::http::Http;
|
use serenity::http::Http;
|
||||||
use serenity::prelude::*;
|
use serenity::prelude::*;
|
||||||
use songbird::{SerenityInit, Songbird};
|
use songbird::{SerenityInit, Songbird};
|
||||||
use reqwest::Client as HttpClient;
|
use reqwest::Client as HttpClient;
|
||||||
use serenity::all::{ShardManager, UserId};
|
use serenity::all::{Cache, ShardManager, UserId};
|
||||||
use tokio::net::TcpListener;
|
use crate::api::App;
|
||||||
|
use crate::bot::handler::BotHandler;
|
||||||
use crate::bot::handler::Handler;
|
|
||||||
use crate::bot::oai::OAI;
|
use crate::bot::oai::OAI;
|
||||||
|
|
||||||
mod api;
|
mod api;
|
||||||
@@ -24,47 +22,24 @@ impl TypeMapKey for HttpKey {
|
|||||||
type Value = HttpClient;
|
type Value = HttpClient;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState {
|
||||||
|
client: reqwest::Client,
|
||||||
|
client_id: String,
|
||||||
|
client_secret: String,
|
||||||
|
redirect_uri: String,
|
||||||
|
http: Arc<Http>,
|
||||||
|
cache: Arc<Cache>,
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
dotenv::dotenv().ok();
|
dotenv::dotenv().ok();
|
||||||
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info"));
|
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info"));
|
||||||
|
|
||||||
if let Err(err) = data::initialize().await {
|
data::initialize().await?;
|
||||||
log::error!("Failed to initialize database: {err}");
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Start API server
|
|
||||||
tokio::spawn(start_api());
|
|
||||||
|
|
||||||
// Start Discord bot
|
|
||||||
start_bot().await;
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn start_api() {
|
|
||||||
let app = Router::new();
|
|
||||||
let addr: String = "127.0.0.1:3000".parse().unwrap();
|
|
||||||
|
|
||||||
let listener = TcpListener::bind(&addr).await.unwrap();
|
|
||||||
log::debug!("API is listening on {}", &addr);
|
|
||||||
axum::serve(listener, app).await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn start_bot() {
|
|
||||||
let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment");
|
let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment");
|
||||||
let intents: GatewayIntents = GatewayIntents::all();
|
|
||||||
|
|
||||||
let http: Http = Http::new(&token);
|
|
||||||
let (owners, bot_id) = get_bot_info(&http).await;
|
|
||||||
|
|
||||||
log::debug!(
|
|
||||||
"Starting Discord bot with ID: {bot_id} and owners: {}",
|
|
||||||
owners
|
|
||||||
.iter()
|
|
||||||
.map(|id| id.to_string())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join(", ")
|
|
||||||
);
|
|
||||||
|
|
||||||
// Set up handler with optional OpenAI integration
|
// Set up handler with optional OpenAI integration
|
||||||
let handler = configure_handler();
|
let handler = configure_handler();
|
||||||
@@ -72,6 +47,8 @@ async fn start_bot() {
|
|||||||
// Set up Songbird for voice functionality
|
// Set up Songbird for voice functionality
|
||||||
let songbird = Songbird::serenity();
|
let songbird = Songbird::serenity();
|
||||||
|
|
||||||
|
let intents: GatewayIntents = GatewayIntents::all();
|
||||||
|
|
||||||
let mut client = Client::builder(token, intents)
|
let mut client = Client::builder(token, intents)
|
||||||
.event_handler(handler)
|
.event_handler(handler)
|
||||||
// .framework(StandardFramework::new().configure(|c| c.owners(owners)))
|
// .framework(StandardFramework::new().configure(|c| c.owners(owners)))
|
||||||
@@ -80,29 +57,53 @@ async fn start_bot() {
|
|||||||
.await
|
.await
|
||||||
.expect("Error creating client");
|
.expect("Error creating client");
|
||||||
|
|
||||||
|
let (bot_owner, bot_id) = get_bot_info(&client.http).await;
|
||||||
|
|
||||||
|
let client_secret: String =
|
||||||
|
env::var("DISCORD_SECRET").expect("Expected a secret in the environment");
|
||||||
|
let redirect_uri: String =
|
||||||
|
env::var("API_CALLBACK_URI").expect("Expected a secret in the environment");
|
||||||
|
let app_state = AppState {
|
||||||
|
client: HttpClient::new(),
|
||||||
|
client_id: bot_id.to_string(),
|
||||||
|
client_secret,
|
||||||
|
redirect_uri,
|
||||||
|
http: Arc::clone(&client.http),
|
||||||
|
cache: Arc::clone(&client.cache),
|
||||||
|
};
|
||||||
|
|
||||||
|
log::debug!("Starting Siren with ID: {bot_id} (Contact: {:?})", bot_owner);
|
||||||
|
|
||||||
// Spawn shutdown signal handling
|
// Spawn shutdown signal handling
|
||||||
let shard_manager = Arc::clone(&client.shard_manager);
|
let shard_manager = Arc::clone(&client.shard_manager);
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
signal_shutdown(shard_manager).await;
|
signal_shutdown(shard_manager).await;
|
||||||
});
|
});
|
||||||
|
|
||||||
// Start the bot
|
// Start API server
|
||||||
|
tokio::spawn(App::new(app_state).serve());
|
||||||
|
|
||||||
|
// Start Discord bot
|
||||||
if let Err(why) = client.start_autosharded().await {
|
if let Err(why) = client.start_autosharded().await {
|
||||||
log::error!("Client error: {why:?}");
|
log::error!("Client error: {why:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_bot_info(http: &Http) -> (HashSet<UserId>, UserId) {
|
async fn get_bot_info(http: &Http) -> (Option<UserId>, UserId) {
|
||||||
match http.get_current_application_info().await {
|
match http.get_current_application_info().await {
|
||||||
Ok(info) => {
|
Ok(info) => {
|
||||||
let mut owners = HashSet::new();
|
let bot_owner;
|
||||||
if let Some(team) = info.team {
|
if let Some(team) = info.team {
|
||||||
owners.insert(team.owner_user_id);
|
bot_owner = Some(team.owner_user_id);
|
||||||
} else if let Some(owner) = info.owner {
|
} else if let Some(owner) = info.owner {
|
||||||
owners.insert(owner.id);
|
bot_owner = Some(owner.id);
|
||||||
|
} else {
|
||||||
|
bot_owner = None;
|
||||||
}
|
}
|
||||||
match http.get_current_user().await {
|
match http.get_current_user().await {
|
||||||
Ok(bot) => (owners, bot.id),
|
Ok(bot) => (bot_owner, bot.id),
|
||||||
Err(why) => panic!("Could not access the bot id: {why:?}"),
|
Err(why) => panic!("Could not access the bot id: {why:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -110,13 +111,13 @@ async fn get_bot_info(http: &Http) -> (HashSet<UserId>, UserId) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn configure_handler() -> Handler {
|
fn configure_handler() -> BotHandler {
|
||||||
match env::var("OPENAI_TOKEN") {
|
match env::var("OPENAI_TOKEN") {
|
||||||
Ok(token) => {
|
Ok(token) => {
|
||||||
log::debug!("OpenAI functionality enabled");
|
log::debug!("OpenAI functionality enabled");
|
||||||
let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string());
|
let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string());
|
||||||
let base_url = env::var("OPENAI_BASE_URL").unwrap();
|
let base_url = env::var("OPENAI_BASE_URL").unwrap();
|
||||||
Handler {
|
BotHandler {
|
||||||
oai: Some(OAI {
|
oai: Some(OAI {
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
base_url,
|
base_url,
|
||||||
@@ -129,7 +130,7 @@ fn configure_handler() -> Handler {
|
|||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
log::warn!("OpenAI functionality disabled");
|
log::warn!("OpenAI functionality disabled");
|
||||||
Handler { oai: None }
|
BotHandler { oai: None }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user