- refactor core library
- refactor python and wasm bindings
- update python pipeline evaluate step to use the core library
- add multi-threading support to the core library, cli application and python library
- first version of unseeded data synthesis in rust (supported on core library, cli and python pipeline)
- general performance optimizations on the core library, python pipeline and web application
This commit is contained in:
Parent: ca020fc9a9
Commit: 981c25da13
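The multi-threading and unseeded-synthesis items above surface through the core library API that the updated CLI (packages/cli/src/main.rs, further down in this diff) consumes. A minimal Rust sketch of that flow, assuming only the names visible in this diff (set_number_of_threads, CsvDataBlockCreator, Generator, SynthesisMode, LoggerProgressReporter); file paths, parameter values and the Unseeded variant spelling are illustrative, not taken from the commit:

    use sds_core::{
        data_block::{csv_block_creator::CsvDataBlockCreator, data_block_creator::DataBlockCreator},
        processing::generator::{Generator, SynthesisMode},
        utils::{reporting::LoggerProgressReporter, threading::set_number_of_threads},
    };

    fn main() {
        // new in this commit: choose how many worker threads the core library uses
        // (defaults to the number of cores when not set)
        set_number_of_threads(4);

        // build an Arc<DataBlock> from a CSV file; path and delimiter are illustrative
        let data_block = CsvDataBlockCreator::create(
            csv::ReaderBuilder::new().delimiter(b',').from_path("sensitive.csv"),
            &[], // use_columns: empty selects all columns
            &[], // sensitive_zeros: no columns keep zeros
            0,   // record_limit: 0 means no limit
        )
        .expect("unable to read the sensitive dataset");

        // unseeded synthesis is the new mode introduced by this commit
        // (variant name assumed from the CLI's "seeded"/"unseeded" possible values)
        let mut generator = Generator::new(data_block);
        let mut progress_reporter: Option<LoggerProgressReporter> = None;
        let generated_data = generator.generate(
            10,               // reporting resolution
            100000,           // cache_max_size
            String::from(""), // representation for empty values
            SynthesisMode::Unseeded,
            &mut progress_reporter,
        );

        // writing now happens on the generated data itself (see the CLI changes below)
        let synthetic_path = String::from("synthetic.tsv");
        generated_data
            .write_synthetic_data(&synthetic_path, '\t')
            .expect("unable to write the synthetic dataset");
    }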
@ -13,5 +13,4 @@ node_modules/
pkg/
target/
.eslintrc.js
babel.config.js
packages/webapp/src/workers/sds-wasm/worker.js
babel.config.js
@ -13697,10 +13697,10 @@ function $$SETUP_STATE(hydrateRuntimeState, basePath) {
|
|||
}]
|
||||
]],
|
||||
["sds-wasm", [
|
||||
["file:../../target/wasm#../../target/wasm::hash=932617&locator=webapp%40workspace%3Apackages%2Fwebapp", {
|
||||
"packageLocation": "./.yarn/cache/sds-wasm-file-aa35777e98-1550beb7ae.zip/node_modules/sds-wasm/",
|
||||
["file:../../target/wasm#../../target/wasm::hash=3f33a3&locator=webapp%40workspace%3Apackages%2Fwebapp", {
|
||||
"packageLocation": "./.yarn/cache/sds-wasm-file-6dcf65ffd1-dd8b934db1.zip/node_modules/sds-wasm/",
|
||||
"packageDependencies": [
|
||||
["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=932617&locator=webapp%40workspace%3Apackages%2Fwebapp"]
|
||||
["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=3f33a3&locator=webapp%40workspace%3Apackages%2Fwebapp"]
|
||||
],
|
||||
"linkType": "HARD",
|
||||
}]
|
||||
|
@ -15482,7 +15482,7 @@ function $$SETUP_STATE(hydrateRuntimeState, basePath) {
|
|||
["react-is", "npm:17.0.2"],
|
||||
["react-router-dom", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:6.0.2"],
|
||||
["recoil", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:0.5.2"],
|
||||
["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=932617&locator=webapp%40workspace%3Apackages%2Fwebapp"],
|
||||
["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=3f33a3&locator=webapp%40workspace%3Apackages%2Fwebapp"],
|
||||
["styled-components", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:5.3.3"],
|
||||
["ts-node", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:10.4.0"],
|
||||
["typescript", "patch:typescript@npm%3A4.4.3#~builtin<compat/typescript>::version=4.4.3&hash=32657b"],
|
||||
|
.vsts-ci.yml
@ -1,33 +0,0 @@
name: SDS CI
pool:
  vmImage: ubuntu-latest

trigger:
  batch: true
  branches:
    include:
      - main

stages:
  - stage: Compliance
    dependsOn: []
    jobs:
      - job: ComplianceJob
        pool:
          vmImage: windows-latest
        steps:
          - task: CredScan@3
            inputs:
              outputFormat: sarif
              debugMode: false

          - task: ComponentGovernanceComponentDetection@0
            inputs:
              scanType: 'Register'
              verbosity: 'Verbose'
              alertWarningLevel: 'High'

          - task: PublishSecurityAnalysisLogs@3
            inputs:
              ArtifactName: 'CodeAnalysisLogs'
              ArtifactType: 'Container'
Binary file not shown.
Binary file not shown.
|
@ -42,6 +42,12 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
|
||||
|
||||
[[package]]
|
||||
name = "base-x"
|
||||
version = "0.2.8"
|
||||
|
@ -103,6 +109,50 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-channel"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"lazy_static",
|
||||
"memoffset",
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "csv"
|
||||
version = "1.1.6"
|
||||
|
@ -307,6 +357,25 @@ version = "2.4.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.8.0"
|
||||
|
@ -499,6 +568,31 @@ dependencies = [
|
|||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"crossbeam-deque",
|
||||
"either",
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"crossbeam-deque",
|
||||
"crossbeam-utils",
|
||||
"lazy_static",
|
||||
"num_cpus",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.2.10"
|
||||
|
@ -582,6 +676,9 @@ dependencies = [
|
|||
"lru",
|
||||
"pyo3",
|
||||
"rand",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -589,6 +686,7 @@ name = "sds-pyo3"
|
|||
version = "1.0.0"
|
||||
dependencies = [
|
||||
"csv",
|
||||
"env_logger",
|
||||
"log",
|
||||
"pyo3",
|
||||
"sds-core",
|
||||
|
@ -1,8 +1,8 @@
version: "3"
version: '3'
services:
webapp:
build:
context: .
dockerfile: webapp.dockerfile
ports:
- 3000:80
webapp:
build:
context: .
dockerfile: webapp.dockerfile
ports:
- 3000:80
@ -7,7 +7,7 @@
"build:": "yarn workspaces foreach -ivt run build",
"start:": "yarn workspaces foreach -piv run start",
"lint:": "essex lint --fix --strict",
"build:lib-wasm": "cd packages/lib-wasm && wasm-pack build --release --target no-modules --out-dir ../../target/wasm",
"build:lib-wasm": "cd packages/lib-wasm && wasm-pack build --release --target web --out-dir ../../target/wasm",
"prettify": "essex prettify",
"rebuild-all": "cargo clean && run-s clean: && cargo build --release && run-s build:lib-wasm && yarn install && run-s build:"
},
@ -7,7 +7,7 @@ repository = "https://github.com/microsoft/synthetic-data-showcase"
edition = "2018"

[dependencies]
sds-core = { path = "../core" }
sds-core = { path = "../core", features = ["rayon"] }
log = { version = "0.4" }
env_logger = { version = "0.9" }
structopt = { version = "0.3" }
|
@ -1,8 +1,11 @@
|
|||
use log::{error, info, log_enabled, trace, Level::Debug};
|
||||
use sds_core::{
|
||||
data_block::{block::DataBlockCreator, csv_block_creator::CsvDataBlockCreator},
|
||||
processing::{aggregator::Aggregator, generator::Generator},
|
||||
utils::reporting::LoggerProgressReporter,
|
||||
data_block::{csv_block_creator::CsvDataBlockCreator, data_block_creator::DataBlockCreator},
|
||||
processing::{
|
||||
aggregator::Aggregator,
|
||||
generator::{Generator, SynthesisMode},
|
||||
},
|
||||
utils::{reporting::LoggerProgressReporter, threading::set_number_of_threads},
|
||||
};
|
||||
use std::process;
|
||||
use structopt::StructOpt;
|
||||
|
@ -26,6 +29,14 @@ enum Command {
|
|||
default_value = "100000"
|
||||
)]
|
||||
cache_max_size: usize,
|
||||
#[structopt(
|
||||
long = "mode",
|
||||
help = "synthesis mode",
|
||||
possible_values = &["seeded", "unseeded"],
|
||||
case_insensitive = true,
|
||||
default_value = "seeded"
|
||||
)]
|
||||
mode: SynthesisMode,
|
||||
},
|
||||
Aggregate {
|
||||
#[structopt(long = "aggregates-path", help = "generated aggregates file path")]
|
||||
|
@ -104,6 +115,12 @@ struct Cli {
|
|||
help = "columns where zeros should not be ignored (can be set multiple times)"
|
||||
)]
|
||||
sensitive_zeros: Vec<String>,
|
||||
|
||||
#[structopt(
|
||||
long = "n-threads",
|
||||
help = "number of threads used to process the data in parallel (default is the number of cores)"
|
||||
)]
|
||||
n_threads: Option<usize>,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
|
@ -118,6 +135,10 @@ fn main() {
|
|||
|
||||
trace!("execution parameters: {:#?}", cli);
|
||||
|
||||
if let Some(n_threads) = cli.n_threads {
|
||||
set_number_of_threads(n_threads);
|
||||
}
|
||||
|
||||
match CsvDataBlockCreator::create(
|
||||
csv::ReaderBuilder::new()
|
||||
.delimiter(cli.sensitive_delimiter.chars().next().unwrap() as u8)
|
||||
|
@ -131,15 +152,20 @@ fn main() {
|
|||
synthetic_path,
|
||||
synthetic_delimiter,
|
||||
cache_max_size,
|
||||
mode,
|
||||
} => {
|
||||
let mut generator = Generator::new(&data_block);
|
||||
let generated_data =
|
||||
generator.generate(cli.resolution, cache_max_size, "", &mut progress_reporter);
|
||||
let mut generator = Generator::new(data_block);
|
||||
let generated_data = generator.generate(
|
||||
cli.resolution,
|
||||
cache_max_size,
|
||||
String::from(""),
|
||||
mode,
|
||||
&mut progress_reporter,
|
||||
);
|
||||
|
||||
if let Err(err) = generator.write_records(
|
||||
&generated_data.synthetic_data,
|
||||
if let Err(err) = generated_data.write_synthetic_data(
|
||||
&synthetic_path,
|
||||
synthetic_delimiter.chars().next().unwrap() as u8,
|
||||
synthetic_delimiter.chars().next().unwrap(),
|
||||
) {
|
||||
error!("error writing output file: {}", err);
|
||||
process::exit(1);
|
||||
|
@ -153,28 +179,24 @@ fn main() {
|
|||
sensitivity_threshold,
|
||||
records_sensitivity_path,
|
||||
} => {
|
||||
let mut aggregator = Aggregator::new(&data_block);
|
||||
let mut aggregator = Aggregator::new(data_block);
|
||||
let mut aggregated_data = aggregator.aggregate(
|
||||
reporting_length,
|
||||
sensitivity_threshold,
|
||||
&mut progress_reporter,
|
||||
);
|
||||
let privacy_risk =
|
||||
aggregator.calc_privacy_risk(&aggregated_data.aggregates_count, cli.resolution);
|
||||
let privacy_risk = aggregated_data.calc_privacy_risk(cli.resolution);
|
||||
|
||||
if !not_protect {
|
||||
Aggregator::protect_aggregates_count(
|
||||
&mut aggregated_data.aggregates_count,
|
||||
cli.resolution,
|
||||
);
|
||||
aggregated_data.protect_aggregates_count(cli.resolution);
|
||||
}
|
||||
|
||||
info!("Calculated privacy risk is: {:#?}", privacy_risk);
|
||||
|
||||
if let Err(err) = aggregator.write_aggregates_count(
|
||||
&aggregated_data.aggregates_count,
|
||||
if let Err(err) = aggregated_data.write_aggregates_count(
|
||||
&aggregates_path,
|
||||
aggregates_delimiter.chars().next().unwrap(),
|
||||
";",
|
||||
cli.resolution,
|
||||
!not_protect,
|
||||
) {
|
||||
|
@ -183,9 +205,7 @@ fn main() {
|
|||
}
|
||||
|
||||
if let Some(path) = records_sensitivity_path {
|
||||
if let Err(err) = aggregator
|
||||
.write_records_sensitivity(&aggregated_data.records_sensitivity, &path)
|
||||
{
|
||||
if let Err(err) = aggregated_data.write_records_sensitivity(&path, '\t') {
|
||||
error!("error writing output file: {}", err);
|
||||
process::exit(1);
|
||||
}
|
||||
|
|
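The aggregate branch above reflects a broader refactor in this commit: privacy-risk calculation, count protection and file output moved from Aggregator onto the AggregatedData it returns. A small sketch of the new flow under the same assumptions as the earlier one (paths, resolution and reporting length are illustrative):

    use sds_core::{
        data_block::block::DataBlock,
        processing::aggregator::Aggregator,
        utils::reporting::LoggerProgressReporter,
    };
    use std::{io::Error, sync::Arc};

    fn aggregate_and_report(data_block: Arc<DataBlock>, resolution: usize) -> Result<(), Error> {
        let mut progress_reporter: Option<LoggerProgressReporter> = None;
        let mut aggregator = Aggregator::new(data_block);

        // reporting_length = 4 and sensitivity_threshold = 0 (no suppression) are illustrative
        let mut aggregated_data = aggregator.aggregate(4, 0, &mut progress_reporter);

        // these used to be Aggregator methods; they now live on AggregatedData
        let privacy_risk = aggregated_data.calc_privacy_risk(resolution);
        println!("privacy risk: {:#?}", privacy_risk);

        aggregated_data.protect_aggregates_count(resolution);
        aggregated_data.write_aggregates_count("aggregates.tsv", '\t', ";", resolution, true)?;
        aggregated_data.write_records_sensitivity("sensitivity.tsv", '\t')
    }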
@ -18,4 +18,7 @@ getrandom = { version = "0.2", features = ["js"] }
log = { version = "0.4", features = ["std"] }
csv = { version = "1.1" }
instant = { version = "0.1", features = [ "stdweb", "wasm-bindgen" ] }
pyo3 = { version = "0.15", features = ["extension-module"], optional = true }
pyo3 = { version = "0.15", features = ["extension-module"], optional = true }
rayon = { version = "1.5", optional = true }
serde = { version = "1.0", features = [ "derive", "rc" ] }
serde_json = { version = "1.0" }
@ -1,12 +1,16 @@
|
|||
use super::{
|
||||
record::DataBlockRecord,
|
||||
typedefs::{
|
||||
AttributeRows, AttributeRowsMap, CsvRecord, CsvRecordSlice, DataBlockHeaders,
|
||||
DataBlockRecords, DataBlockRecordsSlice,
|
||||
AttributeRows, AttributeRowsByColumnMap, AttributeRowsMap, ColumnIndexByName,
|
||||
DataBlockHeaders, DataBlockRecords,
|
||||
},
|
||||
value::DataBlockValue,
|
||||
};
|
||||
use std::collections::HashSet;
|
||||
use fnv::FnvHashMap;
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{processing::aggregator::typedefs::RecordsSet, utils::math::uround_down};
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
@ -15,189 +19,162 @@ use pyo3::prelude::*;
|
|||
/// The goal of this is to allow data processing to handle with memory references
|
||||
/// to the data block instead of copying data around
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DataBlock {
|
||||
/// Vector of strings representhing the data headers
|
||||
/// Vector of strings representing the data headers
|
||||
pub headers: DataBlockHeaders,
|
||||
/// Vector of data records, where each record represents a row (headers not included)
|
||||
pub records: DataBlockRecords,
|
||||
}
|
||||
|
||||
impl DataBlock {
|
||||
/// Returns a new DataBlock with default values
|
||||
pub fn default() -> DataBlock {
|
||||
DataBlock {
|
||||
headers: DataBlockHeaders::default(),
|
||||
records: DataBlockRecords::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new DataBlock
|
||||
/// # Arguments
|
||||
/// * `headers` - Vector of string representhing the data headers
|
||||
/// * `headers` - Vector of string representing the data headers
|
||||
/// * `records` - Vector of data records, where each record represents a row (headers not included)
|
||||
#[inline]
|
||||
pub fn new(headers: DataBlockHeaders, records: DataBlockRecords) -> DataBlock {
|
||||
DataBlock { headers, records }
|
||||
}
|
||||
|
||||
/// Calcules the rows where each value on the data records is present
|
||||
/// # Arguments
|
||||
/// * `records`: List of records to analyze
|
||||
/// Returns a map of column name -> column index
|
||||
#[inline]
|
||||
pub fn calc_attr_rows(records: &DataBlockRecordsSlice) -> AttributeRowsMap {
|
||||
pub fn calc_column_index_by_name(&self) -> ColumnIndexByName {
|
||||
self.headers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(column_index, column_name)| ((**column_name).clone(), column_index))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Calculates the rows where each value on the data records is present
|
||||
#[inline]
|
||||
pub fn calc_attr_rows(&self) -> AttributeRowsMap {
|
||||
let mut attr_rows: AttributeRowsMap = AttributeRowsMap::default();
|
||||
|
||||
for (i, r) in records.iter().enumerate() {
|
||||
for (i, r) in self.records.iter().enumerate() {
|
||||
for v in r.values.iter() {
|
||||
attr_rows
|
||||
.entry(v)
|
||||
.entry(v.clone())
|
||||
.or_insert_with(AttributeRows::new)
|
||||
.push(i);
|
||||
}
|
||||
}
|
||||
attr_rows
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that needs to be implement to create a data block.
|
||||
/// It already contains the logic to create the data block, so we only
|
||||
/// need to worry about mapping the headers and records from InputType
|
||||
pub trait DataBlockCreator {
|
||||
/// Creator input type, it can be a File Reader, another data structure...
|
||||
type InputType;
|
||||
/// The error type that can be generated when parsing headers/records
|
||||
type ErrorType;
|
||||
|
||||
// Calculates the rows where each value on the data records is present
|
||||
/// grouped by column index.
|
||||
#[inline]
|
||||
fn gen_use_columns_set(headers: &CsvRecordSlice, use_columns: &[String]) -> HashSet<usize> {
|
||||
let use_columns_str_set: HashSet<String> = use_columns
|
||||
.iter()
|
||||
.map(|c| c.trim().to_lowercase())
|
||||
.collect();
|
||||
headers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, h)| {
|
||||
if use_columns_str_set.is_empty()
|
||||
|| use_columns_str_set.contains(&h.trim().to_lowercase())
|
||||
{
|
||||
Some(i)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
pub fn calc_attr_rows_with_no_empty_values(&self) -> AttributeRowsByColumnMap {
|
||||
let mut attr_rows_by_column: AttributeRowsByColumnMap = AttributeRowsByColumnMap::default();
|
||||
|
||||
for (i, r) in self.records.iter().enumerate() {
|
||||
for v in r.values.iter() {
|
||||
attr_rows_by_column
|
||||
.entry(v.column_index)
|
||||
.or_insert_with(AttributeRowsMap::default)
|
||||
.entry(v.clone())
|
||||
.or_insert_with(AttributeRows::new)
|
||||
.push(i);
|
||||
}
|
||||
}
|
||||
attr_rows_by_column
|
||||
}
|
||||
|
||||
/// Calculates the rows where each value on the data records is present
|
||||
/// grouped by column index. This will also include empty values mapped with
|
||||
/// the `empty_value` string.
|
||||
/// # Arguments
|
||||
/// * `empty_value` - Empty values on the final synthetic data will be represented by this
|
||||
#[inline]
|
||||
pub fn calc_attr_rows_by_column(&self, empty_value: &Arc<String>) -> AttributeRowsByColumnMap {
|
||||
let mut attr_rows_by_column: FnvHashMap<
|
||||
usize,
|
||||
FnvHashMap<Arc<DataBlockValue>, RecordsSet>,
|
||||
> = FnvHashMap::default();
|
||||
let empty_records: RecordsSet = (0..self.records.len()).collect();
|
||||
|
||||
// start with empty values for all columns
|
||||
for column_index in 0..self.headers.len() {
|
||||
attr_rows_by_column
|
||||
.entry(column_index)
|
||||
.or_insert_with(FnvHashMap::default)
|
||||
.entry(Arc::new(DataBlockValue::new(
|
||||
column_index,
|
||||
empty_value.clone(),
|
||||
)))
|
||||
.or_insert_with(|| empty_records.clone());
|
||||
}
|
||||
|
||||
// go through each record and map where the values occur on the columns
|
||||
for (i, r) in self.records.iter().enumerate() {
|
||||
for value in r.values.iter() {
|
||||
let current_attr_rows = attr_rows_by_column
|
||||
.entry(value.column_index)
|
||||
.or_insert_with(FnvHashMap::default);
|
||||
|
||||
// insert it on the correspondent entry for the data block value
|
||||
current_attr_rows
|
||||
.entry(value.clone())
|
||||
.or_insert_with(RecordsSet::default)
|
||||
.insert(i);
|
||||
// it's now used being used, so we make sure to remove this from the column empty records
|
||||
current_attr_rows
|
||||
.entry(Arc::new(DataBlockValue::new(
|
||||
value.column_index,
|
||||
empty_value.clone(),
|
||||
)))
|
||||
.or_insert_with(RecordsSet::default)
|
||||
.remove(&i);
|
||||
}
|
||||
}
|
||||
|
||||
// sort the records ids, so we can leverage the intersection alg later on
|
||||
attr_rows_by_column
|
||||
.drain()
|
||||
.map(|(column_index, mut attr_rows)| {
|
||||
(
|
||||
column_index,
|
||||
attr_rows
|
||||
.drain()
|
||||
.map(|(value, mut rows_set)| (value, rows_set.drain().sorted().collect()))
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn gen_sensitive_zeros_set(
|
||||
filtered_headers: &CsvRecordSlice,
|
||||
sensitive_zeros: &[String],
|
||||
) -> HashSet<usize> {
|
||||
let sensitive_zeros_str_set: HashSet<String> = sensitive_zeros
|
||||
.iter()
|
||||
.map(|c| c.trim().to_lowercase())
|
||||
.collect();
|
||||
filtered_headers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, h)| {
|
||||
if sensitive_zeros_str_set.contains(&h.trim().to_lowercase()) {
|
||||
Some(i)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
/// Returns the number of records on the data block
|
||||
pub fn number_of_records(&self) -> usize {
|
||||
self.records.len()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn map_headers(
|
||||
headers: &mut CsvRecord,
|
||||
use_columns: &[String],
|
||||
sensitive_zeros: &[String],
|
||||
) -> (CsvRecord, HashSet<usize>, HashSet<usize>) {
|
||||
let use_columns_set = Self::gen_use_columns_set(headers, use_columns);
|
||||
let filtered_headers: CsvRecord = headers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, h)| {
|
||||
if use_columns_set.contains(&i) {
|
||||
Some(h.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let sensitive_zeros_set = Self::gen_sensitive_zeros_set(&filtered_headers, sensitive_zeros);
|
||||
|
||||
(filtered_headers, use_columns_set, sensitive_zeros_set)
|
||||
/// Returns the number of records on the data block protected by `resolution`
|
||||
pub fn protected_number_of_records(&self, resolution: usize) -> usize {
|
||||
uround_down(self.number_of_records() as f64, resolution as f64)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn map_records(
|
||||
records: &[CsvRecord],
|
||||
use_columns_set: &HashSet<usize>,
|
||||
sensitive_zeros_set: &HashSet<usize>,
|
||||
record_limit: usize,
|
||||
) -> DataBlockRecords {
|
||||
let map_result = |record: &CsvRecord| {
|
||||
let values: CsvRecord = record
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, h)| {
|
||||
if use_columns_set.contains(&i) {
|
||||
Some(h.trim().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
DataBlockRecord::new(
|
||||
values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, r)| {
|
||||
let record_val = r.trim();
|
||||
if !record_val.is_empty()
|
||||
&& (sensitive_zeros_set.contains(&i) || record_val != "0")
|
||||
{
|
||||
Some(DataBlockValue::new(i, record_val.into()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
};
|
||||
|
||||
if record_limit > 0 {
|
||||
records.iter().take(record_limit).map(map_result).collect()
|
||||
/// Normalizes the reporting length based on the number of selected headers.
|
||||
/// Returns the normalized value
|
||||
/// # Arguments
|
||||
/// * `reporting_length` - Reporting length to be normalized (0 means use all columns)
|
||||
pub fn normalize_reporting_length(&self, reporting_length: usize) -> usize {
|
||||
if reporting_length == 0 {
|
||||
self.headers.len()
|
||||
} else {
|
||||
records.iter().map(map_result).collect()
|
||||
usize::min(reporting_length, self.headers.len())
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn create(
|
||||
input_res: Result<Self::InputType, Self::ErrorType>,
|
||||
use_columns: &[String],
|
||||
sensitive_zeros: &[String],
|
||||
record_limit: usize,
|
||||
) -> Result<DataBlock, Self::ErrorType> {
|
||||
let mut input = input_res?;
|
||||
let (headers, use_columns_set, sensitive_zeros_set) = Self::map_headers(
|
||||
&mut Self::get_headers(&mut input)?,
|
||||
use_columns,
|
||||
sensitive_zeros,
|
||||
);
|
||||
let records = Self::map_records(
|
||||
&Self::get_records(&mut input)?,
|
||||
&use_columns_set,
|
||||
&sensitive_zeros_set,
|
||||
record_limit,
|
||||
);
|
||||
|
||||
Ok(DataBlock::new(headers, records))
|
||||
}
|
||||
|
||||
/// Should be implemented to return the CsvRecords reprensenting the headers
|
||||
fn get_headers(input: &mut Self::InputType) -> Result<CsvRecord, Self::ErrorType>;
|
||||
|
||||
/// Should be implemented to return the vector of CsvRecords reprensenting rows
|
||||
fn get_records(input: &mut Self::InputType) -> Result<Vec<CsvRecord>, Self::ErrorType>;
|
||||
}
|
||||
|
|
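DataBlock's headers and records now hold Arc'd values instead of borrowed references, and the row-lookup helpers moved onto the struct itself. A small in-memory sketch of the new surface, assuming only the types in this diff (column names and values are illustrative):

    use sds_core::data_block::{block::DataBlock, record::DataBlockRecord, value::DataBlockValue};
    use std::sync::Arc;

    fn main() {
        // two columns, one record
        let headers = vec![Arc::new("age".to_string()), Arc::new("city".to_string())];
        let records = vec![Arc::new(DataBlockRecord::new(vec![
            Arc::new(DataBlockValue::new(0, Arc::new("30".to_string()))),
            Arc::new(DataBlockValue::new(1, Arc::new("seattle".to_string()))),
        ]))];
        let block = DataBlock::new(headers, records);

        // new helper: column name -> column index
        let column_index_by_name = block.calc_column_index_by_name();
        assert_eq!(column_index_by_name["age"], 0);

        // rows where each attribute value occurs, keyed by the value itself
        let attr_rows = block.calc_attr_rows();
        assert_eq!(attr_rows.len(), 2);
        assert_eq!(block.number_of_records(), 1);
    }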
|
@ -1,4 +1,4 @@
|
|||
use super::{block::DataBlockCreator, typedefs::CsvRecord};
|
||||
use super::{data_block_creator::DataBlockCreator, typedefs::CsvRecord};
|
||||
use csv::{Error, Reader, StringRecord};
|
||||
use std::fs::File;
|
||||
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
use csv::Error;
|
||||
use std::fmt::{Display, Formatter, Result};
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::exceptions::PyIOError;
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
/// Wrapper for a csv::Error, so the from
|
||||
/// trait can be implemented for PyErr
|
||||
pub struct CsvIOError {
|
||||
error: Error,
|
||||
}
|
||||
|
||||
impl CsvIOError {
|
||||
/// Creates a new CsvIOError from a csv::Error
|
||||
/// # Arguments
|
||||
/// * `error` - Related csv::Error
|
||||
pub fn new(error: Error) -> CsvIOError {
|
||||
CsvIOError { error }
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for CsvIOError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
write!(f, "{}", self.error)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
impl From<CsvIOError> for PyErr {
|
||||
fn from(err: CsvIOError) -> PyErr {
|
||||
PyIOError::new_err(err.error.to_string())
|
||||
}
|
||||
}
|
|
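The From<CsvIOError> for PyErr conversion above exists so csv errors can cross the Python binding boundary with the ? operator. A hypothetical pyo3-gated function sketching that use; the function itself is not part of this commit:

    #[cfg(feature = "pyo3")]
    use pyo3::prelude::*;

    use sds_core::data_block::csv_io_error::CsvIOError;

    // hypothetical binding: wrapping the csv::Error lets `?` convert it into a PyErr
    #[cfg(feature = "pyo3")]
    #[pyfunction]
    fn count_csv_rows(path: &str) -> PyResult<usize> {
        let mut reader = csv::Reader::from_path(path).map_err(CsvIOError::new)?;
        Ok(reader.records().count())
    }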
@ -0,0 +1,166 @@
|
|||
use super::{
|
||||
block::DataBlock,
|
||||
record::DataBlockRecord,
|
||||
typedefs::{CsvRecord, CsvRecordRef, CsvRecordRefSlice, CsvRecordSlice, DataBlockRecords},
|
||||
value::DataBlockValue,
|
||||
};
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
/// Trait that needs to be implement to create a data block.
|
||||
/// It already contains the logic to create the data block, so we only
|
||||
/// need to worry about mapping the headers and records from InputType
|
||||
pub trait DataBlockCreator {
|
||||
/// Creator input type, it can be a File Reader, another data structure...
|
||||
type InputType;
|
||||
/// The error type that can be generated when parsing headers/records
|
||||
type ErrorType;
|
||||
|
||||
#[inline]
|
||||
fn gen_use_columns_set(headers: &CsvRecordSlice, use_columns: &[String]) -> HashSet<usize> {
|
||||
let use_columns_str_set: HashSet<String> = use_columns
|
||||
.iter()
|
||||
.map(|c| c.trim().to_lowercase())
|
||||
.collect();
|
||||
headers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, h)| {
|
||||
if use_columns_str_set.is_empty()
|
||||
|| use_columns_str_set.contains(&h.trim().to_lowercase())
|
||||
{
|
||||
Some(i)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn gen_sensitive_zeros_set(
|
||||
filtered_headers: &CsvRecordRefSlice,
|
||||
sensitive_zeros: &[String],
|
||||
) -> HashSet<usize> {
|
||||
let sensitive_zeros_str_set: HashSet<String> = sensitive_zeros
|
||||
.iter()
|
||||
.map(|c| c.trim().to_lowercase())
|
||||
.collect();
|
||||
filtered_headers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, h)| {
|
||||
if sensitive_zeros_str_set.contains(&h.trim().to_lowercase()) {
|
||||
Some(i)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn normalize_value(value: &str) -> String {
|
||||
value
|
||||
.trim()
|
||||
// replace reserved delimiters
|
||||
.replace(";", "<semicolon>")
|
||||
.replace(":", "<colon>")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn map_headers(
|
||||
headers: &mut CsvRecord,
|
||||
use_columns: &[String],
|
||||
sensitive_zeros: &[String],
|
||||
) -> (CsvRecordRef, HashSet<usize>, HashSet<usize>) {
|
||||
let use_columns_set = Self::gen_use_columns_set(headers, use_columns);
|
||||
let filtered_headers: CsvRecordRef = headers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, h)| {
|
||||
if use_columns_set.contains(&i) {
|
||||
Some(Arc::new(Self::normalize_value(h)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let sensitive_zeros_set = Self::gen_sensitive_zeros_set(&filtered_headers, sensitive_zeros);
|
||||
|
||||
(filtered_headers, use_columns_set, sensitive_zeros_set)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn map_records(
|
||||
records: &[CsvRecord],
|
||||
use_columns_set: &HashSet<usize>,
|
||||
sensitive_zeros_set: &HashSet<usize>,
|
||||
record_limit: usize,
|
||||
) -> DataBlockRecords {
|
||||
let map_result = |record: &CsvRecord| {
|
||||
let values: CsvRecord = record
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, h)| {
|
||||
if use_columns_set.contains(&i) {
|
||||
Some(h.trim().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Arc::new(DataBlockRecord::new(
|
||||
values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, r)| {
|
||||
let record_val = Self::normalize_value(r);
|
||||
if !record_val.is_empty()
|
||||
&& (sensitive_zeros_set.contains(&i) || record_val != "0")
|
||||
{
|
||||
Some(Arc::new(DataBlockValue::new(i, Arc::new(record_val))))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
};
|
||||
|
||||
if record_limit > 0 {
|
||||
records.iter().take(record_limit).map(map_result).collect()
|
||||
} else {
|
||||
records.iter().map(map_result).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn create(
|
||||
input_res: Result<Self::InputType, Self::ErrorType>,
|
||||
use_columns: &[String],
|
||||
sensitive_zeros: &[String],
|
||||
record_limit: usize,
|
||||
) -> Result<Arc<DataBlock>, Self::ErrorType> {
|
||||
let mut input = input_res?;
|
||||
let (headers, use_columns_set, sensitive_zeros_set) = Self::map_headers(
|
||||
&mut Self::get_headers(&mut input)?,
|
||||
use_columns,
|
||||
sensitive_zeros,
|
||||
);
|
||||
let records = Self::map_records(
|
||||
&Self::get_records(&mut input)?,
|
||||
&use_columns_set,
|
||||
&sensitive_zeros_set,
|
||||
record_limit,
|
||||
);
|
||||
|
||||
Ok(Arc::new(DataBlock::new(headers, records)))
|
||||
}
|
||||
|
||||
/// Should be implemented to return the CsvRecords representing the headers
|
||||
fn get_headers(input: &mut Self::InputType) -> Result<CsvRecord, Self::ErrorType>;
|
||||
|
||||
/// Should be implemented to return the vector of CsvRecords representing rows
|
||||
fn get_records(input: &mut Self::InputType) -> Result<Vec<CsvRecord>, Self::ErrorType>;
|
||||
}
|
|
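The trait above already supplies create, so an implementor only maps its input into headers and records. A hypothetical in-memory creator (not part of this commit) sketching the minimum required:

    use sds_core::data_block::{data_block_creator::DataBlockCreator, typedefs::CsvRecord};
    use std::convert::Infallible;

    // hypothetical creator whose input is already in memory as (headers, rows)
    pub struct VecDataBlockCreator;

    impl DataBlockCreator for VecDataBlockCreator {
        type InputType = (CsvRecord, Vec<CsvRecord>);
        type ErrorType = Infallible;

        fn get_headers(input: &mut Self::InputType) -> Result<CsvRecord, Self::ErrorType> {
            Ok(input.0.clone())
        }

        fn get_records(input: &mut Self::InputType) -> Result<Vec<CsvRecord>, Self::ErrorType> {
            Ok(input.1.clone())
        }
    }

With that in place, VecDataBlockCreator::create(Ok((headers, rows)), &[], &[], 0) would yield the same Arc<DataBlock> the CSV creator produces.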
@ -1,10 +1,21 @@
|
|||
/// Module defining the structures that represent a data block
|
||||
pub mod block;
|
||||
|
||||
/// Module to create data blocks from CSV files
|
||||
pub mod csv_block_creator;
|
||||
|
||||
/// Defines io errors for handling csv files
|
||||
/// that are easier to bind to other languages
|
||||
pub mod csv_io_error;
|
||||
|
||||
/// Module to create data blocks from different input types (trait definitions)
|
||||
pub mod data_block_creator;
|
||||
|
||||
/// Module defining the structures that represent a data block record
|
||||
pub mod record;
|
||||
|
||||
/// Type definitions related to data blocks
|
||||
pub mod typedefs;
|
||||
|
||||
/// Module defining the structures that represent a data block value
|
||||
pub mod value;
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use super::value::DataBlockValue;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Represents all the values of a given row in a data block
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DataBlockRecord {
|
||||
/// Vector of data block values for a given row indexed by column
|
||||
pub values: Vec<DataBlockValue>,
|
||||
pub values: Vec<Arc<DataBlockValue>>,
|
||||
}
|
||||
|
||||
impl DataBlockRecord {
|
||||
|
@ -12,7 +14,7 @@ impl DataBlockRecord {
|
|||
/// # Arguments
|
||||
/// * `values` - Vector of data block values for a given row indexed by column
|
||||
#[inline]
|
||||
pub fn new(values: Vec<DataBlockValue>) -> DataBlockRecord {
|
||||
pub fn new(values: Vec<Arc<DataBlockValue>>) -> DataBlockRecord {
|
||||
DataBlockRecord { values }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
use super::{record::DataBlockRecord, value::DataBlockValue};
|
||||
use fnv::FnvHashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::processing::evaluator::rare_combinations_comparison_data::CombinationComparison;
|
||||
|
||||
/// Ordered vector of rows where a particular value combination is present
|
||||
pub type AttributeRows = Vec<usize>;
|
||||
|
@ -11,23 +14,38 @@ pub type AttributeRowsSlice = [usize];
|
|||
pub type CsvRecord = Vec<String>;
|
||||
|
||||
/// The same as CsvRecord, but keeping a reference to the string value, so data does not have to be duplicated
|
||||
pub type CsvRecordRef<'data_block_value> = Vec<&'data_block_value str>;
|
||||
pub type CsvRecordRef = Vec<Arc<String>>;
|
||||
|
||||
/// Slice of CsvRecord
|
||||
pub type CsvRecordSlice = [String];
|
||||
|
||||
/// Vector of strings representhing the data block headers
|
||||
pub type DataBlockHeaders = CsvRecord;
|
||||
/// Slice of CsvRecord
|
||||
pub type CsvRecordRefSlice = [Arc<String>];
|
||||
|
||||
/// Vector of strings representing the data block headers
|
||||
pub type DataBlockHeaders = CsvRecordRef;
|
||||
|
||||
/// Slice of DataBlockHeaders
|
||||
pub type DataBlockHeadersSlice = CsvRecordSlice;
|
||||
pub type DataBlockHeadersSlice = CsvRecordRefSlice;
|
||||
|
||||
/// Vector of data block records, where each record represents a row
|
||||
pub type DataBlockRecords = Vec<DataBlockRecord>;
|
||||
|
||||
/// Slice of DataBlockRecords
|
||||
pub type DataBlockRecordsSlice = [DataBlockRecord];
|
||||
pub type DataBlockRecords = Vec<Arc<DataBlockRecord>>;
|
||||
|
||||
/// HashMap with a data block value as key and all the attribute row indexes where it occurs as value
|
||||
pub type AttributeRowsMap<'data_block_value> =
|
||||
FnvHashMap<&'data_block_value DataBlockValue, AttributeRows>;
|
||||
pub type AttributeRowsMap = FnvHashMap<Arc<DataBlockValue>, AttributeRows>;
|
||||
|
||||
/// HashMap with a data block value as key and attribute row indexes as value
|
||||
pub type AttributeRowsRefMap = FnvHashMap<Arc<DataBlockValue>, Arc<AttributeRows>>;
|
||||
|
||||
/// Maps the column index -> data block value -> rows where the value appear
|
||||
pub type AttributeRowsByColumnMap = FnvHashMap<usize, AttributeRowsMap>;
|
||||
|
||||
/// Raw synthesized data (vector of csv record references to the original data block)
|
||||
pub type RawSyntheticData = Vec<CsvRecordRef>;
|
||||
|
||||
/// A vector of combination comparisons
|
||||
/// (between sensitive and synthetic data)
|
||||
pub type CombinationsComparisons = Vec<CombinationComparison>;
|
||||
|
||||
/// Maps a column name to the corresponding column index
|
||||
pub type ColumnIndexByName = FnvHashMap<String, usize>;
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
use super::typedefs::DataBlockHeadersSlice;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Display, str::FromStr, sync::Arc};
|
||||
|
||||
const VALUE_DELIMITER: char = ':';
|
||||
|
||||
/// Represents a value of a given data block for a particular row and column
|
||||
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct DataBlockValue {
|
||||
/// Column index this value belongs to starting in '0'
|
||||
pub column_index: usize,
|
||||
/// Value stored on the CSV file for a given row at `column_index`
|
||||
pub value: String,
|
||||
pub value: Arc<String>,
|
||||
}
|
||||
|
||||
impl DataBlockValue {
|
||||
|
@ -13,10 +19,72 @@ impl DataBlockValue {
|
|||
/// * `column_index` - Column index this value belongs to starting in '0'
|
||||
/// * `value` - Value stored on the CSV file for a given row at `column_index`
|
||||
#[inline]
|
||||
pub fn new(column_index: usize, value: String) -> DataBlockValue {
|
||||
pub fn new(column_index: usize, value: Arc<String>) -> DataBlockValue {
|
||||
DataBlockValue {
|
||||
column_index,
|
||||
value,
|
||||
}
|
||||
}
|
||||
|
||||
/// Formats a data block value as String using the
|
||||
/// corresponding header name
|
||||
/// The result is formatted as: `{header_name}:{block_value}`
|
||||
/// # Arguments
|
||||
/// * `headers` - data block headers
|
||||
/// * `value` - value to be formatted
|
||||
#[inline]
|
||||
pub fn format_str_using_headers(&self, headers: &DataBlockHeadersSlice) -> String {
|
||||
format!(
|
||||
"{}{}{}",
|
||||
headers[self.column_index], VALUE_DELIMITER, self.value
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Error that can happen when parsing a data block from
|
||||
/// a string
|
||||
pub struct ParseDataBlockValueError {
|
||||
error_message: String,
|
||||
}
|
||||
|
||||
impl ParseDataBlockValueError {
|
||||
#[inline]
|
||||
/// Creates a new ParseDataBlockValueError with `error_message`
|
||||
pub fn new(error_message: String) -> ParseDataBlockValueError {
|
||||
ParseDataBlockValueError { error_message }
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ParseDataBlockValueError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.error_message)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for DataBlockValue {
|
||||
type Err = ParseDataBlockValueError;
|
||||
|
||||
/// Creates a new DataBlockValue by parsing `str_value`
|
||||
fn from_str(str_value: &str) -> Result<Self, Self::Err> {
|
||||
if let Some(pos) = str_value.find(VALUE_DELIMITER) {
|
||||
Ok(DataBlockValue::new(
|
||||
str_value[..pos]
|
||||
.parse::<usize>()
|
||||
.map_err(|err| ParseDataBlockValueError::new(err.to_string()))?,
|
||||
Arc::new(str_value[pos + 1..].into()),
|
||||
))
|
||||
} else {
|
||||
Err(ParseDataBlockValueError::new(format!(
|
||||
"data block value missing '{}'",
|
||||
VALUE_DELIMITER
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for DataBlockValue {
|
||||
/// Formats the DataBlockValue as a string
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
Ok(write!(f, "{}:{}", self.column_index, self.value)?)
|
||||
}
|
||||
}
|
||||
|
|
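The Display and FromStr implementations above make DataBlockValue round-trip through its `{column_index}:{value}` string form. A short sketch (the value contents are illustrative):

    use sds_core::data_block::value::DataBlockValue;
    use std::sync::Arc;

    fn main() {
        let value = DataBlockValue::new(0, Arc::new("a1".to_string()));

        // Display writes `{column_index}:{value}`
        let as_string = value.to_string();
        assert_eq!(as_string, "0:a1");

        // FromStr parses the same format back
        let parsed: DataBlockValue = as_string.parse().unwrap_or_else(|e| panic!("{}", e));
        assert_eq!(parsed, value);
    }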
|
@ -0,0 +1,448 @@
|
|||
use super::{
|
||||
privacy_risk_summary::PrivacyRiskSummary,
|
||||
records_analysis_data::RecordsAnalysisData,
|
||||
typedefs::{
|
||||
AggregatedCountByLenMap, AggregatesCountMap, AggregatesCountStringMap, RecordsByLenMap,
|
||||
RecordsSensitivity,
|
||||
},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use log::info;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
io::{BufReader, BufWriter, Error, Write},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use crate::{
|
||||
data_block::block::DataBlock,
|
||||
processing::aggregator::typedefs::RecordsSet,
|
||||
utils::{math::uround_down, time::ElapsedDurationLogger},
|
||||
};
|
||||
|
||||
/// Aggregated data produced by the Aggregator
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AggregatedData {
|
||||
/// Data block from where this aggregated data was generated
|
||||
pub data_block: Arc<DataBlock>,
|
||||
/// Maps a value combination to its aggregated count
|
||||
pub aggregates_count: AggregatesCountMap,
|
||||
/// A vector of sensitivities for each record (the vector index is the record index)
|
||||
pub records_sensitivity: RecordsSensitivity,
|
||||
/// Maximum length used to compute attribute combinations
|
||||
pub reporting_length: usize,
|
||||
}
|
||||
|
||||
impl AggregatedData {
|
||||
/// Creates a new AggregatedData struct with default values
|
||||
#[inline]
|
||||
pub fn default() -> AggregatedData {
|
||||
AggregatedData {
|
||||
data_block: Arc::new(DataBlock::default()),
|
||||
aggregates_count: AggregatesCountMap::default(),
|
||||
records_sensitivity: RecordsSensitivity::default(),
|
||||
reporting_length: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new AggregatedData struct
|
||||
/// # Arguments:
|
||||
/// * `data_block` - Data block with the original data
|
||||
/// * `aggregates_count` - Computed aggregates count map
|
||||
/// * `records_sensitivity` - Computed sensitivity for the records
|
||||
/// * `reporting_length` - Maximum length used to compute attribute combinations
|
||||
#[inline]
|
||||
pub fn new(
|
||||
data_block: Arc<DataBlock>,
|
||||
aggregates_count: AggregatesCountMap,
|
||||
records_sensitivity: RecordsSensitivity,
|
||||
reporting_length: usize,
|
||||
) -> AggregatedData {
|
||||
AggregatedData {
|
||||
data_block,
|
||||
aggregates_count,
|
||||
records_sensitivity,
|
||||
reporting_length,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Check if the records len map contains a value across all lengths
|
||||
fn records_by_len_contains(records_by_len: &RecordsByLenMap, value: &usize) -> bool {
|
||||
records_by_len
|
||||
.values()
|
||||
.any(|records| records.contains(value))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Whe first generated the RecordsByLenMap might contain
|
||||
/// the same records appearing in different combination lengths.
|
||||
/// This will keep only the record on the shortest length key
|
||||
/// and remove the other occurrences
|
||||
fn keep_records_only_on_shortest_len(records_by_len: &mut RecordsByLenMap) {
|
||||
let lengths: Vec<usize> = records_by_len.keys().cloned().sorted().collect();
|
||||
let max_len = lengths.last().copied().unwrap_or(0);
|
||||
|
||||
// make sure the record will be only present in the shortest len
|
||||
// start on the shortest length
|
||||
for l in lengths {
|
||||
for r in records_by_len.get(&l).unwrap().clone() {
|
||||
// search all lengths > l and remove the record from there
|
||||
for n in l + 1..=max_len {
|
||||
if let Some(records) = records_by_len.get_mut(&n) {
|
||||
records.remove(&r);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// retain only non-empty record lists
|
||||
records_by_len.retain(|_, records| !records.is_empty());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn _read_from_json(file_path: &str) -> Result<AggregatedData, Error> {
|
||||
let _duration_logger = ElapsedDurationLogger::new("aggregated count json read");
|
||||
|
||||
Ok(serde_json::from_reader(BufReader::new(
|
||||
std::fs::File::open(file_path)?,
|
||||
))?)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl AggregatedData {
|
||||
/// Builds a map from value combinations formatted as string to its aggregated count
|
||||
/// This method will clone the data, so its recommended to have its result stored
|
||||
/// in a local variable to avoid it being called multiple times
|
||||
/// # Arguments:
|
||||
/// * `combination_delimiter` - Delimiter used to join combinations
|
||||
pub fn get_formatted_aggregates_count(
|
||||
&self,
|
||||
combination_delimiter: &str,
|
||||
) -> AggregatesCountStringMap {
|
||||
self.aggregates_count
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
(
|
||||
key.format_str_using_headers(&self.data_block.headers, combination_delimiter),
|
||||
value.clone(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
/// A vector of sensitivities for each record (the vector index is the record index)
|
||||
/// This method will clone the data, so its recommended to have its result stored
|
||||
/// in a local variable to avoid it being called multiple times
|
||||
pub fn get_records_sensitivity(&self) -> RecordsSensitivity {
|
||||
self.records_sensitivity.clone()
|
||||
}
|
||||
|
||||
/// Round the aggregated counts down to the nearest multiple of resolution
|
||||
/// # Arguments:
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn protect_aggregates_count(&mut self, resolution: usize) {
|
||||
let _duration_logger = ElapsedDurationLogger::new("aggregates count protect");
|
||||
|
||||
info!(
|
||||
"protecting aggregates counts with resolution {}",
|
||||
resolution
|
||||
);
|
||||
|
||||
for count in self.aggregates_count.values_mut() {
|
||||
count.count = uround_down(count.count as f64, resolution as f64);
|
||||
}
|
||||
// remove 0 counts from response
|
||||
self.aggregates_count.retain(|_, count| count.count > 0);
|
||||
}
|
||||
|
||||
/// Calculates the records that contain rare combinations grouped by length.
|
||||
/// This might contain duplicated records on different lengths if the record
|
||||
/// contains more than one rare combination. Unique combinations are also contained
|
||||
/// in this.
|
||||
/// # Arguments:
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn calc_all_rare_combinations_records_by_len(&self, resolution: usize) -> RecordsByLenMap {
|
||||
let _duration_logger =
|
||||
ElapsedDurationLogger::new("all rare combinations records by len calculation");
|
||||
let mut rare_records_by_len: RecordsByLenMap = RecordsByLenMap::default();
|
||||
|
||||
for (agg, count) in self.aggregates_count.iter() {
|
||||
if count.count < resolution {
|
||||
rare_records_by_len
|
||||
.entry(agg.len())
|
||||
.or_insert_with(RecordsSet::default)
|
||||
.extend(&count.contained_in_records);
|
||||
}
|
||||
}
|
||||
rare_records_by_len
|
||||
}
|
||||
|
||||
/// Calculates the records that contain unique combinations grouped by length.
|
||||
/// This might contain duplicated records on different lengths if the record
|
||||
/// contains more than one unique combination.
|
||||
pub fn calc_all_unique_combinations_records_by_len(&self) -> RecordsByLenMap {
|
||||
let _duration_logger =
|
||||
ElapsedDurationLogger::new("all unique combinations records by len calculation");
|
||||
let mut unique_records_by_len: RecordsByLenMap = RecordsByLenMap::default();
|
||||
|
||||
for (agg, count) in self.aggregates_count.iter() {
|
||||
if count.count == 1 {
|
||||
unique_records_by_len
|
||||
.entry(agg.len())
|
||||
.or_insert_with(RecordsSet::default)
|
||||
.extend(&count.contained_in_records);
|
||||
}
|
||||
}
|
||||
unique_records_by_len
|
||||
}
|
||||
|
||||
/// Calculate the records that contain unique and rare combinations grouped by length.
|
||||
/// A tuple with the `(unique, rare)` is returned.
|
||||
/// Both returned maps are ensured to only contain the records on the shortest length,
|
||||
/// so each record will appear only on the shortest combination length that isolates it
|
||||
/// within a rare group. Also, if the record has a unique combination, it will not
|
||||
/// be present on the rare map, only on the unique one.
|
||||
/// # Arguments:
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn calc_unique_rare_combinations_records_by_len(
|
||||
&self,
|
||||
resolution: usize,
|
||||
) -> (RecordsByLenMap, RecordsByLenMap) {
|
||||
let _duration_logger =
|
||||
ElapsedDurationLogger::new("unique/rare combinations records by len calculation");
|
||||
let mut unique_records_by_len = self.calc_all_unique_combinations_records_by_len();
|
||||
let mut rare_records_by_len = self.calc_all_rare_combinations_records_by_len(resolution);
|
||||
|
||||
AggregatedData::keep_records_only_on_shortest_len(&mut unique_records_by_len);
|
||||
AggregatedData::keep_records_only_on_shortest_len(&mut rare_records_by_len);
|
||||
|
||||
// remove records with unique combinations from the rare map
|
||||
rare_records_by_len.values_mut().for_each(|records| {
|
||||
records.retain(|r| !AggregatedData::records_by_len_contains(&unique_records_by_len, r));
|
||||
});
|
||||
|
||||
(unique_records_by_len, rare_records_by_len)
|
||||
}
|
||||
|
||||
/// Perform the records analysis and returns the data containing
|
||||
/// unique, rare and risky information grouped per length.
|
||||
/// # Arguments:
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
/// * `protect` - Whether or not the counts should be rounded to the
|
||||
/// nearest smallest multiple of resolution
|
||||
pub fn calc_records_analysis_by_len(
|
||||
&self,
|
||||
resolution: usize,
|
||||
protect: bool,
|
||||
) -> RecordsAnalysisData {
|
||||
let _duration_logger = ElapsedDurationLogger::new("records analysis by len");
|
||||
let (unique_records_by_len, rare_records_by_len) =
|
||||
self.calc_unique_rare_combinations_records_by_len(resolution);
|
||||
|
||||
RecordsAnalysisData::from_unique_rare_combinations_records_by_len(
|
||||
&unique_records_by_len,
|
||||
&rare_records_by_len,
|
||||
self.data_block.records.len(),
|
||||
self.reporting_length,
|
||||
resolution,
|
||||
protect,
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculates the number of rare combinations grouped by combination length
|
||||
/// # Arguments:
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn calc_rare_combinations_count_by_len(
|
||||
&self,
|
||||
resolution: usize,
|
||||
) -> AggregatedCountByLenMap {
|
||||
let _duration_logger =
|
||||
ElapsedDurationLogger::new("rare combinations count by len calculation");
|
||||
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
|
||||
|
||||
info!(
|
||||
"calculating rare combinations counts by length with resolution {}",
|
||||
resolution
|
||||
);
|
||||
|
||||
for (agg, count) in self.aggregates_count.iter() {
|
||||
if count.count < resolution {
|
||||
let curr_count = result.entry(agg.len()).or_insert(0);
|
||||
*curr_count += 1;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Calculates the number of combinations grouped by combination length
|
||||
pub fn calc_combinations_count_by_len(&self) -> AggregatedCountByLenMap {
|
||||
let _duration_logger = ElapsedDurationLogger::new("combination count by len calculation");
|
||||
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
|
||||
|
||||
info!("calculating combination counts by length");
|
||||
|
||||
for agg in self.aggregates_count.keys() {
|
||||
let curr_count = result.entry(agg.len()).or_insert(0);
|
||||
*curr_count += 1;
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Calculates the sum of all combination counts grouped by combination length
|
||||
pub fn calc_combinations_sum_by_len(&self) -> AggregatedCountByLenMap {
|
||||
let _duration_logger = ElapsedDurationLogger::new("combinations sum by len calculation");
|
||||
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
|
||||
|
||||
info!("calculating combination counts sums by length");
|
||||
|
||||
for (agg, count) in self.aggregates_count.iter() {
|
||||
let curr_sum = result.entry(agg.len()).or_insert(0);
|
||||
*curr_sum += count.count;
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Calculates the privacy risk related with data block and the generated
|
||||
/// aggregates counts
|
||||
/// # Arguments:
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn calc_privacy_risk(&self, resolution: usize) -> PrivacyRiskSummary {
|
||||
let _duration_logger = ElapsedDurationLogger::new("privacy risk calculation");
|
||||
|
||||
info!("calculating privacy risk...");
|
||||
|
||||
PrivacyRiskSummary::from_aggregates_count(
|
||||
self.data_block.records.len(),
|
||||
&self.aggregates_count,
|
||||
resolution,
|
||||
)
|
||||
}
|
||||
|
||||
/// Writes the aggregates counts to the file system in a csv/tsv like format
|
||||
/// # Arguments:
|
||||
/// * `aggregates_path` - File path to be written
|
||||
/// * `aggregates_delimiter` - Delimiter to use when writing to `aggregates_path`
|
||||
/// * `combination_delimiter` - Delimiter used to join combinations and format then
|
||||
/// as strings
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
/// * `protected` - Whether or not the counts were protected before calling this
|
||||
pub fn write_aggregates_count(
|
||||
&self,
|
||||
aggregates_path: &str,
|
||||
aggregates_delimiter: char,
|
||||
combination_delimiter: &str,
|
||||
resolution: usize,
|
||||
protected: bool,
|
||||
) -> Result<(), Error> {
|
||||
let _duration_logger = ElapsedDurationLogger::new("write aggregates count");
|
||||
|
||||
info!("writing file {}", aggregates_path);
|
||||
|
||||
let mut file = std::fs::File::create(aggregates_path)?;
|
||||
|
||||
file.write_all(
|
||||
format!(
|
||||
"selections{}{}\n",
|
||||
aggregates_delimiter,
|
||||
if protected {
|
||||
"protected_count"
|
||||
} else {
|
||||
"count"
|
||||
}
|
||||
)
|
||||
.as_bytes(),
|
||||
)?;
|
||||
file.write_all(
|
||||
format!(
|
||||
"selections{}{}\n",
|
||||
aggregates_delimiter,
|
||||
uround_down(self.data_block.records.len() as f64, resolution as f64)
|
||||
)
|
||||
.as_bytes(),
|
||||
)?;
|
||||
for aggregate in self.aggregates_count.keys() {
|
||||
file.write_all(
|
||||
format!(
|
||||
"{}{}{}\n",
|
||||
aggregate
|
||||
.format_str_using_headers(&self.data_block.headers, combination_delimiter),
|
||||
aggregates_delimiter,
|
||||
self.aggregates_count[aggregate].count
|
||||
)
|
||||
.as_bytes(),
|
||||
)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Writes the records sensitivity to the file system in a csv/tsv like format
|
||||
/// # Arguments:
|
||||
/// * `records_sensitivity_path` - File path to be written
|
||||
/// * `records_sensitivity_delimiter` - Delimiter to use when writing to `records_sensitivity_path`
|
||||
pub fn write_records_sensitivity(
|
||||
&self,
|
||||
records_sensitivity_path: &str,
|
||||
records_sensitivity_delimiter: char,
|
||||
) -> Result<(), Error> {
|
||||
let _duration_logger = ElapsedDurationLogger::new("write records sensitivity");
|
||||
|
||||
info!("writing file {}", records_sensitivity_path);
|
||||
|
||||
let mut file = std::fs::File::create(records_sensitivity_path)?;
|
||||
|
||||
file.write_all(
|
||||
format!(
|
||||
"record_index{}record_sensitivity\n",
|
||||
records_sensitivity_delimiter
|
||||
)
|
||||
.as_bytes(),
|
||||
)?;
|
||||
for (i, sensitivity) in self.records_sensitivity.iter().enumerate() {
|
||||
file.write_all(
|
||||
format!("{}{}{}\n", i, records_sensitivity_delimiter, sensitivity).as_bytes(),
|
||||
)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Serializes the aggregated data to a json file
|
||||
/// # Arguments:
|
||||
/// * `file_path` - File path to be written
|
||||
pub fn write_to_json(&self, file_path: &str) -> Result<(), Error> {
|
||||
let _duration_logger = ElapsedDurationLogger::new("aggregated count json write");
|
||||
|
||||
Ok(serde_json::to_writer(
|
||||
BufWriter::new(std::fs::File::create(file_path)?),
|
||||
&self,
|
||||
)?)
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
#[staticmethod]
|
||||
/// Deserializes the aggregated data from a json file
|
||||
/// # Arguments:
|
||||
/// * `file_path` - File path to read from
|
||||
pub fn read_from_json(file_path: &str) -> Result<AggregatedData, Error> {
|
||||
AggregatedData::_read_from_json(file_path)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "pyo3"))]
|
||||
/// Deserializes the aggregated data from a json file
|
||||
/// # Arguments:
|
||||
/// * `file_path` - File path to read from
|
||||
pub fn read_from_json(file_path: &str) -> Result<AggregatedData, Error> {
|
||||
AggregatedData::_read_from_json(file_path)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
pub fn register(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<AggregatedData>()?;
|
||||
Ok(())
|
||||
}
|
|
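AggregatedData now derives Serialize/Deserialize, so aggregates computed in one step can be persisted and reloaded by a later step instead of being recomputed. A minimal sketch; the module path and file name are assumptions:

    use sds_core::processing::aggregator::aggregated_data::AggregatedData;
    use std::io::Error;

    fn persist_and_reload(aggregated_data: &AggregatedData) -> Result<AggregatedData, Error> {
        // serde-based JSON persistence added by this commit
        aggregated_data.write_to_json("aggregated_data.json")?;
        AggregatedData::read_from_json("aggregated_data.json")
    }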
@ -0,0 +1,174 @@
|
|||
use super::aggregated_data::AggregatedData;
|
||||
use super::rows_aggregator::RowsAggregator;
|
||||
use super::typedefs::RecordsSet;
|
||||
use itertools::Itertools;
|
||||
use log::info;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::data_block::block::DataBlock;
|
||||
use crate::processing::aggregator::record_attrs_selector::RecordAttrsSelector;
|
||||
use crate::utils::math::calc_percentage;
|
||||
use crate::utils::reporting::ReportProgress;
|
||||
use crate::utils::threading::get_number_of_threads;
|
||||
use crate::utils::time::ElapsedDurationLogger;
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
/// Result of data aggregation for each combination
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AggregatedCount {
|
||||
/// How many times the combination appears on the records
|
||||
pub count: usize,
|
||||
/// Which records this combination is part of
|
||||
pub contained_in_records: RecordsSet,
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl AggregatedCount {
|
||||
/// How many times the combination appears on the records
|
||||
#[getter]
|
||||
fn count(&self) -> usize {
|
||||
self.count
|
||||
}
|
||||
|
||||
/// Which records this combination is part of
|
||||
/// This method will clone the data, so it's recommended to have its result stored
|
||||
/// in a local variable to avoid it being called multiple times
|
||||
fn get_contained_in_records(&self) -> RecordsSet {
|
||||
self.contained_in_records.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AggregatedCount {
|
||||
fn default() -> AggregatedCount {
|
||||
AggregatedCount {
|
||||
count: 0,
|
||||
contained_in_records: RecordsSet::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Processes a data block to produce aggregated data
|
||||
pub struct Aggregator {
|
||||
data_block: Arc<DataBlock>,
|
||||
}
|
||||
|
||||
impl Aggregator {
|
||||
/// Returns a data aggregator for the given data block
|
||||
/// # Arguments
|
||||
/// * `data_block` - The data block to be processed
|
||||
#[inline]
|
||||
pub fn new(data_block: Arc<DataBlock>) -> Aggregator {
|
||||
Aggregator { data_block }
|
||||
}
|
||||
|
||||
/// Aggregates the data block and returns the aggregated data back
|
||||
/// # Arguments
|
||||
/// * `reporting_length` - Calculate combinations from 1 up to `reporting_length`
|
||||
/// * `sensitivity_threshold` - Sensitivity threshold to filter record attributes
|
||||
/// (0 means no suppression)
|
||||
/// * `progress_reporter` - Will be used to report the processing
|
||||
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
|
||||
pub fn aggregate<T>(
|
||||
&mut self,
|
||||
reporting_length: usize,
|
||||
sensitivity_threshold: usize,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) -> AggregatedData
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let _duration_logger = ElapsedDurationLogger::new("data aggregation");
|
||||
let normalized_reporting_length =
|
||||
self.data_block.normalize_reporting_length(reporting_length);
|
||||
let length_range = (1..=normalized_reporting_length).collect::<Vec<usize>>();
|
||||
let total_n_records = self.data_block.records.len();
|
||||
let total_n_records_f64 = total_n_records as f64;
|
||||
|
||||
info!(
|
||||
"aggregating data with reporting length = {}, sensitivity_threshold = {} and {} thread(s)",
|
||||
normalized_reporting_length, sensitivity_threshold, get_number_of_threads()
|
||||
);
|
||||
|
||||
let result = RowsAggregator::aggregate_all(
|
||||
total_n_records,
|
||||
&mut self.build_rows_aggregators(&length_range, sensitivity_threshold),
|
||||
progress_reporter,
|
||||
);
|
||||
|
||||
Aggregator::update_aggregate_progress(
|
||||
progress_reporter,
|
||||
total_n_records,
|
||||
total_n_records_f64,
|
||||
);
|
||||
|
||||
info!(
|
||||
"data aggregated resulting in {} distinct combinations...",
|
||||
result.aggregates_count.len()
|
||||
);
|
||||
info!(
|
||||
"suppression ratio of aggregates is {:.2}%",
|
||||
(1.0 - (result.selected_combs_count as f64 / result.all_combs_count as f64)) * 100.0
|
||||
);
|
||||
|
||||
AggregatedData::new(
|
||||
self.data_block.clone(),
|
||||
result.aggregates_count,
|
||||
result.records_sensitivity,
|
||||
normalized_reporting_length,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn build_rows_aggregators<'length_range>(
|
||||
&self,
|
||||
length_range: &'length_range [usize],
|
||||
sensitivity_threshold: usize,
|
||||
) -> Vec<RowsAggregator<'length_range>> {
|
||||
if self.data_block.records.is_empty() {
|
||||
return Vec::default();
|
||||
}
|
||||
|
||||
let chunk_size = ((self.data_block.records.len() as f64) / (get_number_of_threads() as f64))
|
||||
.ceil() as usize;
|
||||
let mut rows_aggregators: Vec<RowsAggregator> = Vec::default();
|
||||
let attr_rows_map = Arc::new(self.data_block.calc_attr_rows());
|
||||
|
||||
for c in &self
|
||||
.data_block
|
||||
.records
|
||||
.clone()
|
||||
.drain(..)
|
||||
.enumerate()
|
||||
.chunks(chunk_size)
|
||||
{
|
||||
rows_aggregators.push(RowsAggregator::new(
|
||||
self.data_block.clone(),
|
||||
c.collect(),
|
||||
RecordAttrsSelector::new(
|
||||
length_range,
|
||||
sensitivity_threshold,
|
||||
attr_rows_map.clone(),
|
||||
),
|
||||
))
|
||||
}
|
||||
rows_aggregators
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update_aggregate_progress<T>(
|
||||
progress_reporter: &mut Option<T>,
|
||||
n_processed: usize,
|
||||
total: f64,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
if let Some(r) = progress_reporter {
|
||||
r.report(calc_percentage(n_processed as f64, total));
|
||||
}
|
||||
}
|
||||
}
|
|
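For context, a hedged end-to-end sketch of the entry point defined above; building the `Arc<DataBlock>` (e.g. from the CSV reader elsewhere in the crate) and the concrete `ReportProgress` implementation are assumptions not shown in this file:

// Hedged sketch: driving the multi-threaded Aggregator defined above.
fn run_aggregation<T: ReportProgress>(
    data_block: Arc<DataBlock>,
    progress_reporter: &mut Option<T>,
) -> AggregatedData {
    let mut aggregator = Aggregator::new(data_block);
    // combinations up to length 3; sensitivity_threshold = 0 means no suppression
    aggregator.aggregate(3, 0, progress_reporter)
}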
@ -1,501 +1,27 @@
|
|||
/// Module to represent aggregated data and provide
|
||||
/// some methods/utilities for information extracted from it
|
||||
pub mod aggregated_data;
|
||||
|
||||
/// Dataset privacy risk definitions
|
||||
pub mod privacy_risk_summary;
|
||||
|
||||
/// Module that can be used to suppress attributes from records to
|
||||
/// meet a certain sensitivity threshold
|
||||
pub mod record_attrs_selector;
|
||||
|
||||
/// Defines structures related to records analysis (unique, rare and risky
|
||||
/// information)
|
||||
pub mod records_analysis_data;
|
||||
|
||||
/// Type definitions related to the aggregation process
|
||||
pub mod typedefs;
|
||||
|
||||
use self::typedefs::{
|
||||
AggregatedCountByLenMap, AggregatesCountMap, RecordsSensitivity, RecordsSensitivitySlice,
|
||||
RecordsSet, ValueCombinationSlice,
|
||||
};
|
||||
use instant::Duration;
|
||||
use itertools::Itertools;
|
||||
use log::Level::Trace;
|
||||
use log::{info, log_enabled, trace};
|
||||
use std::cmp::min;
|
||||
use std::io::{Error, Write};
|
||||
/// Defines structures to store value combinations generated
|
||||
/// during the aggregate process
|
||||
pub mod value_combination;
|
||||
|
||||
use crate::data_block::block::DataBlock;
|
||||
use crate::data_block::typedefs::DataBlockHeadersSlice;
|
||||
use crate::data_block::value::DataBlockValue;
|
||||
use crate::measure_time;
|
||||
use crate::processing::aggregator::record_attrs_selector::RecordAttrsSelector;
|
||||
use crate::utils::math::{calc_n_combinations_range, calc_percentage, uround_down};
|
||||
use crate::utils::reporting::ReportProgress;
|
||||
use crate::utils::time::ElapsedDuration;
|
||||
mod data_aggregator;
|
||||
|
||||
/// Represents the privacy risk information related to a data block
|
||||
#[derive(Debug)]
|
||||
pub struct PrivacyRiskSummary {
|
||||
/// Total number of records on the data block
|
||||
pub total_number_of_records: usize,
|
||||
/// Total number of combinations aggregated (up to reporting length)
|
||||
pub total_number_of_combinations: usize,
|
||||
/// Number of records with unique combinations
|
||||
pub records_with_unique_combinations_count: usize,
|
||||
/// Number of records with rare combinations (combination count < resolution)
|
||||
pub records_with_rare_combinations_count: usize,
|
||||
/// Number of unique combinations
|
||||
pub unique_combinations_count: usize,
|
||||
/// Number of rare combinations
|
||||
pub rare_combinations_count: usize,
|
||||
/// Proportion of records containing unique combinations
|
||||
pub records_with_unique_combinations_proportion: f64,
|
||||
/// Proportion of records containing rare combinations
|
||||
pub records_with_rare_combinations_proportion: f64,
|
||||
/// Proportion of unique combinations
|
||||
pub unique_combinations_proportion: f64,
|
||||
/// Proportion of rare combinations
|
||||
pub rare_combinations_proportion: f64,
|
||||
}
|
||||
pub use data_aggregator::{AggregatedCount, Aggregator};
|
||||
|
||||
/// Result of data aggregation for each combination
|
||||
#[derive(Debug)]
|
||||
pub struct AggregatedCount {
|
||||
/// How many times the combination appears on the records
|
||||
pub count: usize,
|
||||
/// Which records this combination is part of
|
||||
pub contained_in_records: RecordsSet,
|
||||
}
|
||||
|
||||
impl Default for AggregatedCount {
|
||||
fn default() -> AggregatedCount {
|
||||
AggregatedCount {
|
||||
count: 0,
|
||||
contained_in_records: RecordsSet::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AggregatorDurations {
|
||||
aggregate: Duration,
|
||||
calc_rare_combinations_count_by_len: Duration,
|
||||
calc_combinations_count_by_len: Duration,
|
||||
calc_combinations_sum_by_len: Duration,
|
||||
calc_privacy_risk: Duration,
|
||||
write_aggregates_count: Duration,
|
||||
}
|
||||
|
||||
/// Aggregated data produced by the Aggregator
|
||||
pub struct AggregatedData<'data_block> {
|
||||
/// Maps a value combination to its aggregated count
|
||||
pub aggregates_count: AggregatesCountMap<'data_block>,
|
||||
/// A vector of sensitivities for each record (the vector index is the record index)
|
||||
pub records_sensitivity: RecordsSensitivity,
|
||||
}
|
||||
|
||||
/// Processes a data block to produce aggregated data
|
||||
pub struct Aggregator<'data_block> {
|
||||
data_block: &'data_block DataBlock,
|
||||
durations: AggregatorDurations,
|
||||
}
|
||||
|
||||
impl<'data_block> Aggregator<'data_block> {
|
||||
/// Returns a data aggregator for the given data block
|
||||
/// # Arguments
|
||||
/// * `data_block` - The data block to be processed
|
||||
#[inline]
|
||||
pub fn new(data_block: &'data_block DataBlock) -> Aggregator<'data_block> {
|
||||
Aggregator {
|
||||
data_block,
|
||||
durations: AggregatorDurations {
|
||||
aggregate: Duration::default(),
|
||||
calc_rare_combinations_count_by_len: Duration::default(),
|
||||
calc_combinations_count_by_len: Duration::default(),
|
||||
calc_combinations_sum_by_len: Duration::default(),
|
||||
calc_privacy_risk: Duration::default(),
|
||||
write_aggregates_count: Duration::default(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Formats a data block value as String.
|
||||
/// The result is formatted as: `{header_name}:{block_value}`
|
||||
/// # Arguments
|
||||
/// * `headers`: data block headers
|
||||
/// * `value`: value to be formatted
|
||||
#[inline]
|
||||
pub fn format_data_block_value_str(
|
||||
headers: &DataBlockHeadersSlice,
|
||||
value: &'data_block DataBlockValue,
|
||||
) -> String {
|
||||
format!("{}:{}", headers[value.column_index], value.value)
|
||||
}
|
||||
|
||||
/// Formats a value combination as String.
|
||||
/// The result is formatted as:
|
||||
/// `{header_name}:{block_value};{header_name}:{block_value}...`
|
||||
/// # Arguments
|
||||
/// * `headers`: data block headers
|
||||
/// * `aggregate`: combinations to be formatted
|
||||
#[inline]
|
||||
pub fn format_aggregate_str(
|
||||
headers: &DataBlockHeadersSlice,
|
||||
aggregate: &ValueCombinationSlice<'data_block>,
|
||||
) -> String {
|
||||
let mut str = String::default();
|
||||
for comb in aggregate {
|
||||
if !str.is_empty() {
|
||||
str += ";";
|
||||
}
|
||||
str += Aggregator::format_data_block_value_str(headers, comb).as_str();
|
||||
}
|
||||
str
|
||||
}
|
||||
|
||||
/// Aggregates the data block and returns the aggregated data back
|
||||
/// # Arguments
|
||||
/// * `reporting_length` - Calculate combinations from 1 up to `reporting_length`
|
||||
/// * `sensitivity_threshold` - Sensitivity threshold to filter record attributes
|
||||
/// (0 means no suppression)
|
||||
/// * `progress_reporter` - Will be used to report the processing
|
||||
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
|
||||
pub fn aggregate<T>(
|
||||
&mut self,
|
||||
reporting_length: usize,
|
||||
sensitivity_threshold: usize,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) -> AggregatedData<'data_block>
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut aggregates_count: AggregatesCountMap = AggregatesCountMap::default();
|
||||
let max_len = if reporting_length == 0 {
|
||||
self.data_block.headers.len()
|
||||
} else {
|
||||
min(reporting_length, self.data_block.headers.len())
|
||||
};
|
||||
let length_range = (1..=max_len).collect::<Vec<usize>>();
|
||||
let mut record_attrs_selector = RecordAttrsSelector::new(
|
||||
&length_range,
|
||||
sensitivity_threshold,
|
||||
DataBlock::calc_attr_rows(&self.data_block.records),
|
||||
);
|
||||
let mut selected_combs_count = 0_u64;
|
||||
let mut all_combs_count = 0_u64;
|
||||
let mut records_sensitivity: RecordsSensitivity = RecordsSensitivity::default();
|
||||
let total_n_records = self.data_block.records.len();
|
||||
let total_n_records_f64 = total_n_records as f64;
|
||||
|
||||
records_sensitivity.resize(total_n_records, 0);
|
||||
|
||||
info!(
|
||||
"aggregating data with reporting length = {} and sensitivity_threshold = {}",
|
||||
max_len, sensitivity_threshold
|
||||
);
|
||||
|
||||
for (record_index, record) in self.data_block.records.iter().enumerate() {
|
||||
let selected_attrs = record_attrs_selector.select_from_record(&record.values);
|
||||
|
||||
all_combs_count +=
|
||||
calc_n_combinations_range(record.values.len(), &length_range);
|
||||
|
||||
Aggregator::update_aggregate_progress(
|
||||
progress_reporter,
|
||||
record_index,
|
||||
total_n_records_f64,
|
||||
);
|
||||
|
||||
for l in &length_range {
|
||||
for c in selected_attrs.iter().combinations(*l) {
|
||||
let curr_count = aggregates_count
|
||||
.entry(
|
||||
c.iter()
|
||||
.sorted_by_key(|k| {
|
||||
Aggregator::format_data_block_value_str(
|
||||
&self.data_block.headers,
|
||||
k,
|
||||
)
|
||||
})
|
||||
.cloned()
|
||||
.copied()
|
||||
.collect_vec(),
|
||||
)
|
||||
.or_insert_with(AggregatedCount::default);
|
||||
curr_count.count += 1;
|
||||
curr_count.contained_in_records.insert(record_index);
|
||||
records_sensitivity[record_index] += 1;
|
||||
selected_combs_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
Aggregator::update_aggregate_progress(
|
||||
progress_reporter,
|
||||
total_n_records,
|
||||
total_n_records_f64,
|
||||
);
|
||||
|
||||
info!(
|
||||
"data aggregated resulting in {} distinct combinations...",
|
||||
aggregates_count.len()
|
||||
);
|
||||
info!(
|
||||
"suppression ratio of aggregates is {:.2}%",
|
||||
(1.0 - (selected_combs_count as f64 / all_combs_count as f64)) * 100.0
|
||||
);
|
||||
AggregatedData {
|
||||
aggregates_count,
|
||||
records_sensitivity,
|
||||
}
|
||||
},
|
||||
(self.durations.aggregate)
|
||||
)
|
||||
}
|
||||
|
||||
/// Round the aggregated counts down to the nearest multiple of resolution
|
||||
/// # Arguments:
|
||||
/// * `aggregates_count` - Counts to be rounded in place
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn protect_aggregates_count(
|
||||
aggregates_count: &mut AggregatesCountMap<'data_block>,
|
||||
resolution: usize,
|
||||
) {
|
||||
info!(
|
||||
"protecting aggregates counts with resolution {}",
|
||||
resolution
|
||||
);
|
||||
for count in aggregates_count.values_mut() {
|
||||
count.count = uround_down(count.count as f64, resolution as f64);
|
||||
}
|
||||
// remove 0 counts from response
|
||||
aggregates_count.retain(|_, count| count.count > 0);
|
||||
}
|
||||
|
||||
/// Calculates the number of rare combinations grouped by combination length
|
||||
/// # Arguments:
|
||||
/// * `aggregates_count` - Calculated aggregates count map
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn calc_rare_combinations_count_by_len(
|
||||
&mut self,
|
||||
aggregates_count: &AggregatesCountMap<'data_block>,
|
||||
resolution: usize,
|
||||
) -> AggregatedCountByLenMap {
|
||||
info!(
|
||||
"calculating rare combinations counts by length with resolution {}",
|
||||
resolution
|
||||
);
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
|
||||
|
||||
for (agg, count) in aggregates_count.iter() {
|
||||
if count.count < resolution {
|
||||
let curr_count = result.entry(agg.len()).or_insert(0);
|
||||
*curr_count += 1;
|
||||
}
|
||||
}
|
||||
result
|
||||
},
|
||||
(self.durations.calc_rare_combinations_count_by_len)
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculates the number of combinations grouped by combination length
|
||||
/// # Arguments:
|
||||
/// * `aggregates_count` - Calculated aggregates count map
|
||||
pub fn calc_combinations_count_by_len(
|
||||
&mut self,
|
||||
aggregates_count: &AggregatesCountMap<'data_block>,
|
||||
) -> AggregatedCountByLenMap {
|
||||
info!("calculating combination counts by length");
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
|
||||
|
||||
for agg in aggregates_count.keys() {
|
||||
let curr_count = result.entry(agg.len()).or_insert(0);
|
||||
*curr_count += 1;
|
||||
}
|
||||
result
|
||||
},
|
||||
(self.durations.calc_combinations_count_by_len)
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculates the sum of all combination counts grouped by combination length
|
||||
/// # Arguments:
|
||||
/// * `aggregates_count` - Calculated aggregates count map
|
||||
pub fn calc_combinations_sum_by_len(
|
||||
&mut self,
|
||||
aggregates_count: &AggregatesCountMap<'data_block>,
|
||||
) -> AggregatedCountByLenMap {
|
||||
info!("calculating combination counts sums by length");
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
|
||||
|
||||
for (agg, count) in aggregates_count.iter() {
|
||||
let curr_sum = result.entry(agg.len()).or_insert(0);
|
||||
*curr_sum += count.count;
|
||||
}
|
||||
result
|
||||
},
|
||||
(self.durations.calc_combinations_sum_by_len)
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculates the privacy risk related with data block and the generated
|
||||
/// aggregates counts
|
||||
/// # Arguments:
|
||||
/// * `aggregates_count` - Calculated aggregates count map
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn calc_privacy_risk(
|
||||
&mut self,
|
||||
aggregates_count: &AggregatesCountMap<'data_block>,
|
||||
resolution: usize,
|
||||
) -> PrivacyRiskSummary {
|
||||
info!("calculating privacy risk...");
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut records_with_unique_combinations = RecordsSet::default();
|
||||
let mut records_with_rare_combinations = RecordsSet::default();
|
||||
let mut unique_combinations_count: usize = 0;
|
||||
let mut rare_combinations_count: usize = 0;
|
||||
let total_number_of_records = self.data_block.records.len();
|
||||
let total_number_of_combinations = aggregates_count.len();
|
||||
|
||||
for count in aggregates_count.values() {
|
||||
if count.count == 1 {
|
||||
unique_combinations_count += 1;
|
||||
count.contained_in_records.iter().for_each(|record_index| {
|
||||
records_with_unique_combinations.insert(*record_index);
|
||||
});
|
||||
}
|
||||
if count.count < resolution {
|
||||
rare_combinations_count += 1;
|
||||
count.contained_in_records.iter().for_each(|record_index| {
|
||||
records_with_rare_combinations.insert(*record_index);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
PrivacyRiskSummary {
|
||||
total_number_of_records,
|
||||
total_number_of_combinations,
|
||||
records_with_unique_combinations_count: records_with_unique_combinations.len(),
|
||||
records_with_rare_combinations_count: records_with_rare_combinations.len(),
|
||||
unique_combinations_count,
|
||||
rare_combinations_count,
|
||||
records_with_unique_combinations_proportion: (records_with_unique_combinations
|
||||
.len()
|
||||
as f64)
|
||||
/ (total_number_of_records as f64),
|
||||
records_with_rare_combinations_proportion: (records_with_rare_combinations.len()
|
||||
as f64)
|
||||
/ (total_number_of_records as f64),
|
||||
unique_combinations_proportion: (unique_combinations_count as f64)
|
||||
/ (total_number_of_combinations as f64),
|
||||
rare_combinations_proportion: (rare_combinations_count as f64)
|
||||
/ (total_number_of_combinations as f64),
|
||||
}
|
||||
},
|
||||
(self.durations.calc_privacy_risk)
|
||||
)
|
||||
}
|
||||
|
||||
/// Writes the aggregates counts to the file system in a csv/tsv like format
|
||||
/// # Arguments:
|
||||
/// * `aggregates_count` - Calculated aggregates count map
|
||||
/// * `aggregates_path` - File path to be written
|
||||
/// * `aggregates_delimiter` - Delimiter to use when writing to `aggregates_path`
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
/// * `protected` - Whether or not the counts were protected before calling this
|
||||
pub fn write_aggregates_count(
|
||||
&mut self,
|
||||
aggregates_count: &AggregatesCountMap<'data_block>,
|
||||
aggregates_path: &str,
|
||||
aggregates_delimiter: char,
|
||||
resolution: usize,
|
||||
protected: bool,
|
||||
) -> Result<(), Error> {
|
||||
info!("writing file {}", aggregates_path);
|
||||
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut file = std::fs::File::create(aggregates_path)?;
|
||||
|
||||
file.write_all(
|
||||
format!(
|
||||
"selections{}{}\n",
|
||||
aggregates_delimiter,
|
||||
if protected {
|
||||
"protected_count"
|
||||
} else {
|
||||
"count"
|
||||
}
|
||||
)
|
||||
.as_bytes(),
|
||||
)?;
|
||||
file.write_all(
|
||||
format!(
|
||||
"selections{}{}\n",
|
||||
aggregates_delimiter,
|
||||
uround_down(self.data_block.records.len() as f64, resolution as f64)
|
||||
)
|
||||
.as_bytes(),
|
||||
)?;
|
||||
for aggregate in aggregates_count.keys() {
|
||||
file.write_all(
|
||||
format!(
|
||||
"{}{}{}\n",
|
||||
Aggregator::format_aggregate_str(&self.data_block.headers, aggregate),
|
||||
aggregates_delimiter,
|
||||
aggregates_count[aggregate].count
|
||||
)
|
||||
.as_bytes(),
|
||||
)?
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
(self.durations.write_aggregates_count)
|
||||
)
|
||||
}
|
||||
|
||||
/// Writes the records sensitivity to the file system in a csv/tsv like format
|
||||
/// # Arguments:
|
||||
/// * `records_sensitivity` - Calculated records sensitivity
|
||||
/// * `records_sensitivity_path` - File path to be written
|
||||
pub fn write_records_sensitivity(
|
||||
&mut self,
|
||||
records_sensitivity: &RecordsSensitivitySlice,
|
||||
records_sensitivity_path: &str,
|
||||
) -> Result<(), Error> {
|
||||
info!("writing file {}", records_sensitivity_path);
|
||||
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut file = std::fs::File::create(records_sensitivity_path)?;
|
||||
|
||||
file.write_all("record_index\trecord_sensitivity\n".as_bytes())?;
|
||||
for (i, sensitivity) in records_sensitivity.iter().enumerate() {
|
||||
file.write_all(format!("{}\t{}\n", i, sensitivity).as_bytes())?
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
(self.durations.write_aggregates_count)
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update_aggregate_progress<T>(
|
||||
progress_reporter: &mut Option<T>,
|
||||
n_processed: usize,
|
||||
total: f64,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
if let Some(r) = progress_reporter {
|
||||
r.report(calc_percentage(n_processed, total));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'data_block> Drop for Aggregator<'data_block> {
|
||||
fn drop(&mut self) {
|
||||
trace!("aggregator durations: {:#?}", self.durations);
|
||||
}
|
||||
}
|
||||
mod rows_aggregator;
|
||||
|
|
|
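The protection step above rounds each count down to the nearest multiple of the resolution and then drops entries that round to zero; a small hedged illustration of `uround_down` (values made up):

// Hedged illustration of the rounding used by protect_aggregates_count.
fn rounding_example() {
    assert_eq!(uround_down(27.0, 10.0), 20); // 27 -> 20 with resolution 10
    assert_eq!(uround_down(10.0, 10.0), 10); // exact multiples are kept
    assert_eq!(uround_down(9.0, 10.0), 0);   // rounds to 0, entry then removed from the map
}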
@ -0,0 +1,144 @@
|
|||
use super::typedefs::{AggregatesCountMap, RecordsSet};
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
/// Represents the privacy risk information related to a data block
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
#[derive(Debug)]
|
||||
pub struct PrivacyRiskSummary {
|
||||
/// Total number of records on the data block
|
||||
pub total_number_of_records: usize,
|
||||
/// Total number of combinations aggregated (up to reporting length)
|
||||
pub total_number_of_combinations: usize,
|
||||
/// Number of records with unique combinations
|
||||
pub records_with_unique_combinations_count: usize,
|
||||
/// Number of records with rare combinations (combination count < resolution)
|
||||
pub records_with_rare_combinations_count: usize,
|
||||
/// Number of unique combinations
|
||||
pub unique_combinations_count: usize,
|
||||
/// Number of rare combinations
|
||||
pub rare_combinations_count: usize,
|
||||
/// Proportion of records containing unique combinations
|
||||
pub records_with_unique_combinations_proportion: f64,
|
||||
/// Proportion of records containing rare combinations
|
||||
pub records_with_rare_combinations_proportion: f64,
|
||||
/// Proportion of unique combinations
|
||||
pub unique_combinations_proportion: f64,
|
||||
/// Proportion of rare combinations
|
||||
pub rare_combinations_proportion: f64,
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl PrivacyRiskSummary {
|
||||
/// Total number of records on the data block
|
||||
#[getter]
|
||||
fn total_number_of_records(&self) -> usize {
|
||||
self.total_number_of_records
|
||||
}
|
||||
|
||||
/// Total number of combinations aggregated (up to reporting length)
|
||||
#[getter]
|
||||
fn total_number_of_combinations(&self) -> usize {
|
||||
self.total_number_of_combinations
|
||||
}
|
||||
|
||||
/// Number of records with unique combinations
|
||||
#[getter]
|
||||
fn records_with_unique_combinations_count(&self) -> usize {
|
||||
self.records_with_unique_combinations_count
|
||||
}
|
||||
|
||||
/// Number of records with rare combinations (combination count < resolution)
|
||||
#[getter]
|
||||
fn records_with_rare_combinations_count(&self) -> usize {
|
||||
self.records_with_rare_combinations_count
|
||||
}
|
||||
|
||||
/// Number of unique combinations
|
||||
#[getter]
|
||||
fn unique_combinations_count(&self) -> usize {
|
||||
self.unique_combinations_count
|
||||
}
|
||||
|
||||
/// Number of rare combinations
|
||||
#[getter]
|
||||
fn rare_combinations_count(&self) -> usize {
|
||||
self.rare_combinations_count
|
||||
}
|
||||
|
||||
/// Proportion of records containing unique combinations
|
||||
#[getter]
|
||||
fn records_with_unique_combinations_proportion(&self) -> f64 {
|
||||
self.records_with_unique_combinations_proportion
|
||||
}
|
||||
|
||||
/// Proportion of records containing rare combinations
|
||||
#[getter]
|
||||
fn records_with_rare_combinations_proportion(&self) -> f64 {
|
||||
self.records_with_rare_combinations_proportion
|
||||
}
|
||||
|
||||
/// Proportion of unique combinations
|
||||
#[getter]
|
||||
fn unique_combinations_proportion(&self) -> f64 {
|
||||
self.unique_combinations_proportion
|
||||
}
|
||||
|
||||
/// Proportion of rare combinations
|
||||
#[getter]
|
||||
fn rare_combinations_proportion(&self) -> f64 {
|
||||
self.rare_combinations_proportion
|
||||
}
|
||||
}
|
||||
|
||||
impl PrivacyRiskSummary {
|
||||
#[inline]
|
||||
/// Calculates and returns the privacy risk related to the aggregates counts
|
||||
/// # Arguments:
|
||||
/// * `total_number_of_records` - Total number of records on the data block
|
||||
/// * `aggregates_count` - Aggregates counts to compute the privacy risk for
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn from_aggregates_count(
|
||||
total_number_of_records: usize,
|
||||
aggregates_count: &AggregatesCountMap,
|
||||
resolution: usize,
|
||||
) -> PrivacyRiskSummary {
|
||||
let mut records_with_unique_combinations = RecordsSet::default();
|
||||
let mut records_with_rare_combinations = RecordsSet::default();
|
||||
let mut unique_combinations_count: usize = 0;
|
||||
let mut rare_combinations_count: usize = 0;
|
||||
let total_number_of_combinations = aggregates_count.len();
|
||||
|
||||
for count in aggregates_count.values() {
|
||||
if count.count == 1 {
|
||||
unique_combinations_count += 1;
|
||||
records_with_unique_combinations.extend(&count.contained_in_records);
|
||||
}
|
||||
if count.count < resolution {
|
||||
rare_combinations_count += 1;
|
||||
records_with_rare_combinations.extend(&count.contained_in_records);
|
||||
}
|
||||
}
|
||||
|
||||
PrivacyRiskSummary {
|
||||
total_number_of_records,
|
||||
total_number_of_combinations,
|
||||
records_with_unique_combinations_count: records_with_unique_combinations.len(),
|
||||
records_with_rare_combinations_count: records_with_rare_combinations.len(),
|
||||
unique_combinations_count,
|
||||
rare_combinations_count,
|
||||
records_with_unique_combinations_proportion: (records_with_unique_combinations.len()
|
||||
as f64)
|
||||
/ (total_number_of_records as f64),
|
||||
records_with_rare_combinations_proportion: (records_with_rare_combinations.len()
|
||||
as f64)
|
||||
/ (total_number_of_records as f64),
|
||||
unique_combinations_proportion: (unique_combinations_count as f64)
|
||||
/ (total_number_of_combinations as f64),
|
||||
rare_combinations_proportion: (rare_combinations_count as f64)
|
||||
/ (total_number_of_combinations as f64),
|
||||
}
|
||||
}
|
||||
}
|
|
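A hedged sketch of producing the summary defined above; the inputs are assumed to come from the aggregation step:

// Hedged sketch: computing the privacy risk summary from aggregated counts.
fn summarize_privacy_risk(
    total_number_of_records: usize,
    aggregates_count: &AggregatesCountMap,
    resolution: usize,
) -> PrivacyRiskSummary {
    // combinations with count == 1 are unique; count < resolution are rare
    PrivacyRiskSummary::from_aggregates_count(total_number_of_records, aggregates_count, resolution)
}

The proportions on the returned struct are already normalized by the totals, so callers only need to scale by 100 to report percentages.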
@ -1,5 +1,5 @@
|
|||
use fnv::FnvHashMap;
|
||||
use itertools::Itertools;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
data_block::{typedefs::AttributeRowsMap, value::DataBlockValue},
|
||||
|
@ -13,19 +13,19 @@ use crate::{
|
|||
/// Sensitivity in this context means the number of combinations that
|
||||
/// can be generated based on the number of non-empty values on the record.
|
||||
/// So attributes will be suppressed from the record until record sensitivity <= `sensitivity_threshold`
|
||||
pub struct RecordAttrsSelector<'length_range, 'data_block> {
|
||||
pub struct RecordAttrsSelector<'length_range> {
|
||||
/// Range to compute the number of combinations [1...reporting_length]
|
||||
pub length_range: &'length_range [usize],
|
||||
/// Cache for how many values should be suppressed based on the number
|
||||
/// of non-empty attributes on the record
|
||||
cache: FnvHashMap<usize, usize>,
|
||||
/// Range to compute the number of combinations [1...reporting_length]
|
||||
length_range: &'length_range [usize],
|
||||
/// Maximum sensitivity allowed for each record
|
||||
sensitivity_threshold: usize,
|
||||
/// Map with a data block value as key and all the attribute row indexes where it occurs as value
|
||||
attr_rows_map: AttributeRowsMap<'data_block>,
|
||||
attr_rows_map: Arc<AttributeRowsMap>,
|
||||
}
|
||||
|
||||
impl<'length_range, 'data_block> RecordAttrsSelector<'length_range, 'data_block> {
|
||||
impl<'length_range> RecordAttrsSelector<'length_range> {
|
||||
/// Returns a new RecordAttrsSelector
|
||||
/// # Arguments
|
||||
/// * `length_range` - Range to compute the number of combinations [1...reporting_length]
|
||||
|
@ -36,8 +36,8 @@ impl<'length_range, 'data_block> RecordAttrsSelector<'length_range, 'data_block>
|
|||
pub fn new(
|
||||
length_range: &'length_range [usize],
|
||||
sensitivity_threshold: usize,
|
||||
attr_rows_map: AttributeRowsMap<'data_block>,
|
||||
) -> RecordAttrsSelector<'length_range, 'data_block> {
|
||||
attr_rows_map: Arc<AttributeRowsMap>,
|
||||
) -> RecordAttrsSelector<'length_range> {
|
||||
RecordAttrsSelector {
|
||||
cache: FnvHashMap::default(),
|
||||
length_range,
|
||||
|
@ -53,30 +53,30 @@ impl<'length_range, 'data_block> RecordAttrsSelector<'length_range, 'data_block>
|
|||
/// all records, the higher the chance for it **not** to be suppressed
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `record`: record to select attributes from
|
||||
/// * `record` - record to select attributes from
|
||||
pub fn select_from_record(
|
||||
&mut self,
|
||||
record: &'data_block [DataBlockValue],
|
||||
) -> Vec<&'data_block DataBlockValue> {
|
||||
record: &[Arc<DataBlockValue>],
|
||||
) -> Vec<Arc<DataBlockValue>> {
|
||||
let n_attributes = record.len();
|
||||
let selected_attrs_count = n_attributes - self.get_suppressed_count(n_attributes);
|
||||
|
||||
if selected_attrs_count < n_attributes {
|
||||
let mut attr_count_map: AttributeCountMap = record
|
||||
.iter()
|
||||
.map(|r| (r, self.attr_rows_map.get(r).unwrap().len()))
|
||||
.map(|r| (r.clone(), self.attr_rows_map.get(r).unwrap().len()))
|
||||
.collect();
|
||||
let mut res: Vec<&DataBlockValue> = Vec::default();
|
||||
let mut res: Vec<Arc<DataBlockValue>> = Vec::default();
|
||||
|
||||
for _ in 0..selected_attrs_count {
|
||||
if let Some(sample) = SynthesizerContext::sample_from_attr_counts(&attr_count_map) {
|
||||
attr_count_map.remove(&sample);
|
||||
res.push(sample);
|
||||
attr_count_map.remove(sample);
|
||||
}
|
||||
}
|
||||
res
|
||||
} else {
|
||||
record.iter().collect_vec()
|
||||
record.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
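A hedged sketch of the selector above applied to one record; `length_range`, `attr_rows_map` and the record are assumed to be produced as in `Aggregator::build_rows_aggregators`, and the threshold value is illustrative:

// Hedged sketch: suppressing attributes from a record until its sensitivity fits the threshold.
fn select_example(
    length_range: &[usize],
    attr_rows_map: Arc<AttributeRowsMap>,
    record: &[Arc<DataBlockValue>],
) -> Vec<Arc<DataBlockValue>> {
    // allow a record sensitivity of at most 10 (illustrative value)
    let mut selector = RecordAttrsSelector::new(length_range, 10, attr_rows_map);
    selector.select_from_record(record)
}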
@ -0,0 +1,247 @@
|
|||
use super::typedefs::{RecordsAnalysisByLenMap, RecordsByLenMap};
|
||||
use itertools::Itertools;
|
||||
use std::io::{Error, Write};
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use crate::utils::math::{calc_percentage, uround_down};
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
#[derive(Clone)]
|
||||
/// Analysis information for the records at a given combination length
|
||||
pub struct RecordsAnalysis {
|
||||
/// Number of records containing unique combinations
|
||||
pub unique_combinations_records_count: usize,
|
||||
/// Percentage of records containing unique combinations
|
||||
pub unique_combinations_records_percentage: f64,
|
||||
/// Number of records containing rare combinations
|
||||
/// (unique is not taken into account)
|
||||
pub rare_combinations_records_count: usize,
|
||||
/// Percentage of records containing rare combinations
|
||||
/// (unique is not taken into account)
|
||||
pub rare_combinations_records_percentage: f64,
|
||||
/// Count of unique + rare
|
||||
pub risky_combinations_records_count: usize,
|
||||
/// Percentage of unique + rare
|
||||
pub risky_combinations_records_percentage: f64,
|
||||
}
|
||||
|
||||
impl RecordsAnalysis {
|
||||
#[inline]
|
||||
/// Creates a new RecordsAnalysis with default values
|
||||
pub fn default() -> RecordsAnalysis {
|
||||
RecordsAnalysis {
|
||||
unique_combinations_records_count: 0,
|
||||
unique_combinations_records_percentage: 0.0,
|
||||
rare_combinations_records_count: 0,
|
||||
rare_combinations_records_percentage: 0.0,
|
||||
risky_combinations_records_count: 0,
|
||||
risky_combinations_records_percentage: 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl RecordsAnalysis {
|
||||
#[getter]
|
||||
/// Number of records containing unique combinations
|
||||
fn unique_combinations_records_count(&self) -> usize {
|
||||
self.unique_combinations_records_count
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Percentage of records containing unique combinations
|
||||
fn unique_combinations_records_percentage(&self) -> f64 {
|
||||
self.unique_combinations_records_percentage
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Number of records containing rare combinations
|
||||
/// (unique is not taken into account)
|
||||
fn rare_combinations_records_count(&self) -> usize {
|
||||
self.rare_combinations_records_count
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Percentage of records containing rare combinations
|
||||
/// (unique is not taken into account)
|
||||
fn rare_combinations_records_percentage(&self) -> f64 {
|
||||
self.rare_combinations_records_percentage
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Count of unique + rare
|
||||
fn risky_combinations_records_count(&self) -> usize {
|
||||
self.risky_combinations_records_count
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Percentage of unique + rare
|
||||
fn risky_combinations_records_percentage(&self) -> f64 {
|
||||
self.risky_combinations_records_percentage
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
/// Stores the records analysis for all records grouped by
|
||||
/// combination length
|
||||
pub struct RecordsAnalysisData {
|
||||
/// Map of records analysis grouped by combination len
|
||||
pub records_analysis_by_len: RecordsAnalysisByLenMap,
|
||||
}
|
||||
|
||||
impl RecordsAnalysisData {
|
||||
/// Computes the record analysis from the arguments.
|
||||
/// # Arguments
|
||||
/// * `unique_records_by_len` - unique records grouped by length
|
||||
/// * `rare_records_by_len` - rare records grouped by length
|
||||
/// * `total_n_records` - Total number of records on the data block
|
||||
/// * `reporting_length` - Reporting length used for the data aggregation
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
/// * `protect` - Whether or not the counts should be rounded to the
|
||||
/// nearest smallest multiple of resolution
|
||||
#[inline]
|
||||
pub fn from_unique_rare_combinations_records_by_len(
|
||||
unique_records_by_len: &RecordsByLenMap,
|
||||
rare_records_by_len: &RecordsByLenMap,
|
||||
total_n_records: usize,
|
||||
reporting_length: usize,
|
||||
resolution: usize,
|
||||
protect: bool,
|
||||
) -> RecordsAnalysisData {
|
||||
let total_n_records_f64 = if protect {
|
||||
uround_down(total_n_records as f64, resolution as f64)
|
||||
} else {
|
||||
total_n_records
|
||||
} as f64;
|
||||
let records_analysis_by_len: RecordsAnalysisByLenMap = (1..=reporting_length)
|
||||
.map(|l| {
|
||||
let mut ra = RecordsAnalysis::default();
|
||||
|
||||
ra.unique_combinations_records_count = unique_records_by_len
|
||||
.get(&l)
|
||||
.map_or(0, |records| records.len());
|
||||
ra.rare_combinations_records_count = rare_records_by_len
|
||||
.get(&l)
|
||||
.map_or(0, |records| records.len());
|
||||
|
||||
if protect {
|
||||
ra.unique_combinations_records_count = uround_down(
|
||||
ra.unique_combinations_records_count as f64,
|
||||
resolution as f64,
|
||||
);
|
||||
ra.rare_combinations_records_count =
|
||||
uround_down(ra.rare_combinations_records_count as f64, resolution as f64)
|
||||
}
|
||||
|
||||
ra.unique_combinations_records_percentage = calc_percentage(
|
||||
ra.unique_combinations_records_count as f64,
|
||||
total_n_records_f64,
|
||||
);
|
||||
ra.rare_combinations_records_percentage = calc_percentage(
|
||||
ra.rare_combinations_records_count as f64,
|
||||
total_n_records_f64,
|
||||
);
|
||||
ra.risky_combinations_records_count =
|
||||
ra.unique_combinations_records_count + ra.rare_combinations_records_count;
|
||||
ra.risky_combinations_records_percentage = calc_percentage(
|
||||
ra.risky_combinations_records_count as f64,
|
||||
total_n_records_f64,
|
||||
);
|
||||
|
||||
(l, ra)
|
||||
})
|
||||
.collect();
|
||||
RecordsAnalysisData {
|
||||
records_analysis_by_len,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl RecordsAnalysisData {
|
||||
/// Returns the records analysis map grouped by length.
|
||||
/// This method will clone the data, so it's recommended to have its result stored
|
||||
/// in a local variable to avoid it being called multiple times
|
||||
pub fn get_records_analysis_by_len(&self) -> RecordsAnalysisByLenMap {
|
||||
self.records_analysis_by_len.clone()
|
||||
}
|
||||
|
||||
/// Returns the total number of records containing unique combinations
|
||||
/// for all lengths
|
||||
pub fn get_total_unique(&self) -> usize {
|
||||
self.records_analysis_by_len
|
||||
.values()
|
||||
.map(|v| v.unique_combinations_records_count)
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Returns the total number of records containing rare combinations
|
||||
/// for all lengths
|
||||
pub fn get_total_rare(&self) -> usize {
|
||||
self.records_analysis_by_len
|
||||
.values()
|
||||
.map(|v| v.rare_combinations_records_count)
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Returns the total number of records containing risky (unique + rare)
|
||||
/// combinations for all lengths
|
||||
pub fn get_total_risky(&self) -> usize {
|
||||
self.records_analysis_by_len
|
||||
.values()
|
||||
.map(|v| v.risky_combinations_records_count)
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Writes the records analysis to the file system in a csv/tsv like format
|
||||
/// # Arguments:
|
||||
/// * `records_analysis_path` - File path to be written
|
||||
/// * `records_analysis_delimiter` - Delimiter to use when writing to `records_analysis_path`
|
||||
pub fn write_records_analysis(
|
||||
&mut self,
|
||||
records_analysis_path: &str,
|
||||
records_analysis_delimiter: char,
|
||||
) -> Result<(), Error> {
|
||||
let mut file = std::fs::File::create(records_analysis_path)?;
|
||||
|
||||
file.write_all(
|
||||
format!(
|
||||
"combo_length{}sen_rare{}sen_rare_pct{}sen_unique{}sen_unique_pct{}sen_risky{}sen_risky_pct\n",
|
||||
records_analysis_delimiter,
|
||||
records_analysis_delimiter,
|
||||
records_analysis_delimiter,
|
||||
records_analysis_delimiter,
|
||||
records_analysis_delimiter,
|
||||
records_analysis_delimiter
|
||||
)
|
||||
.as_bytes(),
|
||||
)?;
|
||||
for l in self.records_analysis_by_len.keys().sorted() {
|
||||
let ra = &self.records_analysis_by_len[l];
|
||||
|
||||
file.write_all(
|
||||
format!(
|
||||
"{}{}{}{}{}{}{}{}{}{}{}{}{}\n",
|
||||
l,
|
||||
records_analysis_delimiter,
|
||||
ra.rare_combinations_records_count,
|
||||
records_analysis_delimiter,
|
||||
ra.rare_combinations_records_percentage,
|
||||
records_analysis_delimiter,
|
||||
ra.unique_combinations_records_count,
|
||||
records_analysis_delimiter,
|
||||
ra.unique_combinations_records_percentage,
|
||||
records_analysis_delimiter,
|
||||
ra.risky_combinations_records_count,
|
||||
records_analysis_delimiter,
|
||||
ra.risky_combinations_records_percentage,
|
||||
)
|
||||
.as_bytes(),
|
||||
)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
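A hedged sketch of building and persisting the per-length analysis above; the unique/rare maps are assumed to come from the evaluator step, and the output path is hypothetical:

// Hedged sketch: computing the records analysis and writing it as a TSV file.
fn analyze_and_write(
    unique_records_by_len: &RecordsByLenMap,
    rare_records_by_len: &RecordsByLenMap,
    total_n_records: usize,
    reporting_length: usize,
    resolution: usize,
) -> Result<(), std::io::Error> {
    let mut analysis = RecordsAnalysisData::from_unique_rare_combinations_records_by_len(
        unique_records_by_len,
        rare_records_by_len,
        total_n_records,
        reporting_length,
        resolution,
        true, // protect: round counts down to the nearest multiple of resolution
    );
    let _total_risky = analysis.get_total_risky(); // unique + rare across all lengths
    analysis.write_records_analysis("records_analysis.tsv", '\t')
}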
@ -0,0 +1,195 @@
|
|||
use super::{
|
||||
record_attrs_selector::RecordAttrsSelector,
|
||||
typedefs::{AggregatesCountMap, EnumeratedDataBlockRecords, RecordsSensitivity},
|
||||
value_combination::ValueCombination,
|
||||
AggregatedCount,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use log::info;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "rayon")]
|
||||
use rayon::prelude::*;
|
||||
|
||||
#[cfg(feature = "rayon")]
|
||||
use std::sync::Mutex;
|
||||
|
||||
use crate::{
|
||||
data_block::block::DataBlock,
|
||||
utils::{
|
||||
math::calc_n_combinations_range,
|
||||
reporting::{ReportProgress, SendableProgressReporter, SendableProgressReporterRef},
|
||||
},
|
||||
};
|
||||
|
||||
pub struct RowsAggregatorResult {
|
||||
pub all_combs_count: u64,
|
||||
pub selected_combs_count: u64,
|
||||
pub aggregates_count: AggregatesCountMap,
|
||||
pub records_sensitivity: RecordsSensitivity,
|
||||
}
|
||||
|
||||
impl RowsAggregatorResult {
|
||||
#[inline]
|
||||
pub fn new(total_n_records: usize) -> RowsAggregatorResult {
|
||||
let mut records_sensitivity = RecordsSensitivity::default();
|
||||
|
||||
records_sensitivity.resize(total_n_records, 0);
|
||||
RowsAggregatorResult {
|
||||
all_combs_count: 0,
|
||||
selected_combs_count: 0,
|
||||
aggregates_count: AggregatesCountMap::default(),
|
||||
records_sensitivity,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RowsAggregator<'length_range> {
|
||||
data_block: Arc<DataBlock>,
|
||||
enumerated_records: EnumeratedDataBlockRecords,
|
||||
record_attrs_selector: RecordAttrsSelector<'length_range>,
|
||||
}
|
||||
|
||||
impl<'length_range> RowsAggregator<'length_range> {
|
||||
#[inline]
|
||||
pub fn new(
|
||||
data_block: Arc<DataBlock>,
|
||||
enumerated_records: EnumeratedDataBlockRecords,
|
||||
record_attrs_selector: RecordAttrsSelector<'length_range>,
|
||||
) -> RowsAggregator {
|
||||
RowsAggregator {
|
||||
data_block,
|
||||
enumerated_records,
|
||||
record_attrs_selector,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rayon")]
|
||||
#[inline]
|
||||
pub fn aggregate_all<T>(
|
||||
total_n_records: usize,
|
||||
rows_aggregators: &mut Vec<RowsAggregator>,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) -> RowsAggregatorResult
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let sendable_pr =
|
||||
Arc::new(Mutex::new(progress_reporter.as_mut().map(|r| {
|
||||
SendableProgressReporter::new(total_n_records as f64, 1.0, r)
|
||||
})));
|
||||
|
||||
RowsAggregator::join_partial_results(
|
||||
total_n_records,
|
||||
rows_aggregators
|
||||
.par_iter_mut()
|
||||
.map(|ra| ra.aggregate_rows(&mut sendable_pr.clone()))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "rayon"))]
|
||||
#[inline]
|
||||
pub fn aggregate_all<T>(
|
||||
total_n_records: usize,
|
||||
rows_aggregators: &mut Vec<RowsAggregator>,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) -> RowsAggregatorResult
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let mut sendable_pr = progress_reporter
|
||||
.as_mut()
|
||||
.map(|r| SendableProgressReporter::new(total_n_records as f64, 1.0, r));
|
||||
|
||||
RowsAggregator::join_partial_results(
|
||||
total_n_records,
|
||||
rows_aggregators
|
||||
.iter_mut()
|
||||
.map(|ra| ra.aggregate_rows(&mut sendable_pr))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn join_partial_results(
|
||||
total_n_records: usize,
|
||||
mut partial_results: Vec<RowsAggregatorResult>,
|
||||
) -> RowsAggregatorResult {
|
||||
info!("joining aggregated partial results...");
|
||||
|
||||
// take the last element as the initial result
|
||||
// or use the default one if there are no results
|
||||
let mut final_result = partial_results
|
||||
.pop()
|
||||
.unwrap_or_else(|| RowsAggregatorResult::new(total_n_records));
|
||||
|
||||
// use drain instead of fold, so we do not duplicate memory
|
||||
for mut partial_result in partial_results.drain(..) {
|
||||
// join counts
|
||||
final_result.all_combs_count += partial_result.all_combs_count;
|
||||
final_result.selected_combs_count += partial_result.selected_combs_count;
|
||||
|
||||
// join aggregated counts
|
||||
for (comb, value) in partial_result.aggregates_count.drain() {
|
||||
let final_count = final_result
|
||||
.aggregates_count
|
||||
.entry(comb)
|
||||
.or_insert_with(AggregatedCount::default);
|
||||
final_count.count += value.count;
|
||||
final_count
|
||||
.contained_in_records
|
||||
.extend(value.contained_in_records);
|
||||
}
|
||||
|
||||
// join records sensitivity
|
||||
for (i, sensitivity) in partial_result.records_sensitivity.drain(..).enumerate() {
|
||||
final_result.records_sensitivity[i] += sensitivity;
|
||||
}
|
||||
}
|
||||
final_result
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn aggregate_rows<T>(
|
||||
&mut self,
|
||||
progress_reporter: &mut SendableProgressReporterRef<T>,
|
||||
) -> RowsAggregatorResult
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let mut result = RowsAggregatorResult::new(self.data_block.records.len());
|
||||
|
||||
for (record_index, record) in self.enumerated_records.iter() {
|
||||
let mut selected_attrs = self
|
||||
.record_attrs_selector
|
||||
.select_from_record(&record.values);
|
||||
|
||||
// sort the attributes here, so combinations will be already sorted
|
||||
// and we do not need to sort entry by entry in the loop below
|
||||
selected_attrs.sort_by_key(|k| k.format_str_using_headers(&self.data_block.headers));
|
||||
|
||||
result.all_combs_count += calc_n_combinations_range(
|
||||
record.values.len(),
|
||||
self.record_attrs_selector.length_range,
|
||||
);
|
||||
|
||||
for l in self.record_attrs_selector.length_range {
|
||||
for mut c in selected_attrs.iter().combinations(*l) {
|
||||
let current_count = result
|
||||
.aggregates_count
|
||||
.entry(ValueCombination::new(
|
||||
c.drain(..).map(|k| (*k).clone()).collect(),
|
||||
))
|
||||
.or_insert_with(AggregatedCount::default);
|
||||
current_count.count += 1;
|
||||
current_count.contained_in_records.insert(*record_index);
|
||||
result.records_sensitivity[*record_index] += 1;
|
||||
result.selected_combs_count += 1;
|
||||
}
|
||||
}
|
||||
SendableProgressReporter::update_progress(progress_reporter, 1.0);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
|
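A hedged illustration of the work partitioning driven from `Aggregator::build_rows_aggregators`: records are split into ceil(n_records / n_threads) sized chunks, one `RowsAggregator` per chunk, and `aggregate_all` joins the partial results afterwards:

// Hedged illustration of the chunk size used to split records across threads.
fn chunk_size(n_records: usize, n_threads: usize) -> usize {
    ((n_records as f64) / (n_threads as f64)).ceil() as usize
}
// e.g. 10 records on 3 threads -> chunks of 4, 4 and 2 records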
@ -1,26 +1,36 @@
|
|||
use super::AggregatedCount;
|
||||
use super::{
|
||||
data_aggregator::AggregatedCount, records_analysis_data::RecordsAnalysis,
|
||||
value_combination::ValueCombination,
|
||||
};
|
||||
use fnv::{FnvHashMap, FnvHashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::data_block::value::DataBlockValue;
|
||||
use crate::data_block::record::DataBlockRecord;
|
||||
|
||||
/// Set of records where the key is the record index starting at 0
|
||||
pub type RecordsSet = FnvHashSet<usize>;
|
||||
|
||||
/// Vector of data block values representing a value combination (sorted by `{header_name}:{block_value}`)
|
||||
pub type ValueCombination<'data_block_value> = Vec<&'data_block_value DataBlockValue>;
|
||||
|
||||
/// Slice of ValueCombination
|
||||
pub type ValueCombinationSlice<'data_block_value> = [&'data_block_value DataBlockValue];
|
||||
|
||||
/// Maps a value combination to its aggregated count
|
||||
pub type AggregatesCountMap<'data_block_value> =
|
||||
FnvHashMap<ValueCombination<'data_block_value>, AggregatedCount>;
|
||||
pub type AggregatesCountMap = FnvHashMap<ValueCombination, AggregatedCount>;
|
||||
|
||||
/// Maps a value combination represented as a string to its aggregated count
|
||||
pub type AggregatesCountStringMap = FnvHashMap<String, AggregatedCount>;
|
||||
|
||||
/// Maps a length (1,2,3... up to reporting length) to a determined count
|
||||
pub type AggregatedCountByLenMap = FnvHashMap<usize, usize>;
|
||||
|
||||
/// Maps a length (1,2,3... up to reporting length) to a record set
|
||||
pub type RecordsByLenMap = FnvHashMap<usize, RecordsSet>;
|
||||
|
||||
/// A vector of sensitivities for each record (the vector index is the record index)
|
||||
pub type RecordsSensitivity = Vec<usize>;
|
||||
|
||||
/// Slice of RecordsSensitivity
|
||||
pub type RecordsSensitivitySlice = [usize];
|
||||
|
||||
/// Vector of tuples:
|
||||
/// (index of the original record, reference to the original record)
|
||||
pub type EnumeratedDataBlockRecords = Vec<(usize, Arc<DataBlockRecord>)>;
|
||||
|
||||
/// Map of records analysis grouped by combination len
|
||||
pub type RecordsAnalysisByLenMap = FnvHashMap<usize, RecordsAnalysis>;
|
||||
|
|
|
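A hedged sketch showing how the aliases above compose; `resolution` and the populated map are assumed to come from earlier processing steps:

// Hedged sketch: grouping the record indexes of rare combinations by combination length.
fn rare_records_by_len(
    aggregates_count: &AggregatesCountMap,
    resolution: usize,
) -> RecordsByLenMap {
    let mut result = RecordsByLenMap::default();
    for (comb, count) in aggregates_count.iter() {
        if count.count < resolution {
            result
                .entry(comb.len())
                .or_insert_with(RecordsSet::default)
                .extend(&count.contained_in_records);
        }
    }
    result
}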
@ -0,0 +1,152 @@
|
|||
use serde::{
|
||||
de::{self, Visitor},
|
||||
Deserialize, Serialize,
|
||||
};
|
||||
use std::{
|
||||
fmt::Display,
|
||||
marker::PhantomData,
|
||||
ops::{Deref, DerefMut},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::data_block::{
|
||||
typedefs::DataBlockHeadersSlice,
|
||||
value::{DataBlockValue, ParseDataBlockValueError},
|
||||
};
|
||||
|
||||
const COMBINATIONS_DELIMITER: char = ';';
|
||||
|
||||
#[derive(Eq, PartialEq, Hash)]
|
||||
/// Wraps a vector of data block values representing a value
|
||||
/// combination (sorted by `{header_name}:{block_value}`)
|
||||
pub struct ValueCombination {
|
||||
combination: Vec<Arc<DataBlockValue>>,
|
||||
}
|
||||
|
||||
impl ValueCombination {
|
||||
#[inline]
|
||||
/// Creates a new ValueCombination with default values
|
||||
pub fn default() -> ValueCombination {
|
||||
ValueCombination::new(Vec::default())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Creates a new ValueCombination
|
||||
/// # Arguments
|
||||
/// * `combination` - raw vector of value combinations
|
||||
/// sorted by `{header_name}:{block_value}`
|
||||
pub fn new(combination: Vec<Arc<DataBlockValue>>) -> ValueCombination {
|
||||
ValueCombination { combination }
|
||||
}
|
||||
|
||||
/// Formats a value combination as String using the headers.
|
||||
/// The result is formatted as:
|
||||
/// `{header_name}:{block_value};{header_name}:{block_value}...`
|
||||
/// # Arguments
|
||||
/// * `headers` - Data block headers
|
||||
/// * `combination_delimiter` - Delimiter used to join combinations
|
||||
#[inline]
|
||||
pub fn format_str_using_headers(
|
||||
&self,
|
||||
headers: &DataBlockHeadersSlice,
|
||||
combination_delimiter: &str,
|
||||
) -> String {
|
||||
let mut str = String::default();
|
||||
|
||||
if let Some(comb) = self.combination.get(0) {
|
||||
str.push_str(&comb.format_str_using_headers(headers));
|
||||
}
|
||||
for comb in self.combination.iter().skip(1) {
|
||||
str += combination_delimiter;
|
||||
str.push_str(&comb.format_str_using_headers(headers));
|
||||
}
|
||||
str
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ValueCombination {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if let Some(comb) = self.combination.get(0) {
|
||||
write!(f, "{}", comb)?;
|
||||
}
|
||||
for comb in self.combination.iter().skip(1) {
|
||||
write!(f, "{}{}", COMBINATIONS_DELIMITER, comb)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for ValueCombination {
|
||||
type Err = ParseDataBlockValueError;
|
||||
|
||||
/// Creates a new ValueCombination by parsing `str_value`
|
||||
fn from_str(str_value: &str) -> Result<Self, Self::Err> {
|
||||
Ok(ValueCombination::new(
|
||||
str_value
|
||||
.split(COMBINATIONS_DELIMITER)
|
||||
.map(|v| Ok(Arc::new(DataBlockValue::from_str(v)?)))
|
||||
.collect::<Result<Vec<Arc<DataBlockValue>>, Self::Err>>()?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ValueCombination {
|
||||
type Target = Vec<Arc<DataBlockValue>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.combination
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for ValueCombination {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.combination
|
||||
}
|
||||
}
|
||||
|
||||
// Implementing a custom serializer, so this can
|
||||
// be properly written to a json file
|
||||
impl Serialize for ValueCombination {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&format!("{}", self))
|
||||
}
|
||||
}
|
||||
|
||||
struct ValueCombinationVisitor {
|
||||
marker: PhantomData<fn() -> ValueCombination>,
|
||||
}
|
||||
|
||||
impl ValueCombinationVisitor {
|
||||
fn new() -> Self {
|
||||
ValueCombinationVisitor {
|
||||
marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Visitor<'de> for ValueCombinationVisitor {
|
||||
type Value = ValueCombination;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a string representing the value combination")
|
||||
}
|
||||
|
||||
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
|
||||
ValueCombination::from_str(v).map_err(E::custom)
|
||||
}
|
||||
}
|
||||
|
||||
// Implementing a custom deserializer, so this can
|
||||
// be properly read from a json file
|
||||
impl<'de> Deserialize<'de> for ValueCombination {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
deserializer.deserialize_string(ValueCombinationVisitor::new())
|
||||
}
|
||||
}
|
|
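A hedged sketch of rendering a combination with the formatting helper above; `headers` and `comb` are assumed to come from the data block and aggregation steps, and the delimiter matches the one used for the aggregate CSV/TSV output:

// Hedged sketch: formatting a value combination for the aggregate output files.
fn render_combination(comb: &ValueCombination, headers: &DataBlockHeadersSlice) -> String {
    // produces "{header_name}:{block_value};{header_name}:{block_value}..."
    comb.format_str_using_headers(headers, ";")
}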
@ -1,86 +0,0 @@
|
|||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
const INITIAL_BIN: usize = 10;
|
||||
const BIN_RATIO: usize = 2;
|
||||
|
||||
/// Stores the preservation information related to a particular bucket.
|
||||
/// A bucket stores counts within a certain value range.
|
||||
/// For example: (0, 10], (10, 20], (20, 40]...
|
||||
#[derive(Debug)]
|
||||
pub struct PreservationByCountBucket {
|
||||
/// How many elements are stored in the bucket
|
||||
pub size: usize,
|
||||
/// Preservation sum of all elements in the bucket
|
||||
pub preservation_sum: f64,
|
||||
/// Combination length sum of all elements in the bucket
|
||||
pub length_sum: usize,
|
||||
}
|
||||
|
||||
impl PreservationByCountBucket {
|
||||
/// Returns a new PreservationByCountBucket with default values
|
||||
pub fn default() -> PreservationByCountBucket {
|
||||
PreservationByCountBucket {
|
||||
size: 0,
|
||||
preservation_sum: 0.0,
|
||||
length_sum: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a new value to the bucket
|
||||
/// # Arguments
|
||||
/// * `preservation` - Preservation related to the value
|
||||
/// * `length` - Combination length related to the value
|
||||
pub fn add(&mut self, preservation: f64, length: usize) {
|
||||
self.size += 1;
|
||||
self.preservation_sum += preservation;
|
||||
self.length_sum += length;
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores the bins for each bucket. For example:
|
||||
/// [10, 20, 40, 80, 160...]
|
||||
#[derive(Debug)]
|
||||
pub struct PreservationByCountBucketBins {
|
||||
/// Bins where the index of the vector is the bin index,
|
||||
/// and the value is the max value count allowed on the bin
|
||||
bins: Vec<usize>,
|
||||
}
|
||||
|
||||
impl PreservationByCountBucketBins {
|
||||
/// Generates a new PreservationByCountBucketBins with bins up to a `max_val`
|
||||
/// # Arguments
|
||||
/// * `max_val` - Max value allowed on the last bucket (inclusive)
|
||||
pub fn new(max_val: usize) -> PreservationByCountBucketBins {
|
||||
let mut bins: Vec<usize> = vec![INITIAL_BIN];
|
||||
loop {
|
||||
let last = *bins.last().unwrap();
|
||||
if last >= max_val {
|
||||
break;
|
||||
}
|
||||
bins.push(last * BIN_RATIO);
|
||||
}
|
||||
PreservationByCountBucketBins { bins }
|
||||
}
|
||||
|
||||
/// Find the first `bucket_max_val` where `val >= bucket_max_val`
|
||||
/// # Arguments
|
||||
/// * `val` - Count to look in which bucket it will fit
|
||||
pub fn find_bucket_max_val(&self, val: usize) -> usize {
|
||||
// find first element x where x >= val
|
||||
self.bins[self.bins.partition_point(|x| *x < val)]
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PreservationByCountBucketBins {
|
||||
type Target = Vec<usize>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.bins
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for PreservationByCountBucketBins {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.bins
|
||||
}
|
||||
}
|
|
@ -0,0 +1,252 @@
|
|||
use super::preservation_by_count::{PreservationByCountBucketBins, PreservationByCountBuckets};
|
||||
use super::rare_combinations_comparison_data::RareCombinationsComparisonData;
|
||||
use fnv::FnvHashSet;
|
||||
use log::info;
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use crate::processing::aggregator::aggregated_data::AggregatedData;
|
||||
use crate::processing::aggregator::typedefs::AggregatedCountByLenMap;
|
||||
use crate::processing::aggregator::value_combination::ValueCombination;
|
||||
use crate::processing::evaluator::preservation_bucket::PreservationBucket;
|
||||
use crate::processing::evaluator::preservation_by_length::PreservationByLengthBuckets;
|
||||
use crate::utils::time::ElapsedDurationLogger;
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
/// Evaluates aggregated, sensitive and synthesized data
|
||||
pub struct Evaluator {}
|
||||
|
||||
impl Evaluator {
|
||||
/// Returns a new Evaluator
|
||||
pub fn default() -> Evaluator {
|
||||
Evaluator {}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl Evaluator {
|
||||
/// Returns a new Evaluator
|
||||
#[cfg(feature = "pyo3")]
|
||||
#[new]
|
||||
pub fn default_pyo3() -> Evaluator {
|
||||
Evaluator::default()
|
||||
}
|
||||
|
||||
/// Calculates the leakage counts grouped by combination length
|
||||
/// (how many attribute combinations exist on the sensitive data and
|
||||
/// appear on the synthetic data with `count < resolution`).
|
||||
/// By design this should be `0`
|
||||
/// # Arguments
|
||||
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
|
||||
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn calc_leakage_count(
|
||||
&self,
|
||||
sensitive_aggregated_data: &AggregatedData,
|
||||
synthetic_aggregated_data: &AggregatedData,
|
||||
resolution: usize,
|
||||
) -> AggregatedCountByLenMap {
|
||||
let _duration_logger = ElapsedDurationLogger::new("leakage count calculation");
|
||||
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
|
||||
|
||||
info!("calculating rare sensitive combination leakages by length");
|
||||
|
||||
for (sensitive_agg, sensitive_count) in sensitive_aggregated_data.aggregates_count.iter() {
|
||||
if sensitive_count.count < resolution {
|
||||
if let Some(synthetic_count) = synthetic_aggregated_data
|
||||
.aggregates_count
|
||||
.get(sensitive_agg)
|
||||
{
|
||||
if synthetic_count.count < resolution {
|
||||
let leaks = result.entry(sensitive_agg.len()).or_insert(0);
|
||||
*leaks += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Calculates the fabricated counts grouped by combination length
|
||||
/// (how many attribute combinations exist on the synthetic data that do not
|
||||
/// exist on the sensitive data).
|
||||
/// By design this should be `0`
|
||||
/// # Arguments
|
||||
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
|
||||
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
|
||||
pub fn calc_fabricated_count(
|
||||
&self,
|
||||
sensitive_aggregated_data: &AggregatedData,
|
||||
synthetic_aggregated_data: &AggregatedData,
|
||||
) -> AggregatedCountByLenMap {
|
||||
let _duration_logger = ElapsedDurationLogger::new("fabricated count calculation");
|
||||
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
|
||||
|
||||
info!("calculating fabricated synthetic combinations by length");
|
||||
|
||||
for synthetic_agg in synthetic_aggregated_data.aggregates_count.keys() {
|
||||
if sensitive_aggregated_data
|
||||
.aggregates_count
|
||||
.get(synthetic_agg)
|
||||
.is_none()
|
||||
{
|
||||
let fabricated = result.entry(synthetic_agg.len()).or_insert(0);
|
||||
*fabricated += 1;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Calculates the preservation information grouped by combination count.
|
||||
/// An example output might be a map like:
|
||||
/// `{ 10 -> PreservationBucket, 20 -> PreservationBucket, ...}`
|
||||
/// # Arguments
|
||||
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
|
||||
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn calc_preservation_by_count(
|
||||
&self,
|
||||
sensitive_aggregated_data: &AggregatedData,
|
||||
synthetic_aggregated_data: &AggregatedData,
|
||||
resolution: usize,
|
||||
) -> PreservationByCountBuckets {
|
||||
let _duration_logger = ElapsedDurationLogger::new("preservation by count calculation");
|
||||
|
||||
info!(
|
||||
"calculating preservation by count with resolution: {}",
|
||||
resolution
|
||||
);
|
||||
|
||||
let max_syn_count = *synthetic_aggregated_data
|
||||
.aggregates_count
|
||||
.values()
|
||||
.map(|a| &a.count)
|
||||
.max()
|
||||
.unwrap_or(&0);
|
||||
let bins: PreservationByCountBucketBins = PreservationByCountBucketBins::new(max_syn_count);
|
||||
let mut buckets: PreservationByCountBuckets = PreservationByCountBuckets::default();
|
||||
let mut processed_combs: FnvHashSet<&ValueCombination> = FnvHashSet::default();
|
||||
|
||||
for (comb, count) in sensitive_aggregated_data.aggregates_count.iter() {
|
||||
// exclude sensitive rare combinations
|
||||
if count.count >= resolution && !processed_combs.contains(comb) {
|
||||
buckets.populate(
|
||||
&bins,
|
||||
comb,
|
||||
&sensitive_aggregated_data.aggregates_count,
|
||||
&synthetic_aggregated_data.aggregates_count,
|
||||
);
|
||||
processed_combs.insert(comb);
|
||||
}
|
||||
}
|
||||
for comb in synthetic_aggregated_data.aggregates_count.keys() {
|
||||
if !processed_combs.contains(comb) {
|
||||
buckets.populate(
|
||||
&bins,
|
||||
comb,
|
||||
&sensitive_aggregated_data.aggregates_count,
|
||||
&synthetic_aggregated_data.aggregates_count,
|
||||
);
|
||||
processed_combs.insert(comb);
|
||||
}
|
||||
}
|
||||
|
||||
// fill empty buckets with default value
|
||||
for bin in bins.iter() {
|
||||
buckets
|
||||
.entry(*bin)
|
||||
.or_insert_with(PreservationBucket::default);
|
||||
}
|
||||
|
||||
buckets
|
||||
}
|
||||
|
||||
/// Calculates the preservation information grouped by combination length.
|
||||
/// An example output might be a map like:
|
||||
/// `{ 1 -> PreservationBucket, 2 -> PreservationBucket, ...}`
|
||||
/// # Arguments
|
||||
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
|
||||
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn calc_preservation_by_length(
|
||||
&self,
|
||||
sensitive_aggregated_data: &AggregatedData,
|
||||
synthetic_aggregated_data: &AggregatedData,
|
||||
resolution: usize,
|
||||
) -> PreservationByLengthBuckets {
|
||||
let _duration_logger = ElapsedDurationLogger::new("preservation by length calculation");
|
||||
|
||||
info!(
|
||||
"calculating preservation by length with resolution: {}",
|
||||
resolution
|
||||
);
|
||||
|
||||
let mut buckets: PreservationByLengthBuckets = PreservationByLengthBuckets::default();
|
||||
let mut processed_combs: FnvHashSet<&ValueCombination> = FnvHashSet::default();
|
||||
|
||||
for (comb, count) in sensitive_aggregated_data.aggregates_count.iter() {
|
||||
// exclude sensitive rare combinations
|
||||
if count.count >= resolution && !processed_combs.contains(comb) {
|
||||
buckets.populate(
|
||||
comb,
|
||||
&sensitive_aggregated_data.aggregates_count,
|
||||
&synthetic_aggregated_data.aggregates_count,
|
||||
);
|
||||
processed_combs.insert(comb);
|
||||
}
|
||||
}
|
||||
for comb in synthetic_aggregated_data.aggregates_count.keys() {
|
||||
if !processed_combs.contains(comb) {
|
||||
buckets.populate(
|
||||
comb,
|
||||
&sensitive_aggregated_data.aggregates_count,
|
||||
&synthetic_aggregated_data.aggregates_count,
|
||||
);
|
||||
processed_combs.insert(comb);
|
||||
}
|
||||
}
|
||||
|
||||
// fill empty buckets with default value
|
||||
for l in 1..=synthetic_aggregated_data.reporting_length {
|
||||
buckets.entry(l).or_insert_with(PreservationBucket::default);
|
||||
}
|
||||
|
||||
buckets
|
||||
}
|
||||
|
||||
/// Compares the rare combinations on the synthetic data with
|
||||
/// the sensitive data counts
|
||||
/// # Arguments
|
||||
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
|
||||
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
/// * `combination_delimiter` - Delimiter used to join combinations and format them
|
||||
/// as strings
|
||||
/// * `protect` - Whether or not the sensitive counts should be rounded to the
|
||||
/// nearest smaller multiple of the resolution
|
||||
pub fn compare_synthetic_and_sensitive_rare(
|
||||
&self,
|
||||
synthetic_aggregated_data: &AggregatedData,
|
||||
sensitive_aggregated_data: &AggregatedData,
|
||||
resolution: usize,
|
||||
combination_delimiter: &str,
|
||||
protect: bool,
|
||||
) -> RareCombinationsComparisonData {
|
||||
let _duration_logger =
|
||||
ElapsedDurationLogger::new("synthetic and sensitive rare comparison");
|
||||
RareCombinationsComparisonData::from_synthetic_and_sensitive_aggregated_data(
|
||||
synthetic_aggregated_data,
|
||||
sensitive_aggregated_data,
|
||||
resolution,
|
||||
combination_delimiter,
|
||||
protect,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
pub fn register(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<Evaluator>()?;
|
||||
Ok(())
|
||||
}
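// Illustrative sketch, not part of this commit: how the refactored Evaluator is meant
// to be driven. `sensitive_aggregated_data` and `synthetic_aggregated_data` are
// placeholders for AggregatedData values produced by the aggregator, and `resolution`
// matches the value used during synthesis.
let evaluator = Evaluator::default();
let leakage_by_len =
    evaluator.calc_leakage_count(&sensitive_aggregated_data, &synthetic_aggregated_data, resolution);
let fabricated_by_len =
    evaluator.calc_fabricated_count(&sensitive_aggregated_data, &synthetic_aggregated_data);
let preservation =
    evaluator.calc_preservation_by_count(&sensitive_aggregated_data, &synthetic_aggregated_data, resolution);
info!(
    "leaks: {:?}, fabricated: {:?}, combination loss: {}",
    leakage_by_len,
    fabricated_by_len,
    preservation.calc_combination_loss()
);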
|
|
@ -1,227 +1,23 @@
|
|||
pub mod bucket;
|
||||
/// Module defining routines to calculate preservation by count
|
||||
pub mod preservation_by_count;
|
||||
|
||||
/// Module defining routines to calculate preservation by length
|
||||
pub mod preservation_by_length;
|
||||
|
||||
/// Module defining a bucket used to calculate preservation by count
|
||||
/// and length
|
||||
pub mod preservation_bucket;
|
||||
|
||||
/// Defines structures related to rare combination comparisons
|
||||
/// between the synthetic and sensitive datasets
|
||||
pub mod rare_combinations_comparison_data;
|
||||
|
||||
/// Type definitions related to the evaluation process
|
||||
pub mod typedefs;
|
||||
|
||||
use self::bucket::{PreservationByCountBucket, PreservationByCountBucketBins};
|
||||
use self::typedefs::PreservationByCountBuckets;
|
||||
use fnv::FnvHashSet;
|
||||
use instant::Duration;
|
||||
use log::Level::Trace;
|
||||
use log::{info, log_enabled, trace};
|
||||
mod data_evaluator;
|
||||
|
||||
use crate::measure_time;
|
||||
use crate::processing::aggregator::typedefs::ValueCombination;
|
||||
use crate::utils::time::ElapsedDuration;
|
||||
pub use data_evaluator::Evaluator;
|
||||
|
||||
use super::aggregator::typedefs::{
|
||||
AggregatedCountByLenMap, AggregatesCountMap, ValueCombinationSlice,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EvaluatorDurations {
|
||||
calc_leakage_count: Duration,
|
||||
calc_fabricated_count: Duration,
|
||||
calc_preservation_by_count: Duration,
|
||||
calc_combination_loss: Duration,
|
||||
}
|
||||
|
||||
/// Evaluates aggregated, sensitive and synthesized data
|
||||
pub struct Evaluator {
|
||||
durations: EvaluatorDurations,
|
||||
}
|
||||
|
||||
impl Evaluator {
|
||||
/// Returns a new Evaluator
|
||||
#[inline]
|
||||
pub fn default() -> Evaluator {
|
||||
Evaluator {
|
||||
durations: EvaluatorDurations {
|
||||
calc_leakage_count: Duration::default(),
|
||||
calc_fabricated_count: Duration::default(),
|
||||
calc_preservation_by_count: Duration::default(),
|
||||
calc_combination_loss: Duration::default(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculates the leakage counts grouped by combination length
|
||||
/// (how many attribute combinations exist on the sensitive data and
|
||||
/// appear on the synthetic data with `count < resolution`).
|
||||
/// By design this should be `0`
|
||||
/// # Arguments
|
||||
/// - `sensitive_aggregates` - Calculated aggregates counts for the sensitive data
|
||||
/// - `synthetic_aggregates` - Calculated aggregates counts for the synthetic data
|
||||
pub fn calc_leakage_count(
|
||||
&mut self,
|
||||
sensitive_aggregates: &AggregatesCountMap,
|
||||
synthetic_aggregates: &AggregatesCountMap,
|
||||
resolution: usize,
|
||||
) -> AggregatedCountByLenMap {
|
||||
info!("calculating rare sensitive combination leakages by length");
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
|
||||
|
||||
for (sensitive_agg, sensitive_count) in sensitive_aggregates.iter() {
|
||||
if sensitive_count.count < resolution {
|
||||
if let Some(synthetic_count) = synthetic_aggregates.get(sensitive_agg) {
|
||||
if synthetic_count.count < resolution {
|
||||
let leaks = result.entry(sensitive_agg.len()).or_insert(0);
|
||||
*leaks += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
},
|
||||
(self.durations.calc_leakage_count)
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculates the fabricated counts grouped by combination length
|
||||
/// (how many attribute combinations exist on the synthetic data that do not
|
||||
/// exist on the sensitive data).
|
||||
/// By design this should be `0`
|
||||
/// # Arguments
|
||||
/// - `sensitive_aggregates` - Calculated aggregates counts for the sensitive data
|
||||
/// - `synthetic_aggregates` - Calculated aggregates counts for the synthetic data
|
||||
pub fn calc_fabricated_count(
|
||||
&mut self,
|
||||
sensitive_aggregates: &AggregatesCountMap,
|
||||
synthetic_aggregates: &AggregatesCountMap,
|
||||
) -> AggregatedCountByLenMap {
|
||||
info!("calculating fabricated synthetic combinations by length");
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
|
||||
|
||||
for synthetic_agg in synthetic_aggregates.keys() {
|
||||
if sensitive_aggregates.get(synthetic_agg).is_none() {
|
||||
let fabricated = result.entry(synthetic_agg.len()).or_insert(0);
|
||||
*fabricated += 1;
|
||||
}
|
||||
}
|
||||
result
|
||||
},
|
||||
(self.durations.calc_fabricated_count)
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculates the preservation information by bucket.
|
||||
/// An example output might be a map like:
|
||||
/// `{ 10 -> PreservationByCountBucket, 20 -> PreservationByCountBucket, ...}`
|
||||
/// # Arguments
|
||||
/// - `sensitive_aggregates` - Calculated aggregates counts for the sensitive data
|
||||
/// - `synthetic_aggregates` - Calculated aggregates counts for the synthetic data
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
pub fn calc_preservation_by_count(
|
||||
&mut self,
|
||||
sensitive_aggregates: &AggregatesCountMap,
|
||||
synthetic_aggregates: &AggregatesCountMap,
|
||||
resolution: usize,
|
||||
) -> PreservationByCountBuckets {
|
||||
info!(
|
||||
"calculating preservation by count with resolution: {}",
|
||||
resolution
|
||||
);
|
||||
measure_time!(
|
||||
|| {
|
||||
let max_syn_count = *synthetic_aggregates
|
||||
.values()
|
||||
.map(|a| &a.count)
|
||||
.max()
|
||||
.unwrap_or(&0);
|
||||
let bins: PreservationByCountBucketBins =
|
||||
PreservationByCountBucketBins::new(max_syn_count);
|
||||
let mut buckets: PreservationByCountBuckets = PreservationByCountBuckets::default();
|
||||
let mut processed_combs: FnvHashSet<&ValueCombination> = FnvHashSet::default();
|
||||
|
||||
for (comb, count) in sensitive_aggregates.iter() {
|
||||
if count.count >= resolution && !processed_combs.contains(comb) {
|
||||
Evaluator::populate_buckets(
|
||||
&mut buckets,
|
||||
&bins,
|
||||
comb,
|
||||
sensitive_aggregates,
|
||||
synthetic_aggregates,
|
||||
);
|
||||
processed_combs.insert(comb);
|
||||
}
|
||||
}
|
||||
for comb in synthetic_aggregates.keys() {
|
||||
if !processed_combs.contains(comb) {
|
||||
Evaluator::populate_buckets(
|
||||
&mut buckets,
|
||||
&bins,
|
||||
comb,
|
||||
sensitive_aggregates,
|
||||
synthetic_aggregates,
|
||||
);
|
||||
processed_combs.insert(comb);
|
||||
}
|
||||
}
|
||||
|
||||
// fill empty buckets with default value
|
||||
for bin in bins.iter() {
|
||||
if !buckets.contains_key(bin) {
|
||||
buckets.insert(*bin, PreservationByCountBucket::default());
|
||||
}
|
||||
}
|
||||
|
||||
buckets
|
||||
},
|
||||
(self.durations.calc_preservation_by_count)
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculates the combination loss for the calculated bucket preservation information
|
||||
/// (`combination_loss = avg(1 - bucket_preservation_sum / bucket_size) for all buckets`)
|
||||
/// # Arguments
|
||||
/// - `buckets` - Calculated preservation buckets
|
||||
pub fn calc_combination_loss(&mut self, buckets: &PreservationByCountBuckets) -> f64 {
|
||||
measure_time!(
|
||||
|| {
|
||||
buckets
|
||||
.values()
|
||||
.map(|b| 1.0 - (b.preservation_sum / (b.size as f64)))
|
||||
.sum::<f64>()
|
||||
/ (buckets.len() as f64)
|
||||
},
|
||||
(self.durations.calc_combination_loss)
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn populate_buckets(
|
||||
buckets: &mut PreservationByCountBuckets,
|
||||
bins: &PreservationByCountBucketBins,
|
||||
comb: &ValueCombinationSlice,
|
||||
sensitive_aggregates: &AggregatesCountMap,
|
||||
synthetic_aggregates: &AggregatesCountMap,
|
||||
) {
|
||||
let sen_count = match sensitive_aggregates.get(comb) {
|
||||
Some(count) => count.count,
|
||||
_ => 0,
|
||||
};
|
||||
let syn_count = match synthetic_aggregates.get(comb) {
|
||||
Some(count) => count.count,
|
||||
_ => 0,
|
||||
};
|
||||
let preservation = if sen_count > 0 {
|
||||
// max value is 100%, so use min
|
||||
f64::min((syn_count as f64) / (sen_count as f64), 1.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
buckets
|
||||
.entry(bins.find_bucket_max_val(syn_count))
|
||||
.or_insert_with(PreservationByCountBucket::default)
|
||||
.add(preservation, comb.len());
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Evaluator {
|
||||
fn drop(&mut self) {
|
||||
trace!("evaluator durations: {:#?}", self.durations);
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "pyo3")]
|
||||
pub use data_evaluator::register;
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
/// Bucket to store preservation information
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PreservationBucket {
|
||||
/// How many elements are stored in the bucket
|
||||
pub size: usize,
|
||||
/// Preservation sum of all elements in the bucket
|
||||
pub preservation_sum: f64,
|
||||
/// Combination length sum of all elements in the bucket
|
||||
pub length_sum: usize,
|
||||
/// Combination count sum
|
||||
pub combination_count_sum: usize,
|
||||
}
|
||||
|
||||
impl PreservationBucket {
|
||||
/// Return a new PreservationBucket with default values
|
||||
pub fn default() -> PreservationBucket {
|
||||
PreservationBucket {
|
||||
size: 0,
|
||||
preservation_sum: 0.0,
|
||||
length_sum: 0,
|
||||
combination_count_sum: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a new value to the bucket
|
||||
/// # Arguments
|
||||
/// * `preservation` - Preservation related to the value
|
||||
/// * `length` - Combination length related to the value
|
||||
/// * `combination_count` - Combination count related to the value
|
||||
pub fn add(&mut self, preservation: f64, length: usize, combination_count: usize) {
|
||||
self.size += 1;
|
||||
self.preservation_sum += preservation;
|
||||
self.length_sum += length;
|
||||
self.combination_count_sum += combination_count;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl PreservationBucket {
|
||||
/// Gets the mean preservation for the values in this bucket
|
||||
pub fn get_mean_preservation(&self) -> f64 {
|
||||
self.preservation_sum / (self.size as f64)
|
||||
}
|
||||
|
||||
/// Gets the mean combination length for the values in this bucket
|
||||
pub fn get_mean_combination_length(&self) -> f64 {
|
||||
(self.length_sum as f64) / (self.size as f64)
|
||||
}
|
||||
|
||||
/// Gets the mean combination count for the values in this bucket
|
||||
pub fn get_mean_combination_count(&self) -> f64 {
|
||||
(self.combination_count_sum as f64) / (self.size as f64)
|
||||
}
|
||||
}
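// Illustrative sketch, not part of this commit: exercising PreservationBucket directly.
// With the two entries below, each mean is simply the corresponding sum divided by the
// bucket size of 2.
let mut bucket = PreservationBucket::default();
bucket.add(1.0, 2, 15);
bucket.add(0.5, 4, 25);
assert_eq!(bucket.get_mean_preservation(), 0.75);
assert_eq!(bucket.get_mean_combination_length(), 3.0);
assert_eq!(bucket.get_mean_combination_count(), 20.0);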
|
|
@ -0,0 +1,191 @@
|
|||
use super::{preservation_bucket::PreservationBucket, typedefs::PreservationBucketsMap};
|
||||
use itertools::Itertools;
|
||||
use std::{
|
||||
io::{Error, Write},
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use crate::{
|
||||
processing::aggregator::{typedefs::AggregatesCountMap, value_combination::ValueCombination},
|
||||
utils::time::ElapsedDurationLogger,
|
||||
};
|
||||
|
||||
const INITIAL_BIN: usize = 10;
|
||||
const BIN_RATIO: usize = 2;
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
/// Wrapper struct mapping the max value allowed in the bucket
|
||||
/// to its corresponding PreservationBucket.
|
||||
/// In this context a PreservationBucket stores the preservation information
|
||||
/// related to a particular bucket.
|
||||
/// A bucket stores counts in a certain value range.
|
||||
/// For example: (0, 10], (10, 20], (20, 40]...
|
||||
pub struct PreservationByCountBuckets {
|
||||
buckets_map: PreservationBucketsMap,
|
||||
}
|
||||
|
||||
impl PreservationByCountBuckets {
|
||||
/// Returns a new default PreservationByCountBuckets
|
||||
#[inline]
|
||||
pub fn default() -> PreservationByCountBuckets {
|
||||
PreservationByCountBuckets {
|
||||
buckets_map: PreservationBucketsMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) fn populate(
|
||||
&mut self,
|
||||
bins: &PreservationByCountBucketBins,
|
||||
comb: &ValueCombination,
|
||||
sensitive_aggregates: &AggregatesCountMap,
|
||||
synthetic_aggregates: &AggregatesCountMap,
|
||||
) {
|
||||
let sen_count = match sensitive_aggregates.get(comb) {
|
||||
Some(count) => count.count,
|
||||
_ => 0,
|
||||
};
|
||||
let syn_count = match synthetic_aggregates.get(comb) {
|
||||
Some(count) => count.count,
|
||||
_ => 0,
|
||||
};
|
||||
let preservation = if sen_count > 0 {
|
||||
// max value is 100%, so use min
|
||||
f64::min((syn_count as f64) / (sen_count as f64), 1.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
self.buckets_map
|
||||
.entry(bins.find_bucket_max_val(syn_count))
|
||||
.or_insert_with(PreservationBucket::default)
|
||||
.add(preservation, comb.len(), syn_count);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl PreservationByCountBuckets {
|
||||
/// Returns the actual buckets grouped by their max value.
|
||||
/// This method will clone the data, so it's recommended to have its result stored
|
||||
/// in a local variable to avoid it being called multiple times
|
||||
pub fn get_buckets(&self) -> PreservationBucketsMap {
|
||||
self.buckets_map.clone()
|
||||
}
|
||||
|
||||
/// Calculates the combination loss for the calculated bucket preservation information
|
||||
/// (`combination_loss = avg(1 - bucket_preservation_sum / bucket_size) for all buckets`)
|
||||
pub fn calc_combination_loss(&self) -> f64 {
|
||||
let _duration_logger = ElapsedDurationLogger::new("combination loss calculation");
|
||||
|
||||
self.buckets_map
|
||||
.values()
|
||||
.map(|b| 1.0 - (b.preservation_sum / (b.size as f64)))
|
||||
.sum::<f64>()
|
||||
/ (self.buckets_map.len() as f64)
|
||||
}
|
||||
|
||||
/// Writes the preservation grouped by count to the file system in a CSV/TSV-like format
|
||||
/// # Arguments
|
||||
/// * `preservation_by_count_path` - File path to be written
|
||||
/// * `preservation_by_count_delimiter` - Delimiter to use when writing to `preservation_by_count_path`
|
||||
pub fn write_preservation_by_count(
|
||||
&self,
|
||||
preservation_by_count_path: &str,
|
||||
preservation_by_count_delimiter: char,
|
||||
) -> Result<(), Error> {
|
||||
let mut file = std::fs::File::create(preservation_by_count_path)?;
|
||||
|
||||
file.write_all(
|
||||
format!(
|
||||
"syn_count_bucket{}mean_combo_count{}mean_combo_length{}count_preservation\n",
|
||||
preservation_by_count_delimiter,
|
||||
preservation_by_count_delimiter,
|
||||
preservation_by_count_delimiter,
|
||||
)
|
||||
.as_bytes(),
|
||||
)?;
|
||||
for max_val in self.buckets_map.keys().sorted().rev() {
|
||||
let b = &self.buckets_map[max_val];
|
||||
|
||||
file.write_all(
|
||||
format!(
|
||||
"{}{}{}{}{}{}{}\n",
|
||||
max_val,
|
||||
preservation_by_count_delimiter,
|
||||
b.get_mean_combination_count(),
|
||||
preservation_by_count_delimiter,
|
||||
b.get_mean_combination_length(),
|
||||
preservation_by_count_delimiter,
|
||||
b.get_mean_preservation(),
|
||||
)
|
||||
.as_bytes(),
|
||||
)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PreservationByCountBuckets {
|
||||
type Target = PreservationBucketsMap;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.buckets_map
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for PreservationByCountBuckets {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.buckets_map
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores the bins for each bucket. For example:
|
||||
/// [10, 20, 40, 80, 160...]
|
||||
#[derive(Debug)]
|
||||
pub struct PreservationByCountBucketBins {
|
||||
/// Bins where the index of the vector is the bin index,
|
||||
/// and the value is the max value count allowed on the bin
|
||||
bins: Vec<usize>,
|
||||
}
|
||||
|
||||
impl PreservationByCountBucketBins {
|
||||
/// Generates a new PreservationByCountBucketBins with bins up to a `max_val`
|
||||
/// # Arguments
|
||||
/// * `max_val` - Max value allowed on the last bucket (inclusive)
|
||||
pub fn new(max_val: usize) -> PreservationByCountBucketBins {
|
||||
let mut bins: Vec<usize> = vec![INITIAL_BIN];
|
||||
loop {
|
||||
let last = *bins.last().unwrap();
|
||||
if last >= max_val {
|
||||
break;
|
||||
}
|
||||
bins.push(last * BIN_RATIO);
|
||||
}
|
||||
PreservationByCountBucketBins { bins }
|
||||
}
|
||||
|
||||
/// Finds the first `bucket_max_val` where `bucket_max_val >= val`
|
||||
/// # Arguments
|
||||
/// * `val` - Count to look in which bucket it will fit
|
||||
pub fn find_bucket_max_val(&self, val: usize) -> usize {
|
||||
// find first element x where x >= val
|
||||
self.bins[self.bins.partition_point(|x| *x < val)]
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PreservationByCountBucketBins {
|
||||
type Target = Vec<usize>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.bins
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for PreservationByCountBucketBins {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.bins
|
||||
}
|
||||
}
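// Illustrative sketch, not part of this commit: bin construction and lookup. With
// INITIAL_BIN = 10 and BIN_RATIO = 2 defined above, new(100) keeps doubling the last
// bin until it covers 100, and find_bucket_max_val returns the first bin max >= the count.
let bins = PreservationByCountBucketBins::new(100);
assert_eq!(*bins, vec![10, 20, 40, 80, 160]);
assert_eq!(bins.find_bucket_max_val(0), 10);
assert_eq!(bins.find_bucket_max_val(25), 40);
assert_eq!(bins.find_bucket_max_val(160), 160);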
|
|
@ -0,0 +1,118 @@
|
|||
use super::{preservation_bucket::PreservationBucket, typedefs::PreservationBucketsMap};
|
||||
use itertools::Itertools;
|
||||
use std::{
|
||||
io::{Error, Write},
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use crate::processing::aggregator::{
|
||||
typedefs::AggregatesCountMap, value_combination::ValueCombination,
|
||||
};
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
/// Wrapper struct storing the preservation buckets grouped
|
||||
/// by length
|
||||
pub struct PreservationByLengthBuckets {
|
||||
buckets_map: PreservationBucketsMap,
|
||||
}
|
||||
|
||||
impl PreservationByLengthBuckets {
|
||||
/// Returns a new default PreservationByLengthBuckets
|
||||
#[inline]
|
||||
pub fn default() -> PreservationByLengthBuckets {
|
||||
PreservationByLengthBuckets {
|
||||
buckets_map: PreservationBucketsMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) fn populate(
|
||||
&mut self,
|
||||
comb: &ValueCombination,
|
||||
sensitive_aggregates: &AggregatesCountMap,
|
||||
synthetic_aggregates: &AggregatesCountMap,
|
||||
) {
|
||||
let sen_count = match sensitive_aggregates.get(comb) {
|
||||
Some(count) => count.count,
|
||||
_ => 0,
|
||||
};
|
||||
let syn_count = match synthetic_aggregates.get(comb) {
|
||||
Some(count) => count.count,
|
||||
_ => 0,
|
||||
};
|
||||
let preservation = if sen_count > 0 {
|
||||
// max value is 100%, so use min as syn count might be > sen_count
|
||||
f64::min((syn_count as f64) / (sen_count as f64), 1.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
self.buckets_map
|
||||
.entry(comb.len())
|
||||
.or_insert_with(PreservationBucket::default)
|
||||
.add(preservation, comb.len(), syn_count);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl PreservationByLengthBuckets {
|
||||
/// Returns the actual buckets grouped by length.
|
||||
/// This method will clone the data, so it's recommended to have its result stored
|
||||
/// in a local variable to avoid it being called multiple times
|
||||
pub fn get_buckets(&self) -> PreservationBucketsMap {
|
||||
self.buckets_map.clone()
|
||||
}
|
||||
|
||||
/// Writes the preservation grouped by length to the file system in a CSV/TSV-like format
|
||||
/// # Arguments
|
||||
/// * `preservation_by_length_path` - File path to be written
|
||||
/// * `preservation_by_length_delimiter` - Delimiter to use when writing to `preservation_by_length_path`
|
||||
pub fn write_preservation_by_length(
|
||||
&mut self,
|
||||
preservation_by_length_path: &str,
|
||||
preservation_by_length_delimiter: char,
|
||||
) -> Result<(), Error> {
|
||||
let mut file = std::fs::File::create(preservation_by_length_path)?;
|
||||
|
||||
file.write_all(
|
||||
format!(
|
||||
"syn_combo_length{}mean_combo_count{}count_preservation\n",
|
||||
preservation_by_length_delimiter, preservation_by_length_delimiter,
|
||||
)
|
||||
.as_bytes(),
|
||||
)?;
|
||||
for length in self.buckets_map.keys().sorted() {
|
||||
let b = &self.buckets_map[length];
|
||||
|
||||
file.write_all(
|
||||
format!(
|
||||
"{}{}{}{}{}\n",
|
||||
length,
|
||||
preservation_by_length_delimiter,
|
||||
b.get_mean_combination_count(),
|
||||
preservation_by_length_delimiter,
|
||||
b.get_mean_preservation(),
|
||||
)
|
||||
.as_bytes(),
|
||||
)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PreservationByLengthBuckets {
|
||||
type Target = PreservationBucketsMap;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.buckets_map
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for PreservationByLengthBuckets {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.buckets_map
|
||||
}
|
||||
}
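// Illustrative sketch, not part of this commit: producing and persisting the
// preservation-by-length metrics. `evaluator`, `sensitive_aggregated_data` and
// `synthetic_aggregated_data` are placeholders for values built earlier in the pipeline;
// the output path and tab delimiter are arbitrary choices.
let mut by_length = evaluator.calc_preservation_by_length(
    &sensitive_aggregated_data,
    &synthetic_aggregated_data,
    10, // resolution
);
by_length.write_preservation_by_length("./preservation_by_length.tsv", '\t')?;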
|
|
@ -0,0 +1,197 @@
|
|||
use std::io::{Error, Write};
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use crate::{
|
||||
data_block::typedefs::CombinationsComparisons,
|
||||
processing::aggregator::aggregated_data::AggregatedData, utils::math::uround_down,
|
||||
};
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
#[derive(Clone)]
|
||||
/// Represents a single combination comparison
|
||||
pub struct CombinationComparison {
|
||||
/// Length of the combination
|
||||
pub combination_length: usize,
|
||||
/// Combination formatted as string
|
||||
pub combination: String,
|
||||
/// Index of the record where this combination occurs
|
||||
pub record_index: usize,
|
||||
/// Number of times this combination occurs on the synthetic data
|
||||
pub synthetic_count: usize,
|
||||
/// Number of times this combination occurs on the sensitive data
|
||||
pub sensitive_count: usize,
|
||||
}
|
||||
|
||||
impl CombinationComparison {
|
||||
#[inline]
|
||||
/// Creates a new CombinationComparison
|
||||
pub fn new(
|
||||
combination_length: usize,
|
||||
combination: String,
|
||||
record_index: usize,
|
||||
synthetic_count: usize,
|
||||
sensitive_count: usize,
|
||||
) -> CombinationComparison {
|
||||
CombinationComparison {
|
||||
combination_length,
|
||||
combination,
|
||||
record_index,
|
||||
synthetic_count,
|
||||
sensitive_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl CombinationComparison {
|
||||
#[getter]
|
||||
/// Length of the combination
|
||||
fn combination_length(&self) -> usize {
|
||||
self.combination_length
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Combination formatted as string
|
||||
fn combination(&self) -> String {
|
||||
self.combination.clone()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Index of the record where this combination occurs
|
||||
fn record_index(&self) -> usize {
|
||||
self.record_index
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Number of times this combination occurs on the synthetic data
|
||||
fn synthetic_count(&self) -> usize {
|
||||
self.synthetic_count
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Number of times this combination occurs on the sensitive data
|
||||
fn sensitive_count(&self) -> usize {
|
||||
self.sensitive_count
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
/// Computed rare combination comparisons between the synthetic
|
||||
/// and sensitive datasets for all records
|
||||
pub struct RareCombinationsComparisonData {
|
||||
/// Rare combination comparison for all records
|
||||
/// (compares rare combinations on the synthetic dataset with
|
||||
/// the sensitive counts)
|
||||
pub rare_combinations: CombinationsComparisons,
|
||||
}
|
||||
|
||||
impl RareCombinationsComparisonData {
|
||||
/// Builds a new comparison between the rare combinations on the synthetic
|
||||
/// data and the sensitive data counts
|
||||
/// # Arguments
|
||||
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
|
||||
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
/// * `combination_delimiter` - Delimiter used to join combinations and format them
|
||||
/// as strings
|
||||
/// * `protect` - Whether or not the sensitive counts should be rounded to the
|
||||
/// nearest smaller multiple of the resolution
|
||||
#[inline]
|
||||
pub fn from_synthetic_and_sensitive_aggregated_data(
|
||||
synthetic_aggregated_data: &AggregatedData,
|
||||
sensitive_aggregated_data: &AggregatedData,
|
||||
resolution: usize,
|
||||
combination_delimiter: &str,
|
||||
protect: bool,
|
||||
) -> RareCombinationsComparisonData {
|
||||
let mut rare_combinations: CombinationsComparisons = CombinationsComparisons::default();
|
||||
let resolution_f64 = resolution as f64;
|
||||
|
||||
for (agg, count) in synthetic_aggregated_data.aggregates_count.iter() {
|
||||
if count.count < resolution {
|
||||
let combination_str = agg.format_str_using_headers(
|
||||
&synthetic_aggregated_data.data_block.headers,
|
||||
combination_delimiter,
|
||||
);
|
||||
let mut sensitive_count = sensitive_aggregated_data
|
||||
.aggregates_count
|
||||
.get(agg)
|
||||
.map_or(0, |c| c.count);
|
||||
|
||||
if protect {
|
||||
sensitive_count = uround_down(sensitive_count as f64, resolution_f64)
|
||||
}
|
||||
|
||||
for record_index in count.contained_in_records.iter() {
|
||||
rare_combinations.push(CombinationComparison::new(
|
||||
agg.len(),
|
||||
combination_str.clone(),
|
||||
*record_index,
|
||||
count.count,
|
||||
sensitive_count,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sort result by combination length
|
||||
rare_combinations.sort_by_key(|c| c.combination_length);
|
||||
|
||||
RareCombinationsComparisonData { rare_combinations }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl RareCombinationsComparisonData {
|
||||
/// Returns the rare combination comparisons between the synthetic
|
||||
/// and sensitive datasets for all records.
|
||||
/// This method will clone the data, so it's recommended to have its result stored
|
||||
/// in a local variable to avoid it being called multiple times
|
||||
pub fn get_rare_combinations(&self) -> CombinationsComparisons {
|
||||
self.rare_combinations.clone()
|
||||
}
|
||||
|
||||
/// Writes the rare combinations to the file system in a CSV/TSV-like format
|
||||
/// # Arguments
|
||||
/// * `rare_combinations_path` - File path to be written
|
||||
/// * `rare_combinations_delimiter` - Delimiter to use when writing to `rare_combinations_path`
|
||||
pub fn write_rare_combinations(
|
||||
&mut self,
|
||||
rare_combinations_path: &str,
|
||||
rare_combinations_delimiter: char,
|
||||
) -> Result<(), Error> {
|
||||
let mut file = std::fs::File::create(rare_combinations_path)?;
|
||||
|
||||
file.write_all(
|
||||
format!(
|
||||
"combo_length{}combo{}record_id{}syn_count{}sen_count\n",
|
||||
rare_combinations_delimiter,
|
||||
rare_combinations_delimiter,
|
||||
rare_combinations_delimiter,
|
||||
rare_combinations_delimiter,
|
||||
)
|
||||
.as_bytes(),
|
||||
)?;
|
||||
for ra in self.rare_combinations.iter() {
|
||||
file.write_all(
|
||||
format!(
|
||||
"{}{}{}{}{}{}{}{}{}\n",
|
||||
ra.combination_length,
|
||||
rare_combinations_delimiter,
|
||||
ra.combination,
|
||||
rare_combinations_delimiter,
|
||||
ra.record_index,
|
||||
rare_combinations_delimiter,
|
||||
ra.synthetic_count,
|
||||
rare_combinations_delimiter,
|
||||
ra.sensitive_count,
|
||||
)
|
||||
.as_bytes(),
|
||||
)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
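// Illustrative sketch, not part of this commit: producing and persisting the
// rare-combination comparison. `evaluator`, `synthetic_aggregated_data` and
// `sensitive_aggregated_data` are placeholders for values built earlier in the pipeline.
let mut comparison = evaluator.compare_synthetic_and_sensitive_rare(
    &synthetic_aggregated_data,
    &sensitive_aggregated_data,
    10,   // resolution
    ";",  // combination_delimiter
    true, // protect: round sensitive counts down to a multiple of the resolution
);
comparison.write_rare_combinations("./rare_combinations.tsv", '\t')?;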
|
|
@ -1,6 +1,5 @@
|
|||
use super::bucket::PreservationByCountBucket;
|
||||
use super::preservation_bucket::PreservationBucket;
|
||||
use fnv::FnvHashMap;
|
||||
|
||||
/// Maps the max value allowed in the bucket to its
|
||||
/// correspondent PreservationByCountBucket
|
||||
pub type PreservationByCountBuckets = FnvHashMap<usize, PreservationByCountBucket>;
|
||||
/// Maps a value to its corresponding PreservationBucket
|
||||
pub type PreservationBucketsMap = FnvHashMap<usize, PreservationBucket>;
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
use super::generated_data::GeneratedData;
|
||||
use super::synthesizer::typedefs::SynthesizedRecords;
|
||||
use super::synthesizer::unseeded::UnseededSynthesizer;
|
||||
use super::SynthesisMode;
|
||||
use log::info;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::data_block::block::DataBlock;
|
||||
use crate::data_block::typedefs::RawSyntheticData;
|
||||
use crate::processing::generator::synthesizer::cache::SynthesizerCacheKey;
|
||||
use crate::processing::generator::synthesizer::seeded::SeededSynthesizer;
|
||||
use crate::utils::reporting::ReportProgress;
|
||||
use crate::utils::time::ElapsedDurationLogger;
|
||||
|
||||
/// Processes a data block and generates new synthetic data
|
||||
pub struct Generator {
|
||||
data_block: Arc<DataBlock>,
|
||||
}
|
||||
|
||||
impl Generator {
|
||||
/// Returns a new Generator
|
||||
/// # Arguments
|
||||
/// * `data_block` - Sensitive data to be synthesized
|
||||
#[inline]
|
||||
pub fn new(data_block: Arc<DataBlock>) -> Generator {
|
||||
Generator { data_block }
|
||||
}
|
||||
|
||||
/// Generates new synthetic data based on sensitive data
|
||||
/// # Arguments
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
/// * `cache_max_size` - Maximum cache size allowed
|
||||
/// * `empty_value` - Empty values on the synthetic data will be represented by this
|
||||
/// * `mode` - Which mode to perform the data synthesis
|
||||
/// * `progress_reporter` - Will be used to report the processing
|
||||
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
|
||||
pub fn generate<T>(
|
||||
&mut self,
|
||||
resolution: usize,
|
||||
cache_max_size: usize,
|
||||
empty_value: String,
|
||||
mode: SynthesisMode,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) -> GeneratedData
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let _duration_logger = ElapsedDurationLogger::new("data generation");
|
||||
let empty_value_arc = Arc::new(empty_value);
|
||||
|
||||
info!("starting {} generation...", mode);
|
||||
|
||||
let synthesized_records = match mode {
|
||||
SynthesisMode::Seeded => {
|
||||
self.seeded_synthesis(resolution, cache_max_size, progress_reporter)
|
||||
}
|
||||
SynthesisMode::Unseeded => self.unseeded_synthesis(
|
||||
resolution,
|
||||
cache_max_size,
|
||||
&empty_value_arc,
|
||||
progress_reporter,
|
||||
),
|
||||
};
|
||||
|
||||
self.build_generated_data(synthesized_records, empty_value_arc)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn build_generated_data(
|
||||
&self,
|
||||
mut synthesized_records: SynthesizedRecords,
|
||||
empty_value: Arc<String>,
|
||||
) -> GeneratedData {
|
||||
let mut result: RawSyntheticData = RawSyntheticData::default();
|
||||
let mut records: RawSyntheticData = synthesized_records
|
||||
.drain(..)
|
||||
.map(|r| {
|
||||
SynthesizerCacheKey::new(self.data_block.headers.len(), &r)
|
||||
.format_record(&empty_value)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// sort by number of defined attributes
|
||||
records.sort();
|
||||
records.sort_by_key(|r| {
|
||||
-r.iter()
|
||||
.map(|s| if s.is_empty() { 0 } else { 1 })
|
||||
.sum::<isize>()
|
||||
});
|
||||
|
||||
result.push(self.data_block.headers.to_vec());
|
||||
result.extend(records);
|
||||
|
||||
let expansion_ratio = (result.len() - 1) as f64 / self.data_block.records.len() as f64;
|
||||
|
||||
info!("expansion ratio: {:.4?}", expansion_ratio);
|
||||
|
||||
GeneratedData::new(result, expansion_ratio)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn seeded_synthesis<T>(
|
||||
&self,
|
||||
resolution: usize,
|
||||
cache_max_size: usize,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) -> SynthesizedRecords
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let attr_rows_map = Arc::new(self.data_block.calc_attr_rows());
|
||||
let mut synth = SeededSynthesizer::new(
|
||||
self.data_block.clone(),
|
||||
attr_rows_map,
|
||||
resolution,
|
||||
cache_max_size,
|
||||
);
|
||||
synth.run(progress_reporter)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn unseeded_synthesis<T>(
|
||||
&self,
|
||||
resolution: usize,
|
||||
cache_max_size: usize,
|
||||
empty_value: &Arc<String>,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) -> SynthesizedRecords
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let attr_rows_map_by_column =
|
||||
Arc::new(self.data_block.calc_attr_rows_by_column(empty_value));
|
||||
let mut synth = UnseededSynthesizer::new(
|
||||
self.data_block.clone(),
|
||||
attr_rows_map_by_column,
|
||||
resolution,
|
||||
cache_max_size,
|
||||
empty_value.clone(),
|
||||
);
|
||||
synth.run(progress_reporter)
|
||||
}
|
||||
}
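// Illustrative sketch, not part of this commit: driving the refactored Generator API.
// `data_block` stands in for an Arc<DataBlock> loaded from the sensitive CSV, and
// LoggerProgressReporter is a hypothetical name for any type implementing ReportProgress
// (pass None to skip progress reporting entirely).
let mut generator = Generator::new(data_block.clone());
let mut progress_reporter: Option<LoggerProgressReporter> = None;
let generated = generator.generate(
    10,               // resolution
    100_000,          // cache_max_size
    String::from(""), // empty_value
    SynthesisMode::Seeded,
    &mut progress_reporter,
);
generated.write_synthetic_data("./synthetic_data.tsv", '\t')?;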
|
|
@ -0,0 +1,93 @@
|
|||
use csv::WriterBuilder;
|
||||
use log::info;
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
use crate::data_block::typedefs::CsvRecord;
|
||||
|
||||
use crate::{
|
||||
data_block::{csv_io_error::CsvIOError, typedefs::RawSyntheticData},
|
||||
utils::time::ElapsedDurationLogger,
|
||||
};
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pyclass)]
|
||||
/// Synthetic data generated by the Generator
|
||||
pub struct GeneratedData {
|
||||
/// Synthesized data - headers (index 0) and records (indexes 1...)
|
||||
pub synthetic_data: RawSyntheticData,
|
||||
/// `Synthetic data length / Sensitive data length` (header not included)
|
||||
pub expansion_ratio: f64,
|
||||
}
|
||||
|
||||
impl GeneratedData {
|
||||
/// Returns a new GeneratedData struct with default values
|
||||
#[inline]
|
||||
pub fn default() -> GeneratedData {
|
||||
GeneratedData {
|
||||
synthetic_data: RawSyntheticData::default(),
|
||||
expansion_ratio: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new GeneratedData struct
|
||||
/// # Arguments
|
||||
/// * `synthetic_data` - Synthesized data - headers (index 0) and records (indexes 1...)
|
||||
/// * `expansion_ratio` - `Synthetic data length / Sensitive data length` (header not included)
|
||||
#[inline]
|
||||
pub fn new(synthetic_data: RawSyntheticData, expansion_ratio: f64) -> GeneratedData {
|
||||
GeneratedData {
|
||||
synthetic_data,
|
||||
expansion_ratio,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "pyo3", pymethods)]
|
||||
impl GeneratedData {
|
||||
#[cfg(feature = "pyo3")]
|
||||
/// Synthesized data - headers (index 0) and records (indexes 1...)
|
||||
/// This method will clone the data, so it's recommended to have its result stored
|
||||
/// in a local variable to avoid it being called multiple times
|
||||
fn get_synthetic_data(&self) -> Vec<CsvRecord> {
|
||||
self.synthetic_data
|
||||
.iter()
|
||||
.map(|row| row.iter().map(|value| (**value).clone()).collect())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3")]
|
||||
#[getter]
|
||||
/// `Synthetic data length / Sensitive data length` (header not included)
|
||||
fn expansion_ratio(&self) -> f64 {
|
||||
self.expansion_ratio
|
||||
}
|
||||
|
||||
/// Writes the synthesized data to the file system
|
||||
/// # Arguments
|
||||
/// * `path` - File path to be written
|
||||
/// * `delimiter` - Delimiter to use when writing to `path`
|
||||
pub fn write_synthetic_data(&self, path: &str, delimiter: char) -> Result<(), CsvIOError> {
|
||||
let _duration_logger = ElapsedDurationLogger::new("write synthetic data");
|
||||
|
||||
info!("writing file {}", path);
|
||||
|
||||
let mut wtr = match WriterBuilder::new()
|
||||
.delimiter(delimiter as u8)
|
||||
.from_path(&path)
|
||||
{
|
||||
Ok(writer) => writer,
|
||||
Err(err) => return Err(CsvIOError::new(err)),
|
||||
};
|
||||
|
||||
// write header and records
|
||||
for r in self.synthetic_data.iter() {
|
||||
match wtr.write_record(r.iter().map(|v| v.as_str())) {
|
||||
Ok(_) => {}
|
||||
Err(err) => return Err(CsvIOError::new(err)),
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,156 +1,13 @@
|
|||
/// Module defining the data synthesis process
|
||||
pub mod synthesizer;
|
||||
|
||||
use csv::{Error, WriterBuilder};
|
||||
use log::Level::Trace;
|
||||
use log::{info, log_enabled, trace};
|
||||
use std::time::Duration;
|
||||
mod synthesis_mode;
|
||||
|
||||
use crate::data_block::block::DataBlock;
|
||||
use crate::data_block::typedefs::CsvRecordRef;
|
||||
use crate::measure_time;
|
||||
use crate::processing::generator::synthesizer::cache::SynthesizerCacheKey;
|
||||
use crate::processing::generator::synthesizer::seeded::SeededSynthesizer;
|
||||
use crate::utils::reporting::ReportProgress;
|
||||
use crate::utils::time::ElapsedDuration;
|
||||
pub use synthesis_mode::SynthesisMode;
|
||||
|
||||
use self::synthesizer::typedefs::SynthesizedRecords;
|
||||
mod data_generator;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct GeneratorDurations {
|
||||
generate: Duration,
|
||||
write_records: Duration,
|
||||
}
|
||||
pub use data_generator::Generator;
|
||||
|
||||
/// Process a data block and generates new synthetic data
|
||||
pub struct Generator<'data_block> {
|
||||
data_block: &'data_block DataBlock,
|
||||
durations: GeneratorDurations,
|
||||
}
|
||||
|
||||
/// Synthetic data generated by the Generator
|
||||
pub struct GeneratedData<'data_block> {
|
||||
/// Synthesized data - headers (index 0) and records indexes 1...
|
||||
pub synthetic_data: Vec<CsvRecordRef<'data_block>>,
|
||||
/// `Synthetic data length / Sensitive data length` (header not included)
|
||||
pub expansion_ratio: f64,
|
||||
}
|
||||
|
||||
impl<'data_block> Generator<'data_block> {
|
||||
/// Returns a new Generator
|
||||
/// # Arguments
|
||||
/// * `data_block` - Sensitive data to be synthesized
|
||||
#[inline]
|
||||
pub fn new(data_block: &'data_block DataBlock) -> Generator<'data_block> {
|
||||
Generator {
|
||||
data_block,
|
||||
durations: GeneratorDurations {
|
||||
generate: Duration::default(),
|
||||
write_records: Duration::default(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates new synthetic data based on sensitive data
|
||||
/// # Arguments
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
/// * `cache_max_size` - Maximum cache size allowed
|
||||
/// * `empty_value` - Empty values on the synthetic data will be represented by this
|
||||
/// * `progress_reporter` - Will be used to report the processing
|
||||
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
|
||||
pub fn generate<T>(
|
||||
&mut self,
|
||||
resolution: usize,
|
||||
cache_max_size: usize,
|
||||
empty_value: &'data_block str,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) -> GeneratedData<'data_block>
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
info!("starting generation...");
|
||||
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut result: Vec<CsvRecordRef<'data_block>> = Vec::default();
|
||||
let mut records: Vec<CsvRecordRef<'data_block>> = self
|
||||
.seeded_synthesis(resolution, cache_max_size, progress_reporter)
|
||||
.iter()
|
||||
.map(|r| {
|
||||
SynthesizerCacheKey::new(self.data_block.headers.len(), r)
|
||||
.format_record(empty_value)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// sort by number of defined attributes
|
||||
records.sort();
|
||||
records.sort_by_key(|r| {
|
||||
-r.iter()
|
||||
.map(|s| if s.is_empty() { 0 } else { 1 })
|
||||
.sum::<isize>()
|
||||
});
|
||||
|
||||
let expansion_ratio = records.len() as f64 / self.data_block.records.len() as f64;
|
||||
|
||||
info!("expansion ratio: {:.4?}", expansion_ratio);
|
||||
|
||||
result.push(self.data_block.headers.iter().map(|h| h.as_str()).collect());
|
||||
result.extend(records);
|
||||
|
||||
GeneratedData {
|
||||
synthetic_data: result,
|
||||
expansion_ratio,
|
||||
}
|
||||
},
|
||||
(self.durations.generate)
|
||||
)
|
||||
}
|
||||
|
||||
/// Writes the synthesized data to the file system
|
||||
/// # Arguments
|
||||
/// * `result` - Synthetic data to write (headers included)
|
||||
/// * `path` - File path to be written
|
||||
/// * `delimiter` - Delimiter to use when writing to `path`
|
||||
pub fn write_records(
|
||||
&mut self,
|
||||
result: &[CsvRecordRef<'data_block>],
|
||||
path: &str,
|
||||
delimiter: u8,
|
||||
) -> Result<(), Error> {
|
||||
info!("writing file {}", path);
|
||||
|
||||
measure_time!(
|
||||
|| {
|
||||
let mut wtr = WriterBuilder::new().delimiter(delimiter).from_path(&path)?;
|
||||
|
||||
// write header and records
|
||||
for r in result.iter() {
|
||||
wtr.write_record(r)?;
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
(self.durations.write_records)
|
||||
)
|
||||
}
|
||||
|
||||
fn seeded_synthesis<T>(
|
||||
&mut self,
|
||||
resolution: usize,
|
||||
cache_max_size: usize,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) -> SynthesizedRecords<'data_block>
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let attr_rows_map = DataBlock::calc_attr_rows(&self.data_block.records);
|
||||
let mut synth =
|
||||
SeededSynthesizer::new(self.data_block, &attr_rows_map, resolution, cache_max_size);
|
||||
synth.run(progress_reporter)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'data_block> Drop for Generator<'data_block> {
|
||||
fn drop(&mut self) {
|
||||
trace!("generator durations: {:#?}", self.durations);
|
||||
}
|
||||
}
|
||||
/// Module to describe the synthetic data generated by the Generator
|
||||
pub mod generated_data;
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
use std::{fmt::Display, str::FromStr};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Modes to execute the data generation/synthesis
|
||||
pub enum SynthesisMode {
|
||||
Seeded,
|
||||
Unseeded,
|
||||
}
|
||||
|
||||
impl FromStr for SynthesisMode {
|
||||
type Err = &'static str;
|
||||
|
||||
fn from_str(mode: &str) -> Result<Self, Self::Err> {
|
||||
match mode.to_lowercase().as_str() {
|
||||
"seeded" => Ok(SynthesisMode::Seeded),
|
||||
"unseeded" => Ok(SynthesisMode::Unseeded),
|
||||
_ => Err("invalid mode, should be seeded or unseeded"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for SynthesisMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SynthesisMode::Seeded => Ok(write!(f, "seeded")?),
|
||||
SynthesisMode::Unseeded => Ok(write!(f, "unseeded")?),
|
||||
}
|
||||
}
|
||||
}
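// Illustrative sketch, not part of this commit: SynthesisMode parsing is case-insensitive
// and Display yields the lowercase name, which is what the CLI relies on.
let mode: SynthesisMode = "Unseeded".parse().expect("valid mode");
assert_eq!(format!("{}", mode), "unseeded");
assert!("fully-seeded".parse::<SynthesisMode>().is_err());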
|
|
@ -1,6 +1,7 @@
|
|||
use super::typedefs::SynthesizedRecord;
|
||||
use fnv::FnvBuildHasher;
|
||||
use lru::LruCache;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::data_block::{typedefs::CsvRecordRef, value::DataBlockValue};
|
||||
|
||||
|
@ -8,29 +9,26 @@ use crate::data_block::{typedefs::CsvRecordRef, value::DataBlockValue};
|
|||
/// (`columns[{column_index}] = `value` if value exists,
|
||||
/// or `columns[{column_index}] = None` otherwise)
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
|
||||
pub struct SynthesizerCacheKey<'value> {
|
||||
pub struct SynthesizerCacheKey {
|
||||
/// Values for a given synthesized record are indexed by `column_index`
|
||||
/// on this vector
|
||||
columns: Vec<Option<&'value String>>,
|
||||
columns: Vec<Option<Arc<String>>>,
|
||||
}
|
||||
|
||||
impl<'value> SynthesizerCacheKey<'value> {
|
||||
impl SynthesizerCacheKey {
|
||||
/// Returns a new SynthesizerCacheKey
|
||||
/// # Arguments
|
||||
/// * `num_columns` - Number of columns in the data block
|
||||
/// * `records` - Synthesized record to build the key for
|
||||
#[inline]
|
||||
pub fn new(
|
||||
num_columns: usize,
|
||||
values: &SynthesizedRecord<'value>,
|
||||
) -> SynthesizerCacheKey<'value> {
|
||||
pub fn new(num_columns: usize, values: &SynthesizedRecord) -> SynthesizerCacheKey {
|
||||
let mut key = SynthesizerCacheKey {
|
||||
columns: Vec::with_capacity(num_columns),
|
||||
};
|
||||
key.columns.resize_with(num_columns, || None);
|
||||
|
||||
for v in values.iter() {
|
||||
key.columns[v.column_index] = Some(&v.value);
|
||||
key.columns[v.column_index] = Some(v.value.clone());
|
||||
}
|
||||
key
|
||||
}
|
||||
|
@ -40,9 +38,9 @@ impl<'value> SynthesizerCacheKey<'value> {
|
|||
/// # Arguments
|
||||
/// * `value` - New data block value to be added in the new key
|
||||
#[inline]
|
||||
pub fn new_with_value(&self, value: &'value DataBlockValue) -> SynthesizerCacheKey<'value> {
|
||||
pub fn new_with_value(&self, value: &Arc<DataBlockValue>) -> SynthesizerCacheKey {
|
||||
let mut new_key = self.clone();
|
||||
new_key.columns[value.column_index] = Some(&value.value);
|
||||
new_key.columns[value.column_index] = Some(value.value.clone());
|
||||
new_key
|
||||
}
|
||||
|
||||
|
@ -59,29 +57,29 @@ impl<'value> SynthesizerCacheKey<'value> {
|
|||
/// # Arguments
|
||||
/// * `empty_value` - String to be used in case the `columns[column_index] == None`
|
||||
#[inline]
|
||||
pub fn format_record(&self, empty_value: &'value str) -> CsvRecordRef<'value> {
|
||||
pub fn format_record(&self, empty_value: &Arc<String>) -> CsvRecordRef {
|
||||
self.columns
|
||||
.iter()
|
||||
.map(|c_opt| match c_opt {
|
||||
Some(c) => c,
|
||||
None => empty_value,
|
||||
Some(c) => (*c).clone(),
|
||||
None => empty_value.clone(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Cache to store keys-values used during the synthesis process
|
||||
pub struct SynthesizerCache<'value, T> {
|
||||
pub struct SynthesizerCache<T> {
|
||||
/// LruCache to store the keys mapping to a generic type T
|
||||
cache: LruCache<SynthesizerCacheKey<'value>, T, FnvBuildHasher>,
|
||||
cache: LruCache<SynthesizerCacheKey, T, FnvBuildHasher>,
|
||||
}
|
||||
|
||||
impl<'value, T> SynthesizerCache<'value, T> {
|
||||
impl<T> SynthesizerCache<T> {
|
||||
/// Returns a new SynthesizerCache
|
||||
/// # Arguments
|
||||
/// * `cache_max_size` - Maximum cache size allowed for the LRU
|
||||
#[inline]
|
||||
pub fn new(cache_max_size: usize) -> SynthesizerCache<'value, T> {
|
||||
pub fn new(cache_max_size: usize) -> SynthesizerCache<T> {
|
||||
SynthesizerCache {
|
||||
cache: LruCache::with_hasher(cache_max_size, FnvBuildHasher::default()),
|
||||
}
|
||||
|
@ -93,7 +91,7 @@ impl<'value, T> SynthesizerCache<'value, T> {
|
|||
/// # Arguments
|
||||
/// * `key` - Key to look for the value
|
||||
#[inline]
|
||||
pub fn get(&mut self, key: &SynthesizerCacheKey<'value>) -> Option<&T> {
|
||||
pub fn get(&mut self, key: &SynthesizerCacheKey) -> Option<&T> {
|
||||
self.cache.get(key)
|
||||
}
|
||||
|
||||
|
@ -104,7 +102,7 @@ impl<'value, T> SynthesizerCache<'value, T> {
|
|||
/// * `key` - Key to be inserted
|
||||
/// * `value` - Value to associated with the key
|
||||
#[inline]
|
||||
pub fn insert(&mut self, key: SynthesizerCacheKey<'value>, value: T) -> Option<T> {
|
||||
pub fn insert(&mut self, key: SynthesizerCacheKey, value: T) -> Option<T> {
|
||||
self.cache.put(key, value)
|
||||
}
|
||||
}
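// Illustrative sketch of how the cache key and LRU cache above fit together,
// assuming the types in this file are in scope; the capacity of 100 and the
// cached value of 42 are arbitrary.
fn cache_roundtrip(num_columns: usize, record: &SynthesizedRecord) -> Option<usize> {
    let mut cache: SynthesizerCache<usize> = SynthesizerCache::new(100);

    // keys are built from the synthesized record and hashed by value
    cache.insert(SynthesizerCacheKey::new(num_columns, record), 42);

    // an equal key built later hits the same entry
    cache.get(&SynthesizerCacheKey::new(num_columns, record)).copied()
}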
|
||||
|
|
|
@ -2,93 +2,76 @@ use super::cache::{SynthesizerCache, SynthesizerCacheKey};
|
|||
use super::typedefs::{
|
||||
AttributeCountMap, NotAllowedAttrSet, SynthesizedRecord, SynthesizerSeedSlice,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::data_block::typedefs::{AttributeRows, AttributeRowsMap, AttributeRowsSlice};
|
||||
use crate::data_block::typedefs::{
|
||||
AttributeRows, AttributeRowsByColumnMap, AttributeRowsMap, AttributeRowsRefMap,
|
||||
AttributeRowsSlice,
|
||||
};
|
||||
use crate::data_block::value::DataBlockValue;
|
||||
use crate::processing::aggregator::typedefs::RecordsSet;
|
||||
use crate::utils::collections::ordered_vec_intersection;
|
||||
|
||||
/// Represents a synthesis context, containing the information
|
||||
/// required to synthesize new records
|
||||
/// (common to seeded and unseeded synthesis)
|
||||
pub struct SynthesizerContext<'data_block, 'attr_rows_map> {
|
||||
pub struct SynthesizerContext {
|
||||
/// Number of headers in the data block
|
||||
headers_len: usize,
|
||||
pub headers_len: usize,
|
||||
/// Number of records in the data block
|
||||
records_len: usize,
|
||||
/// Maps a data block value to all the rows where it occurs
|
||||
attr_rows_map: &'attr_rows_map AttributeRowsMap<'data_block>,
|
||||
pub records_len: usize,
|
||||
/// Cache used to store the rows where value combinations occur
|
||||
cache: SynthesizerCache<'data_block, AttributeRows>,
|
||||
cache: SynthesizerCache<Arc<AttributeRows>>,
|
||||
/// Reporting resolution used for data synthesis
|
||||
resolution: usize,
|
||||
}
|
||||
|
||||
impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map> {
|
||||
impl SynthesizerContext {
|
||||
/// Returns a new SynthesizerContext
|
||||
/// # Arguments
|
||||
/// * `headers_len` - Number of headers in the data block
|
||||
/// * `records_len` - Number of records in the data block
|
||||
/// * `attr_rows_map` - Maps a data block value to all the rows where it occurs
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
/// * `cache_max_size` - Maximum cache size allowed
|
||||
#[inline]
|
||||
pub fn new(
|
||||
headers_len: usize,
|
||||
records_len: usize,
|
||||
attr_rows_map: &'attr_rows_map AttributeRowsMap<'data_block>,
|
||||
resolution: usize,
|
||||
cache_max_size: usize,
|
||||
) -> SynthesizerContext<'data_block, 'attr_rows_map> {
|
||||
) -> SynthesizerContext {
|
||||
SynthesizerContext {
|
||||
headers_len,
|
||||
records_len,
|
||||
attr_rows_map,
|
||||
cache: SynthesizerCache::new(cache_max_size),
|
||||
resolution,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the configured mapping of a data block value to all
|
||||
/// the rows where it occurs
|
||||
#[inline]
|
||||
pub fn get_attr_rows_map(&self) -> &AttributeRowsMap<'data_block> {
|
||||
self.attr_rows_map
|
||||
}
|
||||
|
||||
/// Returns the configured reporting resolution
|
||||
#[inline]
|
||||
pub fn get_resolution(&self) -> usize {
|
||||
self.resolution
|
||||
}
|
||||
|
||||
/// Samples an attribute based on its count
|
||||
/// (the higher the count the greater the chance for the
|
||||
/// attribute to be selected).
|
||||
/// Returns `None` if all the counts are 0 or the map is empty
|
||||
/// # Arguments
|
||||
/// * `counts` - Maps an attribute to its count for sampling
|
||||
pub fn sample_from_attr_counts(
|
||||
counts: &AttributeCountMap<'data_block>,
|
||||
) -> Option<&'data_block DataBlockValue> {
|
||||
let mut res: Option<&'data_block DataBlockValue> = None;
|
||||
pub fn sample_from_attr_counts(counts: &AttributeCountMap) -> Option<Arc<DataBlockValue>> {
|
||||
let mut res: Option<Arc<DataBlockValue>> = None;
|
||||
let total: usize = counts.values().sum();
|
||||
|
||||
if total != 0 {
|
||||
let random = if total == 1 {
|
||||
1
|
||||
} else {
|
||||
rand::thread_rng().gen_range(1..total)
|
||||
};
|
||||
let mut current_sim: usize = 0;
|
||||
let random = rand::thread_rng().gen_range(1..=total);
|
||||
let mut current_sum: usize = 0;
|
||||
|
||||
for (value, count) in counts.iter() {
|
||||
current_sim += count;
|
||||
if current_sim < random {
|
||||
res = Some(value);
|
||||
break;
|
||||
for (value, count) in counts.iter().sorted_by_key(|(_, c)| **c) {
|
||||
if *count > 0 {
|
||||
current_sum += count;
|
||||
res = Some(value.clone());
|
||||
if current_sum >= random {
|
||||
break;
|
||||
}
|
||||
}
|
||||
res = Some(value)
|
||||
}
|
||||
}
|
||||
res
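// Standalone sketch of the count-weighted sampling idea used by
// `sample_from_attr_counts` above, shown with plain string keys instead of
// `AttributeCountMap`; `rand::Rng` is already in scope from the imports at the
// top of this file. Assumes the map is small enough that a linear walk is fine.
fn sample_weighted<'a>(counts: &std::collections::HashMap<&'a str, usize>) -> Option<&'a str> {
    let total: usize = counts.values().sum();
    if total == 0 {
        return None;
    }
    // pick a point in [1, total] and walk the running sum until we reach it
    let target = rand::thread_rng().gen_range(1..=total);
    let mut running = 0;
    for (&value, &count) in counts.iter() {
        running += count;
        if running >= target {
            return Some(value);
        }
    }
    None
}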
|
||||
|
@ -100,52 +83,130 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map
|
|||
/// * `synthesized_record` - Record synthesized so far
|
||||
/// * `current_seed` - Current seed/record used for sampling
|
||||
/// * `not_allowed_attr_set` - Attributes not allowed to be sampled
|
||||
pub fn sample_next_attr(
|
||||
/// * `attr_rows_map` - Maps a data block value to all the rows where it occurs
|
||||
pub fn sample_next_attr_from_seed(
|
||||
&mut self,
|
||||
synthesized_record: &SynthesizedRecord<'data_block>,
|
||||
current_seed: &SynthesizerSeedSlice<'data_block>,
|
||||
not_allowed_attr_set: &NotAllowedAttrSet<'data_block>,
|
||||
) -> Option<&'data_block DataBlockValue> {
|
||||
let counts =
|
||||
self.calc_next_attr_count(synthesized_record, current_seed, not_allowed_attr_set);
|
||||
synthesized_record: &SynthesizedRecord,
|
||||
current_seed: &SynthesizerSeedSlice,
|
||||
not_allowed_attr_set: &NotAllowedAttrSet,
|
||||
attr_rows_map: &AttributeRowsMap,
|
||||
) -> Option<Arc<DataBlockValue>> {
|
||||
let counts = self.calc_next_attr_count(
|
||||
synthesized_record,
|
||||
current_seed,
|
||||
not_allowed_attr_set,
|
||||
attr_rows_map,
|
||||
);
|
||||
SynthesizerContext::sample_from_attr_counts(&counts)
|
||||
}
|
||||
|
||||
/// Samples the next attribute from the column at `column_index`.
|
||||
/// Returns a tuple with
|
||||
/// (intersection of `current_attrs_rows` and the rows where the sampled value appears, sampled value)
|
||||
/// # Arguments
|
||||
/// * `synthesized_record` - Record synthesized so far
|
||||
/// * `column_index` - Index of the column to sample from
|
||||
/// * `attr_rows_map_by_column` - Maps the column index -> data block value -> rows where the value appears
/// * `current_attrs_rows` - Rows where the combination sampled so far appears
|
||||
/// * `empty_value` - Empty values on the synthetic data will be represented by this
|
||||
pub fn sample_next_attr_from_column(
|
||||
&mut self,
|
||||
synthesized_record: &SynthesizedRecord,
|
||||
column_index: usize,
|
||||
attr_rows_map_by_column: &AttributeRowsByColumnMap,
|
||||
current_attrs_rows: &AttributeRowsSlice,
|
||||
empty_value: &Arc<String>,
|
||||
) -> Option<(Arc<AttributeRows>, Arc<DataBlockValue>)> {
|
||||
let cache_key = SynthesizerCacheKey::new(self.headers_len, synthesized_record);
|
||||
let empty_block_value = Arc::new(DataBlockValue::new(column_index, empty_value.clone()));
|
||||
let mut values_to_sample: AttributeRowsRefMap = AttributeRowsRefMap::default();
|
||||
let mut counts = AttributeCountMap::default();
|
||||
|
||||
// calculate the row intersections for each value in the column
|
||||
for (value, rows) in attr_rows_map_by_column[&column_index].iter() {
|
||||
let new_cache_key = cache_key.new_with_value(value);
|
||||
let rows_intersection = match self.cache.get(&new_cache_key) {
|
||||
Some(cached_value) => cached_value.clone(),
|
||||
None => {
|
||||
let intersection = Arc::new(ordered_vec_intersection(current_attrs_rows, rows));
|
||||
self.cache.insert(new_cache_key, intersection.clone());
|
||||
intersection
|
||||
}
|
||||
};
|
||||
values_to_sample.insert(value.clone(), rows_intersection);
|
||||
}
|
||||
|
||||
// store the rows with empty values
|
||||
let mut rows_with_empty_values: RecordsSet = values_to_sample[&empty_block_value]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
for (value, rows) in values_to_sample.iter() {
|
||||
if rows.len() < self.resolution {
|
||||
// if the combination containing the attribute appears in fewer
// than `resolution` rows, we can't use it, so we tag it as an empty value
|
||||
rows_with_empty_values.extend(rows.iter());
|
||||
} else if **value != *empty_block_value {
|
||||
// if we can use the combination containing the attribute
|
||||
// gather its count for sampling
|
||||
counts.insert(value.clone(), rows.len());
|
||||
}
|
||||
}
|
||||
|
||||
// if there are empty values that can be sampled, add them for sampling
|
||||
if !rows_with_empty_values.is_empty() {
|
||||
counts.insert(empty_block_value, rows_with_empty_values.len());
|
||||
}
|
||||
|
||||
Self::sample_from_attr_counts(&counts).map(|sampled_value| {
|
||||
(
|
||||
values_to_sample.remove(&sampled_value).unwrap(),
|
||||
sampled_value,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn calc_current_attrs_rows(
|
||||
&mut self,
|
||||
cache_key: &SynthesizerCacheKey<'data_block>,
|
||||
synthesized_record: &SynthesizedRecord<'data_block>,
|
||||
) -> AttributeRows {
|
||||
cache_key: &SynthesizerCacheKey,
|
||||
synthesized_record: &SynthesizedRecord,
|
||||
attr_rows_map: &AttributeRowsMap,
|
||||
) -> Arc<AttributeRows> {
|
||||
match self.cache.get(cache_key) {
|
||||
Some(cache_value) => cache_value.clone(),
|
||||
None => {
|
||||
let mut current_attrs_rows: AttributeRows = AttributeRows::new();
|
||||
let current_attrs_rows_arc: Arc<AttributeRows>;
|
||||
|
||||
if !synthesized_record.is_empty() {
|
||||
for sr in synthesized_record.iter() {
|
||||
current_attrs_rows = if current_attrs_rows.is_empty() {
|
||||
self.attr_rows_map[sr].clone()
|
||||
attr_rows_map[sr].clone()
|
||||
} else {
|
||||
ordered_vec_intersection(¤t_attrs_rows, &self.attr_rows_map[sr])
|
||||
ordered_vec_intersection(¤t_attrs_rows, &attr_rows_map[sr])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
current_attrs_rows = (0..self.records_len).collect();
|
||||
}
|
||||
current_attrs_rows_arc = Arc::new(current_attrs_rows);
|
||||
|
||||
self.cache
|
||||
.insert(cache_key.clone(), current_attrs_rows.clone());
|
||||
current_attrs_rows
|
||||
.insert(cache_key.clone(), current_attrs_rows_arc.clone());
|
||||
current_attrs_rows_arc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_attr_count_map(
|
||||
&mut self,
|
||||
cache_key: &mut SynthesizerCacheKey<'data_block>,
|
||||
current_seed: &SynthesizerSeedSlice<'data_block>,
|
||||
cache_key: &SynthesizerCacheKey,
|
||||
current_seed: &SynthesizerSeedSlice,
|
||||
current_attrs_rows: &AttributeRowsSlice,
|
||||
not_allowed_attr_set: &NotAllowedAttrSet<'data_block>,
|
||||
) -> AttributeCountMap<'data_block> {
|
||||
not_allowed_attr_set: &NotAllowedAttrSet,
|
||||
attr_rows_map: &AttributeRowsMap,
|
||||
) -> AttributeCountMap {
|
||||
let mut attr_count_map: AttributeCountMap = AttributeCountMap::default();
|
||||
|
||||
for value in current_seed.iter() {
|
||||
|
@ -156,10 +217,10 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map
|
|||
let count = match self.cache.get(&new_cache_key) {
|
||||
Some(cached_value) => cached_value.len(),
|
||||
None => {
|
||||
let intersection = ordered_vec_intersection(
|
||||
let intersection = Arc::new(ordered_vec_intersection(
|
||||
current_attrs_rows,
|
||||
self.attr_rows_map.get(value).unwrap(),
|
||||
);
|
||||
attr_rows_map.get(value).unwrap(),
|
||||
));
|
||||
let count = intersection.len();
|
||||
|
||||
self.cache.insert(new_cache_key, intersection);
|
||||
|
@ -168,7 +229,7 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map
|
|||
};
|
||||
|
||||
if count >= self.resolution {
|
||||
attr_count_map.insert(value, count);
|
||||
attr_count_map.insert(value.clone(), count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -177,18 +238,21 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map
|
|||
|
||||
fn calc_next_attr_count(
|
||||
&mut self,
|
||||
synthesized_record: &SynthesizedRecord<'data_block>,
|
||||
current_seed: &SynthesizerSeedSlice<'data_block>,
|
||||
not_allowed_attr_set: &NotAllowedAttrSet<'data_block>,
|
||||
) -> AttributeCountMap<'data_block> {
|
||||
let mut cache_key = SynthesizerCacheKey::new(self.headers_len, synthesized_record);
|
||||
let current_attrs_rows = self.calc_current_attrs_rows(&cache_key, synthesized_record);
|
||||
synthesized_record: &SynthesizedRecord,
|
||||
current_seed: &SynthesizerSeedSlice,
|
||||
not_allowed_attr_set: &NotAllowedAttrSet,
|
||||
attr_rows_map: &AttributeRowsMap,
|
||||
) -> AttributeCountMap {
|
||||
let cache_key = SynthesizerCacheKey::new(self.headers_len, synthesized_record);
|
||||
let current_attrs_rows =
|
||||
self.calc_current_attrs_rows(&cache_key, synthesized_record, attr_rows_map);
|
||||
|
||||
self.gen_attr_count_map(
|
||||
&mut cache_key,
|
||||
&cache_key,
|
||||
current_seed,
|
||||
¤t_attrs_rows,
|
||||
not_allowed_attr_set,
|
||||
attr_rows_map,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,18 @@
|
|||
/// Module defining the cache used during the synthesis process
pub mod cache;

/// Module defining the context used for data synthesis
pub mod context;

/// Module defining the seeded synthesis process
pub mod seeded;

/// Module defining the unseeded synthesis process
pub mod unseeded;

/// Type definitions related to the synthesis process
pub mod typedefs;

mod seeded_rows_synthesizer;

mod unseeded_rows_synthesizer;
|
||||
|
|
|
@ -1,56 +1,46 @@
|
|||
use super::{
|
||||
context::SynthesizerContext,
|
||||
seeded_rows_synthesizer::SeededRowsSynthesizer,
|
||||
typedefs::{
|
||||
AvailableAttrsMap, NotAllowedAttrSet, SynthesizedRecord, SynthesizedRecords,
|
||||
SynthesizedRecordsSlice, SynthesizerSeed, SynthesizerSeedSlice,
|
||||
},
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use itertools::izip;
|
||||
use log::Level::Trace;
|
||||
use log::{info, log_enabled, trace};
|
||||
use itertools::{izip, Itertools};
|
||||
use log::info;
|
||||
use rand::{prelude::SliceRandom, thread_rng};
|
||||
use std::time::Duration;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
data_block::{
|
||||
block::DataBlock,
|
||||
record::DataBlockRecord,
|
||||
typedefs::{AttributeRowsMap, DataBlockRecords},
|
||||
value::DataBlockValue,
|
||||
},
|
||||
measure_time,
|
||||
data_block::{block::DataBlock, typedefs::AttributeRowsMap, value::DataBlockValue},
|
||||
utils::{
|
||||
math::{calc_percentage, iround_down},
|
||||
reporting::ReportProgress,
|
||||
time::ElapsedDuration,
|
||||
threading::get_number_of_threads,
|
||||
time::ElapsedDurationLogger,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SeededSynthesizerDurations {
|
||||
consolidate: Duration,
|
||||
suppress: Duration,
|
||||
synthesize_rows: Duration,
|
||||
}
|
||||
|
||||
/// Represents all the information required to perform the seeded data synthesis
|
||||
pub struct SeededSynthesizer<'data_block, 'attr_rows_map> {
|
||||
/// Records to be synthesized
|
||||
pub records: &'data_block DataBlockRecords,
|
||||
/// Context used for the synthesis process
|
||||
pub context: SynthesizerContext<'data_block, 'attr_rows_map>,
|
||||
pub struct SeededSynthesizer {
|
||||
/// Reference to the original data block
|
||||
data_block: Arc<DataBlock>,
|
||||
/// Maps a data block value to all the rows where it occurs
|
||||
attr_rows_map: Arc<AttributeRowsMap>,
|
||||
/// Reporting resolution used for data synthesis
|
||||
resolution: usize,
|
||||
/// Maximum cache size allowed
|
||||
cache_max_size: usize,
|
||||
/// Percentage already completed on the row synthesis step
|
||||
synthesize_percentage: f64,
|
||||
/// Percentage already completed on the consolidation step
|
||||
consolidate_percentage: f64,
|
||||
/// Percentage already completed on the suppression step
|
||||
suppress_percentage: f64,
|
||||
/// Elapsed durations on each step for benchmarking
|
||||
durations: SeededSynthesizerDurations,
|
||||
}
|
||||
|
||||
impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> {
|
||||
impl SeededSynthesizer {
|
||||
/// Returns a new SeededSynthesizer
|
||||
/// # Arguments
|
||||
/// * `data_block` - Sensitive data to be synthesized
|
||||
|
@ -59,28 +49,19 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
/// * `cache_max_size` - Maximum cache size allowed
|
||||
#[inline]
|
||||
pub fn new(
|
||||
data_block: &'data_block DataBlock,
|
||||
attr_rows_map: &'attr_rows_map AttributeRowsMap<'data_block>,
|
||||
data_block: Arc<DataBlock>,
|
||||
attr_rows_map: Arc<AttributeRowsMap>,
|
||||
resolution: usize,
|
||||
cache_max_size: usize,
|
||||
) -> SeededSynthesizer<'data_block, 'attr_rows_map> {
|
||||
) -> SeededSynthesizer {
|
||||
SeededSynthesizer {
|
||||
records: &data_block.records,
|
||||
context: SynthesizerContext::new(
|
||||
data_block.headers.len(),
|
||||
data_block.records.len(),
|
||||
attr_rows_map,
|
||||
resolution,
|
||||
cache_max_size,
|
||||
),
|
||||
data_block,
|
||||
attr_rows_map,
|
||||
resolution,
|
||||
cache_max_size,
|
||||
synthesize_percentage: 0.0,
|
||||
consolidate_percentage: 0.0,
|
||||
suppress_percentage: 0.0,
|
||||
durations: SeededSynthesizerDurations {
|
||||
consolidate: Duration::default(),
|
||||
suppress: Duration::default(),
|
||||
synthesize_rows: Duration::default(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -89,64 +70,71 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
/// # Arguments
|
||||
/// * `progress_reporter` - Will be used to report the processing
|
||||
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
|
||||
pub fn run<T>(&mut self, progress_reporter: &mut Option<T>) -> SynthesizedRecords<'data_block>
|
||||
pub fn run<T>(&mut self, progress_reporter: &mut Option<T>) -> SynthesizedRecords
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let mut synthesized_records: SynthesizedRecords<'data_block> = SynthesizedRecords::new();
|
||||
let mut synthesized_records: SynthesizedRecords = SynthesizedRecords::new();
|
||||
|
||||
self.synthesize_percentage = 0.0;
|
||||
self.consolidate_percentage = 0.0;
|
||||
self.suppress_percentage = 0.0;
|
||||
if !self.data_block.records.is_empty() {
|
||||
let mut rows_synthesizers: Vec<SeededRowsSynthesizer> = self.build_rows_synthesizers();
|
||||
|
||||
self.synthesize_rows(&mut synthesized_records, progress_reporter);
|
||||
self.consolidate(&mut synthesized_records, progress_reporter);
|
||||
self.suppress(&mut synthesized_records, progress_reporter);
|
||||
self.synthesize_percentage = 0.0;
|
||||
self.consolidate_percentage = 0.0;
|
||||
self.suppress_percentage = 0.0;
|
||||
|
||||
self.synthesize_rows(
|
||||
&mut synthesized_records,
|
||||
&mut rows_synthesizers,
|
||||
progress_reporter,
|
||||
);
|
||||
self.consolidate(
|
||||
&mut synthesized_records,
|
||||
progress_reporter,
|
||||
// use the first context to leverage already cached intersections
|
||||
&mut rows_synthesizers[0].context,
|
||||
);
|
||||
self.suppress(&mut synthesized_records, progress_reporter);
|
||||
}
|
||||
synthesized_records
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn synthesize_row(
|
||||
&mut self,
|
||||
seed: &'data_block DataBlockRecord,
|
||||
) -> SynthesizedRecord<'data_block> {
|
||||
let current_seed: SynthesizerSeed = seed.values.iter().collect();
|
||||
let mut synthesized_record = SynthesizedRecord::default();
|
||||
let not_allowed_attr_set = NotAllowedAttrSet::default();
|
||||
fn build_rows_synthesizers(&self) -> Vec<SeededRowsSynthesizer> {
|
||||
let chunk_size = ((self.data_block.records.len() as f64) / (get_number_of_threads() as f64))
|
||||
.ceil() as usize;
|
||||
let mut rows_synthesizers: Vec<SeededRowsSynthesizer> = Vec::default();
|
||||
|
||||
loop {
|
||||
let next = self.context.sample_next_attr(
|
||||
&synthesized_record,
|
||||
¤t_seed,
|
||||
¬_allowed_attr_set,
|
||||
);
|
||||
|
||||
match next {
|
||||
Some(value) => {
|
||||
synthesized_record.insert(value);
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
for c in &self.data_block.records.iter().chunks(chunk_size) {
|
||||
rows_synthesizers.push(SeededRowsSynthesizer::new(
|
||||
SynthesizerContext::new(
|
||||
self.data_block.headers.len(),
|
||||
self.data_block.records.len(),
|
||||
self.resolution,
|
||||
self.cache_max_size,
|
||||
),
|
||||
c.cloned().collect(),
|
||||
self.attr_rows_map.clone(),
|
||||
));
|
||||
}
|
||||
synthesized_record
|
||||
rows_synthesizers
|
||||
}
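// Standalone sketch of the chunking strategy used by `build_rows_synthesizers`
// above: records are split into chunks of ceil(record_count / threads), one
// chunk per worker. Assumes `threads > 0`; std-only and illustrative.
fn chunk_counts(record_count: usize, threads: usize) -> Vec<usize> {
    let chunk_size = ((record_count as f64) / (threads as f64)).ceil() as usize;
    let mut remaining = record_count;
    let mut chunks = Vec::new();

    while remaining > chunk_size {
        chunks.push(chunk_size);
        remaining -= chunk_size;
    }
    chunks.push(remaining);
    chunks
}
// e.g. chunk_counts(10, 4) == vec![3, 3, 3, 1]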
|
||||
|
||||
#[inline]
|
||||
fn count_not_used_attrs(
|
||||
&mut self,
|
||||
synthesized_records: &SynthesizedRecordsSlice<'data_block>,
|
||||
) -> AvailableAttrsMap<'data_block> {
|
||||
synthesized_records: &SynthesizedRecordsSlice,
|
||||
) -> AvailableAttrsMap {
|
||||
let mut available_attrs: AvailableAttrsMap = AvailableAttrsMap::default();
|
||||
|
||||
// go through the pairs (original_record, synthesized_record) and count how many
|
||||
// attributes were not used
|
||||
for (original_record, synthesized_record) in
|
||||
izip!(self.records.iter(), synthesized_records.iter())
|
||||
izip!(self.data_block.records.iter(), synthesized_records.iter())
|
||||
{
|
||||
for d in original_record.values.iter() {
|
||||
if !synthesized_record.contains(d) {
|
||||
let attr = available_attrs.entry(d).or_insert(0);
|
||||
let attr = available_attrs.entry(d.clone()).or_insert(0);
|
||||
*attr += 1;
|
||||
}
|
||||
}
|
||||
|
@ -156,16 +144,16 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
|
||||
fn calc_available_attrs(
|
||||
&mut self,
|
||||
synthesized_records: &SynthesizedRecordsSlice<'data_block>,
|
||||
) -> AvailableAttrsMap<'data_block> {
|
||||
synthesized_records: &SynthesizedRecordsSlice,
|
||||
) -> AvailableAttrsMap {
|
||||
let mut available_attrs = self.count_not_used_attrs(synthesized_records);
|
||||
let resolution_f64 = self.context.get_resolution() as f64;
|
||||
let resolution_f64 = self.resolution as f64;
|
||||
|
||||
// add attributes for consolidation
|
||||
for (attr, value) in self.context.get_attr_rows_map().iter() {
|
||||
for (attr, value) in self.attr_rows_map.iter() {
|
||||
let n_rows = value.len();
|
||||
|
||||
if n_rows >= self.context.get_resolution() {
|
||||
if n_rows >= self.resolution {
|
||||
let target_attr_count = iround_down(n_rows as f64, resolution_f64)
|
||||
- (n_rows as isize)
|
||||
+ match available_attrs.get(attr) {
|
||||
|
@ -175,7 +163,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
|
||||
if target_attr_count > 0 {
|
||||
// insert/update the final target count
|
||||
available_attrs.insert(attr, target_attr_count);
|
||||
available_attrs.insert(attr.clone(), target_attr_count);
|
||||
} else {
|
||||
// remove negative and zero values
|
||||
available_attrs.remove(attr);
|
||||
|
@ -192,18 +180,17 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
#[inline]
|
||||
fn calc_not_allowed_attrs(
|
||||
&mut self,
|
||||
available_attrs: &mut AvailableAttrsMap<'data_block>,
|
||||
) -> NotAllowedAttrSet<'data_block> {
|
||||
self.context
|
||||
.get_attr_rows_map()
|
||||
available_attrs: &mut AvailableAttrsMap,
|
||||
) -> NotAllowedAttrSet {
|
||||
self.attr_rows_map
|
||||
.keys()
|
||||
.filter_map(|attr| match available_attrs.get(attr) {
|
||||
// not on available attributes
|
||||
None => Some(*attr),
|
||||
None => Some(attr.clone()),
|
||||
Some(at) => {
|
||||
if *at <= 0 {
|
||||
// on available attributes, but count <= 0
|
||||
Some(*attr)
|
||||
Some(attr.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -215,18 +202,20 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
#[inline]
|
||||
fn consolidate_record(
|
||||
&mut self,
|
||||
available_attrs: &mut AvailableAttrsMap<'data_block>,
|
||||
current_seed: &SynthesizerSeedSlice<'data_block>,
|
||||
) -> SynthesizedRecord<'data_block> {
|
||||
available_attrs: &mut AvailableAttrsMap,
|
||||
current_seed: &SynthesizerSeedSlice,
|
||||
context: &mut SynthesizerContext,
|
||||
) -> SynthesizedRecord {
|
||||
let mut not_allowed_attr_set: NotAllowedAttrSet =
|
||||
self.calc_not_allowed_attrs(available_attrs);
|
||||
let mut synthesized_record = SynthesizedRecord::default();
|
||||
|
||||
loop {
|
||||
let next = self.context.sample_next_attr(
|
||||
let next = context.sample_next_attr_from_seed(
|
||||
&synthesized_record,
|
||||
current_seed,
|
||||
¬_allowed_attr_set,
|
||||
&self.attr_rows_map,
|
||||
);
|
||||
|
||||
match next {
|
||||
|
@ -236,9 +225,9 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
|
||||
if next_count <= 1 {
|
||||
available_attrs.remove(&value);
|
||||
not_allowed_attr_set.insert(value);
|
||||
not_allowed_attr_set.insert(value.clone());
|
||||
} else {
|
||||
available_attrs.insert(value, next_count - 1);
|
||||
available_attrs.insert(value.clone(), next_count - 1);
|
||||
}
|
||||
synthesized_record.insert(value);
|
||||
}
|
||||
|
@ -249,44 +238,43 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
|
||||
fn consolidate<T>(
|
||||
&mut self,
|
||||
synthesized_records: &mut SynthesizedRecords<'data_block>,
|
||||
synthesized_records: &mut SynthesizedRecords,
|
||||
progress_reporter: &mut Option<T>,
|
||||
context: &mut SynthesizerContext,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
measure_time!(
|
||||
|| {
|
||||
info!("consolidating...");
|
||||
let _duration_logger = ElapsedDurationLogger::new("consolidation");
|
||||
info!("consolidating...");
|
||||
|
||||
let mut available_attrs = self.calc_available_attrs(synthesized_records);
|
||||
let current_seed: SynthesizerSeed = available_attrs.keys().cloned().collect();
|
||||
let total = available_attrs.len();
|
||||
let total_f64 = total as f64;
|
||||
let mut n_processed = 0;
|
||||
let mut available_attrs = self.calc_available_attrs(synthesized_records);
|
||||
let current_seed: SynthesizerSeed = available_attrs.keys().cloned().collect();
|
||||
let total = available_attrs.len();
|
||||
let total_f64 = total as f64;
|
||||
let mut n_processed = 0;
|
||||
|
||||
while !available_attrs.is_empty() {
|
||||
self.update_consolidate_progress(n_processed, total_f64, progress_reporter);
|
||||
synthesized_records
|
||||
.push(self.consolidate_record(&mut available_attrs, ¤t_seed));
|
||||
n_processed = total - available_attrs.len();
|
||||
}
|
||||
self.update_consolidate_progress(n_processed, total_f64, progress_reporter);
|
||||
},
|
||||
(self.durations.consolidate)
|
||||
);
|
||||
while !available_attrs.is_empty() {
|
||||
self.update_consolidate_progress(n_processed, total_f64, progress_reporter);
|
||||
synthesized_records.push(self.consolidate_record(
|
||||
&mut available_attrs,
|
||||
¤t_seed,
|
||||
context,
|
||||
));
|
||||
n_processed = total - available_attrs.len();
|
||||
}
|
||||
self.update_consolidate_progress(n_processed, total_f64, progress_reporter);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn count_synthesized_records_attrs(
|
||||
&mut self,
|
||||
synthesized_records: &mut SynthesizedRecords<'data_block>,
|
||||
) -> FnvHashMap<&'data_block DataBlockValue, isize> {
|
||||
let mut current_counts: FnvHashMap<&'data_block DataBlockValue, isize> =
|
||||
FnvHashMap::default();
|
||||
synthesized_records: &mut SynthesizedRecords,
|
||||
) -> FnvHashMap<Arc<DataBlockValue>, isize> {
|
||||
let mut current_counts: FnvHashMap<Arc<DataBlockValue>, isize> = FnvHashMap::default();
|
||||
|
||||
for r in synthesized_records.iter() {
|
||||
for v in r.iter() {
|
||||
let count = current_counts.entry(v).or_insert(0);
|
||||
let count = current_counts.entry(v.clone()).or_insert(0);
|
||||
*count += 1;
|
||||
}
|
||||
}
|
||||
|
@ -296,16 +284,16 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
#[inline]
|
||||
fn calc_exceeded_count_attrs(
|
||||
&mut self,
|
||||
current_counts: &FnvHashMap<&'data_block DataBlockValue, isize>,
|
||||
) -> FnvHashMap<&'data_block DataBlockValue, isize> {
|
||||
let mut targets: FnvHashMap<&'data_block DataBlockValue, isize> = FnvHashMap::default();
|
||||
current_counts: &FnvHashMap<Arc<DataBlockValue>, isize>,
|
||||
) -> FnvHashMap<Arc<DataBlockValue>, isize> {
|
||||
let mut targets: FnvHashMap<Arc<DataBlockValue>, isize> = FnvHashMap::default();
|
||||
|
||||
for (attr, rows) in self.context.get_attr_rows_map().iter() {
|
||||
if rows.len() >= self.context.get_resolution() {
|
||||
let t = current_counts[attr]
|
||||
- iround_down(rows.len() as f64, self.context.get_resolution() as f64);
|
||||
for (attr, rows) in self.attr_rows_map.iter() {
|
||||
if rows.len() >= self.resolution {
|
||||
let t =
|
||||
current_counts[attr] - iround_down(rows.len() as f64, self.resolution as f64);
|
||||
if t > 0 {
|
||||
targets.insert(attr, t);
|
||||
targets.insert(attr.clone(), t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -314,76 +302,73 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
|
||||
fn suppress<T>(
|
||||
&mut self,
|
||||
synthesized_records: &mut SynthesizedRecords<'data_block>,
|
||||
synthesized_records: &mut SynthesizedRecords,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
measure_time!(
|
||||
|| {
|
||||
info!("suppressing...");
|
||||
let _duration_logger = ElapsedDurationLogger::new("suppression");
|
||||
info!("suppressing...");
|
||||
|
||||
let current_counts: FnvHashMap<&'data_block DataBlockValue, isize> =
|
||||
self.count_synthesized_records_attrs(synthesized_records);
|
||||
let mut targets: FnvHashMap<&'data_block DataBlockValue, isize> =
|
||||
self.calc_exceeded_count_attrs(¤t_counts);
|
||||
let total = synthesized_records.len() as f64;
|
||||
let mut n_processed = 0;
|
||||
let current_counts: FnvHashMap<Arc<DataBlockValue>, isize> =
|
||||
self.count_synthesized_records_attrs(synthesized_records);
|
||||
let mut targets: FnvHashMap<Arc<DataBlockValue>, isize> =
|
||||
self.calc_exceeded_count_attrs(¤t_counts);
|
||||
let total = synthesized_records.len() as f64;
|
||||
let mut n_processed = 0;
|
||||
|
||||
synthesized_records.shuffle(&mut thread_rng());
|
||||
synthesized_records.shuffle(&mut thread_rng());
|
||||
|
||||
for r in synthesized_records.iter_mut() {
|
||||
let mut new_record = SynthesizedRecord::default();
|
||||
for r in synthesized_records.iter_mut() {
|
||||
let mut new_record = SynthesizedRecord::default();
|
||||
|
||||
self.update_suppress_progress(n_processed, total, progress_reporter);
|
||||
n_processed += 1;
|
||||
for attr in r.iter() {
|
||||
match targets.get(attr).cloned() {
|
||||
None => {
|
||||
new_record.insert(attr);
|
||||
}
|
||||
Some(attr_count) => {
|
||||
if attr_count == 1 {
|
||||
targets.remove(attr);
|
||||
} else {
|
||||
targets.insert(attr, attr_count - 1);
|
||||
}
|
||||
}
|
||||
self.update_suppress_progress(n_processed, total, progress_reporter);
|
||||
n_processed += 1;
|
||||
for attr in r.iter() {
|
||||
match targets.get(attr).cloned() {
|
||||
None => {
|
||||
new_record.insert(attr.clone());
|
||||
}
|
||||
Some(attr_count) => {
|
||||
if attr_count == 1 {
|
||||
targets.remove(attr);
|
||||
} else {
|
||||
targets.insert(attr.clone(), attr_count - 1);
|
||||
}
|
||||
}
|
||||
*r = new_record;
|
||||
}
|
||||
synthesized_records.retain(|r| !r.is_empty());
|
||||
self.update_suppress_progress(n_processed, total, progress_reporter);
|
||||
},
|
||||
(self.durations.suppress)
|
||||
);
|
||||
}
|
||||
*r = new_record;
|
||||
}
|
||||
synthesized_records.retain(|r| !r.is_empty());
|
||||
self.update_suppress_progress(n_processed, total, progress_reporter);
|
||||
}
|
||||
|
||||
fn synthesize_rows<T>(
|
||||
&mut self,
|
||||
synthesized_records: &mut SynthesizedRecords<'data_block>,
|
||||
synthesized_records: &mut SynthesizedRecords,
|
||||
rows_synthesizers: &mut Vec<SeededRowsSynthesizer>,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
measure_time!(
|
||||
|| {
|
||||
info!("synthesizing rows...");
|
||||
let _duration_logger = ElapsedDurationLogger::new("rows synthesis");
|
||||
|
||||
let mut n_processed = 0;
|
||||
let records = self.records;
|
||||
let total = records.len() as f64;
|
||||
|
||||
for seed in records.iter() {
|
||||
self.update_synthesize_progress(n_processed, total, progress_reporter);
|
||||
n_processed += 1;
|
||||
synthesized_records.push(self.synthesize_row(seed));
|
||||
}
|
||||
self.update_synthesize_progress(n_processed, total, progress_reporter);
|
||||
},
|
||||
(self.durations.synthesize_rows)
|
||||
info!(
|
||||
"synthesizing rows with {} thread(s)...",
|
||||
get_number_of_threads()
|
||||
);
|
||||
|
||||
let total = self.data_block.records.len() as f64;
|
||||
|
||||
SeededRowsSynthesizer::synthesize_all(
|
||||
total,
|
||||
synthesized_records,
|
||||
rows_synthesizers,
|
||||
progress_reporter,
|
||||
);
|
||||
|
||||
self.update_synthesize_progress(self.data_block.records.len(), total, progress_reporter);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -396,7 +381,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
T: ReportProgress,
|
||||
{
|
||||
if let Some(r) = progress_reporter {
|
||||
self.synthesize_percentage = calc_percentage(n_processed, total);
|
||||
self.synthesize_percentage = calc_percentage(n_processed as f64, total);
|
||||
r.report(self.calc_overall_progress());
|
||||
}
|
||||
}
|
||||
|
@ -411,7 +396,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
T: ReportProgress,
|
||||
{
|
||||
if let Some(r) = progress_reporter {
|
||||
self.consolidate_percentage = calc_percentage(n_processed, total);
|
||||
self.consolidate_percentage = calc_percentage(n_processed as f64, total);
|
||||
r.report(self.calc_overall_progress());
|
||||
}
|
||||
}
|
||||
|
@ -426,7 +411,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
T: ReportProgress,
|
||||
{
|
||||
if let Some(r) = progress_reporter {
|
||||
self.suppress_percentage = calc_percentage(n_processed, total);
|
||||
self.suppress_percentage = calc_percentage(n_processed as f64, total);
|
||||
r.report(self.calc_overall_progress());
|
||||
}
|
||||
}
|
||||
|
@ -438,9 +423,3 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
|
|||
+ self.suppress_percentage * 0.2
|
||||
}
|
||||
}
|
||||
|
||||
impl<'data_block, 'attr_rows_map> Drop for SeededSynthesizer<'data_block, 'attr_rows_map> {
|
||||
fn drop(&mut self) {
|
||||
trace!("seed synthesizer durations: {:#?}", self.durations);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
use super::{
|
||||
context::SynthesizerContext,
|
||||
typedefs::{NotAllowedAttrSet, SynthesizedRecord, SynthesizedRecords, SynthesizerSeed},
|
||||
};
|
||||
use crate::{
|
||||
data_block::typedefs::AttributeRowsMap,
|
||||
utils::reporting::{SendableProgressReporter, SendableProgressReporterRef},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "rayon")]
|
||||
use rayon::prelude::*;
|
||||
|
||||
#[cfg(feature = "rayon")]
|
||||
use std::sync::Mutex;
|
||||
|
||||
use crate::{
|
||||
data_block::{record::DataBlockRecord, typedefs::DataBlockRecords},
|
||||
utils::reporting::ReportProgress,
|
||||
};
|
||||
|
||||
pub struct SeededRowsSynthesizer {
|
||||
pub context: SynthesizerContext,
|
||||
pub records: DataBlockRecords,
|
||||
pub attr_rows_map: Arc<AttributeRowsMap>,
|
||||
}
|
||||
|
||||
impl SeededRowsSynthesizer {
|
||||
#[inline]
|
||||
pub fn new(
|
||||
context: SynthesizerContext,
|
||||
records: DataBlockRecords,
|
||||
attr_rows_map: Arc<AttributeRowsMap>,
|
||||
) -> SeededRowsSynthesizer {
|
||||
SeededRowsSynthesizer {
|
||||
context,
|
||||
records,
|
||||
attr_rows_map,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rayon")]
|
||||
#[inline]
|
||||
pub fn synthesize_all<T>(
|
||||
total: f64,
|
||||
synthesized_records: &mut SynthesizedRecords,
|
||||
rows_synthesizers: &mut Vec<SeededRowsSynthesizer>,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let sendable_pr = Arc::new(Mutex::new(
|
||||
progress_reporter
|
||||
.as_mut()
|
||||
.map(|r| SendableProgressReporter::new(total, 0.5, r)),
|
||||
));
|
||||
|
||||
synthesized_records.par_extend(
|
||||
rows_synthesizers
|
||||
.par_iter_mut()
|
||||
.flat_map(|rs| rs.synthesize_rows(&mut sendable_pr.clone())),
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "rayon"))]
|
||||
#[inline]
|
||||
pub fn synthesize_all<T>(
|
||||
total: f64,
|
||||
synthesized_records: &mut SynthesizedRecords,
|
||||
rows_synthesizers: &mut Vec<SeededRowsSynthesizer>,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let mut sendable_pr = progress_reporter
|
||||
.as_mut()
|
||||
.map(|r| SendableProgressReporter::new(total, 0.5, r));
|
||||
|
||||
synthesized_records.extend(
|
||||
rows_synthesizers
|
||||
.iter_mut()
|
||||
.flat_map(|rs| rs.synthesize_rows(&mut sendable_pr)),
|
||||
);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn synthesize_rows<T>(
|
||||
&mut self,
|
||||
progress_reporter: &mut SendableProgressReporterRef<T>,
|
||||
) -> SynthesizedRecords
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let mut synthesized_records = SynthesizedRecords::default();
|
||||
let records = self.records.clone();
|
||||
|
||||
for seed in records.iter() {
|
||||
synthesized_records.push(self.synthesize_row(seed));
|
||||
SendableProgressReporter::update_progress(progress_reporter, 1.0);
|
||||
}
|
||||
synthesized_records
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn synthesize_row(&mut self, seed: &DataBlockRecord) -> SynthesizedRecord {
|
||||
let current_seed: &SynthesizerSeed = &seed.values;
|
||||
let mut synthesized_record = SynthesizedRecord::default();
|
||||
let not_allowed_attr_set = NotAllowedAttrSet::default();
|
||||
|
||||
loop {
|
||||
let next = self.context.sample_next_attr_from_seed(
|
||||
&synthesized_record,
|
||||
current_seed,
|
||||
¬_allowed_attr_set,
|
||||
&self.attr_rows_map,
|
||||
);
|
||||
|
||||
match next {
|
||||
Some(value) => {
|
||||
synthesized_record.insert(value);
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
synthesized_record
|
||||
}
|
||||
}
|
|
@ -1,33 +1,31 @@
|
|||
use fnv::{FnvHashMap, FnvHashSet};
|
||||
use std::collections::BTreeSet;
|
||||
use std::{collections::BTreeSet, sync::Arc};
|
||||
|
||||
use crate::data_block::value::DataBlockValue;
|
||||
|
||||
/// If the data block value were added to the synthesized record, this maps to the
|
||||
/// number of rows the resulting attribute combination will be part of
|
||||
pub type AttributeCountMap<'data_block_value> =
|
||||
FnvHashMap<&'data_block_value DataBlockValue, usize>;
|
||||
pub type AttributeCountMap = FnvHashMap<Arc<DataBlockValue>, usize>;
|
||||
|
||||
/// The seeds used for the current synthesis step (aka current record being processed)
|
||||
pub type SynthesizerSeed<'data_block_value> = Vec<&'data_block_value DataBlockValue>;
|
||||
pub type SynthesizerSeed = Vec<Arc<DataBlockValue>>;
|
||||
|
||||
/// Slice of SynthesizerSeed
|
||||
pub type SynthesizerSeedSlice<'data_block_value> = [&'data_block_value DataBlockValue];
|
||||
pub type SynthesizerSeedSlice = [Arc<DataBlockValue>];
|
||||
|
||||
/// Attributes not allowed to be sampled in a particular data synthesis stage
|
||||
pub type NotAllowedAttrSet<'data_block_value> = FnvHashSet<&'data_block_value DataBlockValue>;
|
||||
pub type NotAllowedAttrSet = FnvHashSet<Arc<DataBlockValue>>;
|
||||
|
||||
/// Record synthesized at a particular stage
|
||||
pub type SynthesizedRecord<'data_block_value> = BTreeSet<&'data_block_value DataBlockValue>;
|
||||
pub type SynthesizedRecord = BTreeSet<Arc<DataBlockValue>>;
|
||||
|
||||
/// Vector of synthesized records
|
||||
pub type SynthesizedRecords<'data_block_value> = Vec<SynthesizedRecord<'data_block_value>>;
|
||||
pub type SynthesizedRecords = Vec<SynthesizedRecord>;
|
||||
|
||||
/// Slice of SynthesizedRecords
|
||||
pub type SynthesizedRecordsSlice<'data_block_value> = [SynthesizedRecord<'data_block_value>];
|
||||
pub type SynthesizedRecordsSlice = [SynthesizedRecord];
|
||||
|
||||
/// How many attributes (of the ones available) should be added during the
|
||||
/// consolidation step for the count to match a multiple of the reporting resolution
|
||||
/// (rounded down)
|
||||
pub type AvailableAttrsMap<'data_block_value> =
|
||||
FnvHashMap<&'data_block_value DataBlockValue, isize>;
|
||||
pub type AvailableAttrsMap = FnvHashMap<Arc<DataBlockValue>, isize>;
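// Illustrative sketch: with the Arc-based typedefs above, records share
// `DataBlockValue`s through reference counting instead of borrowing. Assumes
// `DataBlockValue::new(column_index, value)` as it is used elsewhere in this
// change; the column index and value below are arbitrary.
fn tiny_synthesized_record() -> SynthesizedRecord {
    let value = Arc::new(DataBlockValue::new(0, Arc::new("a1".to_string())));
    let mut record = SynthesizedRecord::default();
    record.insert(value);
    record
}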
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
use super::{
|
||||
context::SynthesizerContext, typedefs::SynthesizedRecords,
|
||||
unseeded_rows_synthesizer::UnseededRowsSynthesizer,
|
||||
};
|
||||
use log::info;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
data_block::{block::DataBlock, typedefs::AttributeRowsByColumnMap},
|
||||
utils::{
|
||||
math::calc_percentage, reporting::ReportProgress, threading::get_number_of_threads,
|
||||
time::ElapsedDurationLogger,
|
||||
},
|
||||
};
|
||||
|
||||
/// Represents all the information required to perform the unseeded data synthesis
|
||||
pub struct UnseededSynthesizer {
|
||||
/// Reference to the original data block
|
||||
data_block: Arc<DataBlock>,
|
||||
/// Maps a data block value to all the rows where it occurs grouped by column
|
||||
attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>,
|
||||
/// Reporting resolution used for data synthesis
|
||||
resolution: usize,
|
||||
/// Maximum cache size allowed
|
||||
cache_max_size: usize,
|
||||
/// Empty values on the synthetic data will be represented by this
|
||||
empty_value: Arc<String>,
|
||||
/// Percentage already completed on the row synthesis step
|
||||
synthesize_percentage: f64,
|
||||
}
|
||||
|
||||
impl UnseededSynthesizer {
|
||||
/// Returns a new UnseededSynthesizer
|
||||
/// # Arguments
|
||||
/// * `data_block` - Sensitive data to be synthesized
|
||||
/// * `attr_rows_map_by_column` - Maps a data block value to all the rows where it occurs grouped by column
|
||||
/// * `resolution` - Reporting resolution used for data synthesis
|
||||
/// * `cache_max_size` - Maximum cache size allowed
|
||||
/// * `empty_value` - Empty values on the synthetic data will be represented by this
|
||||
#[inline]
|
||||
pub fn new(
|
||||
data_block: Arc<DataBlock>,
|
||||
attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>,
|
||||
resolution: usize,
|
||||
cache_max_size: usize,
|
||||
empty_value: Arc<String>,
|
||||
) -> UnseededSynthesizer {
|
||||
UnseededSynthesizer {
|
||||
data_block,
|
||||
attr_rows_map_by_column,
|
||||
resolution,
|
||||
cache_max_size,
|
||||
empty_value,
|
||||
synthesize_percentage: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs the row synthesis
|
||||
/// Returns the synthesized records
|
||||
/// # Arguments
|
||||
/// * `progress_reporter` - Will be used to report the processing
|
||||
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
|
||||
pub fn run<T>(&mut self, progress_reporter: &mut Option<T>) -> SynthesizedRecords
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let mut synthesized_records: SynthesizedRecords = SynthesizedRecords::new();
|
||||
|
||||
if !self.data_block.records.is_empty() {
|
||||
let mut rows_synthesizers: Vec<UnseededRowsSynthesizer> =
|
||||
self.build_rows_synthesizers();
|
||||
|
||||
self.synthesize_percentage = 0.0;
|
||||
|
||||
self.synthesize_rows(
|
||||
&mut synthesized_records,
|
||||
&mut rows_synthesizers,
|
||||
progress_reporter,
|
||||
);
|
||||
}
|
||||
synthesized_records
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn build_rows_synthesizers(&self) -> Vec<UnseededRowsSynthesizer> {
|
||||
let mut total_size = self.data_block.records.len();
|
||||
let chunk_size = ((total_size as f64) / (get_number_of_threads() as f64)).ceil() as usize;
|
||||
let mut rows_synthesizers: Vec<UnseededRowsSynthesizer> = Vec::default();
|
||||
|
||||
loop {
|
||||
if total_size > chunk_size {
|
||||
rows_synthesizers.push(UnseededRowsSynthesizer::new(
|
||||
SynthesizerContext::new(
|
||||
self.data_block.headers.len(),
|
||||
self.data_block.records.len(),
|
||||
self.resolution,
|
||||
self.cache_max_size,
|
||||
),
|
||||
chunk_size,
|
||||
self.attr_rows_map_by_column.clone(),
|
||||
self.empty_value.clone(),
|
||||
));
|
||||
total_size -= chunk_size;
|
||||
} else {
|
||||
rows_synthesizers.push(UnseededRowsSynthesizer::new(
|
||||
SynthesizerContext::new(
|
||||
self.data_block.headers.len(),
|
||||
self.data_block.records.len(),
|
||||
self.resolution,
|
||||
self.cache_max_size,
|
||||
),
|
||||
total_size,
|
||||
self.attr_rows_map_by_column.clone(),
|
||||
self.empty_value.clone(),
|
||||
));
|
||||
break;
|
||||
}
|
||||
}
|
||||
rows_synthesizers
|
||||
}
|
||||
|
||||
fn synthesize_rows<T>(
|
||||
&mut self,
|
||||
synthesized_records: &mut SynthesizedRecords,
|
||||
rows_synthesizers: &mut Vec<UnseededRowsSynthesizer>,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let _duration_logger = ElapsedDurationLogger::new("rows synthesis");
|
||||
|
||||
info!(
|
||||
"synthesizing rows with {} thread(s)...",
|
||||
get_number_of_threads()
|
||||
);
|
||||
|
||||
let total = self.data_block.records.len() as f64;
|
||||
|
||||
UnseededRowsSynthesizer::synthesize_all(
|
||||
total,
|
||||
synthesized_records,
|
||||
rows_synthesizers,
|
||||
progress_reporter,
|
||||
);
|
||||
|
||||
self.update_synthesize_progress(self.data_block.records.len(), total, progress_reporter);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update_synthesize_progress<T>(
|
||||
&mut self,
|
||||
n_processed: usize,
|
||||
total: f64,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
if let Some(r) = progress_reporter {
|
||||
self.synthesize_percentage = calc_percentage(n_processed as f64, total);
|
||||
r.report(self.synthesize_percentage);
|
||||
}
|
||||
}
|
||||
}
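// Hedged driving sketch: assumes a prepared data block and its per-column
// attribute-rows map, and that `LoggerProgressReporter` implements
// `ReportProgress` as its documentation suggests. The resolution and cache
// size below are illustrative only.
fn run_unseeded_example(
    data_block: Arc<DataBlock>,
    attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>,
) -> SynthesizedRecords {
    use crate::utils::reporting::LoggerProgressReporter;

    let mut synthesizer = UnseededSynthesizer::new(
        data_block,
        attr_rows_map_by_column,
        10,                         // reporting resolution
        10_000,                     // cache_max_size
        Arc::new(String::from("")), // empty_value
    );
    // no progress reporting in this sketch
    let mut reporter: Option<LoggerProgressReporter> = None;
    synthesizer.run(&mut reporter)
}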
|
|
@ -0,0 +1,128 @@
|
|||
use super::{
|
||||
context::SynthesizerContext,
|
||||
typedefs::{SynthesizedRecord, SynthesizedRecords},
|
||||
};
|
||||
use crate::{
|
||||
data_block::typedefs::{AttributeRows, AttributeRowsByColumnMap},
|
||||
utils::reporting::{SendableProgressReporter, SendableProgressReporterRef},
|
||||
};
|
||||
use rand::{prelude::SliceRandom, thread_rng};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "rayon")]
|
||||
use rayon::prelude::*;
|
||||
|
||||
#[cfg(feature = "rayon")]
|
||||
use std::sync::Mutex;
|
||||
|
||||
use crate::utils::reporting::ReportProgress;
|
||||
|
||||
pub struct UnseededRowsSynthesizer {
|
||||
context: SynthesizerContext,
|
||||
chunk_size: usize,
|
||||
column_indexes: Vec<usize>,
|
||||
attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>,
|
||||
empty_value: Arc<String>,
|
||||
}
|
||||
|
||||
impl UnseededRowsSynthesizer {
|
||||
#[inline]
|
||||
pub fn new(
|
||||
context: SynthesizerContext,
|
||||
chunk_size: usize,
|
||||
attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>,
|
||||
empty_value: Arc<String>,
|
||||
) -> UnseededRowsSynthesizer {
|
||||
UnseededRowsSynthesizer {
|
||||
context,
|
||||
chunk_size,
|
||||
column_indexes: attr_rows_map_by_column.keys().cloned().collect(),
|
||||
attr_rows_map_by_column,
|
||||
empty_value,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rayon")]
|
||||
#[inline]
|
||||
pub fn synthesize_all<T>(
|
||||
total: f64,
|
||||
synthesized_records: &mut SynthesizedRecords,
|
||||
rows_synthesizers: &mut Vec<UnseededRowsSynthesizer>,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let sendable_pr = Arc::new(Mutex::new(
|
||||
progress_reporter
|
||||
.as_mut()
|
||||
.map(|r| SendableProgressReporter::new(total, 1.0, r)),
|
||||
));
|
||||
|
||||
synthesized_records.par_extend(
|
||||
rows_synthesizers
|
||||
.par_iter_mut()
|
||||
.flat_map(|rs| rs.synthesize_rows(&mut sendable_pr.clone())),
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "rayon"))]
|
||||
#[inline]
|
||||
pub fn synthesize_all<T>(
|
||||
total: f64,
|
||||
synthesized_records: &mut SynthesizedRecords,
|
||||
rows_synthesizers: &mut Vec<UnseededRowsSynthesizer>,
|
||||
progress_reporter: &mut Option<T>,
|
||||
) where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let mut sendable_pr = progress_reporter
|
||||
.as_mut()
|
||||
.map(|r| SendableProgressReporter::new(total, 1.0, r));
|
||||
|
||||
synthesized_records.extend(
|
||||
rows_synthesizers
|
||||
.iter_mut()
|
||||
.flat_map(|rs| rs.synthesize_rows(&mut sendable_pr)),
|
||||
);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn synthesize_rows<T>(
|
||||
&mut self,
|
||||
progress_reporter: &mut SendableProgressReporterRef<T>,
|
||||
) -> SynthesizedRecords
|
||||
where
|
||||
T: ReportProgress,
|
||||
{
|
||||
let mut synthesized_records = SynthesizedRecords::default();
|
||||
|
||||
for _ in 0..self.chunk_size {
|
||||
synthesized_records.push(self.synthesize_row());
|
||||
SendableProgressReporter::update_progress(progress_reporter, 1.0);
|
||||
}
|
||||
synthesized_records
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn synthesize_row(&mut self) -> SynthesizedRecord {
|
||||
let mut synthesized_record = SynthesizedRecord::default();
|
||||
let mut current_attrs_rows: Arc<AttributeRows> =
|
||||
Arc::new((0..self.context.records_len).collect());
|
||||
|
||||
self.column_indexes.shuffle(&mut thread_rng());
|
||||
|
||||
for column_index in self.column_indexes.iter() {
|
||||
if let Some((next_attrs_rows, sample)) = self.context.sample_next_attr_from_column(
|
||||
&synthesized_record,
|
||||
*column_index,
|
||||
&self.attr_rows_map_by_column,
|
||||
¤t_attrs_rows,
|
||||
&self.empty_value,
|
||||
) {
|
||||
current_attrs_rows = next_attrs_rows;
|
||||
synthesized_record.insert(sample);
|
||||
}
|
||||
}
|
||||
synthesized_record
|
||||
}
|
||||
}
|
|
@ -4,8 +4,8 @@ use std::hash::Hash;
|
|||
/// Given two sorted vectors, calculates the intersection between them.
|
||||
/// Returns the result intersection as a new vector
|
||||
/// # Arguments
|
||||
/// - `a`: first sorted vector
|
||||
/// - `b`: second sorted vector
|
||||
/// * `a` - first sorted vector
|
||||
/// * `b` - second sorted vector
|
||||
#[inline]
|
||||
pub fn ordered_vec_intersection<T>(a: &[T], b: &[T]) -> Vec<T>
|
||||
where
|
||||
|
|
|
@ -54,6 +54,6 @@ pub fn calc_n_combinations_range(n: usize, range: &[usize]) -> u64 {
|
|||
|
||||
/// Calculates the percentage of processed elements up to a total
|
||||
#[inline]
|
||||
pub fn calc_percentage(n_processed: usize, total: f64) -> f64 {
|
||||
(n_processed as f64) * 100.0 / total
|
||||
pub fn calc_percentage(n_processed: f64, total: f64) -> f64 {
|
||||
n_processed * 100.0 / total
|
||||
}
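// Illustrative test for the reworked f64-based signature above.
#[cfg(test)]
mod calc_percentage_sketch {
    use super::calc_percentage;

    #[test]
    fn reports_fraction_of_total() {
        // 25 processed out of a total of 200 is 12.5%
        assert!((calc_percentage(25.0, 200.0) - 12.5).abs() < f64::EPSILON);
    }
}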
|
||||
|
|
|
@ -1,8 +1,14 @@
|
|||
/// Module for collection utilities
pub mod collections;

/// Module for math utilities
pub mod math;

/// Module for reporting utilities
pub mod reporting;

/// Module for threading utilities
pub mod threading;

/// Module for time utilities
pub mod time;
|
||||
|
|
|
@ -1,11 +1,6 @@
|
|||
use super::ReportProgress;
|
||||
use log::{log, Level};
|
||||
|
||||
/// Implement this trait to inform progress
|
||||
pub trait ReportProgress {
|
||||
/// Receives the updated progress
|
||||
fn report(&mut self, new_progress: f64);
|
||||
}
|
||||
|
||||
/// Simple progress reporter using the default logger.
|
||||
/// * It will log progress using the configured `log_level`
|
||||
/// * It will only log at every 1% completed
|
||||
|
@ -16,8 +11,8 @@ pub struct LoggerProgressReporter {
|
|||
|
||||
impl LoggerProgressReporter {
|
||||
/// Returns a new LoggerProgressReporter
|
||||
/// # Arguments:
|
||||
/// * `log_level`: which log level use to log progress
|
||||
/// # Arguments
|
||||
/// * `log_level` - which log level to use when logging progress
|
||||
pub fn new(log_level: Level) -> LoggerProgressReporter {
|
||||
LoggerProgressReporter {
|
||||
progress: 0.0,
|
|
@ -0,0 +1,11 @@
|
|||
mod report_progress;

mod logger_progress_reporter;

mod sendable_progress_reporter;

pub use report_progress::ReportProgress;

pub use logger_progress_reporter::LoggerProgressReporter;

pub use sendable_progress_reporter::{SendableProgressReporter, SendableProgressReporterRef};
|
|
@ -0,0 +1,5 @@
|
|||
/// Implement this trait to inform progress
pub trait ReportProgress {
    /// Receives the updated progress
    fn report(&mut self, new_progress: f64);
}
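// Minimal illustrative implementor of the trait above: collects the reported
// values so they can be inspected, e.g. in tests.
pub struct VecProgressReporter {
    pub reported: Vec<f64>,
}

impl ReportProgress for VecProgressReporter {
    fn report(&mut self, new_progress: f64) {
        self.reported.push(new_progress);
    }
}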
|
|
@@ -0,0 +1,92 @@
#[cfg(feature = "rayon")]
use std::sync::{Arc, Mutex};

use crate::utils::{math::calc_percentage, reporting::ReportProgress};

/// Progress reporter to be used in parallel environments
/// (should be wrapped with `Arc` and `Mutex` if multi-threading is enabled).
/// It will calculate `proportion * (n_processed * 100.0 / total)` and report
/// it to the main reporter
pub struct SendableProgressReporter<'main_reporter, T>
where
    T: ReportProgress,
{
    total: f64,
    n_processed: f64,
    proportion: f64,
    main_reporter: &'main_reporter mut T,
}

impl<'main_reporter, T> SendableProgressReporter<'main_reporter, T>
where
    T: ReportProgress,
{
    /// Creates a new SendableProgressReporter
    /// # Arguments
    /// * `total` - Total number of steps to be reported
    /// * `proportion` - Value to multiply the percentage before reporting to
    ///   main_reporter
    /// * `main_reporter` - Main reporter to which this should report
    pub fn new(
        total: f64,
        proportion: f64,
        main_reporter: &'main_reporter mut T,
    ) -> SendableProgressReporter<T> {
        SendableProgressReporter {
            total,
            proportion,
            n_processed: 0.0,
            main_reporter,
        }
    }

    /// Updates `reporter` by adding `value_to_add` and reporting
    /// to the main reporter in a thread safe way
    #[inline]
    pub fn update_progress(reporter: &mut SendableProgressReporterRef<T>, value_to_add: f64)
    where
        T: ReportProgress,
    {
        #[cfg(feature = "rayon")]
        if let Ok(guard) = &mut reporter.lock() {
            if let Some(r) = guard.as_mut() {
                r.report(value_to_add);
            }
        }
        #[cfg(not(feature = "rayon"))]
        if let Some(r) = reporter {
            r.report(value_to_add);
        }
    }
}

impl<'main_reporter, T> ReportProgress for SendableProgressReporter<'main_reporter, T>
where
    T: ReportProgress,
{
    /// Will add `value_to_add` to `n_processed` and call the main reporter with
    /// `proportion * (n_processed * 100.0 / total)`
    fn report(&mut self, value_to_add: f64) {
        self.n_processed += value_to_add;
        self.main_reporter
            .report(self.proportion * calc_percentage(self.n_processed, self.total))
    }
}

#[cfg(feature = "rayon")]
/// People using this should correctly handle race conditions with a `Mutex`
/// (see `SendableProgressReporterRef`), so we inform the compiler this struct is thread safe
unsafe impl<'main_reporter, T> Send for SendableProgressReporter<'main_reporter, T> where
    T: ReportProgress
{
}

#[cfg(feature = "rayon")]
/// Use this to refer to SendableProgressReporter if multi-threading is enabled
pub type SendableProgressReporterRef<'main_reporter, T> =
    Arc<Mutex<Option<SendableProgressReporter<'main_reporter, T>>>>;

#[cfg(not(feature = "rayon"))]
/// Use this to refer to SendableProgressReporter if multi-threading is disabled
pub type SendableProgressReporterRef<'main_reporter, T> =
    Option<SendableProgressReporter<'main_reporter, T>>;
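A minimal usage sketch (illustrative only, not part of this commit): with the `rayon` feature disabled, `SendableProgressReporterRef` is just an `Option`, so a stage that owns half of the overall work can forward its progress to a main reporter like this. `ConsoleReporter` and `run_stage` are made-up names for the example.

struct ConsoleReporter;

impl ReportProgress for ConsoleReporter {
    fn report(&mut self, new_progress: f64) {
        println!("overall progress: {:.1}%", new_progress);
    }
}

fn run_stage(main_reporter: &mut ConsoleReporter) {
    // 10 steps in this stage, contributing 50% of the overall progress.
    let mut stage: SendableProgressReporterRef<ConsoleReporter> =
        Some(SendableProgressReporter::new(10.0, 0.5, main_reporter));

    for _ in 0..10 {
        // Each call adds one processed step and reports
        // 0.5 * (n_processed * 100.0 / 10.0) to the main reporter.
        SendableProgressReporter::update_progress(&mut stage, 1.0);
    }
}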
@@ -0,0 +1,34 @@
#[cfg(feature = "rayon")]
use rayon;

#[cfg(feature = "pyo3")]
use pyo3::prelude::*;

pub fn get_number_of_threads() -> usize {
    #[cfg(feature = "rayon")]
    {
        rayon::current_num_threads()
    }
    #[cfg(not(feature = "rayon"))]
    {
        1
    }
}

#[cfg(feature = "rayon")]
#[cfg_attr(feature = "pyo3", pyfunction)]
#[allow(unused_must_use)]
/// Sets the number of threads used for parallel processing
/// # Arguments
/// * `n` - number of threads
pub fn set_number_of_threads(n: usize) {
    rayon::ThreadPoolBuilder::new()
        .num_threads(n)
        .build_global();
}

#[cfg(feature = "pyo3")]
pub fn register(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(set_number_of_threads, m)?)?;
    Ok(())
}
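An illustrative sketch (not part of this commit) of how the new helpers are meant to be used when the `rayon` feature is enabled: size the global pool once at startup, then query it wherever work is split across threads.

fn configure_parallelism() {
    // Build the global rayon pool with 4 worker threads; the Result from
    // build_global() is intentionally ignored if a pool already exists.
    set_number_of_threads(4);

    // Without the `rayon` feature this always reports a single thread.
    println!("worker threads: {}", get_number_of_threads());
}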
@@ -16,7 +16,7 @@ pub struct ElapsedDuration<'lifetime> {

impl<'lifetime> ElapsedDuration<'lifetime> {
    /// Returns a new ElapsedDuration
    /// # Arguments
    /// - `result`: Duration where to sum the elapsed duration when the `ElapsedDuration`
    /// * `result` - Duration where to sum the elapsed duration when the `ElapsedDuration`
    ///   instance goes out of scope
    pub fn new(result: &mut Duration) -> ElapsedDuration {
        ElapsedDuration {
@@ -41,10 +41,10 @@ pub struct ElapsedDurationLogger {
}

impl ElapsedDurationLogger {
    pub fn new(message: String) -> ElapsedDurationLogger {
    pub fn new<S: Into<String>>(message: S) -> ElapsedDurationLogger {
        ElapsedDurationLogger {
            _start: Instant::now(),
            _message: message,
            _message: message.into(),
        }
    }
}
@@ -14,4 +14,5 @@ crate-type = ["cdylib"]
log = { version = "0.4", features = ["std"] }
csv = { version = "1.1" }
pyo3 = { version = "0.15", features = ["extension-module"] }
sds-core = { path = "../core", features = ["pyo3"] }
sds-core = { path = "../core", features = ["pyo3", "rayon"] }
env_logger = { version = "0.9" }
@@ -0,0 +1,133 @@
use csv::ReaderBuilder;
use pyo3::prelude::*;
use sds_core::{
    data_block::{
        block::DataBlock, csv_block_creator::CsvDataBlockCreator, csv_io_error::CsvIOError,
        data_block_creator::DataBlockCreator,
    },
    processing::{
        aggregator::{aggregated_data::AggregatedData, Aggregator},
        generator::{generated_data::GeneratedData, Generator, SynthesisMode},
    },
    utils::reporting::LoggerProgressReporter,
};
use std::sync::Arc;

#[pyclass]
/// Processor exposing the main features
pub struct SDSProcessor {
    data_block: Arc<DataBlock>,
}

#[pymethods]
impl SDSProcessor {
    /// Creates a new processor by reading the content of a given file
    /// # Arguments
    /// * `path` - File to be read to build the data block
    /// * `delimiter` - CSV/TSV separator for the content on `path`
    /// * `use_columns` - Column names to be used (if `[]` use all columns)
    /// * `sensitive_zeros` - Column names containing sensitive zeros
    ///   (if `[]` no columns are considered to have sensitive zeros)
    /// * `record_limit` - Use only the first `record_limit` records (if `0` use all records)
    #[new]
    pub fn new(
        path: &str,
        delimiter: char,
        use_columns: Vec<String>,
        sensitive_zeros: Vec<String>,
        record_limit: usize,
    ) -> Result<SDSProcessor, CsvIOError> {
        match CsvDataBlockCreator::create(
            ReaderBuilder::new()
                .delimiter(delimiter as u8)
                .from_path(path),
            &use_columns,
            &sensitive_zeros,
            record_limit,
        ) {
            Ok(data_block) => Ok(SDSProcessor { data_block }),
            Err(err) => Err(CsvIOError::new(err)),
        }
    }

    #[staticmethod]
    /// Load the SDS Processor for the data block linked to `aggregated_data`
    pub fn from_aggregated_data(aggregated_data: &AggregatedData) -> SDSProcessor {
        SDSProcessor {
            data_block: aggregated_data.data_block.clone(),
        }
    }

    #[inline]
    /// Returns the number of records on the data block
    pub fn number_of_records(&self) -> usize {
        self.data_block.number_of_records()
    }

    #[inline]
    /// Returns the number of records on the data block protected by `resolution`
    pub fn protected_number_of_records(&self, resolution: usize) -> usize {
        self.data_block.protected_number_of_records(resolution)
    }

    #[inline]
    /// Normalizes the reporting length based on the number of selected headers.
    /// Returns the normalized value
    /// # Arguments
    /// * `reporting_length` - Reporting length to be normalized (0 means use all columns)
    pub fn normalize_reporting_length(&self, reporting_length: usize) -> usize {
        self.data_block.normalize_reporting_length(reporting_length)
    }

    /// Builds the aggregated data for the content
    /// using the specified `reporting_length` and `sensitivity_threshold`.
    /// Returns the aggregated data.
    /// # Arguments
    /// * `reporting_length` - Maximum length to compute attribute combinations
    /// * `sensitivity_threshold` - Sensitivity threshold to filter record attributes
    ///   (0 means no suppression)
    pub fn aggregate(
        &self,
        reporting_length: usize,
        sensitivity_threshold: usize,
    ) -> AggregatedData {
        let mut progress_reporter: Option<LoggerProgressReporter> = None;
        let mut aggregator = Aggregator::new(self.data_block.clone());
        aggregator.aggregate(
            reporting_length,
            sensitivity_threshold,
            &mut progress_reporter,
        )
    }

    /// Synthesizes the content using the specified `resolution` and
    /// returns the generated data
    /// # Arguments
    /// * `cache_max_size` - Maximum cache size used during the synthesis process
    /// * `resolution` - Reporting resolution to be used
    /// * `empty_value` - Empty values on the synthetic data will be represented by this
    /// * `seeded` - True for seeded synthesis, False for unseeded
    pub fn generate(
        &self,
        cache_max_size: usize,
        resolution: usize,
        empty_value: String,
        seeded: bool,
    ) -> GeneratedData {
        let mut progress_reporter: Option<LoggerProgressReporter> = None;
        let mut generator = Generator::new(self.data_block.clone());
        let mode = if seeded {
            SynthesisMode::Seeded
        } else {
            SynthesisMode::Unseeded
        };

        generator.generate(
            resolution,
            cache_max_size,
            empty_value,
            mode,
            &mut progress_reporter,
        )
    }
}
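An end-to-end sketch of the new processor API from the Rust side (illustrative only; the file name and parameter values are made up). Through the pyo3 attributes above, the same constructor and methods are what the Python pipeline calls.

fn run_pipeline() -> Result<(), CsvIOError> {
    let processor = SDSProcessor::new("sensitive.csv", ',', vec![], vec![], 0)?;

    // Aggregate attribute combinations up to length 3 with no sensitivity suppression.
    let _aggregated_data = processor.aggregate(3, 0);

    // Unseeded synthesis at reporting resolution 10.
    let generated_data = processor.generate(100_000, 10, "".to_owned(), false);
    println!("expansion ratio: {}", generated_data.expansion_ratio);
    Ok(())
}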
@@ -1,160 +1,24 @@
//! This crate will generate python bindings for the main features
//! of the `sds_core` library.

use csv::ReaderBuilder;
use pyo3::{exceptions::PyIOError, prelude::*};
use data_processor::SDSProcessor;
use pyo3::prelude::*;
use sds_core::{
    data_block::{
        block::{DataBlock, DataBlockCreator},
        csv_block_creator::CsvDataBlockCreator,
    },
    processing::{
        aggregator::{typedefs::AggregatedCountByLenMap, Aggregator},
        generator::Generator,
    },
    utils::reporting::LoggerProgressReporter,
    processing::{aggregator::aggregated_data, evaluator},
    utils::threading,
};

/// Creates a data block by reading the content of a given file
/// # Arguments
/// * `path` - File to be read to build the data block
/// * `delimiter` - CSV/TSV separator for the content on `path`
/// * `use_columns` - Column names to be used (if `[]` use all columns)
/// * `sensitive_zeros` - Column names containing sensitive zeros
///   (if `[]` no columns are considered to have sensitive zeros)
/// * `record_limit` - Use only the first `record_limit` records (if `0` use all records)
#[pyfunction]
pub fn create_data_block_from_file(
    path: &str,
    delimiter: char,
    use_columns: Vec<String>,
    sensitive_zeros: Vec<String>,
    record_limit: usize,
) -> PyResult<DataBlock> {
    match CsvDataBlockCreator::create(
        ReaderBuilder::new()
            .delimiter(delimiter as u8)
            .from_path(path),
        &use_columns,
        &sensitive_zeros,
        record_limit,
    ) {
        Ok(data_block) => Ok(data_block),
        Err(err) => Err(PyIOError::new_err(format!(
            "error reading data block: {}",
            err
        ))),
    }
}

/// Builds the protected and aggregated data for the content in `data_block`
/// using the specified `reporting_length` and `resolution`.
/// The result is written to `sensitive_aggregates_path` and `reportable_aggregates_path`.
///
/// The function returns a tuple (combo_counts, rare_counts).
///
/// * `combo_counts` - computed number of distinct combinations grouped by combination length (not protected)
/// * `rare_counts` - computed number of rare combinations (`count < resolution`) grouped by combination length (not protected)
///
/// # Arguments
/// * `data_block` - Data block with the content to be aggregated
/// * `sensitive_aggregates_path` - Path to write the sensitive aggregates
/// * `reportable_aggregates_path` - Path to write the aggregates with protected count
///   (rounded down to the nearest multiple of `resolution`)
/// * `delimiter` - CSV/TSV separator for the content be written in
///   `sensitive_aggregates_path` and `reportable_aggregates_path`
/// * `reporting_length` - Maximum length to compute attribute combinations
/// * `resolution` - Reporting resolution to be used
#[pyfunction]
pub fn aggregate_and_write(
    data_block: &DataBlock,
    sensitive_aggregates_path: &str,
    reportable_aggregates_path: &str,
    delimiter: char,
    reporting_length: usize,
    resolution: usize,
) -> PyResult<(AggregatedCountByLenMap, AggregatedCountByLenMap)> {
    let mut progress_reporter: Option<LoggerProgressReporter> = None;
    let mut aggregator = Aggregator::new(data_block);
    let mut aggregated_data = aggregator.aggregate(reporting_length, 0, &mut progress_reporter);
    let combo_count = aggregator.calc_combinations_count_by_len(&aggregated_data.aggregates_count);
    let rare_count = aggregator
        .calc_rare_combinations_count_by_len(&aggregated_data.aggregates_count, resolution);

    if let Err(err) = aggregator.write_aggregates_count(
        &aggregated_data.aggregates_count,
        sensitive_aggregates_path,
        delimiter,
        resolution,
        false,
    ) {
        return Err(PyIOError::new_err(format!(
            "error writing sensitive aggregates: {}",
            err
        )));
    }

    Aggregator::protect_aggregates_count(&mut aggregated_data.aggregates_count, resolution);

    match aggregator.write_aggregates_count(
        &aggregated_data.aggregates_count,
        reportable_aggregates_path,
        delimiter,
        resolution,
        true,
    ) {
        Ok(_) => Ok((combo_count, rare_count)),
        Err(err) => Err(PyIOError::new_err(format!(
            "error writing reportable aggregates: {}",
            err
        ))),
    }
}

/// Synthesizes the `data_block` using the specified `resolution`.
/// The result is written to `synthetic_path`.
///
/// The function returns the expansion ratio:
/// `Synthetic data length / Sensitive data length` (header not included)
///
/// # Arguments
/// * `data_block` - Data block with the content to be synthesized
/// * `synthetic_path` - Path to write the synthetic data
/// * `synthetic_delimiter` - CSV/TSV separator for the content be written in `synthetic_path`
/// * `cache_max_size` - Maximum cache size used during the synthesis process
/// * `resolution` - Reporting resolution to be used
#[pyfunction]
pub fn generate_seeded_and_write(
    data_block: &DataBlock,
    synthetic_path: &str,
    synthetic_delimiter: char,
    cache_max_size: usize,
    resolution: usize,
) -> PyResult<f64> {
    let mut progress_reporter: Option<LoggerProgressReporter> = None;
    let mut generator = Generator::new(data_block);
    let generated_data = generator.generate(resolution, cache_max_size, "", &mut progress_reporter);

    match generator.write_records(
        &generated_data.synthetic_data,
        synthetic_path,
        synthetic_delimiter as u8,
    ) {
        Ok(_) => Ok(generated_data.expansion_ratio),
        Err(err) => Err(PyIOError::new_err(format!(
            "error writing synthetic data: {}",
            err
        ))),
    }
}
/// Module that exposes the main processor
pub mod data_processor;

/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn sds(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(create_data_block_from_file, m)?)?;
    m.add_function(wrap_pyfunction!(aggregate_and_write, m)?)?;
    m.add_function(wrap_pyfunction!(generate_seeded_and_write, m)?)?;
fn sds(py: Python, m: &PyModule) -> PyResult<()> {
    env_logger::init();
    m.add_class::<SDSProcessor>()?;
    threading::register(py, m)?;
    aggregated_data::register(py, m)?;
    evaluator::register(py, m)?;
    Ok(())
}
@@ -19,7 +19,7 @@ csv = { version = "1.1" }
web-sys = { version = "0.3", features = [ "console" ]}
sds-core = { path = "../core" }
js-sys = { version = "0.3" }
serde = { version = "1.0", features = [ "derive" ] }
serde = { version = "1.0", features = [ "derive", "rc" ] }

# The `console_error_panic_hook` crate provides better debugging of panics by
# logging them with `console.error`. This is great for development, but requires
@@ -1,116 +0,0 @@
use js_sys::Object;
use js_sys::Reflect::set;
use log::error;
use sds_core::data_block::block::DataBlock;
use sds_core::data_block::typedefs::DataBlockHeaders;
use sds_core::processing::aggregator::typedefs::{AggregatedCountByLenMap, AggregatesCountMap};
use sds_core::processing::aggregator::{Aggregator, PrivacyRiskSummary};
use sds_core::utils::reporting::ReportProgress;
use sds_core::utils::time::ElapsedDurationLogger;
use serde::Serialize;
use wasm_bindgen::JsValue;

use crate::utils::js::serializers::{serialize_aggregates_count, serialize_privacy_risk};
use crate::{match_or_return_undefined, set_or_return_undefined};

#[derive(Serialize)]
pub struct AggregatedCombination {
    combination_key: String,
    count: usize,
    length: usize,
}

impl AggregatedCombination {
    #[inline]
    pub fn new(combination_key: String, count: usize, length: usize) -> AggregatedCombination {
        AggregatedCombination {
            combination_key,
            count,
            length,
        }
    }
}

impl From<AggregatedCombination> for JsValue {
    #[inline]
    fn from(aggregated_combination: AggregatedCombination) -> Self {
        match_or_return_undefined!(JsValue::from_serde(&aggregated_combination))
    }
}

pub struct AggregatedResult<'data_block> {
    pub headers: &'data_block DataBlockHeaders,
    pub aggregates_count: AggregatesCountMap<'data_block>,
    pub rare_combinations_count_by_len: AggregatedCountByLenMap,
    pub combinations_count_by_len: AggregatedCountByLenMap,
    pub combinations_sum_by_len: AggregatedCountByLenMap,
    pub privacy_risk: PrivacyRiskSummary,
}

impl<'data_block> From<AggregatedResult<'data_block>> for JsValue {
    #[inline]
    fn from(aggregated_result: AggregatedResult) -> Self {
        let _duration_logger =
            ElapsedDurationLogger::new(String::from("AggregatedResult conversion to JsValue"));
        let result = Object::new();

        set_or_return_undefined!(
            &result,
            &"aggregatedCombinations".into(),
            &serialize_aggregates_count(
                &aggregated_result.aggregates_count,
                aggregated_result.headers
            ),
        );
        set_or_return_undefined!(
            &result,
            &"rareCombinationsCountByLen".into(),
            &match_or_return_undefined!(JsValue::from_serde(
                &aggregated_result.rare_combinations_count_by_len
            ))
        );
        set_or_return_undefined!(
            &result,
            &"combinationsCountByLen".into(),
            &match_or_return_undefined!(JsValue::from_serde(
                &aggregated_result.combinations_count_by_len
            ))
        );
        set_or_return_undefined!(
            &result,
            &"combinationsSumByLen".into(),
            &match_or_return_undefined!(JsValue::from_serde(
                &aggregated_result.combinations_sum_by_len
            ))
        );
        set_or_return_undefined!(
            &result,
            &"privacyRisk".into(),
            &serialize_privacy_risk(&aggregated_result.privacy_risk)
        );

        result.into()
    }
}

pub fn aggregate<'data_block, T: ReportProgress>(
    data_block: &'data_block DataBlock,
    reporting_length: usize,
    resolution: usize,
    progress_reporter: &mut Option<T>,
) -> Result<AggregatedResult<'data_block>, String> {
    let mut aggregator = Aggregator::new(data_block);
    let aggregates_data = aggregator.aggregate(reporting_length, 0, progress_reporter);

    Ok(AggregatedResult {
        headers: &data_block.headers,
        rare_combinations_count_by_len: aggregator
            .calc_rare_combinations_count_by_len(&aggregates_data.aggregates_count, resolution),
        combinations_count_by_len: aggregator
            .calc_combinations_count_by_len(&aggregates_data.aggregates_count),
        combinations_sum_by_len: aggregator
            .calc_combinations_sum_by_len(&aggregates_data.aggregates_count),
        privacy_risk: aggregator.calc_privacy_risk(&aggregates_data.aggregates_count, resolution),
        aggregates_count: aggregates_data.aggregates_count,
    })
}
@@ -1,170 +0,0 @@
use js_sys::Reflect::set;
use js_sys::{Array, Function, Object};
use log::{error, info};
use sds_core::processing::aggregator::typedefs::AggregatedCountByLenMap;
use sds_core::processing::aggregator::Aggregator;
use sds_core::processing::evaluator::typedefs::PreservationByCountBuckets;
use sds_core::processing::evaluator::Evaluator;
use sds_core::utils::reporting::ReportProgress;
use sds_core::utils::time::ElapsedDurationLogger;
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsValue;

use crate::aggregator::{aggregate, AggregatedResult};
use crate::utils::js::deserializers::{
    deserialize_csv_data, deserialize_sensitive_zeros, deserialize_use_columns,
};
use crate::utils::js::js_progress_reporter::JsProgressReporter;
use crate::utils::js::serializers::serialize_buckets;
use crate::{match_or_return_undefined, set_or_return_undefined};

struct EvaluatedResult<'data_block> {
    sensitive_aggregated_result: AggregatedResult<'data_block>,
    synthetic_aggregated_result: AggregatedResult<'data_block>,
    leakage_count_by_len: AggregatedCountByLenMap,
    fabricated_count_by_len: AggregatedCountByLenMap,
    preservation_by_count_buckets: PreservationByCountBuckets,
    combination_loss: f64,
    record_expansion: f64,
}

impl<'data_block> From<EvaluatedResult<'data_block>> for JsValue {
    #[inline]
    fn from(evaluated_result: EvaluatedResult) -> Self {
        let _duration_logger =
            ElapsedDurationLogger::new(String::from("EvaluatedResult conversion to JsValue"));
        let result = Object::new();

        set_or_return_undefined!(
            &result,
            &"sensitiveAggregatedResult".into(),
            &evaluated_result.sensitive_aggregated_result.into(),
        );
        set_or_return_undefined!(
            &result,
            &"syntheticAggregatedResult".into(),
            &evaluated_result.synthetic_aggregated_result.into(),
        );
        set_or_return_undefined!(
            &result,
            &"leakageCountByLen".into(),
            &match_or_return_undefined!(JsValue::from_serde(
                &evaluated_result.leakage_count_by_len
            ))
        );
        set_or_return_undefined!(
            &result,
            &"fabricatedCountByLen".into(),
            &match_or_return_undefined!(JsValue::from_serde(
                &evaluated_result.fabricated_count_by_len
            ))
        );
        set_or_return_undefined!(
            &result,
            &"preservationByCountBuckets".into(),
            &serialize_buckets(&evaluated_result.preservation_by_count_buckets)
        );
        set_or_return_undefined!(
            &result,
            &"combinationLoss".into(),
            &match_or_return_undefined!(JsValue::from_serde(&evaluated_result.combination_loss))
        );
        set_or_return_undefined!(
            &result,
            &"recordExpansion".into(),
            &match_or_return_undefined!(JsValue::from_serde(&evaluated_result.record_expansion))
        );

        result.into()
    }
}

#[wasm_bindgen]
#[allow(clippy::too_many_arguments)]
pub fn evaluate(
    sensitive_csv_data: Array,
    synthetic_csv_data: Array,
    use_columns: Array,
    sensitive_zeros: Array,
    record_limit: usize,
    reporting_length: usize,
    resolution: usize,
    progress_callback: Function,
) -> JsValue {
    let _duration_logger = ElapsedDurationLogger::new(String::from("evaluation process"));
    let use_columns_vec = match_or_return_undefined!(deserialize_use_columns(use_columns));
    let sensitive_zeros_vec =
        match_or_return_undefined!(deserialize_sensitive_zeros(sensitive_zeros));

    info!("aggregating sensitive data...");
    let sensitive_data_block = match_or_return_undefined!(deserialize_csv_data(
        sensitive_csv_data,
        &use_columns_vec,
        &sensitive_zeros_vec,
        record_limit,
    ));
    let mut sensitive_aggregated_result = match_or_return_undefined!(aggregate(
        &sensitive_data_block,
        reporting_length,
        resolution,
        &mut Some(JsProgressReporter::new(&progress_callback, &|p| p * 0.40))
    ));

    info!("aggregating synthetic data...");
    let synthetic_data_block = match_or_return_undefined!(deserialize_csv_data(
        synthetic_csv_data,
        &use_columns_vec,
        &sensitive_zeros_vec,
        record_limit,
    ));
    let synthetic_aggregated_result = match_or_return_undefined!(aggregate(
        &synthetic_data_block,
        reporting_length,
        resolution,
        &mut Some(JsProgressReporter::new(&progress_callback, &|p| {
            40.0 + (p * 0.40)
        }))
    ));
    let mut evaluator_instance = Evaluator::default();

    info!("evaluating synthetic data based on sensitive data...");

    let buckets = evaluator_instance.calc_preservation_by_count(
        &sensitive_aggregated_result.aggregates_count,
        &synthetic_aggregated_result.aggregates_count,
        resolution,
    );
    let leakage_count_by_len = evaluator_instance.calc_leakage_count(
        &sensitive_aggregated_result.aggregates_count,
        &synthetic_aggregated_result.aggregates_count,
        resolution,
    );
    let fabricated_count_by_len = evaluator_instance.calc_fabricated_count(
        &sensitive_aggregated_result.aggregates_count,
        &synthetic_aggregated_result.aggregates_count,
    );
    let combination_loss = evaluator_instance.calc_combination_loss(&buckets);

    Aggregator::protect_aggregates_count(
        &mut sensitive_aggregated_result.aggregates_count,
        resolution,
    );

    JsProgressReporter::new(&progress_callback, &|p| p).report(100.0);

    EvaluatedResult {
        leakage_count_by_len,
        fabricated_count_by_len,
        combination_loss,
        preservation_by_count_buckets: buckets,
        record_expansion: (synthetic_aggregated_result
            .privacy_risk
            .total_number_of_records as f64)
            / (sensitive_aggregated_result
                .privacy_risk
                .total_number_of_records as f64),
        sensitive_aggregated_result,
        synthetic_aggregated_result,
    }
    .into()
}
@@ -1,45 +0,0 @@
use js_sys::{Array, Function};
use log::error;
use sds_core::{processing::generator::Generator, utils::time::ElapsedDurationLogger};
use wasm_bindgen::prelude::*;

use crate::{
    match_or_return_undefined,
    utils::js::{
        deserializers::{
            deserialize_csv_data, deserialize_sensitive_zeros, deserialize_use_columns,
        },
        js_progress_reporter::JsProgressReporter,
    },
};

#[wasm_bindgen]
pub fn generate(
    csv_data: Array,
    use_columns: Array,
    sensitive_zeros: Array,
    record_limit: usize,
    resolution: usize,
    cache_size: usize,
    progress_callback: Function,
) -> JsValue {
    let _duration_logger = ElapsedDurationLogger::new(String::from("generation process"));
    let data_block = match_or_return_undefined!(deserialize_csv_data(
        csv_data,
        &match_or_return_undefined!(deserialize_use_columns(use_columns)),
        &match_or_return_undefined!(deserialize_sensitive_zeros(sensitive_zeros)),
        record_limit,
    ));
    let mut generator = Generator::new(&data_block);

    match_or_return_undefined!(JsValue::from_serde(
        &generator
            .generate(
                resolution,
                cache_size,
                "",
                &mut Some(JsProgressReporter::new(&progress_callback, &|p| p))
            )
            .synthetic_data
    ))
}
@@ -1,141 +1,3 @@
//! This crate will generate wasm bindings for the main features
//! of the `sds_core` library:
//!
//! # Init Logger
//! ```typescript
//! init_logger(
//!   level_str: string
//! ): boolean
//! ```
//!
//! Initializes logging using a particular log level.
//! Returns `true` if successfully initialized, `false` otherwise
//!
//! ## Arguments:
//! * `level_str` - String containing a valid log level
//!   (`off`, `error`, `warn`, `info`, `debug` or `trace`)
//!
//! # Generate
//! ```typescript
//! generate(
//!   csv_data: string[][],
//!   use_columns: string[],
//!   sensitive_zeros: string[],
//!   record_limit: number,
//!   resolution: number,
//!   cache_size: number,
//!   progress_callback: (progress: number) => void
//! ): string[][]
//! ```
//!
//! Synthesizes the `csv_data` using the configured parameters and returns it
//!
//! ## Arguments
//! * `csv_data` - Data to be synthesized
//!   * `csv_data[0]` - Should be the headers
//!   * `csv_data[1...]` - Should be the records
//! * `use_columns` - Column names to be used (if `[]` use all columns)
//! * `sensitive_zeros` - Column names containing sensitive zeros
//!   (if `[]` no columns are considered to have sensitive zeros)
//! * `record_limit` - Use only the first `record_limit` records (if `0` use all records)
//! * `resolution` - Reporting resolution to be used
//! * `cache_size` - Maximum cache size used during the synthesis process
//! * `progress_callback` - Callback that informs the processing percentage (0.0 % - 100.0 %)
//!
//! # Evaluate
//! ```typescript
//! interface IAggregatedCombination {
//!   combination_key: number
//!   count: number
//!   length: number
//! }
//!
//! interface IAggregatedCombinations {
//!   [name: string]: IAggregatedCombination
//! }
//!
//! interface IAggregatedCountByLen {
//!   [length: number]: number
//! }
//!
//! interface IPrivacyRiskSummary {
//!   totalNumberOfRecords: number
//!   totalNumberOfCombinations: number
//!   recordsWithUniqueCombinationsCount: number
//!   recordsWithRareCombinationsCount: number
//!   uniqueCombinationsCount: number
//!   rareCombinationsCount: number
//!   recordsWithUniqueCombinationsProportion: number
//!   recordsWithRareCombinationsProportion: number
//!   uniqueCombinationsProportion: number
//!   rareCombinationsProportion: number
//! }
//!
//! interface IAggregatedResult {
//!   aggregatedCombinations?: IAggregatedCombinations
//!   rareCombinationsCountByLen?: IAggregatedCountByLen
//!   combinationsCountByLen?: IAggregatedCountByLen
//!   combinationsSumByLen?: IAggregatedCountByLen
//!   privacyRisk?: IPrivacyRiskSummary
//! }
//!
//! interface IPreservationByCountBucket {
//!   size: number
//!   preservationSum: number
//!   lengthSum: number
//! }
//!
//! interface IPreservationByCountBuckets {
//!   [bucket_index: number]: IPreservationByCountBucket
//! }
//!
//! interface IEvaluatedResult {
//!   sensitiveAggregatedResult?: IAggregatedResult
//!   syntheticAggregatedResult?: IAggregatedResult
//!   leakageCountByLen?: IAggregatedCountByLen
//!   fabricatedCountByLen?: IAggregatedCountByLen
//!   preservationByCountBuckets?: IPreservationByCountBuckets
//!   combinationLoss?: number
//!   recordExpansion?: number
//! }
//!
//! evaluate(
//!   sensitive_csv_data: string[][],
//!   synthetic_csv_data: string[][],
//!   use_columns: string[],
//!   sensitive_zeros: string[],
//!   record_limit: number,
//!   reporting_length: number,
//!   resolution: number,
//!   progress_callback: (progress: number) => void
//! ): IEvaluatedResult
//! ```
//! Evaluates the synthetic data based on the sensitive data and produces a `IEvaluatedResult`
//!
//! ## Arguments
//! * `sensitive_csv_data` - Sensitive data to be evaluated
//!   * `sensitive_csv_data[0]` - Should be the headers
//!   * `sensitive_csv_data[1...]` - Should be the records
//! * `synthetic_csv_data` - Synthetic data produced from the sensitive data
//!   * `synthetic_csv_data[0]` - Should be the headers
//!   * `synthetic_csv_data[1...]` - Should be the records
//! * `use_columns` - Column names to be used (if `[]` use all columns)
//! * `sensitive_zeros` - Column names containing sensitive zeros
//!   (if `[]` no columns are considered to have sensitive zeros)
//! * `record_limit` - Use only the first `record_limit` records (if `0` use all records)
//! * `reporting_length` - Maximum length to compute attribute combinations
//!   for analysis
//! * `resolution` - Reporting resolution to be used
//! * `progress_callback` - Callback that informs the processing percentage (0.0 % - 100.0 %)
pub mod processing;

#[doc(hidden)]
pub mod aggregator;

#[doc(hidden)]
pub mod evaluator;

#[doc(hidden)]
pub mod generator;

#[doc(hidden)]
pub mod utils;
@@ -0,0 +1,31 @@
use serde::Serialize;
use std::convert::TryFrom;
use wasm_bindgen::{JsCast, JsValue};

use crate::utils::js::ts_definitions::JsAggregateCountAndLength;

#[derive(Serialize)]
pub struct WasmAggregateCountAndLength {
    count: usize,
    length: usize,
}

impl WasmAggregateCountAndLength {
    #[inline]
    pub fn new(count: usize, length: usize) -> WasmAggregateCountAndLength {
        WasmAggregateCountAndLength { count, length }
    }
}

impl TryFrom<WasmAggregateCountAndLength> for JsAggregateCountAndLength {
    type Error = JsValue;

    #[inline]
    fn try_from(
        aggregate_count_and_length: WasmAggregateCountAndLength,
    ) -> Result<Self, Self::Error> {
        JsValue::from_serde(&aggregate_count_and_length)
            .map_err(|err| JsValue::from(err.to_string()))
            .map(|c| c.unchecked_into())
    }
}
@@ -0,0 +1,228 @@
use super::aggregate_count_and_length::WasmAggregateCountAndLength;
use js_sys::{Object, Reflect::set};
use sds_core::{
    processing::aggregator::aggregated_data::AggregatedData, utils::time::ElapsedDurationLogger,
};
use std::{
    convert::TryFrom,
    ops::{Deref, DerefMut},
};
use wasm_bindgen::{prelude::*, JsCast};

use crate::utils::js::ts_definitions::{
    JsAggregateCountAndLength, JsAggregateCountByLen, JsAggregateResult, JsAggregatesCount,
    JsPrivacyRiskSummary, JsResult,
};

#[wasm_bindgen]
pub struct WasmAggregateResult {
    aggregated_data: AggregatedData,
}

impl WasmAggregateResult {
    #[inline]
    pub fn default() -> WasmAggregateResult {
        WasmAggregateResult {
            aggregated_data: AggregatedData::default(),
        }
    }

    #[inline]
    pub fn new(aggregated_data: AggregatedData) -> WasmAggregateResult {
        WasmAggregateResult { aggregated_data }
    }
}

#[wasm_bindgen]
impl WasmAggregateResult {
    #[wasm_bindgen(getter)]
    #[wasm_bindgen(js_name = "reportingLength")]
    pub fn reporting_length(&self) -> usize {
        self.aggregated_data.reporting_length
    }

    #[wasm_bindgen(js_name = "aggregatesCountToJs")]
    pub fn aggregates_count_to_js(
        &self,
        combination_delimiter: &str,
    ) -> JsResult<JsAggregatesCount> {
        let result = Object::new();

        for (agg, count) in self.aggregated_data.aggregates_count.iter() {
            set(
                &result,
                &agg.format_str_using_headers(
                    &self.aggregated_data.data_block.headers,
                    combination_delimiter,
                )
                .into(),
                &JsAggregateCountAndLength::try_from(WasmAggregateCountAndLength::new(
                    count.count,
                    agg.len(),
                ))?
                .into(),
            )?;
        }

        Ok(result.unchecked_into::<JsAggregatesCount>())
    }

    #[wasm_bindgen(js_name = "rareCombinationsCountByLenToJs")]
    pub fn rare_combinations_count_by_len_to_js(
        &self,
        resolution: usize,
    ) -> JsResult<JsAggregateCountByLen> {
        let count = self
            .aggregated_data
            .calc_rare_combinations_count_by_len(resolution);

        Ok(JsValue::from_serde(&count)
            .map_err(|err| JsValue::from(err.to_string()))?
            .unchecked_into::<JsAggregateCountByLen>())
    }

    #[wasm_bindgen(js_name = "combinationsCountByLenToJs")]
    pub fn combinations_count_by_len_to_js(&self) -> JsResult<JsAggregateCountByLen> {
        let count = self.aggregated_data.calc_combinations_count_by_len();

        Ok(JsValue::from_serde(&count)
            .map_err(|err| JsValue::from(err.to_string()))?
            .unchecked_into::<JsAggregateCountByLen>())
    }

    #[wasm_bindgen(js_name = "combinationsSumByLenToJs")]
    pub fn combinations_sum_by_len_to_js(&self) -> JsResult<JsAggregateCountByLen> {
        let count = self.aggregated_data.calc_combinations_sum_by_len();

        Ok(JsValue::from_serde(&count)
            .map_err(|err| JsValue::from(err.to_string()))?
            .unchecked_into::<JsAggregateCountByLen>())
    }

    #[wasm_bindgen(js_name = "privacyRiskToJs")]
    pub fn privacy_risk_to_js(&self, resolution: usize) -> JsResult<JsPrivacyRiskSummary> {
        let pr = self.aggregated_data.calc_privacy_risk(resolution);
        let result = Object::new();

        set(
            &result,
            &"totalNumberOfRecords".into(),
            &pr.total_number_of_records.into(),
        )?;
        set(
            &result,
            &"totalNumberOfCombinations".into(),
            &pr.total_number_of_combinations.into(),
        )?;
        set(
            &result,
            &"recordsWithUniqueCombinationsCount".into(),
            &pr.records_with_unique_combinations_count.into(),
        )?;
        set(
            &result,
            &"recordsWithRareCombinationsCount".into(),
            &pr.records_with_rare_combinations_count.into(),
        )?;
        set(
            &result,
            &"uniqueCombinationsCount".into(),
            &pr.unique_combinations_count.into(),
        )?;
        set(
            &result,
            &"rareCombinationsCount".into(),
            &pr.rare_combinations_count.into(),
        )?;
        set(
            &result,
            &"recordsWithUniqueCombinationsProportion".into(),
            &pr.records_with_unique_combinations_proportion.into(),
        )?;
        set(
            &result,
            &"recordsWithRareCombinationsProportion".into(),
            &pr.records_with_rare_combinations_proportion.into(),
        )?;
        set(
            &result,
            &"uniqueCombinationsProportion".into(),
            &pr.unique_combinations_proportion.into(),
        )?;
        set(
            &result,
            &"rareCombinationsProportion".into(),
            &pr.rare_combinations_proportion.into(),
        )?;

        Ok(result.unchecked_into::<JsPrivacyRiskSummary>())
    }

    #[wasm_bindgen(js_name = "toJs")]
    pub fn to_js(
        &self,
        combination_delimiter: &str,
        resolution: usize,
        include_aggregates_count: bool,
    ) -> JsResult<JsAggregateResult> {
        let _duration_logger =
            ElapsedDurationLogger::new(String::from("aggregate result serialization"));
        let result = Object::new();

        set(
            &result,
            &"reportingLength".into(),
            &self.reporting_length().into(),
        )?;
        if include_aggregates_count {
            set(
                &result,
                &"aggregatesCount".into(),
                &self.aggregates_count_to_js(combination_delimiter)?.into(),
            )?;
        }
        set(
            &result,
            &"rareCombinationsCountByLen".into(),
            &self
                .rare_combinations_count_by_len_to_js(resolution)?
                .into(),
        )?;
        set(
            &result,
            &"combinationsCountByLen".into(),
            &self.combinations_count_by_len_to_js()?.into(),
        )?;
        set(
            &result,
            &"combinationsSumByLen".into(),
            &self.combinations_sum_by_len_to_js()?.into(),
        )?;
        set(
            &result,
            &"privacyRisk".into(),
            &self.privacy_risk_to_js(resolution)?.into(),
        )?;

        Ok(JsValue::from(result).unchecked_into::<JsAggregateResult>())
    }

    #[wasm_bindgen(js_name = "protectAggregatesCount")]
    pub fn protect_aggregates_count(&mut self, resolution: usize) {
        self.aggregated_data.protect_aggregates_count(resolution)
    }
}

impl Deref for WasmAggregateResult {
    type Target = AggregatedData;

    fn deref(&self) -> &Self::Target {
        &self.aggregated_data
    }
}

impl DerefMut for WasmAggregateResult {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.aggregated_data
    }
}
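A small sketch of how the wrapper is intended to be used from other wasm-side code (illustrative only, not part of this commit): wrap the core `AggregatedData`, optionally protect the counts, then serialize for the JavaScript caller.

fn aggregated_data_to_js(aggregated_data: AggregatedData) -> JsResult<JsAggregateResult> {
    let mut result = WasmAggregateResult::new(aggregated_data);
    // Round counts down to the nearest multiple of the resolution (10 here)
    // before anything is handed to the UI.
    result.protect_aggregates_count(10);
    // ";" joins the attribute names of each combination key.
    result.to_js(";", 10, true)
}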
@@ -0,0 +1,3 @@
pub mod aggregate_count_and_length;

pub mod aggregate_result;
@@ -0,0 +1,184 @@
use super::preservation_by_count::WasmPreservationByCount;
use js_sys::{Object, Reflect::set};
use sds_core::{processing::evaluator::Evaluator, utils::time::ElapsedDurationLogger};
use wasm_bindgen::{prelude::*, JsCast};

use crate::{
    processing::aggregator::aggregate_result::WasmAggregateResult,
    utils::js::ts_definitions::{
        JsAggregateCountByLen, JsAggregateResult, JsEvaluateResult, JsResult,
    },
};

#[wasm_bindgen]
pub struct WasmEvaluateResult {
    pub(crate) sensitive_aggregate_result: WasmAggregateResult,
    pub(crate) synthetic_aggregate_result: WasmAggregateResult,
    evaluator: Evaluator,
}

impl WasmEvaluateResult {
    #[inline]
    pub fn default() -> WasmEvaluateResult {
        WasmEvaluateResult {
            sensitive_aggregate_result: WasmAggregateResult::default(),
            synthetic_aggregate_result: WasmAggregateResult::default(),
            evaluator: Evaluator::default(),
        }
    }

    #[inline]
    pub fn new(
        sensitive_aggregate_result: WasmAggregateResult,
        synthetic_aggregate_result: WasmAggregateResult,
    ) -> WasmEvaluateResult {
        WasmEvaluateResult {
            sensitive_aggregate_result,
            synthetic_aggregate_result,
            evaluator: Evaluator::default(),
        }
    }
}

#[wasm_bindgen]
impl WasmEvaluateResult {
    #[wasm_bindgen(constructor)]
    pub fn from_aggregate_results(
        sensitive_aggregate_result: WasmAggregateResult,
        synthetic_aggregate_result: WasmAggregateResult,
    ) -> WasmEvaluateResult {
        WasmEvaluateResult::new(sensitive_aggregate_result, synthetic_aggregate_result)
    }

    #[wasm_bindgen(js_name = "sensitiveAggregateResultToJs")]
    pub fn sensitive_aggregate_result_to_js(
        &self,
        combination_delimiter: &str,
        resolution: usize,
        include_aggregates_count: bool,
    ) -> JsResult<JsAggregateResult> {
        self.sensitive_aggregate_result.to_js(
            combination_delimiter,
            resolution,
            include_aggregates_count,
        )
    }

    #[wasm_bindgen(js_name = "syntheticAggregateResultToJs")]
    pub fn synthetic_aggregate_result_to_js(
        &self,
        combination_delimiter: &str,
        resolution: usize,
        include_aggregates_count: bool,
    ) -> JsResult<JsAggregateResult> {
        self.synthetic_aggregate_result.to_js(
            combination_delimiter,
            resolution,
            include_aggregates_count,
        )
    }

    #[wasm_bindgen(js_name = "leakageCountByLenToJs")]
    pub fn leakage_count_by_len_to_js(&self, resolution: usize) -> JsResult<JsAggregateCountByLen> {
        let count = self.evaluator.calc_leakage_count(
            &self.sensitive_aggregate_result,
            &self.synthetic_aggregate_result,
            resolution,
        );

        Ok(JsValue::from_serde(&count)
            .map_err(|err| JsValue::from(err.to_string()))?
            .unchecked_into::<JsAggregateCountByLen>())
    }

    #[wasm_bindgen(js_name = "fabricatedCountByLenToJs")]
    pub fn fabricated_count_by_len_to_js(&self) -> JsResult<JsAggregateCountByLen> {
        let count = self.evaluator.calc_fabricated_count(
            &self.sensitive_aggregate_result,
            &self.synthetic_aggregate_result,
        );

        Ok(JsValue::from_serde(&count)
            .map_err(|err| JsValue::from(err.to_string()))?
            .unchecked_into::<JsAggregateCountByLen>())
    }

    #[wasm_bindgen(js_name = "preservationByCount")]
    pub fn preservation_by_count(&self, resolution: usize) -> WasmPreservationByCount {
        WasmPreservationByCount::new(self.evaluator.calc_preservation_by_count(
            &self.sensitive_aggregate_result,
            &self.synthetic_aggregate_result,
            resolution,
        ))
    }

    #[wasm_bindgen(js_name = "recordExpansion")]
    pub fn record_expansion(&self) -> f64 {
        (self
            .synthetic_aggregate_result
            .data_block
            .number_of_records() as f64)
            / (self
                .sensitive_aggregate_result
                .data_block
                .number_of_records() as f64)
    }

    #[wasm_bindgen(js_name = "toJs")]
    pub fn to_js(
        &self,
        combination_delimiter: &str,
        resolution: usize,
        include_aggregates_count: bool,
    ) -> JsResult<JsEvaluateResult> {
        let _duration_logger =
            ElapsedDurationLogger::new(String::from("evaluate result serialization"));
        let result = Object::new();

        set(
            &result,
            &"sensitiveAggregateResult".into(),
            &self
                .sensitive_aggregate_result_to_js(
                    combination_delimiter,
                    resolution,
                    include_aggregates_count,
                )?
                .into(),
        )?;
        set(
            &result,
            &"syntheticAggregateResult".into(),
            &self
                .synthetic_aggregate_result_to_js(
                    combination_delimiter,
                    resolution,
                    include_aggregates_count,
                )?
                .into(),
        )?;

        set(
            &result,
            &"leakageCountByLen".into(),
            &self.leakage_count_by_len_to_js(resolution)?.into(),
        )?;
        set(
            &result,
            &"fabricatedCountByLen".into(),
            &self.fabricated_count_by_len_to_js()?.into(),
        )?;
        set(
            &result,
            &"preservationByCount".into(),
            &self.preservation_by_count(resolution).to_js()?.into(),
        )?;
        set(
            &result,
            &"recordExpansion".into(),
            &self.record_expansion().into(),
        )?;

        Ok(JsValue::from(result).unchecked_into::<JsEvaluateResult>())
    }
}
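A matching sketch for the evaluator wrapper (again illustrative, not part of the commit): it consumes the two aggregate results, so individual metrics can be pulled out before, or instead of, the full `toJs` serialization.

fn evaluate_to_js(
    sensitive: WasmAggregateResult,
    synthetic: WasmAggregateResult,
) -> JsResult<JsEvaluateResult> {
    let evaluate_result = WasmEvaluateResult::new(sensitive, synthetic);
    // e.g. how much larger the synthetic data is than the sensitive data
    let _expansion = evaluate_result.record_expansion();
    evaluate_result.to_js(";", 10, false)
}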
@@ -0,0 +1,3 @@
pub mod preservation_by_count;

pub mod evaluate_result;
@@ -0,0 +1,93 @@
use js_sys::{Object, Reflect::set};
use sds_core::{
    processing::evaluator::preservation_by_count::PreservationByCountBuckets,
    utils::time::ElapsedDurationLogger,
};
use std::ops::{Deref, DerefMut};
use wasm_bindgen::{prelude::*, JsCast};

use crate::utils::js::ts_definitions::{
    JsPreservationByCount, JsPreservationByCountBuckets, JsResult,
};

#[wasm_bindgen]
pub struct WasmPreservationByCount {
    preservation_by_count_buckets: PreservationByCountBuckets,
}

impl WasmPreservationByCount {
    #[inline]
    pub fn new(
        preservation_by_count_buckets: PreservationByCountBuckets,
    ) -> WasmPreservationByCount {
        WasmPreservationByCount {
            preservation_by_count_buckets,
        }
    }
}

#[wasm_bindgen]
impl WasmPreservationByCount {
    #[wasm_bindgen(js_name = "bucketsToJs")]
    pub fn buckets_to_js(&self) -> JsResult<JsPreservationByCountBuckets> {
        let result = Object::new();

        for (bucket_index, b) in self.preservation_by_count_buckets.iter() {
            let serialized_bucket = Object::new();

            set(&serialized_bucket, &"size".into(), &b.size.into())?;
            set(
                &serialized_bucket,
                &"preservationSum".into(),
                &b.preservation_sum.into(),
            )?;
            set(
                &serialized_bucket,
                &"lengthSum".into(),
                &b.length_sum.into(),
            )?;
            set(
                &serialized_bucket,
                &"combinationCountSum".into(),
                &b.combination_count_sum.into(),
            )?;

            set(&result, &(*bucket_index).into(), &serialized_bucket.into())?;
        }
        Ok(result.unchecked_into::<JsPreservationByCountBuckets>())
    }

    #[wasm_bindgen(js_name = "combinationLoss")]
    pub fn combination_loss(&self) -> f64 {
        self.preservation_by_count_buckets.calc_combination_loss()
    }

    #[wasm_bindgen(js_name = "toJs")]
    pub fn to_js(&self) -> JsResult<JsPreservationByCount> {
        let _duration_logger =
            ElapsedDurationLogger::new(String::from("preservation by count serialization"));
        let result = Object::new();

        set(&result, &"buckets".into(), &self.buckets_to_js()?.into())?;
        set(
            &result,
            &"combinationLoss".into(),
            &self.combination_loss().into(),
        )?;
        Ok(JsValue::from(result).unchecked_into::<JsPreservationByCount>())
    }
}

impl Deref for WasmPreservationByCount {
    type Target = PreservationByCountBuckets;

    fn deref(&self) -> &Self::Target {
        &self.preservation_by_count_buckets
    }
}

impl DerefMut for WasmPreservationByCount {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.preservation_by_count_buckets
    }
}
@@ -0,0 +1,77 @@
use js_sys::{Object, Reflect::set};
use sds_core::{
    processing::generator::generated_data::GeneratedData, utils::time::ElapsedDurationLogger,
};
use std::ops::{Deref, DerefMut};
use wasm_bindgen::{prelude::*, JsCast};

use crate::utils::js::ts_definitions::{JsCsvData, JsGenerateResult, JsResult};

#[wasm_bindgen]
pub struct WasmGenerateResult {
    generated_data: GeneratedData,
}

impl WasmGenerateResult {
    #[inline]
    pub fn default() -> WasmGenerateResult {
        WasmGenerateResult {
            generated_data: GeneratedData::default(),
        }
    }

    #[inline]
    pub fn new(generated_data: GeneratedData) -> WasmGenerateResult {
        WasmGenerateResult { generated_data }
    }
}

#[wasm_bindgen]
impl WasmGenerateResult {
    #[wasm_bindgen(getter)]
    #[wasm_bindgen(js_name = "expansionRatio")]
    pub fn expansion_ratio(&self) -> f64 {
        self.generated_data.expansion_ratio
    }

    #[wasm_bindgen(js_name = "syntheticDataToJs")]
    pub fn synthetic_data_to_js(&self) -> JsResult<JsCsvData> {
        Ok(JsValue::from_serde(&self.generated_data.synthetic_data)
            .map_err(|err| JsValue::from(err.to_string()))?
            .unchecked_into::<JsCsvData>())
    }

    #[wasm_bindgen(js_name = "toJs")]
    pub fn to_js(&self) -> JsResult<JsGenerateResult> {
        let _duration_logger =
            ElapsedDurationLogger::new(String::from("generate result serialization"));
        let result = Object::new();

        set(
            &result,
            &"expansionRatio".into(),
            &self.expansion_ratio().into(),
        )?;
        set(
            &result,
            &"syntheticData".into(),
            &self.synthetic_data_to_js()?.into(),
        )?;

        Ok(JsValue::from(result).unchecked_into::<JsGenerateResult>())
    }
}

impl Deref for WasmGenerateResult {
    type Target = GeneratedData;

    fn deref(&self) -> &Self::Target {
        &self.generated_data
    }
}

impl DerefMut for WasmGenerateResult {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.generated_data
    }
}
@@ -0,0 +1 @@
pub mod generate_result;
@@ -0,0 +1,11 @@
pub mod aggregator;

pub mod evaluator;

pub mod generator;

pub mod navigator;

pub mod sds_processor;

pub mod sds_context;
@@ -0,0 +1,75 @@
use js_sys::{Array, Object, Reflect::set};
use std::{collections::HashMap, convert::TryFrom, sync::Arc};
use wasm_bindgen::{JsCast, JsValue};

use crate::utils::js::ts_definitions::{
    JsAttributesIntersection, JsAttributesIntersectionByColumn,
};

pub struct WasmAttributesIntersection {
    pub(crate) value: Arc<String>,
    pub(crate) estimated_count: usize,
    pub(crate) actual_count: Option<usize>,
}

impl WasmAttributesIntersection {
    #[inline]
    pub fn new(
        value: Arc<String>,
        estimated_count: usize,
        actual_count: Option<usize>,
    ) -> WasmAttributesIntersection {
        WasmAttributesIntersection {
            value,
            estimated_count,
            actual_count,
        }
    }
}

impl TryFrom<WasmAttributesIntersection> for JsAttributesIntersection {
    type Error = JsValue;

    fn try_from(attributes_intersection: WasmAttributesIntersection) -> Result<Self, Self::Error> {
        let result = Object::new();

        set(
            &result,
            &"value".into(),
            &attributes_intersection.value.as_str().into(),
        )?;
        set(
            &result,
            &"estimatedCount".into(),
            &attributes_intersection.estimated_count.into(),
        )?;
        if let Some(actual_count) = attributes_intersection.actual_count {
            set(&result, &"actualCount".into(), &actual_count.into())?;
        }

        Ok(result.unchecked_into::<JsAttributesIntersection>())
    }
}

pub type WasmAttributesIntersectionByColumn = HashMap<usize, Vec<WasmAttributesIntersection>>;

impl TryFrom<WasmAttributesIntersectionByColumn> for JsAttributesIntersectionByColumn {
    type Error = JsValue;

    fn try_from(
        mut intersections_by_column: WasmAttributesIntersectionByColumn,
    ) -> Result<Self, Self::Error> {
        let result = Object::new();

        for (column_index, mut attrs_intersections) in intersections_by_column.drain() {
            let intersections: Array = Array::default();

            for intersection in attrs_intersections.drain(..) {
                intersections.push(&JsAttributesIntersection::try_from(intersection)?.into());
            }
            set(&result, &column_index.into(), &intersections)?;
        }

        Ok(result.unchecked_into::<JsAttributesIntersectionByColumn>())
    }
}
@@ -0,0 +1,5 @@
pub mod attributes_intersection;

pub mod navigate_result;

pub mod selected_attributes;
@@ -0,0 +1,229 @@
use super::{
    attributes_intersection::{WasmAttributesIntersection, WasmAttributesIntersectionByColumn},
    selected_attributes::{WasmSelectedAttributes, WasmSelectedAttributesByColumn},
};
use sds_core::{
    data_block::{
        block::DataBlock,
        typedefs::{
            AttributeRows, AttributeRowsByColumnMap, AttributeRowsMap, ColumnIndexByName, CsvRecord,
        },
        value::DataBlockValue,
    },
    processing::aggregator::value_combination::ValueCombination,
    utils::collections::ordered_vec_intersection,
};
use std::{cmp::Reverse, convert::TryFrom, sync::Arc};
use wasm_bindgen::prelude::*;

use crate::{
    processing::{aggregator::aggregate_result::WasmAggregateResult, sds_processor::SDSProcessor},
    utils::js::ts_definitions::{
        JsAttributesIntersectionByColumn, JsHeaderNames, JsResult, JsSelectedAttributesByColumn,
    },
};

#[wasm_bindgen]
pub struct WasmNavigateResult {
    synthetic_data_block: Arc<DataBlock>,
    attr_rows_by_column: AttributeRowsByColumnMap,
    selected_attributes: WasmSelectedAttributesByColumn,
    selected_attr_rows: AttributeRows,
    all_attr_rows: AttributeRows,
    column_index_by_name: ColumnIndexByName,
}

impl WasmNavigateResult {
    #[inline]
    pub fn default() -> WasmNavigateResult {
        WasmNavigateResult::new(
            Arc::new(DataBlock::default()),
            AttributeRowsByColumnMap::default(),
            WasmSelectedAttributesByColumn::default(),
            AttributeRows::default(),
            AttributeRows::default(),
            ColumnIndexByName::default(),
        )
    }

    #[inline]
    pub fn new(
        synthetic_data_block: Arc<DataBlock>,
        attr_rows_by_column: AttributeRowsByColumnMap,
        selected_attributes: WasmSelectedAttributesByColumn,
        selected_attr_rows: AttributeRows,
        all_attr_rows: AttributeRows,
        column_index_by_name: ColumnIndexByName,
    ) -> WasmNavigateResult {
        WasmNavigateResult {
            synthetic_data_block,
            attr_rows_by_column,
            selected_attributes,
            selected_attr_rows,
            all_attr_rows,
            column_index_by_name,
        }
    }

    #[inline]
    fn intersect_attributes(&self, attributes: &WasmSelectedAttributesByColumn) -> AttributeRows {
        let default_empty_attr_rows = AttributeRowsMap::default();
        let default_empty_rows = AttributeRows::default();
        let mut result = self.all_attr_rows.clone();

        for (column_index, values) in attributes.iter() {
            let attr_rows = self
                .attr_rows_by_column
                .get(column_index)
                .unwrap_or(&default_empty_attr_rows);
            for value in values.iter() {
                result = ordered_vec_intersection(
                    &result,
                    attr_rows.get(value).unwrap_or(&default_empty_rows),
                );
            }
        }
        result
    }

    #[inline]
    fn get_sensitive_aggregates_count(
        &self,
        attributes: &WasmSelectedAttributesByColumn,
        sensitive_aggregate_result: &WasmAggregateResult,
    ) -> Option<usize> {
        let mut combination: Vec<Arc<DataBlockValue>> = attributes
            .values()
            .flat_map(|values| values.iter().cloned())
            .collect();

        combination.sort_by_key(|k| k.format_str_using_headers(&self.synthetic_data_block.headers));

        sensitive_aggregate_result
            .aggregates_count
            .get(&ValueCombination::new(combination))
            .map(|c| c.count)
    }

    #[inline]
    fn get_selected_but_current_column(
        &self,
        column_index: usize,
    ) -> (WasmSelectedAttributesByColumn, AttributeRows) {
        let mut selected_attributes_but_current_column = self.selected_attributes.clone();

        selected_attributes_but_current_column.remove(&column_index);

        let selected_attr_rows_but_current_column =
            self.intersect_attributes(&selected_attributes_but_current_column);

        (
            selected_attributes_but_current_column,
            selected_attr_rows_but_current_column,
        )
    }

    #[inline]
    fn selected_attributes_contains_value(
        &self,
        column_index: usize,
        value: &Arc<DataBlockValue>,
    ) -> bool {
        self.selected_attributes
            .get(&column_index)
            .map(|values| values.contains(value))
            .unwrap_or(false)
    }
}

#[wasm_bindgen]
impl WasmNavigateResult {
    #[wasm_bindgen(constructor)]
    pub fn from_synthetic_processor(synthetic_processor: &SDSProcessor) -> WasmNavigateResult {
        WasmNavigateResult::new(
            synthetic_processor.data_block.clone(),
            synthetic_processor
                .data_block
                .calc_attr_rows_with_no_empty_values(),
            WasmSelectedAttributesByColumn::default(),
            AttributeRows::default(),
            (0..synthetic_processor.data_block.number_of_records()).collect(),
|
||||
synthetic_processor.data_block.calc_column_index_by_name(),
|
||||
)
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "selectAttributes")]
|
||||
pub fn select_attributes(&mut self, attributes: JsSelectedAttributesByColumn) -> JsResult<()> {
|
||||
self.selected_attributes = WasmSelectedAttributesByColumn::try_from(attributes)?;
|
||||
self.selected_attr_rows = self.intersect_attributes(&self.selected_attributes);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "attributesIntersectionsByColumn")]
|
||||
pub fn attributes_intersections_by_column(
|
||||
&mut self,
|
||||
columns: JsHeaderNames,
|
||||
sensitive_aggregate_result: &WasmAggregateResult,
|
||||
) -> JsResult<JsAttributesIntersectionByColumn> {
|
||||
let column_indexes: AttributeRows = CsvRecord::try_from(columns)?
|
||||
.iter()
|
||||
.filter_map(|header_name| self.column_index_by_name.get(header_name))
|
||||
.cloned()
|
||||
.collect();
|
||||
let mut result = WasmAttributesIntersectionByColumn::default();
|
||||
|
||||
for column_index in column_indexes {
|
||||
let result_column_entry = result.entry(column_index).or_insert_with(Vec::default);
|
||||
if let Some(attr_rows) = self.attr_rows_by_column.get(&column_index) {
|
||||
let (selected_attributes_but_current_column, selected_attr_rows_but_current_column) =
|
||||
self.get_selected_but_current_column(column_index);
|
||||
|
||||
for (value, rows) in attr_rows {
|
||||
let current_selected_attr_rows;
|
||||
let current_selected_attributes;
|
||||
|
||||
if self.selected_attributes_contains_value(column_index, value) {
|
||||
// if the selected attributes contain the current value, we
|
||||
// already have the selected attr rows cached
|
||||
current_selected_attributes = &self.selected_attributes;
|
||||
current_selected_attr_rows = &self.selected_attr_rows;
|
||||
} else {
|
||||
// if the selected attributes do not contain the current value
|
||||
// use the intersection of all selected values except the
|
||||
// ones in the current column
|
||||
current_selected_attributes = &selected_attributes_but_current_column;
|
||||
current_selected_attr_rows = &selected_attr_rows_but_current_column;
|
||||
};
|
||||
|
||||
let estimated_count =
|
||||
ordered_vec_intersection(current_selected_attr_rows, rows).len();
|
||||
|
||||
if estimated_count > 0 {
|
||||
let mut attributes = current_selected_attributes.clone();
|
||||
let attributes_entry = attributes
|
||||
.entry(column_index)
|
||||
.or_insert_with(WasmSelectedAttributes::default);
|
||||
|
||||
attributes_entry.insert(value.clone());
|
||||
|
||||
result_column_entry.push(WasmAttributesIntersection::new(
|
||||
value.value.clone(),
|
||||
estimated_count,
|
||||
self.get_sensitive_aggregates_count(
|
||||
&attributes,
|
||||
sensitive_aggregate_result,
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sort by estimated count in descending order
|
||||
for intersections in result.values_mut() {
|
||||
intersections.sort_by_key(|intersection| Reverse(intersection.estimated_count))
|
||||
}
|
||||
|
||||
JsAttributesIntersectionByColumn::try_from(result)
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,48 @@
|
|||
use js_sys::{Array, Object, Set};
|
||||
use sds_core::data_block::value::DataBlockValue;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
convert::TryFrom,
|
||||
sync::Arc,
|
||||
};
|
||||
use wasm_bindgen::{JsCast, JsValue};
|
||||
|
||||
use crate::utils::js::ts_definitions::JsSelectedAttributesByColumn;
|
||||
|
||||
pub type WasmSelectedAttributes = HashSet<Arc<DataBlockValue>>;
|
||||
|
||||
pub type WasmSelectedAttributesByColumn = HashMap<usize, WasmSelectedAttributes>;
|
||||
|
||||
impl TryFrom<JsSelectedAttributesByColumn> for WasmSelectedAttributesByColumn {
|
||||
type Error = JsValue;
|
||||
|
||||
fn try_from(js_selected_attributes: JsSelectedAttributesByColumn) -> Result<Self, Self::Error> {
|
||||
let mut result = WasmSelectedAttributesByColumn::default();
|
||||
|
||||
for entry_res in Object::entries(&js_selected_attributes.dyn_into::<Object>()?).values() {
|
||||
let entry = entry_res?.dyn_into::<Array>()?;
|
||||
let column_index = entry
|
||||
.get(0)
|
||||
.as_string()
|
||||
.ok_or_else(|| {
|
||||
JsValue::from("invalid column index on selected attributes by column")
|
||||
})?
|
||||
.parse::<usize>()
|
||||
.map_err(|err| JsValue::from(err.to_string()))?;
|
||||
let result_entry = result.entry(column_index).or_insert_with(HashSet::default);
|
||||
|
||||
for value_res in entry.get(1).dyn_into::<Set>()?.keys() {
|
||||
let value = value_res?;
|
||||
let value_str = value.as_string().ok_or_else(|| {
|
||||
JsValue::from("invalid value on selected attributes by column")
|
||||
})?;
|
||||
|
||||
result_entry.insert(Arc::new(DataBlockValue::new(
|
||||
column_index,
|
||||
Arc::new(value_str),
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
|
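A minimal sketch (illustrative only, not part of this diff) of the plain JavaScript object shape that the TryFrom<JsSelectedAttributesByColumn> conversion above expects, written in TypeScript against the ISelectedAttributesByColumn interface declared later in ts_definitions; the column indexes and attribute values are made-up examples.

// Keys are column indexes (serialized as strings by Object.entries), values are JS Sets
// of attribute value strings. The indexes and values below are illustrative assumptions.
const selectedAttributes: ISelectedAttributesByColumn = {
    0: new Set(['30', '40']), // two selected values for column 0
    3: new Set(['1']),        // a single selected value for column 3
};
// This object is what selectAttributes() receives before it is converted into a
// WasmSelectedAttributesByColumn on the Rust side.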
@@ -0,0 +1,209 @@
|
|||
use super::{
|
||||
evaluator::evaluate_result::WasmEvaluateResult, generator::generate_result::WasmGenerateResult,
|
||||
navigator::navigate_result::WasmNavigateResult, sds_processor::SDSProcessor,
|
||||
};
|
||||
use log::debug;
|
||||
use wasm_bindgen::{prelude::*, JsCast};
|
||||
|
||||
use crate::utils::js::ts_definitions::{
|
||||
JsAttributesIntersectionByColumn, JsCsvData, JsEvaluateResult, JsGenerateResult, JsHeaderNames,
|
||||
JsReportProgressCallback, JsResult, JsSelectedAttributesByColumn,
|
||||
};
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub struct SDSContext {
|
||||
use_columns: JsHeaderNames,
|
||||
sensitive_zeros: JsHeaderNames,
|
||||
record_limit: usize,
|
||||
sensitive_processor: SDSProcessor,
|
||||
generate_result: WasmGenerateResult,
|
||||
resolution: usize,
|
||||
synthetic_processor: SDSProcessor,
|
||||
evaluate_result: WasmEvaluateResult,
|
||||
navigate_result: WasmNavigateResult,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl SDSContext {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn default() -> SDSContext {
|
||||
SDSContext {
|
||||
use_columns: JsHeaderNames::default(),
|
||||
sensitive_zeros: JsHeaderNames::default(),
|
||||
record_limit: 0,
|
||||
sensitive_processor: SDSProcessor::default(),
|
||||
generate_result: WasmGenerateResult::default(),
|
||||
resolution: 0,
|
||||
synthetic_processor: SDSProcessor::default(),
|
||||
evaluate_result: WasmEvaluateResult::default(),
|
||||
navigate_result: WasmNavigateResult::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "clearSensitiveData")]
|
||||
pub fn clear_sensitive_data(&mut self) {
|
||||
self.use_columns = JsHeaderNames::default();
|
||||
self.sensitive_zeros = JsHeaderNames::default();
|
||||
self.record_limit = 0;
|
||||
self.sensitive_processor = SDSProcessor::default();
|
||||
self.clear_generate();
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "clearGenerate")]
|
||||
pub fn clear_generate(&mut self) {
|
||||
self.generate_result = WasmGenerateResult::default();
|
||||
self.resolution = 0;
|
||||
self.synthetic_processor = SDSProcessor::default();
|
||||
self.clear_evaluate()
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "clearEvaluate")]
|
||||
pub fn clear_evaluate(&mut self) {
|
||||
self.evaluate_result = WasmEvaluateResult::default();
|
||||
self.clear_navigate()
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "clearNavigate")]
|
||||
pub fn clear_navigate(&mut self) {
|
||||
self.navigate_result = WasmNavigateResult::default();
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "setSensitiveData")]
|
||||
pub fn set_sensitive_data(
|
||||
&mut self,
|
||||
csv_data: JsCsvData,
|
||||
use_columns: JsHeaderNames,
|
||||
sensitive_zeros: JsHeaderNames,
|
||||
record_limit: usize,
|
||||
) -> JsResult<()> {
|
||||
debug!("setting sensitive data...");
|
||||
|
||||
self.use_columns = use_columns;
|
||||
self.sensitive_zeros = sensitive_zeros;
|
||||
self.record_limit = record_limit;
|
||||
self.sensitive_processor = SDSProcessor::new(
|
||||
csv_data,
|
||||
self.use_columns.clone().unchecked_into(),
|
||||
self.sensitive_zeros.clone().unchecked_into(),
|
||||
self.record_limit,
|
||||
)?;
|
||||
self.clear_generate();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn generate(
|
||||
&mut self,
|
||||
cache_max_size: usize,
|
||||
resolution: usize,
|
||||
empty_value: String,
|
||||
seeded: bool,
|
||||
progress_callback: JsReportProgressCallback,
|
||||
) -> JsResult<()> {
|
||||
debug!("generating synthetic data...");
|
||||
|
||||
self.generate_result = self.sensitive_processor.generate(
|
||||
cache_max_size,
|
||||
resolution,
|
||||
empty_value,
|
||||
seeded,
|
||||
progress_callback,
|
||||
)?;
|
||||
self.resolution = resolution;
|
||||
|
||||
debug!("creating synthetic data processor...");
|
||||
|
||||
self.synthetic_processor = SDSProcessor::new(
|
||||
self.generate_result.synthetic_data_to_js()?,
|
||||
self.use_columns.clone().unchecked_into(),
|
||||
self.sensitive_zeros.clone().unchecked_into(),
|
||||
self.record_limit,
|
||||
)?;
|
||||
self.clear_evaluate();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn evaluate(
|
||||
&mut self,
|
||||
reporting_length: usize,
|
||||
sensitivity_threshold: usize,
|
||||
sensitive_progress_callback: JsReportProgressCallback,
|
||||
synthetic_progress_callback: JsReportProgressCallback,
|
||||
) -> JsResult<()> {
|
||||
debug!("aggregating sensitive data...");
|
||||
|
||||
let sensitive_aggregate_result = self.sensitive_processor.aggregate(
|
||||
reporting_length,
|
||||
sensitivity_threshold,
|
||||
sensitive_progress_callback,
|
||||
)?;
|
||||
|
||||
debug!("aggregating synthetic data...");
|
||||
|
||||
let synthetic_aggregate_result = self.synthetic_processor.aggregate(
|
||||
reporting_length,
|
||||
sensitivity_threshold,
|
||||
synthetic_progress_callback,
|
||||
)?;
|
||||
|
||||
debug!("evaluating synthetic data based on sensitive data...");
|
||||
|
||||
self.evaluate_result = WasmEvaluateResult::from_aggregate_results(
|
||||
sensitive_aggregate_result,
|
||||
synthetic_aggregate_result,
|
||||
);
|
||||
|
||||
self.clear_navigate();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "protectSensitiveAggregatesCount")]
|
||||
pub fn protect_sensitive_aggregates_count(&mut self) {
|
||||
debug!("protecting sensitive aggregates count...");
|
||||
|
||||
self.evaluate_result
|
||||
.sensitive_aggregate_result
|
||||
.protect_aggregates_count(self.resolution);
|
||||
}
|
||||
|
||||
pub fn navigate(&mut self) {
|
||||
debug!("creating navigate result...");
|
||||
|
||||
self.navigate_result =
|
||||
WasmNavigateResult::from_synthetic_processor(&self.synthetic_processor);
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "selectAttributes")]
|
||||
pub fn select_attributes(&mut self, attributes: JsSelectedAttributesByColumn) -> JsResult<()> {
|
||||
self.navigate_result.select_attributes(attributes)
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "attributesIntersectionsByColumn")]
|
||||
pub fn attributes_intersections_by_column(
|
||||
&mut self,
|
||||
columns: JsHeaderNames,
|
||||
) -> JsResult<JsAttributesIntersectionByColumn> {
|
||||
self.navigate_result.attributes_intersections_by_column(
|
||||
columns,
|
||||
&self.evaluate_result.sensitive_aggregate_result,
|
||||
)
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "generateResultToJs")]
|
||||
pub fn generate_result_to_js(&self) -> JsResult<JsGenerateResult> {
|
||||
self.generate_result.to_js()
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "evaluateResultToJs")]
|
||||
pub fn evaluate_result_to_js(
|
||||
&self,
|
||||
combination_delimiter: &str,
|
||||
include_aggregates_count: bool,
|
||||
) -> JsResult<JsEvaluateResult> {
|
||||
self.evaluate_result.to_js(
|
||||
combination_delimiter,
|
||||
self.resolution,
|
||||
include_aggregates_count,
|
||||
)
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,23 @@
|
|||
use js_sys::Array;
|
||||
use sds_core::data_block::typedefs::CsvRecord;
|
||||
use std::convert::TryFrom;
|
||||
use wasm_bindgen::{JsCast, JsValue};
|
||||
|
||||
use crate::utils::js::ts_definitions::JsHeaderNames;
|
||||
|
||||
impl JsHeaderNames {
|
||||
#[inline]
|
||||
pub fn default() -> JsHeaderNames {
|
||||
Array::default().unchecked_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<JsHeaderNames> for CsvRecord {
|
||||
type Error = JsValue;
|
||||
|
||||
fn try_from(js_csv_record: JsHeaderNames) -> Result<Self, Self::Error> {
|
||||
js_csv_record
|
||||
.into_serde::<CsvRecord>()
|
||||
.map_err(|err| JsValue::from(err.to_string()))
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,5 @@
|
|||
pub mod header_names;
|
||||
|
||||
mod processor;
|
||||
|
||||
pub use processor::SDSProcessor;
|
|
@@ -0,0 +1,120 @@
|
|||
use js_sys::Function;
|
||||
use sds_core::{
|
||||
data_block::{block::DataBlock, data_block_creator::DataBlockCreator, typedefs::CsvRecord},
|
||||
processing::{
|
||||
aggregator::Aggregator,
|
||||
generator::{Generator, SynthesisMode},
|
||||
},
|
||||
utils::time::ElapsedDurationLogger,
|
||||
};
|
||||
use std::convert::TryFrom;
|
||||
use std::sync::Arc;
|
||||
use wasm_bindgen::{prelude::*, JsCast};
|
||||
|
||||
use crate::{
|
||||
utils::js::js_progress_reporter::JsProgressReporter,
|
||||
{
|
||||
processing::{
|
||||
aggregator::aggregate_result::WasmAggregateResult,
|
||||
generator::generate_result::WasmGenerateResult,
|
||||
},
|
||||
utils::js::{
|
||||
js_block_creator::JsDataBlockCreator,
|
||||
ts_definitions::{JsCsvData, JsHeaderNames, JsReportProgressCallback},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub struct SDSProcessor {
|
||||
pub(crate) data_block: Arc<DataBlock>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl SDSProcessor {
|
||||
pub fn default() -> SDSProcessor {
|
||||
SDSProcessor {
|
||||
data_block: Arc::new(DataBlock::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl SDSProcessor {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(
|
||||
csv_data: JsCsvData,
|
||||
use_columns: JsHeaderNames,
|
||||
sensitive_zeros: JsHeaderNames,
|
||||
record_limit: usize,
|
||||
) -> Result<SDSProcessor, JsValue> {
|
||||
let _duration_logger = ElapsedDurationLogger::new(String::from("sds processor creation"));
|
||||
let data_block = JsDataBlockCreator::create(
|
||||
Ok(csv_data.dyn_into()?),
|
||||
&CsvRecord::try_from(use_columns)?,
|
||||
&CsvRecord::try_from(sensitive_zeros)?,
|
||||
record_limit,
|
||||
)?;
|
||||
|
||||
Ok(SDSProcessor { data_block })
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "numberOfRecords")]
|
||||
pub fn number_of_records(&self) -> usize {
|
||||
self.data_block.number_of_records()
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "protectedNumberOfRecords")]
|
||||
pub fn protected_number_of_records(&self, resolution: usize) -> usize {
|
||||
self.data_block.protected_number_of_records(resolution)
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = "normalizeReportingLength")]
|
||||
pub fn normalize_reporting_length(&self, reporting_length: usize) -> usize {
|
||||
self.data_block.normalize_reporting_length(reporting_length)
|
||||
}
|
||||
|
||||
pub fn aggregate(
|
||||
&self,
|
||||
reporting_length: usize,
|
||||
sensitivity_threshold: usize,
|
||||
progress_callback: JsReportProgressCallback,
|
||||
) -> Result<WasmAggregateResult, JsValue> {
|
||||
let _duration_logger =
|
||||
ElapsedDurationLogger::new(String::from("sds processor aggregation"));
|
||||
let js_callback: Function = progress_callback.dyn_into()?;
|
||||
let mut aggregator = Aggregator::new(self.data_block.clone());
|
||||
|
||||
Ok(WasmAggregateResult::new(aggregator.aggregate(
|
||||
reporting_length,
|
||||
sensitivity_threshold,
|
||||
&mut Some(JsProgressReporter::new(&js_callback, &|p| p)),
|
||||
)))
|
||||
}
|
||||
|
||||
pub fn generate(
|
||||
&self,
|
||||
cache_max_size: usize,
|
||||
resolution: usize,
|
||||
empty_value: String,
|
||||
seeded: bool,
|
||||
progress_callback: JsReportProgressCallback,
|
||||
) -> Result<WasmGenerateResult, JsValue> {
|
||||
let _duration_logger = ElapsedDurationLogger::new(String::from("sds processor generation"));
|
||||
let js_callback: Function = progress_callback.dyn_into()?;
|
||||
let mut generator = Generator::new(self.data_block.clone());
|
||||
let mode = if seeded {
|
||||
SynthesisMode::Seeded
|
||||
} else {
|
||||
SynthesisMode::Unseeded
|
||||
};
|
||||
|
||||
Ok(WasmGenerateResult::new(generator.generate(
|
||||
resolution,
|
||||
cache_max_size,
|
||||
empty_value,
|
||||
mode,
|
||||
&mut Some(JsProgressReporter::new(&js_callback, &|p| p)),
|
||||
)))
|
||||
}
|
||||
}
|
|
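As a rough usage sketch (not part of this change), the SDSProcessor binding above could be driven from TypeScript as follows; the CSV rows and parameter values are illustrative, and the assumptions that an empty use_columns list selects all columns and that record_limit = 0 means "no limit" are not confirmed by this diff.

// Illustrative CSV payload: a header row followed by data records (JsCsvData shape).
const csvData: string[][] = [
    ['age', 'gender'],
    ['30', 'female'],
    ['40', 'male'],
];
// use_columns = [], sensitive_zeros = [] and record_limit = 0 are assumptions for this sketch.
const processor = new SDSProcessor(csvData, [], [], 0);
console.log(processor.numberOfRecords());
// reporting_length = 2, sensitivity_threshold = 0, plus a progress callback.
const aggregateResult = processor.aggregate(2, 0, (p: number) => console.log(`aggregate: ${p}%`));
// cache_max_size, resolution, empty_value, seeded flag and a progress callback.
const generateResult = processor.generate(100000, 10, '', true, (p: number) => console.log(`generate: ${p}%`));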
@@ -1,41 +0,0 @@
|
|||
use super::js_block_creator::JsDataBlockCreator;
|
||||
use js_sys::Array;
|
||||
use sds_core::{
|
||||
data_block::{
|
||||
block::{DataBlock, DataBlockCreator},
|
||||
typedefs::CsvRecord,
|
||||
},
|
||||
utils::time::ElapsedDurationLogger,
|
||||
};
|
||||
|
||||
#[inline]
|
||||
pub fn deserialize_use_columns(use_columns: Array) -> Result<CsvRecord, String> {
|
||||
let _duration_logger = ElapsedDurationLogger::new(String::from("deserialize_use_columns"));
|
||||
|
||||
match use_columns.into_serde::<CsvRecord>() {
|
||||
Ok(v) => Ok(v),
|
||||
_ => Err(String::from("use_columns should be an Array<string>")),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn deserialize_sensitive_zeros(sensitive_zeros: Array) -> Result<CsvRecord, String> {
|
||||
let _duration_logger = ElapsedDurationLogger::new(String::from("deserialize_sensitive_zeros"));
|
||||
|
||||
match sensitive_zeros.into_serde::<CsvRecord>() {
|
||||
Ok(v) => Ok(v),
|
||||
_ => Err(String::from("sensitive_zeros should be an Array<string>")),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn deserialize_csv_data(
|
||||
csv_data: Array,
|
||||
use_columns: &[String],
|
||||
sensitive_zeros: &[String],
|
||||
record_limit: usize,
|
||||
) -> Result<DataBlock, String> {
|
||||
let _duration_logger = ElapsedDurationLogger::new(String::from("deserialize_csv_data"));
|
||||
|
||||
JsDataBlockCreator::create(Ok(csv_data), use_columns, sensitive_zeros, record_limit)
|
||||
}
|
|
@@ -1,33 +1,34 @@
|
|||
use js_sys::Array;
|
||||
use sds_core::data_block::{block::DataBlockCreator, typedefs::CsvRecord};
|
||||
use sds_core::data_block::{data_block_creator::DataBlockCreator, typedefs::CsvRecord};
|
||||
use wasm_bindgen::JsValue;
|
||||
|
||||
pub struct JsDataBlockCreator;
|
||||
|
||||
impl DataBlockCreator for JsDataBlockCreator {
|
||||
type InputType = Array;
|
||||
type ErrorType = String;
|
||||
type ErrorType = JsValue;
|
||||
|
||||
fn get_headers(csv_data: &mut Array) -> Result<CsvRecord, String> {
|
||||
fn get_headers(csv_data: &mut Self::InputType) -> Result<CsvRecord, Self::ErrorType> {
|
||||
let headers = csv_data.get(0);
|
||||
|
||||
if headers.is_undefined() {
|
||||
Err(String::from("missing headers"))
|
||||
Err(JsValue::from("csv data missing headers"))
|
||||
} else {
|
||||
match headers.into_serde::<Vec<String>>() {
|
||||
Ok(h) => Ok(h),
|
||||
Err(_) => Err(String::from("headers should an Array<string>")),
|
||||
}
|
||||
headers
|
||||
.into_serde::<CsvRecord>()
|
||||
.map_err(|err| JsValue::from(err.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
fn get_records(csv_data: &mut Self::InputType) -> Result<Vec<CsvRecord>, String> {
|
||||
fn get_records(csv_data: &mut Self::InputType) -> Result<Vec<CsvRecord>, Self::ErrorType> {
|
||||
csv_data
|
||||
.slice(1, csv_data.length())
|
||||
.iter()
|
||||
.map(|record| match record.into_serde::<Vec<String>>() {
|
||||
Ok(h) => Ok(h),
|
||||
Err(_) => Err(String::from("records should an Array<string>")),
|
||||
.map(|record| {
|
||||
record
|
||||
.into_serde::<CsvRecord>()
|
||||
.map_err(|err| JsValue::from(err.to_string()))
|
||||
})
|
||||
.collect::<Result<Vec<CsvRecord>, String>>()
|
||||
.collect::<Result<Vec<CsvRecord>, Self::ErrorType>>()
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -1,32 +0,0 @@
|
|||
#[doc(hidden)]
|
||||
#[macro_export]
|
||||
macro_rules! match_or_return_undefined {
|
||||
($result_to_match: expr) => {
|
||||
match $result_to_match {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
error!("{}", err);
|
||||
return JsValue::undefined();
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[macro_export]
|
||||
macro_rules! set_or_return_undefined {
|
||||
($($arg:tt)*) => {
|
||||
match set($($arg)*) {
|
||||
Ok(has_been_set) => {
|
||||
if !has_been_set {
|
||||
error!("unable to set object value from wasm");
|
||||
return JsValue::undefined();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
error!("unable to set object value from wasm");
|
||||
return JsValue::undefined();
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
|
@@ -1,16 +1,9 @@
|
|||
pub fn set_panic_hook() {
|
||||
// When the `console_error_panic_hook` feature is enabled, we can call the
|
||||
// `set_panic_hook` function at least once during initialization, and then
|
||||
// we will get better error messages if our code ever panics.
|
||||
//
|
||||
// For more details see
|
||||
// https://github.com/rustwasm/console_error_panic_hook#readme
|
||||
#[cfg(feature = "console_error_panic_hook")]
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
||||
|
||||
pub mod deserializers;
|
||||
pub mod js_block_creator;
|
||||
|
||||
pub mod js_progress_reporter;
|
||||
pub mod macros;
|
||||
pub mod serializers;
|
||||
|
||||
mod set_panic_hook;
|
||||
|
||||
pub use set_panic_hook::set_panic_hook;
|
||||
|
||||
pub mod ts_definitions;
|
||||
|
|
|
@@ -1,124 +0,0 @@
|
|||
use js_sys::{Object, Reflect::set};
|
||||
use log::error;
|
||||
use sds_core::{
|
||||
data_block::typedefs::DataBlockHeadersSlice,
|
||||
processing::{
|
||||
aggregator::{typedefs::AggregatesCountMap, Aggregator, PrivacyRiskSummary},
|
||||
evaluator::typedefs::PreservationByCountBuckets,
|
||||
},
|
||||
utils::time::ElapsedDurationLogger,
|
||||
};
|
||||
use wasm_bindgen::JsValue;
|
||||
|
||||
use crate::{aggregator::AggregatedCombination, set_or_return_undefined};
|
||||
|
||||
#[inline]
|
||||
pub fn serialize_aggregates_count(
|
||||
aggregates_count: &AggregatesCountMap,
|
||||
headers: &DataBlockHeadersSlice,
|
||||
) -> JsValue {
|
||||
let _duration_logger = ElapsedDurationLogger::new(String::from("serialize_aggregates_count"));
|
||||
let aggregated_values = Object::new();
|
||||
|
||||
for (agg, count) in aggregates_count.iter() {
|
||||
let combination_key = Aggregator::format_aggregate_str(headers, agg);
|
||||
|
||||
set_or_return_undefined!(
|
||||
&aggregated_values,
|
||||
&combination_key.clone().into(),
|
||||
&AggregatedCombination::new(combination_key, count.count, agg.len()).into()
|
||||
);
|
||||
}
|
||||
aggregated_values.into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn serialize_buckets(buckets: &PreservationByCountBuckets) -> JsValue {
|
||||
let _duration_logger = ElapsedDurationLogger::new(String::from("serialize_buckets"));
|
||||
let serialized_buckets = Object::new();
|
||||
|
||||
for (bucket_index, b) in buckets.iter() {
|
||||
let serialized_bucket = Object::new();
|
||||
|
||||
set_or_return_undefined!(&serialized_bucket, &"size".into(), &b.size.into());
|
||||
set_or_return_undefined!(
|
||||
&serialized_bucket,
|
||||
&"preservationSum".into(),
|
||||
&b.preservation_sum.into()
|
||||
);
|
||||
set_or_return_undefined!(
|
||||
&serialized_bucket,
|
||||
&"lengthSum".into(),
|
||||
&b.length_sum.into()
|
||||
);
|
||||
|
||||
set_or_return_undefined!(
|
||||
&serialized_buckets,
|
||||
&(*bucket_index).into(),
|
||||
&serialized_bucket.into()
|
||||
);
|
||||
}
|
||||
serialized_buckets.into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn serialize_privacy_risk(privacy_risk: &PrivacyRiskSummary) -> JsValue {
|
||||
let _duration_logger = ElapsedDurationLogger::new(String::from("serialize_privacy_risk"));
|
||||
let serialized_privacy_risk = Object::new();
|
||||
|
||||
set_or_return_undefined!(
|
||||
&serialized_privacy_risk,
|
||||
&"totalNumberOfRecords".into(),
|
||||
&privacy_risk.total_number_of_records.into()
|
||||
);
|
||||
set_or_return_undefined!(
|
||||
&serialized_privacy_risk,
|
||||
&"totalNumberOfCombinations".into(),
|
||||
&privacy_risk.total_number_of_combinations.into()
|
||||
);
|
||||
set_or_return_undefined!(
|
||||
&serialized_privacy_risk,
|
||||
&"recordsWithUniqueCombinationsCount".into(),
|
||||
&privacy_risk.records_with_unique_combinations_count.into()
|
||||
);
|
||||
set_or_return_undefined!(
|
||||
&serialized_privacy_risk,
|
||||
&"recordsWithRareCombinationsCount".into(),
|
||||
&privacy_risk.records_with_rare_combinations_count.into()
|
||||
);
|
||||
set_or_return_undefined!(
|
||||
&serialized_privacy_risk,
|
||||
&"uniqueCombinationsCount".into(),
|
||||
&privacy_risk.unique_combinations_count.into()
|
||||
);
|
||||
set_or_return_undefined!(
|
||||
&serialized_privacy_risk,
|
||||
&"rareCombinationsCount".into(),
|
||||
&privacy_risk.rare_combinations_count.into()
|
||||
);
|
||||
set_or_return_undefined!(
|
||||
&serialized_privacy_risk,
|
||||
&"recordsWithUniqueCombinationsProportion".into(),
|
||||
&privacy_risk
|
||||
.records_with_unique_combinations_proportion
|
||||
.into()
|
||||
);
|
||||
set_or_return_undefined!(
|
||||
&serialized_privacy_risk,
|
||||
&"recordsWithRareCombinationsProportion".into(),
|
||||
&privacy_risk
|
||||
.records_with_rare_combinations_proportion
|
||||
.into()
|
||||
);
|
||||
set_or_return_undefined!(
|
||||
&serialized_privacy_risk,
|
||||
&"uniqueCombinationsProportion".into(),
|
||||
&privacy_risk.unique_combinations_proportion.into()
|
||||
);
|
||||
set_or_return_undefined!(
|
||||
&serialized_privacy_risk,
|
||||
&"rareCombinationsProportion".into(),
|
||||
&privacy_risk.rare_combinations_proportion.into()
|
||||
);
|
||||
serialized_privacy_risk.into()
|
||||
}
|
|
@@ -0,0 +1,10 @@
|
|||
pub fn set_panic_hook() {
|
||||
// When the `console_error_panic_hook` feature is enabled, we can call the
|
||||
// `set_panic_hook` function at least once during initialization, and then
|
||||
// we will get better error messages if our code ever panics.
|
||||
//
|
||||
// For more details see
|
||||
// https://github.com/rustwasm/console_error_panic_hook#readme
|
||||
#[cfg(feature = "console_error_panic_hook")]
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
|
@@ -0,0 +1,146 @@
|
|||
use wasm_bindgen::prelude::*;
|
||||
|
||||
#[wasm_bindgen(typescript_custom_section)]
|
||||
const TS_APPEND_CONTENT: &'static str = r#"
|
||||
export type ReportProgressCallback = (progress: number) => void;
|
||||
|
||||
export type HeaderNames = string[];
|
||||
|
||||
export type CsvRecord = string[];
|
||||
|
||||
export type CsvData = CsvRecord[];
|
||||
|
||||
export interface IGenerateResult {
|
||||
expansionRatio: number;
|
||||
syntheticData: CsvData;
|
||||
}
|
||||
|
||||
export interface IPrivacyRiskSummary {
|
||||
totalNumberOfRecords: number
|
||||
totalNumberOfCombinations: number
|
||||
recordsWithUniqueCombinationsCount: number
|
||||
recordsWithRareCombinationsCount: number
|
||||
uniqueCombinationsCount: number
|
||||
rareCombinationsCount: number
|
||||
recordsWithUniqueCombinationsProportion: number
|
||||
recordsWithRareCombinationsProportion: number
|
||||
uniqueCombinationsProportion: number
|
||||
rareCombinationsProportion: number
|
||||
}
|
||||
|
||||
export interface IAggregateCountByLen {
|
||||
[length: number]: number
|
||||
}
|
||||
|
||||
export interface IAggregateCountAndLength {
|
||||
count: number;
|
||||
length: number;
|
||||
}
|
||||
|
||||
export interface IAggregatesCount {
|
||||
[name: string]: IAggregateCountAndLength;
|
||||
}
|
||||
|
||||
export interface IAggregateResult {
|
||||
reportingLength: number;
|
||||
aggregatesCount?: IAggregatesCount;
|
||||
rareCombinationsCountByLen: IAggregateCountByLen;
|
||||
combinationsCountByLen: IAggregateCountByLen;
|
||||
combinationsSumByLen: IAggregateCountByLen;
|
||||
privacyRisk: IPrivacyRiskSummary;
|
||||
}
|
||||
|
||||
export interface IPreservationByCountBucket {
|
||||
size: number;
|
||||
preservationSum: number;
|
||||
lengthSum: number;
|
||||
combinationCountSum: number;
|
||||
}
|
||||
|
||||
export interface IPreservationByCountBuckets {
|
||||
[bucket_index: number]: IPreservationByCountBucket;
|
||||
}
|
||||
|
||||
export interface IPreservationByCount {
|
||||
buckets: IPreservationByCountBuckets;
|
||||
combinationLoss: number;
|
||||
}
|
||||
|
||||
export interface IEvaluateResult {
|
||||
sensitiveAggregateResult: IAggregateResult;
|
||||
syntheticAggregateResult: IAggregateResult;
|
||||
leakageCountByLen: IAggregateCountByLen;
|
||||
fabricatedCountByLen: IAggregateCountByLen;
|
||||
preservationByCount: IPreservationByCount;
|
||||
recordExpansion: number;
|
||||
}
|
||||
|
||||
export interface ISelectedAttributesByColumn {
|
||||
[columnIndex: number]: Set<string>;
|
||||
}
|
||||
|
||||
export interface IAttributesIntersection {
|
||||
value: string;
|
||||
estimatedCount: number;
|
||||
actualCount?: number;
|
||||
}
|
||||
|
||||
export interface IAttributesIntersectionByColumn {
|
||||
[columnIndex: number]: IAttributesIntersection[];
|
||||
}"#;
|
||||
|
||||
#[wasm_bindgen]
|
||||
extern "C" {
|
||||
#[wasm_bindgen(typescript_type = "ReportProgressCallback")]
|
||||
pub type JsReportProgressCallback;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "HeaderNames")]
|
||||
pub type JsHeaderNames;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "CsvRecord")]
|
||||
pub type JsCsvRecord;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "CsvData")]
|
||||
pub type JsCsvData;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IGenerateResult")]
|
||||
pub type JsGenerateResult;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IPrivacyRiskSummary")]
|
||||
pub type JsPrivacyRiskSummary;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IAggregateCountByLen")]
|
||||
pub type JsAggregateCountByLen;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IAggregateCountAndLength")]
|
||||
pub type JsAggregateCountAndLength;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IAggregatesCount")]
|
||||
pub type JsAggregatesCount;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IAggregateResult")]
|
||||
pub type JsAggregateResult;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IPreservationByCountBucket")]
|
||||
pub type JsPreservationByCountBucket;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IPreservationByCountBuckets")]
|
||||
pub type JsPreservationByCountBuckets;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IPreservationByCount")]
|
||||
pub type JsPreservationByCount;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IEvaluateResult")]
|
||||
pub type JsEvaluateResult;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "ISelectedAttributesByColumn")]
|
||||
pub type JsSelectedAttributesByColumn;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IAttributesIntersection")]
|
||||
pub type JsAttributesIntersection;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "IAttributesIntersectionByColumn")]
|
||||
pub type JsAttributesIntersectionByColumn;
|
||||
}
|
||||
|
||||
pub type JsResult<T> = Result<T, JsValue>;
|
|
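For orientation, a rough TypeScript sketch (not part of this commit) of how the SDSContext API defined earlier might be driven end to end from the web worker; all parameter values are illustrative and the wasm module is assumed to be already initialized by the worker setup.

// Illustrative CSV payload: header row plus records.
const csvData: string[][] = [['age', 'gender'], ['30', 'female'], ['40', 'male']];

const context = new SDSContext();
// Empty use_columns / sensitive_zeros and record_limit = 0 are assumptions for this sketch.
context.setSensitiveData(csvData, [], [], 0);
// cache_max_size, resolution, empty_value, seeded flag and a progress callback.
context.generate(100000, 10, '', true, (p: number) => console.log(`generate: ${p}%`));
// reporting_length, sensitivity_threshold and one progress callback per dataset.
context.evaluate(3, 0, (p: number) => console.log(`sensitive: ${p}%`), (p: number) => console.log(`synthetic: ${p}%`));
// Build the navigate result from the synthetic data, select attributes and query intersections.
context.navigate();
context.selectAttributes({ 0: new Set(['30']) });
const intersections = context.attributesIntersectionsByColumn(['age', 'gender']);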
@@ -1,2 +1,3 @@
|
|||
pub mod js;
|
||||
|
||||
pub mod logger;
|
||||
|
|
|
@@ -30,11 +30,13 @@ def aggregate(config):
|
|||
sensitive_zeros = config['sensitive_zeros']
|
||||
output_dir = config['output_dir']
|
||||
prefix = config['prefix']
|
||||
sensitive_aggregated_data_json = path.join(
|
||||
output_dir, f'{prefix}_sensitive_aggregated_data.json')
|
||||
|
||||
logging.info(f'Aggregate {sensitive_microdata_path}')
|
||||
start_time = time.time()
|
||||
|
||||
data_block = sds.create_data_block_from_file(
|
||||
sds_processor = sds.SDSProcessor(
|
||||
sensitive_microdata_path,
|
||||
sensitive_microdata_delimiter,
|
||||
use_columns,
|
||||
|
@@ -42,13 +44,32 @@ def aggregate(config):
|
|||
max(record_limit, 0)
|
||||
)
|
||||
|
||||
len_to_combo_count, len_to_rare_count = sds.aggregate_and_write(
|
||||
data_block,
|
||||
aggregated_data = sds_processor.aggregate(
|
||||
reporting_length,
|
||||
0
|
||||
)
|
||||
len_to_combo_count = aggregated_data.calc_combinations_count_by_len()
|
||||
len_to_rare_count = aggregated_data.calc_rare_combinations_count_by_len(
|
||||
reporting_resolution)
|
||||
|
||||
aggregated_data.write_aggregates_count(
|
||||
sensitive_aggregates_path,
|
||||
'\t',
|
||||
';',
|
||||
reporting_resolution,
|
||||
False
|
||||
)
|
||||
|
||||
aggregated_data.write_to_json(sensitive_aggregated_data_json)
|
||||
|
||||
aggregated_data.protect_aggregates_count(reporting_resolution)
|
||||
|
||||
aggregated_data.write_aggregates_count(
|
||||
reportable_aggregates_path,
|
||||
'\t',
|
||||
reporting_length,
|
||||
reporting_resolution
|
||||
';',
|
||||
reporting_resolution,
|
||||
True
|
||||
)
|
||||
|
||||
leakage_tsv = path.join(
|
||||
|
|
|
@@ -2,8 +2,247 @@ import time
|
|||
import datetime
|
||||
import logging
|
||||
from os import path
|
||||
from collections import defaultdict
|
||||
import util as util
|
||||
import sds
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self, config):
|
||||
self.sds_evaluator = sds.Evaluator()
|
||||
self.use_columns = config['use_columns']
|
||||
self.record_limit = max(config['record_limit'], 0)
|
||||
self.reporting_length = max(config['reporting_length'], 0)
|
||||
self.reporting_resolution = config['reporting_resolution']
|
||||
self.sensitive_microdata_path = config['sensitive_microdata_path']
|
||||
self.sensitive_microdata_delimiter = config['sensitive_microdata_delimiter']
|
||||
self.synthetic_microdata_path = config['synthetic_microdata_path']
|
||||
self.sensitive_zeros = config['sensitive_zeros']
|
||||
self.output_dir = config['output_dir']
|
||||
self.prefix = config['prefix']
|
||||
self.sensitive_aggregated_data_json = path.join(
|
||||
self.output_dir, f'{self.prefix}_sensitive_aggregated_data.json'
|
||||
)
|
||||
self.record_analysis_tsv = path.join(
|
||||
self.output_dir, f'{self.prefix}_sensitive_analysis_by_length.tsv'
|
||||
)
|
||||
self.synthetic_rare_combos_tsv = path.join(
|
||||
self.output_dir, f'{self.prefix}_synthetic_rare_combos_by_length.tsv'
|
||||
)
|
||||
self.parameters_tsv = path.join(
|
||||
self.output_dir, f'{self.prefix}_parameters.tsv'
|
||||
)
|
||||
self.leakage_tsv = path.join(
|
||||
self.output_dir, f'{self.prefix}_synthetic_leakage_by_length.tsv'
|
||||
)
|
||||
self.leakage_svg = path.join(
|
||||
self.output_dir, f'{self.prefix}_synthetic_leakage_by_length.svg'
|
||||
)
|
||||
self.preservation_by_count_tsv = path.join(
|
||||
self.output_dir, f'{self.prefix}_preservation_by_count.tsv'
|
||||
)
|
||||
self.preservation_by_count_svg = path.join(
|
||||
self.output_dir, f'{self.prefix}_preservation_by_count.svg'
|
||||
)
|
||||
self.preservation_by_length_tsv = path.join(
|
||||
self.output_dir, f'{self.prefix}_preservation_by_length.tsv'
|
||||
)
|
||||
self.preservation_by_length_svg = path.join(
|
||||
self.output_dir, f'{self.prefix}_preservation_by_length.svg'
|
||||
)
|
||||
|
||||
def _load_sensitive_aggregates(self):
|
||||
if not path.exists(self.sensitive_aggregated_data_json):
|
||||
logging.info('Computing sensitive aggregates...')
|
||||
self.sen_sds_processor = sds.SDSProcessor(
|
||||
self.sensitive_microdata_path,
|
||||
self.sensitive_microdata_delimiter,
|
||||
self.use_columns,
|
||||
self.sensitive_zeros,
|
||||
self.record_limit
|
||||
)
|
||||
self.sen_aggregated_data = self.sen_sds_processor.aggregate(
|
||||
self.reporting_length,
|
||||
0
|
||||
)
|
||||
else:
|
||||
logging.info('Loading sensitive aggregates...')
|
||||
self.sen_aggregated_data = sds.AggregatedData.read_from_json(
|
||||
self.sensitive_aggregated_data_json
|
||||
)
|
||||
self.sen_sds_processor = sds.SDSProcessor.from_aggregated_data(
|
||||
self.sen_aggregated_data
|
||||
)
|
||||
|
||||
self.reporting_length = self.sen_sds_processor.normalize_reporting_length(
|
||||
self.reporting_length
|
||||
)
|
||||
|
||||
def _load_synthetic_aggregates(self):
|
||||
logging.info('Computing synthetic aggregates...')
|
||||
self.syn_sds_processor = sds.SDSProcessor(
|
||||
self.synthetic_microdata_path,
|
||||
"\t",
|
||||
self.use_columns,
|
||||
self.sensitive_zeros,
|
||||
self.record_limit
|
||||
)
|
||||
self.syn_aggregated_data = self.syn_sds_processor.aggregate(
|
||||
self.reporting_length,
|
||||
0
|
||||
)
|
||||
|
||||
def _do_records_analysis(self):
|
||||
logging.info('Performing records analysis on sensitive aggregates...')
|
||||
self.records_analysis_data = self.sen_aggregated_data.calc_records_analysis_by_len(
|
||||
self.reporting_resolution, True
|
||||
)
|
||||
self.records_analysis_data.write_records_analysis(
|
||||
self.record_analysis_tsv, '\t'
|
||||
)
|
||||
|
||||
def _compare_synthetic_and_sensitive_rare(self):
|
||||
logging.info(
|
||||
'Comparing rare combinations between synthetic and sensitive aggregates...')
|
||||
rare_combinations_data = self.sds_evaluator.compare_synthetic_and_sensitive_rare(
|
||||
self.syn_aggregated_data, self.sen_aggregated_data, self.reporting_resolution, ' AND ', True
|
||||
)
|
||||
rare_combinations_data.write_rare_combinations(
|
||||
self.synthetic_rare_combos_tsv, '\t'
|
||||
)
|
||||
|
||||
def _write_parameters(self):
|
||||
logging.info('Writing evaluation parameters...')
|
||||
total_sen = self.sen_sds_processor.protected_number_of_records(
|
||||
self.reporting_resolution
|
||||
)
|
||||
with open(self.parameters_tsv, 'w') as f:
|
||||
f.write('\t'.join(['parameter', 'value'])+'\n')
|
||||
f.write(
|
||||
'\t'.join(['resolution', str(self.reporting_resolution)])+'\n'
|
||||
)
|
||||
f.write('\t'.join(['limit', str(self.reporting_length)])+'\n')
|
||||
f.write(
|
||||
'\t'.join(['total_sensitive_records', str(total_sen)])+'\n'
|
||||
)
|
||||
f.write('\t'.join(['unique_identifiable', str(
|
||||
self.records_analysis_data.get_total_unique())])+'\n'
|
||||
)
|
||||
f.write('\t'.join(['rare_identifiable', str(
|
||||
self.records_analysis_data.get_total_rare())])+'\n'
|
||||
)
|
||||
f.write('\t'.join(['risky_identifiable', str(
|
||||
self.records_analysis_data.get_total_risky())])+'\n'
|
||||
)
|
||||
f.write(
|
||||
'\t'.join(['risky_identifiable_pct', str(
|
||||
100*self.records_analysis_data.get_total_risky()/total_sen)])+'\n'
|
||||
)
|
||||
|
||||
def _find_leakages(self):
|
||||
logging.info(
|
||||
'Looking for leakage from the sensitive dataset into the synthetic dataset...')
|
||||
comb_counts = self.syn_aggregated_data.calc_combinations_count_by_len()
|
||||
leakage_counts = self.sds_evaluator.calc_leakage_count(
|
||||
self.sen_aggregated_data, self.syn_aggregated_data, self.reporting_resolution
|
||||
)
|
||||
|
||||
with open(self.leakage_tsv, 'w') as f:
|
||||
f.write('\t'.join(
|
||||
['syn_combo_length', 'combo_count', 'leak_count', 'leak_proportion'])+'\n'
|
||||
)
|
||||
for length in range(1, self.reporting_length + 1):
|
||||
combo_count = comb_counts.get(length, 0)
|
||||
leak_count = leakage_counts.get(length, 0)
|
||||
# by design there should be no leakage
|
||||
assert leak_count == 0
|
||||
leak_prop = leak_count/combo_count if combo_count > 0 else 0
|
||||
f.write('\t'.join(
|
||||
[str(length), str(combo_count), str(leak_count), str(leak_prop)])+'\n'
|
||||
)
|
||||
|
||||
util.plotStats(
|
||||
x_axis='syn_combo_length',
|
||||
x_axis_title='Length of Synthetic Combination',
|
||||
y_bar='combo_count',
|
||||
y_bar_title='Count of Combinations',
|
||||
y_line='leak_proportion',
|
||||
y_line_title=f'Proportion of Leaked (<{self.reporting_resolution}) Combinations',
|
||||
color='violet',
|
||||
darker_color='darkviolet',
|
||||
stats_tsv=self.leakage_tsv,
|
||||
stats_svg=self.leakage_svg,
|
||||
delimiter='\t',
|
||||
style='whitegrid',
|
||||
palette='magma'
|
||||
)
|
||||
|
||||
def _calculate_preservation_by_count(self):
|
||||
logging.info('Calculating preservation by count...')
|
||||
preservation_by_count = self.sds_evaluator.calc_preservation_by_count(
|
||||
self.sen_aggregated_data, self.syn_aggregated_data, self.reporting_resolution
|
||||
)
|
||||
preservation_by_count.write_preservation_by_count(
|
||||
self.preservation_by_count_tsv, '\t'
|
||||
)
|
||||
|
||||
util.plotStats(
|
||||
x_axis='syn_count_bucket',
|
||||
x_axis_title='Count of Filtered Synthetic Records',
|
||||
y_bar='mean_combo_length',
|
||||
y_bar_title='Mean Length of Combinations',
|
||||
y_line='count_preservation',
|
||||
y_line_title='Count Preservation (Synthetic/Sensitive)',
|
||||
color='lightgreen',
|
||||
darker_color='green',
|
||||
stats_tsv=self.preservation_by_count_tsv,
|
||||
stats_svg=self.preservation_by_count_svg,
|
||||
delimiter='\t',
|
||||
style='whitegrid',
|
||||
palette='magma'
|
||||
)
|
||||
|
||||
def _calculate_preservation_by_length(self):
|
||||
logging.info('Calculating preservation by length...')
|
||||
preservation_by_length = self.sds_evaluator.calc_preservation_by_length(
|
||||
self.sen_aggregated_data, self.syn_aggregated_data, self.reporting_resolution
|
||||
)
|
||||
preservation_by_length.write_preservation_by_length(
|
||||
self.preservation_by_length_tsv, '\t'
|
||||
)
|
||||
|
||||
util.plotStats(
|
||||
x_axis='syn_combo_length',
|
||||
x_axis_title='Length of Combination',
|
||||
y_bar='mean_combo_count',
|
||||
y_bar_title='Mean Synthetic Count',
|
||||
y_line='count_preservation',
|
||||
y_line_title='Count Preservation (Synthetic/Sensitive)',
|
||||
color='cornflowerblue',
|
||||
darker_color='mediumblue',
|
||||
stats_tsv=self.preservation_by_length_tsv,
|
||||
stats_svg=self.preservation_by_length_svg,
|
||||
delimiter='\t',
|
||||
style='whitegrid',
|
||||
palette='magma'
|
||||
)
|
||||
|
||||
def run(self):
|
||||
logging.info(
|
||||
f'Evaluate {self.synthetic_microdata_path} vs {self.sensitive_microdata_path}'
|
||||
)
|
||||
start_time = time.time()
|
||||
|
||||
self._load_sensitive_aggregates()
|
||||
self._load_synthetic_aggregates()
|
||||
self._do_records_analysis()
|
||||
self._compare_synthetic_and_sensitive_rare()
|
||||
self._write_parameters()
|
||||
self._find_leakages()
|
||||
self._calculate_preservation_by_count()
|
||||
self._calculate_preservation_by_length()
|
||||
|
||||
logging.info(
|
||||
f'Evaluated {self.synthetic_microdata_path} vs {self.sensitive_microdata_path}, took {datetime.timedelta(seconds = time.time() - start_time)}s')
|
||||
|
||||
|
||||
def evaluate(config):
|
||||
|
@@ -15,309 +254,4 @@ def evaluate(config):
|
|||
Args:
|
||||
config: options from the json config file, else default values.
|
||||
"""
|
||||
|
||||
use_columns = config['use_columns']
|
||||
record_limit = config['record_limit']
|
||||
reporting_length = config['reporting_length']
|
||||
reporting_resolution = config['reporting_resolution']
|
||||
sensitive_microdata_path = config['sensitive_microdata_path']
|
||||
sensitive_microdata_delimiter = config['sensitive_microdata_delimiter']
|
||||
synthetic_microdata_path = config['synthetic_microdata_path']
|
||||
sensitive_zeros = config['sensitive_zeros']
|
||||
parallel_jobs = config['parallel_jobs']
|
||||
output_dir = config['output_dir']
|
||||
prefix = config['prefix']
|
||||
|
||||
logging.info(
|
||||
f'Evaluate {synthetic_microdata_path} vs {sensitive_microdata_path}')
|
||||
start_time = time.time()
|
||||
|
||||
sen_counts = None
|
||||
sen_records = None
|
||||
sen_df = util.loadMicrodata(
|
||||
path=sensitive_microdata_path, delimiter=sensitive_microdata_delimiter, record_limit=record_limit,
|
||||
use_columns=use_columns)
|
||||
sen_records = util.genRowList(sen_df, sensitive_zeros)
|
||||
if not path.exists(config['sensitive_aggregates_path']):
|
||||
logging.info('Computing sensitive aggregates...')
|
||||
if reporting_length == -1:
|
||||
reporting_length = max([len(row) for row in sen_records])
|
||||
sen_counts = util.countAllCombos(
|
||||
sen_records, reporting_length, parallel_jobs)
|
||||
else:
|
||||
logging.info('Loading sensitive aggregates...')
|
||||
sen_counts = util.loadSavedAggregates(
|
||||
config['sensitive_aggregates_path'])
|
||||
if reporting_length == -1:
|
||||
reporting_length = max(sen_counts.keys())
|
||||
|
||||
if use_columns != []:
|
||||
reporting_length = min(reporting_length, len(use_columns))
|
||||
|
||||
filtered_sen_counts = {length: {combo: count for combo, count in combo_to_counts.items(
|
||||
) if count >= reporting_resolution} for length, combo_to_counts in sen_counts.items()}
|
||||
syn_df = util.loadMicrodata(path=synthetic_microdata_path,
|
||||
delimiter='\t', record_limit=-1, use_columns=use_columns)
|
||||
syn_records = util.genRowList(syn_df, sensitive_zeros)
|
||||
syn_counts = util.countAllCombos(
|
||||
syn_records, reporting_length, parallel_jobs)
|
||||
|
||||
len_to_syn_count = {length: len(combo_to_count)
|
||||
for length, combo_to_count in syn_counts.items()}
|
||||
len_to_sen_rare = {length: {combo: count for combo, count in combo_to_count.items() if count < reporting_resolution}
|
||||
for length, combo_to_count in sen_counts.items()}
|
||||
len_to_syn_rare = {length: {combo: count for combo, count in combo_to_count.items() if count < reporting_resolution}
|
||||
for length, combo_to_count in syn_counts.items()}
|
||||
len_to_syn_leak = {length: len([1 for rare in rares if rare in syn_counts.get(length, {}).keys()])
|
||||
for length, rares in len_to_sen_rare.items()}
|
||||
|
||||
sen_unique_to_records, sen_rare_to_records, _ = util.mapShortestUniqueRareComboLengthToRecords(
|
||||
sen_records, len_to_sen_rare)
|
||||
sen_rare_to_sen_count = {length: util.protect(len(records), reporting_resolution)
|
||||
for length, records in sen_rare_to_records.items()}
|
||||
sen_unique_to_sen_count = {length: util.protect(len(records), reporting_resolution)
|
||||
for length, records in sen_unique_to_records.items()}
|
||||
|
||||
total_sen = util.protect(len(sen_records), reporting_resolution)
|
||||
unique_total = sum(
|
||||
[v for k, v in sen_unique_to_sen_count.items() if k > 0])
|
||||
rare_total = sum([v for k, v in sen_rare_to_sen_count.items() if k > 0])
|
||||
risky_total = unique_total + rare_total
|
||||
risky_total_pct = 100*risky_total/total_sen
|
||||
|
||||
record_analysis_tsv = path.join(
|
||||
output_dir, f'{prefix}_sensitive_analysis_by_length.tsv')
|
||||
with open(record_analysis_tsv, 'w') as f:
|
||||
f.write('\t'.join(['combo_length', 'sen_rare', 'sen_rare_pct',
|
||||
'sen_unique', 'sen_unique_pct', 'sen_risky', 'sen_risky_pct'])+'\n')
|
||||
for length in sen_counts.keys():
|
||||
sen_rare = sen_rare_to_sen_count.get(length, 0)
|
||||
sen_rare_pct = 100*sen_rare / total_sen if total_sen > 0 else 0
|
||||
sen_unique = sen_unique_to_sen_count.get(length, 0)
|
||||
sen_unique_pct = 100*sen_unique / total_sen if total_sen > 0 else 0
|
||||
sen_risky = sen_rare + sen_unique
|
||||
sen_risky_pct = 100*sen_risky / total_sen if total_sen > 0 else 0
|
||||
f.write(
|
||||
'\t'.join(
|
||||
[str(length),
|
||||
str(sen_rare),
|
||||
str(sen_rare_pct),
|
||||
str(sen_unique),
|
||||
str(sen_unique_pct),
|
||||
str(sen_risky),
|
||||
str(sen_risky_pct)]) + '\n')
|
||||
|
||||
_, _, syn_length_to_combo_to_rare = util.mapShortestUniqueRareComboLengthToRecords(
|
||||
syn_records, len_to_syn_rare)
|
||||
combos_tsv = path.join(
|
||||
output_dir, f'{prefix}_synthetic_rare_combos_by_length.tsv')
|
||||
with open(combos_tsv, 'w') as f:
|
||||
f.write('\t'.join(['combo_length', 'combo',
|
||||
'record_id', 'syn_count', 'sen_count'])+'\n')
|
||||
for length, combo_to_rare in syn_length_to_combo_to_rare.items():
|
||||
for combo, rare_ids in combo_to_rare.items():
|
||||
syn_count = len(rare_ids)
|
||||
for rare_id in rare_ids:
|
||||
sen_count = util.protect(sen_counts.get(length, {})[
|
||||
combo], reporting_resolution)
|
||||
f.write(
|
||||
'\t'.join(
|
||||
[str(length),
|
||||
util.comboToString(combo).replace(';', ' AND '),
|
||||
str(rare_id),
|
||||
str(syn_count),
|
||||
str(sen_count)]) + '\n')
|
||||
|
||||
parameters_tsv = path.join(output_dir, f'{prefix}_parameters.tsv')
|
||||
|
||||
with open(parameters_tsv, 'w') as f:
|
||||
f.write('\t'.join(['parameter', 'value'])+'\n')
|
||||
f.write('\t'.join(['resolution', str(reporting_resolution)])+'\n')
|
||||
f.write('\t'.join(['limit', str(reporting_length)])+'\n')
|
||||
f.write('\t'.join(['total_sensitive_records', str(total_sen)])+'\n')
|
||||
f.write('\t'.join(['unique_identifiable', str(unique_total)])+'\n')
|
||||
f.write('\t'.join(['rare_identifiable', str(rare_total)])+'\n')
|
||||
f.write('\t'.join(['risky_identifiable', str(risky_total)])+'\n')
|
||||
f.write(
|
||||
'\t'.join(['risky_identifiable_pct', str(risky_total_pct)])+'\n')
|
||||
|
||||
leakage_tsv = path.join(
|
||||
output_dir, f'{prefix}_synthetic_leakage_by_length.tsv')
|
||||
leakage_svg = path.join(
|
||||
output_dir, f'{prefix}_synthetic_leakage_by_length.svg')
|
||||
with open(leakage_tsv, 'w') as f:
|
||||
f.write('\t'.join(['syn_combo_length', 'combo_count',
|
||||
'leak_count', 'leak_proportion'])+'\n')
|
||||
for length, leak_count in len_to_syn_leak.items():
|
||||
combo_count = len_to_syn_count.get(length, 0)
|
||||
leak_prop = leak_count/combo_count if combo_count > 0 else 0
|
||||
f.write('\t'.join([str(length), str(combo_count),
|
||||
str(leak_count), str(leak_prop)])+'\n')
|
||||
|
||||
util.plotStats(
|
||||
x_axis='syn_combo_length',
|
||||
x_axis_title='Length of Synthetic Combination',
|
||||
y_bar='combo_count',
|
||||
y_bar_title='Count of Combinations',
|
||||
y_line='leak_proportion',
|
||||
y_line_title=f'Proportion of Leaked (<{reporting_resolution}) Combinations',
|
||||
color='violet',
|
||||
darker_color='darkviolet',
|
||||
stats_tsv=leakage_tsv,
|
||||
stats_svg=leakage_svg,
|
||||
delimiter='\t',
|
||||
style='whitegrid',
|
||||
palette='magma')
|
||||
|
||||
compareDatasets(filtered_sen_counts, syn_counts, output_dir, prefix)
|
||||
|
||||
logging.info(
|
||||
f'Evaluated {synthetic_microdata_path} vs {sensitive_microdata_path}, took {datetime.timedelta(seconds = time.time() - start_time)}s')
|
||||
|
||||
|
||||
def compareDatasets(sensitive_length_to_combo_to_count, synthetic_length_to_combo_to_count, output_dir, prefix):
|
||||
"""Evaluates the error in the synthetic microdata with respect to the sensitive microdata.
|
||||
|
||||
Produces output statistics and graphics binned by attribute count.
|
||||
|
||||
Args:
|
||||
sensitive_length_to_combo_to_count: counts from sensitive microdata.
|
||||
synthetic_length_to_combo_to_count: counts from synthetic microdata.
|
||||
output_dir: where to save output statistics and graphics.
|
||||
prefix: prefix to add to output files.
|
||||
"""
|
||||
|
||||
all_count_length_preservation = []
|
||||
max_syn_count = 0
|
||||
|
||||
all_combos = set()
|
||||
for length, combo_to_count in sensitive_length_to_combo_to_count.items():
|
||||
for combo in combo_to_count.keys():
|
||||
all_combos.add((length, combo))
|
||||
for length, combo_to_count in synthetic_length_to_combo_to_count.items():
|
||||
for combo in combo_to_count.keys():
|
||||
all_combos.add((length, combo))
|
||||
tot = len(all_combos)
|
||||
for i, (length, combo) in enumerate(all_combos):
|
||||
if i % 10000 == 0:
|
||||
logging.info(f'{100*i/tot:.1f}% through comparisons')
|
||||
sen_count = sensitive_length_to_combo_to_count.get(
|
||||
length, {}).get(combo, 0)
|
||||
syn_count = synthetic_length_to_combo_to_count.get(
|
||||
length, {}).get(combo, 0)
|
||||
max_syn_count = max(syn_count, max_syn_count)
|
||||
preservation = 0
|
||||
preservation = syn_count / sen_count if sen_count > 0 else 0
|
||||
if sen_count == 0:
|
||||
logging.error(
|
||||
f'Error: For {combo}, synthetic count is {syn_count} but no sensitive count')
|
||||
all_count_length_preservation.append((syn_count, length, preservation))
|
||||
max_syn_count = max(syn_count, max_syn_count)
|
||||
|
||||
generateStatsAndGraphics(output_dir, max_syn_count,
|
||||
all_count_length_preservation, prefix)
|
||||
|
||||
|
||||
def generateStatsAndGraphics(output_dir, max_syn_count, count_length_preservation, prefix):
|
||||
"""Generates count error statistics for buckets of user-observed counts (post-filtering).
|
||||
|
||||
Outputs the files preservation_by_length and preservation_by_count tsv and svg files.
|
||||
|
||||
Args:
|
||||
output_dir: the folder in which to output summary files.
|
||||
max_syn_count: the maximum observed count of synthetic records matching a single attribute value.
|
||||
count_length_preservation: list of (count, length, preservation) tuples for observed preservation of sensitive counts after filtering by a combo of length.
|
||||
"""
|
||||
    sorted_counts = sorted(count_length_preservation,
                           key=lambda x: x[0], reverse=False)
    buckets = []
    next_bucket = 10
    while next_bucket < max_syn_count:
        buckets.append(next_bucket)
        next_bucket *= 2
    buckets.append(next_bucket)
    bucket_to_preservations = defaultdict(list)
    bucket_to_counts = defaultdict(list)
    bucket_to_lengths = defaultdict(list)
    length_to_preservations = defaultdict(list)
    length_to_counts = defaultdict(list)
    for (count, length, preservation) in sorted_counts:
        bucket = buckets[0]
        for bi in range(len(buckets)-1, -1, -1):
            if count > buckets[bi]:
                bucket = buckets[bi+1]
                break
        bucket_to_preservations[bucket].append(preservation)
        bucket_to_lengths[bucket].append(length)
        bucket_to_counts[bucket].append(count)
        length_to_counts[length].append(count)
        length_to_preservations[length].append(preservation)

    bucket_to_mean_count = {bucket: sum(counts)/len(counts) if len(counts) >
                            0 else 0 for bucket, counts in bucket_to_counts.items()}
    bucket_to_mean_preservation = {bucket: sum(preservations) / len(preservations)
                                   if len(preservations) > 0 else 0 for bucket,
                                   preservations in bucket_to_preservations.items()}
    bucket_to_mean_length = {bucket: sum(lengths)/len(lengths) if len(lengths) >
                             0 else 0 for bucket, lengths in bucket_to_lengths.items()}

    counts_tsv = path.join(output_dir, f'{prefix}_preservation_by_count.tsv')
    counts_svg = path.join(output_dir, f'{prefix}_preservation_by_count.svg')
    with open(counts_tsv, 'w') as f:
        f.write('\t'.join(['syn_count_bucket', 'mean_combo_count',
                           'mean_combo_length', 'count_preservation'])+'\n')
        for bucket in reversed(buckets):
            f.write(
                '\t'.join(
                    [str(bucket),
                     str(bucket_to_mean_count.get(bucket, 0)),
                     str(bucket_to_mean_length.get(bucket, 0)),
                     str(bucket_to_mean_preservation.get(bucket, 0))]) + '\n')

    util.plotStats(
        x_axis='syn_count_bucket',
        x_axis_title='Count of Filtered Synthetic Records',
        y_bar='mean_combo_length',
        y_bar_title='Mean Length of Combinations',
        y_line='count_preservation',
        y_line_title='Count Preservation (Synthetic/Sensitive)',
        color='lightgreen',
        darker_color='green',
        stats_tsv=counts_tsv,
        stats_svg=counts_svg,
        delimiter='\t',
        style='whitegrid',
        palette='magma')

    length_to_mean_preservation = {length: sum(preservations) / len(preservations)
                                   if len(preservations) > 0 else 0 for length,
                                   preservations in length_to_preservations.items()}
    length_to_mean_count = {length: sum(counts)/len(counts) if len(counts) >
                            0 else 0 for length, counts in length_to_counts.items()}

    lengths_tsv = path.join(output_dir, f'{prefix}_preservation_by_length.tsv')
    lengths_svg = path.join(output_dir, f'{prefix}_preservation_by_length.svg')
    with open(lengths_tsv, 'w') as f:
        f.write(
            '\t'.join(['syn_combo_length', 'mean_combo_count', 'count_preservation'])+'\n')
        for length in sorted(length_to_preservations.keys()):
            f.write(
                '\t'.join(
                    [str(length),
                     str(length_to_mean_count.get(length, 0)),
                     str(length_to_mean_preservation.get(length, 0))]) + '\n')

    util.plotStats(
        x_axis='syn_combo_length',
        x_axis_title='Length of Combination',
        y_bar='mean_combo_count',
        y_bar_title='Mean Synthetic Count',
        y_line='count_preservation',
        y_line_title='Count Preservation (Synthetic/Sensitive)',
        color='cornflowerblue',
        darker_color='mediumblue',
        stats_tsv=lengths_tsv,
        stats_svg=lengths_svg,
        delimiter='\t',
        style='whitegrid',
        palette='magma')

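# Worked example (illustrative value) of the doubling buckets built in generateStatsAndGraphics():
# with max_syn_count = 75 the loop yields edges 10, 20, 40, 80, and each (count, length,
# preservation) tuple falls into the smallest edge that is >= its count.
_example_edges = []
_example_next = 10
while _example_next < 75:
    _example_edges.append(_example_next)
    _example_next *= 2
_example_edges.append(_example_next)
assert _example_edges == [10, 20, 40, 80]
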
Evaluator(config).run()

@@ -1,12 +1,6 @@
import time
import datetime
import logging
import joblib
import psutil
import pandas as pd
from math import ceil
from random import random, shuffle
import util as util
import sds


@@ -20,14 +14,12 @@ def generate(config):
    """

    use_columns = config['use_columns']
    parallel_jobs = config['parallel_jobs']
    record_limit = config['record_limit']
    sensitive_microdata_path = config['sensitive_microdata_path']
    sensitive_microdata_delimiter = config['sensitive_microdata_delimiter']
    synthetic_microdata_path = config['synthetic_microdata_path']
    sensitive_zeros = config['sensitive_zeros']
    resolution = config['reporting_resolution']
    memory_limit = config['memory_limit_pct']
    cache_max_size = config['cache_max_size']
    seeded = config['seeded']

@@ -36,215 +28,26 @@ def generate(config):

    if seeded:
        logging.info(f'Generating from seeds')
        data_block = sds.create_data_block_from_file(
            sensitive_microdata_path,
            sensitive_microdata_delimiter,
            use_columns,
            sensitive_zeros,
            max(record_limit, 0)
        )
        syn_ratio = sds.generate_seeded_and_write(
            data_block,
            synthetic_microdata_path,
            '\t',
            cache_max_size,
            resolution
        )

    else:
        df = util.loadMicrodata(path=sensitive_microdata_path, delimiter=sensitive_microdata_delimiter,
                                record_limit=record_limit, use_columns=use_columns)
        columns = df.columns.values
        num = len(df)
        logging.info(f'Prepared data')
        records = []

        chunk = ceil(num/(parallel_jobs))

        logging.info(f'Generating unseeded')
        chunks = [chunk for i in range(parallel_jobs)]
        col_val_ids = util.genColValIdsDict(df, sensitive_zeros)
        res = joblib.Parallel(
            n_jobs=parallel_jobs, backend='loky', verbose=1)(
            joblib.delayed(synthesizeRowsUnseeded)(
                chunk, num, columns, col_val_ids, resolution, memory_limit)
            for chunk in chunks)
        for rows in res:
            records.extend(rows)
        # trim any overgenerated records because of uniform chunk size
        records = records[:num]

        records.sort()
        records.sort(key=lambda x: len(
            [y for y in x if y != '']), reverse=True)

        sdf = pd.DataFrame(records, columns=df.columns)
        syn_ratio = len(sdf) / len(df)
        sdf.to_csv(synthetic_microdata_path, sep='\t', index=False)
    sds_processor = sds.SDSProcessor(
        sensitive_microdata_path,
        sensitive_microdata_delimiter,
        use_columns,
        sensitive_zeros,
        max(record_limit, 0)
    )
    generated_data = sds_processor.generate(
        cache_max_size,
        resolution,
        "",
        seeded
    )
    generated_data.write_synthetic_data(synthetic_microdata_path, '\t')
    syn_ratio = generated_data.expansion_ratio

    config['expansion_ratio'] = syn_ratio

    logging.info(
        f'Generated {synthetic_microdata_path} from {sensitive_microdata_path} with synthesis ratio {syn_ratio}, took {datetime.timedelta(seconds = time.time() - start_time)}s')

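# A minimal sketch (hypothetical paths and values) of the config dict that generate() reads above;
# in the pipeline these keys come from the json config file rather than being hard-coded.
_example_config = {
    'use_columns': [],
    'parallel_jobs': 4,
    'record_limit': 0,
    'sensitive_microdata_path': 'sensitive_microdata.tsv',
    'sensitive_microdata_delimiter': '\t',
    'synthetic_microdata_path': 'synthetic_microdata.tsv',
    'sensitive_zeros': [],
    'reporting_resolution': 10,
    'memory_limit_pct': 80,
    'cache_max_size': 100000,
    'seeded': True,
}
# generate(_example_config) would then synthesize and write synthetic_microdata.tsv.
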
def synthesizeRowsUnseeded(chunk, num_rows, columns, col_val_ids, resolution, memory_limit):
    """Create synthetic records through unconstrained sampling of attribute distributions.

    Args:
        chunk: how many records to synthesize.
        num_rows: how many rows/records in the sensitive dataset.
        columns: the columns of the sensitive dataset.
        col_val_ids: a dict mapping each column to its values and the row ids containing them.
        resolution: the minimum count of sensitive attribute combinations for inclusion in synthetic records.
        memory_limit: the percentage memory use not to exceed.

    Returns:
        rows: synthetic records created through unconstrained sampling of attribute distributions.
    """
    rows = []
    filter_cache = {}
    overall_cache_hits = 0
    overall_cache_misses = 0

    for i in range(chunk):
        if i % 100 == 99:
            logging.info(
                f'{(100*i/chunk):.1f}% through row synthesis, cache utilization {100*overall_cache_hits/(overall_cache_hits + overall_cache_misses):.1f}%')
        filters, cache_hits, cache_misses = synthesizeRowUnseeded(
            filter_cache, num_rows, columns, col_val_ids, resolution, memory_limit)
        overall_cache_hits += cache_hits
        overall_cache_misses += cache_misses
        row = normalize(columns, filters)
        rows.append(row)

    return rows

def synthesizeRowUnseeded(filter_cache, num_rows, columns, col_val_ids, resolution, memory_limit):
    """Creates a synthetic record through unconstrained sampling of attribute distributions.

    Args:
        filter_cache: cache mapping previously seen attribute combinations to their matching row ids.
        num_rows: how many rows/records in the sensitive dataset.
        columns: the columns of the sensitive dataset.
        col_val_ids: a dict mapping each column to its values and the row ids containing them.
        resolution: the minimum count of sensitive attribute combinations for inclusion in synthetic records.
        memory_limit: the percentage memory use not to exceed.

    Returns:
        filters: the sampled attributes of a synthetic record created through unconstrained sampling of attribute distributions.
        row_cache_hits: the number of times the cache was hit.
        row_cache_misses: the number of times the cache was missed.
    """
    shuffled_columns = list(columns)
    shuffle(shuffled_columns)
    output_atts = []
    residual_ids = set(range(num_rows))
    row_cache_hits = 0
    row_cache_misses = 0
    # do not add to cache once memory limit is reached
    use_cache = psutil.virtual_memory()[2] <= memory_limit
    for col in shuffled_columns:
        vals_to_sample = {}
        for val, ids in col_val_ids[col].items():
            next_filters = tuple(
                sorted((*output_atts, (col, val)), key=lambda x: f'{x[0]}:{x[1]}'.lower()))
            if next_filters in filter_cache.keys():
                vals_to_sample[val] = set(filter_cache[next_filters])
                row_cache_hits += 1
            else:
                row_cache_misses += 1
                res = set(ids).intersection(residual_ids)
                vals_to_sample[val] = res
                if use_cache:
                    filter_cache[next_filters] = res
        if '' not in vals_to_sample.keys():
            vals_to_sample[''] = set()
        vals_to_remove = {k: v for k,
                          v in vals_to_sample.items() if len(v) < resolution}
        for val, ids in vals_to_remove.items():
            vals_to_sample[''].update(ids)
            if val != '':
                del vals_to_sample[val]
        val_to_counts = {k: len(v) for k, v in vals_to_sample.items()}
        sampled_val = sampleFromCounts(val_to_counts, True)
        if sampled_val != None:
            output_atts.append((col, sampled_val))
            residual_ids = set(vals_to_sample[sampled_val])

    filters = tuple(
        sorted(output_atts, key=lambda x: f'{x[0]}:{x[1]}'.lower()))
    return filters, row_cache_hits, row_cache_misses

def sampleFromCounts(counts, preferNotNone):
    """Samples from a dict of counts based on their relative frequency.

    Args:
        counts: the dict of counts.
        preferNotNone: whether to avoid sampling None if possible.

    Returns:
        sampled_value: the sampled value.
    """
    dist = convertCountsToCumulativeDistribution(counts)
    sampled_value = None
    r = random()
    keys = sorted(dist.keys())
    for p in keys:
        v = dist[p]
        if r < p:
            if preferNotNone and v == None:
                continue
            else:
                sampled_value = v
                break
        sampled_value = v
    return sampled_value

def convertCountsToCumulativeDistribution(att_to_count):
    """Converts a dict of counts to a cumulative probability distribution for sampling.

    Args:
        att_to_count: a dict of attribute counts.

    Returns:
        dist: a cumulative probability distribution mapping cumulative probabilities [0,1] to attributes.
    """
    dist = {}
    total = sum(att_to_count.values())
    if total == 0:
        return dist
    cumulative = 0
    for att, count in att_to_count.items():
        if count > 0:
            p = count/total
            cumulative += p
            dist[cumulative] = att
    return dist

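# Toy illustration of the cumulative distribution built above: counts {'a': 1, 'b': 3}
# total 4, so the mapping is {0.25: 'a', 1.0: 'b'}; sampleFromCounts() then draws r in [0, 1)
# and returns the value at the first cumulative key exceeding r.
assert convertCountsToCumulativeDistribution({'a': 1, 'b': 3}) == {0.25: 'a', 1.0: 'b'}
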
def normalize(columns, atts):
    """Creates an output record according to the columns schema from a set of attributes.

    Args:
        columns: the columns of the output record.
        atts: the attribute values of the output record.

    Returns:
        row: a normalized row ready for dataframe integration.
    """
    row = []
    for c in columns:
        added = False
        for a in atts:
            if a[0] == c:
                row.append(a[1])
                added = True
                break
        if not added:
            row.append('')
    return row

@@ -1,6 +1,7 @@
import json
import logging
import argparse
import sds
from os import path, mkdir
import aggregator as aggregator
import generator as generator

@@ -49,7 +50,7 @@ def main():
    config['evaluate'] = True
    config['navigate'] = True

    # set based on the number of cores/memory available
    # set based on the number of cores/memory available
    config['parallel_jobs'] = config.get('parallel_jobs', 1)
    config['memory_limit_pct'] = config.get('memory_limit_pct', 80)
    config['cache_max_size'] = config.get('cache_max_size', 100000)

@@ -102,6 +103,9 @@ def runPipeline(config):
        config: options from the json config file, else default values.
    """

    # set number of threads to be used for processing
    sds.set_number_of_threads(config['parallel_jobs'])

    if config['aggregate']:
        aggregator.aggregate(config)

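# Sketch of how the defaults above feed the new threading hook (values illustrative):
# a config file that omits parallel_jobs falls back to 1, and runPipeline() hands that
# value to the core library before any pipeline stage runs.
_example_config = {}
_example_config['parallel_jobs'] = _example_config.get('parallel_jobs', 1)
sds.set_number_of_threads(_example_config['parallel_jobs'])
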
@@ -1,14 +1,10 @@

import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
import logging
import numpy as np
import pandas as pd
import joblib
from itertools import combinations
from collections import defaultdict
import seaborn as sns
from math import ceil, floor
from math import ceil
import matplotlib
# fixes matplotlib + joblib bug "RuntimeError: main thread is not in main loop Tcl_AsyncDelete: async handler deleted by the wrong thread"
matplotlib.use('Agg')

@@ -36,219 +32,6 @@ def loadMicrodata(path, delimiter, record_limit, use_columns):
    return df

def genRowList(df, sensitive_zeros):
    """Converts a dataframe to a list of rows.

    Args:
        df: the dataframe.
        sensitive_zeros: columns for which zero values are meaningful and should be retained.

    Returns:
        row_list: the list of rows.
    """
    row_list = []
    for _, row in df.iterrows():
        res = []
        for c in df.columns:
            val = str(row[c])
            if val != '' and (c in sensitive_zeros or val != '0'):
                res.append((c, val))
        row2 = sorted(res, key=lambda x: f'{x[0]}:{x[1]}'.lower())
        row_list.append(row2)
    return row_list

def genColValIdsDict(df, sensitive_zeros):
    """Converts a dataframe to a dict of col->val->ids.

    Args:
        df: the dataframe.
        sensitive_zeros: columns for which zero values are meaningful and should be retained.

    Returns:
        colValIds: the dict of col->val->ids.
    """
    colValIds = {}
    for rid, row in df.iterrows():
        for c in df.columns:
            val = str(row[c])
            if c not in colValIds.keys():
                colValIds[c] = defaultdict(list)
            if val == '0':
                if c in sensitive_zeros:
                    colValIds[c][val].append(rid)
                else:
                    colValIds[c][''].append(rid)
            else:
                colValIds[c][val].append(rid)

    return colValIds

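# Toy illustration (hypothetical two-row dataframe) of the col->val->ids mapping built above;
# zeros collapse into the '' bucket unless the column is listed in sensitive_zeros.
_example_df = pd.DataFrame({'job': ['nurse', 'nurse'], 'kids': ['0', '2']})
_example_ids = genColValIdsDict(_example_df, sensitive_zeros=[])
assert _example_ids['job']['nurse'] == [0, 1]
assert _example_ids['kids'][''] == [0] and _example_ids['kids']['2'] == [1]
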
def countAllCombos(row_list, length_limit, parallel_jobs):
    """Counts all combinations in the given rows up to a limit.

    Args:
        row_list: a list of rows.
        length_limit: the maximum length to compute counts for (all lengths if -1).
        parallel_jobs: the number of processor cores to use for parallelized counting.

    Returns:
        length_to_combo_to_count: a dict mapping combination lengths to dicts mapping combinations to counts.
    """
    length_to_combo_to_count = {}
    if length_limit == -1:
        length_limit = max([len(x) for x in row_list])
    for length in range(1, length_limit+1):
        logging.info(f'counting combos of length {length}')
        res = joblib.Parallel(n_jobs=parallel_jobs, backend='loky', verbose=1)(
            joblib.delayed(genAllCombos)(row, length) for row in row_list)
        length_to_combo_to_count[length] = defaultdict(int)
        for combos in res:
            for combo in combos:
                length_to_combo_to_count.get(length, {})[combo] += 1

    return length_to_combo_to_count

def genAllCombos(row, length):
    """Generates all combos from row up to and including size length.

    Args:
        row: the row to extract combinations from.
        length: the maximum combination length to extract.

    Returns:
        combos: list of combinations extracted from row.
    """
    res = []
    if len(row) == length:
        canonical_combo = tuple(
            sorted(list(row), key=lambda x: f'{x[0]}:{x[1]}'.lower()))
        res.append(canonical_combo)
    else:
        for combo in combinations(row, length):
            canonical_combo = tuple(
                sorted(list(combo), key=lambda x: f'{x[0]}:{x[1]}'.lower()))
            res.append(canonical_combo)
    return res

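# Toy illustration of combination extraction: a two-attribute row yields its single canonical
# pair at length 2, and each attribute alone at length 1.
_example_row = [('age', '30'), ('job', 'nurse')]
assert genAllCombos(_example_row, 2) == [(('age', '30'), ('job', 'nurse'))]
assert genAllCombos(_example_row, 1) == [(('age', '30'),), (('job', 'nurse'),)]
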
def mapShortestUniqueRareComboLengthToRecords(records, length_to_rare):
    """
    Maps each record to the shortest combination length that isolates it within a rare group (i.e., below resolution).

    Args:
        records: the input records.
        length_to_rare: a dict of length to rare combo to count.

    Returns:
        unique_to_records: dict of combination lengths mapped to the set of record ids isolated by a unique combination of that length.
        rare_to_records: dict of rare combination lengths mapped to record id sets.
        length_to_combo_to_rare: dict of length to rare combo to the set of matching record ids.
    """
    rare_to_records = defaultdict(set)
    unique_to_records = defaultdict(set)
    length_to_combo_to_rare = {length: defaultdict(
        set) for length in length_to_rare.keys()}
    for i, record in enumerate(records):
        matchedRare = False
        matchedUnique = False
        for length in sorted(length_to_rare.keys()):
            if matchedUnique:
                break
            for combo in combinations(record, length):
                canonical_combo = tuple(
                    sorted(list(combo), key=lambda x: f'{x[0]}:{x[1]}'.lower()))
                if canonical_combo in length_to_rare.get(length, {}).keys():
                    # unique
                    if length_to_rare.get(length, {})[canonical_combo] == 1:
                        unique_to_records[length].add(i)
                        matchedUnique = True
                        length_to_combo_to_rare.get(
                            length, {})[canonical_combo].add(i)
                        if i in rare_to_records[length]:
                            rare_to_records[length].remove(i)

                    else:
                        if i not in unique_to_records[length]:
                            rare_to_records[length].add(i)
                            matchedRare = True
                            length_to_combo_to_rare.get(
                                length, {})[canonical_combo].add(i)

                if matchedUnique:
                    break
        if not matchedRare and not matchedUnique:
            rare_to_records[0].add(i)
        if not matchedUnique:
            unique_to_records[0].add(i)
    return unique_to_records, rare_to_records, length_to_combo_to_rare

def protect(value, resolution):
    """Protects a value from a privacy perspective by rounding down to the closest multiple of the supplied resolution.

    Args:
        value: the value to protect.
        resolution: round values down to the closest multiple of this.
    """
    rounded = floor(value / resolution) * resolution
    return rounded

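# Worked example of the privacy rounding above: with a reporting resolution of 10,
# a raw count of 27 is reported as 20 and a count below the resolution becomes 0.
assert protect(27, 10) == 20
assert protect(9, 10) == 0
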
def comboToString(combo_tuple):
    """Creates a string from a tuple of (attribute, value) tuples.

    Args:
        combo_tuple: tuple of (attribute, value) tuples.

    Returns:
        combo_string: a ;-delimited string of ':'-concatenated attribute-value pairs.
    """
    combo_string = ''
    for (col, val) in combo_tuple:
        combo_string += f'{col}:{val};'
    return combo_string[:-1]

def loadSavedAggregates(path):
    """Loads sensitive aggregates from file to speed up evaluation.

    Args:
        path: location of the saved sensitive aggregates.

    Returns:
        length_to_combo_to_count: a dict mapping combination lengths to dicts mapping combinations to counts.
    """
    length_to_combo_to_count = defaultdict(dict)
    with open(path, 'r') as f:
        for i, line in enumerate(f):
            if i == 0:
                continue  # skip header
            parts = [x.strip() for x in line.split('\t')]
            if len(parts[0]) > 0:
                length, combo = stringToLengthAndCombo(parts[0])
                length_to_combo_to_count[length][combo] = int(parts[1])
    return length_to_combo_to_count

def stringToLengthAndCombo(combo_string):
    """Creates a tuple of (attribute, value) tuples from a given string.

    Args:
        combo_string: string representation of (attribute, value) tuples.

    Returns:
        length: the number of attribute-value pairs in the string.
        combo_tuple: tuple of (attribute, value) tuples.
    """
    length = len(combo_string.split(';'))
    combo_list = []
    for col_vals in combo_string.split(';'):
        parts = col_vals.split(':')
        if len(parts) == 2:
            combo_list.append((parts[0], parts[1]))
    combo_tuple = tuple(combo_list)
    return length, combo_tuple

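# Round-trip sketch of the serialization helpers above: comboToString() joins attribute:value
# pairs with ';' and stringToLengthAndCombo() recovers the length and tuple form.
_example_combo = (('age', '30'), ('job', 'nurse'))
assert comboToString(_example_combo) == 'age:30;job:nurse'
assert stringToLengthAndCombo('age:30;job:nurse') == (2, _example_combo)
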
def plotStats(
        x_axis, x_axis_title, y_bar, y_bar_title, y_line, y_line_title, color, darker_color, stats_tsv, stats_svg,
        delimiter, style='whitegrid', palette='magma'):

@@ -2,31 +2,24 @@
 * Copyright (c) Microsoft. All rights reserved.
 * Licensed under the MIT license. See LICENSE file in the project.
 */
import { Stack, Spinner, Label } from '@fluentui/react'
import _ from 'lodash'
import { Stack, Label, Spinner } from '@fluentui/react'
import { memo, useCallback, useEffect, useRef, useState } from 'react'
import { useOnAttributeSelection } from './hooks'
import { IAttributesIntersection } from 'sds-wasm'
import { useMaxCount } from './hooks'
import { AttributeIntersectionValueChart } from '~components/Charts/AttributeIntersectionValueChart'
import { useStopPropagation } from '~components/Charts/hooks'
import { CsvRecord, IAttributesIntersectionValue } from '~models'
import {
	useEvaluatedResultValue,
	useNavigateResultValue,
	useWasmWorkerValue,
} from '~states'
import {
	useSelectedAttributeRows,
	useSelectedAttributes,
} from '~states/dataShowcaseContext'
import { SetSelectedAttributesCallback } from '~components/Pages/DataShowcasePage/DataNavigation'
import { useWasmWorkerValue } from '~states'

export interface ColumnAttributeSelectorProps {
	headers: CsvRecord
	headerName: string
	columnIndex: number
	height: string | number
	width: number | number
	chartHeight: number
	width: string | number
	chartBarHeight: number
	minHeight?: string | number
	selectedAttributes: Set<string>
	onSetSelectedAttributes: SetSelectedAttributesCallback
}

// fixed value for rough axis height so charts don't squish if selected down to 1

@@ -34,87 +27,49 @@ const AXIS_HEIGHT = 16

export const ColumnAttributeSelector: React.FC<ColumnAttributeSelectorProps> =
	memo(function ColumnAttributeSelector({
		headers,
		headerName,
		columnIndex,
		height,
		width,
		chartHeight,
		chartBarHeight,
		minHeight,
		selectedAttributes,
		onSetSelectedAttributes,
	}: ColumnAttributeSelectorProps) {
		const [selectedAttributeRows, setSelectedAttributeRows] =
			useSelectedAttributeRows()
		const [selectedAttributes, setSelectedAttributes] = useSelectedAttributes()
		const navigateResult = useNavigateResultValue()
		const [items, setItems] = useState<IAttributesIntersectionValue[]>([])
		const [items, setItems] = useState<IAttributesIntersection[]>([])
		const [isLoading, setIsLoading] = useState(false)
		const isMounted = useRef(true)
		const evaluatedResult = useEvaluatedResultValue()
		const worker = useWasmWorkerValue()
		const selectedAttribute = selectedAttributes[columnIndex]
		const maxCount =
			_.max([
				_.maxBy(items, item => item.estimatedCount)?.estimatedCount,
				_.maxBy(items, item => item.actualCount ?? 0)?.actualCount,
			]) ?? 1

		const onAttributeSelection = useOnAttributeSelection(
			setIsLoading,
			selectedAttributes,
			worker,
			navigateResult,
			isMounted,
			setSelectedAttributes,
			setSelectedAttributeRows,
		)
		const maxCount = useMaxCount(items)
		const isMounted = useRef(true)
		const stopPropagation = useStopPropagation()

		const handleSelection = useCallback(
			(item: IAttributesIntersectionValue | undefined) => {
			async (item: IAttributesIntersection | undefined) => {
				const newValue = item?.value
				// toggle off with re-click
				if (newValue === selectedAttribute) {
					onAttributeSelection(columnIndex, undefined)
				if (newValue === undefined || selectedAttributes.has(newValue)) {
					await onSetSelectedAttributes(columnIndex, undefined)
				} else {
					onAttributeSelection(columnIndex, newValue)
					await onSetSelectedAttributes(columnIndex, item)
				}
			},
			[selectedAttribute, columnIndex, onAttributeSelection],
			[selectedAttributes, onSetSelectedAttributes, columnIndex],
		)

		const stopPropagation = useStopPropagation()

		useEffect(() => {
			if (worker) {
				setIsLoading(true)
				worker
					.intersectAttributesInColumnsWith(
						headers,
						navigateResult.allRows,
						navigateResult.attrsInColumnsMap[columnIndex] ?? new Set(),
						selectedAttributeRows,
						selectedAttributes,
						navigateResult.attrRowsMap,
						columnIndex,
						evaluatedResult.sensitiveAggregatedResult?.aggregatedCombinations,
					)
					.then(newItems => {
						if (isMounted.current) {
							setItems(newItems ?? [])
							setIsLoading(false)
					.attributesIntersectionsByColumn([headerName])
					.then(intersections => {
						if (!isMounted.current || !intersections) {
							return
						}
						setItems(intersections[columnIndex] ?? [])
						setIsLoading(false)
					})
			}
		}, [
			worker,
			navigateResult,
			headers,
			selectedAttributeRows,
			selectedAttributes,
			columnIndex,
			evaluatedResult,
			setItems,
			setIsLoading,
		])
		}, [worker, setIsLoading, setItems, headerName, columnIndex])

		useEffect(() => {
			return () => {

@@ -151,8 +106,8 @@ export const ColumnAttributeSelector: React.FC<ColumnAttributeSelectorProps> =
					items={items}
					onClick={handleSelection}
					maxCount={maxCount}
					height={chartHeight * Math.max(items.length, 1) + AXIS_HEIGHT}
					selectedValue={selectedAttribute}
					height={chartBarHeight * Math.max(items.length, 1) + AXIS_HEIGHT}
					selectedAttributes={selectedAttributes}
				/>
			</Stack.Item>
		)}