Compare commits

3 Commits

Author SHA1 Message Date
41072b91d1 some changes? 2024-05-01 16:51:07 +03:00
fb03a110ed allow to migrate tables 2022-05-26 17:03:55 +03:00
473953fc9c save users to sqlite 2022-05-26 12:19:47 +03:00
11 changed files with 2442 additions and 279 deletions

View File

@ -7,9 +7,12 @@ platform:
os: linux os: linux
arch: arm arch: arm
clone:
skip_verify: true
steps: steps:
- name: validate - name: validate
image: rust:1.52 image: rust:1.49
commands: commands:
- cargo test --release --target=armv7-unknown-linux-gnueabihf - cargo test --release --target=armv7-unknown-linux-gnueabihf
environment: environment:

1
.gitignore vendored
View File

@ -1,4 +1,3 @@
/target /target
Cargo.lock
.testdata .testdata
.vscode .vscode

2103
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,14 +1,13 @@
[package] [package]
name = "altherego" name = "altherego"
version = "0.9.9" version = "0.9.9"
authors = ["Aleksandr Trushkin <atrushkin@outlook.com>"] authors = ["Aleksandr Trushkin <aleksandr.trushkin@rt.ru>"]
edition = "2018" edition = "2018"
default-run = "altherego"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
teloxide = { version = "0.12.2", features = ["macros", "auto-send"] } teloxide = { version = "0.9.0", features = ["macros", "auto-send"] }
tokio = {version = "1.8", features = ["full"]} tokio = {version = "1.8", features = ["full"]}
uuid = { version = "0.8.1", features = ["v4"] } uuid = { version = "0.8.1", features = ["v4"] }
log = "0.4" log = "0.4"
@ -24,13 +23,5 @@ async-trait = "0.1.53"
tokio-stream = "0.1.8" tokio-stream = "0.1.8"
rand = "0.8.5" rand = "0.8.5"
[[bin]]
name = "altherego"
path = "src/main.rs"
[[bin]]
name = "migrator"
path = "src/migrator/main.rs"
[profile.dev.package.sqlx-macros] [profile.dev.package.sqlx-macros]
opt-level = 3 opt-level = 3

View File

@ -1,12 +1,12 @@
use std::env; use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> { fn main() {
let rev = get_value_from_env("GIT_VERSION") let rev = get_value_from_env("GIT_VERSION")
.or_else(|| get_value_from_command("git", ["rev-parse", "--short", "HEAD"])) .or_else(|| get_value_from_command("git", &["rev-parse", "--short", "HEAD"]))
.unwrap_or_else(|| "unknown".to_owned()); .unwrap_or_else(|| "unknown".to_owned());
let branch = get_value_from_env("GIT_BRANCH") let branch = get_value_from_env("GIT_BRANCH")
.or_else(|| get_value_from_command("git", ["rev-parse", "--abbrev-ref", "HEAD"])) .or_else(|| get_value_from_command("git", &["rev-parse", "--abbrev-ref", "HEAD"]))
.unwrap_or_else(|| "unknown".to_owned()); .unwrap_or_else(|| "unknown".to_owned());
println!("cargo:rustc-env=GIT_REVISION={}", rev); println!("cargo:rustc-env=GIT_REVISION={}", rev);
@ -15,22 +15,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
if let Ok(data) = std::fs::read_to_string(".env") { if let Ok(data) = std::fs::read_to_string(".env") {
data.split('\n').into_iter().for_each(|v| { data.split('\n').into_iter().for_each(|v| {
if v.starts_with("DATABASE") {
return;
}
let kv: Vec<&str> = v.split('=').collect(); let kv: Vec<&str> = v.split('=').collect();
if kv.len() != 2 { if kv.len() != 2 {
return; return;
} }
let (key, value) = (kv[0], kv[1]); let (key, value) = (kv[0], kv[1]);
if key == "DATABASE_URL" {
return;
}
println!("cargo:rustc-env={}={}", key, value); println!("cargo:rustc-env={}={}", key, value);
println!("cargo:rerun-if-env-changed={}", key); println!("cargo:rerun-if-env-changed={}", key);
}) })
} }
Ok(())
} }
fn get_value_from_env(key: &str) -> Option<String> { fn get_value_from_env(key: &str) -> Option<String> {

View File

@ -26,3 +26,8 @@ CREATE TABLE IF NOT EXISTS actions (
REFERENCES users(user_id) REFERENCES users(user_id)
ON DELETE CASCADE ON DELETE CASCADE
); );
-- drop index actions_action_id_user_id_idx;
-- drop table users;
-- drop table parameters;
-- drop table actions;

12
db/002_subscribers.sql Normal file
View File

@ -0,0 +1,12 @@
CREATE TABLE IF NOT EXISTS `subscribers` (
`subscriber_id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
`user_id` INTEGER NOT NULL,
`kind` VARCHAR(16) NOT NULL,
`arguments` JSON NOT NULL DEFAULT '{}',
FOREIGN KEY(`user_id`)
REFERENCES users(user_id)
ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS subscribers_kind_user_id
ON subscribers (`kind`, `user_id`);

View File

@ -3,8 +3,7 @@ export DOCKER_BUILDKIT=1
DOCKERFLAGS:=-it --rm \ DOCKERFLAGS:=-it --rm \
-v "${PWD}":"/app" \ -v "${PWD}":"/app" \
--workdir "/app" \ --workdir "/app" \
-e "PWD=/app" \ -e "PWD=/app"
-u $(shell id -u):$(shell id -g)
DOCKERIMG:="rust-build-env:V1" DOCKERIMG:="rust-build-env:V1"
@ -28,10 +27,6 @@ build_release_arm:
docker run ${DOCKERFLAGS} ${DOCKERIMG} /bin/sh -c 'cargo build --release --target=armv7-unknown-linux-gnueabihf' docker run ${DOCKERFLAGS} ${DOCKERIMG} /bin/sh -c 'cargo build --release --target=armv7-unknown-linux-gnueabihf'
.PHONY: build_release_arm .PHONY: build_release_arm
inside:
docker run ${DOCKERFLAGS} ${DOCKERIMG} /bin/bash
.PHONY: inside
docker_build_release_arm: docker_build_release_arm:
docker run ${DOCKERFLAGS} ${DOCKERIMG} make build_release_arm docker run ${DOCKERFLAGS} ${DOCKERIMG} make build_release_arm
@ -43,11 +38,3 @@ dronefile:
-V target_arch=${TARGET_ARCH} -V target_arch=${TARGET_ARCH}
drone sign frx/altherego --save drone sign frx/altherego --save
.PHONY: dronefile .PHONY: dronefile
init_db:
rm -rf .testdata
mkdir .testdata
sqlite3 -init ./db/001_initial.sql ./.testdata/db.sqlite '.q'
open_db:
sqlite3 ./.testdata/db.sqlite

View File

@ -1,5 +1,3 @@
#![allow(unused)]
mod climate; mod climate;
mod repo; mod repo;
mod utils; mod utils;
@ -8,15 +6,8 @@ use anyhow::Result;
use climate::SelfTemperature; use climate::SelfTemperature;
use envconfig::Envconfig; use envconfig::Envconfig;
use log::{debug, info, warn}; use log::{debug, info, warn};
use teloxide::{ use repo::UserDB;
dispatching::{update_listeners::AsUpdateStream, UpdateFilterExt}, use teloxide::{dispatching::UpdateFilterExt, prelude::*, utils::command::BotCommands};
dptree::di::Injectable,
filter_command,
payloads::SendMessage,
prelude::*,
utils::command::BotCommands,
};
use tokio_stream::StreamExt;
use crate::repo::Storage; use crate::repo::Storage;
@ -24,16 +15,20 @@ const VERSION: &str = env!("GIT_REVISION");
const BRANCH: &str = env!("GIT_BRANCH"); const BRANCH: &str = env!("GIT_BRANCH");
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> Result<()> {
env_logger::init(); env_logger::init();
debug!("starting the application"); debug!("starting the application");
tokio::spawn(run()).await.unwrap(); tokio::spawn(run()).await??;
Ok(())
} }
#[derive(Envconfig, Clone, Debug)] #[derive(Envconfig, Clone, Debug)]
struct Settings { struct Settings {
/// Token is used to authenticate itself as a bot and being able
/// to handle incoming commands and messages.
#[envconfig(from = "ALTEREGO_TELEGRAM_TOKEN")] #[envconfig(from = "ALTEREGO_TELEGRAM_TOKEN")]
pub telegram_token: String, pub telegram_token: String,
@ -57,9 +52,11 @@ async fn run() -> anyhow::Result<()> {
info!("starting"); info!("starting");
let settings = Settings::init_from_env().expect("reading config values"); let settings = Settings::init_from_env().expect("reading config values");
let bot = teloxide::Bot::new(&settings.telegram_token); let bot = teloxide::Bot::new(&settings.telegram_token).auto_send();
let migrate = std::env::args().any(|v| v == "migrate");
let repo_config = repo::SqliteConfig { let repo_config = repo::SqliteConfig {
source: settings.db_source, source: settings.db_source,
migrate,
..Default::default() ..Default::default()
}; };
@ -79,11 +76,37 @@ async fn run() -> anyhow::Result<()> {
let handler = Update::filter_message() let handler = Update::filter_message()
.filter_command::<Command>() .filter_command::<Command>()
.chain(dptree::filter(|msg: Message| msg.chat.is_private())) .chain(dptree::filter(|msg: Message| msg.chat.is_private()))
.chain(dptree::filter_map_async( .chain(dptree::filter_map_async(find_user_mw))
|msg: Message, storage: repo::SqliteRepo| async move { .branch(dptree::case![Command::RoomTemperature].endpoint(handle_temperature_sensor))
.branch(dptree::case![Command::HostTemperature].endpoint(handle_host_temperature))
.branch(dptree::case![Command::VersionRequest].endpoint(handle_version))
.branch(dptree::case![Command::Help].endpoint(handle_help));
let mut dependencies = DependencyMap::new();
dependencies.insert(sqlite_storage);
dependencies.insert(climate_client);
dependencies.insert(self_temp_client);
dependencies.insert(utils::Generators::new());
info!("running");
Dispatcher::builder(bot, handler)
.dependencies(dependencies)
.default_handler(|upd| async move {
warn!("unhandled update: {:?}", upd);
})
.build()
.setup_ctrlc_handler()
.dispatch()
.await;
Ok(())
}
async fn find_user_mw(msg: Message, storage: repo::SqliteRepo) -> Option<UserDB> {
let chat_id = msg.chat.id.0; let chat_id = msg.chat.id.0;
info!("checking if the user {chat_id} exists"); info!("checking if the user with chat_id={chat_id} exists");
let user = storage.load_user_by_chat_id(chat_id).await.unwrap(); let user = storage.load_user_by_chat_id(chat_id).await.unwrap();
match user { match user {
Some(user) => Some(user), Some(user) => Some(user),
@ -103,47 +126,17 @@ async fn run() -> anyhow::Result<()> {
Some(user_db) Some(user_db)
} }
} }
},
))
.branch(dptree::case![Command::RoomTemperature].endpoint(handle_temperature_sensor))
.branch(dptree::case![Command::HostTemperature].endpoint(handle_host_temperature))
.branch(dptree::case![Command::VersionRequest].endpoint(handle_version))
.branch(dptree::case![Command::ChatID].endpoint(handle_chat_id))
.branch(dptree::case![Command::Help].endpoint(handle_help));
let mut dependencies = DependencyMap::new();
dependencies.insert(sqlite_storage);
dependencies.insert(climate_client);
dependencies.insert(self_temp_client);
dependencies.insert(utils::Generators::new());
info!("running");
Dispatcher::builder(bot, handler)
.dependencies(dependencies)
.enable_ctrlc_handler()
.default_handler(|upd| async move {
warn!("unhandled update: {:?}", upd);
})
.build()
.dispatch()
.await;
Ok(())
} }
type HandlerResult<T> = std::result::Result<T, Box<dyn std::error::Error>>;
fn error_msg(reqid: &utils::RequestID) -> String { fn error_msg(reqid: &utils::RequestID) -> String {
format!("There was an error handling command, sorry. Reffer to {reqid}") format!("There was an error handling command, sorry. Reffer to {reqid}")
} }
async fn handle_temperature_sensor( async fn handle_temperature_sensor(
bot: Bot, bot: AutoSend<Bot>,
msg: Message, msg: Message,
climate: climate::Client, climate: climate::Client,
next_req_id: utils::Generators, next_req_id: utils::Generators,
storage: repo::SqliteRepo,
user: repo::UserDB, user: repo::UserDB,
) -> Result<()> { ) -> Result<()> {
let chat_id = msg.chat.id; let chat_id = msg.chat.id;
@ -173,7 +166,7 @@ async fn handle_temperature_sensor(
} }
async fn handle_host_temperature( async fn handle_host_temperature(
bot: Bot, bot: AutoSend<Bot>,
msg: Message, msg: Message,
temp: SelfTemperature, temp: SelfTemperature,
next_req_id: utils::Generators, next_req_id: utils::Generators,
@ -200,7 +193,7 @@ async fn handle_host_temperature(
Ok(()) Ok(())
} }
async fn handle_version(bot: Bot, msg: Message) -> Result<()> { async fn handle_version(bot: AutoSend<Bot>, msg: Message) -> Result<()> {
let chat_id = msg.chat.id; let chat_id = msg.chat.id;
let text = format!("Bot version is {} (branch: {})", VERSION, BRANCH,); let text = format!("Bot version is {} (branch: {})", VERSION, BRANCH,);
@ -209,16 +202,7 @@ async fn handle_version(bot: Bot, msg: Message) -> Result<()> {
Ok(()) Ok(())
} }
async fn handle_chat_id(bot: Bot, msg: Message) -> Result<()> { async fn handle_help(bot: AutoSend<Bot>, msg: Message) -> Result<()> {
let chat_id = msg.chat.id;
let text = format!("Current chat id: {chat_id}");
bot.send_message(chat_id, text).await?;
Ok(())
}
async fn handle_help(bot: Bot, msg: Message) -> Result<()> {
let chat_id = msg.chat.id; let chat_id = msg.chat.id;
bot.send_message(chat_id, Command::descriptions().to_string()) bot.send_message(chat_id, Command::descriptions().to_string())
@ -227,50 +211,15 @@ async fn handle_help(bot: Bot, msg: Message) -> Result<()> {
Ok(()) Ok(())
} }
struct Handler<S: repo::Storage + Clone + Send + Sync> {
storage: S,
climate: climate::Client,
self_temp: climate::SelfTemperature,
started: std::time::Instant,
}
fn log_error<'a, E: std::fmt::Display>(req_id: &'a str, msg: &'a str) -> impl FnOnce(E) -> E + 'a {
move |err: E| -> E {
warn!(
"request_id={}, {} err={}",
req_id.to_owned(),
msg.to_owned(),
err
);
err
}
}
fn log_message(req_id: &str, msg: Message) {
info!(
"message sent to chat_id={}, text={}",
msg.chat.id,
msg.text().unwrap_or_default(),
)
}
#[derive(serde::Deserialize, Debug)]
struct Climate {
humidity: f32,
temp: f32,
}
#[derive(BotCommands, Debug, Clone, PartialEq, Eq)] #[derive(BotCommands, Debug, Clone, PartialEq, Eq)]
#[command(description = "These commands are supported:")] #[command(rename = "lowercase", description = "These commands are supported:")]
enum Command { enum Command {
#[command(rename="help", description = "display this text.")] #[command(description = "display this text.")]
Help, Help,
#[command(rename="roomtemp", description = "temperature of your room.")] #[command(description = "temperature of your room.")]
RoomTemperature, RoomTemperature,
#[command(rename="hosttemp", description = "temperature of raspberry.")] #[command(description = "temperature of raspberry.")]
HostTemperature, HostTemperature,
#[command(rename="version", description = "prints current version.")] #[command(description = "prints current version.")]
VersionRequest, VersionRequest,
#[command(rename="chatid", description = "prints current chat id.")]
ChatID,
} }

View File

@ -1,44 +0,0 @@
use std::str::FromStr;
use anyhow::Result;
use log::{debug, info};
use envconfig::Envconfig;
use sqlx::{
SqlitePool,
sqlite::SqliteConnectOptions,
migrate,
};
#[derive(Envconfig)]
struct Settings {
#[envconfig(from = "ALTEREGO_DATABASE_URL", default = "./db.sqlite")]
pub db_source: String,
}
#[tokio::main]
async fn main() -> Result<()> {
env_logger::init();
info!("starting the application");
tokio::spawn(run()).await?
}
async fn run() -> Result<()> {
debug!("running migrations");
let settings = Settings::init_from_env().expect("reading config values");
info!("opening database {}", settings.db_source);
let opts = SqliteConnectOptions::from_str(&settings.db_source)?
.create_if_missing(true)
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
let pool = SqlitePool::connect_with(opts).await?;
migrate!("./db/").run(&pool).await?;
Ok(())
}

View File

@ -3,26 +3,25 @@ use std::str::FromStr;
use anyhow::Result; use anyhow::Result;
use log::debug; use log::debug;
use sqlx::{ use sqlx::{
sqlite::{ sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous},
SqliteConnectOptions, SqliteError, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous, FromRow, Pool, Sqlite,
},
Executor, Pool, Sqlite,
}; };
use teloxide::types::User;
#[derive(Debug)] #[derive(Debug)]
pub struct SqliteConfig { pub struct SqliteConfig {
pub source: String, pub source: String,
pub timeout: std::time::Duration, pub timeout: std::time::Duration,
pub max_conns: u32, pub max_conns: u32,
pub migrate: bool,
} }
impl Default for SqliteConfig { impl Default for SqliteConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
source: ":memory:".to_string(), source: "sqlite::memory:".to_string(),
timeout: std::time::Duration::from_secs(10), timeout: std::time::Duration::from_secs(10),
max_conns: 2, max_conns: 1,
migrate: true,
} }
} }
} }
@ -34,8 +33,8 @@ pub struct SqliteRepo {
impl SqliteRepo { impl SqliteRepo {
pub async fn from_config(config: SqliteConfig) -> Result<Self> { pub async fn from_config(config: SqliteConfig) -> Result<Self> {
let dsn = if !config.source.starts_with("sqlite://") { let dsn = if !config.source.starts_with("sqlite:") {
format!("sqlite://{}", &config.source) format!("sqlite:{}", &config.source)
} else { } else {
config.source config.source
}; };
@ -54,6 +53,8 @@ impl SqliteRepo {
.connect_with(opts) .connect_with(opts)
.await?; .await?;
sqlx::migrate!("./db").run(&pool).await?;
sqlx::query("pragma temp_store = memory;") sqlx::query("pragma temp_store = memory;")
.execute(&pool) .execute(&pool)
.await?; .await?;
@ -70,7 +71,7 @@ impl SqliteRepo {
} }
} }
#[derive(sqlx::FromRow, Debug, Clone)] #[derive(FromRow, Debug, Clone, PartialEq)]
pub struct UserDB { pub struct UserDB {
pub user_id: i64, pub user_id: i64,
pub chat_id: i64, pub chat_id: i64,
@ -78,7 +79,7 @@ pub struct UserDB {
pub created_at: chrono::NaiveDateTime, pub created_at: chrono::NaiveDateTime,
} }
#[derive(sqlx::FromRow)] #[derive(FromRow)]
pub struct ParameterDB { pub struct ParameterDB {
pub param_id: i64, pub param_id: i64,
pub user_id: i64, pub user_id: i64,
@ -101,79 +102,77 @@ pub trait Storage {
async fn load_user_by_chat_id(&self, chat_id: i64) -> Result<Option<UserDB>>; async fn load_user_by_chat_id(&self, chat_id: i64) -> Result<Option<UserDB>>;
async fn get_user_parameters(&self, user_id: i64) -> Result<MappedParameter>; async fn get_user_parameters(&self, user_id: i64) -> Result<MappedParameter>;
async fn upsert_parameter( async fn upsert_parameter(&self, user_id: i64, key: &str, value: &str) -> Result<ParameterDB>;
&self,
user_id: i64,
key: String,
value: String,
) -> Result<ParameterDB>;
async fn delete_parameter(&self, param_id: i64) -> Result<()>;
async fn insert_action(&self, user_id: i64, name: String) -> Result<()>; async fn insert_action(&self, user_id: i64, name: &str) -> Result<()>;
} }
#[async_trait::async_trait] type SQLResult<T> = sqlx::Result<T>;
impl Storage for SqliteRepo {
async fn create_user(&self, chat_id: i64, name: String) -> Result<UserDB> { async fn create_user(chat_id: i64, name: String, pool: &Pool<Sqlite>) -> SQLResult<UserDB> {
Ok(sqlx::query_as!( sqlx::query_as(
UserDB, "INSERT INTO `users` (`chat_id`, `name`, `created_at`)
"INSERT INTO users (chat_id, name, created_at)" VALUES (?, ?, datetime('now'))
+ " VALUES (?, ?, datetime('now'))" RETURNING `user_id`, `chat_id`, `name`, `created_at`;",
+ " RETURNING user_id, chat_id, name, created_at;",
chat_id,
name,
) )
.fetch_one(&self.pool) .bind(chat_id)
.await?) .bind(name)
.fetch_one(pool)
.await
} }
async fn load_user_by_chat_id(&self, chat_id: i64) -> Result<Option<UserDB>> { struct FindUserParams {
let result: std::result::Result<UserDB, sqlx::Error> = sqlx::query_as!( pub(crate) user_id: Option<i64>,
UserDB, pub(crate) chat_id: Option<i64>,
"SELECT user_id, chat_id, name, created_at" + " FROM users WHERE `chat_id` = ?;", }
chat_id,
)
.fetch_one(&self.pool)
.await;
match result { impl FindUserParams {
Ok(row) => Ok(Some(row)), pub fn new() -> Self {
Err(err) => match err { Self {
sqlx::Error::RowNotFound => Ok(None), user_id: None,
err => Err(anyhow::anyhow!(err)), chat_id: None,
},
} }
} }
async fn get_user(&self, user_id: i64) -> Result<Option<UserDB>> { pub fn with_user_id(mut self, user_id: i64) -> Self {
let result: std::result::Result<UserDB, sqlx::Error> = sqlx::query_as!( self.user_id = Some(user_id);
UserDB, self
"SELECT user_id, chat_id, name, created_at" + " FROM users WHERE `user_id` = ?;", }
user_id,
)
.fetch_one(&self.pool)
.await;
match result { pub fn with_chat_id(mut self, chat_id: i64) -> Self {
Ok(row) => Ok(Some(row)), self.chat_id = Some(chat_id);
Err(err) => match err { self
sqlx::Error::RowNotFound => Ok(None),
err => Err(anyhow::anyhow!(err)),
},
} }
} }
async fn get_user_parameters(&self, user_id: i64) -> Result<MappedParameter> { async fn find_user(params: FindUserParams, executor: &Pool<Sqlite>) -> sqlx::Result<UserDB> {
let mut qb = sqlx::QueryBuilder::new(
"SELECT `user_id`, `chat_id`, `name`, `created_at` FROM `users` WHERE 1=1",
);
if let Some(user_id) = params.user_id {
qb.push(" AND `user_id` = ");
qb.push_bind(user_id);
};
if let Some(chat_id) = params.chat_id {
qb.push(" AND `chat_id` = ");
qb.push_bind(chat_id);
}
let row = qb.build().fetch_one(executor).await?;
UserDB::from_row(&row)
}
async fn get_parameters_by_user(user_id: i64, pool: &Pool<Sqlite>) -> SQLResult<MappedParameter> {
let mut mp: MappedParameter = std::collections::HashMap::new(); let mut mp: MappedParameter = std::collections::HashMap::new();
sqlx::query_as!( sqlx::query_as(
ParameterDB, "SELECT `param_id`, `user_id`, `key`, `value` FROM parameters WHERE `user_id` = ?",
"SELECT `param_id`, `user_id`, `key`, `value`" + " FROM parameters WHERE `user_id` = ?",
user_id,
) )
.fetch_all(&self.pool) .bind(user_id)
.fetch_all(pool)
.await? .await?
.into_iter() .into_iter()
.for_each(|result| { .for_each(|result: ParameterDB| {
let param = ParameterBase { let param = ParameterBase {
param_id: result.param_id, param_id: result.param_id,
user_id: result.user_id, user_id: result.user_id,
@ -186,41 +185,202 @@ impl Storage for SqliteRepo {
Ok(mp) Ok(mp)
} }
async fn upsert_parameter( async fn upsert_parameter_for_user(
&self,
user_id: i64, user_id: i64,
key: String, key: &str,
value: String, value: &str,
) -> Result<ParameterDB> { pool: &Pool<Sqlite>,
Ok(sqlx::query_as!( ) -> SQLResult<ParameterDB> {
ParameterDB, sqlx::query_as(
"INSERT INTO parameters (`user_id`, `key`, `value`)" "INSERT INTO parameters (`user_id`, `key`, `value`)
+ " VALUES (?, ?, ?)" VALUES (?, ?, ?)
+ " RETURNING `param_id`, `user_id`, `key`, `value`;", RETURNING `param_id`, `user_id`, `key`, `value`;",
user_id,
key,
value,
) )
.fetch_one(&self.pool) .bind(user_id)
.await?) .bind(key)
.bind(value)
.fetch_one(pool)
.await
} }
async fn delete_parameter(&self, param_id: i64) -> Result<()> { async fn insert_user_action(user_id: i64, name: &str, pool: &Pool<Sqlite>) -> Result<()> {
sqlx::query("DELETE FROM parameters WHERE `param_id` = ?;")
.bind(param_id)
.execute(&self.pool)
.await?;
Ok(())
}
async fn insert_action(&self, user_id: i64, name: String) -> Result<()> {
sqlx::query("INSERT INTO actions (`user_id`, `name`) VALUES (?, ?)") sqlx::query("INSERT INTO actions (`user_id`, `name`) VALUES (?, ?)")
.bind(user_id) .bind(user_id)
.bind(name) .bind(name)
.execute(&self.pool) .execute(pool)
.await?; .await?;
Ok(()) Ok(())
} }
async fn subscriber_user<S>(
user_id: i64,
kind: &str,
args: Option<&S>,
executor: &Pool<Sqlite>,
) -> Result<()>
where
S: serde::Serialize,
{
let args_out = match args {
Some(args) => Some(serde_json::to_string(args)?),
None => None,
};
sqlx::query("INSERT INTO subscribers (`user_id`, `kind`, `args`) VALUES (?, ?, ?)")
.bind(user_id)
.bind(kind)
.bind(args_out)
.execute(executor)
.await?;
Ok(())
}
async fn unsubscribe_user(user_id: i64, kind: &str, executor: &Pool<Sqlite>) -> Result<()> {
sqlx::query("DELETE FROM subscribers WHERE `user_id` = ? AND `kind` = ?")
.bind(user_id)
.bind(kind)
.execute(executor)
.await?;
Ok(())
}
#[derive(FromRow, Debug)]
pub struct SubscriptionDB {
pub subscribe_id: i64,
pub user_id: i64,
pub kind: String,
pub args: String,
}
async fn find_subscribers_by_kind(
kind: &str,
executor: &Pool<Sqlite>,
) -> Result<Vec<SubscriptionDB>> {
Ok(sqlx::query_as( "SELECT `subscriber_id`, `user_id`, `kind`, `arguments` FROM `subcribers` WHERE `kind` = ?")
.bind(kind)
.fetch_all(executor)
.await?)
}
#[async_trait::async_trait]
impl Storage for SqliteRepo {
async fn create_user(&self, chat_id: i64, name: String) -> Result<UserDB> {
Ok(create_user(chat_id, name, &self.pool).await?)
}
async fn load_user_by_chat_id(&self, chat_id: i64) -> Result<Option<UserDB>> {
let params = FindUserParams::new().with_chat_id(chat_id);
match find_user(params, &self.pool).await {
Ok(row) => Ok(Some(row)),
Err(err) => match err {
sqlx::Error::RowNotFound => Ok(None),
err => Err(anyhow::anyhow!(err)),
},
}
}
async fn get_user(&self, user_id: i64) -> Result<Option<UserDB>> {
let params = FindUserParams::new().with_user_id(user_id);
match find_user(params, &self.pool).await {
Ok(row) => Ok(Some(row)),
Err(err) => match err {
sqlx::Error::RowNotFound => Ok(None),
err => Err(anyhow::anyhow!(err)),
},
}
}
async fn get_user_parameters(&self, user_id: i64) -> Result<MappedParameter> {
Ok(get_parameters_by_user(user_id, &self.pool).await?)
}
async fn upsert_parameter(&self, user_id: i64, key: &str, value: &str) -> Result<ParameterDB> {
Ok(upsert_parameter_for_user(user_id, key, value, &self.pool).await?)
}
async fn insert_action(&self, user_id: i64, name: &str) -> Result<()> {
insert_user_action(user_id, name, &self.pool).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Once;
static ONCE: Once = Once::new();
async fn prepare() -> Result<Pool<Sqlite>> {
const DSN: &str = "sqlite::memory:";
ONCE.call_once(|| {
std::env::set_var("RUST_LOG", "debug");
env_logger::init();
});
let pool = SqlitePoolOptions::new().connect(DSN).await?;
sqlx::migrate!("./db").run(&pool).await?;
sqlx::query(
"INSERT INTO `users` (user_id, chat_id, name, created_at)
VALUES (1, 100, 'Alex', datetime('now'));",
)
.execute(&pool)
.await?;
Ok(pool)
}
#[tokio::test]
pub async fn test_get_user_by_chat_id() {
let executor = prepare().await.expect("should prepare store");
let params = FindUserParams::new().with_chat_id(100);
let user = find_user(params, &executor)
.await
.expect("should found user");
let exp_user = UserDB {
user_id: 1,
chat_id: 100,
name: "Alex".to_string(),
created_at: user.created_at,
};
assert_eq!(exp_user, user);
}
#[tokio::test]
pub async fn test_get_user_by_user_id() {
let executor = prepare().await.expect("should prepare store");
let params = FindUserParams::new().with_user_id(1);
let user = find_user(params, &executor)
.await
.expect("should found user");
let exp_user = UserDB {
user_id: 1,
chat_id: 100,
name: "Alex".to_string(),
created_at: user.created_at,
};
assert_eq!(exp_user, user);
}
#[tokio::test]
pub async fn test_create_user() {
let pool = prepare().await.expect("should prepare store");
let user = create_user(101, "Phew".to_owned(), &pool)
.await
.expect("should create user");
let exp_user = UserDB {
user_id: user.user_id,
chat_id: 101,
name: "Phew".to_string(),
created_at: user.created_at,
};
assert_eq!(exp_user, user);
}
} }