diff --git a/.eslintignore b/.eslintignore
index bb47cf2..ceb9de4 100644
--- a/.eslintignore
+++ b/.eslintignore
@@ -13,5 +13,4 @@ node_modules/
 pkg/
 target/
 .eslintrc.js
-babel.config.js
-packages/webapp/src/workers/sds-wasm/worker.js
\ No newline at end of file
+babel.config.js
\ No newline at end of file
diff --git a/.pnp.cjs b/.pnp.cjs
index 626e66c..e686ead 100755
--- a/.pnp.cjs
+++ b/.pnp.cjs
@@ -13697,10 +13697,10 @@ function $$SETUP_STATE(hydrateRuntimeState, basePath) {
       }]
     ]],
     ["sds-wasm", [
-      ["file:../../target/wasm#../../target/wasm::hash=932617&locator=webapp%40workspace%3Apackages%2Fwebapp", {
-        "packageLocation": "./.yarn/cache/sds-wasm-file-aa35777e98-1550beb7ae.zip/node_modules/sds-wasm/",
+      ["file:../../target/wasm#../../target/wasm::hash=3f33a3&locator=webapp%40workspace%3Apackages%2Fwebapp", {
+        "packageLocation": "./.yarn/cache/sds-wasm-file-6dcf65ffd1-dd8b934db1.zip/node_modules/sds-wasm/",
         "packageDependencies": [
-          ["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=932617&locator=webapp%40workspace%3Apackages%2Fwebapp"]
+          ["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=3f33a3&locator=webapp%40workspace%3Apackages%2Fwebapp"]
         ],
         "linkType": "HARD",
       }]
@@ -15482,7 +15482,7 @@ function $$SETUP_STATE(hydrateRuntimeState, basePath) {
         ["react-is", "npm:17.0.2"],
         ["react-router-dom", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:6.0.2"],
         ["recoil", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:0.5.2"],
-        ["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=932617&locator=webapp%40workspace%3Apackages%2Fwebapp"],
+        ["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=3f33a3&locator=webapp%40workspace%3Apackages%2Fwebapp"],
         ["styled-components", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:5.3.3"],
         ["ts-node", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:10.4.0"],
         ["typescript", "patch:typescript@npm%3A4.4.3#~builtin::version=4.4.3&hash=32657b"],
diff --git a/.vsts-ci.yml b/.vsts-ci.yml
deleted file mode 100644
index f853a1d..0000000
--- a/.vsts-ci.yml
+++ /dev/null
@@ -1,33 +0,0 @@
-name: SDS CI
-pool:
-  vmImage: ubuntu-latest
-
-trigger:
-  batch: true
-  branches:
-    include:
-      - main
-
-stages:
-  - stage: Compliance
-    dependsOn: []
-    jobs:
-      - job: ComplianceJob
-        pool:
-          vmImage: windows-latest
-        steps:
-          - task: CredScan@3
-            inputs:
-              outputFormat: sarif
-              debugMode: false
-
-          - task: ComponentGovernanceComponentDetection@0
-            inputs:
-              scanType: 'Register'
-              verbosity: 'Verbose'
-              alertWarningLevel: 'High'
-
-          - task: PublishSecurityAnalysisLogs@3
-            inputs:
-              ArtifactName: 'CodeAnalysisLogs'
-              ArtifactType: 'Container'
\ No newline at end of file
diff --git a/.yarn/cache/sds-wasm-file-6dcf65ffd1-dd8b934db1.zip b/.yarn/cache/sds-wasm-file-6dcf65ffd1-dd8b934db1.zip
new file mode 100644
index 0000000..2aefde3
Binary files /dev/null and b/.yarn/cache/sds-wasm-file-6dcf65ffd1-dd8b934db1.zip differ
diff --git a/.yarn/cache/sds-wasm-file-aa35777e98-1550beb7ae.zip b/.yarn/cache/sds-wasm-file-aa35777e98-1550beb7ae.zip
deleted file mode 100644
index 7a48ae7..0000000
Binary files a/.yarn/cache/sds-wasm-file-aa35777e98-1550beb7ae.zip and /dev/null differ
diff --git a/.yarn/sdks/eslint/bin/eslint.js b/.yarn/sdks/eslint/bin/eslint.js
old mode 100755
new mode 100644
diff --git a/.yarn/sdks/prettier/index.js b/.yarn/sdks/prettier/index.js
old mode 100755
new mode 100644
diff --git a/.yarn/sdks/typescript/bin/tsc b/.yarn/sdks/typescript/bin/tsc
old mode 100755
new mode 100644
diff --git a/.yarn/sdks/typescript/bin/tsserver b/.yarn/sdks/typescript/bin/tsserver
old mode 100755
new mode 100644
diff --git a/Cargo.lock b/Cargo.lock
index 47d6be0..941093e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -42,6 +42,12 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "autocfg"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
+
 [[package]]
 name = "base-x"
 version = "0.2.8"
@@ -103,6 +109,50 @@ dependencies = [
  "wasm-bindgen",
 ]
 
+[[package]]
+name = "crossbeam-channel"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4"
+dependencies = [
+ "cfg-if",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e"
+dependencies = [
+ "cfg-if",
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd"
+dependencies = [
+ "cfg-if",
+ "crossbeam-utils",
+ "lazy_static",
+ "memoffset",
+ "scopeguard",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db"
+dependencies = [
+ "cfg-if",
+ "lazy_static",
+]
+
 [[package]]
 name = "csv"
 version = "1.1.6"
@@ -307,6 +357,25 @@ version = "2.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a"
 
+[[package]]
+name = "memoffset"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "num_cpus"
+version = "1.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
+dependencies = [
+ "hermit-abi",
+ "libc",
+]
+
 [[package]]
 name = "once_cell"
 version = "1.8.0"
@@ -499,6 +568,31 @@ dependencies = [
  "rand_core",
 ]
 
+[[package]]
+name = "rayon"
+version = "1.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90"
+dependencies = [
+ "autocfg",
+ "crossbeam-deque",
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e"
+dependencies = [
+ "crossbeam-channel",
+ "crossbeam-deque",
+ "crossbeam-utils",
+ "lazy_static",
+ "num_cpus",
+]
+
 [[package]]
 name = "redox_syscall"
 version = "0.2.10"
@@ -582,6 +676,9 @@ dependencies = [
  "lru",
  "pyo3",
  "rand",
+ "rayon",
+ "serde",
+ "serde_json",
 ]
 
 [[package]]
@@ -589,6 +686,7 @@ name = "sds-pyo3"
 version = "1.0.0"
 dependencies = [
  "csv",
+ "env_logger",
  "log",
  "pyo3",
  "sds-core",
diff --git a/docker-compose.yml b/docker-compose.yml
index 4522b94..a9d056f 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,8 +1,8 @@
-version: "3"
+version: '3'
 services:
-  webapp:
-    build:
-      context: .
-      dockerfile: webapp.dockerfile
-    ports:
-      - 3000:80
\ No newline at end of file
+  webapp:
+    build:
+      context: .
+      dockerfile: webapp.dockerfile
+    ports:
+      - 3000:80
diff --git a/package.json b/package.json
index 3d471be..78b1ae8 100644
--- a/package.json
+++ b/package.json
@@ -7,7 +7,7 @@
 		"build:": "yarn workspaces foreach -ivt run build",
 		"start:": "yarn workspaces foreach -piv run start",
 		"lint:": "essex lint --fix --strict",
-		"build:lib-wasm": "cd packages/lib-wasm && wasm-pack build --release --target no-modules --out-dir ../../target/wasm",
+		"build:lib-wasm": "cd packages/lib-wasm && wasm-pack build --release --target web --out-dir ../../target/wasm",
 		"prettify": "essex prettify",
 		"rebuild-all": "cargo clean && run-s clean: && cargo build --release && run-s build:lib-wasm && yarn install && run-s build:"
 	},
diff --git a/packages/cli/Cargo.toml b/packages/cli/Cargo.toml
index 3a40aab..dcf10c3 100644
--- a/packages/cli/Cargo.toml
+++ b/packages/cli/Cargo.toml
@@ -7,7 +7,7 @@ repository = "https://github.com/microsoft/synthetic-data-showcase"
 edition = "2018"
 
 [dependencies]
-sds-core = { path = "../core" }
+sds-core = { path = "../core", features = ["rayon"] }
 log = { version = "0.4" }
 env_logger = { version = "0.9" }
 structopt = { version = "0.3" }
diff --git a/packages/cli/src/main.rs b/packages/cli/src/main.rs
index 5c7ae83..6148b8e 100644
--- a/packages/cli/src/main.rs
+++ b/packages/cli/src/main.rs
@@ -1,8 +1,11 @@
 use log::{error, info, log_enabled, trace, Level::Debug};
 use sds_core::{
-    data_block::{block::DataBlockCreator, csv_block_creator::CsvDataBlockCreator},
-    processing::{aggregator::Aggregator, generator::Generator},
-    utils::reporting::LoggerProgressReporter,
+    data_block::{csv_block_creator::CsvDataBlockCreator, data_block_creator::DataBlockCreator},
+    processing::{
+        aggregator::Aggregator,
+        generator::{Generator, SynthesisMode},
+    },
+    utils::{reporting::LoggerProgressReporter, threading::set_number_of_threads},
 };
 use std::process;
 use structopt::StructOpt;
@@ -26,6 +29,14 @@ enum Command {
             default_value = "100000"
         )]
         cache_max_size: usize,
+        #[structopt(
+            long = "mode",
+            help = "synthesis mode",
+            possible_values = &["seeded", "unseeded"],
+            case_insensitive = true,
+            default_value = "seeded"
+        )]
+        mode: SynthesisMode,
     },
     Aggregate {
         #[structopt(long = "aggregates-path", help = "generated aggregates file path")]
@@ -104,6 +115,12 @@ struct Cli {
         help = "columns where zeros should not be ignored (can be set multiple times)"
     )]
     sensitive_zeros: Vec<String>,
+
+    #[structopt(
+        long = "n-threads",
+        help = "number of threads used to process the data in parallel (default is the number of cores)"
+    )]
+    n_threads: Option<usize>,
 }
 
 fn main() {
@@ -118,6 +135,10 @@ fn main() {
 
     trace!("execution parameters: {:#?}", cli);
 
+    if let Some(n_threads) = cli.n_threads {
+        set_number_of_threads(n_threads);
+    }
+
     match CsvDataBlockCreator::create(
         csv::ReaderBuilder::new()
             .delimiter(cli.sensitive_delimiter.chars().next().unwrap() as u8)
@@ -131,15 +152,20 @@ fn main() {
             synthetic_path,
             synthetic_delimiter,
             cache_max_size,
+            mode,
         } => {
-            let mut generator = Generator::new(&data_block);
-            let generated_data =
-                generator.generate(cli.resolution, cache_max_size, "", &mut progress_reporter);
+            let mut generator = Generator::new(data_block);
+            let generated_data = generator.generate(
+                cli.resolution,
+                cache_max_size,
+                String::from(""),
+                mode,
+                &mut progress_reporter,
+            );
 
-            if let Err(err) = generator.write_records(
-                &generated_data.synthetic_data,
+            if let Err(err) = generated_data.write_synthetic_data(
                 &synthetic_path,
-                synthetic_delimiter.chars().next().unwrap() as u8,
+                synthetic_delimiter.chars().next().unwrap(),
             ) {
                 error!("error writing output file: {}", err);
                 process::exit(1);
@@ -153,28 +179,24 @@ fn main() {
             sensitivity_threshold,
             records_sensitivity_path,
         } => {
-            let mut aggregator = Aggregator::new(&data_block);
+            let mut aggregator = Aggregator::new(data_block);
             let mut aggregated_data = aggregator.aggregate(
                 reporting_length,
                 sensitivity_threshold,
                 &mut progress_reporter,
             );
-            let privacy_risk =
-                aggregator.calc_privacy_risk(&aggregated_data.aggregates_count, cli.resolution);
+            let privacy_risk = aggregated_data.calc_privacy_risk(cli.resolution);
 
             if !not_protect {
-                Aggregator::protect_aggregates_count(
-                    &mut aggregated_data.aggregates_count,
-                    cli.resolution,
-                );
+                aggregated_data.protect_aggregates_count(cli.resolution);
             }
 
             info!("Calculated privacy risk is: {:#?}", privacy_risk);
 
-            if let Err(err) = aggregator.write_aggregates_count(
-                &aggregated_data.aggregates_count,
+            if let Err(err) = aggregated_data.write_aggregates_count(
                 &aggregates_path,
                 aggregates_delimiter.chars().next().unwrap(),
+                ";",
                 cli.resolution,
                 !not_protect,
             ) {
@@ -183,9 +205,7 @@ fn main() {
             }
 
             if let Some(path) = records_sensitivity_path {
-                if let Err(err) = aggregator
-                    .write_records_sensitivity(&aggregated_data.records_sensitivity, &path)
-                {
+                if let Err(err) = aggregated_data.write_records_sensitivity(&path, '\t') {
                     error!("error writing output file: {}", err);
                     process::exit(1);
                 }
diff --git a/packages/core/Cargo.toml b/packages/core/Cargo.toml
index 7631b61..47c84d8 100644
--- a/packages/core/Cargo.toml
+++ b/packages/core/Cargo.toml
@@ -18,4 +18,7 @@ getrandom = { version = "0.2", features = ["js"] }
 log = { version = "0.4", features = ["std"] }
 csv = { version = "1.1" }
 instant = { version = "0.1", features = [ "stdweb", "wasm-bindgen" ] }
-pyo3 = { version = "0.15", features = ["extension-module"], optional = true }
\ No newline at end of file
+pyo3 = { version = "0.15", features = ["extension-module"], optional = true }
+rayon = { version = "1.5", optional = true }
+serde = { version = "1.0", features = [ "derive", "rc" ] }
+serde_json = { version = "1.0" }
\ No newline at end of file
diff --git a/packages/core/src/data_block/block.rs b/packages/core/src/data_block/block.rs
index 2856040..b1a282d 100644
--- a/packages/core/src/data_block/block.rs
+++ b/packages/core/src/data_block/block.rs
@@ -1,12 +1,16 @@
 use super::{
-    record::DataBlockRecord,
     typedefs::{
-        AttributeRows, AttributeRowsMap, CsvRecord, CsvRecordSlice, DataBlockHeaders,
-        DataBlockRecords, DataBlockRecordsSlice,
+        AttributeRows, AttributeRowsByColumnMap, AttributeRowsMap, ColumnIndexByName,
+        DataBlockHeaders, DataBlockRecords,
     },
     value::DataBlockValue,
 };
-use std::collections::HashSet;
+use fnv::FnvHashMap;
+use itertools::Itertools;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+
+use crate::{processing::aggregator::typedefs::RecordsSet, utils::math::uround_down};
 
 #[cfg(feature = "pyo3")]
 use pyo3::prelude::*;
@@ -15,189 +19,162 @@ use pyo3::prelude::*;
 /// The goal of this is to allow data processing to handle with memory references
 /// to the data block instead of copying data around
 #[cfg_attr(feature = "pyo3", pyclass)]
-#[derive(Debug)]
+#[derive(Debug, Serialize, Deserialize)]
 pub struct DataBlock {
-    /// Vector of strings representhing the data headers
+    /// Vector of strings representing the data headers
     pub headers: DataBlockHeaders,
     /// Vector of data records, where each record represents a row (headers not included)
     pub records: DataBlockRecords,
 }
 
 impl DataBlock {
+    /// Returns a new DataBlock with default values
+    pub fn default() -> DataBlock {
+        DataBlock {
+            headers: DataBlockHeaders::default(),
+            records: DataBlockRecords::default(),
+        }
+    }
+
     /// Returns a new DataBlock
     /// # Arguments
-    /// * `headers` - Vector of string representhing the data headers
+    /// * `headers` - Vector of string representing the data headers
     /// * `records` - Vector of data records, where each record represents a row (headers not included)
     #[inline]
     pub fn new(headers: DataBlockHeaders, records: DataBlockRecords) -> DataBlock {
         DataBlock { headers, records }
     }
 
-    /// Calcules the rows where each value on the data records is present
-    /// # Arguments
-    /// * `records`: List of records to analyze
+    /// Returns a map of column name -> column index
     #[inline]
-    pub fn calc_attr_rows(records: &DataBlockRecordsSlice) -> AttributeRowsMap {
+    pub fn calc_column_index_by_name(&self) -> ColumnIndexByName {
+        self.headers
+            .iter()
+            .enumerate()
+            .map(|(column_index, column_name)| ((**column_name).clone(), column_index))
+            .collect()
+    }
+
+    /// Calculates the rows where each value on the data records is present
+    #[inline]
+    pub fn calc_attr_rows(&self) -> AttributeRowsMap {
         let mut attr_rows: AttributeRowsMap = AttributeRowsMap::default();
 
-        for (i, r) in records.iter().enumerate() {
+        for (i, r) in self.records.iter().enumerate() {
             for v in r.values.iter() {
                 attr_rows
-                    .entry(v)
+                    .entry(v.clone())
                     .or_insert_with(AttributeRows::new)
                     .push(i);
             }
         }
         attr_rows
     }
-}
-
-/// Trait that needs to be implement to create a data block.
-/// It already contains the logic to create the data block, so we only
-/// need to worry about mapping the headers and records from InputType
-pub trait DataBlockCreator {
-    /// Creator input type, it can be a File Reader, another data structure...
-    type InputType;
-    /// The error type that can be generated when parsing headers/records
-    type ErrorType;
 
+    /// Calculates the rows where each value on the data records is present
+    /// grouped by column index.
     #[inline]
-    fn gen_use_columns_set(headers: &CsvRecordSlice, use_columns: &[String]) -> HashSet<usize> {
-        let use_columns_str_set: HashSet<String> = use_columns
-            .iter()
-            .map(|c| c.trim().to_lowercase())
-            .collect();
-        headers
-            .iter()
-            .enumerate()
-            .filter_map(|(i, h)| {
-                if use_columns_str_set.is_empty()
-                    || use_columns_str_set.contains(&h.trim().to_lowercase())
-                {
-                    Some(i)
-                } else {
-                    None
-                }
+    pub fn calc_attr_rows_with_no_empty_values(&self) -> AttributeRowsByColumnMap {
+        let mut attr_rows_by_column: AttributeRowsByColumnMap = AttributeRowsByColumnMap::default();
+
+        for (i, r) in self.records.iter().enumerate() {
+            for v in r.values.iter() {
+                attr_rows_by_column
+                    .entry(v.column_index)
+                    .or_insert_with(AttributeRowsMap::default)
+                    .entry(v.clone())
+                    .or_insert_with(AttributeRows::new)
+                    .push(i);
+            }
+        }
+        attr_rows_by_column
+    }
+
+    /// Calculates the rows where each value on the data records is present
+    /// grouped by column index. This will also include empty values mapped with
+    /// the `empty_value` string.
+    /// # Arguments
+    /// * `empty_value` - Empty values on the final synthetic data will be represented by this
+    #[inline]
+    pub fn calc_attr_rows_by_column(&self, empty_value: &Arc<String>) -> AttributeRowsByColumnMap {
+        let mut attr_rows_by_column: FnvHashMap<
+            usize,
+            FnvHashMap<Arc<DataBlockValue>, RecordsSet>,
+        > = FnvHashMap::default();
+        let empty_records: RecordsSet = (0..self.records.len()).collect();
+
+        // start with empty values for all columns
+        for column_index in 0..self.headers.len() {
+            attr_rows_by_column
+                .entry(column_index)
+                .or_insert_with(FnvHashMap::default)
+                .entry(Arc::new(DataBlockValue::new(
+                    column_index,
+                    empty_value.clone(),
+                )))
+                .or_insert_with(|| empty_records.clone());
+        }
+
+        // go through each record and map where the values occur on the columns
+        for (i, r) in self.records.iter().enumerate() {
+            for value in r.values.iter() {
+                let current_attr_rows = attr_rows_by_column
+                    .entry(value.column_index)
+                    .or_insert_with(FnvHashMap::default);
+
+                // insert it on the corresponding entry for the data block value
+                current_attr_rows
+                    .entry(value.clone())
+                    .or_insert_with(RecordsSet::default)
+                    .insert(i);
+
+                // it's now being used, so we make sure to remove this from the column empty records
+                current_attr_rows
+                    .entry(Arc::new(DataBlockValue::new(
+                        value.column_index,
+                        empty_value.clone(),
+                    )))
+                    .or_insert_with(RecordsSet::default)
+                    .remove(&i);
+            }
+        }
+
+        // sort the record ids, so we can leverage the intersection alg later on
+        attr_rows_by_column
+            .drain()
+            .map(|(column_index, mut attr_rows)| {
+                (
+                    column_index,
+                    attr_rows
+                        .drain()
+                        .map(|(value, mut rows_set)| (value, rows_set.drain().sorted().collect()))
+                        .collect(),
+                )
             })
             .collect()
     }
 
     #[inline]
-    fn gen_sensitive_zeros_set(
-        filtered_headers: &CsvRecordSlice,
-        sensitive_zeros: &[String],
-    ) -> HashSet<usize> {
-        let sensitive_zeros_str_set: HashSet<String> = sensitive_zeros
-            .iter()
-            .map(|c| c.trim().to_lowercase())
-            .collect();
-        filtered_headers
-            .iter()
-            .enumerate()
-            .filter_map(|(i, h)| {
-                if sensitive_zeros_str_set.contains(&h.trim().to_lowercase()) {
-                    Some(i)
-                } else {
-                    None
-                }
-            })
-            .collect()
+    /// Returns the number of records on the data block
+    pub fn number_of_records(&self) -> usize {
+        self.records.len()
     }
 
     #[inline]
-    fn map_headers(
-        headers: &mut CsvRecord,
-        use_columns: &[String],
-        sensitive_zeros: &[String],
-    ) -> (CsvRecord, HashSet<usize>, HashSet<usize>) {
-        let use_columns_set = Self::gen_use_columns_set(headers, use_columns);
-        let filtered_headers: CsvRecord = headers
-            .iter()
-            .enumerate()
-            .filter_map(|(i, h)| {
-                if use_columns_set.contains(&i) {
-                    Some(h.clone())
-                } else {
-                    None
-                }
-            })
-            .collect();
-        let sensitive_zeros_set = Self::gen_sensitive_zeros_set(&filtered_headers, sensitive_zeros);
-
-        (filtered_headers, use_columns_set, sensitive_zeros_set)
+    /// Returns the number of records on the data block protected by `resolution`
+    pub fn protected_number_of_records(&self, resolution: usize) -> usize {
+        uround_down(self.number_of_records() as f64, resolution as f64)
     }
 
     #[inline]
-    fn map_records(
-        records: &[CsvRecord],
-        use_columns_set: &HashSet<usize>,
-        sensitive_zeros_set: &HashSet<usize>,
-        record_limit: usize,
-    ) -> DataBlockRecords {
-        let map_result = |record: &CsvRecord| {
-            let values: CsvRecord = record
-                .iter()
-                .enumerate()
-                .filter_map(|(i, h)| {
-                    if use_columns_set.contains(&i) {
-                        Some(h.trim().to_string())
-                    } else {
-                        None
-                    }
-                })
-                .collect();
-
-            DataBlockRecord::new(
-                values
-                    .iter()
-                    .enumerate()
-                    .filter_map(|(i, r)| {
-                        let record_val = r.trim();
-                        if !record_val.is_empty()
-                            && (sensitive_zeros_set.contains(&i) || record_val != "0")
-                        {
-                            Some(DataBlockValue::new(i, record_val.into()))
-                        } else {
-                            None
-                        }
-                    })
-                    .collect(),
-            )
-        };
-
-        if record_limit > 0 {
-            records.iter().take(record_limit).map(map_result).collect()
+    /// Normalizes the reporting length based on the number of selected headers.
+    /// Returns the normalized value
+    /// # Arguments
+    /// * `reporting_length` - Reporting length to be normalized (0 means use all columns)
+    pub fn normalize_reporting_length(&self, reporting_length: usize) -> usize {
+        if reporting_length == 0 {
+            self.headers.len()
         } else {
-            records.iter().map(map_result).collect()
+            usize::min(reporting_length, self.headers.len())
         }
     }
-
-    #[inline]
-    fn create(
-        input_res: Result<Self::InputType, Self::ErrorType>,
-        use_columns: &[String],
-        sensitive_zeros: &[String],
-        record_limit: usize,
-    ) -> Result<DataBlock, Self::ErrorType> {
-        let mut input = input_res?;
-        let (headers, use_columns_set, sensitive_zeros_set) = Self::map_headers(
-            &mut Self::get_headers(&mut input)?,
-            use_columns,
-            sensitive_zeros,
-        );
-        let records = Self::map_records(
-            &Self::get_records(&mut input)?,
-            &use_columns_set,
-            &sensitive_zeros_set,
-            record_limit,
-        );
-
-        Ok(DataBlock::new(headers, records))
-    }
-
-    /// Should be implemented to return the CsvRecords reprensenting the headers
-    fn get_headers(input: &mut Self::InputType) -> Result<CsvRecord, Self::ErrorType>;
-
-    /// Should be implemented to return the vector of CsvRecords reprensenting rows
-    fn get_records(input: &mut Self::InputType) -> Result<Vec<CsvRecord>, Self::ErrorType>;
 }
diff --git a/packages/core/src/data_block/csv_block_creator.rs b/packages/core/src/data_block/csv_block_creator.rs
index 1826dbc..6a8a0c4 100644
--- a/packages/core/src/data_block/csv_block_creator.rs
+++ b/packages/core/src/data_block/csv_block_creator.rs
@@ -1,4 +1,4 @@
-use super::{block::DataBlockCreator, typedefs::CsvRecord};
+use super::{data_block_creator::DataBlockCreator, typedefs::CsvRecord};
 use csv::{Error, Reader, StringRecord};
 use std::fs::File;
diff --git a/packages/core/src/data_block/csv_io_error.rs b/packages/core/src/data_block/csv_io_error.rs
new file mode 100644
index 0000000..88f9f95
--- /dev/null
+++ b/packages/core/src/data_block/csv_io_error.rs
@@ -0,0 +1,36 @@
+use csv::Error;
+use std::fmt::{Display, Formatter, Result};
+
+#[cfg(feature = "pyo3")]
+use pyo3::exceptions::PyIOError;
+
+#[cfg(feature = "pyo3")]
+use pyo3::prelude::*;
+
+/// Wrapper for a csv::Error, so the from
+/// trait can be implemented for PyErr
+pub struct CsvIOError {
+    error: Error,
+}
+
+impl CsvIOError {
+    /// Creates a new CsvIOError from a csv::Error
+    /// # Arguments
+    /// * `error` - Related csv::Error
+    pub fn new(error: Error) -> CsvIOError {
+        CsvIOError { error }
+    }
+}
+
+impl Display for CsvIOError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        write!(f, "{}", self.error)
+    }
+}
+
+#[cfg(feature = "pyo3")]
+impl From<CsvIOError> for PyErr {
+    fn from(err: CsvIOError) -> PyErr {
+        PyIOError::new_err(err.error.to_string())
+    }
+}
diff --git a/packages/core/src/data_block/data_block_creator.rs b/packages/core/src/data_block/data_block_creator.rs
new file mode 100644
index 0000000..22d566b
--- /dev/null
+++ b/packages/core/src/data_block/data_block_creator.rs
@@ -0,0 +1,166 @@
+use super::{
+    block::DataBlock,
+    record::DataBlockRecord,
+    typedefs::{CsvRecord, CsvRecordRef, CsvRecordRefSlice, CsvRecordSlice, DataBlockRecords},
+    value::DataBlockValue,
+};
+use std::{collections::HashSet, sync::Arc};
+
+/// Trait that needs to be implemented to create a data block.
+/// It already contains the logic to create the data block, so we only
+/// need to worry about mapping the headers and records from InputType
+pub trait DataBlockCreator {
+    /// Creator input type, it can be a File Reader, another data structure...
+    type InputType;
+    /// The error type that can be generated when parsing headers/records
+    type ErrorType;
+
+    #[inline]
+    fn gen_use_columns_set(headers: &CsvRecordSlice, use_columns: &[String]) -> HashSet<usize> {
+        let use_columns_str_set: HashSet<String> = use_columns
+            .iter()
+            .map(|c| c.trim().to_lowercase())
+            .collect();
+        headers
+            .iter()
+            .enumerate()
+            .filter_map(|(i, h)| {
+                if use_columns_str_set.is_empty()
+                    || use_columns_str_set.contains(&h.trim().to_lowercase())
+                {
+                    Some(i)
+                } else {
+                    None
+                }
+            })
+            .collect()
+    }
+
+    #[inline]
+    fn gen_sensitive_zeros_set(
+        filtered_headers: &CsvRecordRefSlice,
+        sensitive_zeros: &[String],
+    ) -> HashSet<usize> {
+        let sensitive_zeros_str_set: HashSet<String> = sensitive_zeros
+            .iter()
+            .map(|c| c.trim().to_lowercase())
+            .collect();
+        filtered_headers
+            .iter()
+            .enumerate()
+            .filter_map(|(i, h)| {
+                if sensitive_zeros_str_set.contains(&h.trim().to_lowercase()) {
+                    Some(i)
+                } else {
+                    None
+                }
+            })
+            .collect()
+    }
+
+    #[inline]
+    fn normalize_value(value: &str) -> String {
+        value
+            .trim()
+            // replace reserved delimiters
+            .replace(";", "")
+            .replace(":", "")
+    }
+
+    #[inline]
+    fn map_headers(
+        headers: &mut CsvRecord,
+        use_columns: &[String],
+        sensitive_zeros: &[String],
+    ) -> (CsvRecordRef, HashSet<usize>, HashSet<usize>) {
+        let use_columns_set = Self::gen_use_columns_set(headers, use_columns);
+        let filtered_headers: CsvRecordRef = headers
+            .iter()
+            .enumerate()
+            .filter_map(|(i, h)| {
+                if use_columns_set.contains(&i) {
+                    Some(Arc::new(Self::normalize_value(h)))
+                } else {
+                    None
+                }
+            })
+            .collect();
+        let sensitive_zeros_set = Self::gen_sensitive_zeros_set(&filtered_headers, sensitive_zeros);
+
+        (filtered_headers, use_columns_set, sensitive_zeros_set)
+    }
+
+    #[inline]
+    fn map_records(
+        records: &[CsvRecord],
+        use_columns_set: &HashSet<usize>,
+        sensitive_zeros_set: &HashSet<usize>,
+        record_limit: usize,
+    ) -> DataBlockRecords {
+        let map_result = |record: &CsvRecord| {
+            let values: CsvRecord = record
+                .iter()
+                .enumerate()
+                .filter_map(|(i, h)| {
+                    if use_columns_set.contains(&i) {
+                        Some(h.trim().to_string())
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+
+            Arc::new(DataBlockRecord::new(
+                values
+                    .iter()
+                    .enumerate()
+                    .filter_map(|(i, r)| {
+                        let record_val = Self::normalize_value(r);
+                        if !record_val.is_empty()
+                            && (sensitive_zeros_set.contains(&i) || record_val != "0")
+                        {
+                            Some(Arc::new(DataBlockValue::new(i, Arc::new(record_val))))
+                        } else {
+                            None
+                        }
+                    })
+                    .collect(),
+            ))
+        };
+
+        if record_limit > 0 {
+            records.iter().take(record_limit).map(map_result).collect()
+        } else {
+            records.iter().map(map_result).collect()
+        }
+    }
+
+    #[inline]
+    fn create(
+        input_res: Result<Self::InputType, Self::ErrorType>,
+        use_columns: &[String],
+        sensitive_zeros: &[String],
+        record_limit: usize,
+    ) -> Result<Arc<DataBlock>, Self::ErrorType> {
+        let mut input = input_res?;
+        let (headers, use_columns_set, sensitive_zeros_set) = Self::map_headers(
+            &mut Self::get_headers(&mut input)?,
+            use_columns,
+            sensitive_zeros,
+        );
+        let records = Self::map_records(
+            &Self::get_records(&mut input)?,
+            &use_columns_set,
+            &sensitive_zeros_set,
+            record_limit,
+        );
+
+        Ok(Arc::new(DataBlock::new(headers, records)))
+    }
+
+    /// Should be implemented to return the CsvRecords representing the headers
+    fn get_headers(input: &mut Self::InputType) -> Result<CsvRecord, Self::ErrorType>;
+
+    /// Should be implemented to return the vector of CsvRecords representing rows
+    fn get_records(input: &mut Self::InputType) -> Result<Vec<CsvRecord>, Self::ErrorType>;
+}
diff --git a/packages/core/src/data_block/mod.rs b/packages/core/src/data_block/mod.rs
index 3e73948..a01d02d 100644
--- a/packages/core/src/data_block/mod.rs
+++ b/packages/core/src/data_block/mod.rs
@@ -1,10 +1,21 @@
 /// Module defining the structures that represent a data block
 pub mod block;
+
 /// Module to create data blocks from CSV files
 pub mod csv_block_creator;
+
+/// Defines io errors for handling csv files
+/// that are easier to bind to other languages
+pub mod csv_io_error;
+
+/// Module to create data blocks from different input types (trait definitions)
+pub mod data_block_creator;
+
 /// Module defining the structures that represent a data block record
 pub mod record;
+
 /// Type definitions related to data blocks
 pub mod typedefs;
+
 /// Module defining the structures that represent a data block value
 pub mod value;
diff --git a/packages/core/src/data_block/record.rs b/packages/core/src/data_block/record.rs
index a358b35..4912ce3 100644
--- a/packages/core/src/data_block/record.rs
+++ b/packages/core/src/data_block/record.rs
@@ -1,10 +1,12 @@
 use super::value::DataBlockValue;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
 
 /// Represents all the values of a given row in a data block
-#[derive(Debug)]
+#[derive(Debug, Serialize, Deserialize)]
 pub struct DataBlockRecord {
     /// Vector of data block values for a given row indexed by column
-    pub values: Vec<DataBlockValue>,
+    pub values: Vec<Arc<DataBlockValue>>,
 }
 
 impl DataBlockRecord {
@@ -12,7 +14,7 @@ impl DataBlockRecord {
     /// # Arguments
     /// * `values` - Vector of data block values for a given row indexed by column
     #[inline]
-    pub fn new(values: Vec<DataBlockValue>) -> DataBlockRecord {
+    pub fn new(values: Vec<Arc<DataBlockValue>>) -> DataBlockRecord {
         DataBlockRecord { values }
     }
 }
diff --git a/packages/core/src/data_block/typedefs.rs b/packages/core/src/data_block/typedefs.rs
index 6a303f9..6a74600 100644
--- a/packages/core/src/data_block/typedefs.rs
+++ b/packages/core/src/data_block/typedefs.rs
@@ -1,5 +1,8 @@
 use super::{record::DataBlockRecord, value::DataBlockValue};
 use fnv::FnvHashMap;
+use std::sync::Arc;
+
+use crate::processing::evaluator::rare_combinations_comparison_data::CombinationComparison;
 
 /// Ordered vector of rows where a particular value combination is present
 pub type AttributeRows = Vec<usize>;
@@ -11,23 +14,38 @@ pub type AttributeRowsSlice = [usize];
 pub type CsvRecord = Vec<String>;
 
 /// The same as CsvRecord, but keeping a reference to the string value, so data does not have to be duplicated
-pub type CsvRecordRef<'data_block_value> = Vec<&'data_block_value str>;
+pub type CsvRecordRef = Vec<Arc<String>>;
 
 /// Slice of CsvRecord
 pub type CsvRecordSlice = [String];
 
-/// Vector of strings representhing the data block headers
-pub type DataBlockHeaders = CsvRecord;
+/// Slice of CsvRecord
+pub type CsvRecordRefSlice = [Arc<String>];
+
+/// Vector of strings representing the data block headers
+pub type DataBlockHeaders = CsvRecordRef;
 
 /// Slice of DataBlockHeaders
-pub type DataBlockHeadersSlice = CsvRecordSlice;
+pub type DataBlockHeadersSlice = CsvRecordRefSlice;
 
 /// Vector of data block records, where each record represents a row
-pub type DataBlockRecords = Vec<DataBlockRecord>;
-
-/// Slice of DataBlockRecords
-pub type DataBlockRecordsSlice = [DataBlockRecord];
+pub type DataBlockRecords = Vec<Arc<DataBlockRecord>>;
 
 /// HashMap with a data block value as key and all the attribute row indexes where it occurs as value
-pub type AttributeRowsMap<'data_block_value> =
-    FnvHashMap<&'data_block_value DataBlockValue, AttributeRows>;
+pub type AttributeRowsMap = FnvHashMap<Arc<DataBlockValue>, AttributeRows>;
+
+/// HashMap with a data block value as key and attribute row indexes as value
+pub type AttributeRowsRefMap = FnvHashMap<Arc<DataBlockValue>, Arc<AttributeRows>>;
+
+/// Maps the column index -> data block value -> rows where the value appear
+pub type AttributeRowsByColumnMap = FnvHashMap<usize, AttributeRowsMap>;
+
+/// Raw synthesized data (vector of csv record references to the original data block)
+pub type RawSyntheticData = Vec<CsvRecordRef>;
+
+/// A vector of combination comparisons
+/// (between sensitive and synthetic data)
+pub type CombinationsComparisons = Vec<CombinationComparison>;
+
+/// Maps a column name to the corresponding column index
+pub type ColumnIndexByName = FnvHashMap<String, usize>;
diff --git a/packages/core/src/data_block/value.rs b/packages/core/src/data_block/value.rs
index 926d682..c3f32c3 100644
--- a/packages/core/src/data_block/value.rs
+++ b/packages/core/src/data_block/value.rs
@@ -1,10 +1,16 @@
+use super::typedefs::DataBlockHeadersSlice;
+use serde::{Deserialize, Serialize};
+use std::{fmt::Display, str::FromStr, sync::Arc};
+
+const VALUE_DELIMITER: char = ':';
+
 /// Represents a value of a given data block for a particular row and column
-#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
 pub struct DataBlockValue {
     /// Column index this value belongs to starting in '0'
     pub column_index: usize,
     /// Value stored on the CSV file for a given row at `column_index`
-    pub value: String,
+    pub value: Arc<String>,
 }
 
 impl DataBlockValue {
@@ -13,10 +19,72 @@ impl DataBlockValue {
     /// * `column_index` - Column index this value belongs to starting in '0'
     /// * `value` - Value stored on the CSV file for a given row at `column_index`
     #[inline]
-    pub fn new(column_index: usize, value: String) -> DataBlockValue {
+    pub fn new(column_index: usize, value: Arc<String>) -> DataBlockValue {
         DataBlockValue {
             column_index,
             value,
         }
     }
+
+    /// Formats a data block value as String using the
+    /// corresponding header name
+    /// The result is formatted as: `{header_name}:{block_value}`
+    /// # Arguments
+    /// * `headers` - data block headers
+    /// * `value` - value to be formatted
+    #[inline]
+    pub fn format_str_using_headers(&self, headers: &DataBlockHeadersSlice) -> String {
+        format!(
+            "{}{}{}",
+            headers[self.column_index], VALUE_DELIMITER, self.value
+        )
+    }
+}
+
+/// Error that can happen when parsing a data block from
+/// a string
+pub struct ParseDataBlockValueError {
+    error_message: String,
+}
+
+impl ParseDataBlockValueError {
+    #[inline]
+    /// Creates a new ParseDataBlockValueError with `error_message`
+    pub fn new(error_message: String) -> ParseDataBlockValueError {
+        ParseDataBlockValueError { error_message }
+    }
+}
+
+impl Display for ParseDataBlockValueError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.error_message)
+    }
+}
+
+impl FromStr for DataBlockValue {
+    type Err = ParseDataBlockValueError;
+
+    /// Creates a new DataBlockValue by parsing `str_value`
+    fn from_str(str_value: &str) -> Result<Self, Self::Err> {
+        if let Some(pos) = str_value.find(VALUE_DELIMITER) {
+            Ok(DataBlockValue::new(
+                str_value[..pos]
+                    .parse::<usize>()
+                    .map_err(|err| ParseDataBlockValueError::new(err.to_string()))?,
+                Arc::new(str_value[pos + 1..].into()),
+            ))
+        } else {
+            Err(ParseDataBlockValueError::new(format!(
+                "data block value missing '{}'",
+                VALUE_DELIMITER
+            )))
+        }
+    }
+}
+
+impl Display for DataBlockValue {
+    /// Formats the DataBlockValue as a string
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        Ok(write!(f, "{}:{}", self.column_index, self.value)?)
+    }
+}
diff --git a/packages/core/src/processing/aggregator/aggregated_data.rs b/packages/core/src/processing/aggregator/aggregated_data.rs
new file mode 100644
index 0000000..a9514c7
--- /dev/null
+++ b/packages/core/src/processing/aggregator/aggregated_data.rs
@@ -0,0 +1,448 @@
+use super::{
+    privacy_risk_summary::PrivacyRiskSummary,
+    records_analysis_data::RecordsAnalysisData,
+    typedefs::{
+        AggregatedCountByLenMap, AggregatesCountMap, AggregatesCountStringMap, RecordsByLenMap,
+        RecordsSensitivity,
+    },
+};
+use itertools::Itertools;
+use log::info;
+use serde::{Deserialize, Serialize};
+use std::{
+    io::{BufReader, BufWriter, Error, Write},
+    sync::Arc,
+};
+
+#[cfg(feature = "pyo3")]
+use pyo3::prelude::*;
+
+use crate::{
+    data_block::block::DataBlock,
+    processing::aggregator::typedefs::RecordsSet,
+    utils::{math::uround_down, time::ElapsedDurationLogger},
+};
+
+/// Aggregated data produced by the Aggregator
+#[cfg_attr(feature = "pyo3", pyclass)]
+#[derive(Serialize, Deserialize)]
+pub struct AggregatedData {
+    /// Data block from where this aggregated data was generated
+    pub data_block: Arc<DataBlock>,
+    /// Maps a value combination to its aggregated count
+    pub aggregates_count: AggregatesCountMap,
+    /// A vector of sensitivities for each record (the vector index is the record index)
+    pub records_sensitivity: RecordsSensitivity,
+    /// Maximum length used to compute attribute combinations
+    pub reporting_length: usize,
+}
+
+impl AggregatedData {
+    /// Creates a new AggregatedData struct with default values
+    #[inline]
+    pub fn default() -> AggregatedData {
+        AggregatedData {
+            data_block: Arc::new(DataBlock::default()),
+            aggregates_count: AggregatesCountMap::default(),
+            records_sensitivity: RecordsSensitivity::default(),
+            reporting_length: 0,
+        }
+    }
+
+    /// Creates a new AggregatedData struct
+    /// # Arguments:
+    /// * `data_block` - Data block with the original data
+    /// * `aggregates_count` - Computed aggregates count map
+    /// * `records_sensitivity` - Computed sensitivity for the records
+    /// * `reporting_length` - Maximum length used to compute attribute combinations
+    #[inline]
+    pub fn new(
+        data_block: Arc<DataBlock>,
+        aggregates_count: AggregatesCountMap,
+        records_sensitivity: RecordsSensitivity,
+        reporting_length: usize,
+    ) -> AggregatedData {
+        AggregatedData {
+            data_block,
+            aggregates_count,
+            records_sensitivity,
+            reporting_length,
+        }
+    }
+
+    #[inline]
+    /// Check if the records len map contains a value across all lengths
+    fn records_by_len_contains(records_by_len: &RecordsByLenMap, value: &usize) -> bool {
+        records_by_len
+            .values()
+            .any(|records| records.contains(value))
+    }
+
+    #[inline]
+    /// When first generated, the RecordsByLenMap might contain
+    /// the same records appearing in different combination lengths.
+    /// This will keep only the record on the shortest length key
+    /// and remove the other occurrences
+    fn keep_records_only_on_shortest_len(records_by_len: &mut RecordsByLenMap) {
+        let lengths: Vec<usize> = records_by_len.keys().cloned().sorted().collect();
+        let max_len = lengths.last().copied().unwrap_or(0);
+
+        // make sure the record will be only present in the shortest len
+        // start on the shortest length
+        for l in lengths {
+            for r in records_by_len.get(&l).unwrap().clone() {
+                // search all lengths > l and remove the record from there
+                for n in l + 1..=max_len {
+                    if let Some(records) = records_by_len.get_mut(&n) {
+                        records.remove(&r);
+                    }
+                }
+            }
+        }
+        // retain only non-empty record lists
+        records_by_len.retain(|_, records| !records.is_empty());
+    }
+
+    #[inline]
+    fn _read_from_json(file_path: &str) -> Result<AggregatedData, Error> {
+        let _duration_logger = ElapsedDurationLogger::new("aggregated count json read");
+
+        Ok(serde_json::from_reader(BufReader::new(
+            std::fs::File::open(file_path)?,
+        ))?)
+    }
+}
+
+#[cfg_attr(feature = "pyo3", pymethods)]
+impl AggregatedData {
+    /// Builds a map from value combinations formatted as strings to their aggregated counts
+    /// This method will clone the data, so it's recommended to have its result stored
+    /// in a local variable to avoid it being called multiple times
+    /// # Arguments:
+    /// * `combination_delimiter` - Delimiter used to join combinations
+    pub fn get_formatted_aggregates_count(
+        &self,
+        combination_delimiter: &str,
+    ) -> AggregatesCountStringMap {
+        self.aggregates_count
+            .iter()
+            .map(|(key, value)| {
+                (
+                    key.format_str_using_headers(&self.data_block.headers, combination_delimiter),
+                    value.clone(),
+                )
+            })
+            .collect()
+    }
+
+    #[cfg(feature = "pyo3")]
+    /// A vector of sensitivities for each record (the vector index is the record index)
+    /// This method will clone the data, so it's recommended to have its result stored
+    /// in a local variable to avoid it being called multiple times
+    pub fn get_records_sensitivity(&self) -> RecordsSensitivity {
+        self.records_sensitivity.clone()
+    }
+
+    /// Round the aggregated counts down to the nearest multiple of resolution
+    /// # Arguments:
+    /// * `resolution` - Reporting resolution used for data synthesis
+    pub fn protect_aggregates_count(&mut self, resolution: usize) {
+        let _duration_logger = ElapsedDurationLogger::new("aggregates count protect");
+
+        info!(
+            "protecting aggregates counts with resolution {}",
+            resolution
+        );
+
+        for count in self.aggregates_count.values_mut() {
+            count.count = uround_down(count.count as f64, resolution as f64);
+        }
+        // remove 0 counts from response
+        self.aggregates_count.retain(|_, count| count.count > 0);
+    }
+
+    /// Calculates the records that contain rare combinations grouped by length.
+    /// This might contain duplicated records on different lengths if the record
+    /// contains more than one rare combination. Unique combinations are also contained
+    /// in this.
+    /// # Arguments:
+    /// * `resolution` - Reporting resolution used for data synthesis
+    pub fn calc_all_rare_combinations_records_by_len(&self, resolution: usize) -> RecordsByLenMap {
+        let _duration_logger =
+            ElapsedDurationLogger::new("all rare combinations records by len calculation");
+        let mut rare_records_by_len: RecordsByLenMap = RecordsByLenMap::default();
+
+        for (agg, count) in self.aggregates_count.iter() {
+            if count.count < resolution {
+                rare_records_by_len
+                    .entry(agg.len())
+                    .or_insert_with(RecordsSet::default)
+                    .extend(&count.contained_in_records);
+            }
+        }
+        rare_records_by_len
+    }
+
+    /// Calculates the records that contain unique combinations grouped by length.
+    /// This might contain duplicated records on different lengths if the record
+    /// contains more than one unique combination.
+    pub fn calc_all_unique_combinations_records_by_len(&self) -> RecordsByLenMap {
+        let _duration_logger =
+            ElapsedDurationLogger::new("all unique combinations records by len calculation");
+        let mut unique_records_by_len: RecordsByLenMap = RecordsByLenMap::default();
+
+        for (agg, count) in self.aggregates_count.iter() {
+            if count.count == 1 {
+                unique_records_by_len
+                    .entry(agg.len())
+                    .or_insert_with(RecordsSet::default)
+                    .extend(&count.contained_in_records);
+            }
+        }
+        unique_records_by_len
+    }
+
+    /// Calculates the records that contain unique and rare combinations grouped by length.
+    /// A tuple with the `(unique, rare)` is returned.
+    /// Both returned maps are ensured to only contain the records on the shortest length,
+    /// so each record will appear only on the shortest combination length that isolates it
+    /// within a rare group. Also, if the record has a unique combination, it will not
+    /// be present on the rare map, only on the unique one.
+    /// # Arguments:
+    /// * `resolution` - Reporting resolution used for data synthesis
+    pub fn calc_unique_rare_combinations_records_by_len(
+        &self,
+        resolution: usize,
+    ) -> (RecordsByLenMap, RecordsByLenMap) {
+        let _duration_logger =
+            ElapsedDurationLogger::new("unique/rare combinations records by len calculation");
+        let mut unique_records_by_len = self.calc_all_unique_combinations_records_by_len();
+        let mut rare_records_by_len = self.calc_all_rare_combinations_records_by_len(resolution);
+
+        AggregatedData::keep_records_only_on_shortest_len(&mut unique_records_by_len);
+        AggregatedData::keep_records_only_on_shortest_len(&mut rare_records_by_len);
+
+        // remove records with unique combinations from the rare map
+        rare_records_by_len.values_mut().for_each(|records| {
+            records.retain(|r| !AggregatedData::records_by_len_contains(&unique_records_by_len, r));
+        });
+
+        (unique_records_by_len, rare_records_by_len)
+    }
+
+    /// Performs the records analysis and returns the data containing
+    /// unique, rare and risky
+    /// information grouped per length.
+    /// # Arguments:
+    /// * `resolution` - Reporting resolution used for data synthesis
+    /// * `protect` - Whether or not the counts should be rounded to the
+    ///   nearest smallest multiple of resolution
+    pub fn calc_records_analysis_by_len(
+        &self,
+        resolution: usize,
+        protect: bool,
+    ) -> RecordsAnalysisData {
+        let _duration_logger = ElapsedDurationLogger::new("records analysis by len");
+        let (unique_records_by_len, rare_records_by_len) =
+            self.calc_unique_rare_combinations_records_by_len(resolution);
+
+        RecordsAnalysisData::from_unique_rare_combinations_records_by_len(
+            &unique_records_by_len,
+            &rare_records_by_len,
+            self.data_block.records.len(),
+            self.reporting_length,
+            resolution,
+            protect,
+        )
+    }
+
+    /// Calculates the number of rare combinations grouped by combination length
+    /// # Arguments:
+    /// * `resolution` - Reporting resolution used for data synthesis
+    pub fn calc_rare_combinations_count_by_len(
+        &self,
+        resolution: usize,
+    ) -> AggregatedCountByLenMap {
+        let _duration_logger =
+            ElapsedDurationLogger::new("rare combinations count by len calculation");
+        let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
+
+        info!(
+            "calculating rare combinations counts by length with resolution {}",
+            resolution
+        );
+
+        for (agg, count) in self.aggregates_count.iter() {
+            if count.count < resolution {
+                let curr_count = result.entry(agg.len()).or_insert(0);
+                *curr_count += 1;
+            }
+        }
+        result
+    }
+
+    /// Calculates the number of combinations grouped by combination length
+    pub fn calc_combinations_count_by_len(&self) -> AggregatedCountByLenMap {
+        let _duration_logger = ElapsedDurationLogger::new("combination count by len calculation");
+        let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
+
+        info!("calculating combination counts by length");
+
+        for agg in self.aggregates_count.keys() {
+            let curr_count = result.entry(agg.len()).or_insert(0);
+            *curr_count += 1;
+        }
+        result
+    }
+
+    /// Calculates the sum of all combination counts grouped by combination length
+    pub fn calc_combinations_sum_by_len(&self) -> AggregatedCountByLenMap {
+        let _duration_logger = ElapsedDurationLogger::new("combinations sum by len calculation");
+        let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
+
+        info!("calculating combination counts sums by length");
+
+        for (agg, count) in self.aggregates_count.iter() {
+            let curr_sum = result.entry(agg.len()).or_insert(0);
+            *curr_sum += count.count;
+        }
+        result
+    }
+
+    /// Calculates the privacy risk related to the data block and the generated
+    /// aggregates counts
+    /// # Arguments:
+    /// * `resolution` - Reporting resolution used for data synthesis
+    pub fn calc_privacy_risk(&self, resolution: usize) -> PrivacyRiskSummary {
+        let _duration_logger = ElapsedDurationLogger::new("privacy risk calculation");
+
+        info!("calculating privacy risk...");
+
+        PrivacyRiskSummary::from_aggregates_count(
+            self.data_block.records.len(),
+            &self.aggregates_count,
+            resolution,
+        )
+    }
+
+    /// Writes the aggregates counts to the file system in a csv/tsv like format
+    /// # Arguments:
+    /// * `aggregates_path` - File path to be written
+    /// * `aggregates_delimiter` - Delimiter to use when writing to `aggregates_path`
+    /// * `combination_delimiter` - Delimiter used to join combinations and format them
+    ///   as strings
+    /// * `resolution` - Reporting resolution used for data synthesis
+    /// * `protected` - Whether or not the counts were protected before calling this
+    pub fn write_aggregates_count(
+        &self,
+        aggregates_path: &str,
+        aggregates_delimiter: char,
+        combination_delimiter: &str,
+        resolution: usize,
+        protected: bool,
+    ) -> Result<(), Error> {
+        let _duration_logger = ElapsedDurationLogger::new("write aggregates count");
+
+        info!("writing file {}", aggregates_path);
+
+        let mut file = std::fs::File::create(aggregates_path)?;
+
+        file.write_all(
+            format!(
+                "selections{}{}\n",
+                aggregates_delimiter,
+                if protected { "protected_count" } else { "count" }
+            )
+            .as_bytes(),
+        )?;
+        file.write_all(
+            format!(
+                "selections{}{}\n",
+                aggregates_delimiter,
+                uround_down(self.data_block.records.len() as f64, resolution as f64)
+            )
+            .as_bytes(),
+        )?;
+        for aggregate in self.aggregates_count.keys() {
+            file.write_all(
+                format!(
+                    "{}{}{}\n",
+                    aggregate
+                        .format_str_using_headers(&self.data_block.headers, combination_delimiter),
+                    aggregates_delimiter,
+                    self.aggregates_count[aggregate].count
+                )
+                .as_bytes(),
+            )?
+        }
+        Ok(())
+    }
+
+    /// Writes the records sensitivity to the file system in a csv/tsv like format
+    /// # Arguments:
+    /// * `records_sensitivity_path` - File path to be written
+    /// * `records_sensitivity_delimiter` - Delimiter to use when writing to `records_sensitivity_path`
+    pub fn write_records_sensitivity(
+        &self,
+        records_sensitivity_path: &str,
+        records_sensitivity_delimiter: char,
+    ) -> Result<(), Error> {
+        let _duration_logger = ElapsedDurationLogger::new("write records sensitivity");
+
+        info!("writing file {}", records_sensitivity_path);
+
+        let mut file = std::fs::File::create(records_sensitivity_path)?;
+
+        file.write_all(
+            format!(
+                "record_index{}record_sensitivity\n",
+                records_sensitivity_delimiter
+            )
+            .as_bytes(),
+        )?;
+        for (i, sensitivity) in self.records_sensitivity.iter().enumerate() {
+            file.write_all(
+                format!("{}{}{}\n", i, records_sensitivity_delimiter, sensitivity).as_bytes(),
+            )?
+        }
+        Ok(())
+    }
+
+    /// Serializes the aggregated data to a json file
+    /// # Arguments:
+    /// * `file_path` - File path to be written
+    pub fn write_to_json(&self, file_path: &str) -> Result<(), Error> {
+        let _duration_logger = ElapsedDurationLogger::new("aggregated count json write");
+
+        Ok(serde_json::to_writer(
+            BufWriter::new(std::fs::File::create(file_path)?),
+            &self,
+        )?)
+    }
+
+    #[cfg(feature = "pyo3")]
+    #[staticmethod]
+    /// Deserializes the aggregated data from a json file
+    /// # Arguments:
+    /// * `file_path` - File path to read from
+    pub fn read_from_json(file_path: &str) -> Result<AggregatedData, Error> {
+        AggregatedData::_read_from_json(file_path)
+    }
+
+    #[cfg(not(feature = "pyo3"))]
+    /// Deserializes the aggregated data from a json file
+    /// # Arguments:
+    /// * `file_path` - File path to read from
+    pub fn read_from_json(file_path: &str) -> Result<AggregatedData, Error> {
+        AggregatedData::_read_from_json(file_path)
+    }
+}
+
+#[cfg(feature = "pyo3")]
+pub fn register(_py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<AggregatedData>()?;
+    Ok(())
+}
diff --git a/packages/core/src/processing/aggregator/data_aggregator.rs b/packages/core/src/processing/aggregator/data_aggregator.rs
new file mode 100644
index 0000000..0ae57ce
--- /dev/null
+++ b/packages/core/src/processing/aggregator/data_aggregator.rs
@@ -0,0 +1,174 @@
+use super::aggregated_data::AggregatedData;
+use super::rows_aggregator::RowsAggregator;
+use super::typedefs::RecordsSet;
+use itertools::Itertools;
+use log::info;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+
+use crate::data_block::block::DataBlock;
+use crate::processing::aggregator::record_attrs_selector::RecordAttrsSelector;
+use crate::utils::math::calc_percentage;
+use crate::utils::reporting::ReportProgress;
+use crate::utils::threading::get_number_of_threads;
+use crate::utils::time::ElapsedDurationLogger;
+
+#[cfg(feature = "pyo3")]
+use pyo3::prelude::*;
+
+/// Result of data aggregation for each combination
+#[cfg_attr(feature = "pyo3", pyclass)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AggregatedCount {
+    /// How many times the combination appears on the records
+    pub count: usize,
+    /// Which records this combination is part of
+    pub contained_in_records: RecordsSet,
+}
+
+#[cfg(feature = "pyo3")]
+#[cfg_attr(feature = "pyo3", pymethods)]
+impl AggregatedCount {
+    /// How many times the combination appears on the records
+    #[getter]
+    fn count(&self) -> usize {
+        self.count
+    }
+
+    /// Which records this combination is part of
+    /// This method will clone the data, so it's recommended to have its result stored
+    /// in a local variable to avoid it being called multiple times
+    fn get_contained_in_records(&self) -> RecordsSet {
+        self.contained_in_records.clone()
+    }
+}
+
+impl Default for AggregatedCount {
+    fn default() -> AggregatedCount {
+        AggregatedCount {
+            count: 0,
+            contained_in_records: RecordsSet::default(),
+        }
+    }
+}
+
+/// Process a data block to produce aggregated data
+pub struct Aggregator {
+    data_block: Arc<DataBlock>,
+}
+
+impl Aggregator {
+    /// Returns a data aggregator for the given data block
+    /// # Arguments
+    /// * `data_block` - The data block to be processed
+    #[inline]
+    pub fn new(data_block: Arc<DataBlock>) -> Aggregator {
+        Aggregator { data_block }
+    }
+
+    /// Aggregates the data block and returns the aggregated data back
+    /// # Arguments
+    /// * `reporting_length` - Calculate combinations from 1 up to `reporting_length`
+    /// * `sensitivity_threshold` - Sensitivity threshold to filter record attributes
+    ///   (0 means no suppression)
+    /// * `progress_reporter` - Will be used to report the processing
+    ///   progress (`ReportProgress` trait). If `None`, nothing will be reported
+    pub fn aggregate<T>(
+        &mut self,
+        reporting_length: usize,
+        sensitivity_threshold: usize,
+        progress_reporter: &mut Option<T>,
+    ) -> AggregatedData
+    where
+        T: ReportProgress,
+    {
+        let _duration_logger = ElapsedDurationLogger::new("data aggregation");
+        let normalized_reporting_length =
+            self.data_block.normalize_reporting_length(reporting_length);
+        let length_range = (1..=normalized_reporting_length).collect::<Vec<usize>>();
+        let total_n_records = self.data_block.records.len();
+        let total_n_records_f64 = total_n_records as f64;
+
+        info!(
+            "aggregating data with reporting length = {}, sensitivity_threshold = {} and {} thread(s)",
+            normalized_reporting_length, sensitivity_threshold, get_number_of_threads()
+        );
+
+        let result = RowsAggregator::aggregate_all(
+            total_n_records,
+            &mut self.build_rows_aggregators(&length_range, sensitivity_threshold),
+            progress_reporter,
+        );
+
+        Aggregator::update_aggregate_progress(
+            progress_reporter,
+            total_n_records,
+            total_n_records_f64,
+        );
+
+        info!(
+            "data aggregated resulting in {} distinct combinations...",
+            result.aggregates_count.len()
+        );
+        info!(
+            "suppression ratio of aggregates is {:.2}%",
+            (1.0 - (result.selected_combs_count as f64 / result.all_combs_count as f64)) * 100.0
+        );
+
+        AggregatedData::new(
+            self.data_block.clone(),
+            result.aggregates_count,
+            result.records_sensitivity,
+            normalized_reporting_length,
+        )
+    }
+
+    #[inline]
+    fn build_rows_aggregators<'length_range>(
+        &self,
+        length_range: &'length_range [usize],
+        sensitivity_threshold: usize,
+    ) -> Vec<RowsAggregator<'length_range>> {
+        if self.data_block.records.is_empty() {
+            return Vec::default();
+        }
+
+        let chunk_size = ((self.data_block.records.len() as f64) / (get_number_of_threads() as f64))
+            .ceil() as usize;
+        let mut rows_aggregators: Vec<RowsAggregator> = Vec::default();
+        let attr_rows_map = Arc::new(self.data_block.calc_attr_rows());
+
+        for c in &self
+            .data_block
+            .records
+            .clone()
+            .drain(..)
+            .enumerate()
+            .chunks(chunk_size)
+        {
+            rows_aggregators.push(RowsAggregator::new(
+                self.data_block.clone(),
+                c.collect(),
+                RecordAttrsSelector::new(
+                    length_range,
+                    sensitivity_threshold,
+                    attr_rows_map.clone(),
+                ),
+            ))
+        }
+        rows_aggregators
+    }
+
+    #[inline]
+    fn update_aggregate_progress<T>(
+        progress_reporter: &mut Option<T>,
+        n_processed: usize,
+        total: f64,
+    ) where
+        T: ReportProgress,
+    {
+        if let Some(r) = progress_reporter {
+            r.report(calc_percentage(n_processed as f64, total));
+        }
+    }
+}
diff --git a/packages/core/src/processing/aggregator/mod.rs b/packages/core/src/processing/aggregator/mod.rs
index a5fcf75..d509ba5 100644
--- a/packages/core/src/processing/aggregator/mod.rs
+++ b/packages/core/src/processing/aggregator/mod.rs
@@ -1,501 +1,27 @@
+/// Module to represent aggregated data and provide
+/// some methods/utilities for information extracted from it
+pub mod aggregated_data;
+
+/// Dataset privacy risk definitions
+pub mod privacy_risk_summary;
+
 /// Module that can the used to suppress attributes from records to
 /// meet a certain sensitivity threshold
 pub mod record_attrs_selector;
 
+/// Defines structures related to records analysis (unique, rare and risky
+/// information)
+pub mod records_analysis_data;
+
 /// Type definitions related to the aggregation process
 pub mod typedefs;
 
-use self::typedefs::{
-    AggregatedCountByLenMap, AggregatesCountMap, RecordsSensitivity, RecordsSensitivitySlice,
-    RecordsSet, ValueCombinationSlice,
-};
-use instant::Duration;
-use itertools::Itertools;
-use log::Level::Trace;
-use log::{info, log_enabled, trace};
-use std::cmp::min;
-use std::io::{Error, Write};
+/// Defines structures to store value combinations generated
+/// during the aggregate process
+pub mod value_combination;
 
-use crate::data_block::block::DataBlock;
-use crate::data_block::typedefs::DataBlockHeadersSlice;
-use crate::data_block::value::DataBlockValue;
-use crate::measure_time;
-use crate::processing::aggregator::record_attrs_selector::RecordAttrsSelector;
-use crate::utils::math::{calc_n_combinations_range, calc_percentage, uround_down};
-use crate::utils::reporting::ReportProgress;
-use crate::utils::time::ElapsedDuration;
+mod data_aggregator;
 
-/// Represents the privacy risk information related to a data block
-#[derive(Debug)]
-pub struct PrivacyRiskSummary {
-    /// Total number of records on the data block
-    pub total_number_of_records: usize,
-    /// Total number of combinations aggregated (up to reporting length)
-    pub total_number_of_combinations: usize,
-    /// Number of records with unique combinations
-    pub records_with_unique_combinations_count: usize,
-    /// Number of records with rare combinations (combination count < resolution)
-    pub records_with_rare_combinations_count: usize,
-    /// Number of unique combinations
-    pub unique_combinations_count: usize,
-    /// Number of rare combinations
-    pub rare_combinations_count: usize,
-    /// Proportion of records containing unique combinations
-    pub records_with_unique_combinations_proportion: f64,
-    /// Proportion of records containing rare combinations
-    pub records_with_rare_combinations_proportion: f64,
-    /// Proportion of unique combinations
-    pub unique_combinations_proportion: f64,
-    /// Proportion of rare combinations
-    pub rare_combinations_proportion: f64,
-}
+pub use data_aggregator::{AggregatedCount, Aggregator};
 
-/// Result of data aggregation for each combination
-#[derive(Debug)]
-pub struct AggregatedCount {
-    /// How many times the combination appears on the records
-    pub count: usize,
- /// Which records this combinations is part of - pub contained_in_records: RecordsSet, -} - -impl Default for AggregatedCount { - fn default() -> AggregatedCount { - AggregatedCount { - count: 0, - contained_in_records: RecordsSet::default(), - } - } -} - -#[derive(Debug)] -struct AggregatorDurations { - aggregate: Duration, - calc_rare_combinations_count_by_len: Duration, - calc_combinations_count_by_len: Duration, - calc_combinations_sum_by_len: Duration, - calc_privacy_risk: Duration, - write_aggregates_count: Duration, -} - -/// Aggregated data produced by the Aggregator -pub struct AggregatedData<'data_block> { - /// Maps a value combination to its aggregated count - pub aggregates_count: AggregatesCountMap<'data_block>, - /// A vector of sensitivities for each record (the vector index is the record index) - pub records_sensitivity: RecordsSensitivity, -} - -/// Process a data block to produced aggregated data -pub struct Aggregator<'data_block> { - data_block: &'data_block DataBlock, - durations: AggregatorDurations, -} - -impl<'data_block> Aggregator<'data_block> { - /// Returns a data aggregator for the given data block - /// # Arguments - /// * `data_block` - The data block to be processed - #[inline] - pub fn new(data_block: &'data_block DataBlock) -> Aggregator<'data_block> { - Aggregator { - data_block, - durations: AggregatorDurations { - aggregate: Duration::default(), - calc_rare_combinations_count_by_len: Duration::default(), - calc_combinations_count_by_len: Duration::default(), - calc_combinations_sum_by_len: Duration::default(), - calc_privacy_risk: Duration::default(), - write_aggregates_count: Duration::default(), - }, - } - } - - /// Formats a data block value as String. - /// The result is formatted as: `{header_name}:{block_value}` - /// # Arguments - /// * `headers`: data block headers - /// * `value`: value to be formatted - #[inline] - pub fn format_data_block_value_str( - headers: &DataBlockHeadersSlice, - value: &'data_block DataBlockValue, - ) -> String { - format!("{}:{}", headers[value.column_index], value.value) - } - - /// Formats a value combination as String. - /// The result is formatted as: - /// `{header_name}:{block_value};{header_name}:{block_value}...` - /// # Arguments - /// * `headers`: data block headers - /// * `aggregate`: combinations to be formatted - #[inline] - pub fn format_aggregate_str( - headers: &DataBlockHeadersSlice, - aggregate: &ValueCombinationSlice<'data_block>, - ) -> String { - let mut str = String::default(); - for comb in aggregate { - if !str.is_empty() { - str += ";"; - } - str += Aggregator::format_data_block_value_str(headers, comb).as_str(); - } - str - } - - /// Aggregates the data block and returns the aggregated data back - /// # Arguments - /// * `reporting_length` - Calculate combinations from 1 up to `reporting_length` - /// * `sensitivity_threshold` - Sensitivity threshold to filter record attributes - /// (0 means no suppression) - /// * `progress_reporter` - Will be used to report the processing - /// progress (`ReportProgress` trait). 
If `None`, nothing will be reported - pub fn aggregate( - &mut self, - reporting_length: usize, - sensitivity_threshold: usize, - progress_reporter: &mut Option, - ) -> AggregatedData<'data_block> - where - T: ReportProgress, - { - measure_time!( - || { - let mut aggregates_count: AggregatesCountMap = AggregatesCountMap::default(); - let max_len = if reporting_length == 0 { - self.data_block.headers.len() - } else { - min(reporting_length, self.data_block.headers.len()) - }; - let length_range = (1..=max_len).collect::>(); - let mut record_attrs_selector = RecordAttrsSelector::new( - &length_range, - sensitivity_threshold, - DataBlock::calc_attr_rows(&self.data_block.records), - ); - let mut selected_combs_count = 0_u64; - let mut all_combs_count = 0_u64; - let mut records_sensitivity: RecordsSensitivity = RecordsSensitivity::default(); - let total_n_records = self.data_block.records.len(); - let total_n_records_f64 = total_n_records as f64; - - records_sensitivity.resize(total_n_records, 0); - - info!( - "aggregating data with reporting length = {} and sensitivity_threshold = {}", - max_len, sensitivity_threshold - ); - - for (record_index, record) in self.data_block.records.iter().enumerate() { - let selected_attrs = record_attrs_selector.select_from_record(&record.values); - - all_combs_count += - calc_n_combinations_range(record.values.len(), &length_range); - - Aggregator::update_aggregate_progress( - progress_reporter, - record_index, - total_n_records_f64, - ); - - for l in &length_range { - for c in selected_attrs.iter().combinations(*l) { - let curr_count = aggregates_count - .entry( - c.iter() - .sorted_by_key(|k| { - Aggregator::format_data_block_value_str( - &self.data_block.headers, - k, - ) - }) - .cloned() - .copied() - .collect_vec(), - ) - .or_insert_with(AggregatedCount::default); - curr_count.count += 1; - curr_count.contained_in_records.insert(record_index); - records_sensitivity[record_index] += 1; - selected_combs_count += 1; - } - } - } - Aggregator::update_aggregate_progress( - progress_reporter, - total_n_records, - total_n_records_f64, - ); - - info!( - "data aggregated resulting in {} distinct combinations...", - aggregates_count.len() - ); - info!( - "suppression ratio of aggregates is {:.2}%", - (1.0 - (selected_combs_count as f64 / all_combs_count as f64)) * 100.0 - ); - AggregatedData { - aggregates_count, - records_sensitivity, - } - }, - (self.durations.aggregate) - ) - } - - /// Round the aggregated counts down to the nearest multiple of resolution - /// # Arguments: - /// * `aggregates_count` - Counts to be rounded in place - /// * `resolution` - Reporting resolution used for data synthesis - pub fn protect_aggregates_count( - aggregates_count: &mut AggregatesCountMap<'data_block>, - resolution: usize, - ) { - info!( - "protecting aggregates counts with resolution {}", - resolution - ); - for count in aggregates_count.values_mut() { - count.count = uround_down(count.count as f64, resolution as f64); - } - // remove 0 counts from response - aggregates_count.retain(|_, count| count.count > 0); - } - - /// Calculates the number of rare combinations grouped by combination length - /// # Arguments: - /// * `aggregates_count` - Calculated aggregates count map - /// * `resolution` - Reporting resolution used for data synthesis - pub fn calc_rare_combinations_count_by_len( - &mut self, - aggregates_count: &AggregatesCountMap<'data_block>, - resolution: usize, - ) -> AggregatedCountByLenMap { - info!( - "calculating rare combinations counts by length with 
resolution {}", - resolution - ); - measure_time!( - || { - let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default(); - - for (agg, count) in aggregates_count.iter() { - if count.count < resolution { - let curr_count = result.entry(agg.len()).or_insert(0); - *curr_count += 1; - } - } - result - }, - (self.durations.calc_rare_combinations_count_by_len) - ) - } - - /// Calculates the number of combinations grouped by combination length - /// # Arguments: - /// * `aggregates_count` - Calculated aggregates count map - pub fn calc_combinations_count_by_len( - &mut self, - aggregates_count: &AggregatesCountMap<'data_block>, - ) -> AggregatedCountByLenMap { - info!("calculating combination counts by length"); - measure_time!( - || { - let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default(); - - for agg in aggregates_count.keys() { - let curr_count = result.entry(agg.len()).or_insert(0); - *curr_count += 1; - } - result - }, - (self.durations.calc_combinations_count_by_len) - ) - } - - /// Calculates the sum of all combination counts grouped by combination length - /// # Arguments: - /// * `aggregates_count` - Calculated aggregates count map - pub fn calc_combinations_sum_by_len( - &mut self, - aggregates_count: &AggregatesCountMap<'data_block>, - ) -> AggregatedCountByLenMap { - info!("calculating combination counts sums by length"); - measure_time!( - || { - let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default(); - - for (agg, count) in aggregates_count.iter() { - let curr_sum = result.entry(agg.len()).or_insert(0); - *curr_sum += count.count; - } - result - }, - (self.durations.calc_combinations_sum_by_len) - ) - } - - /// Calculates the privacy risk related with data block and the generated - /// aggregates counts - /// # Arguments: - /// * `aggregates_count` - Calculated aggregates count map - /// * `resolution` - Reporting resolution used for data synthesis - pub fn calc_privacy_risk( - &mut self, - aggregates_count: &AggregatesCountMap<'data_block>, - resolution: usize, - ) -> PrivacyRiskSummary { - info!("calculating privacy risk..."); - measure_time!( - || { - let mut records_with_unique_combinations = RecordsSet::default(); - let mut records_with_rare_combinations = RecordsSet::default(); - let mut unique_combinations_count: usize = 0; - let mut rare_combinations_count: usize = 0; - let total_number_of_records = self.data_block.records.len(); - let total_number_of_combinations = aggregates_count.len(); - - for count in aggregates_count.values() { - if count.count == 1 { - unique_combinations_count += 1; - count.contained_in_records.iter().for_each(|record_index| { - records_with_unique_combinations.insert(*record_index); - }); - } - if count.count < resolution { - rare_combinations_count += 1; - count.contained_in_records.iter().for_each(|record_index| { - records_with_rare_combinations.insert(*record_index); - }); - } - } - - PrivacyRiskSummary { - total_number_of_records, - total_number_of_combinations, - records_with_unique_combinations_count: records_with_unique_combinations.len(), - records_with_rare_combinations_count: records_with_rare_combinations.len(), - unique_combinations_count, - rare_combinations_count, - records_with_unique_combinations_proportion: (records_with_unique_combinations - .len() - as f64) - / (total_number_of_records as f64), - records_with_rare_combinations_proportion: (records_with_rare_combinations.len() - as f64) - / (total_number_of_records as f64), - unique_combinations_proportion: 
(unique_combinations_count as f64) - / (total_number_of_combinations as f64), - rare_combinations_proportion: (rare_combinations_count as f64) - / (total_number_of_combinations as f64), - } - }, - (self.durations.calc_privacy_risk) - ) - } - - /// Writes the aggregates counts to the file system in a csv/tsv like format - /// # Arguments: - /// * `aggregates_count` - Calculated aggregates count map - /// * `aggregates_path` - File path to be written - /// * `aggregates_delimiter` - Delimiter to use when writing to `aggregates_path` - /// * `resolution` - Reporting resolution used for data synthesis - /// * `protected` - Whether or not the counts were protected before calling this - pub fn write_aggregates_count( - &mut self, - aggregates_count: &AggregatesCountMap<'data_block>, - aggregates_path: &str, - aggregates_delimiter: char, - resolution: usize, - protected: bool, - ) -> Result<(), Error> { - info!("writing file {}", aggregates_path); - - measure_time!( - || { - let mut file = std::fs::File::create(aggregates_path)?; - - file.write_all( - format!( - "selections{}{}\n", - aggregates_delimiter, - if protected { - "protected_count" - } else { - "count" - } - ) - .as_bytes(), - )?; - file.write_all( - format!( - "selections{}{}\n", - aggregates_delimiter, - uround_down(self.data_block.records.len() as f64, resolution as f64) - ) - .as_bytes(), - )?; - for aggregate in aggregates_count.keys() { - file.write_all( - format!( - "{}{}{}\n", - Aggregator::format_aggregate_str(&self.data_block.headers, aggregate), - aggregates_delimiter, - aggregates_count[aggregate].count - ) - .as_bytes(), - )? - } - Ok(()) - }, - (self.durations.write_aggregates_count) - ) - } - - /// Writes the records sensitivity to the file system in a csv/tsv like format - /// # Arguments: - /// * `records_sensitivity` - Calculated records sensitivity - /// * `records_sensitivity_path` - File path to be written - pub fn write_records_sensitivity( - &mut self, - records_sensitivity: &RecordsSensitivitySlice, - records_sensitivity_path: &str, - ) -> Result<(), Error> { - info!("writing file {}", records_sensitivity_path); - - measure_time!( - || { - let mut file = std::fs::File::create(records_sensitivity_path)?; - - file.write_all("record_index\trecord_sensitivity\n".as_bytes())?; - for (i, sensitivity) in records_sensitivity.iter().enumerate() { - file.write_all(format!("{}\t{}\n", i, sensitivity).as_bytes())? 
- } - Ok(()) - }, - (self.durations.write_aggregates_count) - ) - } - - #[inline] - fn update_aggregate_progress( - progress_reporter: &mut Option, - n_processed: usize, - total: f64, - ) where - T: ReportProgress, - { - if let Some(r) = progress_reporter { - r.report(calc_percentage(n_processed, total)); - } - } -} - -impl<'data_block> Drop for Aggregator<'data_block> { - fn drop(&mut self) { - trace!("aggregator durations: {:#?}", self.durations); - } -} +mod rows_aggregator; diff --git a/packages/core/src/processing/aggregator/privacy_risk_summary.rs b/packages/core/src/processing/aggregator/privacy_risk_summary.rs new file mode 100644 index 0000000..b19f01e --- /dev/null +++ b/packages/core/src/processing/aggregator/privacy_risk_summary.rs @@ -0,0 +1,144 @@ +use super::typedefs::{AggregatesCountMap, RecordsSet}; + +#[cfg(feature = "pyo3")] +use pyo3::prelude::*; + +/// Represents the privacy risk information related to a data block +#[cfg_attr(feature = "pyo3", pyclass)] +#[derive(Debug)] +pub struct PrivacyRiskSummary { + /// Total number of records on the data block + pub total_number_of_records: usize, + /// Total number of combinations aggregated (up to reporting length) + pub total_number_of_combinations: usize, + /// Number of records with unique combinations + pub records_with_unique_combinations_count: usize, + /// Number of records with rare combinations (combination count < resolution) + pub records_with_rare_combinations_count: usize, + /// Number of unique combinations + pub unique_combinations_count: usize, + /// Number of rare combinations + pub rare_combinations_count: usize, + /// Proportion of records containing unique combinations + pub records_with_unique_combinations_proportion: f64, + /// Proportion of records containing rare combinations + pub records_with_rare_combinations_proportion: f64, + /// Proportion of unique combinations + pub unique_combinations_proportion: f64, + /// Proportion of rare combinations + pub rare_combinations_proportion: f64, +} + +#[cfg(feature = "pyo3")] +#[cfg_attr(feature = "pyo3", pymethods)] +impl PrivacyRiskSummary { + /// Total number of records on the data block + #[getter] + fn total_number_of_records(&self) -> usize { + self.total_number_of_records + } + + /// Total number of combinations aggregated (up to reporting length) + #[getter] + fn total_number_of_combinations(&self) -> usize { + self.total_number_of_combinations + } + + /// Number of records with unique combinations + #[getter] + fn records_with_unique_combinations_count(&self) -> usize { + self.records_with_unique_combinations_count + } + + /// Number of records with rare combinations (combination count < resolution) + #[getter] + fn records_with_rare_combinations_count(&self) -> usize { + self.records_with_rare_combinations_count + } + + /// Number of unique combinations + #[getter] + fn unique_combinations_count(&self) -> usize { + self.unique_combinations_count + } + + /// Number of rare combinations + #[getter] + fn rare_combinations_count(&self) -> usize { + self.rare_combinations_count + } + + /// Proportion of records containing unique combinations + #[getter] + fn records_with_unique_combinations_proportion(&self) -> f64 { + self.records_with_unique_combinations_proportion + } + + /// Proportion of records containing rare combinations + #[getter] + fn records_with_rare_combinations_proportion(&self) -> f64 { + self.records_with_rare_combinations_proportion + } + + /// Proportion of unique combinations + #[getter] + fn unique_combinations_proportion(&self) -> 
f64 {
+        self.unique_combinations_proportion
+    }
+
+    /// Proportion of rare combinations
+    #[getter]
+    fn rare_combinations_proportion(&self) -> f64 {
+        self.rare_combinations_proportion
+    }
+}
+
+impl PrivacyRiskSummary {
+    #[inline]
+    /// Calculates and returns the privacy risk related to the aggregates counts
+    /// # Arguments:
+    /// * `total_number_of_records` - Total number of records on the data block
+    /// * `aggregates_count` - Aggregates counts to compute the privacy risk for
+    /// * `resolution` - Reporting resolution used for data synthesis
+    pub fn from_aggregates_count(
+        total_number_of_records: usize,
+        aggregates_count: &AggregatesCountMap,
+        resolution: usize,
+    ) -> PrivacyRiskSummary {
+        let mut records_with_unique_combinations = RecordsSet::default();
+        let mut records_with_rare_combinations = RecordsSet::default();
+        let mut unique_combinations_count: usize = 0;
+        let mut rare_combinations_count: usize = 0;
+        let total_number_of_combinations = aggregates_count.len();
+
+        for count in aggregates_count.values() {
+            if count.count == 1 {
+                unique_combinations_count += 1;
+                records_with_unique_combinations.extend(&count.contained_in_records);
+            }
+            if count.count < resolution {
+                rare_combinations_count += 1;
+                records_with_rare_combinations.extend(&count.contained_in_records);
+            }
+        }
+
+        PrivacyRiskSummary {
+            total_number_of_records,
+            total_number_of_combinations,
+            records_with_unique_combinations_count: records_with_unique_combinations.len(),
+            records_with_rare_combinations_count: records_with_rare_combinations.len(),
+            unique_combinations_count,
+            rare_combinations_count,
+            records_with_unique_combinations_proportion: (records_with_unique_combinations.len()
+                as f64)
+                / (total_number_of_records as f64),
+            records_with_rare_combinations_proportion: (records_with_rare_combinations.len()
+                as f64)
+                / (total_number_of_records as f64),
+            unique_combinations_proportion: (unique_combinations_count as f64)
+                / (total_number_of_combinations as f64),
+            rare_combinations_proportion: (rare_combinations_count as f64)
+                / (total_number_of_combinations as f64),
+        }
+    }
+}
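
For reference, the unique/rare classification that `from_aggregates_count` performs reduces to a single pass over the counts. This is a self-contained sketch using plain standard-library maps as simplified stand-ins for the crate's `AggregatesCountMap` and `RecordsSet`, not the crate's actual API:

```rust
use std::collections::{HashMap, HashSet};

fn main() {
    // combination label -> (count, records containing it); a simplified
    // stand-in for AggregatesCountMap, evaluated with resolution = 10
    let aggregates: HashMap<&str, (usize, Vec<usize>)> = HashMap::from([
        ("age:30;job:nurse", (1, vec![0])),     // unique (count == 1)
        ("age:30", (4, vec![0, 1, 2, 3])),      // rare (count < resolution)
        ("job:nurse", (25, (0..25).collect())), // neither
    ]);
    let resolution = 10;

    let mut unique_combinations = 0;
    let mut rare_combinations = 0;
    let mut records_with_rare: HashSet<usize> = HashSet::new();

    for (count, records) in aggregates.values() {
        if *count == 1 {
            unique_combinations += 1;
        }
        // as in the original, unique combinations also count as rare
        if *count < resolution {
            rare_combinations += 1;
            records_with_rare.extend(records.iter().copied());
        }
    }

    println!(
        "unique: {}, rare: {}, records at risk: {}",
        unique_combinations,
        rare_combinations,
        records_with_rare.len()
    );
}
```
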
diff --git a/packages/core/src/processing/aggregator/record_attrs_selector.rs b/packages/core/src/processing/aggregator/record_attrs_selector.rs
index aa62a43..9c85c1d 100644
--- a/packages/core/src/processing/aggregator/record_attrs_selector.rs
+++ b/packages/core/src/processing/aggregator/record_attrs_selector.rs
@@ -1,5 +1,5 @@
 use fnv::FnvHashMap;
-use itertools::Itertools;
+use std::sync::Arc;
 
 use crate::{
     data_block::{typedefs::AttributeRowsMap, value::DataBlockValue},
@@ -13,19 +13,19 @@ use crate::{
 /// Sensitivity in this context means the number of combinations that
 /// can be generated based on the number of non-empty values on the record.
 /// So attributes will be suppressed from the record until record sensitivity <= `sensitivity_threshold`
-pub struct RecordAttrsSelector<'length_range, 'data_block> {
+pub struct RecordAttrsSelector<'length_range> {
+    /// Range to compute the number of combinations [1...reporting_length]
+    pub length_range: &'length_range [usize],
     /// Cache for how many values should be suppressed based on the number
     /// of non-empty attributes on the record
     cache: FnvHashMap<usize, usize>,
-    /// Range to compute the number of combinations [1...reporting_length]
-    length_range: &'length_range [usize],
     /// Maximum sensitivity allowed for each record
     sensitivity_threshold: usize,
     /// Map with a data block value as key and all the attribute row indexes where it occurs as value
-    attr_rows_map: AttributeRowsMap<'data_block>,
+    attr_rows_map: Arc<AttributeRowsMap>,
 }
 
-impl<'length_range, 'data_block> RecordAttrsSelector<'length_range, 'data_block> {
+impl<'length_range> RecordAttrsSelector<'length_range> {
     /// Returns a new RecordAttrsSelector
     /// # Arguments
     /// * `length_range` - Range to compute the number of combinations [1...reporting_length]
@@ -36,8 +36,8 @@ impl<'length_range, 'data_block> RecordAttrsSelector<'length_range, 'data_block>
     pub fn new(
         length_range: &'length_range [usize],
         sensitivity_threshold: usize,
-        attr_rows_map: AttributeRowsMap<'data_block>,
-    ) -> RecordAttrsSelector<'length_range, 'data_block> {
+        attr_rows_map: Arc<AttributeRowsMap>,
+    ) -> RecordAttrsSelector<'length_range> {
         RecordAttrsSelector {
             cache: FnvHashMap::default(),
             length_range,
@@ -53,30 +53,30 @@ impl<'length_range, 'data_block> RecordAttrsSelector<'length_range, 'data_block>
     /// all records, higher the chance for it **not** to be suppressed
     ///
     /// # Arguments
-    /// * `record`: record to select attributes from
+    /// * `record` - record to select attributes from
     pub fn select_from_record(
         &mut self,
-        record: &'data_block [DataBlockValue],
-    ) -> Vec<&'data_block DataBlockValue> {
+        record: &[Arc<DataBlockValue>],
+    ) -> Vec<Arc<DataBlockValue>> {
         let n_attributes = record.len();
         let selected_attrs_count = n_attributes - self.get_suppressed_count(n_attributes);
 
         if selected_attrs_count < n_attributes {
             let mut attr_count_map: AttributeCountMap = record
                 .iter()
-                .map(|r| (r, self.attr_rows_map.get(r).unwrap().len()))
+                .map(|r| (r.clone(), self.attr_rows_map.get(r).unwrap().len()))
                 .collect();
-            let mut res: Vec<&DataBlockValue> = Vec::default();
+            let mut res: Vec<Arc<DataBlockValue>> = Vec::default();
 
             for _ in 0..selected_attrs_count {
                 if let Some(sample) = SynthesizerContext::sample_from_attr_counts(&attr_count_map) {
+                    attr_count_map.remove(&sample);
                     res.push(sample);
-                    attr_count_map.remove(sample);
                 }
             }
             res
         } else {
-            record.iter().collect_vec()
+            record.to_vec()
         }
     }
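
Dropping the `'data_block` lifetime in favor of `Arc` is what lets each worker own a selector while sharing one read-only `attr_rows_map`. A minimal sketch of that sharing pattern, using `std::thread` and a plain `HashMap` as stand-ins for the rayon workers and `AttributeRowsMap`:

```rust
use std::collections::HashMap;
use std::sync::Arc;
use std::thread;

fn main() {
    // read-only map shared by all workers, as attr_rows_map is shared
    // by the per-chunk RecordAttrsSelector instances
    let attr_rows: Arc<HashMap<String, Vec<usize>>> = Arc::new(HashMap::from([
        ("age:30".to_string(), vec![0, 2]),
        ("job:nurse".to_string(), vec![1]),
    ]));

    let handles: Vec<_> = (0..2)
        .map(|_| {
            let attr_rows = attr_rows.clone(); // cheap pointer bump, no deep copy
            thread::spawn(move || attr_rows.values().map(|v| v.len()).sum::<usize>())
        })
        .collect();

    for h in handles {
        println!("rows seen by worker: {}", h.join().unwrap());
    }
}
```
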
diff --git a/packages/core/src/processing/aggregator/records_analysis_data.rs b/packages/core/src/processing/aggregator/records_analysis_data.rs
new file mode 100644
index 0000000..0c831b7
--- /dev/null
+++ b/packages/core/src/processing/aggregator/records_analysis_data.rs
@@ -0,0 +1,247 @@
+use super::typedefs::{RecordsAnalysisByLenMap, RecordsByLenMap};
+use itertools::Itertools;
+use std::io::{Error, Write};
+
+#[cfg(feature = "pyo3")]
+use pyo3::prelude::*;
+
+use crate::utils::math::{calc_percentage, uround_down};
+
+#[cfg_attr(feature = "pyo3", pyclass)]
+#[derive(Clone)]
+/// Analysis information related to a single record
+pub struct RecordsAnalysis {
+    /// Number of records containing unique combinations
+    pub unique_combinations_records_count: usize,
+    /// Percentage of records containing unique combinations
+    pub unique_combinations_records_percentage: f64,
+    /// Number of records containing rare combinations
+    /// (unique is not taken into account)
+    pub rare_combinations_records_count: usize,
+    /// Percentage of records containing rare combinations
+    /// (unique is not taken into account)
+    pub rare_combinations_records_percentage: f64,
+    /// Count of unique + rare
+    pub risky_combinations_records_count: usize,
+    /// Percentage of unique + rare
+    pub risky_combinations_records_percentage: f64,
+}
+
+impl RecordsAnalysis {
+    #[inline]
+    /// Creates a new RecordsAnalysis with default values
+    pub fn default() -> RecordsAnalysis {
+        RecordsAnalysis {
+            unique_combinations_records_count: 0,
+            unique_combinations_records_percentage: 0.0,
+            rare_combinations_records_count: 0,
+            rare_combinations_records_percentage: 0.0,
+            risky_combinations_records_count: 0,
+            risky_combinations_records_percentage: 0.0,
+        }
+    }
+}
+
+#[cfg(feature = "pyo3")]
+#[cfg_attr(feature = "pyo3", pymethods)]
+impl RecordsAnalysis {
+    #[getter]
+    /// Number of records containing unique combinations
+    fn unique_combinations_records_count(&self) -> usize {
+        self.unique_combinations_records_count
+    }
+
+    #[getter]
+    /// Percentage of records containing unique combinations
+    fn unique_combinations_records_percentage(&self) -> f64 {
+        self.unique_combinations_records_percentage
+    }
+
+    #[getter]
+    /// Number of records containing rare combinations
+    /// (unique is not taken into account)
+    fn rare_combinations_records_count(&self) -> usize {
+        self.rare_combinations_records_count
+    }
+
+    #[getter]
+    /// Percentage of records containing rare combinations
+    /// (unique is not taken into account)
+    fn rare_combinations_records_percentage(&self) -> f64 {
+        self.rare_combinations_records_percentage
+    }
+
+    #[getter]
+    /// Count of unique + rare
+    fn risky_combinations_records_count(&self) -> usize {
+        self.risky_combinations_records_count
+    }
+
+    #[getter]
+    /// Percentage of unique + rare
+    fn risky_combinations_records_percentage(&self) -> f64 {
+        self.risky_combinations_records_percentage
+    }
+}
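
The `#[cfg_attr(feature = "pyo3", pyclass)]` plus `#[getter]` combination above is the recurring pattern this change uses to mirror Rust fields as read-only Python attributes. Reduced to its essentials, with a hypothetical `Stats` type and module name, and assuming a pyo3 version of this era:

```rust
use pyo3::prelude::*;

/// Hypothetical struct exposed to Python as a read-only record
#[pyclass]
pub struct Stats {
    pub count: usize,
}

#[pymethods]
impl Stats {
    /// Readable from Python as `stats.count`
    #[getter]
    fn count(&self) -> usize {
        self.count
    }
}

/// Hypothetical module; importable from Python as `import demo`
#[pymodule]
fn demo(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Stats>()?;
    Ok(())
}
```
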
+#[cfg_attr(feature = "pyo3", pyclass)]
+/// Stores the records analysis for all records grouped by
+/// combination length
+pub struct RecordsAnalysisData {
+    /// Map of records analysis grouped by combination len
+    pub records_analysis_by_len: RecordsAnalysisByLenMap,
+}
+
+impl RecordsAnalysisData {
+    /// Computes the record analysis from the arguments.
+    /// # Arguments
+    /// * `unique_records_by_len` - unique records grouped by length
+    /// * `rare_records_by_len` - rare records grouped by length
+    /// * `total_n_records` - Total number of records on the data block
+    /// * `reporting_length` - Reporting length used for the data aggregation
+    /// * `resolution` - Reporting resolution used for data synthesis
+    /// * `protect` - Whether or not the counts should be rounded down to the
+    /// nearest multiple of resolution
+    #[inline]
+    pub fn from_unique_rare_combinations_records_by_len(
+        unique_records_by_len: &RecordsByLenMap,
+        rare_records_by_len: &RecordsByLenMap,
+        total_n_records: usize,
+        reporting_length: usize,
+        resolution: usize,
+        protect: bool,
+    ) -> RecordsAnalysisData {
+        let total_n_records_f64 = if protect {
+            uround_down(total_n_records as f64, resolution as f64)
+        } else {
+            total_n_records
+        } as f64;
+        let records_analysis_by_len: RecordsAnalysisByLenMap = (1..=reporting_length)
+            .map(|l| {
+                let mut ra = RecordsAnalysis::default();
+
+                ra.unique_combinations_records_count = unique_records_by_len
+                    .get(&l)
+                    .map_or(0, |records| records.len());
+                ra.rare_combinations_records_count = rare_records_by_len
+                    .get(&l)
+                    .map_or(0, |records| records.len());
+
+                if protect {
+                    ra.unique_combinations_records_count = uround_down(
+                        ra.unique_combinations_records_count as f64,
+                        resolution as f64,
+                    );
+                    ra.rare_combinations_records_count =
+                        uround_down(ra.rare_combinations_records_count as f64, resolution as f64)
+                }
+
+                ra.unique_combinations_records_percentage = calc_percentage(
+                    ra.unique_combinations_records_count as f64,
+                    total_n_records_f64,
+                );
+                ra.rare_combinations_records_percentage = calc_percentage(
+                    ra.rare_combinations_records_count as f64,
+                    total_n_records_f64,
+                );
+                ra.risky_combinations_records_count =
+                    ra.unique_combinations_records_count + ra.rare_combinations_records_count;
+                ra.risky_combinations_records_percentage = calc_percentage(
+                    ra.risky_combinations_records_count as f64,
+                    total_n_records_f64,
+                );
+
+                (l, ra)
+            })
+            .collect();
+        RecordsAnalysisData {
+            records_analysis_by_len,
+        }
+    }
+}
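
The `protect` branch above leans on `utils::math::uround_down`, whose implementation is not part of this diff; assuming the natural "round down to the nearest multiple of the resolution" semantics its name and usage suggest, a standalone sketch:

```rust
// assumed semantics of utils::math::uround_down: round `val` down
// to the nearest multiple of `multiple`
fn uround_down(val: f64, multiple: f64) -> usize {
    ((val / multiple).floor() * multiple) as usize
}

fn main() {
    let resolution = 10.0;
    // a raw count of 37 is reported as 30 when protection is enabled
    assert_eq!(uround_down(37.0, resolution), 30);
    // counts below the resolution vanish entirely
    assert_eq!(uround_down(7.0, resolution), 0);
    println!("protected counts: 30 and 0");
}
```
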
+#[cfg_attr(feature = "pyo3", pymethods)]
+impl RecordsAnalysisData {
+    /// Returns the records analysis map grouped by length.
+    /// This method will clone the data, so it's recommended to have its result stored
+    /// in a local variable to avoid it being called multiple times
+    pub fn get_records_analysis_by_len(&self) -> RecordsAnalysisByLenMap {
+        self.records_analysis_by_len.clone()
+    }
+
+    /// Returns the total number of records containing unique combinations
+    /// for all lengths
+    pub fn get_total_unique(&self) -> usize {
+        self.records_analysis_by_len
+            .values()
+            .map(|v| v.unique_combinations_records_count)
+            .sum()
+    }
+
+    /// Returns the total number of records containing rare combinations
+    /// for all lengths
+    pub fn get_total_rare(&self) -> usize {
+        self.records_analysis_by_len
+            .values()
+            .map(|v| v.rare_combinations_records_count)
+            .sum()
+    }
+
+    /// Returns the total number of records containing risky (unique + rare)
+    /// combinations for all lengths
+    pub fn get_total_risky(&self) -> usize {
+        self.records_analysis_by_len
+            .values()
+            .map(|v| v.risky_combinations_records_count)
+            .sum()
+    }
+
+    /// Writes the records analysis to the file system in a csv/tsv like format
+    /// # Arguments:
+    /// * `records_analysis_path` - File path to be written
+    /// * `records_analysis_delimiter` - Delimiter to use when writing to `records_analysis_path`
+    pub fn write_records_analysis(
+        &mut self,
+        records_analysis_path: &str,
+        records_analysis_delimiter: char,
+    ) -> Result<(), Error> {
+        let mut file = std::fs::File::create(records_analysis_path)?;
+
+        file.write_all(
+            format!(
+                "combo_length{}sen_rare{}sen_rare_pct{}sen_unique{}sen_unique_pct{}sen_risky{}sen_risky_pct\n",
+                records_analysis_delimiter,
+                records_analysis_delimiter,
+                records_analysis_delimiter,
+                records_analysis_delimiter,
+                records_analysis_delimiter,
+                records_analysis_delimiter
+            )
+            .as_bytes(),
+        )?;
+        for l in self.records_analysis_by_len.keys().sorted() {
+            let ra = &self.records_analysis_by_len[l];
+
+            file.write_all(
+                format!(
+                    "{}{}{}{}{}{}{}{}{}{}{}{}{}\n",
+                    l,
+                    records_analysis_delimiter,
+                    ra.rare_combinations_records_count,
+                    records_analysis_delimiter,
+                    ra.rare_combinations_records_percentage,
+                    records_analysis_delimiter,
+                    ra.unique_combinations_records_count,
+                    records_analysis_delimiter,
+                    ra.unique_combinations_records_percentage,
+                    records_analysis_delimiter,
+                    ra.risky_combinations_records_count,
+                    records_analysis_delimiter,
+                    ra.risky_combinations_records_percentage,
+                )
+                .as_bytes(),
+            )?
+        }
+        Ok(())
+    }
+}
diff --git a/packages/core/src/processing/aggregator/rows_aggregator.rs b/packages/core/src/processing/aggregator/rows_aggregator.rs
new file mode 100644
index 0000000..9913d78
--- /dev/null
+++ b/packages/core/src/processing/aggregator/rows_aggregator.rs
@@ -0,0 +1,195 @@
+use super::{
+    record_attrs_selector::RecordAttrsSelector,
+    typedefs::{AggregatesCountMap, EnumeratedDataBlockRecords, RecordsSensitivity},
+    value_combination::ValueCombination,
+    AggregatedCount,
+};
+use itertools::Itertools;
+use log::info;
+use std::sync::Arc;
+
+#[cfg(feature = "rayon")]
+use rayon::prelude::*;
+
+#[cfg(feature = "rayon")]
+use std::sync::Mutex;
+
+use crate::{
+    data_block::block::DataBlock,
+    utils::{
+        math::calc_n_combinations_range,
+        reporting::{ReportProgress, SendableProgressReporter, SendableProgressReporterRef},
+    },
+};
+
+pub struct RowsAggregatorResult {
+    pub all_combs_count: u64,
+    pub selected_combs_count: u64,
+    pub aggregates_count: AggregatesCountMap,
+    pub records_sensitivity: RecordsSensitivity,
+}
+
+impl RowsAggregatorResult {
+    #[inline]
+    pub fn new(total_n_records: usize) -> RowsAggregatorResult {
+        let mut records_sensitivity = RecordsSensitivity::default();
+
+        records_sensitivity.resize(total_n_records, 0);
+        RowsAggregatorResult {
+            all_combs_count: 0,
+            selected_combs_count: 0,
+            aggregates_count: AggregatesCountMap::default(),
+            records_sensitivity,
+        }
+    }
+}
+
+pub struct RowsAggregator<'length_range> {
+    data_block: Arc<DataBlock>,
+    enumerated_records: EnumeratedDataBlockRecords,
+    record_attrs_selector: RecordAttrsSelector<'length_range>,
+}
+
+impl<'length_range> RowsAggregator<'length_range> {
+    #[inline]
+    pub fn new(
+        data_block: Arc<DataBlock>,
+        enumerated_records: EnumeratedDataBlockRecords,
+        record_attrs_selector: RecordAttrsSelector<'length_range>,
+    ) -> RowsAggregator {
+        RowsAggregator {
+            data_block,
+            enumerated_records,
+            record_attrs_selector,
+        }
+    }
+
+    #[cfg(feature = "rayon")]
+    #[inline]
+    pub fn aggregate_all<T>(
+        total_n_records: usize,
+        rows_aggregators: &mut Vec<RowsAggregator>,
+        progress_reporter: &mut Option<T>,
+    ) -> RowsAggregatorResult
+    where
+        T: ReportProgress,
+    {
+        let sendable_pr =
+            Arc::new(Mutex::new(progress_reporter.as_mut().map(|r| {
+                SendableProgressReporter::new(total_n_records as f64, 1.0, r)
+            })));
+
+        RowsAggregator::join_partial_results(
+            total_n_records,
+            rows_aggregators
+                .par_iter_mut()
+                .map(|ra| ra.aggregate_rows(&mut sendable_pr.clone()))
+                .collect(),
+        )
+    }
+
+    #[cfg(not(feature = "rayon"))]
+    #[inline]
+    pub fn aggregate_all<T>(
+        total_n_records: usize,
+        rows_aggregators: &mut Vec<RowsAggregator>,
+        progress_reporter: &mut Option<T>,
+    ) -> RowsAggregatorResult
+    where
+        T: ReportProgress,
+    {
+        let mut sendable_pr = progress_reporter
+            .as_mut()
+            .map(|r| SendableProgressReporter::new(total_n_records as f64, 1.0, r));
+
+        RowsAggregator::join_partial_results(
+            total_n_records,
+            rows_aggregators
+                .iter_mut()
+                .map(|ra| ra.aggregate_rows(&mut sendable_pr))
+                .collect(),
+        )
+    }
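
The two `aggregate_all` variants above are selected at compile time by the `rayon` cargo feature, keeping a single call site for both the parallel and sequential paths. The same pattern in miniature (assuming `rayon` is declared as an optional dependency of the crate):

```rust
#[cfg(feature = "rayon")]
use rayon::prelude::*;

// parallel arm: compiled only when built with `--features rayon`
#[cfg(feature = "rayon")]
fn sum_of_squares(values: &[u64]) -> u64 {
    values.par_iter().map(|v| v * v).sum()
}

// sequential fallback with the identical signature, so callers never change
#[cfg(not(feature = "rayon"))]
fn sum_of_squares(values: &[u64]) -> u64 {
    values.iter().map(|v| v * v).sum()
}

fn main() {
    println!("{}", sum_of_squares(&[1, 2, 3, 4])); // 30
}
```
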
+    #[inline]
+    fn join_partial_results(
+        total_n_records: usize,
+        mut partial_results: Vec<RowsAggregatorResult>,
+    ) -> RowsAggregatorResult {
+        info!("joining aggregated partial results...");
+
+        // take the last element as the initial result,
+        // or use the default one if there are no results
+        let mut final_result = partial_results
+            .pop()
+            .unwrap_or_else(|| RowsAggregatorResult::new(total_n_records));
+
+        // use drain instead of fold, so we do not duplicate memory
+        for mut partial_result in partial_results.drain(..) {
+            // join counts
+            final_result.all_combs_count += partial_result.all_combs_count;
+            final_result.selected_combs_count += partial_result.selected_combs_count;
+
+            // join aggregated counts
+            for (comb, value) in partial_result.aggregates_count.drain() {
+                let final_count = final_result
+                    .aggregates_count
+                    .entry(comb)
+                    .or_insert_with(AggregatedCount::default);
+                final_count.count += value.count;
+                final_count
+                    .contained_in_records
+                    .extend(value.contained_in_records);
+            }
+
+            // join records sensitivity
+            for (i, sensitivity) in partial_result.records_sensitivity.drain(..).enumerate() {
+                final_result.records_sensitivity[i] += sensitivity;
+            }
+        }
+        final_result
+    }
+
+    #[inline]
+    fn aggregate_rows<T>(
+        &mut self,
+        progress_reporter: &mut SendableProgressReporterRef<T>,
+    ) -> RowsAggregatorResult
+    where
+        T: ReportProgress,
+    {
+        let mut result = RowsAggregatorResult::new(self.data_block.records.len());
+
+        for (record_index, record) in self.enumerated_records.iter() {
+            let mut selected_attrs = self
+                .record_attrs_selector
+                .select_from_record(&record.values);
+
+            // sort the attributes here, so combinations will be already sorted
+            // and we do not need to sort entry by entry on the loop below
+            selected_attrs.sort_by_key(|k| k.format_str_using_headers(&self.data_block.headers));
+
+            result.all_combs_count += calc_n_combinations_range(
+                record.values.len(),
+                self.record_attrs_selector.length_range,
+            );
+
+            for l in self.record_attrs_selector.length_range {
+                for mut c in selected_attrs.iter().combinations(*l) {
+                    let current_count = result
+                        .aggregates_count
+                        .entry(ValueCombination::new(
+                            c.drain(..).map(|k| (*k).clone()).collect(),
+                        ))
+                        .or_insert_with(AggregatedCount::default);
+                    current_count.count += 1;
+                    current_count.contained_in_records.insert(*record_index);
+                    result.records_sensitivity[*record_index] += 1;
+                    result.selected_combs_count += 1;
+                }
+            }
+            SendableProgressReporter::update_progress(progress_reporter, 1.0);
+        }
+        result
+    }
+}
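
`join_partial_results` above is effectively the reduce step of a map-reduce: each worker's map is drained into a single result through the `entry` API so entries are moved rather than cloned. A self-contained sketch of that merge, with `String` keys standing in for `ValueCombination`:

```rust
use std::collections::HashMap;

fn merge(mut partials: Vec<HashMap<String, usize>>) -> HashMap<String, usize> {
    // reuse the last partial map as the seed instead of allocating a new one
    let mut merged = partials.pop().unwrap_or_default();

    // drain the rest so entries are moved, not cloned
    for mut partial in partials.drain(..) {
        for (comb, count) in partial.drain() {
            *merged.entry(comb).or_insert(0) += count;
        }
    }
    merged
}

fn main() {
    let a = HashMap::from([("age:30".to_string(), 2)]);
    let b = HashMap::from([("age:30".to_string(), 1), ("job:nurse".to_string(), 5)]);
    println!("{:?}", merge(vec![a, b])); // {"age:30": 3, "job:nurse": 5}
}
```
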
diff --git a/packages/core/src/processing/aggregator/typedefs.rs b/packages/core/src/processing/aggregator/typedefs.rs
index b651e5e..a2f9d1d 100644
--- a/packages/core/src/processing/aggregator/typedefs.rs
+++ b/packages/core/src/processing/aggregator/typedefs.rs
@@ -1,26 +1,36 @@
-use super::AggregatedCount;
+use super::{
+    data_aggregator::AggregatedCount, records_analysis_data::RecordsAnalysis,
+    value_combination::ValueCombination,
+};
 use fnv::{FnvHashMap, FnvHashSet};
+use std::sync::Arc;
 
-use crate::data_block::value::DataBlockValue;
+use crate::data_block::record::DataBlockRecord;
 
 /// Set of records where the key is the record index starting in 0
 pub type RecordsSet = FnvHashSet<usize>;
 
-/// Vector of data block values representing a value combination (sorted by `{header_name}:{block_value}`)
-pub type ValueCombination<'data_block_value> = Vec<&'data_block_value DataBlockValue>;
-
-/// Slice of ValueCombination
-pub type ValueCombinationSlice<'data_block_value> = [&'data_block_value DataBlockValue];
-
 /// Maps a value combination to its aggregated count
-pub type AggregatesCountMap<'data_block_value> =
-    FnvHashMap<ValueCombination<'data_block_value>, AggregatedCount>;
+pub type AggregatesCountMap = FnvHashMap<ValueCombination, AggregatedCount>;
+
+/// Maps a value combination represented as a string to its aggregated count
+pub type AggregatesCountStringMap = FnvHashMap<String, AggregatedCount>;
 
 /// Maps a length (1,2,3... up to reporting length) to a determined count
 pub type AggregatedCountByLenMap = FnvHashMap<usize, usize>;
 
+/// Maps a length (1,2,3... up to reporting length) to a record set
+pub type RecordsByLenMap = FnvHashMap<usize, RecordsSet>;
+
 /// A vector of sensitivities for each record (the vector index is the record index)
 pub type RecordsSensitivity = Vec<usize>;
 
 /// Slice of RecordsSensitivity
 pub type RecordsSensitivitySlice = [usize];
+
+/// Vector of tuples:
+/// (index of the original record, reference to the original record)
+pub type EnumeratedDataBlockRecords = Vec<(usize, Arc<DataBlockRecord>)>;
+
+/// Map of records analysis grouped by combination len
+pub type RecordsAnalysisByLenMap = FnvHashMap<usize, RecordsAnalysis>;
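
The new `value_combination.rs` below persists a whole combination as one delimited string rather than a JSON array of values. The same newtype-as-string serde pattern in standalone form (a simplified `Combo` stand-in; requires the `serde` and `serde_json` crates):

```rust
use serde::{Deserialize, Deserializer, Serialize, Serializer};

/// Simplified stand-in for ValueCombination: a sorted list of values
struct Combo(Vec<String>);

impl Serialize for Combo {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // write "a;b;c" instead of ["a", "b", "c"]
        serializer.serialize_str(&self.0.join(";"))
    }
}

impl<'de> Deserialize<'de> for Combo {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let s = String::deserialize(deserializer)?;
        Ok(Combo(s.split(';').map(str::to_string).collect()))
    }
}

fn main() -> Result<(), serde_json::Error> {
    let json = serde_json::to_string(&Combo(vec!["age:30".into(), "job:nurse".into()]))?;
    assert_eq!(json, r#""age:30;job:nurse""#);
    let back: Combo = serde_json::from_str(&json)?;
    assert_eq!(back.0.len(), 2);
    Ok(())
}
```
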
diff --git a/packages/core/src/processing/aggregator/value_combination.rs b/packages/core/src/processing/aggregator/value_combination.rs
new file mode 100644
index 0000000..42fe3d6
--- /dev/null
+++ b/packages/core/src/processing/aggregator/value_combination.rs
@@ -0,0 +1,152 @@
+use serde::{
+    de::{self, Visitor},
+    Deserialize, Serialize,
+};
+use std::{
+    fmt::Display,
+    marker::PhantomData,
+    ops::{Deref, DerefMut},
+    str::FromStr,
+    sync::Arc,
+};
+
+use crate::data_block::{
+    typedefs::DataBlockHeadersSlice,
+    value::{DataBlockValue, ParseDataBlockValueError},
+};
+
+const COMBINATIONS_DELIMITER: char = ';';
+
+#[derive(Eq, PartialEq, Hash)]
+/// Wraps a vector of data block values representing a value
+/// combination (sorted by `{header_name}:{block_value}`)
+pub struct ValueCombination {
+    combination: Vec<Arc<DataBlockValue>>,
+}
+
+impl ValueCombination {
+    #[inline]
+    /// Creates a new ValueCombination with default values
+    pub fn default() -> ValueCombination {
+        ValueCombination::new(Vec::default())
+    }
+
+    #[inline]
+    /// Creates a new ValueCombination
+    /// # Arguments
+    /// * `combination` - raw vector of value combinations
+    /// sorted by `{header_name}:{block_value}`
+    pub fn new(combination: Vec<Arc<DataBlockValue>>) -> ValueCombination {
+        ValueCombination { combination }
+    }
+
+    /// Formats a value combination as String using the headers.
+    /// The result is formatted as:
+    /// `{header_name}:{block_value};{header_name}:{block_value}...`
+    /// # Arguments
+    /// * `headers` - Data block headers
+    /// * `combination_delimiter` - Delimiter used to join combinations
+    #[inline]
+    pub fn format_str_using_headers(
+        &self,
+        headers: &DataBlockHeadersSlice,
+        combination_delimiter: &str,
+    ) -> String {
+        let mut str = String::default();
+
+        if let Some(comb) = self.combination.get(0) {
+            str.push_str(&comb.format_str_using_headers(headers));
+        }
+        for comb in self.combination.iter().skip(1) {
+            str += combination_delimiter;
+            str.push_str(&comb.format_str_using_headers(headers));
+        }
+        str
+    }
+}
+
+impl Display for ValueCombination {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        if let Some(comb) = self.combination.get(0) {
+            write!(f, "{}", comb)?;
+        }
+        for comb in self.combination.iter().skip(1) {
+            write!(f, "{}{}", COMBINATIONS_DELIMITER, comb)?;
+        }
+        Ok(())
+    }
+}
+
+impl FromStr for ValueCombination {
+    type Err = ParseDataBlockValueError;
+
+    /// Creates a new ValueCombination by parsing `str_value`
+    fn from_str(str_value: &str) -> Result<Self, Self::Err> {
+        Ok(ValueCombination::new(
+            str_value
+                .split(COMBINATIONS_DELIMITER)
+                .map(|v| Ok(Arc::new(DataBlockValue::from_str(v)?)))
+                .collect::<Result<Vec<Arc<DataBlockValue>>, Self::Err>>()?,
+        ))
+    }
+}
+
+impl Deref for ValueCombination {
+    type Target = Vec<Arc<DataBlockValue>>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.combination
+    }
+}
+
+impl DerefMut for ValueCombination {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.combination
+    }
+}
+
+// Implementing a custom serializer, so this can
+// be properly written to a json file
+impl Serialize for ValueCombination {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        serializer.serialize_str(&format!("{}", self))
+    }
+}
+
+struct ValueCombinationVisitor {
+    marker: PhantomData<fn() -> ValueCombination>,
+}
+
+impl ValueCombinationVisitor {
+    fn new() -> Self {
+        ValueCombinationVisitor {
+            marker: PhantomData,
+        }
+    }
+}
+
+impl<'de> Visitor<'de> for ValueCombinationVisitor {
+    type Value = ValueCombination;
+
+    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+        formatter.write_str("a string representing the value combination")
+    }
+
+    fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
+        ValueCombination::from_str(v).map_err(E::custom)
+    }
+}
+
+// Implementing a custom deserializer, so this can
+// be properly read from a json file
+impl<'de> Deserialize<'de> for ValueCombination {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        deserializer.deserialize_string(ValueCombinationVisitor::new())
+    }
+}
diff --git a/packages/core/src/processing/evaluator/bucket.rs b/packages/core/src/processing/evaluator/bucket.rs
deleted file mode 100644
index 3400e91..0000000
--- a/packages/core/src/processing/evaluator/bucket.rs
+++ /dev/null
@@ -1,86 +0,0 @@
-use std::ops::{Deref, DerefMut};
-
-const INITIAL_BIN: usize = 10;
-const BIN_RATIO: usize = 2;
-
-/// Stores the preservation information related to a particular bucket.
-/// A bucket stores count in a certain range value.
-/// For example: (0, 10], (10, 20], (20, 40]...
-#[derive(Debug)] -pub struct PreservationByCountBucket { - /// How many elements are stored in the bucket - pub size: usize, - /// Preservation sum of all elements in the bucket - pub preservation_sum: f64, - /// Combination length sum of all elements in the bucket - pub length_sum: usize, -} - -impl PreservationByCountBucket { - /// Return a new PreservationByCountBucket with default values - pub fn default() -> PreservationByCountBucket { - PreservationByCountBucket { - size: 0, - preservation_sum: 0.0, - length_sum: 0, - } - } - - /// Adds a new value to the bucket - /// # Arguments - /// * `preservation` - Preservation related to the value - /// * `length` - Combination length related to the value - pub fn add(&mut self, preservation: f64, length: usize) { - self.size += 1; - self.preservation_sum += preservation; - self.length_sum += length; - } -} - -/// Stores the bins for each bucket. For example: -/// [10, 20, 40, 80, 160...] -#[derive(Debug)] -pub struct PreservationByCountBucketBins { - /// Bins where the index of the vector is the bin index, - /// and the value is the max value count allowed on the bin - bins: Vec, -} - -impl PreservationByCountBucketBins { - /// Generates a new PreservationByCountBucketBins with bins up to a `max_val` - /// # Arguments - /// * `max_val` - Max value allowed on the last bucket (inclusive) - pub fn new(max_val: usize) -> PreservationByCountBucketBins { - let mut bins: Vec = vec![INITIAL_BIN]; - loop { - let last = *bins.last().unwrap(); - if last >= max_val { - break; - } - bins.push(last * BIN_RATIO); - } - PreservationByCountBucketBins { bins } - } - - /// Find the first `bucket_max_val` where `val >= bucket_max_val` - /// # Arguments - /// * `val` - Count to look in which bucket it will fit - pub fn find_bucket_max_val(&self, val: usize) -> usize { - // find first element x where x >= val - self.bins[self.bins.partition_point(|x| *x < val)] - } -} - -impl Deref for PreservationByCountBucketBins { - type Target = Vec; - - fn deref(&self) -> &Self::Target { - &self.bins - } -} - -impl DerefMut for PreservationByCountBucketBins { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.bins - } -} diff --git a/packages/core/src/processing/evaluator/data_evaluator.rs b/packages/core/src/processing/evaluator/data_evaluator.rs new file mode 100644 index 0000000..5f7a7da --- /dev/null +++ b/packages/core/src/processing/evaluator/data_evaluator.rs @@ -0,0 +1,252 @@ +use super::preservation_by_count::{PreservationByCountBucketBins, PreservationByCountBuckets}; +use super::rare_combinations_comparison_data::RareCombinationsComparisonData; +use fnv::FnvHashSet; +use log::info; + +#[cfg(feature = "pyo3")] +use pyo3::prelude::*; + +use crate::processing::aggregator::aggregated_data::AggregatedData; +use crate::processing::aggregator::typedefs::AggregatedCountByLenMap; +use crate::processing::aggregator::value_combination::ValueCombination; +use crate::processing::evaluator::preservation_bucket::PreservationBucket; +use crate::processing::evaluator::preservation_by_length::PreservationByLengthBuckets; +use crate::utils::time::ElapsedDurationLogger; + +#[cfg_attr(feature = "pyo3", pyclass)] +/// Evaluates aggregated, sensitive and synthesized data +pub struct Evaluator {} + +impl Evaluator { + /// Returns a new Evaluator + pub fn default() -> Evaluator { + Evaluator {} + } +} + +#[cfg_attr(feature = "pyo3", pymethods)] +impl Evaluator { + /// Returns a new Evaluator + #[cfg(feature = "pyo3")] + #[new] + pub fn default_pyo3() -> Evaluator { + 
Evaluator::default() + } + + /// Calculates the leakage counts grouped by combination length + /// (how many attribute combinations exist on the sensitive data and + /// appear on the synthetic data with `count < resolution`). + /// By design this should be `0` + /// # Arguments + /// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data + /// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data + /// * `resolution` - Reporting resolution used for data synthesis + pub fn calc_leakage_count( + &self, + sensitive_aggregated_data: &AggregatedData, + synthetic_aggregated_data: &AggregatedData, + resolution: usize, + ) -> AggregatedCountByLenMap { + let _duration_logger = ElapsedDurationLogger::new("leakage count calculation"); + let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default(); + + info!("calculating rare sensitive combination leakages by length"); + + for (sensitive_agg, sensitive_count) in sensitive_aggregated_data.aggregates_count.iter() { + if sensitive_count.count < resolution { + if let Some(synthetic_count) = synthetic_aggregated_data + .aggregates_count + .get(sensitive_agg) + { + if synthetic_count.count < resolution { + let leaks = result.entry(sensitive_agg.len()).or_insert(0); + *leaks += 1; + } + } + } + } + result + } + + /// Calculates the fabricated counts grouped by combination length + /// (how many attribute combinations exist on the synthetic data that do not + /// exist on the sensitive data). + /// By design this should be `0` + /// # Arguments + /// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data + /// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data + pub fn calc_fabricated_count( + &self, + sensitive_aggregated_data: &AggregatedData, + synthetic_aggregated_data: &AggregatedData, + ) -> AggregatedCountByLenMap { + let _duration_logger = ElapsedDurationLogger::new("fabricated count calculation"); + let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default(); + + info!("calculating fabricated synthetic combinations by length"); + + for synthetic_agg in synthetic_aggregated_data.aggregates_count.keys() { + if sensitive_aggregated_data + .aggregates_count + .get(synthetic_agg) + .is_none() + { + let fabricated = result.entry(synthetic_agg.len()).or_insert(0); + *fabricated += 1; + } + } + result + } + + /// Calculates the preservation information grouped by combination count. 
+ /// An example output might be a map like: + /// `{ 10 -> PreservationBucket, 20 -> PreservationBucket, ...}` + /// # Arguments + /// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data + /// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data + /// * `resolution` - Reporting resolution used for data synthesis + pub fn calc_preservation_by_count( + &self, + sensitive_aggregated_data: &AggregatedData, + synthetic_aggregated_data: &AggregatedData, + resolution: usize, + ) -> PreservationByCountBuckets { + let _duration_logger = ElapsedDurationLogger::new("preservation by count calculation"); + + info!( + "calculating preservation by count with resolution: {}", + resolution + ); + + let max_syn_count = *synthetic_aggregated_data + .aggregates_count + .values() + .map(|a| &a.count) + .max() + .unwrap_or(&0); + let bins: PreservationByCountBucketBins = PreservationByCountBucketBins::new(max_syn_count); + let mut buckets: PreservationByCountBuckets = PreservationByCountBuckets::default(); + let mut processed_combs: FnvHashSet<&ValueCombination> = FnvHashSet::default(); + + for (comb, count) in sensitive_aggregated_data.aggregates_count.iter() { + // exclude sensitive rare combinations + if count.count >= resolution && !processed_combs.contains(comb) { + buckets.populate( + &bins, + comb, + &sensitive_aggregated_data.aggregates_count, + &synthetic_aggregated_data.aggregates_count, + ); + processed_combs.insert(comb); + } + } + for comb in synthetic_aggregated_data.aggregates_count.keys() { + if !processed_combs.contains(comb) { + buckets.populate( + &bins, + comb, + &sensitive_aggregated_data.aggregates_count, + &synthetic_aggregated_data.aggregates_count, + ); + processed_combs.insert(comb); + } + } + + // fill empty buckets with default value + for bin in bins.iter() { + buckets + .entry(*bin) + .or_insert_with(PreservationBucket::default); + } + + buckets + } + + /// Calculates the preservation information grouped by combination length. 
+    /// An example output might be a map like:
+    /// `{ 1 -> PreservationBucket, 2 -> PreservationBucket, ...}`
+    /// # Arguments
+    /// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
+    /// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
+    /// * `resolution` - Reporting resolution used for data synthesis
+    pub fn calc_preservation_by_length(
+        &self,
+        sensitive_aggregated_data: &AggregatedData,
+        synthetic_aggregated_data: &AggregatedData,
+        resolution: usize,
+    ) -> PreservationByLengthBuckets {
+        let _duration_logger = ElapsedDurationLogger::new("preservation by length calculation");
+
+        info!(
+            "calculating preservation by length with resolution: {}",
+            resolution
+        );
+
+        let mut buckets: PreservationByLengthBuckets = PreservationByLengthBuckets::default();
+        let mut processed_combs: FnvHashSet<&ValueCombination> = FnvHashSet::default();
+
+        for (comb, count) in sensitive_aggregated_data.aggregates_count.iter() {
+            // exclude sensitive rare combinations
+            if count.count >= resolution && !processed_combs.contains(comb) {
+                buckets.populate(
+                    comb,
+                    &sensitive_aggregated_data.aggregates_count,
+                    &synthetic_aggregated_data.aggregates_count,
+                );
+                processed_combs.insert(comb);
+            }
+        }
+        for comb in synthetic_aggregated_data.aggregates_count.keys() {
+            if !processed_combs.contains(comb) {
+                buckets.populate(
+                    comb,
+                    &sensitive_aggregated_data.aggregates_count,
+                    &synthetic_aggregated_data.aggregates_count,
+                );
+                processed_combs.insert(comb);
+            }
+        }
+
+        // fill empty buckets with default value
+        for l in 1..=synthetic_aggregated_data.reporting_length {
+            buckets.entry(l).or_insert_with(PreservationBucket::default);
+        }
+
+        buckets
+    }
+
+    /// Compares the rare combinations on the synthetic data with
+    /// the sensitive data counts
+    /// # Arguments
+    /// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
+    /// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
+    /// * `resolution` - Reporting resolution used for data synthesis
+    /// * `combination_delimiter` - Delimiter used to join combinations and format them
+    /// as strings
+    /// * `protect` - Whether or not the sensitive counts should be rounded down to the
+    /// nearest multiple of resolution
+    pub fn compare_synthetic_and_sensitive_rare(
+        &self,
+        synthetic_aggregated_data: &AggregatedData,
+        sensitive_aggregated_data: &AggregatedData,
+        resolution: usize,
+        combination_delimiter: &str,
+        protect: bool,
+    ) -> RareCombinationsComparisonData {
+        let _duration_logger =
+            ElapsedDurationLogger::new("synthetic and sensitive rare comparison");
+        RareCombinationsComparisonData::from_synthetic_and_sensitive_aggregated_data(
+            synthetic_aggregated_data,
+            sensitive_aggregated_data,
+            resolution,
+            combination_delimiter,
+            protect,
+        )
+    }
+}
+
+#[cfg(feature = "pyo3")]
+pub fn register(_py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<Evaluator>()?;
+    Ok(())
+}
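
For intuition on the preservation-by-count buckets used by the evaluator (and defined in `preservation_by_count.rs` further below), the bin upper bounds grow geometrically from `INITIAL_BIN` by `BIN_RATIO`, and a synthetic count is routed to the first bin whose upper bound is at or above it. A standalone sketch mirroring the `partition_point` lookup:

```rust
const INITIAL_BIN: usize = 10;
const BIN_RATIO: usize = 2;

/// Bin upper bounds up to and including the first one >= max_val,
/// e.g. max_val = 35 -> [10, 20, 40]
fn make_bins(max_val: usize) -> Vec<usize> {
    let mut bins = vec![INITIAL_BIN];
    while *bins.last().unwrap() < max_val {
        bins.push(bins.last().unwrap() * BIN_RATIO);
    }
    bins
}

/// First bin upper bound b with val <= b; like the original, this
/// assumes val never exceeds the last bin, since bins are sized
/// from the maximum synthetic count
fn find_bucket_max_val(bins: &[usize], val: usize) -> usize {
    bins[bins.partition_point(|x| *x < val)]
}

fn main() {
    let bins = make_bins(35);
    assert_eq!(bins, vec![10, 20, 40]);
    assert_eq!(find_bucket_max_val(&bins, 7), 10);  // falls in (0, 10]
    assert_eq!(find_bucket_max_val(&bins, 10), 10); // boundary stays in (0, 10]
    assert_eq!(find_bucket_max_val(&bins, 11), 20); // falls in (10, 20]
    println!("bins: {:?}", bins);
}
```
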
diff --git a/packages/core/src/processing/evaluator/mod.rs b/packages/core/src/processing/evaluator/mod.rs
index b11dd8a..0c325e7 100644
--- a/packages/core/src/processing/evaluator/mod.rs
+++ b/packages/core/src/processing/evaluator/mod.rs
@@ -1,227 +1,23 @@
-pub mod bucket;
+/// Module defining routines to calculate preservation by count
+pub mod preservation_by_count;
+
+/// Module defining routines to calculate preservation by length
+pub mod preservation_by_length;
+
+/// Module defining a bucket used to calculate preservation by count
+/// and length
+pub mod preservation_bucket;
+
+/// Defines structures related to rare combination comparisons
+/// between the synthetic and sensitive datasets
+pub mod rare_combinations_comparison_data;
+
+/// Type definitions related to the evaluation process
 pub mod typedefs;
 
-use self::bucket::{PreservationByCountBucket, PreservationByCountBucketBins};
-use self::typedefs::PreservationByCountBuckets;
-use fnv::FnvHashSet;
-use instant::Duration;
-use log::Level::Trace;
-use log::{info, log_enabled, trace};
+mod data_evaluator;
 
-use crate::measure_time;
-use crate::processing::aggregator::typedefs::ValueCombination;
-use crate::utils::time::ElapsedDuration;
+pub use data_evaluator::Evaluator;
 
-use super::aggregator::typedefs::{
-    AggregatedCountByLenMap, AggregatesCountMap, ValueCombinationSlice,
-};
-
-#[derive(Debug)]
-struct EvaluatorDurations {
-    calc_leakage_count: Duration,
-    calc_fabricated_count: Duration,
-    calc_preservation_by_count: Duration,
-    calc_combination_loss: Duration,
-}
-
-/// Evaluates aggregated, sensitive and synthesized data
-pub struct Evaluator {
-    durations: EvaluatorDurations,
-}
-
-impl Evaluator {
-    /// Returns a new Evaluator
-    #[inline]
-    pub fn default() -> Evaluator {
-        Evaluator {
-            durations: EvaluatorDurations {
-                calc_leakage_count: Duration::default(),
-                calc_fabricated_count: Duration::default(),
-                calc_preservation_by_count: Duration::default(),
-                calc_combination_loss: Duration::default(),
-            },
-        }
-    }
-
-    /// Calculates the leakage counts grouped by combination length
-    /// (how many attribute combinations exist on the sensitive data and
-    /// appear on the synthetic data with `count < resolution`).
-    /// By design this should be `0`
-    /// # Arguments
-    /// - `sensitive_aggregates` - Calculated aggregates counts for the sensitive data
-    /// - `synthetic_aggregates` - Calculated aggregates counts for the synthetic data
-    pub fn calc_leakage_count(
-        &mut self,
-        sensitive_aggregates: &AggregatesCountMap,
-        synthetic_aggregates: &AggregatesCountMap,
-        resolution: usize,
-    ) -> AggregatedCountByLenMap {
-        info!("calculating rare sensitive combination leakages by length");
-        measure_time!(
-            || {
-                let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
-
-                for (sensitive_agg, sensitive_count) in sensitive_aggregates.iter() {
-                    if sensitive_count.count < resolution {
-                        if let Some(synthetic_count) = synthetic_aggregates.get(sensitive_agg) {
-                            if synthetic_count.count < resolution {
-                                let leaks = result.entry(sensitive_agg.len()).or_insert(0);
-                                *leaks += 1;
-                            }
-                        }
-                    }
-                }
-                result
-            },
-            (self.durations.calc_leakage_count)
-        )
-    }
-
-    /// Calculates the fabricated counts grouped by combination length
-    /// (how many attribute combinations exist on the synthetic data that do not
-    /// exist on the sensitive data).
- /// By design this should be `0` - /// # Arguments - /// - `sensitive_aggregates` - Calculated aggregates counts for the sensitive data - /// - `synthetic_aggregates` - Calculated aggregates counts for the synthetic data - pub fn calc_fabricated_count( - &mut self, - sensitive_aggregates: &AggregatesCountMap, - synthetic_aggregates: &AggregatesCountMap, - ) -> AggregatedCountByLenMap { - info!("calculating fabricated synthetic combinations by length"); - measure_time!( - || { - let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default(); - - for synthetic_agg in synthetic_aggregates.keys() { - if sensitive_aggregates.get(synthetic_agg).is_none() { - let fabricated = result.entry(synthetic_agg.len()).or_insert(0); - *fabricated += 1; - } - } - result - }, - (self.durations.calc_fabricated_count) - ) - } - - /// Calculates the preservation information by bucket. - /// An example output might be a map like: - /// `{ 10 -> PreservationByCountBucket, 20 -> PreservationByCountBucket, ...}` - /// # Arguments - /// - `sensitive_aggregates` - Calculated aggregates counts for the sensitive data - /// - `synthetic_aggregates` - Calculated aggregates counts for the synthetic data - /// * `resolution` - Reporting resolution used for data synthesis - pub fn calc_preservation_by_count( - &mut self, - sensitive_aggregates: &AggregatesCountMap, - synthetic_aggregates: &AggregatesCountMap, - resolution: usize, - ) -> PreservationByCountBuckets { - info!( - "calculating preservation by count with resolution: {}", - resolution - ); - measure_time!( - || { - let max_syn_count = *synthetic_aggregates - .values() - .map(|a| &a.count) - .max() - .unwrap_or(&0); - let bins: PreservationByCountBucketBins = - PreservationByCountBucketBins::new(max_syn_count); - let mut buckets: PreservationByCountBuckets = PreservationByCountBuckets::default(); - let mut processed_combs: FnvHashSet<&ValueCombination> = FnvHashSet::default(); - - for (comb, count) in sensitive_aggregates.iter() { - if count.count >= resolution && !processed_combs.contains(comb) { - Evaluator::populate_buckets( - &mut buckets, - &bins, - comb, - sensitive_aggregates, - synthetic_aggregates, - ); - processed_combs.insert(comb); - } - } - for comb in synthetic_aggregates.keys() { - if !processed_combs.contains(comb) { - Evaluator::populate_buckets( - &mut buckets, - &bins, - comb, - sensitive_aggregates, - synthetic_aggregates, - ); - processed_combs.insert(comb); - } - } - - // fill empty buckets with default value - for bin in bins.iter() { - if !buckets.contains_key(bin) { - buckets.insert(*bin, PreservationByCountBucket::default()); - } - } - - buckets - }, - (self.durations.calc_preservation_by_count) - ) - } - - /// Calculates the combination loss for the calculated bucket preservation information - /// (`combination_loss = avg(1 - bucket_preservation_sum / bucket_size) for all buckets`) - /// # Arguments - /// - `buckets` - Calculated preservation buckets - pub fn calc_combination_loss(&mut self, buckets: &PreservationByCountBuckets) -> f64 { - measure_time!( - || { - buckets - .values() - .map(|b| 1.0 - (b.preservation_sum / (b.size as f64))) - .sum::() - / (buckets.len() as f64) - }, - (self.durations.calc_combination_loss) - ) - } - - #[inline] - fn populate_buckets( - buckets: &mut PreservationByCountBuckets, - bins: &PreservationByCountBucketBins, - comb: &ValueCombinationSlice, - sensitive_aggregates: &AggregatesCountMap, - synthetic_aggregates: &AggregatesCountMap, - ) { - let sen_count = match 
sensitive_aggregates.get(comb) { - Some(count) => count.count, - _ => 0, - }; - let syn_count = match synthetic_aggregates.get(comb) { - Some(count) => count.count, - _ => 0, - }; - let preservation = if sen_count > 0 { - // max value is 100%, so use min - f64::min((syn_count as f64) / (sen_count as f64), 1.0) - } else { - 0.0 - }; - - buckets - .entry(bins.find_bucket_max_val(syn_count)) - .or_insert_with(PreservationByCountBucket::default) - .add(preservation, comb.len()); - } -} - -impl Drop for Evaluator { - fn drop(&mut self) { - trace!("evaluator durations: {:#?}", self.durations); - } -} +#[cfg(feature = "pyo3")] +pub use data_evaluator::register; diff --git a/packages/core/src/processing/evaluator/preservation_bucket.rs b/packages/core/src/processing/evaluator/preservation_bucket.rs new file mode 100644 index 0000000..760c4d4 --- /dev/null +++ b/packages/core/src/processing/evaluator/preservation_bucket.rs @@ -0,0 +1,58 @@ +#[cfg(feature = "pyo3")] +use pyo3::prelude::*; + +/// Bucket to store preservation information +#[cfg_attr(feature = "pyo3", pyclass)] +#[derive(Debug, Clone)] +pub struct PreservationBucket { + /// How many elements are stored in the bucket + pub size: usize, + /// Preservation sum of all elements in the bucket + pub preservation_sum: f64, + /// Combination length sum of all elements in the bucket + pub length_sum: usize, + /// Combination count sum + pub combination_count_sum: usize, +} + +impl PreservationBucket { + /// Return a new PreservationBucket with default values + pub fn default() -> PreservationBucket { + PreservationBucket { + size: 0, + preservation_sum: 0.0, + length_sum: 0, + combination_count_sum: 0, + } + } + + /// Adds a new value to the bucket + /// # Arguments + /// * `preservation` - Preservation related to the value + /// * `length` - Combination length related to the value + /// * `combination_count` - Combination count related to the value + pub fn add(&mut self, preservation: f64, length: usize, combination_count: usize) { + self.size += 1; + self.preservation_sum += preservation; + self.length_sum += length; + self.combination_count_sum += combination_count; + } +} + +#[cfg_attr(feature = "pyo3", pymethods)] +impl PreservationBucket { + /// Gets the mean preservation for the values in this bucket + pub fn get_mean_preservation(&self) -> f64 { + self.preservation_sum / (self.size as f64) + } + + /// Gets the mean combination length for the values in this bucket + pub fn get_mean_combination_length(&self) -> f64 { + (self.length_sum as f64) / (self.size as f64) + } + + /// Gets the mean combination count for the values in this bucket + pub fn get_mean_combination_count(&self) -> f64 { + (self.combination_count_sum as f64) / (self.size as f64) + } +} diff --git a/packages/core/src/processing/evaluator/preservation_by_count.rs b/packages/core/src/processing/evaluator/preservation_by_count.rs new file mode 100644 index 0000000..47e1728 --- /dev/null +++ b/packages/core/src/processing/evaluator/preservation_by_count.rs @@ -0,0 +1,191 @@ +use super::{preservation_bucket::PreservationBucket, typedefs::PreservationBucketsMap}; +use itertools::Itertools; +use std::{ + io::{Error, Write}, + ops::{Deref, DerefMut}, +}; + +#[cfg(feature = "pyo3")] +use pyo3::prelude::*; + +use crate::{ + processing::aggregator::{typedefs::AggregatesCountMap, value_combination::ValueCombination}, + utils::time::ElapsedDurationLogger, +}; + +const INITIAL_BIN: usize = 10; +const BIN_RATIO: usize = 2; + +#[cfg_attr(feature = "pyo3", pyclass)] +/// Wrapping struct 
mapping the max value allowed in the bucket
+/// to its corresponding PreservationBucket.
+/// In this context a PreservationBucket stores the preservation information
+/// related to a particular bucket.
+/// A bucket stores counts within a certain value range.
+/// For example: (0, 10], (10, 20], (20, 40]...
+pub struct PreservationByCountBuckets {
+    buckets_map: PreservationBucketsMap,
+}
+
+impl PreservationByCountBuckets {
+    /// Returns a new default PreservationByCountBuckets
+    #[inline]
+    pub fn default() -> PreservationByCountBuckets {
+        PreservationByCountBuckets {
+            buckets_map: PreservationBucketsMap::default(),
+        }
+    }
+
+    #[inline]
+    pub(super) fn populate(
+        &mut self,
+        bins: &PreservationByCountBucketBins,
+        comb: &ValueCombination,
+        sensitive_aggregates: &AggregatesCountMap,
+        synthetic_aggregates: &AggregatesCountMap,
+    ) {
+        let sen_count = match sensitive_aggregates.get(comb) {
+            Some(count) => count.count,
+            _ => 0,
+        };
+        let syn_count = match synthetic_aggregates.get(comb) {
+            Some(count) => count.count,
+            _ => 0,
+        };
+        let preservation = if sen_count > 0 {
+            // max value is 100%, so use min
+            f64::min((syn_count as f64) / (sen_count as f64), 1.0)
+        } else {
+            0.0
+        };
+
+        self.buckets_map
+            .entry(bins.find_bucket_max_val(syn_count))
+            .or_insert_with(PreservationBucket::default)
+            .add(preservation, comb.len(), syn_count);
+    }
+}
+
+#[cfg_attr(feature = "pyo3", pymethods)]
+impl PreservationByCountBuckets {
+    /// Returns the actual buckets grouped by their max value.
+    /// This method will clone the data, so it's recommended to store its result
+    /// in a local variable to avoid calling it multiple times
+    pub fn get_buckets(&self) -> PreservationBucketsMap {
+        self.buckets_map.clone()
+    }
+
+    /// Calculates the combination loss for the calculated bucket preservation information
+    /// (`combination_loss = avg(1 - bucket_preservation_sum / bucket_size) for all buckets`)
+    pub fn calc_combination_loss(&self) -> f64 {
+        let _duration_logger = ElapsedDurationLogger::new("combination loss calculation");
+
+        self.buckets_map
+            .values()
+            .map(|b| 1.0 - (b.preservation_sum / (b.size as f64)))
+            .sum::<f64>()
+            / (self.buckets_map.len() as f64)
+    }
+
+    /// Writes the preservation grouped by counts to the file system in a csv/tsv-like format
+    /// # Arguments:
+    /// * `preservation_by_count_path` - File path to be written
+    /// * `preservation_by_count_delimiter` - Delimiter to use when writing to `preservation_by_count_path`
+    pub fn write_preservation_by_count(
+        &self,
+        preservation_by_count_path: &str,
+        preservation_by_count_delimiter: char,
+    ) -> Result<(), Error> {
+        let mut file = std::fs::File::create(preservation_by_count_path)?;
+
+        file.write_all(
+            format!(
+                "syn_count_bucket{}mean_combo_count{}mean_combo_length{}count_preservation\n",
+                preservation_by_count_delimiter,
+                preservation_by_count_delimiter,
+                preservation_by_count_delimiter,
+            )
+            .as_bytes(),
+        )?;
+        for max_val in self.buckets_map.keys().sorted().rev() {
+            let b = &self.buckets_map[max_val];
+
+            file.write_all(
+                format!(
+                    "{}{}{}{}{}{}{}\n",
+                    max_val,
+                    preservation_by_count_delimiter,
+                    b.get_mean_combination_count(),
+                    preservation_by_count_delimiter,
+                    b.get_mean_combination_length(),
+                    preservation_by_count_delimiter,
+                    b.get_mean_preservation(),
+                )
+                .as_bytes(),
+            )?
+        }
+        Ok(())
+    }
+}
+
+impl Deref for PreservationByCountBuckets {
+    type Target = PreservationBucketsMap;
+
+    fn deref(&self) -> &Self::Target {
+        &self.buckets_map
+    }
+}
+
+impl DerefMut for PreservationByCountBuckets {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.buckets_map
+    }
+}
+
+/// Stores the bins for each bucket. For example:
+/// [10, 20, 40, 80, 160...]
+#[derive(Debug)]
+pub struct PreservationByCountBucketBins {
+    /// Bins where the index of the vector is the bin index,
+    /// and the value is the max count allowed on the bin
+    bins: Vec<usize>,
+}
+
+impl PreservationByCountBucketBins {
+    /// Generates a new PreservationByCountBucketBins with bins up to `max_val`
+    /// # Arguments
+    /// * `max_val` - Max value allowed on the last bucket (inclusive)
+    pub fn new(max_val: usize) -> PreservationByCountBucketBins {
+        let mut bins: Vec<usize> = vec![INITIAL_BIN];
+        loop {
+            let last = *bins.last().unwrap();
+            if last >= max_val {
+                break;
+            }
+            bins.push(last * BIN_RATIO);
+        }
+        PreservationByCountBucketBins { bins }
+    }
+
+    /// Finds the first `bucket_max_val` where `bucket_max_val >= val`
+    /// # Arguments
+    /// * `val` - Count to find the containing bucket for
+    pub fn find_bucket_max_val(&self, val: usize) -> usize {
+        // find the first element x where x >= val
+        self.bins[self.bins.partition_point(|x| *x < val)]
+    }
+}
+
+impl Deref for PreservationByCountBucketBins {
+    type Target = Vec<usize>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.bins
+    }
+}
+
+impl DerefMut for PreservationByCountBucketBins {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.bins
+    }
+}
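For illustration (not part of the patch): the doubling-bin scheme above can be exercised standalone. A minimal sketch mirroring `PreservationByCountBucketBins` under the same constants (`INITIAL_BIN = 10`, `BIN_RATIO = 2`):

fn make_bins(max_val: usize) -> Vec<usize> {
    // bins grow as [10, 20, 40, 80, ...] until max_val is covered
    let mut bins: Vec<usize> = vec![10];
    while *bins.last().unwrap() < max_val {
        let last = *bins.last().unwrap();
        bins.push(last * 2);
    }
    bins
}

fn find_bucket_max_val(bins: &[usize], val: usize) -> usize {
    // first bin upper bound b such that b >= val
    bins[bins.partition_point(|b| *b < val)]
}

fn main() {
    let bins = make_bins(100); // [10, 20, 40, 80, 160]
    assert_eq!(find_bucket_max_val(&bins, 7), 10); // 7 falls in (0, 10]
    assert_eq!(find_bucket_max_val(&bins, 10), 10); // upper bounds are inclusive
    assert_eq!(find_bucket_max_val(&bins, 95), 160); // 95 falls in (80, 160]
}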
diff --git a/packages/core/src/processing/evaluator/preservation_by_length.rs b/packages/core/src/processing/evaluator/preservation_by_length.rs
new file mode 100644
index 0000000..99f0af6
--- /dev/null
+++ b/packages/core/src/processing/evaluator/preservation_by_length.rs
@@ -0,0 +1,118 @@
+use super::{preservation_bucket::PreservationBucket, typedefs::PreservationBucketsMap};
+use itertools::Itertools;
+use std::{
+    io::{Error, Write},
+    ops::{Deref, DerefMut},
+};
+
+#[cfg(feature = "pyo3")]
+use pyo3::prelude::*;
+
+use crate::processing::aggregator::{
+    typedefs::AggregatesCountMap, value_combination::ValueCombination,
+};
+
+#[cfg_attr(feature = "pyo3", pyclass)]
+/// Wrapper to store the preservation buckets grouped
+/// by combination length
+pub struct PreservationByLengthBuckets {
+    buckets_map: PreservationBucketsMap,
+}
+
+impl PreservationByLengthBuckets {
+    /// Returns a new default PreservationByLengthBuckets
+    #[inline]
+    pub fn default() -> PreservationByLengthBuckets {
+        PreservationByLengthBuckets {
+            buckets_map: PreservationBucketsMap::default(),
+        }
+    }
+
+    #[inline]
+    pub(super) fn populate(
+        &mut self,
+        comb: &ValueCombination,
+        sensitive_aggregates: &AggregatesCountMap,
+        synthetic_aggregates: &AggregatesCountMap,
+    ) {
+        let sen_count = match sensitive_aggregates.get(comb) {
+            Some(count) => count.count,
+            _ => 0,
+        };
+        let syn_count = match synthetic_aggregates.get(comb) {
+            Some(count) => count.count,
+            _ => 0,
+        };
+        let preservation = if sen_count > 0 {
+            // max value is 100%, so use min as syn_count might be > sen_count
+            f64::min((syn_count as f64) / (sen_count as f64), 1.0)
+        } else {
+            0.0
+        };
+
+        self.buckets_map
+            .entry(comb.len())
+            .or_insert_with(PreservationBucket::default)
+            .add(preservation, comb.len(), syn_count);
+    }
+}
+
+#[cfg_attr(feature = "pyo3", pymethods)]
+impl PreservationByLengthBuckets {
+    /// Returns the actual buckets grouped by length.
+    /// This method will clone the data, so it's recommended to store its result
+    /// in a local variable to avoid calling it multiple times
+    pub fn get_buckets(&self) -> PreservationBucketsMap {
+        self.buckets_map.clone()
+    }
+
+    /// Writes the preservation grouped by length to the file system in a csv/tsv-like format
+    /// # Arguments:
+    /// * `preservation_by_length_path` - File path to be written
+    /// * `preservation_by_length_delimiter` - Delimiter to use when writing to `preservation_by_length_path`
+    pub fn write_preservation_by_length(
+        &mut self,
+        preservation_by_length_path: &str,
+        preservation_by_length_delimiter: char,
+    ) -> Result<(), Error> {
+        let mut file = std::fs::File::create(preservation_by_length_path)?;
+
+        file.write_all(
+            format!(
+                "syn_combo_length{}mean_combo_count{}count_preservation\n",
+                preservation_by_length_delimiter, preservation_by_length_delimiter,
+            )
+            .as_bytes(),
+        )?;
+        for length in self.buckets_map.keys().sorted() {
+            let b = &self.buckets_map[length];
+
+            file.write_all(
+                format!(
+                    "{}{}{}{}{}\n",
+                    length,
+                    preservation_by_length_delimiter,
+                    b.get_mean_combination_count(),
+                    preservation_by_length_delimiter,
+                    b.get_mean_preservation(),
+                )
+                .as_bytes(),
+            )?
+        }
+        Ok(())
+    }
+}
+
+impl Deref for PreservationByLengthBuckets {
+    type Target = PreservationBucketsMap;
+
+    fn deref(&self) -> &Self::Target {
+        &self.buckets_map
+    }
+}
+
+impl DerefMut for PreservationByLengthBuckets {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.buckets_map
+    }
+}
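For illustration (not part of the patch): the rare-combinations comparison added below takes a `protect` flag that rounds sensitive counts down to the nearest multiple of the reporting resolution before exposing them. A standalone sketch of that step, where `uround_down` is a stand-in for the crate's `utils::math::uround_down` (its exact behavior is an assumption based on the doc comments):

// assumed semantics: round `value` down to the nearest multiple of `multiple`
fn uround_down(value: f64, multiple: f64) -> usize {
    ((value / multiple).floor() * multiple) as usize
}

fn main() {
    let resolution = 10.0;
    // a raw sensitive count of 37 is reported as 30 when protected,
    // so exact rare counts are never disclosed
    assert_eq!(uround_down(37.0, resolution), 30);
    // counts below the resolution collapse to 0
    assert_eq!(uround_down(7.0, resolution), 0);
}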
diff --git a/packages/core/src/processing/evaluator/rare_combinations_comparison_data.rs b/packages/core/src/processing/evaluator/rare_combinations_comparison_data.rs
new file mode 100644
index 0000000..71b72e1
--- /dev/null
+++ b/packages/core/src/processing/evaluator/rare_combinations_comparison_data.rs
@@ -0,0 +1,197 @@
+use std::io::{Error, Write};
+
+#[cfg(feature = "pyo3")]
+use pyo3::prelude::*;
+
+use crate::{
+    data_block::typedefs::CombinationsComparisons,
+    processing::aggregator::aggregated_data::AggregatedData, utils::math::uround_down,
+};
+
+#[cfg_attr(feature = "pyo3", pyclass)]
+#[derive(Clone)]
+/// Represents a single combination comparison
+pub struct CombinationComparison {
+    /// Length of the combination
+    pub combination_length: usize,
+    /// Combination formatted as string
+    pub combination: String,
+    /// Index of the record where this combination occurs
+    pub record_index: usize,
+    /// Number of times this combination occurs on the synthetic data
+    pub synthetic_count: usize,
+    /// Number of times this combination occurs on the sensitive data
+    pub sensitive_count: usize,
+}
+
+impl CombinationComparison {
+    #[inline]
+    /// Creates a new CombinationComparison
+    pub fn new(
+        combination_length: usize,
+        combination: String,
+        record_index: usize,
+        synthetic_count: usize,
+        sensitive_count: usize,
+    ) -> CombinationComparison {
+        CombinationComparison {
+            combination_length,
+            combination,
+            record_index,
+            synthetic_count,
+            sensitive_count,
+        }
+    }
+}
+
+#[cfg(feature = "pyo3")]
+#[cfg_attr(feature = "pyo3", pymethods)]
+impl CombinationComparison {
+    #[getter]
+    /// Length of the combination
+    fn combination_length(&self) -> usize {
+        self.combination_length
+    }
+
+    #[getter]
+    /// Combination formatted as string
+    fn combination(&self) -> String {
+        self.combination.clone()
+    }
+
+    #[getter]
+    /// Index of the record where this combination occurs
+    fn record_index(&self) -> usize {
+        self.record_index
+    }
+
+    #[getter]
+    /// Number of times this combination occurs on the synthetic data
+    fn synthetic_count(&self) -> usize {
+        self.synthetic_count
+    }
+
+    #[getter]
+    /// Number of times this combination occurs on the sensitive data
+    fn sensitive_count(&self) -> usize {
+        self.sensitive_count
+    }
+}
+
+#[cfg_attr(feature = "pyo3", pyclass)]
+/// Computed rare combination comparisons between the synthetic
+/// and sensitive datasets for all records
+pub struct RareCombinationsComparisonData {
+    /// Rare combination comparison for all records
+    /// (compares rare combinations on the synthetic dataset with
+    /// the sensitive counts)
+    pub rare_combinations: CombinationsComparisons,
+}
+
+impl RareCombinationsComparisonData {
+    /// Builds a new comparison between the rare combinations on the synthetic
+    /// data and the sensitive data counts
+    /// # Arguments
+    /// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
+    /// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
+    /// * `resolution` - Reporting resolution used for data synthesis
+    /// * `combination_delimiter` - Delimiter used to join combinations and format them
+    /// as strings
+    /// * `protect` - Whether or not the sensitive counts should be rounded down to the
+    /// nearest multiple of the resolution
+    #[inline]
+    pub fn from_synthetic_and_sensitive_aggregated_data(
+        synthetic_aggregated_data: &AggregatedData,
+        sensitive_aggregated_data: &AggregatedData,
+        resolution: usize,
+        combination_delimiter: &str,
+        protect: bool,
+    ) -> RareCombinationsComparisonData {
+        let mut rare_combinations: CombinationsComparisons = CombinationsComparisons::default();
+        let resolution_f64 = resolution as f64;
+
+        for (agg, count) in synthetic_aggregated_data.aggregates_count.iter() {
+            if count.count < resolution {
+                let combination_str = agg.format_str_using_headers(
+                    &synthetic_aggregated_data.data_block.headers,
+                    combination_delimiter,
+                );
+                let mut sensitive_count = sensitive_aggregated_data
+                    .aggregates_count
+                    .get(agg)
+                    .map_or(0, |c| c.count);
+
+                if protect {
+                    sensitive_count = uround_down(sensitive_count as f64, resolution_f64)
+                }
+
+                for record_index in count.contained_in_records.iter() {
+                    rare_combinations.push(CombinationComparison::new(
+                        agg.len(),
+                        combination_str.clone(),
+                        *record_index,
+                        count.count,
+                        sensitive_count,
+                    ));
+                }
+            }
+        }
+
+        // sort result by combination length
+        rare_combinations.sort_by_key(|c| c.combination_length);
+
+        RareCombinationsComparisonData { rare_combinations }
+    }
+}
+
+#[cfg_attr(feature = "pyo3", pymethods)]
+impl RareCombinationsComparisonData {
+    /// Returns the rare combination comparisons between the synthetic
+    /// and sensitive datasets for all records.
+ /// This method will clone the data, so its recommended to have its result stored + /// in a local variable to avoid it being called multiple times + pub fn get_rare_combinations(&self) -> CombinationsComparisons { + self.rare_combinations.clone() + } + + /// Writes the rare combinations to the file system in a csv/tsv like format + /// # Arguments: + /// * `rare_combinations_path` - File path to be written + /// * `rare_combinations_delimiter` - Delimiter to use when writing to `rare_combinations_path` + pub fn write_rare_combinations( + &mut self, + rare_combinations_path: &str, + rare_combinations_delimiter: char, + ) -> Result<(), Error> { + let mut file = std::fs::File::create(rare_combinations_path)?; + + file.write_all( + format!( + "combo_length{}combo{}record_id{}syn_count{}sen_count\n", + rare_combinations_delimiter, + rare_combinations_delimiter, + rare_combinations_delimiter, + rare_combinations_delimiter, + ) + .as_bytes(), + )?; + for ra in self.rare_combinations.iter() { + file.write_all( + format!( + "{}{}{}{}{}{}{}{}{}\n", + ra.combination_length, + rare_combinations_delimiter, + ra.combination, + rare_combinations_delimiter, + ra.record_index, + rare_combinations_delimiter, + ra.synthetic_count, + rare_combinations_delimiter, + ra.sensitive_count, + ) + .as_bytes(), + )? + } + Ok(()) + } +} diff --git a/packages/core/src/processing/evaluator/typedefs.rs b/packages/core/src/processing/evaluator/typedefs.rs index 306c76b..4701e08 100644 --- a/packages/core/src/processing/evaluator/typedefs.rs +++ b/packages/core/src/processing/evaluator/typedefs.rs @@ -1,6 +1,5 @@ -use super::bucket::PreservationByCountBucket; +use super::preservation_bucket::PreservationBucket; use fnv::FnvHashMap; -/// Maps the max value allowed in the bucket to its -/// correspondent PreservationByCountBucket -pub type PreservationByCountBuckets = FnvHashMap; +/// Maps a value to its correspondent PreservationBucket +pub type PreservationBucketsMap = FnvHashMap; diff --git a/packages/core/src/processing/generator/data_generator.rs b/packages/core/src/processing/generator/data_generator.rs new file mode 100644 index 0000000..e6a4d67 --- /dev/null +++ b/packages/core/src/processing/generator/data_generator.rs @@ -0,0 +1,143 @@ +use super::generated_data::GeneratedData; +use super::synthesizer::typedefs::SynthesizedRecords; +use super::synthesizer::unseeded::UnseededSynthesizer; +use super::SynthesisMode; +use log::info; +use std::sync::Arc; + +use crate::data_block::block::DataBlock; +use crate::data_block::typedefs::RawSyntheticData; +use crate::processing::generator::synthesizer::cache::SynthesizerCacheKey; +use crate::processing::generator::synthesizer::seeded::SeededSynthesizer; +use crate::utils::reporting::ReportProgress; +use crate::utils::time::ElapsedDurationLogger; + +/// Process a data block and generates new synthetic data +pub struct Generator { + data_block: Arc, +} + +impl Generator { + /// Returns a new Generator + /// # Arguments + /// * `data_block` - Sensitive data to be synthesized + #[inline] + pub fn new(data_block: Arc) -> Generator { + Generator { data_block } + } + + /// Generates new synthetic data based on sensitive data + /// # Arguments + /// * `resolution` - Reporting resolution used for data synthesis + /// * `cache_max_size` - Maximum cache size allowed + /// * `empty_value` - Empty values on the synthetic data will be represented by this + /// * `mode` - Which mode to perform the data synthesis + /// * `progress_reporter` - Will be used to report the processing + /// 
progress (`ReportProgress` trait). If `None`, nothing will be reported + pub fn generate( + &mut self, + resolution: usize, + cache_max_size: usize, + empty_value: String, + mode: SynthesisMode, + progress_reporter: &mut Option, + ) -> GeneratedData + where + T: ReportProgress, + { + let _duration_logger = ElapsedDurationLogger::new("data generation"); + let empty_value_arc = Arc::new(empty_value); + + info!("starting {} generation...", mode); + + let synthesized_records = match mode { + SynthesisMode::Seeded => { + self.seeded_synthesis(resolution, cache_max_size, progress_reporter) + } + SynthesisMode::Unseeded => self.unseeded_synthesis( + resolution, + cache_max_size, + &empty_value_arc, + progress_reporter, + ), + }; + + self.build_generated_data(synthesized_records, empty_value_arc) + } + + #[inline] + fn build_generated_data( + &self, + mut synthesized_records: SynthesizedRecords, + empty_value: Arc, + ) -> GeneratedData { + let mut result: RawSyntheticData = RawSyntheticData::default(); + let mut records: RawSyntheticData = synthesized_records + .drain(..) + .map(|r| { + SynthesizerCacheKey::new(self.data_block.headers.len(), &r) + .format_record(&empty_value) + }) + .collect(); + + // sort by number of defined attributes + records.sort(); + records.sort_by_key(|r| { + -r.iter() + .map(|s| if s.is_empty() { 0 } else { 1 }) + .sum::() + }); + + result.push(self.data_block.headers.to_vec()); + result.extend(records); + + let expansion_ratio = (result.len() - 1) as f64 / self.data_block.records.len() as f64; + + info!("expansion ratio: {:.4?}", expansion_ratio); + + GeneratedData::new(result, expansion_ratio) + } + + #[inline] + fn seeded_synthesis( + &self, + resolution: usize, + cache_max_size: usize, + progress_reporter: &mut Option, + ) -> SynthesizedRecords + where + T: ReportProgress, + { + let attr_rows_map = Arc::new(self.data_block.calc_attr_rows()); + let mut synth = SeededSynthesizer::new( + self.data_block.clone(), + attr_rows_map, + resolution, + cache_max_size, + ); + synth.run(progress_reporter) + } + + #[inline] + fn unseeded_synthesis( + &self, + resolution: usize, + cache_max_size: usize, + empty_value: &Arc, + progress_reporter: &mut Option, + ) -> SynthesizedRecords + where + T: ReportProgress, + { + let attr_rows_map_by_column = + Arc::new(self.data_block.calc_attr_rows_by_column(empty_value)); + let mut synth = UnseededSynthesizer::new( + self.data_block.clone(), + attr_rows_map_by_column, + resolution, + cache_max_size, + empty_value.clone(), + ); + synth.run(progress_reporter) + } +} diff --git a/packages/core/src/processing/generator/generated_data.rs b/packages/core/src/processing/generator/generated_data.rs new file mode 100644 index 0000000..d5fc19f --- /dev/null +++ b/packages/core/src/processing/generator/generated_data.rs @@ -0,0 +1,93 @@ +use csv::WriterBuilder; +use log::info; + +#[cfg(feature = "pyo3")] +use pyo3::prelude::*; + +#[cfg(feature = "pyo3")] +use crate::data_block::typedefs::CsvRecord; + +use crate::{ + data_block::{csv_io_error::CsvIOError, typedefs::RawSyntheticData}, + utils::time::ElapsedDurationLogger, +}; + +#[cfg_attr(feature = "pyo3", pyclass)] +/// Synthetic data generated by the Generator +pub struct GeneratedData { + /// Synthesized data - headers (index 0) and records indexes 1... 
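For illustration (not part of the patch): driving the reshaped `Generator` API end to end, per the signatures in this patch. The `sds_core` import paths and the exact shape of the `ReportProgress` trait are assumptions; the `Generator`, `SynthesisMode`, and `GeneratedData` calls are taken from the diff itself.

use std::sync::Arc;

use sds_core::data_block::block::DataBlock; // assumed path
use sds_core::processing::generator::{Generator, SynthesisMode};
use sds_core::utils::reporting::ReportProgress;

// minimal no-op reporter; the trait method shape is an assumption
struct NoOpReporter;
impl ReportProgress for NoOpReporter {
    fn report(&mut self, _progress: f64) {}
}

fn run_generation(data_block: Arc<DataBlock>) {
    let mut generator = Generator::new(data_block);
    let generated = generator.generate(
        10,                    // resolution
        100_000,               // cache_max_size
        String::from(""),      // empty_value
        SynthesisMode::Seeded, // or SynthesisMode::Unseeded
        &mut Some(NoOpReporter),
    );
    generated
        .write_synthetic_data("synthetic.tsv", '\t')
        .expect("failed to write synthetic data");
}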
+ pub synthetic_data: RawSyntheticData, + /// `Synthetic data length / Sensitive data length` (header not included) + pub expansion_ratio: f64, +} + +impl GeneratedData { + /// Returns a new GeneratedData struct with default values + #[inline] + pub fn default() -> GeneratedData { + GeneratedData { + synthetic_data: RawSyntheticData::default(), + expansion_ratio: 0.0, + } + } + + /// Returns a new GeneratedData struct + /// # Arguments + /// * `synthetic_data` - Synthesized data - headers (index 0) and records indexes 1... + /// * `expansion_ratio` - `Synthetic data length / Sensitive data length` (header not included) + #[inline] + pub fn new(synthetic_data: RawSyntheticData, expansion_ratio: f64) -> GeneratedData { + GeneratedData { + synthetic_data, + expansion_ratio, + } + } +} + +#[cfg_attr(feature = "pyo3", pymethods)] +impl GeneratedData { + #[cfg(feature = "pyo3")] + /// Synthesized data - headers (index 0) and records indexes 1... + /// This method will clone the data, so its recommended to have its result stored + /// in a local variable to avoid it being called multiple times + fn get_synthetic_data(&self) -> Vec { + self.synthetic_data + .iter() + .map(|row| row.iter().map(|value| (**value).clone()).collect()) + .collect() + } + + #[cfg(feature = "pyo3")] + #[getter] + /// `Synthetic data length / Sensitive data length` (header not included) + fn expansion_ratio(&self) -> f64 { + self.expansion_ratio + } + + /// Writes the synthesized data to the file system + /// # Arguments + /// * `path` - File path to be written + /// * `delimiter` - Delimiter to use when writing to `path` + pub fn write_synthetic_data(&self, path: &str, delimiter: char) -> Result<(), CsvIOError> { + let _duration_logger = ElapsedDurationLogger::new("write synthetic data"); + + info!("writing file {}", path); + + let mut wtr = match WriterBuilder::new() + .delimiter(delimiter as u8) + .from_path(&path) + { + Ok(writer) => writer, + Err(err) => return Err(CsvIOError::new(err)), + }; + + // write header and records + for r in self.synthetic_data.iter() { + match wtr.write_record(r.iter().map(|v| v.as_str())) { + Ok(_) => {} + Err(err) => return Err(CsvIOError::new(err)), + }; + } + Ok(()) + } +} diff --git a/packages/core/src/processing/generator/mod.rs b/packages/core/src/processing/generator/mod.rs index 444472a..4fc79cd 100644 --- a/packages/core/src/processing/generator/mod.rs +++ b/packages/core/src/processing/generator/mod.rs @@ -1,156 +1,13 @@ /// Module defining the data synthesis process pub mod synthesizer; -use csv::{Error, WriterBuilder}; -use log::Level::Trace; -use log::{info, log_enabled, trace}; -use std::time::Duration; +mod synthesis_mode; -use crate::data_block::block::DataBlock; -use crate::data_block::typedefs::CsvRecordRef; -use crate::measure_time; -use crate::processing::generator::synthesizer::cache::SynthesizerCacheKey; -use crate::processing::generator::synthesizer::seeded::SeededSynthesizer; -use crate::utils::reporting::ReportProgress; -use crate::utils::time::ElapsedDuration; +pub use synthesis_mode::SynthesisMode; -use self::synthesizer::typedefs::SynthesizedRecords; +mod data_generator; -#[derive(Debug)] -struct GeneratorDurations { - generate: Duration, - write_records: Duration, -} +pub use data_generator::Generator; -/// Process a data block and generates new synthetic data -pub struct Generator<'data_block> { - data_block: &'data_block DataBlock, - durations: GeneratorDurations, -} - -/// Synthetic data generated by the Generator -pub struct GeneratedData<'data_block> { - 
/// Synthesized data - headers (index 0) and records indexes 1... - pub synthetic_data: Vec>, - /// `Synthetic data length / Sensitive data length` (header not included) - pub expansion_ratio: f64, -} - -impl<'data_block> Generator<'data_block> { - /// Returns a new Generator - /// # Arguments - /// * `data_block` - Sensitive data to be synthesized - #[inline] - pub fn new(data_block: &'data_block DataBlock) -> Generator<'data_block> { - Generator { - data_block, - durations: GeneratorDurations { - generate: Duration::default(), - write_records: Duration::default(), - }, - } - } - - /// Generates new synthetic data based on sensitive data - /// # Arguments - /// * `resolution` - Reporting resolution used for data synthesis - /// * `cache_max_size` - Maximum cache size allowed - /// * `empty_value` - Empty values on the synthetic data will be represented by this - /// * `progress_reporter` - Will be used to report the processing - /// progress (`ReportProgress` trait). If `None`, nothing will be reported - pub fn generate( - &mut self, - resolution: usize, - cache_max_size: usize, - empty_value: &'data_block str, - progress_reporter: &mut Option, - ) -> GeneratedData<'data_block> - where - T: ReportProgress, - { - info!("starting generation..."); - - measure_time!( - || { - let mut result: Vec> = Vec::default(); - let mut records: Vec> = self - .seeded_synthesis(resolution, cache_max_size, progress_reporter) - .iter() - .map(|r| { - SynthesizerCacheKey::new(self.data_block.headers.len(), r) - .format_record(empty_value) - }) - .collect(); - - // sort by number of defined attributes - records.sort(); - records.sort_by_key(|r| { - -r.iter() - .map(|s| if s.is_empty() { 0 } else { 1 }) - .sum::() - }); - - let expansion_ratio = records.len() as f64 / self.data_block.records.len() as f64; - - info!("expansion ratio: {:.4?}", expansion_ratio); - - result.push(self.data_block.headers.iter().map(|h| h.as_str()).collect()); - result.extend(records); - - GeneratedData { - synthetic_data: result, - expansion_ratio, - } - }, - (self.durations.generate) - ) - } - - /// Writes the synthesized data to the file system - /// # Arguments - /// * `result` - Synthetic data to write (headers included) - /// * `path` - File path to be written - /// * `delimiter` - Delimiter to use when writing to `path` - pub fn write_records( - &mut self, - result: &[CsvRecordRef<'data_block>], - path: &str, - delimiter: u8, - ) -> Result<(), Error> { - info!("writing file {}", path); - - measure_time!( - || { - let mut wtr = WriterBuilder::new().delimiter(delimiter).from_path(&path)?; - - // write header and records - for r in result.iter() { - wtr.write_record(r)?; - } - Ok(()) - }, - (self.durations.write_records) - ) - } - - fn seeded_synthesis( - &mut self, - resolution: usize, - cache_max_size: usize, - progress_reporter: &mut Option, - ) -> SynthesizedRecords<'data_block> - where - T: ReportProgress, - { - let attr_rows_map = DataBlock::calc_attr_rows(&self.data_block.records); - let mut synth = - SeededSynthesizer::new(self.data_block, &attr_rows_map, resolution, cache_max_size); - synth.run(progress_reporter) - } -} - -impl<'data_block> Drop for Generator<'data_block> { - fn drop(&mut self) { - trace!("generator durations: {:#?}", self.durations); - } -} +/// Module to describe the synthetic data generated by the Generator +pub mod generated_data; diff --git a/packages/core/src/processing/generator/synthesis_mode.rs b/packages/core/src/processing/generator/synthesis_mode.rs new file mode 100644 index 0000000..da60303 
--- /dev/null +++ b/packages/core/src/processing/generator/synthesis_mode.rs @@ -0,0 +1,29 @@ +use std::{fmt::Display, str::FromStr}; + +#[derive(Debug)] +/// Modes to execute the data generation/synthesis +pub enum SynthesisMode { + Seeded, + Unseeded, +} + +impl FromStr for SynthesisMode { + type Err = &'static str; + + fn from_str(mode: &str) -> Result { + match mode.to_lowercase().as_str() { + "seeded" => Ok(SynthesisMode::Seeded), + "unseeded" => Ok(SynthesisMode::Unseeded), + _ => Err("invalid mode, should be seeded or unseeded"), + } + } +} + +impl Display for SynthesisMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SynthesisMode::Seeded => Ok(write!(f, "seeded")?), + SynthesisMode::Unseeded => Ok(write!(f, "unseeded")?), + } + } +} diff --git a/packages/core/src/processing/generator/synthesizer/cache.rs b/packages/core/src/processing/generator/synthesizer/cache.rs index 18991b9..7c7159f 100644 --- a/packages/core/src/processing/generator/synthesizer/cache.rs +++ b/packages/core/src/processing/generator/synthesizer/cache.rs @@ -1,6 +1,7 @@ use super::typedefs::SynthesizedRecord; use fnv::FnvBuildHasher; use lru::LruCache; +use std::sync::Arc; use crate::data_block::{typedefs::CsvRecordRef, value::DataBlockValue}; @@ -8,29 +9,26 @@ use crate::data_block::{typedefs::CsvRecordRef, value::DataBlockValue}; /// (`columns[{column_index}] = `value` if value exists, /// or `columns[{column_index}] = None` otherwise) #[derive(Debug, PartialEq, Eq, Hash, Clone)] -pub struct SynthesizerCacheKey<'value> { +pub struct SynthesizerCacheKey { /// Values for a given synthesized record are indexed by `column_index` /// on this vector - columns: Vec>, + columns: Vec>>, } -impl<'value> SynthesizerCacheKey<'value> { +impl SynthesizerCacheKey { /// Returns a new SynthesizerCacheKey /// # Arguments /// * `num_columns` - Number of columns in the data block /// * `records` - Synthesized record to build the key for #[inline] - pub fn new( - num_columns: usize, - values: &SynthesizedRecord<'value>, - ) -> SynthesizerCacheKey<'value> { + pub fn new(num_columns: usize, values: &SynthesizedRecord) -> SynthesizerCacheKey { let mut key = SynthesizerCacheKey { columns: Vec::with_capacity(num_columns), }; key.columns.resize_with(num_columns, || None); for v in values.iter() { - key.columns[v.column_index] = Some(&v.value); + key.columns[v.column_index] = Some(v.value.clone()); } key } @@ -40,9 +38,9 @@ impl<'value> SynthesizerCacheKey<'value> { /// # Arguments /// * `value` - New data block value to be added in the new key #[inline] - pub fn new_with_value(&self, value: &'value DataBlockValue) -> SynthesizerCacheKey<'value> { + pub fn new_with_value(&self, value: &Arc) -> SynthesizerCacheKey { let mut new_key = self.clone(); - new_key.columns[value.column_index] = Some(&value.value); + new_key.columns[value.column_index] = Some(value.value.clone()); new_key } @@ -59,29 +57,29 @@ impl<'value> SynthesizerCacheKey<'value> { /// # Arguments /// * `empty_value` - String to be used in case the `columns[column_index] == None` #[inline] - pub fn format_record(&self, empty_value: &'value str) -> CsvRecordRef<'value> { + pub fn format_record(&self, empty_value: &Arc) -> CsvRecordRef { self.columns .iter() .map(|c_opt| match c_opt { - Some(c) => c, - None => empty_value, + Some(c) => (*c).clone(), + None => empty_value.clone(), }) .collect() } } /// Cache to store keys-values used during the synthesis process -pub struct SynthesizerCache<'value, T> { +pub struct SynthesizerCache { 
/// LruCache to store the keys mapping to a generic type T - cache: LruCache, T, FnvBuildHasher>, + cache: LruCache, } -impl<'value, T> SynthesizerCache<'value, T> { +impl SynthesizerCache { /// Returns a new SynthesizerCache /// # Arguments /// * `cache_max_size` - Maximum cache size allowed for the LRU #[inline] - pub fn new(cache_max_size: usize) -> SynthesizerCache<'value, T> { + pub fn new(cache_max_size: usize) -> SynthesizerCache { SynthesizerCache { cache: LruCache::with_hasher(cache_max_size, FnvBuildHasher::default()), } @@ -93,7 +91,7 @@ impl<'value, T> SynthesizerCache<'value, T> { /// # Arguments /// * `key` - Key to look for the value #[inline] - pub fn get(&mut self, key: &SynthesizerCacheKey<'value>) -> Option<&T> { + pub fn get(&mut self, key: &SynthesizerCacheKey) -> Option<&T> { self.cache.get(key) } @@ -104,7 +102,7 @@ impl<'value, T> SynthesizerCache<'value, T> { /// * `key` - Key to be inserted /// * `value` - Value to associated with the key #[inline] - pub fn insert(&mut self, key: SynthesizerCacheKey<'value>, value: T) -> Option { + pub fn insert(&mut self, key: SynthesizerCacheKey, value: T) -> Option { self.cache.put(key, value) } } diff --git a/packages/core/src/processing/generator/synthesizer/context.rs b/packages/core/src/processing/generator/synthesizer/context.rs index b6e990d..ea093e2 100644 --- a/packages/core/src/processing/generator/synthesizer/context.rs +++ b/packages/core/src/processing/generator/synthesizer/context.rs @@ -2,93 +2,76 @@ use super::cache::{SynthesizerCache, SynthesizerCacheKey}; use super::typedefs::{ AttributeCountMap, NotAllowedAttrSet, SynthesizedRecord, SynthesizerSeedSlice, }; +use itertools::Itertools; use rand::Rng; +use std::sync::Arc; -use crate::data_block::typedefs::{AttributeRows, AttributeRowsMap, AttributeRowsSlice}; +use crate::data_block::typedefs::{ + AttributeRows, AttributeRowsByColumnMap, AttributeRowsMap, AttributeRowsRefMap, + AttributeRowsSlice, +}; use crate::data_block::value::DataBlockValue; +use crate::processing::aggregator::typedefs::RecordsSet; use crate::utils::collections::ordered_vec_intersection; /// Represents a synthesis context, containing the information /// required to synthesize new records /// (common to seeded and unseeded synthesis) -pub struct SynthesizerContext<'data_block, 'attr_rows_map> { +pub struct SynthesizerContext { /// Number of headers in the data block - headers_len: usize, + pub headers_len: usize, /// Number of records in the data block - records_len: usize, - /// Maps a data block value to all the rows where it occurs - attr_rows_map: &'attr_rows_map AttributeRowsMap<'data_block>, + pub records_len: usize, /// Cache used to store the rows where value combinations occurs - cache: SynthesizerCache<'data_block, AttributeRows>, + cache: SynthesizerCache>, /// Reporting resolution used for data synthesis resolution: usize, } -impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map> { +impl SynthesizerContext { /// Returns a new SynthesizerContext /// # Arguments /// * `headers_len` - Number of headers in the data block /// * `records_len` - Number of records in the data block - /// * `attr_rows_map` - Maps a data block value to all the rows where it occurs /// * `resolution` - Reporting resolution used for data synthesis /// * `cache_max_size` - Maximum cache size allowed #[inline] pub fn new( headers_len: usize, records_len: usize, - attr_rows_map: &'attr_rows_map AttributeRowsMap<'data_block>, resolution: usize, cache_max_size: usize, - ) -> 
SynthesizerContext<'data_block, 'attr_rows_map> { + ) -> SynthesizerContext { SynthesizerContext { headers_len, records_len, - attr_rows_map, cache: SynthesizerCache::new(cache_max_size), resolution, } } - /// Returns the configured mapping of a data block value to all - /// the rows where it occurs - #[inline] - pub fn get_attr_rows_map(&self) -> &AttributeRowsMap<'data_block> { - self.attr_rows_map - } - - /// Returns the configured reporting resolution - #[inline] - pub fn get_resolution(&self) -> usize { - self.resolution - } - /// Samples an attribute based on its count /// (the higher the count the greater the chance for the /// attribute to be selected). /// Returns `None` if all the counts are 0 or the map is empty /// # Arguments /// * `counts` - Maps an attribute to its count for sampling - pub fn sample_from_attr_counts( - counts: &AttributeCountMap<'data_block>, - ) -> Option<&'data_block DataBlockValue> { - let mut res: Option<&'data_block DataBlockValue> = None; + pub fn sample_from_attr_counts(counts: &AttributeCountMap) -> Option> { + let mut res: Option> = None; let total: usize = counts.values().sum(); if total != 0 { - let random = if total == 1 { - 1 - } else { - rand::thread_rng().gen_range(1..total) - }; - let mut current_sim: usize = 0; + let random = rand::thread_rng().gen_range(1..=total); + let mut current_sum: usize = 0; - for (value, count) in counts.iter() { - current_sim += count; - if current_sim < random { - res = Some(value); - break; + for (value, count) in counts.iter().sorted_by_key(|(_, c)| **c) { + if *count > 0 { + current_sum += count; + res = Some(value.clone()); + if current_sum >= random { + break; + } } - res = Some(value) } } res @@ -100,52 +83,130 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map /// * `synthesized_record` - Record synthesized so far /// * `current_seed` - Current seed/record used for sampling /// * `not_allowed_attr_set` - Attributes not allowed to be sampled - pub fn sample_next_attr( + /// * `attr_rows_map` - Maps a data block value to all the rows where it occurs + pub fn sample_next_attr_from_seed( &mut self, - synthesized_record: &SynthesizedRecord<'data_block>, - current_seed: &SynthesizerSeedSlice<'data_block>, - not_allowed_attr_set: &NotAllowedAttrSet<'data_block>, - ) -> Option<&'data_block DataBlockValue> { - let counts = - self.calc_next_attr_count(synthesized_record, current_seed, not_allowed_attr_set); + synthesized_record: &SynthesizedRecord, + current_seed: &SynthesizerSeedSlice, + not_allowed_attr_set: &NotAllowedAttrSet, + attr_rows_map: &AttributeRowsMap, + ) -> Option> { + let counts = self.calc_next_attr_count( + synthesized_record, + current_seed, + not_allowed_attr_set, + attr_rows_map, + ); SynthesizerContext::sample_from_attr_counts(&counts) } + /// Samples the next attribute from the column at `column_index`. 
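For illustration (not part of the patch): the patched `sample_from_attr_counts` above now draws `gen_range(1..=total)` and walks the counts in ascending order, accumulating until the running sum reaches the draw, replacing the previous sampling loop. The same count-weighted selection reduced to plain types (assuming `rand = "0.8"`):

use rand::Rng;
use std::collections::HashMap;

/// Picks a key with probability proportional to its count;
/// returns None when the map is empty or every count is zero.
fn sample_weighted(counts: &HashMap<String, usize>) -> Option<String> {
    let total: usize = counts.values().sum();
    if total == 0 {
        return None;
    }
    let target = rand::thread_rng().gen_range(1..=total);
    let mut entries: Vec<_> = counts.iter().filter(|(_, c)| **c > 0).collect();
    entries.sort_by_key(|(_, c)| **c); // ascending, as in the patch
    let mut cumulative = 0;
    for (key, count) in entries {
        cumulative += count;
        if cumulative >= target {
            return Some(key.clone());
        }
    }
    None
}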
+ /// Returns a tuple with + /// (intersection of `current_attrs_rows` and rows that sampled value appear, sampled value) + /// # Arguments + /// * `synthesized_record` - Record synthesized so far + /// * `column_index` - Index of the column to sample from + /// * `attr_rows_map_by_column` - Maps the column index -> data block value -> rows where the value appear + /// * `current_attrs_rows` - Rows where the so far sampled combination appear + /// * `empty_value` - Empty values on the synthetic data will be represented by this + pub fn sample_next_attr_from_column( + &mut self, + synthesized_record: &SynthesizedRecord, + column_index: usize, + attr_rows_map_by_column: &AttributeRowsByColumnMap, + current_attrs_rows: &AttributeRowsSlice, + empty_value: &Arc, + ) -> Option<(Arc, Arc)> { + let cache_key = SynthesizerCacheKey::new(self.headers_len, synthesized_record); + let empty_block_value = Arc::new(DataBlockValue::new(column_index, empty_value.clone())); + let mut values_to_sample: AttributeRowsRefMap = AttributeRowsRefMap::default(); + let mut counts = AttributeCountMap::default(); + + // calculate the row intersections for each value in the column + for (value, rows) in attr_rows_map_by_column[&column_index].iter() { + let new_cache_key = cache_key.new_with_value(value); + let rows_intersection = match self.cache.get(&new_cache_key) { + Some(cached_value) => cached_value.clone(), + None => { + let intersection = Arc::new(ordered_vec_intersection(current_attrs_rows, rows)); + self.cache.insert(new_cache_key, intersection.clone()); + intersection + } + }; + values_to_sample.insert(value.clone(), rows_intersection); + } + + // store the rows with empty values + let mut rows_with_empty_values: RecordsSet = values_to_sample[&empty_block_value] + .iter() + .cloned() + .collect(); + + for (value, rows) in values_to_sample.iter() { + if rows.len() < self.resolution { + // if the combination containing the attribute appears in less + // than resolution rows, we can't use it so we tag it as an empty value + rows_with_empty_values.extend(rows.iter()); + } else if **value != *empty_block_value { + // if we can use the combination containing the attribute + // gather its count for sampling + counts.insert(value.clone(), rows.len()); + } + } + + // if there are empty values that can be sampled, add them for sampling + if !rows_with_empty_values.is_empty() { + counts.insert(empty_block_value, rows_with_empty_values.len()); + } + + Self::sample_from_attr_counts(&counts).map(|sampled_value| { + ( + values_to_sample.remove(&sampled_value).unwrap(), + sampled_value, + ) + }) + } + fn calc_current_attrs_rows( &mut self, - cache_key: &SynthesizerCacheKey<'data_block>, - synthesized_record: &SynthesizedRecord<'data_block>, - ) -> AttributeRows { + cache_key: &SynthesizerCacheKey, + synthesized_record: &SynthesizedRecord, + attr_rows_map: &AttributeRowsMap, + ) -> Arc { match self.cache.get(cache_key) { Some(cache_value) => cache_value.clone(), None => { let mut current_attrs_rows: AttributeRows = AttributeRows::new(); + let current_attrs_rows_arc: Arc; if !synthesized_record.is_empty() { for sr in synthesized_record.iter() { current_attrs_rows = if current_attrs_rows.is_empty() { - self.attr_rows_map[sr].clone() + attr_rows_map[sr].clone() } else { - ordered_vec_intersection(¤t_attrs_rows, &self.attr_rows_map[sr]) + ordered_vec_intersection(¤t_attrs_rows, &attr_rows_map[sr]) } } } else { current_attrs_rows = (0..self.records_len).collect(); } + current_attrs_rows_arc = Arc::new(current_attrs_rows); + 
self.cache - .insert(cache_key.clone(), current_attrs_rows.clone()); - current_attrs_rows + .insert(cache_key.clone(), current_attrs_rows_arc.clone()); + current_attrs_rows_arc } } } fn gen_attr_count_map( &mut self, - cache_key: &mut SynthesizerCacheKey<'data_block>, - current_seed: &SynthesizerSeedSlice<'data_block>, + cache_key: &SynthesizerCacheKey, + current_seed: &SynthesizerSeedSlice, current_attrs_rows: &AttributeRowsSlice, - not_allowed_attr_set: &NotAllowedAttrSet<'data_block>, - ) -> AttributeCountMap<'data_block> { + not_allowed_attr_set: &NotAllowedAttrSet, + attr_rows_map: &AttributeRowsMap, + ) -> AttributeCountMap { let mut attr_count_map: AttributeCountMap = AttributeCountMap::default(); for value in current_seed.iter() { @@ -156,10 +217,10 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map let count = match self.cache.get(&new_cache_key) { Some(cached_value) => cached_value.len(), None => { - let intersection = ordered_vec_intersection( + let intersection = Arc::new(ordered_vec_intersection( current_attrs_rows, - self.attr_rows_map.get(value).unwrap(), - ); + attr_rows_map.get(value).unwrap(), + )); let count = intersection.len(); self.cache.insert(new_cache_key, intersection); @@ -168,7 +229,7 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map }; if count >= self.resolution { - attr_count_map.insert(value, count); + attr_count_map.insert(value.clone(), count); } } } @@ -177,18 +238,21 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map fn calc_next_attr_count( &mut self, - synthesized_record: &SynthesizedRecord<'data_block>, - current_seed: &SynthesizerSeedSlice<'data_block>, - not_allowed_attr_set: &NotAllowedAttrSet<'data_block>, - ) -> AttributeCountMap<'data_block> { - let mut cache_key = SynthesizerCacheKey::new(self.headers_len, synthesized_record); - let current_attrs_rows = self.calc_current_attrs_rows(&cache_key, synthesized_record); + synthesized_record: &SynthesizedRecord, + current_seed: &SynthesizerSeedSlice, + not_allowed_attr_set: &NotAllowedAttrSet, + attr_rows_map: &AttributeRowsMap, + ) -> AttributeCountMap { + let cache_key = SynthesizerCacheKey::new(self.headers_len, synthesized_record); + let current_attrs_rows = + self.calc_current_attrs_rows(&cache_key, synthesized_record, attr_rows_map); self.gen_attr_count_map( - &mut cache_key, + &cache_key, current_seed, ¤t_attrs_rows, not_allowed_attr_set, + attr_rows_map, ) } } diff --git a/packages/core/src/processing/generator/synthesizer/mod.rs b/packages/core/src/processing/generator/synthesizer/mod.rs index 907dc1f..4acab1f 100644 --- a/packages/core/src/processing/generator/synthesizer/mod.rs +++ b/packages/core/src/processing/generator/synthesizer/mod.rs @@ -1,8 +1,18 @@ /// Module defining the cache used during the synthesis process pub mod cache; + /// Module defining the context used for data synthesis pub mod context; + /// Module defining the seeded synthesis process pub mod seeded; + +/// Module defining the unseeded synthesis process +pub mod unseeded; + /// Type definitions related to the synthesis process pub mod typedefs; + +mod seeded_rows_synthesizer; + +mod unseeded_rows_synthesizer; diff --git a/packages/core/src/processing/generator/synthesizer/seeded.rs b/packages/core/src/processing/generator/synthesizer/seeded.rs index b7f6f24..88dd9c8 100644 --- a/packages/core/src/processing/generator/synthesizer/seeded.rs +++ b/packages/core/src/processing/generator/synthesizer/seeded.rs @@ -1,56 +1,46 
@@ use super::{ context::SynthesizerContext, + seeded_rows_synthesizer::SeededRowsSynthesizer, typedefs::{ AvailableAttrsMap, NotAllowedAttrSet, SynthesizedRecord, SynthesizedRecords, SynthesizedRecordsSlice, SynthesizerSeed, SynthesizerSeedSlice, }, }; use fnv::FnvHashMap; -use itertools::izip; -use log::Level::Trace; -use log::{info, log_enabled, trace}; +use itertools::{izip, Itertools}; +use log::info; use rand::{prelude::SliceRandom, thread_rng}; -use std::time::Duration; +use std::sync::Arc; use crate::{ - data_block::{ - block::DataBlock, - record::DataBlockRecord, - typedefs::{AttributeRowsMap, DataBlockRecords}, - value::DataBlockValue, - }, - measure_time, + data_block::{block::DataBlock, typedefs::AttributeRowsMap, value::DataBlockValue}, utils::{ math::{calc_percentage, iround_down}, reporting::ReportProgress, - time::ElapsedDuration, + threading::get_number_of_threads, + time::ElapsedDurationLogger, }, }; -#[derive(Debug)] -struct SeededSynthesizerDurations { - consolidate: Duration, - suppress: Duration, - synthesize_rows: Duration, -} - /// Represents all the information required to perform the seeded data synthesis -pub struct SeededSynthesizer<'data_block, 'attr_rows_map> { - /// Records to be synthesized - pub records: &'data_block DataBlockRecords, - /// Context used for the synthesis process - pub context: SynthesizerContext<'data_block, 'attr_rows_map>, +pub struct SeededSynthesizer { + /// Reference to the original data block + data_block: Arc, + /// Maps a data block value to all the rows where it occurs + attr_rows_map: Arc, + /// Reporting resolution used for data synthesis + resolution: usize, + /// Maximum cache size allowed + cache_max_size: usize, /// Percentage already completed on the row synthesis step synthesize_percentage: f64, /// Percentage already completed on the consolidation step consolidate_percentage: f64, /// Percentage already completed on the suppression step suppress_percentage: f64, - /// Elapsed durations on each step for benchmarking - durations: SeededSynthesizerDurations, } -impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> { +impl SeededSynthesizer { /// Returns a new SeededSynthesizer /// # Arguments /// * `data_block` - Sensitive data to be synthesized @@ -59,28 +49,19 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> /// * `cache_max_size` - Maximum cache size allowed #[inline] pub fn new( - data_block: &'data_block DataBlock, - attr_rows_map: &'attr_rows_map AttributeRowsMap<'data_block>, + data_block: Arc, + attr_rows_map: Arc, resolution: usize, cache_max_size: usize, - ) -> SeededSynthesizer<'data_block, 'attr_rows_map> { + ) -> SeededSynthesizer { SeededSynthesizer { - records: &data_block.records, - context: SynthesizerContext::new( - data_block.headers.len(), - data_block.records.len(), - attr_rows_map, - resolution, - cache_max_size, - ), + data_block, + attr_rows_map, + resolution, + cache_max_size, synthesize_percentage: 0.0, consolidate_percentage: 0.0, suppress_percentage: 0.0, - durations: SeededSynthesizerDurations { - consolidate: Duration::default(), - suppress: Duration::default(), - synthesize_rows: Duration::default(), - }, } } @@ -89,64 +70,71 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> /// # Arguments /// * `progress_reporter` - Will be used to report the processing /// progress (`ReportProgress` trait). 
If `None`, nothing will be reported - pub fn run<T>(&mut self, progress_reporter: &mut Option<T>) -> SynthesizedRecords<'data_block> + pub fn run<T>(&mut self, progress_reporter: &mut Option<T>) -> SynthesizedRecords where T: ReportProgress, { - let mut synthesized_records: SynthesizedRecords<'data_block> = SynthesizedRecords::new(); + let mut synthesized_records: SynthesizedRecords = SynthesizedRecords::new(); - self.synthesize_percentage = 0.0; - self.consolidate_percentage = 0.0; - self.suppress_percentage = 0.0; + if !self.data_block.records.is_empty() { + let mut rows_synthesizers: Vec<SeededRowsSynthesizer> = self.build_rows_synthesizers(); - self.synthesize_rows(&mut synthesized_records, progress_reporter); - self.consolidate(&mut synthesized_records, progress_reporter); - self.suppress(&mut synthesized_records, progress_reporter); + self.synthesize_percentage = 0.0; + self.consolidate_percentage = 0.0; + self.suppress_percentage = 0.0; + self.synthesize_rows( + &mut synthesized_records, + &mut rows_synthesizers, + progress_reporter, + ); + self.consolidate( + &mut synthesized_records, + progress_reporter, + // use the first context to leverage already cached intersections + &mut rows_synthesizers[0].context, + ); + self.suppress(&mut synthesized_records, progress_reporter); + } synthesized_records } #[inline] - fn synthesize_row( - &mut self, - seed: &'data_block DataBlockRecord, - ) -> SynthesizedRecord<'data_block> { - let current_seed: SynthesizerSeed = seed.values.iter().collect(); - let mut synthesized_record = SynthesizedRecord::default(); - let not_allowed_attr_set = NotAllowedAttrSet::default(); + fn build_rows_synthesizers(&self) -> Vec<SeededRowsSynthesizer> { + let chunk_size = ((self.data_block.records.len() as f64) / (get_number_of_threads() as f64)) + .ceil() as usize; + let mut rows_synthesizers: Vec<SeededRowsSynthesizer> = Vec::default(); - loop { - let next = self.context.sample_next_attr( - &synthesized_record, - &current_seed, - &not_allowed_attr_set, - ); - - match next { - Some(value) => { - synthesized_record.insert(value); - } - None => break, - } + for c in &self.data_block.records.iter().chunks(chunk_size) { + rows_synthesizers.push(SeededRowsSynthesizer::new( + SynthesizerContext::new( + self.data_block.headers.len(), + self.data_block.records.len(), + self.resolution, + self.cache_max_size, + ), + c.cloned().collect(), + self.attr_rows_map.clone(), + )); } - synthesized_record + rows_synthesizers } #[inline] fn count_not_used_attrs( &mut self, - synthesized_records: &SynthesizedRecordsSlice<'data_block>, - ) -> AvailableAttrsMap<'data_block> { + synthesized_records: &SynthesizedRecordsSlice, + ) -> AvailableAttrsMap { let mut available_attrs: AvailableAttrsMap = AvailableAttrsMap::default(); // go through the pairs (original_record, synthesized_record) and count how many // attributes were not used for (original_record, synthesized_record) in - izip!(self.records.iter(), synthesized_records.iter()) + izip!(self.data_block.records.iter(), synthesized_records.iter()) { for d in original_record.values.iter() { if !synthesized_record.contains(d) { - let attr = available_attrs.entry(d).or_insert(0); + let attr = available_attrs.entry(d.clone()).or_insert(0); *attr += 1; } } @@ -156,16 +144,16 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> fn calc_available_attrs( &mut self, - synthesized_records: &SynthesizedRecordsSlice<'data_block>, - ) -> AvailableAttrsMap<'data_block> { + synthesized_records: &SynthesizedRecordsSlice, + ) -> AvailableAttrsMap { let mut available_attrs = 
self.count_not_used_attrs(synthesized_records); - let resolution_f64 = self.context.get_resolution() as f64; + let resolution_f64 = self.resolution as f64; // add attributes for consolidation - for (attr, value) in self.context.get_attr_rows_map().iter() { + for (attr, value) in self.attr_rows_map.iter() { let n_rows = value.len(); - if n_rows >= self.context.get_resolution() { + if n_rows >= self.resolution { let target_attr_count = iround_down(n_rows as f64, resolution_f64) - (n_rows as isize) + match available_attrs.get(attr) { @@ -175,7 +163,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> if target_attr_count > 0 { // insert/update the final target count - available_attrs.insert(attr, target_attr_count); + available_attrs.insert(attr.clone(), target_attr_count); } else { // remove negative and zero values available_attrs.remove(attr); @@ -192,18 +180,17 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> #[inline] fn calc_not_allowed_attrs( &mut self, - available_attrs: &mut AvailableAttrsMap<'data_block>, - ) -> NotAllowedAttrSet<'data_block> { - self.context - .get_attr_rows_map() + available_attrs: &mut AvailableAttrsMap, + ) -> NotAllowedAttrSet { + self.attr_rows_map .keys() .filter_map(|attr| match available_attrs.get(attr) { // not on available attributes - None => Some(*attr), + None => Some(attr.clone()), Some(at) => { if *at <= 0 { // on available attributes, but count <= 0 - Some(*attr) + Some(attr.clone()) } else { None } @@ -215,18 +202,20 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> #[inline] fn consolidate_record( &mut self, - available_attrs: &mut AvailableAttrsMap<'data_block>, - current_seed: &SynthesizerSeedSlice<'data_block>, - ) -> SynthesizedRecord<'data_block> { + available_attrs: &mut AvailableAttrsMap, + current_seed: &SynthesizerSeedSlice, + context: &mut SynthesizerContext, + ) -> SynthesizedRecord { let mut not_allowed_attr_set: NotAllowedAttrSet = self.calc_not_allowed_attrs(available_attrs); let mut synthesized_record = SynthesizedRecord::default(); loop { - let next = self.context.sample_next_attr( + let next = context.sample_next_attr_from_seed( &synthesized_record, current_seed, &not_allowed_attr_set, + &self.attr_rows_map, ); match next { @@ -236,9 +225,9 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> if next_count <= 1 { available_attrs.remove(&value); - not_allowed_attr_set.insert(value); + not_allowed_attr_set.insert(value.clone()); } else { - available_attrs.insert(value, next_count - 1); + available_attrs.insert(value.clone(), next_count - 1); } synthesized_record.insert(value); } @@ -249,44 +238,43 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> fn consolidate<T>( &mut self, - synthesized_records: &mut SynthesizedRecords<'data_block>, + synthesized_records: &mut SynthesizedRecords, progress_reporter: &mut Option<T>, + context: &mut SynthesizerContext, ) where T: ReportProgress, { - measure_time!( - || { - info!("consolidating..."); + let _duration_logger = ElapsedDurationLogger::new("consolidation"); + info!("consolidating..."); - let mut available_attrs = self.calc_available_attrs(synthesized_records); - let current_seed: SynthesizerSeed = available_attrs.keys().cloned().collect(); - let total = available_attrs.len(); - let total_f64 = total as f64; - let mut n_processed = 0; + let mut available_attrs = self.calc_available_attrs(synthesized_records); + let current_seed: 
SynthesizerSeed = available_attrs.keys().cloned().collect(); + let total = available_attrs.len(); + let total_f64 = total as f64; + let mut n_processed = 0; - while !available_attrs.is_empty() { - self.update_consolidate_progress(n_processed, total_f64, progress_reporter); - synthesized_records - .push(self.consolidate_record(&mut available_attrs, &current_seed)); - n_processed = total - available_attrs.len(); - } - self.update_consolidate_progress(n_processed, total_f64, progress_reporter); - }, - (self.durations.consolidate) - ); + while !available_attrs.is_empty() { + self.update_consolidate_progress(n_processed, total_f64, progress_reporter); + synthesized_records.push(self.consolidate_record( + &mut available_attrs, + &current_seed, + context, + )); + n_processed = total - available_attrs.len(); + } + self.update_consolidate_progress(n_processed, total_f64, progress_reporter); } #[inline] fn count_synthesized_records_attrs( &mut self, - synthesized_records: &mut SynthesizedRecords<'data_block>, - ) -> FnvHashMap<&'data_block DataBlockValue, isize> { - let mut current_counts: FnvHashMap<&'data_block DataBlockValue, isize> = - FnvHashMap::default(); + synthesized_records: &mut SynthesizedRecords, + ) -> FnvHashMap<Arc<DataBlockValue>, isize> { + let mut current_counts: FnvHashMap<Arc<DataBlockValue>, isize> = FnvHashMap::default(); for r in synthesized_records.iter() { for v in r.iter() { - let count = current_counts.entry(v).or_insert(0); + let count = current_counts.entry(v.clone()).or_insert(0); *count += 1; } } @@ -296,16 +284,16 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> #[inline] fn calc_exceeded_count_attrs( &mut self, - current_counts: &FnvHashMap<&'data_block DataBlockValue, isize>, - ) -> FnvHashMap<&'data_block DataBlockValue, isize> { - let mut targets: FnvHashMap<&'data_block DataBlockValue, isize> = FnvHashMap::default(); + current_counts: &FnvHashMap<Arc<DataBlockValue>, isize>, + ) -> FnvHashMap<Arc<DataBlockValue>, isize> { + let mut targets: FnvHashMap<Arc<DataBlockValue>, isize> = FnvHashMap::default(); - for (attr, rows) in self.context.get_attr_rows_map().iter() { - if rows.len() >= self.context.get_resolution() { - let t = current_counts[attr] - - iround_down(rows.len() as f64, self.context.get_resolution() as f64); + for (attr, rows) in self.attr_rows_map.iter() { + if rows.len() >= self.resolution { + let t = + current_counts[attr] - iround_down(rows.len() as f64, self.resolution as f64); if t > 0 { - targets.insert(attr, t); + targets.insert(attr.clone(), t); } } } @@ -314,76 +302,73 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> fn suppress<T>( &mut self, - synthesized_records: &mut SynthesizedRecords<'data_block>, + synthesized_records: &mut SynthesizedRecords, progress_reporter: &mut Option<T>, ) where T: ReportProgress, { - measure_time!( - || { - info!("suppressing..."); + let _duration_logger = ElapsedDurationLogger::new("suppression"); + info!("suppressing..."); - let current_counts: FnvHashMap<&'data_block DataBlockValue, isize> = - self.count_synthesized_records_attrs(synthesized_records); - let mut targets: FnvHashMap<&'data_block DataBlockValue, isize> = - self.calc_exceeded_count_attrs(&current_counts); - let total = synthesized_records.len() as f64; - let mut n_processed = 0; + let current_counts: FnvHashMap<Arc<DataBlockValue>, isize> = + self.count_synthesized_records_attrs(synthesized_records); + let mut targets: FnvHashMap<Arc<DataBlockValue>, isize> = + self.calc_exceeded_count_attrs(&current_counts); + let total = synthesized_records.len() as f64; + let mut n_processed = 0; - synthesized_records.shuffle(&mut thread_rng()); - 
synthesized_records.shuffle(&mut thread_rng()); - for r in synthesized_records.iter_mut() { - let mut new_record = SynthesizedRecord::default(); + for r in synthesized_records.iter_mut() { + let mut new_record = SynthesizedRecord::default(); - self.update_suppress_progress(n_processed, total, progress_reporter); - n_processed += 1; - for attr in r.iter() { - match targets.get(attr).cloned() { - None => { - new_record.insert(attr); - } - Some(attr_count) => { - if attr_count == 1 { - targets.remove(attr); - } else { - targets.insert(attr, attr_count - 1); - } - } + self.update_suppress_progress(n_processed, total, progress_reporter); + n_processed += 1; + for attr in r.iter() { + match targets.get(attr).cloned() { + None => { + new_record.insert(attr.clone()); + } + Some(attr_count) => { + if attr_count == 1 { + targets.remove(attr); + } else { + targets.insert(attr.clone(), attr_count - 1); } } - *r = new_record; } - synthesized_records.retain(|r| !r.is_empty()); - self.update_suppress_progress(n_processed, total, progress_reporter); - }, - (self.durations.suppress) - ); + } + *r = new_record; + } + synthesized_records.retain(|r| !r.is_empty()); + self.update_suppress_progress(n_processed, total, progress_reporter); } fn synthesize_rows<T>( &mut self, - synthesized_records: &mut SynthesizedRecords<'data_block>, + synthesized_records: &mut SynthesizedRecords, + rows_synthesizers: &mut Vec<SeededRowsSynthesizer>, progress_reporter: &mut Option<T>, ) where T: ReportProgress, { - measure_time!( - || { - info!("synthesizing rows..."); + let _duration_logger = ElapsedDurationLogger::new("rows synthesis"); - let mut n_processed = 0; - let records = self.records; - let total = records.len() as f64; - - for seed in records.iter() { - self.update_synthesize_progress(n_processed, total, progress_reporter); - n_processed += 1; - synthesized_records.push(self.synthesize_row(seed)); - } - self.update_synthesize_progress(n_processed, total, progress_reporter); - }, - (self.durations.synthesize_rows) + info!( + "synthesizing rows with {} thread(s)...", + get_number_of_threads() ); + + let total = self.data_block.records.len() as f64; + + SeededRowsSynthesizer::synthesize_all( + total, + synthesized_records, + rows_synthesizers, + progress_reporter, + ); + + self.update_synthesize_progress(self.data_block.records.len(), total, progress_reporter); } #[inline] @@ -396,7 +381,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> T: ReportProgress, { if let Some(r) = progress_reporter { - self.synthesize_percentage = calc_percentage(n_processed, total); + self.synthesize_percentage = calc_percentage(n_processed as f64, total); r.report(self.calc_overall_progress()); } } @@ -411,7 +396,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> T: ReportProgress, { if let Some(r) = progress_reporter { - self.consolidate_percentage = calc_percentage(n_processed, total); + self.consolidate_percentage = calc_percentage(n_processed as f64, total); r.report(self.calc_overall_progress()); } } @@ -426,7 +411,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> T: ReportProgress, { if let Some(r) = progress_reporter { - self.suppress_percentage = calc_percentage(n_processed, total); + self.suppress_percentage = calc_percentage(n_processed as f64, total); r.report(self.calc_overall_progress()); } } @@ -438,9 +423,3 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> + self.suppress_percentage * 0.2 } } - -impl<'data_block, 
'attr_rows_map> Drop for SeededSynthesizer<'data_block, 'attr_rows_map> { - fn drop(&mut self) { - trace!("seed synthesizer durations: {:#?}", self.durations); - } -} diff --git a/packages/core/src/processing/generator/synthesizer/seeded_rows_synthesizer.rs b/packages/core/src/processing/generator/synthesizer/seeded_rows_synthesizer.rs new file mode 100644 index 0000000..e9f85fc --- /dev/null +++ b/packages/core/src/processing/generator/synthesizer/seeded_rows_synthesizer.rs @@ -0,0 +1,127 @@ +use super::{ + context::SynthesizerContext, + typedefs::{NotAllowedAttrSet, SynthesizedRecord, SynthesizedRecords, SynthesizerSeed}, +}; +use crate::{ + data_block::typedefs::AttributeRowsMap, + utils::reporting::{SendableProgressReporter, SendableProgressReporterRef}, +}; +use std::sync::Arc; + +#[cfg(feature = "rayon")] +use rayon::prelude::*; + +#[cfg(feature = "rayon")] +use std::sync::Mutex; + +use crate::{ + data_block::{record::DataBlockRecord, typedefs::DataBlockRecords}, + utils::reporting::ReportProgress, +}; + +pub struct SeededRowsSynthesizer { + pub context: SynthesizerContext, + pub records: DataBlockRecords, + pub attr_rows_map: Arc<AttributeRowsMap>, +} + +impl SeededRowsSynthesizer { + #[inline] + pub fn new( + context: SynthesizerContext, + records: DataBlockRecords, + attr_rows_map: Arc<AttributeRowsMap>, + ) -> SeededRowsSynthesizer { + SeededRowsSynthesizer { + context, + records, + attr_rows_map, + } + } + + #[cfg(feature = "rayon")] + #[inline] + pub fn synthesize_all<T>( + total: f64, + synthesized_records: &mut SynthesizedRecords, + rows_synthesizers: &mut Vec<SeededRowsSynthesizer>, + progress_reporter: &mut Option<T>, + ) where + T: ReportProgress, + { + let sendable_pr = Arc::new(Mutex::new( + progress_reporter + .as_mut() + .map(|r| SendableProgressReporter::new(total, 0.5, r)), + )); + + synthesized_records.par_extend( + rows_synthesizers + .par_iter_mut() + .flat_map(|rs| rs.synthesize_rows(&mut sendable_pr.clone())), + ); + } + + #[cfg(not(feature = "rayon"))] + #[inline] + pub fn synthesize_all<T>( + total: f64, + synthesized_records: &mut SynthesizedRecords, + rows_synthesizers: &mut Vec<SeededRowsSynthesizer>, + progress_reporter: &mut Option<T>, + ) where + T: ReportProgress, + { + let mut sendable_pr = progress_reporter + .as_mut() + .map(|r| SendableProgressReporter::new(total, 0.5, r)); + + synthesized_records.extend( + rows_synthesizers + .iter_mut() + .flat_map(|rs| rs.synthesize_rows(&mut sendable_pr)), + ); + } + + #[inline] + fn synthesize_rows<T>( + &mut self, + progress_reporter: &mut SendableProgressReporterRef<T>, + ) -> SynthesizedRecords + where + T: ReportProgress, + { + let mut synthesized_records = SynthesizedRecords::default(); + let records = self.records.clone(); + + for seed in records.iter() { + synthesized_records.push(self.synthesize_row(seed)); + SendableProgressReporter::update_progress(progress_reporter, 1.0); + } + synthesized_records + } + + #[inline] + fn synthesize_row(&mut self, seed: &DataBlockRecord) -> SynthesizedRecord { + let current_seed: &SynthesizerSeed = &seed.values; + let mut synthesized_record = SynthesizedRecord::default(); + let not_allowed_attr_set = NotAllowedAttrSet::default(); + + loop { + let next = self.context.sample_next_attr_from_seed( + &synthesized_record, + current_seed, + &not_allowed_attr_set, + &self.attr_rows_map, + ); + + match next { + Some(value) => { + synthesized_record.insert(value); + } + None => break, + } + } + synthesized_record + } +} diff --git a/packages/core/src/processing/generator/synthesizer/typedefs.rs b/packages/core/src/processing/generator/synthesizer/typedefs.rs index 
3b2a556..c67dd40 100644 --- a/packages/core/src/processing/generator/synthesizer/typedefs.rs +++ b/packages/core/src/processing/generator/synthesizer/typedefs.rs @@ -1,33 +1,31 @@ use fnv::{FnvHashMap, FnvHashSet}; -use std::collections::BTreeSet; +use std::{collections::BTreeSet, sync::Arc}; use crate::data_block::value::DataBlockValue; /// If the data block value were added to the synthesized record, this maps to the /// number of rows the resulting attribute combination will be part of -pub type AttributeCountMap<'data_block_value> = - FnvHashMap<&'data_block_value DataBlockValue, usize>; +pub type AttributeCountMap = FnvHashMap<Arc<DataBlockValue>, usize>; /// The seeds used for the current synthesis step (aka current record being processed) -pub type SynthesizerSeed<'data_block_value> = Vec<&'data_block_value DataBlockValue>; +pub type SynthesizerSeed = Vec<Arc<DataBlockValue>>; /// Slice of SynthesizerSeed -pub type SynthesizerSeedSlice<'data_block_value> = [&'data_block_value DataBlockValue]; +pub type SynthesizerSeedSlice = [Arc<DataBlockValue>]; /// Attributes not allowed to be sampled in a particular data synthesis stage -pub type NotAllowedAttrSet<'data_block_value> = FnvHashSet<&'data_block_value DataBlockValue>; +pub type NotAllowedAttrSet = FnvHashSet<Arc<DataBlockValue>>; /// Record synthesized at a particular stage -pub type SynthesizedRecord<'data_block_value> = BTreeSet<&'data_block_value DataBlockValue>; +pub type SynthesizedRecord = BTreeSet<Arc<DataBlockValue>>; /// Vector of synthesized records -pub type SynthesizedRecords<'data_block_value> = Vec<SynthesizedRecord<'data_block_value>>; +pub type SynthesizedRecords = Vec<SynthesizedRecord>; /// Slice of SynthesizedRecords -pub type SynthesizedRecordsSlice<'data_block_value> = [SynthesizedRecord<'data_block_value>]; +pub type SynthesizedRecordsSlice = [SynthesizedRecord]; /// How many attributes (of the ones available) should be added during the /// consolidation step for the count to match a multiple of reporting resolution /// (rounded down) -pub type AvailableAttrsMap<'data_block_value> = - FnvHashMap<&'data_block_value DataBlockValue, isize>; +pub type AvailableAttrsMap = FnvHashMap<Arc<DataBlockValue>, isize>; diff --git a/packages/core/src/processing/generator/synthesizer/unseeded.rs b/packages/core/src/processing/generator/synthesizer/unseeded.rs new file mode 100644 index 0000000..ed77597 --- /dev/null +++ b/packages/core/src/processing/generator/synthesizer/unseeded.rs @@ -0,0 +1,163 @@ +use super::{ + context::SynthesizerContext, typedefs::SynthesizedRecords, + unseeded_rows_synthesizer::UnseededRowsSynthesizer, +}; +use log::info; +use std::sync::Arc; + +use crate::{ + data_block::{block::DataBlock, typedefs::AttributeRowsByColumnMap}, + utils::{ + math::calc_percentage, reporting::ReportProgress, threading::get_number_of_threads, + time::ElapsedDurationLogger, + }, +}; + +/// Represents all the information required to perform the unseeded data synthesis +pub struct UnseededSynthesizer { + /// Reference to the original data block + data_block: Arc<DataBlock>, + /// Maps a data block value to all the rows where it occurs grouped by column + attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>, + /// Reporting resolution used for data synthesis + resolution: usize, + /// Maximum cache size allowed + cache_max_size: usize, + /// Empty values on the synthetic data will be represented by this + empty_value: Arc<String>, + /// Percentage already completed on the row synthesis step + synthesize_percentage: f64, +} + +impl UnseededSynthesizer { + /// Returns a new UnseededSynthesizer + /// # Arguments + /// * `data_block` - Sensitive data to be synthesized + /// * `attr_rows_map_by_column` - Maps a data block value to all 
the rows where it occurs grouped by column + /// * `resolution` - Reporting resolution used for data synthesis + /// * `cache_max_size` - Maximum cache size allowed + /// * `empty_value` - Empty values on the synthetic data will be represented by this + #[inline] + pub fn new( + data_block: Arc<DataBlock>, + attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>, + resolution: usize, + cache_max_size: usize, + empty_value: Arc<String>, + ) -> UnseededSynthesizer { + UnseededSynthesizer { + data_block, + attr_rows_map_by_column, + resolution, + cache_max_size, + empty_value, + synthesize_percentage: 0.0, + } + } + + /// Performs the row synthesis + /// Returns the synthesized records + /// # Arguments + /// * `progress_reporter` - Will be used to report the processing + /// progress (`ReportProgress` trait). If `None`, nothing will be reported + pub fn run<T>(&mut self, progress_reporter: &mut Option<T>) -> SynthesizedRecords + where + T: ReportProgress, + { + let mut synthesized_records: SynthesizedRecords = SynthesizedRecords::new(); + + if !self.data_block.records.is_empty() { + let mut rows_synthesizers: Vec<UnseededRowsSynthesizer> = + self.build_rows_synthesizers(); + + self.synthesize_percentage = 0.0; + + self.synthesize_rows( + &mut synthesized_records, + &mut rows_synthesizers, + progress_reporter, + ); + } + synthesized_records + } + + #[inline] + fn build_rows_synthesizers(&self) -> Vec<UnseededRowsSynthesizer> { + let mut total_size = self.data_block.records.len(); + let chunk_size = ((total_size as f64) / (get_number_of_threads() as f64)).ceil() as usize; + let mut rows_synthesizers: Vec<UnseededRowsSynthesizer> = Vec::default(); + + loop { + if total_size > chunk_size { + rows_synthesizers.push(UnseededRowsSynthesizer::new( + SynthesizerContext::new( + self.data_block.headers.len(), + self.data_block.records.len(), + self.resolution, + self.cache_max_size, + ), + chunk_size, + self.attr_rows_map_by_column.clone(), + self.empty_value.clone(), + )); + total_size -= chunk_size; + } else { + rows_synthesizers.push(UnseededRowsSynthesizer::new( + SynthesizerContext::new( + self.data_block.headers.len(), + self.data_block.records.len(), + self.resolution, + self.cache_max_size, + ), + total_size, + self.attr_rows_map_by_column.clone(), + self.empty_value.clone(), + )); + break; + } + } + rows_synthesizers + } + + fn synthesize_rows<T>( + &mut self, + synthesized_records: &mut SynthesizedRecords, + rows_synthesizers: &mut Vec<UnseededRowsSynthesizer>, + progress_reporter: &mut Option<T>, + ) where + T: ReportProgress, + { + let _duration_logger = ElapsedDurationLogger::new("rows synthesis"); + + info!( + "synthesizing rows with {} thread(s)...", + get_number_of_threads() + ); + + let total = self.data_block.records.len() as f64; + + UnseededRowsSynthesizer::synthesize_all( + total, + synthesized_records, + rows_synthesizers, + progress_reporter, + ); + + self.update_synthesize_progress(self.data_block.records.len(), total, progress_reporter); + } + + #[inline] + fn update_synthesize_progress<T>( + &mut self, + n_processed: usize, + total: f64, + progress_reporter: &mut Option<T>, + ) where + T: ReportProgress, + { + if let Some(r) = progress_reporter { + self.synthesize_percentage = calc_percentage(n_processed as f64, total); + r.report(self.synthesize_percentage); + } + } +} diff --git a/packages/core/src/processing/generator/synthesizer/unseeded_rows_synthesizer.rs b/packages/core/src/processing/generator/synthesizer/unseeded_rows_synthesizer.rs new file mode 100644 index 0000000..75fe1e9 --- /dev/null +++ b/packages/core/src/processing/generator/synthesizer/unseeded_rows_synthesizer.rs @@ -0,0 +1,128 @@ +use super::{ + 
context::SynthesizerContext, + typedefs::{SynthesizedRecord, SynthesizedRecords}, +}; +use crate::{ + data_block::typedefs::{AttributeRows, AttributeRowsByColumnMap}, + utils::reporting::{SendableProgressReporter, SendableProgressReporterRef}, +}; +use rand::{prelude::SliceRandom, thread_rng}; +use std::sync::Arc; + +#[cfg(feature = "rayon")] +use rayon::prelude::*; + +#[cfg(feature = "rayon")] +use std::sync::Mutex; + +use crate::utils::reporting::ReportProgress; + +pub struct UnseededRowsSynthesizer { + context: SynthesizerContext, + chunk_size: usize, + column_indexes: Vec<usize>, + attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>, + empty_value: Arc<String>, +} + +impl UnseededRowsSynthesizer { + #[inline] + pub fn new( + context: SynthesizerContext, + chunk_size: usize, + attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>, + empty_value: Arc<String>, + ) -> UnseededRowsSynthesizer { + UnseededRowsSynthesizer { + context, + chunk_size, + column_indexes: attr_rows_map_by_column.keys().cloned().collect(), + attr_rows_map_by_column, + empty_value, + } + } + + #[cfg(feature = "rayon")] + #[inline] + pub fn synthesize_all<T>( + total: f64, + synthesized_records: &mut SynthesizedRecords, + rows_synthesizers: &mut Vec<UnseededRowsSynthesizer>, + progress_reporter: &mut Option<T>, + ) where + T: ReportProgress, + { + let sendable_pr = Arc::new(Mutex::new( + progress_reporter + .as_mut() + .map(|r| SendableProgressReporter::new(total, 1.0, r)), + )); + + synthesized_records.par_extend( + rows_synthesizers + .par_iter_mut() + .flat_map(|rs| rs.synthesize_rows(&mut sendable_pr.clone())), + ); + } + + #[cfg(not(feature = "rayon"))] + #[inline] + pub fn synthesize_all<T>( + total: f64, + synthesized_records: &mut SynthesizedRecords, + rows_synthesizers: &mut Vec<UnseededRowsSynthesizer>, + progress_reporter: &mut Option<T>, + ) where + T: ReportProgress, + { + let mut sendable_pr = progress_reporter + .as_mut() + .map(|r| SendableProgressReporter::new(total, 1.0, r)); + + synthesized_records.extend( + rows_synthesizers + .iter_mut() + .flat_map(|rs| rs.synthesize_rows(&mut sendable_pr)), + ); + } + + #[inline] + fn synthesize_rows<T>( + &mut self, + progress_reporter: &mut SendableProgressReporterRef<T>, + ) -> SynthesizedRecords + where + T: ReportProgress, + { + let mut synthesized_records = SynthesizedRecords::default(); + + for _ in 0..self.chunk_size { + synthesized_records.push(self.synthesize_row()); + SendableProgressReporter::update_progress(progress_reporter, 1.0); + } + synthesized_records + } + + #[inline] + fn synthesize_row(&mut self) -> SynthesizedRecord { + let mut synthesized_record = SynthesizedRecord::default(); + let mut current_attrs_rows: Arc<AttributeRows> = + Arc::new((0..self.context.records_len).collect()); + + self.column_indexes.shuffle(&mut thread_rng()); + + for column_index in self.column_indexes.iter() { + if let Some((next_attrs_rows, sample)) = self.context.sample_next_attr_from_column( + &synthesized_record, + *column_index, + &self.attr_rows_map_by_column, + &current_attrs_rows, + &self.empty_value, + ) { + current_attrs_rows = next_attrs_rows; + synthesized_record.insert(sample); + } + } + synthesized_record + } +} diff --git a/packages/core/src/utils/collections.rs b/packages/core/src/utils/collections.rs index f708d51..3411b55 100644 --- a/packages/core/src/utils/collections.rs +++ b/packages/core/src/utils/collections.rs @@ -4,8 +4,8 @@ use std::hash::Hash; /// Given two sorted vectors, calculates the intersection between them. 
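For orientation on the helper documented here: this diff only rewords the doc comment of `ordered_vec_intersection`, so its body is not shown. A routine like this is typically written as a two-pointer walk over both sorted slices; the sketch below is illustrative only, and its `Ord + Clone` bounds are an assumption rather than the crate's actual signature.

```rust
use std::cmp::Ordering;

/// Illustrative two-pointer intersection of two sorted slices,
/// running in O(a.len() + b.len()). Not the crate's actual body.
fn ordered_intersection_sketch<T: Ord + Clone>(a: &[T], b: &[T]) -> Vec<T> {
    let (mut i, mut j) = (0, 0);
    let mut result = Vec::new();
    while i < a.len() && j < b.len() {
        match a[i].cmp(&b[j]) {
            // advance whichever side holds the smaller element
            Ordering::Less => i += 1,
            Ordering::Greater => j += 1,
            Ordering::Equal => {
                result.push(a[i].clone());
                i += 1;
                j += 1;
            }
        }
    }
    result
}
```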
/// Returns the result intersection as a new vector /// # Arguments -/// - `a`: first sorted vector -/// - `b`: second sorted vector +/// * `a` - first sorted vector +/// * `b` - second sorted vector #[inline] pub fn ordered_vec_intersection<T>(a: &[T], b: &[T]) -> Vec<T> where diff --git a/packages/core/src/utils/math.rs b/packages/core/src/utils/math.rs index 29dd51c..0cc7946 100644 --- a/packages/core/src/utils/math.rs +++ b/packages/core/src/utils/math.rs @@ -54,6 +54,6 @@ pub fn calc_n_combinations_range(n: usize, range: &[usize]) -> u64 { /// Calculates the percentage of processed elements up to a total #[inline] -pub fn calc_percentage(n_processed: usize, total: f64) -> f64 { - (n_processed as f64) * 100.0 / total +pub fn calc_percentage(n_processed: f64, total: f64) -> f64 { + n_processed * 100.0 / total } diff --git a/packages/core/src/utils/mod.rs b/packages/core/src/utils/mod.rs index 0213097..fe65517 100644 --- a/packages/core/src/utils/mod.rs +++ b/packages/core/src/utils/mod.rs @@ -1,8 +1,14 @@ /// Module for collection utilities pub mod collections; + /// Module for math utilities pub mod math; + /// Module for reporting utilities pub mod reporting; + +/// Module for threading utilities +pub mod threading; + /// Module for time utilities pub mod time; diff --git a/packages/core/src/utils/reporting.rs b/packages/core/src/utils/reporting/logger_progress_reporter.rs similarity index 78% rename from packages/core/src/utils/reporting.rs rename to packages/core/src/utils/reporting/logger_progress_reporter.rs index 340eb72..6045661 100644 --- a/packages/core/src/utils/reporting.rs +++ b/packages/core/src/utils/reporting/logger_progress_reporter.rs @@ -1,11 +1,6 @@ +use super::ReportProgress; use log::{log, Level}; -/// Implement this trait to inform progress -pub trait ReportProgress { - /// Receives the updated progress - fn report(&mut self, new_progress: f64); -} - /// Simple progress reporter using the default logger. 
/// * It will log progress using the configured `log_level` /// * It will only log at every 1% completed @@ -16,8 +11,8 @@ pub struct LoggerProgressReporter { impl LoggerProgressReporter { /// Returns a new LoggerProgressReporter - /// # Arguments: - /// * `log_level`: which log level use to log progress + /// # Arguments + /// * `log_level` - which log level to use to log progress pub fn new(log_level: Level) -> LoggerProgressReporter { LoggerProgressReporter { progress: 0.0, diff --git a/packages/core/src/utils/reporting/mod.rs b/packages/core/src/utils/reporting/mod.rs new file mode 100644 index 0000000..a330534 --- /dev/null +++ b/packages/core/src/utils/reporting/mod.rs @@ -0,0 +1,11 @@ +mod report_progress; + +mod logger_progress_reporter; + +mod sendable_progress_reporter; + +pub use report_progress::ReportProgress; + +pub use logger_progress_reporter::LoggerProgressReporter; + +pub use sendable_progress_reporter::{SendableProgressReporter, SendableProgressReporterRef}; diff --git a/packages/core/src/utils/reporting/report_progress.rs b/packages/core/src/utils/reporting/report_progress.rs new file mode 100644 index 0000000..de9158b --- /dev/null +++ b/packages/core/src/utils/reporting/report_progress.rs @@ -0,0 +1,5 @@ +/// Implement this trait to inform progress +pub trait ReportProgress { + /// Receives the updated progress + fn report(&mut self, new_progress: f64); +} diff --git a/packages/core/src/utils/reporting/sendable_progress_reporter.rs b/packages/core/src/utils/reporting/sendable_progress_reporter.rs new file mode 100644 index 0000000..ddf0f6d --- /dev/null +++ b/packages/core/src/utils/reporting/sendable_progress_reporter.rs @@ -0,0 +1,92 @@ +#[cfg(feature = "rayon")] +use std::sync::{Arc, Mutex}; + +use crate::utils::{math::calc_percentage, reporting::ReportProgress}; + +/// Progress reporter to be used in parallel environments +/// (should be wrapped with `Arc` and `Mutex` if multi-threading is enabled). 
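A minimal sketch of the wrapping pattern this doc comment describes, mirroring how the `synthesize_all` functions elsewhere in this diff drive it. It assumes the `rayon` feature is enabled (so `SendableProgressReporterRef` is the `Arc<Mutex<...>>` alias) and that the crate's `reporting` exports are in scope; `StdoutReporter` is a made-up main reporter for illustration.

```rust
use std::sync::{Arc, Mutex};

// Hypothetical main reporter; anything implementing ReportProgress works.
struct StdoutReporter;

impl ReportProgress for StdoutReporter {
    fn report(&mut self, new_progress: f64) {
        println!("progress: {:.1}%", new_progress);
    }
}

fn drive_progress(main_reporter: &mut StdoutReporter) {
    // 100 work items, scaled to half of the overall progress bar.
    let sendable_pr = Arc::new(Mutex::new(Some(SendableProgressReporter::new(
        100.0, 0.5, main_reporter,
    ))));
    // Each worker clones the Arc and bumps the shared counter by 1.0 per item.
    SendableProgressReporter::update_progress(&mut sendable_pr.clone(), 1.0);
}
```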
+/// It will calculate `proportion * (n_processed * 100.0 / total)` and report +/// it to the main reporter pub struct SendableProgressReporter<'main_reporter, T> +where + T: ReportProgress, +{ + total: f64, + n_processed: f64, + proportion: f64, + main_reporter: &'main_reporter mut T, +} + +impl<'main_reporter, T> SendableProgressReporter<'main_reporter, T> +where + T: ReportProgress, +{ + /// Creates a new SendableProgressReporter + /// # Arguments + /// * `total` - Total number of steps to be reported + /// * `proportion` - Value to multiply the percentage before reporting to + /// main_reporter + /// * `main_reporter` - Main reporter to which this should report + pub fn new( + total: f64, + proportion: f64, + main_reporter: &'main_reporter mut T, + ) -> SendableProgressReporter<'main_reporter, T> { + SendableProgressReporter { + total, + proportion, + n_processed: 0.0, + main_reporter, + } + } + + /// Updates `reporter` by adding `value_to_add` and reporting + /// to the main reporter in a thread safe way + #[inline] + pub fn update_progress(reporter: &mut SendableProgressReporterRef<T>, value_to_add: f64) + where + T: ReportProgress, + { + #[cfg(feature = "rayon")] + if let Ok(guard) = &mut reporter.lock() { + if let Some(r) = guard.as_mut() { + r.report(value_to_add); + } + } + #[cfg(not(feature = "rayon"))] + if let Some(r) = reporter { + r.report(value_to_add); + } + } +} + +impl<'main_reporter, T> ReportProgress for SendableProgressReporter<'main_reporter, T> +where + T: ReportProgress, +{ + /// Will add `value_to_add` to `n_processed` and call the main reporter with + /// `proportion * (n_processed * 100.0 / total)` + fn report(&mut self, value_to_add: f64) { + self.n_processed += value_to_add; + self.main_reporter + .report(self.proportion * calc_percentage(self.n_processed, self.total)) + } +} + +#[cfg(feature = "rayon")] +/// People using this should correctly handle race conditions with a `Mutex` +/// (see `SendableProgressReporterRef`), so we inform the compiler this struct is thread safe +unsafe impl<'main_reporter, T> Send for SendableProgressReporter<'main_reporter, T> where + T: ReportProgress +{ +} + +#[cfg(feature = "rayon")] +/// Use this to refer to SendableProgressReporter if multi-threading is enabled +pub type SendableProgressReporterRef<'main_reporter, T> = + Arc<Mutex<Option<SendableProgressReporter<'main_reporter, T>>>>; + +#[cfg(not(feature = "rayon"))] +/// Use this to refer to SendableProgressReporter if multi-threading is disabled +pub type SendableProgressReporterRef<'main_reporter, T> = + Option<SendableProgressReporter<'main_reporter, T>>; diff --git a/packages/core/src/utils/threading.rs b/packages/core/src/utils/threading.rs new file mode 100644 index 0000000..de3baa5 --- /dev/null +++ b/packages/core/src/utils/threading.rs @@ -0,0 +1,34 @@ +#[cfg(feature = "rayon")] +use rayon; + +#[cfg(feature = "pyo3")] +use pyo3::prelude::*; + +pub fn get_number_of_threads() -> usize { + #[cfg(feature = "rayon")] + { + rayon::current_num_threads() + } + #[cfg(not(feature = "rayon"))] + { + 1 + } +} + +#[cfg(feature = "rayon")] +#[cfg_attr(feature = "pyo3", pyfunction)] +#[allow(unused_must_use)] +/// Sets the number of threads used for parallel processing +/// # Arguments +/// * `n` - number of threads +pub fn set_number_of_threads(n: usize) { + rayon::ThreadPoolBuilder::new() + .num_threads(n) + .build_global(); +} + +#[cfg(feature = "pyo3")] +pub fn register(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(set_number_of_threads, m)?)?; + Ok(()) +} diff --git a/packages/core/src/utils/time.rs b/packages/core/src/utils/time.rs index 41fd7ab..fd0b0d9 
100644 --- a/packages/core/src/utils/time.rs +++ b/packages/core/src/utils/time.rs @@ -16,7 +16,7 @@ pub struct ElapsedDuration<'lifetime> { impl<'lifetime> ElapsedDuration<'lifetime> { /// Returns a new ElapsedDuration /// # Arguments - /// - `result`: Duration where to sum the elapsed duration when the `ElapsedDuration` + /// * `result` - Duration where to sum the elapsed duration when the `ElapsedDuration` /// instance goes out of scope pub fn new(result: &mut Duration) -> ElapsedDuration { ElapsedDuration { @@ -41,10 +41,10 @@ pub struct ElapsedDurationLogger { } impl ElapsedDurationLogger { - pub fn new(message: String) -> ElapsedDurationLogger { + pub fn new<S: Into<String>>(message: S) -> ElapsedDurationLogger { ElapsedDurationLogger { _start: Instant::now(), - _message: message, + _message: message.into(), } } } diff --git a/packages/lib-python/Cargo.toml b/packages/lib-python/Cargo.toml index 61f630c..0aa5061 100644 --- a/packages/lib-python/Cargo.toml +++ b/packages/lib-python/Cargo.toml @@ -14,4 +14,5 @@ crate-type = ["cdylib"] log = { version = "0.4", features = ["std"] } csv = { version = "1.1" } pyo3 = { version = "0.15", features = ["extension-module"] } -sds-core = { path = "../core", features = ["pyo3"] } \ No newline at end of file +sds-core = { path = "../core", features = ["pyo3", "rayon"] } +env_logger = { version = "0.9" } \ No newline at end of file diff --git a/packages/lib-python/src/data_processor.rs b/packages/lib-python/src/data_processor.rs new file mode 100644 index 0000000..8084437 --- /dev/null +++ b/packages/lib-python/src/data_processor.rs @@ -0,0 +1,133 @@ +use csv::ReaderBuilder; +use pyo3::prelude::*; +use sds_core::{ + data_block::{ + block::DataBlock, csv_block_creator::CsvDataBlockCreator, csv_io_error::CsvIOError, + data_block_creator::DataBlockCreator, + }, + processing::{ + aggregator::{aggregated_data::AggregatedData, Aggregator}, + generator::{generated_data::GeneratedData, Generator, SynthesisMode}, + }, + utils::reporting::LoggerProgressReporter, +}; +use std::sync::Arc; + +#[pyclass] +/// Processor exposing the main features +pub struct SDSProcessor { + data_block: Arc<DataBlock>, +} + +#[pymethods] +impl SDSProcessor { + /// Creates a new processor by reading the content of a given file + /// # Arguments + /// * `path` - File to be read to build the data block + /// * `delimiter` - CSV/TSV separator for the content on `path` + /// * `use_columns` - Column names to be used (if `[]` use all columns) + /// * `sensitive_zeros` - Column names containing sensitive zeros + /// (if `[]` no columns are considered to have sensitive zeros) + /// * `record_limit` - Use only the first `record_limit` records (if `0` use all records) + #[new] + pub fn new( + path: &str, + delimiter: char, + use_columns: Vec<String>, + sensitive_zeros: Vec<String>, + record_limit: usize, + ) -> Result<SDSProcessor, CsvIOError> { + match CsvDataBlockCreator::create( + ReaderBuilder::new() + .delimiter(delimiter as u8) + .from_path(path), + &use_columns, + &sensitive_zeros, + record_limit, + ) { + Ok(data_block) => Ok(SDSProcessor { data_block }), + Err(err) => Err(CsvIOError::new(err)), + } + } + + #[staticmethod] + /// Load the SDS Processor for the data block linked to `aggregated_data` + pub fn from_aggregated_data(aggregated_data: &AggregatedData) -> SDSProcessor { + SDSProcessor { + data_block: aggregated_data.data_block.clone(), + } + } + + #[inline] + /// Returns the number of records on the data block + pub fn number_of_records(&self) -> usize { + self.data_block.number_of_records() + } + + #[inline] + /// Returns the number of 
records on the data block protected by `resolution` + pub fn protected_number_of_records(&self, resolution: usize) -> usize { + self.data_block.protected_number_of_records(resolution) + } + + #[inline] + /// Normalizes the reporting length based on the number of selected headers. + /// Returns the normalized value + /// # Arguments + /// * `reporting_length` - Reporting length to be normalized (0 means use all columns) + pub fn normalize_reporting_length(&self, reporting_length: usize) -> usize { + self.data_block.normalize_reporting_length(reporting_length) + } + + /// Builds the aggregated data for the content + /// using the specified `reporting_length` and `sensitivity_threshold`. + /// The result is written to `sensitive_aggregates_path` and `reportable_aggregates_path`. + /// # Arguments + /// * `reporting_length` - Maximum length to compute attribute combinations + /// * `sensitivity_threshold` - Sensitivity threshold to filter record attributes + /// (0 means no suppression) + pub fn aggregate( + &self, + reporting_length: usize, + sensitivity_threshold: usize, + ) -> AggregatedData { + let mut progress_reporter: Option<LoggerProgressReporter> = None; + let mut aggregator = Aggregator::new(self.data_block.clone()); + aggregator.aggregate( + reporting_length, + sensitivity_threshold, + &mut progress_reporter, + ) + } + + /// Synthesizes the content using the specified `resolution` and + /// returns the generated data + /// # Arguments + /// * `cache_max_size` - Maximum cache size used during the synthesis process + /// * `resolution` - Reporting resolution to be used + /// * `empty_value` - Empty values on the synthetic data will be represented by this + /// * `seeded` - True for seeded synthesis, False for unseeded + pub fn generate( + &self, + cache_max_size: usize, + resolution: usize, + empty_value: String, + seeded: bool, + ) -> GeneratedData { + let mut progress_reporter: Option<LoggerProgressReporter> = None; + let mut generator = Generator::new(self.data_block.clone()); + let mode = if seeded { + SynthesisMode::Seeded + } else { + SynthesisMode::Unseeded + }; + + generator.generate( + resolution, + cache_max_size, + empty_value, + mode, + &mut progress_reporter, + ) + } +} diff --git a/packages/lib-python/src/lib.rs b/packages/lib-python/src/lib.rs index 92b2673..9b7b574 100644 --- a/packages/lib-python/src/lib.rs +++ b/packages/lib-python/src/lib.rs @@ -1,160 +1,24 @@ //! This crate will generate python bindings for the main features //! of the `sds_core` library. 
- -use csv::ReaderBuilder; -use pyo3::{exceptions::PyIOError, prelude::*}; +use data_processor::SDSProcessor; +use pyo3::prelude::*; use sds_core::{ - data_block::{ - block::{DataBlock, DataBlockCreator}, - csv_block_creator::CsvDataBlockCreator, - }, - processing::{ - aggregator::{typedefs::AggregatedCountByLenMap, Aggregator}, - generator::Generator, - }, - utils::reporting::LoggerProgressReporter, + processing::{aggregator::aggregated_data, evaluator}, + utils::threading, }; -/// Creates a data block by reading the content of a given file -/// # Arguments -/// * `path` - File to be read to build the data block -/// * `delimiter` - CSV/TSV separator for the content on `path` -/// * `use_columns` - Column names to be used (if `[]` use all columns) -/// * `sensitive_zeros` - Column names containing sensitive zeros -/// (if `[]` no columns are considered to have sensitive zeros) -/// * `record_limit` - Use only the first `record_limit` records (if `0` use all records) -#[pyfunction] -pub fn create_data_block_from_file( - path: &str, - delimiter: char, - use_columns: Vec, - sensitive_zeros: Vec, - record_limit: usize, -) -> PyResult { - match CsvDataBlockCreator::create( - ReaderBuilder::new() - .delimiter(delimiter as u8) - .from_path(path), - &use_columns, - &sensitive_zeros, - record_limit, - ) { - Ok(data_block) => Ok(data_block), - Err(err) => Err(PyIOError::new_err(format!( - "error reading data block: {}", - err - ))), - } -} - -/// Builds the protected and aggregated data for the content in `data_block` -/// using the specified `reporting_length` and `resolution`. -/// The result is written to `sensitive_aggregates_path` and `reportable_aggregates_path`. -/// -/// The function returns a tuple (combo_counts, rare_counts). -/// -/// * `combo_counts` - computed number of distinct combinations grouped by combination length (not protected) -/// * `rare_counts` - computed number of rare combinations (`count < resolution`) grouped by combination length (not protected) -/// -/// # Arguments -/// * `data_block` - Data block with the content to be aggregated -/// * `sensitive_aggregates_path` - Path to write the sensitive aggregates -/// * `reportable_aggregates_path` - Path to write the aggregates with protected count -/// (rounded down to the nearest multiple of `resolution`) -/// * `delimiter` - CSV/TSV separator for the content be written in -/// `sensitive_aggregates_path` and `reportable_aggregates_path` -/// * `reporting_length` - Maximum length to compute attribute combinations -/// * `resolution` - Reporting resolution to be used -#[pyfunction] -pub fn aggregate_and_write( - data_block: &DataBlock, - sensitive_aggregates_path: &str, - reportable_aggregates_path: &str, - delimiter: char, - reporting_length: usize, - resolution: usize, -) -> PyResult<(AggregatedCountByLenMap, AggregatedCountByLenMap)> { - let mut progress_reporter: Option = None; - let mut aggregator = Aggregator::new(data_block); - let mut aggregated_data = aggregator.aggregate(reporting_length, 0, &mut progress_reporter); - let combo_count = aggregator.calc_combinations_count_by_len(&aggregated_data.aggregates_count); - let rare_count = aggregator - .calc_rare_combinations_count_by_len(&aggregated_data.aggregates_count, resolution); - - if let Err(err) = aggregator.write_aggregates_count( - &aggregated_data.aggregates_count, - sensitive_aggregates_path, - delimiter, - resolution, - false, - ) { - return Err(PyIOError::new_err(format!( - "error writing sensitive aggregates: {}", - err - ))); - } - - 
Aggregator::protect_aggregates_count(&mut aggregated_data.aggregates_count, resolution); - - match aggregator.write_aggregates_count( - &aggregated_data.aggregates_count, - reportable_aggregates_path, - delimiter, - resolution, - true, - ) { - Ok(_) => Ok((combo_count, rare_count)), - Err(err) => Err(PyIOError::new_err(format!( - "error writing reportable aggregates: {}", - err - ))), - } -} - -/// Synthesizes the `data_block` using the specified `resolution`. -/// The result is written to `synthetic_path`. -/// -/// The function returns the expansion ratio: -/// `Synthetic data length / Sensitive data length` (header not included) -/// -/// # Arguments -/// * `data_block` - Data block with the content to be synthesized -/// * `synthetic_path` - Path to write the synthetic data -/// * `synthetic_delimiter` - CSV/TSV separator for the content be written in `synthetic_path` -/// * `cache_max_size` - Maximum cache size used during the synthesis process -/// * `resolution` - Reporting resolution to be used -#[pyfunction] -pub fn generate_seeded_and_write( - data_block: &DataBlock, - synthetic_path: &str, - synthetic_delimiter: char, - cache_max_size: usize, - resolution: usize, -) -> PyResult<f64> { - let mut progress_reporter: Option<LoggerProgressReporter> = None; - let mut generator = Generator::new(data_block); - let generated_data = generator.generate(resolution, cache_max_size, "", &mut progress_reporter); - - match generator.write_records( - &generated_data.synthetic_data, - synthetic_path, - synthetic_delimiter as u8, - ) { - Ok(_) => Ok(generated_data.expansion_ratio), - Err(err) => Err(PyIOError::new_err(format!( - "error writing synthetic data: {}", - err - ))), - } -} +/// Module that exposes the main processor +pub mod data_processor; /// A Python module implemented in Rust. The name of this function must match /// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to /// import the module. #[pymodule] -fn sds(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(create_data_block_from_file, m)?)?; - m.add_function(wrap_pyfunction!(aggregate_and_write, m)?)?; - m.add_function(wrap_pyfunction!(generate_seeded_and_write, m)?)?; +fn sds(py: Python, m: &PyModule) -> PyResult<()> { + env_logger::init(); + m.add_class::<SDSProcessor>()?; + threading::register(py, m)?; + aggregated_data::register(py, m)?; + evaluator::register(py, m)?; Ok(()) } diff --git a/packages/lib-wasm/Cargo.toml b/packages/lib-wasm/Cargo.toml index 821024a..170b3f1 100644 --- a/packages/lib-wasm/Cargo.toml +++ b/packages/lib-wasm/Cargo.toml @@ -19,7 +19,7 @@ csv = { version = "1.1" } web-sys = { version = "0.3", features = [ "console" ]} sds-core = { path = "../core" } js-sys = { version = "0.3" } -serde = { version = "1.0", features = [ "derive" ] } +serde = { version = "1.0", features = [ "derive", "rc" ] } # The `console_error_panic_hook` crate provides better debugging of panics by # logging them with `console.error`. 
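The `rc` feature added to the serde dependency above is what lets the new `Arc`-based typedefs (e.g. `Arc<DataBlockValue>`) keep deriving `Serialize`: serde only provides `Serialize`/`Deserialize` impls for `Rc`/`Arc` behind that flag. A minimal illustration; the `SyntheticRow` type here is hypothetical:

```rust
use serde::Serialize;
use std::sync::Arc;

// Without serde's "rc" feature, Arc<String> does not implement Serialize
// and this derive would fail to compile.
#[derive(Serialize)]
struct SyntheticRow {
    values: Vec<Arc<String>>,
}
```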
This is great for development, but requires diff --git a/packages/lib-wasm/src/aggregator/mod.rs b/packages/lib-wasm/src/aggregator/mod.rs deleted file mode 100644 index b7cc3e2..0000000 --- a/packages/lib-wasm/src/aggregator/mod.rs +++ /dev/null @@ -1,116 +0,0 @@ -use js_sys::Object; -use js_sys::Reflect::set; -use log::error; -use sds_core::data_block::block::DataBlock; -use sds_core::data_block::typedefs::DataBlockHeaders; -use sds_core::processing::aggregator::typedefs::{AggregatedCountByLenMap, AggregatesCountMap}; -use sds_core::processing::aggregator::{Aggregator, PrivacyRiskSummary}; -use sds_core::utils::reporting::ReportProgress; -use sds_core::utils::time::ElapsedDurationLogger; -use serde::Serialize; -use wasm_bindgen::JsValue; - -use crate::utils::js::serializers::{serialize_aggregates_count, serialize_privacy_risk}; -use crate::{match_or_return_undefined, set_or_return_undefined}; - -#[derive(Serialize)] -pub struct AggregatedCombination { - combination_key: String, - count: usize, - length: usize, -} - -impl AggregatedCombination { - #[inline] - pub fn new(combination_key: String, count: usize, length: usize) -> AggregatedCombination { - AggregatedCombination { - combination_key, - count, - length, - } - } -} - -impl From for JsValue { - #[inline] - fn from(aggregated_combination: AggregatedCombination) -> Self { - match_or_return_undefined!(JsValue::from_serde(&aggregated_combination)) - } -} - -pub struct AggregatedResult<'data_block> { - pub headers: &'data_block DataBlockHeaders, - pub aggregates_count: AggregatesCountMap<'data_block>, - pub rare_combinations_count_by_len: AggregatedCountByLenMap, - pub combinations_count_by_len: AggregatedCountByLenMap, - pub combinations_sum_by_len: AggregatedCountByLenMap, - pub privacy_risk: PrivacyRiskSummary, -} - -impl<'data_block> From> for JsValue { - #[inline] - fn from(aggregated_result: AggregatedResult) -> Self { - let _duration_logger = - ElapsedDurationLogger::new(String::from("AggregatedResult conversion to JsValue")); - let result = Object::new(); - - set_or_return_undefined!( - &result, - &"aggregatedCombinations".into(), - &serialize_aggregates_count( - &aggregated_result.aggregates_count, - aggregated_result.headers - ), - ); - set_or_return_undefined!( - &result, - &"rareCombinationsCountByLen".into(), - &match_or_return_undefined!(JsValue::from_serde( - &aggregated_result.rare_combinations_count_by_len - )) - ); - set_or_return_undefined!( - &result, - &"combinationsCountByLen".into(), - &match_or_return_undefined!(JsValue::from_serde( - &aggregated_result.combinations_count_by_len - )) - ); - set_or_return_undefined!( - &result, - &"combinationsSumByLen".into(), - &match_or_return_undefined!(JsValue::from_serde( - &aggregated_result.combinations_sum_by_len - )) - ); - set_or_return_undefined!( - &result, - &"privacyRisk".into(), - &serialize_privacy_risk(&aggregated_result.privacy_risk) - ); - - result.into() - } -} - -pub fn aggregate<'data_block, T: ReportProgress>( - data_block: &'data_block DataBlock, - reporting_length: usize, - resolution: usize, - progress_reporter: &mut Option, -) -> Result, String> { - let mut aggregator = Aggregator::new(data_block); - let aggregates_data = aggregator.aggregate(reporting_length, 0, progress_reporter); - - Ok(AggregatedResult { - headers: &data_block.headers, - rare_combinations_count_by_len: aggregator - .calc_rare_combinations_count_by_len(&aggregates_data.aggregates_count, resolution), - combinations_count_by_len: aggregator - 
.calc_combinations_count_by_len(&aggregates_data.aggregates_count), - combinations_sum_by_len: aggregator - .calc_combinations_sum_by_len(&aggregates_data.aggregates_count), - privacy_risk: aggregator.calc_privacy_risk(&aggregates_data.aggregates_count, resolution), - aggregates_count: aggregates_data.aggregates_count, - }) -} diff --git a/packages/lib-wasm/src/evaluator/mod.rs b/packages/lib-wasm/src/evaluator/mod.rs deleted file mode 100644 index 41cc4cc..0000000 --- a/packages/lib-wasm/src/evaluator/mod.rs +++ /dev/null @@ -1,170 +0,0 @@ -use js_sys::Reflect::set; -use js_sys::{Array, Function, Object}; -use log::{error, info}; -use sds_core::processing::aggregator::typedefs::AggregatedCountByLenMap; -use sds_core::processing::aggregator::Aggregator; -use sds_core::processing::evaluator::typedefs::PreservationByCountBuckets; -use sds_core::processing::evaluator::Evaluator; -use sds_core::utils::reporting::ReportProgress; -use sds_core::utils::time::ElapsedDurationLogger; -use wasm_bindgen::prelude::*; -use wasm_bindgen::JsValue; - -use crate::aggregator::{aggregate, AggregatedResult}; -use crate::utils::js::deserializers::{ - deserialize_csv_data, deserialize_sensitive_zeros, deserialize_use_columns, -}; -use crate::utils::js::js_progress_reporter::JsProgressReporter; -use crate::utils::js::serializers::serialize_buckets; -use crate::{match_or_return_undefined, set_or_return_undefined}; - -struct EvaluatedResult<'data_block> { - sensitive_aggregated_result: AggregatedResult<'data_block>, - synthetic_aggregated_result: AggregatedResult<'data_block>, - leakage_count_by_len: AggregatedCountByLenMap, - fabricated_count_by_len: AggregatedCountByLenMap, - preservation_by_count_buckets: PreservationByCountBuckets, - combination_loss: f64, - record_expansion: f64, -} - -impl<'data_block> From> for JsValue { - #[inline] - fn from(evaluated_result: EvaluatedResult) -> Self { - let _duration_logger = - ElapsedDurationLogger::new(String::from("EvaluatedResult conversion to JsValue")); - let result = Object::new(); - - set_or_return_undefined!( - &result, - &"sensitiveAggregatedResult".into(), - &evaluated_result.sensitive_aggregated_result.into(), - ); - set_or_return_undefined!( - &result, - &"syntheticAggregatedResult".into(), - &evaluated_result.synthetic_aggregated_result.into(), - ); - set_or_return_undefined!( - &result, - &"leakageCountByLen".into(), - &match_or_return_undefined!(JsValue::from_serde( - &evaluated_result.leakage_count_by_len - )) - ); - set_or_return_undefined!( - &result, - &"fabricatedCountByLen".into(), - &match_or_return_undefined!(JsValue::from_serde( - &evaluated_result.fabricated_count_by_len - )) - ); - set_or_return_undefined!( - &result, - &"preservationByCountBuckets".into(), - &serialize_buckets(&evaluated_result.preservation_by_count_buckets) - ); - set_or_return_undefined!( - &result, - &"combinationLoss".into(), - &match_or_return_undefined!(JsValue::from_serde(&evaluated_result.combination_loss)) - ); - set_or_return_undefined!( - &result, - &"recordExpansion".into(), - &match_or_return_undefined!(JsValue::from_serde(&evaluated_result.record_expansion)) - ); - - result.into() - } -} - -#[wasm_bindgen] -#[allow(clippy::too_many_arguments)] -pub fn evaluate( - sensitive_csv_data: Array, - synthetic_csv_data: Array, - use_columns: Array, - sensitive_zeros: Array, - record_limit: usize, - reporting_length: usize, - resolution: usize, - progress_callback: Function, -) -> JsValue { - let _duration_logger = ElapsedDurationLogger::new(String::from("evaluation 
process")); - let use_columns_vec = match_or_return_undefined!(deserialize_use_columns(use_columns)); - let sensitive_zeros_vec = - match_or_return_undefined!(deserialize_sensitive_zeros(sensitive_zeros)); - - info!("aggregating sensitive data..."); - let sensitive_data_block = match_or_return_undefined!(deserialize_csv_data( - sensitive_csv_data, - &use_columns_vec, - &sensitive_zeros_vec, - record_limit, - )); - let mut sensitive_aggregated_result = match_or_return_undefined!(aggregate( - &sensitive_data_block, - reporting_length, - resolution, - &mut Some(JsProgressReporter::new(&progress_callback, &|p| p * 0.40)) - )); - - info!("aggregating synthetic data..."); - let synthetic_data_block = match_or_return_undefined!(deserialize_csv_data( - synthetic_csv_data, - &use_columns_vec, - &sensitive_zeros_vec, - record_limit, - )); - let synthetic_aggregated_result = match_or_return_undefined!(aggregate( - &synthetic_data_block, - reporting_length, - resolution, - &mut Some(JsProgressReporter::new(&progress_callback, &|p| { - 40.0 + (p * 0.40) - })) - )); - let mut evaluator_instance = Evaluator::default(); - - info!("evaluating synthetic data based on sensitive data..."); - - let buckets = evaluator_instance.calc_preservation_by_count( - &sensitive_aggregated_result.aggregates_count, - &synthetic_aggregated_result.aggregates_count, - resolution, - ); - let leakage_count_by_len = evaluator_instance.calc_leakage_count( - &sensitive_aggregated_result.aggregates_count, - &synthetic_aggregated_result.aggregates_count, - resolution, - ); - let fabricated_count_by_len = evaluator_instance.calc_fabricated_count( - &sensitive_aggregated_result.aggregates_count, - &synthetic_aggregated_result.aggregates_count, - ); - let combination_loss = evaluator_instance.calc_combination_loss(&buckets); - - Aggregator::protect_aggregates_count( - &mut sensitive_aggregated_result.aggregates_count, - resolution, - ); - - JsProgressReporter::new(&progress_callback, &|p| p).report(100.0); - - EvaluatedResult { - leakage_count_by_len, - fabricated_count_by_len, - combination_loss, - preservation_by_count_buckets: buckets, - record_expansion: (synthetic_aggregated_result - .privacy_risk - .total_number_of_records as f64) - / (sensitive_aggregated_result - .privacy_risk - .total_number_of_records as f64), - sensitive_aggregated_result, - synthetic_aggregated_result, - } - .into() -} diff --git a/packages/lib-wasm/src/generator/mod.rs b/packages/lib-wasm/src/generator/mod.rs deleted file mode 100644 index c5995aa..0000000 --- a/packages/lib-wasm/src/generator/mod.rs +++ /dev/null @@ -1,45 +0,0 @@ -use js_sys::{Array, Function}; -use log::error; -use sds_core::{processing::generator::Generator, utils::time::ElapsedDurationLogger}; -use wasm_bindgen::prelude::*; - -use crate::{ - match_or_return_undefined, - utils::js::{ - deserializers::{ - deserialize_csv_data, deserialize_sensitive_zeros, deserialize_use_columns, - }, - js_progress_reporter::JsProgressReporter, - }, -}; - -#[wasm_bindgen] -pub fn generate( - csv_data: Array, - use_columns: Array, - sensitive_zeros: Array, - record_limit: usize, - resolution: usize, - cache_size: usize, - progress_callback: Function, -) -> JsValue { - let _duration_logger = ElapsedDurationLogger::new(String::from("generation process")); - let data_block = match_or_return_undefined!(deserialize_csv_data( - csv_data, - &match_or_return_undefined!(deserialize_use_columns(use_columns)), - &match_or_return_undefined!(deserialize_sensitive_zeros(sensitive_zeros)), - record_limit, - )); - let 
mut generator = Generator::new(&data_block); - - match_or_return_undefined!(JsValue::from_serde( - &generator - .generate( - resolution, - cache_size, - "", - &mut Some(JsProgressReporter::new(&progress_callback, &|p| p)) - ) - .synthetic_data - )) -} diff --git a/packages/lib-wasm/src/lib.rs b/packages/lib-wasm/src/lib.rs index e178be9..b7a5a61 100644 --- a/packages/lib-wasm/src/lib.rs +++ b/packages/lib-wasm/src/lib.rs @@ -1,141 +1,3 @@ -//! This crate will generate wasm bindings for the main features -//! of the `sds_core` library: -//! -//! # Init Logger -//! ```typescript -//! init_logger( -//! level_str: string -//! ): boolean -//! ``` -//! -//! Initializes logging using a particular log level. -//! Returns `true` if successfully initialized, `false` otherwise -//! -//! ## Arguments: -//! * `level_str` - String containing a valid log level -//! (`off`, `error`, `warn`, `info`, `debug` or `trace`) -//! -//! # Generate -//! ```typescript -//! generate( -//! csv_data: string[][], -//! use_columns: string[], -//! sensitive_zeros: string[], -//! record_limit: number, -//! resolution: number, -//! cache_size: number, -//! progress_callback: (progress: value) => void -//! ): string[][] -//! ``` -//! -//! Synthesizes the `csv_data` using the configured parameters and returns it -//! -//! ## Arguments -//! * `csv_data` - Data to be synthesized -//! * `csv_data[0]` - Should be the headers -//! * `csv_data[1...]` - Should be the records -//! * `use_columns` - Column names to be used (if `[]` use all columns) -//! * `sensitive_zeros` - Column names containing sensitive zeros -//! (if `[]` no columns are considered to have sensitive zeros) -//! * `record_limit` - Use only the first `record_limit` records (if `0` use all records) -//! * `resolution` - Reporting resolution to be used -//! * `cache_size` - Maximum cache size used during the synthesis process -//! * `progress_callback` - Callback that informs the processing percentage (0.0 % - 100.0 %) -//! -//! # Evaluate -//! ```typescript -//! interface IAggregatedCombination { -//! combination_key: number -//! count: number -//! length: number -//! } -//! -//! interface IAggregatedCombinations { -//! [name: string]: IAggregatedCombination -//! } -//! -//! interface IAggregatedCountByLen { -//! [length: number]: number -//! } -//! -//! interface IPrivacyRiskSummary { -//! totalNumberOfRecords: number -//! totalNumberOfCombinations: number -//! recordsWithUniqueCombinationsCount: number -//! recordsWithRareCombinationsCount: number -//! uniqueCombinationsCount: number -//! rareCombinationsCount: number -//! recordsWithUniqueCombinationsProportion: number -//! recordsWithRareCombinationsProportion: number -//! uniqueCombinationsProportion: number -//! rareCombinationsProportion: number -//! } -//! -//! interface IAggregatedResult { -//! aggregatedCombinations?: IAggregatedCombinations -//! rareCombinationsCountByLen?: IAggregatedCountByLen -//! combinationsCountByLen?: IAggregatedCountByLen -//! combinationsSumByLen?: IAggregatedCountByLen -//! privacyRisk?: IPrivacyRiskSummary -//! } -//! -//! interface IPreservationByCountBucket { -//! size: number -//! preservationSum: number -//! lengthSum: number -//! } -//! -//! interface IPreservationByCountBuckets { -//! [bucket_index: number]: IPreservationByCountBucket -//! } -//! -//! interface IEvaluatedResult { -//! sensitiveAggregatedResult?: IAggregatedResult -//! syntheticAggregatedResult?: IAggregatedResult -//! leakageCountByLen?: IAggregatedCountByLen -//! 
fabricatedCountByLen?: IAggregatedCountByLen -//! preservationByCountBuckets?: IPreservationByCountBuckets -//! combinationLoss?: number -//! recordExpansion?: number -//! } -//! -//! evaluate( -//! sensitive_csv_data: string[][], -//! synthetic_csv_data: string[][], -//! use_columns: string[], -//! sensitive_zeros: string[], -//! record_limit: number, -//! reporting_length: number, -//! resolution: number, -//! progress_callback: (progress: value) => void -//! ): IEvaluatedResult -//! ``` -//! Evaluates the synthetic data based on the sensitive data and produces a `IEvaluatedResult` -//! -//! ## Arguments -//! * `sensitive_csv_data` - Sensitive data to be evaluated -//! * `sensitive_csv_data[0]` - Should be the headers -//! * `sensitive_csv_data[1...]` - Should be the records -//! * `synthetic_csv_data` - Synthetic data produced from the synthetic data -//! * `synthetic_csv_data[0]` - Should be the headers -//! * `synthetic_csv_data[1...]` - Should be the records -//! * `use_columns` - Column names to be used (if `[]` use all columns) -//! * `sensitive_zeros` - Column names containing sensitive zeros -//! (if `[]` no columns are considered to have sensitive zeros) -//! * `record_limit` - Use only the first `record_limit` records (if `0` use all records) -//! * `reporting_length` - Maximum length to compute attribute combinations -//! for analysis -//! * `resolution` - Reporting resolution to be used -//! * `progress_callback` - Callback that informs the processing percentage (0.0 % - 100.0 %) +pub mod processing; -#[doc(hidden)] -pub mod aggregator; - -#[doc(hidden)] -pub mod evaluator; - -#[doc(hidden)] -pub mod generator; - -#[doc(hidden)] pub mod utils; diff --git a/packages/lib-wasm/src/processing/aggregator/aggregate_count_and_length.rs b/packages/lib-wasm/src/processing/aggregator/aggregate_count_and_length.rs new file mode 100644 index 0000000..cf0b070 --- /dev/null +++ b/packages/lib-wasm/src/processing/aggregator/aggregate_count_and_length.rs @@ -0,0 +1,31 @@ +use serde::Serialize; +use std::convert::TryFrom; +use wasm_bindgen::{JsCast, JsValue}; + +use crate::utils::js::ts_definitions::JsAggregateCountAndLength; + +#[derive(Serialize)] +pub struct WasmAggregateCountAndLength { + count: usize, + length: usize, +} + +impl WasmAggregateCountAndLength { + #[inline] + pub fn new(count: usize, length: usize) -> WasmAggregateCountAndLength { + WasmAggregateCountAndLength { count, length } + } +} + +impl TryFrom for JsAggregateCountAndLength { + type Error = JsValue; + + #[inline] + fn try_from( + aggregate_count_and_length: WasmAggregateCountAndLength, + ) -> Result { + JsValue::from_serde(&aggregate_count_and_length) + .map_err(|err| JsValue::from(err.to_string())) + .map(|c| c.unchecked_into()) + } +} diff --git a/packages/lib-wasm/src/processing/aggregator/aggregate_result.rs b/packages/lib-wasm/src/processing/aggregator/aggregate_result.rs new file mode 100644 index 0000000..d189855 --- /dev/null +++ b/packages/lib-wasm/src/processing/aggregator/aggregate_result.rs @@ -0,0 +1,228 @@ +use super::aggregate_count_and_length::WasmAggregateCountAndLength; +use js_sys::{Object, Reflect::set}; +use sds_core::{ + processing::aggregator::aggregated_data::AggregatedData, utils::time::ElapsedDurationLogger, +}; +use std::{ + convert::TryFrom, + ops::{Deref, DerefMut}, +}; +use wasm_bindgen::{prelude::*, JsCast}; + +use crate::utils::js::ts_definitions::{ + JsAggregateCountAndLength, JsAggregateCountByLen, JsAggregateResult, JsAggregatesCount, + JsPrivacyRiskSummary, JsResult, +}; + 
+#[wasm_bindgen] +pub struct WasmAggregateResult { + aggregated_data: AggregatedData, +} + +impl WasmAggregateResult { + #[inline] + pub fn default() -> WasmAggregateResult { + WasmAggregateResult { + aggregated_data: AggregatedData::default(), + } + } + + #[inline] + pub fn new(aggregated_data: AggregatedData) -> WasmAggregateResult { + WasmAggregateResult { aggregated_data } + } +} + +#[wasm_bindgen] +impl WasmAggregateResult { + #[wasm_bindgen(getter)] + #[wasm_bindgen(js_name = "reportingLength")] + pub fn reporting_length(&self) -> usize { + self.aggregated_data.reporting_length + } + + #[wasm_bindgen(js_name = "aggregatesCountToJs")] + pub fn aggregates_count_to_js( + &self, + combination_delimiter: &str, + ) -> JsResult { + let result = Object::new(); + + for (agg, count) in self.aggregated_data.aggregates_count.iter() { + set( + &result, + &agg.format_str_using_headers( + &self.aggregated_data.data_block.headers, + combination_delimiter, + ) + .into(), + &JsAggregateCountAndLength::try_from(WasmAggregateCountAndLength::new( + count.count, + agg.len(), + ))? + .into(), + )?; + } + + Ok(result.unchecked_into::()) + } + + #[wasm_bindgen(js_name = "rareCombinationsCountByLenToJs")] + pub fn rare_combinations_count_by_len_to_js( + &self, + resolution: usize, + ) -> JsResult { + let count = self + .aggregated_data + .calc_rare_combinations_count_by_len(resolution); + + Ok(JsValue::from_serde(&count) + .map_err(|err| JsValue::from(err.to_string()))? + .unchecked_into::()) + } + + #[wasm_bindgen(js_name = "combinationsCountByLenToJs")] + pub fn combinations_count_by_len_to_js(&self) -> JsResult { + let count = self.aggregated_data.calc_combinations_count_by_len(); + + Ok(JsValue::from_serde(&count) + .map_err(|err| JsValue::from(err.to_string()))? + .unchecked_into::()) + } + + #[wasm_bindgen(js_name = "combinationsSumByLenToJs")] + pub fn combinations_sum_by_len_to_js(&self) -> JsResult { + let count = self.aggregated_data.calc_combinations_sum_by_len(); + + Ok(JsValue::from_serde(&count) + .map_err(|err| JsValue::from(err.to_string()))? 
+ .unchecked_into::()) + } + + #[wasm_bindgen(js_name = "privacyRiskToJs")] + pub fn privacy_risk_to_js(&self, resolution: usize) -> JsResult { + let pr = self.aggregated_data.calc_privacy_risk(resolution); + let result = Object::new(); + + set( + &result, + &"totalNumberOfRecords".into(), + &pr.total_number_of_records.into(), + )?; + set( + &result, + &"totalNumberOfCombinations".into(), + &pr.total_number_of_combinations.into(), + )?; + set( + &result, + &"recordsWithUniqueCombinationsCount".into(), + &pr.records_with_unique_combinations_count.into(), + )?; + set( + &result, + &"recordsWithRareCombinationsCount".into(), + &pr.records_with_rare_combinations_count.into(), + )?; + set( + &result, + &"uniqueCombinationsCount".into(), + &pr.unique_combinations_count.into(), + )?; + set( + &result, + &"rareCombinationsCount".into(), + &pr.rare_combinations_count.into(), + )?; + set( + &result, + &"recordsWithUniqueCombinationsProportion".into(), + &pr.records_with_unique_combinations_proportion.into(), + )?; + set( + &result, + &"recordsWithRareCombinationsProportion".into(), + &pr.records_with_rare_combinations_proportion.into(), + )?; + set( + &result, + &"uniqueCombinationsProportion".into(), + &pr.unique_combinations_proportion.into(), + )?; + set( + &result, + &"rareCombinationsProportion".into(), + &pr.rare_combinations_proportion.into(), + )?; + + Ok(result.unchecked_into::()) + } + + #[wasm_bindgen(js_name = "toJs")] + pub fn to_js( + &self, + combination_delimiter: &str, + resolution: usize, + include_aggregates_count: bool, + ) -> JsResult { + let _duration_logger = + ElapsedDurationLogger::new(String::from("aggregate result serialization")); + let result = Object::new(); + + set( + &result, + &"reportingLength".into(), + &self.reporting_length().into(), + )?; + if include_aggregates_count { + set( + &result, + &"aggregatesCount".into(), + &self.aggregates_count_to_js(combination_delimiter)?.into(), + )?; + } + set( + &result, + &"rareCombinationsCountByLen".into(), + &self + .rare_combinations_count_by_len_to_js(resolution)? 
+ .into(), + )?; + set( + &result, + &"combinationsCountByLen".into(), + &self.combinations_count_by_len_to_js()?.into(), + )?; + set( + &result, + &"combinationsSumByLen".into(), + &self.combinations_sum_by_len_to_js()?.into(), + )?; + set( + &result, + &"privacyRisk".into(), + &self.privacy_risk_to_js(resolution)?.into(), + )?; + + Ok(JsValue::from(result).unchecked_into::()) + } + + #[wasm_bindgen(js_name = "protectAggregatesCount")] + pub fn protect_aggregates_count(&mut self, resolution: usize) { + self.aggregated_data.protect_aggregates_count(resolution) + } +} + +impl Deref for WasmAggregateResult { + type Target = AggregatedData; + + fn deref(&self) -> &Self::Target { + &self.aggregated_data + } +} + +impl DerefMut for WasmAggregateResult { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.aggregated_data + } +} diff --git a/packages/lib-wasm/src/processing/aggregator/mod.rs b/packages/lib-wasm/src/processing/aggregator/mod.rs new file mode 100644 index 0000000..f3b5699 --- /dev/null +++ b/packages/lib-wasm/src/processing/aggregator/mod.rs @@ -0,0 +1,3 @@ +pub mod aggregate_count_and_length; + +pub mod aggregate_result; diff --git a/packages/lib-wasm/src/processing/evaluator/evaluate_result.rs b/packages/lib-wasm/src/processing/evaluator/evaluate_result.rs new file mode 100644 index 0000000..1d3a3e9 --- /dev/null +++ b/packages/lib-wasm/src/processing/evaluator/evaluate_result.rs @@ -0,0 +1,184 @@ +use super::preservation_by_count::WasmPreservationByCount; +use js_sys::{Object, Reflect::set}; +use sds_core::{processing::evaluator::Evaluator, utils::time::ElapsedDurationLogger}; +use wasm_bindgen::{prelude::*, JsCast}; + +use crate::{ + processing::aggregator::aggregate_result::WasmAggregateResult, + utils::js::ts_definitions::{ + JsAggregateCountByLen, JsAggregateResult, JsEvaluateResult, JsResult, + }, +}; + +#[wasm_bindgen] +pub struct WasmEvaluateResult { + pub(crate) sensitive_aggregate_result: WasmAggregateResult, + pub(crate) synthetic_aggregate_result: WasmAggregateResult, + evaluator: Evaluator, +} + +impl WasmEvaluateResult { + #[inline] + pub fn default() -> WasmEvaluateResult { + WasmEvaluateResult { + sensitive_aggregate_result: WasmAggregateResult::default(), + synthetic_aggregate_result: WasmAggregateResult::default(), + evaluator: Evaluator::default(), + } + } + + #[inline] + pub fn new( + sensitive_aggregate_result: WasmAggregateResult, + synthetic_aggregate_result: WasmAggregateResult, + ) -> WasmEvaluateResult { + WasmEvaluateResult { + sensitive_aggregate_result, + synthetic_aggregate_result, + evaluator: Evaluator::default(), + } + } +} + +#[wasm_bindgen] +impl WasmEvaluateResult { + #[wasm_bindgen(constructor)] + pub fn from_aggregate_results( + sensitive_aggregate_result: WasmAggregateResult, + synthetic_aggregate_result: WasmAggregateResult, + ) -> WasmEvaluateResult { + WasmEvaluateResult::new(sensitive_aggregate_result, synthetic_aggregate_result) + } + + #[wasm_bindgen(js_name = "sensitiveAggregateResultToJs")] + pub fn sensitive_aggregate_result_to_js( + &self, + combination_delimiter: &str, + resolution: usize, + include_aggregates_count: bool, + ) -> JsResult { + self.sensitive_aggregate_result.to_js( + combination_delimiter, + resolution, + include_aggregates_count, + ) + } + + #[wasm_bindgen(js_name = "syntheticAggregateResultToJs")] + pub fn synthetic_aggregate_result_to_js( + &self, + combination_delimiter: &str, + resolution: usize, + include_aggregates_count: bool, + ) -> JsResult { + self.synthetic_aggregate_result.to_js( + 
combination_delimiter, + resolution, + include_aggregates_count, + ) + } + + #[wasm_bindgen(js_name = "leakageCountByLenToJs")] + pub fn leakage_count_by_len_to_js(&self, resolution: usize) -> JsResult { + let count = self.evaluator.calc_leakage_count( + &self.sensitive_aggregate_result, + &self.synthetic_aggregate_result, + resolution, + ); + + Ok(JsValue::from_serde(&count) + .map_err(|err| JsValue::from(err.to_string()))? + .unchecked_into::()) + } + + #[wasm_bindgen(js_name = "fabricatedCountByLenToJs")] + pub fn fabricated_count_by_len_to_js(&self) -> JsResult { + let count = self.evaluator.calc_fabricated_count( + &self.sensitive_aggregate_result, + &self.synthetic_aggregate_result, + ); + + Ok(JsValue::from_serde(&count) + .map_err(|err| JsValue::from(err.to_string()))? + .unchecked_into::()) + } + + #[wasm_bindgen(js_name = "preservationByCount")] + pub fn preservation_by_count(&self, resolution: usize) -> WasmPreservationByCount { + WasmPreservationByCount::new(self.evaluator.calc_preservation_by_count( + &self.sensitive_aggregate_result, + &self.synthetic_aggregate_result, + resolution, + )) + } + + #[wasm_bindgen(js_name = "recordExpansion")] + pub fn record_expansion(&self) -> f64 { + (self + .synthetic_aggregate_result + .data_block + .number_of_records() as f64) + / (self + .sensitive_aggregate_result + .data_block + .number_of_records() as f64) + } + + #[wasm_bindgen(js_name = "toJs")] + pub fn to_js( + &self, + combination_delimiter: &str, + resolution: usize, + include_aggregates_count: bool, + ) -> JsResult { + let _duration_logger = + ElapsedDurationLogger::new(String::from("evaluate result serialization")); + let result = Object::new(); + + set( + &result, + &"sensitiveAggregateResult".into(), + &self + .sensitive_aggregate_result_to_js( + combination_delimiter, + resolution, + include_aggregates_count, + )? + .into(), + )?; + set( + &result, + &"syntheticAggregateResult".into(), + &self + .synthetic_aggregate_result_to_js( + combination_delimiter, + resolution, + include_aggregates_count, + )? 
+ .into(), + )?; + + set( + &result, + &"leakageCountByLen".into(), + &self.leakage_count_by_len_to_js(resolution)?.into(), + )?; + set( + &result, + &"fabricatedCountByLen".into(), + &self.fabricated_count_by_len_to_js()?.into(), + )?; + set( + &result, + &"preservationByCount".into(), + &self.preservation_by_count(resolution).to_js()?.into(), + )?; + set( + &result, + &"recordExpansion".into(), + &self.record_expansion().into(), + )?; + + Ok(JsValue::from(result).unchecked_into::()) + } +} diff --git a/packages/lib-wasm/src/processing/evaluator/mod.rs b/packages/lib-wasm/src/processing/evaluator/mod.rs new file mode 100644 index 0000000..7ca97c1 --- /dev/null +++ b/packages/lib-wasm/src/processing/evaluator/mod.rs @@ -0,0 +1,3 @@ +pub mod preservation_by_count; + +pub mod evaluate_result; diff --git a/packages/lib-wasm/src/processing/evaluator/preservation_by_count.rs b/packages/lib-wasm/src/processing/evaluator/preservation_by_count.rs new file mode 100644 index 0000000..d95a947 --- /dev/null +++ b/packages/lib-wasm/src/processing/evaluator/preservation_by_count.rs @@ -0,0 +1,93 @@ +use js_sys::{Object, Reflect::set}; +use sds_core::{ + processing::evaluator::preservation_by_count::PreservationByCountBuckets, + utils::time::ElapsedDurationLogger, +}; +use std::ops::{Deref, DerefMut}; +use wasm_bindgen::{prelude::*, JsCast}; + +use crate::utils::js::ts_definitions::{ + JsPreservationByCount, JsPreservationByCountBuckets, JsResult, +}; + +#[wasm_bindgen] +pub struct WasmPreservationByCount { + preservation_by_count_buckets: PreservationByCountBuckets, +} + +impl WasmPreservationByCount { + #[inline] + pub fn new( + preservation_by_count_buckets: PreservationByCountBuckets, + ) -> WasmPreservationByCount { + WasmPreservationByCount { + preservation_by_count_buckets, + } + } +} + +#[wasm_bindgen] +impl WasmPreservationByCount { + #[wasm_bindgen(js_name = "bucketsToJs")] + pub fn buckets_to_js(&self) -> JsResult { + let result = Object::new(); + + for (bucket_index, b) in self.preservation_by_count_buckets.iter() { + let serialized_bucket = Object::new(); + + set(&serialized_bucket, &"size".into(), &b.size.into())?; + set( + &serialized_bucket, + &"preservationSum".into(), + &b.preservation_sum.into(), + )?; + set( + &serialized_bucket, + &"lengthSum".into(), + &b.length_sum.into(), + )?; + set( + &serialized_bucket, + &"combinationCountSum".into(), + &b.combination_count_sum.into(), + )?; + + set(&result, &(*bucket_index).into(), &serialized_bucket.into())?; + } + Ok(result.unchecked_into::()) + } + + #[wasm_bindgen(js_name = "combinationLoss")] + pub fn combination_loss(&self) -> f64 { + self.preservation_by_count_buckets.calc_combination_loss() + } + + #[wasm_bindgen(js_name = "toJs")] + pub fn to_js(&self) -> JsResult { + let _duration_logger = + ElapsedDurationLogger::new(String::from("preservation by count serialization")); + let result = Object::new(); + + set(&result, &"buckets".into(), &self.buckets_to_js()?.into())?; + set( + &result, + &"combinationLoss".into(), + &self.combination_loss().into(), + )?; + Ok(JsValue::from(result).unchecked_into::()) + } +} + +impl Deref for WasmPreservationByCount { + type Target = PreservationByCountBuckets; + + fn deref(&self) -> &Self::Target { + &self.preservation_by_count_buckets + } +} + +impl DerefMut for WasmPreservationByCount { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.preservation_by_count_buckets + } +} diff --git a/packages/lib-wasm/src/processing/generator/generate_result.rs 
b/packages/lib-wasm/src/processing/generator/generate_result.rs new file mode 100644 index 0000000..db1866f --- /dev/null +++ b/packages/lib-wasm/src/processing/generator/generate_result.rs @@ -0,0 +1,77 @@ +use js_sys::{Object, Reflect::set}; +use sds_core::{ + processing::generator::generated_data::GeneratedData, utils::time::ElapsedDurationLogger, +}; +use std::ops::{Deref, DerefMut}; +use wasm_bindgen::{prelude::*, JsCast}; + +use crate::utils::js::ts_definitions::{JsCsvData, JsGenerateResult, JsResult}; + +#[wasm_bindgen] +pub struct WasmGenerateResult { + generated_data: GeneratedData, +} + +impl WasmGenerateResult { + #[inline] + pub fn default() -> WasmGenerateResult { + WasmGenerateResult { + generated_data: GeneratedData::default(), + } + } + + #[inline] + pub fn new(generated_data: GeneratedData) -> WasmGenerateResult { + WasmGenerateResult { generated_data } + } +} + +#[wasm_bindgen] +impl WasmGenerateResult { + #[wasm_bindgen(getter)] + #[wasm_bindgen(js_name = "expansionRatio")] + pub fn expansion_ratio(&self) -> f64 { + self.generated_data.expansion_ratio + } + + #[wasm_bindgen(js_name = "syntheticDataToJs")] + pub fn synthetic_data_to_js(&self) -> JsResult { + Ok(JsValue::from_serde(&self.generated_data.synthetic_data) + .map_err(|err| JsValue::from(err.to_string()))? + .unchecked_into::()) + } + + #[wasm_bindgen(js_name = "toJs")] + pub fn to_js(&self) -> JsResult { + let _duration_logger = + ElapsedDurationLogger::new(String::from("generate result serialization")); + let result = Object::new(); + + set( + &result, + &"expansionRatio".into(), + &self.expansion_ratio().into(), + )?; + set( + &result, + &"syntheticData".into(), + &self.synthetic_data_to_js()?.into(), + )?; + + Ok(JsValue::from(result).unchecked_into::()) + } +} + +impl Deref for WasmGenerateResult { + type Target = GeneratedData; + + fn deref(&self) -> &Self::Target { + &self.generated_data + } +} + +impl DerefMut for WasmGenerateResult { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.generated_data + } +} diff --git a/packages/lib-wasm/src/processing/generator/mod.rs b/packages/lib-wasm/src/processing/generator/mod.rs new file mode 100644 index 0000000..4e53e8f --- /dev/null +++ b/packages/lib-wasm/src/processing/generator/mod.rs @@ -0,0 +1 @@ +pub mod generate_result; diff --git a/packages/lib-wasm/src/processing/mod.rs b/packages/lib-wasm/src/processing/mod.rs new file mode 100644 index 0000000..e89237e --- /dev/null +++ b/packages/lib-wasm/src/processing/mod.rs @@ -0,0 +1,11 @@ +pub mod aggregator; + +pub mod evaluator; + +pub mod generator; + +pub mod navigator; + +pub mod sds_processor; + +pub mod sds_context; diff --git a/packages/lib-wasm/src/processing/navigator/attributes_intersection.rs b/packages/lib-wasm/src/processing/navigator/attributes_intersection.rs new file mode 100644 index 0000000..79b39bc --- /dev/null +++ b/packages/lib-wasm/src/processing/navigator/attributes_intersection.rs @@ -0,0 +1,75 @@ +use js_sys::{Array, Object, Reflect::set}; +use std::{collections::HashMap, convert::TryFrom, sync::Arc}; +use wasm_bindgen::{JsCast, JsValue}; + +use crate::utils::js::ts_definitions::{ + JsAttributesIntersection, JsAttributesIntersectionByColumn, +}; + +pub struct WasmAttributesIntersection { + pub(crate) value: Arc, + pub(crate) estimated_count: usize, + pub(crate) actual_count: Option, +} + +impl WasmAttributesIntersection { + #[inline] + pub fn new( + value: Arc, + estimated_count: usize, + actual_count: Option, + ) -> WasmAttributesIntersection { + WasmAttributesIntersection 
{ + value, + estimated_count, + actual_count, + } + } +} + +impl TryFrom for JsAttributesIntersection { + type Error = JsValue; + + fn try_from(attributes_intersection: WasmAttributesIntersection) -> Result { + let result = Object::new(); + + set( + &result, + &"value".into(), + &attributes_intersection.value.as_str().into(), + )?; + set( + &result, + &"estimatedCount".into(), + &attributes_intersection.estimated_count.into(), + )?; + if let Some(actual_count) = attributes_intersection.actual_count { + set(&result, &"actualCount".into(), &actual_count.into())?; + } + + Ok(result.unchecked_into::()) + } +} + +pub type WasmAttributesIntersectionByColumn = HashMap>; + +impl TryFrom for JsAttributesIntersectionByColumn { + type Error = JsValue; + + fn try_from( + mut intersections_by_column: WasmAttributesIntersectionByColumn, + ) -> Result { + let result = Object::new(); + + for (column_index, mut attrs_intersections) in intersections_by_column.drain() { + let intersections: Array = Array::default(); + + for intersection in attrs_intersections.drain(..) { + intersections.push(&JsAttributesIntersection::try_from(intersection)?.into()); + } + set(&result, &column_index.into(), &intersections)?; + } + + Ok(result.unchecked_into::()) + } +} diff --git a/packages/lib-wasm/src/processing/navigator/mod.rs b/packages/lib-wasm/src/processing/navigator/mod.rs new file mode 100644 index 0000000..eb50a84 --- /dev/null +++ b/packages/lib-wasm/src/processing/navigator/mod.rs @@ -0,0 +1,5 @@ +pub mod attributes_intersection; + +pub mod navigate_result; + +pub mod selected_attributes; diff --git a/packages/lib-wasm/src/processing/navigator/navigate_result.rs b/packages/lib-wasm/src/processing/navigator/navigate_result.rs new file mode 100644 index 0000000..fc0b5c9 --- /dev/null +++ b/packages/lib-wasm/src/processing/navigator/navigate_result.rs @@ -0,0 +1,229 @@ +use super::{ + attributes_intersection::{WasmAttributesIntersection, WasmAttributesIntersectionByColumn}, + selected_attributes::{WasmSelectedAttributes, WasmSelectedAttributesByColumn}, +}; +use sds_core::{ + data_block::{ + block::DataBlock, + typedefs::{ + AttributeRows, AttributeRowsByColumnMap, AttributeRowsMap, ColumnIndexByName, CsvRecord, + }, + value::DataBlockValue, + }, + processing::aggregator::value_combination::ValueCombination, + utils::collections::ordered_vec_intersection, +}; +use std::{cmp::Reverse, convert::TryFrom, sync::Arc}; +use wasm_bindgen::prelude::*; + +use crate::{ + processing::{aggregator::aggregate_result::WasmAggregateResult, sds_processor::SDSProcessor}, + utils::js::ts_definitions::{ + JsAttributesIntersectionByColumn, JsHeaderNames, JsResult, JsSelectedAttributesByColumn, + }, +}; + +#[wasm_bindgen] +pub struct WasmNavigateResult { + synthetic_data_block: Arc, + attr_rows_by_column: AttributeRowsByColumnMap, + selected_attributes: WasmSelectedAttributesByColumn, + selected_attr_rows: AttributeRows, + all_attr_rows: AttributeRows, + column_index_by_name: ColumnIndexByName, +} + +impl WasmNavigateResult { + #[inline] + pub fn default() -> WasmNavigateResult { + WasmNavigateResult::new( + Arc::new(DataBlock::default()), + AttributeRowsByColumnMap::default(), + WasmSelectedAttributesByColumn::default(), + AttributeRows::default(), + AttributeRows::default(), + ColumnIndexByName::default(), + ) + } + + #[inline] + pub fn new( + synthetic_data_block: Arc, + attr_rows_by_column: AttributeRowsByColumnMap, + selected_attributes: WasmSelectedAttributesByColumn, + selected_attr_rows: AttributeRows, + all_attr_rows: 
AttributeRows, + column_index_by_name: ColumnIndexByName, + ) -> WasmNavigateResult { + WasmNavigateResult { + synthetic_data_block, + attr_rows_by_column, + selected_attributes, + selected_attr_rows, + all_attr_rows, + column_index_by_name, + } + } + + #[inline] + fn intersect_attributes(&self, attributes: &WasmSelectedAttributesByColumn) -> AttributeRows { + let default_empty_attr_rows = AttributeRowsMap::default(); + let default_empty_rows = AttributeRows::default(); + let mut result = self.all_attr_rows.clone(); + + for (column_index, values) in attributes.iter() { + let attr_rows = self + .attr_rows_by_column + .get(column_index) + .unwrap_or(&default_empty_attr_rows); + for value in values.iter() { + result = ordered_vec_intersection( + &result, + attr_rows.get(value).unwrap_or(&default_empty_rows), + ); + } + } + result + } + + #[inline] + fn get_sensitive_aggregates_count( + &self, + attributes: &WasmSelectedAttributesByColumn, + sensitive_aggregate_result: &WasmAggregateResult, + ) -> Option { + let mut combination: Vec> = attributes + .values() + .flat_map(|values| values.iter().cloned()) + .collect(); + + combination.sort_by_key(|k| k.format_str_using_headers(&self.synthetic_data_block.headers)); + + sensitive_aggregate_result + .aggregates_count + .get(&ValueCombination::new(combination)) + .map(|c| c.count) + } + + #[inline] + fn get_selected_but_current_column( + &self, + column_index: usize, + ) -> (WasmSelectedAttributesByColumn, AttributeRows) { + let mut selected_attributes_but_current_column = self.selected_attributes.clone(); + + selected_attributes_but_current_column.remove(&column_index); + + let selected_attr_rows_but_current_column = + self.intersect_attributes(&selected_attributes_but_current_column); + + ( + selected_attributes_but_current_column, + selected_attr_rows_but_current_column, + ) + } + + #[inline] + fn selected_attributes_contains_value( + &self, + column_index: usize, + value: &Arc, + ) -> bool { + self.selected_attributes + .get(&column_index) + .map(|values| values.contains(value)) + .unwrap_or(false) + } +} + +#[wasm_bindgen] +impl WasmNavigateResult { + #[wasm_bindgen(constructor)] + pub fn from_synthetic_processor(synthetic_processor: &SDSProcessor) -> WasmNavigateResult { + WasmNavigateResult::new( + synthetic_processor.data_block.clone(), + synthetic_processor + .data_block + .calc_attr_rows_with_no_empty_values(), + WasmSelectedAttributesByColumn::default(), + AttributeRows::default(), + (0..synthetic_processor.data_block.number_of_records()).collect(), + synthetic_processor.data_block.calc_column_index_by_name(), + ) + } + + #[wasm_bindgen(js_name = "selectAttributes")] + pub fn select_attributes(&mut self, attributes: JsSelectedAttributesByColumn) -> JsResult<()> { + self.selected_attributes = WasmSelectedAttributesByColumn::try_from(attributes)?; + self.selected_attr_rows = self.intersect_attributes(&self.selected_attributes); + Ok(()) + } + + #[wasm_bindgen(js_name = "attributesIntersectionsByColumn")] + pub fn attributes_intersections_by_column( + &mut self, + columns: JsHeaderNames, + sensitive_aggregate_result: &WasmAggregateResult, + ) -> JsResult { + let column_indexes: AttributeRows = CsvRecord::try_from(columns)? 
+ .iter() + .filter_map(|header_name| self.column_index_by_name.get(header_name)) + .cloned() + .collect(); + let mut result = WasmAttributesIntersectionByColumn::default(); + + for column_index in column_indexes { + let result_column_entry = result.entry(column_index).or_insert_with(Vec::default); + if let Some(attr_rows) = self.attr_rows_by_column.get(&column_index) { + let (selected_attributes_but_current_column, selected_attr_rows_but_current_column) = + self.get_selected_but_current_column(column_index); + + for (value, rows) in attr_rows { + let current_selected_attr_rows; + let current_selected_attributes; + + if self.selected_attributes_contains_value(column_index, value) { + // if the selected attributes contain the current value, we + // have already the selected attr rows cached + current_selected_attributes = &self.selected_attributes; + current_selected_attr_rows = &self.selected_attr_rows; + } else { + // if the selected attributes do not contain the current value + // use the intersection between all selected values, but the + // ones in the current column + current_selected_attributes = &selected_attributes_but_current_column; + current_selected_attr_rows = &selected_attr_rows_but_current_column; + }; + + let estimated_count = + ordered_vec_intersection(current_selected_attr_rows, rows).len(); + + if estimated_count > 0 { + let mut attributes = current_selected_attributes.clone(); + let attributes_entry = attributes + .entry(column_index) + .or_insert_with(WasmSelectedAttributes::default); + + attributes_entry.insert(value.clone()); + + result_column_entry.push(WasmAttributesIntersection::new( + value.value.clone(), + estimated_count, + self.get_sensitive_aggregates_count( + &attributes, + sensitive_aggregate_result, + ), + )); + } + } + } + } + + // sort by estimated count in descending order + for intersections in result.values_mut() { + intersections.sort_by_key(|intersection| Reverse(intersection.estimated_count)) + } + + JsAttributesIntersectionByColumn::try_from(result) + } +} diff --git a/packages/lib-wasm/src/processing/navigator/selected_attributes.rs b/packages/lib-wasm/src/processing/navigator/selected_attributes.rs new file mode 100644 index 0000000..aa14ff1 --- /dev/null +++ b/packages/lib-wasm/src/processing/navigator/selected_attributes.rs @@ -0,0 +1,48 @@ +use js_sys::{Array, Object, Set}; +use sds_core::data_block::value::DataBlockValue; +use std::{ + collections::{HashMap, HashSet}, + convert::TryFrom, + sync::Arc, +}; +use wasm_bindgen::{JsCast, JsValue}; + +use crate::utils::js::ts_definitions::JsSelectedAttributesByColumn; + +pub type WasmSelectedAttributes = HashSet>; + +pub type WasmSelectedAttributesByColumn = HashMap; + +impl TryFrom for WasmSelectedAttributesByColumn { + type Error = JsValue; + + fn try_from(js_selected_attributes: JsSelectedAttributesByColumn) -> Result { + let mut result = WasmSelectedAttributesByColumn::default(); + + for entry_res in Object::entries(&js_selected_attributes.dyn_into::()?).values() { + let entry = entry_res?.dyn_into::()?; + let column_index = entry + .get(0) + .as_string() + .ok_or_else(|| { + JsValue::from("invalid column index on selected attributes by column") + })? 
+ .parse::() + .map_err(|err| JsValue::from(err.to_string()))?; + let result_entry = result.entry(column_index).or_insert_with(HashSet::default); + + for value_res in entry.get(1).dyn_into::()?.keys() { + let value = value_res?; + let value_str = value.as_string().ok_or_else(|| { + JsValue::from("invalid value on selected attributes by column") + })?; + + result_entry.insert(Arc::new(DataBlockValue::new( + column_index, + Arc::new(value_str), + ))); + } + } + Ok(result) + } +} diff --git a/packages/lib-wasm/src/processing/sds_context.rs b/packages/lib-wasm/src/processing/sds_context.rs new file mode 100644 index 0000000..2ac2bf7 --- /dev/null +++ b/packages/lib-wasm/src/processing/sds_context.rs @@ -0,0 +1,209 @@ +use super::{ + evaluator::evaluate_result::WasmEvaluateResult, generator::generate_result::WasmGenerateResult, + navigator::navigate_result::WasmNavigateResult, sds_processor::SDSProcessor, +}; +use log::debug; +use wasm_bindgen::{prelude::*, JsCast}; + +use crate::utils::js::ts_definitions::{ + JsAttributesIntersectionByColumn, JsCsvData, JsEvaluateResult, JsGenerateResult, JsHeaderNames, + JsReportProgressCallback, JsResult, JsSelectedAttributesByColumn, +}; + +#[wasm_bindgen] +pub struct SDSContext { + use_columns: JsHeaderNames, + sensitive_zeros: JsHeaderNames, + record_limit: usize, + sensitive_processor: SDSProcessor, + generate_result: WasmGenerateResult, + resolution: usize, + synthetic_processor: SDSProcessor, + evaluate_result: WasmEvaluateResult, + navigate_result: WasmNavigateResult, +} + +#[wasm_bindgen] +impl SDSContext { + #[wasm_bindgen(constructor)] + pub fn default() -> SDSContext { + SDSContext { + use_columns: JsHeaderNames::default(), + sensitive_zeros: JsHeaderNames::default(), + record_limit: 0, + sensitive_processor: SDSProcessor::default(), + generate_result: WasmGenerateResult::default(), + resolution: 0, + synthetic_processor: SDSProcessor::default(), + evaluate_result: WasmEvaluateResult::default(), + navigate_result: WasmNavigateResult::default(), + } + } + + #[wasm_bindgen(js_name = "clearSensitiveData")] + pub fn clear_sensitive_data(&mut self) { + self.use_columns = JsHeaderNames::default(); + self.sensitive_zeros = JsHeaderNames::default(); + self.record_limit = 0; + self.sensitive_processor = SDSProcessor::default(); + self.clear_generate(); + } + + #[wasm_bindgen(js_name = "clearGenerate")] + pub fn clear_generate(&mut self) { + self.generate_result = WasmGenerateResult::default(); + self.resolution = 0; + self.synthetic_processor = SDSProcessor::default(); + self.clear_evaluate() + } + + #[wasm_bindgen(js_name = "clearEvaluate")] + pub fn clear_evaluate(&mut self) { + self.evaluate_result = WasmEvaluateResult::default(); + self.clear_navigate() + } + + #[wasm_bindgen(js_name = "clearNavigate")] + pub fn clear_navigate(&mut self) { + self.navigate_result = WasmNavigateResult::default(); + } + + #[wasm_bindgen(js_name = "setSensitiveData")] + pub fn set_sensitive_data( + &mut self, + csv_data: JsCsvData, + use_columns: JsHeaderNames, + sensitive_zeros: JsHeaderNames, + record_limit: usize, + ) -> JsResult<()> { + debug!("setting sensitive data..."); + + self.use_columns = use_columns; + self.sensitive_zeros = sensitive_zeros; + self.record_limit = record_limit; + self.sensitive_processor = SDSProcessor::new( + csv_data, + self.use_columns.clone().unchecked_into(), + self.sensitive_zeros.clone().unchecked_into(), + self.record_limit, + )?; + self.clear_generate(); + Ok(()) + } + + pub fn generate( + &mut self, + cache_max_size: usize, + 
resolution: usize, + empty_value: String, + seeded: bool, + progress_callback: JsReportProgressCallback, + ) -> JsResult<()> { + debug!("generating synthetic data..."); + + self.generate_result = self.sensitive_processor.generate( + cache_max_size, + resolution, + empty_value, + seeded, + progress_callback, + )?; + self.resolution = resolution; + + debug!("creating synthetic data processor..."); + + self.synthetic_processor = SDSProcessor::new( + self.generate_result.synthetic_data_to_js()?, + self.use_columns.clone().unchecked_into(), + self.sensitive_zeros.clone().unchecked_into(), + self.record_limit, + )?; + self.clear_evaluate(); + Ok(()) + } + + pub fn evaluate( + &mut self, + reporting_length: usize, + sensitivity_threshold: usize, + sensitive_progress_callback: JsReportProgressCallback, + synthetic_progress_callback: JsReportProgressCallback, + ) -> JsResult<()> { + debug!("aggregating sensitive data..."); + + let sensitive_aggregate_result = self.sensitive_processor.aggregate( + reporting_length, + sensitivity_threshold, + sensitive_progress_callback, + )?; + + debug!("aggregating synthetic data..."); + + let synthetic_aggregate_result = self.synthetic_processor.aggregate( + reporting_length, + sensitivity_threshold, + synthetic_progress_callback, + )?; + + debug!("evaluating synthetic data based on sensitive data..."); + + self.evaluate_result = WasmEvaluateResult::from_aggregate_results( + sensitive_aggregate_result, + synthetic_aggregate_result, + ); + + self.clear_navigate(); + + Ok(()) + } + + #[wasm_bindgen(js_name = "protectSensitiveAggregatesCount")] + pub fn protect_sensitive_aggregates_count(&mut self) { + debug!("protecting sensitive aggregates count..."); + + self.evaluate_result + .sensitive_aggregate_result + .protect_aggregates_count(self.resolution); + } + + pub fn navigate(&mut self) { + debug!("creating navigate result..."); + + self.navigate_result = + WasmNavigateResult::from_synthetic_processor(&self.synthetic_processor); + } + + #[wasm_bindgen(js_name = "selectAttributes")] + pub fn select_attributes(&mut self, attributes: JsSelectedAttributesByColumn) -> JsResult<()> { + self.navigate_result.select_attributes(attributes) + } + + #[wasm_bindgen(js_name = "attributesIntersectionsByColumn")] + pub fn attributes_intersections_by_column( + &mut self, + columns: JsHeaderNames, + ) -> JsResult { + self.navigate_result.attributes_intersections_by_column( + columns, + &self.evaluate_result.sensitive_aggregate_result, + ) + } + + #[wasm_bindgen(js_name = "generateResultToJs")] + pub fn generate_result_to_js(&self) -> JsResult { + self.generate_result.to_js() + } + + #[wasm_bindgen(js_name = "evaluateResultToJs")] + pub fn evaluate_result_to_js( + &self, + combination_delimiter: &str, + include_aggregates_count: bool, + ) -> JsResult { + self.evaluate_result.to_js( + combination_delimiter, + self.resolution, + include_aggregates_count, + ) + } +} diff --git a/packages/lib-wasm/src/processing/sds_processor/header_names.rs b/packages/lib-wasm/src/processing/sds_processor/header_names.rs new file mode 100644 index 0000000..e1cede8 --- /dev/null +++ b/packages/lib-wasm/src/processing/sds_processor/header_names.rs @@ -0,0 +1,23 @@ +use js_sys::Array; +use sds_core::data_block::typedefs::CsvRecord; +use std::convert::TryFrom; +use wasm_bindgen::{JsCast, JsValue}; + +use crate::utils::js::ts_definitions::JsHeaderNames; + +impl JsHeaderNames { + #[inline] + pub fn default() -> JsHeaderNames { + Array::default().unchecked_into() + } +} + +impl TryFrom for CsvRecord { + type 
Error = JsValue; + + fn try_from(js_csv_record: JsHeaderNames) -> Result { + js_csv_record + .into_serde::() + .map_err(|err| JsValue::from(err.to_string())) + } +} diff --git a/packages/lib-wasm/src/processing/sds_processor/mod.rs b/packages/lib-wasm/src/processing/sds_processor/mod.rs new file mode 100644 index 0000000..8940b62 --- /dev/null +++ b/packages/lib-wasm/src/processing/sds_processor/mod.rs @@ -0,0 +1,5 @@ +pub mod header_names; + +mod processor; + +pub use processor::SDSProcessor; diff --git a/packages/lib-wasm/src/processing/sds_processor/processor.rs b/packages/lib-wasm/src/processing/sds_processor/processor.rs new file mode 100644 index 0000000..fb18684 --- /dev/null +++ b/packages/lib-wasm/src/processing/sds_processor/processor.rs @@ -0,0 +1,120 @@ +use js_sys::Function; +use sds_core::{ + data_block::{block::DataBlock, data_block_creator::DataBlockCreator, typedefs::CsvRecord}, + processing::{ + aggregator::Aggregator, + generator::{Generator, SynthesisMode}, + }, + utils::time::ElapsedDurationLogger, +}; +use std::convert::TryFrom; +use std::sync::Arc; +use wasm_bindgen::{prelude::*, JsCast}; + +use crate::{ + utils::js::js_progress_reporter::JsProgressReporter, + { + processing::{ + aggregator::aggregate_result::WasmAggregateResult, + generator::generate_result::WasmGenerateResult, + }, + utils::js::{ + js_block_creator::JsDataBlockCreator, + ts_definitions::{JsCsvData, JsHeaderNames, JsReportProgressCallback}, + }, + }, +}; + +#[wasm_bindgen] +pub struct SDSProcessor { + pub(crate) data_block: Arc, +} + +#[wasm_bindgen] +impl SDSProcessor { + pub fn default() -> SDSProcessor { + SDSProcessor { + data_block: Arc::new(DataBlock::default()), + } + } +} + +#[wasm_bindgen] +impl SDSProcessor { + #[wasm_bindgen(constructor)] + pub fn new( + csv_data: JsCsvData, + use_columns: JsHeaderNames, + sensitive_zeros: JsHeaderNames, + record_limit: usize, + ) -> Result { + let _duration_logger = ElapsedDurationLogger::new(String::from("sds processor creation")); + let data_block = JsDataBlockCreator::create( + Ok(csv_data.dyn_into()?), + &CsvRecord::try_from(use_columns)?, + &CsvRecord::try_from(sensitive_zeros)?, + record_limit, + )?; + + Ok(SDSProcessor { data_block }) + } + + #[wasm_bindgen(js_name = "numberOfRecords")] + pub fn number_of_records(&self) -> usize { + self.data_block.number_of_records() + } + + #[wasm_bindgen(js_name = "protectedNumberOfRecords")] + pub fn protected_number_of_records(&self, resolution: usize) -> usize { + self.data_block.protected_number_of_records(resolution) + } + + #[wasm_bindgen(js_name = "normalizeReportingLength")] + pub fn normalize_reporting_length(&self, reporting_length: usize) -> usize { + self.data_block.normalize_reporting_length(reporting_length) + } + + pub fn aggregate( + &self, + reporting_length: usize, + sensitivity_threshold: usize, + progress_callback: JsReportProgressCallback, + ) -> Result { + let _duration_logger = + ElapsedDurationLogger::new(String::from("sds processor aggregation")); + let js_callback: Function = progress_callback.dyn_into()?; + let mut aggregator = Aggregator::new(self.data_block.clone()); + + Ok(WasmAggregateResult::new(aggregator.aggregate( + reporting_length, + sensitivity_threshold, + &mut Some(JsProgressReporter::new(&js_callback, &|p| p)), + ))) + } + + pub fn generate( + &self, + cache_max_size: usize, + resolution: usize, + empty_value: String, + seeded: bool, + progress_callback: JsReportProgressCallback, + ) -> Result { + let _duration_logger = ElapsedDurationLogger::new(String::from("sds 
processor generation")); + let js_callback: Function = progress_callback.dyn_into()?; + let mut generator = Generator::new(self.data_block.clone()); + let mode = if seeded { + SynthesisMode::Seeded + } else { + SynthesisMode::Unseeded + }; + + Ok(WasmGenerateResult::new(generator.generate( + resolution, + cache_max_size, + empty_value, + mode, + &mut Some(JsProgressReporter::new(&js_callback, &|p| p)), + ))) + } +} diff --git a/packages/lib-wasm/src/utils/js/deserializers.rs b/packages/lib-wasm/src/utils/js/deserializers.rs deleted file mode 100644 index f655582..0000000 --- a/packages/lib-wasm/src/utils/js/deserializers.rs +++ /dev/null @@ -1,41 +0,0 @@ -use super::js_block_creator::JsDataBlockCreator; -use js_sys::Array; -use sds_core::{ - data_block::{ - block::{DataBlock, DataBlockCreator}, - typedefs::CsvRecord, - }, - utils::time::ElapsedDurationLogger, -}; - -#[inline] -pub fn deserialize_use_columns(use_columns: Array) -> Result { - let _duration_logger = ElapsedDurationLogger::new(String::from("deserialize_use_columns")); - - match use_columns.into_serde::() { - Ok(v) => Ok(v), - _ => Err(String::from("use_columns should be an Array")), - } -} - -#[inline] -pub fn deserialize_sensitive_zeros(sensitive_zeros: Array) -> Result { - let _duration_logger = ElapsedDurationLogger::new(String::from("deserialize_sensitive_zeros")); - - match sensitive_zeros.into_serde::() { - Ok(v) => Ok(v), - _ => Err(String::from("sensitive_zeros should be an Array")), - } -} - -#[inline] -pub fn deserialize_csv_data( - csv_data: Array, - use_columns: &[String], - sensitive_zeros: &[String], - record_limit: usize, -) -> Result { - let _duration_logger = ElapsedDurationLogger::new(String::from("deserialize_csv_data")); - - JsDataBlockCreator::create(Ok(csv_data), use_columns, sensitive_zeros, record_limit) -} diff --git a/packages/lib-wasm/src/utils/js/js_block_creator.rs b/packages/lib-wasm/src/utils/js/js_block_creator.rs index 909540c..e1096e0 100644 --- a/packages/lib-wasm/src/utils/js/js_block_creator.rs +++ b/packages/lib-wasm/src/utils/js/js_block_creator.rs @@ -1,33 +1,34 @@ use js_sys::Array; -use sds_core::data_block::{block::DataBlockCreator, typedefs::CsvRecord}; +use sds_core::data_block::{data_block_creator::DataBlockCreator, typedefs::CsvRecord}; +use wasm_bindgen::JsValue; pub struct JsDataBlockCreator; impl DataBlockCreator for JsDataBlockCreator { type InputType = Array; - type ErrorType = String; + type ErrorType = JsValue; - fn get_headers(csv_data: &mut Array) -> Result { + fn get_headers(csv_data: &mut Self::InputType) -> Result { let headers = csv_data.get(0); if headers.is_undefined() { - Err(String::from("missing headers")) + Err(JsValue::from("csv data missing headers")) } else { - match headers.into_serde::>() { - Ok(h) => Ok(h), - Err(_) => Err(String::from("headers should an Array")), - } + headers + .into_serde::() + .map_err(|err| JsValue::from(err.to_string())) } } - fn get_records(csv_data: &mut Self::InputType) -> Result, String> { + fn get_records(csv_data: &mut Self::InputType) -> Result, Self::ErrorType> { csv_data .slice(1, csv_data.length()) .iter() - .map(|record| match record.into_serde::>() { - Ok(h) => Ok(h), - Err(_) => Err(String::from("records should an Array")), + .map(|record| { + record + .into_serde::() + .map_err(|err| JsValue::from(err.to_string())) }) - .collect::, String>>() + .collect::, Self::ErrorType>>() } } diff --git a/packages/lib-wasm/src/utils/js/macros.rs b/packages/lib-wasm/src/utils/js/macros.rs deleted file mode 100644 index 
bad4ef1..0000000 --- a/packages/lib-wasm/src/utils/js/macros.rs +++ /dev/null @@ -1,32 +0,0 @@ -#[doc(hidden)] -#[macro_export] -macro_rules! match_or_return_undefined { - ($result_to_match: expr) => { - match $result_to_match { - Ok(result) => result, - Err(err) => { - error!("{}", err); - return JsValue::undefined(); - } - } - }; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! set_or_return_undefined { - ($($arg:tt)*) => { - match set($($arg)*) { - Ok(has_been_set) => { - if !has_been_set { - error!("unable to set object value from wasm"); - return JsValue::undefined(); - } - } - _ => { - error!("unable to set object value from wasm"); - return JsValue::undefined(); - } - }; - }; -} diff --git a/packages/lib-wasm/src/utils/js/mod.rs b/packages/lib-wasm/src/utils/js/mod.rs index 4c542d8..91614d9 100644 --- a/packages/lib-wasm/src/utils/js/mod.rs +++ b/packages/lib-wasm/src/utils/js/mod.rs @@ -1,16 +1,9 @@ -pub fn set_panic_hook() { - // When the `console_error_panic_hook` feature is enabled, we can call the - // `set_panic_hook` function at least once during initialization, and then - // we will get better error messages if our code ever panics. - // - // For more details see - // https://github.com/rustwasm/console_error_panic_hook#readme - #[cfg(feature = "console_error_panic_hook")] - console_error_panic_hook::set_once(); -} - -pub mod deserializers; pub mod js_block_creator; + pub mod js_progress_reporter; -pub mod macros; -pub mod serializers; + +mod set_panic_hook; + +pub use set_panic_hook::set_panic_hook; + +pub mod ts_definitions; diff --git a/packages/lib-wasm/src/utils/js/serializers.rs b/packages/lib-wasm/src/utils/js/serializers.rs deleted file mode 100644 index e6b0ea3..0000000 --- a/packages/lib-wasm/src/utils/js/serializers.rs +++ /dev/null @@ -1,124 +0,0 @@ -use js_sys::{Object, Reflect::set}; -use log::error; -use sds_core::{ - data_block::typedefs::DataBlockHeadersSlice, - processing::{ - aggregator::{typedefs::AggregatesCountMap, Aggregator, PrivacyRiskSummary}, - evaluator::typedefs::PreservationByCountBuckets, - }, - utils::time::ElapsedDurationLogger, -}; -use wasm_bindgen::JsValue; - -use crate::{aggregator::AggregatedCombination, set_or_return_undefined}; - -#[inline] -pub fn serialize_aggregates_count( - aggregates_count: &AggregatesCountMap, - headers: &DataBlockHeadersSlice, -) -> JsValue { - let _duration_logger = ElapsedDurationLogger::new(String::from("serialize_aggregates_count")); - let aggregated_values = Object::new(); - - for (agg, count) in aggregates_count.iter() { - let combination_key = Aggregator::format_aggregate_str(headers, agg); - - set_or_return_undefined!( - &aggregated_values, - &combination_key.clone().into(), - &AggregatedCombination::new(combination_key, count.count, agg.len()).into() - ); - } - aggregated_values.into() -} - -#[inline] -pub fn serialize_buckets(buckets: &PreservationByCountBuckets) -> JsValue { - let _duration_logger = ElapsedDurationLogger::new(String::from("serialize_buckets")); - let serialized_buckets = Object::new(); - - for (bucket_index, b) in buckets.iter() { - let serialized_bucket = Object::new(); - - set_or_return_undefined!(&serialized_bucket, &"size".into(), &b.size.into()); - set_or_return_undefined!( - &serialized_bucket, - &"preservationSum".into(), - &b.preservation_sum.into() - ); - set_or_return_undefined!( - &serialized_bucket, - &"lengthSum".into(), - &b.length_sum.into() - ); - - set_or_return_undefined!( - &serialized_buckets, - &(*bucket_index).into(), - &serialized_bucket.into() - ); - } - 
serialized_buckets.into() -} - -#[inline] -pub fn serialize_privacy_risk(privacy_risk: &PrivacyRiskSummary) -> JsValue { - let _duration_logger = ElapsedDurationLogger::new(String::from("serialize_privacy_risk")); - let serialized_privacy_risk = Object::new(); - - set_or_return_undefined!( - &serialized_privacy_risk, - &"totalNumberOfRecords".into(), - &privacy_risk.total_number_of_records.into() - ); - set_or_return_undefined!( - &serialized_privacy_risk, - &"totalNumberOfCombinations".into(), - &privacy_risk.total_number_of_combinations.into() - ); - set_or_return_undefined!( - &serialized_privacy_risk, - &"recordsWithUniqueCombinationsCount".into(), - &privacy_risk.records_with_unique_combinations_count.into() - ); - set_or_return_undefined!( - &serialized_privacy_risk, - &"recordsWithRareCombinationsCount".into(), - &privacy_risk.records_with_rare_combinations_count.into() - ); - set_or_return_undefined!( - &serialized_privacy_risk, - &"uniqueCombinationsCount".into(), - &privacy_risk.unique_combinations_count.into() - ); - set_or_return_undefined!( - &serialized_privacy_risk, - &"rareCombinationsCount".into(), - &privacy_risk.rare_combinations_count.into() - ); - set_or_return_undefined!( - &serialized_privacy_risk, - &"recordsWithUniqueCombinationsProportion".into(), - &privacy_risk - .records_with_unique_combinations_proportion - .into() - ); - set_or_return_undefined!( - &serialized_privacy_risk, - &"recordsWithRareCombinationsProportion".into(), - &privacy_risk - .records_with_rare_combinations_proportion - .into() - ); - set_or_return_undefined!( - &serialized_privacy_risk, - &"uniqueCombinationsProportion".into(), - &privacy_risk.unique_combinations_proportion.into() - ); - set_or_return_undefined!( - &serialized_privacy_risk, - &"rareCombinationsProportion".into(), - &privacy_risk.rare_combinations_proportion.into() - ); - serialized_privacy_risk.into() -} diff --git a/packages/lib-wasm/src/utils/js/set_panic_hook.rs b/packages/lib-wasm/src/utils/js/set_panic_hook.rs new file mode 100644 index 0000000..b1d7929 --- /dev/null +++ b/packages/lib-wasm/src/utils/js/set_panic_hook.rs @@ -0,0 +1,10 @@ +pub fn set_panic_hook() { + // When the `console_error_panic_hook` feature is enabled, we can call the + // `set_panic_hook` function at least once during initialization, and then + // we will get better error messages if our code ever panics. 
+ // + // For more details see + // https://github.com/rustwasm/console_error_panic_hook#readme + #[cfg(feature = "console_error_panic_hook")] + console_error_panic_hook::set_once(); +} diff --git a/packages/lib-wasm/src/utils/js/ts_definitions.rs b/packages/lib-wasm/src/utils/js/ts_definitions.rs new file mode 100644 index 0000000..342c9fa --- /dev/null +++ b/packages/lib-wasm/src/utils/js/ts_definitions.rs @@ -0,0 +1,146 @@ +use wasm_bindgen::prelude::*; + +#[wasm_bindgen(typescript_custom_section)] +const TS_APPEND_CONTENT: &'static str = r#" +export type ReportProgressCallback = (progress: number) => void; + +export type HeaderNames = string[]; + +export type CsvRecord = string[]; + +export type CsvData = CsvRecord[]; + +export interface IGenerateResult { + expansionRatio: number; + syntheticData: CsvData; +} + +export interface IPrivacyRiskSummary { + totalNumberOfRecords: number + totalNumberOfCombinations: number + recordsWithUniqueCombinationsCount: number + recordsWithRareCombinationsCount: number + uniqueCombinationsCount: number + rareCombinationsCount: number + recordsWithUniqueCombinationsProportion: number + recordsWithRareCombinationsProportion: number + uniqueCombinationsProportion: number + rareCombinationsProportion: number +} + +export interface IAggregateCountByLen { + [length: number]: number +} + +export interface IAggregateCountAndLength { + count: number; + length: number; +} + +export interface IAggregatesCount { + [name: string]: IAggregateCountAndLength; +} + +export interface IAggregateResult { + reportingLength: number; + aggregatesCount?: IAggregatesCount; + rareCombinationsCountByLen: IAggregateCountByLen; + combinationsCountByLen: IAggregateCountByLen; + combinationsSumByLen: IAggregateCountByLen; + privacyRisk: IPrivacyRiskSummary; +} + +export interface IPreservationByCountBucket { + size: number; + preservationSum: number; + lengthSum: number; + combinationCountSum: number; +} + +export interface IPreservationByCountBuckets { + [bucket_index: number]: IPreservationByCountBucket; +} + +export interface IPreservationByCount { + buckets: IPreservationByCountBuckets; + combinationLoss: number; +} + +export interface IEvaluateResult { + sensitiveAggregateResult: IAggregateResult; + syntheticAggregateResult: IAggregateResult; + leakageCountByLen: IAggregateCountByLen; + fabricatedCountByLen: IAggregateCountByLen; + preservationByCount: IPreservationByCount; + recordExpansion: number; +} + +export interface ISelectedAttributesByColumn { + [columnIndex: number]: Set; +} + +export interface IAttributesIntersection { + value: string; + estimatedCount: number; + actualCount?: number; +} + +export interface IAttributesIntersectionByColumn { + [columnIndex: number]: IAttributesIntersection[]; +}"#; + +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(typescript_type = "ReportProgressCallback")] + pub type JsReportProgressCallback; + + #[wasm_bindgen(typescript_type = "HeaderNames")] + pub type JsHeaderNames; + + #[wasm_bindgen(typescript_type = "CsvRecord")] + pub type JsCsvRecord; + + #[wasm_bindgen(typescript_type = "CsvData")] + pub type JsCsvData; + + #[wasm_bindgen(typescript_type = "IGenerateResult")] + pub type JsGenerateResult; + + #[wasm_bindgen(typescript_type = "IPrivacyRiskSummary")] + pub type JsPrivacyRiskSummary; + + #[wasm_bindgen(typescript_type = "IAggregateCountByLen")] + pub type JsAggregateCountByLen; + + #[wasm_bindgen(typescript_type = "IAggregateCountAndLength")] + pub type JsAggregateCountAndLength; + + #[wasm_bindgen(typescript_type = 
"IAggregatesCount")] + pub type JsAggregatesCount; + + #[wasm_bindgen(typescript_type = "IAggregateResult")] + pub type JsAggregateResult; + + #[wasm_bindgen(typescript_type = "IPreservationByCountBucket")] + pub type JsPreservationByCountBucket; + + #[wasm_bindgen(typescript_type = "IPreservationByCountBuckets")] + pub type JsPreservationByCountBuckets; + + #[wasm_bindgen(typescript_type = "IPreservationByCount")] + pub type JsPreservationByCount; + + #[wasm_bindgen(typescript_type = "IEvaluateResult")] + pub type JsEvaluateResult; + + #[wasm_bindgen(typescript_type = "ISelectedAttributesByColumn")] + pub type JsSelectedAttributesByColumn; + + #[wasm_bindgen(typescript_type = "IAttributesIntersection")] + pub type JsAttributesIntersection; + + #[wasm_bindgen(typescript_type = "IAttributesIntersectionByColumn")] + pub type JsAttributesIntersectionByColumn; +} + +pub type JsResult = Result; diff --git a/packages/lib-wasm/src/utils/mod.rs b/packages/lib-wasm/src/utils/mod.rs index a54c46d..aef1430 100644 --- a/packages/lib-wasm/src/utils/mod.rs +++ b/packages/lib-wasm/src/utils/mod.rs @@ -1,2 +1,3 @@ pub mod js; + pub mod logger; diff --git a/packages/python-pipeline/src/aggregator.py b/packages/python-pipeline/src/aggregator.py index 32571b8..5f73de9 100644 --- a/packages/python-pipeline/src/aggregator.py +++ b/packages/python-pipeline/src/aggregator.py @@ -30,11 +30,13 @@ def aggregate(config): sensitive_zeros = config['sensitive_zeros'] output_dir = config['output_dir'] prefix = config['prefix'] + sensitive_aggregated_data_json = path.join( + output_dir, f'{prefix}_sensitive_aggregated_data.json') logging.info(f'Aggregate {sensitive_microdata_path}') start_time = time.time() - data_block = sds.create_data_block_from_file( + sds_processor = sds.SDSProcessor( sensitive_microdata_path, sensitive_microdata_delimiter, use_columns, @@ -42,13 +44,32 @@ def aggregate(config): max(record_limit, 0) ) - len_to_combo_count, len_to_rare_count = sds.aggregate_and_write( - data_block, + aggregated_data = sds_processor.aggregate( + reporting_length, + 0 + ) + len_to_combo_count = aggregated_data.calc_combinations_count_by_len() + len_to_rare_count = aggregated_data.calc_rare_combinations_count_by_len( + reporting_resolution) + + aggregated_data.write_aggregates_count( sensitive_aggregates_path, + '\t', + ';', + reporting_resolution, + False + ) + + aggregated_data.write_to_json(sensitive_aggregated_data_json) + + aggregated_data.protect_aggregates_count(reporting_resolution) + + aggregated_data.write_aggregates_count( reportable_aggregates_path, '\t', - reporting_length, - reporting_resolution + ';', + reporting_resolution, + True ) leakage_tsv = path.join( diff --git a/packages/python-pipeline/src/evaluator.py b/packages/python-pipeline/src/evaluator.py index 7c1e422..d840386 100644 --- a/packages/python-pipeline/src/evaluator.py +++ b/packages/python-pipeline/src/evaluator.py @@ -2,8 +2,247 @@ import time import datetime import logging from os import path -from collections import defaultdict import util as util +import sds + + +class Evaluator: + def __init__(self, config): + self.sds_evaluator = sds.Evaluator() + self.use_columns = config['use_columns'] + self.record_limit = max(config['record_limit'], 0) + self.reporting_length = max(config['reporting_length'], 0) + self.reporting_resolution = config['reporting_resolution'] + self.sensitive_microdata_path = config['sensitive_microdata_path'] + self.sensitive_microdata_delimiter = config['sensitive_microdata_delimiter'] + self.synthetic_microdata_path 
= config['synthetic_microdata_path'] + self.sensitive_zeros = config['sensitive_zeros'] + self.output_dir = config['output_dir'] + self.prefix = config['prefix'] + self.sensitive_aggregated_data_json = path.join( + self.output_dir, f'{self.prefix}_sensitive_aggregated_data.json' + ) + self.record_analysis_tsv = path.join( + self.output_dir, f'{self.prefix}_sensitive_analysis_by_length.tsv' + ) + self.synthetic_rare_combos_tsv = path.join( + self.output_dir, f'{self.prefix}_synthetic_rare_combos_by_length.tsv' + ) + self.parameters_tsv = path.join( + self.output_dir, f'{self.prefix}_parameters.tsv' + ) + self.leakage_tsv = path.join( + self.output_dir, f'{self.prefix}_synthetic_leakage_by_length.tsv' + ) + self.leakage_svg = path.join( + self.output_dir, f'{self.prefix}_synthetic_leakage_by_length.svg' + ) + self.preservation_by_count_tsv = path.join( + self.output_dir, f'{self.prefix}_preservation_by_count.tsv' + ) + self.preservation_by_count_svg = path.join( + self.output_dir, f'{self.prefix}_preservation_by_count.svg' + ) + self.preservation_by_length_tsv = path.join( + self.output_dir, f'{self.prefix}_preservation_by_length.tsv' + ) + self.preservation_by_length_svg = path.join( + self.output_dir, f'{self.prefix}_preservation_by_length.svg' + ) + + def _load_sensitive_aggregates(self): + if not path.exists(self.sensitive_aggregated_data_json): + logging.info('Computing sensitive aggregates...') + self.sen_sds_processor = sds.SDSProcessor( + self.sensitive_microdata_path, + self.sensitive_microdata_delimiter, + self.use_columns, + self.sensitive_zeros, + self.record_limit + ) + self.sen_aggregated_data = self.sen_sds_processor.aggregate( + self.reporting_length, + 0 + ) + else: + logging.info('Loading sensitive aggregates...') + self.sen_aggregated_data = sds.AggregatedData.read_from_json( + self.sensitive_aggregated_data_json + ) + self.sen_sds_processor = sds.SDSProcessor.from_aggregated_data( + self.sen_aggregated_data + ) + + self.reporting_length = self.sen_sds_processor.normalize_reporting_length( + self.reporting_length + ) + + def _load_synthetic_aggregates(self): + logging.info('Computing synthetic aggregates...') + self.syn_sds_processor = sds.SDSProcessor( + self.synthetic_microdata_path, + "\t", + self.use_columns, + self.sensitive_zeros, + self.record_limit + ) + self.syn_aggregated_data = self.syn_sds_processor.aggregate( + self.reporting_length, + 0 + ) + + def _do_records_analysis(self): + logging.info('Performing records analysis on sensitive aggregates...') + self.records_analysis_data = self.sen_aggregated_data.calc_records_analysis_by_len( + self.reporting_resolution, True + ) + self.records_analysis_data.write_records_analysis( + self.record_analysis_tsv, '\t' + ) + + def _compare_synthetic_and_sensitive_rare(self): + logging.info( + 'Comparing synthetic and sensitive aggregates rare combinations...') + rare_combinations_data = self.sds_evaluator.compare_synthetic_and_sensitive_rare( + self.syn_aggregated_data, self.sen_aggregated_data, self.reporting_resolution, ' AND ', True + ) + rare_combinations_data.write_rare_combinations( + self.synthetic_rare_combos_tsv, '\t' + ) + + def _write_parameters(self): + logging.info('Writing evaluation parameters...') + total_sen = self.sen_sds_processor.protected_number_of_records( + self.reporting_resolution + ) + with open(self.parameters_tsv, 'w') as f: + f.write('\t'.join(['parameter', 'value'])+'\n') + f.write( + '\t'.join(['resolution', str(self.reporting_resolution)])+'\n' + ) + f.write('\t'.join(['limit', 
str(self.reporting_length)])+'\n') + f.write( + '\t'.join(['total_sensitive_records', str(total_sen)])+'\n' + ) + f.write('\t'.join(['unique_identifiable', str( + self.records_analysis_data.get_total_unique())])+'\n' + ) + f.write('\t'.join(['rare_identifiable', str( + self.records_analysis_data.get_total_rare())])+'\n' + ) + f.write('\t'.join(['risky_identifiable', str( + self.records_analysis_data.get_total_risky())])+'\n' + ) + f.write( + '\t'.join(['risky_identifiable_pct', str( + 100*self.records_analysis_data.get_total_risky()/total_sen)])+'\n' + ) + + def _find_leakages(self): + logging.info( + 'Looking for leakages from the sensitive dataset on the synthetic dataset...') + comb_counts = self.syn_aggregated_data.calc_combinations_count_by_len() + leakage_counts = self.sds_evaluator.calc_leakage_count( + self.sen_aggregated_data, self.syn_aggregated_data, self.reporting_resolution + ) + + with open(self.leakage_tsv, 'w') as f: + f.write('\t'.join( + ['syn_combo_length', 'combo_count', 'leak_count', 'leak_proportion'])+'\n' + ) + for length in range(1, self.reporting_length + 1): + combo_count = comb_counts.get(length, 0) + leak_count = leakage_counts.get(length, 0) + # by design there should be no leakage + assert leak_count == 0 + leak_prop = leak_count/combo_count if combo_count > 0 else 0 + f.write('\t'.join( + [str(length), str(combo_count), str(leak_count), str(leak_prop)])+'\n' + ) + + util.plotStats( + x_axis='syn_combo_length', + x_axis_title='Length of Synthetic Combination', + y_bar='combo_count', + y_bar_title='Count of Combinations', + y_line='leak_proportion', + y_line_title=f'Proportion of Leaked (<{self.reporting_resolution}) Combinations', + color='violet', + darker_color='darkviolet', + stats_tsv=self.leakage_tsv, + stats_svg=self.leakage_svg, + delimiter='\t', + style='whitegrid', + palette='magma' + ) + + def _calculate_preservation_by_count(self): + logging.info('Calculating preservation by count...') + preservation_by_count = self.sds_evaluator.calc_preservation_by_count( + self.sen_aggregated_data, self.syn_aggregated_data, self.reporting_resolution + ) + preservation_by_count.write_preservation_by_count( + self.preservation_by_count_tsv, '\t' + ) + + util.plotStats( + x_axis='syn_count_bucket', + x_axis_title='Count of Filtered Synthetic Records', + y_bar='mean_combo_length', + y_bar_title='Mean Length of Combinations', + y_line='count_preservation', + y_line_title='Count Preservation (Synthetic/Sensitive)', + color='lightgreen', + darker_color='green', + stats_tsv=self.preservation_by_count_tsv, + stats_svg=self.preservation_by_count_svg, + delimiter='\t', + style='whitegrid', + palette='magma' + ) + + def _calculate_preservation_by_length(self): + logging.info('Calculating preservation by length...') + preservation_by_length = self.sds_evaluator.calc_preservation_by_length( + self.sen_aggregated_data, self.syn_aggregated_data, self.reporting_resolution + ) + preservation_by_length.write_preservation_by_length( + self.preservation_by_length_tsv, '\t' + ) + + util.plotStats( + x_axis='syn_combo_length', + x_axis_title='Length of Combination', + y_bar='mean_combo_count', + y_bar_title='Mean Synthetic Count', + y_line='count_preservation', + y_line_title='Count Preservation (Synthetic/Sensitive)', + color='cornflowerblue', + darker_color='mediumblue', + stats_tsv=self.preservation_by_length_tsv, + stats_svg=self.preservation_by_length_svg, + delimiter='\t', + style='whitegrid', + palette='magma' + ) + + def run(self): + logging.info( + f'Evaluate 
{self.synthetic_microdata_path} vs {self.sensitive_microdata_path}' + ) + start_time = time.time() + + self._load_sensitive_aggregates() + self._load_synthetic_aggregates() + self._do_records_analysis() + self._compare_synthetic_and_sensitive_rare() + self._write_parameters() + self._find_leakages() + self._calculate_preservation_by_count() + self._calculate_preservation_by_length() + + logging.info( + f'Evaluated {self.synthetic_microdata_path} vs {self.sensitive_microdata_path}, took {datetime.timedelta(seconds = time.time() - start_time)}s') def evaluate(config): @@ -15,309 +254,4 @@ def evaluate(config): Args: config: options from the json config file, else default values. """ - - use_columns = config['use_columns'] - record_limit = config['record_limit'] - reporting_length = config['reporting_length'] - reporting_resolution = config['reporting_resolution'] - sensitive_microdata_path = config['sensitive_microdata_path'] - sensitive_microdata_delimiter = config['sensitive_microdata_delimiter'] - synthetic_microdata_path = config['synthetic_microdata_path'] - sensitive_zeros = config['sensitive_zeros'] - parallel_jobs = config['parallel_jobs'] - output_dir = config['output_dir'] - prefix = config['prefix'] - - logging.info( - f'Evaluate {synthetic_microdata_path} vs {sensitive_microdata_path}') - start_time = time.time() - - sen_counts = None - sen_records = None - sen_df = util.loadMicrodata( - path=sensitive_microdata_path, delimiter=sensitive_microdata_delimiter, record_limit=record_limit, - use_columns=use_columns) - sen_records = util.genRowList(sen_df, sensitive_zeros) - if not path.exists(config['sensitive_aggregates_path']): - logging.info('Computing sensitive aggregates...') - if reporting_length == -1: - reporting_length = max([len(row) for row in sen_records]) - sen_counts = util.countAllCombos( - sen_records, reporting_length, parallel_jobs) - else: - logging.info('Loading sensitive aggregates...') - sen_counts = util.loadSavedAggregates( - config['sensitive_aggregates_path']) - if reporting_length == -1: - reporting_length = max(sen_counts.keys()) - - if use_columns != []: - reporting_length = min(reporting_length, len(use_columns)) - - filtered_sen_counts = {length: {combo: count for combo, count in combo_to_counts.items( - ) if count >= reporting_resolution} for length, combo_to_counts in sen_counts.items()} - syn_df = util.loadMicrodata(path=synthetic_microdata_path, - delimiter='\t', record_limit=-1, use_columns=use_columns) - syn_records = util.genRowList(syn_df, sensitive_zeros) - syn_counts = util.countAllCombos( - syn_records, reporting_length, parallel_jobs) - - len_to_syn_count = {length: len(combo_to_count) - for length, combo_to_count in syn_counts.items()} - len_to_sen_rare = {length: {combo: count for combo, count in combo_to_count.items() if count < reporting_resolution} - for length, combo_to_count in sen_counts.items()} - len_to_syn_rare = {length: {combo: count for combo, count in combo_to_count.items() if count < reporting_resolution} - for length, combo_to_count in syn_counts.items()} - len_to_syn_leak = {length: len([1 for rare in rares if rare in syn_counts.get(length, {}).keys()]) - for length, rares in len_to_sen_rare.items()} - - sen_unique_to_records, sen_rare_to_records, _ = util.mapShortestUniqueRareComboLengthToRecords( - sen_records, len_to_sen_rare) - sen_rare_to_sen_count = {length: util.protect(len(records), reporting_resolution) - for length, records in sen_rare_to_records.items()} - sen_unique_to_sen_count = {length: 
util.protect(len(records), reporting_resolution) - for length, records in sen_unique_to_records.items()} - - total_sen = util.protect(len(sen_records), reporting_resolution) - unique_total = sum( - [v for k, v in sen_unique_to_sen_count.items() if k > 0]) - rare_total = sum([v for k, v in sen_rare_to_sen_count.items() if k > 0]) - risky_total = unique_total + rare_total - risky_total_pct = 100*risky_total/total_sen - - record_analysis_tsv = path.join( - output_dir, f'{prefix}_sensitive_analysis_by_length.tsv') - with open(record_analysis_tsv, 'w') as f: - f.write('\t'.join(['combo_length', 'sen_rare', 'sen_rare_pct', - 'sen_unique', 'sen_unique_pct', 'sen_risky', 'sen_risky_pct'])+'\n') - for length in sen_counts.keys(): - sen_rare = sen_rare_to_sen_count.get(length, 0) - sen_rare_pct = 100*sen_rare / total_sen if total_sen > 0 else 0 - sen_unique = sen_unique_to_sen_count.get(length, 0) - sen_unique_pct = 100*sen_unique / total_sen if total_sen > 0 else 0 - sen_risky = sen_rare + sen_unique - sen_risky_pct = 100*sen_risky / total_sen if total_sen > 0 else 0 - f.write( - '\t'.join( - [str(length), - str(sen_rare), - str(sen_rare_pct), - str(sen_unique), - str(sen_unique_pct), - str(sen_risky), - str(sen_risky_pct)]) + '\n') - - _, _, syn_length_to_combo_to_rare = util.mapShortestUniqueRareComboLengthToRecords( - syn_records, len_to_syn_rare) - combos_tsv = path.join( - output_dir, f'{prefix}_synthetic_rare_combos_by_length.tsv') - with open(combos_tsv, 'w') as f: - f.write('\t'.join(['combo_length', 'combo', - 'record_id', 'syn_count', 'sen_count'])+'\n') - for length, combo_to_rare in syn_length_to_combo_to_rare.items(): - for combo, rare_ids in combo_to_rare.items(): - syn_count = len(rare_ids) - for rare_id in rare_ids: - sen_count = util.protect(sen_counts.get(length, {})[ - combo], reporting_resolution) - f.write( - '\t'.join( - [str(length), - util.comboToString(combo).replace(';', ' AND '), - str(rare_id), - str(syn_count), - str(sen_count)]) + '\n') - - parameters_tsv = path.join(output_dir, f'{prefix}_parameters.tsv') - - with open(parameters_tsv, 'w') as f: - f.write('\t'.join(['parameter', 'value'])+'\n') - f.write('\t'.join(['resolution', str(reporting_resolution)])+'\n') - f.write('\t'.join(['limit', str(reporting_length)])+'\n') - f.write('\t'.join(['total_sensitive_records', str(total_sen)])+'\n') - f.write('\t'.join(['unique_identifiable', str(unique_total)])+'\n') - f.write('\t'.join(['rare_identifiable', str(rare_total)])+'\n') - f.write('\t'.join(['risky_identifiable', str(risky_total)])+'\n') - f.write( - '\t'.join(['risky_identifiable_pct', str(risky_total_pct)])+'\n') - - leakage_tsv = path.join( - output_dir, f'{prefix}_synthetic_leakage_by_length.tsv') - leakage_svg = path.join( - output_dir, f'{prefix}_synthetic_leakage_by_length.svg') - with open(leakage_tsv, 'w') as f: - f.write('\t'.join(['syn_combo_length', 'combo_count', - 'leak_count', 'leak_proportion'])+'\n') - for length, leak_count in len_to_syn_leak.items(): - combo_count = len_to_syn_count.get(length, 0) - leak_prop = leak_count/combo_count if combo_count > 0 else 0 - f.write('\t'.join([str(length), str(combo_count), - str(leak_count), str(leak_prop)])+'\n') - - util.plotStats( - x_axis='syn_combo_length', - x_axis_title='Length of Synthetic Combination', - y_bar='combo_count', - y_bar_title='Count of Combinations', - y_line='leak_proportion', - y_line_title=f'Proportion of Leaked (<{reporting_resolution}) Combinations', - color='violet', - darker_color='darkviolet', - stats_tsv=leakage_tsv, - 
stats_svg=leakage_svg, - delimiter='\t', - style='whitegrid', - palette='magma') - - compareDatasets(filtered_sen_counts, syn_counts, output_dir, prefix) - - logging.info( - f'Evaluated {synthetic_microdata_path} vs {sensitive_microdata_path}, took {datetime.timedelta(seconds = time.time() - start_time)}s') - - -def compareDatasets(sensitive_length_to_combo_to_count, synthetic_length_to_combo_to_count, output_dir, prefix): - """Evaluates the error in the synthetic microdata with respect to the sensitive microdata. - - Produces output statistics and graphics binned by attribute count. - - Args: - sensitive_length_to_combo_to_count: counts from sensitive microdata. - synthetic_length_to_combo_to_count: counts from synthetic microdata. - output_dir: where to save output statistics and graphics. - prefix: prefix to add to output files. - """ - - all_count_length_preservation = [] - max_syn_count = 0 - - all_combos = set() - for length, combo_to_count in sensitive_length_to_combo_to_count.items(): - for combo in combo_to_count.keys(): - all_combos.add((length, combo)) - for length, combo_to_count in synthetic_length_to_combo_to_count.items(): - for combo in combo_to_count.keys(): - all_combos.add((length, combo)) - tot = len(all_combos) - for i, (length, combo) in enumerate(all_combos): - if i % 10000 == 0: - logging.info(f'{100*i/tot:.1f}% through comparisons') - sen_count = sensitive_length_to_combo_to_count.get( - length, {}).get(combo, 0) - syn_count = synthetic_length_to_combo_to_count.get( - length, {}).get(combo, 0) - max_syn_count = max(syn_count, max_syn_count) - preservation = 0 - preservation = syn_count / sen_count if sen_count > 0 else 0 - if sen_count == 0: - logging.error( - f'Error: For {combo}, synthetic count is {syn_count} but no sensitive count') - all_count_length_preservation.append((syn_count, length, preservation)) - max_syn_count = max(syn_count, max_syn_count) - - generateStatsAndGraphics(output_dir, max_syn_count, - all_count_length_preservation, prefix) - - -def generateStatsAndGraphics(output_dir, max_syn_count, count_length_preservation, prefix): - """Generates count error statistics for buckets of user-observed counts (post-filtering). - - Outputs the files preservation_by_length and preservation_by_count tsv and svg files. - - Args: - output_dir: the folder in which to output summary files. - max_syn_count: the maximum observed count of synthetic records matching a single attribute value. - count_length_preservation: list of (count, length, preservation) tuples for observed preservation of sensitive counts after filtering by a combo of length. 
- """ - sorted_counts = sorted(count_length_preservation, - key=lambda x: x[0], reverse=False) - buckets = [] - next_bucket = 10 - while next_bucket < max_syn_count: - buckets.append(next_bucket) - next_bucket *= 2 - buckets.append(next_bucket) - bucket_to_preservations = defaultdict(list) - bucket_to_counts = defaultdict(list) - bucket_to_lengths = defaultdict(list) - length_to_preservations = defaultdict(list) - length_to_counts = defaultdict(list) - for (count, length, preservation) in sorted_counts: - bucket = buckets[0] - for bi in range(len(buckets)-1, -1, -1): - if count > buckets[bi]: - bucket = buckets[bi+1] - break - bucket_to_preservations[bucket].append(preservation) - bucket_to_lengths[bucket].append(length) - bucket_to_counts[bucket].append(count) - length_to_counts[length].append(count) - length_to_preservations[length].append(preservation) - - bucket_to_mean_count = {bucket: sum(counts)/len(counts) if len(counts) > - 0 else 0 for bucket, counts in bucket_to_counts.items()} - bucket_to_mean_preservation = {bucket: sum(preservations) / len(preservations) - if len(preservations) > 0 else 0 for bucket, - preservations in bucket_to_preservations.items()} - bucket_to_mean_length = {bucket: sum(lengths)/len(lengths) if len(lengths) > - 0 else 0 for bucket, lengths in bucket_to_lengths.items()} - - counts_tsv = path.join(output_dir, f'{prefix}_preservation_by_count.tsv') - counts_svg = path.join(output_dir, f'{prefix}_preservation_by_count.svg') - with open(counts_tsv, 'w') as f: - f.write('\t'.join(['syn_count_bucket', 'mean_combo_count', - 'mean_combo_length', 'count_preservation'])+'\n') - for bucket in reversed(buckets): - f.write( - '\t'.join( - [str(bucket), - str(bucket_to_mean_count.get(bucket, 0)), - str(bucket_to_mean_length.get(bucket, 0)), - str(bucket_to_mean_preservation.get(bucket, 0))]) + '\n') - - util.plotStats( - x_axis='syn_count_bucket', - x_axis_title='Count of Filtered Synthetic Records', - y_bar='mean_combo_length', - y_bar_title='Mean Length of Combinations', - y_line='count_preservation', - y_line_title='Count Preservation (Synthetic/Sensitive)', - color='lightgreen', - darker_color='green', - stats_tsv=counts_tsv, - stats_svg=counts_svg, - delimiter='\t', - style='whitegrid', - palette='magma') - - length_to_mean_preservation = {length: sum(preservations) / len(preservations) - if len(preservations) > 0 else 0 for length, - preservations in length_to_preservations.items()} - length_to_mean_count = {length: sum(counts)/len(counts) if len(counts) > - 0 else 0 for length, counts in length_to_counts.items()} - - lengths_tsv = path.join(output_dir, f'{prefix}_preservation_by_length.tsv') - lengths_svg = path.join(output_dir, f'{prefix}_preservation_by_length.svg') - with open(lengths_tsv, 'w') as f: - f.write( - '\t'.join(['syn_combo_length', 'mean_combo_count', 'count_preservation'])+'\n') - for length in sorted(length_to_preservations.keys()): - f.write( - '\t'.join( - [str(length), - str(length_to_mean_count.get(length, 0)), - str(length_to_mean_preservation.get(length, 0))]) + '\n') - - util.plotStats( - x_axis='syn_combo_length', - x_axis_title='Length of Combination', - y_bar='mean_combo_count', - y_bar_title='Mean Synthetic Count', - y_line='count_preservation', - y_line_title='Count Preservation (Synthetic/Sensitive)', - color='cornflowerblue', - darker_color='mediumblue', - stats_tsv=lengths_tsv, - stats_svg=lengths_svg, - delimiter='\t', - style='whitegrid', - palette='magma') + Evaluator(config).run() diff --git 
a/packages/python-pipeline/src/generator.py b/packages/python-pipeline/src/generator.py index c18e4ab..144c130 100644 --- a/packages/python-pipeline/src/generator.py +++ b/packages/python-pipeline/src/generator.py @@ -1,12 +1,6 @@ import time import datetime import logging -import joblib -import psutil -import pandas as pd -from math import ceil -from random import random, shuffle -import util as util import sds @@ -20,14 +14,12 @@ def generate(config): """ use_columns = config['use_columns'] - parallel_jobs = config['parallel_jobs'] record_limit = config['record_limit'] sensitive_microdata_path = config['sensitive_microdata_path'] sensitive_microdata_delimiter = config['sensitive_microdata_delimiter'] synthetic_microdata_path = config['synthetic_microdata_path'] sensitive_zeros = config['sensitive_zeros'] resolution = config['reporting_resolution'] - memory_limit = config['memory_limit_pct'] cache_max_size = config['cache_max_size'] seeded = config['seeded'] @@ -36,215 +28,26 @@ def generate(config): if seeded: logging.info(f'Generating from seeds') - data_block = sds.create_data_block_from_file( - sensitive_microdata_path, - sensitive_microdata_delimiter, - use_columns, - sensitive_zeros, - max(record_limit, 0) - ) - syn_ratio = sds.generate_seeded_and_write( - data_block, - synthetic_microdata_path, - '\t', - cache_max_size, - resolution - ) - else: - df = util.loadMicrodata(path=sensitive_microdata_path, delimiter=sensitive_microdata_delimiter, - record_limit=record_limit, use_columns=use_columns) - columns = df.columns.values - num = len(df) - logging.info(f'Prepared data') - records = [] - - chunk = ceil(num/(parallel_jobs)) - logging.info(f'Generating unseeded') - chunks = [chunk for i in range(parallel_jobs)] - col_val_ids = util.genColValIdsDict(df, sensitive_zeros) - res = joblib.Parallel( - n_jobs=parallel_jobs, backend='loky', verbose=1)( - joblib.delayed(synthesizeRowsUnseeded)( - chunk, num, columns, col_val_ids, resolution, memory_limit) - for chunk in chunks) - for rows in res: - records.extend(rows) - # trim any overgenerated records because of uniform chunk size - records = records[:num] - records.sort() - records.sort(key=lambda x: len( - [y for y in x if y != '']), reverse=True) - - sdf = pd.DataFrame(records, columns=df.columns) - syn_ratio = len(sdf) / len(df) - sdf.to_csv(synthetic_microdata_path, sep='\t', index=False) + sds_processor = sds.SDSProcessor( + sensitive_microdata_path, + sensitive_microdata_delimiter, + use_columns, + sensitive_zeros, + max(record_limit, 0) + ) + generated_data = sds_processor.generate( + cache_max_size, + resolution, + "", + seeded + ) + generated_data.write_synthetic_data(synthetic_microdata_path, '\t') + syn_ratio = generated_data.expansion_ratio config['expansion_ratio'] = syn_ratio logging.info( f'Generated {synthetic_microdata_path} from {sensitive_microdata_path} with synthesis ratio {syn_ratio}, took {datetime.timedelta(seconds = time.time() - start_time)}s') - - -def synthesizeRowsUnseeded(chunk, num_rows, columns, col_val_ids, resolution, memory_limit): - """Create synthetic records through unconstrained sampling of attribute distributions. - - Args: - chunk: how many records to synthesize. - num_rows: how many rows/records in the sensitive dataset. - columns: the columns of the sensitive dataset. - atts_to_ids: a dict mapping attributes to row ids. - resolution: the minimum count of sensitive attribute combinations for inclusion in synthetic records. - memory_limit: the percentage memory use not to exceed. 
- - Returns: - rows: synthetic records created through unconstrained sampling of attribute distributions. - overall_cache_hits: the number of times the cache was hit. - overall_cache_misses: the number of times the cache was missed. - """ - rows = [] - filter_cache = {} - overall_cache_hits = 0 - overall_cache_misses = 0 - - for i in range(chunk): - if i % 100 == 99: - logging.info( - f'{(100*i/chunk):.1f}% through row synthesis, cache utilization {100*overall_cache_hits/(overall_cache_hits + overall_cache_misses):.1f}%') - filters, cache_hits, cache_misses = synthesizeRowUnseeded( - filter_cache, num_rows, columns, col_val_ids, resolution, memory_limit) - overall_cache_hits += cache_hits - overall_cache_misses += cache_misses - row = normalize(columns, filters) - rows.append(row) - - return rows - - -def synthesizeRowUnseeded(filter_cache, num_rows, columns, col_val_ids, resolution, memory_limit): - """Creates a synthetic record through unconstrained sampling of attribute distributions. - - Args: - num_rows: how many rows/records in the sensitive dataset. - columns: the columns of the sensitive dataset. - atts_to_ids: a dict mapping attributes to row ids. - resolution: the minimum count of sensitive attribute combinations for inclusion in synthetic records. - memory_limit: the percentage memory use not to exceed. - - Returns: - row: a synthetic record created through unconstrained sampling of attribute distributions. - row_cache_hits: the number of times the cache was hit. - row_cache_misses: the number of times the cache was missed. - """ - shuffled_columns = list(columns) - shuffle(shuffled_columns) - output_atts = [] - residual_ids = set(range(num_rows)) - row_cache_hits = 0 - row_cache_misses = 0 - # do not add to cache once memory limit is reached - use_cache = psutil.virtual_memory()[2] <= memory_limit - for col in shuffled_columns: - vals_to_sample = {} - for val, ids in col_val_ids[col].items(): - next_filters = tuple( - sorted((*output_atts, (col, val)), key=lambda x: f'{x[0]}:{x[1]}'.lower())) - if next_filters in filter_cache.keys(): - vals_to_sample[val] = set(filter_cache[next_filters]) - row_cache_hits += 1 - else: - row_cache_misses += 1 - res = set(ids).intersection(residual_ids) - vals_to_sample[val] = res - if use_cache: - filter_cache[next_filters] = res - if '' not in vals_to_sample.keys(): - vals_to_sample[''] = set() - vals_to_remove = {k: v for k, - v in vals_to_sample.items() if len(v) < resolution} - for val, ids in vals_to_remove.items(): - vals_to_sample[''].update(ids) - if val != '': - del vals_to_sample[val] - val_to_counts = {k: len(v) for k, v in vals_to_sample.items()} - sampled_val = sampleFromCounts(val_to_counts, True) - if sampled_val != None: - output_atts.append((col, sampled_val)) - residual_ids = set(vals_to_sample[sampled_val]) - - filters = tuple( - sorted(output_atts, key=lambda x: f'{x[0]}:{x[1]}'.lower())) - return filters, row_cache_hits, row_cache_misses - - -def sampleFromCounts(counts, preferNotNone): - """Samples from a dict of counts based on their relative frequency - - Args: - counts: the dict of counts. 
- preferNotNone: whether to avoid sampling None if possible - - Returns: - sampled_value: the sampled value - """ - dist = convertCountsToCumulativeDistribution(counts) - sampled_value = None - r = random() - keys = sorted(dist.keys()) - for p in keys: - v = dist[p] - if r < p: - if preferNotNone and v == None: - continue - else: - sampled_value = v - break - sampled_value = v - return sampled_value - - -def convertCountsToCumulativeDistribution(att_to_count): - """Converts a dict of counts to a cumulative probability distribution for sampling. - - Args: - att_to_count: a dict of attribute counts. - - Returns: - dist a cumulative probability distribution mapping cumulative probabilities [0,1] to attributes. - """ - dist = {} - total = sum(att_to_count.values()) - if total == 0: - return dist - cumulative = 0 - for att, count in att_to_count.items(): - if count > 0: - p = count/total - cumulative += p - dist[cumulative] = att - return dist - - -def normalize(columns, atts): - """Creates an output record according to the columns schema from a set of attributes. - - Args: - columns: the columns of the output record. - atts: the attribute values of the output record. - - Returns: - row: a normalized row ready for dataframe integration. - """ - row = [] - for c in columns: - added = False - for a in atts: - if a[0] == c: - row.append(a[1]) - added = True - break - if not added: - row.append('') - return row diff --git a/packages/python-pipeline/src/showcase.py b/packages/python-pipeline/src/showcase.py index 333ca16..6f72a66 100644 --- a/packages/python-pipeline/src/showcase.py +++ b/packages/python-pipeline/src/showcase.py @@ -1,6 +1,7 @@ import json import logging import argparse +import sds from os import path, mkdir import aggregator as aggregator import generator as generator @@ -49,7 +50,7 @@ def main(): config['evaluate'] = True config['navigate'] = True - # set based on the number of cores/memory available + # set based on the number of cores/memory available config['parallel_jobs'] = config.get('parallel_jobs', 1) config['memory_limit_pct'] = config.get('memory_limit_pct', 80) config['cache_max_size'] = config.get('cache_max_size', 100000) @@ -102,6 +103,9 @@ def runPipeline(config): config: options from the json config file, else default values. """ + # set number of threads to be used for processing + sds.set_number_of_threads(config['parallel_jobs']) + if config['aggregate']: aggregator.aggregate(config) diff --git a/packages/python-pipeline/src/util.py b/packages/python-pipeline/src/util.py index 772648b..3973acf 100644 --- a/packages/python-pipeline/src/util.py +++ b/packages/python-pipeline/src/util.py @@ -1,14 +1,10 @@ import matplotlib.ticker as ticker import matplotlib.pyplot as plt -import logging import numpy as np import pandas as pd -import joblib -from itertools import combinations -from collections import defaultdict import seaborn as sns -from math import ceil, floor +from math import ceil import matplotlib # fixes matplotlib + joblib bug "RuntimeError: main thread is not in main loop Tcl_AsyncDelete: async handler deleted by the wrong thread" matplotlib.use('Agg') @@ -36,219 +32,6 @@ def loadMicrodata(path, delimiter, record_limit, use_columns): return df -def genRowList(df, sensitive_zeros): - """Converts a dataframe to a list of rows. - - Args: - df: the dataframe. - sensitive_zeros: columns for which negative values should be controlled. - - Returns: - row_list: the list of rows. 
- """ - row_list = [] - for _, row in df.iterrows(): - res = [] - for c in df.columns: - val = str(row[c]) - if val != '' and (c in sensitive_zeros or val != '0'): - res.append((c, val)) - row2 = sorted(res, key=lambda x: f'{x[0]}:{x[1]}'.lower()) - row_list.append(row2) - return row_list - - -def genColValIdsDict(df, sensitive_zeros): - """Converts a dataframe to a dict of col->val->ids. - - Args: - df: the dataframe - - Returns: - colValIds: the dict of col->val->ids. - """ - colValIds = {} - for rid, row in df.iterrows(): - for c in df.columns: - val = str(row[c]) - if c not in colValIds.keys(): - colValIds[c] = defaultdict(list) - if val == '0': - if c in sensitive_zeros: - colValIds[c][val].append(rid) - else: - colValIds[c][''].append(rid) - else: - colValIds[c][val].append(rid) - - return colValIds - - -def countAllCombos(row_list, length_limit, parallel_jobs): - """Counts all combinations in the given rows up to a limit. - - Args: - row_list: a list of rows. - length_limit: the maximum length to compute counts for (all lengths if -1). - parallel_jobs: the number of processor cores to use for parallelized counting. - - Returns: - length_to_combo_to_count: a dict mapping combination lengths to dicts mapping combinations to counts. - """ - length_to_combo_to_count = {} - if length_limit == -1: - length_limit = max([len(x) for x in row_list]) - for length in range(1, length_limit+1): - logging.info(f'counting combos of length {length}') - res = joblib.Parallel(n_jobs=parallel_jobs, backend='loky', verbose=1)( - joblib.delayed(genAllCombos)(row, length) for row in row_list) - length_to_combo_to_count[length] = defaultdict(int) - for combos in res: - for combo in combos: - length_to_combo_to_count.get(length, {})[combo] += 1 - - return length_to_combo_to_count - - -def genAllCombos(row, length): - """Generates all combos from row up to and including size length. - - Args: - row: the row to extract combinations from. - length: the maximum combination length to extract. - - Returns: - combos: list of combinations extracted from row. - """ - res = [] - if len(row) == length: - canonical_combo = tuple( - sorted(list(row), key=lambda x: f'{x[0]}:{x[1]}'.lower())) - res.append(canonical_combo) - else: - for combo in combinations(row, length): - canonical_combo = tuple( - sorted(list(combo), key=lambda x: f'{x[0]}:{x[1]}'.lower())) - res.append(canonical_combo) - return res - - -def mapShortestUniqueRareComboLengthToRecords(records, length_to_rare): - """ - Maps each record to the shortest combination length that isolates it within a rare group (i.e., below resolution). - - Args: - records: the input records. - length_to_rare: a dict of length to rare combo to count. 
- - Returns: - rare_to_records: dict of rare combination lengths mapped to record lists - """ - rare_to_records = defaultdict(set) - unique_to_records = defaultdict(set) - length_to_combo_to_rare = {length: defaultdict( - set) for length in length_to_rare.keys()} - for i, record in enumerate(records): - matchedRare = False - matchedUnique = False - for length in sorted(length_to_rare.keys()): - if matchedUnique: - break - for combo in combinations(record, length): - canonical_combo = tuple( - sorted(list(combo), key=lambda x: f'{x[0]}:{x[1]}'.lower())) - if canonical_combo in length_to_rare.get(length, {}).keys(): - # unique - if length_to_rare.get(length, {})[canonical_combo] == 1: - unique_to_records[length].add(i) - matchedUnique = True - length_to_combo_to_rare.get( - length, {})[canonical_combo].add(i) - if i in rare_to_records[length]: - rare_to_records[length].remove(i) - - else: - if i not in unique_to_records[length]: - rare_to_records[length].add(i) - matchedRare = True - length_to_combo_to_rare.get( - length, {})[canonical_combo].add(i) - - if matchedUnique: - break - if not matchedRare and not matchedUnique: - rare_to_records[0].add(i) - if not matchedUnique: - unique_to_records[0].add(i) - return unique_to_records, rare_to_records, length_to_combo_to_rare - - -def protect(value, resolution): - """Protects a value from a privacy perspective by rounding down to the closest multiple of the supplied resolution. - - Args: - value: the value to protect. - resolution: round values down to the closest multiple of this. - """ - rounded = floor(value / resolution) * resolution - return rounded - - -def comboToString(combo_tuple): - """Creates a string from a tuple of (attribute, value) tuples. - - Args: - combo_tuple: tuple of (attribute, value) tuples. - - Returns: - combo_string: a ;-delimited string of ':'-concatenated attribute-value pairs. - """ - combo_string = '' - for (col, val) in combo_tuple: - combo_string += f'{col}:{val};' - return combo_string[:-1] - - -def loadSavedAggregates(path): - """Loads sensitive aggregates from file to speed up evaluation. - - Args: - path: location of the saved sensitive aggregates. - - Returns: - length_to_combo_to_count: a dict mapping combination lengths to dicts mapping combinations to counts - """ - length_to_combo_to_count = defaultdict(dict) - with open(path, 'r') as f: - for i, line in enumerate(f): - if i == 0: - continue # skip header - parts = [x.strip() for x in line.split('\t')] - if len(parts[0]) > 0: - length, combo = stringToLengthAndCombo(parts[0]) - length_to_combo_to_count[length][combo] = int(parts[1]) - return length_to_combo_to_count - - -def stringToLengthAndCombo(combo_string): - """Creates a tuple of (attribute, value) tuples from a given string. - - Args: - combo_string: string representation of (attribute, value) tuples. - - Returns: - combo_tuple: tuple of (attribute, value) tuples. 
- """ - length = len(combo_string.split(';')) - combo_list = [] - for col_vals in combo_string.split(';'): - parts = col_vals.split(':') - if len(parts) == 2: - combo_list.append((parts[0], parts[1])) - combo_tuple = tuple(combo_list) - return length, combo_tuple - - def plotStats( x_axis, x_axis_title, y_bar, y_bar_title, y_line, y_line_title, color, darker_color, stats_tsv, stats_svg, delimiter, style='whitegrid', palette='magma'): diff --git a/packages/webapp/src/components/AttributeSelector/ColumnAttributeSelector.tsx b/packages/webapp/src/components/AttributeSelector/ColumnAttributeSelector.tsx index 909ff20..115418f 100644 --- a/packages/webapp/src/components/AttributeSelector/ColumnAttributeSelector.tsx +++ b/packages/webapp/src/components/AttributeSelector/ColumnAttributeSelector.tsx @@ -2,31 +2,24 @@ * Copyright (c) Microsoft. All rights reserved. * Licensed under the MIT license. See LICENSE file in the project. */ -import { Stack, Spinner, Label } from '@fluentui/react' -import _ from 'lodash' +import { Stack, Label, Spinner } from '@fluentui/react' import { memo, useCallback, useEffect, useRef, useState } from 'react' -import { useOnAttributeSelection } from './hooks' +import { IAttributesIntersection } from 'sds-wasm' +import { useMaxCount } from './hooks' import { AttributeIntersectionValueChart } from '~components/Charts/AttributeIntersectionValueChart' import { useStopPropagation } from '~components/Charts/hooks' -import { CsvRecord, IAttributesIntersectionValue } from '~models' -import { - useEvaluatedResultValue, - useNavigateResultValue, - useWasmWorkerValue, -} from '~states' -import { - useSelectedAttributeRows, - useSelectedAttributes, -} from '~states/dataShowcaseContext' +import { SetSelectedAttributesCallback } from '~components/Pages/DataShowcasePage/DataNavigation' +import { useWasmWorkerValue } from '~states' export interface ColumnAttributeSelectorProps { - headers: CsvRecord headerName: string columnIndex: number height: string | number - width: number | number - chartHeight: number + width: string | number + chartBarHeight: number minHeight?: string | number + selectedAttributes: Set + onSetSelectedAttributes: SetSelectedAttributesCallback } // fixed value for rough axis height so charts don't squish if selected down to 1 @@ -34,87 +27,49 @@ const AXIS_HEIGHT = 16 export const ColumnAttributeSelector: React.FC = memo(function ColumnAttributeSelector({ - headers, headerName, columnIndex, height, width, - chartHeight, + chartBarHeight, minHeight, + selectedAttributes, + onSetSelectedAttributes, }: ColumnAttributeSelectorProps) { - const [selectedAttributeRows, setSelectedAttributeRows] = - useSelectedAttributeRows() - const [selectedAttributes, setSelectedAttributes] = useSelectedAttributes() - const navigateResult = useNavigateResultValue() - const [items, setItems] = useState([]) + const [items, setItems] = useState([]) const [isLoading, setIsLoading] = useState(false) - const isMounted = useRef(true) - const evaluatedResult = useEvaluatedResultValue() const worker = useWasmWorkerValue() - const selectedAttribute = selectedAttributes[columnIndex] - const maxCount = - _.max([ - _.maxBy(items, item => item.estimatedCount)?.estimatedCount, - _.maxBy(items, item => item.actualCount ?? 0)?.actualCount, - ]) ?? 
1 - - const onAttributeSelection = useOnAttributeSelection( - setIsLoading, - selectedAttributes, - worker, - navigateResult, - isMounted, - setSelectedAttributes, - setSelectedAttributeRows, - ) + const maxCount = useMaxCount(items) + const isMounted = useRef(true) + const stopPropagation = useStopPropagation() const handleSelection = useCallback( - (item: IAttributesIntersectionValue | undefined) => { + async (item: IAttributesIntersection | undefined) => { const newValue = item?.value // toggle off with re-click - if (newValue === selectedAttribute) { - onAttributeSelection(columnIndex, undefined) + if (newValue === undefined || selectedAttributes.has(newValue)) { + await onSetSelectedAttributes(columnIndex, undefined) } else { - onAttributeSelection(columnIndex, newValue) + await onSetSelectedAttributes(columnIndex, item) } }, - [selectedAttribute, columnIndex, onAttributeSelection], + [selectedAttributes, onSetSelectedAttributes, columnIndex], ) - const stopPropagation = useStopPropagation() - useEffect(() => { if (worker) { setIsLoading(true) worker - .intersectAttributesInColumnsWith( - headers, - navigateResult.allRows, - navigateResult.attrsInColumnsMap[columnIndex] ?? new Set(), - selectedAttributeRows, - selectedAttributes, - navigateResult.attrRowsMap, - columnIndex, - evaluatedResult.sensitiveAggregatedResult?.aggregatedCombinations, - ) - .then(newItems => { - if (isMounted.current) { - setItems(newItems ?? []) - setIsLoading(false) + .attributesIntersectionsByColumn([headerName]) + .then(intersections => { + if (!isMounted.current || !intersections) { + return } + setItems(intersections[columnIndex] ?? []) + setIsLoading(false) }) } - }, [ - worker, - navigateResult, - headers, - selectedAttributeRows, - selectedAttributes, - columnIndex, - evaluatedResult, - setItems, - setIsLoading, - ]) + }, [worker, setIsLoading, setItems, headerName, columnIndex]) useEffect(() => { return () => { @@ -151,8 +106,8 @@ export const ColumnAttributeSelector: React.FC = items={items} onClick={handleSelection} maxCount={maxCount} - height={chartHeight * Math.max(items.length, 1) + AXIS_HEIGHT} - selectedValue={selectedAttribute} + height={chartBarHeight * Math.max(items.length, 1) + AXIS_HEIGHT} + selectedAttributes={selectedAttributes} /> )} diff --git a/packages/webapp/src/components/AttributeSelector/ColumnAttributeSelectorGrid.tsx b/packages/webapp/src/components/AttributeSelector/ColumnAttributeSelectorGrid.tsx new file mode 100644 index 0000000..f002829 --- /dev/null +++ b/packages/webapp/src/components/AttributeSelector/ColumnAttributeSelectorGrid.tsx @@ -0,0 +1,81 @@ +/*! + * Copyright (c) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE file in the project. 
+ */ +import { Stack, useTheme } from '@fluentui/react' +import { memo } from 'react' +import { HeaderNames, ISelectedAttributesByColumn } from 'sds-wasm' +import { ColumnAttributeSelector } from './ColumnAttributeSelector' +import { AttributeIntersectionValueChartLegend } from '~components/AttributeIntersectionValueChartLegend' +import { useHorizontalScrolling } from '~components/Charts/hooks' +import { SetSelectedAttributesCallback } from '~components/Pages/DataShowcasePage/DataNavigation' + +export interface ColumnAttributeSelectorGridProps { + viewHeight: string | number + headers: HeaderNames + selectedHeaders: boolean[] + chartHeight: string | number + chartWidth: string | number + chartBarHeight: number + chartMinHeight?: string | number + selectedAttributesByColumn: ISelectedAttributesByColumn + onSetSelectedAttributes: SetSelectedAttributesCallback +} + +export const ColumnAttributeSelectorGrid: React.FC = + memo(function ColumnAttributeSelectorGrid({ + viewHeight, + headers, + selectedHeaders, + chartHeight, + chartWidth, + chartBarHeight, + chartMinHeight, + selectedAttributesByColumn, + onSetSelectedAttributes, + }: ColumnAttributeSelectorGridProps) { + const theme = useTheme() + const doHorizontalScroll = useHorizontalScrolling() + + return ( + + + + {headers.map((h, i) => { + return ( + selectedHeaders[i] && ( + + + + ) + ) + })} + + + ) + }) diff --git a/packages/webapp/src/components/AttributeSelector/HeaderSelector.tsx b/packages/webapp/src/components/AttributeSelector/HeaderSelector.tsx index 29d70e9..efe0df0 100644 --- a/packages/webapp/src/components/AttributeSelector/HeaderSelector.tsx +++ b/packages/webapp/src/components/AttributeSelector/HeaderSelector.tsx @@ -4,10 +4,10 @@ */ import { Stack, useTheme, IStackTokens, Checkbox } from '@fluentui/react' import { memo } from 'react' -import { CsvRecord } from '~models' +import { HeaderNames } from 'sds-wasm' export interface HeaderSelectorProps { - headers: CsvRecord + headers: HeaderNames selectedHeaders: boolean[] onToggle: (columnIndex: number) => void } diff --git a/packages/webapp/src/components/AttributeSelector/SelectedAttributes.tsx b/packages/webapp/src/components/AttributeSelector/SelectedAttributes.tsx index 631aa1f..fda641b 100644 --- a/packages/webapp/src/components/AttributeSelector/SelectedAttributes.tsx +++ b/packages/webapp/src/components/AttributeSelector/SelectedAttributes.tsx @@ -10,80 +10,62 @@ import { Stack, useTheme, } from '@fluentui/react' -import { memo, useEffect, useRef, useState } from 'react' -import { useOnAttributeSelection, useOnClearAttributeSelection } from './hooks' -import { CsvRecord } from '~models' -import { useNavigateResultValue, useWasmWorkerValue } from '~states' +import { memo } from 'react' +import { HeaderNames, ISelectedAttributesByColumn } from 'sds-wasm' +import { useSelectedAttributesByColumnEntries } from './hooks' import { - useSelectedAttributeRowsSetter, - useSelectedAttributes, -} from '~states/dataShowcaseContext' + ClearSelectedAttributesCallback, + SetSelectedAttributesCallback, +} from '~components/Pages/DataShowcasePage/DataNavigation' const deleteIcon: IIconProps = { iconName: 'Delete' } export interface SelectedAttributesProps { - headers: CsvRecord + headers: HeaderNames + selectedAttributesByColumn: ISelectedAttributesByColumn + onSetSelectedAttributes: SetSelectedAttributesCallback + onClearSelectedAttributes: ClearSelectedAttributesCallback } export const SelectedAttributes: React.FC = memo( - function SelectedAttributes({ headers }: 
SelectedAttributesProps) { - const [selectedAttributes, setSelectedAttributes] = useSelectedAttributes() - const [isLoading, setIsLoading] = useState(false) + function SelectedAttributes({ + headers, + selectedAttributesByColumn, + onSetSelectedAttributes, + onClearSelectedAttributes, + }: SelectedAttributesProps) { const theme = useTheme() - const worker = useWasmWorkerValue() - const navigateResult = useNavigateResultValue() - const setSelectedAttributeRows = useSelectedAttributeRowsSetter() - const isMounted = useRef(true) + const selectedEntries = useSelectedAttributesByColumnEntries( + selectedAttributesByColumn, + ) const stackTokens: IStackTokens = { childrenGap: theme.spacing.s2, } - const onAttributeSelection = useOnAttributeSelection( - setIsLoading, - selectedAttributes, - worker, - navigateResult, - isMounted, - setSelectedAttributes, - setSelectedAttributeRows, - ) - - const onClearAttributeSection = useOnClearAttributeSelection( - setIsLoading, - worker, - navigateResult, - isMounted, - setSelectedAttributes, - setSelectedAttributeRows, - ) - - const selectedEntries = Object.entries(selectedAttributes) - - useEffect(() => { - return () => { - isMounted.current = false - } - }, []) - return ( Clear - {selectedEntries.map(entry => { - return ( - onAttributeSelection(+entry[0], undefined)} - /> - ) + {selectedEntries.flatMap(entry => { + return Array.from(entry[1].keys()) + .sort() + .map(value => { + return ( + + await onSetSelectedAttributes(+entry[0], undefined) + } + /> + ) + }) })} ) diff --git a/packages/webapp/src/components/AttributeSelector/hooks.ts b/packages/webapp/src/components/AttributeSelector/hooks.ts index adba895..a8da91c 100644 --- a/packages/webapp/src/components/AttributeSelector/hooks.ts +++ b/packages/webapp/src/components/AttributeSelector/hooks.ts @@ -2,101 +2,28 @@ * Copyright (c) Microsoft. All rights reserved. * Licensed under the MIT license. See LICENSE file in the project. */ -import { Dispatch, MutableRefObject, SetStateAction, useCallback } from 'react' -import { SetterOrUpdater } from 'recoil' -import { SdsWasmWorker } from 'src/workers/sds-wasm' -import { AttributeRows, INavigateResult, ISelectedAttributes } from '~models' +import _ from 'lodash' +import { useMemo } from 'react' +import { IAttributesIntersection, ISelectedAttributesByColumn } from 'sds-wasm' -async function loadNewSelectedAttributeRows( - setIsLoading: Dispatch>, - newSelectedAttributes: ISelectedAttributes, - worker: SdsWasmWorker | null, - navigateResult: INavigateResult, - isMounted: MutableRefObject, - setSelectedAttributes: SetterOrUpdater, - setSelectedAttributeRows: SetterOrUpdater, -) { - const newSelectedAttributeRows = - (await worker?.intersectSelectedAttributesWith( - newSelectedAttributes, - navigateResult.allRows, - navigateResult.attrRowsMap, - )) ?? 
navigateResult.allRows - - if (isMounted.current) { - setSelectedAttributes(newSelectedAttributes) - setSelectedAttributeRows(newSelectedAttributeRows) - setIsLoading(false) - } -} - -export function useOnAttributeSelection( - setIsLoading: Dispatch>, - selectedAttributes: ISelectedAttributes, - worker: SdsWasmWorker | null, - navigateResult: INavigateResult, - isMounted: MutableRefObject, - setSelectedAttributes: SetterOrUpdater, - setSelectedAttributeRows: SetterOrUpdater, -): (columnIndex: number, attr?: string) => Promise { - return useCallback( - async (columnIndex, attr) => { - setIsLoading(true) - const newSelectedAttributes = { - ...selectedAttributes, - [columnIndex]: attr, - } - - if (!attr) { - delete newSelectedAttributes[columnIndex] - } - await loadNewSelectedAttributeRows( - setIsLoading, - newSelectedAttributes, - worker, - navigateResult, - isMounted, - setSelectedAttributes, - setSelectedAttributeRows, - ) - }, - [ - selectedAttributes, - navigateResult, - worker, - setSelectedAttributes, - setSelectedAttributeRows, - setIsLoading, - isMounted, - ], +export function useMaxCount(items: IAttributesIntersection[]): number { + return useMemo( + () => + Number( + _.max([ + _.maxBy(items, item => item.estimatedCount)?.estimatedCount, + _.maxBy(items, item => item.actualCount ?? 0)?.actualCount, + ]), + ) ?? 1, + [items], ) } -export function useOnClearAttributeSelection( - setIsLoading: Dispatch>, - worker: SdsWasmWorker | null, - navigateResult: INavigateResult, - isMounted: MutableRefObject, - setSelectedAttributes: SetterOrUpdater, - setSelectedAttributeRows: SetterOrUpdater, -): () => Promise { - return useCallback(async () => { - setIsLoading(true) - await loadNewSelectedAttributeRows( - setIsLoading, - {}, - worker, - navigateResult, - isMounted, - setSelectedAttributes, - setSelectedAttributeRows, - ) - }, [ - navigateResult, - worker, - setSelectedAttributes, - setSelectedAttributeRows, - setIsLoading, - isMounted, - ]) +export function useSelectedAttributesByColumnEntries( + selectedAttributesByColumn: ISelectedAttributesByColumn, +): [string, Set][] { + return useMemo( + () => Object.entries(selectedAttributesByColumn), + [selectedAttributesByColumn], + ) } diff --git a/packages/webapp/src/components/AttributeSelector/index.ts b/packages/webapp/src/components/AttributeSelector/index.ts index 05b729f..8f05e91 100644 --- a/packages/webapp/src/components/AttributeSelector/index.ts +++ b/packages/webapp/src/components/AttributeSelector/index.ts @@ -3,5 +3,6 @@ * Licensed under the MIT license. See LICENSE file in the project. 
*/ export * from './ColumnAttributeSelector' +export * from './ColumnAttributeSelectorGrid' export * from './HeaderSelector' export * from './SelectedAttributes' diff --git a/packages/webapp/src/components/Charts/AttributeIntersectionValueChart.tsx b/packages/webapp/src/components/Charts/AttributeIntersectionValueChart.tsx index bbc18e9..0a449fd 100644 --- a/packages/webapp/src/components/Charts/AttributeIntersectionValueChart.tsx +++ b/packages/webapp/src/components/Charts/AttributeIntersectionValueChart.tsx @@ -6,19 +6,19 @@ import type { Plugin } from 'chart.js' import ChartDataLabels from 'chartjs-plugin-datalabels' import { memo, useCallback } from 'react' import { Bar } from 'react-chartjs-2' +import { IAttributesIntersection } from 'sds-wasm' import { useActualBarConfig, useDataLabelsConfig, useEstimatedBarConfig, -} from '~components/Charts/hooks' -import { IAttributesIntersectionValue } from '~models' +} from './hooks' export interface AttributeIntersectionValueChartProps { - items: IAttributesIntersectionValue[] + items: IAttributesIntersection[] maxCount: number height: number - onClick?: (item: IAttributesIntersectionValue | undefined) => void - selectedValue?: string + onClick?: (item: IAttributesIntersection | undefined) => void + selectedAttributes: Set } export const AttributeIntersectionValueChart: React.FC = @@ -27,16 +27,20 @@ export const AttributeIntersectionValueChart: React.FC i.value) : [] - const estimated = items ? items.map(i => i.estimatedCount) : [] + const estimated = items ? items.map(i => Number(i.estimatedCount)) : [] const actual = items - ? items.map(i => i.actualCount).filter(i => i !== undefined) + ? items + .map(i => + i.actualCount !== undefined ? Number(i.actualCount) : undefined, + ) + .filter(i => i !== undefined) : [] - const estimatedBarConfig = useEstimatedBarConfig(labels, selectedValue) - const actualBarConfig = useActualBarConfig(labels, selectedValue) - const dataLabelsConfig = useDataLabelsConfig(labels, selectedValue) + const estimatedBarConfig = useEstimatedBarConfig(labels, selectedAttributes) + const actualBarConfig = useActualBarConfig(labels, selectedAttributes) + const dataLabelsConfig = useDataLabelsConfig(labels, selectedAttributes) const handleClick = useCallback( (evt, elements, chart) => { @@ -69,7 +73,21 @@ export const AttributeIntersectionValueChart: React.FC]} + plugins={[ + ChartDataLabels as Plugin<'bar'>, + { + id: 'event-catcher', + beforeEvent(chart, args, _pluginOptions) { + // on hover at options will not handle well the case + // where the mouse leaves the bar + if (args.event.type === 'mousemove') { + const elements = chart.getActiveElements() + chart.canvas.style.cursor = + elements && elements[0] ? 
'pointer' : 'default' + } + }, + }, + ]} options={{ plugins: { ...dataLabelsConfig, diff --git a/packages/webapp/src/components/Charts/FabricatedCountChart.tsx b/packages/webapp/src/components/Charts/FabricatedCountChart.tsx index 7b5554c..3e9821e 100644 --- a/packages/webapp/src/components/Charts/FabricatedCountChart.tsx +++ b/packages/webapp/src/components/Charts/FabricatedCountChart.tsx @@ -4,13 +4,13 @@ */ import { memo, useMemo } from 'react' import { Bar } from 'react-chartjs-2' -import { IAggregatedCountByLen } from '~models' +import { IAggregateCountByLen } from 'sds-wasm' export interface FabricatedCountChartProps { combinationsLabel: string fabricatedLabel: string - combinationsCountByLen: IAggregatedCountByLen - fabricatedCountByLen: IAggregatedCountByLen + combinationsCountByLen: IAggregateCountByLen + fabricatedCountByLen: IAggregateCountByLen height: number width: number } diff --git a/packages/webapp/src/components/Charts/LeakageCountChart.tsx b/packages/webapp/src/components/Charts/LeakageCountChart.tsx index 0211d73..602c92f 100644 --- a/packages/webapp/src/components/Charts/LeakageCountChart.tsx +++ b/packages/webapp/src/components/Charts/LeakageCountChart.tsx @@ -4,13 +4,13 @@ */ import { memo } from 'react' import { Bar } from 'react-chartjs-2' -import { IAggregatedCountByLen } from '~models' +import { IAggregateCountByLen } from 'sds-wasm' export interface LeakageCountChartProps { combinationsLabel: string leakageLabel: string - combinationsCountByLen: IAggregatedCountByLen - leakageCountByLen: IAggregatedCountByLen + combinationsCountByLen: IAggregateCountByLen + leakageCountByLen: IAggregateCountByLen height: number width: number } diff --git a/packages/webapp/src/components/Charts/MeanCombinationsByLengthChart.tsx b/packages/webapp/src/components/Charts/MeanCombinationsByLengthChart.tsx index b0dfb80..4cebd9f 100644 --- a/packages/webapp/src/components/Charts/MeanCombinationsByLengthChart.tsx +++ b/packages/webapp/src/components/Charts/MeanCombinationsByLengthChart.tsx @@ -4,12 +4,12 @@ */ import { memo } from 'react' import { Bar } from 'react-chartjs-2' -import { IAggregatedCountByLen } from '~models' +import { IAggregateCountByLen } from 'sds-wasm' export interface MeanCombinationsByLengthChartProps { meanLabel: string - combinationsCountByLen: IAggregatedCountByLen - combinationsSumByLen: IAggregatedCountByLen + combinationsCountByLen: IAggregateCountByLen + combinationsSumByLen: IAggregateCountByLen height: number width: number } diff --git a/packages/webapp/src/components/Charts/PreservationByCountChart.tsx b/packages/webapp/src/components/Charts/PreservationByCountChart.tsx index 5a2afa0..49f4aed 100644 --- a/packages/webapp/src/components/Charts/PreservationByCountChart.tsx +++ b/packages/webapp/src/components/Charts/PreservationByCountChart.tsx @@ -5,7 +5,7 @@ import { memo } from 'react' import { Bar } from 'react-chartjs-2' -import { IPreservationByCountBuckets } from '~models' +import { IPreservationByCountBuckets } from 'sds-wasm' export interface PreservationByCountChartProps { meanLengthLabel: string diff --git a/packages/webapp/src/components/Charts/PreservationPercentageByLength.tsx b/packages/webapp/src/components/Charts/PreservationPercentageByLength.tsx index dad4f2b..64dccea 100644 --- a/packages/webapp/src/components/Charts/PreservationPercentageByLength.tsx +++ b/packages/webapp/src/components/Charts/PreservationPercentageByLength.tsx @@ -4,14 +4,14 @@ */ import { memo } from 'react' import { Bar } from 'react-chartjs-2' -import { 
IAggregatedCountByLen } from '~models' +import { IAggregateCountByLen } from 'sds-wasm' export interface PreservationPercentageByLengthProps { combinationsLabel: string preservationLabel: string - combinationsCountByLen: IAggregatedCountByLen - sensitiveCombinationsCountByLen: IAggregatedCountByLen - syntheticCombinationsCountByLen: IAggregatedCountByLen + combinationsCountByLen: IAggregateCountByLen + sensitiveCombinationsCountByLen: IAggregateCountByLen + syntheticCombinationsCountByLen: IAggregateCountByLen height: number width: number } diff --git a/packages/webapp/src/components/Charts/RareCombinationsByLengthChart.tsx b/packages/webapp/src/components/Charts/RareCombinationsByLengthChart.tsx index bfa0dfd..0e65dd3 100644 --- a/packages/webapp/src/components/Charts/RareCombinationsByLengthChart.tsx +++ b/packages/webapp/src/components/Charts/RareCombinationsByLengthChart.tsx @@ -4,13 +4,13 @@ */ import { memo } from 'react' import { Bar } from 'react-chartjs-2' -import { IAggregatedCountByLen } from '~models' +import { IAggregateCountByLen } from 'sds-wasm' export interface RareCombinationsByLengthChartProps { combinationsLabel: string rareCombinationsLabel: string - combinationsCountByLen: IAggregatedCountByLen - rareCombinationsCountByLen: IAggregatedCountByLen + combinationsCountByLen: IAggregateCountByLen + rareCombinationsCountByLen: IAggregateCountByLen height: number width: number } diff --git a/packages/webapp/src/components/Charts/hooks.ts b/packages/webapp/src/components/Charts/hooks.ts index 4254eb6..38c7f8f 100644 --- a/packages/webapp/src/components/Charts/hooks.ts +++ b/packages/webapp/src/components/Charts/hooks.ts @@ -30,12 +30,12 @@ export interface DataLabelsConfig { function useBarConfig( colors: BarColors, items: string[], - selectedValue?: string, + selectedAttributes: Set, ): ChartJsDatasetConfig { return useMemo(() => { const backgroundColor = items.map(i => { - if (selectedValue) { - return i === selectedValue ? colors.selected : colors.suppressed + if (selectedAttributes.size !== 0) { + return selectedAttributes.has(i) ? colors.selected : colors.suppressed } return colors.normal }) @@ -43,28 +43,28 @@ function useBarConfig( type: 'bar', backgroundColor: backgroundColor.length > 0 ? backgroundColor : undefined, } - }, [colors, items, selectedValue]) + }, [colors, items, selectedAttributes]) } export function useActualBarConfig( items: string[], - selectedValue?: string, + selectedAttributes: Set, ): ChartJsDatasetConfig { const colors = useActualBarChartColors() - return useBarConfig(colors, items, selectedValue) + return useBarConfig(colors, items, selectedAttributes) } export function useEstimatedBarConfig( items: string[], - selectedValue?: string, + selectedAttributes: Set, ): ChartJsDatasetConfig { const colors = useEstimatedBarChartColors() - return useBarConfig(colors, items, selectedValue) + return useBarConfig(colors, items, selectedAttributes) } export function useDataLabelsConfig( items: string[], - selectedValue?: string, + selectedAttributes: Set, ): DataLabelsConfig { const thematic = useThematic() return useMemo(() => { @@ -75,11 +75,11 @@ export function useDataLabelsConfig( align: 'end', offset: 5, color: items.map(item => - item === selectedValue ? greys[0] : greys[80], + selectedAttributes.has(item) ? 
greys[0] : greys[80], ), }, } - }, [thematic, items, selectedValue]) + }, [thematic, items, selectedAttributes]) } export function useHorizontalScrolling(): (e: WheelEvent) => void { diff --git a/packages/webapp/src/components/DataBinning/ColumnValueReplacer.tsx b/packages/webapp/src/components/DataBinning/ColumnValueReplacer.tsx index 689eba3..81aec18 100644 --- a/packages/webapp/src/components/DataBinning/ColumnValueReplacer.tsx +++ b/packages/webapp/src/components/DataBinning/ColumnValueReplacer.tsx @@ -15,17 +15,7 @@ import { } from '@fluentui/react' import _ from 'lodash' import { memo, useCallback, useEffect, useState } from 'react' -import { - defaultCsvContent, - defaultEvaluatedResult, - defaultNavigateResult, -} from '~models' -import { - useEvaluatedResultSetter, - useNavigateResultSetter, - useSensitiveContent, - useSyntheticContentSetter, -} from '~states' +import { useClearGenerate, useSensitiveContent } from '~states' import { BinOperationJoinCondition, BinOperationType, @@ -48,9 +38,7 @@ export const ColumnValueReplacer: React.FC = memo( const [currentValue, setCurrentValue] = useState('') const [valueToReplace, setValueToReplace] = useState('') const [csvContent, setCsvContent] = useSensitiveContent() - const setSyntheticContent = useSyntheticContentSetter() - const setEvaluatedResult = useEvaluatedResultSetter() - const setNavigateResult = useNavigateResultSetter() + const clearGenerate = useClearGenerate() const theme = useTheme() @@ -83,7 +71,7 @@ export const ColumnValueReplacer: React.FC = memo( [selectedValues, setSelectedValues], ) - const onRun = useCallback(() => { + const onRun = useCallback(async () => { if (selectedValues.length > 0 && valueToReplace.length > 0) { const newItems = [...csvContent.items.map(item => [...item])] const bins: ICustomBin[] = [ @@ -98,13 +86,11 @@ export const ColumnValueReplacer: React.FC = memo( ] new InplaceBinning().customBins(bins).run(newItems, headerIndex) + await clearGenerate() setCsvContent({ ...csvContent, items: newItems, }) - setSyntheticContent(defaultCsvContent) - setEvaluatedResult(defaultEvaluatedResult) - setNavigateResult(defaultNavigateResult) onUpdateCurrentValue('') onUpdateValueToReplace('') } @@ -113,10 +99,8 @@ export const ColumnValueReplacer: React.FC = memo( valueToReplace, csvContent, headerIndex, + clearGenerate, setCsvContent, - setSyntheticContent, - setEvaluatedResult, - setNavigateResult, onUpdateCurrentValue, onUpdateValueToReplace, ]) diff --git a/packages/webapp/src/components/DataBinning/CustomDataBinning.tsx b/packages/webapp/src/components/DataBinning/CustomDataBinning.tsx index ec166d7..3273801 100644 --- a/packages/webapp/src/components/DataBinning/CustomDataBinning.tsx +++ b/packages/webapp/src/components/DataBinning/CustomDataBinning.tsx @@ -16,17 +16,7 @@ import { useTheme, } from '@fluentui/react' import { memo, useCallback, useState } from 'react' -import { - defaultCsvContent, - defaultEvaluatedResult, - defaultNavigateResult, -} from '~models' -import { - useEvaluatedResultSetter, - useNavigateResultSetter, - useSensitiveContent, - useSyntheticContentSetter, -} from '~states' +import { useClearGenerate, useSensitiveContent } from '~states' import { BinOperationJoinCondition, BinOperationType, @@ -54,9 +44,7 @@ export const CustomDataBinning: React.FC = memo( function CustomDataBinning({ headerIndex }: CustomDataBinningProps) { const [bins, setBins] = useState([]) const [csvContent, setCsvContent] = useSensitiveContent() - const setSyntheticContent = useSyntheticContentSetter() - const 
setEvaluatedResult = useEvaluatedResultSetter() - const setNavigateResult = useNavigateResultSetter() + const clearGenerate = useClearGenerate() const theme = useTheme() @@ -140,29 +128,19 @@ export const CustomDataBinning: React.FC = memo( [bins, setBins], ) - const onRun = useCallback(() => { + const onRun = useCallback(async () => { if (bins.length > 0) { const newItems = [...csvContent.items.map(item => [...item])] new InplaceBinning().customBins(bins).run(newItems, headerIndex) + await clearGenerate() setCsvContent({ ...csvContent, items: newItems, }) - setSyntheticContent(defaultCsvContent) - setEvaluatedResult(defaultEvaluatedResult) - setNavigateResult(defaultNavigateResult) } - }, [ - bins, - csvContent, - headerIndex, - setCsvContent, - setSyntheticContent, - setEvaluatedResult, - setNavigateResult, - ]) + }, [bins, csvContent, headerIndex, clearGenerate, setCsvContent]) const stackTokens: IStackTokens = { childrenGap: theme.spacing.m, diff --git a/packages/webapp/src/components/DataBinning/FixedCountDataBinning.tsx b/packages/webapp/src/components/DataBinning/FixedCountDataBinning.tsx index 3e8440b..3723a5d 100644 --- a/packages/webapp/src/components/DataBinning/FixedCountDataBinning.tsx +++ b/packages/webapp/src/components/DataBinning/FixedCountDataBinning.tsx @@ -12,17 +12,7 @@ import { import { Form, useFormik, FormikProvider } from 'formik' import { memo, useCallback, useEffect } from 'react' import * as yup from 'yup' -import { - defaultCsvContent, - defaultEvaluatedResult, - defaultNavigateResult, -} from '~models' -import { - useEvaluatedResultSetter, - useNavigateResultSetter, - useSensitiveContent, - useSyntheticContentSetter, -} from '~states' +import { useClearGenerate, useSensitiveContent } from '~states' import { findMinMax, InplaceBinning, stringToNumber } from '~utils' export interface FixedCountDataBinningProps { @@ -38,33 +28,22 @@ const validationSchema = yup.object().shape({ export const FixedCountDataBinning: React.FC = memo( function FixedCountDataBinning({ headerIndex }: FixedCountDataBinningProps) { const [csvContent, setCsvContent] = useSensitiveContent() - const setSyntheticContent = useSyntheticContentSetter() - const setEvaluatedResult = useEvaluatedResultSetter() - const setNavigateResult = useNavigateResultSetter() + const clearGenerate = useClearGenerate() const onRun = useCallback( - values => { + async values => { const newItems = [...csvContent.items.map(item => [...item])] new InplaceBinning() .fixedBinCount(values.binCount, values.minValue, values.maxValue) .run(newItems, headerIndex) + await clearGenerate() setCsvContent({ ...csvContent, items: newItems, }) - setSyntheticContent(defaultCsvContent) - setEvaluatedResult(defaultEvaluatedResult) - setNavigateResult(defaultNavigateResult) }, - [ - csvContent, - headerIndex, - setCsvContent, - setSyntheticContent, - setEvaluatedResult, - setNavigateResult, - ], + [csvContent, headerIndex, clearGenerate, setCsvContent], ) const formik = useFormik({ validationSchema, diff --git a/packages/webapp/src/components/DataBinning/FixedWidthDataBinning.tsx b/packages/webapp/src/components/DataBinning/FixedWidthDataBinning.tsx index d8a6492..cd02db3 100644 --- a/packages/webapp/src/components/DataBinning/FixedWidthDataBinning.tsx +++ b/packages/webapp/src/components/DataBinning/FixedWidthDataBinning.tsx @@ -12,17 +12,7 @@ import { import { Form, useFormik, FormikProvider } from 'formik' import { memo, useCallback, useEffect } from 'react' import * as yup from 'yup' -import { - defaultCsvContent, - 
defaultEvaluatedResult, - defaultNavigateResult, -} from '~models' -import { - useEvaluatedResultSetter, - useNavigateResultSetter, - useSensitiveContent, - useSyntheticContentSetter, -} from '~states' +import { useClearGenerate, useSensitiveContent } from '~states' import { findMinMax, InplaceBinning, stringToNumber } from '~utils' export interface FixedWidthDataBinningProps { @@ -38,33 +28,22 @@ const validationSchema = yup.object().shape({ export const FixedWidthDataBinning: React.FC = memo( function FixedWidthDataBinning({ headerIndex }: FixedWidthDataBinningProps) { const [csvContent, setCsvContent] = useSensitiveContent() - const setSyntheticContent = useSyntheticContentSetter() - const setEvaluatedResult = useEvaluatedResultSetter() - const setNavigateResult = useNavigateResultSetter() + const clearGenerate = useClearGenerate() const onRun = useCallback( - values => { + async values => { const newItems = [...csvContent.items.map(item => [...item])] new InplaceBinning() .fixedBinWidth(values.binWidth, values.minValue, values.maxValue) .run(newItems, headerIndex) + await clearGenerate() setCsvContent({ ...csvContent, items: newItems, }) - setSyntheticContent(defaultCsvContent) - setEvaluatedResult(defaultEvaluatedResult) - setNavigateResult(defaultNavigateResult) }, - [ - csvContent, - headerIndex, - setCsvContent, - setSyntheticContent, - setEvaluatedResult, - setNavigateResult, - ], + [csvContent, headerIndex, clearGenerate, setCsvContent], ) const formik = useFormik({ validationSchema, diff --git a/packages/webapp/src/components/EvaluationSummary/EvaluationSummary.tsx b/packages/webapp/src/components/EvaluationSummary/EvaluationSummary.tsx index df7e0a2..4fa6138 100644 --- a/packages/webapp/src/components/EvaluationSummary/EvaluationSummary.tsx +++ b/packages/webapp/src/components/EvaluationSummary/EvaluationSummary.tsx @@ -10,7 +10,7 @@ import { SelectionMode, } from '@fluentui/react' import { memo } from 'react' -import { IPrivacyRiskSummary } from '~models' +import { IPrivacyRiskSummary } from 'sds-wasm' export interface EvaluationSummaryProps { privacyRiskLabel: string diff --git a/packages/webapp/src/components/Pages/DataShowcasePage/DataEvaluation.tsx b/packages/webapp/src/components/Pages/DataShowcasePage/DataEvaluation.tsx index 9a97fd8..29eb743 100644 --- a/packages/webapp/src/components/Pages/DataShowcasePage/DataEvaluation.tsx +++ b/packages/webapp/src/components/Pages/DataShowcasePage/DataEvaluation.tsx @@ -11,7 +11,6 @@ import { TextField, } from '@fluentui/react' import { memo, useCallback } from 'react' -import { CsvRecord } from 'src/models/csv' import { FabricatedCountChart, LeakageCountChart, @@ -21,33 +20,22 @@ import { RareCombinationsByLengthChart, } from '~components/Charts' import { EvaluationSummary } from '~components/EvaluationSummary' -import { defaultEvaluatedResult, defaultNavigateResult } from '~models' import { + useClearEvaluate, + useEvaluateResult, useIsProcessing, - useRecordLimitValue, - useReportingLength, - useResolutionValue, - useSensitiveContentValue, -} from '~states' -import { - useEvaluatedResult, - useNavigateResultSetter, useProcessingProgressSetter, - useSyntheticContentValue, + useReportingLength, useWasmWorkerValue, -} from '~states/dataShowcaseContext' +} from '~states' export const DataEvaluation: React.FC = memo(function DataEvaluation() { - const worker = useWasmWorkerValue() - const recordLimit = useRecordLimitValue() const [reportingLength, setReportingLength] = useReportingLength() const [isProcessing, setIsProcessing] = 
useIsProcessing() - const sensitiveContent = useSensitiveContentValue() - const syntheticContent = useSyntheticContentValue() - const [evaluatedResult, setEvaluatedResult] = useEvaluatedResult() - const setNavigateResult = useNavigateResultSetter() - const resolution = useResolutionValue() + const [evaluateResult, setEvaluateResult] = useEvaluateResult() + const worker = useWasmWorkerValue() const setProcessingProgress = useProcessingProgressSetter() + const clearEvaluate = useClearEvaluate() const theme = getTheme() @@ -92,26 +80,14 @@ export const DataEvaluation: React.FC = memo(function DataEvaluation() { const onRunEvaluate = useCallback(async () => { setIsProcessing(true) - setEvaluatedResult(defaultEvaluatedResult) - setNavigateResult(defaultNavigateResult) + await clearEvaluate() setProcessingProgress(0.0) const response = await worker?.evaluate( - [ - sensitiveContent.headers.map(h => h.name), - ...(sensitiveContent.items as CsvRecord[]), - ], - [ - syntheticContent.headers.map(h => h.name), - ...(syntheticContent.items as CsvRecord[]), - ], - sensitiveContent.headers.filter(h => h.use).map(h => h.name), - sensitiveContent.headers - .filter(h => h.hasSensitiveZeros) - .map(h => h.name), - recordLimit, reportingLength, - resolution, + 0, + ';', + false, p => { setProcessingProgress(p) }, @@ -119,18 +95,14 @@ export const DataEvaluation: React.FC = memo(function DataEvaluation() { setIsProcessing(false) if (response) { - setEvaluatedResult(response) + setEvaluateResult(response) } }, [ worker, setIsProcessing, - sensitiveContent, - syntheticContent, - recordLimit, reportingLength, - resolution, - setEvaluatedResult, - setNavigateResult, + clearEvaluate, + setEvaluateResult, setProcessingProgress, ]) @@ -163,26 +135,22 @@ export const DataEvaluation: React.FC = memo(function DataEvaluation() { - {evaluatedResult.sensitiveAggregatedResult && ( + {evaluateResult && ( <>
[JSX markup stripped during extraction. The surviving structure shows this hunk removing the old per-field guards — `evaluatedResult.sensitiveAggregatedResult.privacyRisk`, `evaluatedResult.recordExpansion !== undefined`, `evaluatedResult.combinationLoss !== undefined`, and the `evaluatedResult.syntheticAggregatedResult` wrapper — so that the "Summary", "Sensitive data charts", and "Synthetic data charts" sections now render unconditionally inside the single `{evaluateResult && (...)}` check.]
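For orientation (not part of the patch): the slimmed-down `onRunEvaluate` above relies on the new stateful worker API defined later in SdsWasmWorker.ts, where the sensitive and synthetic data already live worker-side after generation. A minimal sketch of driving it outside React — the import path, the reporting length of 3, and the logging are illustrative assumptions only:

import { SdsWasmWorker } from './SdsWasmWorker' // path assumed from this patch's layout

async function runEvaluate(worker: SdsWasmWorker): Promise<void> {
	// Mirrors DataEvaluation's onRunEvaluate: inputs now live in worker-side
	// state, so only evaluation parameters are passed.
	const result = await worker.evaluate(
		3,     // reportingLength (illustrative value)
		0,     // sensitivityThreshold
		';',   // combinationDelimiter
		false, // includeAggregatesCount
		progress => console.log(`evaluate progress: ${progress}`),
	)
	if (result) {
		// result is an sds-wasm IEvaluateResult
		console.log('evaluate finished', result)
	}
}

Note the cascading reset this patch introduces: per the hooks added below, `useClearEvaluate` nulls the Recoil result and also calls `clearNavigate`, so stale navigation state cannot outlive an evaluation rerun.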
- - - - - - - - - - - - - - - - - - - - - - )} )} diff --git a/packages/webapp/src/components/Pages/DataShowcasePage/DataInput.tsx b/packages/webapp/src/components/Pages/DataShowcasePage/DataInput.tsx index 99bf14f..3df4f77 100644 --- a/packages/webapp/src/components/Pages/DataShowcasePage/DataInput.tsx +++ b/packages/webapp/src/components/Pages/DataShowcasePage/DataInput.tsx @@ -15,17 +15,14 @@ import { } from '@fluentui/react' import { parse } from 'papaparse' import { memo, useCallback, useRef } from 'react' -import { defaultCsvContent, ICsvTableHeader } from 'src/models/csv' import { CsvTable } from './CsvTable' import { DataBinning } from '~components/DataBinning' -import { defaultEvaluatedResult, defaultNavigateResult } from '~models' +import { ICsvTableHeader } from '~models' import { - useEvaluatedResultSetter, + useClearSensitiveData, useIsProcessing, - useNavigateResultSetter, useRecordLimit, useSensitiveContent, - useSyntheticContentSetter, useWasmWorkerValue, } from '~states' @@ -35,10 +32,8 @@ export const DataInput: React.FC = memo(function DataInput() { const [recordLimit, setRecordLimit] = useRecordLimit() const [isProcessing, setIsProcessing] = useIsProcessing() const [sensitiveContent, setSensitiveContent] = useSensitiveContent() - const setSyntheticContent = useSyntheticContentSetter() - const setEvaluatedResult = useEvaluatedResultSetter() - const setNavigateResult = useNavigateResultSetter() const worker = useWasmWorkerValue() + const clearSensitiveData = useClearSensitiveData() const inputFile = useRef(null) @@ -62,14 +57,13 @@ export const DataInput: React.FC = memo(function DataInput() { } const onFileChange = useCallback( - (e: React.ChangeEvent) => { + async (e: React.ChangeEvent) => { const f = e.target.files?.[0] if (f) { setIsProcessing(true) - setSyntheticContent(defaultCsvContent) - setEvaluatedResult(defaultEvaluatedResult) - setNavigateResult(defaultNavigateResult) + await clearSensitiveData() + parse>(f, { complete: async results => { const headers = @@ -91,18 +85,15 @@ export const DataInput: React.FC = memo(function DataInput() { columnsWithZeros: await worker?.findColumnsWithZeros(items), delimiter: results.meta.delimiter, }) + // allow the same file to be loaded again + if (inputFile.current) { + inputFile.current.value = '' + } }, }) } }, - [ - worker, - setIsProcessing, - setSyntheticContent, - setEvaluatedResult, - setNavigateResult, - setSensitiveContent, - ], + [worker, setIsProcessing, setSensitiveContent, clearSensitiveData], ) const sensitiveColumnsWithZeros = sensitiveContent.columnsWithZeros?.filter( diff --git a/packages/webapp/src/components/Pages/DataShowcasePage/DataNavigation.tsx b/packages/webapp/src/components/Pages/DataShowcasePage/DataNavigation.tsx index 370f132..f5a92d2 100644 --- a/packages/webapp/src/components/Pages/DataShowcasePage/DataNavigation.tsx +++ b/packages/webapp/src/components/Pages/DataShowcasePage/DataNavigation.tsx @@ -13,43 +13,45 @@ import { Separator, } from '@fluentui/react' import { memo, useCallback, useEffect, useRef, useState } from 'react' -import { AttributeIntersectionValueChartLegend } from '~components/AttributeIntersectionValueChartLegend' +import { IAttributesIntersection, ISelectedAttributesByColumn } from 'sds-wasm' import { - ColumnAttributeSelector, + ColumnAttributeSelectorGrid, HeaderSelector, SelectedAttributes, } from '~components/AttributeSelector' -import { useHorizontalScrolling } from '~components/Charts/hooks' -import { defaultNavigateResult, PipelineStep } from '~models' +import { 
PipelineStep } from '~models' import { - useNavigateResultSetter, - useSyntheticContentValue, useWasmWorkerValue, -} from '~states' -import { - useSelectedAttributeRowsSetter, - useSelectedAttributesSetter, useSelectedPipelineStepSetter, -} from '~states/dataShowcaseContext' + useSyntheticHeaders, +} from '~states' const backIcon: IIconProps = { iconName: 'Back' } const initiallySelectedHeaders = 6 +const viewHeight = 'calc(100vh - 225px)' + +const chartHeight = `calc((${viewHeight} / 2) - 20px)` + +export type SetSelectedAttributesCallback = ( + headerIndex: number, + item: IAttributesIntersection | undefined, +) => Promise + +export type ClearSelectedAttributesCallback = () => Promise + export const DataNavigation: React.FC = memo(function DataNavigation() { const [isLoading, setIsLoading] = useState(true) - const setNavigateResult = useNavigateResultSetter() - const setSelectedAttributes = useSelectedAttributesSetter() - const syntheticContent = useSyntheticContentValue() + const [selectedAttributesByColumn, setSelectedAttributesByColumn] = + useState({}) const worker = useWasmWorkerValue() - const setSelectedAttributeRows = useSelectedAttributeRowsSetter() const setSelectedPipelineStep = useSelectedPipelineStepSetter() const isMounted = useRef(true) - const headers = syntheticContent.headers.map(h => h.name) + const headers = useSyntheticHeaders() const [selectedHeaders, setSelectedHeaders] = useState( headers.map((_, i) => i < initiallySelectedHeaders), ) - const theme = useTheme() const mainStackStyles: IStackStyles = { @@ -69,44 +71,66 @@ export const DataNavigation: React.FC = memo(function DataNavigation() { childrenGap: theme.spacing.s1, } - const viewHeigh = 'calc(100vh - 225px)' + const setNewSelectedAttributesByColumn = useCallback( + async (newSelectedAttributesByColumn: ISelectedAttributesByColumn) => { + if (worker) { + setIsLoading(true) + const result = await worker.selectAttributes( + newSelectedAttributesByColumn, + ) - const chartHeight = `calc((${viewHeigh} / 2) - 20px)` + if (isMounted.current && result) { + setSelectedAttributesByColumn(newSelectedAttributesByColumn) + setIsLoading(false) + } + } + }, + [worker, setIsLoading, isMounted, setSelectedAttributesByColumn], + ) + + const onSetSelectedAttributes = useCallback( + async (headerIndex: number, item: IAttributesIntersection | undefined) => { + setNewSelectedAttributesByColumn({ + ...selectedAttributesByColumn, + [headerIndex]: + item !== undefined + ? 
new Set([item.value]) + : new Set(), + }) + }, + [setNewSelectedAttributesByColumn, selectedAttributesByColumn], + ) + + const onClearSelectedAttributes = useCallback(async () => { + setNewSelectedAttributesByColumn({}) + }, [setNewSelectedAttributesByColumn]) const onGoBack = useCallback(() => { setSelectedPipelineStep(PipelineStep.Evaluate) }, [setSelectedPipelineStep]) const onToggleSelectedHeader = useCallback( - index => { + async index => { const newSelectedHeaders = [...selectedHeaders] newSelectedHeaders[index] = !newSelectedHeaders[index] - setSelectedHeaders(newSelectedHeaders) + await setSelectedHeaders(newSelectedHeaders) }, [setSelectedHeaders, selectedHeaders], ) - const doHorizontalScroll = useHorizontalScrolling() - useEffect(() => { if (worker) { setIsLoading(true) - setSelectedAttributeRows([]) - setSelectedAttributes({}) - setNavigateResult(defaultNavigateResult) - setSelectedAttributeRows([]) - - worker.navigate(syntheticContent.items).then(result => { - if (isMounted.current) { - if (result) { - setNavigateResult(result) - setSelectedAttributeRows(result.allRows) - } + worker.navigate().then(result => { + if (isMounted.current && result) { + setSelectedHeaders( + headers.map((_, i) => i < initiallySelectedHeaders), + ) setIsLoading(false) } }) } - }, [syntheticContent.items, setIsLoading, setSelectedAttributeRows, setSelectedAttributes, worker, setNavigateResult]) + }, [setIsLoading, worker, isMounted, setSelectedHeaders, headers]) useEffect(() => { return () => { @@ -121,70 +145,55 @@ export const DataNavigation: React.FC = memo(function DataNavigation() {
[JSX heading markup stripped during extraction; the recoverable text is the panel title "Compare sensitive and synthetic results".]
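For orientation (not part of the patch): the selection flow this component now drives entirely through the worker. A sketch under the new API, assuming an initialized SdsWasmWorker whose generate/evaluate steps have completed; the import path, attribute value, and column names are placeholders:

import { ISelectedAttributesByColumn } from 'sds-wasm'
import { SdsWasmWorker } from '~workers/sds-wasm/SdsWasmWorker' // import path assumed

async function pinAttributeAndQuery(worker: SdsWasmWorker): Promise<void> {
	// Build worker-side navigation state from the last synthesis.
	await worker.navigate()

	// Pin a single attribute in column 0 (mirrors onSetSelectedAttributes:
	// keys are column indexes, values are sets of attribute values).
	const selected: ISelectedAttributesByColumn = { 0: new Set(['attr-value']) }
	await worker.selectAttributes(selected)

	// Fetch the estimated/actual counts for the columns being charted;
	// 'col_a'/'col_b' are hypothetical header names.
	const byColumn = await worker.attributesIntersectionsByColumn(['col_a', 'col_b'])
	console.log(byColumn)
}

Keeping the row-intersection bookkeeping inside the WASM context is what lets the patch delete the hand-rolled `intersectAttributeRows` logic from worker.js further down.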
- + - - - + {isLoading ? ( + + ) : ( + <> + + + - + - - {isLoading ? ( - - ) : ( - - - - {headers.map((h, i) => { - return ( - selectedHeaders[i] && ( - - - - ) - ) - })} - - - )} - + + + + + )} ) diff --git a/packages/webapp/src/components/Pages/DataShowcasePage/DataSynthesis.tsx b/packages/webapp/src/components/Pages/DataShowcasePage/DataSynthesis.tsx index 3e874c4..757193d 100644 --- a/packages/webapp/src/components/Pages/DataShowcasePage/DataSynthesis.tsx +++ b/packages/webapp/src/components/Pages/DataShowcasePage/DataSynthesis.tsx @@ -11,35 +11,30 @@ import { TextField, } from '@fluentui/react' import { memo, useCallback } from 'react' -import { defaultCsvContent, ICsvTableHeader } from 'src/models/csv' +import { ICsvTableHeader } from 'src/models/csv' import { CsvTable } from './CsvTable' -import { defaultEvaluatedResult, defaultNavigateResult } from '~models' import { useCacheSize, + useClearGenerate, useIsProcessing, + useProcessingProgressSetter, useRecordLimitValue, useResolution, useSensitiveContentValue, useSyntheticContent, -} from '~states' -import { - useEvaluatedResultSetter, - useNavigateResultSetter, - useProcessingProgressSetter, useWasmWorkerValue, -} from '~states/dataShowcaseContext' +} from '~states' export const DataSynthesis: React.FC = memo(function DataSynthesis() { - const worker = useWasmWorkerValue() - const recordLimit = useRecordLimitValue() const [resolution, setResolution] = useResolution() const [cacheSize, setCacheSize] = useCacheSize() const [isProcessing, setIsProcessing] = useIsProcessing() - const sensitiveContent = useSensitiveContentValue() const [syntheticContent, setSyntheticContent] = useSyntheticContent() - const setEvaluatedResult = useEvaluatedResultSetter() - const setNavigateResult = useNavigateResultSetter() + const worker = useWasmWorkerValue() + const recordLimit = useRecordLimitValue() + const sensitiveContent = useSensitiveContentValue() const setProcessingProgress = useProcessingProgressSetter() + const clearGenerate = useClearGenerate() const theme = getTheme() @@ -62,9 +57,7 @@ export const DataSynthesis: React.FC = memo(function DataSynthesis() { const onRunGenerate = useCallback(async () => { setIsProcessing(true) - setSyntheticContent(defaultCsvContent) - setEvaluatedResult(defaultEvaluatedResult) - setNavigateResult(defaultNavigateResult) + await clearGenerate() setProcessingProgress(0.0) const response = await worker?.generate( @@ -100,8 +93,7 @@ export const DataSynthesis: React.FC = memo(function DataSynthesis() { worker, setIsProcessing, setSyntheticContent, - setEvaluatedResult, - setNavigateResult, + clearGenerate, sensitiveContent, recordLimit, resolution, diff --git a/packages/webapp/src/models/aggregate/index.ts b/packages/webapp/src/models/aggregate/index.ts deleted file mode 100644 index 88af253..0000000 --- a/packages/webapp/src/models/aggregate/index.ts +++ /dev/null @@ -1,46 +0,0 @@ -/*! - * Copyright (c) Microsoft. All rights reserved. - * Licensed under the MIT license. See LICENSE file in the project. 
- */ -export interface IAggregatedCombination { - combination_key: number - count: number - length: number -} - -export interface IAggregatedCombinations { - [name: string]: IAggregatedCombination -} - -export interface IAggregatedCountByLen { - [length: number]: number -} - -export interface IPrivacyRiskSummary { - totalNumberOfRecords: number - totalNumberOfCombinations: number - recordsWithUniqueCombinationsCount: number - recordsWithRareCombinationsCount: number - uniqueCombinationsCount: number - rareCombinationsCount: number - recordsWithUniqueCombinationsProportion: number - recordsWithRareCombinationsProportion: number - uniqueCombinationsProportion: number - rareCombinationsProportion: number -} - -export interface IAggregatedResult { - aggregatedCombinations?: IAggregatedCombinations - rareCombinationsCountByLen?: IAggregatedCountByLen - combinationsCountByLen?: IAggregatedCountByLen - combinationsSumByLen?: IAggregatedCountByLen - privacyRisk?: IPrivacyRiskSummary -} - -export const defaultAggregatedResult: IAggregatedResult = { - aggregatedCombinations: undefined, - rareCombinationsCountByLen: undefined, - combinationsCountByLen: undefined, - combinationsSumByLen: undefined, - privacyRisk: undefined, -} diff --git a/packages/webapp/src/models/csv/index.ts b/packages/webapp/src/models/csv/index.ts index 876b80a..78c88d4 100644 --- a/packages/webapp/src/models/csv/index.ts +++ b/packages/webapp/src/models/csv/index.ts @@ -2,7 +2,7 @@ * Copyright (c) Microsoft. All rights reserved. * Licensed under the MIT license. See LICENSE file in the project. */ -export type CsvRecord = string[] +import { CsvRecord } from 'sds-wasm' export interface ICsvTableHeader { name: string diff --git a/packages/webapp/src/models/evaluate/index.ts b/packages/webapp/src/models/evaluate/index.ts deleted file mode 100644 index 8ff4f11..0000000 --- a/packages/webapp/src/models/evaluate/index.ts +++ /dev/null @@ -1,33 +0,0 @@ -/*! - * Copyright (c) Microsoft. All rights reserved. - * Licensed under the MIT license. See LICENSE file in the project. - */ -import { IAggregatedCountByLen, IAggregatedResult } from '~models/aggregate' - -export interface IPreservationByCountBucket { - size: number - preservationSum: number - lengthSum: number -} - -export interface IPreservationByCountBuckets { - [bucket_index: number]: IPreservationByCountBucket -} - -export interface IEvaluatedResult { - sensitiveAggregatedResult?: IAggregatedResult - syntheticAggregatedResult?: IAggregatedResult - leakageCountByLen?: IAggregatedCountByLen - fabricatedCountByLen?: IAggregatedCountByLen - preservationByCountBuckets?: IPreservationByCountBuckets - combinationLoss?: number - recordExpansion?: number -} - -export const defaultEvaluatedResult: IEvaluatedResult = { - sensitiveAggregatedResult: undefined, - syntheticAggregatedResult: undefined, - leakageCountByLen: undefined, - fabricatedCountByLen: undefined, - combinationLoss: undefined, -} diff --git a/packages/webapp/src/models/index.ts b/packages/webapp/src/models/index.ts index 01e42a1..f601ba5 100644 --- a/packages/webapp/src/models/index.ts +++ b/packages/webapp/src/models/index.ts @@ -2,8 +2,5 @@ * Copyright (c) Microsoft. All rights reserved. * Licensed under the MIT license. See LICENSE file in the project. 
*/ -export * from './aggregate' export * from './csv' -export * from './evaluate' -export * from './navigate' export * from './pipeline' diff --git a/packages/webapp/src/models/navigate/index.ts b/packages/webapp/src/models/navigate/index.ts deleted file mode 100644 index 53b09db..0000000 --- a/packages/webapp/src/models/navigate/index.ts +++ /dev/null @@ -1,39 +0,0 @@ -/*! - * Copyright (c) Microsoft. All rights reserved. - * Licensed under the MIT license. See LICENSE file in the project. - */ -export type AttributeRows = number[] - -export interface IAttributeRowsMap { - [columnIndex: number]: { - [attr: string]: AttributeRows - } -} - -export type AttributesInColumn = Set - -export interface IAttributesInColumnsMap { - [columnIndex: number]: AttributesInColumn -} - -export interface INavigateResult { - attrRowsMap: IAttributeRowsMap - attrsInColumnsMap: IAttributesInColumnsMap - allRows: AttributeRows -} - -export interface ISelectedAttributes { - [columnIndex: number]: string | undefined -} - -export interface IAttributesIntersectionValue { - value: string - estimatedCount: number - actualCount?: number -} - -export const defaultNavigateResult = { - attrRowsMap: {}, - attrsInColumnsMap: {}, - allRows: [], -} diff --git a/packages/webapp/src/states/dataShowcaseContext/evaluateResult.ts b/packages/webapp/src/states/dataShowcaseContext/evaluateResult.ts new file mode 100644 index 0000000..c0ecc95 --- /dev/null +++ b/packages/webapp/src/states/dataShowcaseContext/evaluateResult.ts @@ -0,0 +1,32 @@ +/*! + * Copyright (c) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE file in the project. + */ +import { + atom, + SetterOrUpdater, + useRecoilState, + useRecoilValue, + useSetRecoilState, +} from 'recoil' +import { IEvaluateResult } from 'sds-wasm' + +const state = atom({ + key: 'evaluate-result', + default: null, +}) + +export function useEvaluateResult(): [ + IEvaluateResult | null, + SetterOrUpdater, +] { + return useRecoilState(state) +} + +export function useEvaluateResultValue(): IEvaluateResult | null { + return useRecoilValue(state) +} + +export function useEvaluateResultSetter(): SetterOrUpdater { + return useSetRecoilState(state) +} diff --git a/packages/webapp/src/states/dataShowcaseContext/evaluatedResult.ts b/packages/webapp/src/states/dataShowcaseContext/evaluatedResult.ts deleted file mode 100644 index d235cbf..0000000 --- a/packages/webapp/src/states/dataShowcaseContext/evaluatedResult.ts +++ /dev/null @@ -1,32 +0,0 @@ -/*! - * Copyright (c) Microsoft. All rights reserved. - * Licensed under the MIT license. See LICENSE file in the project. - */ -import { - atom, - SetterOrUpdater, - useRecoilState, - useRecoilValue, - useSetRecoilState, -} from 'recoil' -import { defaultEvaluatedResult, IEvaluatedResult } from '~models' - -const state = atom({ - key: 'evaluted-result', - default: defaultEvaluatedResult, -}) - -export function useEvaluatedResult(): [ - IEvaluatedResult, - SetterOrUpdater, -] { - return useRecoilState(state) -} - -export function useEvaluatedResultValue(): IEvaluatedResult { - return useRecoilValue(state) -} - -export function useEvaluatedResultSetter(): SetterOrUpdater { - return useSetRecoilState(state) -} diff --git a/packages/webapp/src/states/dataShowcaseContext/hooks.ts b/packages/webapp/src/states/dataShowcaseContext/hooks.ts new file mode 100644 index 0000000..b486c75 --- /dev/null +++ b/packages/webapp/src/states/dataShowcaseContext/hooks.ts @@ -0,0 +1,67 @@ +/*! + * Copyright (c) Microsoft. All rights reserved. 
+ * Licensed under the MIT license. See LICENSE file in the project. + */ +import { useCallback, useMemo } from 'react' +import { HeaderNames } from 'sds-wasm' +import { defaultCsvContent } from '~models' +import { + useEvaluateResultSetter, + useSensitiveContentSetter, + useSyntheticContentSetter, + useSyntheticContentValue, + useWasmWorkerValue, +} from '~states' + +export function useClearSensitiveData(): () => Promise { + const worker = useWasmWorkerValue() + const setSensitiveContent = useSensitiveContentSetter() + const clearGenerate = useClearGenerate() + + return useCallback(async () => { + await worker?.clearSensitiveData() + setSensitiveContent(defaultCsvContent) + await clearGenerate() + }, [worker, setSensitiveContent, clearGenerate]) +} + +export function useClearGenerate(): () => Promise { + const worker = useWasmWorkerValue() + const setSyntheticContent = useSyntheticContentSetter() + const clearEvaluate = useClearEvaluate() + + return useCallback(async () => { + await worker?.clearGenerate() + setSyntheticContent(defaultCsvContent) + await clearEvaluate() + }, [worker, setSyntheticContent, clearEvaluate]) +} + +export function useClearEvaluate(): () => Promise { + const worker = useWasmWorkerValue() + const setEvaluateResult = useEvaluateResultSetter() + const clearNavigate = useClearNavigate() + + return useCallback(async () => { + await worker?.clearEvaluate() + setEvaluateResult(null) + await clearNavigate() + }, [worker, setEvaluateResult, clearNavigate]) +} + +export function useClearNavigate(): () => Promise { + const worker = useWasmWorkerValue() + + return useCallback(async () => { + await worker?.clearNavigate() + }, [worker]) +} + +export function useSyntheticHeaders(): HeaderNames { + const syntheticContent = useSyntheticContentValue() + + return useMemo( + () => syntheticContent.headers.map(h => h.name), + [syntheticContent.headers], + ) +} diff --git a/packages/webapp/src/states/dataShowcaseContext/index.ts b/packages/webapp/src/states/dataShowcaseContext/index.ts index 642344e..2358198 100644 --- a/packages/webapp/src/states/dataShowcaseContext/index.ts +++ b/packages/webapp/src/states/dataShowcaseContext/index.ts @@ -3,16 +3,14 @@ * Licensed under the MIT license. See LICENSE file in the project. */ export * from './cacheSize' -export * from './evaluatedResult' +export * from './evaluateResult' export * from './isProcessing' -export * from './navigateResult' export * from './processingProgress' export * from './recordLimit' export * from './reportingLength' export * from './resolution' -export * from './selectedAttributeRows' -export * from './selectedAttributes' export * from './selectedPipelineStep' export * from './sensitiveContent' export * from './syntheticContent' export * from './wasmWorker' +export * from './hooks' diff --git a/packages/webapp/src/states/dataShowcaseContext/navigateResult.ts b/packages/webapp/src/states/dataShowcaseContext/navigateResult.ts deleted file mode 100644 index 45af1e8..0000000 --- a/packages/webapp/src/states/dataShowcaseContext/navigateResult.ts +++ /dev/null @@ -1,32 +0,0 @@ -/*! - * Copyright (c) Microsoft. All rights reserved. - * Licensed under the MIT license. See LICENSE file in the project. 
- */ -import { - atom, - SetterOrUpdater, - useRecoilState, - useRecoilValue, - useSetRecoilState, -} from 'recoil' -import { defaultNavigateResult, INavigateResult } from '~models' - -const state = atom({ - key: 'navigate-result', - default: defaultNavigateResult, -}) - -export function useNavigateResult(): [ - INavigateResult, - SetterOrUpdater, -] { - return useRecoilState(state) -} - -export function useNavigateResultValue(): INavigateResult { - return useRecoilValue(state) -} - -export function useNavigateResultSetter(): SetterOrUpdater { - return useSetRecoilState(state) -} diff --git a/packages/webapp/src/states/dataShowcaseContext/selectedAttributeRows.ts b/packages/webapp/src/states/dataShowcaseContext/selectedAttributeRows.ts deleted file mode 100644 index 579a981..0000000 --- a/packages/webapp/src/states/dataShowcaseContext/selectedAttributeRows.ts +++ /dev/null @@ -1,32 +0,0 @@ -/*! - * Copyright (c) Microsoft. All rights reserved. - * Licensed under the MIT license. See LICENSE file in the project. - */ -import { - atom, - SetterOrUpdater, - useRecoilState, - useRecoilValue, - useSetRecoilState, -} from 'recoil' -import { AttributeRows } from '~models' - -const state = atom({ - key: 'selected-attribute-rows', - default: [], -}) - -export function useSelectedAttributeRows(): [ - AttributeRows, - SetterOrUpdater, -] { - return useRecoilState(state) -} - -export function useSelectedAttributeRowsValue(): AttributeRows { - return useRecoilValue(state) -} - -export function useSelectedAttributeRowsSetter(): SetterOrUpdater { - return useSetRecoilState(state) -} diff --git a/packages/webapp/src/states/dataShowcaseContext/selectedAttributes.ts b/packages/webapp/src/states/dataShowcaseContext/selectedAttributes.ts deleted file mode 100644 index 2bc995c..0000000 --- a/packages/webapp/src/states/dataShowcaseContext/selectedAttributes.ts +++ /dev/null @@ -1,32 +0,0 @@ -/*! - * Copyright (c) Microsoft. All rights reserved. - * Licensed under the MIT license. See LICENSE file in the project. - */ -import { - atom, - SetterOrUpdater, - useRecoilState, - useRecoilValue, - useSetRecoilState, -} from 'recoil' -import { ISelectedAttributes } from '~models' - -const state = atom({ - key: 'selected-attributes', - default: {}, -}) - -export function useSelectedAttributes(): [ - ISelectedAttributes, - SetterOrUpdater, -] { - return useRecoilState(state) -} - -export function useSelectedAttributesValue(): ISelectedAttributes { - return useRecoilValue(state) -} - -export function useSelectedAttributesSetter(): SetterOrUpdater { - return useSetRecoilState(state) -} diff --git a/packages/webapp/src/storage/index.ts b/packages/webapp/src/storage/index.ts deleted file mode 100644 index 7fd6d35..0000000 --- a/packages/webapp/src/storage/index.ts +++ /dev/null @@ -1,4 +0,0 @@ -/*! - * Copyright (c) Microsoft. All rights reserved. - * Licensed under the MIT license. See LICENSE file in the project. - */ diff --git a/packages/webapp/src/storage/localStorageEffect.ts b/packages/webapp/src/storage/localStorageEffect.ts deleted file mode 100644 index 7b2558e..0000000 --- a/packages/webapp/src/storage/localStorageEffect.ts +++ /dev/null @@ -1,23 +0,0 @@ -/*! - * Copyright (c) Microsoft. All rights reserved. - * Licensed under the MIT license. See LICENSE file in the project. 
- */ -import { AtomEffect, DefaultValue } from 'recoil' - -export function localStorageEffect(key: string): AtomEffect { - return ({ setSelf, onSet }) => { - const savedValue = localStorage.getItem(key) - - if (savedValue != null) { - setSelf(JSON.parse(savedValue)) - } - - onSet(newValue => { - if (newValue instanceof DefaultValue) { - localStorage.removeItem(key) - } else { - localStorage.setItem(key, JSON.stringify(newValue)) - } - }) - } -} diff --git a/packages/webapp/src/utils/binning.ts b/packages/webapp/src/utils/binning.ts index 25071c2..dcacc0b 100644 --- a/packages/webapp/src/utils/binning.ts +++ b/packages/webapp/src/utils/binning.ts @@ -2,9 +2,9 @@ * Copyright (c) Microsoft. All rights reserved. * Licensed under the MIT license. See LICENSE file in the project. */ +import { CsvRecord } from 'sds-wasm' import { calcPrecision, countDecimals } from './math' import { stringToNumber } from './strings' -import { CsvRecord } from '~models' interface IBin { match: (value: string) => boolean diff --git a/packages/webapp/src/workers/sds-wasm/SdsWasmWorker.ts b/packages/webapp/src/workers/sds-wasm/SdsWasmWorker.ts index 3e693d6..8fb401f 100644 --- a/packages/webapp/src/workers/sds-wasm/SdsWasmWorker.ts +++ b/packages/webapp/src/workers/sds-wasm/SdsWasmWorker.ts @@ -3,42 +3,43 @@ * Licensed under the MIT license. See LICENSE file in the project. */ import { - AttributeRows, - AttributesInColumn, - CsvRecord, - IAggregatedCombinations, - IAttributeRowsMap, - IAttributesIntersectionValue, - IEvaluatedResult, - INavigateResult, - ISelectedAttributes, -} from 'src/models' + CsvData, + HeaderNames, + IAttributesIntersectionByColumn, + IEvaluateResult, + ISelectedAttributesByColumn, + ReportProgressCallback, +} from 'sds-wasm' import { v4 } from 'uuid' import { + SdsWasmAttributesIntersectionsByColumnMessage, + SdsWasmAttributesIntersectionsByColumnResponse, + SdsWasmClearEvaluateMessage, + SdsWasmClearGenerateMessage, + SdsWasmClearNavigateMessage, + SdsWasmClearSensitiveDataMessage, + SdsWasmErrorResponse, + SdsWasmEvaluateMessage, + SdsWasmEvaluateResponse, SdsWasmGenerateMessage, SdsWasmGenerateResponse, SdsWasmInitMessage, SdsWasmMessage, SdsWasmMessageType, - SdsWasmResponse, - SdsWasmEvaluateMessage, SdsWasmNavigateMessage, - SdsWasmIntersectSelectedAttributesMessage, - SdsWasmIntersectSelectedAttributesResponse, - ReportProgressCallback, -} from './types' -import { - SdsWasmEvaluateResponse, - SdsWasmIntersectAttributesInColumnsMessage, - SdsWasmIntersectAttributesInColumnsResponse, - SdsWasmNavigateResponse, SdsWasmReportProgressResponse, -} from '.' 
+ SdsWasmResponse, + SdsWasmSelectAttributesMessage, +} from './types' +import Worker from './worker?worker' type SdsWasmResponseCallback = ((value: SdsWasmResponse) => void) | undefined +type SdsWasmErrorCallback = ((reason?: string) => void) | undefined + interface ICallbackMapValue { resolver: SdsWasmResponseCallback + rejector: SdsWasmErrorCallback reportProgress?: ReportProgressCallback } @@ -60,6 +61,9 @@ export class SdsWasmWorker { callback.reportProgress?.( (response as SdsWasmReportProgressResponse).progress, ) + } else if (response.type === SdsWasmMessageType.Error) { + callback.rejector?.((response as SdsWasmErrorResponse).errorMessage) + this._callback_map.delete(response.id) } else { callback.resolver?.(response) this._callback_map.delete(response.id) @@ -73,12 +77,16 @@ export class SdsWasmWorker { reportProgress?: ReportProgressCallback, ): Promise { let resolver: SdsWasmResponseCallback = undefined - const receivePromise = new Promise(resolve => { + let rejector: SdsWasmErrorCallback = undefined + + const receivePromise = new Promise((resolve, reject) => { resolver = resolve + rejector = reject }) this._callback_map.set(message.id, { resolver, + rejector, reportProgress, }) this._worker?.postMessage(message) @@ -87,7 +95,7 @@ export class SdsWasmWorker { } public async init(logLevel: string): Promise { - this._worker = new Worker((await import('./worker?url')).default) + this._worker = new Worker() this._worker.onmessage = this.responseReceived.bind(this) const response = await this.execute({ @@ -101,82 +109,110 @@ export class SdsWasmWorker { return response.type === SdsWasmMessageType.Init } + public async clearSensitiveData(): Promise { + const response = await this.execute({ + id: v4(), + type: SdsWasmMessageType.ClearSensitiveData, + } as SdsWasmClearSensitiveDataMessage) + + return response.type === SdsWasmMessageType.ClearSensitiveData + } + + public async clearGenerate(): Promise { + const response = await this.execute({ + id: v4(), + type: SdsWasmMessageType.ClearGenerate, + } as SdsWasmClearGenerateMessage) + + return response.type === SdsWasmMessageType.ClearGenerate + } + + public async clearEvaluate(): Promise { + const response = await this.execute({ + id: v4(), + type: SdsWasmMessageType.ClearEvaluate, + } as SdsWasmClearEvaluateMessage) + + return response.type === SdsWasmMessageType.ClearEvaluate + } + + public async clearNavigate(): Promise { + const response = await this.execute({ + id: v4(), + type: SdsWasmMessageType.ClearNavigate, + } as SdsWasmClearNavigateMessage) + + return response.type === SdsWasmMessageType.ClearNavigate + } + public async generate( - csvContent: CsvRecord[], + sensitiveCsvData: CsvData, useColumns: string[], sensitiveZeros: string[], recordLimit: number, resolution: number, cacheSize: number, reportProgress?: ReportProgressCallback, - ): Promise { + emptyValue = '', + seeded = true, + ): Promise { const response = await this.execute( { id: v4(), type: SdsWasmMessageType.Generate, - csvContent, + sensitiveCsvData, useColumns, sensitiveZeros, recordLimit, resolution, + emptyValue, cacheSize, + seeded, } as SdsWasmGenerateMessage, reportProgress, ) if (response.type === SdsWasmMessageType.Generate) { - return (response as SdsWasmGenerateResponse).records + return (response as SdsWasmGenerateResponse).syntheticCsvData } return undefined } public async evaluate( - sensitiveCsvContent: CsvRecord[], - syntheticCsvContent: CsvRecord[], - useColumns: string[], - sensitiveZeros: string[], - recordLimit: number, reportingLength: number, - 
resolution: number, + sensitivityThreshold = 0, + combinationDelimiter = ';', + includeAggregatesCount = false, reportProgress?: ReportProgressCallback, - ): Promise { + ): Promise { const response = await this.execute( { id: v4(), type: SdsWasmMessageType.Evaluate, - sensitiveCsvContent, - syntheticCsvContent, - useColumns, - sensitiveZeros, - recordLimit, reportingLength, - resolution, + sensitivityThreshold, + combinationDelimiter, + includeAggregatesCount, } as SdsWasmEvaluateMessage, reportProgress, ) if (response.type === SdsWasmMessageType.Evaluate) { - return (response as SdsWasmEvaluateResponse).evaluatedResult + return (response as SdsWasmEvaluateResponse).evaluateResult } return undefined } - public async navigate( - syntheticCsvContent: CsvRecord[], - ): Promise { + public async navigate(): Promise { const response = await this.execute({ id: v4(), type: SdsWasmMessageType.Navigate, - syntheticCsvContent, } as SdsWasmNavigateMessage) - if (response.type === SdsWasmMessageType.Navigate) { - return (response as SdsWasmNavigateResponse).navigateResult - } - return undefined + return response.type === SdsWasmMessageType.Navigate } - public async findColumnsWithZeros(items: CsvRecord[]): Promise { + public async findColumnsWithZeros(items: CsvData): Promise { const zeros = new Set() items.forEach(line => { @@ -191,52 +227,30 @@ export class SdsWasmWorker { return Array.from(zeros) } - public async intersectSelectedAttributesWith( - selectedAttributes: ISelectedAttributes, - initialRows: AttributeRows, - attrRowsMap: IAttributeRowsMap, - ): Promise { + public async selectAttributes( + attributes: ISelectedAttributesByColumn, + ): Promise { const response = await this.execute({ id: v4(), - type: SdsWasmMessageType.IntersectSelectedAttributes, - selectedAttributes, - initialRows, - attrRowsMap, - } as SdsWasmIntersectSelectedAttributesMessage) + type: SdsWasmMessageType.SelectAttributes, + attributes, + } as SdsWasmSelectAttributesMessage) - if (response.type === SdsWasmMessageType.IntersectSelectedAttributes) { - return (response as SdsWasmIntersectSelectedAttributesResponse) - .intersectionResult - } - return undefined + return response.type === SdsWasmMessageType.SelectAttributes } - public async intersectAttributesInColumnsWith( - headers: CsvRecord, - initialRows: AttributeRows, - attrsInColumn: AttributesInColumn, - selectedAttributeRows: AttributeRows, - selectedAttributes: ISelectedAttributes, - attrRowsMap: IAttributeRowsMap, - columnIndex: number, - sensitiveAggregatedCombinations?: IAggregatedCombinations, - ): Promise { + public async attributesIntersectionsByColumn( + columns: HeaderNames, + ): Promise { const response = await this.execute({ id: v4(), - type: SdsWasmMessageType.IntersectAttributesInColumns, - headers, - initialRows, - attrsInColumn, - selectedAttributeRows, - selectedAttributes, - attrRowsMap, - columnIndex, - sensitiveAggregatedCombinations, - } as SdsWasmIntersectAttributesInColumnsMessage) + type: SdsWasmMessageType.AttributesIntersectionsByColumn, + columns, + } as SdsWasmAttributesIntersectionsByColumnMessage) - if (response.type === SdsWasmMessageType.IntersectAttributesInColumns) { - return (response as SdsWasmIntersectAttributesInColumnsResponse) - .intersectionResult + if (response.type === SdsWasmMessageType.AttributesIntersectionsByColumn) { + return (response as SdsWasmAttributesIntersectionsByColumnResponse) + .attributesIntersectionByColumn } return undefined } diff --git a/packages/webapp/src/workers/sds-wasm/types.ts 
b/packages/webapp/src/workers/sds-wasm/types.ts index 84aef95..2392438 100644 --- a/packages/webapp/src/workers/sds-wasm/types.ts +++ b/packages/webapp/src/workers/sds-wasm/types.ts @@ -2,26 +2,27 @@ * Copyright (c) Microsoft. All rights reserved. * Licensed under the MIT license. See LICENSE file in the project. */ -import { CsvRecord } from 'src/models/csv' import { - AttributeRows, - AttributesInColumn, - IAggregatedCombinations, - IAttributeRowsMap, - IAttributesIntersectionValue, - IEvaluatedResult, - INavigateResult, - ISelectedAttributes, -} from '~models' + CsvData, + HeaderNames, + IAttributesIntersectionByColumn, + IEvaluateResult, + ISelectedAttributesByColumn, +} from 'sds-wasm' export enum SdsWasmMessageType { Init = 'Init', + Error = 'Error', + ClearSensitiveData = 'ClearSensitiveData', + ClearGenerate = 'ClearGenerate', + ClearEvaluate = 'ClearEvaluate', + ClearNavigate = 'ClearNavigate', ReportProgress = 'ReportProgress', Generate = 'Generate', Evaluate = 'Evaluate', Navigate = 'Navigate', - IntersectSelectedAttributes = 'IntersectSelectedAttributes', - IntersectAttributesInColumns = 'IntersectAttributesInColumns', + SelectAttributes = 'SelectAttributes', + AttributesIntersectionsByColumn = 'AttributesIntersectionsByColumn', } export interface SdsWasmMessage { @@ -29,6 +30,11 @@ export interface SdsWasmMessage { type: SdsWasmMessageType } +export interface SdsWasmResponse { + type: SdsWasmMessageType + id: string +} + export interface SdsWasmInitMessage extends SdsWasmMessage { type: SdsWasmMessageType.Init logLevel: string @@ -36,92 +42,107 @@ export interface SdsWasmInitMessage extends SdsWasmMessage { wasmPath: string } -export interface SdsWasmGenerateMessage extends SdsWasmMessage { - type: SdsWasmMessageType.Generate - csvContent: CsvRecord[] - useColumns: string[] - sensitiveZeros: string[] - recordLimit: number - resolution: number - cacheSize: number -} - -export interface SdsWasmEvaluateMessage extends SdsWasmMessage { - type: SdsWasmMessageType.Evaluate - sensitiveCsvContent: CsvRecord[] - syntheticCsvContent: CsvRecord[] - useColumns: string[] - sensitiveZeros: string[] - recordLimit: number - reportingLength: number - resolution: number -} - -export interface SdsWasmNavigateMessage extends SdsWasmMessage { - type: SdsWasmMessageType.Navigate - syntheticCsvContent: CsvRecord[] -} - -export interface SdsWasmIntersectSelectedAttributesMessage - extends SdsWasmMessage { - type: SdsWasmMessageType.IntersectSelectedAttributes - selectedAttributes: ISelectedAttributes - initialRows: AttributeRows - attrRowsMap: IAttributeRowsMap -} - -export interface SdsWasmIntersectAttributesInColumnsMessage - extends SdsWasmMessage { - type: SdsWasmMessageType.IntersectAttributesInColumns - headers: CsvRecord - initialRows: AttributeRows - attrsInColumn: AttributesInColumn - selectedAttributeRows: AttributeRows - selectedAttributes: ISelectedAttributes - attrRowsMap: IAttributeRowsMap - columnIndex: number - sensitiveAggregatedCombinations?: IAggregatedCombinations -} - -export interface SdsWasmResponse { - type: SdsWasmMessageType - id: string -} - export interface SdsWasmInitResponse extends SdsWasmResponse { type: SdsWasmMessageType.Init } +export interface SdsWasmErrorResponse extends SdsWasmResponse { + type: SdsWasmMessageType.Error + errorMessage: string +} + +export interface SdsWasmClearSensitiveDataMessage extends SdsWasmMessage { + type: SdsWasmMessageType.ClearSensitiveData +} + +export interface SdsWasmClearSensitiveDataResponse extends SdsWasmResponse { + type: 
SdsWasmMessageType.ClearSensitiveData +} + +export interface SdsWasmClearGenerateMessage extends SdsWasmMessage { + type: SdsWasmMessageType.ClearGenerate +} + +export interface SdsWasmClearGenerateResponse extends SdsWasmResponse { + type: SdsWasmMessageType.ClearGenerate +} + +export interface SdsWasmClearEvaluateMessage extends SdsWasmMessage { + type: SdsWasmMessageType.ClearEvaluate +} + +export interface SdsWasmClearEvaluateResponse extends SdsWasmResponse { + type: SdsWasmMessageType.ClearEvaluate +} + +export interface SdsWasmClearNavigateMessage extends SdsWasmMessage { + type: SdsWasmMessageType.ClearNavigate +} + +export interface SdsWasmClearNavigateResponse extends SdsWasmResponse { + type: SdsWasmMessageType.ClearNavigate +} + export interface SdsWasmReportProgressResponse extends SdsWasmResponse { type: SdsWasmMessageType.ReportProgress progress: number } +export interface SdsWasmGenerateMessage extends SdsWasmMessage { + type: SdsWasmMessageType.Generate + sensitiveCsvData: CsvData + useColumns: HeaderNames + sensitiveZeros: HeaderNames + recordLimit: number + resolution: number + emptyValue: string + cacheSize: number + seeded: boolean +} + export interface SdsWasmGenerateResponse extends SdsWasmResponse { type: SdsWasmMessageType.Generate - records?: CsvRecord[] + syntheticCsvData: CsvData +} + +export interface SdsWasmEvaluateMessage extends SdsWasmMessage { + type: SdsWasmMessageType.Evaluate + reportingLength: number + sensitivityThreshold: number + combinationDelimiter: string + includeAggregatesCount: boolean } export interface SdsWasmEvaluateResponse extends SdsWasmResponse { type: SdsWasmMessageType.Evaluate - evaluatedResult?: IEvaluatedResult + evaluateResult: IEvaluateResult +} + +export interface SdsWasmNavigateMessage extends SdsWasmMessage { + type: SdsWasmMessageType.Navigate } export interface SdsWasmNavigateResponse extends SdsWasmResponse { type: SdsWasmMessageType.Navigate - navigateResult?: INavigateResult } -export interface SdsWasmIntersectSelectedAttributesResponse +export interface SdsWasmSelectAttributesMessage extends SdsWasmMessage { + type: SdsWasmMessageType.SelectAttributes + attributes: ISelectedAttributesByColumn +} + +export interface SdsWasmSelectAttributesResponse extends SdsWasmResponse { + type: SdsWasmMessageType.SelectAttributes +} + +export interface SdsWasmAttributesIntersectionsByColumnMessage + extends SdsWasmMessage { + type: SdsWasmMessageType.AttributesIntersectionsByColumn + columns: HeaderNames +} + +export interface SdsWasmAttributesIntersectionsByColumnResponse extends SdsWasmResponse { - type: SdsWasmMessageType.IntersectSelectedAttributes - intersectionResult?: AttributeRows + type: SdsWasmMessageType.AttributesIntersectionsByColumn + attributesIntersectionByColumn: IAttributesIntersectionByColumn } - -export interface SdsWasmIntersectAttributesInColumnsResponse - extends SdsWasmResponse { - type: SdsWasmMessageType.IntersectAttributesInColumns - intersectionResult?: IAttributesIntersectionValue[] -} - -export type ReportProgressCallback = (progress: number) => void diff --git a/packages/webapp/src/workers/sds-wasm/worker.js b/packages/webapp/src/workers/sds-wasm/worker.js deleted file mode 100644 index 46b1fbc..0000000 --- a/packages/webapp/src/workers/sds-wasm/worker.js +++ /dev/null @@ -1,265 +0,0 @@ -/*! - * Copyright (c) Microsoft. All rights reserved. - * Licensed under the MIT license. See LICENSE file in the project. 
- */ -let initLogger -let generate -let evaluate - -function postProgress(id, progress) { - postMessage({ - id, - type: 'ReportProgress', - progress, - }) -} - -function calcAttributeRowsMap(items) { - const attrRowsMap = {} - - items.forEach((record, i) => { - record.forEach((value, columnIndex) => { - if (!attrRowsMap[columnIndex]) { - attrRowsMap[columnIndex] = {} - } - if (value) { - if (!attrRowsMap[columnIndex][value]) { - attrRowsMap[columnIndex][value] = [] - } - attrRowsMap[columnIndex][value].push(i) - } - }) - }) - return attrRowsMap -} - -function calcAttributesInColumnsMap(items) { - const attrsInColumns = {} - - items.forEach(record => { - record.forEach((value, columnIndex) => { - if (!attrsInColumns[columnIndex]) { - attrsInColumns[columnIndex] = new Set() - } - if (value) { - attrsInColumns[columnIndex].add(value) - } - }) - }) - return attrsInColumns -} - -function intersectAttributeRows(a, b) { - const result = [] - let aIndex = 0 - let bIndex = 0 - - while (aIndex < a.length && bIndex < b.length) { - const aVal = a[aIndex] - const bVal = b[bIndex] - - if (aVal > bVal) { - bIndex++ - } else if (aVal < bVal) { - aIndex++ - } else { - result.push(aVal) - aIndex++ - bIndex++ - } - } - return result -} - -function getAggregatedCount(headers, aggregatedCombinations, selected) { - if (!selected) { - return undefined - } - const key = Object.entries(selected) - .map(entry => `${headers[entry[0]]}:${entry[1]}`) - .sort() - .join(';') - - return aggregatedCombinations?.[key]?.count -} - -async function handleInit(message) { - // eslint-disable-next-line no-undef - importScripts(message.wasmJsPath) - - // eslint-disable-next-line no-undef - initLogger = wasm_bindgen.init_logger - // eslint-disable-next-line no-undef - generate = wasm_bindgen.generate - // eslint-disable-next-line no-undef - evaluate = wasm_bindgen.evaluate - - // eslint-disable-next-line no-undef - await wasm_bindgen(message.wasmPath) - - initLogger(message.logLevel) - - return { - id: message.id, - type: message.type, - } -} - -function handleGenerate(message) { - return { - id: message.id, - type: message.type, - records: generate( - message.csvContent, - message.useColumns, - message.sensitiveZeros, - message.recordLimit, - message.resolution, - message.cacheSize, - p => { - postProgress(message.id, p) - }, - ), - } -} - -function handleEvaluate(message) { - return { - id: message.id, - type: message.type, - evaluatedResult: evaluate( - message.sensitiveCsvContent, - message.syntheticCsvContent, - message.useColumns, - message.sensitiveZeros, - message.recordLimit, - message.reportingLength, - message.resolution, - p => { - postProgress(message.id, p) - }, - ), - } -} - -function handleNavigate(message) { - return { - id: message.id, - type: message.type, - navigateResult: { - attrRowsMap: calcAttributeRowsMap(message.syntheticCsvContent), - attrsInColumnsMap: calcAttributesInColumnsMap( - message.syntheticCsvContent, - ), - allRows: message.syntheticCsvContent.map((_, i) => i), - }, - } -} - -function handleIntersectSelectedAttributes(message) { - let newSelectedAttributeRows = message.initialRows - - Object.entries(message.selectedAttributes).forEach(entry => { - newSelectedAttributeRows = - intersectAttributeRows( - newSelectedAttributeRows, - message.attrRowsMap[entry[0]][entry[1]], - ) ?? 
[] - }) - return { - id: message.id, - type: message.type, - intersectionResult: newSelectedAttributeRows, - } -} - -function handleIntersectAttributeInColumns(message) { - const result = [] - - const selectedAttribute = message.selectedAttributes[message.columnIndex] - const selectedAttributesButCurrentColumn = { - ...message.selectedAttributes, - } - let selectedAttributeRowsButCurrentColumn - - if (selectedAttribute !== undefined) { - delete selectedAttributesButCurrentColumn[message.columnIndex] - selectedAttributeRowsButCurrentColumn = handleIntersectSelectedAttributes({ - initialRows: message.initialRows, - selectedAttributes: selectedAttributesButCurrentColumn, - attrRowsMap: message.attrRowsMap, - }).intersectionResult - } else { - selectedAttributeRowsButCurrentColumn = message.selectedAttributeRows - } - - message.attrsInColumn.forEach(attr => { - const estimatedCount = - intersectAttributeRows( - attr === selectedAttribute - ? message.selectedAttributeRows - : selectedAttributeRowsButCurrentColumn, - message.attrRowsMap[message.columnIndex][attr], - )?.length ?? 0 - - if (estimatedCount > 0) { - result.push({ - value: attr, - estimatedCount, - actualCount: getAggregatedCount( - message.headers, - message.sensitiveAggregatedCombinations, - { - ...(attr === selectedAttribute - ? message.selectedAttributes - : selectedAttributesButCurrentColumn), - [message.columnIndex]: attr, - }, - ), - }) - } - }) - // sort descending by estimated count - result.sort((a, b) => b.estimatedCount - a.estimatedCount) - return { - id: message.id, - type: message.type, - intersectionResult: result, - } -} - -onmessage = async event => { - const message = event?.data - let response = undefined - - switch (message?.type) { - case 'Init': { - response = await handleInit(message) - break - } - case 'Generate': { - response = handleGenerate(message) - break - } - case 'Evaluate': { - response = handleEvaluate(message) - break - } - case 'Navigate': { - response = handleNavigate(message) - break - } - case 'IntersectSelectedAttributes': { - response = handleIntersectSelectedAttributes(message) - break - } - case 'IntersectAttributesInColumns': { - response = handleIntersectAttributeInColumns(message) - break - } - default: { - break - } - } - postMessage(response) -} diff --git a/packages/webapp/src/workers/sds-wasm/worker.ts b/packages/webapp/src/workers/sds-wasm/worker.ts new file mode 100644 index 0000000..c6a3b56 --- /dev/null +++ b/packages/webapp/src/workers/sds-wasm/worker.ts @@ -0,0 +1,223 @@ +/*! + * Copyright (c) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE file in the project. 
diff --git a/packages/webapp/src/workers/sds-wasm/worker.ts b/packages/webapp/src/workers/sds-wasm/worker.ts
new file mode 100644
index 0000000..c6a3b56
--- /dev/null
+++ b/packages/webapp/src/workers/sds-wasm/worker.ts
@@ -0,0 +1,223 @@
+/*!
+ * Copyright (c) Microsoft. All rights reserved.
+ * Licensed under the MIT license. See LICENSE file in the project.
+ */
+import init, { init_logger, SDSContext } from 'sds-wasm'
+
+import {
+	SdsWasmAttributesIntersectionsByColumnMessage,
+	SdsWasmAttributesIntersectionsByColumnResponse,
+	SdsWasmClearEvaluateMessage,
+	SdsWasmClearEvaluateResponse,
+	SdsWasmClearGenerateMessage,
+	SdsWasmClearGenerateResponse,
+	SdsWasmClearNavigateMessage,
+	SdsWasmClearNavigateResponse,
+	SdsWasmClearSensitiveDataMessage,
+	SdsWasmClearSensitiveDataResponse,
+	SdsWasmErrorResponse,
+	SdsWasmEvaluateMessage,
+	SdsWasmEvaluateResponse,
+	SdsWasmGenerateMessage,
+	SdsWasmGenerateResponse,
+	SdsWasmInitMessage,
+	SdsWasmInitResponse,
+	SdsWasmMessage,
+	SdsWasmMessageType,
+	SdsWasmNavigateMessage,
+	SdsWasmNavigateResponse,
+	SdsWasmReportProgressResponse,
+	SdsWasmSelectAttributesMessage,
+	SdsWasmSelectAttributesResponse,
+} from './types'
+
+let CONTEXT: SDSContext
+
+const HANDLERS = {
+	[SdsWasmMessageType.Init]: handleInit,
+	[SdsWasmMessageType.ClearSensitiveData]: handleClearSensitiveData,
+	[SdsWasmMessageType.ClearGenerate]: handleClearGenerate,
+	[SdsWasmMessageType.ClearEvaluate]: handleClearEvaluate,
+	[SdsWasmMessageType.ClearNavigate]: handleClearNavigate,
+	[SdsWasmMessageType.Generate]: handleGenerate,
+	[SdsWasmMessageType.Evaluate]: handleEvaluate,
+	[SdsWasmMessageType.Navigate]: handleNavigate,
+	[SdsWasmMessageType.SelectAttributes]: handleSelectAttributes,
+	[SdsWasmMessageType.AttributesIntersectionsByColumn]:
+		handleAttributesIntersectionsByColumn,
+}
+
+function postProgress(id: string, progress: number) {
+	postMessage({
+		id,
+		type: SdsWasmMessageType.ReportProgress,
+		progress,
+	} as SdsWasmReportProgressResponse)
+}
+
+function postError(id: string, errorMessage: string) {
+	postMessage({
+		id,
+		type: SdsWasmMessageType.Error,
+		errorMessage,
+	} as SdsWasmErrorResponse)
+}
+
+async function handleInit(
+	message: SdsWasmInitMessage,
+): Promise<SdsWasmInitResponse> {
+	await init(message.wasmPath)
+
+	init_logger(message.logLevel)
+
+	CONTEXT = new SDSContext()
+
+	return {
+		id: message.id,
+		type: message.type,
+	}
+}
+
+async function handleClearSensitiveData(
+	message: SdsWasmClearSensitiveDataMessage,
+): Promise<SdsWasmClearSensitiveDataResponse> {
+	CONTEXT.clearSensitiveData()
+	return {
+		id: message.id,
+		type: message.type,
+	}
+}
+
+async function handleClearGenerate(
+	message: SdsWasmClearGenerateMessage,
+): Promise<SdsWasmClearGenerateResponse> {
+	CONTEXT.clearGenerate()
+	return {
+		id: message.id,
+		type: message.type,
+	}
+}
+
+async function handleClearEvaluate(
+	message: SdsWasmClearEvaluateMessage,
+): Promise<SdsWasmClearEvaluateResponse> {
+	CONTEXT.clearEvaluate()
+	return {
+		id: message.id,
+		type: message.type,
+	}
+}
+
+async function handleClearNavigate(
+	message: SdsWasmClearNavigateMessage,
+): Promise<SdsWasmClearNavigateResponse> {
+	CONTEXT.clearNavigate()
+	return {
+		id: message.id,
+		type: message.type,
+	}
+}
+
+async function handleGenerate(
+	message: SdsWasmGenerateMessage,
+): Promise<SdsWasmGenerateResponse> {
+	CONTEXT.setSensitiveData(
+		message.sensitiveCsvData,
+		message.useColumns,
+		message.sensitiveZeros,
+		message.recordLimit,
+	)
+
+	CONTEXT.generate(
+		message.cacheSize,
+		message.resolution,
+		message.emptyValue,
+		message.seeded,
+		p => {
+			postProgress(message.id, p)
+		},
+	)
+
+	return {
+		id: message.id,
+		type: message.type,
+		syntheticCsvData: CONTEXT.generateResultToJs().syntheticData,
+	}
+}
+
+async function handleEvaluate(
+	message: SdsWasmEvaluateMessage,
+): Promise<SdsWasmEvaluateResponse> {
+	CONTEXT.evaluate(
+		message.reportingLength,
+		message.sensitivityThreshold,
+		p => {
+			postProgress(message.id, 0.5 * p)
+		},
+		p => {
+			postProgress(message.id, 50.0 + 0.5 * p)
+		},
+	)
+
+	const evaluateResult = CONTEXT.evaluateResultToJs(
+		message.combinationDelimiter,
+		message.includeAggregatesCount,
+	)
+
+	CONTEXT.protectSensitiveAggregatesCount()
+
+	return {
+		id: message.id,
+		type: message.type,
+		evaluateResult,
+	}
+}
+
+async function handleNavigate(
+	message: SdsWasmNavigateMessage,
+): Promise<SdsWasmNavigateResponse> {
+	CONTEXT.navigate()
+
+	return {
+		id: message.id,
+		type: message.type,
+	}
+}
+
+async function handleSelectAttributes(
+	message: SdsWasmSelectAttributesMessage,
+): Promise<SdsWasmSelectAttributesResponse> {
+	CONTEXT.selectAttributes(message.attributes)
+
+	return {
+		id: message.id,
+		type: message.type,
+	}
+}
+
+async function handleAttributesIntersectionsByColumn(
+	message: SdsWasmAttributesIntersectionsByColumnMessage,
+): Promise<SdsWasmAttributesIntersectionsByColumnResponse> {
+	return {
+		id: message.id,
+		type: message.type,
+		attributesIntersectionByColumn: CONTEXT.attributesIntersectionsByColumn(
+			message.columns,
+		),
+	}
+}
+
+onmessage = async event => {
+	const message = event?.data as SdsWasmMessage | undefined
+
+	if (message) {
+		try {
+			const response = HANDLERS[message.type]?.(message)
+
+			if (response) {
+				postMessage(await response)
+			}
+		} catch (err) {
+			// postError takes the request id, not the message type
+			postError(message.id, `wasm error: ${err}`)
+		}
+	}
+}
diff --git a/packages/webapp/tsconfig.json b/packages/webapp/tsconfig.json
index 89b36d2..bfebeec 100644
--- a/packages/webapp/tsconfig.json
+++ b/packages/webapp/tsconfig.json
@@ -11,9 +11,7 @@
 			"~models": ["src/models/index.ts"],
 			"~models/*": ["src/models/*"],
 			"~states": ["src/states/index.ts"],
-			"~states/*": ["src/states/*"],
-			"~storage": ["src/storage/index.ts"],
-			"~storage/*": ["src/storage/*"]
+			"~states/*": ["src/states/*"]
 		},
 		"allowJs": true,
 		"lib": ["ESNext", "DOM"],
diff --git a/webapp.dockerfile b/webapp.dockerfile
index 982233e..a19d1d1 100644
--- a/webapp.dockerfile
+++ b/webapp.dockerfile
@@ -24,7 +24,7 @@ WORKDIR /usr/src/sds
 COPY . .
 
 # build the wasm bindings for webapp to use
-RUN cd packages/lib-wasm && wasm-pack build --release --target no-modules --out-dir ../../target/wasm
+RUN cd packages/lib-wasm && wasm-pack build --release --target web --out-dir ../../target/wasm
 
 # --- compile application from typescript ---
 FROM node:14 as app-builder
diff --git a/yarn.lock b/yarn.lock
index f479365..2357f74 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -10468,9 +10468,9 @@ resolve@^2.0.0-next.3:
   linkType: hard
 
 "sds-wasm@file:../../target/wasm::locator=webapp%40workspace%3Apackages%2Fwebapp":
-  version: 0.1.0
-  resolution: "sds-wasm@file:../../target/wasm#../../target/wasm::hash=932617&locator=webapp%40workspace%3Apackages%2Fwebapp"
-  checksum: 1550beb7ae9523087f19757d2e67287739bf91b847ab00d8d8e4bd4ed601009e441071d0a775bc818977d2b241b0977983d0f7682eb3f577b02025a2d10609f7
+  version: 1.0.0
+  resolution: "sds-wasm@file:../../target/wasm#../../target/wasm::hash=3f33a3&locator=webapp%40workspace%3Apackages%2Fwebapp"
+  checksum: dd8b934db19b301312cd3919ad007765652b8524dedbcd4dc742736d4567e02425db5ba5a5a40123be67492ff4117539f05445109d059b9e7452e65cfc51688b
 languageName: node
 linkType: hard
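The webapp.dockerfile switch from --target no-modules to --target web is what the new worker.ts depends on: wasm-pack's web target emits an ES module whose default export asynchronously fetches and instantiates the .wasm binary, replacing the wasm_bindgen global that worker.js pulled in with importScripts. A minimal sketch of the two loading styles, with illustrative paths and log level:

// --target web (new): ES-module bindings with an async default init
import init, { init_logger } from 'sds-wasm'

export async function loadBindings(wasmPath: string): Promise<void> {
	// no export is callable until the binary has been instantiated
	await init(wasmPath)
	init_logger('warn') // log level value is illustrative
}

// --target no-modules (old): a classic script exposing a `wasm_bindgen`
// global, loaded in the worker via importScripts(wasmJsPath) and then
// instantiated with `await wasm_bindgen(wasmPath)` — no module syntax at all.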