1use rand::seq::IndexedRandom;
2use std::sync::LazyLock;
3
4use axum::body::Bytes;
5use axum_extra::extract::cookie::{Cookie, SameSite};
6use bindet::FileType;
7use diesel::{ExpressionMethods, QueryDsl};
8use diesel_async::RunQueryDsl;
9use getrandom::fill;
10use hex::encode;
11use regex::Regex;
12use serde::{Serialize, de::DeserializeOwned};
13use time::Duration;
14use uuid::Uuid;
15
16use crate::{
17 Conn,
18 config::Config,
19 error::Error,
20 objects::{HasIsAbove, HasUuid},
21 schema::users,
22 wordlist::{ADJECTIVES, ANIMALS},
23};
24
25pub static EMAIL_REGEX: LazyLock<Regex> = LazyLock::new(|| {
26 Regex::new(r"[-A-Za-z0-9!#$%&'*+/=?^_`{|}~]+(?:\.[-A-Za-z0-9!#$%&'*+/=?^_`{|}~]+)*@(?:[A-Za-z0-9](?:[-A-Za-z0-9]*[A-Za-z0-9])?\.)+[A-Za-z0-9](?:[-A-Za-z0-9]*[A-Za-z0-9])?").unwrap()
27});
28
29pub static USERNAME_REGEX: LazyLock<Regex> =
30 LazyLock::new(|| Regex::new(r"^[a-z0-9_.-]+$").unwrap());
31
32pub static CHANNEL_REGEX: LazyLock<Regex> =
33 LazyLock::new(|| Regex::new(r"^[a-z0-9_.-]+$").unwrap());
34
35pub static PASSWORD_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[0-9a-f]{96}").unwrap());
36
37pub fn new_refresh_token_cookie(config: &Config, refresh_token: String) -> Cookie {
38 Cookie::build(("refresh_token", refresh_token))
39 .http_only(true)
40 .secure(true)
41 .same_site(SameSite::None)
42 .path(config.web.backend_url.path().to_string())
43 .max_age(Duration::days(30))
44 .build()
45}
46
47pub fn generate_token<const N: usize>() -> Result<String, getrandom::Error> {
48 let mut buf = [0u8; N];
49 fill(&mut buf)?;
50 Ok(encode(buf))
51}
52
53pub fn image_check(icon: Bytes) -> Result<String, Error> {
54 let buf = std::io::Cursor::new(icon);
55
56 let detect = bindet::detect(buf).map_err(|e| e.kind());
57
58 if let Ok(Some(file_type)) = detect {
59 if file_type.likely_to_be == vec![FileType::Jpg] {
60 return Ok(String::from("jpg"));
61 } else if file_type.likely_to_be == vec![FileType::Png] {
62 return Ok(String::from("png"));
63 }
64 }
65
66 Err(Error::BadRequest(
67 "Uploaded file is not an image".to_string(),
68 ))
69}
70
71pub async fn user_uuid_from_identifier(
72 conn: &mut Conn,
73 identifier: &String,
74) -> Result<Uuid, Error> {
75 if EMAIL_REGEX.is_match(identifier) {
76 use users::dsl;
77 let user_uuid = dsl::users
78 .filter(dsl::email.eq(identifier))
79 .select(dsl::uuid)
80 .get_result(conn)
81 .await?;
82
83 Ok(user_uuid)
84 } else if USERNAME_REGEX.is_match(identifier) {
85 use users::dsl;
86 let user_uuid = dsl::users
87 .filter(dsl::username.eq(identifier))
88 .select(dsl::uuid)
89 .get_result(conn)
90 .await?;
91
92 Ok(user_uuid)
93 } else {
94 Err(Error::BadRequest(
95 "Please provide a valid username or email".to_string(),
96 ))
97 }
98}
99
100pub async fn user_uuid_from_username(conn: &mut Conn, username: &String) -> Result<Uuid, Error> {
101 if USERNAME_REGEX.is_match(username) {
102 use users::dsl;
103 let user_uuid = dsl::users
104 .filter(dsl::username.eq(username))
105 .select(dsl::uuid)
106 .get_result(conn)
107 .await?;
108
109 Ok(user_uuid)
110 } else {
111 Err(Error::BadRequest(
112 "Please provide a valid username".to_string(),
113 ))
114 }
115}
116
117pub async fn global_checks(conn: &mut Conn, config: &Config, user_uuid: Uuid) -> Result<(), Error> {
118 if config.instance.require_email_verification {
119 use users::dsl;
120 let email_verified: bool = dsl::users
121 .filter(dsl::uuid.eq(user_uuid))
122 .select(dsl::email_verified)
123 .get_result(conn)
124 .await?;
125
126 if !email_verified {
127 return Err(Error::Forbidden(
128 "server requires email verification".to_string(),
129 ));
130 }
131 }
132
133 Ok(())
134}
135
136pub async fn order_by_is_above<T>(mut items: Vec<T>) -> Result<Vec<T>, Error>
137where
138 T: HasUuid + HasIsAbove,
139{
140 let mut ordered = Vec::new();
141
142 let head_pos = items
144 .iter()
145 .position(|item| !items.iter().any(|i| i.is_above() == Some(item.uuid())));
146
147 if let Some(pos) = head_pos {
148 ordered.push(items.swap_remove(pos));
149
150 while let Some(next_pos) = items
151 .iter()
152 .position(|item| Some(item.uuid()) == ordered.last().unwrap().is_above())
153 {
154 ordered.push(items.swap_remove(next_pos));
155 }
156 }
157
158 Ok(ordered)
159}
160
161#[allow(async_fn_in_trait)]
162pub trait CacheFns {
163 async fn set_cache_key(
164 &self,
165 key: String,
166 value: impl Serialize,
167 expire: u32,
168 ) -> Result<(), Error>;
169 async fn get_cache_key<T>(&self, key: String) -> Result<T, Error>
170 where
171 T: DeserializeOwned;
172 async fn del_cache_key(&self, key: String) -> Result<(), Error>;
173}
174
175impl CacheFns for redis::Client {
176 async fn set_cache_key(
177 &self,
178 key: String,
179 value: impl Serialize,
180 expire: u32,
181 ) -> Result<(), Error> {
182 let mut conn = self.get_multiplexed_tokio_connection().await?;
183
184 let key_encoded = encode(key);
185
186 let value_json = serde_json::to_string(&value)?;
187
188 redis::cmd("SET")
189 .arg(&[key_encoded.clone(), value_json])
190 .exec_async(&mut conn)
191 .await?;
192
193 redis::cmd("EXPIRE")
194 .arg(&[key_encoded, expire.to_string()])
195 .exec_async(&mut conn)
196 .await?;
197
198 Ok(())
199 }
200
201 async fn get_cache_key<T>(&self, key: String) -> Result<T, Error>
202 where
203 T: DeserializeOwned,
204 {
205 let mut conn = self.get_multiplexed_tokio_connection().await?;
206
207 let key_encoded = encode(key);
208
209 let res: String = redis::cmd("GET")
210 .arg(key_encoded)
211 .query_async(&mut conn)
212 .await?;
213
214 Ok(serde_json::from_str(&res)?)
215 }
216
217 async fn del_cache_key(&self, key: String) -> Result<(), Error> {
218 let mut conn = self.get_multiplexed_tokio_connection().await?;
219
220 let key_encoded = encode(key);
221
222 Ok(redis::cmd("DEL")
223 .arg(key_encoded)
224 .query_async(&mut conn)
225 .await?)
226 }
227}
228
229pub fn generate_device_name() -> String {
230 let mut rng = rand::rng();
231
232 let adjective = ADJECTIVES.choose(&mut rng).unwrap();
233 let animal = ANIMALS.choose(&mut rng).unwrap();
234
235 [*adjective, *animal].join(" ")
236}