From 52ba998438a92ffb547675e4eace54837439563a Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Mon, 11 Jul 2022 13:47:17 +0800 Subject: [PATCH] MySQL and PostgreSQL support --- registry/Cargo.lock | 29 -- registry/common-utils/src/lib.rs | 18 +- registry/feathr-registry/Cargo.toml | 2 +- .../src/models/entity_prop.rs | 11 + registry/sql-provider/Cargo.toml | 4 +- registry/sql-provider/src/database/mod.rs | 35 +- registry/sql-provider/src/database/mssql.rs | 37 +- registry/sql-provider/src/database/sqlx.rs | 350 +++++++++++++++++- registry/sql-provider/src/lib.rs | 2 +- 9 files changed, 398 insertions(+), 90 deletions(-) diff --git a/registry/Cargo.lock b/registry/Cargo.lock index e508561..51dfab8 100644 --- a/registry/Cargo.lock +++ b/registry/Cargo.lock @@ -800,18 +800,6 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" -[[package]] -name = "flume" -version = "0.10.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ceeb589a3157cac0ab8cc585feb749bd2cea5cb55a6ee802ad72d9fd38303da" -dependencies = [ - "futures-core", - "futures-sink", - "pin-project", - "spin 0.9.3", -] - [[package]] name = "fnv" version = "1.0.7" @@ -1309,17 +1297,6 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33a33a362ce288760ec6a508b94caaec573ae7d3bbbd91b87aa0bad4456839db" -[[package]] -name = "libsqlite3-sys" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "898745e570c7d0453cc1fbc4a701eb6c662ed54e8fec8b7d14be137ebeeb9d14" -dependencies = [ - "cc", - "pkg-config", - "vcpkg", -] - [[package]] name = "linked-hash-map" version = "0.5.6" @@ -2645,9 +2622,6 @@ name = "spin" version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c530c2b0d0bf8b69304b39fe2001993e267461948b890cd037d8ad4293fa1a0d" -dependencies = [ - "lock_api", -] [[package]] name = "spki" @@ -2722,10 +2696,8 @@ dependencies = [ "dirs", "either", "event-listener", - "flume", "futures-channel", "futures-core", - "futures-executor", "futures-intrusive", "futures-util", "generic-array", @@ -2736,7 +2708,6 @@ dependencies = [ "indexmap", "itoa", "libc", - "libsqlite3-sys", "log", "md-5", "memchr", diff --git a/registry/common-utils/src/lib.rs b/registry/common-utils/src/lib.rs index cd414eb..76c0ddc 100644 --- a/registry/common-utils/src/lib.rs +++ b/registry/common-utils/src/lib.rs @@ -65,8 +65,7 @@ impl Appliable for T where T: Sized {} /** * Flip `Option>` to `Result, E>` so we can use `?` on the result */ -pub trait FlippedOptionResult -{ +pub trait FlippedOptionResult { fn flip(self) -> Result, E>; } @@ -78,9 +77,9 @@ impl FlippedOptionResult for Option> { pub fn is_default(t: &T) -> bool where - T: Default + Eq + T: Default + Eq, { - t==&T::default() + t == &T::default() } pub trait Blank { @@ -112,8 +111,15 @@ pub fn init_logger() { "registry_api", "registry_app", ]; - let module_logs = modules.into_iter().map(|m| format!("{}=debug", m)).collect::>().join(","); - let rust_log = format!("info,tantivy=warn,tiberius=warn,openraft=warn,{}", module_logs); + let module_logs = modules + .into_iter() + .map(|m| format!("{}=trace", m)) + .collect::>() + .join(","); + let rust_log = format!( + "info,tantivy=warn,tiberius=warn,openraft=warn,sqlx=warn,{}", + module_logs + ); if std::env::var_os("RUST_LOG").is_none() { std::env::set_var("RUST_LOG", &rust_log); } diff --git a/registry/feathr-registry/Cargo.toml b/registry/feathr-registry/Cargo.toml index 6414f1a..aa2aaf9 100644 --- a/registry/feathr-registry/Cargo.toml +++ b/registry/feathr-registry/Cargo.toml @@ -25,7 +25,7 @@ openraft = { git = "https://github.com/windoze/openraft.git", features = ["serde common-utils = { path = "../common-utils" } registry-provider = { path = "../registry-provider" } -sql-provider = { path = "../sql-provider" } +sql-provider = { path = "../sql-provider", features = ["default"] } registry-api = { path = "../registry-api" } raft-registry = { path = "../raft-registry" } diff --git a/registry/registry-provider/src/models/entity_prop.rs b/registry/registry-provider/src/models/entity_prop.rs index a9bad9e..ade7a74 100644 --- a/registry/registry-provider/src/models/entity_prop.rs +++ b/registry/registry-provider/src/models/entity_prop.rs @@ -17,6 +17,14 @@ pub enum EntityStatus { Deprecated, } +fn default_version() -> u64 { + 1 +} + +fn default_created_on() -> DateTime { + Utc::now() +} + #[derive(Clone, Debug, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct EntityProperty { @@ -28,8 +36,11 @@ pub struct EntityProperty { pub labels: Vec, #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub tags: HashMap, + #[serde(default = "default_version")] pub version: u64, + #[serde(default)] pub created_by: String, + #[serde(default = "default_created_on")] pub created_on: DateTime, #[serde(flatten)] pub attributes: Attributes, diff --git a/registry/sql-provider/Cargo.toml b/registry/sql-provider/Cargo.toml index 65f668b..7e1a295 100644 --- a/registry/sql-provider/Cargo.toml +++ b/registry/sql-provider/Cargo.toml @@ -32,6 +32,7 @@ sqlx = { version = "0.6.0", features = [ "runtime-tokio-rustls", "any", "uuid", + "macros", ], default-features = false, optional = true } common-utils = { path = "../common-utils" } @@ -43,7 +44,6 @@ rand = "0.8" [features] default = ["mssql", "ossdbms"] mssql = ["tiberius", "bb8", "bb8-tiberius"] -ossdbms = ["mysql", "postgres", "sqlite"] +ossdbms = ["mysql", "postgres"] mysql = ["sqlx/mysql"] postgres = ["sqlx/postgres"] -sqlite = ["sqlx/sqlite"] diff --git a/registry/sql-provider/src/database/mod.rs b/registry/sql-provider/src/database/mod.rs index e01dbeb..0db7fa0 100644 --- a/registry/sql-provider/src/database/mod.rs +++ b/registry/sql-provider/src/database/mod.rs @@ -5,34 +5,31 @@ use crate::Registry; #[cfg(feature = "mssql")] mod mssql; -#[cfg(feature = "ossdmbs")] +#[cfg(feature = "ossdbms")] mod sqlx; -pub async fn load_registry() -> Result, anyhow::Error> { - #[cfg(feature = "ossdmbs")] - if sqlx::validate_condition() { - return sqlx::load_registry().await; - } +pub fn attach_storage(registry: &mut Registry) { #[cfg(feature = "mssql")] if mssql::validate_condition() { - return mssql::load_registry().await; + mssql::attach_storage(registry); } - anyhow::bail!("Unable to load registry") -} -pub fn attach_storage(registry: &mut Registry) { - #[cfg(feature = "ossdmbs")] - todo!(); - - #[cfg(feature = "mssql")] - mssql::attach_storage(registry); + #[cfg(feature = "ossdbms")] + if sqlx::validate_condition() { + sqlx::attach_storage(registry); + } } pub async fn load_content( ) -> Result<(Vec>, Vec), anyhow::Error> { - #[cfg(feature = "ossdmbs")] - todo!(); - #[cfg(feature = "mssql")] - mssql::load_content().await + if mssql::validate_condition() { + return mssql::load_content().await; + } + + #[cfg(feature = "ossdbms")] + if sqlx::validate_condition() { + return sqlx::load_content().await; + } + anyhow::bail!("Unable to load registry") } diff --git a/registry/sql-provider/src/database/mssql.rs b/registry/sql-provider/src/database/mssql.rs index 73613a1..699c777 100644 --- a/registry/sql-provider/src/database/mssql.rs +++ b/registry/sql-provider/src/database/mssql.rs @@ -53,7 +53,7 @@ async fn load_entities( conn: &mut PooledConnection<'static, ConnectionManager>, ) -> Result, anyhow::Error> { let entities_table = - std::env::var("MSSQL_ENTITY_TABLE").unwrap_or_else(|_| "entities".to_string()); + std::env::var("ENTITY_TABLE").unwrap_or_else(|_| "entities".to_string()); debug!("Loading entities from {}", entities_table); let x: Vec = conn .simple_query(format!("SELECT entity_content from {}", entities_table)) @@ -70,7 +70,7 @@ async fn load_entities( async fn load_edges( conn: &mut PooledConnection<'static, ConnectionManager>, ) -> Result, anyhow::Error> { - let edges_table = std::env::var("MSSQL_EDGE_TABLE").unwrap_or_else(|_| "edges".to_string()); + let edges_table = std::env::var("EDGE_TABLE").unwrap_or_else(|_| "edges".to_string()); debug!("Loading edges from {}", edges_table); let x: Vec = conn .simple_query(format!( @@ -112,30 +112,11 @@ async fn connect() -> Result, anyho } pub fn validate_condition() -> bool { - // TODO: - true -} - -pub async fn load_registry() -> Result, anyhow::Error> { - debug!("Loading registry data from database"); - let mut conn = connect().await?; - let edges = load_edges(&mut conn).await?; - let entities = load_entities(&mut conn).await?; - debug!( - "{} entities and {} edges loaded", - entities.len(), - edges.len() - ); - let mut registry = Registry::load( - entities.into_iter().map(|e| e.into()), - edges.into_iter().map(|e| e.into()), - ) - .await?; - registry - .external_storage - .push(Arc::new(RwLock::new(MsSqlStorage::default()))); - - Ok(registry) + if let Ok(conn_str) = std::env::var("CONNECTION_STR") { + tiberius::Config::from_ado_string(&conn_str).is_ok() + } else { + false + } } pub async fn load_content() -> Result<(Vec>, Vec), anyhow::Error> { @@ -178,8 +159,8 @@ impl MsSqlStorage { impl Default for MsSqlStorage { fn default() -> Self { Self::new( - &std::env::var("MSSQL_ENTITY_TABLE").unwrap_or_else(|_| "entities".to_string()), - &std::env::var("MSSQL_EDGE_TABLE").unwrap_or_else(|_| "edges".to_string()), + &std::env::var("ENTITY_TABLE").unwrap_or_else(|_| "entities".to_string()), + &std::env::var("EDGE_TABLE").unwrap_or_else(|_| "edges".to_string()), ) } } diff --git a/registry/sql-provider/src/database/sqlx.rs b/registry/sql-provider/src/database/sqlx.rs index a1fbc64..d639905 100644 --- a/registry/sql-provider/src/database/sqlx.rs +++ b/registry/sql-provider/src/database/sqlx.rs @@ -1,8 +1,350 @@ -pub async fn load_registry() -> Result, anyhow::Error> { - todo!() +use std::sync::Arc; + +use async_trait::async_trait; +use log::debug; +use sqlx::{pool::PoolConnection, Any, AnyConnection, AnyPool, Connection, Executor}; + +use crate::{db_registry::ExternalStorage, Registry}; +use common_utils::Logged; +use registry_provider::{Edge, EdgeType, Entity, EntityProperty, RegistryError}; +use tokio::sync::{OnceCell, RwLock}; +use uuid::Uuid; + +#[derive(sqlx::FromRow)] +struct EntityPropertyWrapper { + entity_content: String, +} + +async fn load_entities() -> Result, anyhow::Error> { + let entities_table = std::env::var("ENTITY_TABLE").unwrap_or_else(|_| "entities".to_string()); + debug!("Loading entities from {}", entities_table); + let pool = POOL + .get_or_init(|| async { init_pool().await.ok() }) + .await + .clone() + .ok_or_else(|| anyhow::Error::msg("Environment variable 'CONNECTION_STR' is not set."))?; + debug!("SQLx connection pool acquired, connecting to database"); + let sql = format!("SELECT entity_content from {}", entities_table); + let rows = sqlx::query_as::<_, EntityPropertyWrapper>(&sql) + .fetch_all(&pool) + .await?; + debug!("{} rows loaded", rows.len()); + let x = rows + .into_iter() + .map(|r| { + debug!("{}", r.entity_content); + serde_json::from_str::( + &r.entity_content.replace('\n', "").replace('\r', ""), + ) + .map_err(|e| { + anyhow::Error::msg(format!( + "Failed to parse entity content: '{}', error is {}", + &r.entity_content, + e.to_string() + )) + }) + .log() + }) + .collect::, anyhow::Error>>()?; + debug!("{} entities loaded", x.len()); + Ok(x) +} + +#[derive(sqlx::FromRow)] +struct EdgeWrapper { + from_id: String, + to_id: String, + edge_type: String, +} + +async fn load_edges() -> Result, anyhow::Error> { + let edges_table = std::env::var("EDGE_TABLE").unwrap_or_else(|_| "edges".to_string()); + debug!("Loading edges from {}", edges_table); + let pool = POOL + .get_or_init(|| async { init_pool().await.ok() }) + .await + .clone() + .ok_or_else(|| anyhow::Error::msg("Environment variable 'CONNECTION_STR' is not set."))?; + debug!("SQLx connection pool acquired, connecting to database"); + let sql = format!("SELECT from_id, to_id, edge_type from {}", edges_table); + let rows: Vec = sqlx::query_as::<_, EdgeWrapper>(&sql) + .fetch_all(&pool) + .await?; + debug!("{} rows loaded", rows.len()); + let x = rows + .into_iter() + .map(|r| -> Result { + let edge_type = match serde_json::from_str::(&format!("\"{}\"", &r.edge_type)) + { + Ok(v) => v, + Err(e) => { + return Err(anyhow::Error::msg(format!( + "Failed to parse edge type: {}, error {}", + r.edge_type, + e.to_string() + ))); + } + }; + let from = match Uuid::parse_str(&r.from_id) { + Ok(v) => v, + Err(e) => { + return Err(anyhow::Error::msg(format!( + "Failed to parse from id: {}, error {}", + r.from_id, + e.to_string() + ))); + } + }; + let to = match Uuid::parse_str(&r.to_id) { + Ok(v) => v, + Err(e) => { + return Err(anyhow::Error::msg(format!( + "Failed to parse to id: {}, error {}", + r.to_id, + e.to_string() + ))); + } + }; + + Ok(Edge { + edge_type, + from, + to, + }) + }) + .collect::, anyhow::Error>>()?; + debug!("{} edges loaded", x.len()); + Ok(x) +} + +pub async fn load_content() -> Result<(Vec>, Vec), anyhow::Error> { + debug!("Loading registry data from database"); + let edges = load_edges().await?; + let entities = load_entities().await?; + debug!( + "{} entities and {} edges loaded", + entities.len(), + edges.len() + ); + Ok(( + entities.into_iter().map(|e| e.into()).collect(), + edges.into_iter().map(|e| e.into()).collect(), + )) } pub fn validate_condition() -> bool { - // TODO: - false + if let Ok(conn_str) = std::env::var("CONNECTION_STR") { + conn_str + .parse::<::Options>() + .is_ok() + } else { + false + } +} + +pub fn attach_storage(registry: &mut Registry) { + registry + .external_storage + .push(Arc::new(RwLock::new(SqlxStorage::default()))); +} + +static POOL: OnceCell> = OnceCell::const_new(); + +async fn init_pool() -> anyhow::Result { + debug!("Initializing SQLx connection pool"); + let conn_str = std::env::var("CONNECTION_STR")?; + let pool = AnyPool::connect(conn_str.as_str()).await?; + debug!("SQLx connection pool initialized"); + Ok(pool) +} + +async fn connect() -> Result, anyhow::Error> { + debug!("Acquiring SQLx connection pool"); + let pool = POOL + .get_or_init(|| async { init_pool().await.ok() }) + .await + .clone() + .ok_or_else(|| anyhow::Error::msg("Environment variable 'CONNECTION_STR' is not set."))?; + debug!("SQLx connection pool acquired, connecting to database"); + let conn = pool.acquire().await?; + debug!("Database connected"); + Ok(conn) +} + +#[derive(Debug)] +struct SqlxStorage { + entity_table: String, + edge_table: String, +} + +impl SqlxStorage { + pub fn new(entity_table: &str, edge_table: &str) -> Self { + Self { + entity_table: entity_table.to_string(), + edge_table: edge_table.to_string(), + } + } +} + +impl Default for SqlxStorage { + fn default() -> Self { + Self::new( + &std::env::var("ENTITY_TABLE").unwrap_or_else(|_| "entities".to_string()), + &std::env::var("EDGE_TABLE").unwrap_or_else(|_| "edges".to_string()), + ) + } +} + +#[async_trait] +impl ExternalStorage for SqlxStorage { + /** + * Function will be called when a new entity is added in the graph + * ExternalStorage may need to create the entity record in database, etc + */ + async fn add_entity( + &mut self, + id: Uuid, + entity: &Entity, + ) -> Result<(), RegistryError> { + let mut conn = connect() + .await + .map_err(|e| RegistryError::ExternalStorageError(format!("{:?}", e)))?; + match conn.kind() { + sqlx::any::AnyKind::Postgres => { + let sql = &format!( + r#"INSERT INTO {} + (entity_id, entity_content) + values + ($1, $2) + ON CONFLICT DO NOTHING;"#, + self.entity_table, + ); + let query = sqlx::query(&sql) + .bind(id.to_string()) + .bind(serde_json::to_string_pretty(&entity.properties).unwrap()); + conn.execute(query) + .await + .map_err(|e| RegistryError::ExternalStorageError(format!("{:?}", e)))?; + } + sqlx::any::AnyKind::MySql => { + let sql = format!( + r#"INSERT IGNORE INTO {} + (entity_id, entity_content) + values + (?, ?)"#, + self.entity_table, + ); + let query = sqlx::query(&sql) + .bind(id.to_string()) + .bind(serde_json::to_string_pretty(&entity.properties).unwrap()); + conn.execute(query) + .await + .map_err(|e| RegistryError::ExternalStorageError(format!("{:?}", e)))?; + } + }; + Ok(()) + } + + /** + * Function will be called when an entity is deleted in the graph + * ExternalStorage may need to remove the entity record from database, etc + */ + async fn delete_entity( + &mut self, + id: Uuid, + _entity: &Entity, + ) -> Result<(), RegistryError> { + let sql = format!(r#"DELETE {} WHERE entity_id = ?;"#, self.entity_table,); + let query = sqlx::query(&sql).bind(id.to_string()); + let mut conn = connect() + .await + .map_err(|e| RegistryError::ExternalStorageError(format!("{:?}", e)))?; + conn.execute(query) + .await + .map_err(|e| RegistryError::ExternalStorageError(format!("{:?}", e)))?; + Ok(()) + } + + /** + * Function will be called when 2 entities are connected. + * EntityProp has already been updated accordingly. + * ExternalStorage may need to create the edge record in database, etc + */ + async fn connect( + &mut self, + from_id: Uuid, + to_id: Uuid, + edge_type: EdgeType, + ) -> Result<(), RegistryError> { + let mut conn = connect() + .await + .map_err(|e| RegistryError::ExternalStorageError(format!("{:?}", e)))?; + match conn.kind() { + sqlx::any::AnyKind::Postgres => { + let sql = &format!( + r#"INSERT INTO {} + (from_id, to_id, edge_type) + values + ($1, $2, $3) + ON CONFLICT DO NOTHING;"#, + self.edge_table, + ); + let query = sqlx::query(&sql) + .bind(from_id.to_string()) + .bind(to_id.to_string()) + .bind(format!("{:?}", edge_type)); + conn.execute(query) + .await + .map_err(|e| RegistryError::ExternalStorageError(format!("{:?}", e)))?; + } + sqlx::any::AnyKind::MySql => { + let sql = format!( + r#"INSERT IGNORE INTO {} + (from_id, to_id, edge_type) + values + (?, ?, ?)"#, + self.edge_table, + ); + let query = sqlx::query(&sql) + .bind(from_id.to_string()) + .bind(to_id.to_string()) + .bind(format!("{:?}", edge_type)); + conn.execute(query) + .await + .map_err(|e| RegistryError::ExternalStorageError(format!("{:?}", e)))?; + } + }; + Ok(()) + } + + /** + * Function will be called when 2 entities are disconnected. + * EntityProp has already been updated accordingly. + * ExternalStorage may need to remove the edge record from database, etc + */ + async fn disconnect( + &mut self, + _from: &Entity, + from_id: Uuid, + _to: &Entity, + to_id: Uuid, + edge_type: EdgeType, + _edge_id: Uuid, + ) -> Result<(), RegistryError> { + let sql = format!( + r#"DELETE {} WHERE from_id=? and to_id=? and edge_type=?;"#, + self.edge_table, + ); + let query = sqlx::query(&sql) + .bind(from_id.to_string()) + .bind(to_id.to_string()) + .bind(format!("{:?}", edge_type)); + let mut conn = connect() + .await + .map_err(|e| RegistryError::ExternalStorageError(format!("{:?}", e)))?; + conn.execute(query) + .await + .map_err(|e| RegistryError::ExternalStorageError(format!("{:?}", e)))?; + Ok(()) + } } diff --git a/registry/sql-provider/src/lib.rs b/registry/sql-provider/src/lib.rs index 252635d..b3c55a7 100644 --- a/registry/sql-provider/src/lib.rs +++ b/registry/sql-provider/src/lib.rs @@ -10,7 +10,7 @@ use std::collections::HashSet; use std::fmt::Debug; use async_trait::async_trait; -pub use database::{attach_storage, load_content, load_registry}; +pub use database::{attach_storage, load_content}; pub use db_registry::Registry; use log::debug; use registry_provider::{