Merge pull request #12 from bensherriff/develop

v0.2.6
This commit is contained in:
Ben Sherriff
2023-10-19 16:56:11 -04:00
committed by GitHub
54 changed files with 2516 additions and 515 deletions

View File

@@ -1,5 +1,6 @@
{
"rust-analyzer.linkedProjects": [
"./service/Cargo.toml"
]
],
"rust-analyzer.showUnlinkedFileNotification": false
}

View File

@@ -6,6 +6,23 @@ DATABASE_NAME=siren
DATABASE_HOST=localhost
DATABASE_PORT=5432
ACCESS_TOKEN_PRIVATE_KEY=
ACCESS_TOKEN_PUBLIC_KEY=
ACCESS_TOKEN_MAXAGE=5
REFRESH_TOKEN_PRIVATE_KEY=
REFRESH_TOKEN_PUBLIC_KEY=
REFRESH_TOKEN_MAXAGE=30
REDIS_HOST=localhost
REDIS_PORT=6379
MINIO_ROOT_USER=siren
MINIO_ROOT_PASSWORD=
MINIO_HOST=localhost
MINIO_PORT=9000
MINIO_PORT_INTERNAL=9001
SERVICE_HOST=localhost
SERVICE_PORT=5000
DATA_DIR_PATH=

View File

@@ -1 +1 @@
SIREN_VERSION=0.2.5
SIREN_VERSION=0.2.6

View File

@@ -1,6 +1,6 @@
[package]
name = "service"
version = "0.2.5"
version = "0.2.6"
edition = "2021"
authors = ["Ben Sherriff <hello@bensherriff.com>"]
repository = "https://github.com/bensherriff/siren"
@@ -13,7 +13,6 @@ path = "src/lib.rs"
[dependencies]
actix-web = "4.4.0"
actix-rt = "2.9.0"
actix-cors = "0.6.4"
actix-web-httpauth = "0.8.1"
chrono = { version = "0.4.31", features = ["serde"] }
@@ -25,6 +24,11 @@ diesel_migrations = { version = "2.1.0", features = ["postgres"] }
r2d2 = "0.8.10"
lazy_static = "1.4.0"
uuid = { version = "1.4.1", features = ["serde", "v4"] }
argon2 = "0.5.2"
jsonwebtoken = "9.0.0"
redis = { version = "0.23.3", features = ["tokio-comp", "connection-manager", "r2d2"] }
base64 = "0.21.4"
rust-s3 = "0.33.0"
[dependencies.tokio]
version = "1.32.0"

View File

@@ -17,20 +17,41 @@ RUN cargo build --release
FROM debian:bookworm-slim as packages
WORKDIR /packages
ARG TARGETPLATFORM
RUN apt-get update && apt-get install -y curl tar xz-utils && \
curl -L https://github.com/yt-dlp/yt-dlp/releases/latest/download/yt-dlp_linux > yt-dlp && \
chmod +x yt-dlp && \
curl -L https://github.com/yt-dlp/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz > ffmpeg.tar.xz && \
tar -xJf ffmpeg.tar.xz --wildcards */bin/ffmpeg --transform='s/^.*\///' && rm ffmpeg.tar.xz
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
echo "Unsupported platform: amd64" && false; \
elif [ "$TARGETPLATFORM" = "linux/arm/v7" ]; then \
curl -L https://github.com/yt-dlp/yt-dlp/releases/latest/download/yt-dlp_linux_armv7l > yt-dlp && \
chmod +x yt-dlp; \
elif [ "$TARGETPLATFORM" = "linux/aarch64" ]; then \
curl -L https://github.com/yt-dlp/yt-dlp/releases/latest/download/yt-dlp_linux_aarch64 > yt-dlp && \
chmod +x yt-dlp; \
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
curl -L https://github.com/yt-dlp/yt-dlp/releases/latest/download/yt-dlp_linux_aarch64 > yt-dlp && \
chmod +x yt-dlp && \
curl -L https://github.com/yt-dlp/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linuxarm64-gpl.tar.xz > ffmpeg.tar.xz && \
tar -xJf ffmpeg.tar.xz --wildcards */bin/ffmpeg --transform='s/^.*\///' && rm ffmpeg.tar.xz; \
elif [ "$TARGETPLATFORM" = "linux/x86_64" ]; then \
curl -L https://github.com/yt-dlp/yt-dlp/releases/latest/download/yt-dlp_linux > yt-dlp && \
chmod +x yt-dlp && \
curl -L https://github.com/yt-dlp/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz > ffmpeg.tar.xz && \
tar -xJf ffmpeg.tar.xz --wildcards */bin/ffmpeg --transform='s/^.*\///' && rm ffmpeg.tar.xz; \
else \
echo "Unsupported platform: $TARGETPLATFORM" && false; \
fi
# =========
# Runtime
# =========
FROM rust:bookworm as runtime
FROM debian:bookworm-slim as runtime
WORKDIR /service
USER root
COPY --from=builder /builder/target/release/service /usr/local/bin/service
COPY --from=packages /packages /usr/bin
RUN apt-get update && apt-get install -y libc6 libc6-dev libopus-dev libpq5 libpq-dev python3-pip ffmpeg
CMD ["service"]

View File

@@ -13,8 +13,10 @@ help: ## Help command
build: ## Build the docker image
docker compose build
db: ## Start the docker database
utils: ## Start the utils
docker compose up -d db
docker compose up -d redis
docker compose up -d minio
up: ## Start the app
docker compose up -d

View File

@@ -1,5 +1,8 @@
version: '3.8'
x-env_file_personifi: &env
- .env
name: siren
services:
service:
@@ -10,11 +13,14 @@ services:
dockerfile: ./Dockerfile
args:
- VERSION=${SIREN_VERSION:-latest}
env_file:
- .env
env_file: *env
environment:
DATABASE_HOST: db
DATABASE_PORT: 5432
REDIS_HOST: redis
REDIS_PORT: 6379
MINIO_HOST: minio
MINIO_PORT: 9000
SERVICE_HOST: service
SERVICE_PORT: 5000
DATA_DIR_PATH: /data
@@ -31,8 +37,7 @@ services:
db:
image: postgres:latest
container_name: siren-db
env_file:
- .env
env_file: *env
environment:
POSTGRES_USER: ${DATABASE_USER}
POSTGRES_PASSWORD: ${DATABASE_PASSWORD}
@@ -45,10 +50,34 @@ services:
networks:
- backend
restart: unless-stopped
redis:
image: redis:latest
container_name: siren-redis
ports:
- ${REDIS_PORT:-6379}:6379
networks:
- backend
restart: unless-stopped
minio:
image: minio/minio
container_name: siren-minio
environment:
MINIO_ROOT_USER: ${MINIO_ROOT_USER}
MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD}
volumes:
- minio:/data
ports:
- ${MINIO_PORT:-9000}:9000
- ${MINIO_PORT_INTERNAL:-9001}:9001
networks:
- backend
command: server --console-address ":9001" /data
restart: unless-stopped
volumes:
db:
db_logs:
minio:
networks:
frontend:

View File

@@ -0,0 +1 @@
DROP TABLE guilds;

View File

@@ -0,0 +1,5 @@
CREATE TABLE IF NOT EXISTS guilds (
id BIGINT PRIMARY KEY NOT NULL,
bot_id BIGINT NOT NULL,
volume INTEGER NOT NULL
);

View File

@@ -0,0 +1 @@
DROP TABLE users;

View File

@@ -0,0 +1,8 @@
CREATE TABLE IF NOT EXISTS users (
email TEXT PRIMARY KEY NOT NULL,
hash TEXT NOT NULL,
role TEXT NOT NULL,
first_name TEXT NOT NULL,
last_name TEXT NOT NULL,
verified BOOLEAN NOT NULL DEFAULT FALSE
);

98
service/src/auth/mod.rs Normal file
View File

@@ -0,0 +1,98 @@
use std::env;
use argon2::{password_hash::{rand_core::OsRng, PasswordHasher, PasswordVerifier, SaltString, Error as HashError}, Argon2, PasswordHash};
use base64::{engine::general_purpose, Engine as _};
use jsonwebtoken::{DecodingKey, EncodingKey, Header, encode, decode, Validation, Algorithm};
use serde::{Deserialize, Serialize};
mod model;
mod routes;
pub use model::*;
pub use routes::init_routes;
use siren::ServiceError;
#[derive(Debug, Serialize, Deserialize)]
struct TokenClaims {
sub: String, // Subject
token_uuid: String, // Issuer
exp: i64, // Expiration time
iat: i64, // Issued At
nbf: i64 // Not Before
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenDetails {
pub token: Option<String>,
pub token_uuid: uuid::Uuid,
pub email: String,
pub expires_in: Option<i64>
}
// https://codevoweb.com/rust-actix-web-jwt-access-and-refresh-tokens/
// https://github.com/wpcodevo/rust-jwt-rs256/blob/master/src/main.rs
pub fn verify_token(token: &str, public_key: &str) -> Result<TokenDetails, ServiceError> {
let bytes_public_key = general_purpose::STANDARD.decode(public_key).unwrap();
let decoded_public_key = String::from_utf8(bytes_public_key).unwrap();
let key = DecodingKey::from_rsa_pem(decoded_public_key.as_bytes())?;
let validation = Validation::new(Algorithm::RS256);
let decoded = decode::<TokenClaims>(token, &key, &validation)?;
let email = decoded.claims.sub;
let token_uuid = uuid::Uuid::parse_str(decoded.claims.token_uuid.as_str()).unwrap();
Ok(TokenDetails { token: None, token_uuid, email, expires_in: None })
}
pub fn generate_access_token(email: &str) -> Result<TokenDetails, ServiceError> {
let access_token_max_age = env::var("ACCESS_TOKEN_MAXAGE")
.expect("ACCESS_TOKEN_MAXAGE must be set")
.parse::<i64>()
.expect("ACCESS_TOKEN_MAXAGE must be an integer");
let access_private_key = env::var("ACCESS_TOKEN_PRIVATE_KEY")
.expect("ACCESS_TOKEN_PRIVATE_KEY must be set");
generate_token(&email, access_token_max_age, &access_private_key)
}
pub fn generate_refresh_token(email: &str) -> Result<TokenDetails, ServiceError> {
let refresh_token_max_age = env::var("REFRESH_TOKEN_MAXAGE")
.expect("REFRESH_TOKEN_MAXAGE must be set")
.parse::<i64>()
.expect("REFRESH_TOKEN_MAXAGE must be an integer");
let refresh_private_key = env::var("REFRESH_TOKEN_PRIVATE_KEY")
.expect("REFRESH_TOKEN_PRIVATE_KEY must be set");
generate_token(&email, refresh_token_max_age, &refresh_private_key)
}
pub fn generate_token(email: &str, ttl: i64, private_key: &str) -> Result<TokenDetails, ServiceError> {
let now = chrono::Utc::now();
let mut token_details = TokenDetails {
token: None,
token_uuid: uuid::Uuid::new_v4(),
email: email.to_string(),
expires_in: Some((now + chrono::Duration::minutes(ttl)).timestamp())
};
let claims = TokenClaims {
sub: token_details.email.clone(),
token_uuid: token_details.token_uuid.to_string(),
exp: token_details.expires_in.unwrap(),
iat: now.timestamp(),
nbf: now.timestamp()
};
let header = Header::new(Algorithm::RS256);
let bytes_private_key = general_purpose::STANDARD.decode(private_key).unwrap();
let decoded_private_key = String::from_utf8(bytes_private_key).unwrap();
let key = EncodingKey::from_rsa_pem(decoded_private_key.as_bytes())?;
let token = encode(&header, &claims, &key)?;
token_details.token = Some(token);
Ok(token_details)
}
pub fn hash_password(password: &[u8]) -> Result<String, HashError> {
let salt = SaltString::generate(&mut OsRng);
Ok(Argon2::default().hash_password(password, &salt)?.to_string())
}
pub fn verify_password(hash: &str, password: &[u8]) -> Result<(), HashError> {
let parsed_hash = PasswordHash::new(hash)?;
Ok(Argon2::default().verify_password(password, &parsed_hash)?)
}

185
service/src/auth/model.rs Normal file
View File

@@ -0,0 +1,185 @@
use std::{future::{ready, Ready}, env};
use actix_web::{FromRequest, Error as ActixError, HttpRequest, dev::Payload, http};
use diesel::prelude::*;
use log::error;
use redis::Commands;
use serde::{Serialize, Deserialize};
use siren::ServiceError;
use crate::db::schema::users;
use super::{hash_password, verify_token};
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterUser {
pub email: String,
pub password: String,
pub first_name: String,
pub last_name: String,
}
impl RegisterUser {
pub fn convert_to_insert(self) -> Result<InsertUser, ServiceError> {
let hash = hash_password(self.password.as_bytes())?;
Ok(InsertUser {
email: self.email.to_lowercase(),
hash,
role: "user".to_string(),
first_name: self.first_name,
last_name: self.last_name,
verified: false,
})
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginRequest {
pub email: String,
pub password: String,
}
#[derive(Debug, Queryable, QueryableByName, Serialize, Deserialize)]
#[diesel(table_name = users)]
pub struct QueryUser {
pub email: String,
pub hash: String,
pub role: String,
pub first_name: String,
pub last_name: String,
pub verified: bool,
}
impl QueryUser {
pub fn get_by_email(email: &str) -> Result<QueryUser, ServiceError> {
let mut conn = crate::db::connection()?;
// Check if the user exists by email, case insensitive
let user = users::table
.filter(users::email.eq(email.to_lowercase()))
.first(&mut conn)?;
Ok(user)
}
}
#[derive(Debug, Insertable, AsChangeset, Serialize, Deserialize)]
#[diesel(table_name = users)]
pub struct InsertUser {
pub email: String,
pub hash: String,
pub role: String,
pub first_name: String,
pub last_name: String,
pub verified: bool,
}
impl InsertUser {
pub fn insert(user: Self) -> Result<QueryUser, ServiceError> {
let mut conn = crate::db::connection()?;
let user = diesel::insert_into(users::table)
.values(user)
.get_result(&mut conn)?;
Ok(user)
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseUser {
pub email: String,
pub role: String,
pub first_name: String,
pub last_name: String,
}
impl From<QueryUser> for ResponseUser {
fn from(user: QueryUser) -> Self {
ResponseUser {
email: user.email,
role: user.role,
first_name: user.first_name,
last_name: user.last_name,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtAuth {
pub token: uuid::Uuid,
pub user: ResponseUser
}
impl FromRequest for JwtAuth {
type Error = ActixError;
type Future = Ready<Result<Self, Self::Error>>;
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
let access_token = match req
.cookie("access_token")
.map(|c| c.value().to_string())
.or_else(|| {
req.headers().get(http::header::AUTHORIZATION)
.map(|h| h.to_str().unwrap().split_at(7).1.to_string())
}) {
Some(token) => token,
None => return ready(Err(ActixError::from(ServiceError {
status: 401,
message: "Unauthorized".to_string()
})))
};
let public_key = env::var("ACCESS_TOKEN_PUBLIC_KEY")
.expect("ACCESS_TOKEN_PUBLIC_KEY must be set");
let access_token_details = match verify_token(&access_token, &public_key) {
Ok(token_details) => token_details,
Err(err) => {
error!("Failed to verify access token: {}", err);
return ready(Err(ActixError::from(ServiceError {
status: 401,
message: format!("Failed to verify access token: {}", err)
})))
}
};
let access_token_uuid = uuid::Uuid::parse_str(&access_token_details.token_uuid.to_string()).unwrap();
let mut conn = match crate::db::redis_connection() {
Ok(conn) => conn,
Err(err) => {
error!("Failed to get redis connection: {}", err);
return ready(Err(ActixError::from(ServiceError {
status: 500,
message: format!("Failed to get redis connection: {}", err)
})))
}
};
let user_email = match conn.get::<_, String>(access_token_uuid.clone().to_string()) {
Ok(result) => result,
Err(_) => {
return ready(Err(ActixError::from(ServiceError {
status: 401,
message: format!("Access token was not found")
})))
}
};
match QueryUser::get_by_email(&user_email) {
Ok(user) => {
ready(Ok(JwtAuth { token: access_token_uuid, user: user.into() }))
}
Err(_) => return ready(Err(ActixError::from(ServiceError {
status: 401,
message: format!("User was not found")
})))
}
}
}
pub fn verify_role(auth: &JwtAuth, role: &str) -> Result<(), ServiceError> {
if auth.user.role == role {
Ok(())
} else {
Err(ServiceError {
status: 403,
message: "Forbidden".to_string()
})
}
}

367
service/src/auth/routes.rs Normal file
View File

@@ -0,0 +1,367 @@
use std::env;
use actix_web::{get, post, web, HttpResponse, ResponseError, cookie::{Cookie, time::Duration}, HttpRequest};
use log::error;
use redis::AsyncCommands;
use serde::{Serialize, Deserialize};
use siren::ServiceError;
use crate::{auth::{LoginRequest, RegisterUser, InsertUser, QueryUser, verify_password, JwtAuth, verify_token, generate_access_token, generate_refresh_token}, db};
#[post("/register")]
async fn register(user: web::Json<RegisterUser>) -> HttpResponse {
let register_user = user.0;
let insert_user: InsertUser = match register_user.convert_to_insert() {
Ok(user) => user,
Err(err) => return ResponseError::error_response(&err)
};
match InsertUser::insert(insert_user) {
Ok(_) => {
HttpResponse::Created().finish()
},
Err(err) => {
// Obfuscate the service error message to prevent leaking database details
if err.status == 409 {
return HttpResponse::Conflict().finish();
} else {
return ResponseError::error_response(&err);
}
}
}
}
#[post("/login")]
async fn login(request: web::Json<LoginRequest>) -> HttpResponse {
let email = request.email.clone();
let query_user = match QueryUser::get_by_email(&email) {
Ok(query_user) => query_user,
Err(err) => return ResponseError::error_response(&err)
};
let hash = &query_user.hash;
let password = request.password.as_bytes();
match verify_password(hash, password) {
Ok(_) => {
let access_token_details = match generate_access_token(&email) {
Ok(token_details) => token_details,
Err(err) => {
error!("Failed to generate access token: {}", err);
return ResponseError::error_response(&err)
}
};
let refresh_token_details = match generate_refresh_token(&email) {
Ok(token_details) => token_details,
Err(err) => {
error!("Failed to generate refresh token: {}", err);
return ResponseError::error_response(&err)
}
};
let mut conn = match db::redis_async_connection().await {
Ok(conn) => conn,
Err(err) => {
error!("Failed to get redis connection: {}", err);
return ResponseError::error_response(&err)
}
};
let access_token_max_age = env::var("ACCESS_TOKEN_MAXAGE")
.expect("ACCESS_TOKEN_MAXAGE must be set")
.parse::<i64>()
.expect("ACCESS_TOKEN_MAXAGE must be an integer");
let refresh_token_max_age = env::var("REFRESH_TOKEN_MAXAGE")
.expect("REFRESH_TOKEN_MAXAGE must be set")
.parse::<i64>()
.expect("REFRESH_TOKEN_MAXAGE must be an integer");
let access_result: redis::RedisResult<()> = conn.set_ex(access_token_details.token_uuid.to_string(), &email, (access_token_max_age * 60) as usize).await;
if let Err(err) = access_result {
error!("Failed to set access token in redis: {}", err);
return ResponseError::error_response(&ServiceError {
status: 500,
message: format!("Failed to set access token in redis: {}", err)
})
};
let refresh_result: redis::RedisResult<()> = conn.set_ex(refresh_token_details.token_uuid.to_string(), &email, (refresh_token_max_age * 60) as usize).await;
if let Err(err) = refresh_result {
error!("Failed to set refresh token in redis: {}", err);
return ResponseError::error_response(&ServiceError {
status: 500,
message: format!("Failed to set refresh token in redis: {}", err)
})
};
let access_cookie = Cookie::build("access_token", access_token_details.token.clone().unwrap())
.path("/")
.max_age(Duration::new(access_token_max_age * 60, 0))
.http_only(true)
.secure(true)
.finish();
let refresh_cookie = Cookie::build("refresh_token", refresh_token_details.token.clone().unwrap())
.path("/")
.max_age(Duration::new(refresh_token_max_age * 60, 0))
.http_only(true)
.secure(true)
.finish();
let logged_in_cookie = Cookie::build("logged_in", "true")
.path("/")
.max_age(Duration::new(access_token_max_age * 60, 0))
.http_only(false)
.finish();
let access_token_uuid = uuid::Uuid::parse_str(&access_token_details.token_uuid.to_string()).unwrap();
HttpResponse::Ok()
.cookie(access_cookie)
.cookie(refresh_cookie)
.cookie(logged_in_cookie)
.json(JwtAuth { token: access_token_uuid, user: query_user.into() })
},
Err(err) => ResponseError::error_response(&ServiceError {
status: 401,
message: err.to_string()
})
}
}
#[derive(Serialize, Deserialize)]
struct RefreshParams {
refresh_token_rotation: Option<bool>
}
#[get("/refresh")]
async fn refresh(req: HttpRequest) -> HttpResponse {
let params = match web::Query::<RefreshParams>::from_query(req.query_string()) {
Ok(params) => params,
Err(err) => return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
};
let refresh_token = match req.cookie("refresh_token") {
Some(cookie) => cookie.value().to_string(),
None => return ResponseError::error_response(&ServiceError {
status: 401,
message: "Refresh token not found".to_string()
})
};
let public_key = env::var("REFRESH_TOKEN_PUBLIC_KEY")
.expect("REFRESH_TOKEN_PUBLIC_KEY must be set");
let refresh_token_details = match verify_token(&refresh_token, &public_key) {
Ok(token_details) => token_details,
Err(err) => return ResponseError::error_response(&err)
};
let email = refresh_token_details.email.clone();
match QueryUser::get_by_email(&email) {
Ok(query_user) => {
let access_token_details = match generate_access_token(&email) {
Ok(token_details) => token_details,
Err(err) => {
error!("Failed to generate access token: {}", err);
return ResponseError::error_response(&err)
}
};
let mut conn = match db::redis_async_connection().await {
Ok(conn) => conn,
Err(err) => {
error!("Failed to get redis connection: {}", err);
return ResponseError::error_response(&err)
}
};
// Delete old auth token if it exists
match req.cookie("access_token") {
Some(cookie) => {
let access_token = cookie.value().to_string();
let public_key = env::var("ACCESS_TOKEN_PUBLIC_KEY")
.expect("ACCESS_TOKEN_PUBLIC_KEY must be set");
match verify_token(&access_token, &public_key) {
Ok(token_details) => {
let _: redis::RedisResult<()> = conn.del(token_details.token_uuid.to_string()).await;
},
Err(_) => {}
};
},
None => {}
};
let access_token_max_age = env::var("ACCESS_TOKEN_MAXAGE")
.expect("ACCESS_TOKEN_MAXAGE must be set")
.parse::<i64>()
.expect("ACCESS_TOKEN_MAXAGE must be an integer");
let access_result: redis::RedisResult<()> = conn.set_ex(access_token_details.token_uuid.to_string(), &email, (access_token_max_age * 60) as usize).await;
if let Err(err) = access_result {
error!("Failed to set access token in redis: {}", err);
return ResponseError::error_response(&ServiceError {
status: 500,
message: format!("Failed to set access token in redis: {}", err)
})
};
let access_cookie = Cookie::build("access_token", access_token_details.token.clone().unwrap())
.path("/")
.max_age(Duration::new(access_token_max_age * 60, 0))
.http_only(true)
.secure(true)
.finish();
let logged_in_cookie = Cookie::build("logged_in", "true")
.path("/")
.max_age(Duration::new(access_token_max_age * 60, 0))
.http_only(false)
.finish();
let access_token_uuid = uuid::Uuid::parse_str(&access_token_details.token_uuid.to_string()).unwrap();
// Refresh the refresh token if requested
let refresh_token_rotation = match params.refresh_token_rotation {
Some(refresh_token_rotation) => refresh_token_rotation,
None => false
};
if refresh_token_rotation {
// Delete the old refresh token
let _: redis::RedisResult<()> = conn.del(refresh_token_details.token_uuid.to_string()).await;
let refresh_token_details = match generate_refresh_token(&refresh_token_details.email) {
Ok(token_details) => token_details,
Err(err) => {
error!("Failed to generate refresh token: {}", err);
return ResponseError::error_response(&err)
}
};
let refresh_token_max_age = env::var("REFRESH_TOKEN_MAXAGE")
.expect("REFRESH_TOKEN_MAXAGE must be set")
.parse::<i64>()
.expect("REFRESH_TOKEN_MAXAGE must be an integer");
let refresh_result: redis::RedisResult<()> = conn.set_ex(refresh_token_details.token_uuid.to_string(), &refresh_token_details.email, (refresh_token_max_age * 60) as usize).await;
if let Err(err) = refresh_result {
error!("Failed to set refresh token in redis: {}", err);
return ResponseError::error_response(&ServiceError {
status: 500,
message: format!("Failed to set refresh token in redis: {}", err)
})
};
let refresh_cookie = Cookie::build("refresh_token", refresh_token_details.token.clone().unwrap())
.path("/")
.max_age(Duration::new(refresh_token_max_age * 60, 0))
.http_only(true)
.secure(true)
.finish();
HttpResponse::Ok()
.cookie(refresh_cookie)
.cookie(access_cookie)
.cookie(logged_in_cookie)
.json(JwtAuth { token: access_token_uuid, user: query_user.into() })
} else {
HttpResponse::Ok()
.cookie(access_cookie)
.cookie(logged_in_cookie)
.json(JwtAuth { token: access_token_uuid, user: query_user.into() })
}
},
Err(err) => return ResponseError::error_response(&err)
}
}
#[post("/logout")]
async fn logout(req: HttpRequest, auth: JwtAuth) -> HttpResponse {
let refresh_token = match req.cookie("refresh_token") {
Some(cookie) => cookie.value().to_string(),
None => return ResponseError::error_response(&ServiceError {
status: 401,
message: "Refresh token not found".to_string()
})
};
let public_key = env::var("REFRESH_TOKEN_PUBLIC_KEY")
.expect("REFRESH_TOKEN_PUBLIC_KEY must be set");
let refresh_token_details = match verify_token(&refresh_token, &public_key) {
Ok(token_details) => token_details,
Err(err) => return ResponseError::error_response(&err)
};
let mut conn = match db::redis_async_connection().await {
Ok(conn) => conn,
Err(err) => {
error!("Failed to get redis connection: {}", err);
return ResponseError::error_response(&err)
}
};
let access_result: redis::RedisResult<()> = conn.del(&[
refresh_token_details.token_uuid.to_string(),
auth.token.to_string()
]).await;
if let Err(err) = access_result {
error!("Failed to set access token in redis: {}", err);
return ResponseError::error_response(&ServiceError {
status: 500,
message: format!("Failed to set access token in redis: {}", err)
})
};
let access_cookie = Cookie::build("access_token", "")
.path("/")
.max_age(Duration::new(-1, 0))
.http_only(true)
.finish();
let refresh_cookie = Cookie::build("refresh_token", "")
.path("/")
.max_age(Duration::new(-1, 0))
.http_only(true)
.finish();
let logged_in_cookie = Cookie::build("logged_in", "")
.path("/")
.max_age(Duration::new(-1, 0))
.http_only(true)
.finish();
HttpResponse::Ok()
.cookie(access_cookie)
.cookie(refresh_cookie)
.cookie(logged_in_cookie)
.finish()
}
#[get("/me")]
async fn me(auth: JwtAuth) -> HttpResponse {
HttpResponse::Ok().json(auth)
}
#[get("/roles")]
async fn roles() -> HttpResponse {
HttpResponse::Ok().json(vec!["admin", "user"])
}
pub fn init_routes(config: &mut web::ServiceConfig) {
let r = RegisterUser {
email: "admin".to_string(),
password: "admin".to_string(),
first_name: "Admin".to_string(),
last_name: "Admin".to_string(),
};
let mut u = r.convert_to_insert().unwrap();
u.role = "admin".to_string();
u.verified = true;
let _ = InsertUser::insert(u);
config.service(web::scope("auth")
.service(register)
.service(login)
.service(refresh)
.service(logout)
.service(me)
.service(roles)
);
}

View File

@@ -0,0 +1,5 @@
mod model;
mod routes;
pub use model::*;
pub use routes::init_routes;

View File

View File

@@ -0,0 +1,396 @@
use std::{sync::Arc, pin::Pin};
use actix_web::{get, post, web, HttpResponse, ResponseError};
use log::warn;
use serde::{Serialize, Deserialize};
use serenity::model::prelude::{GuildChannel, ChannelType};
use siren::ServiceError;
use crate::{AppState, bot::commands::audio::{play::play_track, join}, db::guilds::QueryGuild, auth::{JwtAuth, verify_role}};
#[get("/guilds")]
async fn get_guilds(data: web::Data<Arc<AppState>>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let guild_results = &data.http.get_guilds(None, None).await;
let guilds = match guild_results {
Ok(guilds) => guilds,
Err(err) => return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
};
HttpResponse::Ok().json(guilds)
}
#[get("/{id}/text")]
async fn get_text_channels(id: web::Path<String>, data: web::Data<Arc<AppState>>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let channel_results = &data.http.get_channels(id.parse::<u64>().unwrap()).await;
let channels = match channel_results {
Ok(channels) => channels.iter().filter(|c| c.kind == ChannelType::Text).collect::<Vec<&GuildChannel>>(),
Err(err) => return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
};
HttpResponse::Ok().json(channels)
}
#[get("/{id}/voice")]
async fn get_voice_channels(id: web::Path<String>, data: web::Data<Arc<AppState>>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let channel_results = &data.http.get_channels(id.parse::<u64>().unwrap()).await;
let channels = match channel_results {
Ok(channels) => channels.iter().filter(|c| c.kind == ChannelType::Voice).collect::<Vec<&GuildChannel>>(),
Err(err) => return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
};
HttpResponse::Ok().json(channels)
}
#[derive(Serialize, Deserialize)]
struct ChannelMessage {
message: String
}
#[post("/{guild_id}/text/{channel_id}/message")]
async fn send_message(path: web::Path<(String, String)>, text: web::Json<ChannelMessage>, data: web::Data<Arc<AppState>>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let (guild_id, channel_id) = path.into_inner();
let guild_id = match guild_id.parse::<u64>() {
Ok(id) => id,
Err(err) => {
warn!("Could not parse guild id: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
};
let channel_id = match channel_id.parse::<u64>() {
Ok(id) => id,
Err(err) => {
warn!("Could not parse channel id: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
};
let channel_results = &data.http.get_channels(guild_id).await;
let channels = match channel_results {
Ok(channels) => channels,
Err(err) => {
warn!("Could not get channels: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
};
let channel = match channels.iter().find(|c| c.id.0 == channel_id) {
Some(channel) => channel,
None => {
warn!("Could not find channel with id {}", channel_id);
return ResponseError::error_response(&ServiceError {
status: 422,
message: format!("Could not find channel with id {}", channel_id)
})
}
};
if let Err(err) = channel.say(&Pin::new(&data.http).get_ref(), &text.message).await {
warn!("Could not send message: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
};
HttpResponse::Ok().finish()
}
#[derive(Serialize, Deserialize)]
struct PlayRequest {
track_url: String
}
#[post("/{guild_id}/voice/{channel_id}/play")]
async fn play(path: web::Path<(String, String)>, play_request: web::Json<PlayRequest>, data: web::Data<Arc<AppState>>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let (guild_id, channel_id) = path.into_inner();
let guild_id = match guild_id.parse::<u64>() {
Ok(id) => id,
Err(err) => {
warn!("Could not parse guild id: {:?}", err);
return ResponseError::error_response(&ServiceError { status: 422, message: err.to_string() })
}
};
let channel_id = match channel_id.parse::<u64>() {
Ok(id) => id,
Err(err) => {
warn!("Could not parse channel id: {:?}", err);
return ResponseError::error_response(&ServiceError { status: 422, message: err.to_string() })
}
};
let http = Pin::new(&data.http).get_ref();
let guild = match http.get_guild(guild_id).await {
Ok(guild) => guild,
Err(err) => {
warn!("Could not get guild: {:?}", err);
return ResponseError::error_response(&ServiceError { status: 422, message: err.to_string() })
}
};
let channel = match http.get_channel(channel_id).await {
Ok(channel) => channel,
Err(err) => {
warn!("Could not get channel: {:?}", err);
return ResponseError::error_response(&ServiceError { status: 422, message: err.to_string() })
}
};
let manager = Arc::clone(&data.songbird);
match join(Arc::clone(&manager), &guild.id, &channel.id()).await {
Ok(_) => {
match play_track(Arc::clone(&data.songbird), guild.id, play_request.track_url.to_string()).await {
Ok(_) => HttpResponse::Ok().finish(),
Err(err) => {
warn!("Could not play track: {:?}", err);
return ResponseError::error_response(&err)
}
}
},
Err(err) => {
warn!("Could not join channel: {:?}", err);
return ResponseError::error_response(&ServiceError { status: 500, message: err.to_string() })
}
}
}
#[post("/{guild_id}/voice/stop")]
async fn stop(path: web::Path<String>, data: web::Data<Arc<AppState>>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let guild_id = path.into_inner();
let guild_id = match guild_id.parse::<u64>() {
Ok(id) => id,
Err(err) => {
warn!("Could not parse guild id: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
};
if let Some(handler_lock) = data.songbird.get(guild_id) {
let handler = handler_lock.lock().await;
handler.queue().stop();
}
HttpResponse::Ok().finish()
}
#[post("/{guild_id}/voice/resume")]
async fn resume(path: web::Path<String>, data: web::Data<Arc<AppState>>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let guild_id = path.into_inner();
let guild_id = match guild_id.parse::<u64>() {
Ok(id) => id,
Err(err) => {
warn!("Could not parse guild id: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
};
if let Some(handler_lock) = data.songbird.get(guild_id) {
let handler = handler_lock.lock().await;
if let Err(err) = handler.queue().resume() {
warn!("Could not resume track: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
}
HttpResponse::Ok().finish()
}
#[post("/{guild_id}/voice/pause")]
async fn pause(path: web::Path<String>, data: web::Data<Arc<AppState>>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let guild_id = path.into_inner();
let guild_id = match guild_id.parse::<u64>() {
Ok(id) => id,
Err(err) => {
warn!("Could not parse guild id: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
};
if let Some(handler_lock) = data.songbird.get(guild_id) {
let handler = handler_lock.lock().await;
if let Err(err) = handler.queue().pause() {
warn!("Could not pause track: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
}
HttpResponse::Ok().finish()
}
#[derive(Serialize, Deserialize)]
struct SetVolume {
volume: String
}
#[get("/{guild_id}/voice/volume")]
async fn get_volume(path: web::Path<String>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let guild_id = path.into_inner();
let guild_id = match guild_id.parse::<u64>() {
Ok(id) => id,
Err(err) => {
warn!("Could not parse guild id: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
};
let volume = match QueryGuild::get(guild_id as i64) {
Ok(guild) => guild.volume,
Err(err) => {
warn!("Could not get volume: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
};
HttpResponse::Ok().json(volume)
}
#[post("/{guild_id}/voice/volume")]
async fn set_volume(path: web::Path<String>, volume: web::Json::<SetVolume>, data: web::Data<Arc<AppState>>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let guild_id = path.into_inner();
let guild_id = match guild_id.parse::<u64>() {
Ok(id) => id,
Err(err) => {
warn!("Could not parse guild id: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
};
let volume = volume.volume.parse::<i32>().unwrap_or(0);
let manager = Arc::clone(&data.songbird);
let http = Arc::clone(&data.http);
let guild = match http.get_guild(guild_id).await {
Ok(guild) => guild,
Err(err) => {
warn!("Could not get guild: {:?}", err);
return ResponseError::error_response(&ServiceError { status: 422, message: err.to_string() })
}
};
crate::bot::commands::audio::volume::set_volume(manager, guild.id, volume).await;
HttpResponse::Ok().finish()
}
#[post("/{guild_id}/voice/skip")]
async fn skip(path: web::Path<String>, data: web::Data<Arc<AppState>>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let guild_id = path.into_inner();
let guild_id = match guild_id.parse::<u64>() {
Ok(id) => id,
Err(err) => {
warn!("Could not parse guild id: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
};
if let Some(handler_lock) = data.songbird.get(guild_id) {
let handler = handler_lock.lock().await;
if let Err(err) = handler.queue().skip() {
warn!("Could not skip track: {:?}", err);
return ResponseError::error_response(&ServiceError {
status: 422,
message: err.to_string()
})
}
}
HttpResponse::Ok().finish()
}
pub fn init_routes(config: &mut web::ServiceConfig) {
config
.service(get_guilds)
.service(web::scope("guilds")
.service(get_text_channels)
.service(get_voice_channels)
.service(send_message)
.service(play)
.service(stop)
.service(resume)
.service(pause)
.service(set_volume)
.service(get_volume)
.service(skip)
);
}

View File

@@ -1,12 +1,13 @@
use std::collections::HashMap;
use std::sync::Arc;
use log::debug;
use log::{debug, warn};
use serenity::client::Cache;
use serenity::model::application::interaction::{InteractionResponseType, application_command::ApplicationCommandInteraction};
use serenity::model::prelude::{GuildId, ChannelId};
use serenity::model::user::User;
use serenity::prelude::*;
use siren::ServiceError;
use songbird::{Call, Songbird};
use songbird::input::{Restartable, Input, Metadata, error::Error as SongbirdError};
@@ -17,58 +18,33 @@ pub mod skip;
pub mod stop;
pub mod volume;
#[derive(Clone, Debug)]
pub struct AudioConfigs;
impl TypeMapKey for AudioConfigs {
type Value = Arc<RwLock<HashMap<GuildId, AudioConfig>>>;
}
#[derive(Clone, Debug)]
pub struct AudioConfig {
pub volume: f32
}
/// Joins a Discord voice channel.
///
/// # Arguments
/// - ctx - The context of the command.
/// - guild_id_option - The guild ID of the guild to join.
/// - user - The user that is requesting to join the voice channel.
///
/// # Returns
/// Result<(), String> - Ok if the bot successfully joined the voice channel, Err if there was an error.
pub async fn join(ctx: &Context, guild_id_option: &Option<GuildId>, user: &User) -> Result<(), String> {
pub async fn join_by_user(cache: &Arc<Cache>, manager: Arc<Songbird>, guild_id_option: &Option<GuildId>, user: &User) -> Result<(), ServiceError> {
let guild_id = match guild_id_option {
Some(g) => g,
None => {
return Err(format!("{}", "No guild ID set"));
}
None => return Err(ServiceError { status: 422, message: format!("{}", "No guild ID set") })
};
let channel_id = match find_voice_channel(&ctx, &guild_id, &user) {
let channel_id = match find_voice_channel(cache, &guild_id, &user) {
Ok(channel) => channel,
Err(err) => return Err(format!("{}", err))
Err(err) => return Err(ServiceError { status: 500, message: err.to_string() })
};
debug!("<{}> Joining channel {}", guild_id.0, channel_id);
let manager = get_songbird(ctx).await;
join(manager, guild_id, &channel_id).await
}
pub async fn join(manager: Arc<Songbird>, guild_id: &GuildId, channel_id: &ChannelId) -> Result<(), ServiceError> {
debug!("<{}> Joining channel {}", guild_id.0, channel_id.0);
let (_handle_lock, success) = manager.join(guild_id.to_owned(), channel_id.to_owned()).await;
match success {
Ok(s) => Ok(s),
Err(err) => Err(format!("{}", err))
Err(err) => {
warn!("Failed to join channel: {:?}", err);
Err(ServiceError { status: 500, message: err.to_string() })
}
}
}
/// Leaves a Discord voice channel.
///
/// # Arguments
/// - ctx - The context of the command.
/// - guild_id_option - The guild ID of the guild to leave.
///
/// # Returns
/// Result<(), String> - Ok if the bot successfully left the voice channel, Err if there was an error.
pub async fn leave(ctx: &Context, guild_id_option: &Option<GuildId>) -> Result<(), String> {
pub async fn leave(manager: Arc<Songbird>, guild_id_option: &Option<GuildId>) -> Result<(), String> {
let guild_id = match guild_id_option {
Some(g) => g,
None => {
@@ -76,7 +52,6 @@ pub async fn leave(ctx: &Context, guild_id_option: &Option<GuildId>) -> Result<(
}
};
let manager = get_songbird(ctx).await;
if manager.get(*guild_id).is_some() {
debug!("<{}> Disconnecting from channel", guild_id.0);
if let Err(e) = manager.remove(*guild_id).await {
@@ -86,17 +61,8 @@ pub async fn leave(ctx: &Context, guild_id_option: &Option<GuildId>) -> Result<(
Ok(())
}
/// Finds the voice channel that the user is in.
///
/// # Arguments
/// - ctx - The context of the command.
/// - guild_id - The guild ID of the guild to search.
/// - user - The user to search for.
///
/// # Returns
/// Result<ChannelId, String> - Ok if the user is in a voice channel, Err if the user is not in a voice channel.
fn find_voice_channel(ctx: &Context, guild_id: &GuildId, user: &User) -> Result<ChannelId, String> {
let guild = match guild_id.to_guild_cached(ctx.cache.to_owned()) {
fn find_voice_channel(cache: &Arc<Cache>, guild_id: &GuildId, user: &User) -> Result<ChannelId, String> {
let guild = match guild_id.to_guild_cached(cache.to_owned()) {
Some(g) => g,
None => return Err(format!("Guild not found"))
};
@@ -107,15 +73,6 @@ fn find_voice_channel(ctx: &Context, guild_id: &GuildId, user: &User) -> Result<
}
}
/// Creates a response to an interaction.
///
/// # Arguments
/// - ctx - The context of the command.
/// - command - The command that was sent.
/// - content - The content of the response.
///
/// # Returns
/// Result<(), SerenityError> - Ok if the response was created successfully, Err if there was an error.
pub async fn create_response(ctx: &Context, command: &ApplicationCommandInteraction, content: String) -> Result<(), SerenityError> {
command.create_interaction_response(&ctx.http, |response: &mut serenity::builder::CreateInteractionResponse<'_>| {
response
@@ -124,31 +81,13 @@ pub async fn create_response(ctx: &Context, command: &ApplicationCommandInteract
}).await
}
/// Edits a response to an interaction.
///
/// # Arguments
/// - ctx - The context of the command.
/// - command - The command that was sent.
/// - content - The content of the response.
///
/// # Returns
/// Result<Message, SerenityError> - Ok if the response was edited successfully, Err if there was an error.
pub async fn edit_response(ctx: &Context, command: &ApplicationCommandInteraction, content: String) -> Result<serenity::model::channel::Message, SerenityError> {
command.edit_original_interaction_response(&ctx.http, |response: &mut serenity::builder::EditInteractionResponse| {
response.content(content)
}).await
}
/// Adds a song to the queue.
///
/// # Arguments
/// - call - The call to add the song to.
/// - url - The URL of the song to add.
/// - lazy - Whether or not to lazy load the song.
///
/// # Returns
/// Result<Metadata, SongbirdError> - Ok if the song was added successfully, Err if there was an error.
pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool, audio_config: Option<&AudioConfig>) -> Result<Metadata, SongbirdError> {
pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool, volume: Option<f32>) -> Result<Metadata, SongbirdError> {
let source = if is_valid_url(url) {
Restartable::ytdl(url.to_owned(), lazy).await?
} else {
@@ -158,19 +97,12 @@ pub async fn add_song(call: Arc<Mutex<Call>>, url: &str, lazy: bool, audio_confi
let track: Input = source.into();
let metadata = *track.metadata.clone();
let track_handle = handler.enqueue_source(track);
if let Some(ac) = audio_config {
let _ = track_handle.set_volume(ac.volume);
if let Some(volume) = volume {
let _ = track_handle.set_volume(volume);
}
Ok(metadata)
}
/// Checks if a string is a valid URL.
///
/// # Arguments
/// - url - The string to check.
///
/// # Returns
/// bool - True if the string is a valid URL, false if it is not.
fn is_valid_url(url: &str) -> bool {
match url.parse::<reqwest::Url>() {
Ok(_) => return true,
@@ -178,13 +110,6 @@ fn is_valid_url(url: &str) -> bool {
}
}
/// Gets the Songbird voice client.
///
/// # Arguments
/// - ctx - The context of the command.
///
/// # Returns
/// Arc<Songbird> - The Songbird voice client.
pub async fn get_songbird(ctx: &Context) -> Arc<Songbird> {
songbird::get(ctx).await.expect("Songbird Voice client placed in at initialization")
}

View File

@@ -1,13 +1,18 @@
use std::sync::Arc;
use log::{debug, warn, error};
use serenity::model::prelude::GuildId;
use serenity::{prelude::*, async_trait};
use serenity::builder::CreateApplicationCommand;
use serenity::model::application::interaction::application_command::ApplicationCommandInteraction;
use songbird::EventHandler;
use siren::ServiceError;
use songbird::{EventHandler, Songbird};
use crate::bot::commands::audio::{join, leave, add_song, get_songbird, AudioConfigs};
use crate::bot::commands::audio::{leave, add_song, get_songbird};
use crate::db::guilds::QueryGuild;
use super::{create_response, edit_response};
use super::{create_response, edit_response, join_by_user};
pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
// Get the track url
@@ -46,7 +51,8 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
return;
}
match join(&ctx, &command.guild_id, &command.user).await {
let manager = get_songbird(ctx).await;
match join_by_user(&ctx.cache, manager,&command.guild_id, &command.user).await {
Ok(_) => {
let guild_id = match command.guild_id {
Some(g) => g,
@@ -58,41 +64,20 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
}
};
debug!("Play command executed with track: {:?}", track_url);
let manager = get_songbird(ctx).await;
if let Some(handler_lock) = manager.get(guild_id) {
let is_queue_empty = {
let call_handler = handler_lock.lock().await;
call_handler.queue().is_empty()
};
let audio_config = {
let data_read = ctx.data.read().await;
data_read.get::<AudioConfigs>().expect("Expected AudioConfigs in TypeMap.").clone()
};
let ac = audio_config.read().await;
match add_song(handler_lock.clone(), &track_url, is_queue_empty, ac.get(&guild_id)).await {
Ok(added_song) => {
let track_title = added_song.title.unwrap();
debug!("Added track: {}", track_title);
if let Err(why) = edit_response(&ctx, &command, format!("Added track to queue: {}", track_title)).await {
error!("Failed to edit response message: {}", why);
}
let mut handler = handler_lock.lock().await;
handler.remove_all_global_events();
handler.add_global_event(songbird::Event::Track(songbird::TrackEvent::End), TrackEndNotifier { guild_id, call: manager })
match play_track(manager, guild_id, track_url).await {
Ok(_) => {
if let Err(why) = edit_response(&ctx, &command, "Playing track".to_string()).await {
error!("Failed to edit response message: {}", why);
}
Err(why) => {
warn!("Failed to add song: {}", why);
if let Err(why) = edit_response(&ctx, &command, format!("Failed to add song: {}", why)).await {
error!("Failed to edit response message: {}", why);
}
if let Err(why) = leave(&ctx, &command.guild_id).await {
error!("Failed to leave voice channel: {}", why);
}
return;
},
Err(err) => {
warn!("Failed to play track: {}", err);
if let Err(why) = edit_response(&ctx, &command, format!("Failed to play track: {}", err)).await {
error!("Failed to edit response message: {}", why);
}
};
}
}
};
},
Err(err) => {
warn!("{}", err);
@@ -103,6 +88,33 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
}
}
pub async fn play_track(manager: Arc<Songbird>, guild_id: GuildId, track_url: String) -> Result<(), ServiceError> {
if let Some(handler_lock) = manager.get(guild_id) {
let is_queue_empty = {
let call_handler = handler_lock.lock().await;
call_handler.queue().is_empty()
};
let guild = QueryGuild::get(guild_id.0 as i64)?;
match add_song(handler_lock.clone(), &track_url, is_queue_empty, Some(guild.volume as f32)).await {
Ok(added_song) => {
let track_title = added_song.title.unwrap();
debug!("Added track: {}", track_title);
let mut handler = handler_lock.lock().await;
handler.remove_all_global_events();
handler.add_global_event(songbird::Event::Track(songbird::TrackEvent::End), TrackEndNotifier { guild_id, call: manager })
},
Err(err) => {
warn!("Failed to add song: {}", err);
if let Err(why) = leave(manager, &Some(guild_id)).await {
error!("Failed to leave voice channel: {}", why);
}
return Err(ServiceError { status: 422, message: err.to_string() })
}
}
}
Ok(())
}
pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand {
command.name("play").description("Plays the given track").create_option(|option| { option
.name("track")

View File

@@ -1,17 +1,22 @@
use std::sync::Arc;
use log::{error, warn};
use serenity::prelude::*;
use serenity::{prelude::*, model::prelude::GuildId};
use serenity::builder::CreateApplicationCommand;
use serenity::model::application::interaction::application_command::ApplicationCommandInteraction;
use songbird::Songbird;
use super::{get_songbird, create_response, edit_response, AudioConfigs, AudioConfig};
use crate::db::guilds::InsertGuild;
use super::{get_songbird, create_response, edit_response};
pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
// Get the volume
let volume = match command.data.options.get(0) {
Some(t) => match &t.value {
Some(v) => match v.as_i64() {
Some(p) => std::cmp::min(100, std::cmp::max(0, p)),
Some(p) => p as i32,
None => {
warn!("Unable to get volume option as a string");
if let Err(why) = create_response(&ctx, &command, format!("Volume option is missing")).await {
@@ -37,9 +42,6 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
}
};
// Format volume to f32 bound between 0.0 and 1.0
let bound_volume = volume as f32 / 100.0;
// Create the initial response
if let Err(why) = create_response(&ctx, &command, "Processing command...".to_string()).await {
error!("Failed to create response message: {}", why);
@@ -55,24 +57,25 @@ pub async fn run(ctx: &Context, command: &ApplicationCommandInteraction) {
return;
}
};
let audio_config_lock = {
let data_read = ctx.data.read().await;
data_read.get::<AudioConfigs>().expect("Expected AudioConfigs in TypeMap.").clone()
};
{
let mut audio_configs = audio_config_lock.write().await;
*audio_configs.entry(guild_id).or_insert(AudioConfig { volume: 1.0 }) = AudioConfig { volume: bound_volume };
}
let manager = get_songbird(ctx).await;
set_volume(manager, guild_id, volume).await;
if let Err(why) = edit_response(&ctx, &command, format!("Setting the volume to {}", volume)).await {
error!("Failed to set the volume: {}", why);
}
}
pub async fn set_volume(manager: Arc<Songbird>, guild_id: GuildId, volume: i32) {
// Format volume to f32 bound between 0.0 and 1.0
let volume = std::cmp::min(100, std::cmp::max(0, volume));
let bound_volume = volume as f32 / 100.0;
let _ = InsertGuild::update_audio(guild_id.0 as i64, volume);
if let Some(handler_lock) = manager.get(guild_id) {
let handler = handler_lock.lock().await;
for (_, track_handle) in handler.queue().current_queue().iter().enumerate() {
let _ = track_handle.set_volume(bound_volume);
}
}
if let Err(why) = edit_response(&ctx, &command, format!("Setting the volume to {}", volume)).await {
error!("Failed to set the volume: {}", why);
}
}
pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand {

View File

View File

@@ -1,5 +1,6 @@
pub mod audio;
pub mod help;
pub mod message;
pub mod oai;
pub mod ping;
pub mod schedule;

View File

@@ -0,0 +1,97 @@
use log::{warn, info, error};
use serenity::async_trait;
use serenity::model::application::interaction::Interaction;
use serenity::model::gateway::Ready;
use serenity::model::channel::Message;
use serenity::prelude::*;
use crate::db::guilds::InsertGuild;
use super::commands;
use super::commands::audio::create_response;
pub struct Handler {
// Open AI Config
pub oai: Option<commands::oai::OAI>
}
#[async_trait]
impl EventHandler for Handler {
async fn message(&self, ctx: Context, msg: Message) {
// Ignore messages from bots
if msg.author.bot {
return;
}
match &self.oai {
Some(oai) => {
match msg.mentions_me(&ctx.http).await {
Ok(mentioned) => {
let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await {
Ok(t) => {
match t.iter().find(|t| t.user_id.unwrap().0 == ctx.cache.current_user_id().0) {
Some(_) => true,
None => false
}
}
Err(_) => false
};
if mentioned || bot_in_thread {
commands::oai::generate_response(&ctx, &msg, oai).await;
}
}
Err(why) => warn!("Could not check mentions: {:?}", why)
};
}
None => {}
}
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
if let Interaction::ApplicationCommand(command) = interaction {
match command.data.name.as_str() {
"play" => commands::audio::play::run(&ctx, &command).await,
"stop" => commands::audio::stop::run(&ctx, &command).await,
"pause" => commands::audio::pause::run(&ctx, &command).await,
"resume" => commands::audio::resume::run(&ctx, &command).await,
"skip" => commands::audio::skip::run(&ctx, &command).await,
"volume" => commands::audio::volume::run(&ctx, &command).await,
_ => {
let content: String = match command.data.name.as_str() {
"ping" => commands::ping::run(&command.data.options),
_ => "Unknown command".to_string()
};
if let Err(why) = create_response(&ctx, &command, content).await {
warn!("Cannot respond to slash command: {}", why);
}
}
}
}
}
async fn ready(&self, ctx: Context, ready: Ready) {
if ready.guilds.is_empty() {
warn!("No ready guilds found");
}
for guild in ready.guilds {
let _ = InsertGuild::insert(InsertGuild {
id: (guild.id.0 as i64),
bot_id: ctx.cache.current_user().id.0 as i64,
volume: 100
});
let commands = guild.id.set_application_commands(&ctx.http, |commands| {
commands.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::ping::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::play::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::stop::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::pause::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::resume::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::skip::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::volume::register(command) })
}).await;
match commands {
Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.0),
Err(why) => error!("Could not register commands for guild {}: {:?}", guild.id.0, why)
};
}
}
}

View File

@@ -1,169 +1,3 @@
use std::collections::{HashSet, HashMap};
use std::env;
use std::sync::Arc;
use commands::audio::{create_response, AudioConfig, AudioConfigs};
use log::{error, warn, info};
use serenity::async_trait;
use serenity::framework::StandardFramework;
use serenity::model::application::interaction::Interaction;
use serenity::model::gateway::Ready;
use serenity::model::channel::Message;
use serenity::http::Http;
use serenity::prelude::*;
use songbird::SerenityInit;
use crate::bot::commands::oai::GPTModel;
pub mod api;
pub mod commands;
struct Handler {
// Open AI Config
oai: Option<commands::oai::OAI>
}
#[async_trait]
impl EventHandler for Handler {
async fn message(&self, ctx: Context, msg: Message) {
// Ignore messages from bots
if msg.author.bot {
return;
}
match &self.oai {
Some(oai) => {
match msg.mentions_me(&ctx.http).await {
Ok(mentioned) => {
let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await {
Ok(t) => {
match t.iter().find(|t| t.user_id.unwrap().0 == ctx.cache.current_user_id().0) {
Some(_) => true,
None => false
}
}
Err(_) => false
};
if mentioned || bot_in_thread {
commands::oai::generate_response(&ctx, &msg, oai).await;
}
}
Err(why) => warn!("Could not check mentions: {:?}", why)
};
}
None => {}
}
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
if let Interaction::ApplicationCommand(command) = interaction {
match command.data.name.as_str() {
"play" => commands::audio::play::run(&ctx, &command).await,
"stop" => commands::audio::stop::run(&ctx, &command).await,
"pause" => commands::audio::pause::run(&ctx, &command).await,
"resume" => commands::audio::resume::run(&ctx, &command).await,
"skip" => commands::audio::skip::run(&ctx, &command).await,
"volume" => commands::audio::volume::run(&ctx, &command).await,
_ => {
let content: String = match command.data.name.as_str() {
"ping" => commands::ping::run(&command.data.options),
_ => "Unknown command".to_string()
};
if let Err(why) = create_response(&ctx, &command, content).await {
warn!("Cannot respond to slash command: {}", why);
}
}
}
}
}
async fn ready(&self, ctx: Context, ready: Ready) {
if ready.guilds.is_empty() {
warn!("No ready guilds found");
}
for guild in ready.guilds {
let audio_config_lock = {
let data_read = ctx.data.read().await;
data_read.get::<AudioConfigs>().expect("Expected AudioConfigs in TypeMap.").clone()
};
{
let mut audio_configs = audio_config_lock.write().await;
let _ = audio_configs.insert(guild.id, AudioConfig { volume: 1.0 });
}
let commands = guild.id.set_application_commands(&ctx.http, |commands| {
commands.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::ping::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::play::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::stop::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::pause::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::resume::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::skip::register(command) })
.create_application_command(|command: &mut serenity::builder::CreateApplicationCommand| { commands::audio::volume::register(command) })
}).await;
match commands {
Ok(c) => info!("Registered {} commands for guild {}", c.len(), guild.id.0),
Err(why) => error!("Could not register commands for guild {}: {:?}", guild.id.0, why)
};
}
}
}
pub async fn run() {
let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment");
let intents: GatewayIntents = GatewayIntents::all();
let http: Http = Http::new(&token);
let (owners, _bot_id) = match http.get_current_application_info().await {
Ok(info) => {
let mut owners: HashSet<serenity::model::id::UserId> = HashSet::new();
if let Some(team) = info.team {
owners.insert(team.owner_user_id);
} else {
owners.insert(info.owner.id);
}
match http.get_current_user().await {
Ok(bot) => (owners, bot.id),
Err(why) => panic!("Could not access the bot id: {:?}", why)
}
},
Err(why) => panic!("Could not access application info: {:?}", why)
};
let handler = match env::var("OPENAI_API_KEY") {
Ok(token) => {
info!("Loaded OpenAI token");
Handler {
oai: Some(commands::oai::OAI {
client: reqwest::Client::new(),
base_url: "https://api.openai.com/v1".to_string(),
service_url: "http://localhost:5000".to_string(),
max_attempts: 5,
token,
max_context_questions: 30,
max_tokens: 2048,
default_model: GPTModel::GPT35Turbo,
})
}
}
Err(err) => {
warn!("Could not load OpenAI token: {}", err);
Handler { oai: None }
}
};
let mut client = Client::builder(token, intents)
.event_handler(handler)
.framework(StandardFramework::new()
.configure(|c| c.owners(owners)))
.register_songbird()
.await
.expect("Error creating client");
{
let mut data = client.data.write().await;
data.insert::<AudioConfigs>(Arc::new(RwLock::new(HashMap::default())));
}
if let Err(why) = client.start_autosharded().await {
error!("An error occurred while running the client: {:?}", why);
}
}
pub mod handler;

View File

@@ -0,0 +1,3 @@
mod model;
pub use model::*;

View File

@@ -0,0 +1,43 @@
use diesel::prelude::*;
use serde::{Serialize, Deserialize};
use siren::ServiceError;
use crate::db::{schema::guilds, connection};
#[derive(Queryable, QueryableByName, Serialize, Deserialize)]
#[diesel(table_name = guilds)]
pub struct QueryGuild {
pub id: i64,
pub bot_id: i64,
pub volume: i32
}
impl QueryGuild {
pub fn get(id: i64) -> Result<Self, ServiceError> {
let mut conn = connection()?;
let guild = guilds::table.filter(guilds::id.eq(id)).first(&mut conn)?;
Ok(guild)
}
}
#[derive(Insertable, AsChangeset, Serialize, Deserialize)]
#[diesel(table_name = guilds)]
pub struct InsertGuild {
pub id: i64,
pub bot_id: i64,
pub volume: i32
}
impl InsertGuild {
pub fn insert(guild: Self) -> Result<QueryGuild, ServiceError> {
let mut conn = connection()?;
let guild = diesel::insert_into(guilds::table).values(guild).get_result(&mut conn)?;
Ok(guild)
}
pub fn update_audio(id: i64, volume: i32) -> Result<QueryGuild, ServiceError> {
let mut conn = connection()?;
let guild = diesel::update(guilds::table.filter(guilds::id.eq(id))).set(guilds::volume.eq(volume)).get_result(&mut conn)?;
Ok(guild)
}
}

View File

@@ -3,7 +3,7 @@ use log::error;
use serde::{Serialize, Deserialize};
use siren::{GetResponse, Metadata, ServiceError};
use crate::db::messages::{QueryMessage, QueryFilters, InsertMessage};
use crate::{db::messages::{QueryMessage, QueryFilters, InsertMessage}, auth::{JwtAuth, verify_role}};
#[derive(Serialize, Deserialize)]
struct GetAllParams {
@@ -21,7 +21,11 @@ struct GetAllParams {
}
#[get("/messages")]
async fn get_all(req: HttpRequest) -> HttpResponse {
async fn get_all(req: HttpRequest, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let params = match web::Query::<GetAllParams>::from_query(req.query_string()) {
Ok(params) => params,
Err(err) => return ResponseError::error_response(&ServiceError {
@@ -64,7 +68,11 @@ async fn get_all(req: HttpRequest) -> HttpResponse {
}
#[post("/messages")]
async fn create(message: web::Json<InsertMessage>) -> HttpResponse {
async fn create(message: web::Json<InsertMessage>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
match InsertMessage::insert(message.into_inner()) {
Ok(message) => HttpResponse::Created().json(message),
Err(err) => {

View File

@@ -1,4 +1,6 @@
use diesel::{r2d2::ConnectionManager, PgConnection};
use diesel::{r2d2::ConnectionManager as DieselConnectionManager, PgConnection};
// use redis::{aio::{Connection as RedisConnection, ConnectionManager as RedisConnectionManager}, AsyncCommands};
use redis::aio::Connection as RedisConnection;
use siren::ServiceError;
use crate::diesel_migrations::MigrationHarness;
use lazy_static::lazy_static;
@@ -11,30 +13,39 @@ pub mod bestiary;
pub mod classes;
pub mod conditions;
pub mod feats;
pub mod guilds;
pub mod items;
pub mod messages;
pub mod options;
pub mod races;
pub mod spells;
pub mod users;
pub mod schema;
type Pool = r2d2::Pool<ConnectionManager<PgConnection>>;
pub type DbConnection = r2d2::PooledConnection<ConnectionManager<PgConnection>>;
type DbPool = r2d2::Pool<DieselConnectionManager<PgConnection>>;
pub type DbConnection = r2d2::PooledConnection<DieselConnectionManager<PgConnection>>;
// type RedisPool = r2d2::Pool<redis::ConnectionManager>;
pub const MIGRATIONS: diesel_migrations::EmbeddedMigrations = embed_migrations!();
lazy_static! {
static ref POOL: Pool = {
static ref POOL: DbPool = {
let username = env::var("DATABASE_USER").expect("DATABASE_USERNAME is not set");
let password = env::var("DATABASE_PASSWORD").expect("DATABASE_PASSWORD is not set");
let host = env::var("DATABASE_HOST").unwrap_or("localhost".to_string());
let name = env::var("DATABASE_NAME").expect("DATABASE_NAME is not set");
let port = env::var("DATABASE_PORT").unwrap_or("5432".to_string());
let url = format!("postgres://{}:{}@{}:{}/{}", username, password, host, port, name);
let manager = ConnectionManager::<PgConnection>::new(url);
Pool::builder().test_on_check_out(true).build(manager).expect("Failed to create db pool")
let manager = DieselConnectionManager::<PgConnection>::new(url);
DbPool::builder().test_on_check_out(true).build(manager).expect("Failed to create db pool")
};
// static ref REDIS_POOL: RedisPool = {
// let host = env::var("REDIS_HOST").unwrap_or("localhost".to_string());
// let port = env::var("REDIS_PORT").unwrap_or("6379".to_string());
// let url = format!("redis://{}:{}", host, port);
// let client = redis::Client::open(url).expect("Failed to create redis client");
// let manager = RedisConnectionManager::new(client);
// "".to_string()
// };
}
pub fn init() {
@@ -51,6 +62,26 @@ pub fn connection() -> Result<DbConnection, ServiceError> {
.map_err(|e| ServiceError::new(500, format!("Failed getting db connection: {}", e)))
}
pub fn redis_client() -> Result<redis::Client, ServiceError> {
let host = env::var("REDIS_HOST").unwrap_or("localhost".to_string());
let port = env::var("REDIS_PORT").unwrap_or("6379".to_string());
let url = format!("redis://{}:{}", host, port);
let client = redis::Client::open(url)?;
Ok(client)
}
pub fn redis_connection() -> Result<redis::Connection, ServiceError> {
let client = redis_client()?;
let conn = client.get_connection()?;
Ok(conn)
}
pub async fn redis_async_connection() -> Result<RedisConnection, ServiceError> {
let client = redis_client()?;
let conn = client.get_async_connection().await?;
Ok(conn)
}
pub fn load_data(data_dir_path: &str) {
spells::load_data(data_dir_path);
}

View File

@@ -30,3 +30,22 @@ diesel::table! {
data -> Jsonb
}
}
diesel::table! {
guilds (id) {
id -> BigInt,
bot_id -> BigInt,
volume -> Integer,
}
}
diesel::table! {
users (email) {
email -> Text,
hash -> Text,
role -> Text,
first_name -> Text,
last_name -> Text,
verified -> Bool,
}
}

View File

@@ -6,7 +6,7 @@ use crate::db::{schema::spells::{self}, classes::AbilityType, conditions::Condit
use super::{SchoolType, CastingTime, SpellAttackType, SpellDamageType, Range, Area, Components, Duration, Source, Description, DurationType, Effect};
#[derive(Queryable, QueryableByName, Serialize, Deserialize)]
#[derive(Debug, Queryable, QueryableByName, Serialize, Deserialize)]
#[diesel(table_name = spells)]
pub struct QuerySpell {
pub id: i32,
@@ -163,7 +163,7 @@ impl QuerySpell {
}
}
#[derive(Insertable, AsChangeset)]
#[derive(Debug, Insertable, AsChangeset)]
#[diesel(table_name = spells)]
pub struct InsertSpell {
pub name: String,

View File

@@ -3,7 +3,7 @@ use log::error;
use serde::{Serialize, Deserialize};
use siren::{GetResponse, Metadata, ServiceError};
use crate::db::spells::{QuerySpell, QueryFilters};
use crate::{db::spells::{QuerySpell, QueryFilters}, auth::{JwtAuth, verify_role}};
use super::{Spell, InsertSpell};
@@ -134,7 +134,11 @@ async fn get_by_id(id: web::Path<String>) -> HttpResponse {
}
#[post("/spells")]
async fn create(spell: web::Json<Spell>) -> HttpResponse {
async fn create(spell: web::Json<Spell>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
match InsertSpell::insert(spell.into_inner().into()) {
Ok(spell) => HttpResponse::Created().json(Spell::from(spell)),
Err(err) => {
@@ -145,7 +149,11 @@ async fn create(spell: web::Json<Spell>) -> HttpResponse {
}
#[put("/spells/{id}")]
async fn update(id: web::Path<String>, spell: web::Json<Spell>) -> HttpResponse {
async fn update(id: web::Path<String>, spell: web::Json<Spell>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let id = match id.parse::<i32>() {
Ok(id) => id,
Err(err) => return ResponseError::error_response(&ServiceError {
@@ -163,7 +171,11 @@ async fn update(id: web::Path<String>, spell: web::Json<Spell>) -> HttpResponse
}
#[delete("/spells/{id}")]
async fn delete(id: web::Path<String>) -> HttpResponse {
async fn delete(id: web::Path<String>, auth: JwtAuth) -> HttpResponse {
let _ = match verify_role(&auth, "admin") {
Ok(_) => {},
Err(err) => return ResponseError::error_response(&err)
};
let id = match id.parse::<i32>() {
Ok(id) => id,
Err(err) => return ResponseError::error_response(&ServiceError {
@@ -181,8 +193,10 @@ async fn delete(id: web::Path<String>) -> HttpResponse {
}
pub fn init_routes(config: &mut web::ServiceConfig) {
config.service(get_all);
config.service(get_by_id);
config.service(create);
config.service(delete);
config.service(web::scope("dnd")
.service(get_all)
.service(get_by_id)
.service(create)
.service(update)
);
}

View File

@@ -177,54 +177,13 @@ pub struct Range {
#[derive(Debug, Serialize, Deserialize)]
pub struct Area {
#[serde(rename = "type")]
pub area_type: AreaType,
pub area_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub value: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub unit: Option<String>
}
#[derive(Debug, Serialize, Deserialize)]
pub enum AreaType {
#[serde(rename = "cone")]
Cone,
#[serde(rename = "cube")]
Cube,
#[serde(rename = "cylinder")]
Cylinder,
#[serde(rename = "line")]
Line,
#[serde(rename = "sphere")]
Sphere
}
// impl AreaType {
// pub fn to_string(&self) -> String {
// match self {
// AreaType::Cone => "cone".to_string(),
// AreaType::Cube => "cube".to_string(),
// AreaType::Cylinder => "cylinder".to_string(),
// AreaType::Line => "line".to_string(),
// AreaType::Sphere => "sphere".to_string()
// }
// }
// }
// impl FromStr for AreaType {
// type Err = ();
// fn from_str(s: &str) -> Result<Self, Self::Err> {
// match s {
// "cone" => Ok(AreaType::Cone),
// "cube" => Ok(AreaType::Cube),
// "cylinder" => Ok(AreaType::Cylinder),
// "line" => Ok(AreaType::Line),
// "sphere" => Ok(AreaType::Sphere),
// _ => Err(())
// }
// }
// }
#[derive(Debug, Serialize, Deserialize)]
pub struct Duration {
#[serde(rename = "type")]
@@ -263,7 +222,7 @@ pub struct Description {
#[derive(Debug)]
pub struct Entry {
pub text: Option<Vec<String>>,
pub text: Option<String>,
pub list: Option<Vec<String>>,
pub table: Option<EntryTable>
}
@@ -279,11 +238,18 @@ impl<'de> Deserialize<'de> for Entry {
let value = serde_json::Value::deserialize(deserializer)?;
match value {
serde_json::Value::String(s) => Ok(Entry {
text: Some(vec![s]),
text: Some(s),
list: None,
table: None,
}),
serde_json::Value::Object(o) => {
let text = match o.get("text") {
Some(t) => match t.as_str() {
Some(s) => Some(s.to_string()),
None => return Err(serde::de::Error::custom("Invalid entry text"))
},
None => None
};
let list = match o.get("list") {
Some(i) => match i.as_array() {
Some(a) => {
@@ -352,7 +318,7 @@ impl<'de> Deserialize<'de> for Entry {
None => None
};
Ok(Entry {
text: None,
text,
list,
table
})

View File

@@ -1,3 +0,0 @@
mod model;
pub use model::*;

View File

@@ -1,3 +0,0 @@
pub struct User {
pub id: i32
}

0
service/src/dnd/mod.rs Normal file
View File

View File

@@ -57,7 +57,14 @@ impl fmt::Display for ServiceError {
impl From<DieselError> for ServiceError {
fn from(error: DieselError) -> ServiceError {
match error {
DieselError::DatabaseError(_, err) => ServiceError::new(409, err.message().to_string()),
DieselError::DatabaseError(kind, err) => {
match kind {
diesel::result::DatabaseErrorKind::UniqueViolation => {
ServiceError::new(409, err.message().to_string())
},
_ => ServiceError::new(500, err.message().to_string())
}
},
DieselError::NotFound => {
ServiceError::new(404, "The record was not found".to_string())
},
@@ -81,6 +88,30 @@ impl From<serde_json::Error> for ServiceError {
}
}
impl From<serenity::Error> for ServiceError {
fn from(error: serenity::Error) -> ServiceError {
ServiceError::new(500, format!("Unknown serenity error: {}", error))
}
}
impl From<argon2::password_hash::Error> for ServiceError {
fn from(error: argon2::password_hash::Error) -> ServiceError {
ServiceError::new(500, format!("Unknown argon2 error: {}", error))
}
}
impl From<jsonwebtoken::errors::Error> for ServiceError {
fn from(error: jsonwebtoken::errors::Error) -> ServiceError {
ServiceError::new(500, format!("Unknown jsonwebtoken error: {}", error))
}
}
impl From<redis::RedisError> for ServiceError {
fn from(error: redis::RedisError) -> ServiceError {
ServiceError::new(500, format!("Unknown redis error: {}", error))
}
}
impl ResponseError for ServiceError {
fn error_response(&self) -> HttpResponse {
let status_code = match StatusCode::from_u16(self.status) {

View File

@@ -3,13 +3,23 @@ extern crate diesel;
extern crate diesel_migrations;
use std::env;
use std::collections::HashSet;
use std::sync::Arc;
use log::{error, warn, info};
use serenity::client::Cache;
use serenity::framework::StandardFramework;
use serenity::http::Http;
use serenity::prelude::*;
use songbird::{SerenityInit, Songbird};
use actix_cors::Cors;
use actix_web::{HttpServer, App};
use actix_web::{HttpServer, App, web};
use crate::bot::{commands::oai::GPTModel, handler::Handler};
use dotenv::dotenv;
use log::{error, info, warn};
mod auth;
mod dnd;
mod bot;
mod db;
@@ -23,21 +33,98 @@ async fn main() -> std::io::Result<()> {
Err(err) => warn!("Unable to load initial database data: {}", err)
};
let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment");
let intents: GatewayIntents = GatewayIntents::all();
let http: Http = Http::new(&token);
let (owners, _bot_id) = match http.get_current_application_info().await {
Ok(info) => {
let mut owners: HashSet<serenity::model::id::UserId> = HashSet::new();
if let Some(team) = info.team {
owners.insert(team.owner_user_id);
} else {
owners.insert(info.owner.id);
}
match http.get_current_user().await {
Ok(bot) => (owners, bot.id),
Err(why) => panic!("Could not access the bot id: {:?}", why)
}
},
Err(why) => panic!("Could not access application info: {:?}", why)
};
let handler = match env::var("OPENAI_API_KEY") {
Ok(token) => {
info!("Loaded OpenAI token");
Handler {
oai: Some(bot::commands::oai::OAI {
client: reqwest::Client::new(),
base_url: "https://api.openai.com/v1".to_string(),
service_url: "http://localhost:5000".to_string(),
max_attempts: 5,
token,
max_context_questions: 30,
max_tokens: 2048,
default_model: GPTModel::GPT35Turbo,
})
}
}
Err(err) => {
warn!("Could not load OpenAI token: {}", err);
Handler { oai: None }
}
};
let songbird = Songbird::serenity();
let mut client = Client::builder(token, intents)
.event_handler(handler)
.framework(StandardFramework::new()
.configure(|c| c.owners(owners)))
.register_songbird_with(Arc::clone(&songbird))
.await
.expect("Error creating client");
let http = Arc::clone(&client.cache_and_http.http);
let cache = Arc::clone(&client.cache_and_http.cache);
let app_data = Arc::new(AppState {
http,
cache,
songbird: Arc::clone(&songbird)
});
let shard_manager = Arc::clone(&client.shard_manager);
// tokio::spawn(async move {
// tokio::signal::ctrl_c().await.expect("Could not register ctrl+c handler");
// shard_manager.lock().await.shutdown_all().await;
// });
// tokio::spawn(async move {
// if let Err(why) = client.start_autosharded().await {
// error!("An error occurred while running the client: {:?}", why);
// }
// });
let host = env::var("SERVICE_HOST").unwrap_or("localhost".to_string());
let port = env::var("SERVICE_PORT").unwrap_or("5000".to_string());
tokio::spawn(bot::run());
match HttpServer::new(|| {
let server = match HttpServer::new(move || {
let cors = Cors::default()
.allow_any_origin()
.allow_any_method()
.allow_any_header()
.supports_credentials()
.max_age(3600);
App::new()
.configure(db::messages::init_routes)
.configure(db::spells::init_routes)
.wrap(cors)
.app_data(web::Data::new(Arc::clone(&app_data)))
.configure(crate::db::messages::init_routes)
.configure(crate::db::spells::init_routes)
.configure(crate::auth::init_routes)
.configure(crate::bot::api::init_routes)
})
.bind(format!("{}:{}", host, port)) {
Ok(b) => {
@@ -48,7 +135,14 @@ async fn main() -> std::io::Result<()> {
error!("Could not bind server: {}", err);
return Err(err);
}
}
.run()
};
server.run()
.await
}
pub struct AppState {
pub http: Arc<Http>,
pub cache: Arc<Cache>,
pub songbird: Arc<Songbird>
}

View File

@@ -9,7 +9,7 @@ services:
environment:
- NODE_ENV=${NODE_ENV:-development}
ports:
- ${UI_PORT:-8080}:3000
- ${UI_PORT:-3000}:3000
build:
context: ./
target: dev

40
ui/package-lock.json generated
View File

@@ -9,10 +9,12 @@
"version": "0.1.0",
"dependencies": {
"@mantine/core": "^7.1.2",
"@mantine/form": "^7.1.2",
"@mantine/hooks": "^7.1.2",
"@mantine/modals": "^7.1.2",
"@mantine/notifications": "^7.1.2",
"axios": "^1.5.1",
"js-cookie": "^3.0.5",
"next": "^13.5.4",
"react": "^18.2.0",
"react-dom": "^18.2.0",
@@ -22,6 +24,7 @@
"recoil": "^0.7.7"
},
"devDependencies": {
"@types/js-cookie": "^3.0.4",
"@types/node": "20.8.2",
"@types/react": "18.2.24",
"@types/react-dom": "18.2.8",
@@ -225,6 +228,18 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/@mantine/form": {
"version": "7.1.2",
"resolved": "https://registry.npmjs.org/@mantine/form/-/form-7.1.2.tgz",
"integrity": "sha512-FnUu5XNmRM265G0wy19qSRiItG/2eQ0GQCctnokw6ws9ZnCU1NqvsmpuDE/UiV4YCAOhAVHfqnjG/8tsrlw7ug==",
"dependencies": {
"fast-deep-equal": "^3.1.3",
"klona": "^2.0.5"
},
"peerDependencies": {
"react": "^18.2.0"
}
},
"node_modules/@mantine/hooks": {
"version": "7.1.2",
"resolved": "https://registry.npmjs.org/@mantine/hooks/-/hooks-7.1.2.tgz",
@@ -573,6 +588,12 @@
"resolved": "https://registry.npmjs.org/@types/d3-timer/-/d3-timer-3.0.0.tgz",
"integrity": "sha512-HNB/9GHqu7Fo8AQiugyJbv6ZxYz58wef0esl4Mv828w1ZKpAshw/uFWVDUcIB9KKFeFKoxS3cHY07FFgtTRZ1g=="
},
"node_modules/@types/js-cookie": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/@types/js-cookie/-/js-cookie-3.0.4.tgz",
"integrity": "sha512-vMMnFF+H5KYqdd/myCzq6wLDlPpteJK+jGFgBus3Da7lw+YsDmx2C8feGTzY2M3Fo823yON+HC2CL240j4OV+w==",
"dev": true
},
"node_modules/@types/json-schema": {
"version": "7.0.13",
"resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.13.tgz",
@@ -2305,8 +2326,7 @@
"node_modules/fast-deep-equal": {
"version": "3.1.3",
"resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz",
"integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==",
"dev": true
"integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q=="
},
"node_modules/fast-diff": {
"version": "1.3.0",
@@ -3282,6 +3302,14 @@
"reflect.getprototypeof": "^1.0.3"
}
},
"node_modules/js-cookie": {
"version": "3.0.5",
"resolved": "https://registry.npmjs.org/js-cookie/-/js-cookie-3.0.5.tgz",
"integrity": "sha512-cEiJEAEoIbWfCZYKWhVwFuvPX1gETRYPw6LlaTKoxD3s2AkXzkCjnp6h0V77ozyqj0jakteJ4YqDJT830+lVGw==",
"engines": {
"node": ">=14"
}
},
"node_modules/js-tokens": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz",
@@ -3353,6 +3381,14 @@
"json-buffer": "3.0.1"
}
},
"node_modules/klona": {
"version": "2.0.6",
"resolved": "https://registry.npmjs.org/klona/-/klona-2.0.6.tgz",
"integrity": "sha512-dhG34DXATL5hSxJbIexCft8FChFXtmskoZYnoPWjXQuebWYCNkVeV3KkGegCK9CP1oswI/vQibS2GY7Em/sJJA==",
"engines": {
"node": ">= 8"
}
},
"node_modules/language-subtag-registry": {
"version": "0.3.22",
"resolved": "https://registry.npmjs.org/language-subtag-registry/-/language-subtag-registry-0.3.22.tgz",

View File

@@ -10,10 +10,12 @@
},
"dependencies": {
"@mantine/core": "^7.1.2",
"@mantine/form": "^7.1.2",
"@mantine/hooks": "^7.1.2",
"@mantine/modals": "^7.1.2",
"@mantine/notifications": "^7.1.2",
"axios": "^1.5.1",
"js-cookie": "^3.0.5",
"next": "^13.5.4",
"react": "^18.2.0",
"react-dom": "^18.2.0",
@@ -23,6 +25,7 @@
"recoil": "^0.7.7"
},
"devDependencies": {
"@types/js-cookie": "^3.0.4",
"@types/node": "20.8.2",
"@types/react": "18.2.24",
"@types/react-dom": "18.2.8",

42
ui/src/api/auth.ts Normal file
View File

@@ -0,0 +1,42 @@
import { getRequest, postRequest } from '.';
import { RegisterUser, ResponseAuth } from './auth.types';
export async function login(email: string, password: string): Promise<ResponseAuth | undefined> {
const response = await postRequest('auth/login', { email, password });
if (response?.status === 200) {
return response.data as ResponseAuth;
} else {
return undefined;
}
}
export async function register(user: RegisterUser): Promise<boolean> {
const response = await postRequest('auth/register', user);
if (response?.status === 201) {
return true;
} else {
return false;
}
}
export async function logout() {
return await postRequest('auth/logout', {});
}
export async function refresh(refresh_token_rotation?: boolean): Promise<ResponseAuth | undefined> {
const response = await getRequest('auth/refresh', { params: { refresh_token_rotation } });
if (response?.status === 200) {
return response.data as ResponseAuth;
} else {
return undefined;
}
}
export async function me(): Promise<ResponseAuth | undefined> {
const response = await getRequest('auth/me');
if (response?.status === 200) {
return response.data;
} else {
return undefined;
}
}

18
ui/src/api/auth.types.ts Normal file
View File

@@ -0,0 +1,18 @@
export interface ResponseAuth {
token: string;
user: User;
}
export interface RegisterUser {
email: string;
password: string;
first_name: string;
last_name: string;
}
export interface User {
email: string;
role: string;
first_name: string;
last_name: string;
}

50
ui/src/api/guilds.ts Normal file
View File

@@ -0,0 +1,50 @@
import { getRequest, postRequest } from '.';
import { GuildChannel, GuildInfo } from './guilds.types';
export async function getGuilds(): Promise<GuildInfo[]> {
const response = await getRequest('guilds');
return response?.data || { data: [] };
}
export async function getTextChannels(guildId: number): Promise<GuildChannel[]> {
const response = await getRequest(`guilds/${guildId}/text`);
return response?.data || { data: [] };
}
export async function sendMessage(guildId: number, channelId: number, message: string): Promise<void> {
await postRequest(`guilds/${guildId}/text/${channelId}/message`, { message });
}
export async function getVoiceChannels(guildId: number): Promise<GuildChannel[]> {
const response = await getRequest(`guilds/${guildId}/voice`);
return response?.data || { data: [] };
}
export async function playTrack(guildId: number, channelId: number, track: string): Promise<void> {
await postRequest(`guilds/${guildId}/voice/${channelId}/play`, { track_url: track });
}
export async function stopTrack(guildId: number): Promise<void> {
await postRequest(`guilds/${guildId}/voice/stop`, {});
}
export async function pauseTrack(guildId: number): Promise<void> {
await postRequest(`guilds/${guildId}/voice/pause`, {});
}
export async function resumeTrack(guildId: number): Promise<void> {
await postRequest(`guilds/${guildId}/voice/resume`, {});
}
export async function setVolume(guildId: number, volume: number): Promise<void> {
await postRequest(`guilds/${guildId}/voice/volume`, { volume: `${volume}` });
}
export async function skipTrack(guildId: number): Promise<void> {
await postRequest(`guilds/${guildId}/voice/skip`, {});
}
export async function getVolume(guildId: number): Promise<number> {
const response = await getRequest(`guilds/${guildId}/voice/volume`);
return response?.data?.volume || 0;
}

View File

@@ -0,0 +1,13 @@
export interface GuildInfo {
id: number;
icon?: string;
name: string;
owner: boolean;
}
export interface GuildChannel {
id: number;
name: string;
type: string;
guild_id: number;
}

View File

@@ -1,19 +1,42 @@
import axios, { AxiosResponse } from 'axios';
import axios, { AxiosInstance, AxiosRequestConfig, AxiosResponse } from 'axios';
const serviceHost = process.env.SERVICE_HOST || 'http://localhost';
const servicePort = process.env.SERVICE_PORT || 5000;
export async function getRequest(endpoint: string, params: any): Promise<AxiosResponse<any, any> | undefined> {
const response = await axios
.get(`${serviceHost}:${servicePort}/${endpoint}`, { params })
.catch((error) => console.error(error));
function createAxiosClient(): AxiosInstance {
const axiosClient = axios.create({
baseURL: `${serviceHost}:${servicePort}`
});
axiosClient.interceptors.request.use(
(request) => {
request.withCredentials = true;
return request;
},
(error) => {
console.error(error);
return Promise.reject(error);
}
);
return axiosClient;
}
const axiosClient = createAxiosClient();
export async function getRequest(
url: string,
config?: AxiosRequestConfig<any>
): Promise<AxiosResponse<any, any> | undefined> {
const response = await axiosClient.get(`/${url}`, config);
return response || undefined;
}
export async function postRequest(endpoint: string, body: any): Promise<AxiosResponse<any, any> | undefined> {
const response = await axios
.post(`${serviceHost}:${servicePort}/${endpoint}`, { body })
.catch((error) => console.error(error));
export async function postRequest(
url: string,
data?: any,
config?: AxiosRequestConfig<any>
): Promise<AxiosResponse<any, any> | undefined> {
const response = await axiosClient.post(`/${url}`, data, config);
return response || undefined;
}

View File

@@ -19,21 +19,23 @@ interface GetSpellsParams {
}
export async function getSpells(params?: GetSpellsParams): Promise<GetSpellsResponse> {
const response = await getRequest('spells', {
name: params?.name,
like_name: params?.like_name,
schools: params?.schools?.join(','),
levels: params?.levels?.join(','),
ritual: params?.ritual,
concentration: params?.concentration,
classes: params?.classes?.join(','),
damage_inflict: params?.damage_inflict?.join(','),
damage_resist: params?.damage_resist?.join(','),
conditions: params?.conditions?.join(','),
saving_throw: params?.saving_throw?.join(','),
attack_type: params?.attack_type?.join(','),
limit: params?.limit,
page: params?.page
const response = await getRequest('dnd/spells', {
params: {
name: params?.name,
like_name: params?.like_name,
schools: params?.schools?.join(','),
levels: params?.levels?.join(','),
ritual: params?.ritual,
concentration: params?.concentration,
classes: params?.classes?.join(','),
damage_inflict: params?.damage_inflict?.join(','),
damage_resist: params?.damage_resist?.join(','),
conditions: params?.conditions?.join(','),
saving_throw: params?.saving_throw?.join(','),
attack_type: params?.attack_type?.join(','),
limit: params?.limit,
page: params?.page
}
});
return response?.data || { data: [] };
}

View File

@@ -61,14 +61,21 @@ export interface Source {
}
export interface Description {
entries: EntryType[];
// entries: EntryType[];
entries: Entry[];
}
type EntryType = string | Entry;
// type EntryType = string | Entry;
export interface Entry {
type: string;
items: string[];
text?: string;
list?: string[];
table?: EntryTable;
}
export interface EntryTable {
headers: string[];
rows: string[][];
}
export interface GetSpellResponse {

View File

@@ -0,0 +1,140 @@
'use client';
import {
getGuilds,
getTextChannels,
getVoiceChannels,
getVolume,
pauseTrack,
playTrack,
resumeTrack,
sendMessage,
setVolume,
skipTrack,
stopTrack
} from '@/api/guilds';
import { GuildChannel, GuildInfo } from '@/api/guilds.types';
import { Button, Slider, Tabs, TextInput, Textarea } from '@mantine/core';
import { useForm } from '@mantine/form';
import React, { useEffect, useState } from 'react';
export default function Page() {
const [guilds, setGuilds] = useState<GuildInfo[]>([]);
const [activeGuild, setActiveGuild] = useState<GuildInfo | null>(null);
const [textChannels, setTextChannels] = useState<GuildChannel[]>([]);
const [voiceChannels, setVoiceChannels] = useState<GuildChannel[]>([]);
const [guildVolume, setGuildVolume] = useState<number>(50.0);
useEffect(() => {
getGuilds().then((g) => {
setGuilds(g);
if (g.length > 0) {
setActiveGuild(g[0]);
}
});
}, []);
useEffect(() => {
if (activeGuild) {
getTextChannels(activeGuild.id).then((c) => setTextChannels(c));
getVoiceChannels(activeGuild.id).then((c) => setVoiceChannels(c));
getVolume(activeGuild.id).then((v) => setGuildVolume(v));
}
}, [activeGuild]);
const playForm = useForm({
initialValues: {
message: '',
trackUrl: '',
volume: 50.0
}
});
return (
<Tabs orientation='vertical' defaultValue={activeGuild?.name}>
<Tabs.List>
{guilds.map((guild) => (
<Tabs.Tab key={`guild-tab-${guild.id}`} value={guild.name} onClick={() => setActiveGuild(guild)}>
{guild.name}
</Tabs.Tab>
))}
</Tabs.List>
{guilds.map((guild) => (
<Tabs.Panel key={`guild-${guild.id}`} value={guild.name}>
<h1>{guild.name}</h1>
<h2>Text Channels</h2>
<Tabs orientation='horizontal' defaultValue={textChannels[0]?.name}>
<Tabs.List>
{textChannels.map((channel) => (
<Tabs.Tab key={`text-channel-tab-${channel.id}`} value={channel.name}>
{channel.name}
</Tabs.Tab>
))}
</Tabs.List>
{textChannels.map((channel) => (
<Tabs.Panel key={`text-channel-${channel.id}`} value={channel.name}>
<form
style={{ margin: '1em' }}
onSubmit={playForm.onSubmit((values) => sendMessage(activeGuild!.id, channel.id, values.message))}
>
<Textarea placeholder='Message...' {...playForm.getInputProps('message')} />
<Button type='submit'>Send Message</Button>
</form>
</Tabs.Panel>
))}
</Tabs>
<h2>Voice Channels</h2>
<Tabs orientation='horizontal' defaultValue={voiceChannels[0]?.name}>
<Tabs.List>
{voiceChannels.map((channel) => (
<Tabs.Tab key={`voice-channel-tab-${channel.id}`} value={channel.name}>
{channel.name}
</Tabs.Tab>
))}
</Tabs.List>
{voiceChannels.map((channel) => (
<Tabs.Panel key={`voice-channel-${channel.id}`} value={channel.name}>
<form
style={{ margin: '1em' }}
onSubmit={playForm.onSubmit((values) => {
playTrack(activeGuild!.id, channel.id, values.trackUrl);
})}
>
<TextInput placeholder='Youtube URL...' {...playForm.getInputProps('trackUrl')} />
<Button type='submit'>Play Track</Button>
<Button onClick={() => skipTrack(activeGuild!.id)}>Skip Track</Button>
</form>
<div style={{ margin: '1em' }}>
<Button style={{ marginRight: '1em' }} onClick={() => stopTrack(activeGuild!.id)}>
Stop
</Button>
<Button style={{ marginRight: '1em' }} onClick={() => pauseTrack(activeGuild!.id)}>
Pause
</Button>
<Button style={{ marginRight: '1em' }} onClick={() => resumeTrack(activeGuild!.id)}>
Resume
</Button>
</div>
<form
style={{ margin: '1em' }}
onSubmit={playForm.onSubmit((values) => setVolume(activeGuild!.id, values.volume))}
>
<Slider
defaultValue={guildVolume}
{...playForm.getInputProps('volume')}
marks={[
{ value: 25, label: '25%' },
{ value: 50, label: '50%' },
{ value: 75, label: '75%' }
]}
/>
<Button type='submit'>Set Volume</Button>
</form>
</Tabs.Panel>
))}
</Tabs>
</Tabs.Panel>
))}
</Tabs>
);
}

View File

@@ -1,5 +1,6 @@
import React from 'react';
// Home page for siren
export default function Page() {
return <></>;
return <div></div>;
}

View File

@@ -22,6 +22,16 @@ export default function Page() {
const [activeSpell, setActiveSpell] = useState<Spell | undefined>(undefined);
const [isOpen, setIsOpen] = useState(false);
const [searchName, setSearchName] = useState('');
const [includeCantrips, setIncludeCantrips] = useState(true);
const [includeLevel1, setIncludeLevel1] = useState(true);
const [includeLevel2, setIncludeLevel2] = useState(true);
const [includeLevel3, setIncludeLevel3] = useState(true);
const [includeLevel4, setIncludeLevel4] = useState(true);
const [includeLevel5, setIncludeLevel5] = useState(true);
const [includeLevel6, setIncludeLevel6] = useState(true);
const [includeLevel7, setIncludeLevel7] = useState(true);
const [includeLevel8, setIncludeLevel8] = useState(true);
const [includeLevel9, setIncludeLevel9] = useState(true);
useEffect(() => {
getSpells({ levels: [0] }).then((s) => setCantrips(s.data));
@@ -54,6 +64,15 @@ export default function Page() {
}}
/>
<hr />
<SpellSection
title='Level 1'
spells={level1.filter((s) => s.name.toLowerCase().includes(searchName.toLowerCase()))}
onClick={(spell) => {
setActiveSpell(spell);
setIsOpen(true);
}}
/>
<hr />
{activeSpell && <SpellModal spell={activeSpell} isOpen={isOpen} onClose={() => setIsOpen(false)} />}
</Box>
);

View File

@@ -50,7 +50,7 @@ export default function SpellModal({ spell, isOpen, onClose }: SpellModalProps)
<span style={{ overflowWrap: 'break-word' }}>
{spell.classes.map((c) => (
<span style={{ paddingRight: '0.6em', display: 'inline-block' }} className='link'>
{capitalize(c)}
{parseText(c, true)}
</span>
))}
</span>
@@ -86,66 +86,105 @@ export default function SpellModal({ spell, isOpen, onClose }: SpellModalProps)
);
}
function SpellDescription({ spell }: { spell: Spell }) {
function parseText(text: string) {
const regex = /{@(.*?) (.*?)}/g;
const matches = text.matchAll(regex);
const result = [];
let lastIndex = 0;
for (const match of matches) {
const [full, type, name] = match;
result.push(text.slice(lastIndex, match.index));
if (match.index !== undefined) {
function parseText(text: string, capitalizeFirst?: boolean) {
const regex = /{@(.*?) (.*?)}/g;
const matches = text.matchAll(regex);
const result = [];
let lastIndex = 0;
for (const match of matches) {
const [full, type, name] = match;
result.push(text.slice(lastIndex, match.index));
if (match.index !== undefined) {
if (type == 'dice') {
result.push(
<span onClick={() => handleLink(type, name)} className='link'>
{name}
</span>
);
lastIndex = match.index + full.length;
} else if (type == 'scaledice') {
// scaledice format is {@scaledice 1d6|1-9|1d6|}. Parse this out into dice, levels, and dice again.
const [dice, levels] = name.split('|');
result.push(
<span onClick={() => handleLink('dice', dice)} className='link'>
{dice}
</span>
);
} else if (type == 'bold') {
result.push(<span style={{ fontWeight: 'bold' }}>{name}</span>);
} else if (type == 'subclass') {
const [className, subclassName] = name.split('|');
result.push(
<span>
{capitalize(className)} ({capitalize(subclassName)})
</span>
);
} else {
result.push(<span>{capitalizeFirst ? capitalize(name) : name}</span>);
}
lastIndex = match.index + full.length;
}
result.push(text.slice(lastIndex));
return result;
}
result.push(text.slice(lastIndex));
return result;
}
function handleLink(type: string, name: string) {
if (type == 'spell') {
console.log(`Link to spell: ${name}`);
} else if (type == 'dice' || type == 'damage') {
const rolls = rollDice(name);
notifications.show({
title: `Rolling ${name}`,
message: `${rolls.join(' + ')} = ${rolls.reduce((a, b) => a + b, 0)}`,
color: 'blue',
autoClose: 5000,
withCloseButton: false
});
} else {
console.error(`Unknown link type: ${type}`);
}
function handleLink(type: string, name: string) {
if (type == 'spell') {
console.log(`Link to spell: ${name}`);
} else if (type == 'dice') {
const rolls = rollDice(name);
notifications.show({
title: `Rolling ${name}`,
message: `${rolls.join(' + ')} = ${rolls.reduce((a, b) => a + b, 0)}`,
color: 'blue',
autoClose: 5000,
withCloseButton: false
});
} else if (type == 'scaledice') {
console.log(`Link to scaledice: ${name}`);
} else {
console.error(`Unknown link type: ${type}`);
}
}
function SpellDescription({ spell }: { spell: Spell }) {
return (
<>
{spell.description && (
<>
{spell.description.entries.map((e) =>
typeof e === 'string' ? (
<p>{parseText(e)}</p>
) : (
<>
{e.type == 'list' ? (
<ul>
{e.items.map((text) => (
<li>{parseText(text)}</li>
{spell.description.entries.map((e) => (
<>
{e.text && <p>{parseText(e.text)}</p>}
{e.list && (
<ul>
{e.list.map((text) => (
<li>{parseText(text)}</li>
))}
</ul>
)}
{e.table && (
<table>
<thead>
<tr>
{e.table.headers.map((label) => (
<th>{label}</th>
))}
</tr>
</thead>
<tbody>
{e.table.rows.map((row) => (
<tr>
{row.map((cell) => (
<td>{parseText(cell)}</td>
))}
</tr>
))}
</ul>
) : (
<></>
)}
</>
)
)}
</tbody>
</table>
)}
</>
))}
</>
)}
</>

View File

@@ -3,8 +3,39 @@
import Link from 'next/link';
import { usePathname } from 'next/navigation';
import './topbar.css';
import {
Anchor,
Avatar,
Button,
Card,
Checkbox,
Container,
Grid,
Group,
Menu,
Modal,
Paper,
PasswordInput,
Text,
TextInput,
Title,
UnstyledButton
} from '@mantine/core';
import Cookies from 'js-cookie';
import { useEffect, useState } from 'react';
import { useForm } from '@mantine/form';
import { login, register, logout, me, refresh } from '@/api/auth';
import { User } from '@/api/auth.types';
import { useToggle } from '@mantine/hooks';
import { notifications } from '@mantine/notifications';
const headerItems = [
interface HeaderItem {
name: string;
link: string;
role?: string;
}
const headerItems: HeaderItem[] = [
{
name: 'Races',
link: '/races'
@@ -32,26 +63,375 @@ const headerItems = [
{
name: 'Spells',
link: '/spells'
},
{
name: 'Management',
link: '/management',
role: 'admin'
}
];
export default function Topbar() {
const pathName = usePathname();
const [modalType, toggle] = useToggle([undefined, 'login', 'register', 'reset']);
const [headers, setHeaders] = useState<HeaderItem[]>([]);
const [user, setUser] = useState<User | undefined>(undefined);
useEffect(() => {
if (Cookies.get('logged_in')) {
me().then((response) => {
if (response) {
setUser(response.user);
}
});
} else {
refresh(true).then((response) => {
if (response) {
setUser(response.user);
} else {
setUser(undefined);
}
});
}
}, [pathName]);
useEffect(() => {
const h: HeaderItem[] = [];
headerItems.forEach((item) => {
if (item.role == undefined || user?.role == item.role) {
h.push(item);
}
setHeaders(h);
});
}, [user]);
return (
<nav className='navbar'>
<div className='left'>
<Link href={'/'} className='title'>
Siren
</Link>
<div className='header-items'>
{headerItems.map((item) => (
<Link className={`header-item ${pathName == item.link && 'active'}`} href={item.link} key={item.name}>
{item.name}
</Link>
))}
<>
<nav className='navbar'>
<div className='left'>
<Link href={'/'} className='title'>
Siren
</Link>
<div className='header-items'>
{headers.map((item) => (
<Link className={`header-item ${pathName == item.link && 'active'}`} href={item.link} key={item.name}>
{item.name}
</Link>
))}
</div>
</div>
</div>
</nav>
<div className='user-section'>
{user ? (
<Menu shadow='md' width={200} openDelay={100} closeDelay={400}>
<Menu.Target>
<UnstyledButton className='user user-button'>
<Group>
<Avatar />
<div style={{ flex: 1 }}>
<Text size='sm' fw={500}>
{user.first_name} {user.last_name}
</Text>
<Text c='dimmed' size='xs'>
{user.role}
</Text>
</div>
</Group>
</UnstyledButton>
</Menu.Target>
<Menu.Dropdown>
<Card>
<Card.Section h={140} style={{}} />
<Avatar size={80} radius={80} mx={'auto'} mt={-30} />
<Text ta='center' fz='lg' fw={500} mt='sm'>
{user.first_name} {user.last_name}
</Text>
<Text ta='center' fz='sm' c='dimmed'>
{user.role}
</Text>
<Grid mt='xl'>
<Grid.Col span={6}>
<Button
fullWidth
radius='md'
size='xs'
variant='default'
onClick={() => {
toggle(undefined);
}}
>
Profile
</Button>
</Grid.Col>
<Grid.Col span={6}>
<Button
fullWidth
radius='md'
size='xs'
variant='default'
onClick={async () => {
const response = await logout();
if (response?.status == 200) {
Cookies.remove('logged_in');
setUser(undefined);
}
}}
>
Logout
</Button>
</Grid.Col>
</Grid>
</Card>
</Menu.Dropdown>
</Menu>
) : (
<Group className='user'>
<Button onClick={() => toggle('login')}>Login</Button>
<Button variant='outline' onClick={() => toggle('register')}>
Sign up
</Button>
</Group>
)}
</div>
</nav>
<LoginModal type={modalType} toggle={toggle} setUser={setUser} />
</>
);
}
interface LoginModalProps {
type?: string;
toggle: any;
setUser: (user: User) => void;
}
function LoginModal({ type, toggle, setUser }: LoginModalProps) {
function passwordValidator(value: string) {
if (value.trim().length < 10) {
return 'Password must be at least 10 characters';
}
if (value.trim().length >= 128) {
return 'Password must be at most 128 characters';
}
if (!/(\d)/.test(value)) {
return 'Password must contain at least one number';
}
if (!/[a-z]/.test(value)) {
return 'Password must contain at least one lowercase letter';
}
if (!/[A-Z]/.test(value)) {
return 'Password must contain at least one uppercase letter';
}
if (!/[!@#$%^&*]/.test(value)) {
return 'Password must contain at least one special character';
}
return null;
}
function emailValidator(value: string) {
if (value.trim().length == 0) {
return 'Email is required';
}
if (!/^\S+@\S+$/.test(value)) {
return 'Invalid email';
}
return null;
}
const registerForm = useForm({
initialValues: {
firstName: '',
lastName: '',
email: '',
password: ''
},
validate: {
firstName: (value) => (value.trim().length > 0 ? null : 'First name is required'),
lastName: (value) => (value.trim().length > 0 ? null : 'Last name is required'),
email: emailValidator,
password: passwordValidator
}
});
const loginForm = useForm({
initialValues: {
email: '',
password: '',
remember: false
}
});
const resetForm = useForm({
initialValues: {
email: ''
}
});
function onClose() {
toggle(undefined);
registerForm.reset();
resetForm.reset();
if (!loginForm.values.remember) {
loginForm.reset();
}
}
return (
<Modal opened={type !== undefined} onClose={onClose} withCloseButton={false}>
{type == 'reset' ? (
<Container>
<Title ta='center'>Reset password</Title>
<Text c='dimmed' size='sm' ta='center' mt={5}>
Enter your email and we will send you a link to reset your password.{' '}
<Anchor size='sm' component='a' onClick={() => toggle('login')}>
Go Back
</Anchor>
</Text>
<Paper withBorder shadow='md' p={30} mt={30} radius='md'>
<form onSubmit={resetForm.onSubmit(async (values) => console.log(values))}>
<TextInput label='Email' placeholder='you@example.com' required {...resetForm.getInputProps('email')} />
<Button type='submit' fullWidth mt='xl'>
Reset password
</Button>
</form>
</Paper>
</Container>
) : type == 'register' ? (
<Container>
<Title ta='center'>Create account</Title>
<Text c='dimmed' size='sm' ta='center' mt={5}>
Already have an account?{' '}
<Anchor size='sm' component='a' onClick={() => toggle('login')}>
Sign in
</Anchor>
</Text>
<Paper withBorder shadow='md' p={30} mt={30} radius='md'>
<form
onSubmit={registerForm.onSubmit(async (values) => {
const id = notifications.show({
loading: true,
title: `Creating account`,
message: `Please wait...`,
autoClose: false,
withCloseButton: false
});
const registerResponse = await register({
first_name: values.firstName,
last_name: values.lastName,
email: values.email,
password: values.password
});
if (registerResponse) {
const loginResponse = await login(values.email, values.password);
if (loginResponse) {
setUser(loginResponse.user);
onClose();
notifications.update({
id,
title: `Account created`,
message: `Welcome ${loginResponse.user.first_name}!`,
color: 'green',
autoClose: 2000,
loading: false
});
} else {
notifications.update({
id,
title: `Unable to Login`,
message: `Please try again.`,
color: 'red',
autoClose: 2000,
loading: false
});
}
} else {
notifications.update({
id,
title: `Unable to Register`,
message: `Please try again.`,
color: 'error',
autoClose: 2000,
loading: false
});
}
})}
>
<TextInput label='First name' placeholder='John' required {...registerForm.getInputProps('firstName')} />
<TextInput
label='Last name'
placeholder='Smith'
required
mt='md'
{...registerForm.getInputProps('lastName')}
/>
<TextInput
label='Email'
placeholder='you@example.com'
required
{...registerForm.getInputProps('email')}
/>
<PasswordInput
label='Password'
description='Passwords must be at least 10 characters long, contain at least one number, one uppercase letter, one lowercase letter, and one special character.'
placeholder='Your password'
required
mt='md'
{...registerForm.getInputProps('password')}
/>
<Button type='submit' fullWidth mt='xl'>
Sign up
</Button>
</form>
</Paper>
</Container>
) : (
<Container>
<Title ta='center'>Welcome back!</Title>
<Text c='dimmed' size='sm' ta='center' mt={5}>
Do not have an account yet?{' '}
<Anchor size='sm' component='a' onClick={() => toggle('register')}>
Create account
</Anchor>
</Text>
<Paper withBorder shadow='md' p={30} mt={30} radius='md'>
<form
onSubmit={loginForm.onSubmit(async (values) => {
const response = await login(values.email, values.password);
if (response) {
setUser(response.user);
onClose();
} else {
notifications.show({
title: `Unable to Login`,
message: `Please try again.`,
color: 'red',
autoClose: 2000
});
}
})}
>
<TextInput label='Email' placeholder='you@example.com' required {...loginForm.getInputProps('email')} />
<PasswordInput
label='Password'
placeholder='Your password'
required
mt='md'
{...loginForm.getInputProps('password')}
/>
<Group justify='space-between' mt='lg'>
<Checkbox label='Remember me' {...loginForm.getInputProps('remember')} />
<Anchor component='a' size='sm' onClick={() => toggle('reset')}>
Forgot password?
</Anchor>
</Group>
<Button type='submit' fullWidth mt='xl'>
Sign in
</Button>
</form>
</Paper>
</Container>
)}
</Modal>
);
}

View File

@@ -20,20 +20,14 @@
margin: auto;
}
.navbar .avatar {
padding-right: 2em;
margin-top: auto;
margin-bottom: auto;
}
.header-items {
display: flex;
justify-content: space-between;
}
.header-items .header-item {
padding-left: 2em;
padding-right: 2em;
padding-left: 2rem;
padding-right: 2rem;
margin: auto;
border-bottom: 2px solid transparent;
}
@@ -45,3 +39,23 @@
.header-items .active {
border-bottom: 2px solid #5f5f5f;
}
.user-section {
margin-left: 2rem;
margin-right: 2rem;
}
.user {
display: flex;
justify-content: space-between;
border-radius: 0.5rem;
padding: 0.5rem;
padding-left: 1rem;
padding-right: 1rem;
margin-top: 0.5rem;
margin-bottom: 0.5rem;
}
.user-button:hover {
background-color: #e6e6e6;
}

View File

@@ -18,6 +18,5 @@ export function rollDice(dice: string): number[] {
for (let i = 0; i < parseInt(count); i++) {
rolls.push(Math.floor(Math.random() * parseInt(sides)) + 1);
}
console.log(rolls);
return rolls;
}