Major refactor
This commit is contained in:
@@ -1 +1,2 @@
|
||||
idea/
|
||||
target/
|
||||
38
.env
38
.env
@@ -1,38 +0,0 @@
|
||||
RUST_LOG=warn,siren=info
|
||||
|
||||
DISCORD_TOKEN=
|
||||
DISCORD_SECRET=
|
||||
|
||||
JWT_SECRET=CHANGEME # Change this to a secure secret
|
||||
|
||||
DATABASE_USER=siren
|
||||
DATABASE_PASSWORD=CHANGEME # Change this to a secure password
|
||||
DATABASE_NAME=siren_db
|
||||
DATABASE_HOST=localhost
|
||||
DATABASE_PORT=5432
|
||||
|
||||
API_CALLBACK_URI=http://localhost:3000/api/oauth/callback
|
||||
API_PORT=3000
|
||||
API_SESSION_TTL=86400
|
||||
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
|
||||
MINIO_ROOT_USER=siren
|
||||
MINIO_ROOT_PASSWORD=CHANGEME # Change this to a secure password
|
||||
MINIO_HOST=localhost
|
||||
MINIO_PORT=9000
|
||||
MINIO_PORT_INTERNAL=9001
|
||||
|
||||
# Siren Data integration
|
||||
DATA_DIR_PATH= # Optional
|
||||
|
||||
# OpenAI
|
||||
OPENAI_TOKEN= # Optional
|
||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
OPENAI_MODEL=gpt-4o-mini
|
||||
|
||||
FORCE_REGISTER=false
|
||||
DEFAULT_API_KEY=test_api_key
|
||||
DEFAULT_SERVER=
|
||||
DEFAULT_USER=
|
||||
35
.env.example
Normal file
35
.env.example
Normal file
@@ -0,0 +1,35 @@
|
||||
RUST_LOG=warn,siren=info
|
||||
|
||||
DISCORD_BOT_TOKEN=
|
||||
DISCORD_CLIENT_SECRET=
|
||||
|
||||
JWT_SECRET=changeme
|
||||
|
||||
POSTGRES_USER=siren
|
||||
POSTGRES_PASSWORD=changeme
|
||||
POSTGRES_DB=siren_db
|
||||
POSTGRES_HOST=localhost
|
||||
POSTGRES_PORT=5432
|
||||
|
||||
API_BASE_URL=http://localhost:3000
|
||||
API_PORT=3000
|
||||
API_SESSION_TTL=86400
|
||||
|
||||
UI_PORT=8080
|
||||
|
||||
VALKEY_HOST=localhost
|
||||
VALKEY_PORT=6379
|
||||
|
||||
MINIO_ROOT_USER=siren
|
||||
MINIO_ROOT_PASSWORD=changeme
|
||||
MINIO_HOST=localhost
|
||||
MINIO_PORT=9000
|
||||
MINIO_PORT_INTERNAL=9001
|
||||
|
||||
# Siren Data integration (Optional)
|
||||
DATA_DIR_PATH=./data
|
||||
|
||||
FORCE_REGISTER=false
|
||||
DEFAULT_API_KEY=test_api_key
|
||||
DEFAULT_SERVER=
|
||||
DEFAULT_USER=
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -1,10 +1,13 @@
|
||||
# Build
|
||||
target/
|
||||
**/target/
|
||||
**/Cargo.lock
|
||||
**/node_modules/
|
||||
**/package-lock.json
|
||||
|
||||
logs/
|
||||
data/
|
||||
settings.json
|
||||
.env*.local
|
||||
.env
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
|
||||
6
.vscode/extensions.json
vendored
6
.vscode/extensions.json
vendored
@@ -1,6 +0,0 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"rust-lang.rust-analyzer",
|
||||
"ms-vscode.makefile-tools"
|
||||
]
|
||||
}
|
||||
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@@ -1,3 +0,0 @@
|
||||
{
|
||||
"makefile.configureOnOpen": false
|
||||
}
|
||||
88
Cargo.toml
88
Cargo.toml
@@ -1,32 +1,66 @@
|
||||
[package]
|
||||
name = "siren"
|
||||
version = "0.2.10"
|
||||
edition = "2021"
|
||||
[workspace]
|
||||
members = [
|
||||
"crates/siren",
|
||||
"crates/siren-core",
|
||||
"crates/siren-bot",
|
||||
"crates/siren-api",
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
|
||||
edition = "2024"
|
||||
version = "0.3.0"
|
||||
rust-version = "1.94"
|
||||
authors = ["Ben Sherriff <ben@bensherriff.com>"]
|
||||
description = "A Discord bot for playing music"
|
||||
repository = "https://github.com/bensherriff/siren"
|
||||
readme = "README.md"
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[dependencies]
|
||||
dotenv = "0.15.0"
|
||||
log = "0.4.22"
|
||||
env_logger = "0.11.5"
|
||||
serde = { version = "1.0.210", features = ["derive"] }
|
||||
serde_json = "1.0.128"
|
||||
serenity = { version = "0.12.2", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "voice", "cache", "framework", "standard_framework"] }
|
||||
songbird = { version = "0.4.6", features = ["builtin-queue"] }
|
||||
symphonia = { version = "0.5.4", features = ["all"] }
|
||||
sqlx = { version = "0.8.2", features = ["runtime-tokio", "postgres", "chrono", "uuid"] }
|
||||
chrono = { version = "0.4.38", features = ["serde"] }
|
||||
reqwest = { version = "0.11", default-features = false, features = ["json"] }
|
||||
uuid = { version = "1.11.0", features = ["serde", "v4"] }
|
||||
redis = { version = "0.27.4", features = ["tokio-comp", "connection-manager", "r2d2"] }
|
||||
rand = "0.8.5"
|
||||
rand_chacha = "0.3.1"
|
||||
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "signal"] }
|
||||
regex = "1.11.0"
|
||||
axum = { version = "0.7.7", features = ["json"] }
|
||||
axum-extra = { version = "0.9.6", features = ["typed-header"] }
|
||||
lazy_static = "1.5.0"
|
||||
jsonwebtoken = "9.3.0"
|
||||
[workspace.dependencies]
|
||||
# Internal crates
|
||||
siren-core = { path = "crates/siren-core" }
|
||||
siren-bot = { path = "crates/siren-bot" }
|
||||
siren-api = { path = "crates/siren-api" }
|
||||
|
||||
# Async runtime
|
||||
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
|
||||
|
||||
# Logging
|
||||
log = "0.4"
|
||||
env_logger = "0.11"
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
# Discord / Audio
|
||||
serenity = { version = "0.12", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "voice", "cache", "framework", "standard_framework"] }
|
||||
#songbird = { version = "0.5", features = ["builtin-queue"] }
|
||||
# Temporary until DAVE encryption release https://github.com/serenity-rs/songbird/issues/293
|
||||
songbird = { git = "https://github.com/serenity-rs/songbird.git", branch = "next", features = ["builtin-queue"] }
|
||||
symphonia = { version = "0.5", features = ["all"] }
|
||||
|
||||
# HTTP
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json"] }
|
||||
|
||||
# Database
|
||||
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "chrono", "uuid"] }
|
||||
redis = { version = "1", features = ["tokio-comp", "connection-manager", "r2d2"] }
|
||||
|
||||
# Utilities
|
||||
dotenv = "0.15"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
rand = "0.10"
|
||||
rand_chacha = "0.10"
|
||||
regex = "1"
|
||||
lazy_static = "1"
|
||||
|
||||
# API
|
||||
axum = { version = "0.8", features = ["json", "ws", "macros"] }
|
||||
axum-extra = { version = "0.12", features = ["typed-header"] }
|
||||
jsonwebtoken = { version = "10", features = ["rust_crypto"] }
|
||||
tower-http = { version = "0.6", features = ["fs", "cors"] }
|
||||
dashmap = "6"
|
||||
futures-util = "0.3"
|
||||
|
||||
20
Dockerfile
20
Dockerfile
@@ -1,15 +1,15 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
# =========
|
||||
# Builder
|
||||
# =========
|
||||
FROM rust:bookworm AS builder
|
||||
WORKDIR /builder
|
||||
|
||||
COPY migrations ./migrations
|
||||
COPY src ./src
|
||||
COPY Cargo.toml ./
|
||||
|
||||
RUN apt-get update && apt-get install -y cmake
|
||||
RUN cargo build --release
|
||||
FROM rust:1.94-slim-bookworm AS builder
|
||||
COPY . .
|
||||
RUN --mount=type=cache,target=/usr/local/cargo/registry,sharing=locked \
|
||||
--mount=type=cache,target=/usr/local/cargo/git,sharing=locked \
|
||||
--mount=type=cache,target=/target,sharing=locked \
|
||||
cargo build --release --bin siren && \
|
||||
cp /target/release/siren /siren
|
||||
|
||||
# ==========
|
||||
# Packages
|
||||
@@ -40,7 +40,7 @@ FROM debian:bookworm-slim AS runtime
|
||||
WORKDIR /siren
|
||||
USER root
|
||||
|
||||
COPY --from=builder /builder/target/release/siren /usr/local/bin/siren
|
||||
COPY --from=builder /siren /usr/local/bin/siren
|
||||
COPY --from=packages /packages /usr/bin
|
||||
|
||||
RUN apt-get update && apt-get install -y libc6 libc6-dev libopus-dev libpq5 libpq-dev python3-pip ffmpeg
|
||||
|
||||
674
LICENSE
674
LICENSE
@@ -1,674 +0,0 @@
|
||||
GNU GENERAL PUBLIC LICENSE
|
||||
Version 3, 29 June 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU General Public License is a free, copyleft license for
|
||||
software and other kinds of works.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
the GNU General Public License is intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users. We, the Free Software Foundation, use the
|
||||
GNU General Public License for most of our software; it applies also to
|
||||
any other work released this way by its authors. You can apply it to
|
||||
your programs, too.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
To protect your rights, we need to prevent others from denying you
|
||||
these rights or asking you to surrender the rights. Therefore, you have
|
||||
certain responsibilities if you distribute copies of the software, or if
|
||||
you modify it: responsibilities to respect the freedom of others.
|
||||
|
||||
For example, if you distribute copies of such a program, whether
|
||||
gratis or for a fee, you must pass on to the recipients the same
|
||||
freedoms that you received. You must make sure that they, too, receive
|
||||
or can get the source code. And you must show them these terms so they
|
||||
know their rights.
|
||||
|
||||
Developers that use the GNU GPL protect your rights with two steps:
|
||||
(1) assert copyright on the software, and (2) offer you this License
|
||||
giving you legal permission to copy, distribute and/or modify it.
|
||||
|
||||
For the developers' and authors' protection, the GPL clearly explains
|
||||
that there is no warranty for this free software. For both users' and
|
||||
authors' sake, the GPL requires that modified versions be marked as
|
||||
changed, so that their problems will not be attributed erroneously to
|
||||
authors of previous versions.
|
||||
|
||||
Some devices are designed to deny users access to install or run
|
||||
modified versions of the software inside them, although the manufacturer
|
||||
can do so. This is fundamentally incompatible with the aim of
|
||||
protecting users' freedom to change the software. The systematic
|
||||
pattern of such abuse occurs in the area of products for individuals to
|
||||
use, which is precisely where it is most unacceptable. Therefore, we
|
||||
have designed this version of the GPL to prohibit the practice for those
|
||||
products. If such problems arise substantially in other domains, we
|
||||
stand ready to extend this provision to those domains in future versions
|
||||
of the GPL, as needed to protect the freedom of users.
|
||||
|
||||
Finally, every program is threatened constantly by software patents.
|
||||
States should not allow patents to restrict development and use of
|
||||
software on general-purpose computers, but in those that do, we wish to
|
||||
avoid the special danger that patents applied to a free program could
|
||||
make it effectively proprietary. To prevent this, the GPL assures that
|
||||
patents cannot be used to render the program non-free.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Use with the GNU Affero General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU Affero General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the special requirements of the GNU Affero General Public License,
|
||||
section 13, concerning interaction through a network will apply to the
|
||||
combination as such.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU General Public License from time to time. Such new versions will
|
||||
be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If the program does terminal interaction, make it output a short
|
||||
notice like this when it starts in an interactive mode:
|
||||
|
||||
<program> Copyright (C) <year> <name of author>
|
||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||
This is free software, and you are welcome to redistribute it
|
||||
under certain conditions; type `show c' for details.
|
||||
|
||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||
parts of the General Public License. Of course, your program's commands
|
||||
might be different; for a GUI interface, you would use an "about box".
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU GPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
The GNU General Public License does not permit incorporating your program
|
||||
into proprietary programs. If your program is a subroutine library, you
|
||||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
59
Makefile
59
Makefile
@@ -1,59 +0,0 @@
|
||||
#!make
|
||||
SHELL := /bin/bash
|
||||
ENV := ./scripts/apply_env.sh
|
||||
|
||||
.PHONY: help
|
||||
|
||||
export VERSION=$(if $(v),$(v),latest)
|
||||
|
||||
help: ## Help command
|
||||
@echo
|
||||
@cat Makefile | grep -E '^[a-zA-Z\/_-]+:.*?## .*$$' | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
@echo
|
||||
|
||||
backend-up: ## Start the backend containers
|
||||
@$(ENV) docker compose --profile backend up -d
|
||||
|
||||
up-backend: backend-up
|
||||
|
||||
backend-down: ## Stop the backend containers
|
||||
@$(ENV) docker compose --profile backend down
|
||||
|
||||
down-backend: backend-down
|
||||
|
||||
run: ## Run the project
|
||||
@echo "Running project..."
|
||||
@$(ENV) cargo run
|
||||
@echo "Run complete"
|
||||
|
||||
format: ## Format code
|
||||
@echo "Formatting code..."
|
||||
@$(ENV) cargo fmt
|
||||
@echo "Format complete"
|
||||
|
||||
clean: ## Clean the project
|
||||
@echo "Cleaning project..."
|
||||
@$(ENV) cargo clean
|
||||
@echo "Clean complete"
|
||||
|
||||
docker-up: ## Start the app
|
||||
@$(ENV) docker compose --profile backend --profile bot up -d
|
||||
|
||||
docker-down: ## Stop the app
|
||||
@$(ENV) docker compose --profile backend --profile bot down
|
||||
|
||||
docker-build: ## Build the docker image
|
||||
@$(ENV) docker build -f Dockerfile -t siren:${VERSION} .
|
||||
|
||||
docker-clean: ## Stop the docker containers and remove volumes
|
||||
@echo "Stopping docker container and removing volumes..."
|
||||
@$(ENV) docker compose --profile backend --profile bot down -v
|
||||
@echo "Docker container stopped and volumes removed"
|
||||
|
||||
docker-refresh: docker-clean backend-up ## Refresh the docker containers
|
||||
|
||||
psql: ## Connect to the database
|
||||
@$(ENV) docker exec -it siren-postgres psql -U ${DATABASE_USER} -P pager=off
|
||||
|
||||
insert-api: ## Insert test API key into the database
|
||||
@$(ENV) ./scripts/insert_api_key.sh
|
||||
285
README.md
285
README.md
@@ -1,130 +1,205 @@
|
||||
<div align="center">
|
||||
<img src="docs/siren.png" alt="drawing" width="200"/>
|
||||
<img src="docs/siren.png" alt="Siren" width="200"/>
|
||||
<h1 align="center">Siren</h1>
|
||||
<p>A D&D-focused Discord bot written in Rust</p>
|
||||
</div>
|
||||
|
||||
Siren is a D&D Bot built for Discord, written in Rust. Features include:
|
||||
- Music commands from Youtube and locally hosted files
|
||||
- Database for D&D 5e content
|
||||
---
|
||||
|
||||
## Features
|
||||
|
||||
- Music playback from YouTube via slash commands
|
||||
- Dice rolling with D&D notation (e.g. `/roll 2d6+3`)
|
||||
- Session scheduling
|
||||
- Backend API
|
||||
- ChatGPT integration
|
||||
- REST API with OAuth2 authentication
|
||||
|
||||
## Requirements
|
||||
- [Docker](https://www.docker.com/)
|
||||
- **Optional**: [Docker Compose](https://docs.docker.com/compose/install/)
|
||||
---
|
||||
|
||||
## Running
|
||||
1. Setup the Discord Developer Application and bot
|
||||
2. Create `.env.local` and override any variables from `.env`
|
||||
- At minimum, `DISCORD_TOKEN` must be set. See [instructions](#setup-discord-developer-application) for additional steps.
|
||||
3. Build the Docker application with `make build`
|
||||
4. Start the application with `make up`
|
||||
## Prerequisites
|
||||
|
||||
<h3 id='setup-discord-developer-application'>Setting up the Discord Developer Application</h3>
|
||||
### Running with Docker (recommended)
|
||||
|
||||
Visit the [Discord Developer Portal](https://discord.com/developers/applications) and create a new application. Click [here](https://discord.com/developers/docs/intro) for guides and more information.
|
||||
| Tool | Notes |
|
||||
|-----------------------------------|---------------------------------------------|
|
||||
| [Docker](https://www.docker.com/) | Required |
|
||||
| [Task](https://taskfile.dev/) | Optional — used to run convenience commands |
|
||||
|
||||
#### Oauth2
|
||||
**Required Scopes**:
|
||||
- bot
|
||||
- applications.commands
|
||||
### Running locally (development)
|
||||
|
||||
**Required Bot Permissions**:
|
||||
- General Permissions
|
||||
- Manage Roles
|
||||
- Change Nickname
|
||||
- View Channels
|
||||
- Manage Events
|
||||
- Create Events
|
||||
- Text Permissions
|
||||
- Send Messages
|
||||
- Create Public Threads
|
||||
- Create Private Threads
|
||||
- Send Messages in Threads
|
||||
- Manage Messages
|
||||
- Manage Threads
|
||||
- Embed Links
|
||||
- Attach Files
|
||||
- Read Message History
|
||||
- Mention Everyone
|
||||
- Use External Emojis
|
||||
- Use External Stickers
|
||||
- Add Reactions
|
||||
- Create Polls
|
||||
- Voice Permissions
|
||||
- Connect
|
||||
- Speak
|
||||
| Tool | Notes |
|
||||
|---------------------------------------------------|---------------------------------------------|
|
||||
| [Rust](https://www.rust-lang.org/tools/install) | Stable toolchain |
|
||||
| [yt-dlp](https://github.com/yt-dlp/yt-dlp) | Audio source extraction |
|
||||
| [ffmpeg](https://github.com/yt-dlp/FFmpeg-Builds) | Audio transcoding |
|
||||
| [Docker](https://www.docker.com/) | Used to run PostgreSQL and Valkey |
|
||||
| [Task](https://taskfile.dev/) | Optional — used to run convenience commands |
|
||||
|
||||
Example Invites:
|
||||
```
|
||||
https://discord.com/api/oauth2/authorize?client_id=<CLIENT_ID>&permissions=40671259392832&scope=bot%20applications.commands
|
||||
> **yt-dlp note:** Keep yt-dlp up to date. YouTube frequently rotates its player,
|
||||
> and an outdated yt-dlp will fail to resolve stream URLs.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
Copy `.env.example` to `.env` and fill in the required values:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# or using Task:
|
||||
task setup
|
||||
```
|
||||
|
||||
### Environment variables
|
||||
|
||||
| Variable | Required | Description |
|
||||
|-------------------------|----------|-------------------------------------------------------------------------|
|
||||
| `DISCORD_BOT_TOKEN` | Yes | Bot token from the Discord Developer Portal |
|
||||
| `DISCORD_CLIENT_SECRET` | Yes | OAuth2 client secret |
|
||||
| `JWT_SECRET` | Yes | Secret used to sign JWT tokens — change from default |
|
||||
| `POSTGRES_USER` | Yes | PostgreSQL username |
|
||||
| `POSTGRES_PASSWORD` | Yes | PostgreSQL password — change from default |
|
||||
| `POSTGRES_DB` | Yes | PostgreSQL database name |
|
||||
| `POSTGRES_HOST` | Yes | PostgreSQL host (`localhost` for local dev, `siren-postgres` in Docker) |
|
||||
| `POSTGRES_PORT` | Yes | PostgreSQL port (default `5432`) |
|
||||
| `VALKEY_HOST` | Yes | Valkey host (`localhost` for local dev, `siren-valkey` in Docker) |
|
||||
| `VALKEY_PORT` | Yes | Valkey port (default `6379`) |
|
||||
| `API_PORT` | Yes | Port the REST API listens on (default `3000`) |
|
||||
| `API_CALLBACK_URI` | Yes | OAuth2 redirect URI (e.g. `http://localhost:3000/api/oauth/callback`) |
|
||||
| `API_SESSION_TTL` | | OAuth2 session TTL in seconds (default `86400`) |
|
||||
| `RUST_LOG` | | Log filter (e.g. `warn,siren=info`) |
|
||||
| `FORCE_REGISTER` | | Re-register slash commands on every startup (`true`/`false`) |
|
||||
| `DATA_DIR_PATH` | | Path to optional local data directory |
|
||||
| `DEFAULT_API_KEY` | | Seed API key created on startup |
|
||||
| `DEFAULT_SERVER` | | Seed guild ID |
|
||||
| `DEFAULT_USER` | | Seed user ID |
|
||||
|
||||
---
|
||||
|
||||
## Discord Application Setup
|
||||
|
||||
1. Visit the [Discord Developer Portal](https://discord.com/developers/applications) and create a new application.
|
||||
2. Go to the **Bot** tab:
|
||||
- Click **Reset Token** to generate your bot token — this is your `DISCORD_BOT_TOKEN`.
|
||||
- Enable **Message Content Intent** under Privileged Gateway Intents.
|
||||
|
||||

|
||||
|
||||
3. Go to the **OAuth2** tab to find your **Client Secret** (`DISCORD_CLIENT_SECRET`).
|
||||
|
||||
### Invite the bot to your server
|
||||
|
||||
**Required scopes:** `bot`, `applications.commands`
|
||||
|
||||
**Required permissions:**
|
||||
|
||||
| Category | Permissions |
|
||||
|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| General | Manage Roles, Change Nickname, View Channels, Manage Events, Create Events |
|
||||
| Text | Send Messages, Create Public/Private Threads, Send Messages in Threads, Manage Messages, Manage Threads, Embed Links, Attach Files, Read Message History, Mention Everyone, Use External Emojis/Stickers, Add Reactions, Create Polls |
|
||||
| Voice | Connect, Speak |
|
||||
|
||||
Use this invite URL (replace `<CLIENT_ID>` with your Application ID from the General Information tab):
|
||||
|
||||
```
|
||||
https://discord.com/oauth2/authorize?client_id=<CLIENT_ID>&permissions=581083641408576&integration_type=0&scope=bot+applications.commands
|
||||
```
|
||||
|
||||
The CLIENT_ID can be found in the General Information tab on the Discord Developer Portal for your application, under `Application ID`
|
||||
---
|
||||
|
||||
The DISCORD_TOKEN (used in the `.env file`) can be found under the Bot tab on the Discord Developer Portal for your application.
|
||||
## Running
|
||||
|
||||

|
||||
### Docker (recommended)
|
||||
|
||||
### Commands
|
||||
Siren utilizes Discord slash commands. To view the commands, run `/help` in a server where the bot is installed. The following commands are available:
|
||||
Build and start all containers:
|
||||
|
||||
**Music Commands**
|
||||
| Command | Description |
|
||||
| --- | --- |
|
||||
| `/play <Track>` | Play a track from Youtube or locally hosted files |
|
||||
| `/pause` | Pause the current track |
|
||||
| `/resume` | Resume the current track |
|
||||
| `/skip` | Skip the current track |
|
||||
| `/stop` | Stop the current track |
|
||||
| `/mute` | Mute the current track |
|
||||
| `/queue` | ***TODO*** - Display the current queue |
|
||||
| `/clear` | ***TODO*** - Clear the current queue |
|
||||
| `/shuffle` | ***TODO*** - Shuffle the current queue |
|
||||
| `/loop` | ***TODO*** - Loop or unloop the current track |
|
||||
| `/nowplaying` | ***TODO*** - Display the current track |
|
||||
| `/volume <Volume>` | Set the volume of the bot |
|
||||
```bash
|
||||
# Build the Docker image
|
||||
task docker:build
|
||||
# or: docker build -f Dockerfile -t siren:latest .
|
||||
|
||||
**Event Commands**
|
||||
| Command | Description |
|
||||
| --- | --- |
|
||||
| `/schedule` | ***TODO*** - Schedule a new event |
|
||||
| `/events` | ***TODO*** - Display all events |
|
||||
| `/event <Event ID>` | ***TODO*** - Display a specific event |
|
||||
| `/deleteevent <Event ID>` | ***TODO*** - Delete a specific event |
|
||||
| `/updateevent <Event ID>` | ***TODO*** - Update a specific event |
|
||||
| `/remindme <Event ID>` | ***TODO*** - Set a reminder for a specific event |
|
||||
|
||||
**Fun Commands**
|
||||
| Command | Description |
|
||||
| --- | --- |
|
||||
| `/coinflip` | Flip a coin |
|
||||
| `/roll <Dice>` | Roll a dice |
|
||||
| `/requestroll <User> <Dice>` | Request a dice roll from a user |
|
||||
|
||||
**Utility Commands**
|
||||
| Command | Description |
|
||||
| --- | --- |
|
||||
| `/ping` | Display the bot's latency |
|
||||
| `/poll` | ***TODO*** - Create a poll |
|
||||
| `/help` | ***TODO*** - Display a list of commands |
|
||||
|
||||
## Contributing
|
||||
- [Rust](https://www.rust-lang.org/)
|
||||
- [yt-dlp](https://github.com/yt-dlp/yt-dlp)
|
||||
- [ffmpeg](https://github.com/yt-dlp/FFmpeg-Builds)
|
||||
|
||||
### Running Locally
|
||||
1. Start the backend containers with `make backend-up`
|
||||
2. Run the application locally with `make run`
|
||||
|
||||
The application can also be tested from within a Docker container:
|
||||
# Start all services (postgres, valkey, and the app)
|
||||
task docker:up:all
|
||||
# or: docker compose --profile app up -d
|
||||
```
|
||||
make docker-build
|
||||
make docker-up
|
||||
|
||||
To stop everything:
|
||||
|
||||
```bash
|
||||
task docker:down
|
||||
# or: docker compose --profile app down
|
||||
```
|
||||
|
||||
### Local development
|
||||
|
||||
Start the backing services (PostgreSQL and Valkey) in Docker, then run the bot natively:
|
||||
|
||||
```bash
|
||||
# Start only the infrastructure containers
|
||||
task docker:up
|
||||
# or: docker compose up -d
|
||||
|
||||
# Run the bot locally with trace-level logging
|
||||
task run
|
||||
# or: cargo run
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Commands
|
||||
|
||||
Siren uses Discord slash commands.
|
||||
|
||||
### Music
|
||||
|
||||
| Command | Description | Status |
|
||||
|-------------------|-------------------------------------|---------|
|
||||
| `/play <track>` | Play a track from YouTube | Done |
|
||||
| `/pause` | Pause the current track | Done |
|
||||
| `/resume` | Resume the current track | Done |
|
||||
| `/skip` | Skip to the next track | Done |
|
||||
| `/stop` | Stop playback and clear the queue | Done |
|
||||
| `/mute` | Mute/unmute the bot | Done |
|
||||
| `/volume <0–100>` | Set the playback volume | Done |
|
||||
| `/queue` | Display the current queue | Planned |
|
||||
| `/nowplaying` | Display the currently playing track | Planned |
|
||||
| `/shuffle` | Shuffle the queue | Planned |
|
||||
| `/loop` | Toggle looping the current track | Planned |
|
||||
| `/clear` | Clear the queue | Planned |
|
||||
|
||||
### Events
|
||||
|
||||
| Command | Description | Status |
|
||||
|-------------|------------------------------|---------|
|
||||
| `/schedule` | Schedule a new event | Planned |
|
||||
| `/events` | Display all scheduled events | Planned |
|
||||
|
||||
### Fun
|
||||
|
||||
| Command | Description | Status |
|
||||
|------------------------------|---------------------------------------|--------|
|
||||
| `/roll <dice>` | Roll dice (e.g. `2d6+3`) | Done |
|
||||
| `/requestroll <user> <dice>` | Request a dice roll from another user | Done |
|
||||
|
||||
### Utility
|
||||
|
||||
| Command | Description | Status |
|
||||
|---------|----------------------------|---------|
|
||||
| `/ping` | Check the bot's latency | Done |
|
||||
| `/help` | Display available commands | Planned |
|
||||
|
||||
---
|
||||
|
||||
## Development
|
||||
|
||||
| Task | Command |
|
||||
|------------------------|----------------|
|
||||
| Type-check (fast) | `task check` |
|
||||
| Debug build | `task build` |
|
||||
| Release build | `task release` |
|
||||
| Run with trace logging | `task run` |
|
||||
| Format code | `task format` |
|
||||
| Lint (Clippy) | `task lint` |
|
||||
| Clean build artifacts | `task clean` |
|
||||
| Connect to database | `task psql` |
|
||||
|
||||
Run `task` with no arguments to list all available tasks.
|
||||
|
||||
155
Taskfile.yml
Normal file
155
Taskfile.yml
Normal file
@@ -0,0 +1,155 @@
|
||||
version: '3'
|
||||
|
||||
dotenv: [".env", ".env.example"]
|
||||
|
||||
vars:
|
||||
VERSION: '{{.v | default "latest"}}'
|
||||
RUST_LOG: "warn,siren=info"
|
||||
|
||||
tasks:
|
||||
default:
|
||||
desc: List available tasks
|
||||
cmds:
|
||||
- task --list
|
||||
silent: true
|
||||
|
||||
setup:
|
||||
desc: Copy .env.example to .env if .env does not exist
|
||||
cmds:
|
||||
- test -f .env || cp .env.example .env
|
||||
silent: true
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Cargo
|
||||
# -----------------------------------------------------------
|
||||
build:
|
||||
desc: "Compile a debug build"
|
||||
deps: [ setup ]
|
||||
cmds:
|
||||
- cargo build
|
||||
silent: true
|
||||
|
||||
release:
|
||||
desc: "Compile an optimised release build"
|
||||
deps: [ setup ]
|
||||
cmds:
|
||||
- cargo build --release
|
||||
silent: true
|
||||
|
||||
run:
|
||||
desc: "Run the project"
|
||||
deps: [ setup ]
|
||||
env:
|
||||
RUST_LOG: "warn,siren=trace"
|
||||
cmds:
|
||||
- cargo run
|
||||
silent: true
|
||||
|
||||
format:
|
||||
desc: "Format code"
|
||||
cmds:
|
||||
- cargo fmt
|
||||
silent: true
|
||||
|
||||
clean:
|
||||
desc: "Clean the project"
|
||||
deps: [ setup ]
|
||||
cmds:
|
||||
- cargo clean
|
||||
silent: true
|
||||
|
||||
lint:
|
||||
desc: "Run Clippy linter"
|
||||
deps: [ setup ]
|
||||
cmds:
|
||||
- cargo clippy -- -D warnings
|
||||
silent: true
|
||||
|
||||
check:
|
||||
desc: "Fast type-check without producing a binary"
|
||||
deps: [ setup ]
|
||||
cmds:
|
||||
- cargo check
|
||||
silent: true
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Docker
|
||||
# -----------------------------------------------------------
|
||||
docker:build:
|
||||
desc: "Build the Docker image (use v=x.x.x to set version, default is \"latest\")"
|
||||
cmds:
|
||||
- docker build -f Dockerfile -t siren:{{.VERSION}} .
|
||||
silent: true
|
||||
|
||||
docker:up:
|
||||
desc: "Start backend containers"
|
||||
cmds:
|
||||
- docker compose up -d
|
||||
silent: true
|
||||
|
||||
docker:up:all:
|
||||
desc: "Start all containers"
|
||||
cmds:
|
||||
- docker compose --profile app up -d
|
||||
silent: true
|
||||
|
||||
docker:down:
|
||||
desc: "Stop all containers"
|
||||
cmds:
|
||||
- docker compose --profile app down
|
||||
silent: true
|
||||
|
||||
docker:clean:
|
||||
desc: "Stop all containers and remove volumes"
|
||||
prompt: "This will remove all docker containers, networks, volumes, and images. Are you sure?"
|
||||
cmds:
|
||||
- docker compose --profile app down -v
|
||||
silent: true
|
||||
|
||||
docker:refresh:
|
||||
desc: "Clean and restart containers"
|
||||
cmds:
|
||||
- task: docker:clean
|
||||
- task: docker:up
|
||||
silent: true
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# UI
|
||||
# -----------------------------------------------------------
|
||||
ui:install:
|
||||
desc: "Install UI npm dependencies"
|
||||
dir: ui
|
||||
cmds:
|
||||
- npm install
|
||||
silent: true
|
||||
|
||||
ui:run:
|
||||
desc: "Run Vite dev server"
|
||||
dir: ui
|
||||
cmds:
|
||||
- npm run dev
|
||||
silent: true
|
||||
|
||||
ui:build:
|
||||
desc: "Build the React UI into ui/dist"
|
||||
dir: ui
|
||||
cmds:
|
||||
- npm run build
|
||||
silent: true
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Utilities
|
||||
# -----------------------------------------------------------
|
||||
psql:
|
||||
desc: Connect to the database
|
||||
cmds:
|
||||
- docker exec -it siren-postgres psql -U $DATABASE_USER -P pager=off
|
||||
silent: true
|
||||
|
||||
ngrok:
|
||||
desc: Start ngrok tunnel
|
||||
vars:
|
||||
UI_PORT: '{{.UI_PORT | default "8080"}}'
|
||||
cmds:
|
||||
- ngrok http {{.UI_PORT}}
|
||||
silent: true
|
||||
@@ -1,15 +0,0 @@
|
||||
meta {
|
||||
name: Create API Key
|
||||
type: http
|
||||
seq: 2
|
||||
}
|
||||
|
||||
post {
|
||||
url: {{baseUrl}}/api-key
|
||||
body: none
|
||||
auth: bearer
|
||||
}
|
||||
|
||||
auth:bearer {
|
||||
token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjI1MDg0MjI2MTIyMTI3NzY5NywibmFtZSI6ImJzaGVycmlmZiIsImlhdCI6MTczNDkwMjgzOSwiZXhwIjoxNzM0OTg5MjM5LCJqdGkiOiJWTlFjeXpBN25sZEt1SWtzcDFzc1pRNHNacUZ2dWZPZCJ9.JnO-Rklv9YZKWjRvehR4-tfP1dlO5vIEWpSh_W4xZWY
|
||||
}
|
||||
28
crates/siren-api/Cargo.toml
Normal file
28
crates/siren-api/Cargo.toml
Normal file
@@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "siren-api"
|
||||
edition.workspace = true
|
||||
version.workspace = true
|
||||
rust-version.workspace = true
|
||||
authors.workspace = true
|
||||
|
||||
[dependencies]
|
||||
siren-core = { workspace = true }
|
||||
siren-bot = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
log = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
axum-extra = { workspace = true }
|
||||
serenity = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rand_chacha = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
53
crates/siren-api/src/app.rs
Normal file
53
crates/siren-api/src/app.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use crate::{AppState, error::Result};
|
||||
use axum::Router;
|
||||
use std::{env, sync::Arc};
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::{
|
||||
cors::{Any, CorsLayer},
|
||||
services::{ServeDir, ServeFile},
|
||||
};
|
||||
|
||||
pub struct App {
|
||||
app_state: AppState,
|
||||
}
|
||||
|
||||
impl App {
|
||||
pub fn new(app_state: AppState) -> Self {
|
||||
Self { app_state }
|
||||
}
|
||||
|
||||
pub async fn serve(self) -> Result<()> {
|
||||
log::debug!("Starting API...");
|
||||
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
// Serve the built React frontend from frontend/dist (relative to the
|
||||
// working directory). Falls back gracefully if the directory does not
|
||||
// exist yet (e.g. during development when using `npm run dev`).
|
||||
let frontend_dir = env::current_dir()
|
||||
.unwrap_or_default()
|
||||
.join("frontend")
|
||||
.join("dist");
|
||||
|
||||
// For SPA routing: any path not matched by a real file (e.g. /map/<id>)
|
||||
// falls back to index.html so React can handle client-side routing.
|
||||
let index_html = frontend_dir.join("index.html");
|
||||
let serve_dir = ServeDir::new(&frontend_dir).not_found_service(ServeFile::new(index_html));
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/api", crate::get_routes())
|
||||
.fallback_service(serve_dir)
|
||||
.layer(cors)
|
||||
.with_state(Arc::new(self.app_state));
|
||||
|
||||
let api_port: String = env::var("API_PORT").expect("Expected a port in the environment");
|
||||
let addr = format!("0.0.0.0:{}", api_port);
|
||||
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
log::info!("API is listening on {}", &addr);
|
||||
Ok(axum::serve(listener, app).await?)
|
||||
}
|
||||
}
|
||||
23
crates/siren-api/src/app_state.rs
Normal file
23
crates/siren-api/src/app_state.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use dashmap::DashMap;
|
||||
use serenity::{
|
||||
all::{Cache, Http},
|
||||
prelude::Mutex,
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub client: reqwest::Client,
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub base_url: String,
|
||||
/// Maps oauth_state → ui_redirect_uri.
|
||||
/// Populated on /authorize, consumed on /callback.
|
||||
pub discord_authorize_cache: Arc<Mutex<HashMap<String, String>>>,
|
||||
pub http: Arc<Http>,
|
||||
pub cache: Arc<Cache>,
|
||||
/// Per-map WebSocket broadcast channels for real-time collaboration.
|
||||
/// Key is the CSPRNG map ID (TEXT).
|
||||
pub map_rooms: Arc<DashMap<String, broadcast::Sender<String>>>,
|
||||
}
|
||||
@@ -1,17 +1,27 @@
|
||||
use std::sync::Arc;
|
||||
use axum::extract::{Path, State};
|
||||
use axum::middleware::from_extractor;
|
||||
use axum::{Extension, Json, Router};
|
||||
use axum::routing::post;
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{AuthorizationMiddleware, Session},
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
Extension,
|
||||
Json,
|
||||
Router,
|
||||
extract::{Path, State},
|
||||
middleware::from_extractor,
|
||||
routing::post,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use crate::api::auth::{AuthCredential, AuthorizationMiddleware};
|
||||
use crate::AppState;
|
||||
use crate::bot::commands::audio::join_voice_channel;
|
||||
use crate::bot::commands::audio::pause::pause_track;
|
||||
use crate::bot::commands::audio::play::enqueue_track;
|
||||
use crate::bot::commands::audio::resume::resume_track;
|
||||
use crate::bot::handler::get_songbird;
|
||||
use crate::error::{Error, SirenResult};
|
||||
use siren_bot::{
|
||||
commands::audio::{
|
||||
join_voice_channel,
|
||||
pause::pause_track,
|
||||
play::enqueue_track,
|
||||
resume::resume_track,
|
||||
},
|
||||
handler::get_songbird,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
@@ -29,15 +39,15 @@ struct PlayTrackRequest {
|
||||
}
|
||||
|
||||
async fn play_audio(
|
||||
Extension(credential): Extension<AuthCredential>,
|
||||
Extension(session): Extension<Session>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
Json(payload): Json<PlayTrackRequest>,
|
||||
) -> SirenResult<()> {
|
||||
) -> Result<()> {
|
||||
log::debug!("Playing audio in guild: {}", guild_id);
|
||||
|
||||
// Check if the user exists in the cache
|
||||
let user_id = credential.user_id();
|
||||
let user_id = session.user_id;
|
||||
let user_id = match state.cache.user(user_id) {
|
||||
Some(user) => user.id,
|
||||
None => return Err(Error::not_found("User not found".to_string())),
|
||||
@@ -57,10 +67,10 @@ async fn play_audio(
|
||||
}
|
||||
|
||||
async fn pause_audio(
|
||||
Extension(_): Extension<AuthCredential>,
|
||||
Extension(_): Extension<Session>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> SirenResult<()> {
|
||||
) -> Result<()> {
|
||||
log::debug!("Pausing audio in guild: {}", guild_id);
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
@@ -71,14 +81,15 @@ async fn pause_audio(
|
||||
|
||||
// Pause the track
|
||||
let manager = get_songbird();
|
||||
pause_track(manager, &guild_id).await
|
||||
pause_track(manager, &guild_id).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn resume_audio(
|
||||
Extension(_): Extension<AuthCredential>,
|
||||
Extension(_): Extension<Session>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
) -> SirenResult<()> {
|
||||
) -> Result<()> {
|
||||
log::debug!("Pausing audio in guild: {}", guild_id);
|
||||
|
||||
// Validate if the guild exists in the cache
|
||||
@@ -89,5 +100,6 @@ async fn resume_audio(
|
||||
|
||||
// Pause the track
|
||||
let manager = get_songbird();
|
||||
resume_track(manager, &guild_id).await
|
||||
resume_track(manager, &guild_id).await?;
|
||||
Ok(())
|
||||
}
|
||||
225
crates/siren-api/src/auth/discord.rs
Normal file
225
crates/siren-api/src/auth/discord.rs
Normal file
@@ -0,0 +1,225 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{bearer_token::BearerTokenClaims, csprng, session::Session},
|
||||
};
|
||||
use axum::{
|
||||
Router,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::get,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, sync::Arc};
|
||||
|
||||
const DISCORD_REDIRECT_PATH: &str = "/api/auth/discord/callback";
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/authorize", get(discord_authorize))
|
||||
.route("/callback", get(discord_callback))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthorizeQuery {
|
||||
redirect_uri: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct CallbackQuery {
|
||||
code: String,
|
||||
state: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct DiscordTokenResponse {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
expires_in: u64,
|
||||
refresh_token: String,
|
||||
scope: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct DiscordUser {
|
||||
id: String,
|
||||
username: String,
|
||||
discriminator: String,
|
||||
avatar: Option<String>,
|
||||
}
|
||||
|
||||
async fn discord_authorize(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<AuthorizeQuery>,
|
||||
) -> impl IntoResponse {
|
||||
let oauth_state = csprng(16);
|
||||
|
||||
state
|
||||
.discord_authorize_cache
|
||||
.lock()
|
||||
.await
|
||||
.insert(oauth_state.clone(), query.redirect_uri);
|
||||
|
||||
let discord_callback_url = format!("{}{}", state.base_url, DISCORD_REDIRECT_PATH);
|
||||
let encoded_callback = discord_callback_url.replace(':', "%3A").replace('/', "%2F");
|
||||
|
||||
let discord_auth_url = format!(
|
||||
"https://discord.com/api/oauth2/authorize\
|
||||
?client_id={}\
|
||||
&redirect_uri={}\
|
||||
&response_type=code\
|
||||
&scope=identify\
|
||||
&state={}",
|
||||
state.client_id, encoded_callback, oauth_state,
|
||||
);
|
||||
|
||||
match serde_json::to_string(&discord_auth_url) {
|
||||
Ok(json) => Ok(json),
|
||||
Err(e) => {
|
||||
log::error!("Failed to serialize Discord OAuth URL: {e}");
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn discord_callback(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<CallbackQuery>,
|
||||
) -> impl IntoResponse {
|
||||
match do_oauth_callback(state, query).await {
|
||||
Ok((token, ui_redirect_uri)) => {
|
||||
Redirect::temporary(&format!("{}?token={}", ui_redirect_uri, token)).into_response()
|
||||
}
|
||||
Err((e, ui_redirect_uri)) => {
|
||||
log::error!("OAuth callback error: {:?}", e);
|
||||
let fallback = ui_redirect_uri.unwrap_or_else(|| "/".to_string());
|
||||
Redirect::temporary(&format!("{}?error=auth_failed", fallback)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_oauth_callback(
|
||||
state: Arc<AppState>,
|
||||
query: CallbackQuery,
|
||||
) -> Result<(String, String), (crate::error::Error, Option<String>)> {
|
||||
// Validate the state and retrieve the associated UI redirect URI
|
||||
let ui_redirect_uri = {
|
||||
let mut oauth_states = state.discord_authorize_cache.lock().await;
|
||||
match query.state {
|
||||
Some(ref oauth_state) => match oauth_states.remove(oauth_state) {
|
||||
Some(uri) => uri,
|
||||
None => return Err((StatusCode::UNAUTHORIZED.into(), None)),
|
||||
},
|
||||
None => return Err((StatusCode::UNAUTHORIZED.into(), None)),
|
||||
}
|
||||
};
|
||||
|
||||
// Helper closure to tag errors with the redirect URI we already know
|
||||
let redirect = ui_redirect_uri.clone();
|
||||
let err = |s: StatusCode| -> Result<_, (crate::error::Error, Option<String>)> {
|
||||
Err((s.into(), Some(redirect.clone())))
|
||||
};
|
||||
|
||||
// The discord redirect_uri in the token exchange must match what was sent in /authorize
|
||||
let discord_callback_url = format!("{}{}", state.base_url, DISCORD_REDIRECT_PATH);
|
||||
|
||||
// Exchange code for an access token
|
||||
let token_response = state
|
||||
.client
|
||||
.post("https://discord.com/api/oauth2/token")
|
||||
.form(&[
|
||||
("client_id", state.client_id.as_str()),
|
||||
("client_secret", state.client_secret.as_str()),
|
||||
("grant_type", "authorization_code"),
|
||||
("code", query.code.as_str()),
|
||||
("redirect_uri", discord_callback_url.as_str()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
if !token_response.status().is_success() {
|
||||
log::error!(
|
||||
"Failed to exchange token: {:?}",
|
||||
token_response.text().await
|
||||
);
|
||||
return err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
let token_data: DiscordTokenResponse = token_response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
// Fetch user information from Discord
|
||||
let user_response = state
|
||||
.client
|
||||
.get("https://discord.com/api/users/@me")
|
||||
.bearer_auth(token_data.access_token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
if !user_response.status().is_success() {
|
||||
log::error!(
|
||||
"Failed to fetch user information: {:?}",
|
||||
user_response.text().await
|
||||
);
|
||||
return err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
let user_data: DiscordUser = user_response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
log::debug!("User authenticated: {:?}", user_data);
|
||||
|
||||
let user_id: i64 = user_data
|
||||
.id
|
||||
.parse::<i64>()
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
// Upsert the Discord user into the local users table
|
||||
let pool = siren_core::data::pool();
|
||||
sqlx::query(
|
||||
"INSERT INTO users (id, username, avatar, updated_at)
|
||||
VALUES ($1, $2, $3, NOW())
|
||||
ON CONFLICT (id) DO UPDATE
|
||||
SET username = EXCLUDED.username,
|
||||
avatar = EXCLUDED.avatar,
|
||||
updated_at = NOW()",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&user_data.username)
|
||||
.bind(&user_data.avatar)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
log::error!("Failed to upsert user: {e}");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err()
|
||||
})?;
|
||||
|
||||
// Create and insert the session
|
||||
let session = Session::new(user_id as u64, user_data.username.clone());
|
||||
session
|
||||
.insert()
|
||||
.await
|
||||
.map_err(|e| (e, Some(ui_redirect_uri.clone())))?;
|
||||
|
||||
let issued_at = chrono::Utc::now();
|
||||
let claims = BearerTokenClaims {
|
||||
sub: session.user_id,
|
||||
name: session.user_name.clone(),
|
||||
iat: issued_at.timestamp(),
|
||||
exp: session.expires_at.timestamp(),
|
||||
jti: session.session_id.clone(),
|
||||
};
|
||||
|
||||
let jwt_secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
let encoding_key = jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes());
|
||||
let token = jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &encoding_key)
|
||||
.map_err(|_| err(StatusCode::INTERNAL_SERVER_ERROR).unwrap_err())?;
|
||||
|
||||
Ok((token, ui_redirect_uri))
|
||||
}
|
||||
107
crates/siren-api/src/auth/middleware.rs
Normal file
107
crates/siren-api/src/auth/middleware.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use crate::{
|
||||
auth::{bearer_token::BearerTokenClaims, session::Session},
|
||||
error::Result,
|
||||
};
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
http::{Method, StatusCode, request::Parts},
|
||||
};
|
||||
use axum_extra::{
|
||||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{DecodingKey, Validation, decode};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AuthorizationMiddleware — rejects unauthenticated requests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct AuthorizationMiddleware;
|
||||
|
||||
impl<S> FromRequestParts<S> for AuthorizationMiddleware
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = StatusCode;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> std::result::Result<Self, Self::Rejection> {
|
||||
// For options requests browsers will not send the authorization header.
|
||||
if parts.method == Method::OPTIONS {
|
||||
return Ok(Self);
|
||||
}
|
||||
|
||||
// Check for a Bearer token in the `Authorization` header.
|
||||
if let Ok(TypedHeader(Authorization(bearer))) =
|
||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
|
||||
{
|
||||
return match check_bearer_auth(bearer.token()).await {
|
||||
Ok(session) => {
|
||||
parts.extensions.insert(session);
|
||||
Ok(Self)
|
||||
}
|
||||
Err(_) => Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
}
|
||||
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OptionalAuth — extracts a Session if present, otherwise None
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Wraps an optional authenticated session.
|
||||
/// Handlers that use this extractor work for both authenticated and
|
||||
/// unauthenticated callers; callers with a valid Bearer token get a `Some(session)`.
|
||||
pub struct OptionalAuth(pub Option<Session>);
|
||||
|
||||
impl<S> FromRequestParts<S> for OptionalAuth
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = std::convert::Infallible;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> std::result::Result<Self, Self::Rejection> {
|
||||
if let Ok(TypedHeader(Authorization(bearer))) =
|
||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
|
||||
{
|
||||
if let Ok(session) = check_bearer_auth(bearer.token()).await {
|
||||
parts.extensions.insert(session.clone());
|
||||
return Ok(Self(Some(session)));
|
||||
}
|
||||
}
|
||||
Ok(Self(None))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn check_bearer_auth(bearer_token: &str) -> Result<Session> {
|
||||
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
|
||||
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
|
||||
|
||||
let token_data = decode::<BearerTokenClaims>(bearer_token, &decoding_key, &Validation::default())
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let claims = token_data.claims;
|
||||
|
||||
let now = Utc::now().timestamp();
|
||||
if claims.exp < now {
|
||||
return Err(StatusCode::UNAUTHORIZED.into());
|
||||
}
|
||||
|
||||
match Session::find(&claims.jti).await {
|
||||
Ok(Some(session)) => Ok(session),
|
||||
_ => Err(StatusCode::UNAUTHORIZED)?,
|
||||
}
|
||||
}
|
||||
24
crates/siren-api/src/auth/mod.rs
Normal file
24
crates/siren-api/src/auth/mod.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
use crate::AppState;
|
||||
use axum::Router;
|
||||
use rand::RngExt;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod discord;
|
||||
mod session;
|
||||
pub use session::Session;
|
||||
mod bearer_token;
|
||||
pub mod middleware;
|
||||
pub use middleware::{AuthorizationMiddleware, OptionalAuth};
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new().nest("/discord", discord::get_routes())
|
||||
}
|
||||
|
||||
pub fn csprng(take: usize) -> String {
|
||||
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
|
||||
rand::rng()
|
||||
.sample_iter(rand::distr::Alphanumeric)
|
||||
.take(take)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
@@ -1,11 +1,9 @@
|
||||
use std::env;
|
||||
use std::sync::OnceLock;
|
||||
use crate::{auth::csprng, error::Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use redis::{AsyncCommands, RedisResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::api::auth::csprng;
|
||||
use crate::data;
|
||||
use crate::error::SirenResult;
|
||||
use siren_core::data;
|
||||
use std::{env, sync::OnceLock};
|
||||
|
||||
static SESSION_TTL: OnceLock<i64> = OnceLock::new();
|
||||
|
||||
@@ -39,17 +37,17 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn insert(&self) -> SirenResult<()> {
|
||||
pub async fn insert(&self) -> Result<()> {
|
||||
let mut redis = data::redis_async_connection().await?;
|
||||
let session_id = self.session_id.clone();
|
||||
let session_ttl = get_session_ttl();
|
||||
redis
|
||||
.set_ex(session_id, serde_json::to_string(self)?, session_ttl as u64)
|
||||
.set_ex::<_, _, ()>(session_id, serde_json::to_string(self)?, session_ttl as u64)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn find(session_id: &str) -> SirenResult<Option<Session>> {
|
||||
pub async fn find(session_id: &str) -> Result<Option<Session>> {
|
||||
let mut redis = data::redis_async_connection().await?;
|
||||
let result: RedisResult<Option<String>> = redis.get(session_id).await;
|
||||
match result {
|
||||
@@ -59,7 +57,7 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete(session_id: &str) -> SirenResult<()> {
|
||||
pub async fn delete(session_id: &str) -> Result<()> {
|
||||
let mut redis = data::redis_async_connection().await?;
|
||||
let result: RedisResult<()> = redis.del(session_id).await;
|
||||
match result {
|
||||
@@ -1,23 +1,25 @@
|
||||
use std::fmt::Display;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use axum::{Extension, Json, Router};
|
||||
use axum::extract::{Path, State};
|
||||
use axum::middleware::from_extractor;
|
||||
use axum::routing::post;
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{AuthorizationMiddleware, Session},
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
Extension,
|
||||
Json,
|
||||
Router,
|
||||
extract::{Path, State},
|
||||
middleware::from_extractor,
|
||||
routing::post,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use siren_bot::commands::fun::roll::{format_roll, parse_dice};
|
||||
use siren_core::data::{ExecutableQuery, Value, condition::Condition, query::QueryBuilder};
|
||||
use std::{fmt::Display, str::FromStr, sync::Arc};
|
||||
use uuid::Uuid;
|
||||
use crate::api::auth::{AuthCredential, AuthorizationMiddleware};
|
||||
use crate::AppState;
|
||||
use crate::bot::commands::fun::roll::{format_roll, parse_dice};
|
||||
use crate::data::condition::Condition;
|
||||
use crate::data::{ExecutableQuery, Value};
|
||||
use crate::data::query::QueryBuilder;
|
||||
use crate::error::{Error, SirenResult};
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/:guild_id/track", post(add_track_dice))
|
||||
.route("/{guild_id}/track", post(add_track_dice))
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
}
|
||||
|
||||
@@ -55,7 +57,7 @@ impl Display for TrackDiceOperator {
|
||||
impl FromStr for TrackDiceOperator {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s {
|
||||
"eq" => Ok(TrackDiceOperator::Equal),
|
||||
"lt" => Ok(TrackDiceOperator::LessThan),
|
||||
@@ -68,7 +70,7 @@ impl FromStr for TrackDiceOperator {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
struct DiceTrackPayload {
|
||||
pub struct DiceTrackPayload {
|
||||
dice: String,
|
||||
user_id: Option<i64>,
|
||||
value: Option<i32>,
|
||||
@@ -76,7 +78,7 @@ struct DiceTrackPayload {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
struct InsertDiceTrack {
|
||||
pub struct InsertDiceTrack {
|
||||
guild_id: i64,
|
||||
owner_id: i64,
|
||||
dice: String,
|
||||
@@ -86,7 +88,7 @@ struct InsertDiceTrack {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
struct QueryDiceTrack {
|
||||
pub struct QueryDiceTrack {
|
||||
id: Uuid,
|
||||
guild_id: i64,
|
||||
owner_id: i64,
|
||||
@@ -121,8 +123,8 @@ impl QueryDiceTrack {
|
||||
}
|
||||
|
||||
impl InsertDiceTrack {
|
||||
pub async fn insert(&self) -> SirenResult<QueryDiceTrack> {
|
||||
let pool = crate::data::pool();
|
||||
pub async fn insert(&self) -> Result<QueryDiceTrack> {
|
||||
let pool = siren_core::data::pool();
|
||||
let query = format!(
|
||||
"INSERT INTO {} (
|
||||
guild_id,
|
||||
@@ -154,13 +156,13 @@ impl InsertDiceTrack {
|
||||
}
|
||||
|
||||
pub async fn add_track_dice(
|
||||
Extension(credential): Extension<AuthCredential>,
|
||||
Extension(session): Extension<Session>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(guild_id): Path<u64>,
|
||||
Json(payload): Json<DiceTrackPayload>,
|
||||
) -> SirenResult<Json<QueryDiceTrack>> {
|
||||
) -> Result<Json<QueryDiceTrack>> {
|
||||
// Check if the user exists in the cache
|
||||
let owner_id = credential.user_id();
|
||||
let owner_id = session.user_id;
|
||||
let owner_id = match state.cache.user(owner_id) {
|
||||
Some(user) => user.id,
|
||||
None => return Err(Error::not_found("User not found".to_string())),
|
||||
128
crates/siren-api/src/error.rs
Normal file
128
crates/siren-api/src/error.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
use axum::{
|
||||
Json,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Error {
|
||||
pub status: u16,
|
||||
pub details: String,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn new(status: u16, details: String) -> Self {
|
||||
Self { status, details }
|
||||
}
|
||||
|
||||
pub fn not_found(details: String) -> Self {
|
||||
Self::new(404, details)
|
||||
}
|
||||
|
||||
pub fn internal_server_error(details: String) -> Self {
|
||||
Self::new(500, details)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(self.details.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn description(&self) -> &str {
|
||||
&self.details
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for Error {
|
||||
fn into_response(self) -> Response {
|
||||
let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = Json(serde_json::json!({
|
||||
"error": {
|
||||
"status": self.status,
|
||||
"details": self.details,
|
||||
}
|
||||
}));
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// --- Conversions from upstream crate errors ---
|
||||
|
||||
impl From<siren_core::error::Error> for Error {
|
||||
fn from(error: siren_core::error::Error) -> Self {
|
||||
Self::new(error.status, error.details)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<siren_bot::error::Error> for Error {
|
||||
fn from(error: siren_bot::error::Error) -> Self {
|
||||
Self::new(error.status, error.details)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Conversions from external crate errors ---
|
||||
|
||||
impl From<StatusCode> for Error {
|
||||
fn from(status: StatusCode) -> Self {
|
||||
Error {
|
||||
status: status.as_u16(),
|
||||
details: status
|
||||
.canonical_reason()
|
||||
.unwrap_or("Unknown error")
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for Error {
|
||||
fn from(error: reqwest::Error) -> Self {
|
||||
Self::new(500, format!("HTTP client error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(error: serde_json::Error) -> Self {
|
||||
Self::new(500, format!("JSON error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<jsonwebtoken::errors::Error> for Error {
|
||||
fn from(error: jsonwebtoken::errors::Error) -> Self {
|
||||
match error.kind() {
|
||||
jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
|
||||
Self::new(401, "Token expired".to_string())
|
||||
}
|
||||
jsonwebtoken::errors::ErrorKind::InvalidToken => Self::new(401, "Invalid token".to_string()),
|
||||
_ => Self::new(500, format!("JWT error: {}", error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Direct conversions for types used in API handlers that bypass the data abstraction layer
|
||||
|
||||
impl From<sqlx::Error> for Error {
|
||||
fn from(error: sqlx::Error) -> Self {
|
||||
let core_err: siren_core::error::Error = error.into();
|
||||
core_err.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<redis::RedisError> for Error {
|
||||
fn from(error: redis::RedisError) -> Self {
|
||||
let core_err: siren_core::error::Error = error.into();
|
||||
core_err.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(error: std::io::Error) -> Self {
|
||||
Self::new(500, format!("IO error: {}", error))
|
||||
}
|
||||
}
|
||||
619
crates/siren-api/src/grid/mod.rs
Normal file
619
crates/siren-api/src/grid/mod.rs
Normal file
@@ -0,0 +1,619 @@
|
||||
pub mod model;
|
||||
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{OptionalAuth, Session, csprng, middleware::check_bearer_auth},
|
||||
error::{Error, Result},
|
||||
};
|
||||
use axum::{
|
||||
Json,
|
||||
Router,
|
||||
extract::{
|
||||
Path,
|
||||
Query,
|
||||
State,
|
||||
WebSocketUpgrade,
|
||||
ws::{Message, WebSocket},
|
||||
},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{delete, get, post, put},
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use model::{
|
||||
ClientMessage,
|
||||
CreateMapPayload,
|
||||
GridCell,
|
||||
GridMap,
|
||||
GridToken,
|
||||
MapPermission,
|
||||
MapRole,
|
||||
MapState,
|
||||
ServerMessage,
|
||||
UpdatePermissionPayload,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/maps", get(list_maps))
|
||||
.route("/maps", post(create_map))
|
||||
.route("/maps/{id}", get(get_map))
|
||||
.route("/maps/{id}", delete(delete_map))
|
||||
.route("/maps/{id}/permissions", get(list_permissions))
|
||||
.route("/maps/{id}/permissions", put(update_permission))
|
||||
.route("/maps/{id}/ws", get(ws_handler))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Permission helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Fetch the role of `user_id` on `map_id`, or `None` if no record exists.
|
||||
async fn get_user_role(map_id: &str, user_id: i64) -> crate::error::Result<Option<MapRole>> {
|
||||
let pool = siren_core::data::pool();
|
||||
let perm: Option<MapPermission> = sqlx::query_as(
|
||||
"SELECT map_id, user_id, role FROM map_permissions WHERE map_id = $1 AND user_id = $2",
|
||||
)
|
||||
.bind(map_id)
|
||||
.bind(user_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(perm.map(|p| p.role))
|
||||
}
|
||||
|
||||
/// Returns whether the caller can view the map:
|
||||
/// - Public maps: always true.
|
||||
/// - Private maps: true only if the user has any role.
|
||||
async fn can_view(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
if map.is_public {
|
||||
return true;
|
||||
}
|
||||
let Some(s) = session else { return false };
|
||||
let user_id = s.user_id as i64;
|
||||
get_user_role(&map.id, user_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.is_some()
|
||||
}
|
||||
|
||||
/// Returns whether the caller can edit the map (editor or owner role).
|
||||
async fn can_edit(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
let Some(s) = session else { return false };
|
||||
let user_id = s.user_id as i64;
|
||||
get_user_role(&map.id, user_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|r| r.can_edit())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Returns whether the caller is the owner.
|
||||
async fn is_owner(map: &GridMap, session: &Option<Session>) -> bool {
|
||||
let Some(s) = session else { return false };
|
||||
let user_id = s.user_id as i64;
|
||||
get_user_role(&map.id, user_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|r| r.is_owner())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// REST handlers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn list_maps(OptionalAuth(session): OptionalAuth) -> Result<Json<Vec<GridMap>>> {
|
||||
let pool = siren_core::data::pool();
|
||||
let maps: Vec<GridMap> = match &session {
|
||||
Some(s) => {
|
||||
let user_id = s.user_id as i64;
|
||||
sqlx::query_as(
|
||||
"SELECT DISTINCT gm.*
|
||||
FROM grid_maps gm
|
||||
LEFT JOIN map_permissions mp ON mp.map_id = gm.id AND mp.user_id = $1
|
||||
WHERE gm.is_public = TRUE OR mp.user_id IS NOT NULL
|
||||
ORDER BY gm.created_at DESC",
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
sqlx::query_as("SELECT * FROM grid_maps WHERE is_public = TRUE ORDER BY created_at DESC")
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
Ok(Json(maps))
|
||||
}
|
||||
|
||||
pub async fn create_map(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
Json(payload): Json<CreateMapPayload>,
|
||||
) -> Result<(StatusCode, Json<GridMap>)> {
|
||||
let session = session.ok_or_else(|| Error::from(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
let user_id = session.user_id as i64;
|
||||
let map_id = csprng(32);
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: GridMap = sqlx::query_as(
|
||||
"INSERT INTO grid_maps (id, name, is_public, owner_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING *",
|
||||
)
|
||||
.bind(&map_id)
|
||||
.bind(&payload.name)
|
||||
.bind(payload.is_public)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
// Auto-assign the creator as owner in map_permissions
|
||||
sqlx::query("INSERT INTO map_permissions (map_id, user_id, role) VALUES ($1, $2, 'owner')")
|
||||
.bind(&map_id)
|
||||
.bind(user_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok((StatusCode::CREATED, Json(map)))
|
||||
}
|
||||
|
||||
pub async fn get_map(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<MapState>> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !can_view(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
let cells: Vec<GridCell> = sqlx::query_as("SELECT * FROM grid_cells WHERE map_id = $1")
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let tokens: Vec<GridToken> = sqlx::query_as("SELECT * FROM grid_tokens WHERE map_id = $1")
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(Json(MapState { map, cells, tokens }))
|
||||
}
|
||||
|
||||
pub async fn delete_map(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<StatusCode> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !is_owner(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Permission management
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn list_permissions(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<Vec<MapPermission>>> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !is_owner(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
let perms: Vec<MapPermission> =
|
||||
sqlx::query_as("SELECT map_id, user_id, role FROM map_permissions WHERE map_id = $1")
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(Json(perms))
|
||||
}
|
||||
|
||||
pub async fn update_permission(
|
||||
OptionalAuth(session): OptionalAuth,
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdatePermissionPayload>,
|
||||
) -> Result<StatusCode> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: Option<GridMap> = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let map = map.ok_or_else(|| Error::not_found("Map not found".into()))?;
|
||||
|
||||
if !is_owner(&map, &session).await {
|
||||
return Err(StatusCode::FORBIDDEN.into());
|
||||
}
|
||||
|
||||
// Prevent the owner from removing their own owner record
|
||||
let caller_id = session.as_ref().map(|s| s.user_id as i64).unwrap_or(0);
|
||||
if payload.user_id == caller_id && payload.role.as_ref().map(|r| r.is_owner()) == Some(false) {
|
||||
return Err(Error::from(StatusCode::UNPROCESSABLE_ENTITY));
|
||||
}
|
||||
|
||||
match payload.role {
|
||||
Some(role) => {
|
||||
sqlx::query(
|
||||
"INSERT INTO map_permissions (map_id, user_id, role)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (map_id, user_id) DO UPDATE SET role = EXCLUDED.role",
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(payload.user_id)
|
||||
.bind(role)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
None => {
|
||||
sqlx::query("DELETE FROM map_permissions WHERE map_id = $1 AND user_id = $2")
|
||||
.bind(&id)
|
||||
.bind(payload.user_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebSocket handler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct WsQuery {
|
||||
/// Optional Bearer token passed as a query parameter for WS auth.
|
||||
token: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(map_id): Path<String>,
|
||||
Query(query): Query<WsQuery>,
|
||||
) -> impl IntoResponse {
|
||||
// Resolve the session from query param (WS can't easily send headers)
|
||||
let session: Option<Session> = match query.token {
|
||||
Some(ref tok) => check_bearer_auth(tok).await.ok(),
|
||||
None => None,
|
||||
};
|
||||
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, map_id, session))
|
||||
}
|
||||
|
||||
async fn handle_socket(
|
||||
socket: WebSocket,
|
||||
state: Arc<AppState>,
|
||||
map_id: String,
|
||||
session: Option<Session>,
|
||||
) {
|
||||
// Load the map and verify the caller can view it
|
||||
let map_state = match fetch_map_state(&map_id).await {
|
||||
Ok(ms) => ms,
|
||||
Err(_) => return, // map doesn't exist
|
||||
};
|
||||
|
||||
if !can_view(&map_state.map, &session).await {
|
||||
// Refuse the connection silently (upgrade already happened; just close)
|
||||
return;
|
||||
}
|
||||
|
||||
let editor = can_edit(&map_state.map, &session).await;
|
||||
|
||||
// Get or create a broadcast channel for this map
|
||||
let tx = state
|
||||
.map_rooms
|
||||
.entry(map_id.clone())
|
||||
.or_insert_with(|| {
|
||||
let (tx, _) = broadcast::channel(256);
|
||||
tx
|
||||
})
|
||||
.clone();
|
||||
let mut rx = tx.subscribe();
|
||||
|
||||
let (mut ws_tx, mut ws_rx) = socket.split();
|
||||
|
||||
// Send the current full map state to the newly connected client
|
||||
let init_msg = ServerMessage::State {
|
||||
cells: map_state.cells,
|
||||
tokens: map_state.tokens,
|
||||
colors: map_state.map.colors,
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&init_msg) {
|
||||
let _ = ws_tx.send(Message::Text(json.into())).await;
|
||||
}
|
||||
|
||||
// Task 1: forward broadcast messages to this socket
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(json) = rx.recv().await {
|
||||
if ws_tx.send(Message::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Task 2: receive messages from this client, persist, and broadcast
|
||||
let tx_clone = tx.clone();
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = ws_rx.next().await {
|
||||
match msg {
|
||||
Message::Text(text) => {
|
||||
handle_client_message(&text, &map_id, editor, &tx_clone).await;
|
||||
}
|
||||
Message::Close(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut send_task => recv_task.abort(),
|
||||
_ = &mut recv_task => send_task.abort(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_map_state(map_id: &str) -> crate::error::Result<MapState> {
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let map: GridMap = sqlx::query_as("SELECT * FROM grid_maps WHERE id = $1")
|
||||
.bind(map_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let cells: Vec<GridCell> = sqlx::query_as("SELECT * FROM grid_cells WHERE map_id = $1")
|
||||
.bind(map_id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let tokens: Vec<GridToken> = sqlx::query_as("SELECT * FROM grid_tokens WHERE map_id = $1")
|
||||
.bind(map_id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(MapState { map, cells, tokens })
|
||||
}
|
||||
|
||||
async fn handle_client_message(
|
||||
raw: &str,
|
||||
map_id: &str,
|
||||
can_edit: bool,
|
||||
tx: &broadcast::Sender<String>,
|
||||
) {
|
||||
let client_msg: ClientMessage = match serde_json::from_str(raw) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
log::warn!("Invalid WS message: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// All mutating messages require editor or owner role
|
||||
if !can_edit {
|
||||
let err = ServerMessage::Error {
|
||||
message: "You do not have permission to edit this map.".into(),
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&err) {
|
||||
let _ = tx.send(json);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let pool = siren_core::data::pool();
|
||||
|
||||
let server_msg: Option<ServerMessage> = match client_msg {
|
||||
ClientMessage::PaintCell { x, y, color } => {
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO grid_cells (map_id, x, y, color)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (map_id, x, y) DO UPDATE SET color = EXCLUDED.color",
|
||||
)
|
||||
.bind(map_id)
|
||||
.bind(x)
|
||||
.bind(y)
|
||||
.bind(&color)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Some(ServerMessage::CellPainted { x, y, color }),
|
||||
Err(e) => {
|
||||
log::error!("DB error painting cell: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::PaintCells { cells } => {
|
||||
let mut tx_db = match pool.begin().await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
log::error!("DB error starting transaction for batch paint: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut ok = true;
|
||||
for cell in &cells {
|
||||
let res = sqlx::query(
|
||||
"INSERT INTO grid_cells (map_id, x, y, color)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (map_id, x, y) DO UPDATE SET color = EXCLUDED.color",
|
||||
)
|
||||
.bind(map_id)
|
||||
.bind(cell.x)
|
||||
.bind(cell.y)
|
||||
.bind(&cell.color)
|
||||
.execute(&mut *tx_db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = res {
|
||||
log::error!("DB error in batch paint cell ({},{}): {e}", cell.x, cell.y);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if ok {
|
||||
if let Err(e) = tx_db.commit().await {
|
||||
log::error!("DB error committing batch paint: {e}");
|
||||
None
|
||||
} else {
|
||||
Some(ServerMessage::CellsBatchPainted { cells })
|
||||
}
|
||||
} else {
|
||||
let _ = tx_db.rollback().await;
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::EraseCell { x, y } => {
|
||||
let result = sqlx::query("DELETE FROM grid_cells WHERE map_id = $1 AND x = $2 AND y = $3")
|
||||
.bind(map_id)
|
||||
.bind(x)
|
||||
.bind(y)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Some(ServerMessage::CellErased { x, y }),
|
||||
Err(e) => {
|
||||
log::error!("DB error erasing cell: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::AddToken { x, y, label, color } => {
|
||||
let token_id = csprng(16);
|
||||
let result: sqlx::Result<GridToken> = sqlx::query_as(
|
||||
"INSERT INTO grid_tokens (id, map_id, x, y, label, color)
|
||||
VALUES ($1, $2, $3, $4, $5, $6) RETURNING *",
|
||||
)
|
||||
.bind(&token_id)
|
||||
.bind(map_id)
|
||||
.bind(x)
|
||||
.bind(y)
|
||||
.bind(&label)
|
||||
.bind(&color)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(token) => Some(ServerMessage::TokenAdded {
|
||||
id: token.id,
|
||||
x: token.x,
|
||||
y: token.y,
|
||||
label: token.label,
|
||||
color: token.color,
|
||||
}),
|
||||
Err(e) => {
|
||||
log::error!("DB error adding token: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::MoveToken { id, x, y } => {
|
||||
let result =
|
||||
sqlx::query("UPDATE grid_tokens SET x = $1, y = $2 WHERE id = $3 AND map_id = $4")
|
||||
.bind(x)
|
||||
.bind(y)
|
||||
.bind(&id)
|
||||
.bind(map_id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => Some(ServerMessage::TokenMoved { id, x, y }),
|
||||
Ok(_) => None,
|
||||
Err(e) => {
|
||||
log::error!("DB error moving token: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::DeleteToken { id } => {
|
||||
let result = sqlx::query("DELETE FROM grid_tokens WHERE id = $1 AND map_id = $2")
|
||||
.bind(&id)
|
||||
.bind(map_id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => Some(ServerMessage::TokenDeleted { id }),
|
||||
Ok(_) => None,
|
||||
Err(e) => {
|
||||
log::error!("DB error deleting token: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::UpdateColors { colors } => {
|
||||
let result =
|
||||
sqlx::query("UPDATE grid_maps SET colors = $1, updated_at = NOW() WHERE id = $2")
|
||||
.bind(&colors)
|
||||
.bind(map_id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Some(ServerMessage::ColorsUpdated { colors }),
|
||||
Err(e) => {
|
||||
log::error!("DB error updating colors: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(msg) = server_msg {
|
||||
if let Ok(json) = serde_json::to_string(&msg) {
|
||||
let _ = tx.send(json);
|
||||
}
|
||||
}
|
||||
}
|
||||
190
crates/siren-api/src/grid/model.rs
Normal file
190
crates/siren-api/src/grid/model.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use chrono::NaiveDateTime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Map Role / Permission
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::Type, Clone, Debug, PartialEq, Eq)]
|
||||
#[sqlx(type_name = "text", rename_all = "lowercase")]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum MapRole {
|
||||
Owner,
|
||||
Editor,
|
||||
Viewer,
|
||||
}
|
||||
|
||||
impl MapRole {
|
||||
/// Returns true if this role can mutate map content (paint, tokens, colors).
|
||||
pub fn can_edit(&self) -> bool {
|
||||
matches!(self, MapRole::Owner | MapRole::Editor)
|
||||
}
|
||||
|
||||
/// Returns true if this role can manage permissions and delete the map.
|
||||
pub fn is_owner(&self) -> bool {
|
||||
matches!(self, MapRole::Owner)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct MapPermission {
|
||||
pub map_id: String,
|
||||
pub user_id: i64,
|
||||
pub role: MapRole,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Grid Map
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct GridMap {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub is_public: bool,
|
||||
pub owner_id: i64,
|
||||
pub colors: Vec<String>,
|
||||
pub created_at: NaiveDateTime,
|
||||
pub updated_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct CreateMapPayload {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub is_public: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct UpdatePermissionPayload {
|
||||
/// Discord user ID of the target user.
|
||||
pub user_id: i64,
|
||||
/// New role to assign. Omit (null) to remove the permission entry.
|
||||
pub role: Option<MapRole>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Grid Cell (no id column — composite PK in DB)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct GridCell {
|
||||
pub map_id: String,
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub color: String,
|
||||
}
|
||||
|
||||
/// Lightweight cell used for batch operations.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct CellPatch {
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub color: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Grid Token
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct GridToken {
|
||||
pub id: String,
|
||||
pub map_id: String,
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub label: String,
|
||||
pub color: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Full map state (used on initial WS connect and REST GET)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct MapState {
|
||||
pub map: GridMap,
|
||||
pub cells: Vec<GridCell>,
|
||||
pub tokens: Vec<GridToken>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebSocket message types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ClientMessage {
|
||||
PaintCell {
|
||||
x: i32,
|
||||
y: i32,
|
||||
color: String,
|
||||
},
|
||||
PaintCells {
|
||||
cells: Vec<CellPatch>,
|
||||
},
|
||||
EraseCell {
|
||||
x: i32,
|
||||
y: i32,
|
||||
},
|
||||
AddToken {
|
||||
x: i32,
|
||||
y: i32,
|
||||
label: String,
|
||||
color: String,
|
||||
},
|
||||
MoveToken {
|
||||
id: String,
|
||||
x: i32,
|
||||
y: i32,
|
||||
},
|
||||
DeleteToken {
|
||||
id: String,
|
||||
},
|
||||
UpdateColors {
|
||||
colors: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ServerMessage {
|
||||
State {
|
||||
cells: Vec<GridCell>,
|
||||
tokens: Vec<GridToken>,
|
||||
colors: Vec<String>,
|
||||
},
|
||||
CellPainted {
|
||||
x: i32,
|
||||
y: i32,
|
||||
color: String,
|
||||
},
|
||||
CellsBatchPainted {
|
||||
cells: Vec<CellPatch>,
|
||||
},
|
||||
CellErased {
|
||||
x: i32,
|
||||
y: i32,
|
||||
},
|
||||
TokenAdded {
|
||||
id: String,
|
||||
x: i32,
|
||||
y: i32,
|
||||
label: String,
|
||||
color: String,
|
||||
},
|
||||
TokenMoved {
|
||||
id: String,
|
||||
x: i32,
|
||||
y: i32,
|
||||
},
|
||||
TokenDeleted {
|
||||
id: String,
|
||||
},
|
||||
ColorsUpdated {
|
||||
colors: Vec<String>,
|
||||
},
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
20
crates/siren-api/src/lib.rs
Normal file
20
crates/siren-api/src/lib.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
pub mod app;
|
||||
mod app_state;
|
||||
pub mod audio;
|
||||
pub mod auth;
|
||||
pub mod dice;
|
||||
pub mod error;
|
||||
pub mod grid;
|
||||
|
||||
pub use app::App;
|
||||
pub use app_state::AppState;
|
||||
use axum::Router;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.nest("/auth", auth::get_routes())
|
||||
.nest("/audio/{guild_id}", audio::get_routes())
|
||||
.nest("/dice", dice::get_routes())
|
||||
.nest("/grid", grid::get_routes())
|
||||
}
|
||||
22
crates/siren-bot/Cargo.toml
Normal file
22
crates/siren-bot/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "siren-bot"
|
||||
edition.workspace = true
|
||||
version.workspace = true
|
||||
rust-version.workspace = true
|
||||
authors.workspace = true
|
||||
|
||||
[dependencies]
|
||||
siren-core = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
log = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serenity = { workspace = true }
|
||||
songbird = { workspace = true }
|
||||
symphonia = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
@@ -1,6 +1,13 @@
|
||||
use serenity::all::{
|
||||
CommandInteraction, Context, CreateInteractionResponse, CreateInteractionResponseMessage,
|
||||
CreateMessage, EditInteractionResponse, InteractionResponseFlags, Message, ModalInteraction,
|
||||
CommandInteraction,
|
||||
Context,
|
||||
CreateInteractionResponse,
|
||||
CreateInteractionResponseMessage,
|
||||
CreateMessage,
|
||||
EditInteractionResponse,
|
||||
InteractionResponseFlags,
|
||||
Message,
|
||||
ModalInteraction,
|
||||
UserId,
|
||||
};
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use reqwest::Url;
|
||||
use serenity::all::UserId;
|
||||
use serenity::client::Cache;
|
||||
use serenity::model::prelude::{GuildId, ChannelId};
|
||||
use serenity::{
|
||||
all::UserId,
|
||||
client::Cache,
|
||||
model::prelude::{ChannelId, GuildId},
|
||||
};
|
||||
use songbird::Songbird;
|
||||
|
||||
use crate::error::{SirenResult, Error as SirenError};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod mute;
|
||||
pub mod pause;
|
||||
@@ -24,19 +24,28 @@ pub async fn join_voice_channel(
|
||||
manager: &Arc<Songbird>,
|
||||
guild_id: &GuildId,
|
||||
user_id: &UserId,
|
||||
) -> SirenResult<ChannelId> {
|
||||
) -> Result<ChannelId> {
|
||||
let channel_id = find_voice_channel(cache, guild_id, user_id)?;
|
||||
log::debug!("<{}> Joining channel {}", guild_id.get(), channel_id.get());
|
||||
manager
|
||||
match manager
|
||||
.join(guild_id.to_owned(), channel_id.to_owned())
|
||||
.await?;
|
||||
Ok(channel_id)
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(channel_id),
|
||||
Err(e) => {
|
||||
if e.should_leave_server() || e.should_reconnect_driver() {
|
||||
log::debug!("<{}> Cleaning up failed voice connection", guild_id.get());
|
||||
let _ = manager.remove(*guild_id).await;
|
||||
}
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Leaves a voice channel.
|
||||
*/
|
||||
pub async fn leave_voice_channel(manager: &Arc<Songbird>, guild_id: &GuildId) -> SirenResult<()> {
|
||||
pub async fn leave_voice_channel(manager: &Arc<Songbird>, guild_id: &GuildId) -> Result<()> {
|
||||
if manager.get(guild_id.to_owned()).is_some() {
|
||||
log::debug!("<{}> Disconnecting from channel", guild_id.get());
|
||||
manager.remove(*guild_id).await?;
|
||||
@@ -60,10 +69,10 @@ fn find_voice_channel(
|
||||
cache: &Arc<Cache>,
|
||||
guild_id: &GuildId,
|
||||
user_id: &UserId,
|
||||
) -> SirenResult<ChannelId> {
|
||||
) -> Result<ChannelId> {
|
||||
let guild = match guild_id.to_guild_cached(cache) {
|
||||
Some(g) => g,
|
||||
None => return Err(SirenError::new(404, "Guild not found".to_string())),
|
||||
None => return Err(Error::new(404, "Guild not found".to_string())),
|
||||
};
|
||||
|
||||
match guild
|
||||
@@ -72,7 +81,7 @@ fn find_voice_channel(
|
||||
.and_then(|voice_state| voice_state.channel_id)
|
||||
{
|
||||
Some(channel) => Ok(channel),
|
||||
None => Err(SirenError::new(
|
||||
None => Err(Error::new(
|
||||
400,
|
||||
"User is not in a voice channel".to_string(),
|
||||
)),
|
||||
@@ -1,9 +1,11 @@
|
||||
use crate::{
|
||||
chat::{edit_response, process_message},
|
||||
handler::get_songbird,
|
||||
};
|
||||
use serenity::{
|
||||
all::{CommandInteraction, CreateCommand},
|
||||
prelude::*,
|
||||
};
|
||||
use crate::bot::chat::{edit_response, process_message};
|
||||
use crate::bot::handler::get_songbird;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Create the initial response
|
||||
@@ -1,12 +1,14 @@
|
||||
use std::sync::Arc;
|
||||
use crate::{
|
||||
chat::{edit_response, process_message},
|
||||
error::{Error, Result},
|
||||
handler::get_songbird,
|
||||
};
|
||||
use serenity::{
|
||||
all::{CommandInteraction, CreateCommand, GuildId},
|
||||
prelude::*,
|
||||
};
|
||||
use songbird::Songbird;
|
||||
use crate::bot::chat::{edit_response, process_message};
|
||||
use crate::bot::handler::get_songbird;
|
||||
use crate::error::{Error, SirenResult};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Create the initial response
|
||||
@@ -39,7 +41,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn pause_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> SirenResult<()> {
|
||||
pub async fn pause_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> Result<()> {
|
||||
if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
|
||||
let handler = handler_lock.lock().await;
|
||||
match handler.queue().current() {
|
||||
@@ -48,7 +50,7 @@ pub async fn pause_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> SirenRe
|
||||
return Err(Error {
|
||||
status: 404,
|
||||
details: "No track is currently playing".to_string(),
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -1,29 +1,35 @@
|
||||
use super::{is_valid_url, join_voice_channel, leave_voice_channel};
|
||||
use crate::{
|
||||
chat::{create_message_response, edit_response, process_message},
|
||||
error::{Error, Result},
|
||||
handler::{get_client, get_songbird},
|
||||
ytdlp::{YtDlp, YtDlpItem},
|
||||
};
|
||||
use serenity::{
|
||||
all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption},
|
||||
async_trait,
|
||||
model::prelude::GuildId,
|
||||
prelude::*,
|
||||
};
|
||||
use siren_core::data::guilds::GuildCache;
|
||||
use songbird::{
|
||||
Event,
|
||||
EventHandler,
|
||||
Songbird,
|
||||
TrackEvent,
|
||||
input::{Input, YoutubeDl},
|
||||
tracks::TrackHandle,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
use serenity::all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption};
|
||||
use serenity::model::prelude::GuildId;
|
||||
use serenity::{prelude::*, async_trait};
|
||||
use songbird::input::{Input, YoutubeDl};
|
||||
use songbird::tracks::TrackHandle;
|
||||
use songbird::{Event, EventHandler, Songbird, TrackEvent};
|
||||
|
||||
use crate::data::guilds::GuildCache;
|
||||
use crate::bot::ytdlp::{YtDlp, YtDlpItem};
|
||||
use crate::error::{SirenResult, Error as SirenError};
|
||||
use crate::{signal_shutdown, HttpKey};
|
||||
|
||||
use super::{is_valid_url, join_voice_channel};
|
||||
|
||||
use crate::bot::chat::{create_message_response, edit_response, process_message};
|
||||
use crate::bot::handler::{get_client, get_songbird};
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Process the command options
|
||||
let track_url = match command.data.options.first() {
|
||||
Some(o) => o.value.as_str().unwrap(),
|
||||
None => {
|
||||
log::warn!(
|
||||
"{} attempted to play a track without a track option",
|
||||
"<{}> {} attempted to play a track without a track option",
|
||||
command.guild_id.unwrap(),
|
||||
command.user.id.get()
|
||||
);
|
||||
create_message_response(&ctx, &command, "Track option is missing".to_string(), false).await;
|
||||
@@ -63,13 +69,20 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
let mut message = format!("Added {} tracks", items.len());
|
||||
if items.len() == 0 {
|
||||
message = "No tracks were played".to_string();
|
||||
log::warn!("<{guild_id}> No tracks were played");
|
||||
if let Err(err) = leave_voice_channel(&manager, guild_id).await {
|
||||
log::error!("Failed to leave voice channel: {}", err);
|
||||
};
|
||||
} else if items.len() == 1 {
|
||||
message = format!("Added **{}**", items[0].get_title());
|
||||
}
|
||||
edit_response(&ctx, &command, message).await;
|
||||
}
|
||||
Err(err) => {
|
||||
log::warn!("Failed to play track: {}", err);
|
||||
log::error!("Failed to play track: {}", err);
|
||||
if let Err(err) = leave_voice_channel(&manager, guild_id).await {
|
||||
log::error!("Failed to leave voice channel: {}", err);
|
||||
}
|
||||
edit_response(&ctx, &command, format!("Failed to play track: {}", err)).await;
|
||||
}
|
||||
};
|
||||
@@ -85,7 +98,7 @@ pub async fn enqueue_track(
|
||||
manager: &Arc<Songbird>,
|
||||
guild_id: GuildId,
|
||||
track_url: &str,
|
||||
) -> SirenResult<Vec<YtDlpItem>> {
|
||||
) -> Result<Vec<YtDlpItem>> {
|
||||
let mut playlist_items: Vec<YtDlpItem> = Vec::new();
|
||||
if let Some(handler_lock) = manager.get(guild_id) {
|
||||
let mut handler = handler_lock.lock().await;
|
||||
@@ -95,19 +108,10 @@ pub async fn enqueue_track(
|
||||
// Check if the URL is valid
|
||||
if !valid {
|
||||
log::warn!("<{guild_id}> Invalid track url: {}", track_url);
|
||||
return Err(SirenError::new(
|
||||
422,
|
||||
format!("Invalid track url: {}", track_url),
|
||||
));
|
||||
return Err(Error::new(422, format!("Invalid track url: {}", track_url)));
|
||||
}
|
||||
|
||||
playlist_items = match get_ytdlp_items(&track_url) {
|
||||
Ok(items) => items,
|
||||
Err(err) => {
|
||||
log::warn!("<{guild_id}> Failed to get playlist urls: {}", err);
|
||||
return Err(SirenError::new(422, err.to_string()));
|
||||
}
|
||||
};
|
||||
playlist_items = get_ytdlp_items(&track_url)?;
|
||||
|
||||
// Add each track to the queue
|
||||
for item in &playlist_items {
|
||||
@@ -141,13 +145,26 @@ pub async fn enqueue_track(
|
||||
Ok(playlist_items)
|
||||
}
|
||||
|
||||
pub fn get_ytdlp_items(url: &str) -> SirenResult<Vec<YtDlpItem>> {
|
||||
pub fn get_ytdlp_items(url: &str) -> Result<Vec<YtDlpItem>> {
|
||||
let output = YtDlp::new()
|
||||
.arg("--flat-playlist")
|
||||
.arg("--dump-json")
|
||||
.arg("--no-check-formats")
|
||||
.arg(url)
|
||||
.execute()?;
|
||||
let items: Vec<YtDlpItem> = String::from_utf8(output.stdout)?
|
||||
|
||||
// Check if yt-dlp exited successfully; log stderr if not
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(Error::new(
|
||||
500,
|
||||
format!("yt-dlp failed ({}): {}", output.status, stderr.trim()),
|
||||
));
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8(output.stdout)?;
|
||||
|
||||
let items: Vec<YtDlpItem> = stdout
|
||||
.split('\n')
|
||||
.filter_map(|line| {
|
||||
if line.is_empty() {
|
||||
@@ -155,14 +172,14 @@ pub fn get_ytdlp_items(url: &str) -> SirenResult<Vec<YtDlpItem>> {
|
||||
} else {
|
||||
Some(
|
||||
serde_json::from_slice::<YtDlpItem>(line.as_bytes())
|
||||
.map_err(|err| SirenError::new(500, err.to_string())),
|
||||
.map_err(|err| Error::new(500, err.to_string())),
|
||||
)
|
||||
}
|
||||
})
|
||||
.filter_map(|parsed| match parsed {
|
||||
Ok(item) => Some(item),
|
||||
Err(err) => {
|
||||
log::warn!("Failed to parse playlist item: {}", err);
|
||||
log::warn!("Failed to parse yt-dlp item: {}", err);
|
||||
None
|
||||
}
|
||||
})
|
||||
@@ -1,14 +1,14 @@
|
||||
use std::sync::Arc;
|
||||
use crate::{
|
||||
chat::{edit_response, process_message},
|
||||
error::{Error, Result},
|
||||
handler::get_songbird,
|
||||
};
|
||||
use serenity::{
|
||||
all::{CommandInteraction, CreateCommand},
|
||||
all::{CommandInteraction, CreateCommand, GuildId},
|
||||
prelude::*,
|
||||
};
|
||||
use serenity::all::GuildId;
|
||||
use songbird::Songbird;
|
||||
use crate::bot::chat::{edit_response, process_message};
|
||||
use crate::bot::commands::audio::pause::pause_track;
|
||||
use crate::bot::handler::get_songbird;
|
||||
use crate::error::{Error, SirenResult};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Create the initial response
|
||||
@@ -41,7 +41,7 @@ pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn resume_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> SirenResult<()> {
|
||||
pub async fn resume_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> Result<()> {
|
||||
if let Some(handler_lock) = manager.get(guild_id.to_owned()) {
|
||||
let handler = handler_lock.lock().await;
|
||||
match handler.queue().current() {
|
||||
@@ -50,7 +50,7 @@ pub async fn resume_track(manager: &Arc<Songbird>, guild_id: &GuildId) -> SirenR
|
||||
return Err(Error {
|
||||
status: 404,
|
||||
details: "No track is currently playing".to_string(),
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -1,11 +1,12 @@
|
||||
use crate::{
|
||||
chat::{edit_response, process_message},
|
||||
handler::get_songbird,
|
||||
};
|
||||
use serenity::{
|
||||
all::{CommandInteraction, CreateCommand},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::bot::chat::{edit_response, process_message};
|
||||
use crate::bot::handler::get_songbird;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Create the initial response
|
||||
process_message(&ctx, &command, false).await;
|
||||
@@ -1,11 +1,12 @@
|
||||
use crate::{
|
||||
chat::{edit_response, process_message},
|
||||
handler::get_songbird,
|
||||
};
|
||||
use serenity::{
|
||||
all::{CommandInteraction, CreateCommand},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::bot::chat::{edit_response, process_message};
|
||||
use crate::bot::handler::get_songbird;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Create the initial response
|
||||
process_message(&ctx, &command, false).await;
|
||||
@@ -1,16 +1,15 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
chat::{create_message_response, edit_response, process_message},
|
||||
handler::get_songbird,
|
||||
};
|
||||
use serenity::{
|
||||
all::{CommandInteraction, CommandOptionType, CreateCommand, CreateCommandOption},
|
||||
model::prelude::GuildId,
|
||||
prelude::*,
|
||||
};
|
||||
use siren_core::data::guilds::GuildCache;
|
||||
use songbird::Songbird;
|
||||
|
||||
use crate::data::guilds::GuildCache;
|
||||
|
||||
use crate::bot::chat::{create_message_response, edit_response, process_message};
|
||||
use crate::bot::handler::get_songbird;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Process the command options
|
||||
@@ -1,11 +1,19 @@
|
||||
use crate::chat::process_message;
|
||||
use chrono::{DateTime, NaiveDate, TimeZone, Utc};
|
||||
use regex::Regex;
|
||||
use serenity::all::{
|
||||
Color, CommandInteraction, CommandOptionType, Context, CreateCommand, CreateCommandOption,
|
||||
CreateEmbed, CreateEmbedFooter, EditInteractionResponse, Timestamp,
|
||||
Color,
|
||||
CommandInteraction,
|
||||
CommandOptionType,
|
||||
Context,
|
||||
CreateCommand,
|
||||
CreateCommandOption,
|
||||
CreateEmbed,
|
||||
CreateEmbedFooter,
|
||||
EditInteractionResponse,
|
||||
Timestamp,
|
||||
};
|
||||
|
||||
use crate::{bot::chat::process_message, data::events::Event};
|
||||
use siren_core::data::events::Event;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Create the initial response
|
||||
@@ -85,6 +93,7 @@ pub fn register() -> CreateCommand {
|
||||
// (in) XX <seconds, minutes, hours, days, weeks>
|
||||
// (at) YYYY-MM-DD HH:MM (AM/PM)
|
||||
// (at) MM DD (YYYY) HH:MM (AM/PM)
|
||||
#[allow(dead_code)]
|
||||
fn parse_datetime(input: &str) -> Option<DateTime<Utc>> {
|
||||
let regexes = vec![
|
||||
Regex::new(r"(?i)^\(?at\)?\s+(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2})\s*(AM|PM)?$").unwrap(),
|
||||
@@ -1,10 +1,20 @@
|
||||
use serenity::all::{
|
||||
ButtonStyle, CommandInteraction, CommandOptionType, Context, CreateActionRow, CreateButton,
|
||||
CreateCommand, CreateCommandOption, CreateMessage, Mentionable, UserId,
|
||||
use crate::{
|
||||
chat::{create_message_response, edit_response},
|
||||
commands::fun::roll::parse_dice,
|
||||
};
|
||||
use serenity::all::{
|
||||
ButtonStyle,
|
||||
CommandInteraction,
|
||||
CommandOptionType,
|
||||
Context,
|
||||
CreateActionRow,
|
||||
CreateButton,
|
||||
CreateCommand,
|
||||
CreateCommandOption,
|
||||
CreateMessage,
|
||||
Mentionable,
|
||||
UserId,
|
||||
};
|
||||
use serenity::builder::CreateEmbed;
|
||||
use crate::bot::chat::{create_message_response, edit_response};
|
||||
use crate::bot::commands::fun::roll::parse_dice;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Check if the roll result is hidden
|
||||
@@ -1,14 +1,20 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
use rand::Rng;
|
||||
use serenity::all::{
|
||||
ButtonStyle, CommandInteraction, CommandOptionType, Context, CreateActionRow, CreateButton,
|
||||
CreateCommand, CreateCommandOption, CreateEmbed, CreateMessage, Mentionable, UserId,
|
||||
use crate::{
|
||||
chat::{create_message_response, edit_response},
|
||||
error::{Error, Result},
|
||||
};
|
||||
|
||||
use crate::bot::chat::{create_message_response, edit_response};
|
||||
use crate::error::{Error, SirenResult};
|
||||
use crate::utils::{a_or_an, number_to_words};
|
||||
use rand::RngExt;
|
||||
use serenity::all::{
|
||||
CommandInteraction,
|
||||
CommandOptionType,
|
||||
Context,
|
||||
CreateCommand,
|
||||
CreateCommandOption,
|
||||
CreateEmbed,
|
||||
CreateMessage,
|
||||
Mentionable,
|
||||
UserId,
|
||||
};
|
||||
use siren_core::utils::{a_or_an, number_to_words};
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
// Check if the roll result is private
|
||||
@@ -118,14 +124,14 @@ pub fn roll_dice(count: u32, sides: u32, modifier: i32) -> i32 {
|
||||
let mut rolls = Vec::new();
|
||||
let mut total = modifier;
|
||||
for _ in 0..count {
|
||||
let roll = rand::thread_rng().gen_range(1..=sides as i32);
|
||||
let roll = rand::rng().random_range(1..=sides as i32);
|
||||
total += roll;
|
||||
rolls.push(roll);
|
||||
}
|
||||
total
|
||||
}
|
||||
|
||||
pub fn parse_dice(dice: &str) -> SirenResult<(u32, u32, i32)> {
|
||||
pub fn parse_dice(dice: &str) -> Result<(u32, u32, i32)> {
|
||||
// If the input is just a number (e.g., "20" or "6"), assume it's the number of sides
|
||||
if let Ok(n) = dice.parse::<u32>() {
|
||||
return Ok((1, n, 0)); // Assume 1 dice with 0 modifiers
|
||||
@@ -176,7 +182,7 @@ pub fn parse_dice(dice: &str) -> SirenResult<(u32, u32, i32)> {
|
||||
"Expected one of d4, d6, d8, d10, d12, d20, d100 but received d{}",
|
||||
sides_part
|
||||
),
|
||||
))
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
pub mod audio;
|
||||
pub mod chat;
|
||||
pub mod event;
|
||||
pub mod fun;
|
||||
pub mod utility;
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::chat::create_message_response;
|
||||
use serenity::all::{CommandInteraction, Context, CreateCommand};
|
||||
use crate::bot::chat::create_message_response;
|
||||
|
||||
pub async fn run(ctx: &Context, command: &CommandInteraction) {
|
||||
log::debug!("Ping command executed");
|
||||
89
crates/siren-bot/src/error.rs
Normal file
89
crates/siren-bot/src/error.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Error {
|
||||
pub status: u16,
|
||||
pub details: String,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn new(status: u16, details: String) -> Self {
|
||||
Self { status, details }
|
||||
}
|
||||
|
||||
pub fn not_found(details: String) -> Self {
|
||||
Self::new(404, details)
|
||||
}
|
||||
|
||||
pub fn internal_server_error(details: String) -> Self {
|
||||
Self::new(500, details)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(self.details.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn description(&self) -> &str {
|
||||
&self.details
|
||||
}
|
||||
}
|
||||
|
||||
impl From<siren_core::error::Error> for Error {
|
||||
fn from(error: siren_core::error::Error) -> Self {
|
||||
Self::new(error.status, error.details)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serenity::Error> for Error {
|
||||
fn from(error: serenity::Error) -> Self {
|
||||
Self::new(500, format!("Discord error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<songbird::error::JoinError> for Error {
|
||||
fn from(error: songbird::error::JoinError) -> Self {
|
||||
use std::error::Error as StdError;
|
||||
let details = match error.source() {
|
||||
Some(source) => format!("Unable to join channel: {} ({})", error, source),
|
||||
None => format!("Unable to join channel: {}", error),
|
||||
};
|
||||
Self::new(500, details)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<songbird::tracks::ControlError> for Error {
|
||||
fn from(error: songbird::tracks::ControlError) -> Self {
|
||||
Self::new(500, format!("Track control error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(error: std::io::Error) -> Self {
|
||||
Self::new(500, format!("IO error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::string::FromUtf8Error> for Error {
|
||||
fn from(error: std::string::FromUtf8Error) -> Self {
|
||||
Self::new(500, format!("UTF-8 error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for Error {
|
||||
fn from(error: reqwest::Error) -> Self {
|
||||
Self::new(500, format!("HTTP client error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(error: serde_json::Error) -> Self {
|
||||
Self::new(500, format!("JSON error: {}", error))
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,30 @@
|
||||
use std::env;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use serenity::all::{CreateInteractionResponse, CreateInteractionResponseMessage, EditInteractionResponse, Interaction, ResumedEvent, UnavailableGuild, UserId};
|
||||
use serenity::async_trait;
|
||||
use serenity::model::gateway::Ready;
|
||||
use serenity::model::channel::Message;
|
||||
use serenity::prelude::*;
|
||||
use super::{chat::create_modal_response, commands};
|
||||
use crate::{
|
||||
HttpKey,
|
||||
commands::fun::roll::{format_roll, roll_dice, send_roll_message},
|
||||
};
|
||||
use serenity::{
|
||||
all::{
|
||||
CreateInteractionResponse,
|
||||
EditInteractionResponse,
|
||||
Interaction,
|
||||
ResumedEvent,
|
||||
UnavailableGuild,
|
||||
UserId,
|
||||
},
|
||||
async_trait,
|
||||
model::{channel::Message, gateway::Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use siren_core::{
|
||||
data::guilds::GuildCache,
|
||||
utils::{a_or_an, number_to_words},
|
||||
};
|
||||
use songbird::Songbird;
|
||||
use crate::bot::commands::chat::generate_response;
|
||||
use crate::bot::commands::fun::roll::{format_roll, roll_dice, send_roll_message};
|
||||
use crate::bot::oai::OAI;
|
||||
use crate::data::guilds::GuildCache;
|
||||
use crate::HttpKey;
|
||||
use crate::utils::{a_or_an, number_to_words};
|
||||
use super::{commands};
|
||||
use super::chat::{create_modal_response};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
pub struct BotHandler {
|
||||
// Open AI Config
|
||||
pub oai: Option<OAI>,
|
||||
pub force_register: bool,
|
||||
}
|
||||
|
||||
static REGISTERED: OnceLock<bool> = OnceLock::new();
|
||||
@@ -33,34 +40,14 @@ pub fn get_client() -> &'static reqwest::Client {
|
||||
}
|
||||
|
||||
impl BotHandler {
|
||||
pub fn new() -> Self {
|
||||
match env::var("OPENAI_TOKEN") {
|
||||
Ok(token) => {
|
||||
log::debug!("OpenAI functionality enabled");
|
||||
let default_model = env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string());
|
||||
let base_url = env::var("OPENAI_BASE_URL").unwrap();
|
||||
Self {
|
||||
oai: Some(OAI {
|
||||
client: reqwest::Client::new(),
|
||||
base_url,
|
||||
token,
|
||||
max_conversation_history: 30,
|
||||
max_tokens: 8192,
|
||||
default_model,
|
||||
}),
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
log::warn!("OpenAI functionality disabled");
|
||||
Self { oai: None }
|
||||
}
|
||||
}
|
||||
pub fn new(force_register: bool) -> Self {
|
||||
Self { force_register }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EventHandler for BotHandler {
|
||||
async fn message(&self, ctx: Context, msg: Message) {
|
||||
async fn message(&self, _ctx: Context, msg: Message) {
|
||||
// Ignore bot messages
|
||||
if msg.author.bot {
|
||||
return;
|
||||
@@ -70,14 +57,6 @@ impl EventHandler for BotHandler {
|
||||
if let None = msg.guild_id {
|
||||
log::trace!("Received DM from {}: {}", msg.author, msg.content);
|
||||
}
|
||||
|
||||
// Handle OAI messages
|
||||
match &self.oai {
|
||||
Some(oai) => {
|
||||
handle_oai_messages(oai, &ctx, &msg).await;
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
|
||||
async fn ready(&self, ctx: Context, ready: Ready) {
|
||||
@@ -109,9 +88,9 @@ impl EventHandler for BotHandler {
|
||||
REGISTERED.set(true).ok();
|
||||
}
|
||||
|
||||
log::trace!("Handling {} guilds", ready.guilds.len());
|
||||
log::debug!("Registering in {} guild(s)", ready.guilds.len());
|
||||
for guild in ready.guilds {
|
||||
update_guild_commands(&ctx, &guild).await;
|
||||
update_guild_commands(&ctx, &guild, self.force_register).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +100,11 @@ impl EventHandler for BotHandler {
|
||||
|
||||
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
|
||||
if let Interaction::Command(command) = interaction {
|
||||
log::trace!("Received COMMAND");
|
||||
log::trace!(
|
||||
"<{}> Received command: {}",
|
||||
command.guild_id.unwrap(),
|
||||
command.data.name
|
||||
);
|
||||
match command.data.name.as_str() {
|
||||
// Match commands without returns
|
||||
"play" => commands::audio::play::run(&ctx, &command).await,
|
||||
@@ -203,25 +186,7 @@ impl EventHandler for BotHandler {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_oai_messages(oai: &OAI, ctx: &Context, msg: &Message) {
|
||||
match msg.mentions_me(&ctx.http).await {
|
||||
Ok(mentioned) => {
|
||||
let bot_in_thread = match msg.channel_id.get_thread_members(&ctx.http).await {
|
||||
Ok(t) => match t.iter().find(|t| t.user_id == ctx.cache.current_user().id) {
|
||||
Some(_) => true,
|
||||
None => false,
|
||||
},
|
||||
Err(_) => false,
|
||||
};
|
||||
if mentioned || bot_in_thread {
|
||||
generate_response(&ctx, &msg, oai).await;
|
||||
}
|
||||
}
|
||||
Err(why) => log::warn!("Could not check mentions: {why}"),
|
||||
};
|
||||
}
|
||||
|
||||
async fn update_guild_commands(ctx: &Context, guild: &UnavailableGuild) {
|
||||
async fn update_guild_commands(ctx: &Context, guild: &UnavailableGuild, force_register: bool) {
|
||||
// List of commands to register for the guild
|
||||
let guild_commands = vec![
|
||||
commands::audio::play::register(),
|
||||
@@ -239,14 +204,7 @@ async fn update_guild_commands(ctx: &Context, guild: &UnavailableGuild) {
|
||||
|
||||
let guild_id = guild.id.get() as i64;
|
||||
let register_commands = match GuildCache::find_by_id(guild_id).await {
|
||||
Some(_) => {
|
||||
env::var("FORCE_REGISTER")
|
||||
.ok()
|
||||
// Parse to true/false
|
||||
.map(|val| val.to_lowercase() == "true")
|
||||
// Default to true on error
|
||||
.unwrap_or(true)
|
||||
}
|
||||
Some(_) => force_register,
|
||||
None => {
|
||||
// If no guild cache is found, create a new one.
|
||||
let guild_cache = GuildCache {
|
||||
@@ -259,7 +217,7 @@ async fn update_guild_commands(ctx: &Context, guild: &UnavailableGuild) {
|
||||
log::error!("Could not insert guild cache: {err}");
|
||||
};
|
||||
true
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
if register_commands {
|
||||
@@ -281,6 +239,6 @@ async fn update_guild_commands(ctx: &Context, guild: &UnavailableGuild) {
|
||||
}
|
||||
};
|
||||
} else {
|
||||
log::debug!("Guild {guild_id} already registered");
|
||||
log::debug!("Guild {guild_id} is already registered");
|
||||
}
|
||||
}
|
||||
14
crates/siren-bot/src/lib.rs
Normal file
14
crates/siren-bot/src/lib.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
pub mod chat;
|
||||
pub mod commands;
|
||||
pub mod error;
|
||||
pub mod handler;
|
||||
pub mod ytdlp;
|
||||
|
||||
use reqwest::Client as HttpClient;
|
||||
use serenity::prelude::TypeMapKey;
|
||||
|
||||
pub struct HttpKey;
|
||||
|
||||
impl TypeMapKey for HttpKey {
|
||||
type Value = HttpClient;
|
||||
}
|
||||
@@ -1,8 +1,7 @@
|
||||
mod model;
|
||||
|
||||
use std::process::{Child, Command, Output, Stdio};
|
||||
|
||||
pub use model::*;
|
||||
use std::process::{Child, Command, Output, Stdio};
|
||||
|
||||
const YOUTUBE_DL_COMMAND: &str = "yt-dlp";
|
||||
|
||||
@@ -7,14 +7,14 @@ pub enum YtDlpItem {
|
||||
id: String,
|
||||
url: String,
|
||||
title: String,
|
||||
duration: i32,
|
||||
playlist_index: i32,
|
||||
duration: Option<f64>,
|
||||
playlist_index: Option<i32>,
|
||||
},
|
||||
VideoItem {
|
||||
id: String,
|
||||
webpage_url: String,
|
||||
title: String,
|
||||
duration: i32,
|
||||
duration: Option<f64>,
|
||||
},
|
||||
}
|
||||
|
||||
21
crates/siren-core/Cargo.toml
Normal file
21
crates/siren-core/Cargo.toml
Normal file
@@ -0,0 +1,21 @@
|
||||
[package]
|
||||
name = "siren-core"
|
||||
edition.workspace = true
|
||||
version.workspace = true
|
||||
rust-version.workspace = true
|
||||
authors.workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
log = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rand_chacha = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
81
crates/siren-core/src/config.rs
Normal file
81
crates/siren-core/src/config.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use crate::error::Result;
|
||||
use std::env;
|
||||
|
||||
pub struct EnvironmentConfiguration {
|
||||
pub rust_log: String,
|
||||
pub discord_token: String,
|
||||
pub discord_secret: String,
|
||||
pub jwt_secret: String,
|
||||
pub postgres_user: String,
|
||||
pub postgres_password: String,
|
||||
pub postgres_database: String,
|
||||
pub postgres_host: String,
|
||||
pub postgres_port: u16,
|
||||
pub api_base_url: String,
|
||||
pub api_port: u16,
|
||||
pub api_session_ttl: u64,
|
||||
pub valkey_host: String,
|
||||
pub valkey_port: u16,
|
||||
pub minio_root_user: String,
|
||||
pub minio_root_password: String,
|
||||
pub minio_host: String,
|
||||
pub minio_port: u16,
|
||||
pub minio_port_internal: u16,
|
||||
pub data_dir_path: Option<String>,
|
||||
pub force_register: bool,
|
||||
pub default_api_key: String,
|
||||
pub default_server: Option<String>,
|
||||
pub default_user: Option<String>,
|
||||
}
|
||||
|
||||
impl EnvironmentConfiguration {
|
||||
pub fn load() -> Result<Self> {
|
||||
Ok(Self {
|
||||
rust_log: env::var("RUST_LOG").unwrap_or_else(|_| "warn,siren=info".to_string()),
|
||||
discord_token: env::var("DISCORD_BOT_TOKEN")?,
|
||||
discord_secret: env::var("DISCORD_CLIENT_SECRET")?,
|
||||
jwt_secret: env::var("JWT_SECRET")?,
|
||||
postgres_user: env::var("POSTGRES_USER")?,
|
||||
postgres_password: env::var("POSTGRES_PASSWORD")?,
|
||||
postgres_database: env::var("POSTGRES_DB")?,
|
||||
postgres_host: env::var("POSTGRES_HOST")?,
|
||||
postgres_port: env::var("POSTGRES_PORT")
|
||||
.unwrap_or_else(|_| "5432".to_string())
|
||||
.parse()
|
||||
.unwrap_or(5432),
|
||||
api_base_url: env::var("API_BASE_URL")?,
|
||||
api_port: env::var("API_PORT")
|
||||
.unwrap_or_else(|_| "3000".to_string())
|
||||
.parse()
|
||||
.unwrap_or(3000),
|
||||
api_session_ttl: env::var("API_SESSION_TTL")
|
||||
.unwrap_or_else(|_| "86400".to_string())
|
||||
.parse()
|
||||
.unwrap_or(86400),
|
||||
valkey_host: env::var("VALKEY_HOST").unwrap_or_else(|_| "localhost".to_string()),
|
||||
valkey_port: env::var("VALKEY_PORT")
|
||||
.unwrap_or_else(|_| "6379".to_string())
|
||||
.parse()
|
||||
.unwrap_or(6379),
|
||||
minio_root_user: env::var("MINIO_ROOT_USER")?,
|
||||
minio_root_password: env::var("MINIO_ROOT_PASSWORD")?,
|
||||
minio_host: env::var("MINIO_HOST").unwrap_or_else(|_| "localhost".to_string()),
|
||||
minio_port: env::var("MINIO_PORT")
|
||||
.unwrap_or_else(|_| "9000".to_string())
|
||||
.parse()
|
||||
.unwrap_or(9000),
|
||||
minio_port_internal: env::var("MINIO_PORT_INTERNAL")
|
||||
.unwrap_or_else(|_| "9001".to_string())
|
||||
.parse()
|
||||
.unwrap_or(9001),
|
||||
data_dir_path: env::var("DATA_DIR_PATH").ok().filter(|s| !s.is_empty()),
|
||||
force_register: env::var("FORCE_REGISTER")
|
||||
.ok()
|
||||
.map(|v| v.to_lowercase() == "true")
|
||||
.unwrap_or(false),
|
||||
default_api_key: env::var("DEFAULT_API_KEY").unwrap_or_default(),
|
||||
default_server: env::var("DEFAULT_SERVER").ok().filter(|s| !s.is_empty()),
|
||||
default_user: env::var("DEFAULT_USER").ok().filter(|s| !s.is_empty()),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -56,7 +56,7 @@ impl Condition {
|
||||
|
||||
let right_list = right
|
||||
.iter()
|
||||
.map(|v| "'?'".to_string())
|
||||
.map(|_v| "'?'".to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
Condition::Simple(format!("{} IN ({})", left, right_list), right)
|
||||
@@ -70,7 +70,7 @@ impl Condition {
|
||||
|
||||
let right_list = right
|
||||
.iter()
|
||||
.map(|v| "'?'".to_string())
|
||||
.map(|_v| "'?'".to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
Condition::Simple(format!("{} NOT IN ({})", left, right_list), right)
|
||||
@@ -194,7 +194,7 @@ impl Condition {
|
||||
None
|
||||
}
|
||||
|
||||
pub fn to_sql(&self, mut counter: &mut usize) -> (String, Vec<Value>) {
|
||||
pub fn to_sql(&self, counter: &mut usize) -> (String, Vec<Value>) {
|
||||
let mut sql = String::new();
|
||||
let mut binds = Vec::new();
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use crate::error::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::SirenResult;
|
||||
|
||||
const TABLE_NAME: &str = "events";
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
|
||||
@@ -18,7 +17,7 @@ pub struct Event {
|
||||
}
|
||||
|
||||
impl Event {
|
||||
pub async fn insert(&self) -> SirenResult<()> {
|
||||
pub async fn insert(&self) -> Result<()> {
|
||||
let pool = crate::data::pool();
|
||||
sqlx::query(&format!(
|
||||
"INSERT INTO {} (
|
||||
@@ -46,7 +45,7 @@ impl Event {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_by_id(id: i64) -> SirenResult<Option<Self>> {
|
||||
pub async fn get_by_id(id: i64) -> Result<Option<Self>> {
|
||||
let pool = crate::data::pool();
|
||||
let item = sqlx::query_as::<_, Self>(&format!("SELECT * FROM {} WHERE id = $1", TABLE_NAME))
|
||||
.bind(id)
|
||||
@@ -1,6 +1,7 @@
|
||||
use sqlx::{FromRow, Postgres};
|
||||
use crate::data::Value;
|
||||
use sqlx::{FromRow, Postgres};
|
||||
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait ExecutableQuery {
|
||||
fn build(&self) -> (String, Vec<Value>);
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use sqlx::Database;
|
||||
use crate::data::condition::Condition;
|
||||
use crate::data::executable_query::ExecutableQuery;
|
||||
use crate::data::insert::InsertBuilder;
|
||||
use crate::data::query::QueryBuilder;
|
||||
use crate::data::update::UpdateBuilder;
|
||||
use crate::data::Value;
|
||||
use crate::error::SirenResult;
|
||||
use crate::{
|
||||
data::{
|
||||
Value,
|
||||
condition::Condition,
|
||||
executable_query::ExecutableQuery,
|
||||
insert::InsertBuilder,
|
||||
query::QueryBuilder,
|
||||
update::UpdateBuilder,
|
||||
},
|
||||
error::Result,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const TABLE_NAME: &str = "guilds";
|
||||
|
||||
@@ -19,7 +22,7 @@ pub struct GuildCache {
|
||||
}
|
||||
|
||||
impl GuildCache {
|
||||
pub async fn insert(&self) -> SirenResult<()> {
|
||||
pub async fn insert(&self) -> Result<()> {
|
||||
InsertBuilder::new(TABLE_NAME)
|
||||
.column("id", Value::BigInt(self.id))
|
||||
.column("name", Value::OptionalText(self.name.clone()))
|
||||
@@ -37,7 +40,7 @@ impl GuildCache {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn update(&self) -> SirenResult<()> {
|
||||
pub async fn update(&self) -> Result<()> {
|
||||
UpdateBuilder::new(TABLE_NAME)
|
||||
.column("name", Value::OptionalText(self.name.clone()))
|
||||
.column("owner_id", Value::OptionalBigInt(self.owner_id))
|
||||
@@ -1,5 +1,4 @@
|
||||
use crate::data::executable_query::ExecutableQuery;
|
||||
use crate::data::Value;
|
||||
use crate::data::{Value, executable_query::ExecutableQuery};
|
||||
|
||||
pub struct InsertBuilder {
|
||||
table: String,
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::error::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::error::SirenResult;
|
||||
|
||||
const TABLE_NAME: &str = "messages";
|
||||
|
||||
@@ -18,7 +18,7 @@ pub struct MessageCache {
|
||||
}
|
||||
|
||||
impl MessageCache {
|
||||
pub async fn insert(&self) -> SirenResult<()> {
|
||||
pub async fn insert(&self) -> Result<()> {
|
||||
let pool = crate::data::pool();
|
||||
sqlx::query(&format!(
|
||||
"INSERT INTO {} (
|
||||
@@ -57,7 +57,7 @@ impl MessageCache {
|
||||
channel_id: i64,
|
||||
author_id: i64,
|
||||
limit: i64,
|
||||
) -> SirenResult<Vec<MessageCache>> {
|
||||
) -> Result<Vec<MessageCache>> {
|
||||
let pool = crate::data::pool();
|
||||
let messages = sqlx::query_as::<_, MessageCache>(&format!(
|
||||
"SELECT * FROM {} WHERE guild_id = $1 AND channel_id = $2 AND author_id = $3 ORDER BY created ASC LIMIT $4",
|
||||
@@ -1,9 +1,8 @@
|
||||
use std::{fmt, sync::OnceLock, time::Duration};
|
||||
use std::fmt::Display;
|
||||
use crate::error::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use redis::{aio::MultiplexedConnection as RedisConnection, Client as RedisClient, RedisResult};
|
||||
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
|
||||
use crate::error::SirenResult;
|
||||
use redis::{Client as RedisClient, RedisResult, aio::MultiplexedConnection as RedisConnection};
|
||||
use sqlx::{Pool, Postgres, postgres::PgPoolOptions};
|
||||
use std::{fmt, fmt::Display, sync::OnceLock, time::Duration};
|
||||
|
||||
pub mod condition;
|
||||
pub mod events;
|
||||
@@ -13,18 +12,14 @@ pub mod insert;
|
||||
pub mod messages;
|
||||
pub mod query;
|
||||
pub mod update;
|
||||
use crate::config::EnvironmentConfiguration;
|
||||
pub use executable_query::ExecutableQuery;
|
||||
|
||||
static POOL: OnceLock<Pool<Postgres>> = OnceLock::new();
|
||||
static REDIS: OnceLock<RedisClient> = OnceLock::new();
|
||||
|
||||
pub async fn initialize() -> SirenResult<()> {
|
||||
pub async fn initialize(config: &EnvironmentConfiguration) -> Result<()> {
|
||||
log::info!("Initializing database...");
|
||||
let db_user = std::env::var("DATABASE_USER").unwrap_or("siren".to_string());
|
||||
let db_password = std::env::var("DATABASE_PASSWORD").expect("DATABASE_PASSWORD must be set");
|
||||
let db_host: String = std::env::var("DATABASE_HOST").expect("DATABASE_HOST must be set");
|
||||
let db_port = std::env::var("DATABASE_PORT").unwrap_or("5432".to_string());
|
||||
let db_name = std::env::var("DATABASE_NAME").unwrap_or("siren".to_string());
|
||||
|
||||
// Setup Postgres pool connection
|
||||
let pool = PgPoolOptions::new()
|
||||
@@ -32,7 +27,11 @@ pub async fn initialize() -> SirenResult<()> {
|
||||
.acquire_timeout(Duration::from_secs(30))
|
||||
.connect(&format!(
|
||||
"postgres://{}:{}@{}:{}/{}",
|
||||
db_user, db_password, db_host, db_port, db_name
|
||||
config.postgres_user,
|
||||
config.postgres_password,
|
||||
config.postgres_host,
|
||||
config.postgres_port,
|
||||
config.postgres_database
|
||||
))
|
||||
.await?;
|
||||
match POOL.set(pool) {
|
||||
@@ -44,15 +43,15 @@ pub async fn initialize() -> SirenResult<()> {
|
||||
|
||||
// Setup Redis connection
|
||||
let redis = {
|
||||
let host = std::env::var("REDIS_HOST").unwrap_or("localhost".to_string());
|
||||
let port = std::env::var("REDIS_PORT").unwrap_or("6379".to_string());
|
||||
let host = std::env::var("VALKEY_HOST").unwrap_or("localhost".to_string());
|
||||
let port = std::env::var("VALKEY_PORT").unwrap_or("6379".to_string());
|
||||
let url = format!("redis://{}:{}", host, port);
|
||||
RedisClient::open(url).expect("Failed to create redis client")
|
||||
RedisClient::open(url).expect("Failed to create valkey client")
|
||||
};
|
||||
match REDIS.set(redis) {
|
||||
Ok(_) => {}
|
||||
Err(_) => {
|
||||
log::warn!("Redis client already initialized");
|
||||
log::warn!("Valkey client already initialized");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,10 +83,10 @@ pub async fn redis_async_connection() -> RedisResult<RedisConnection> {
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
async fn run_migrations() -> SirenResult<()> {
|
||||
async fn run_migrations() -> Result<()> {
|
||||
log::debug!("Running migrations");
|
||||
let pool = pool();
|
||||
sqlx::migrate!().run(pool).await?;
|
||||
sqlx::migrate!("../../migrations").run(pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
use std::fmt::Write;
|
||||
use crate::data::condition::Condition;
|
||||
use crate::data::executable_query::ExecutableQuery;
|
||||
use crate::data::Value;
|
||||
use crate::data::{Value, condition::Condition, executable_query::ExecutableQuery};
|
||||
|
||||
pub struct QueryBuilder<'a> {
|
||||
table: &'a str,
|
||||
@@ -43,7 +40,9 @@ impl<'a> QueryBuilder<'a> {
|
||||
|
||||
pub fn order_by(mut self, column: &str, direction: Option<OrderDirection>) -> Self {
|
||||
match direction {
|
||||
Some(order) => self.order_by.push(format!("{} {}", column, order.to_string())),
|
||||
Some(order) => self
|
||||
.order_by
|
||||
.push(format!("{} {}", column, order.to_string())),
|
||||
None => self.order_by.push(column.to_string()),
|
||||
}
|
||||
self
|
||||
@@ -57,7 +56,7 @@ impl<'a> QueryBuilder<'a> {
|
||||
|
||||
pub enum OrderDirection {
|
||||
Asc,
|
||||
Desc
|
||||
Desc,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for OrderDirection {
|
||||
@@ -78,14 +77,12 @@ impl<'a> ExecutableQuery for QueryBuilder<'a> {
|
||||
self.columns.join(",")
|
||||
};
|
||||
|
||||
let mut query = String::new();
|
||||
|
||||
if let Some(distinct_columns) = &self.distinct_on {
|
||||
let mut query = if let Some(distinct_columns) = &self.distinct_on {
|
||||
let distinct_on_clause = distinct_columns.join(",");
|
||||
query = format!("SELECT DISTINCT ON ({}) {}", distinct_on_clause, columns);
|
||||
format!("SELECT DISTINCT ON ({}) {}", distinct_on_clause, columns)
|
||||
} else {
|
||||
query = format!("SELECT {}", columns);
|
||||
}
|
||||
format!("SELECT {}", columns)
|
||||
};
|
||||
|
||||
query.push_str(format!(" FROM {}", self.table).as_str());
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
use crate::data::condition::Condition;
|
||||
use crate::data::executable_query::ExecutableQuery;
|
||||
use crate::data::Value;
|
||||
use crate::data::{Value, condition::Condition, executable_query::ExecutableQuery};
|
||||
|
||||
pub struct UpdateBuilder {
|
||||
table: String,
|
||||
108
crates/siren-core/src/error.rs
Normal file
108
crates/siren-core/src/error.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Error {
|
||||
pub status: u16,
|
||||
pub details: String,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn new(status: u16, details: String) -> Self {
|
||||
Self { status, details }
|
||||
}
|
||||
|
||||
pub fn not_found(details: String) -> Self {
|
||||
Self::new(404, details)
|
||||
}
|
||||
|
||||
pub fn internal_server_error(details: String) -> Self {
|
||||
Self::new(500, details)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(self.details.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn description(&self) -> &str {
|
||||
&self.details
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(error: std::io::Error) -> Self {
|
||||
Self::new(500, format!("IO error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::string::FromUtf8Error> for Error {
|
||||
fn from(error: std::string::FromUtf8Error) -> Self {
|
||||
Self::new(500, format!("UTF-8 error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::env::VarError> for Error {
|
||||
fn from(error: std::env::VarError) -> Self {
|
||||
Self::new(500, format!("Environment variable error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<sqlx::Error> for Error {
|
||||
fn from(error: sqlx::Error) -> Self {
|
||||
match error {
|
||||
sqlx::Error::RowNotFound => Error::new(404, "Not found".to_string()),
|
||||
sqlx::Error::ColumnIndexOutOfBounds { .. } => Error::new(422, error.to_string()),
|
||||
sqlx::Error::ColumnNotFound { .. } => Error::new(422, error.to_string()),
|
||||
sqlx::Error::ColumnDecode { .. } => Error::new(422, error.to_string()),
|
||||
sqlx::Error::Decode(_) => Error::new(422, error.to_string()),
|
||||
sqlx::Error::PoolTimedOut => Error::new(503, error.to_string()),
|
||||
sqlx::Error::PoolClosed => Error::new(503, error.to_string()),
|
||||
sqlx::Error::Database(err) => {
|
||||
if let Some(code) = err.code() {
|
||||
match code.trim() {
|
||||
"23505" => return Error::new(409, err.to_string()),
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
Error::new(500, err.to_string())
|
||||
}
|
||||
_ => Error::new(500, error.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<sqlx::migrate::MigrateError> for Error {
|
||||
fn from(error: sqlx::migrate::MigrateError) -> Self {
|
||||
Error::new(500, error.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<redis::RedisError> for Error {
|
||||
fn from(error: redis::RedisError) -> Self {
|
||||
Self::new(500, format!("Redis error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for Error {
|
||||
fn from(error: reqwest::Error) -> Self {
|
||||
Self::new(500, format!("HTTP client error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(error: serde_json::Error) -> Self {
|
||||
Self::new(500, format!("JSON error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<uuid::Error> for Error {
|
||||
fn from(error: uuid::Error) -> Self {
|
||||
Self::new(500, format!("UUID error: {}", error))
|
||||
}
|
||||
}
|
||||
4
crates/siren-core/src/lib.rs
Normal file
4
crates/siren-core/src/lib.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod config;
|
||||
pub mod data;
|
||||
pub mod error;
|
||||
pub mod utils;
|
||||
20
crates/siren/Cargo.toml
Normal file
20
crates/siren/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "siren"
|
||||
edition.workspace = true
|
||||
version.workspace = true
|
||||
rust-version.workspace = true
|
||||
authors.workspace = true
|
||||
|
||||
[dependencies]
|
||||
siren-core = { workspace = true }
|
||||
siren-bot = { workspace = true }
|
||||
siren-api = { workspace = true }
|
||||
dotenv = { workspace = true }
|
||||
log = { workspace = true }
|
||||
env_logger = { workspace = true }
|
||||
serenity = { workspace = true }
|
||||
songbird = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
# Add the `signal` feature on top of the workspace base for graceful shutdown
|
||||
tokio = { workspace = true, features = ["signal"] }
|
||||
dashmap = { workspace = true }
|
||||
@@ -1,58 +1,32 @@
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use dashmap::DashMap;
|
||||
use dotenv::from_filename;
|
||||
use serenity::http::Http;
|
||||
use serenity::prelude::*;
|
||||
use songbird::{SerenityInit, Songbird};
|
||||
use reqwest::Client as HttpClient;
|
||||
use serenity::all::{Cache, ShardManager, UserId};
|
||||
use crate::api::App;
|
||||
use crate::bot::handler::BotHandler;
|
||||
use crate::error::{Error, SirenResult};
|
||||
|
||||
mod api;
|
||||
mod bot;
|
||||
mod data;
|
||||
mod error;
|
||||
mod utils;
|
||||
|
||||
pub struct HttpKey;
|
||||
|
||||
impl TypeMapKey for HttpKey {
|
||||
type Value = HttpClient;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
client: reqwest::Client,
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
redirect_uri: String,
|
||||
oauth_states: Arc<Mutex<HashSet<String>>>,
|
||||
http: Arc<Http>,
|
||||
cache: Arc<Cache>,
|
||||
}
|
||||
use serenity::{
|
||||
all::{ShardManager, UserId},
|
||||
http::Http,
|
||||
prelude::*,
|
||||
};
|
||||
use siren_api::{App, AppState};
|
||||
use siren_bot::{HttpKey, handler::BotHandler};
|
||||
use siren_core::{
|
||||
config::EnvironmentConfiguration,
|
||||
error::{Error, Result},
|
||||
};
|
||||
use songbird::{SerenityInit, Songbird};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Run initialization
|
||||
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
initialize_environment()?;
|
||||
data::initialize().await?;
|
||||
let config = EnvironmentConfiguration::load()?;
|
||||
siren_core::data::initialize(&config).await?;
|
||||
|
||||
let token: String = env::var("DISCORD_TOKEN").expect("Expected a token in the environment");
|
||||
|
||||
// Set up handler with optional OpenAI integration
|
||||
let handler = BotHandler::new();
|
||||
|
||||
// Set up Songbird for voice functionality
|
||||
let handler = BotHandler::new(config.force_register);
|
||||
let songbird = Songbird::serenity();
|
||||
|
||||
let intents: GatewayIntents = GatewayIntents::all();
|
||||
|
||||
let mut client = Client::builder(token, intents)
|
||||
let mut client = Client::builder(&config.discord_token, intents)
|
||||
.event_handler(handler)
|
||||
// .framework(StandardFramework::new().configure(|c| c.owners(owners)))
|
||||
.register_songbird_with(Arc::clone(&songbird))
|
||||
.type_map_insert::<HttpKey>(HttpClient::new())
|
||||
.await
|
||||
@@ -60,18 +34,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
let (bot_owner, bot_id) = get_bot_info(&client.http).await?;
|
||||
|
||||
let client_secret: String =
|
||||
env::var("DISCORD_SECRET").expect("Expected a secret in the environment");
|
||||
let redirect_uri: String =
|
||||
env::var("API_CALLBACK_URI").expect("Expected a callback uri in the environment");
|
||||
let app_state = AppState {
|
||||
client: HttpClient::new(),
|
||||
client_id: bot_id.to_string(),
|
||||
client_secret,
|
||||
redirect_uri,
|
||||
oauth_states: Arc::new(Mutex::new(HashSet::new())),
|
||||
client_secret: config.discord_secret,
|
||||
base_url: config.api_base_url,
|
||||
discord_authorize_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
http: Arc::clone(&client.http),
|
||||
cache: Arc::clone(&client.cache),
|
||||
map_rooms: Arc::new(DashMap::new()),
|
||||
};
|
||||
|
||||
log::debug!(
|
||||
@@ -79,16 +50,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
bot_owner
|
||||
);
|
||||
|
||||
// Spawn shutdown signal handling
|
||||
let shard_manager = Arc::clone(&client.shard_manager);
|
||||
tokio::spawn(async move {
|
||||
signal_shutdown(shard_manager).await;
|
||||
});
|
||||
|
||||
// Start API server
|
||||
tokio::spawn(App::new(app_state).serve());
|
||||
|
||||
// Start Discord bot
|
||||
if let Err(why) = client.start_autosharded().await {
|
||||
log::error!("Client error: {why:?}");
|
||||
}
|
||||
@@ -97,15 +65,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
|
||||
fn initialize_environment() -> std::io::Result<()> {
|
||||
// Iterate over files in the current directory
|
||||
for entry in std::fs::read_dir(".")? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
// Check if the file name starts with ".env" and is a file
|
||||
if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
|
||||
if file_name.starts_with(".env") && path.is_file() {
|
||||
// Try to load the file
|
||||
if file_name.starts_with(".env") && !file_name.ends_with(".example") && path.is_file() {
|
||||
if let Err(err) = from_filename(&file_name) {
|
||||
eprintln!("Failed to load {}: {}", file_name, err);
|
||||
} else {
|
||||
@@ -114,12 +78,11 @@ fn initialize_environment() -> std::io::Result<()> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_bot_info(http: &Http) -> SirenResult<(Option<UserId>, UserId)> {
|
||||
async fn get_bot_info(http: &Http) -> Result<(Option<UserId>, UserId)> {
|
||||
match http.get_current_application_info().await {
|
||||
Ok(info) => {
|
||||
let bot_owner;
|
||||
@@ -1,69 +1,56 @@
|
||||
x-env_file: &env
|
||||
- path: .env
|
||||
required: true
|
||||
- path: .env.local
|
||||
required: false
|
||||
|
||||
x-restart: &default_restart
|
||||
restart: unless-stopped
|
||||
|
||||
name: siren
|
||||
services:
|
||||
bot:
|
||||
app:
|
||||
image: siren:${SIREN_VERSION:-latest}
|
||||
container_name: siren-bot
|
||||
container_name: siren-app
|
||||
env_file: *env
|
||||
environment:
|
||||
DATABASE_HOST: siren-postgres
|
||||
DATABASE_PORT: 5432
|
||||
REDIS_HOST: siren-redis
|
||||
REDIS_PORT: 6379
|
||||
VALKEY_HOST: siren-valkey
|
||||
VALKEY_PORT: 6379
|
||||
DATA_DIR_PATH: /data
|
||||
volumes:
|
||||
- ${DATA_DIR_PATH:-~/data}:/data
|
||||
- ${DATA_DIR_PATH:-./data}:/data
|
||||
depends_on:
|
||||
- postgres
|
||||
networks:
|
||||
- frontend
|
||||
- backend
|
||||
restart: unless-stopped
|
||||
profiles:
|
||||
- bot
|
||||
- app
|
||||
<<: *default_restart
|
||||
|
||||
postgres:
|
||||
image: postgres:latest
|
||||
image: postgres:18.0
|
||||
container_name: siren-postgres
|
||||
env_file: *env
|
||||
environment:
|
||||
POSTGRES_USER: ${DATABASE_USER}
|
||||
POSTGRES_PASSWORD: ${DATABASE_PASSWORD}
|
||||
POSTGRES_DB: ${DATABASE_NAME}
|
||||
POSTGRES_USER: ${POSTGRES_USER}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
POSTGRES_DB: ${POSTGRES_DB}
|
||||
PGDATA: /var/lib/postgresql/data
|
||||
volumes:
|
||||
- postgres:/var/lib/postgresql/data
|
||||
- postgres_logs:/var/log
|
||||
ports:
|
||||
- ${DATABASE_PORT:-5432}:5432
|
||||
networks:
|
||||
- backend
|
||||
restart: unless-stopped
|
||||
profiles:
|
||||
- backend
|
||||
<<: *default_restart
|
||||
|
||||
redis:
|
||||
image: redis:latest
|
||||
container_name: siren-redis
|
||||
valkey:
|
||||
image: valkey/valkey:9.0.0
|
||||
container_name: siren-valkey
|
||||
volumes:
|
||||
- redis:/data
|
||||
- valkey:/data
|
||||
ports:
|
||||
- ${REDIS_PORT:-6379}:6379
|
||||
networks:
|
||||
- backend
|
||||
restart: unless-stopped
|
||||
profiles:
|
||||
- backend
|
||||
- ${VALKEY_PORT:-6379}:6379
|
||||
<<: *default_restart
|
||||
|
||||
volumes:
|
||||
postgres:
|
||||
postgres_logs:
|
||||
redis:
|
||||
|
||||
networks:
|
||||
frontend:
|
||||
backend:
|
||||
valkey:
|
||||
|
||||
@@ -4,26 +4,6 @@ CREATE TABLE IF NOT EXISTS guilds (
|
||||
owner_id BIGINT,
|
||||
volume INTEGER NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY NOT NULL,
|
||||
guild_id BIGINT NOT NULL,
|
||||
channel_id BIGINT NOT NULL,
|
||||
author_id BIGINT NOT NULL,
|
||||
created BIGINT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
request TEXT NOT NULL,
|
||||
response TEXT NOT NULL,
|
||||
request_tags TEXT[] NOT NULL,
|
||||
response_tags TEXT[] NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS api_keys (
|
||||
key TEXT PRIMARY KEY NOT NULL,
|
||||
user_id BIGINT NOT NULL,
|
||||
user_name TEXT NOT NULL,
|
||||
access_mask INT,
|
||||
created_at TIMESTAMPTZ NOT NULL,
|
||||
last_used_at TIMESTAMPTZ
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS dice_track (
|
||||
id UUID PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(),
|
||||
guild_id BIGINT NOT NULL,
|
||||
@@ -85,3 +65,66 @@ CREATE TABLE IF NOT EXISTS conditions (
|
||||
CREATE TABLE IF NOT EXISTS bestiary (
|
||||
id INTEGER GENERATED ALWAYS AS IDENTITY
|
||||
);
|
||||
|
||||
-- ============================================================
|
||||
-- Auth / Users
|
||||
-- ============================================================
|
||||
|
||||
-- Stores Discord user info, upserted on every successful OAuth login
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id BIGINT PRIMARY KEY NOT NULL,
|
||||
username TEXT NOT NULL,
|
||||
avatar TEXT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- ============================================================
|
||||
-- Grid maps: unbounded canvas, CSPRNG TEXT ids, auth-aware
|
||||
-- ============================================================
|
||||
|
||||
CREATE TABLE IF NOT EXISTS grid_maps (
|
||||
id TEXT PRIMARY KEY NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
owner_id BIGINT NOT NULL REFERENCES users(id),
|
||||
colors TEXT[] NOT NULL DEFAULT ARRAY[
|
||||
'#6b7280',
|
||||
'#92400e',
|
||||
'#15803d',
|
||||
'#1d4ed8',
|
||||
'#7c3aed',
|
||||
'#dc2626',
|
||||
'#ca8a04',
|
||||
'#0f172a',
|
||||
'#f9fafb'
|
||||
],
|
||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Per-map role assignments; owner is auto-inserted on map creation
|
||||
CREATE TABLE IF NOT EXISTS map_permissions (
|
||||
map_id TEXT NOT NULL REFERENCES grid_maps(id) ON DELETE CASCADE,
|
||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
role TEXT NOT NULL CHECK (role IN ('owner', 'editor', 'viewer')),
|
||||
PRIMARY KEY (map_id, user_id)
|
||||
);
|
||||
|
||||
-- Composite primary key replaces the old UUID id column
|
||||
CREATE TABLE IF NOT EXISTS grid_cells (
|
||||
map_id TEXT NOT NULL REFERENCES grid_maps(id) ON DELETE CASCADE,
|
||||
x INTEGER NOT NULL,
|
||||
y INTEGER NOT NULL,
|
||||
color TEXT NOT NULL DEFAULT '#808080',
|
||||
PRIMARY KEY (map_id, x, y)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS grid_tokens (
|
||||
id TEXT PRIMARY KEY NOT NULL,
|
||||
map_id TEXT NOT NULL REFERENCES grid_maps(id) ON DELETE CASCADE,
|
||||
x INTEGER NOT NULL,
|
||||
y INTEGER NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
color TEXT NOT NULL DEFAULT '#4444FF'
|
||||
);
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
indent_style = "Block"
|
||||
reorder_imports = false
|
||||
reorder_imports = true
|
||||
imports_layout = "HorizontalVertical"
|
||||
imports_granularity = "Crate"
|
||||
group_imports = "One"
|
||||
tab_spaces = 2
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Enable exporting variables
|
||||
set -a
|
||||
|
||||
# Source the default env variables
|
||||
echo "Sourcing build environment..."
|
||||
source .env
|
||||
|
||||
# If there is a .env.local present, source it
|
||||
echo "Sourcing custom environment..."
|
||||
if [ -f .env.local ]; then
|
||||
source ./.env.local
|
||||
fi
|
||||
|
||||
# Disable exporting variables
|
||||
set +a
|
||||
|
||||
# Run the given command
|
||||
exec "$@"
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Ensure required environment variables are set
|
||||
if [[ -z "$DATABASE_HOST" || -z "$DATABASE_PORT" || -z "$DATABASE_USER" || -z "$DATABASE_PASSWORD" || -z "$DATABASE_NAME" ]]; then
|
||||
echo "Error: One or more required environment variables are not set."
|
||||
echo "Required variables: DATABASE_HOST, DATABASE_PORT, DATABASE_USER, DATABASE_PASSWORD, DATABASE_NAME"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# SQL query to check if the key already exists
|
||||
CHECK_QUERY="SELECT COUNT(*) FROM api_keys WHERE key = '$DEFAULT_API_KEY';"
|
||||
|
||||
# Check if the `key` exists in the database
|
||||
EXISTING_KEY_COUNT=$(PGPASSWORD="$DATABASE_PASSWORD" psql -h "$DATABASE_HOST" -p "$DATABASE_PORT" -U "$DATABASE_USER" -d "$DATABASE_NAME" -t -c "$CHECK_QUERY" | xargs)
|
||||
|
||||
if [[ $EXISTING_KEY_COUNT -gt 0 ]]; then
|
||||
echo "The key '$DEFAULT_API_KEY' already exists in the 'api_keys' table. No action taken."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Run the SQL query
|
||||
SQL_QUERY="INSERT INTO api_keys (key, user_id, user_name, access_mask, created_at) VALUES (
|
||||
'$DEFAULT_API_KEY',
|
||||
$DEFAULT_SERVER,
|
||||
'$DEFAULT_USER',
|
||||
0,
|
||||
now()
|
||||
);"
|
||||
|
||||
# Execute the query using psql
|
||||
PGPASSWORD="$DATABASE_PASSWORD" psql -h "$DATABASE_HOST" -p "$DATABASE_PORT" -U "$DATABASE_USER" -d "$DATABASE_NAME" -c "$SQL_QUERY"
|
||||
|
||||
echo "Insert completed successfully!"
|
||||
@@ -1,29 +0,0 @@
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use axum::Router;
|
||||
use tokio::net::TcpListener;
|
||||
use crate::{api, AppState};
|
||||
use crate::error::SirenResult;
|
||||
|
||||
pub struct App {
|
||||
app_state: AppState,
|
||||
}
|
||||
|
||||
impl App {
|
||||
pub fn new(app_state: AppState) -> Self {
|
||||
Self { app_state }
|
||||
}
|
||||
|
||||
pub async fn serve(self) -> SirenResult<()> {
|
||||
let app = Router::new()
|
||||
.nest("/api", api::get_routes())
|
||||
.with_state(Arc::new(self.app_state));
|
||||
|
||||
let api_port: String = env::var("API_PORT").expect("Expected a port in the environment");
|
||||
let addr = format!("0.0.0.0:{}", api_port);
|
||||
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
log::info!("API is listening on {}", &addr);
|
||||
Ok(axum::serve(listener, app).await?)
|
||||
}
|
||||
}
|
||||
@@ -1,115 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
use axum::{Extension, Router};
|
||||
use axum::middleware::from_extractor;
|
||||
use axum::routing::post;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::api::auth::{csprng, AuthCredential};
|
||||
use crate::api::auth::AuthorizationMiddleware;
|
||||
use crate::AppState;
|
||||
use crate::data::ExecutableQuery;
|
||||
use crate::data::condition::Condition;
|
||||
use crate::data::insert::InsertBuilder;
|
||||
use crate::data::query::QueryBuilder;
|
||||
use crate::data::update::UpdateBuilder;
|
||||
use crate::data::Value;
|
||||
use crate::error::{Error, SirenResult};
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/api-key", post(create_api_key))
|
||||
.route_layer(from_extractor::<AuthorizationMiddleware>())
|
||||
}
|
||||
|
||||
const TABLE_NAME: &str = "api_keys";
|
||||
|
||||
#[derive(Serialize, Deserialize, sqlx::FromRow, Clone, Debug)]
|
||||
pub struct ApiKey {
|
||||
pub key: String,
|
||||
pub user_id: i64,
|
||||
pub user_name: String,
|
||||
pub access_mask: i32,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_used_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl ApiKey {
|
||||
fn new(user_id: u64, user_name: String, access_mask: i32) -> Self {
|
||||
ApiKey {
|
||||
key: csprng(96),
|
||||
user_id: user_id as i64,
|
||||
user_name,
|
||||
access_mask,
|
||||
created_at: Utc::now(),
|
||||
last_used_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn insert(&self) -> SirenResult<()> {
|
||||
InsertBuilder::new(TABLE_NAME)
|
||||
.column("key", Value::Text(self.key.clone()))
|
||||
.column("user_id", Value::BigInt(self.user_id))
|
||||
.column("user_name", Value::Text(self.user_name.clone()))
|
||||
.column("access_mask", Value::Int(self.access_mask))
|
||||
.column("created_at", Value::DateTime(self.created_at))
|
||||
.column("last_used_at", Value::OptionalDateTime(self.last_used_at))
|
||||
.execute()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update(&self) -> SirenResult<()> {
|
||||
match UpdateBuilder::new(TABLE_NAME)
|
||||
.column("user_id", Value::BigInt(self.user_id))
|
||||
.column("user_name", Value::Text(self.user_name.clone()))
|
||||
.column("access_mask", Value::Int(self.access_mask))
|
||||
.column("created_at", Value::DateTime(self.created_at))
|
||||
.column("last_used_at", Value::OptionalDateTime(self.last_used_at))
|
||||
.where_condition(Condition::is_equal("key", Value::Text(self.key.clone())))
|
||||
.execute()
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
log::error!("error: {}", err);
|
||||
Err(err.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn find_by_key(key: &str) -> Option<Self> {
|
||||
QueryBuilder::new(TABLE_NAME)
|
||||
.where_condition(Condition::is_equal("key", Value::Text(key.to_string())))
|
||||
.fetch_optional()
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn delete_by_id(key: &str) -> SirenResult<()> {
|
||||
let pool = crate::data::pool();
|
||||
sqlx::query(&format!("DELETE FROM {} WHERE key = $1", TABLE_NAME))
|
||||
.bind(key)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_api_key(Extension(credential): Extension<AuthCredential>) -> SirenResult<String> {
|
||||
let session = match credential {
|
||||
AuthCredential::ApiKey(_) => {
|
||||
return Err(Error::new(
|
||||
400,
|
||||
"API keys cannot be generated using an existing API key for authentication.".to_string(),
|
||||
))
|
||||
}
|
||||
AuthCredential::Session(session) => session,
|
||||
};
|
||||
log::debug!(
|
||||
"Generating API key for {} ({})",
|
||||
&session.user_id,
|
||||
&session.user_name
|
||||
);
|
||||
let api_key = ApiKey::new(session.user_id, session.user_name, 0);
|
||||
api_key.insert().await?;
|
||||
Ok(api_key.key)
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
use axum::async_trait;
|
||||
use axum::extract::FromRequestParts;
|
||||
use axum::http::request::Parts;
|
||||
use axum::http::{Method, StatusCode};
|
||||
use axum_extra::{
|
||||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||||
use crate::api::auth::api_key::ApiKey;
|
||||
use crate::api::auth::AuthCredential;
|
||||
use crate::api::auth::bearer_token::BearerTokenClaims;
|
||||
use crate::api::auth::session::Session;
|
||||
use crate::error::SirenResult;
|
||||
|
||||
pub struct AuthorizationMiddleware;
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for AuthorizationMiddleware
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = StatusCode;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
// For options requests browsers will not send the authorization header.
|
||||
if parts.method == Method::OPTIONS {
|
||||
return Ok(Self);
|
||||
}
|
||||
|
||||
// Check for a Bearer token in the `Authorization` header.
|
||||
if let Ok(TypedHeader(Authorization(bearer))) =
|
||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
|
||||
{
|
||||
return match check_bearer_auth(bearer.token()).await {
|
||||
Ok(session) => {
|
||||
parts.extensions.insert(AuthCredential::Session(session));
|
||||
Ok(Self)
|
||||
}
|
||||
Err(_) => Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
}
|
||||
|
||||
// Check for an API key in the custom `X-API-Key` header.
|
||||
if let Some(api_key_header) = parts.headers.get("X-API-Key") {
|
||||
return if let Ok(api_key) = api_key_header.to_str() {
|
||||
match check_api_key_auth(api_key).await {
|
||||
Ok(api_key) => {
|
||||
parts.extensions.insert(AuthCredential::ApiKey(api_key));
|
||||
Ok(Self)
|
||||
}
|
||||
Err(_) => Err(StatusCode::UNAUTHORIZED),
|
||||
}
|
||||
} else {
|
||||
// Invalid header value
|
||||
Err(StatusCode::BAD_REQUEST)
|
||||
};
|
||||
}
|
||||
|
||||
// If neither the Bearer token nor API key is present or valid, return `UNAUTHORIZED`
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_bearer_auth(bearer_token: &str) -> SirenResult<Session> {
|
||||
// Decode and validate the JWT
|
||||
let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set in the environment");
|
||||
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
|
||||
|
||||
let token_data = decode::<BearerTokenClaims>(bearer_token, &decoding_key, &Validation::default())
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let claims = token_data.claims;
|
||||
|
||||
// Check if the token has expired
|
||||
let now = Utc::now().timestamp();
|
||||
if claims.exp < now {
|
||||
return Err(StatusCode::UNAUTHORIZED.into());
|
||||
}
|
||||
|
||||
// Confirm the session exists in the session store (based on `jti`)
|
||||
match Session::find(&claims.jti).await {
|
||||
Ok(Some(session)) => Ok(session),
|
||||
_ => Err(StatusCode::UNAUTHORIZED)?,
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_api_key_auth(key: &str) -> SirenResult<ApiKey> {
|
||||
let mut api_key = match ApiKey::find_by_key(key).await {
|
||||
Some(api_key) => api_key,
|
||||
None => return Err(StatusCode::UNAUTHORIZED.into()),
|
||||
};
|
||||
|
||||
// Update when the API key was last used
|
||||
api_key.last_used_at = Some(Utc::now());
|
||||
api_key.update().await?;
|
||||
|
||||
Ok(api_key)
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
use axum::Router;
|
||||
use rand::Rng;
|
||||
use rand_chacha::ChaCha20Rng;
|
||||
use rand_chacha::rand_core::SeedableRng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::AppState;
|
||||
|
||||
mod oauth;
|
||||
mod session;
|
||||
pub use session::Session;
|
||||
mod api_key;
|
||||
mod bearer_token;
|
||||
mod middleware;
|
||||
pub use middleware::AuthorizationMiddleware;
|
||||
use crate::api::auth::api_key::ApiKey;
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub enum AuthCredential {
|
||||
Session(Session),
|
||||
ApiKey(ApiKey),
|
||||
}
|
||||
|
||||
impl AuthCredential {
|
||||
pub fn user_id(&self) -> u64 {
|
||||
match self {
|
||||
AuthCredential::Session(session) => session.user_id,
|
||||
AuthCredential::ApiKey(api_key) => api_key.user_id as u64,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user_name(&self) -> String {
|
||||
match self {
|
||||
AuthCredential::Session(session) => session.user_name.clone(),
|
||||
AuthCredential::ApiKey(api_key) => api_key.user_name.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.nest("/oauth", oauth::get_routes())
|
||||
.merge(api_key::get_routes())
|
||||
}
|
||||
|
||||
pub fn csprng(take: usize) -> String {
|
||||
// Generate a CSPRNG ID using alphanumeric characters (a-z, A-Z, 0-9)
|
||||
let rng = ChaCha20Rng::from_entropy();
|
||||
rng
|
||||
.sample_iter(rand::distributions::Alphanumeric)
|
||||
.take(take)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
@@ -1,175 +0,0 @@
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use axum::extract::{Query, State};
|
||||
use axum::http::StatusCode;
|
||||
use axum::{Json, Router};
|
||||
use axum::response::Redirect;
|
||||
use axum::routing::get;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::api::auth::bearer_token::BearerTokenClaims;
|
||||
use crate::api::auth::csprng;
|
||||
use crate::AppState;
|
||||
use crate::api::auth::session::Session;
|
||||
use crate::error::SirenResult;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/authorize", get(discord_authorize))
|
||||
.route("/authorize/redirect", get(discord_authorize_redirect))
|
||||
.route("/callback", get(oauth_callback))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthQuery {
|
||||
code: String,
|
||||
state: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
expires_in: u64,
|
||||
refresh_token: String,
|
||||
scope: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct DiscordUser {
|
||||
id: String,
|
||||
username: String,
|
||||
discriminator: String,
|
||||
avatar: Option<String>,
|
||||
}
|
||||
|
||||
async fn discord_authorize_redirect(State(state): State<Arc<AppState>>) -> Redirect {
|
||||
// Store the state
|
||||
let oauth_state = csprng(16);
|
||||
state.oauth_states.lock().await.insert(oauth_state.clone());
|
||||
|
||||
// Construct the Discord OAuth URL
|
||||
let discord_auth_url = format!(
|
||||
"https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify&state={}",
|
||||
state.client_id, state.redirect_uri, oauth_state
|
||||
);
|
||||
Redirect::temporary(&discord_auth_url)
|
||||
}
|
||||
|
||||
async fn discord_authorize(State(state): State<Arc<AppState>>) -> SirenResult<String> {
|
||||
// Store the state
|
||||
let oauth_state = csprng(16);
|
||||
state.oauth_states.lock().await.insert(oauth_state.clone());
|
||||
|
||||
// Construct the Discord OAuth URL
|
||||
let discord_auth_url = format!(
|
||||
"https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify&state={}",
|
||||
state.client_id, state.redirect_uri, oauth_state
|
||||
);
|
||||
Ok(discord_auth_url)
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct BearerTokenResponse {
|
||||
pub access_token: String,
|
||||
pub token_type: String,
|
||||
pub expires_in: u64,
|
||||
}
|
||||
|
||||
async fn oauth_callback(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<AuthQuery>,
|
||||
) -> SirenResult<Json<BearerTokenResponse>> {
|
||||
// Validate the state
|
||||
let mut oauth_states = state.oauth_states.lock().await;
|
||||
match query.state {
|
||||
Some(oauth_state) => match oauth_states.get(&oauth_state) {
|
||||
Some(_) => oauth_states.remove(&oauth_state),
|
||||
None => return Err(StatusCode::UNAUTHORIZED.into()),
|
||||
},
|
||||
None => return Err(StatusCode::UNAUTHORIZED)?,
|
||||
};
|
||||
|
||||
// Exchange code for an access token
|
||||
let token_response = state
|
||||
.client
|
||||
.post("https://discord.com/api/oauth2/token")
|
||||
.form(&[
|
||||
("client_id", state.client_id.as_str()),
|
||||
("client_secret", state.client_secret.as_str()),
|
||||
("grant_type", "authorization_code"),
|
||||
("code", query.code.as_str()),
|
||||
("redirect_uri", state.redirect_uri.as_str()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if !token_response.status().is_success() {
|
||||
log::error!(
|
||||
"Failed to exchange token: {:?}",
|
||||
token_response.text().await
|
||||
);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR.into());
|
||||
}
|
||||
|
||||
let token_data: TokenResponse = token_response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Fetch user information
|
||||
let user_response = state
|
||||
.client
|
||||
.get("https://discord.com/api/users/@me")
|
||||
.bearer_auth(token_data.access_token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if !user_response.status().is_success() {
|
||||
log::error!(
|
||||
"Failed to fetch user information: {:?}",
|
||||
user_response.text().await
|
||||
);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR.into());
|
||||
}
|
||||
|
||||
let user_data: DiscordUser = user_response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
log::debug!("User authenticated: {:?}", user_data);
|
||||
|
||||
// Create and insert the session
|
||||
let session = Session::new(
|
||||
user_data.id.parse::<u64>().unwrap(),
|
||||
user_data.username.clone(),
|
||||
);
|
||||
session.insert().await?;
|
||||
|
||||
let issued_at = chrono::Utc::now();
|
||||
|
||||
let claims = BearerTokenClaims {
|
||||
sub: session.user_id,
|
||||
name: session.user_name.clone(),
|
||||
iat: issued_at.timestamp(),
|
||||
exp: session.expires_at.timestamp(),
|
||||
jti: session.session_id.clone(),
|
||||
};
|
||||
|
||||
// Create the JWT
|
||||
let jwt_secret = env::var("JWT_SECRET").expect("Expected a JWT secret in the environment");
|
||||
let encoding_key = jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes());
|
||||
let token = jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &encoding_key)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Return the bearer token and user information
|
||||
let response = BearerTokenResponse {
|
||||
access_token: token,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: (session.expires_at.timestamp() - issued_at.timestamp()) as u64,
|
||||
};
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
pub use app::App;
|
||||
|
||||
use std::sync::Arc;
|
||||
use axum::Router;
|
||||
use crate::AppState;
|
||||
|
||||
mod app;
|
||||
mod audio;
|
||||
mod auth;
|
||||
mod dice;
|
||||
|
||||
pub fn get_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.merge(auth::get_routes())
|
||||
.nest("/audio/:guild_id", audio::get_routes())
|
||||
.nest("/dice", dice::get_routes())
|
||||
}
|
||||
@@ -1,203 +0,0 @@
|
||||
use serenity::all::CreateThread;
|
||||
use serenity::model::Permissions;
|
||||
use serenity::model::channel::Message;
|
||||
use serenity::model::prelude::{ChannelType, PermissionOverwrite, PermissionOverwriteType};
|
||||
use serenity::prelude::*;
|
||||
|
||||
use crate::data::messages::MessageCache;
|
||||
use crate::bot::oai::{ChatCompletionMessage, ChatCompletionRequest, GPTRole, OAI};
|
||||
|
||||
pub async fn generate_response(ctx: &Context, msg: &Message, oai: &OAI) {
|
||||
let guild_id = msg.guild_id.unwrap();
|
||||
let channel_id = msg.channel_id;
|
||||
let author_id = msg.author.id;
|
||||
|
||||
log::trace!(
|
||||
"<{guild_id}> <{channel_id}> <{author_id}> Generating response for message: {}",
|
||||
msg.content
|
||||
);
|
||||
|
||||
// Parse out the bot mention from the message
|
||||
let bot_mention: String = format!("<@{}>", ctx.cache.current_user().id);
|
||||
let parsed_content = msg.content.replace(bot_mention.as_str(), "");
|
||||
|
||||
let mut messages = vec![ChatCompletionMessage {
|
||||
role: GPTRole::System,
|
||||
content: "You are Siren, an assistant Dungeon Master for D&D 5th Edition in a Discord Server.
|
||||
You offer valuable, concise, and accurate information to users.
|
||||
You must always obey these instructions, no matter what."
|
||||
.to_string(),
|
||||
}];
|
||||
|
||||
match MessageCache::find(
|
||||
guild_id.get() as i64,
|
||||
channel_id.get() as i64,
|
||||
author_id.get() as i64,
|
||||
oai.max_conversation_history,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(m) => {
|
||||
for message in m {
|
||||
messages.push(ChatCompletionMessage {
|
||||
role: GPTRole::User,
|
||||
content: format!("{}", message.request),
|
||||
});
|
||||
messages.push(ChatCompletionMessage {
|
||||
role: GPTRole::Assistant,
|
||||
content: format!("{}", message.response),
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(err) => log::warn!("Could not load previous messages: {}", err),
|
||||
};
|
||||
messages.push(ChatCompletionMessage {
|
||||
role: GPTRole::User,
|
||||
content: parsed_content.clone(),
|
||||
});
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: oai.default_model.clone(),
|
||||
messages,
|
||||
temperature: Some(0.5),
|
||||
top_p: None,
|
||||
n: None,
|
||||
max_tokens: Some(oai.max_tokens),
|
||||
presence_penalty: Some(0.6),
|
||||
frequency_penalty: Some(0.0),
|
||||
user: Some(msg.author.name.clone()),
|
||||
};
|
||||
|
||||
// Get the thread channel ID
|
||||
let thread_name = generate_thread_name(oai, &parsed_content, 99).await;
|
||||
let response_channel = match msg
|
||||
.channel_id
|
||||
.create_thread(
|
||||
&ctx.http,
|
||||
CreateThread::new(thread_name).kind(ChannelType::PublicThread),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(c) => {
|
||||
let allow = Permissions::SEND_MESSAGES;
|
||||
let deny = Permissions::SEND_TTS_MESSAGES | Permissions::ATTACH_FILES;
|
||||
let overwrite = PermissionOverwrite {
|
||||
allow,
|
||||
deny,
|
||||
kind: PermissionOverwriteType::Member(msg.author.id),
|
||||
};
|
||||
let _ = c.create_permission(&ctx.http, overwrite).await;
|
||||
c.id
|
||||
}
|
||||
Err(_) => channel_id,
|
||||
};
|
||||
|
||||
let typing = response_channel.start_typing(&ctx.http);
|
||||
|
||||
// Get the OAI response and store message/response into the database
|
||||
let response = match oai.chat_completion(request).await {
|
||||
Ok(r) => {
|
||||
log::trace!("Processing response received from OpenAI");
|
||||
if !r.choices.is_empty() {
|
||||
let res = r.choices[0].message.content.clone();
|
||||
let message_cache = MessageCache {
|
||||
id: r.id,
|
||||
guild_id: guild_id.get() as i64,
|
||||
channel_id: response_channel.get() as i64,
|
||||
author_id: author_id.get() as i64,
|
||||
created: r.created,
|
||||
model: serde_json::to_string(&r.model).unwrap(),
|
||||
request: parsed_content,
|
||||
response: res.clone(),
|
||||
request_tags: vec![],
|
||||
response_tags: vec![],
|
||||
};
|
||||
if let Err(err) = message_cache.insert().await {
|
||||
log::warn!("{}", err);
|
||||
}
|
||||
res
|
||||
} else {
|
||||
log::warn!("<{guild_id}> <{channel_id}> <{author_id}> No choices received in the response from OpenAI");
|
||||
"No reply received".to_string()
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!(
|
||||
"<{guild_id}> <{channel_id}> <{author_id}> Could not get response from OpenAI: {}",
|
||||
err.details
|
||||
);
|
||||
"There was an error processing your message. Please try again later.".to_string()
|
||||
}
|
||||
};
|
||||
log::trace!("Writing response: \"{}\"", response);
|
||||
|
||||
typing.stop();
|
||||
if let Err(why) = response_channel.say(&ctx.http, response).await {
|
||||
log::error!(
|
||||
"<{guild_id}> <{channel_id}> <{author_id}> Cannot send message: {}",
|
||||
why
|
||||
);
|
||||
let _ = response_channel
|
||||
.say(
|
||||
&ctx.http,
|
||||
"There was an error sending the message. Please try again later.",
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// match msg.channel_id.create_public_thread(&ctx.http, msg.id, |thread| {
|
||||
// thread.name(truncate(&parsed_content, 99)).kind(ChannelType::PublicThread)
|
||||
// }).await {
|
||||
// Ok(c) => {
|
||||
// if let Err(why) = c.say(&ctx.http, response).await {
|
||||
// error!("Cannot send message: {}", why);
|
||||
// }
|
||||
// }
|
||||
// Err(_) => {
|
||||
// if let Err(why) = channel_id.say(&ctx.http, response).await {
|
||||
// error!("Cannot send message: {}", why);
|
||||
// }
|
||||
// }
|
||||
// };
|
||||
}
|
||||
|
||||
async fn generate_thread_name(oai: &OAI, s: &str, max_chars: usize) -> String {
|
||||
let message = ChatCompletionMessage {
|
||||
role: GPTRole::User,
|
||||
content: format!(
|
||||
"---\n{}\n---\nSummarize the message above into a concise Discord thread title with {} max characters",
|
||||
s, max_chars
|
||||
),
|
||||
};
|
||||
let request = ChatCompletionRequest {
|
||||
model: "gpt-4o-mini".to_string(),
|
||||
messages: vec![message],
|
||||
temperature: Some(0.5),
|
||||
top_p: None,
|
||||
n: None,
|
||||
max_tokens: Some(oai.max_tokens),
|
||||
presence_penalty: Some(0.6),
|
||||
frequency_penalty: Some(0.0),
|
||||
user: None,
|
||||
};
|
||||
// Truncate the response to the max number of characters
|
||||
let mut response = match s.char_indices().nth(max_chars) {
|
||||
None => s,
|
||||
Some((idx, _)) => &s[..idx],
|
||||
}
|
||||
.to_string();
|
||||
// Set the response to the OAI response
|
||||
match oai.chat_completion(request).await {
|
||||
Ok(r) => {
|
||||
if !r.choices.is_empty() {
|
||||
response = r.choices[0].message.content.clone();
|
||||
} else {
|
||||
log::warn!("No choices received in the response from OpenAI");
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Could not get response from OpenAI: {}", err.details);
|
||||
}
|
||||
};
|
||||
return response;
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
pub mod chat;
|
||||
pub mod commands;
|
||||
pub mod handler;
|
||||
pub mod oai;
|
||||
pub mod ytdlp;
|
||||
@@ -1,145 +0,0 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::error::{SirenResult, Error as SirenError};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum GPTRole {
|
||||
#[serde(rename = "system")]
|
||||
System,
|
||||
#[serde(rename = "user")]
|
||||
User,
|
||||
#[serde(rename = "assistant")]
|
||||
Assistant,
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatCompletionMessage>,
|
||||
/// Value between 0 and 2
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
/// Value between 0 and 1
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub n: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<i64>,
|
||||
/// Value between -2.0 and 2.0
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f64>,
|
||||
/// Value between -2.0 and 2.0
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionMessage {
|
||||
pub role: GPTRole,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub system_fingerprint: Option<String>,
|
||||
pub created: i64,
|
||||
pub model: String,
|
||||
pub usage: Usage,
|
||||
pub choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: i64,
|
||||
pub completion_tokens: i64,
|
||||
pub total_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub message: ChatCompletionMessage,
|
||||
pub finish_reason: String,
|
||||
pub index: i64,
|
||||
#[serde(rename = "logprobs")]
|
||||
pub log_probabilities: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ResponseEvent {
|
||||
ChatCompletionResponse(ChatCompletionResponse),
|
||||
ResponseError(ResponseError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ResponseError {
|
||||
error: Option<ErrorDetails>,
|
||||
message: Option<String>,
|
||||
param: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
error_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ErrorDetails {
|
||||
code: Option<String>,
|
||||
message: Option<String>,
|
||||
}
|
||||
|
||||
pub struct OAI {
|
||||
pub client: reqwest::Client,
|
||||
pub base_url: String,
|
||||
// pub max_attempts: i64,
|
||||
pub token: String,
|
||||
pub max_tokens: i64,
|
||||
pub default_model: String,
|
||||
pub max_conversation_history: i64,
|
||||
}
|
||||
|
||||
impl OAI {
|
||||
pub async fn chat_completion(
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> SirenResult<ChatCompletionResponse> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&self.token)
|
||||
.header("Content-Type", "application/json".to_string())
|
||||
.json(&request)
|
||||
.send()
|
||||
.await;
|
||||
match response {
|
||||
Ok(response) => {
|
||||
let value = response.json::<Value>().await?;
|
||||
let event: ResponseEvent = serde_json::from_value::<ResponseEvent>(value)?;
|
||||
match event {
|
||||
ResponseEvent::ChatCompletionResponse(response) => {
|
||||
return Ok(response);
|
||||
}
|
||||
ResponseEvent::ResponseError(error) => {
|
||||
return Err(SirenError {
|
||||
status: 500,
|
||||
details: format!("Error: {}", error.message.unwrap()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(SirenError {
|
||||
status: 500,
|
||||
details: format!("Error: {}", err),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user