227 lines
6.1 KiB
Rust
227 lines
6.1 KiB
Rust
use std::str::FromStr;
|
|
|
|
use anyhow::Result;
|
|
use log::debug;
|
|
use sqlx::{
|
|
sqlite::{
|
|
SqliteConnectOptions, SqliteError, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous,
|
|
},
|
|
Executor, Pool, Sqlite,
|
|
};
|
|
use teloxide::types::User;
|
|
|
|
/// Connection settings consumed by `SqliteRepo::from_config`.
///
/// `Clone` lets callers keep a copy of the configuration after handing one to
/// `from_config`, which takes it by value.
#[derive(Debug, Clone)]
pub struct SqliteConfig {
    /// Database location: either a bare path / `:memory:` or a `sqlite://` DSN.
    /// Values without the `sqlite://` scheme get it prepended before connecting.
    pub source: String,
    /// Used both as the SQLite busy timeout and as the pool connect timeout.
    pub timeout: std::time::Duration,
    /// Upper bound on the number of pooled connections.
    pub max_conns: u32,
}
|
|
|
|
impl Default for SqliteConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
source: ":memory:".to_string(),
|
|
timeout: std::time::Duration::from_secs(10),
|
|
max_conns: 2,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct SqliteRepo {
|
|
pool: Pool<Sqlite>,
|
|
}
|
|
|
|
impl SqliteRepo {
|
|
pub async fn from_config(config: SqliteConfig) -> Result<Self> {
|
|
let dsn = if !config.source.starts_with("sqlite://") {
|
|
format!("sqlite://{}", &config.source)
|
|
} else {
|
|
config.source
|
|
};
|
|
|
|
debug!("connecting to {}", dsn);
|
|
|
|
let opts = SqliteConnectOptions::from_str(&dsn)?
|
|
.create_if_missing(true)
|
|
.journal_mode(SqliteJournalMode::Wal)
|
|
.synchronous(SqliteSynchronous::Normal)
|
|
.busy_timeout(config.timeout);
|
|
|
|
let pool = SqlitePoolOptions::new()
|
|
.max_connections(config.max_conns)
|
|
.connect_timeout(config.timeout)
|
|
.connect_with(opts)
|
|
.await?;
|
|
|
|
sqlx::query("pragma temp_store = memory;")
|
|
.execute(&pool)
|
|
.await?;
|
|
|
|
sqlx::query("pragma mmap_size = 30000000000;")
|
|
.execute(&pool)
|
|
.await?;
|
|
|
|
sqlx::query("pragma page_size = 4096;")
|
|
.execute(&pool)
|
|
.await?;
|
|
|
|
Ok(Self { pool })
|
|
}
|
|
}
|
|
|
|
#[derive(sqlx::FromRow, Debug, Clone)]
|
|
pub struct UserDB {
|
|
pub user_id: i64,
|
|
pub chat_id: i64,
|
|
pub name: String,
|
|
pub created_at: chrono::NaiveDateTime,
|
|
}
|
|
|
|
#[derive(sqlx::FromRow)]
|
|
pub struct ParameterDB {
|
|
pub param_id: i64,
|
|
pub user_id: i64,
|
|
pub key: String,
|
|
pub value: String,
|
|
}
|
|
|
|
/// A user's parameters keyed by parameter name (the `key` column).
pub type MappedParameter = std::collections::HashMap<String, ParameterBase>;

/// A parameter row without its key, which lives in the surrounding
/// [`MappedParameter`] map instead.
///
/// Derives `Debug` and `Clone` so a whole `MappedParameter` can be printed and
/// copied, consistent with the other row types.
#[derive(Debug, Clone)]
pub struct ParameterBase {
    /// Row identifier of the parameter.
    pub param_id: i64,
    /// `user_id` of the owning user.
    pub user_id: i64,
    /// Parameter value, stored as text.
    pub value: String,
}
|
|
|
|
#[async_trait::async_trait]
|
|
pub trait Storage {
|
|
async fn create_user(&self, chat_id: i64, name: String) -> Result<UserDB>;
|
|
async fn get_user(&self, user_id: i64) -> Result<Option<UserDB>>;
|
|
async fn load_user_by_chat_id(&self, chat_id: i64) -> Result<Option<UserDB>>;
|
|
async fn get_user_parameters(&self, user_id: i64) -> Result<MappedParameter>;
|
|
|
|
async fn upsert_parameter(
|
|
&self,
|
|
user_id: i64,
|
|
key: String,
|
|
value: String,
|
|
) -> Result<ParameterDB>;
|
|
async fn delete_parameter(&self, param_id: i64) -> Result<()>;
|
|
|
|
async fn insert_action(&self, user_id: i64, name: String) -> Result<()>;
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl Storage for SqliteRepo {
|
|
async fn create_user(&self, chat_id: i64, name: String) -> Result<UserDB> {
|
|
Ok(sqlx::query_as!(
|
|
UserDB,
|
|
"INSERT INTO users (chat_id, name, created_at)"
|
|
+ " VALUES (?, ?, datetime('now'))"
|
|
+ " RETURNING user_id, chat_id, name, created_at;",
|
|
chat_id,
|
|
name,
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?)
|
|
}
|
|
|
|
async fn load_user_by_chat_id(&self, chat_id: i64) -> Result<Option<UserDB>> {
|
|
let result: std::result::Result<UserDB, sqlx::Error> = sqlx::query_as!(
|
|
UserDB,
|
|
"SELECT user_id, chat_id, name, created_at" + " FROM users WHERE `chat_id` = ?;",
|
|
chat_id,
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await;
|
|
|
|
match result {
|
|
Ok(row) => Ok(Some(row)),
|
|
Err(err) => match err {
|
|
sqlx::Error::RowNotFound => Ok(None),
|
|
err => Err(anyhow::anyhow!(err)),
|
|
},
|
|
}
|
|
}
|
|
|
|
async fn get_user(&self, user_id: i64) -> Result<Option<UserDB>> {
|
|
let result: std::result::Result<UserDB, sqlx::Error> = sqlx::query_as!(
|
|
UserDB,
|
|
"SELECT user_id, chat_id, name, created_at" + " FROM users WHERE `user_id` = ?;",
|
|
user_id,
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await;
|
|
|
|
match result {
|
|
Ok(row) => Ok(Some(row)),
|
|
Err(err) => match err {
|
|
sqlx::Error::RowNotFound => Ok(None),
|
|
err => Err(anyhow::anyhow!(err)),
|
|
},
|
|
}
|
|
}
|
|
|
|
async fn get_user_parameters(&self, user_id: i64) -> Result<MappedParameter> {
|
|
let mut mp: MappedParameter = std::collections::HashMap::new();
|
|
sqlx::query_as!(
|
|
ParameterDB,
|
|
"SELECT `param_id`, `user_id`, `key`, `value`" + " FROM parameters WHERE `user_id` = ?",
|
|
user_id,
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await?
|
|
.into_iter()
|
|
.for_each(|result| {
|
|
let param = ParameterBase {
|
|
param_id: result.param_id,
|
|
user_id: result.user_id,
|
|
value: result.value,
|
|
};
|
|
|
|
mp.insert(result.key, param);
|
|
});
|
|
|
|
Ok(mp)
|
|
}
|
|
|
|
async fn upsert_parameter(
|
|
&self,
|
|
user_id: i64,
|
|
key: String,
|
|
value: String,
|
|
) -> Result<ParameterDB> {
|
|
Ok(sqlx::query_as!(
|
|
ParameterDB,
|
|
"INSERT INTO parameters (`user_id`, `key`, `value`)"
|
|
+ " VALUES (?, ?, ?)"
|
|
+ " RETURNING `param_id`, `user_id`, `key`, `value`;",
|
|
user_id,
|
|
key,
|
|
value,
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?)
|
|
}
|
|
|
|
async fn delete_parameter(&self, param_id: i64) -> Result<()> {
|
|
sqlx::query("DELETE FROM parameters WHERE `param_id` = ?;")
|
|
.bind(param_id)
|
|
.execute(&self.pool)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn insert_action(&self, user_id: i64, name: String) -> Result<()> {
|
|
sqlx::query("INSERT INTO actions (`user_id`, `name`) VALUES (?, ?)")
|
|
.bind(user_id)
|
|
.bind(name)
|
|
.execute(&self.pool)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|