Re-implementing the API
This commit is contained in:
5
.env
5
.env
@@ -1,6 +1,7 @@
|
||||
RUST_LOG=warn,siren=info
|
||||
|
||||
DISCORD_TOKEN=
|
||||
DISCORD_SECRET=
|
||||
|
||||
DATABASE_USER=siren
|
||||
DATABASE_PASSWORD=CHANGEME # Change this to a secure password
|
||||
@@ -8,7 +9,9 @@ DATABASE_NAME=siren
|
||||
DATABASE_HOST=localhost
|
||||
DATABASE_PORT=5432
|
||||
|
||||
SESSION_TTL=1440
|
||||
API_CALLBACK_URI=http://localhost:3000/api/oauth/callback
|
||||
API_PORT=3000
|
||||
API_SESSION_TTL=86400
|
||||
|
||||
MINIO_ROOT_USER=siren
|
||||
MINIO_ROOT_PASSWORD=CHANGEME # Change this to a secure password
|
||||
|
||||
@@ -23,8 +23,10 @@ reqwest = { version = "0.11", default-features = false, features = ["json"] }
|
||||
uuid = { version = "1.11.0", features = ["serde", "v4"] }
|
||||
redis = { version = "0.27.4", features = ["tokio-comp", "connection-manager", "r2d2"] }
|
||||
rand = "0.8.5"
|
||||
rand_chacha = "0.3.1"
|
||||
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] }
|
||||
regex = "1.11.0"
|
||||
axum = "0.7.7"
|
||||
lazy_static = "1.5.0"
|
||||
futures = "0.3.31"
|
||||
axum-login = "0.16.0"
|
||||
|
||||
@@ -13,7 +13,7 @@ services:
|
||||
environment:
|
||||
DATABASE_HOST: siren-postgres
|
||||
DATABASE_PORT: 5432
|
||||
REDIS_HOST: redis
|
||||
REDIS_HOST: siren-redis
|
||||
REDIS_PORT: 6379
|
||||
DATA_DIR_PATH: /data
|
||||
volumes:
|
||||
@@ -42,14 +42,27 @@ services:
|
||||
- ${DATABASE_PORT:-5432}:5432
|
||||
networks:
|
||||
- backend
|
||||
restart: unless-stopped
|
||||
profiles:
|
||||
- backend
|
||||
restart: unless-stopped
|
||||
|
||||
redis:
|
||||
image: redis:latest
|
||||
container_name: siren-redis
|
||||
volumes:
|
||||
- redis:/data
|
||||
ports:
|
||||
- ${REDIS_PORT:-6379}:6379
|
||||
networks:
|
||||
- backend
|
||||
restart: unless-stopped
|
||||
profiles:
|
||||
- backend
|
||||
|
||||
volumes:
|
||||
postgres:
|
||||
postgres_logs:
|
||||
redis:
|
||||
|
||||
networks:
|
||||
frontend:
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS guilds (
|
||||
id BIGINT PRIMARY KEY NOT NULL,
|
||||
bot_id BIGINT NOT NULL,
|
||||
volume INTEGER NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY NOT NULL,
|
||||
guild_id BIGINT NOT NULL,
|
||||
channel_id BIGINT NOT NULL,
|
||||
author_id BIGINT NOT NULL,
|
||||
created BIGINT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
request TEXT NOT NULL,
|
||||
response TEXT NOT NULL,
|
||||
request_tags TEXT[] NOT NULL,
|
||||
response_tags TEXT[] NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
id UUID PRIMARY KEY NOT NULL,
|
||||
guild_id BIGINT NOT NULL,
|
||||
author_id BIGINT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
date_time TIMESTAMP NOT NULL,
|
||||
description TEXT,
|
||||
rsvp BIGINT[] NOT NULL
|
||||
);
|
||||
73
migrations/000_initial.sql
Normal file
73
migrations/000_initial.sql
Normal file
@@ -0,0 +1,73 @@
|
||||
CREATE TABLE IF NOT EXISTS guilds (
|
||||
id BIGINT PRIMARY KEY NOT NULL,
|
||||
name TEXT,
|
||||
owner_id BIGINT,
|
||||
volume INTEGER NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY NOT NULL,
|
||||
guild_id BIGINT NOT NULL,
|
||||
channel_id BIGINT NOT NULL,
|
||||
author_id BIGINT NOT NULL,
|
||||
created BIGINT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
request TEXT NOT NULL,
|
||||
response TEXT NOT NULL,
|
||||
request_tags TEXT[] NOT NULL,
|
||||
response_tags TEXT[] NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS dice_rolls (
|
||||
id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid()
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
id UUID PRIMARY KEY NOT NULL,
|
||||
guild_id BIGINT NOT NULL,
|
||||
author_id BIGINT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
date_time TIMESTAMP NOT NULL,
|
||||
description TEXT,
|
||||
rsvp BIGINT[] NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS races (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY,
|
||||
name TEXT NOT NULL,
|
||||
size TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
data JSON NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS classes (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS feats (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS options_features (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS backgrounds (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS items (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS spells (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY,
|
||||
name TEXT NOT NULL,
|
||||
school TEXT NOT NULL,
|
||||
level INTEGER NOT NULL,
|
||||
ritual BOOLEAN DEFAULT FALSE,
|
||||
concentration BOOLEAN DEFAULT FALSE,
|
||||
classes TEXT[] NOT NULL,
|
||||
damage_inflict TEXT[] NOT NULL,
|
||||
damage_resist TEXT[] NOT NULL,
|
||||
conditions TEXT[] NOT NULL,
|
||||
saving_throw TEXT[] NOT NULL,
|
||||
attack_type TEXT,
|
||||
data JSONB NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS conditions (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS bestiary (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
@@ -1,43 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS races (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY,
|
||||
name TEXT NOT NULL,
|
||||
size TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
data JSON NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS classes (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS feats (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS options_features (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS backgrounds (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS items (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS spells (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY,
|
||||
name TEXT NOT NULL,
|
||||
school TEXT NOT NULL,
|
||||
level INTEGER NOT NULL,
|
||||
ritual BOOLEAN DEFAULT FALSE,
|
||||
concentration BOOLEAN DEFAULT FALSE,
|
||||
classes TEXT[] NOT NULL,
|
||||
damage_inflict TEXT[] NOT NULL,
|
||||
damage_resist TEXT[] NOT NULL,
|
||||
conditions TEXT[] NOT NULL,
|
||||
saving_throw TEXT[] NOT NULL,
|
||||
attack_type TEXT,
|
||||
data JSONB NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS conditions (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS bestiary (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
29
src/api/app.rs
Normal file
29
src/api/app.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use axum::Router;
|
||||
use tokio::net::TcpListener;
|
||||
use crate::{api, AppState};
|
||||
use crate::error::SirenResult;
|
||||
|
||||
pub struct App {
|
||||
app_state: AppState,
|
||||
}
|
||||
|
||||
impl App {
|
||||
pub fn new(app_state: AppState) -> Self {
|
||||
Self { app_state }
|
||||
}
|
||||
|
||||
pub async fn serve(self) -> SirenResult<()> {
|
||||
let app = Router::new()
|
||||
.nest("/api", api::get_routes())
|
||||
.with_state(Arc::new(self.app_state));
|
||||
|
||||
let api_port: String = env::var("API_PORT").expect("Expected a port in the environment");
|
||||
let addr = format!("0.0.0.0:{}", api_port);
|
||||
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
log::info!("API is listening on {}", &addr);
|
||||
Ok(axum::serve(listener, app).await?)
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,13 @@
|
||||
use axum::Router;
|
||||
mod app;
|
||||
mod oauth;
|
||||
|
||||
pub fn get_routes() -> Router {
|
||||
Router::new()
|
||||
pub use app::App;
|
||||
|
||||
use std::sync::Arc;
|
||||
use axum::Router;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::AppState;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new().nest("/oauth", oauth::get_routes())
|
||||
}
|
||||
|
||||
220
src/api/oauth.rs
Normal file
220
src/api/oauth.rs
Normal file
@@ -0,0 +1,220 @@
|
||||
use std::env;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use axum::extract::{Query, State};
|
||||
use axum::http::{HeaderMap, HeaderValue, StatusCode};
|
||||
use axum::{Json, Router};
|
||||
use axum::http::header::SET_COOKIE;
|
||||
use axum::response::Redirect;
|
||||
use axum::routing::get;
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::Rng;
|
||||
use rand_chacha::ChaCha20Rng;
|
||||
use rand_chacha::rand_core::SeedableRng;
|
||||
use redis::{AsyncCommands, RedisResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::{data, AppState};
|
||||
use crate::error::SirenResult;
|
||||
|
||||
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/authorize", get(discord_authorize))
|
||||
.route("/callback", get(oauth_callback))
|
||||
}
|
||||
|
||||
fn get_session_ttl() -> i64 {
|
||||
// Initialize the SESSION_TTL value lazily
|
||||
*SESSION_TTL.get_or_init(|| {
|
||||
env::var("SESSION_TTL")
|
||||
.ok()
|
||||
.and_then(|val| val.parse::<i64>().ok())
|
||||
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn csprng(take: usize) -> String {
|
||||
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
|
||||
let rng = ChaCha20Rng::from_entropy();
|
||||
rng
|
||||
.sample_iter(rand::distributions::Alphanumeric)
|
||||
.take(take)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthQuery {
|
||||
code: String,
|
||||
state: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
expires_in: u64,
|
||||
refresh_token: String,
|
||||
scope: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct DiscordUser {
|
||||
id: String,
|
||||
username: String,
|
||||
discriminator: String,
|
||||
avatar: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct Session {
|
||||
session_id: String,
|
||||
user_id: String,
|
||||
user_name: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn new(id: String, user_id: String, user_name: String) -> Session {
|
||||
let now = Utc::now();
|
||||
let session_ttl = get_session_ttl();
|
||||
Session {
|
||||
session_id: id,
|
||||
user_id,
|
||||
user_name,
|
||||
expires_at: now + chrono::Duration::seconds(session_ttl),
|
||||
}
|
||||
}
|
||||
|
||||
async fn insert(&self) -> SirenResult<()> {
|
||||
let mut redis = data::redis_async_connection().await?;
|
||||
let session_id = self.session_id.clone();
|
||||
redis
|
||||
.set_ex(
|
||||
session_id,
|
||||
serde_json::to_string(self)?,
|
||||
self.expires_at.timestamp() as u64,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get(session_id: String) -> SirenResult<Option<Session>> {
|
||||
let mut redis = data::redis_async_connection().await?;
|
||||
let result: RedisResult<Option<String>> = redis.get(session_id).await;
|
||||
match result {
|
||||
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
|
||||
Ok(None) => Ok(None),
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete(session_id: String) -> SirenResult<()> {
|
||||
let mut redis = data::redis_async_connection().await?;
|
||||
let result: RedisResult<()> = redis.del(session_id).await;
|
||||
match result {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect {
|
||||
// // Construct the Discord OAuth URL
|
||||
// let discord_auth_url = format!(
|
||||
// "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify",
|
||||
// state.client_id, state.redirect_uri
|
||||
// );
|
||||
// Redirect::temporary(&discord_auth_url)
|
||||
// }
|
||||
|
||||
async fn discord_authorize(State(state): State<Arc<AppState>>) -> SirenResult<String> {
|
||||
// Construct the Discord OAuth URL
|
||||
let discord_auth_url = format!(
|
||||
"https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify",
|
||||
state.client_id, state.redirect_uri
|
||||
);
|
||||
Ok(discord_auth_url)
|
||||
}
|
||||
|
||||
async fn oauth_callback(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<AuthQuery>,
|
||||
) -> SirenResult<(HeaderMap, Json<DiscordUser>)> {
|
||||
// Exchange code for an access token
|
||||
let token_response = state
|
||||
.client
|
||||
.post("https://discord.com/api/oauth2/token")
|
||||
.form(&[
|
||||
("client_id", state.client_id.as_str()),
|
||||
("client_secret", state.client_secret.as_str()),
|
||||
("grant_type", "authorization_code"),
|
||||
("code", query.code.as_str()),
|
||||
("redirect_uri", state.redirect_uri.as_str()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if !token_response.status().is_success() {
|
||||
log::error!(
|
||||
"Failed to exchange token: {:?}",
|
||||
token_response.text().await
|
||||
);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR.into());
|
||||
}
|
||||
|
||||
let token_data: TokenResponse = token_response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Fetch user information
|
||||
let user_response = state
|
||||
.client
|
||||
.get("https://discord.com/api/users/@me")
|
||||
.bearer_auth(token_data.access_token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if !user_response.status().is_success() {
|
||||
log::error!(
|
||||
"Failed to fetch user information: {:?}",
|
||||
user_response.text().await
|
||||
);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR.into());
|
||||
}
|
||||
|
||||
let user_data: DiscordUser = user_response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
log::debug!("User authenticated: {:?}", user_data);
|
||||
|
||||
// Generate a session token
|
||||
let session_token = csprng(16);
|
||||
let expiration = env::var("API_SESSION_TTL")
|
||||
.expect("Expected a session ttl in the environment")
|
||||
.parse::<u64>()
|
||||
.unwrap();
|
||||
|
||||
// Create and insert the session
|
||||
let session = Session::new(
|
||||
session_token.clone(),
|
||||
user_data.id.clone(),
|
||||
user_data.username.clone(),
|
||||
);
|
||||
session.insert().await?;
|
||||
|
||||
let cookie_value = format!(
|
||||
"session={}; HttpOnly; Path=/; Max-Age={}",
|
||||
session_token, expiration
|
||||
);
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
|
||||
|
||||
Ok((headers, Json(user_data)))
|
||||
}
|
||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use serenity::all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption};
|
||||
use serenity::model::prelude::GuildId;
|
||||
use serenity::{prelude::*, async_trait, futures};
|
||||
use serenity::{prelude::*, async_trait};
|
||||
use songbird::input::{Input, YoutubeDl};
|
||||
use songbird::tracks::TrackHandle;
|
||||
use songbird::{Event, EventHandler, Songbird, TrackEvent};
|
||||
@@ -25,7 +25,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
"{} attempted to play a track without a track option",
|
||||
command.user.id.get()
|
||||
);
|
||||
create_message_response(&ctx, &command, format!("Track option is missing"), false).await;
|
||||
create_message_response(&ctx, &command, "Track option is missing".to_string(), false).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
@@ -53,7 +53,9 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Join the user's voice channel
|
||||
match join_voice_channel(&ctx.cache, &manager, guild_id, &command.user).await {
|
||||
Ok(channel_id) => {
|
||||
log::debug!("<{guild_id}> Play command executed on channel {channel_id} with track: {track_url:?}");
|
||||
log::debug!(
|
||||
"<{guild_id}> Play command executed on channel {channel_id} with track: {track_url:?}"
|
||||
);
|
||||
// Handle the track url
|
||||
match enqueue_track(ctx, manager, guild_id.to_owned(), track_url).await {
|
||||
Ok(items) => {
|
||||
|
||||
@@ -124,7 +124,7 @@ pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
Err(err) => {
|
||||
log::error!(
|
||||
"<{guild_id}> <{channel_id}> <{author_id}> Could not get response from OpenAI: {}",
|
||||
err.message
|
||||
err.details
|
||||
);
|
||||
"There was an error processing your message. Please try again later.".to_string()
|
||||
}
|
||||
@@ -196,7 +196,7 @@ async fn generate_thread_name(oai: &OAI, s: &str, max_chars: usize) -> String {
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Could not get response from OpenAI: {}", err.message);
|
||||
log::error!("Could not get response from OpenAI: {}", err.details);
|
||||
}
|
||||
};
|
||||
return response;
|
||||
|
||||
@@ -3,6 +3,14 @@ use crate::bot::chat::create_message_response;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
log::debug!("Ping command executed");
|
||||
|
||||
if let Some(guild_id) = command.guild_id {
|
||||
if let Some(guild) = guild_id.to_guild_cached(&ctx.cache) {
|
||||
let owner_id = guild.owner_id;
|
||||
if command.user.id == owner_id {}
|
||||
}
|
||||
}
|
||||
|
||||
create_message_response(&ctx, &command, "pong".to_string(), true).await;
|
||||
}
|
||||
|
||||
|
||||
@@ -9,13 +9,13 @@ use crate::data::guilds::GuildCache;
|
||||
use super::{commands};
|
||||
use super::chat::{create_message_response, create_modal_response};
|
||||
|
||||
pub struct Handler {
|
||||
pub struct BotHandler {
|
||||
// Open AI Config
|
||||
pub oai: Option<OAI>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EventHandler for Handler {
|
||||
impl EventHandler for BotHandler {
|
||||
async fn message(&self, ctx: Context, msg: Message) {
|
||||
// Ignore bot messages
|
||||
if msg.author.bot {
|
||||
@@ -47,7 +47,8 @@ impl EventHandler for Handler {
|
||||
if let None = GuildCache::get_by_id(guild_id).await.unwrap() {
|
||||
let guild_cache = GuildCache {
|
||||
id: guild_id,
|
||||
bot_id: 1,
|
||||
name: guild.id.name(&ctx.cache),
|
||||
owner_id: None,
|
||||
volume: 100,
|
||||
};
|
||||
guild_cache.insert().await.unwrap();
|
||||
|
||||
@@ -129,7 +129,7 @@ impl OAI {
|
||||
ResponseEvent::ResponseError(error) => {
|
||||
return Err(SirenError {
|
||||
status: 500,
|
||||
message: format!("Error: {}", error.message.unwrap()),
|
||||
details: format!("Error: {}", error.message.unwrap()),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -137,7 +137,7 @@ impl OAI {
|
||||
Err(err) => {
|
||||
return Err(SirenError {
|
||||
status: 500,
|
||||
message: format!("Error: {}", err),
|
||||
details: format!("Error: {}", err),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,8 @@ const TABLE_NAME: &str = "guilds";
|
||||
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct GuildCache {
|
||||
pub id: i64,
|
||||
pub bot_id: i64,
|
||||
pub name: Option<String>,
|
||||
pub owner_id: Option<i64>,
|
||||
pub volume: i32,
|
||||
}
|
||||
|
||||
@@ -16,18 +17,20 @@ impl GuildCache {
|
||||
sqlx::query(&format!(
|
||||
"INSERT INTO {} (
|
||||
id,
|
||||
bot_id,
|
||||
name,
|
||||
owner_id,
|
||||
volume
|
||||
) VALUES (
|
||||
$1, $2, $3
|
||||
$1, $2, $3, $4
|
||||
)",
|
||||
TABLE_NAME
|
||||
))
|
||||
.bind(self.id)
|
||||
.bind(self.bot_id)
|
||||
.bind(self.volume)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
.bind(self.id)
|
||||
.bind(&self.name)
|
||||
.bind(self.owner_id)
|
||||
.bind(self.volume)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -45,16 +48,18 @@ impl GuildCache {
|
||||
let pool = crate::data::pool();
|
||||
sqlx::query(&format!(
|
||||
"UPDATE {} SET
|
||||
bot_id = $2,
|
||||
volume = $3
|
||||
name = $2,
|
||||
owner_id = $3,
|
||||
volume = $4
|
||||
WHERE id = $1",
|
||||
TABLE_NAME
|
||||
))
|
||||
.bind(self.id)
|
||||
.bind(self.bot_id)
|
||||
.bind(self.volume)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
.bind(self.id)
|
||||
.bind(&self.name)
|
||||
.bind(self.owner_id)
|
||||
.bind(self.volume)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
44
src/error.rs
44
src/error.rs
@@ -1,4 +1,7 @@
|
||||
use std::fmt;
|
||||
use axum::http::StatusCode;
|
||||
use axum::Json;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub type SirenResult<T> = Result<T, Error>;
|
||||
@@ -6,21 +9,44 @@ pub type SirenResult<T> = Result<T, Error>;
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Error {
|
||||
pub status: u16,
|
||||
pub message: String,
|
||||
pub details: String,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn new(error_status_code: u16, error_message: String) -> Self {
|
||||
Self {
|
||||
status: error_status_code,
|
||||
message: error_message,
|
||||
details: error_message,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(self.message.as_str())
|
||||
f.write_str(self.details.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn description(&self) -> &str {
|
||||
&self.details
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for Error {
|
||||
fn into_response(self) -> Response {
|
||||
let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
// Create a JSON response with the structured error
|
||||
let body = Json(serde_json::json!({
|
||||
"error": {
|
||||
"status": self.status,
|
||||
"details": self.details,
|
||||
}
|
||||
}));
|
||||
|
||||
// Return the response with the proper status and error body
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +56,18 @@ impl From<std::io::Error> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StatusCode> for Error {
|
||||
fn from(status: StatusCode) -> Self {
|
||||
Error {
|
||||
status: status.as_u16(),
|
||||
details: status
|
||||
.canonical_reason()
|
||||
.unwrap_or("Unknown error")
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::string::FromUtf8Error> for Error {
|
||||
fn from(error: std::string::FromUtf8Error) -> Self {
|
||||
Self::new(500, format!("Unknown from utf8 error: {}", error))
|
||||
|
||||
99
src/main.rs
99
src/main.rs
@@ -1,15 +1,13 @@
|
||||
use std::env;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use axum::Router;
|
||||
use serenity::http::Http;
|
||||
use serenity::prelude::*;
|
||||
use songbird::{SerenityInit, Songbird};
|
||||
use reqwest::Client as HttpClient;
|
||||
use serenity::all::{ShardManager, UserId};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use crate::bot::handler::Handler;
|
||||
use serenity::all::{Cache, ShardManager, UserId};
|
||||
use crate::api::App;
|
||||
use crate::bot::handler::BotHandler;
|
||||
use crate::bot::oai::OAI;
|
||||
|
||||
mod api;
|
||||
@@ -24,47 +22,24 @@ impl TypeMapKey for HttpKey {
|
||||
type Value = HttpClient;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
client: reqwest::Client,
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
redirect_uri: String,
|
||||
http: Arc<Http>,
|
||||
cache: Arc<Cache>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
dotenv::dotenv().ok();
|
||||
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info"));
|
||||
|
||||
if let Err(err) = data::initialize().await {
|
||||
log::error!("Failed to initialize database: {err}");
|
||||
return;
|
||||
};
|
||||
data::initialize().await?;
|
||||
|
||||
// Start API server
|
||||
tokio::spawn(start_api());
|
||||
|
||||
// Start Discord bot
|
||||
start_bot().await;
|
||||
}
|
||||
|
||||
async fn start_api() {
|
||||
let app = Router::new();
|
||||
let addr: String = "127.0.0.1:3000".parse().unwrap();
|
||||
|
||||
let listener = TcpListener::bind(&addr).await.unwrap();
|
||||
log::debug!("API is listening on {}", &addr);
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
|
||||
async fn start_bot() {
|
||||
let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment");
|
||||
let intents: GatewayIntents = GatewayIntents::all();
|
||||
|
||||
let http: Http = Http::new(&token);
|
||||
let (owners, bot_id) = get_bot_info(&http).await;
|
||||
|
||||
log::debug!(
|
||||
"Starting Discord bot with ID: {bot_id} and owners: {}",
|
||||
owners
|
||||
.iter()
|
||||
.map(|id| id.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
);
|
||||
|
||||
// Set up handler with optional OpenAI integration
|
||||
let handler = configure_handler();
|
||||
@@ -72,6 +47,8 @@ async fn start_bot() {
|
||||
// Set up Songbird for voice functionality
|
||||
let songbird = Songbird::serenity();
|
||||
|
||||
let intents: GatewayIntents = GatewayIntents::all();
|
||||
|
||||
let mut client = Client::builder(token, intents)
|
||||
.event_handler(handler)
|
||||
// .framework(StandardFramework::new().configure(|c| c.owners(owners)))
|
||||
@@ -80,29 +57,53 @@ async fn start_bot() {
|
||||
.await
|
||||
.expect("Error creating client");
|
||||
|
||||
let (bot_owner, bot_id) = get_bot_info(&client.http).await;
|
||||
|
||||
let client_secret: String =
|
||||
env::var("DISCORD_SECRET").expect("Expected a secret in the environment");
|
||||
let redirect_uri: String =
|
||||
env::var("API_CALLBACK_URI").expect("Expected a secret in the environment");
|
||||
let app_state = AppState {
|
||||
client: HttpClient::new(),
|
||||
client_id: bot_id.to_string(),
|
||||
client_secret,
|
||||
redirect_uri,
|
||||
http: Arc::clone(&client.http),
|
||||
cache: Arc::clone(&client.cache),
|
||||
};
|
||||
|
||||
log::debug!("Starting Siren with ID: {bot_id} (Contact: {:?})", bot_owner);
|
||||
|
||||
// Spawn shutdown signal handling
|
||||
let shard_manager = Arc::clone(&client.shard_manager);
|
||||
tokio::spawn(async move {
|
||||
signal_shutdown(shard_manager).await;
|
||||
});
|
||||
|
||||
// Start the bot
|
||||
// Start API server
|
||||
tokio::spawn(App::new(app_state).serve());
|
||||
|
||||
// Start Discord bot
|
||||
if let Err(why) = client.start_autosharded().await {
|
||||
log::error!("Client error: {why:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_bot_info(http: &Http) -> (HashSet<UserId>, UserId) {
|
||||
async fn get_bot_info(http: &Http) -> (Option<UserId>, UserId) {
|
||||
match http.get_current_application_info().await {
|
||||
Ok(info) => {
|
||||
let mut owners = HashSet::new();
|
||||
let bot_owner;
|
||||
if let Some(team) = info.team {
|
||||
owners.insert(team.owner_user_id);
|
||||
bot_owner = Some(team.owner_user_id);
|
||||
} else if let Some(owner) = info.owner {
|
||||
owners.insert(owner.id);
|
||||
bot_owner = Some(owner.id);
|
||||
} else {
|
||||
bot_owner = None;
|
||||
}
|
||||
match http.get_current_user().await {
|
||||
Ok(bot) => (owners, bot.id),
|
||||
Ok(bot) => (bot_owner, bot.id),
|
||||
Err(why) => panic!("Could not access the bot id: {why:?}"),
|
||||
}
|
||||
}
|
||||
@@ -110,13 +111,13 @@ async fn get_bot_info(http: &Http) -> (HashSet<UserId>, UserId) {
|
||||
}
|
||||
}
|
||||
|
||||
fn configure_handler() -> Handler {
|
||||
fn configure_handler() -> BotHandler {
|
||||
match env::var("OPENAI_TOKEN") {
|
||||
Ok(token) => {
|
||||
log::debug!("OpenAI functionality enabled");
|
||||
let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string());
|
||||
let base_url = env::var("OPENAI_BASE_URL").unwrap();
|
||||
Handler {
|
||||
BotHandler {
|
||||
oai: Some(OAI {
|
||||
client: reqwest::Client::new(),
|
||||
base_url,
|
||||
@@ -129,7 +130,7 @@ fn configure_handler() -> Handler {
|
||||
}
|
||||
Err(_) => {
|
||||
log::warn!("OpenAI functionality disabled");
|
||||
Handler { oai: None }
|
||||
BotHandler { oai: None }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user