Added user auth
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
use crate::db;
|
||||
use crate::error_handler::ServiceError;
|
||||
use crate::schema::airports;
|
||||
use crate::db::schema::airports;
|
||||
use diesel::dsl::count_star;
|
||||
use diesel::prelude::*;
|
||||
// use log::trace;
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
use std::env;
|
||||
|
||||
use argon2::{password_hash::{rand_core::OsRng, PasswordHasher, PasswordVerifier, SaltString, Error as HashError}, Argon2, PasswordHash};
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, encode, decode, Validation, Algorithm};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
mod model;
|
||||
mod routes;
|
||||
|
||||
pub use model::*;
|
||||
pub use routes::init_routes;
|
||||
use crate::error_handler::ServiceError;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct TokenClaims {
|
||||
sub: String, // Subject
|
||||
token_uuid: String, // Token UUID
|
||||
iss: String, // Issuer
|
||||
exp: i64, // Expiration time
|
||||
iat: i64, // Issued At
|
||||
nbf: i64 // Not Before
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TokenDetails {
|
||||
pub token: Option<String>,
|
||||
pub token_uuid: uuid::Uuid,
|
||||
pub email: String,
|
||||
pub expires_in: Option<i64>
|
||||
}
|
||||
|
||||
pub fn verify_token(token: &str, public_key: &str) -> Result<TokenDetails, ServiceError> {
|
||||
let bytes_public_key = general_purpose::STANDARD.decode(public_key).unwrap();
|
||||
let decoded_public_key = String::from_utf8(bytes_public_key).unwrap();
|
||||
let key = DecodingKey::from_rsa_pem(decoded_public_key.as_bytes())?;
|
||||
let validation = Validation::new(Algorithm::RS256);
|
||||
let decoded = decode::<TokenClaims>(token, &key, &validation)?;
|
||||
let email = decoded.claims.sub;
|
||||
let token_uuid = uuid::Uuid::parse_str(decoded.claims.token_uuid.as_str()).unwrap();
|
||||
Ok(TokenDetails { token: None, token_uuid, email, expires_in: None })
|
||||
}
|
||||
|
||||
pub fn generate_access_token(email: &str) -> Result<TokenDetails, ServiceError> {
|
||||
let access_token_max_age = env::var("ACCESS_TOKEN_MAXAGE")
|
||||
.expect("ACCESS_TOKEN_MAXAGE must be set")
|
||||
.parse::<i64>()
|
||||
.expect("ACCESS_TOKEN_MAXAGE must be an integer");
|
||||
let access_private_key = env::var("ACCESS_TOKEN_PRIVATE_KEY")
|
||||
.expect("ACCESS_TOKEN_PRIVATE_KEY must be set");
|
||||
generate_token(&email, access_token_max_age, &access_private_key)
|
||||
}
|
||||
|
||||
pub fn generate_refresh_token(email: &str) -> Result<TokenDetails, ServiceError> {
|
||||
let refresh_token_max_age = env::var("REFRESH_TOKEN_MAXAGE")
|
||||
.expect("REFRESH_TOKEN_MAXAGE must be set")
|
||||
.parse::<i64>()
|
||||
.expect("REFRESH_TOKEN_MAXAGE must be an integer");
|
||||
let refresh_private_key = env::var("REFRESH_TOKEN_PRIVATE_KEY")
|
||||
.expect("REFRESH_TOKEN_PRIVATE_KEY must be set");
|
||||
generate_token(&email, refresh_token_max_age, &refresh_private_key)
|
||||
}
|
||||
|
||||
pub fn generate_token(email: &str, ttl: i64, private_key: &str) -> Result<TokenDetails, ServiceError> {
|
||||
let now = chrono::Utc::now();
|
||||
let mut token_details = TokenDetails {
|
||||
token: None,
|
||||
token_uuid: uuid::Uuid::new_v4(),
|
||||
email: email.to_string(),
|
||||
expires_in: Some((now + chrono::Duration::minutes(ttl)).timestamp())
|
||||
};
|
||||
let claims = TokenClaims {
|
||||
sub: token_details.email.clone(),
|
||||
iss: "aviation-weather".to_string(),
|
||||
token_uuid: token_details.token_uuid.to_string(),
|
||||
exp: token_details.expires_in.unwrap(),
|
||||
iat: now.timestamp(),
|
||||
nbf: now.timestamp()
|
||||
};
|
||||
let header = Header::new(Algorithm::RS256);
|
||||
let bytes_private_key = general_purpose::STANDARD.decode(private_key).unwrap();
|
||||
let decoded_private_key = String::from_utf8(bytes_private_key).unwrap();
|
||||
let key = EncodingKey::from_rsa_pem(decoded_private_key.as_bytes())?;
|
||||
let token = encode(&header, &claims, &key)?;
|
||||
token_details.token = Some(token);
|
||||
Ok(token_details)
|
||||
}
|
||||
|
||||
pub fn hash_password(password: &[u8]) -> Result<String, HashError> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
Ok(Argon2::default().hash_password(password, &salt)?.to_string())
|
||||
}
|
||||
|
||||
pub fn verify_password(hash: &str, password: &[u8]) -> Result<(), HashError> {
|
||||
let parsed_hash = PasswordHash::new(hash)?;
|
||||
Ok(Argon2::default().verify_password(password, &parsed_hash)?)
|
||||
}
|
||||
205
service/src/auth/model.rs
Normal file
205
service/src/auth/model.rs
Normal file
@@ -0,0 +1,205 @@
|
||||
use std::{future::{ready, Ready}, env};
|
||||
use actix_web::{FromRequest, Error as ActixError, HttpRequest, dev::Payload, http};
|
||||
use diesel::prelude::*;
|
||||
use log::error;
|
||||
use redis::Commands;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::error_handler::ServiceError;
|
||||
|
||||
use crate::db::{schema::users, connection};
|
||||
|
||||
use super::{hash_password, verify_token};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RegisterUser {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
pub first_name: String,
|
||||
pub last_name: String,
|
||||
}
|
||||
|
||||
impl RegisterUser {
|
||||
pub fn convert_to_insert(self) -> Result<InsertUser, ServiceError> {
|
||||
let hash = hash_password(self.password.as_bytes())?;
|
||||
Ok(InsertUser {
|
||||
email: self.email.to_lowercase(),
|
||||
hash,
|
||||
role: "user".to_string(),
|
||||
first_name: self.first_name,
|
||||
last_name: self.last_name,
|
||||
updated_at: chrono::Utc::now().naive_utc(),
|
||||
created_at: chrono::Utc::now().naive_utc(),
|
||||
profile_picture: None,
|
||||
verified: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Queryable, QueryableByName, Serialize, Deserialize)]
|
||||
#[diesel(table_name = users)]
|
||||
pub struct QueryUser {
|
||||
pub email: String,
|
||||
pub hash: String,
|
||||
pub role: String,
|
||||
pub first_name: String,
|
||||
pub last_name: String,
|
||||
pub updated_at: chrono::NaiveDateTime,
|
||||
pub created_at: chrono::NaiveDateTime,
|
||||
pub profile_picture: Option<String>,
|
||||
pub verified: bool,
|
||||
}
|
||||
|
||||
impl QueryUser {
|
||||
pub fn get_by_email(email: &str) -> Result<QueryUser, ServiceError> {
|
||||
let mut conn = connection()?;
|
||||
// Check if the user exists by email, case insensitive
|
||||
|
||||
let user = users::table
|
||||
.filter(users::email.eq(email.to_lowercase()))
|
||||
.first(&mut conn)?;
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Insertable, AsChangeset, Serialize, Deserialize)]
|
||||
#[diesel(table_name = users)]
|
||||
pub struct InsertUser {
|
||||
pub email: String,
|
||||
pub hash: String,
|
||||
pub role: String,
|
||||
pub first_name: String,
|
||||
pub last_name: String,
|
||||
pub updated_at: chrono::NaiveDateTime,
|
||||
pub created_at: chrono::NaiveDateTime,
|
||||
pub profile_picture: Option<String>,
|
||||
pub verified: bool,
|
||||
}
|
||||
|
||||
impl InsertUser {
|
||||
pub fn insert(user: Self) -> Result<QueryUser, ServiceError> {
|
||||
let mut conn = connection()?;
|
||||
let user = diesel::insert_into(users::table)
|
||||
.values(user)
|
||||
.get_result(&mut conn)?;
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub fn update_profile(email: &str, profile_picture: Option<&str>) -> Result<QueryUser, ServiceError> {
|
||||
let mut conn = connection()?;
|
||||
let user = diesel::update(users::table)
|
||||
.filter(users::email.eq(&email))
|
||||
.set(users::profile_picture.eq(profile_picture))
|
||||
.get_result(&mut conn)?;
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ResponseUser {
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub first_name: String,
|
||||
pub last_name: String,
|
||||
pub profile_picture: Option<String>,
|
||||
}
|
||||
|
||||
impl From<QueryUser> for ResponseUser {
|
||||
fn from(user: QueryUser) -> Self {
|
||||
ResponseUser {
|
||||
email: user.email,
|
||||
role: user.role,
|
||||
first_name: user.first_name,
|
||||
last_name: user.last_name,
|
||||
profile_picture: user.profile_picture,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct JwtAuth {
|
||||
pub token: uuid::Uuid,
|
||||
pub user: ResponseUser
|
||||
}
|
||||
|
||||
impl FromRequest for JwtAuth {
|
||||
type Error = ActixError;
|
||||
type Future = Ready<Result<Self, Self::Error>>;
|
||||
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
|
||||
let access_token = match req
|
||||
.cookie("access_token")
|
||||
.map(|c| c.value().to_string())
|
||||
.or_else(|| {
|
||||
req.headers().get(http::header::AUTHORIZATION)
|
||||
.map(|h| h.to_str().unwrap().split_at(7).1.to_string())
|
||||
}) {
|
||||
Some(token) => token,
|
||||
None => return ready(Err(ActixError::from(ServiceError {
|
||||
status: 401,
|
||||
message: "Unauthorized".to_string()
|
||||
})))
|
||||
};
|
||||
|
||||
let public_key = env::var("ACCESS_TOKEN_PUBLIC_KEY")
|
||||
.expect("ACCESS_TOKEN_PUBLIC_KEY must be set");
|
||||
|
||||
let access_token_details = match verify_token(&access_token, &public_key) {
|
||||
Ok(token_details) => token_details,
|
||||
Err(err) => {
|
||||
error!("Failed to verify access token: {}", err);
|
||||
return ready(Err(ActixError::from(ServiceError {
|
||||
status: 401,
|
||||
message: format!("Failed to verify access token: {}", err)
|
||||
})))
|
||||
}
|
||||
};
|
||||
|
||||
let access_token_uuid = uuid::Uuid::parse_str(&access_token_details.token_uuid.to_string()).unwrap();
|
||||
|
||||
let mut conn = match crate::db::redis_connection() {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
error!("Failed to get redis connection: {}", err);
|
||||
return ready(Err(ActixError::from(ServiceError {
|
||||
status: 500,
|
||||
message: format!("Failed to get redis connection: {}", err)
|
||||
})))
|
||||
}
|
||||
};
|
||||
let user_email = match conn.get::<_, String>(access_token_uuid.clone().to_string()) {
|
||||
Ok(result) => result,
|
||||
Err(_) => {
|
||||
return ready(Err(ActixError::from(ServiceError {
|
||||
status: 401,
|
||||
message: format!("Access token was not found")
|
||||
})))
|
||||
}
|
||||
};
|
||||
|
||||
match QueryUser::get_by_email(&user_email) {
|
||||
Ok(user) => {
|
||||
ready(Ok(JwtAuth { token: access_token_uuid, user: user.into() }))
|
||||
}
|
||||
Err(_) => return ready(Err(ActixError::from(ServiceError {
|
||||
status: 401,
|
||||
message: format!("User was not found")
|
||||
})))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn verify_role(auth: &JwtAuth, role: &str) -> Result<(), ServiceError> {
|
||||
if auth.user.role == role {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ServiceError {
|
||||
status: 403,
|
||||
message: "Forbidden".to_string()
|
||||
})
|
||||
}
|
||||
}
|
||||
367
service/src/auth/routes.rs
Normal file
367
service/src/auth/routes.rs
Normal file
@@ -0,0 +1,367 @@
|
||||
use std::env;
|
||||
|
||||
use actix_web::{get, post, web, HttpResponse, ResponseError, cookie::{Cookie, time::Duration}, HttpRequest};
|
||||
use log::error;
|
||||
use redis::AsyncCommands;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::error_handler::ServiceError;
|
||||
|
||||
use crate::{auth::{LoginRequest, RegisterUser, InsertUser, QueryUser, verify_password, JwtAuth, verify_token, generate_access_token, generate_refresh_token}, db};
|
||||
|
||||
#[post("/register")]
|
||||
async fn register(user: web::Json<RegisterUser>) -> HttpResponse {
|
||||
let register_user = user.0;
|
||||
let insert_user: InsertUser = match register_user.convert_to_insert() {
|
||||
Ok(user) => user,
|
||||
Err(err) => return ResponseError::error_response(&err)
|
||||
};
|
||||
match InsertUser::insert(insert_user) {
|
||||
Ok(_) => {
|
||||
HttpResponse::Created().finish()
|
||||
},
|
||||
Err(err) => {
|
||||
// Obfuscate the service error message to prevent leaking database details
|
||||
if err.status == 409 {
|
||||
return HttpResponse::Conflict().finish();
|
||||
} else {
|
||||
return ResponseError::error_response(&err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/login")]
|
||||
async fn login(request: web::Json<LoginRequest>) -> HttpResponse {
|
||||
let email = request.email.clone();
|
||||
|
||||
let query_user = match QueryUser::get_by_email(&email) {
|
||||
Ok(query_user) => query_user,
|
||||
Err(err) => return ResponseError::error_response(&err)
|
||||
};
|
||||
let hash = &query_user.hash;
|
||||
let password = request.password.as_bytes();
|
||||
match verify_password(hash, password) {
|
||||
Ok(_) => {
|
||||
let access_token_details = match generate_access_token(&email) {
|
||||
Ok(token_details) => token_details,
|
||||
Err(err) => {
|
||||
error!("Failed to generate access token: {}", err);
|
||||
return ResponseError::error_response(&err)
|
||||
}
|
||||
};
|
||||
|
||||
let refresh_token_details = match generate_refresh_token(&email) {
|
||||
Ok(token_details) => token_details,
|
||||
Err(err) => {
|
||||
error!("Failed to generate refresh token: {}", err);
|
||||
return ResponseError::error_response(&err)
|
||||
}
|
||||
};
|
||||
|
||||
let mut conn = match db::redis_async_connection().await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
error!("Failed to get redis connection: {}", err);
|
||||
return ResponseError::error_response(&err)
|
||||
}
|
||||
};
|
||||
|
||||
let access_token_max_age = env::var("ACCESS_TOKEN_MAXAGE")
|
||||
.expect("ACCESS_TOKEN_MAXAGE must be set")
|
||||
.parse::<i64>()
|
||||
.expect("ACCESS_TOKEN_MAXAGE must be an integer");
|
||||
|
||||
let refresh_token_max_age = env::var("REFRESH_TOKEN_MAXAGE")
|
||||
.expect("REFRESH_TOKEN_MAXAGE must be set")
|
||||
.parse::<i64>()
|
||||
.expect("REFRESH_TOKEN_MAXAGE must be an integer");
|
||||
|
||||
let access_result: redis::RedisResult<()> = conn.set_ex(access_token_details.token_uuid.to_string(), &email, (access_token_max_age * 60) as usize).await;
|
||||
if let Err(err) = access_result {
|
||||
error!("Failed to set access token in redis: {}", err);
|
||||
return ResponseError::error_response(&ServiceError {
|
||||
status: 500,
|
||||
message: format!("Failed to set access token in redis: {}", err)
|
||||
})
|
||||
};
|
||||
|
||||
let refresh_result: redis::RedisResult<()> = conn.set_ex(refresh_token_details.token_uuid.to_string(), &email, (refresh_token_max_age * 60) as usize).await;
|
||||
if let Err(err) = refresh_result {
|
||||
error!("Failed to set refresh token in redis: {}", err);
|
||||
return ResponseError::error_response(&ServiceError {
|
||||
status: 500,
|
||||
message: format!("Failed to set refresh token in redis: {}", err)
|
||||
})
|
||||
};
|
||||
|
||||
let access_cookie = Cookie::build("access_token", access_token_details.token.clone().unwrap())
|
||||
.path("/")
|
||||
.max_age(Duration::new(access_token_max_age * 60, 0))
|
||||
.http_only(true)
|
||||
.secure(true)
|
||||
.finish();
|
||||
let refresh_cookie = Cookie::build("refresh_token", refresh_token_details.token.clone().unwrap())
|
||||
.path("/")
|
||||
.max_age(Duration::new(refresh_token_max_age * 60, 0))
|
||||
.http_only(true)
|
||||
.secure(true)
|
||||
.finish();
|
||||
let logged_in_cookie = Cookie::build("logged_in", "true")
|
||||
.path("/")
|
||||
.max_age(Duration::new(access_token_max_age * 60, 0))
|
||||
.http_only(false)
|
||||
.finish();
|
||||
|
||||
let access_token_uuid = uuid::Uuid::parse_str(&access_token_details.token_uuid.to_string()).unwrap();
|
||||
|
||||
HttpResponse::Ok()
|
||||
.cookie(access_cookie)
|
||||
.cookie(refresh_cookie)
|
||||
.cookie(logged_in_cookie)
|
||||
.json(JwtAuth { token: access_token_uuid, user: query_user.into() })
|
||||
},
|
||||
Err(err) => ResponseError::error_response(&ServiceError {
|
||||
status: 401,
|
||||
message: err.to_string()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct RefreshParams {
|
||||
refresh_token_rotation: Option<bool>
|
||||
}
|
||||
|
||||
#[get("/refresh")]
|
||||
async fn refresh(req: HttpRequest) -> HttpResponse {
|
||||
let params = match web::Query::<RefreshParams>::from_query(req.query_string()) {
|
||||
Ok(params) => params,
|
||||
Err(err) => return ResponseError::error_response(&ServiceError {
|
||||
status: 422,
|
||||
message: err.to_string()
|
||||
})
|
||||
};
|
||||
|
||||
let refresh_token = match req.cookie("refresh_token") {
|
||||
Some(cookie) => cookie.value().to_string(),
|
||||
None => return ResponseError::error_response(&ServiceError {
|
||||
status: 401,
|
||||
message: "Refresh token not found".to_string()
|
||||
})
|
||||
};
|
||||
|
||||
let public_key = env::var("REFRESH_TOKEN_PUBLIC_KEY")
|
||||
.expect("REFRESH_TOKEN_PUBLIC_KEY must be set");
|
||||
let refresh_token_details = match verify_token(&refresh_token, &public_key) {
|
||||
Ok(token_details) => token_details,
|
||||
Err(err) => return ResponseError::error_response(&err)
|
||||
};
|
||||
|
||||
let email = refresh_token_details.email.clone();
|
||||
|
||||
match QueryUser::get_by_email(&email) {
|
||||
Ok(query_user) => {
|
||||
let access_token_details = match generate_access_token(&email) {
|
||||
Ok(token_details) => token_details,
|
||||
Err(err) => {
|
||||
error!("Failed to generate access token: {}", err);
|
||||
return ResponseError::error_response(&err)
|
||||
}
|
||||
};
|
||||
|
||||
let mut conn = match db::redis_async_connection().await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
error!("Failed to get redis connection: {}", err);
|
||||
return ResponseError::error_response(&err)
|
||||
}
|
||||
};
|
||||
|
||||
// Delete old auth token if it exists
|
||||
match req.cookie("access_token") {
|
||||
Some(cookie) => {
|
||||
let access_token = cookie.value().to_string();
|
||||
let public_key = env::var("ACCESS_TOKEN_PUBLIC_KEY")
|
||||
.expect("ACCESS_TOKEN_PUBLIC_KEY must be set");
|
||||
match verify_token(&access_token, &public_key) {
|
||||
Ok(token_details) => {
|
||||
let _: redis::RedisResult<()> = conn.del(token_details.token_uuid.to_string()).await;
|
||||
|
||||
},
|
||||
Err(_) => {}
|
||||
};
|
||||
},
|
||||
None => {}
|
||||
};
|
||||
|
||||
let access_token_max_age = env::var("ACCESS_TOKEN_MAXAGE")
|
||||
.expect("ACCESS_TOKEN_MAXAGE must be set")
|
||||
.parse::<i64>()
|
||||
.expect("ACCESS_TOKEN_MAXAGE must be an integer");
|
||||
|
||||
let access_result: redis::RedisResult<()> = conn.set_ex(access_token_details.token_uuid.to_string(), &email, (access_token_max_age * 60) as usize).await;
|
||||
if let Err(err) = access_result {
|
||||
error!("Failed to set access token in redis: {}", err);
|
||||
return ResponseError::error_response(&ServiceError {
|
||||
status: 500,
|
||||
message: format!("Failed to set access token in redis: {}", err)
|
||||
})
|
||||
};
|
||||
|
||||
let access_cookie = Cookie::build("access_token", access_token_details.token.clone().unwrap())
|
||||
.path("/")
|
||||
.max_age(Duration::new(access_token_max_age * 60, 0))
|
||||
.http_only(true)
|
||||
.secure(true)
|
||||
.finish();
|
||||
let logged_in_cookie = Cookie::build("logged_in", "true")
|
||||
.path("/")
|
||||
.max_age(Duration::new(access_token_max_age * 60, 0))
|
||||
.http_only(false)
|
||||
.finish();
|
||||
|
||||
let access_token_uuid = uuid::Uuid::parse_str(&access_token_details.token_uuid.to_string()).unwrap();
|
||||
|
||||
// Refresh the refresh token if requested
|
||||
let refresh_token_rotation = match params.refresh_token_rotation {
|
||||
Some(refresh_token_rotation) => refresh_token_rotation,
|
||||
None => false
|
||||
};
|
||||
if refresh_token_rotation {
|
||||
// Delete the old refresh token
|
||||
let _: redis::RedisResult<()> = conn.del(refresh_token_details.token_uuid.to_string()).await;
|
||||
|
||||
let refresh_token_details = match generate_refresh_token(&refresh_token_details.email) {
|
||||
Ok(token_details) => token_details,
|
||||
Err(err) => {
|
||||
error!("Failed to generate refresh token: {}", err);
|
||||
return ResponseError::error_response(&err)
|
||||
}
|
||||
};
|
||||
|
||||
let refresh_token_max_age = env::var("REFRESH_TOKEN_MAXAGE")
|
||||
.expect("REFRESH_TOKEN_MAXAGE must be set")
|
||||
.parse::<i64>()
|
||||
.expect("REFRESH_TOKEN_MAXAGE must be an integer");
|
||||
|
||||
let refresh_result: redis::RedisResult<()> = conn.set_ex(refresh_token_details.token_uuid.to_string(), &refresh_token_details.email, (refresh_token_max_age * 60) as usize).await;
|
||||
if let Err(err) = refresh_result {
|
||||
error!("Failed to set refresh token in redis: {}", err);
|
||||
return ResponseError::error_response(&ServiceError {
|
||||
status: 500,
|
||||
message: format!("Failed to set refresh token in redis: {}", err)
|
||||
})
|
||||
};
|
||||
|
||||
let refresh_cookie = Cookie::build("refresh_token", refresh_token_details.token.clone().unwrap())
|
||||
.path("/")
|
||||
.max_age(Duration::new(refresh_token_max_age * 60, 0))
|
||||
.http_only(true)
|
||||
.secure(true)
|
||||
.finish();
|
||||
|
||||
HttpResponse::Ok()
|
||||
.cookie(refresh_cookie)
|
||||
.cookie(access_cookie)
|
||||
.cookie(logged_in_cookie)
|
||||
.json(JwtAuth { token: access_token_uuid, user: query_user.into() })
|
||||
} else {
|
||||
HttpResponse::Ok()
|
||||
.cookie(access_cookie)
|
||||
.cookie(logged_in_cookie)
|
||||
.json(JwtAuth { token: access_token_uuid, user: query_user.into() })
|
||||
}
|
||||
},
|
||||
Err(err) => return ResponseError::error_response(&err)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/logout")]
|
||||
async fn logout(req: HttpRequest, auth: JwtAuth) -> HttpResponse {
|
||||
let refresh_token = match req.cookie("refresh_token") {
|
||||
Some(cookie) => cookie.value().to_string(),
|
||||
None => return ResponseError::error_response(&ServiceError {
|
||||
status: 401,
|
||||
message: "Refresh token not found".to_string()
|
||||
})
|
||||
};
|
||||
let public_key = env::var("REFRESH_TOKEN_PUBLIC_KEY")
|
||||
.expect("REFRESH_TOKEN_PUBLIC_KEY must be set");
|
||||
let refresh_token_details = match verify_token(&refresh_token, &public_key) {
|
||||
Ok(token_details) => token_details,
|
||||
Err(err) => return ResponseError::error_response(&err)
|
||||
};
|
||||
|
||||
let mut conn = match db::redis_async_connection().await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
error!("Failed to get redis connection: {}", err);
|
||||
return ResponseError::error_response(&err)
|
||||
}
|
||||
};
|
||||
|
||||
let access_result: redis::RedisResult<()> = conn.del(&[
|
||||
refresh_token_details.token_uuid.to_string(),
|
||||
auth.token.to_string()
|
||||
]).await;
|
||||
if let Err(err) = access_result {
|
||||
error!("Failed to set access token in redis: {}", err);
|
||||
return ResponseError::error_response(&ServiceError {
|
||||
status: 500,
|
||||
message: format!("Failed to set access token in redis: {}", err)
|
||||
})
|
||||
};
|
||||
|
||||
let access_cookie = Cookie::build("access_token", "")
|
||||
.path("/")
|
||||
.max_age(Duration::new(-1, 0))
|
||||
.http_only(true)
|
||||
.finish();
|
||||
let refresh_cookie = Cookie::build("refresh_token", "")
|
||||
.path("/")
|
||||
.max_age(Duration::new(-1, 0))
|
||||
.http_only(true)
|
||||
.finish();
|
||||
let logged_in_cookie = Cookie::build("logged_in", "")
|
||||
.path("/")
|
||||
.max_age(Duration::new(-1, 0))
|
||||
.http_only(true)
|
||||
.finish();
|
||||
|
||||
HttpResponse::Ok()
|
||||
.cookie(access_cookie)
|
||||
.cookie(refresh_cookie)
|
||||
.cookie(logged_in_cookie)
|
||||
.finish()
|
||||
}
|
||||
|
||||
#[get("/me")]
|
||||
async fn me(auth: JwtAuth) -> HttpResponse {
|
||||
HttpResponse::Ok().json(auth)
|
||||
}
|
||||
|
||||
#[get("/roles")]
|
||||
async fn roles() -> HttpResponse {
|
||||
HttpResponse::Ok().json(vec!["admin", "user"])
|
||||
}
|
||||
|
||||
pub fn init_routes(config: &mut web::ServiceConfig) {
|
||||
let r = RegisterUser {
|
||||
email: "admin".to_string(),
|
||||
password: "admin".to_string(),
|
||||
first_name: "Admin".to_string(),
|
||||
last_name: "Admin".to_string(),
|
||||
};
|
||||
let mut u = r.convert_to_insert().unwrap();
|
||||
u.role = "admin".to_string();
|
||||
u.verified = true;
|
||||
let _ = InsertUser::insert(u);
|
||||
config.service(web::scope("auth")
|
||||
.service(register)
|
||||
.service(login)
|
||||
.service(refresh)
|
||||
.service(logout)
|
||||
.service(me)
|
||||
.service(roles)
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::{error_handler::ServiceError, airports::{InsertAirport, QueryAirport}};
|
||||
use diesel::{r2d2::ConnectionManager, PgConnection};
|
||||
use redis::{Client as RedisClient, aio::Connection as RedisConnection};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::diesel_migrations::MigrationHarness;
|
||||
use lazy_static::lazy_static;
|
||||
@@ -7,6 +8,8 @@ use log::{error, debug, info, warn};
|
||||
use r2d2;
|
||||
use std::env;
|
||||
|
||||
pub mod schema;
|
||||
|
||||
type Pool = r2d2::Pool<ConnectionManager<PgConnection>>;
|
||||
pub type DbConnection = r2d2::PooledConnection<ConnectionManager<PgConnection>>;
|
||||
|
||||
@@ -16,29 +19,24 @@ lazy_static! {
|
||||
static ref POOL: Pool = {
|
||||
let username = env::var("DATABASE_USER").expect("Database username is not set");
|
||||
let password = env::var("DATABASE_PASSWORD").expect("Database password is not set");
|
||||
let host = match env::var("DATABASE_HOST") {
|
||||
Ok(h) => h,
|
||||
Err(_) => {
|
||||
warn!("Defaulting to DATABASE_HOST localhost");
|
||||
"localhost".to_string()
|
||||
}
|
||||
};
|
||||
let host = env::var("DATABASE_HOST").unwrap_or("localhost".to_string());
|
||||
let name = env::var("DATABASE_NAME").expect("Database name is not set");
|
||||
let port = match env::var("DATABASE_PORT") {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
warn!("Defaulting to DATABASE_PORT 5432");
|
||||
"5432".to_string()
|
||||
}
|
||||
};
|
||||
let port = env::var("DATABASE_PORT").unwrap_or("5432".to_string());
|
||||
let url = format!("postgres://{}:{}@{}:{}/{}", username, password, host, port, name);
|
||||
let manager = ConnectionManager::<PgConnection>::new(url);
|
||||
Pool::builder().test_on_check_out(true).build(manager).expect("Failed to create db pool")
|
||||
};
|
||||
static ref REDIS: RedisClient = {
|
||||
let host = env::var("REDIS_HOST").unwrap_or("localhost".to_string());
|
||||
let port = env::var("REDIS_PORT").unwrap_or("6379".to_string());
|
||||
let url = format!("redis://{}:{}", host, port);
|
||||
RedisClient::open(url).expect("Failed to create redis client")
|
||||
};
|
||||
}
|
||||
|
||||
pub fn init() {
|
||||
lazy_static::initialize(&POOL);
|
||||
lazy_static::initialize(&REDIS);
|
||||
let mut pool: DbConnection = connection().expect("Failed to get db connection");
|
||||
match pool.run_pending_migrations(MIGRATIONS) {
|
||||
Ok(_) => info!("Database initialized"),
|
||||
@@ -51,6 +49,16 @@ pub fn connection() -> Result<DbConnection, ServiceError> {
|
||||
.map_err(|e| ServiceError::new(500, format!("Failed getting db connection: {}", e)))
|
||||
}
|
||||
|
||||
pub fn redis_connection() -> Result<redis::Connection, ServiceError> {
|
||||
let conn = REDIS.get_connection()?;
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
pub async fn redis_async_connection() -> Result<RedisConnection, ServiceError> {
|
||||
let conn = REDIS.get_async_connection().await?;
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
pub fn import_data() {
|
||||
let path = "airport-codes.json";
|
||||
debug!("Importing data from {}", path);
|
||||
@@ -48,13 +48,15 @@ diesel::table! {
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
use diesel::sql_types::*;
|
||||
use crate::users::PgUserType;
|
||||
users (id) {
|
||||
id -> Uuid,
|
||||
users (email) {
|
||||
email -> Text,
|
||||
hash -> Text,
|
||||
role -> Text,
|
||||
first_name -> Text,
|
||||
last_name -> Text,
|
||||
user_type -> PgUserType,
|
||||
favorites -> Array<Text>
|
||||
updated_at -> Timestamp,
|
||||
created_at -> Timestamp,
|
||||
profile_picture -> Nullable<Text>,
|
||||
verified -> Bool,
|
||||
}
|
||||
}
|
||||
@@ -8,60 +8,100 @@ use std::fmt;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ServiceError {
|
||||
pub error_status_code: u16,
|
||||
pub error_message: String,
|
||||
pub status: u16,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl ServiceError {
|
||||
pub fn new(error_status_code: u16, error_message: String) -> ServiceError {
|
||||
ServiceError {
|
||||
error_status_code,
|
||||
error_message,
|
||||
}
|
||||
pub fn new(status: u16, message: String) -> ServiceError {
|
||||
ServiceError {
|
||||
status,
|
||||
message,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_http_response(&self) -> HttpResponse {
|
||||
let status_code = match StatusCode::from_u16(self.error_status_code) {
|
||||
Ok(s) => s,
|
||||
Err(err) => {
|
||||
warn!("{}", err);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
};
|
||||
HttpResponse::build(status_code).body(self.error_message.to_string())
|
||||
}
|
||||
pub fn to_http_response(&self) -> HttpResponse {
|
||||
let status = match StatusCode::from_u16(self.status) {
|
||||
Ok(s) => s,
|
||||
Err(err) => {
|
||||
warn!("{}", err);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
};
|
||||
HttpResponse::build(status).body(self.message.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ServiceError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(self.error_message.as_str())
|
||||
}
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(self.message.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DieselError> for ServiceError {
|
||||
fn from(error: DieselError) -> ServiceError {
|
||||
match error {
|
||||
DieselError::DatabaseError(_, err) => ServiceError::new(409, err.message().to_string()),
|
||||
DieselError::NotFound => {
|
||||
ServiceError::new(404, "The record was not found".to_string())
|
||||
}
|
||||
err => ServiceError::new(500, format!("Unknown Diesel error: {}", err)),
|
||||
fn from(error: DieselError) -> ServiceError {
|
||||
match error {
|
||||
DieselError::DatabaseError(kind, err) => {
|
||||
match kind {
|
||||
diesel::result::DatabaseErrorKind::UniqueViolation => {
|
||||
ServiceError::new(409, err.message().to_string())
|
||||
},
|
||||
_ => ServiceError::new(500, err.message().to_string())
|
||||
}
|
||||
},
|
||||
DieselError::NotFound => {
|
||||
ServiceError::new(404, "The record was not found".to_string())
|
||||
},
|
||||
DieselError::SerializationError(err) => {
|
||||
ServiceError::new(422, err.to_string())
|
||||
},
|
||||
err => ServiceError::new(500, format!("Unknown Diesel error: {}", err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for ServiceError {
|
||||
fn from(error: reqwest::Error) -> ServiceError {
|
||||
ServiceError::new(500, format!("Unknown reqwest error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for ServiceError {
|
||||
fn from(error: serde_json::Error) -> ServiceError {
|
||||
ServiceError::new(500, format!("Unknown serde_json error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<argon2::password_hash::Error> for ServiceError {
|
||||
fn from(error: argon2::password_hash::Error) -> ServiceError {
|
||||
ServiceError::new(500, format!("Unknown argon2 error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<jsonwebtoken::errors::Error> for ServiceError {
|
||||
fn from(error: jsonwebtoken::errors::Error) -> ServiceError {
|
||||
ServiceError::new(500, format!("Unknown jsonwebtoken error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<redis::RedisError> for ServiceError {
|
||||
fn from(error: redis::RedisError) -> ServiceError {
|
||||
ServiceError::new(500, format!("Unknown redis error: {}", error))
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseError for ServiceError {
|
||||
fn error_response(&self) -> HttpResponse {
|
||||
let status_code = match StatusCode::from_u16(self.error_status_code) {
|
||||
Ok(status_code) => status_code,
|
||||
Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
fn error_response(&self) -> HttpResponse {
|
||||
let status = match StatusCode::from_u16(self.status) {
|
||||
Ok(status) => status,
|
||||
Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
|
||||
let error_message = match status_code.as_u16() < 500 {
|
||||
true => self.error_message.clone(),
|
||||
false => "Internal server error".to_string(),
|
||||
};
|
||||
let message = match status.as_u16() < 500 {
|
||||
true => self.message.clone(),
|
||||
false => "Internal server error".to_string(),
|
||||
};
|
||||
|
||||
HttpResponse::build(status_code).json(json!({ "message": error_message }))
|
||||
}
|
||||
HttpResponse::build(status).json(json!({ "status": status.as_u16(), "message": message }))
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,13 @@
|
||||
extern crate actix_web;
|
||||
extern crate diesel;
|
||||
#[macro_use]
|
||||
extern crate diesel_migrations;
|
||||
|
||||
use std::env;
|
||||
|
||||
use actix_cors::Cors;
|
||||
use actix_web::{App, HttpServer, middleware::Logger};
|
||||
use dotenv::dotenv;
|
||||
use env_logger::Env;
|
||||
use listenfd::ListenFd;
|
||||
use log::{debug, warn};
|
||||
use log::{info, error};
|
||||
|
||||
mod airports;
|
||||
mod auth;
|
||||
@@ -16,53 +15,43 @@ mod db;
|
||||
mod error_handler;
|
||||
mod metars;
|
||||
mod users;
|
||||
mod schema;
|
||||
mod scheduler;
|
||||
|
||||
#[actix_rt::main]
|
||||
#[actix_web::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
dotenv().ok();
|
||||
if std::env::var_os("RUST_LOG").is_none() {
|
||||
std::env::set_var("RUST_LOG", "info,actix=info,diesel_migrations=warn,reqwest=warn,hyper=warn");
|
||||
}
|
||||
env_logger::init_from_env(Env::default().default_filter_or("info"));
|
||||
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "warn,siren=info"));
|
||||
db::init();
|
||||
scheduler::update_airports();
|
||||
|
||||
let mut listenfd = ListenFd::from_env();
|
||||
let mut server = HttpServer::new(|| {
|
||||
let host = env::var("SERVICE_HOST").unwrap_or("localhost".to_string());
|
||||
let port = env::var("SERVICE_PORT").unwrap_or("5000".to_string());
|
||||
|
||||
let server = match HttpServer::new(move || {
|
||||
let cors = Cors::default()
|
||||
.allow_any_origin()
|
||||
.allow_any_method()
|
||||
.allow_any_header();
|
||||
.allow_any_header()
|
||||
.supports_credentials()
|
||||
.max_age(3600);
|
||||
App::new()
|
||||
.configure(airports::init_routes)
|
||||
.configure(metars::init_routes)
|
||||
.configure(users::init_routes)
|
||||
.wrap(cors)
|
||||
.wrap(Logger::default())
|
||||
});
|
||||
|
||||
server = match listenfd.take_tcp_listener(0)? {
|
||||
Some(listener) => server.listen(listener)?,
|
||||
None => {
|
||||
let host = match std::env::var("SERVICE_HOST") {
|
||||
Ok(h) => h,
|
||||
Err(_) => {
|
||||
warn!("Defaulting to SERVICE_HOST localhost");
|
||||
"localhost".to_string()
|
||||
}
|
||||
};
|
||||
let port = match std::env::var("SERVICE_PORT") {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
warn!("Defaulting to SERVICE_PORT 5000");
|
||||
"5000".to_string()
|
||||
}
|
||||
};
|
||||
debug!("Binding server to {}:{}", host, port);
|
||||
server.bind(format!("{}:{}", host, port))?
|
||||
.configure(airports::init_routes)
|
||||
.configure(metars::init_routes)
|
||||
.configure(auth::init_routes)
|
||||
.configure(users::init_routes)
|
||||
})
|
||||
.bind(format!("{}:{}", host, port)) {
|
||||
Ok(b) => {
|
||||
info!("Binding server to {}:{}", host, port);
|
||||
b
|
||||
},
|
||||
Err(err) => {
|
||||
error!("Could not bind server: {}", err);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
server.run().await
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::{error_handler::ServiceError, db};
|
||||
use crate::schema::metars::{self};
|
||||
use crate::db::schema::metars::{self};
|
||||
use diesel::{prelude::*, sql_query};
|
||||
use log::{warn, trace};
|
||||
use std::collections::HashSet;
|
||||
@@ -375,7 +375,7 @@ impl QueryMetar {
|
||||
let mut conn = db::connection()?;
|
||||
let db_metars: Vec<Self> = match sql_query(format!("SELECT DISTINCT ON (station_id) * FROM metars WHERE station_id IN ({}) ORDER BY station_id, observation_time DESC", station_query.join(","))).load(&mut conn) {
|
||||
Ok(m) => m,
|
||||
Err(err) => return Err(ServiceError { error_status_code: 500, error_message: format!("{}", err) })
|
||||
Err(err) => return Err(ServiceError { status: 500, message: format!("{}", err) })
|
||||
};
|
||||
return Ok(db_metars);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
mod model;
|
||||
mod routes;
|
||||
mod user_type;
|
||||
|
||||
pub use user_type::PgUserType;
|
||||
pub use model::*;
|
||||
pub use routes::init_routes;
|
||||
@@ -1,50 +0,0 @@
|
||||
use std::{future::Future, pin::Pin, sync::RwLock};
|
||||
|
||||
use actix_web::{dev::Payload, error::ErrorUnauthorized, web, Error, FromRequest, HttpRequest};
|
||||
use actix_identity::Identity;
|
||||
use diesel::{query_builder::AsChangeset, prelude::Insertable};
|
||||
use log::warn;
|
||||
use crate::schema::users;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use super::user_type::UserType;
|
||||
|
||||
#[derive(Serialize, Deserialize, AsChangeset, Insertable)]
|
||||
#[diesel(table_name = users)]
|
||||
pub struct InsertUser {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
user_type: UserType,
|
||||
favorites: Vec<String>
|
||||
}
|
||||
|
||||
// impl FromRequest for InsertUser {
|
||||
// type Config = ();
|
||||
// type Error = Error;
|
||||
// type Future = Pin<Box<dyn Future<Output = Result<Self, Error>>>>;
|
||||
|
||||
// fn from_request(req: &HttpRequest, pl: &mut Payload) -> Self::Future {
|
||||
// let fut = Identity::from_request(req, pl);
|
||||
// let sessions: Option<&web::Data<RwLock<Sessions>>> = req.app_data();
|
||||
// if sessions.is_none() {
|
||||
// warn!("sessions is empty(none)!");
|
||||
// return Box::pin(async { Err(ErrorUnauthorized("unauthorized")) });
|
||||
// }
|
||||
// let sessions = sessions.unwrap().clone();
|
||||
// Box::pin(async move {
|
||||
// if let Some(identity) = fut.await?.identity() {
|
||||
// if let Some(user) = sessions
|
||||
// .read()
|
||||
// .unwrap()
|
||||
// .map
|
||||
// .get(&identity)
|
||||
// .map(|x| x.clone())
|
||||
// {
|
||||
// return Ok(user);
|
||||
// }
|
||||
// };
|
||||
|
||||
// Err(ErrorUnauthorized("unauthorized"))
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
@@ -1,29 +1,4 @@
|
||||
use actix_web::{get, post, delete, put, web, HttpResponse};
|
||||
|
||||
#[get("users")]
|
||||
async fn get() -> HttpResponse {
|
||||
HttpResponse::NotImplemented().finish()
|
||||
}
|
||||
|
||||
#[get("users/{id}")]
|
||||
async fn get_all() -> HttpResponse {
|
||||
HttpResponse::NotImplemented().finish()
|
||||
}
|
||||
|
||||
#[post("users")]
|
||||
async fn create() -> HttpResponse {
|
||||
HttpResponse::NotImplemented().finish()
|
||||
}
|
||||
|
||||
#[delete("users")]
|
||||
async fn delete() -> HttpResponse {
|
||||
HttpResponse::NotImplemented().finish()
|
||||
}
|
||||
|
||||
#[put("users")]
|
||||
async fn update() -> HttpResponse {
|
||||
HttpResponse::NotImplemented().finish()
|
||||
}
|
||||
use actix_web::{get, post, delete, web, HttpResponse};
|
||||
|
||||
#[get("users/favorites")]
|
||||
async fn get_favorites() -> HttpResponse {
|
||||
@@ -41,10 +16,6 @@ async fn delete_favorite() -> HttpResponse {
|
||||
}
|
||||
|
||||
pub fn init_routes(config: &mut web::ServiceConfig) {
|
||||
config.service(get);
|
||||
config.service(create);
|
||||
config.service(delete);
|
||||
config.service(update);
|
||||
config.service(get_favorites);
|
||||
config.service(add_favorite);
|
||||
config.service(delete_favorite);
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
use std::io::Write;
|
||||
|
||||
use diesel::{sql_types::SqlType, deserialize::{FromSqlRow, FromSql, self}, expression::AsExpression, serialize::{ToSql, Output, self, IsNull}, pg::{Pg, PgValue}};
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
#[derive(SqlType)]
|
||||
#[diesel(postgres_type(name = "User_Type"))]
|
||||
pub struct PgUserType;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq, FromSqlRow, AsExpression, Eq)]
|
||||
#[diesel(sql_type = PgUserType)]
|
||||
pub enum UserType {
|
||||
Admin,
|
||||
User,
|
||||
}
|
||||
|
||||
impl ToSql<PgUserType, Pg> for UserType {
|
||||
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
|
||||
match *self {
|
||||
Self::Admin => out.write_all(b"admin")?,
|
||||
Self::User => out.write_all(b"user")?,
|
||||
}
|
||||
Ok(IsNull::No)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromSql<PgUserType, Pg> for UserType {
|
||||
fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> {
|
||||
match bytes.as_bytes() {
|
||||
b"admin" => Ok(Self::Admin),
|
||||
b"user" => Ok(Self::User),
|
||||
_ => Err("Unrecognized enum variant".into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user