Added auth route (temp api key route)
This commit is contained in:
2
.env
2
.env
@@ -3,6 +3,8 @@ RUST_LOG=warn,siren=info
|
|||||||
DISCORD_TOKEN=
|
DISCORD_TOKEN=
|
||||||
DISCORD_SECRET=
|
DISCORD_SECRET=
|
||||||
|
|
||||||
|
JWT_SECRET=CHANGEME
|
||||||
|
|
||||||
DATABASE_USER=siren
|
DATABASE_USER=siren
|
||||||
DATABASE_PASSWORD=CHANGEME # Change this to a secure password
|
DATABASE_PASSWORD=CHANGEME # Change this to a secure password
|
||||||
DATABASE_NAME=siren
|
DATABASE_NAME=siren
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ rand_chacha = "0.3.1"
|
|||||||
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] }
|
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] }
|
||||||
regex = "1.11.0"
|
regex = "1.11.0"
|
||||||
axum = "0.7.7"
|
axum = "0.7.7"
|
||||||
|
axum-extra = { version = "0.9.6", features = ["typed-header"] }
|
||||||
lazy_static = "1.5.0"
|
lazy_static = "1.5.0"
|
||||||
futures = "0.3.31"
|
jsonwebtoken = "9.3.0"
|
||||||
axum-login = "0.16.0"
|
|
||||||
sqlx-postgres = "0.8.2"
|
|
||||||
|
|||||||
36
src/api/auth/api_key.rs
Normal file
36
src/api/auth/api_key.rs
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
use axum::{middleware, Extension, Router};
|
||||||
|
use axum::middleware::from_extractor;
|
||||||
|
use axum::routing::post;
|
||||||
|
use crate::api::auth::{authenticate_middleware, csprng};
|
||||||
|
use crate::api::auth::middleware::AuthorizationMiddleware;
|
||||||
|
use crate::api::auth::session::Session;
|
||||||
|
use crate::AppState;
|
||||||
|
use crate::error::SirenResult;
|
||||||
|
|
||||||
|
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||||
|
Router::new().route("/api-key", post(create_api_key))
|
||||||
|
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ApiKey {
|
||||||
|
pub key: String,
|
||||||
|
pub user_id: String,
|
||||||
|
pub access_mask: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ApiKey {
|
||||||
|
fn new(user_id: String, access_mask: u32) -> Self {
|
||||||
|
ApiKey {
|
||||||
|
key: csprng(64),
|
||||||
|
user_id,
|
||||||
|
access_mask
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_api_key(Extension(session): Extension<Session>) -> SirenResult<String> {
|
||||||
|
log::debug!("Generating API key for {} ({})", &session.user_id, &session.user_name);
|
||||||
|
let api_key = ApiKey::new(session.user_id, 0);
|
||||||
|
Ok(api_key.key)
|
||||||
|
}
|
||||||
10
src/api/auth/bearer_token.rs
Normal file
10
src/api/auth/bearer_token.rs
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct BearerTokenClaims {
|
||||||
|
pub sub: String,
|
||||||
|
pub name: String,
|
||||||
|
pub iat: i64,
|
||||||
|
pub exp: i64,
|
||||||
|
pub jti: String,
|
||||||
|
}
|
||||||
71
src/api/auth/middleware.rs
Normal file
71
src/api/auth/middleware.rs
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
use axum::async_trait;
|
||||||
|
use axum::extract::FromRequestParts;
|
||||||
|
use axum::http::request::Parts;
|
||||||
|
use axum::http::{Method, StatusCode};
|
||||||
|
use axum_extra::{TypedHeader, headers::{Authorization, authorization::Bearer}};
|
||||||
|
use chrono::Utc;
|
||||||
|
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||||||
|
use crate::api::auth::bearer_token::BearerTokenClaims;
|
||||||
|
use crate::api::auth::session::Session;
|
||||||
|
use crate::error::SirenResult;
|
||||||
|
|
||||||
|
pub struct AuthorizationMiddleware;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S> FromRequestParts<S> for AuthorizationMiddleware
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
{
|
||||||
|
type Rejection = StatusCode;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
// For options requests browsers will not send the authorization header.
|
||||||
|
if parts.method == Method::OPTIONS {
|
||||||
|
return Ok(Self);
|
||||||
|
}
|
||||||
|
|
||||||
|
let Ok(TypedHeader(Authorization(bearer))) =
|
||||||
|
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
|
||||||
|
else {
|
||||||
|
log::error!("Could not get Authorization header from the request");
|
||||||
|
return Err(StatusCode::UNAUTHORIZED);
|
||||||
|
};
|
||||||
|
|
||||||
|
match check_auth(bearer).await {
|
||||||
|
Ok(session) => {
|
||||||
|
parts.extensions.insert(session);
|
||||||
|
Ok(Self)
|
||||||
|
},
|
||||||
|
Err(err) => {
|
||||||
|
log::error!("{:?}", err);
|
||||||
|
Err(StatusCode::UNAUTHORIZED)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn check_auth(bearer: Bearer) -> SirenResult<Session> {
|
||||||
|
// Decode and validate the JWT
|
||||||
|
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
|
||||||
|
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
|
||||||
|
|
||||||
|
let token_data = decode::<BearerTokenClaims>(
|
||||||
|
bearer.token(),
|
||||||
|
&decoding_key,
|
||||||
|
&Validation::default()
|
||||||
|
).map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||||
|
|
||||||
|
let claims = token_data.claims;
|
||||||
|
|
||||||
|
// Check if the token has expired
|
||||||
|
let now = Utc::now().timestamp();
|
||||||
|
if claims.exp < now {
|
||||||
|
return Err(StatusCode::UNAUTHORIZED.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Confirm the session exists in the session store (based on `jti`)
|
||||||
|
match Session::find(&claims.jti).await {
|
||||||
|
Ok(Some(session)) => Ok(session),
|
||||||
|
_ => Err(StatusCode::UNAUTHORIZED)?,
|
||||||
|
}
|
||||||
|
}
|
||||||
28
src/api/auth/mod.rs
Normal file
28
src/api/auth/mod.rs
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
use axum::Router;
|
||||||
|
use rand::Rng;
|
||||||
|
use rand_chacha::ChaCha20Rng;
|
||||||
|
use rand_chacha::rand_core::SeedableRng;
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
mod oauth;
|
||||||
|
mod session;
|
||||||
|
mod api_key;
|
||||||
|
mod bearer_token;
|
||||||
|
mod middleware;
|
||||||
|
|
||||||
|
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||||
|
Router::new()
|
||||||
|
.nest("/oauth", oauth::get_routes())
|
||||||
|
.merge(api_key::get_routes())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn csprng(take: usize) -> String {
|
||||||
|
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
|
||||||
|
let rng = ChaCha20Rng::from_entropy();
|
||||||
|
rng
|
||||||
|
.sample_iter(rand::distributions::Alphanumeric)
|
||||||
|
.take(take)
|
||||||
|
.map(char::from)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
@@ -1,48 +1,24 @@
|
|||||||
use std::env;
|
use std::env;
|
||||||
use std::sync::{Arc, OnceLock};
|
use std::sync::Arc;
|
||||||
use axum::extract::{Query, State};
|
use axum::extract::{Query, State};
|
||||||
use axum::http::{HeaderMap, HeaderValue, StatusCode};
|
use axum::http::{HeaderMap, HeaderValue, StatusCode};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
use axum::http::header::SET_COOKIE;
|
use axum::http::header::SET_COOKIE;
|
||||||
use axum::response::Redirect;
|
use axum::response::Redirect;
|
||||||
use axum::routing::get;
|
use axum::routing::get;
|
||||||
use chrono::{DateTime, Utc};
|
|
||||||
use rand::Rng;
|
|
||||||
use rand_chacha::ChaCha20Rng;
|
|
||||||
use rand_chacha::rand_core::SeedableRng;
|
|
||||||
use redis::{AsyncCommands, RedisResult};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use crate::{data, AppState};
|
use crate::api::auth::bearer_token::BearerTokenClaims;
|
||||||
|
use crate::AppState;
|
||||||
|
use crate::api::auth::csprng;
|
||||||
|
use crate::api::auth::session::Session;
|
||||||
use crate::error::SirenResult;
|
use crate::error::SirenResult;
|
||||||
|
|
||||||
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
|
|
||||||
|
|
||||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/authorize", get(discord_authorize))
|
.route("/authorize", get(discord_authorize))
|
||||||
.route("/callback", get(oauth_callback))
|
.route("/callback", get(oauth_callback))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_session_ttl() -> i64 {
|
|
||||||
// Initialize the SESSION_TTL value lazily
|
|
||||||
*SESSION_TTL.get_or_init(|| {
|
|
||||||
env::var("SESSION_TTL")
|
|
||||||
.ok()
|
|
||||||
.and_then(|val| val.parse::<i64>().ok())
|
|
||||||
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn csprng(take: usize) -> String {
|
|
||||||
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
|
|
||||||
let rng = ChaCha20Rng::from_entropy();
|
|
||||||
rng
|
|
||||||
.sample_iter(rand::distributions::Alphanumeric)
|
|
||||||
.take(take)
|
|
||||||
.map(char::from)
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct AuthQuery {
|
struct AuthQuery {
|
||||||
code: String,
|
code: String,
|
||||||
@@ -66,59 +42,6 @@ struct DiscordUser {
|
|||||||
avatar: Option<String>,
|
avatar: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
struct Session {
|
|
||||||
session_id: String,
|
|
||||||
user_id: String,
|
|
||||||
user_name: String,
|
|
||||||
pub expires_at: DateTime<Utc>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Session {
|
|
||||||
fn new(id: String, user_id: String, user_name: String) -> Session {
|
|
||||||
let now = Utc::now();
|
|
||||||
let session_ttl = get_session_ttl();
|
|
||||||
Session {
|
|
||||||
session_id: id,
|
|
||||||
user_id,
|
|
||||||
user_name,
|
|
||||||
expires_at: now + chrono::Duration::seconds(session_ttl),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn insert(&self) -> SirenResult<()> {
|
|
||||||
let mut redis = data::redis_async_connection().await?;
|
|
||||||
let session_id = self.session_id.clone();
|
|
||||||
redis
|
|
||||||
.set_ex(
|
|
||||||
session_id,
|
|
||||||
serde_json::to_string(self)?,
|
|
||||||
self.expires_at.timestamp() as u64,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get(session_id: String) -> SirenResult<Option<Session>> {
|
|
||||||
let mut redis = data::redis_async_connection().await?;
|
|
||||||
let result: RedisResult<Option<String>> = redis.get(session_id).await;
|
|
||||||
match result {
|
|
||||||
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
|
|
||||||
Ok(None) => Ok(None),
|
|
||||||
Err(err) => Err(err.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn delete(session_id: String) -> SirenResult<()> {
|
|
||||||
let mut redis = data::redis_async_connection().await?;
|
|
||||||
let result: RedisResult<()> = redis.del(session_id).await;
|
|
||||||
match result {
|
|
||||||
Ok(_) => Ok(()),
|
|
||||||
Err(err) => Err(err.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect {
|
// async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect {
|
||||||
// // Construct the Discord OAuth URL
|
// // Construct the Discord OAuth URL
|
||||||
// let discord_auth_url = format!(
|
// let discord_auth_url = format!(
|
||||||
@@ -137,10 +60,17 @@ async fn discord_authorize(State(state): State<Arc<AppState>>) -> SirenResult<St
|
|||||||
Ok(discord_auth_url)
|
Ok(discord_auth_url)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct BearerTokenResponse {
|
||||||
|
pub access_token: String,
|
||||||
|
pub token_type: String,
|
||||||
|
pub expires_in: u64,
|
||||||
|
}
|
||||||
|
|
||||||
async fn oauth_callback(
|
async fn oauth_callback(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Query(query): Query<AuthQuery>,
|
Query(query): Query<AuthQuery>,
|
||||||
) -> SirenResult<(HeaderMap, Json<DiscordUser>)> {
|
) -> SirenResult<Json<BearerTokenResponse>> {
|
||||||
// Exchange code for an access token
|
// Exchange code for an access token
|
||||||
let token_response = state
|
let token_response = state
|
||||||
.client
|
.client
|
||||||
@@ -193,28 +123,32 @@ async fn oauth_callback(
|
|||||||
|
|
||||||
log::debug!("User authenticated: {:?}", user_data);
|
log::debug!("User authenticated: {:?}", user_data);
|
||||||
|
|
||||||
// Generate a session token
|
|
||||||
let session_token = csprng(16);
|
|
||||||
let expiration = env::var("API_SESSION_TTL")
|
|
||||||
.expect("Expected a session ttl in the environment")
|
|
||||||
.parse::<u64>()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Create and insert the session
|
// Create and insert the session
|
||||||
let session = Session::new(
|
let session = Session::new(user_data.id.clone(), user_data.username.clone());
|
||||||
session_token.clone(),
|
|
||||||
user_data.id.clone(),
|
|
||||||
user_data.username.clone(),
|
|
||||||
);
|
|
||||||
session.insert().await?;
|
session.insert().await?;
|
||||||
|
|
||||||
let cookie_value = format!(
|
let issued_at = chrono::Utc::now();
|
||||||
"session={}; HttpOnly; Path=/; Max-Age={}",
|
|
||||||
session_token, expiration
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let claims = BearerTokenClaims {
|
||||||
headers.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
|
sub: session.user_id.clone(),
|
||||||
|
name: session.user_name.clone(),
|
||||||
|
iat: issued_at.timestamp(),
|
||||||
|
exp: session.expires_at.timestamp(),
|
||||||
|
jti: session.session_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
Ok((headers, Json(user_data)))
|
// Create the JWT
|
||||||
|
let jwt_secret = env::var("JWT_SECRET").expect("Expected a JWT secret in the environment");
|
||||||
|
let encoding_key = jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes());
|
||||||
|
let token = jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &encoding_key)
|
||||||
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||||
|
|
||||||
|
// Return the bearer token and user information
|
||||||
|
let response = BearerTokenResponse {
|
||||||
|
access_token: token,
|
||||||
|
token_type: "Bearer".to_string(),
|
||||||
|
expires_in: (session.expires_at.timestamp() - issued_at.timestamp()) as u64,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Json(response))
|
||||||
}
|
}
|
||||||
73
src/api/auth/session.rs
Normal file
73
src/api/auth/session.rs
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
use std::env;
|
||||||
|
use std::sync::OnceLock;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use redis::{AsyncCommands, RedisResult};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use crate::api::auth::csprng;
|
||||||
|
use crate::data;
|
||||||
|
use crate::error::SirenResult;
|
||||||
|
|
||||||
|
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
|
||||||
|
|
||||||
|
fn get_session_ttl() -> i64 {
|
||||||
|
// Initialize the SESSION_TTL value lazily
|
||||||
|
*SESSION_TTL.get_or_init(|| {
|
||||||
|
env::var("API_SESSION_TTL")
|
||||||
|
.ok()
|
||||||
|
.and_then(|val| val.parse::<i64>().ok())
|
||||||
|
.unwrap_or(3600) // Default to 3600 seconds (1 hour)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||||
|
pub struct Session {
|
||||||
|
pub session_id: String,
|
||||||
|
pub user_id: String,
|
||||||
|
pub user_name: String,
|
||||||
|
pub expires_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Session {
|
||||||
|
pub fn new(user_id: String, user_name: String) -> Session {
|
||||||
|
let now = Utc::now();
|
||||||
|
let session_ttl = get_session_ttl();
|
||||||
|
Session {
|
||||||
|
session_id: csprng(32),
|
||||||
|
user_id,
|
||||||
|
user_name,
|
||||||
|
expires_at: now + chrono::Duration::seconds(session_ttl),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert(&self) -> SirenResult<()> {
|
||||||
|
let mut redis = data::redis_async_connection().await?;
|
||||||
|
let session_id = self.session_id.clone();
|
||||||
|
redis
|
||||||
|
.set_ex(
|
||||||
|
session_id,
|
||||||
|
serde_json::to_string(self)?,
|
||||||
|
self.expires_at.timestamp() as u64,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn find(session_id: &str) -> SirenResult<Option<Session>> {
|
||||||
|
let mut redis = data::redis_async_connection().await?;
|
||||||
|
let result: RedisResult<Option<String>> = redis.get(session_id).await;
|
||||||
|
match result {
|
||||||
|
Ok(Some(value)) => Ok(Some(serde_json::from_str(&value)?)),
|
||||||
|
Ok(None) => Ok(None),
|
||||||
|
Err(err) => Err(err.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete(session_id: &str) -> SirenResult<()> {
|
||||||
|
let mut redis = data::redis_async_connection().await?;
|
||||||
|
let result: RedisResult<()> = redis.del(session_id).await;
|
||||||
|
match result {
|
||||||
|
Ok(_) => Ok(()),
|
||||||
|
Err(err) => Err(err.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,13 +1,12 @@
|
|||||||
mod app;
|
|
||||||
mod oauth;
|
|
||||||
|
|
||||||
pub use app::App;
|
pub use app::App;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
|
mod app;
|
||||||
|
mod auth;
|
||||||
|
|
||||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||||
Router::new().nest("/oauth", oauth::get_routes())
|
Router::new().merge(auth::get_routes())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,24 +10,24 @@ pub async fn process_message(ctx: &Context, command: &CommandInteraction, privat
|
|||||||
|
|
||||||
pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Option<Message> {
|
pub async fn user_id_dm(ctx: &Context, user_id: &UserId, content: String) -> Option<Message> {
|
||||||
let data = CreateMessage::new().content(content.to_owned());
|
let data = CreateMessage::new().content(content.to_owned());
|
||||||
return match user_id.dm(ctx, data).await {
|
match user_id.dm(ctx, data).await {
|
||||||
Ok(message) => Some(message),
|
Ok(message) => Some(message),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
log::error!("Failed to create direct message for {content}\n{err}");
|
log::error!("Failed to create direct message for {content}\n{err}");
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn user_dm(ctx: &Context, user: &User, content: String) -> Option<Message> {
|
pub async fn user_dm(ctx: &Context, user: &User, content: String) -> Option<Message> {
|
||||||
let data = CreateMessage::new().content(content.to_owned());
|
let data = CreateMessage::new().content(content.to_owned());
|
||||||
return match user.direct_message(ctx, data).await {
|
match user.direct_message(ctx, data).await {
|
||||||
Ok(message) => Some(message),
|
Ok(message) => Some(message),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
log::error!("Failed to create direct message for {content}\n{err}");
|
log::error!("Failed to create direct message for {content}\n{err}");
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_message_response(
|
pub async fn create_message_response(
|
||||||
@@ -50,7 +50,7 @@ pub async fn create_message_response(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_modal_response(ctx: &Context, modal: &ModalInteraction) {
|
pub async fn create_modal_response(ctx: &Context, modal: &ModalInteraction) {
|
||||||
let mut data = CreateInteractionResponseMessage::new();
|
let data = CreateInteractionResponseMessage::new();
|
||||||
let builder = CreateInteractionResponse::Message(data);
|
let builder = CreateInteractionResponse::Message(data);
|
||||||
match modal.create_response(&ctx.http, builder).await {
|
match modal.create_response(&ctx.http, builder).await {
|
||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use serenity::all::{CommandDataOption, CommandInteraction, Context, CreateCommand};
|
use serenity::all::{CommandInteraction, Context, CreateCommand};
|
||||||
use crate::bot::chat::create_message_response;
|
use crate::bot::chat::create_message_response;
|
||||||
|
|
||||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use serenity::all::{CreateInteractionResponse, Interaction, ResumedEvent};
|
use serenity::all::{Interaction, ResumedEvent};
|
||||||
use serenity::async_trait;
|
use serenity::async_trait;
|
||||||
use serenity::model::gateway::Ready;
|
use serenity::model::gateway::Ready;
|
||||||
use serenity::model::channel::Message;
|
use serenity::model::channel::Message;
|
||||||
@@ -7,7 +7,7 @@ use crate::bot::commands::chat::generate_response;
|
|||||||
use crate::bot::oai::OAI;
|
use crate::bot::oai::OAI;
|
||||||
use crate::data::guilds::GuildCache;
|
use crate::data::guilds::GuildCache;
|
||||||
use super::{commands};
|
use super::{commands};
|
||||||
use super::chat::{create_message_response, create_modal_response};
|
use super::chat::{create_modal_response};
|
||||||
|
|
||||||
pub struct BotHandler {
|
pub struct BotHandler {
|
||||||
// Open AI Config
|
// Open AI Config
|
||||||
|
|||||||
@@ -26,12 +26,12 @@ impl GuildCache {
|
|||||||
)",
|
)",
|
||||||
TABLE_NAME
|
TABLE_NAME
|
||||||
))
|
))
|
||||||
.bind(self.id)
|
.bind(self.id)
|
||||||
.bind(&self.name)
|
.bind(&self.name)
|
||||||
.bind(self.owner_id)
|
.bind(self.owner_id)
|
||||||
.bind(self.volume)
|
.bind(self.volume)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,10 +40,7 @@ impl GuildCache {
|
|||||||
let query = QueryBuilder::new(TABLE_NAME)
|
let query = QueryBuilder::new(TABLE_NAME)
|
||||||
.where_condition(Condition::is_equal("id", "$1")) // Use a placeholder
|
.where_condition(Condition::is_equal("id", "$1")) // Use a placeholder
|
||||||
.build();
|
.build();
|
||||||
let item = sqlx::query_as(&query)
|
let item = sqlx::query_as(&query).bind(id).fetch_optional(pool).await?;
|
||||||
.bind(id)
|
|
||||||
.fetch_optional(pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(item)
|
Ok(item)
|
||||||
}
|
}
|
||||||
@@ -58,12 +55,12 @@ impl GuildCache {
|
|||||||
WHERE id = $1",
|
WHERE id = $1",
|
||||||
TABLE_NAME
|
TABLE_NAME
|
||||||
))
|
))
|
||||||
.bind(self.id)
|
.bind(self.id)
|
||||||
.bind(&self.name)
|
.bind(&self.name)
|
||||||
.bind(self.owner_id)
|
.bind(self.owner_id)
|
||||||
.bind(self.volume)
|
.bind(self.volume)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -160,4 +160,4 @@ impl Condition {
|
|||||||
Condition::Group(a) => format!("({})", a.to_sql()),
|
Condition::Group(a) => format!("({})", a.to_sql()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
use std::env;
|
use std::env;
|
||||||
use std::collections::HashSet;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use serenity::http::Http;
|
use serenity::http::Http;
|
||||||
use serenity::prelude::*;
|
use serenity::prelude::*;
|
||||||
@@ -72,7 +71,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
cache: Arc::clone(&client.cache),
|
cache: Arc::clone(&client.cache),
|
||||||
};
|
};
|
||||||
|
|
||||||
log::debug!("Starting Siren with ID: {bot_id} (Contact: {:?})", bot_owner);
|
log::debug!(
|
||||||
|
"Starting Siren with ID: {bot_id} (Contact: {:?})",
|
||||||
|
bot_owner
|
||||||
|
);
|
||||||
|
|
||||||
// Spawn shutdown signal handling
|
// Spawn shutdown signal handling
|
||||||
let shard_manager = Arc::clone(&client.shard_manager);
|
let shard_manager = Arc::clone(&client.shard_manager);
|
||||||
|
|||||||
Reference in New Issue
Block a user