use std::str::FromStr; use anyhow::Result; use log::debug; use sqlx::{ sqlite::{ SqliteConnectOptions, SqliteError, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous, }, Executor, Pool, Sqlite, }; use teloxide::types::User; #[derive(Debug)] pub struct SqliteConfig { pub source: String, pub timeout: std::time::Duration, pub max_conns: u32, } impl Default for SqliteConfig { fn default() -> Self { Self { source: ":memory:".to_string(), timeout: std::time::Duration::from_secs(10), max_conns: 2, } } } #[derive(Clone)] pub struct SqliteRepo { pool: Pool, } impl SqliteRepo { pub async fn from_config(config: SqliteConfig) -> Result { let dsn = if !config.source.starts_with("sqlite://") { format!("sqlite://{}", &config.source) } else { config.source }; debug!("connecting to {}", dsn); let opts = SqliteConnectOptions::from_str(&dsn)? .create_if_missing(true) .journal_mode(SqliteJournalMode::Wal) .synchronous(SqliteSynchronous::Normal) .busy_timeout(config.timeout); let pool = SqlitePoolOptions::new() .max_connections(config.max_conns) .connect_timeout(config.timeout) .connect_with(opts) .await?; sqlx::query("pragma temp_store = memory;") .execute(&pool) .await?; sqlx::query("pragma mmap_size = 30000000000;") .execute(&pool) .await?; sqlx::query("pragma page_size = 4096;") .execute(&pool) .await?; Ok(Self { pool }) } } #[derive(sqlx::FromRow, Debug, Clone)] pub struct UserDB { pub user_id: i64, pub chat_id: i64, pub name: String, pub created_at: chrono::NaiveDateTime, } #[derive(sqlx::FromRow)] pub struct ParameterDB { pub param_id: i64, pub user_id: i64, pub key: String, pub value: String, } pub type MappedParameter = std::collections::HashMap; pub struct ParameterBase { pub param_id: i64, pub user_id: i64, pub value: String, } #[async_trait::async_trait] pub trait Storage { async fn create_user(&self, chat_id: i64, name: String) -> Result; async fn get_user(&self, user_id: i64) -> Result>; async fn load_user_by_chat_id(&self, chat_id: i64) -> Result>; async fn get_user_parameters(&self, user_id: i64) -> Result; async fn upsert_parameter( &self, user_id: i64, key: String, value: String, ) -> Result; async fn delete_parameter(&self, param_id: i64) -> Result<()>; async fn insert_action(&self, user_id: i64, name: String) -> Result<()>; } #[async_trait::async_trait] impl Storage for SqliteRepo { async fn create_user(&self, chat_id: i64, name: String) -> Result { Ok(sqlx::query_as!( UserDB, "INSERT INTO users (chat_id, name, created_at)" + " VALUES (?, ?, datetime('now'))" + " RETURNING user_id, chat_id, name, created_at;", chat_id, name, ) .fetch_one(&self.pool) .await?) } async fn load_user_by_chat_id(&self, chat_id: i64) -> Result> { let result: std::result::Result = sqlx::query_as!( UserDB, "SELECT user_id, chat_id, name, created_at" + " FROM users WHERE `chat_id` = ?;", chat_id, ) .fetch_one(&self.pool) .await; match result { Ok(row) => Ok(Some(row)), Err(err) => match err { sqlx::Error::RowNotFound => Ok(None), err => Err(anyhow::anyhow!(err)), }, } } async fn get_user(&self, user_id: i64) -> Result> { let result: std::result::Result = sqlx::query_as!( UserDB, "SELECT user_id, chat_id, name, created_at" + " FROM users WHERE `user_id` = ?;", user_id, ) .fetch_one(&self.pool) .await; match result { Ok(row) => Ok(Some(row)), Err(err) => match err { sqlx::Error::RowNotFound => Ok(None), err => Err(anyhow::anyhow!(err)), }, } } async fn get_user_parameters(&self, user_id: i64) -> Result { let mut mp: MappedParameter = std::collections::HashMap::new(); sqlx::query_as!( ParameterDB, "SELECT `param_id`, `user_id`, `key`, `value`" + " FROM parameters WHERE `user_id` = ?", user_id, ) .fetch_all(&self.pool) .await? .into_iter() .for_each(|result| { let param = ParameterBase { param_id: result.param_id, user_id: result.user_id, value: result.value, }; mp.insert(result.key, param); }); Ok(mp) } async fn upsert_parameter( &self, user_id: i64, key: String, value: String, ) -> Result { Ok(sqlx::query_as!( ParameterDB, "INSERT INTO parameters (`user_id`, `key`, `value`)" + " VALUES (?, ?, ?)" + " RETURNING `param_id`, `user_id`, `key`, `value`;", user_id, key, value, ) .fetch_one(&self.pool) .await?) } async fn delete_parameter(&self, param_id: i64) -> Result<()> { sqlx::query("DELETE FROM parameters WHERE `param_id` = ?;") .bind(param_id) .execute(&self.pool) .await?; Ok(()) } async fn insert_action(&self, user_id: i64, name: String) -> Result<()> { sqlx::query("INSERT INTO actions (`user_id`, `name`) VALUES (?, ?)") .bind(user_id) .bind(name) .execute(&self.pool) .await?; Ok(()) } }