backend/
utils.rs

1use std::sync::LazyLock;
2
3use actix_web::{
4    cookie::{Cookie, SameSite, time::Duration},
5    http::header::HeaderMap,
6    web::BytesMut,
7};
8use bindet::FileType;
9use diesel::{ExpressionMethods, QueryDsl};
10use diesel_async::RunQueryDsl;
11use getrandom::fill;
12use hex::encode;
13use redis::RedisError;
14use regex::Regex;
15use serde::Serialize;
16use uuid::Uuid;
17
18use crate::{
19    Conn, Data,
20    config::Config,
21    error::Error,
22    objects::{HasIsAbove, HasUuid},
23    schema::users,
24};
25
26pub static EMAIL_REGEX: LazyLock<Regex> = LazyLock::new(|| {
27    Regex::new(r"[-A-Za-z0-9!#$%&'*+/=?^_`{|}~]+(?:\.[-A-Za-z0-9!#$%&'*+/=?^_`{|}~]+)*@(?:[A-Za-z0-9](?:[-A-Za-z0-9]*[A-Za-z0-9])?\.)+[A-Za-z0-9](?:[-A-Za-z0-9]*[A-Za-z0-9])?").unwrap()
28});
29
30pub static USERNAME_REGEX: LazyLock<Regex> =
31    LazyLock::new(|| Regex::new(r"^[a-z0-9_.-]+$").unwrap());
32
33pub static CHANNEL_REGEX: LazyLock<Regex> =
34    LazyLock::new(|| Regex::new(r"^[a-z0-9_.-]+$").unwrap());
35
36// Password is expected to be hashed using SHA3-384
37pub static PASSWORD_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[0-9a-f]{96}").unwrap());
38
39pub fn get_auth_header(headers: &HeaderMap) -> Result<&str, Error> {
40    let auth_token = headers.get(actix_web::http::header::AUTHORIZATION);
41
42    if auth_token.is_none() {
43        return Err(Error::Unauthorized(
44            "No authorization header provided".to_string(),
45        ));
46    }
47
48    let auth_raw = auth_token.unwrap().to_str()?;
49
50    let mut auth = auth_raw.split_whitespace();
51
52    let auth_type = auth.next();
53
54    let auth_value = auth.next();
55
56    if auth_type.is_none() {
57        return Err(Error::BadRequest(
58            "Authorization header is empty".to_string(),
59        ));
60    } else if auth_type.is_some_and(|at| at != "Bearer") {
61        return Err(Error::BadRequest(
62            "Only token auth is supported".to_string(),
63        ));
64    }
65
66    if auth_value.is_none() {
67        return Err(Error::BadRequest("No token provided".to_string()));
68    }
69
70    Ok(auth_value.unwrap())
71}
72
73pub fn get_ws_protocol_header(headers: &HeaderMap) -> Result<&str, Error> {
74    let auth_token = headers.get(actix_web::http::header::SEC_WEBSOCKET_PROTOCOL);
75
76    if auth_token.is_none() {
77        return Err(Error::Unauthorized(
78            "No authorization header provided".to_string(),
79        ));
80    }
81
82    let auth_raw = auth_token.unwrap().to_str()?;
83
84    let mut auth = auth_raw.split_whitespace();
85
86    let response_proto = auth.next();
87
88    let auth_value = auth.next();
89
90    if response_proto.is_none() {
91        return Err(Error::BadRequest(
92            "Sec-WebSocket-Protocol header is empty".to_string(),
93        ));
94    } else if response_proto.is_some_and(|rp| rp != "Authorization,") {
95        return Err(Error::BadRequest(
96            "First protocol should be Authorization".to_string(),
97        ));
98    }
99
100    if auth_value.is_none() {
101        return Err(Error::BadRequest("No token provided".to_string()));
102    }
103
104    Ok(auth_value.unwrap())
105}
106
107pub fn new_refresh_token_cookie(config: &Config, refresh_token: String) -> Cookie<'static> {
108    Cookie::build("refresh_token", refresh_token)
109        .http_only(true)
110        .secure(true)
111        .same_site(SameSite::None)
112        //.domain(config.web.backend_url.domain().unwrap().to_string())
113        .path(config.web.backend_url.path().to_string())
114        .max_age(Duration::days(30))
115        .finish()
116}
117
118pub fn generate_access_token() -> Result<String, getrandom::Error> {
119    let mut buf = [0u8; 16];
120    fill(&mut buf)?;
121    Ok(encode(buf))
122}
123
124pub fn generate_refresh_token() -> Result<String, getrandom::Error> {
125    let mut buf = [0u8; 32];
126    fill(&mut buf)?;
127    Ok(encode(buf))
128}
129
130pub fn image_check(icon: BytesMut) -> Result<String, Error> {
131    let buf = std::io::Cursor::new(icon);
132
133    let detect = bindet::detect(buf).map_err(|e| e.kind());
134
135    if let Ok(Some(file_type)) = detect {
136        if file_type.likely_to_be == vec![FileType::Jpg] {
137            return Ok(String::from("jpg"));
138        } else if file_type.likely_to_be == vec![FileType::Png] {
139            return Ok(String::from("png"));
140        }
141    }
142
143    Err(Error::BadRequest(
144        "Uploaded file is not an image".to_string(),
145    ))
146}
147
148pub async fn user_uuid_from_identifier(
149    conn: &mut Conn,
150    identifier: &String,
151) -> Result<Uuid, Error> {
152    if EMAIL_REGEX.is_match(identifier) {
153        use users::dsl;
154        let user_uuid = dsl::users
155            .filter(dsl::email.eq(identifier))
156            .select(dsl::uuid)
157            .get_result(conn)
158            .await?;
159
160        Ok(user_uuid)
161    } else if USERNAME_REGEX.is_match(identifier) {
162        use users::dsl;
163        let user_uuid = dsl::users
164            .filter(dsl::username.eq(identifier))
165            .select(dsl::uuid)
166            .get_result(conn)
167            .await?;
168
169        Ok(user_uuid)
170    } else {
171        Err(Error::BadRequest(
172            "Please provide a valid username or email".to_string(),
173        ))
174    }
175}
176
177pub async fn global_checks(data: &Data, user_uuid: Uuid) -> Result<(), Error> {
178    if data.config.instance.require_email_verification {
179        let mut conn = data.pool.get().await?;
180
181        use users::dsl;
182        let email_verified: bool = dsl::users
183            .filter(dsl::uuid.eq(user_uuid))
184            .select(dsl::email_verified)
185            .get_result(&mut conn)
186            .await?;
187
188        if !email_verified {
189            return Err(Error::Forbidden(
190                "server requires email verification".to_string(),
191            ));
192        }
193    }
194
195    Ok(())
196}
197
198pub async fn order_by_is_above<T>(mut items: Vec<T>) -> Result<Vec<T>, Error>
199where
200    T: HasUuid + HasIsAbove,
201{
202    let mut ordered = Vec::new();
203
204    // Find head
205    let head_pos = items
206        .iter()
207        .position(|item| !items.iter().any(|i| i.is_above() == Some(item.uuid())));
208
209    if let Some(pos) = head_pos {
210        ordered.push(items.swap_remove(pos));
211
212        while let Some(next_pos) = items
213            .iter()
214            .position(|item| Some(item.uuid()) == ordered.last().unwrap().is_above())
215        {
216            ordered.push(items.swap_remove(next_pos));
217        }
218    }
219
220    Ok(ordered)
221}
222
223impl Data {
224    pub async fn set_cache_key(
225        &self,
226        key: String,
227        value: impl Serialize,
228        expire: u32,
229    ) -> Result<(), Error> {
230        let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?;
231
232        let key_encoded = encode(key);
233
234        let value_json = serde_json::to_string(&value)?;
235
236        redis::cmd("SET")
237            .arg(&[key_encoded.clone(), value_json])
238            .exec_async(&mut conn)
239            .await?;
240
241        redis::cmd("EXPIRE")
242            .arg(&[key_encoded, expire.to_string()])
243            .exec_async(&mut conn)
244            .await?;
245
246        Ok(())
247    }
248
249    pub async fn get_cache_key(&self, key: String) -> Result<String, RedisError> {
250        let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?;
251
252        let key_encoded = encode(key);
253
254        redis::cmd("GET")
255            .arg(key_encoded)
256            .query_async(&mut conn)
257            .await
258    }
259
260    pub async fn del_cache_key(&self, key: String) -> Result<(), RedisError> {
261        let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?;
262
263        let key_encoded = encode(key);
264
265        redis::cmd("DEL")
266            .arg(key_encoded)
267            .query_async(&mut conn)
268            .await
269    }
270}