- refactor python and wasm bindings
- update python pipeline evaluate step to use the core library
- add multi-threading support to the core library, cli application and python library
- first version of unseeded data synthesis in rust (supported on core library, cli and python pipeline)
- general performance optimizations on the core library, python pipeline and web application
This commit is contained in:
Rodrigo Racanicci 2022-01-05 10:28:14 -03:00
Parent ca020fc9a9
Commit 981c25da13
144 changed files with 7017 additions and 4487 deletions
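The multi-threading and unseeded-synthesis items above come together in the updated CLI (see the `packages/cli` diff further down). Below is a minimal, hedged sketch of that flow rather than the exact CLI code: the `.from_path(...)` call and the `SynthesisMode::Unseeded` variant name are assumptions, since the hunks in this diff truncate before the reader construction and only show the `--mode` string values.

```rust
// Hedged sketch of the new core-library flow used by the CLI below.
// Assumptions: the csv reader is finished with `.from_path(...)` and the
// synthesis mode enum exposes an `Unseeded` variant (only the strings
// "seeded"/"unseeded" are visible in this diff).
use sds_core::{
    data_block::{csv_block_creator::CsvDataBlockCreator, data_block_creator::DataBlockCreator},
    processing::generator::{Generator, SynthesisMode},
    utils::{reporting::LoggerProgressReporter, threading::set_number_of_threads},
};

fn synthesize_unseeded(sensitive_path: &str, synthetic_path: &str) {
    // New in this commit: control the thread pool size (defaults to the core count).
    set_number_of_threads(4);

    let data_block = CsvDataBlockCreator::create(
        csv::ReaderBuilder::new().delimiter(b',').from_path(sensitive_path),
        &[], // use_columns: empty selects every column
        &[], // sensitive_zeros
        0,   // record_limit: 0 means no limit
    )
    .expect("unable to read the sensitive dataset");

    let mut progress_reporter: Option<LoggerProgressReporter> = None;
    let mut generator = Generator::new(data_block);
    let generated = generator.generate(
        10,                      // reporting resolution
        100_000,                 // cache_max_size
        String::from(""),        // representation of empty values
        SynthesisMode::Unseeded, // new unseeded mode (assumed variant name)
        &mut progress_reporter,
    );

    if let Err(err) = generated.write_synthetic_data(synthetic_path, '\t') {
        eprintln!("error writing synthetic data: {}", err);
    }
}
```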

View file

@ -14,4 +14,3 @@ pkg/
target/
.eslintrc.js
babel.config.js
packages/webapp/src/workers/sds-wasm/worker.js

8
.pnp.cjs Generated
View file

@ -13697,10 +13697,10 @@ function $$SETUP_STATE(hydrateRuntimeState, basePath) {
}]
]],
["sds-wasm", [
["file:../../target/wasm#../../target/wasm::hash=932617&locator=webapp%40workspace%3Apackages%2Fwebapp", {
"packageLocation": "./.yarn/cache/sds-wasm-file-aa35777e98-1550beb7ae.zip/node_modules/sds-wasm/",
["file:../../target/wasm#../../target/wasm::hash=3f33a3&locator=webapp%40workspace%3Apackages%2Fwebapp", {
"packageLocation": "./.yarn/cache/sds-wasm-file-6dcf65ffd1-dd8b934db1.zip/node_modules/sds-wasm/",
"packageDependencies": [
["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=932617&locator=webapp%40workspace%3Apackages%2Fwebapp"]
["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=3f33a3&locator=webapp%40workspace%3Apackages%2Fwebapp"]
],
"linkType": "HARD",
}]
@ -15482,7 +15482,7 @@ function $$SETUP_STATE(hydrateRuntimeState, basePath) {
["react-is", "npm:17.0.2"],
["react-router-dom", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:6.0.2"],
["recoil", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:0.5.2"],
["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=932617&locator=webapp%40workspace%3Apackages%2Fwebapp"],
["sds-wasm", "file:../../target/wasm#../../target/wasm::hash=3f33a3&locator=webapp%40workspace%3Apackages%2Fwebapp"],
["styled-components", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:5.3.3"],
["ts-node", "virtual:d293af44cc1e0d0fc09cc0c8c4a3d9e5fccdf4ddebae06b8fad52a312360d8122c830d53ecc46b13c13aaad8c6ae7dbd798566bd5cba581433425b2ff3f7540b#npm:10.4.0"],
["typescript", "patch:typescript@npm%3A4.4.3#~builtin<compat/typescript>::version=4.4.3&hash=32657b"],

View file

@ -1,33 +0,0 @@
name: SDS CI
pool:
vmImage: ubuntu-latest
trigger:
batch: true
branches:
include:
- main
stages:
- stage: Compliance
dependsOn: []
jobs:
- job: ComplianceJob
pool:
vmImage: windows-latest
steps:
- task: CredScan@3
inputs:
outputFormat: sarif
debugMode: false
- task: ComponentGovernanceComponentDetection@0
inputs:
scanType: 'Register'
verbosity: 'Verbose'
alertWarningLevel: 'High'
- task: PublishSecurityAnalysisLogs@3
inputs:
ArtifactName: 'CodeAnalysisLogs'
ArtifactType: 'Container'

Binary data
.yarn/cache/sds-wasm-file-6dcf65ffd1-dd8b934db1.zip Vendored Normal file

Binary file not shown.

Binary data
.yarn/cache/sds-wasm-file-aa35777e98-1550beb7ae.zip Vendored

Binary file not shown.

0
.yarn/sdks/eslint/bin/eslint.js Vendored Executable file → Normal file
View file

0
.yarn/sdks/prettier/index.js Vendored Executable file → Normal file
View file

0
.yarn/sdks/typescript/bin/tsc Vendored Executable file → Normal file
View file

0
.yarn/sdks/typescript/bin/tsserver Vendored Executable file → Normal file
View file

98
Cargo.lock Generated
View file

@ -42,6 +42,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "autocfg"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
[[package]]
name = "base-x"
version = "0.2.8"
@ -103,6 +109,50 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e"
dependencies = [
"cfg-if",
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd"
dependencies = [
"cfg-if",
"crossbeam-utils",
"lazy_static",
"memoffset",
"scopeguard",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db"
dependencies = [
"cfg-if",
"lazy_static",
]
[[package]]
name = "csv"
version = "1.1.6"
@ -307,6 +357,25 @@ version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a"
[[package]]
name = "memoffset"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce"
dependencies = [
"autocfg",
]
[[package]]
name = "num_cpus"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "once_cell"
version = "1.8.0"
@ -499,6 +568,31 @@ dependencies = [
"rand_core",
]
[[package]]
name = "rayon"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90"
dependencies = [
"autocfg",
"crossbeam-deque",
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-utils",
"lazy_static",
"num_cpus",
]
[[package]]
name = "redox_syscall"
version = "0.2.10"
@ -582,6 +676,9 @@ dependencies = [
"lru",
"pyo3",
"rand",
"rayon",
"serde",
"serde_json",
]
[[package]]
@ -589,6 +686,7 @@ name = "sds-pyo3"
version = "1.0.0"
dependencies = [
"csv",
"env_logger",
"log",
"pyo3",
"sds-core",

View file

@ -1,4 +1,4 @@
version: "3"
version: '3'
services:
webapp:
build:

View file

@ -7,7 +7,7 @@
"build:": "yarn workspaces foreach -ivt run build",
"start:": "yarn workspaces foreach -piv run start",
"lint:": "essex lint --fix --strict",
"build:lib-wasm": "cd packages/lib-wasm && wasm-pack build --release --target no-modules --out-dir ../../target/wasm",
"build:lib-wasm": "cd packages/lib-wasm && wasm-pack build --release --target web --out-dir ../../target/wasm",
"prettify": "essex prettify",
"rebuild-all": "cargo clean && run-s clean: && cargo build --release && run-s build:lib-wasm && yarn install && run-s build:"
},

View file

@ -7,7 +7,7 @@ repository = "https://github.com/microsoft/synthetic-data-showcase"
edition = "2018"
[dependencies]
sds-core = { path = "../core" }
sds-core = { path = "../core", features = ["rayon"] }
log = { version = "0.4" }
env_logger = { version = "0.9" }
structopt = { version = "0.3" }

View file

@ -1,8 +1,11 @@
use log::{error, info, log_enabled, trace, Level::Debug};
use sds_core::{
data_block::{block::DataBlockCreator, csv_block_creator::CsvDataBlockCreator},
processing::{aggregator::Aggregator, generator::Generator},
utils::reporting::LoggerProgressReporter,
data_block::{csv_block_creator::CsvDataBlockCreator, data_block_creator::DataBlockCreator},
processing::{
aggregator::Aggregator,
generator::{Generator, SynthesisMode},
},
utils::{reporting::LoggerProgressReporter, threading::set_number_of_threads},
};
use std::process;
use structopt::StructOpt;
@ -26,6 +29,14 @@ enum Command {
default_value = "100000"
)]
cache_max_size: usize,
#[structopt(
long = "mode",
help = "synthesis mode",
possible_values = &["seeded", "unseeded"],
case_insensitive = true,
default_value = "seeded"
)]
mode: SynthesisMode,
},
Aggregate {
#[structopt(long = "aggregates-path", help = "generated aggregates file path")]
@ -104,6 +115,12 @@ struct Cli {
help = "columns where zeros should not be ignored (can be set multiple times)"
)]
sensitive_zeros: Vec<String>,
#[structopt(
long = "n-threads",
help = "number of threads used to process the data in parallel (default is the number of cores)"
)]
n_threads: Option<usize>,
}
fn main() {
@ -118,6 +135,10 @@ fn main() {
trace!("execution parameters: {:#?}", cli);
if let Some(n_threads) = cli.n_threads {
set_number_of_threads(n_threads);
}
match CsvDataBlockCreator::create(
csv::ReaderBuilder::new()
.delimiter(cli.sensitive_delimiter.chars().next().unwrap() as u8)
@ -131,15 +152,20 @@ fn main() {
synthetic_path,
synthetic_delimiter,
cache_max_size,
mode,
} => {
let mut generator = Generator::new(&data_block);
let generated_data =
generator.generate(cli.resolution, cache_max_size, "", &mut progress_reporter);
let mut generator = Generator::new(data_block);
let generated_data = generator.generate(
cli.resolution,
cache_max_size,
String::from(""),
mode,
&mut progress_reporter,
);
if let Err(err) = generator.write_records(
&generated_data.synthetic_data,
if let Err(err) = generated_data.write_synthetic_data(
&synthetic_path,
synthetic_delimiter.chars().next().unwrap() as u8,
synthetic_delimiter.chars().next().unwrap(),
) {
error!("error writing output file: {}", err);
process::exit(1);
@ -153,28 +179,24 @@ fn main() {
sensitivity_threshold,
records_sensitivity_path,
} => {
let mut aggregator = Aggregator::new(&data_block);
let mut aggregator = Aggregator::new(data_block);
let mut aggregated_data = aggregator.aggregate(
reporting_length,
sensitivity_threshold,
&mut progress_reporter,
);
let privacy_risk =
aggregator.calc_privacy_risk(&aggregated_data.aggregates_count, cli.resolution);
let privacy_risk = aggregated_data.calc_privacy_risk(cli.resolution);
if !not_protect {
Aggregator::protect_aggregates_count(
&mut aggregated_data.aggregates_count,
cli.resolution,
);
aggregated_data.protect_aggregates_count(cli.resolution);
}
info!("Calculated privacy risk is: {:#?}", privacy_risk);
if let Err(err) = aggregator.write_aggregates_count(
&aggregated_data.aggregates_count,
if let Err(err) = aggregated_data.write_aggregates_count(
&aggregates_path,
aggregates_delimiter.chars().next().unwrap(),
";",
cli.resolution,
!not_protect,
) {
@ -183,9 +205,7 @@ fn main() {
}
if let Some(path) = records_sensitivity_path {
if let Err(err) = aggregator
.write_records_sensitivity(&aggregated_data.records_sensitivity, &path)
{
if let Err(err) = aggregated_data.write_records_sensitivity(&path, '\t') {
error!("error writing output file: {}", err);
process::exit(1);
}

View file

@ -19,3 +19,6 @@ log = { version = "0.4", features = ["std"] }
csv = { version = "1.1" }
instant = { version = "0.1", features = [ "stdweb", "wasm-bindgen" ] }
pyo3 = { version = "0.15", features = ["extension-module"], optional = true }
rayon = { version = "1.5", optional = true }
serde = { version = "1.0", features = [ "derive", "rc" ] }
serde_json = { version = "1.0" }

View file

@ -1,12 +1,16 @@
use super::{
record::DataBlockRecord,
typedefs::{
AttributeRows, AttributeRowsMap, CsvRecord, CsvRecordSlice, DataBlockHeaders,
DataBlockRecords, DataBlockRecordsSlice,
AttributeRows, AttributeRowsByColumnMap, AttributeRowsMap, ColumnIndexByName,
DataBlockHeaders, DataBlockRecords,
},
value::DataBlockValue,
};
use std::collections::HashSet;
use fnv::FnvHashMap;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::{processing::aggregator::typedefs::RecordsSet, utils::math::uround_down};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
@ -15,189 +19,162 @@ use pyo3::prelude::*;
/// The goal of this is to allow data processing to handle with memory references
/// to the data block instead of copying data around
#[cfg_attr(feature = "pyo3", pyclass)]
#[derive(Debug)]
#[derive(Debug, Serialize, Deserialize)]
pub struct DataBlock {
/// Vector of strings representhing the data headers
/// Vector of strings representing the data headers
pub headers: DataBlockHeaders,
/// Vector of data records, where each record represents a row (headers not included)
pub records: DataBlockRecords,
}
impl DataBlock {
/// Returns a new DataBlock with default values
pub fn default() -> DataBlock {
DataBlock {
headers: DataBlockHeaders::default(),
records: DataBlockRecords::default(),
}
}
/// Returns a new DataBlock
/// # Arguments
/// * `headers` - Vector of string representhing the data headers
/// * `headers` - Vector of string representing the data headers
/// * `records` - Vector of data records, where each record represents a row (headers not included)
#[inline]
pub fn new(headers: DataBlockHeaders, records: DataBlockRecords) -> DataBlock {
DataBlock { headers, records }
}
/// Calcules the rows where each value on the data records is present
/// # Arguments
/// * `records`: List of records to analyze
/// Returns a map of column name -> column index
#[inline]
pub fn calc_attr_rows(records: &DataBlockRecordsSlice) -> AttributeRowsMap {
pub fn calc_column_index_by_name(&self) -> ColumnIndexByName {
self.headers
.iter()
.enumerate()
.map(|(column_index, column_name)| ((**column_name).clone(), column_index))
.collect()
}
/// Calculates the rows where each value on the data records is present
#[inline]
pub fn calc_attr_rows(&self) -> AttributeRowsMap {
let mut attr_rows: AttributeRowsMap = AttributeRowsMap::default();
for (i, r) in records.iter().enumerate() {
for (i, r) in self.records.iter().enumerate() {
for v in r.values.iter() {
attr_rows
.entry(v)
.entry(v.clone())
.or_insert_with(AttributeRows::new)
.push(i);
}
}
attr_rows
}
}
/// Trait that needs to be implement to create a data block.
/// It already contains the logic to create the data block, so we only
/// need to worry about mapping the headers and records from InputType
pub trait DataBlockCreator {
/// Creator input type, it can be a File Reader, another data structure...
type InputType;
/// The error type that can be generated when parsing headers/records
type ErrorType;
/// Calculates the rows where each value on the data records is present
/// grouped by column index.
#[inline]
fn gen_use_columns_set(headers: &CsvRecordSlice, use_columns: &[String]) -> HashSet<usize> {
let use_columns_str_set: HashSet<String> = use_columns
.iter()
.map(|c| c.trim().to_lowercase())
.collect();
headers
.iter()
.enumerate()
.filter_map(|(i, h)| {
if use_columns_str_set.is_empty()
|| use_columns_str_set.contains(&h.trim().to_lowercase())
{
Some(i)
} else {
None
pub fn calc_attr_rows_with_no_empty_values(&self) -> AttributeRowsByColumnMap {
let mut attr_rows_by_column: AttributeRowsByColumnMap = AttributeRowsByColumnMap::default();
for (i, r) in self.records.iter().enumerate() {
for v in r.values.iter() {
attr_rows_by_column
.entry(v.column_index)
.or_insert_with(AttributeRowsMap::default)
.entry(v.clone())
.or_insert_with(AttributeRows::new)
.push(i);
}
})
.collect()
}
attr_rows_by_column
}
/// Calculates the rows where each value on the data records is present
/// grouped by column index. This will also include empty values mapped with
/// the `empty_value` string.
/// # Arguments
/// * `empty_value` - Empty values on the final synthetic data will be represented by this
#[inline]
fn gen_sensitive_zeros_set(
filtered_headers: &CsvRecordSlice,
sensitive_zeros: &[String],
) -> HashSet<usize> {
let sensitive_zeros_str_set: HashSet<String> = sensitive_zeros
.iter()
.map(|c| c.trim().to_lowercase())
.collect();
filtered_headers
.iter()
.enumerate()
.filter_map(|(i, h)| {
if sensitive_zeros_str_set.contains(&h.trim().to_lowercase()) {
Some(i)
} else {
None
}
})
.collect()
pub fn calc_attr_rows_by_column(&self, empty_value: &Arc<String>) -> AttributeRowsByColumnMap {
let mut attr_rows_by_column: FnvHashMap<
usize,
FnvHashMap<Arc<DataBlockValue>, RecordsSet>,
> = FnvHashMap::default();
let empty_records: RecordsSet = (0..self.records.len()).collect();
// start with empty values for all columns
for column_index in 0..self.headers.len() {
attr_rows_by_column
.entry(column_index)
.or_insert_with(FnvHashMap::default)
.entry(Arc::new(DataBlockValue::new(
column_index,
empty_value.clone(),
)))
.or_insert_with(|| empty_records.clone());
}
#[inline]
fn map_headers(
headers: &mut CsvRecord,
use_columns: &[String],
sensitive_zeros: &[String],
) -> (CsvRecord, HashSet<usize>, HashSet<usize>) {
let use_columns_set = Self::gen_use_columns_set(headers, use_columns);
let filtered_headers: CsvRecord = headers
.iter()
.enumerate()
.filter_map(|(i, h)| {
if use_columns_set.contains(&i) {
Some(h.clone())
} else {
None
}
})
.collect();
let sensitive_zeros_set = Self::gen_sensitive_zeros_set(&filtered_headers, sensitive_zeros);
// go through each record and map where the values occur on the columns
for (i, r) in self.records.iter().enumerate() {
for value in r.values.iter() {
let current_attr_rows = attr_rows_by_column
.entry(value.column_index)
.or_insert_with(FnvHashMap::default);
(filtered_headers, use_columns_set, sensitive_zeros_set)
// insert it on the correspondent entry for the data block value
current_attr_rows
.entry(value.clone())
.or_insert_with(RecordsSet::default)
.insert(i);
// it's now being used, so we make sure to remove this from the column empty records
current_attr_rows
.entry(Arc::new(DataBlockValue::new(
value.column_index,
empty_value.clone(),
)))
.or_insert_with(RecordsSet::default)
.remove(&i);
}
}
#[inline]
fn map_records(
records: &[CsvRecord],
use_columns_set: &HashSet<usize>,
sensitive_zeros_set: &HashSet<usize>,
record_limit: usize,
) -> DataBlockRecords {
let map_result = |record: &CsvRecord| {
let values: CsvRecord = record
.iter()
.enumerate()
.filter_map(|(i, h)| {
if use_columns_set.contains(&i) {
Some(h.trim().to_string())
} else {
None
}
})
.collect();
DataBlockRecord::new(
values
.iter()
.enumerate()
.filter_map(|(i, r)| {
let record_val = r.trim();
if !record_val.is_empty()
&& (sensitive_zeros_set.contains(&i) || record_val != "0")
{
Some(DataBlockValue::new(i, record_val.into()))
} else {
None
}
})
// sort the record ids, so we can leverage the intersection algorithm later on
attr_rows_by_column
.drain()
.map(|(column_index, mut attr_rows)| {
(
column_index,
attr_rows
.drain()
.map(|(value, mut rows_set)| (value, rows_set.drain().sorted().collect()))
.collect(),
)
};
if record_limit > 0 {
records.iter().take(record_limit).map(map_result).collect()
} else {
records.iter().map(map_result).collect()
}
})
.collect()
}
#[inline]
fn create(
input_res: Result<Self::InputType, Self::ErrorType>,
use_columns: &[String],
sensitive_zeros: &[String],
record_limit: usize,
) -> Result<DataBlock, Self::ErrorType> {
let mut input = input_res?;
let (headers, use_columns_set, sensitive_zeros_set) = Self::map_headers(
&mut Self::get_headers(&mut input)?,
use_columns,
sensitive_zeros,
);
let records = Self::map_records(
&Self::get_records(&mut input)?,
&use_columns_set,
&sensitive_zeros_set,
record_limit,
);
Ok(DataBlock::new(headers, records))
/// Returns the number of records on the data block
pub fn number_of_records(&self) -> usize {
self.records.len()
}
/// Should be implemented to return the CsvRecords reprensenting the headers
fn get_headers(input: &mut Self::InputType) -> Result<CsvRecord, Self::ErrorType>;
/// Should be implemented to return the vector of CsvRecords reprensenting rows
fn get_records(input: &mut Self::InputType) -> Result<Vec<CsvRecord>, Self::ErrorType>;
#[inline]
/// Returns the number of records on the data block protected by `resolution`
pub fn protected_number_of_records(&self, resolution: usize) -> usize {
uround_down(self.number_of_records() as f64, resolution as f64)
}
#[inline]
/// Normalizes the reporting length based on the number of selected headers.
/// Returns the normalized value
/// # Arguments
/// * `reporting_length` - Reporting length to be normalized (0 means use all columns)
pub fn normalize_reporting_length(&self, reporting_length: usize) -> usize {
if reporting_length == 0 {
self.headers.len()
} else {
usize::min(reporting_length, self.headers.len())
}
}
}
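Since `uround_down` rounds down to the nearest multiple of the resolution, the new `DataBlock` helpers behave as in this small illustrative sketch (the numbers are made up):

```rust
use sds_core::data_block::block::DataBlock;

// Illustrative sketch: protected counts are rounded down to the nearest
// multiple of the reporting resolution, and a reporting length of 0 falls
// back to using all columns.
fn summarize(data_block: &DataBlock, resolution: usize) {
    // e.g. 987 records with resolution 10 reports a protected count of 980
    println!("records: {}", data_block.number_of_records());
    println!(
        "protected records: {}",
        data_block.protected_number_of_records(resolution)
    );
    // 0 means "use every column" for the reporting length
    println!(
        "reporting length: {}",
        data_block.normalize_reporting_length(0)
    );
}
```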

View file

@ -1,4 +1,4 @@
use super::{block::DataBlockCreator, typedefs::CsvRecord};
use super::{data_block_creator::DataBlockCreator, typedefs::CsvRecord};
use csv::{Error, Reader, StringRecord};
use std::fs::File;

View file

@ -0,0 +1,36 @@
use csv::Error;
use std::fmt::{Display, Formatter, Result};
#[cfg(feature = "pyo3")]
use pyo3::exceptions::PyIOError;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
/// Wrapper for a csv::Error, so the From
/// trait can be implemented for PyErr
pub struct CsvIOError {
error: Error,
}
impl CsvIOError {
/// Creates a new CsvIOError from a csv::Error
/// # Arguments
/// * `error` - Related csv::Error
pub fn new(error: Error) -> CsvIOError {
CsvIOError { error }
}
}
impl Display for CsvIOError {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(f, "{}", self.error)
}
}
#[cfg(feature = "pyo3")]
impl From<CsvIOError> for PyErr {
fn from(err: CsvIOError) -> PyErr {
PyIOError::new_err(err.error.to_string())
}
}

View file

@ -0,0 +1,166 @@
use super::{
block::DataBlock,
record::DataBlockRecord,
typedefs::{CsvRecord, CsvRecordRef, CsvRecordRefSlice, CsvRecordSlice, DataBlockRecords},
value::DataBlockValue,
};
use std::{collections::HashSet, sync::Arc};
/// Trait that needs to be implemented to create a data block.
/// It already contains the logic to create the data block, so we only
/// need to worry about mapping the headers and records from InputType
pub trait DataBlockCreator {
/// Creator input type, it can be a File Reader, another data structure...
type InputType;
/// The error type that can be generated when parsing headers/records
type ErrorType;
#[inline]
fn gen_use_columns_set(headers: &CsvRecordSlice, use_columns: &[String]) -> HashSet<usize> {
let use_columns_str_set: HashSet<String> = use_columns
.iter()
.map(|c| c.trim().to_lowercase())
.collect();
headers
.iter()
.enumerate()
.filter_map(|(i, h)| {
if use_columns_str_set.is_empty()
|| use_columns_str_set.contains(&h.trim().to_lowercase())
{
Some(i)
} else {
None
}
})
.collect()
}
#[inline]
fn gen_sensitive_zeros_set(
filtered_headers: &CsvRecordRefSlice,
sensitive_zeros: &[String],
) -> HashSet<usize> {
let sensitive_zeros_str_set: HashSet<String> = sensitive_zeros
.iter()
.map(|c| c.trim().to_lowercase())
.collect();
filtered_headers
.iter()
.enumerate()
.filter_map(|(i, h)| {
if sensitive_zeros_str_set.contains(&h.trim().to_lowercase()) {
Some(i)
} else {
None
}
})
.collect()
}
#[inline]
fn normalize_value(value: &str) -> String {
value
.trim()
// replace reserved delimiters
.replace(";", "<semicolon>")
.replace(":", "<colon>")
}
#[inline]
fn map_headers(
headers: &mut CsvRecord,
use_columns: &[String],
sensitive_zeros: &[String],
) -> (CsvRecordRef, HashSet<usize>, HashSet<usize>) {
let use_columns_set = Self::gen_use_columns_set(headers, use_columns);
let filtered_headers: CsvRecordRef = headers
.iter()
.enumerate()
.filter_map(|(i, h)| {
if use_columns_set.contains(&i) {
Some(Arc::new(Self::normalize_value(h)))
} else {
None
}
})
.collect();
let sensitive_zeros_set = Self::gen_sensitive_zeros_set(&filtered_headers, sensitive_zeros);
(filtered_headers, use_columns_set, sensitive_zeros_set)
}
#[inline]
fn map_records(
records: &[CsvRecord],
use_columns_set: &HashSet<usize>,
sensitive_zeros_set: &HashSet<usize>,
record_limit: usize,
) -> DataBlockRecords {
let map_result = |record: &CsvRecord| {
let values: CsvRecord = record
.iter()
.enumerate()
.filter_map(|(i, h)| {
if use_columns_set.contains(&i) {
Some(h.trim().to_string())
} else {
None
}
})
.collect();
Arc::new(DataBlockRecord::new(
values
.iter()
.enumerate()
.filter_map(|(i, r)| {
let record_val = Self::normalize_value(r);
if !record_val.is_empty()
&& (sensitive_zeros_set.contains(&i) || record_val != "0")
{
Some(Arc::new(DataBlockValue::new(i, Arc::new(record_val))))
} else {
None
}
})
.collect(),
))
};
if record_limit > 0 {
records.iter().take(record_limit).map(map_result).collect()
} else {
records.iter().map(map_result).collect()
}
}
#[inline]
fn create(
input_res: Result<Self::InputType, Self::ErrorType>,
use_columns: &[String],
sensitive_zeros: &[String],
record_limit: usize,
) -> Result<Arc<DataBlock>, Self::ErrorType> {
let mut input = input_res?;
let (headers, use_columns_set, sensitive_zeros_set) = Self::map_headers(
&mut Self::get_headers(&mut input)?,
use_columns,
sensitive_zeros,
);
let records = Self::map_records(
&Self::get_records(&mut input)?,
&use_columns_set,
&sensitive_zeros_set,
record_limit,
);
Ok(Arc::new(DataBlock::new(headers, records)))
}
/// Should be implemented to return the CsvRecords representing the headers
fn get_headers(input: &mut Self::InputType) -> Result<CsvRecord, Self::ErrorType>;
/// Should be implemented to return the vector of CsvRecords representing rows
fn get_records(input: &mut Self::InputType) -> Result<Vec<CsvRecord>, Self::ErrorType>;
}
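To implement the trait for another input type, only `get_headers` and `get_records` are required; the provided `create` handles column selection, sensitive zeros and value normalization. A hypothetical in-memory creator (names invented for illustration, not part of this commit) could look like this:

```rust
use sds_core::data_block::{
    block::DataBlock, data_block_creator::DataBlockCreator, typedefs::CsvRecord,
};
use std::sync::Arc;

// Hypothetical in-memory creator: the first row is treated as the header,
// the remaining rows as records.
struct VecDataBlockCreator;

impl DataBlockCreator for VecDataBlockCreator {
    type InputType = Vec<CsvRecord>;
    type ErrorType = std::io::Error;

    fn get_headers(input: &mut Self::InputType) -> Result<CsvRecord, Self::ErrorType> {
        Ok(input.first().cloned().unwrap_or_default())
    }

    fn get_records(input: &mut Self::InputType) -> Result<Vec<CsvRecord>, Self::ErrorType> {
        Ok(input.iter().skip(1).cloned().collect())
    }
}

fn block_from_rows(rows: Vec<CsvRecord>) -> Arc<DataBlock> {
    // No column filter, no sensitive zeros, no record limit.
    VecDataBlockCreator::create(Ok(rows), &[], &[], 0).unwrap()
}
```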

View file

@ -1,10 +1,21 @@
/// Module defining the structures that represent a data block
pub mod block;
/// Module to create data blocks from CSV files
pub mod csv_block_creator;
/// Defines io errors for handling csv files
/// that are easier to bind to other languages
pub mod csv_io_error;
/// Module to create data blocks from different input types (trait definitions)
pub mod data_block_creator;
/// Module defining the structures that represent a data block record
pub mod record;
/// Type definitions related to data blocks
pub mod typedefs;
/// Module defining the structures that represent a data block value
pub mod value;

View file

@ -1,10 +1,12 @@
use super::value::DataBlockValue;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Represents all the values of a given row in a data block
#[derive(Debug)]
#[derive(Debug, Serialize, Deserialize)]
pub struct DataBlockRecord {
/// Vector of data block values for a given row indexed by column
pub values: Vec<DataBlockValue>,
pub values: Vec<Arc<DataBlockValue>>,
}
impl DataBlockRecord {
@ -12,7 +14,7 @@ impl DataBlockRecord {
/// # Arguments
/// * `values` - Vector of data block values for a given row indexed by column
#[inline]
pub fn new(values: Vec<DataBlockValue>) -> DataBlockRecord {
pub fn new(values: Vec<Arc<DataBlockValue>>) -> DataBlockRecord {
DataBlockRecord { values }
}
}

View file

@ -1,5 +1,8 @@
use super::{record::DataBlockRecord, value::DataBlockValue};
use fnv::FnvHashMap;
use std::sync::Arc;
use crate::processing::evaluator::rare_combinations_comparison_data::CombinationComparison;
/// Ordered vector of rows where a particular value combination is present
pub type AttributeRows = Vec<usize>;
@ -11,23 +14,38 @@ pub type AttributeRowsSlice = [usize];
pub type CsvRecord = Vec<String>;
/// The same as CsvRecord, but keeping a reference to the string value, so data does not have to be duplicated
pub type CsvRecordRef<'data_block_value> = Vec<&'data_block_value str>;
pub type CsvRecordRef = Vec<Arc<String>>;
/// Slice of CsvRecord
pub type CsvRecordSlice = [String];
/// Vector of strings representhing the data block headers
pub type DataBlockHeaders = CsvRecord;
/// Slice of CsvRecordRef
pub type CsvRecordRefSlice = [Arc<String>];
/// Vector of strings representing the data block headers
pub type DataBlockHeaders = CsvRecordRef;
/// Slice of DataBlockHeaders
pub type DataBlockHeadersSlice = CsvRecordSlice;
pub type DataBlockHeadersSlice = CsvRecordRefSlice;
/// Vector of data block records, where each record represents a row
pub type DataBlockRecords = Vec<DataBlockRecord>;
/// Slice of DataBlockRecords
pub type DataBlockRecordsSlice = [DataBlockRecord];
pub type DataBlockRecords = Vec<Arc<DataBlockRecord>>;
/// HashMap with a data block value as key and all the attribute row indexes where it occurs as value
pub type AttributeRowsMap<'data_block_value> =
FnvHashMap<&'data_block_value DataBlockValue, AttributeRows>;
pub type AttributeRowsMap = FnvHashMap<Arc<DataBlockValue>, AttributeRows>;
/// HashMap with a data block value as key and attribute row indexes as value
pub type AttributeRowsRefMap = FnvHashMap<Arc<DataBlockValue>, Arc<AttributeRows>>;
/// Maps the column index -> data block value -> rows where the value appear
pub type AttributeRowsByColumnMap = FnvHashMap<usize, AttributeRowsMap>;
/// Raw synthesized data (vector of csv record references to the original data block)
pub type RawSyntheticData = Vec<CsvRecordRef>;
/// A vector of combination comparisons
/// (between sensitive and synthetic data)
pub type CombinationsComparisons = Vec<CombinationComparison>;
/// Maps a column name to the corresponding column index
pub type ColumnIndexByName = FnvHashMap<String, usize>;

View file

@ -1,10 +1,16 @@
use super::typedefs::DataBlockHeadersSlice;
use serde::{Deserialize, Serialize};
use std::{fmt::Display, str::FromStr, sync::Arc};
const VALUE_DELIMITER: char = ':';
/// Represents a value of a given data block for a particular row and column
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct DataBlockValue {
/// Column index this value belongs to starting in '0'
pub column_index: usize,
/// Value stored on the CSV file for a given row at `column_index`
pub value: String,
pub value: Arc<String>,
}
impl DataBlockValue {
@ -13,10 +19,72 @@ impl DataBlockValue {
/// * `column_index` - Column index this value belongs to starting in '0'
/// * `value` - Value stored on the CSV file for a given row at `column_index`
#[inline]
pub fn new(column_index: usize, value: String) -> DataBlockValue {
pub fn new(column_index: usize, value: Arc<String>) -> DataBlockValue {
DataBlockValue {
column_index,
value,
}
}
/// Formats a data block value as String using the
/// corresponding header name
/// The result is formatted as: `{header_name}:{block_value}`
/// # Arguments
/// * `headers` - data block headers
#[inline]
pub fn format_str_using_headers(&self, headers: &DataBlockHeadersSlice) -> String {
format!(
"{}{}{}",
headers[self.column_index], VALUE_DELIMITER, self.value
)
}
}
/// Error that can happen when parsing a data block value from
/// a string
pub struct ParseDataBlockValueError {
error_message: String,
}
impl ParseDataBlockValueError {
#[inline]
/// Creates a new ParseDataBlockValueError with `error_message`
pub fn new(error_message: String) -> ParseDataBlockValueError {
ParseDataBlockValueError { error_message }
}
}
impl Display for ParseDataBlockValueError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.error_message)
}
}
impl FromStr for DataBlockValue {
type Err = ParseDataBlockValueError;
/// Creates a new DataBlockValue by parsing `str_value`
fn from_str(str_value: &str) -> Result<Self, Self::Err> {
if let Some(pos) = str_value.find(VALUE_DELIMITER) {
Ok(DataBlockValue::new(
str_value[..pos]
.parse::<usize>()
.map_err(|err| ParseDataBlockValueError::new(err.to_string()))?,
Arc::new(str_value[pos + 1..].into()),
))
} else {
Err(ParseDataBlockValueError::new(format!(
"data block value missing '{}'",
VALUE_DELIMITER
)))
}
}
}
impl Display for DataBlockValue {
/// Formats the DataBlockValue as a string
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Ok(write!(f, "{}:{}", self.column_index, self.value)?)
}
}
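The `Display`/`FromStr` pair above serializes a value as `{column_index}:{value}`, while `format_str_using_headers` swaps the index for the header name. A small round-trip sketch (the example values are invented):

```rust
use sds_core::data_block::value::DataBlockValue;
use std::sync::Arc;

fn round_trip_example() {
    let value = DataBlockValue::new(2, Arc::new(String::from("engaged")));

    // Display produces "{column_index}:{value}"
    let serialized = value.to_string(); // "2:engaged"

    // FromStr parses that form back
    if let Ok(parsed) = serialized.parse::<DataBlockValue>() {
        assert_eq!(parsed.column_index, 2);
        assert_eq!(*parsed.value, "engaged");
    }

    // With headers, the column index is replaced by the header name
    let headers = vec![
        Arc::new(String::from("id")),
        Arc::new(String::from("age")),
        Arc::new(String::from("status")),
    ];
    assert_eq!(value.format_str_using_headers(&headers), "status:engaged");
}
```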

View file

@ -0,0 +1,448 @@
use super::{
privacy_risk_summary::PrivacyRiskSummary,
records_analysis_data::RecordsAnalysisData,
typedefs::{
AggregatedCountByLenMap, AggregatesCountMap, AggregatesCountStringMap, RecordsByLenMap,
RecordsSensitivity,
},
};
use itertools::Itertools;
use log::info;
use serde::{Deserialize, Serialize};
use std::{
io::{BufReader, BufWriter, Error, Write},
sync::Arc,
};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use crate::{
data_block::block::DataBlock,
processing::aggregator::typedefs::RecordsSet,
utils::{math::uround_down, time::ElapsedDurationLogger},
};
/// Aggregated data produced by the Aggregator
#[cfg_attr(feature = "pyo3", pyclass)]
#[derive(Serialize, Deserialize)]
pub struct AggregatedData {
/// Data block from where this aggregated data was generated
pub data_block: Arc<DataBlock>,
/// Maps a value combination to its aggregated count
pub aggregates_count: AggregatesCountMap,
/// A vector of sensitivities for each record (the vector index is the record index)
pub records_sensitivity: RecordsSensitivity,
/// Maximum length used to compute attribute combinations
pub reporting_length: usize,
}
impl AggregatedData {
/// Creates a new AggregatedData struct with default values
#[inline]
pub fn default() -> AggregatedData {
AggregatedData {
data_block: Arc::new(DataBlock::default()),
aggregates_count: AggregatesCountMap::default(),
records_sensitivity: RecordsSensitivity::default(),
reporting_length: 0,
}
}
/// Creates a new AggregatedData struct
/// # Arguments:
/// * `data_block` - Data block with the original data
/// * `aggregates_count` - Computed aggregates count map
/// * `records_sensitivity` - Computed sensitivity for the records
/// * `reporting_length` - Maximum length used to compute attribute combinations
#[inline]
pub fn new(
data_block: Arc<DataBlock>,
aggregates_count: AggregatesCountMap,
records_sensitivity: RecordsSensitivity,
reporting_length: usize,
) -> AggregatedData {
AggregatedData {
data_block,
aggregates_count,
records_sensitivity,
reporting_length,
}
}
#[inline]
/// Check if the records len map contains a value across all lengths
fn records_by_len_contains(records_by_len: &RecordsByLenMap, value: &usize) -> bool {
records_by_len
.values()
.any(|records| records.contains(value))
}
#[inline]
/// When first generated, the RecordsByLenMap might contain
/// the same records appearing in different combination lengths.
/// This will keep only the record on the shortest length key
/// and remove the other occurrences
fn keep_records_only_on_shortest_len(records_by_len: &mut RecordsByLenMap) {
let lengths: Vec<usize> = records_by_len.keys().cloned().sorted().collect();
let max_len = lengths.last().copied().unwrap_or(0);
// make sure the record will be only present in the shortest len
// start on the shortest length
for l in lengths {
for r in records_by_len.get(&l).unwrap().clone() {
// search all lengths > l and remove the record from there
for n in l + 1..=max_len {
if let Some(records) = records_by_len.get_mut(&n) {
records.remove(&r);
}
}
}
}
// retain only non-empty record lists
records_by_len.retain(|_, records| !records.is_empty());
}
#[inline]
fn _read_from_json(file_path: &str) -> Result<AggregatedData, Error> {
let _duration_logger = ElapsedDurationLogger::new("aggregated count json read");
Ok(serde_json::from_reader(BufReader::new(
std::fs::File::open(file_path)?,
))?)
}
}
#[cfg_attr(feature = "pyo3", pymethods)]
impl AggregatedData {
/// Builds a map from value combinations formatted as string to its aggregated count
/// This method will clone the data, so it's recommended to have its result stored
/// in a local variable to avoid it being called multiple times
/// # Arguments:
/// * `combination_delimiter` - Delimiter used to join combinations
pub fn get_formatted_aggregates_count(
&self,
combination_delimiter: &str,
) -> AggregatesCountStringMap {
self.aggregates_count
.iter()
.map(|(key, value)| {
(
key.format_str_using_headers(&self.data_block.headers, combination_delimiter),
value.clone(),
)
})
.collect()
}
#[cfg(feature = "pyo3")]
/// A vector of sensitivities for each record (the vector index is the record index)
/// This method will clone the data, so it's recommended to have its result stored
/// in a local variable to avoid it being called multiple times
pub fn get_records_sensitivity(&self) -> RecordsSensitivity {
self.records_sensitivity.clone()
}
/// Round the aggregated counts down to the nearest multiple of resolution
/// # Arguments:
/// * `resolution` - Reporting resolution used for data synthesis
pub fn protect_aggregates_count(&mut self, resolution: usize) {
let _duration_logger = ElapsedDurationLogger::new("aggregates count protect");
info!(
"protecting aggregates counts with resolution {}",
resolution
);
for count in self.aggregates_count.values_mut() {
count.count = uround_down(count.count as f64, resolution as f64);
}
// remove 0 counts from response
self.aggregates_count.retain(|_, count| count.count > 0);
}
/// Calculates the records that contain rare combinations grouped by length.
/// This might contain duplicated records on different lengths if the record
/// contains more than one rare combination. Unique combinations are also contained
/// in this.
/// # Arguments:
/// * `resolution` - Reporting resolution used for data synthesis
pub fn calc_all_rare_combinations_records_by_len(&self, resolution: usize) -> RecordsByLenMap {
let _duration_logger =
ElapsedDurationLogger::new("all rare combinations records by len calculation");
let mut rare_records_by_len: RecordsByLenMap = RecordsByLenMap::default();
for (agg, count) in self.aggregates_count.iter() {
if count.count < resolution {
rare_records_by_len
.entry(agg.len())
.or_insert_with(RecordsSet::default)
.extend(&count.contained_in_records);
}
}
rare_records_by_len
}
/// Calculates the records that contain unique combinations grouped by length.
/// This might contain duplicated records on different lengths if the record
/// contains more than one unique combination.
pub fn calc_all_unique_combinations_records_by_len(&self) -> RecordsByLenMap {
let _duration_logger =
ElapsedDurationLogger::new("all unique combinations records by len calculation");
let mut unique_records_by_len: RecordsByLenMap = RecordsByLenMap::default();
for (agg, count) in self.aggregates_count.iter() {
if count.count == 1 {
unique_records_by_len
.entry(agg.len())
.or_insert_with(RecordsSet::default)
.extend(&count.contained_in_records);
}
}
unique_records_by_len
}
/// Calculate the records that contain unique and rare combinations grouped by length.
/// A tuple with the `(unique, rare)` is returned.
/// Both returned maps are ensured to only contain the records on the shortest length,
/// so each record will appear only on the shortest combination length that isolates it
/// within a rare group. Also, if the record has a unique combination, it will not
/// be present on the rare map, only on the unique one.
/// # Arguments:
/// * `resolution` - Reporting resolution used for data synthesis
pub fn calc_unique_rare_combinations_records_by_len(
&self,
resolution: usize,
) -> (RecordsByLenMap, RecordsByLenMap) {
let _duration_logger =
ElapsedDurationLogger::new("unique/rare combinations records by len calculation");
let mut unique_records_by_len = self.calc_all_unique_combinations_records_by_len();
let mut rare_records_by_len = self.calc_all_rare_combinations_records_by_len(resolution);
AggregatedData::keep_records_only_on_shortest_len(&mut unique_records_by_len);
AggregatedData::keep_records_only_on_shortest_len(&mut rare_records_by_len);
// remove records with unique combinations from the rare map
rare_records_by_len.values_mut().for_each(|records| {
records.retain(|r| !AggregatedData::records_by_len_contains(&unique_records_by_len, r));
});
(unique_records_by_len, rare_records_by_len)
}
/// Perform the records analysis and returns the data containing
/// unique, rare and risky information grouped per length.
/// # Arguments:
/// * `resolution` - Reporting resolution used for data synthesis
/// * `protect` - Whether or not the counts should be rounded to the
/// nearest smallest multiple of resolution
pub fn calc_records_analysis_by_len(
&self,
resolution: usize,
protect: bool,
) -> RecordsAnalysisData {
let _duration_logger = ElapsedDurationLogger::new("records analysis by len");
let (unique_records_by_len, rare_records_by_len) =
self.calc_unique_rare_combinations_records_by_len(resolution);
RecordsAnalysisData::from_unique_rare_combinations_records_by_len(
&unique_records_by_len,
&rare_records_by_len,
self.data_block.records.len(),
self.reporting_length,
resolution,
protect,
)
}
/// Calculates the number of rare combinations grouped by combination length
/// # Arguments:
/// * `resolution` - Reporting resolution used for data synthesis
pub fn calc_rare_combinations_count_by_len(
&self,
resolution: usize,
) -> AggregatedCountByLenMap {
let _duration_logger =
ElapsedDurationLogger::new("rare combinations count by len calculation");
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
info!(
"calculating rare combinations counts by length with resolution {}",
resolution
);
for (agg, count) in self.aggregates_count.iter() {
if count.count < resolution {
let curr_count = result.entry(agg.len()).or_insert(0);
*curr_count += 1;
}
}
result
}
/// Calculates the number of combinations grouped by combination length
pub fn calc_combinations_count_by_len(&self) -> AggregatedCountByLenMap {
let _duration_logger = ElapsedDurationLogger::new("combination count by len calculation");
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
info!("calculating combination counts by length");
for agg in self.aggregates_count.keys() {
let curr_count = result.entry(agg.len()).or_insert(0);
*curr_count += 1;
}
result
}
/// Calculates the sum of all combination counts grouped by combination length
pub fn calc_combinations_sum_by_len(&self) -> AggregatedCountByLenMap {
let _duration_logger = ElapsedDurationLogger::new("combinations sum by len calculation");
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
info!("calculating combination counts sums by length");
for (agg, count) in self.aggregates_count.iter() {
let curr_sum = result.entry(agg.len()).or_insert(0);
*curr_sum += count.count;
}
result
}
/// Calculates the privacy risk related with data block and the generated
/// aggregates counts
/// # Arguments:
/// * `resolution` - Reporting resolution used for data synthesis
pub fn calc_privacy_risk(&self, resolution: usize) -> PrivacyRiskSummary {
let _duration_logger = ElapsedDurationLogger::new("privacy risk calculation");
info!("calculating privacy risk...");
PrivacyRiskSummary::from_aggregates_count(
self.data_block.records.len(),
&self.aggregates_count,
resolution,
)
}
/// Writes the aggregates counts to the file system in a csv/tsv like format
/// # Arguments:
/// * `aggregates_path` - File path to be written
/// * `aggregates_delimiter` - Delimiter to use when writing to `aggregates_path`
/// * `combination_delimiter` - Delimiter used to join combinations and format them
/// as strings
/// * `resolution` - Reporting resolution used for data synthesis
/// * `protected` - Whether or not the counts were protected before calling this
pub fn write_aggregates_count(
&self,
aggregates_path: &str,
aggregates_delimiter: char,
combination_delimiter: &str,
resolution: usize,
protected: bool,
) -> Result<(), Error> {
let _duration_logger = ElapsedDurationLogger::new("write aggregates count");
info!("writing file {}", aggregates_path);
let mut file = std::fs::File::create(aggregates_path)?;
file.write_all(
format!(
"selections{}{}\n",
aggregates_delimiter,
if protected {
"protected_count"
} else {
"count"
}
)
.as_bytes(),
)?;
file.write_all(
format!(
"selections{}{}\n",
aggregates_delimiter,
uround_down(self.data_block.records.len() as f64, resolution as f64)
)
.as_bytes(),
)?;
for aggregate in self.aggregates_count.keys() {
file.write_all(
format!(
"{}{}{}\n",
aggregate
.format_str_using_headers(&self.data_block.headers, combination_delimiter),
aggregates_delimiter,
self.aggregates_count[aggregate].count
)
.as_bytes(),
)?
}
Ok(())
}
/// Writes the records sensitivity to the file system in a csv/tsv like format
/// # Arguments:
/// * `records_sensitivity_path` - File path to be written
/// * `records_sensitivity_delimiter` - Delimiter to use when writing to `records_sensitivity_path`
pub fn write_records_sensitivity(
&self,
records_sensitivity_path: &str,
records_sensitivity_delimiter: char,
) -> Result<(), Error> {
let _duration_logger = ElapsedDurationLogger::new("write records sensitivity");
info!("writing file {}", records_sensitivity_path);
let mut file = std::fs::File::create(records_sensitivity_path)?;
file.write_all(
format!(
"record_index{}record_sensitivity\n",
records_sensitivity_delimiter
)
.as_bytes(),
)?;
for (i, sensitivity) in self.records_sensitivity.iter().enumerate() {
file.write_all(
format!("{}{}{}\n", i, records_sensitivity_delimiter, sensitivity).as_bytes(),
)?
}
Ok(())
}
/// Serializes the aggregated data to a json file
/// # Arguments:
/// * `file_path` - File path to be written
pub fn write_to_json(&self, file_path: &str) -> Result<(), Error> {
let _duration_logger = ElapsedDurationLogger::new("aggregated count json write");
Ok(serde_json::to_writer(
BufWriter::new(std::fs::File::create(file_path)?),
&self,
)?)
}
#[cfg(feature = "pyo3")]
#[staticmethod]
/// Deserializes the aggregated data from a json file
/// # Arguments:
/// * `file_path` - File path to read from
pub fn read_from_json(file_path: &str) -> Result<AggregatedData, Error> {
AggregatedData::_read_from_json(file_path)
}
#[cfg(not(feature = "pyo3"))]
/// Deserializes the aggregated data from a json file
/// # Arguments:
/// * `file_path` - File path to read from
pub fn read_from_json(file_path: &str) -> Result<AggregatedData, Error> {
AggregatedData::_read_from_json(file_path)
}
}
#[cfg(feature = "pyo3")]
pub fn register(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<AggregatedData>()?;
Ok(())
}

View file

@ -0,0 +1,174 @@
use super::aggregated_data::AggregatedData;
use super::rows_aggregator::RowsAggregator;
use super::typedefs::RecordsSet;
use itertools::Itertools;
use log::info;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::data_block::block::DataBlock;
use crate::processing::aggregator::record_attrs_selector::RecordAttrsSelector;
use crate::utils::math::calc_percentage;
use crate::utils::reporting::ReportProgress;
use crate::utils::threading::get_number_of_threads;
use crate::utils::time::ElapsedDurationLogger;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
/// Result of data aggregation for each combination
#[cfg_attr(feature = "pyo3", pyclass)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregatedCount {
/// How many times the combination appears on the records
pub count: usize,
/// Which records this combination is part of
pub contained_in_records: RecordsSet,
}
#[cfg(feature = "pyo3")]
#[cfg_attr(feature = "pyo3", pymethods)]
impl AggregatedCount {
/// How many times the combination appears on the records
#[getter]
fn count(&self) -> usize {
self.count
}
/// Which records this combination is part of
/// This method will clone the data, so it's recommended to have its result stored
/// in a local variable to avoid it being called multiple times
fn get_contained_in_records(&self) -> RecordsSet {
self.contained_in_records.clone()
}
}
impl Default for AggregatedCount {
fn default() -> AggregatedCount {
AggregatedCount {
count: 0,
contained_in_records: RecordsSet::default(),
}
}
}
/// Process a data block to produce aggregated data
pub struct Aggregator {
data_block: Arc<DataBlock>,
}
impl Aggregator {
/// Returns a data aggregator for the given data block
/// # Arguments
/// * `data_block` - The data block to be processed
#[inline]
pub fn new(data_block: Arc<DataBlock>) -> Aggregator {
Aggregator { data_block }
}
/// Aggregates the data block and returns the aggregated data back
/// # Arguments
/// * `reporting_length` - Calculate combinations from 1 up to `reporting_length`
/// * `sensitivity_threshold` - Sensitivity threshold to filter record attributes
/// (0 means no suppression)
/// * `progress_reporter` - Will be used to report the processing
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
pub fn aggregate<T>(
&mut self,
reporting_length: usize,
sensitivity_threshold: usize,
progress_reporter: &mut Option<T>,
) -> AggregatedData
where
T: ReportProgress,
{
let _duration_logger = ElapsedDurationLogger::new("data aggregation");
let normalized_reporting_length =
self.data_block.normalize_reporting_length(reporting_length);
let length_range = (1..=normalized_reporting_length).collect::<Vec<usize>>();
let total_n_records = self.data_block.records.len();
let total_n_records_f64 = total_n_records as f64;
info!(
"aggregating data with reporting length = {}, sensitivity_threshold = {} and {} thread(s)",
normalized_reporting_length, sensitivity_threshold, get_number_of_threads()
);
let result = RowsAggregator::aggregate_all(
total_n_records,
&mut self.build_rows_aggregators(&length_range, sensitivity_threshold),
progress_reporter,
);
Aggregator::update_aggregate_progress(
progress_reporter,
total_n_records,
total_n_records_f64,
);
info!(
"data aggregated resulting in {} distinct combinations...",
result.aggregates_count.len()
);
info!(
"suppression ratio of aggregates is {:.2}%",
(1.0 - (result.selected_combs_count as f64 / result.all_combs_count as f64)) * 100.0
);
AggregatedData::new(
self.data_block.clone(),
result.aggregates_count,
result.records_sensitivity,
normalized_reporting_length,
)
}
#[inline]
fn build_rows_aggregators<'length_range>(
&self,
length_range: &'length_range [usize],
sensitivity_threshold: usize,
) -> Vec<RowsAggregator<'length_range>> {
if self.data_block.records.is_empty() {
return Vec::default();
}
let chunk_size = ((self.data_block.records.len() as f64) / (get_number_of_threads() as f64))
.ceil() as usize;
let mut rows_aggregators: Vec<RowsAggregator> = Vec::default();
let attr_rows_map = Arc::new(self.data_block.calc_attr_rows());
for c in &self
.data_block
.records
.clone()
.drain(..)
.enumerate()
.chunks(chunk_size)
{
rows_aggregators.push(RowsAggregator::new(
self.data_block.clone(),
c.collect(),
RecordAttrsSelector::new(
length_range,
sensitivity_threshold,
attr_rows_map.clone(),
),
))
}
rows_aggregators
}
#[inline]
fn update_aggregate_progress<T>(
progress_reporter: &mut Option<T>,
n_processed: usize,
total: f64,
) where
T: ReportProgress,
{
if let Some(r) = progress_reporter {
r.report(calc_percentage(n_processed as f64, total));
}
}
}
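Putting the refactor together: the post-processing and writing that used to live on `Aggregator` now mostly hangs off the returned `AggregatedData`. A hedged end-to-end sketch (file paths and parameter values are illustrative):

```rust
use sds_core::{
    data_block::block::DataBlock, processing::aggregator::Aggregator,
    utils::reporting::LoggerProgressReporter,
};
use std::{io::Error, sync::Arc};

// Sketch of the refactored flow: Aggregator owns an Arc<DataBlock> and the
// post-processing/writing methods now live on AggregatedData.
fn aggregate_and_write(data_block: Arc<DataBlock>) -> Result<(), Error> {
    let mut progress_reporter: Option<LoggerProgressReporter> = None;
    let mut aggregator = Aggregator::new(data_block);

    // reporting_length = 3, sensitivity_threshold = 0 (no suppression)
    let mut aggregated_data = aggregator.aggregate(3, 0, &mut progress_reporter);

    let resolution = 10;
    println!("{:#?}", aggregated_data.calc_privacy_risk(resolution));

    // round counts down to the nearest multiple of the resolution
    aggregated_data.protect_aggregates_count(resolution);

    aggregated_data.write_aggregates_count("aggregates.tsv", '\t', ";", resolution, true)?;
    aggregated_data.write_records_sensitivity("records_sensitivity.tsv", '\t')?;
    aggregated_data.write_to_json("aggregated_data.json")
}
```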

View file

@ -1,501 +1,27 @@
/// Module to represent aggregated data and provide
/// some methods/utilities for information extracted from it
pub mod aggregated_data;
/// Dataset privacy risk definitions
pub mod privacy_risk_summary;
/// Module that can be used to suppress attributes from records to
/// meet a certain sensitivity threshold
pub mod record_attrs_selector;
/// Defines structures related to records analysis (unique, rare and risky
/// information)
pub mod records_analysis_data;
/// Type definitions related to the aggregation process
pub mod typedefs;
use self::typedefs::{
AggregatedCountByLenMap, AggregatesCountMap, RecordsSensitivity, RecordsSensitivitySlice,
RecordsSet, ValueCombinationSlice,
};
use instant::Duration;
use itertools::Itertools;
use log::Level::Trace;
use log::{info, log_enabled, trace};
use std::cmp::min;
use std::io::{Error, Write};
/// Defines structures to store value combinations generated
/// during the aggregate process
pub mod value_combination;
use crate::data_block::block::DataBlock;
use crate::data_block::typedefs::DataBlockHeadersSlice;
use crate::data_block::value::DataBlockValue;
use crate::measure_time;
use crate::processing::aggregator::record_attrs_selector::RecordAttrsSelector;
use crate::utils::math::{calc_n_combinations_range, calc_percentage, uround_down};
use crate::utils::reporting::ReportProgress;
use crate::utils::time::ElapsedDuration;
mod data_aggregator;
/// Represents the privacy risk information related to a data block
#[derive(Debug)]
pub struct PrivacyRiskSummary {
/// Total number of records on the data block
pub total_number_of_records: usize,
/// Total number of combinations aggregated (up to reporting length)
pub total_number_of_combinations: usize,
/// Number of records with unique combinations
pub records_with_unique_combinations_count: usize,
/// Number of records with rare combinations (combination count < resolution)
pub records_with_rare_combinations_count: usize,
/// Number of unique combinations
pub unique_combinations_count: usize,
/// Number of rare combinations
pub rare_combinations_count: usize,
/// Proportion of records containing unique combinations
pub records_with_unique_combinations_proportion: f64,
/// Proportion of records containing rare combinations
pub records_with_rare_combinations_proportion: f64,
/// Proportion of unique combinations
pub unique_combinations_proportion: f64,
/// Proportion of rare combinations
pub rare_combinations_proportion: f64,
}
pub use data_aggregator::{AggregatedCount, Aggregator};
/// Result of data aggregation for each combination
#[derive(Debug)]
pub struct AggregatedCount {
/// How many times the combination appears on the records
pub count: usize,
/// Which records this combinations is part of
pub contained_in_records: RecordsSet,
}
impl Default for AggregatedCount {
fn default() -> AggregatedCount {
AggregatedCount {
count: 0,
contained_in_records: RecordsSet::default(),
}
}
}
#[derive(Debug)]
struct AggregatorDurations {
aggregate: Duration,
calc_rare_combinations_count_by_len: Duration,
calc_combinations_count_by_len: Duration,
calc_combinations_sum_by_len: Duration,
calc_privacy_risk: Duration,
write_aggregates_count: Duration,
}
/// Aggregated data produced by the Aggregator
pub struct AggregatedData<'data_block> {
/// Maps a value combination to its aggregated count
pub aggregates_count: AggregatesCountMap<'data_block>,
/// A vector of sensitivities for each record (the vector index is the record index)
pub records_sensitivity: RecordsSensitivity,
}
/// Process a data block to produced aggregated data
pub struct Aggregator<'data_block> {
data_block: &'data_block DataBlock,
durations: AggregatorDurations,
}
impl<'data_block> Aggregator<'data_block> {
/// Returns a data aggregator for the given data block
/// # Arguments
/// * `data_block` - The data block to be processed
#[inline]
pub fn new(data_block: &'data_block DataBlock) -> Aggregator<'data_block> {
Aggregator {
data_block,
durations: AggregatorDurations {
aggregate: Duration::default(),
calc_rare_combinations_count_by_len: Duration::default(),
calc_combinations_count_by_len: Duration::default(),
calc_combinations_sum_by_len: Duration::default(),
calc_privacy_risk: Duration::default(),
write_aggregates_count: Duration::default(),
},
}
}
/// Formats a data block value as String.
/// The result is formatted as: `{header_name}:{block_value}`
/// # Arguments
/// * `headers`: data block headers
/// * `value`: value to be formatted
#[inline]
pub fn format_data_block_value_str(
headers: &DataBlockHeadersSlice,
value: &'data_block DataBlockValue,
) -> String {
format!("{}:{}", headers[value.column_index], value.value)
}
/// Formats a value combination as String.
/// The result is formatted as:
/// `{header_name}:{block_value};{header_name}:{block_value}...`
/// # Arguments
/// * `headers`: data block headers
/// * `aggregate`: combinations to be formatted
#[inline]
pub fn format_aggregate_str(
headers: &DataBlockHeadersSlice,
aggregate: &ValueCombinationSlice<'data_block>,
) -> String {
let mut str = String::default();
for comb in aggregate {
if !str.is_empty() {
str += ";";
}
str += Aggregator::format_data_block_value_str(headers, comb).as_str();
}
str
}
/// Aggregates the data block and returns the aggregated data back
/// # Arguments
/// * `reporting_length` - Calculate combinations from 1 up to `reporting_length`
/// * `sensitivity_threshold` - Sensitivity threshold to filter record attributes
/// (0 means no suppression)
/// * `progress_reporter` - Will be used to report the processing
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
pub fn aggregate<T>(
&mut self,
reporting_length: usize,
sensitivity_threshold: usize,
progress_reporter: &mut Option<T>,
) -> AggregatedData<'data_block>
where
T: ReportProgress,
{
measure_time!(
|| {
let mut aggregates_count: AggregatesCountMap = AggregatesCountMap::default();
let max_len = if reporting_length == 0 {
self.data_block.headers.len()
} else {
min(reporting_length, self.data_block.headers.len())
};
let length_range = (1..=max_len).collect::<Vec<usize>>();
let mut record_attrs_selector = RecordAttrsSelector::new(
&length_range,
sensitivity_threshold,
DataBlock::calc_attr_rows(&self.data_block.records),
);
let mut selected_combs_count = 0_u64;
let mut all_combs_count = 0_u64;
let mut records_sensitivity: RecordsSensitivity = RecordsSensitivity::default();
let total_n_records = self.data_block.records.len();
let total_n_records_f64 = total_n_records as f64;
records_sensitivity.resize(total_n_records, 0);
info!(
"aggregating data with reporting length = {} and sensitivity_threshold = {}",
max_len, sensitivity_threshold
);
for (record_index, record) in self.data_block.records.iter().enumerate() {
let selected_attrs = record_attrs_selector.select_from_record(&record.values);
all_combs_count +=
calc_n_combinations_range(record.values.len(), &length_range);
Aggregator::update_aggregate_progress(
progress_reporter,
record_index,
total_n_records_f64,
);
for l in &length_range {
for c in selected_attrs.iter().combinations(*l) {
let curr_count = aggregates_count
.entry(
c.iter()
.sorted_by_key(|k| {
Aggregator::format_data_block_value_str(
&self.data_block.headers,
k,
)
})
.cloned()
.copied()
.collect_vec(),
)
.or_insert_with(AggregatedCount::default);
curr_count.count += 1;
curr_count.contained_in_records.insert(record_index);
records_sensitivity[record_index] += 1;
selected_combs_count += 1;
}
}
}
Aggregator::update_aggregate_progress(
progress_reporter,
total_n_records,
total_n_records_f64,
);
info!(
"data aggregated resulting in {} distinct combinations...",
aggregates_count.len()
);
info!(
"suppression ratio of aggregates is {:.2}%",
(1.0 - (selected_combs_count as f64 / all_combs_count as f64)) * 100.0
);
AggregatedData {
aggregates_count,
records_sensitivity,
}
},
(self.durations.aggregate)
)
}
/// Round the aggregated counts down to the nearest multiple of resolution
/// # Arguments:
/// * `aggregates_count` - Counts to be rounded in place
/// * `resolution` - Reporting resolution used for data synthesis
pub fn protect_aggregates_count(
aggregates_count: &mut AggregatesCountMap<'data_block>,
resolution: usize,
) {
info!(
"protecting aggregates counts with resolution {}",
resolution
);
for count in aggregates_count.values_mut() {
count.count = uround_down(count.count as f64, resolution as f64);
}
// remove 0 counts from response
aggregates_count.retain(|_, count| count.count > 0);
}
/// Calculates the number of rare combinations grouped by combination length
/// # Arguments:
/// * `aggregates_count` - Calculated aggregates count map
/// * `resolution` - Reporting resolution used for data synthesis
pub fn calc_rare_combinations_count_by_len(
&mut self,
aggregates_count: &AggregatesCountMap<'data_block>,
resolution: usize,
) -> AggregatedCountByLenMap {
info!(
"calculating rare combinations counts by length with resolution {}",
resolution
);
measure_time!(
|| {
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
for (agg, count) in aggregates_count.iter() {
if count.count < resolution {
let curr_count = result.entry(agg.len()).or_insert(0);
*curr_count += 1;
}
}
result
},
(self.durations.calc_rare_combinations_count_by_len)
)
}
/// Calculates the number of combinations grouped by combination length
/// # Arguments:
/// * `aggregates_count` - Calculated aggregates count map
pub fn calc_combinations_count_by_len(
&mut self,
aggregates_count: &AggregatesCountMap<'data_block>,
) -> AggregatedCountByLenMap {
info!("calculating combination counts by length");
measure_time!(
|| {
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
for agg in aggregates_count.keys() {
let curr_count = result.entry(agg.len()).or_insert(0);
*curr_count += 1;
}
result
},
(self.durations.calc_combinations_count_by_len)
)
}
/// Calculates the sum of all combination counts grouped by combination length
/// # Arguments:
/// * `aggregates_count` - Calculated aggregates count map
pub fn calc_combinations_sum_by_len(
&mut self,
aggregates_count: &AggregatesCountMap<'data_block>,
) -> AggregatedCountByLenMap {
info!("calculating combination counts sums by length");
measure_time!(
|| {
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
for (agg, count) in aggregates_count.iter() {
let curr_sum = result.entry(agg.len()).or_insert(0);
*curr_sum += count.count;
}
result
},
(self.durations.calc_combinations_sum_by_len)
)
}
/// Calculates the privacy risk related with data block and the generated
/// aggregates counts
/// # Arguments:
/// * `aggregates_count` - Calculated aggregates count map
/// * `resolution` - Reporting resolution used for data synthesis
pub fn calc_privacy_risk(
&mut self,
aggregates_count: &AggregatesCountMap<'data_block>,
resolution: usize,
) -> PrivacyRiskSummary {
info!("calculating privacy risk...");
measure_time!(
|| {
let mut records_with_unique_combinations = RecordsSet::default();
let mut records_with_rare_combinations = RecordsSet::default();
let mut unique_combinations_count: usize = 0;
let mut rare_combinations_count: usize = 0;
let total_number_of_records = self.data_block.records.len();
let total_number_of_combinations = aggregates_count.len();
for count in aggregates_count.values() {
if count.count == 1 {
unique_combinations_count += 1;
count.contained_in_records.iter().for_each(|record_index| {
records_with_unique_combinations.insert(*record_index);
});
}
if count.count < resolution {
rare_combinations_count += 1;
count.contained_in_records.iter().for_each(|record_index| {
records_with_rare_combinations.insert(*record_index);
});
}
}
PrivacyRiskSummary {
total_number_of_records,
total_number_of_combinations,
records_with_unique_combinations_count: records_with_unique_combinations.len(),
records_with_rare_combinations_count: records_with_rare_combinations.len(),
unique_combinations_count,
rare_combinations_count,
records_with_unique_combinations_proportion: (records_with_unique_combinations
.len()
as f64)
/ (total_number_of_records as f64),
records_with_rare_combinations_proportion: (records_with_rare_combinations.len()
as f64)
/ (total_number_of_records as f64),
unique_combinations_proportion: (unique_combinations_count as f64)
/ (total_number_of_combinations as f64),
rare_combinations_proportion: (rare_combinations_count as f64)
/ (total_number_of_combinations as f64),
}
},
(self.durations.calc_privacy_risk)
)
}
/// Writes the aggregates counts to the file system in a csv/tsv like format
/// # Arguments:
/// * `aggregates_count` - Calculated aggregates count map
/// * `aggregates_path` - File path to be written
/// * `aggregates_delimiter` - Delimiter to use when writing to `aggregates_path`
/// * `resolution` - Reporting resolution used for data synthesis
/// * `protected` - Whether or not the counts were protected before calling this
pub fn write_aggregates_count(
&mut self,
aggregates_count: &AggregatesCountMap<'data_block>,
aggregates_path: &str,
aggregates_delimiter: char,
resolution: usize,
protected: bool,
) -> Result<(), Error> {
info!("writing file {}", aggregates_path);
measure_time!(
|| {
let mut file = std::fs::File::create(aggregates_path)?;
file.write_all(
format!(
"selections{}{}\n",
aggregates_delimiter,
if protected {
"protected_count"
} else {
"count"
}
)
.as_bytes(),
)?;
file.write_all(
format!(
"selections{}{}\n",
aggregates_delimiter,
uround_down(self.data_block.records.len() as f64, resolution as f64)
)
.as_bytes(),
)?;
for aggregate in aggregates_count.keys() {
file.write_all(
format!(
"{}{}{}\n",
Aggregator::format_aggregate_str(&self.data_block.headers, aggregate),
aggregates_delimiter,
aggregates_count[aggregate].count
)
.as_bytes(),
)?
}
Ok(())
},
(self.durations.write_aggregates_count)
)
}
/// Writes the records sensitivity to the file system in a csv/tsv like format
/// # Arguments:
/// * `records_sensitivity` - Calculated records sensitivity
/// * `records_sensitivity_path` - File path to be written
pub fn write_records_sensitivity(
&mut self,
records_sensitivity: &RecordsSensitivitySlice,
records_sensitivity_path: &str,
) -> Result<(), Error> {
info!("writing file {}", records_sensitivity_path);
measure_time!(
|| {
let mut file = std::fs::File::create(records_sensitivity_path)?;
file.write_all("record_index\trecord_sensitivity\n".as_bytes())?;
for (i, sensitivity) in records_sensitivity.iter().enumerate() {
file.write_all(format!("{}\t{}\n", i, sensitivity).as_bytes())?
}
Ok(())
},
(self.durations.write_records_sensitivity)
)
}
#[inline]
fn update_aggregate_progress<T>(
progress_reporter: &mut Option<T>,
n_processed: usize,
total: f64,
) where
T: ReportProgress,
{
if let Some(r) = progress_reporter {
r.report(calc_percentage(n_processed, total));
}
}
}
impl<'data_block> Drop for Aggregator<'data_block> {
fn drop(&mut self) {
trace!("aggregator durations: {:#?}", self.durations);
}
}
mod rows_aggregator;
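// --- Editor's usage sketch (hypothetical, not part of the original change set) ---
// Drives the aggregator end-to-end for an already-built `DataBlock`: aggregate, protect
// the counts, then persist them. The reporting length (3), the resolution (10), the
// output path and the function name are illustrative assumptions only.
#[allow(dead_code)]
fn aggregate_and_write_sketch<T: ReportProgress>(
    data_block: &DataBlock,
    progress_reporter: &mut Option<T>,
) -> Result<(), std::io::Error> {
    let mut aggregator = Aggregator::new(data_block);
    // combinations of length 1..=3, sensitivity_threshold = 0 (no attribute suppression)
    let mut aggregated_data = aggregator.aggregate(3, 0, progress_reporter);
    // round every count down to the nearest multiple of the reporting resolution
    Aggregator::protect_aggregates_count(&mut aggregated_data.aggregates_count, 10);
    // write the protected counts as a TSV file
    aggregator.write_aggregates_count(
        &aggregated_data.aggregates_count,
        "protected_aggregates.tsv",
        '\t',
        10,
        true,
    )
}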

View file

@@ -0,0 +1,144 @@
use super::typedefs::{AggregatesCountMap, RecordsSet};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
/// Represents the privacy risk information related to a data block
#[cfg_attr(feature = "pyo3", pyclass)]
#[derive(Debug)]
pub struct PrivacyRiskSummary {
/// Total number of records on the data block
pub total_number_of_records: usize,
/// Total number of combinations aggregated (up to reporting length)
pub total_number_of_combinations: usize,
/// Number of records with unique combinations
pub records_with_unique_combinations_count: usize,
/// Number of records with rare combinations (combination count < resolution)
pub records_with_rare_combinations_count: usize,
/// Number of unique combinations
pub unique_combinations_count: usize,
/// Number of rare combinations
pub rare_combinations_count: usize,
/// Proportion of records containing unique combinations
pub records_with_unique_combinations_proportion: f64,
/// Proportion of records containing rare combinations
pub records_with_rare_combinations_proportion: f64,
/// Proportion of unique combinations
pub unique_combinations_proportion: f64,
/// Proportion of rare combinations
pub rare_combinations_proportion: f64,
}
#[cfg(feature = "pyo3")]
#[cfg_attr(feature = "pyo3", pymethods)]
impl PrivacyRiskSummary {
/// Total number of records on the data block
#[getter]
fn total_number_of_records(&self) -> usize {
self.total_number_of_records
}
/// Total number of combinations aggregated (up to reporting length)
#[getter]
fn total_number_of_combinations(&self) -> usize {
self.total_number_of_combinations
}
/// Number of records with unique combinations
#[getter]
fn records_with_unique_combinations_count(&self) -> usize {
self.records_with_unique_combinations_count
}
/// Number of records with rare combinations (combination count < resolution)
#[getter]
fn records_with_rare_combinations_count(&self) -> usize {
self.records_with_rare_combinations_count
}
/// Number of unique combinations
#[getter]
fn unique_combinations_count(&self) -> usize {
self.unique_combinations_count
}
/// Number of rare combinations
#[getter]
fn rare_combinations_count(&self) -> usize {
self.rare_combinations_count
}
/// Proportion of records containing unique combinations
#[getter]
fn records_with_unique_combinations_proportion(&self) -> f64 {
self.records_with_unique_combinations_proportion
}
/// Proportion of records containing rare combinations
#[getter]
fn records_with_rare_combinations_proportion(&self) -> f64 {
self.records_with_rare_combinations_proportion
}
/// Proportion of unique combinations
#[getter]
fn unique_combinations_proportion(&self) -> f64 {
self.unique_combinations_proportion
}
/// Proportion of rare combinations
#[getter]
fn rare_combinations_proportion(&self) -> f64 {
self.rare_combinations_proportion
}
}
impl PrivacyRiskSummary {
#[inline]
/// Calculates and returns the privacy risk related to the aggregates counts
/// # Arguments:
/// * `total_number_of_records` - Total number of records on the data block
/// * `aggregates_count` - Aggregates counts to compute the privacy risk for
/// * `resolution` - Reporting resolution used for data synthesis
pub fn from_aggregates_count(
total_number_of_records: usize,
aggregates_count: &AggregatesCountMap,
resolution: usize,
) -> PrivacyRiskSummary {
let mut records_with_unique_combinations = RecordsSet::default();
let mut records_with_rare_combinations = RecordsSet::default();
let mut unique_combinations_count: usize = 0;
let mut rare_combinations_count: usize = 0;
let total_number_of_combinations = aggregates_count.len();
for count in aggregates_count.values() {
if count.count == 1 {
unique_combinations_count += 1;
records_with_unique_combinations.extend(&count.contained_in_records);
}
if count.count < resolution {
rare_combinations_count += 1;
records_with_rare_combinations.extend(&count.contained_in_records);
}
}
PrivacyRiskSummary {
total_number_of_records,
total_number_of_combinations,
records_with_unique_combinations_count: records_with_unique_combinations.len(),
records_with_rare_combinations_count: records_with_rare_combinations.len(),
unique_combinations_count,
rare_combinations_count,
records_with_unique_combinations_proportion: (records_with_unique_combinations.len()
as f64)
/ (total_number_of_records as f64),
records_with_rare_combinations_proportion: (records_with_rare_combinations.len()
as f64)
/ (total_number_of_records as f64),
unique_combinations_proportion: (unique_combinations_count as f64)
/ (total_number_of_combinations as f64),
rare_combinations_proportion: (rare_combinations_count as f64)
/ (total_number_of_combinations as f64),
}
}
}
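// --- Editor's usage sketch (hypothetical, not part of the original change set) ---
// Computes the privacy risk for already-aggregated data; the resolution of 10 is an
// illustrative assumption (combinations with count < 10 are treated as rare).
#[allow(dead_code)]
fn privacy_risk_sketch(
    total_number_of_records: usize,
    aggregates_count: &AggregatesCountMap,
) -> PrivacyRiskSummary {
    PrivacyRiskSummary::from_aggregates_count(total_number_of_records, aggregates_count, 10)
}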

View file

@@ -1,5 +1,5 @@
use fnv::FnvHashMap;
use itertools::Itertools;
use std::sync::Arc;
use crate::{
data_block::{typedefs::AttributeRowsMap, value::DataBlockValue},
@@ -13,19 +13,19 @@ use crate::{
/// Sensitivity in this context means the number of combinations that
/// can be generated based on the number of non-empty values on the record.
/// So attributes will be suppressed from the record until record sensitivity <= `sensitivity_threshold`
pub struct RecordAttrsSelector<'length_range, 'data_block> {
pub struct RecordAttrsSelector<'length_range> {
/// Range to compute the number of combinations [1...reporting_length]
pub length_range: &'length_range [usize],
/// Cache for how many values should be suppressed based on the number
/// of non-empty attributes on the record
cache: FnvHashMap<usize, usize>,
/// Range to compute the number of combinations [1...reporting_length]
length_range: &'length_range [usize],
/// Maximum sensitivity allowed for each record
sensitivity_threshold: usize,
/// Map with a data block value as key and all the attribute row indexes where it occurs as value
attr_rows_map: AttributeRowsMap<'data_block>,
attr_rows_map: Arc<AttributeRowsMap>,
}
impl<'length_range, 'data_block> RecordAttrsSelector<'length_range, 'data_block> {
impl<'length_range> RecordAttrsSelector<'length_range> {
/// Returns a new RecordAttrsSelector
/// # Arguments
/// * `length_range` - Range to compute the number of combinations [1...reporting_length]
@@ -36,8 +36,8 @@ impl<'length_range, 'data_block> RecordAttrsSelector<'length_range, 'data_block>
pub fn new(
length_range: &'length_range [usize],
sensitivity_threshold: usize,
attr_rows_map: AttributeRowsMap<'data_block>,
) -> RecordAttrsSelector<'length_range, 'data_block> {
attr_rows_map: Arc<AttributeRowsMap>,
) -> RecordAttrsSelector<'length_range> {
RecordAttrsSelector {
cache: FnvHashMap::default(),
length_range,
@@ -53,30 +53,30 @@ impl<'length_range, 'data_block> RecordAttrsSelector<'length_range, 'data_block>
/// all records, higher the chance for it **not** to be suppressed
///
/// # Arguments
/// * `record`: record to select attributes from
/// * `record` - record to select attributes from
pub fn select_from_record(
&mut self,
record: &'data_block [DataBlockValue],
) -> Vec<&'data_block DataBlockValue> {
record: &[Arc<DataBlockValue>],
) -> Vec<Arc<DataBlockValue>> {
let n_attributes = record.len();
let selected_attrs_count = n_attributes - self.get_suppressed_count(n_attributes);
if selected_attrs_count < n_attributes {
let mut attr_count_map: AttributeCountMap = record
.iter()
.map(|r| (r, self.attr_rows_map.get(r).unwrap().len()))
.map(|r| (r.clone(), self.attr_rows_map.get(r).unwrap().len()))
.collect();
let mut res: Vec<&DataBlockValue> = Vec::default();
let mut res: Vec<Arc<DataBlockValue>> = Vec::default();
for _ in 0..selected_attrs_count {
if let Some(sample) = SynthesizerContext::sample_from_attr_counts(&attr_count_map) {
attr_count_map.remove(&sample);
res.push(sample);
attr_count_map.remove(sample);
}
}
res
} else {
record.iter().collect_vec()
record.to_vec()
}
}
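// Editor's note (worked example, not part of the original change set): with 4 non-empty
// attributes and length_range = [1, 2], a record yields C(4,1) + C(4,2) = 4 + 6 = 10
// combinations. With sensitivity_threshold = 6, attributes are suppressed until the count
// fits: dropping one attribute leaves C(3,1) + C(3,2) = 3 + 3 = 6 <= 6.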

View file

@@ -0,0 +1,247 @@
use super::typedefs::{RecordsAnalysisByLenMap, RecordsByLenMap};
use itertools::Itertools;
use std::io::{Error, Write};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use crate::utils::math::{calc_percentage, uround_down};
#[cfg_attr(feature = "pyo3", pyclass)]
#[derive(Clone)]
/// Analysis information related to a single record
pub struct RecordsAnalysis {
/// Number of records containing unique combinations
pub unique_combinations_records_count: usize,
/// Percentage of records containing unique combinations
pub unique_combinations_records_percentage: f64,
/// Number of records containing rare combinations
/// (unique is not taken into account)
pub rare_combinations_records_count: usize,
/// Percentage of records containing rare combinations
/// (unique is not taken into account)
pub rare_combinations_records_percentage: f64,
/// Count of unique + rare
pub risky_combinations_records_count: usize,
/// Percentage of unique + rare
pub risky_combinations_records_percentage: f64,
}
impl RecordsAnalysis {
#[inline]
/// Creates a new RecordsAnalysis with default values
pub fn default() -> RecordsAnalysis {
RecordsAnalysis {
unique_combinations_records_count: 0,
unique_combinations_records_percentage: 0.0,
rare_combinations_records_count: 0,
rare_combinations_records_percentage: 0.0,
risky_combinations_records_count: 0,
risky_combinations_records_percentage: 0.0,
}
}
}
#[cfg(feature = "pyo3")]
#[cfg_attr(feature = "pyo3", pymethods)]
impl RecordsAnalysis {
#[getter]
/// Number of records containing unique combinations
fn unique_combinations_records_count(&self) -> usize {
self.unique_combinations_records_count
}
#[getter]
/// Percentage of records containing unique combinations
fn unique_combinations_records_percentage(&self) -> f64 {
self.unique_combinations_records_percentage
}
#[getter]
/// Number of records containing rare combinations
/// (unique is not taken into account)
fn rare_combinations_records_count(&self) -> usize {
self.rare_combinations_records_count
}
#[getter]
/// Percentage of records containing rare combinations
/// (unique is not taken into account)
fn rare_combinations_records_percentage(&self) -> f64 {
self.rare_combinations_records_percentage
}
#[getter]
/// Count of unique + rare
fn risky_combinations_records_count(&self) -> usize {
self.risky_combinations_records_count
}
#[getter]
/// Percentage of unique + rare
fn risky_combinations_records_percentage(&self) -> f64 {
self.risky_combinations_records_percentage
}
}
#[cfg_attr(feature = "pyo3", pyclass)]
/// Stores the records analysis for all records grouped by
/// combination length
pub struct RecordsAnalysisData {
/// Map of records analysis grouped by combination len
pub records_analysis_by_len: RecordsAnalysisByLenMap,
}
impl RecordsAnalysisData {
/// Computes the record analysis from the arguments.
/// # Arguments
/// * `unique_records_by_len` - unique records grouped by length
/// * `rare_records_by_len` - rare records grouped by length
/// * `total_n_records` - Total number of records on the data block
/// * `reporting_length` - Reporting length used for the data aggregation
/// * `resolution` - Reporting resolution used for data synthesis
/// * `protect` - Whether or not the counts should be rounded to the
/// nearest smallest multiple of resolution
#[inline]
pub fn from_unique_rare_combinations_records_by_len(
unique_records_by_len: &RecordsByLenMap,
rare_records_by_len: &RecordsByLenMap,
total_n_records: usize,
reporting_length: usize,
resolution: usize,
protect: bool,
) -> RecordsAnalysisData {
let total_n_records_f64 = if protect {
uround_down(total_n_records as f64, resolution as f64)
} else {
total_n_records
} as f64;
let records_analysis_by_len: RecordsAnalysisByLenMap = (1..=reporting_length)
.map(|l| {
let mut ra = RecordsAnalysis::default();
ra.unique_combinations_records_count = unique_records_by_len
.get(&l)
.map_or(0, |records| records.len());
ra.rare_combinations_records_count = rare_records_by_len
.get(&l)
.map_or(0, |records| records.len());
if protect {
ra.unique_combinations_records_count = uround_down(
ra.unique_combinations_records_count as f64,
resolution as f64,
);
ra.rare_combinations_records_count =
uround_down(ra.rare_combinations_records_count as f64, resolution as f64)
}
ra.unique_combinations_records_percentage = calc_percentage(
ra.unique_combinations_records_count as f64,
total_n_records_f64,
);
ra.rare_combinations_records_percentage = calc_percentage(
ra.rare_combinations_records_count as f64,
total_n_records_f64,
);
ra.risky_combinations_records_count =
ra.unique_combinations_records_count + ra.rare_combinations_records_count;
ra.risky_combinations_records_percentage = calc_percentage(
ra.risky_combinations_records_count as f64,
total_n_records_f64,
);
(l, ra)
})
.collect();
RecordsAnalysisData {
records_analysis_by_len,
}
}
}
#[cfg_attr(feature = "pyo3", pymethods)]
impl RecordsAnalysisData {
/// Returns the records analysis map grouped by length.
/// This method will clone the data, so it's recommended to store its result
/// in a local variable to avoid calling it multiple times
pub fn get_records_analysis_by_len(&self) -> RecordsAnalysisByLenMap {
self.records_analysis_by_len.clone()
}
/// Returns the total number of records containing unique combinations
/// for all lengths
pub fn get_total_unique(&self) -> usize {
self.records_analysis_by_len
.values()
.map(|v| v.unique_combinations_records_count)
.sum()
}
/// Returns the total number of records containing rare combinations
/// for all lengths
pub fn get_total_rare(&self) -> usize {
self.records_analysis_by_len
.values()
.map(|v| v.rare_combinations_records_count)
.sum()
}
/// Returns the total number of records containing risky (unique + rare)
/// combinations for all lengths
pub fn get_total_risky(&self) -> usize {
self.records_analysis_by_len
.values()
.map(|v| v.risky_combinations_records_count)
.sum()
}
/// Writes the records analysis to the file system in a csv/tsv like format
/// # Arguments:
/// * `records_analysis_path` - File path to be written
/// * `records_analysis_delimiter` - Delimiter to use when writing to `records_analysis_path`
pub fn write_records_analysis(
&mut self,
records_analysis_path: &str,
records_analysis_delimiter: char,
) -> Result<(), Error> {
let mut file = std::fs::File::create(records_analysis_path)?;
file.write_all(
format!(
"combo_length{}sen_rare{}sen_rare_pct{}sen_unique{}sen_unique_pct{}sen_risky{}sen_risky_pct\n",
records_analysis_delimiter,
records_analysis_delimiter,
records_analysis_delimiter,
records_analysis_delimiter,
records_analysis_delimiter,
records_analysis_delimiter
)
.as_bytes(),
)?;
for l in self.records_analysis_by_len.keys().sorted() {
let ra = &self.records_analysis_by_len[l];
file.write_all(
format!(
"{}{}{}{}{}{}{}{}{}{}{}{}{}\n",
l,
records_analysis_delimiter,
ra.rare_combinations_records_count,
records_analysis_delimiter,
ra.rare_combinations_records_percentage,
records_analysis_delimiter,
ra.unique_combinations_records_count,
records_analysis_delimiter,
ra.unique_combinations_records_percentage,
records_analysis_delimiter,
ra.risky_combinations_records_count,
records_analysis_delimiter,
ra.risky_combinations_records_percentage,
)
.as_bytes(),
)?
}
Ok(())
}
}
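// --- Editor's usage sketch (hypothetical, not part of the original change set) ---
// Summarizes record-level risk once the unique and rare record sets have been grouped
// by combination length elsewhere; the total record count (1000), reporting length (3)
// and resolution (10) are illustrative assumptions.
#[allow(dead_code)]
fn records_analysis_sketch(
    unique_records_by_len: &RecordsByLenMap,
    rare_records_by_len: &RecordsByLenMap,
) -> usize {
    let analysis = RecordsAnalysisData::from_unique_rare_combinations_records_by_len(
        unique_records_by_len,
        rare_records_by_len,
        1000, // total number of records
        3,    // reporting length
        10,   // resolution
        true, // round counts down to the nearest multiple of the resolution
    );
    // total number of records containing unique or rare combinations
    analysis.get_total_risky()
}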

View file

@@ -0,0 +1,195 @@
use super::{
record_attrs_selector::RecordAttrsSelector,
typedefs::{AggregatesCountMap, EnumeratedDataBlockRecords, RecordsSensitivity},
value_combination::ValueCombination,
AggregatedCount,
};
use itertools::Itertools;
use log::info;
use std::sync::Arc;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
#[cfg(feature = "rayon")]
use std::sync::Mutex;
use crate::{
data_block::block::DataBlock,
utils::{
math::calc_n_combinations_range,
reporting::{ReportProgress, SendableProgressReporter, SendableProgressReporterRef},
},
};
pub struct RowsAggregatorResult {
pub all_combs_count: u64,
pub selected_combs_count: u64,
pub aggregates_count: AggregatesCountMap,
pub records_sensitivity: RecordsSensitivity,
}
impl RowsAggregatorResult {
#[inline]
pub fn new(total_n_records: usize) -> RowsAggregatorResult {
let mut records_sensitivity = RecordsSensitivity::default();
records_sensitivity.resize(total_n_records, 0);
RowsAggregatorResult {
all_combs_count: 0,
selected_combs_count: 0,
aggregates_count: AggregatesCountMap::default(),
records_sensitivity,
}
}
}
pub struct RowsAggregator<'length_range> {
data_block: Arc<DataBlock>,
enumerated_records: EnumeratedDataBlockRecords,
record_attrs_selector: RecordAttrsSelector<'length_range>,
}
impl<'length_range> RowsAggregator<'length_range> {
#[inline]
pub fn new(
data_block: Arc<DataBlock>,
enumerated_records: EnumeratedDataBlockRecords,
record_attrs_selector: RecordAttrsSelector<'length_range>,
) -> RowsAggregator {
RowsAggregator {
data_block,
enumerated_records,
record_attrs_selector,
}
}
#[cfg(feature = "rayon")]
#[inline]
pub fn aggregate_all<T>(
total_n_records: usize,
rows_aggregators: &mut Vec<RowsAggregator>,
progress_reporter: &mut Option<T>,
) -> RowsAggregatorResult
where
T: ReportProgress,
{
let sendable_pr =
Arc::new(Mutex::new(progress_reporter.as_mut().map(|r| {
SendableProgressReporter::new(total_n_records as f64, 1.0, r)
})));
RowsAggregator::join_partial_results(
total_n_records,
rows_aggregators
.par_iter_mut()
.map(|ra| ra.aggregate_rows(&mut sendable_pr.clone()))
.collect(),
)
}
#[cfg(not(feature = "rayon"))]
#[inline]
pub fn aggregate_all<T>(
total_n_records: usize,
rows_aggregators: &mut Vec<RowsAggregator>,
progress_reporter: &mut Option<T>,
) -> RowsAggregatorResult
where
T: ReportProgress,
{
let mut sendable_pr = progress_reporter
.as_mut()
.map(|r| SendableProgressReporter::new(total_n_records as f64, 1.0, r));
RowsAggregator::join_partial_results(
total_n_records,
rows_aggregators
.iter_mut()
.map(|ra| ra.aggregate_rows(&mut sendable_pr))
.collect(),
)
}
#[inline]
fn join_partial_results(
total_n_records: usize,
mut partial_results: Vec<RowsAggregatorResult>,
) -> RowsAggregatorResult {
info!("joining aggregated partial results...");
// take last element and the initial result
// or use the default one if there are no results
let mut final_result = partial_results
.pop()
.unwrap_or_else(|| RowsAggregatorResult::new(total_n_records));
// use drain instead of fold, so we do not duplicate memory
for mut partial_result in partial_results.drain(..) {
// join counts
final_result.all_combs_count += partial_result.all_combs_count;
final_result.selected_combs_count += partial_result.selected_combs_count;
// join aggregated counts
for (comb, value) in partial_result.aggregates_count.drain() {
let final_count = final_result
.aggregates_count
.entry(comb)
.or_insert_with(AggregatedCount::default);
final_count.count += value.count;
final_count
.contained_in_records
.extend(value.contained_in_records);
}
// join records sensitivity
for (i, sensitivity) in partial_result.records_sensitivity.drain(..).enumerate() {
final_result.records_sensitivity[i] += sensitivity;
}
}
final_result
}
#[inline]
fn aggregate_rows<T>(
&mut self,
progress_reporter: &mut SendableProgressReporterRef<T>,
) -> RowsAggregatorResult
where
T: ReportProgress,
{
let mut result = RowsAggregatorResult::new(self.data_block.records.len());
for (record_index, record) in self.enumerated_records.iter() {
let mut selected_attrs = self
.record_attrs_selector
.select_from_record(&record.values);
// sort the attributes here, so combinations will be already sorted
// and we do not need to sort entry by entry on the loop below
selected_attrs.sort_by_key(|k| k.format_str_using_headers(&self.data_block.headers));
result.all_combs_count += calc_n_combinations_range(
record.values.len(),
self.record_attrs_selector.length_range,
);
for l in self.record_attrs_selector.length_range {
for mut c in selected_attrs.iter().combinations(*l) {
let current_count = result
.aggregates_count
.entry(ValueCombination::new(
c.drain(..).map(|k| (*k).clone()).collect(),
))
.or_insert_with(AggregatedCount::default);
current_count.count += 1;
current_count.contained_in_records.insert(*record_index);
result.records_sensitivity[*record_index] += 1;
result.selected_combs_count += 1;
}
}
SendableProgressReporter::update_progress(progress_reporter, 1.0);
}
result
}
}
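// --- Editor's usage sketch (hypothetical, not part of the original change set) ---
// Fans the records out over a fixed number of chunks so that each RowsAggregator can
// aggregate its share (in parallel when the "rayon" feature is enabled) before the
// partial results are joined. The chunking strategy and parameter values are illustrative,
// and it assumes `DataBlock::records` is a `Vec<Arc<DataBlockRecord>>` after the Arc
// refactor, with `attr_rows_map` built by the caller.
#[allow(dead_code)]
fn aggregate_in_chunks_sketch<T: ReportProgress>(
    data_block: Arc<DataBlock>,
    attr_rows_map: Arc<crate::data_block::typedefs::AttributeRowsMap>,
    length_range: &[usize],
    n_chunks: usize,
    progress_reporter: &mut Option<T>,
) -> RowsAggregatorResult {
    let total_n_records = data_block.records.len();
    let chunk_size = (total_n_records / n_chunks).max(1);
    let mut rows_aggregators: Vec<_> = data_block
        .records
        .iter()
        .cloned()
        .enumerate()
        .collect::<Vec<_>>()
        .chunks(chunk_size)
        .map(|chunk| {
            RowsAggregator::new(
                data_block.clone(),
                chunk.to_vec(),
                // sensitivity_threshold = 0 means no attribute suppression
                RecordAttrsSelector::new(length_range, 0, attr_rows_map.clone()),
            )
        })
        .collect();
    RowsAggregator::aggregate_all(total_n_records, &mut rows_aggregators, progress_reporter)
}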

View file

@@ -1,26 +1,36 @@
use super::AggregatedCount;
use super::{
data_aggregator::AggregatedCount, records_analysis_data::RecordsAnalysis,
value_combination::ValueCombination,
};
use fnv::{FnvHashMap, FnvHashSet};
use std::sync::Arc;
use crate::data_block::value::DataBlockValue;
use crate::data_block::record::DataBlockRecord;
/// Set of records where the key is the record index starting in 0
pub type RecordsSet = FnvHashSet<usize>;
/// Vector of data block values representing a value combination (sorted by `{header_name}:{block_value}`)
pub type ValueCombination<'data_block_value> = Vec<&'data_block_value DataBlockValue>;
/// Slice of ValueCombination
pub type ValueCombinationSlice<'data_block_value> = [&'data_block_value DataBlockValue];
/// Maps a value combination to its aggregated count
pub type AggregatesCountMap<'data_block_value> =
FnvHashMap<ValueCombination<'data_block_value>, AggregatedCount>;
pub type AggregatesCountMap = FnvHashMap<ValueCombination, AggregatedCount>;
/// Maps a value combination represented as a string to its aggregated count
pub type AggregatesCountStringMap = FnvHashMap<String, AggregatedCount>;
/// Maps a length (1,2,3... up to reporting length) to a determined count
pub type AggregatedCountByLenMap = FnvHashMap<usize, usize>;
/// Maps a length (1,2,3... up to reporting length) to a record set
pub type RecordsByLenMap = FnvHashMap<usize, RecordsSet>;
/// A vector of sensitivities for each record (the vector index is the record index)
pub type RecordsSensitivity = Vec<usize>;
/// Slice of RecordsSensitivity
pub type RecordsSensitivitySlice = [usize];
/// Vector of tuples:
/// (index of the original record, reference to the original record)
pub type EnumeratedDataBlockRecords = Vec<(usize, Arc<DataBlockRecord>)>;
/// Map of records analysis grouped by combination len
pub type RecordsAnalysisByLenMap = FnvHashMap<usize, RecordsAnalysis>;

View file

@@ -0,0 +1,152 @@
use serde::{
de::{self, Visitor},
Deserialize, Serialize,
};
use std::{
fmt::Display,
marker::PhantomData,
ops::{Deref, DerefMut},
str::FromStr,
sync::Arc,
};
use crate::data_block::{
typedefs::DataBlockHeadersSlice,
value::{DataBlockValue, ParseDataBlockValueError},
};
const COMBINATIONS_DELIMITER: char = ';';
#[derive(Eq, PartialEq, Hash)]
/// Wraps a vector of data block values representing a value
/// combination (sorted by `{header_name}:{block_value}`)
pub struct ValueCombination {
combination: Vec<Arc<DataBlockValue>>,
}
impl ValueCombination {
#[inline]
/// Creates a new ValueCombination with default values
pub fn default() -> ValueCombination {
ValueCombination::new(Vec::default())
}
#[inline]
/// Creates a new ValueCombination
/// # Arguments
/// * `combination` - raw vector of value combinations
/// sorted by `{header_name}:{block_value}`
pub fn new(combination: Vec<Arc<DataBlockValue>>) -> ValueCombination {
ValueCombination { combination }
}
/// Formats a value combination as String using the headers.
/// The result is formatted as:
/// `{header_name}:{block_value};{header_name}:{block_value}...`
/// # Arguments
/// * `headers` - Data block headers
/// * `combination_delimiter` - Delimiter used to join combinations
#[inline]
pub fn format_str_using_headers(
&self,
headers: &DataBlockHeadersSlice,
combination_delimiter: &str,
) -> String {
let mut str = String::default();
if let Some(comb) = self.combination.get(0) {
str.push_str(&comb.format_str_using_headers(headers));
}
for comb in self.combination.iter().skip(1) {
str += combination_delimiter;
str.push_str(&comb.format_str_using_headers(headers));
}
str
}
}
impl Display for ValueCombination {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(comb) = self.combination.get(0) {
write!(f, "{}", comb)?;
}
for comb in self.combination.iter().skip(1) {
write!(f, "{}{}", COMBINATIONS_DELIMITER, comb)?;
}
Ok(())
}
}
impl FromStr for ValueCombination {
type Err = ParseDataBlockValueError;
/// Creates a new ValueCombination by parsing `str_value`
fn from_str(str_value: &str) -> Result<Self, Self::Err> {
Ok(ValueCombination::new(
str_value
.split(COMBINATIONS_DELIMITER)
.map(|v| Ok(Arc::new(DataBlockValue::from_str(v)?)))
.collect::<Result<Vec<Arc<DataBlockValue>>, Self::Err>>()?,
))
}
}
impl Deref for ValueCombination {
type Target = Vec<Arc<DataBlockValue>>;
fn deref(&self) -> &Self::Target {
&self.combination
}
}
impl DerefMut for ValueCombination {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.combination
}
}
// Implementing a custom serializer, so this can
// be properly written to a json file
impl Serialize for ValueCombination {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&format!("{}", self))
}
}
struct ValueCombinationVisitor {
marker: PhantomData<fn() -> ValueCombination>,
}
impl ValueCombinationVisitor {
fn new() -> Self {
ValueCombinationVisitor {
marker: PhantomData,
}
}
}
impl<'de> Visitor<'de> for ValueCombinationVisitor {
type Value = ValueCombination;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string representing the value combination")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
ValueCombination::from_str(v).map_err(E::custom)
}
}
// Implementing a custom deserializer, so this can
// be properly read from a json file
impl<'de> Deserialize<'de> for ValueCombination {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_string(ValueCombinationVisitor::new())
}
}
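// --- Editor's usage sketch (hypothetical, not part of the original change set) ---
// The custom Serialize/Deserialize impls above round-trip a combination through its
// `Display`/`FromStr` string form, so a JSON file stores each combination as a single
// ';'-delimited string. This helper exercises the same path directly.
#[allow(dead_code)]
fn value_combination_string_roundtrip_sketch(
    comb: &ValueCombination,
) -> Result<ValueCombination, ParseDataBlockValueError> {
    // format via Display, then parse back via FromStr
    ValueCombination::from_str(&comb.to_string())
}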

View file

@@ -1,86 +0,0 @@
use std::ops::{Deref, DerefMut};
const INITIAL_BIN: usize = 10;
const BIN_RATIO: usize = 2;
/// Stores the preservation information related to a particular bucket.
/// A bucket stores count in a certain range value.
/// For example: (0, 10], (10, 20], (20, 40]...
#[derive(Debug)]
pub struct PreservationByCountBucket {
/// How many elements are stored in the bucket
pub size: usize,
/// Preservation sum of all elements in the bucket
pub preservation_sum: f64,
/// Combination length sum of all elements in the bucket
pub length_sum: usize,
}
impl PreservationByCountBucket {
/// Return a new PreservationByCountBucket with default values
pub fn default() -> PreservationByCountBucket {
PreservationByCountBucket {
size: 0,
preservation_sum: 0.0,
length_sum: 0,
}
}
/// Adds a new value to the bucket
/// # Arguments
/// * `preservation` - Preservation related to the value
/// * `length` - Combination length related to the value
pub fn add(&mut self, preservation: f64, length: usize) {
self.size += 1;
self.preservation_sum += preservation;
self.length_sum += length;
}
}
/// Stores the bins for each bucket. For example:
/// [10, 20, 40, 80, 160...]
#[derive(Debug)]
pub struct PreservationByCountBucketBins {
/// Bins where the index of the vector is the bin index,
/// and the value is the max value count allowed on the bin
bins: Vec<usize>,
}
impl PreservationByCountBucketBins {
/// Generates a new PreservationByCountBucketBins with bins up to a `max_val`
/// # Arguments
/// * `max_val` - Max value allowed on the last bucket (inclusive)
pub fn new(max_val: usize) -> PreservationByCountBucketBins {
let mut bins: Vec<usize> = vec![INITIAL_BIN];
loop {
let last = *bins.last().unwrap();
if last >= max_val {
break;
}
bins.push(last * BIN_RATIO);
}
PreservationByCountBucketBins { bins }
}
/// Find the first `bucket_max_val` where `bucket_max_val >= val`
/// # Arguments
/// * `val` - Count to look in which bucket it will fit
pub fn find_bucket_max_val(&self, val: usize) -> usize {
// find first element x where x >= val
self.bins[self.bins.partition_point(|x| *x < val)]
}
}
impl Deref for PreservationByCountBucketBins {
type Target = Vec<usize>;
fn deref(&self) -> &Self::Target {
&self.bins
}
}
impl DerefMut for PreservationByCountBucketBins {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.bins
}
}

View file

@@ -0,0 +1,252 @@
use super::preservation_by_count::{PreservationByCountBucketBins, PreservationByCountBuckets};
use super::rare_combinations_comparison_data::RareCombinationsComparisonData;
use fnv::FnvHashSet;
use log::info;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use crate::processing::aggregator::aggregated_data::AggregatedData;
use crate::processing::aggregator::typedefs::AggregatedCountByLenMap;
use crate::processing::aggregator::value_combination::ValueCombination;
use crate::processing::evaluator::preservation_bucket::PreservationBucket;
use crate::processing::evaluator::preservation_by_length::PreservationByLengthBuckets;
use crate::utils::time::ElapsedDurationLogger;
#[cfg_attr(feature = "pyo3", pyclass)]
/// Evaluates aggregated, sensitive and synthesized data
pub struct Evaluator {}
impl Evaluator {
/// Returns a new Evaluator
pub fn default() -> Evaluator {
Evaluator {}
}
}
#[cfg_attr(feature = "pyo3", pymethods)]
impl Evaluator {
/// Returns a new Evaluator
#[cfg(feature = "pyo3")]
#[new]
pub fn default_pyo3() -> Evaluator {
Evaluator::default()
}
/// Calculates the leakage counts grouped by combination length
/// (how many attribute combinations exist on the sensitive data and
/// appear on the synthetic data with `count < resolution`).
/// By design this should be `0`
/// # Arguments
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
/// * `resolution` - Reporting resolution used for data synthesis
pub fn calc_leakage_count(
&self,
sensitive_aggregated_data: &AggregatedData,
synthetic_aggregated_data: &AggregatedData,
resolution: usize,
) -> AggregatedCountByLenMap {
let _duration_logger = ElapsedDurationLogger::new("leakage count calculation");
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
info!("calculating rare sensitive combination leakages by length");
for (sensitive_agg, sensitive_count) in sensitive_aggregated_data.aggregates_count.iter() {
if sensitive_count.count < resolution {
if let Some(synthetic_count) = synthetic_aggregated_data
.aggregates_count
.get(sensitive_agg)
{
if synthetic_count.count < resolution {
let leaks = result.entry(sensitive_agg.len()).or_insert(0);
*leaks += 1;
}
}
}
}
result
}
/// Calculates the fabricated counts grouped by combination length
/// (how many attribute combinations exist on the synthetic data that do not
/// exist on the sensitive data).
/// By design this should be `0`
/// # Arguments
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
pub fn calc_fabricated_count(
&self,
sensitive_aggregated_data: &AggregatedData,
synthetic_aggregated_data: &AggregatedData,
) -> AggregatedCountByLenMap {
let _duration_logger = ElapsedDurationLogger::new("fabricated count calculation");
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
info!("calculating fabricated synthetic combinations by length");
for synthetic_agg in synthetic_aggregated_data.aggregates_count.keys() {
if sensitive_aggregated_data
.aggregates_count
.get(synthetic_agg)
.is_none()
{
let fabricated = result.entry(synthetic_agg.len()).or_insert(0);
*fabricated += 1;
}
}
result
}
/// Calculates the preservation information grouped by combination count.
/// An example output might be a map like:
/// `{ 10 -> PreservationBucket, 20 -> PreservationBucket, ...}`
/// # Arguments
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
/// * `resolution` - Reporting resolution used for data synthesis
pub fn calc_preservation_by_count(
&self,
sensitive_aggregated_data: &AggregatedData,
synthetic_aggregated_data: &AggregatedData,
resolution: usize,
) -> PreservationByCountBuckets {
let _duration_logger = ElapsedDurationLogger::new("preservation by count calculation");
info!(
"calculating preservation by count with resolution: {}",
resolution
);
let max_syn_count = *synthetic_aggregated_data
.aggregates_count
.values()
.map(|a| &a.count)
.max()
.unwrap_or(&0);
let bins: PreservationByCountBucketBins = PreservationByCountBucketBins::new(max_syn_count);
let mut buckets: PreservationByCountBuckets = PreservationByCountBuckets::default();
let mut processed_combs: FnvHashSet<&ValueCombination> = FnvHashSet::default();
for (comb, count) in sensitive_aggregated_data.aggregates_count.iter() {
// exclude sensitive rare combinations
if count.count >= resolution && !processed_combs.contains(comb) {
buckets.populate(
&bins,
comb,
&sensitive_aggregated_data.aggregates_count,
&synthetic_aggregated_data.aggregates_count,
);
processed_combs.insert(comb);
}
}
for comb in synthetic_aggregated_data.aggregates_count.keys() {
if !processed_combs.contains(comb) {
buckets.populate(
&bins,
comb,
&sensitive_aggregated_data.aggregates_count,
&synthetic_aggregated_data.aggregates_count,
);
processed_combs.insert(comb);
}
}
// fill empty buckets with default value
for bin in bins.iter() {
buckets
.entry(*bin)
.or_insert_with(PreservationBucket::default);
}
buckets
}
/// Calculates the preservation information grouped by combination length.
/// An example output might be a map like:
/// `{ 1 -> PreservationBucket, 2 -> PreservationBucket, ...}`
/// # Arguments
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
/// * `resolution` - Reporting resolution used for data synthesis
pub fn calc_preservation_by_length(
&self,
sensitive_aggregated_data: &AggregatedData,
synthetic_aggregated_data: &AggregatedData,
resolution: usize,
) -> PreservationByLengthBuckets {
let _duration_logger = ElapsedDurationLogger::new("preservation by length calculation");
info!(
"calculating preservation by length with resolution: {}",
resolution
);
let mut buckets: PreservationByLengthBuckets = PreservationByLengthBuckets::default();
let mut processed_combs: FnvHashSet<&ValueCombination> = FnvHashSet::default();
for (comb, count) in sensitive_aggregated_data.aggregates_count.iter() {
// exclude sensitive rare combinations
if count.count >= resolution && !processed_combs.contains(comb) {
buckets.populate(
comb,
&sensitive_aggregated_data.aggregates_count,
&synthetic_aggregated_data.aggregates_count,
);
processed_combs.insert(comb);
}
}
for comb in synthetic_aggregated_data.aggregates_count.keys() {
if !processed_combs.contains(comb) {
buckets.populate(
comb,
&sensitive_aggregated_data.aggregates_count,
&synthetic_aggregated_data.aggregates_count,
);
processed_combs.insert(comb);
}
}
// fill empty buckets with default value
for l in 1..=synthetic_aggregated_data.reporting_length {
buckets.entry(l).or_insert_with(PreservationBucket::default);
}
buckets
}
/// Compares the rare combinations on the synthetic data with
/// the sensitive data counts
/// # Arguments
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
/// * `resolution` - Reporting resolution used for data synthesis
/// * `combination_delimiter` - Delimiter used to join combinations and format them
/// as strings
/// * `protect` - Whether or not the sensitive counts should be rounded to the
/// nearest smallest multiple of resolution
pub fn compare_synthetic_and_sensitive_rare(
&self,
synthetic_aggregated_data: &AggregatedData,
sensitive_aggregated_data: &AggregatedData,
resolution: usize,
combination_delimiter: &str,
protect: bool,
) -> RareCombinationsComparisonData {
let _duration_logger =
ElapsedDurationLogger::new("synthetic and sensitive rare comparison");
RareCombinationsComparisonData::from_synthetic_and_sensitive_aggregated_data(
synthetic_aggregated_data,
sensitive_aggregated_data,
resolution,
combination_delimiter,
protect,
)
}
}
#[cfg(feature = "pyo3")]
pub fn register(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Evaluator>()?;
Ok(())
}
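// --- Editor's usage sketch (hypothetical, not part of the original change set) ---
// A typical evaluation pass comparing sensitive and synthetic aggregates at a given
// resolution; the function name and the way the results are combined are illustrative.
#[allow(dead_code)]
fn evaluate_sketch(
    sensitive_aggregated_data: &AggregatedData,
    synthetic_aggregated_data: &AggregatedData,
    resolution: usize,
) -> f64 {
    let evaluator = Evaluator::default();
    // both of these are expected to be empty by design
    let _leakage =
        evaluator.calc_leakage_count(sensitive_aggregated_data, synthetic_aggregated_data, resolution);
    let _fabricated =
        evaluator.calc_fabricated_count(sensitive_aggregated_data, synthetic_aggregated_data);
    // preservation grouped by synthetic count, reduced to a single combination loss value
    evaluator
        .calc_preservation_by_count(sensitive_aggregated_data, synthetic_aggregated_data, resolution)
        .calc_combination_loss()
}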

View file

@@ -1,227 +1,23 @@
pub mod bucket;
/// Module defining routines to calculate preservation by count
pub mod preservation_by_count;
/// Module defining routines to calculate preservation by length
pub mod preservation_by_length;
/// Module defining a bucket used to calculate preservation by count
/// and length
pub mod preservation_bucket;
/// Defines structures related to rare combination comparisons
/// between the synthetic and sensitive datasets
pub mod rare_combinations_comparison_data;
/// Type definitions related to the evaluation process
pub mod typedefs;
use self::bucket::{PreservationByCountBucket, PreservationByCountBucketBins};
use self::typedefs::PreservationByCountBuckets;
use fnv::FnvHashSet;
use instant::Duration;
use log::Level::Trace;
use log::{info, log_enabled, trace};
mod data_evaluator;
use crate::measure_time;
use crate::processing::aggregator::typedefs::ValueCombination;
use crate::utils::time::ElapsedDuration;
pub use data_evaluator::Evaluator;
use super::aggregator::typedefs::{
AggregatedCountByLenMap, AggregatesCountMap, ValueCombinationSlice,
};
#[derive(Debug)]
struct EvaluatorDurations {
calc_leakage_count: Duration,
calc_fabricated_count: Duration,
calc_preservation_by_count: Duration,
calc_combination_loss: Duration,
}
/// Evaluates aggregated, sensitive and synthesized data
pub struct Evaluator {
durations: EvaluatorDurations,
}
impl Evaluator {
/// Returns a new Evaluator
#[inline]
pub fn default() -> Evaluator {
Evaluator {
durations: EvaluatorDurations {
calc_leakage_count: Duration::default(),
calc_fabricated_count: Duration::default(),
calc_preservation_by_count: Duration::default(),
calc_combination_loss: Duration::default(),
},
}
}
/// Calculates the leakage counts grouped by combination length
/// (how many attribute combinations exist on the sensitive data and
/// appear on the synthetic data with `count < resolution`).
/// By design this should be `0`
/// # Arguments
/// - `sensitive_aggregates` - Calculated aggregates counts for the sensitive data
/// - `synthetic_aggregates` - Calculated aggregates counts for the synthetic data
pub fn calc_leakage_count(
&mut self,
sensitive_aggregates: &AggregatesCountMap,
synthetic_aggregates: &AggregatesCountMap,
resolution: usize,
) -> AggregatedCountByLenMap {
info!("calculating rare sensitive combination leakages by length");
measure_time!(
|| {
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
for (sensitive_agg, sensitive_count) in sensitive_aggregates.iter() {
if sensitive_count.count < resolution {
if let Some(synthetic_count) = synthetic_aggregates.get(sensitive_agg) {
if synthetic_count.count < resolution {
let leaks = result.entry(sensitive_agg.len()).or_insert(0);
*leaks += 1;
}
}
}
}
result
},
(self.durations.calc_leakage_count)
)
}
/// Calculates the fabricated counts grouped by combination length
/// (how many attribute combinations exist on the synthetic data that do not
/// exist on the sensitive data).
/// By design this should be `0`
/// # Arguments
/// - `sensitive_aggregates` - Calculated aggregates counts for the sensitive data
/// - `synthetic_aggregates` - Calculated aggregates counts for the synthetic data
pub fn calc_fabricated_count(
&mut self,
sensitive_aggregates: &AggregatesCountMap,
synthetic_aggregates: &AggregatesCountMap,
) -> AggregatedCountByLenMap {
info!("calculating fabricated synthetic combinations by length");
measure_time!(
|| {
let mut result: AggregatedCountByLenMap = AggregatedCountByLenMap::default();
for synthetic_agg in synthetic_aggregates.keys() {
if sensitive_aggregates.get(synthetic_agg).is_none() {
let fabricated = result.entry(synthetic_agg.len()).or_insert(0);
*fabricated += 1;
}
}
result
},
(self.durations.calc_fabricated_count)
)
}
/// Calculates the preservation information by bucket.
/// An example output might be a map like:
/// `{ 10 -> PreservationByCountBucket, 20 -> PreservationByCountBucket, ...}`
/// # Arguments
/// - `sensitive_aggregates` - Calculated aggregates counts for the sensitive data
/// - `synthetic_aggregates` - Calculated aggregates counts for the synthetic data
/// * `resolution` - Reporting resolution used for data synthesis
pub fn calc_preservation_by_count(
&mut self,
sensitive_aggregates: &AggregatesCountMap,
synthetic_aggregates: &AggregatesCountMap,
resolution: usize,
) -> PreservationByCountBuckets {
info!(
"calculating preservation by count with resolution: {}",
resolution
);
measure_time!(
|| {
let max_syn_count = *synthetic_aggregates
.values()
.map(|a| &a.count)
.max()
.unwrap_or(&0);
let bins: PreservationByCountBucketBins =
PreservationByCountBucketBins::new(max_syn_count);
let mut buckets: PreservationByCountBuckets = PreservationByCountBuckets::default();
let mut processed_combs: FnvHashSet<&ValueCombination> = FnvHashSet::default();
for (comb, count) in sensitive_aggregates.iter() {
if count.count >= resolution && !processed_combs.contains(comb) {
Evaluator::populate_buckets(
&mut buckets,
&bins,
comb,
sensitive_aggregates,
synthetic_aggregates,
);
processed_combs.insert(comb);
}
}
for comb in synthetic_aggregates.keys() {
if !processed_combs.contains(comb) {
Evaluator::populate_buckets(
&mut buckets,
&bins,
comb,
sensitive_aggregates,
synthetic_aggregates,
);
processed_combs.insert(comb);
}
}
// fill empty buckets with default value
for bin in bins.iter() {
if !buckets.contains_key(bin) {
buckets.insert(*bin, PreservationByCountBucket::default());
}
}
buckets
},
(self.durations.calc_preservation_by_count)
)
}
/// Calculates the combination loss for the calculated bucket preservation information
/// (`combination_loss = avg(1 - bucket_preservation_sum / bucket_size) for all buckets`)
/// # Arguments
/// - `buckets` - Calculated preservation buckets
pub fn calc_combination_loss(&mut self, buckets: &PreservationByCountBuckets) -> f64 {
measure_time!(
|| {
buckets
.values()
.map(|b| 1.0 - (b.preservation_sum / (b.size as f64)))
.sum::<f64>()
/ (buckets.len() as f64)
},
(self.durations.calc_combination_loss)
)
}
#[inline]
fn populate_buckets(
buckets: &mut PreservationByCountBuckets,
bins: &PreservationByCountBucketBins,
comb: &ValueCombinationSlice,
sensitive_aggregates: &AggregatesCountMap,
synthetic_aggregates: &AggregatesCountMap,
) {
let sen_count = match sensitive_aggregates.get(comb) {
Some(count) => count.count,
_ => 0,
};
let syn_count = match synthetic_aggregates.get(comb) {
Some(count) => count.count,
_ => 0,
};
let preservation = if sen_count > 0 {
// max value is 100%, so use min
f64::min((syn_count as f64) / (sen_count as f64), 1.0)
} else {
0.0
};
buckets
.entry(bins.find_bucket_max_val(syn_count))
.or_insert_with(PreservationByCountBucket::default)
.add(preservation, comb.len());
}
}
impl Drop for Evaluator {
fn drop(&mut self) {
trace!("evaluator durations: {:#?}", self.durations);
}
}
#[cfg(feature = "pyo3")]
pub use data_evaluator::register;

View file

@@ -0,0 +1,58 @@
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
/// Bucket to store preservation information
#[cfg_attr(feature = "pyo3", pyclass)]
#[derive(Debug, Clone)]
pub struct PreservationBucket {
/// How many elements are stored in the bucket
pub size: usize,
/// Preservation sum of all elements in the bucket
pub preservation_sum: f64,
/// Combination length sum of all elements in the bucket
pub length_sum: usize,
/// Combination count sum
pub combination_count_sum: usize,
}
impl PreservationBucket {
/// Return a new PreservationBucket with default values
pub fn default() -> PreservationBucket {
PreservationBucket {
size: 0,
preservation_sum: 0.0,
length_sum: 0,
combination_count_sum: 0,
}
}
/// Adds a new value to the bucket
/// # Arguments
/// * `preservation` - Preservation related to the value
/// * `length` - Combination length related to the value
/// * `combination_count` - Combination count related to the value
pub fn add(&mut self, preservation: f64, length: usize, combination_count: usize) {
self.size += 1;
self.preservation_sum += preservation;
self.length_sum += length;
self.combination_count_sum += combination_count;
}
}
#[cfg_attr(feature = "pyo3", pymethods)]
impl PreservationBucket {
/// Gets the mean preservation for the values in this bucket
pub fn get_mean_preservation(&self) -> f64 {
self.preservation_sum / (self.size as f64)
}
/// Gets the mean combination length for the values in this bucket
pub fn get_mean_combination_length(&self) -> f64 {
(self.length_sum as f64) / (self.size as f64)
}
/// Gets the mean combination count for the values in this bucket
pub fn get_mean_combination_count(&self) -> f64 {
(self.combination_count_sum as f64) / (self.size as f64)
}
}
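// --- Editor's usage sketch (hypothetical, not part of the original change set) ---
// Worked example of the bucket means: the four combinations added below give
// mean preservation 0.75, mean length 2.5 and mean combination count 15.0.
#[allow(dead_code)]
fn preservation_bucket_sketch() -> (f64, f64, f64) {
    let mut bucket = PreservationBucket::default();
    bucket.add(1.0, 2, 20);
    bucket.add(0.5, 3, 10);
    bucket.add(1.0, 2, 20);
    bucket.add(0.5, 3, 10);
    (
        bucket.get_mean_preservation(),       // 3.0 / 4 = 0.75
        bucket.get_mean_combination_length(), // 10 / 4 = 2.5
        bucket.get_mean_combination_count(),  // 60 / 4 = 15.0
    )
}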

View file

@@ -0,0 +1,191 @@
use super::{preservation_bucket::PreservationBucket, typedefs::PreservationBucketsMap};
use itertools::Itertools;
use std::{
io::{Error, Write},
ops::{Deref, DerefMut},
};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use crate::{
processing::aggregator::{typedefs::AggregatesCountMap, value_combination::ValueCombination},
utils::time::ElapsedDurationLogger,
};
const INITIAL_BIN: usize = 10;
const BIN_RATIO: usize = 2;
#[cfg_attr(feature = "pyo3", pyclass)]
/// Wrapper struct mapping the max value allowed in a bucket
/// to its corresponding PreservationBucket.
/// In this context a PreservationBucket stores the preservation information
/// related to a particular bucket.
/// A bucket stores counts within a certain value range.
/// For example: (0, 10], (10, 20], (20, 40]...
pub struct PreservationByCountBuckets {
buckets_map: PreservationBucketsMap,
}
impl PreservationByCountBuckets {
/// Returns a new default PreservationByCountBuckets
#[inline]
pub fn default() -> PreservationByCountBuckets {
PreservationByCountBuckets {
buckets_map: PreservationBucketsMap::default(),
}
}
#[inline]
pub(super) fn populate(
&mut self,
bins: &PreservationByCountBucketBins,
comb: &ValueCombination,
sensitive_aggregates: &AggregatesCountMap,
synthetic_aggregates: &AggregatesCountMap,
) {
let sen_count = match sensitive_aggregates.get(comb) {
Some(count) => count.count,
_ => 0,
};
let syn_count = match synthetic_aggregates.get(comb) {
Some(count) => count.count,
_ => 0,
};
let preservation = if sen_count > 0 {
// max value is 100%, so use min
f64::min((syn_count as f64) / (sen_count as f64), 1.0)
} else {
0.0
};
self.buckets_map
.entry(bins.find_bucket_max_val(syn_count))
.or_insert_with(PreservationBucket::default)
.add(preservation, comb.len(), syn_count);
}
}
#[cfg_attr(feature = "pyo3", pymethods)]
impl PreservationByCountBuckets {
/// Returns the actual buckets grouped by their max value.
/// This method will clone the data, so it's recommended to store its result
/// in a local variable to avoid calling it multiple times
pub fn get_buckets(&self) -> PreservationBucketsMap {
self.buckets_map.clone()
}
/// Calculates the combination loss for the calculated bucket preservation information
/// (`combination_loss = avg(1 - bucket_preservation_sum / bucket_size) for all buckets`)
pub fn calc_combination_loss(&self) -> f64 {
let _duration_logger = ElapsedDurationLogger::new("combination loss calculation");
self.buckets_map
.values()
.map(|b| 1.0 - (b.preservation_sum / (b.size as f64)))
.sum::<f64>()
/ (self.buckets_map.len() as f64)
}
/// Writes the preservation grouped by counts to the file system in a csv/tsv like format
/// # Arguments:
/// * `preservation_by_count_path` - File path to be written
/// * `preservation_by_count_delimiter` - Delimiter to use when writing to `preservation_by_count_path`
pub fn write_preservation_by_count(
&self,
preservation_by_count_path: &str,
preservation_by_count_delimiter: char,
) -> Result<(), Error> {
let mut file = std::fs::File::create(preservation_by_count_path)?;
file.write_all(
format!(
"syn_count_bucket{}mean_combo_count{}mean_combo_length{}count_preservation\n",
preservation_by_count_delimiter,
preservation_by_count_delimiter,
preservation_by_count_delimiter,
)
.as_bytes(),
)?;
for max_val in self.buckets_map.keys().sorted().rev() {
let b = &self.buckets_map[max_val];
file.write_all(
format!(
"{}{}{}{}{}{}{}\n",
max_val,
preservation_by_count_delimiter,
b.get_mean_combination_count(),
preservation_by_count_delimiter,
b.get_mean_combination_length(),
preservation_by_count_delimiter,
b.get_mean_preservation(),
)
.as_bytes(),
)?
}
Ok(())
}
}
impl Deref for PreservationByCountBuckets {
type Target = PreservationBucketsMap;
fn deref(&self) -> &Self::Target {
&self.buckets_map
}
}
impl DerefMut for PreservationByCountBuckets {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.buckets_map
}
}
/// Stores the bins for each bucket. For example:
/// [10, 20, 40, 80, 160...]
#[derive(Debug)]
pub struct PreservationByCountBucketBins {
/// Bins where the index of the vector is the bin index,
    /// and the value is the maximum count allowed in the bin
bins: Vec<usize>,
}
impl PreservationByCountBucketBins {
/// Generates a new PreservationByCountBucketBins with bins up to a `max_val`
/// # Arguments
/// * `max_val` - Max value allowed on the last bucket (inclusive)
pub fn new(max_val: usize) -> PreservationByCountBucketBins {
let mut bins: Vec<usize> = vec![INITIAL_BIN];
loop {
let last = *bins.last().unwrap();
if last >= max_val {
break;
}
bins.push(last * BIN_RATIO);
}
PreservationByCountBucketBins { bins }
}
    /// Finds the first `bucket_max_val` where `bucket_max_val >= val`
/// # Arguments
    /// * `val` - Count for which to find the bucket where it fits
pub fn find_bucket_max_val(&self, val: usize) -> usize {
// find first element x where x >= val
self.bins[self.bins.partition_point(|x| *x < val)]
}
}
impl Deref for PreservationByCountBucketBins {
type Target = Vec<usize>;
fn deref(&self) -> &Self::Target {
&self.bins
}
}
impl DerefMut for PreservationByCountBucketBins {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.bins
}
}
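A standalone sketch (not part of the crate) of how the doubling bins above are built and queried: start at INITIAL_BIN, multiply by BIN_RATIO until the maximum count is covered, then look up a count with partition_point exactly as find_bucket_max_val does:

fn main() {
    let (initial_bin, bin_ratio, max_val) = (10_usize, 2_usize, 100_usize);
    let mut bins = vec![initial_bin];
    while *bins.last().unwrap() < max_val {
        let next = bins.last().unwrap() * bin_ratio;
        bins.push(next);
    }
    assert_eq!(bins, vec![10, 20, 40, 80, 160]);
    // first bucket max value x where x >= val
    let val = 37_usize;
    let bucket_max_val = bins[bins.partition_point(|x| *x < val)];
    assert_eq!(bucket_max_val, 40);
}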

View File

@ -0,0 +1,118 @@
use super::{preservation_bucket::PreservationBucket, typedefs::PreservationBucketsMap};
use itertools::Itertools;
use std::{
io::{Error, Write},
ops::{Deref, DerefMut},
};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use crate::processing::aggregator::{
typedefs::AggregatesCountMap, value_combination::ValueCombination,
};
#[cfg_attr(feature = "pyo3", pyclass)]
/// Wrapper to store the preservation buckets grouped
/// by length
pub struct PreservationByLengthBuckets {
buckets_map: PreservationBucketsMap,
}
impl PreservationByLengthBuckets {
/// Returns a new default PreservationByLengthBuckets
#[inline]
pub fn default() -> PreservationByLengthBuckets {
PreservationByLengthBuckets {
buckets_map: PreservationBucketsMap::default(),
}
}
#[inline]
pub(super) fn populate(
&mut self,
comb: &ValueCombination,
sensitive_aggregates: &AggregatesCountMap,
synthetic_aggregates: &AggregatesCountMap,
) {
let sen_count = match sensitive_aggregates.get(comb) {
Some(count) => count.count,
_ => 0,
};
let syn_count = match synthetic_aggregates.get(comb) {
Some(count) => count.count,
_ => 0,
};
let preservation = if sen_count > 0 {
// max value is 100%, so use min as syn count might be > sen_count
f64::min((syn_count as f64) / (sen_count as f64), 1.0)
} else {
0.0
};
self.buckets_map
.entry(comb.len())
.or_insert_with(PreservationBucket::default)
.add(preservation, comb.len(), syn_count);
}
}
#[cfg_attr(feature = "pyo3", pymethods)]
impl PreservationByLengthBuckets {
/// Returns the actual buckets grouped by length.
    /// This method will clone the data, so it's recommended to store its result
    /// in a local variable to avoid calling it multiple times
pub fn get_buckets(&self) -> PreservationBucketsMap {
self.buckets_map.clone()
}
    /// Writes the preservation grouped by length to the file system in a csv/tsv-like format
/// # Arguments:
/// * `preservation_by_length_path` - File path to be written
/// * `preservation_by_length_delimiter` - Delimiter to use when writing to `preservation_by_length_path`
pub fn write_preservation_by_length(
&mut self,
preservation_by_length_path: &str,
preservation_by_length_delimiter: char,
) -> Result<(), Error> {
let mut file = std::fs::File::create(preservation_by_length_path)?;
file.write_all(
format!(
"syn_combo_length{}mean_combo_count{}count_preservation\n",
preservation_by_length_delimiter, preservation_by_length_delimiter,
)
.as_bytes(),
)?;
for length in self.buckets_map.keys().sorted() {
let b = &self.buckets_map[length];
file.write_all(
format!(
"{}{}{}{}{}\n",
length,
preservation_by_length_delimiter,
b.get_mean_combination_count(),
preservation_by_length_delimiter,
b.get_mean_preservation(),
)
.as_bytes(),
)?
}
Ok(())
}
}
impl Deref for PreservationByLengthBuckets {
type Target = PreservationBucketsMap;
fn deref(&self) -> &Self::Target {
&self.buckets_map
}
}
impl DerefMut for PreservationByLengthBuckets {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.buckets_map
}
}
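Both bucket groupings use the same per-combination preservation ratio in populate: synthetic count over sensitive count, capped at 1.0. A standalone sketch (not part of the crate):

fn preservation(sensitive_count: usize, synthetic_count: usize) -> f64 {
    if sensitive_count > 0 {
        // capped at 100%, since the synthetic count may exceed the sensitive one
        f64::min((synthetic_count as f64) / (sensitive_count as f64), 1.0)
    } else {
        0.0
    }
}

fn main() {
    assert_eq!(preservation(10, 7), 0.7);
    assert_eq!(preservation(10, 15), 1.0);
    assert_eq!(preservation(0, 5), 0.0); // combination absent from the sensitive data
}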

View File

@ -0,0 +1,197 @@
use std::io::{Error, Write};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use crate::{
data_block::typedefs::CombinationsComparisons,
processing::aggregator::aggregated_data::AggregatedData, utils::math::uround_down,
};
#[cfg_attr(feature = "pyo3", pyclass)]
#[derive(Clone)]
/// Represents a single combination comparison
pub struct CombinationComparison {
/// Length of the combination
pub combination_length: usize,
/// Combination formatted as string
pub combination: String,
    /// Index of the record where this combination occurs
    pub record_index: usize,
    /// Number of times this combination occurs in the synthetic data
    pub synthetic_count: usize,
    /// Number of times this combination occurs in the sensitive data
pub sensitive_count: usize,
}
impl CombinationComparison {
#[inline]
/// Creates a new CombinationComparison
pub fn new(
combination_length: usize,
combination: String,
record_index: usize,
synthetic_count: usize,
sensitive_count: usize,
) -> CombinationComparison {
CombinationComparison {
combination_length,
combination,
record_index,
synthetic_count,
sensitive_count,
}
}
}
#[cfg(feature = "pyo3")]
#[cfg_attr(feature = "pyo3", pymethods)]
impl CombinationComparison {
#[getter]
/// Length of the combination
fn combination_length(&self) -> usize {
self.combination_length
}
#[getter]
/// Combination formatted as string
fn combination(&self) -> String {
self.combination.clone()
}
#[getter]
    /// Index of the record where this combination occurs
fn record_index(&self) -> usize {
self.record_index
}
#[getter]
    /// Number of times this combination occurs in the synthetic data
fn synthetic_count(&self) -> usize {
self.synthetic_count
}
#[getter]
    /// Number of times this combination occurs in the sensitive data
fn sensitive_count(&self) -> usize {
self.sensitive_count
}
}
#[cfg_attr(feature = "pyo3", pyclass)]
/// Computed rare combination comparisons between the synthetic
/// and sensitive datasets for all records
pub struct RareCombinationsComparisonData {
/// Rare combination comparison for all records
/// (compares rare combinations on the synthetic dataset with
/// the sensitive counts)
pub rare_combinations: CombinationsComparisons,
}
impl RareCombinationsComparisonData {
    /// Builds a new comparison between the rare combinations on the synthetic
/// data and the sensitive data counts
/// # Arguments
/// * `sensitive_aggregated_data` - Calculated aggregated data for the sensitive data
/// * `synthetic_aggregated_data` - Calculated aggregated data for the synthetic data
/// * `resolution` - Reporting resolution used for data synthesis
    /// * `combination_delimiter` - Delimiter used to join combinations and format them
/// as strings
    /// * `protect` - Whether or not the sensitive counts should be rounded down to the
    /// nearest multiple of the resolution
#[inline]
pub fn from_synthetic_and_sensitive_aggregated_data(
synthetic_aggregated_data: &AggregatedData,
sensitive_aggregated_data: &AggregatedData,
resolution: usize,
combination_delimiter: &str,
protect: bool,
) -> RareCombinationsComparisonData {
let mut rare_combinations: CombinationsComparisons = CombinationsComparisons::default();
let resolution_f64 = resolution as f64;
for (agg, count) in synthetic_aggregated_data.aggregates_count.iter() {
if count.count < resolution {
let combination_str = agg.format_str_using_headers(
&synthetic_aggregated_data.data_block.headers,
combination_delimiter,
);
let mut sensitive_count = sensitive_aggregated_data
.aggregates_count
.get(agg)
.map_or(0, |c| c.count);
if protect {
sensitive_count = uround_down(sensitive_count as f64, resolution_f64)
}
for record_index in count.contained_in_records.iter() {
rare_combinations.push(CombinationComparison::new(
agg.len(),
combination_str.clone(),
*record_index,
count.count,
sensitive_count,
));
}
}
}
// sort result by combination length
rare_combinations.sort_by_key(|c| c.combination_length);
RareCombinationsComparisonData { rare_combinations }
}
}
#[cfg_attr(feature = "pyo3", pymethods)]
impl RareCombinationsComparisonData {
/// Returns the rare combination comparisons between the synthetic
/// and sensitive datasets for all records.
    /// This method will clone the data, so it's recommended to store its result
    /// in a local variable to avoid calling it multiple times
pub fn get_rare_combinations(&self) -> CombinationsComparisons {
self.rare_combinations.clone()
}
    /// Writes the rare combinations to the file system in a csv/tsv-like format
/// # Arguments:
/// * `rare_combinations_path` - File path to be written
/// * `rare_combinations_delimiter` - Delimiter to use when writing to `rare_combinations_path`
pub fn write_rare_combinations(
&mut self,
rare_combinations_path: &str,
rare_combinations_delimiter: char,
) -> Result<(), Error> {
let mut file = std::fs::File::create(rare_combinations_path)?;
file.write_all(
format!(
"combo_length{}combo{}record_id{}syn_count{}sen_count\n",
rare_combinations_delimiter,
rare_combinations_delimiter,
rare_combinations_delimiter,
rare_combinations_delimiter,
)
.as_bytes(),
)?;
for ra in self.rare_combinations.iter() {
file.write_all(
format!(
"{}{}{}{}{}{}{}{}{}\n",
ra.combination_length,
rare_combinations_delimiter,
ra.combination,
rare_combinations_delimiter,
ra.record_index,
rare_combinations_delimiter,
ra.synthetic_count,
rare_combinations_delimiter,
ra.sensitive_count,
)
.as_bytes(),
)?
}
Ok(())
}
}
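When `protect` is set, the sensitive counts reported next to rare synthetic combinations are passed through uround_down with the reporting resolution. A standalone sketch of the assumed rounding behavior (the helper name below is illustrative, not the crate's):

fn round_down_to_multiple(count: usize, resolution: usize) -> usize {
    // assumed behavior of uround_down: largest multiple of `resolution` that is <= `count`
    (count / resolution) * resolution
}

fn main() {
    assert_eq!(round_down_to_multiple(23, 10), 20);
    assert_eq!(round_down_to_multiple(7, 10), 0); // counts below the resolution are reported as 0
}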

View File

@ -1,6 +1,5 @@
use super::bucket::PreservationByCountBucket;
use super::preservation_bucket::PreservationBucket;
use fnv::FnvHashMap;
/// Maps the max value allowed in the bucket to its
/// correspondent PreservationByCountBucket
pub type PreservationByCountBuckets = FnvHashMap<usize, PreservationByCountBucket>;
/// Maps a value to its corresponding PreservationBucket
pub type PreservationBucketsMap = FnvHashMap<usize, PreservationBucket>;

View File

@ -0,0 +1,143 @@
use super::generated_data::GeneratedData;
use super::synthesizer::typedefs::SynthesizedRecords;
use super::synthesizer::unseeded::UnseededSynthesizer;
use super::SynthesisMode;
use log::info;
use std::sync::Arc;
use crate::data_block::block::DataBlock;
use crate::data_block::typedefs::RawSyntheticData;
use crate::processing::generator::synthesizer::cache::SynthesizerCacheKey;
use crate::processing::generator::synthesizer::seeded::SeededSynthesizer;
use crate::utils::reporting::ReportProgress;
use crate::utils::time::ElapsedDurationLogger;
/// Processes a data block and generates new synthetic data
pub struct Generator {
data_block: Arc<DataBlock>,
}
impl Generator {
/// Returns a new Generator
/// # Arguments
/// * `data_block` - Sensitive data to be synthesized
#[inline]
pub fn new(data_block: Arc<DataBlock>) -> Generator {
Generator { data_block }
}
/// Generates new synthetic data based on sensitive data
/// # Arguments
/// * `resolution` - Reporting resolution used for data synthesis
/// * `cache_max_size` - Maximum cache size allowed
/// * `empty_value` - Empty values on the synthetic data will be represented by this
/// * `mode` - Which mode to perform the data synthesis
/// * `progress_reporter` - Will be used to report the processing
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
pub fn generate<T>(
&mut self,
resolution: usize,
cache_max_size: usize,
empty_value: String,
mode: SynthesisMode,
progress_reporter: &mut Option<T>,
) -> GeneratedData
where
T: ReportProgress,
{
let _duration_logger = ElapsedDurationLogger::new("data generation");
let empty_value_arc = Arc::new(empty_value);
info!("starting {} generation...", mode);
let synthesized_records = match mode {
SynthesisMode::Seeded => {
self.seeded_synthesis(resolution, cache_max_size, progress_reporter)
}
SynthesisMode::Unseeded => self.unseeded_synthesis(
resolution,
cache_max_size,
&empty_value_arc,
progress_reporter,
),
};
self.build_generated_data(synthesized_records, empty_value_arc)
}
#[inline]
fn build_generated_data(
&self,
mut synthesized_records: SynthesizedRecords,
empty_value: Arc<String>,
) -> GeneratedData {
let mut result: RawSyntheticData = RawSyntheticData::default();
let mut records: RawSyntheticData = synthesized_records
.drain(..)
.map(|r| {
SynthesizerCacheKey::new(self.data_block.headers.len(), &r)
.format_record(&empty_value)
})
.collect();
// sort by number of defined attributes
records.sort();
records.sort_by_key(|r| {
-r.iter()
.map(|s| if s.is_empty() { 0 } else { 1 })
.sum::<isize>()
});
result.push(self.data_block.headers.to_vec());
result.extend(records);
let expansion_ratio = (result.len() - 1) as f64 / self.data_block.records.len() as f64;
info!("expansion ratio: {:.4?}", expansion_ratio);
GeneratedData::new(result, expansion_ratio)
}
#[inline]
fn seeded_synthesis<T>(
&self,
resolution: usize,
cache_max_size: usize,
progress_reporter: &mut Option<T>,
) -> SynthesizedRecords
where
T: ReportProgress,
{
let attr_rows_map = Arc::new(self.data_block.calc_attr_rows());
let mut synth = SeededSynthesizer::new(
self.data_block.clone(),
attr_rows_map,
resolution,
cache_max_size,
);
synth.run(progress_reporter)
}
#[inline]
fn unseeded_synthesis<T>(
&self,
resolution: usize,
cache_max_size: usize,
empty_value: &Arc<String>,
progress_reporter: &mut Option<T>,
) -> SynthesizedRecords
where
T: ReportProgress,
{
let attr_rows_map_by_column =
Arc::new(self.data_block.calc_attr_rows_by_column(empty_value));
let mut synth = UnseededSynthesizer::new(
self.data_block.clone(),
attr_rows_map_by_column,
resolution,
cache_max_size,
empty_value.clone(),
);
synth.run(progress_reporter)
}
}
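build_generated_data above orders the synthesized records so that the ones defining the most attributes come first. A standalone sketch of that ordering (not part of the crate):

fn main() {
    let mut records: Vec<Vec<String>> = vec![
        vec!["".into(), "b".into()],
        vec!["a".into(), "b".into()],
        vec!["".into(), "".into()],
    ];
    // stable tie-break on the record contents first
    records.sort();
    // then order by the number of non-empty attributes, descending
    records.sort_by_key(|r| {
        -r.iter()
            .map(|s| if s.is_empty() { 0 } else { 1 })
            .sum::<isize>()
    });
    assert_eq!(records[0], vec!["a".to_string(), "b".to_string()]);
    assert_eq!(records[2], vec!["".to_string(), "".to_string()]);
}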

View File

@ -0,0 +1,93 @@
use csv::WriterBuilder;
use log::info;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
#[cfg(feature = "pyo3")]
use crate::data_block::typedefs::CsvRecord;
use crate::{
data_block::{csv_io_error::CsvIOError, typedefs::RawSyntheticData},
utils::time::ElapsedDurationLogger,
};
#[cfg_attr(feature = "pyo3", pyclass)]
/// Synthetic data generated by the Generator
pub struct GeneratedData {
/// Synthesized data - headers (index 0) and records indexes 1...
pub synthetic_data: RawSyntheticData,
/// `Synthetic data length / Sensitive data length` (header not included)
pub expansion_ratio: f64,
}
impl GeneratedData {
/// Returns a new GeneratedData struct with default values
#[inline]
pub fn default() -> GeneratedData {
GeneratedData {
synthetic_data: RawSyntheticData::default(),
expansion_ratio: 0.0,
}
}
/// Returns a new GeneratedData struct
/// # Arguments
/// * `synthetic_data` - Synthesized data - headers (index 0) and records indexes 1...
/// * `expansion_ratio` - `Synthetic data length / Sensitive data length` (header not included)
#[inline]
pub fn new(synthetic_data: RawSyntheticData, expansion_ratio: f64) -> GeneratedData {
GeneratedData {
synthetic_data,
expansion_ratio,
}
}
}
#[cfg_attr(feature = "pyo3", pymethods)]
impl GeneratedData {
#[cfg(feature = "pyo3")]
/// Synthesized data - headers (index 0) and records indexes 1...
    /// This method will clone the data, so it's recommended to store its result
    /// in a local variable to avoid calling it multiple times
fn get_synthetic_data(&self) -> Vec<CsvRecord> {
self.synthetic_data
.iter()
.map(|row| row.iter().map(|value| (**value).clone()).collect())
.collect()
}
#[cfg(feature = "pyo3")]
#[getter]
/// `Synthetic data length / Sensitive data length` (header not included)
fn expansion_ratio(&self) -> f64 {
self.expansion_ratio
}
/// Writes the synthesized data to the file system
/// # Arguments
/// * `path` - File path to be written
/// * `delimiter` - Delimiter to use when writing to `path`
pub fn write_synthetic_data(&self, path: &str, delimiter: char) -> Result<(), CsvIOError> {
let _duration_logger = ElapsedDurationLogger::new("write synthetic data");
info!("writing file {}", path);
let mut wtr = match WriterBuilder::new()
.delimiter(delimiter as u8)
.from_path(&path)
{
Ok(writer) => writer,
Err(err) => return Err(CsvIOError::new(err)),
};
// write header and records
for r in self.synthetic_data.iter() {
match wtr.write_record(r.iter().map(|v| v.as_str())) {
Ok(_) => {}
Err(err) => return Err(CsvIOError::new(err)),
};
}
Ok(())
}
}

View File

@ -1,156 +1,13 @@
/// Module defining the data synthesis process
pub mod synthesizer;
use csv::{Error, WriterBuilder};
use log::Level::Trace;
use log::{info, log_enabled, trace};
use std::time::Duration;
mod synthesis_mode;
use crate::data_block::block::DataBlock;
use crate::data_block::typedefs::CsvRecordRef;
use crate::measure_time;
use crate::processing::generator::synthesizer::cache::SynthesizerCacheKey;
use crate::processing::generator::synthesizer::seeded::SeededSynthesizer;
use crate::utils::reporting::ReportProgress;
use crate::utils::time::ElapsedDuration;
pub use synthesis_mode::SynthesisMode;
use self::synthesizer::typedefs::SynthesizedRecords;
mod data_generator;
#[derive(Debug)]
struct GeneratorDurations {
generate: Duration,
write_records: Duration,
}
pub use data_generator::Generator;
/// Process a data block and generates new synthetic data
pub struct Generator<'data_block> {
data_block: &'data_block DataBlock,
durations: GeneratorDurations,
}
/// Synthetic data generated by the Generator
pub struct GeneratedData<'data_block> {
/// Synthesized data - headers (index 0) and records indexes 1...
pub synthetic_data: Vec<CsvRecordRef<'data_block>>,
/// `Synthetic data length / Sensitive data length` (header not included)
pub expansion_ratio: f64,
}
impl<'data_block> Generator<'data_block> {
/// Returns a new Generator
/// # Arguments
/// * `data_block` - Sensitive data to be synthesized
#[inline]
pub fn new(data_block: &'data_block DataBlock) -> Generator<'data_block> {
Generator {
data_block,
durations: GeneratorDurations {
generate: Duration::default(),
write_records: Duration::default(),
},
}
}
/// Generates new synthetic data based on sensitive data
/// # Arguments
/// * `resolution` - Reporting resolution used for data synthesis
/// * `cache_max_size` - Maximum cache size allowed
/// * `empty_value` - Empty values on the synthetic data will be represented by this
/// * `progress_reporter` - Will be used to report the processing
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
pub fn generate<T>(
&mut self,
resolution: usize,
cache_max_size: usize,
empty_value: &'data_block str,
progress_reporter: &mut Option<T>,
) -> GeneratedData<'data_block>
where
T: ReportProgress,
{
info!("starting generation...");
measure_time!(
|| {
let mut result: Vec<CsvRecordRef<'data_block>> = Vec::default();
let mut records: Vec<CsvRecordRef<'data_block>> = self
.seeded_synthesis(resolution, cache_max_size, progress_reporter)
.iter()
.map(|r| {
SynthesizerCacheKey::new(self.data_block.headers.len(), r)
.format_record(empty_value)
})
.collect();
// sort by number of defined attributes
records.sort();
records.sort_by_key(|r| {
-r.iter()
.map(|s| if s.is_empty() { 0 } else { 1 })
.sum::<isize>()
});
let expansion_ratio = records.len() as f64 / self.data_block.records.len() as f64;
info!("expansion ratio: {:.4?}", expansion_ratio);
result.push(self.data_block.headers.iter().map(|h| h.as_str()).collect());
result.extend(records);
GeneratedData {
synthetic_data: result,
expansion_ratio,
}
},
(self.durations.generate)
)
}
/// Writes the synthesized data to the file system
/// # Arguments
/// * `result` - Synthetic data to write (headers included)
/// * `path` - File path to be written
/// * `delimiter` - Delimiter to use when writing to `path`
pub fn write_records(
&mut self,
result: &[CsvRecordRef<'data_block>],
path: &str,
delimiter: u8,
) -> Result<(), Error> {
info!("writing file {}", path);
measure_time!(
|| {
let mut wtr = WriterBuilder::new().delimiter(delimiter).from_path(&path)?;
// write header and records
for r in result.iter() {
wtr.write_record(r)?;
}
Ok(())
},
(self.durations.write_records)
)
}
fn seeded_synthesis<T>(
&mut self,
resolution: usize,
cache_max_size: usize,
progress_reporter: &mut Option<T>,
) -> SynthesizedRecords<'data_block>
where
T: ReportProgress,
{
let attr_rows_map = DataBlock::calc_attr_rows(&self.data_block.records);
let mut synth =
SeededSynthesizer::new(self.data_block, &attr_rows_map, resolution, cache_max_size);
synth.run(progress_reporter)
}
}
impl<'data_block> Drop for Generator<'data_block> {
fn drop(&mut self) {
trace!("generator durations: {:#?}", self.durations);
}
}
/// Module to describe the synthetic data generated by the Generator
pub mod generated_data;

View File

@ -0,0 +1,29 @@
use std::{fmt::Display, str::FromStr};
#[derive(Debug)]
/// Modes to execute the data generation/synthesis
pub enum SynthesisMode {
Seeded,
Unseeded,
}
impl FromStr for SynthesisMode {
type Err = &'static str;
fn from_str(mode: &str) -> Result<Self, Self::Err> {
match mode.to_lowercase().as_str() {
"seeded" => Ok(SynthesisMode::Seeded),
"unseeded" => Ok(SynthesisMode::Unseeded),
_ => Err("invalid mode, should be seeded or unseeded"),
}
}
}
impl Display for SynthesisMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SynthesisMode::Seeded => Ok(write!(f, "seeded")?),
SynthesisMode::Unseeded => Ok(write!(f, "unseeded")?),
}
}
}
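A minimal usage sketch for the enum above, assuming SynthesisMode is in scope (parsing is case-insensitive, formatting is lowercase):

use std::str::FromStr;

fn main() {
    let mode = SynthesisMode::from_str("Seeded").unwrap();
    assert_eq!(mode.to_string(), "seeded");
    assert!("unseeded".parse::<SynthesisMode>().is_ok());
    assert!("other".parse::<SynthesisMode>().is_err());
}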

View File

@ -1,6 +1,7 @@
use super::typedefs::SynthesizedRecord;
use fnv::FnvBuildHasher;
use lru::LruCache;
use std::sync::Arc;
use crate::data_block::{typedefs::CsvRecordRef, value::DataBlockValue};
@ -8,29 +9,26 @@ use crate::data_block::{typedefs::CsvRecordRef, value::DataBlockValue};
/// (`columns[{column_index}] = value` if value exists,
/// or `columns[{column_index}] = None` otherwise)
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct SynthesizerCacheKey<'value> {
pub struct SynthesizerCacheKey {
/// Values for a given synthesized record are indexed by `column_index`
/// on this vector
columns: Vec<Option<&'value String>>,
columns: Vec<Option<Arc<String>>>,
}
impl<'value> SynthesizerCacheKey<'value> {
impl SynthesizerCacheKey {
/// Returns a new SynthesizerCacheKey
/// # Arguments
/// * `num_columns` - Number of columns in the data block
    /// * `values` - Synthesized record to build the key for
#[inline]
pub fn new(
num_columns: usize,
values: &SynthesizedRecord<'value>,
) -> SynthesizerCacheKey<'value> {
pub fn new(num_columns: usize, values: &SynthesizedRecord) -> SynthesizerCacheKey {
let mut key = SynthesizerCacheKey {
columns: Vec::with_capacity(num_columns),
};
key.columns.resize_with(num_columns, || None);
for v in values.iter() {
key.columns[v.column_index] = Some(&v.value);
key.columns[v.column_index] = Some(v.value.clone());
}
key
}
@ -40,9 +38,9 @@ impl<'value> SynthesizerCacheKey<'value> {
/// # Arguments
/// * `value` - New data block value to be added in the new key
#[inline]
pub fn new_with_value(&self, value: &'value DataBlockValue) -> SynthesizerCacheKey<'value> {
pub fn new_with_value(&self, value: &Arc<DataBlockValue>) -> SynthesizerCacheKey {
let mut new_key = self.clone();
new_key.columns[value.column_index] = Some(&value.value);
new_key.columns[value.column_index] = Some(value.value.clone());
new_key
}
@ -59,29 +57,29 @@ impl<'value> SynthesizerCacheKey<'value> {
/// # Arguments
/// * `empty_value` - String to be used in case the `columns[column_index] == None`
#[inline]
pub fn format_record(&self, empty_value: &'value str) -> CsvRecordRef<'value> {
pub fn format_record(&self, empty_value: &Arc<String>) -> CsvRecordRef {
self.columns
.iter()
.map(|c_opt| match c_opt {
Some(c) => c,
None => empty_value,
Some(c) => (*c).clone(),
None => empty_value.clone(),
})
.collect()
}
}
/// Cache to store key/value pairs used during the synthesis process
pub struct SynthesizerCache<'value, T> {
pub struct SynthesizerCache<T> {
/// LruCache to store the keys mapping to a generic type T
cache: LruCache<SynthesizerCacheKey<'value>, T, FnvBuildHasher>,
cache: LruCache<SynthesizerCacheKey, T, FnvBuildHasher>,
}
impl<'value, T> SynthesizerCache<'value, T> {
impl<T> SynthesizerCache<T> {
/// Returns a new SynthesizerCache
/// # Arguments
/// * `cache_max_size` - Maximum cache size allowed for the LRU
#[inline]
pub fn new(cache_max_size: usize) -> SynthesizerCache<'value, T> {
pub fn new(cache_max_size: usize) -> SynthesizerCache<T> {
SynthesizerCache {
cache: LruCache::with_hasher(cache_max_size, FnvBuildHasher::default()),
}
@ -93,7 +91,7 @@ impl<'value, T> SynthesizerCache<'value, T> {
/// # Arguments
/// * `key` - Key to look for the value
#[inline]
pub fn get(&mut self, key: &SynthesizerCacheKey<'value>) -> Option<&T> {
pub fn get(&mut self, key: &SynthesizerCacheKey) -> Option<&T> {
self.cache.get(key)
}
@ -104,7 +102,7 @@ impl<'value, T> SynthesizerCache<'value, T> {
/// * `key` - Key to be inserted
    /// * `value` - Value to be associated with the key
#[inline]
pub fn insert(&mut self, key: SynthesizerCacheKey<'value>, value: T) -> Option<T> {
pub fn insert(&mut self, key: SynthesizerCacheKey, value: T) -> Option<T> {
self.cache.put(key, value)
}
}
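A standalone sketch (not part of the crate) of the idea behind the cache key and format_record above: a partially synthesized record becomes one Option per column, and missing columns are rendered with the empty-value placeholder:

use std::sync::Arc;

fn main() {
    let num_columns = 3;
    let mut columns: Vec<Option<Arc<String>>> = Vec::with_capacity(num_columns);
    columns.resize_with(num_columns, || None);
    columns[1] = Some(Arc::new("a1".to_string()));
    let empty_value = Arc::new(String::new());
    // replace missing columns with the empty value, as format_record does
    let record: Vec<Arc<String>> = columns
        .iter()
        .map(|c_opt| match c_opt {
            Some(c) => c.clone(),
            None => empty_value.clone(),
        })
        .collect();
    assert_eq!(record[1].as_str(), "a1");
    assert_eq!(record[0].as_str(), "");
}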

View File

@ -2,93 +2,76 @@ use super::cache::{SynthesizerCache, SynthesizerCacheKey};
use super::typedefs::{
AttributeCountMap, NotAllowedAttrSet, SynthesizedRecord, SynthesizerSeedSlice,
};
use itertools::Itertools;
use rand::Rng;
use std::sync::Arc;
use crate::data_block::typedefs::{AttributeRows, AttributeRowsMap, AttributeRowsSlice};
use crate::data_block::typedefs::{
AttributeRows, AttributeRowsByColumnMap, AttributeRowsMap, AttributeRowsRefMap,
AttributeRowsSlice,
};
use crate::data_block::value::DataBlockValue;
use crate::processing::aggregator::typedefs::RecordsSet;
use crate::utils::collections::ordered_vec_intersection;
/// Represents a synthesis context, containing the information
/// required to synthesize new records
/// (common to seeded and unseeded synthesis)
pub struct SynthesizerContext<'data_block, 'attr_rows_map> {
pub struct SynthesizerContext {
/// Number of headers in the data block
headers_len: usize,
pub headers_len: usize,
/// Number of records in the data block
records_len: usize,
/// Maps a data block value to all the rows where it occurs
attr_rows_map: &'attr_rows_map AttributeRowsMap<'data_block>,
pub records_len: usize,
/// Cache used to store the rows where value combinations occurs
cache: SynthesizerCache<'data_block, AttributeRows>,
cache: SynthesizerCache<Arc<AttributeRows>>,
/// Reporting resolution used for data synthesis
resolution: usize,
}
impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map> {
impl SynthesizerContext {
/// Returns a new SynthesizerContext
/// # Arguments
/// * `headers_len` - Number of headers in the data block
/// * `records_len` - Number of records in the data block
/// * `attr_rows_map` - Maps a data block value to all the rows where it occurs
/// * `resolution` - Reporting resolution used for data synthesis
/// * `cache_max_size` - Maximum cache size allowed
#[inline]
pub fn new(
headers_len: usize,
records_len: usize,
attr_rows_map: &'attr_rows_map AttributeRowsMap<'data_block>,
resolution: usize,
cache_max_size: usize,
) -> SynthesizerContext<'data_block, 'attr_rows_map> {
) -> SynthesizerContext {
SynthesizerContext {
headers_len,
records_len,
attr_rows_map,
cache: SynthesizerCache::new(cache_max_size),
resolution,
}
}
/// Returns the configured mapping of a data block value to all
/// the rows where it occurs
#[inline]
pub fn get_attr_rows_map(&self) -> &AttributeRowsMap<'data_block> {
self.attr_rows_map
}
/// Returns the configured reporting resolution
#[inline]
pub fn get_resolution(&self) -> usize {
self.resolution
}
/// Samples an attribute based on its count
/// (the higher the count the greater the chance for the
/// attribute to be selected).
/// Returns `None` if all the counts are 0 or the map is empty
/// # Arguments
/// * `counts` - Maps an attribute to its count for sampling
pub fn sample_from_attr_counts(
counts: &AttributeCountMap<'data_block>,
) -> Option<&'data_block DataBlockValue> {
let mut res: Option<&'data_block DataBlockValue> = None;
pub fn sample_from_attr_counts(counts: &AttributeCountMap) -> Option<Arc<DataBlockValue>> {
let mut res: Option<Arc<DataBlockValue>> = None;
let total: usize = counts.values().sum();
if total != 0 {
let random = if total == 1 {
1
} else {
rand::thread_rng().gen_range(1..total)
};
let mut current_sim: usize = 0;
let random = rand::thread_rng().gen_range(1..=total);
let mut current_sum: usize = 0;
for (value, count) in counts.iter() {
current_sim += count;
if current_sim < random {
res = Some(value);
for (value, count) in counts.iter().sorted_by_key(|(_, c)| **c) {
if *count > 0 {
current_sum += count;
res = Some(value.clone());
if current_sum >= random {
break;
}
res = Some(value)
}
}
}
res
@ -100,52 +83,130 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map
/// * `synthesized_record` - Record synthesized so far
/// * `current_seed` - Current seed/record used for sampling
/// * `not_allowed_attr_set` - Attributes not allowed to be sampled
pub fn sample_next_attr(
/// * `attr_rows_map` - Maps a data block value to all the rows where it occurs
pub fn sample_next_attr_from_seed(
&mut self,
synthesized_record: &SynthesizedRecord<'data_block>,
current_seed: &SynthesizerSeedSlice<'data_block>,
not_allowed_attr_set: &NotAllowedAttrSet<'data_block>,
) -> Option<&'data_block DataBlockValue> {
let counts =
self.calc_next_attr_count(synthesized_record, current_seed, not_allowed_attr_set);
synthesized_record: &SynthesizedRecord,
current_seed: &SynthesizerSeedSlice,
not_allowed_attr_set: &NotAllowedAttrSet,
attr_rows_map: &AttributeRowsMap,
) -> Option<Arc<DataBlockValue>> {
let counts = self.calc_next_attr_count(
synthesized_record,
current_seed,
not_allowed_attr_set,
attr_rows_map,
);
SynthesizerContext::sample_from_attr_counts(&counts)
}
/// Samples the next attribute from the column at `column_index`.
/// Returns a tuple with
    /// (intersection of `current_attrs_rows` and the rows where the sampled value appears, sampled value)
/// # Arguments
/// * `synthesized_record` - Record synthesized so far
/// * `column_index` - Index of the column to sample from
    /// * `attr_rows_map_by_column` - Maps the column index -> data block value -> rows where the value appears
/// * `current_attrs_rows` - Rows where the so far sampled combination appear
/// * `empty_value` - Empty values on the synthetic data will be represented by this
pub fn sample_next_attr_from_column(
&mut self,
synthesized_record: &SynthesizedRecord,
column_index: usize,
attr_rows_map_by_column: &AttributeRowsByColumnMap,
current_attrs_rows: &AttributeRowsSlice,
empty_value: &Arc<String>,
) -> Option<(Arc<AttributeRows>, Arc<DataBlockValue>)> {
let cache_key = SynthesizerCacheKey::new(self.headers_len, synthesized_record);
let empty_block_value = Arc::new(DataBlockValue::new(column_index, empty_value.clone()));
let mut values_to_sample: AttributeRowsRefMap = AttributeRowsRefMap::default();
let mut counts = AttributeCountMap::default();
// calculate the row intersections for each value in the column
for (value, rows) in attr_rows_map_by_column[&column_index].iter() {
let new_cache_key = cache_key.new_with_value(value);
let rows_intersection = match self.cache.get(&new_cache_key) {
Some(cached_value) => cached_value.clone(),
None => {
let intersection = Arc::new(ordered_vec_intersection(current_attrs_rows, rows));
self.cache.insert(new_cache_key, intersection.clone());
intersection
}
};
values_to_sample.insert(value.clone(), rows_intersection);
}
// store the rows with empty values
let mut rows_with_empty_values: RecordsSet = values_to_sample[&empty_block_value]
.iter()
.cloned()
.collect();
for (value, rows) in values_to_sample.iter() {
if rows.len() < self.resolution {
                // if the combination containing the attribute appears in fewer
                // than `resolution` rows, we can't use it, so we tag it as an empty value
rows_with_empty_values.extend(rows.iter());
} else if **value != *empty_block_value {
// if we can use the combination containing the attribute
// gather its count for sampling
counts.insert(value.clone(), rows.len());
}
}
// if there are empty values that can be sampled, add them for sampling
if !rows_with_empty_values.is_empty() {
counts.insert(empty_block_value, rows_with_empty_values.len());
}
Self::sample_from_attr_counts(&counts).map(|sampled_value| {
(
values_to_sample.remove(&sampled_value).unwrap(),
sampled_value,
)
})
}
fn calc_current_attrs_rows(
&mut self,
cache_key: &SynthesizerCacheKey<'data_block>,
synthesized_record: &SynthesizedRecord<'data_block>,
) -> AttributeRows {
cache_key: &SynthesizerCacheKey,
synthesized_record: &SynthesizedRecord,
attr_rows_map: &AttributeRowsMap,
) -> Arc<AttributeRows> {
match self.cache.get(cache_key) {
Some(cache_value) => cache_value.clone(),
None => {
let mut current_attrs_rows: AttributeRows = AttributeRows::new();
let current_attrs_rows_arc: Arc<AttributeRows>;
if !synthesized_record.is_empty() {
for sr in synthesized_record.iter() {
current_attrs_rows = if current_attrs_rows.is_empty() {
self.attr_rows_map[sr].clone()
attr_rows_map[sr].clone()
} else {
ordered_vec_intersection(&current_attrs_rows, &self.attr_rows_map[sr])
ordered_vec_intersection(&current_attrs_rows, &attr_rows_map[sr])
}
}
} else {
current_attrs_rows = (0..self.records_len).collect();
}
current_attrs_rows_arc = Arc::new(current_attrs_rows);
self.cache
.insert(cache_key.clone(), current_attrs_rows.clone());
current_attrs_rows
.insert(cache_key.clone(), current_attrs_rows_arc.clone());
current_attrs_rows_arc
}
}
}
fn gen_attr_count_map(
&mut self,
cache_key: &mut SynthesizerCacheKey<'data_block>,
current_seed: &SynthesizerSeedSlice<'data_block>,
cache_key: &SynthesizerCacheKey,
current_seed: &SynthesizerSeedSlice,
current_attrs_rows: &AttributeRowsSlice,
not_allowed_attr_set: &NotAllowedAttrSet<'data_block>,
) -> AttributeCountMap<'data_block> {
not_allowed_attr_set: &NotAllowedAttrSet,
attr_rows_map: &AttributeRowsMap,
) -> AttributeCountMap {
let mut attr_count_map: AttributeCountMap = AttributeCountMap::default();
for value in current_seed.iter() {
@ -156,10 +217,10 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map
let count = match self.cache.get(&new_cache_key) {
Some(cached_value) => cached_value.len(),
None => {
let intersection = ordered_vec_intersection(
let intersection = Arc::new(ordered_vec_intersection(
current_attrs_rows,
self.attr_rows_map.get(value).unwrap(),
);
attr_rows_map.get(value).unwrap(),
));
let count = intersection.len();
self.cache.insert(new_cache_key, intersection);
@ -168,7 +229,7 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map
};
if count >= self.resolution {
attr_count_map.insert(value, count);
attr_count_map.insert(value.clone(), count);
}
}
}
@ -177,18 +238,21 @@ impl<'data_block, 'attr_rows_map> SynthesizerContext<'data_block, 'attr_rows_map
fn calc_next_attr_count(
&mut self,
synthesized_record: &SynthesizedRecord<'data_block>,
current_seed: &SynthesizerSeedSlice<'data_block>,
not_allowed_attr_set: &NotAllowedAttrSet<'data_block>,
) -> AttributeCountMap<'data_block> {
let mut cache_key = SynthesizerCacheKey::new(self.headers_len, synthesized_record);
let current_attrs_rows = self.calc_current_attrs_rows(&cache_key, synthesized_record);
synthesized_record: &SynthesizedRecord,
current_seed: &SynthesizerSeedSlice,
not_allowed_attr_set: &NotAllowedAttrSet,
attr_rows_map: &AttributeRowsMap,
) -> AttributeCountMap {
let cache_key = SynthesizerCacheKey::new(self.headers_len, synthesized_record);
let current_attrs_rows =
self.calc_current_attrs_rows(&cache_key, synthesized_record, attr_rows_map);
self.gen_attr_count_map(
&mut cache_key,
&cache_key,
current_seed,
&current_attrs_rows,
not_allowed_attr_set,
attr_rows_map,
)
}
}
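A standalone sketch (not the crate's code) of the count-proportional sampling performed by sample_from_attr_counts above: draw a number in 1..=total and return the first value whose cumulative count reaches it, skipping zero counts (this sketch assumes the rand crate, which the library already depends on):

use rand::Rng;

fn sample_weighted(counts: &[(char, usize)]) -> Option<char> {
    let total: usize = counts.iter().map(|(_, c)| *c).sum();
    if total == 0 {
        return None;
    }
    let random = rand::thread_rng().gen_range(1..=total);
    let mut current_sum = 0;
    for (value, count) in counts.iter() {
        if *count > 0 {
            current_sum += *count;
            if current_sum >= random {
                return Some(*value);
            }
        }
    }
    None
}

fn main() {
    // 'a' is expected roughly three times as often as 'b'; 'c' is never returned
    println!("{:?}", sample_weighted(&[('a', 3), ('b', 1), ('c', 0)]));
}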

View File

@ -1,8 +1,18 @@
/// Module defining the cache used during the synthesis process
pub mod cache;
/// Module defining the context used for data synthesis
pub mod context;
/// Module defining the seeded synthesis process
pub mod seeded;
/// Module defining the unseeded synthesis process
pub mod unseeded;
/// Type definitions related to the synthesis process
pub mod typedefs;
mod seeded_rows_synthesizer;
mod unseeded_rows_synthesizer;

View File

@ -1,56 +1,46 @@
use super::{
context::SynthesizerContext,
seeded_rows_synthesizer::SeededRowsSynthesizer,
typedefs::{
AvailableAttrsMap, NotAllowedAttrSet, SynthesizedRecord, SynthesizedRecords,
SynthesizedRecordsSlice, SynthesizerSeed, SynthesizerSeedSlice,
},
};
use fnv::FnvHashMap;
use itertools::izip;
use log::Level::Trace;
use log::{info, log_enabled, trace};
use itertools::{izip, Itertools};
use log::info;
use rand::{prelude::SliceRandom, thread_rng};
use std::time::Duration;
use std::sync::Arc;
use crate::{
data_block::{
block::DataBlock,
record::DataBlockRecord,
typedefs::{AttributeRowsMap, DataBlockRecords},
value::DataBlockValue,
},
measure_time,
data_block::{block::DataBlock, typedefs::AttributeRowsMap, value::DataBlockValue},
utils::{
math::{calc_percentage, iround_down},
reporting::ReportProgress,
time::ElapsedDuration,
threading::get_number_of_threads,
time::ElapsedDurationLogger,
},
};
#[derive(Debug)]
struct SeededSynthesizerDurations {
consolidate: Duration,
suppress: Duration,
synthesize_rows: Duration,
}
/// Represents all the information required to perform the seeded data synthesis
pub struct SeededSynthesizer<'data_block, 'attr_rows_map> {
/// Records to be synthesized
pub records: &'data_block DataBlockRecords,
/// Context used for the synthesis process
pub context: SynthesizerContext<'data_block, 'attr_rows_map>,
pub struct SeededSynthesizer {
/// Reference to the original data block
data_block: Arc<DataBlock>,
/// Maps a data block value to all the rows where it occurs
attr_rows_map: Arc<AttributeRowsMap>,
/// Reporting resolution used for data synthesis
resolution: usize,
/// Maximum cache size allowed
cache_max_size: usize,
/// Percentage already completed on the row synthesis step
synthesize_percentage: f64,
/// Percentage already completed on the consolidation step
consolidate_percentage: f64,
/// Percentage already completed on the suppression step
suppress_percentage: f64,
/// Elapsed durations on each step for benchmarking
durations: SeededSynthesizerDurations,
}
impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map> {
impl SeededSynthesizer {
/// Returns a new SeededSynthesizer
/// # Arguments
/// * `data_block` - Sensitive data to be synthesized
@ -59,28 +49,19 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
/// * `cache_max_size` - Maximum cache size allowed
#[inline]
pub fn new(
data_block: &'data_block DataBlock,
attr_rows_map: &'attr_rows_map AttributeRowsMap<'data_block>,
data_block: Arc<DataBlock>,
attr_rows_map: Arc<AttributeRowsMap>,
resolution: usize,
cache_max_size: usize,
) -> SeededSynthesizer<'data_block, 'attr_rows_map> {
) -> SeededSynthesizer {
SeededSynthesizer {
records: &data_block.records,
context: SynthesizerContext::new(
data_block.headers.len(),
data_block.records.len(),
data_block,
attr_rows_map,
resolution,
cache_max_size,
),
synthesize_percentage: 0.0,
consolidate_percentage: 0.0,
suppress_percentage: 0.0,
durations: SeededSynthesizerDurations {
consolidate: Duration::default(),
suppress: Duration::default(),
synthesize_rows: Duration::default(),
},
}
}
@ -89,64 +70,71 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
/// # Arguments
/// * `progress_reporter` - Will be used to report the processing
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
pub fn run<T>(&mut self, progress_reporter: &mut Option<T>) -> SynthesizedRecords<'data_block>
pub fn run<T>(&mut self, progress_reporter: &mut Option<T>) -> SynthesizedRecords
where
T: ReportProgress,
{
let mut synthesized_records: SynthesizedRecords<'data_block> = SynthesizedRecords::new();
let mut synthesized_records: SynthesizedRecords = SynthesizedRecords::new();
if !self.data_block.records.is_empty() {
let mut rows_synthesizers: Vec<SeededRowsSynthesizer> = self.build_rows_synthesizers();
self.synthesize_percentage = 0.0;
self.consolidate_percentage = 0.0;
self.suppress_percentage = 0.0;
self.synthesize_rows(&mut synthesized_records, progress_reporter);
self.consolidate(&mut synthesized_records, progress_reporter);
self.synthesize_rows(
&mut synthesized_records,
&mut rows_synthesizers,
progress_reporter,
);
self.consolidate(
&mut synthesized_records,
progress_reporter,
// use the first context to leverage already cached intersections
&mut rows_synthesizers[0].context,
);
self.suppress(&mut synthesized_records, progress_reporter);
}
synthesized_records
}
#[inline]
fn synthesize_row(
&mut self,
seed: &'data_block DataBlockRecord,
) -> SynthesizedRecord<'data_block> {
let current_seed: SynthesizerSeed = seed.values.iter().collect();
let mut synthesized_record = SynthesizedRecord::default();
let not_allowed_attr_set = NotAllowedAttrSet::default();
fn build_rows_synthesizers(&self) -> Vec<SeededRowsSynthesizer> {
let chunk_size = ((self.data_block.records.len() as f64) / (get_number_of_threads() as f64))
.ceil() as usize;
let mut rows_synthesizers: Vec<SeededRowsSynthesizer> = Vec::default();
loop {
let next = self.context.sample_next_attr(
&synthesized_record,
&current_seed,
&not_allowed_attr_set,
);
match next {
Some(value) => {
synthesized_record.insert(value);
for c in &self.data_block.records.iter().chunks(chunk_size) {
rows_synthesizers.push(SeededRowsSynthesizer::new(
SynthesizerContext::new(
self.data_block.headers.len(),
self.data_block.records.len(),
self.resolution,
self.cache_max_size,
),
c.cloned().collect(),
self.attr_rows_map.clone(),
));
}
None => break,
}
}
synthesized_record
rows_synthesizers
}
#[inline]
fn count_not_used_attrs(
&mut self,
synthesized_records: &SynthesizedRecordsSlice<'data_block>,
) -> AvailableAttrsMap<'data_block> {
synthesized_records: &SynthesizedRecordsSlice,
) -> AvailableAttrsMap {
let mut available_attrs: AvailableAttrsMap = AvailableAttrsMap::default();
// go through the pairs (original_record, synthesized_record) and count how many
// attributes were not used
for (original_record, synthesized_record) in
izip!(self.records.iter(), synthesized_records.iter())
izip!(self.data_block.records.iter(), synthesized_records.iter())
{
for d in original_record.values.iter() {
if !synthesized_record.contains(d) {
let attr = available_attrs.entry(d).or_insert(0);
let attr = available_attrs.entry(d.clone()).or_insert(0);
*attr += 1;
}
}
@ -156,16 +144,16 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
fn calc_available_attrs(
&mut self,
synthesized_records: &SynthesizedRecordsSlice<'data_block>,
) -> AvailableAttrsMap<'data_block> {
synthesized_records: &SynthesizedRecordsSlice,
) -> AvailableAttrsMap {
let mut available_attrs = self.count_not_used_attrs(synthesized_records);
let resolution_f64 = self.context.get_resolution() as f64;
let resolution_f64 = self.resolution as f64;
// add attributes for consolidation
for (attr, value) in self.context.get_attr_rows_map().iter() {
for (attr, value) in self.attr_rows_map.iter() {
let n_rows = value.len();
if n_rows >= self.context.get_resolution() {
if n_rows >= self.resolution {
let target_attr_count = iround_down(n_rows as f64, resolution_f64)
- (n_rows as isize)
+ match available_attrs.get(attr) {
@ -175,7 +163,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
if target_attr_count > 0 {
// insert/update the final target count
available_attrs.insert(attr, target_attr_count);
available_attrs.insert(attr.clone(), target_attr_count);
} else {
// remove negative and zero values
available_attrs.remove(attr);
@ -192,18 +180,17 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
#[inline]
fn calc_not_allowed_attrs(
&mut self,
available_attrs: &mut AvailableAttrsMap<'data_block>,
) -> NotAllowedAttrSet<'data_block> {
self.context
.get_attr_rows_map()
available_attrs: &mut AvailableAttrsMap,
) -> NotAllowedAttrSet {
self.attr_rows_map
.keys()
.filter_map(|attr| match available_attrs.get(attr) {
// not on available attributes
None => Some(*attr),
None => Some(attr.clone()),
Some(at) => {
if *at <= 0 {
// on available attributes, but count <= 0
Some(*attr)
Some(attr.clone())
} else {
None
}
@ -215,18 +202,20 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
#[inline]
fn consolidate_record(
&mut self,
available_attrs: &mut AvailableAttrsMap<'data_block>,
current_seed: &SynthesizerSeedSlice<'data_block>,
) -> SynthesizedRecord<'data_block> {
available_attrs: &mut AvailableAttrsMap,
current_seed: &SynthesizerSeedSlice,
context: &mut SynthesizerContext,
) -> SynthesizedRecord {
let mut not_allowed_attr_set: NotAllowedAttrSet =
self.calc_not_allowed_attrs(available_attrs);
let mut synthesized_record = SynthesizedRecord::default();
loop {
let next = self.context.sample_next_attr(
let next = context.sample_next_attr_from_seed(
&synthesized_record,
current_seed,
&not_allowed_attr_set,
&self.attr_rows_map,
);
match next {
@ -236,9 +225,9 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
if next_count <= 1 {
available_attrs.remove(&value);
not_allowed_attr_set.insert(value);
not_allowed_attr_set.insert(value.clone());
} else {
available_attrs.insert(value, next_count - 1);
available_attrs.insert(value.clone(), next_count - 1);
}
synthesized_record.insert(value);
}
@ -249,13 +238,13 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
fn consolidate<T>(
&mut self,
synthesized_records: &mut SynthesizedRecords<'data_block>,
synthesized_records: &mut SynthesizedRecords,
progress_reporter: &mut Option<T>,
context: &mut SynthesizerContext,
) where
T: ReportProgress,
{
measure_time!(
|| {
let _duration_logger = ElapsedDurationLogger::new("consolidation");
info!("consolidating...");
let mut available_attrs = self.calc_available_attrs(synthesized_records);
@ -266,27 +255,26 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
while !available_attrs.is_empty() {
self.update_consolidate_progress(n_processed, total_f64, progress_reporter);
synthesized_records
.push(self.consolidate_record(&mut available_attrs, &current_seed));
synthesized_records.push(self.consolidate_record(
&mut available_attrs,
&current_seed,
context,
));
n_processed = total - available_attrs.len();
}
self.update_consolidate_progress(n_processed, total_f64, progress_reporter);
},
(self.durations.consolidate)
);
}
#[inline]
fn count_synthesized_records_attrs(
&mut self,
synthesized_records: &mut SynthesizedRecords<'data_block>,
) -> FnvHashMap<&'data_block DataBlockValue, isize> {
let mut current_counts: FnvHashMap<&'data_block DataBlockValue, isize> =
FnvHashMap::default();
synthesized_records: &mut SynthesizedRecords,
) -> FnvHashMap<Arc<DataBlockValue>, isize> {
let mut current_counts: FnvHashMap<Arc<DataBlockValue>, isize> = FnvHashMap::default();
for r in synthesized_records.iter() {
for v in r.iter() {
let count = current_counts.entry(v).or_insert(0);
let count = current_counts.entry(v.clone()).or_insert(0);
*count += 1;
}
}
@ -296,16 +284,16 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
#[inline]
fn calc_exceeded_count_attrs(
&mut self,
current_counts: &FnvHashMap<&'data_block DataBlockValue, isize>,
) -> FnvHashMap<&'data_block DataBlockValue, isize> {
let mut targets: FnvHashMap<&'data_block DataBlockValue, isize> = FnvHashMap::default();
current_counts: &FnvHashMap<Arc<DataBlockValue>, isize>,
) -> FnvHashMap<Arc<DataBlockValue>, isize> {
let mut targets: FnvHashMap<Arc<DataBlockValue>, isize> = FnvHashMap::default();
for (attr, rows) in self.context.get_attr_rows_map().iter() {
if rows.len() >= self.context.get_resolution() {
let t = current_counts[attr]
- iround_down(rows.len() as f64, self.context.get_resolution() as f64);
for (attr, rows) in self.attr_rows_map.iter() {
if rows.len() >= self.resolution {
let t =
current_counts[attr] - iround_down(rows.len() as f64, self.resolution as f64);
if t > 0 {
targets.insert(attr, t);
targets.insert(attr.clone(), t);
}
}
}
@ -314,18 +302,17 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
fn suppress<T>(
&mut self,
synthesized_records: &mut SynthesizedRecords<'data_block>,
synthesized_records: &mut SynthesizedRecords,
progress_reporter: &mut Option<T>,
) where
T: ReportProgress,
{
measure_time!(
|| {
let _duration_logger = ElapsedDurationLogger::new("suppression");
info!("suppressing...");
let current_counts: FnvHashMap<&'data_block DataBlockValue, isize> =
let current_counts: FnvHashMap<Arc<DataBlockValue>, isize> =
self.count_synthesized_records_attrs(synthesized_records);
let mut targets: FnvHashMap<&'data_block DataBlockValue, isize> =
let mut targets: FnvHashMap<Arc<DataBlockValue>, isize> =
self.calc_exceeded_count_attrs(&current_counts);
let total = synthesized_records.len() as f64;
let mut n_processed = 0;
@ -340,13 +327,13 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
for attr in r.iter() {
match targets.get(attr).cloned() {
None => {
new_record.insert(attr);
new_record.insert(attr.clone());
}
Some(attr_count) => {
if attr_count == 1 {
targets.remove(attr);
} else {
targets.insert(attr, attr_count - 1);
targets.insert(attr.clone(), attr_count - 1);
}
}
}
@ -355,35 +342,33 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
}
synthesized_records.retain(|r| !r.is_empty());
self.update_suppress_progress(n_processed, total, progress_reporter);
},
(self.durations.suppress)
);
}
fn synthesize_rows<T>(
&mut self,
synthesized_records: &mut SynthesizedRecords<'data_block>,
synthesized_records: &mut SynthesizedRecords,
rows_synthesizers: &mut Vec<SeededRowsSynthesizer>,
progress_reporter: &mut Option<T>,
) where
T: ReportProgress,
{
measure_time!(
|| {
info!("synthesizing rows...");
let _duration_logger = ElapsedDurationLogger::new("rows synthesis");
let mut n_processed = 0;
let records = self.records;
let total = records.len() as f64;
for seed in records.iter() {
self.update_synthesize_progress(n_processed, total, progress_reporter);
n_processed += 1;
synthesized_records.push(self.synthesize_row(seed));
}
self.update_synthesize_progress(n_processed, total, progress_reporter);
},
(self.durations.synthesize_rows)
info!(
"synthesizing rows with {} thread(s)...",
get_number_of_threads()
);
let total = self.data_block.records.len() as f64;
SeededRowsSynthesizer::synthesize_all(
total,
synthesized_records,
rows_synthesizers,
progress_reporter,
);
self.update_synthesize_progress(self.data_block.records.len(), total, progress_reporter);
}
#[inline]
@ -396,7 +381,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
T: ReportProgress,
{
if let Some(r) = progress_reporter {
self.synthesize_percentage = calc_percentage(n_processed, total);
self.synthesize_percentage = calc_percentage(n_processed as f64, total);
r.report(self.calc_overall_progress());
}
}
@ -411,7 +396,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
T: ReportProgress,
{
if let Some(r) = progress_reporter {
self.consolidate_percentage = calc_percentage(n_processed, total);
self.consolidate_percentage = calc_percentage(n_processed as f64, total);
r.report(self.calc_overall_progress());
}
}
@ -426,7 +411,7 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
T: ReportProgress,
{
if let Some(r) = progress_reporter {
self.suppress_percentage = calc_percentage(n_processed, total);
self.suppress_percentage = calc_percentage(n_processed as f64, total);
r.report(self.calc_overall_progress());
}
}
@ -438,9 +423,3 @@ impl<'data_block, 'attr_rows_map> SeededSynthesizer<'data_block, 'attr_rows_map>
+ self.suppress_percentage * 0.2
}
}
impl<'data_block, 'attr_rows_map> Drop for SeededSynthesizer<'data_block, 'attr_rows_map> {
fn drop(&mut self) {
trace!("seed synthesizer durations: {:#?}", self.durations);
}
}
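build_rows_synthesizers above splits the records into one chunk per thread, with the chunk size being the record count divided by the thread count, rounded up. A standalone sketch of that split (not the crate's code, assuming itertools is available):

use itertools::Itertools;

fn main() {
    let records: Vec<usize> = (0..10).collect();
    let number_of_threads = 4_usize;
    let chunk_size =
        ((records.len() as f64) / (number_of_threads as f64)).ceil() as usize;
    let mut chunks: Vec<Vec<usize>> = Vec::default();
    for c in &records.iter().chunks(chunk_size) {
        chunks.push(c.cloned().collect());
    }
    assert_eq!(chunk_size, 3);
    assert_eq!(chunks.len(), 4); // 3 + 3 + 3 + 1 records
}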

View File

@ -0,0 +1,127 @@
use super::{
context::SynthesizerContext,
typedefs::{NotAllowedAttrSet, SynthesizedRecord, SynthesizedRecords, SynthesizerSeed},
};
use crate::{
data_block::typedefs::AttributeRowsMap,
utils::reporting::{SendableProgressReporter, SendableProgressReporterRef},
};
use std::sync::Arc;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
#[cfg(feature = "rayon")]
use std::sync::Mutex;
use crate::{
data_block::{record::DataBlockRecord, typedefs::DataBlockRecords},
utils::reporting::ReportProgress,
};
pub struct SeededRowsSynthesizer {
pub context: SynthesizerContext,
pub records: DataBlockRecords,
pub attr_rows_map: Arc<AttributeRowsMap>,
}
impl SeededRowsSynthesizer {
#[inline]
pub fn new(
context: SynthesizerContext,
records: DataBlockRecords,
attr_rows_map: Arc<AttributeRowsMap>,
) -> SeededRowsSynthesizer {
SeededRowsSynthesizer {
context,
records,
attr_rows_map,
}
}
#[cfg(feature = "rayon")]
#[inline]
pub fn synthesize_all<T>(
total: f64,
synthesized_records: &mut SynthesizedRecords,
rows_synthesizers: &mut Vec<SeededRowsSynthesizer>,
progress_reporter: &mut Option<T>,
) where
T: ReportProgress,
{
let sendable_pr = Arc::new(Mutex::new(
progress_reporter
.as_mut()
.map(|r| SendableProgressReporter::new(total, 0.5, r)),
));
synthesized_records.par_extend(
rows_synthesizers
.par_iter_mut()
.flat_map(|rs| rs.synthesize_rows(&mut sendable_pr.clone())),
);
}
#[cfg(not(feature = "rayon"))]
#[inline]
pub fn synthesize_all<T>(
total: f64,
synthesized_records: &mut SynthesizedRecords,
rows_synthesizers: &mut Vec<SeededRowsSynthesizer>,
progress_reporter: &mut Option<T>,
) where
T: ReportProgress,
{
let mut sendable_pr = progress_reporter
.as_mut()
.map(|r| SendableProgressReporter::new(total, 0.5, r));
synthesized_records.extend(
rows_synthesizers
.iter_mut()
.flat_map(|rs| rs.synthesize_rows(&mut sendable_pr)),
);
}
#[inline]
fn synthesize_rows<T>(
&mut self,
progress_reporter: &mut SendableProgressReporterRef<T>,
) -> SynthesizedRecords
where
T: ReportProgress,
{
let mut synthesized_records = SynthesizedRecords::default();
let records = self.records.clone();
for seed in records.iter() {
synthesized_records.push(self.synthesize_row(seed));
SendableProgressReporter::update_progress(progress_reporter, 1.0);
}
synthesized_records
}
#[inline]
fn synthesize_row(&mut self, seed: &DataBlockRecord) -> SynthesizedRecord {
let current_seed: &SynthesizerSeed = &seed.values;
let mut synthesized_record = SynthesizedRecord::default();
let not_allowed_attr_set = NotAllowedAttrSet::default();
loop {
let next = self.context.sample_next_attr_from_seed(
&synthesized_record,
current_seed,
&not_allowed_attr_set,
&self.attr_rows_map,
);
match next {
Some(value) => {
synthesized_record.insert(value);
}
None => break,
}
}
synthesized_record
}
}
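With the rayon feature enabled, synthesize_all above lets each rows synthesizer build its own vector of records and flattens the partial results into one. A standalone sketch of that pattern (not the crate's code, assuming rayon is available):

use rayon::prelude::*;

fn main() {
    // one vector of partial results per worker
    let mut workers: Vec<Vec<usize>> = vec![vec![1, 2], vec![3], vec![4, 5, 6]];
    let mut all: Vec<usize> = Vec::new();
    all.par_extend(workers.par_iter_mut().flat_map(|w| w.clone()));
    assert_eq!(all.len(), 6);
}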

View File

@ -1,33 +1,31 @@
use fnv::{FnvHashMap, FnvHashSet};
use std::collections::BTreeSet;
use std::{collections::BTreeSet, sync::Arc};
use crate::data_block::value::DataBlockValue;
/// If the data block value were added to the synthesized record, this maps to the
/// number of rows the resulting attribute combination will be part of
pub type AttributeCountMap<'data_block_value> =
FnvHashMap<&'data_block_value DataBlockValue, usize>;
pub type AttributeCountMap = FnvHashMap<Arc<DataBlockValue>, usize>;
/// The seeds used for the current synthesis step (aka current record being processed)
pub type SynthesizerSeed<'data_block_value> = Vec<&'data_block_value DataBlockValue>;
pub type SynthesizerSeed = Vec<Arc<DataBlockValue>>;
/// Slice of SynthesizerSeed
pub type SynthesizerSeedSlice<'data_block_value> = [&'data_block_value DataBlockValue];
pub type SynthesizerSeedSlice = [Arc<DataBlockValue>];
/// Attributes not allowed to be sampled in a particular data synthesis stage
pub type NotAllowedAttrSet<'data_block_value> = FnvHashSet<&'data_block_value DataBlockValue>;
pub type NotAllowedAttrSet = FnvHashSet<Arc<DataBlockValue>>;
/// Record synthesized at a particular stage
pub type SynthesizedRecord<'data_block_value> = BTreeSet<&'data_block_value DataBlockValue>;
pub type SynthesizedRecord = BTreeSet<Arc<DataBlockValue>>;
/// Vector of synthesized records
pub type SynthesizedRecords<'data_block_value> = Vec<SynthesizedRecord<'data_block_value>>;
pub type SynthesizedRecords = Vec<SynthesizedRecord>;
/// Slice of SynthesizedRecords
pub type SynthesizedRecordsSlice<'data_block_value> = [SynthesizedRecord<'data_block_value>];
pub type SynthesizedRecordsSlice = [SynthesizedRecord];
/// How many attributes (of the ones available) should be added during the
/// consolidation step for the count to match a multiple of the reporting resolution
/// (rounded down)
pub type AvailableAttrsMap<'data_block_value> =
FnvHashMap<&'data_block_value DataBlockValue, isize>;
pub type AvailableAttrsMap = FnvHashMap<Arc<DataBlockValue>, isize>;
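
Replacing the `&'data_block_value DataBlockValue` borrows with `Arc<DataBlockValue>` is what frees these collections from the data block's lifetime so they can be sent across threads. A small self-contained illustration of the difference, using a stand-in value type rather than the real `DataBlockValue`:

use std::{collections::BTreeSet, sync::Arc, thread};

#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct Value(String); // stand-in for DataBlockValue

fn main() {
    let shared = Arc::new(Value("A1:yes".into()));
    // Cloning an Arc only bumps a reference count; the value itself is shared.
    let record: BTreeSet<Arc<Value>> = [shared.clone()].into_iter().collect();

    // Because the set owns Arcs instead of borrowing, it can be moved to another
    // thread, which is what the multi-threaded synthesizers rely on.
    let handle = thread::spawn(move || record.len());
    assert_eq!(handle.join().unwrap(), 1);
}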

Просмотреть файл

@@ -0,0 +1,163 @@
use super::{
context::SynthesizerContext, typedefs::SynthesizedRecords,
unseeded_rows_synthesizer::UnseededRowsSynthesizer,
};
use log::info;
use std::sync::Arc;
use crate::{
data_block::{block::DataBlock, typedefs::AttributeRowsByColumnMap},
utils::{
math::calc_percentage, reporting::ReportProgress, threading::get_number_of_threads,
time::ElapsedDurationLogger,
},
};
/// Represents all the information required to perform the unseeded data synthesis
pub struct UnseededSynthesizer {
/// Reference to the original data block
data_block: Arc<DataBlock>,
/// Maps a data block value to all the rows where it occurs grouped by column
attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>,
/// Reporting resolution used for data synthesis
resolution: usize,
/// Maximum cache size allowed
cache_max_size: usize,
/// Empty values on the synthetic data will be represented by this
empty_value: Arc<String>,
/// Percentage already completed on the row synthesis step
synthesize_percentage: f64,
}
impl UnseededSynthesizer {
/// Returns a new UnseededSynthesizer
/// # Arguments
/// * `data_block` - Sensitive data to be synthesized
/// * `attr_rows_map_by_column` - Maps a data block value to all the rows where it occurs grouped by column
/// * `resolution` - Reporting resolution used for data synthesis
/// * `cache_max_size` - Maximum cache size allowed
/// * `empty_value` - Empty values on the synthetic data will be represented by this
#[inline]
pub fn new(
data_block: Arc<DataBlock>,
attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>,
resolution: usize,
cache_max_size: usize,
empty_value: Arc<String>,
) -> UnseededSynthesizer {
UnseededSynthesizer {
data_block,
attr_rows_map_by_column,
resolution,
cache_max_size,
empty_value,
synthesize_percentage: 0.0,
}
}
/// Performs the row synthesis
/// Returns the synthesized records
/// # Arguments
/// * `progress_reporter` - Will be used to report the processing
/// progress (`ReportProgress` trait). If `None`, nothing will be reported
pub fn run<T>(&mut self, progress_reporter: &mut Option<T>) -> SynthesizedRecords
where
T: ReportProgress,
{
let mut synthesized_records: SynthesizedRecords = SynthesizedRecords::new();
if !self.data_block.records.is_empty() {
let mut rows_synthesizers: Vec<UnseededRowsSynthesizer> =
self.build_rows_synthesizers();
self.synthesize_percentage = 0.0;
self.synthesize_rows(
&mut synthesized_records,
&mut rows_synthesizers,
progress_reporter,
);
}
synthesized_records
}
#[inline]
fn build_rows_synthesizers(&self) -> Vec<UnseededRowsSynthesizer> {
let mut total_size = self.data_block.records.len();
let chunk_size = ((total_size as f64) / (get_number_of_threads() as f64)).ceil() as usize;
let mut rows_synthesizers: Vec<UnseededRowsSynthesizer> = Vec::default();
loop {
if total_size > chunk_size {
rows_synthesizers.push(UnseededRowsSynthesizer::new(
SynthesizerContext::new(
self.data_block.headers.len(),
self.data_block.records.len(),
self.resolution,
self.cache_max_size,
),
chunk_size,
self.attr_rows_map_by_column.clone(),
self.empty_value.clone(),
));
total_size -= chunk_size;
} else {
rows_synthesizers.push(UnseededRowsSynthesizer::new(
SynthesizerContext::new(
self.data_block.headers.len(),
self.data_block.records.len(),
self.resolution,
self.cache_max_size,
),
total_size,
self.attr_rows_map_by_column.clone(),
self.empty_value.clone(),
));
break;
}
}
rows_synthesizers
}
fn synthesize_rows<T>(
&mut self,
synthesized_records: &mut SynthesizedRecords,
rows_synthesizers: &mut Vec<UnseededRowsSynthesizer>,
progress_reporter: &mut Option<T>,
) where
T: ReportProgress,
{
let _duration_logger = ElapsedDurationLogger::new("rows synthesis");
info!(
"synthesizing rows with {} thread(s)...",
get_number_of_threads()
);
let total = self.data_block.records.len() as f64;
UnseededRowsSynthesizer::synthesize_all(
total,
synthesized_records,
rows_synthesizers,
progress_reporter,
);
self.update_synthesize_progress(self.data_block.records.len(), total, progress_reporter);
}
#[inline]
fn update_synthesize_progress<T>(
&mut self,
n_processed: usize,
total: f64,
progress_reporter: &mut Option<T>,
) where
T: ReportProgress,
{
if let Some(r) = progress_reporter {
self.synthesize_percentage = calc_percentage(n_processed as f64, total);
r.report(self.synthesize_percentage);
}
}
}
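
`build_rows_synthesizers` above splits the records into one chunk per thread: the chunk size is the record count divided by the thread count, rounded up, and the last synthesizer takes whatever remains. A tiny sketch of the same arithmetic with plain numbers (no synthesizer structs involved):

fn chunk_sizes(mut total_size: usize, n_threads: usize) -> Vec<usize> {
    let chunk_size = ((total_size as f64) / (n_threads as f64)).ceil() as usize;
    let mut sizes = Vec::new();
    loop {
        if total_size > chunk_size {
            sizes.push(chunk_size);
            total_size -= chunk_size;
        } else {
            sizes.push(total_size); // last chunk gets the remainder
            break;
        }
    }
    sizes
}

fn main() {
    // 10 records across 4 threads -> chunks of 3, 3, 3 and 1
    assert_eq!(chunk_sizes(10, 4), vec![3, 3, 3, 1]);
}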

Просмотреть файл

@@ -0,0 +1,128 @@
use super::{
context::SynthesizerContext,
typedefs::{SynthesizedRecord, SynthesizedRecords},
};
use crate::{
data_block::typedefs::{AttributeRows, AttributeRowsByColumnMap},
utils::reporting::{SendableProgressReporter, SendableProgressReporterRef},
};
use rand::{prelude::SliceRandom, thread_rng};
use std::sync::Arc;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
#[cfg(feature = "rayon")]
use std::sync::Mutex;
use crate::utils::reporting::ReportProgress;
pub struct UnseededRowsSynthesizer {
context: SynthesizerContext,
chunk_size: usize,
column_indexes: Vec<usize>,
attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>,
empty_value: Arc<String>,
}
impl UnseededRowsSynthesizer {
#[inline]
pub fn new(
context: SynthesizerContext,
chunk_size: usize,
attr_rows_map_by_column: Arc<AttributeRowsByColumnMap>,
empty_value: Arc<String>,
) -> UnseededRowsSynthesizer {
UnseededRowsSynthesizer {
context,
chunk_size,
column_indexes: attr_rows_map_by_column.keys().cloned().collect(),
attr_rows_map_by_column,
empty_value,
}
}
#[cfg(feature = "rayon")]
#[inline]
pub fn synthesize_all<T>(
total: f64,
synthesized_records: &mut SynthesizedRecords,
rows_synthesizers: &mut Vec<UnseededRowsSynthesizer>,
progress_reporter: &mut Option<T>,
) where
T: ReportProgress,
{
let sendable_pr = Arc::new(Mutex::new(
progress_reporter
.as_mut()
.map(|r| SendableProgressReporter::new(total, 1.0, r)),
));
synthesized_records.par_extend(
rows_synthesizers
.par_iter_mut()
.flat_map(|rs| rs.synthesize_rows(&mut sendable_pr.clone())),
);
}
#[cfg(not(feature = "rayon"))]
#[inline]
pub fn synthesize_all<T>(
total: f64,
synthesized_records: &mut SynthesizedRecords,
rows_synthesizers: &mut Vec<UnseededRowsSynthesizer>,
progress_reporter: &mut Option<T>,
) where
T: ReportProgress,
{
let mut sendable_pr = progress_reporter
.as_mut()
.map(|r| SendableProgressReporter::new(total, 1.0, r));
synthesized_records.extend(
rows_synthesizers
.iter_mut()
.flat_map(|rs| rs.synthesize_rows(&mut sendable_pr)),
);
}
#[inline]
fn synthesize_rows<T>(
&mut self,
progress_reporter: &mut SendableProgressReporterRef<T>,
) -> SynthesizedRecords
where
T: ReportProgress,
{
let mut synthesized_records = SynthesizedRecords::default();
for _ in 0..self.chunk_size {
synthesized_records.push(self.synthesize_row());
SendableProgressReporter::update_progress(progress_reporter, 1.0);
}
synthesized_records
}
#[inline]
fn synthesize_row(&mut self) -> SynthesizedRecord {
let mut synthesized_record = SynthesizedRecord::default();
let mut current_attrs_rows: Arc<AttributeRows> =
Arc::new((0..self.context.records_len).collect());
self.column_indexes.shuffle(&mut thread_rng());
for column_index in self.column_indexes.iter() {
if let Some((next_attrs_rows, sample)) = self.context.sample_next_attr_from_column(
&synthesized_record,
*column_index,
&self.attr_rows_map_by_column,
&current_attrs_rows,
&self.empty_value,
) {
current_attrs_rows = next_attrs_rows;
synthesized_record.insert(sample);
}
}
synthesized_record
}
}
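
Each unseeded row is assembled column by column in a random order: `synthesize_row` starts from the full set of row indexes and, for every column it visits, samples an attribute and narrows `current_attrs_rows` to the rows still consistent with everything sampled so far. A rough self-contained sketch of that narrowing loop, with a toy sampler in place of `SynthesizerContext::sample_next_attr_from_column`:

use rand::{prelude::SliceRandom, thread_rng};

// Toy sampler: pretend every column keeps only the even row indexes.
fn sample_for_column(column: usize, current_rows: &[usize]) -> (Vec<usize>, String) {
    let narrowed: Vec<usize> = current_rows.iter().copied().filter(|r| r % 2 == 0).collect();
    (narrowed, format!("col{}:even", column))
}

fn main() {
    let mut column_indexes: Vec<usize> = (0..3).collect();
    column_indexes.shuffle(&mut thread_rng()); // visit columns in random order

    let mut current_rows: Vec<usize> = (0..6).collect(); // start from all records
    let mut synthesized_record: Vec<String> = Vec::new();

    for column_index in column_indexes.iter() {
        let (next_rows, sample) = sample_for_column(*column_index, &current_rows);
        current_rows = next_rows; // keep only rows consistent with the samples so far
        synthesized_record.push(sample);
    }
    assert_eq!(synthesized_record.len(), 3);
}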

Просмотреть файл

@@ -4,8 +4,8 @@ use std::hash::Hash;
/// Given two sorted vectors, calculates the intersection between them.
/// Returns the result intersection as a new vector
/// # Arguments
/// - `a`: first sorted vector
/// - `b`: second sorted vector
/// * `a` - first sorted vector
/// * `b` - second sorted vector
#[inline]
pub fn ordered_vec_intersection<T>(a: &[T], b: &[T]) -> Vec<T>
where
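
For reference, a quick usage sketch of the intersection helper on two sorted inputs (integers are assumed to satisfy the elided trait bounds, as they do at the call sites in this crate):

let a = vec![1usize, 3, 5, 7];
let b = vec![3usize, 4, 7, 9];
assert_eq!(ordered_vec_intersection(&a, &b), vec![3, 7]);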

Просмотреть файл

@@ -54,6 +54,6 @@ pub fn calc_n_combinations_range(n: usize, range: &[usize]) -> u64 {
/// Calculates the percentage of processed elements up to a total
#[inline]
pub fn calc_percentage(n_processed: usize, total: f64) -> f64 {
(n_processed as f64) * 100.0 / total
pub fn calc_percentage(n_processed: f64, total: f64) -> f64 {
n_processed * 100.0 / total
}
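
Callers now convert counts to `f64` before calling; the arithmetic itself is unchanged, e.g.:

assert!((calc_percentage(30.0, 120.0) - 25.0).abs() < f64::EPSILON);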

Просмотреть файл

@@ -1,8 +1,14 @@
/// Module for collection utilities
pub mod collections;
/// Module for math utilities
pub mod math;
/// Module for reporting utilities
pub mod reporting;
/// Module for threading utilities
pub mod threading;
/// Module for time utilities
pub mod time;

Просмотреть файл

@@ -1,11 +1,6 @@
use super::ReportProgress;
use log::{log, Level};
/// Implement this trait to inform progress
pub trait ReportProgress {
/// Receives the updated progress
fn report(&mut self, new_progress: f64);
}
/// Simple progress reporter using the default logger.
/// * It will log progress using the configured `log_level`
/// * It will only log once for every 1% completed
@@ -16,8 +11,8 @@ pub struct LoggerProgressReporter {
impl LoggerProgressReporter {
/// Returns a new LoggerProgressReporter
/// # Arguments:
/// * `log_level`: which log level use to log progress
/// # Arguments
/// * `log_level` - which log level to use to log progress
pub fn new(log_level: Level) -> LoggerProgressReporter {
LoggerProgressReporter {
progress: 0.0,

Просмотреть файл

@@ -0,0 +1,11 @@
mod report_progress;
mod logger_progress_reporter;
mod sendable_progress_reporter;
pub use report_progress::ReportProgress;
pub use logger_progress_reporter::LoggerProgressReporter;
pub use sendable_progress_reporter::{SendableProgressReporter, SendableProgressReporterRef};

Просмотреть файл

@@ -0,0 +1,5 @@
/// Implement this trait to inform progress
pub trait ReportProgress {
/// Receives the updated progress
fn report(&mut self, new_progress: f64);
}

Просмотреть файл

@@ -0,0 +1,92 @@
#[cfg(feature = "rayon")]
use std::sync::{Arc, Mutex};
use crate::utils::{math::calc_percentage, reporting::ReportProgress};
/// Progress reporter to be used in parallel environments
/// (should be wrapped with `Arc` and `Mutex` if multi-threading is enabled).
/// It will calculate `proportion * (n_processed * 100.0 / total)` and report
/// it to the main reporter
pub struct SendableProgressReporter<'main_reporter, T>
where
T: ReportProgress,
{
total: f64,
n_processed: f64,
proportion: f64,
main_reporter: &'main_reporter mut T,
}
impl<'main_reporter, T> SendableProgressReporter<'main_reporter, T>
where
T: ReportProgress,
{
/// Creates a new SendableProgressReporter
/// # Arguments
/// * `total` - Total number of steps to be reported
/// * `proportion` - Value to multiply the percentage before reporting to
/// main_reporter
/// * `main_reporter` - Main reporter to which this should report
pub fn new(
total: f64,
proportion: f64,
main_reporter: &'main_reporter mut T,
) -> SendableProgressReporter<T> {
SendableProgressReporter {
total,
proportion,
n_processed: 0.0,
main_reporter,
}
}
/// Updates `reporter` by adding `value_to_add` and reporting
/// to the main reporter in a thread safe way
#[inline]
pub fn update_progress(reporter: &mut SendableProgressReporterRef<T>, value_to_add: f64)
where
T: ReportProgress,
{
#[cfg(feature = "rayon")]
if let Ok(guard) = &mut reporter.lock() {
if let Some(r) = guard.as_mut() {
r.report(value_to_add);
}
}
#[cfg(not(feature = "rayon"))]
if let Some(r) = reporter {
r.report(value_to_add);
}
}
}
impl<'main_reporter, T> ReportProgress for SendableProgressReporter<'main_reporter, T>
where
T: ReportProgress,
{
/// Will add `value_to_add` to `n_processed` and call the main reporter with
/// `proportion * (n_processed * 100.0 / total)`
fn report(&mut self, value_to_add: f64) {
self.n_processed += value_to_add;
self.main_reporter
.report(self.proportion * calc_percentage(self.n_processed, self.total))
}
}
#[cfg(feature = "rayon")]
/// People using this should correctly handle race conditions with a `Mutex`
/// (see `SendableProgressReporterRef`), so we inform the compiler this struct is thread safe
unsafe impl<'main_reporter, T> Send for SendableProgressReporter<'main_reporter, T> where
T: ReportProgress
{
}
#[cfg(feature = "rayon")]
/// Use this to refer to SendableProgressReporter if multi-threading is enabled
pub type SendableProgressReporterRef<'main_reporter, T> =
Arc<Mutex<Option<SendableProgressReporter<'main_reporter, T>>>>;
#[cfg(not(feature = "rayon"))]
/// Use this to refer to SendableProgressReporter if multi-threading is disabled
pub type SendableProgressReporterRef<'main_reporter, T> =
Option<SendableProgressReporter<'main_reporter, T>>;
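
A short usage sketch showing how a stage-level `SendableProgressReporter` scales its progress before forwarding it to a main reporter (the `VecReporter` test double is made up for illustration; only the items defined in this module are real):

struct VecReporter {
    reported: Vec<f64>, // hypothetical test double collecting reported values
}

impl ReportProgress for VecReporter {
    fn report(&mut self, new_progress: f64) {
        self.reported.push(new_progress);
    }
}

fn sketch() {
    let mut main_reporter = VecReporter { reported: Vec::new() };
    // 200 steps in total, and this stage accounts for 50% of the overall work.
    let mut sendable = Some(SendableProgressReporter::new(200.0, 0.5, &mut main_reporter));
    // Without the "rayon" feature, SendableProgressReporterRef is just this Option.
    SendableProgressReporter::update_progress(&mut sendable, 50.0);
    // The main reporter receives 0.5 * (50 / 200 * 100) = 12.5.
}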

Просмотреть файл

@@ -0,0 +1,34 @@
#[cfg(feature = "rayon")]
use rayon;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
pub fn get_number_of_threads() -> usize {
#[cfg(feature = "rayon")]
{
rayon::current_num_threads()
}
#[cfg(not(feature = "rayon"))]
{
1
}
}
#[cfg(feature = "rayon")]
#[cfg_attr(feature = "pyo3", pyfunction)]
#[allow(unused_must_use)]
/// Sets the number of threads used for parallel processing
/// # Arguments
/// * `n` - number of threads
pub fn set_number_of_threads(n: usize) {
rayon::ThreadPoolBuilder::new()
.num_threads(n)
.build_global();
}
#[cfg(feature = "pyo3")]
pub fn register(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(set_number_of_threads, m)?)?;
Ok(())
}
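
A minimal sketch of how these helpers are meant to be used from Rust with the `rayon` feature enabled (the ignored `Result` of `build_global` means a second call after the global pool exists is silently a no-op, which is presumably why `unused_must_use` is allowed):

fn configure_parallelism() {
    set_number_of_threads(4); // build a 4-thread global pool, or keep the existing one
    assert!(get_number_of_threads() >= 1);
}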

Просмотреть файл

@@ -16,7 +16,7 @@ pub struct ElapsedDuration<'lifetime> {
impl<'lifetime> ElapsedDuration<'lifetime> {
/// Returns a new ElapsedDuration
/// # Arguments
/// - `result`: Duration where to sum the elapsed duration when the `ElapsedDuration`
/// * `result` - Duration where to sum the elapsed duration when the `ElapsedDuration`
/// instance goes out of scope
pub fn new(result: &mut Duration) -> ElapsedDuration {
ElapsedDuration {
@@ -41,10 +41,10 @@
}
impl ElapsedDurationLogger {
pub fn new(message: String) -> ElapsedDurationLogger {
pub fn new<S: Into<String>>(message: S) -> ElapsedDurationLogger {
ElapsedDurationLogger {
_start: Instant::now(),
_message: message,
_message: message.into(),
}
}
}
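
With the new `Into<String>` bound, call sites no longer need an explicit `String::from`; both forms compile:

let _a = ElapsedDurationLogger::new("rows synthesis");           // &str
let _b = ElapsedDurationLogger::new(String::from("evaluation")); // String still works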

Просмотреть файл

@@ -14,4 +14,5 @@ crate-type = ["cdylib"]
log = { version = "0.4", features = ["std"] }
csv = { version = "1.1" }
pyo3 = { version = "0.15", features = ["extension-module"] }
sds-core = { path = "../core", features = ["pyo3"] }
sds-core = { path = "../core", features = ["pyo3", "rayon"] }
env_logger = { version = "0.9" }

Просмотреть файл

@@ -0,0 +1,133 @@
use csv::ReaderBuilder;
use pyo3::prelude::*;
use sds_core::{
data_block::{
block::DataBlock, csv_block_creator::CsvDataBlockCreator, csv_io_error::CsvIOError,
data_block_creator::DataBlockCreator,
},
processing::{
aggregator::{aggregated_data::AggregatedData, Aggregator},
generator::{generated_data::GeneratedData, Generator, SynthesisMode},
},
utils::reporting::LoggerProgressReporter,
};
use std::sync::Arc;
#[pyclass]
/// Processor exposing the main features
pub struct SDSProcessor {
data_block: Arc<DataBlock>,
}
#[pymethods]
impl SDSProcessor {
/// Creates a new processor by reading the content of a given file
/// # Arguments
/// * `path` - File to be read to build the data block
/// * `delimiter` - CSV/TSV separator for the content on `path`
/// * `use_columns` - Column names to be used (if `[]` use all columns)
/// * `sensitive_zeros` - Column names containing sensitive zeros
/// (if `[]` no columns are considered to have sensitive zeros)
/// * `record_limit` - Use only the first `record_limit` records (if `0` use all records)
#[new]
pub fn new(
path: &str,
delimiter: char,
use_columns: Vec<String>,
sensitive_zeros: Vec<String>,
record_limit: usize,
) -> Result<SDSProcessor, CsvIOError> {
match CsvDataBlockCreator::create(
ReaderBuilder::new()
.delimiter(delimiter as u8)
.from_path(path),
&use_columns,
&sensitive_zeros,
record_limit,
) {
Ok(data_block) => Ok(SDSProcessor { data_block }),
Err(err) => Err(CsvIOError::new(err)),
}
}
#[staticmethod]
/// Load the SDS Processor for the data block linked to `aggregated_data`
pub fn from_aggregated_data(aggregated_data: &AggregatedData) -> SDSProcessor {
SDSProcessor {
data_block: aggregated_data.data_block.clone(),
}
}
#[inline]
/// Returns the number of records on the data block
pub fn number_of_records(&self) -> usize {
self.data_block.number_of_records()
}
#[inline]
/// Returns the number of records on the data block protected by `resolution`
pub fn protected_number_of_records(&self, resolution: usize) -> usize {
self.data_block.protected_number_of_records(resolution)
}
#[inline]
/// Normalizes the reporting length based on the number of selected headers.
/// Returns the normalized value
/// # Arguments
/// * `reporting_length` - Reporting length to be normalized (0 means use all columns)
pub fn normalize_reporting_length(&self, reporting_length: usize) -> usize {
self.data_block.normalize_reporting_length(reporting_length)
}
/// Builds the aggregated data for the content
/// using the specified `reporting_length` and `sensitivity_threshold`.
/// Returns the resulting `AggregatedData`.
/// # Arguments
/// * `reporting_length` - Maximum length to compute attribute combinations
/// * `sensitivity_threshold` - Sensitivity threshold to filter record attributes
/// (0 means no suppression)
pub fn aggregate(
&self,
reporting_length: usize,
sensitivity_threshold: usize,
) -> AggregatedData {
let mut progress_reporter: Option<LoggerProgressReporter> = None;
let mut aggregator = Aggregator::new(self.data_block.clone());
aggregator.aggregate(
reporting_length,
sensitivity_threshold,
&mut progress_reporter,
)
}
/// Synthesizes the content using the specified `resolution` and
/// returns the generated data
/// # Arguments
/// * `cache_max_size` - Maximum cache size used during the synthesis process
/// * `resolution` - Reporting resolution to be used
/// * `empty_value` - Empty values on the synthetic data will be represented by this
/// * `seeded` - True for seeded synthesis, False for unseeded
pub fn generate(
&self,
cache_max_size: usize,
resolution: usize,
empty_value: String,
seeded: bool,
) -> GeneratedData {
let mut progress_reporter: Option<LoggerProgressReporter> = None;
let mut generator = Generator::new(self.data_block.clone());
let mode = if seeded {
SynthesisMode::Seeded
} else {
SynthesisMode::Unseeded
};
generator.generate(
resolution,
cache_max_size,
empty_value,
mode,
&mut progress_reporter,
)
}
}
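
Although `SDSProcessor` is exposed to Python through pyo3, it is ordinary Rust and can be driven directly; a hedged sketch of the intended call sequence (the file name and parameter values below are made up for illustration):

fn sketch() -> Result<(), CsvIOError> {
    // Read every column of a CSV, no sensitive zeros, no record limit.
    let processor = SDSProcessor::new("sensitive.csv", ',', vec![], vec![], 0)?;

    // Aggregate up to 3-way attribute combinations, with no sensitivity suppression.
    let aggregated = processor.aggregate(3, 0);

    // Unseeded synthesis at resolution 10, 100_000-entry cache, "" as the empty value.
    let generated = processor.generate(100_000, 10, "".to_owned(), false);

    let _ = (aggregated, generated);
    Ok(())
}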

Просмотреть файл

@@ -1,160 +1,24 @@
//! This crate will generate python bindings for the main features
//! of the `sds_core` library.
use csv::ReaderBuilder;
use pyo3::{exceptions::PyIOError, prelude::*};
use data_processor::SDSProcessor;
use pyo3::prelude::*;
use sds_core::{
data_block::{
block::{DataBlock, DataBlockCreator},
csv_block_creator::CsvDataBlockCreator,
},
processing::{
aggregator::{typedefs::AggregatedCountByLenMap, Aggregator},
generator::Generator,
},
utils::reporting::LoggerProgressReporter,
processing::{aggregator::aggregated_data, evaluator},
utils::threading,
};
/// Creates a data block by reading the content of a given file
/// # Arguments
/// * `path` - File to be read to build the data block
/// * `delimiter` - CSV/TSV separator for the content on `path`
/// * `use_columns` - Column names to be used (if `[]` use all columns)
/// * `sensitive_zeros` - Column names containing sensitive zeros
/// (if `[]` no columns are considered to have sensitive zeros)
/// * `record_limit` - Use only the first `record_limit` records (if `0` use all records)
#[pyfunction]
pub fn create_data_block_from_file(
path: &str,
delimiter: char,
use_columns: Vec<String>,
sensitive_zeros: Vec<String>,
record_limit: usize,
) -> PyResult<DataBlock> {
match CsvDataBlockCreator::create(
ReaderBuilder::new()
.delimiter(delimiter as u8)
.from_path(path),
&use_columns,
&sensitive_zeros,
record_limit,
) {
Ok(data_block) => Ok(data_block),
Err(err) => Err(PyIOError::new_err(format!(
"error reading data block: {}",
err
))),
}
}
/// Builds the protected and aggregated data for the content in `data_block`
/// using the specified `reporting_length` and `resolution`.
/// The result is written to `sensitive_aggregates_path` and `reportable_aggregates_path`.
///
/// The function returns a tuple (combo_counts, rare_counts).
///
/// * `combo_counts` - computed number of distinct combinations grouped by combination length (not protected)
/// * `rare_counts` - computed number of rare combinations (`count < resolution`) grouped by combination length (not protected)
///
/// # Arguments
/// * `data_block` - Data block with the content to be aggregated
/// * `sensitive_aggregates_path` - Path to write the sensitive aggregates
/// * `reportable_aggregates_path` - Path to write the aggregates with protected count
/// (rounded down to the nearest multiple of `resolution`)
/// * `delimiter` - CSV/TSV separator for the content to be written to
/// `sensitive_aggregates_path` and `reportable_aggregates_path`
/// * `reporting_length` - Maximum length to compute attribute combinations
/// * `resolution` - Reporting resolution to be used
#[pyfunction]
pub fn aggregate_and_write(
data_block: &DataBlock,
sensitive_aggregates_path: &str,
reportable_aggregates_path: &str,
delimiter: char,
reporting_length: usize,
resolution: usize,
) -> PyResult<(AggregatedCountByLenMap, AggregatedCountByLenMap)> {
let mut progress_reporter: Option<LoggerProgressReporter> = None;
let mut aggregator = Aggregator::new(data_block);
let mut aggregated_data = aggregator.aggregate(reporting_length, 0, &mut progress_reporter);
let combo_count = aggregator.calc_combinations_count_by_len(&aggregated_data.aggregates_count);
let rare_count = aggregator
.calc_rare_combinations_count_by_len(&aggregated_data.aggregates_count, resolution);
if let Err(err) = aggregator.write_aggregates_count(
&aggregated_data.aggregates_count,
sensitive_aggregates_path,
delimiter,
resolution,
false,
) {
return Err(PyIOError::new_err(format!(
"error writing sensitive aggregates: {}",
err
)));
}
Aggregator::protect_aggregates_count(&mut aggregated_data.aggregates_count, resolution);
match aggregator.write_aggregates_count(
&aggregated_data.aggregates_count,
reportable_aggregates_path,
delimiter,
resolution,
true,
) {
Ok(_) => Ok((combo_count, rare_count)),
Err(err) => Err(PyIOError::new_err(format!(
"error writing reportable aggregates: {}",
err
))),
}
}
/// Synthesizes the `data_block` using the specified `resolution`.
/// The result is written to `synthetic_path`.
///
/// The function returns the expansion ratio:
/// `Synthetic data length / Sensitive data length` (header not included)
///
/// # Arguments
/// * `data_block` - Data block with the content to be synthesized
/// * `synthetic_path` - Path to write the synthetic data
/// * `synthetic_delimiter` - CSV/TSV separator for the content to be written to `synthetic_path`
/// * `cache_max_size` - Maximum cache size used during the synthesis process
/// * `resolution` - Reporting resolution to be used
#[pyfunction]
pub fn generate_seeded_and_write(
data_block: &DataBlock,
synthetic_path: &str,
synthetic_delimiter: char,
cache_max_size: usize,
resolution: usize,
) -> PyResult<f64> {
let mut progress_reporter: Option<LoggerProgressReporter> = None;
let mut generator = Generator::new(data_block);
let generated_data = generator.generate(resolution, cache_max_size, "", &mut progress_reporter);
match generator.write_records(
&generated_data.synthetic_data,
synthetic_path,
synthetic_delimiter as u8,
) {
Ok(_) => Ok(generated_data.expansion_ratio),
Err(err) => Err(PyIOError::new_err(format!(
"error writing synthetic data: {}",
err
))),
}
}
/// Module that exposes the main processor
pub mod data_processor;
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn sds(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(create_data_block_from_file, m)?)?;
m.add_function(wrap_pyfunction!(aggregate_and_write, m)?)?;
m.add_function(wrap_pyfunction!(generate_seeded_and_write, m)?)?;
fn sds(py: Python, m: &PyModule) -> PyResult<()> {
env_logger::init();
m.add_class::<SDSProcessor>()?;
threading::register(py, m)?;
aggregated_data::register(py, m)?;
evaluator::register(py, m)?;
Ok(())
}

Просмотреть файл

@@ -19,7 +19,7 @@ csv = { version = "1.1" }
web-sys = { version = "0.3", features = [ "console" ]}
sds-core = { path = "../core" }
js-sys = { version = "0.3" }
serde = { version = "1.0", features = [ "derive" ] }
serde = { version = "1.0", features = [ "derive", "rc" ] }
# The `console_error_panic_hook` crate provides better debugging of panics by
# logging them with `console.error`. This is great for development, but requires

Просмотреть файл

@@ -1,116 +0,0 @@
use js_sys::Object;
use js_sys::Reflect::set;
use log::error;
use sds_core::data_block::block::DataBlock;
use sds_core::data_block::typedefs::DataBlockHeaders;
use sds_core::processing::aggregator::typedefs::{AggregatedCountByLenMap, AggregatesCountMap};
use sds_core::processing::aggregator::{Aggregator, PrivacyRiskSummary};
use sds_core::utils::reporting::ReportProgress;
use sds_core::utils::time::ElapsedDurationLogger;
use serde::Serialize;
use wasm_bindgen::JsValue;
use crate::utils::js::serializers::{serialize_aggregates_count, serialize_privacy_risk};
use crate::{match_or_return_undefined, set_or_return_undefined};
#[derive(Serialize)]
pub struct AggregatedCombination {
combination_key: String,
count: usize,
length: usize,
}
impl AggregatedCombination {
#[inline]
pub fn new(combination_key: String, count: usize, length: usize) -> AggregatedCombination {
AggregatedCombination {
combination_key,
count,
length,
}
}
}
impl From<AggregatedCombination> for JsValue {
#[inline]
fn from(aggregated_combination: AggregatedCombination) -> Self {
match_or_return_undefined!(JsValue::from_serde(&aggregated_combination))
}
}
pub struct AggregatedResult<'data_block> {
pub headers: &'data_block DataBlockHeaders,
pub aggregates_count: AggregatesCountMap<'data_block>,
pub rare_combinations_count_by_len: AggregatedCountByLenMap,
pub combinations_count_by_len: AggregatedCountByLenMap,
pub combinations_sum_by_len: AggregatedCountByLenMap,
pub privacy_risk: PrivacyRiskSummary,
}
impl<'data_block> From<AggregatedResult<'data_block>> for JsValue {
#[inline]
fn from(aggregated_result: AggregatedResult) -> Self {
let _duration_logger =
ElapsedDurationLogger::new(String::from("AggregatedResult conversion to JsValue"));
let result = Object::new();
set_or_return_undefined!(
&result,
&"aggregatedCombinations".into(),
&serialize_aggregates_count(
&aggregated_result.aggregates_count,
aggregated_result.headers
),
);
set_or_return_undefined!(
&result,
&"rareCombinationsCountByLen".into(),
&match_or_return_undefined!(JsValue::from_serde(
&aggregated_result.rare_combinations_count_by_len
))
);
set_or_return_undefined!(
&result,
&"combinationsCountByLen".into(),
&match_or_return_undefined!(JsValue::from_serde(
&aggregated_result.combinations_count_by_len
))
);
set_or_return_undefined!(
&result,
&"combinationsSumByLen".into(),
&match_or_return_undefined!(JsValue::from_serde(
&aggregated_result.combinations_sum_by_len
))
);
set_or_return_undefined!(
&result,
&"privacyRisk".into(),
&serialize_privacy_risk(&aggregated_result.privacy_risk)
);
result.into()
}
}
pub fn aggregate<'data_block, T: ReportProgress>(
data_block: &'data_block DataBlock,
reporting_length: usize,
resolution: usize,
progress_reporter: &mut Option<T>,
) -> Result<AggregatedResult<'data_block>, String> {
let mut aggregator = Aggregator::new(data_block);
let aggregates_data = aggregator.aggregate(reporting_length, 0, progress_reporter);
Ok(AggregatedResult {
headers: &data_block.headers,
rare_combinations_count_by_len: aggregator
.calc_rare_combinations_count_by_len(&aggregates_data.aggregates_count, resolution),
combinations_count_by_len: aggregator
.calc_combinations_count_by_len(&aggregates_data.aggregates_count),
combinations_sum_by_len: aggregator
.calc_combinations_sum_by_len(&aggregates_data.aggregates_count),
privacy_risk: aggregator.calc_privacy_risk(&aggregates_data.aggregates_count, resolution),
aggregates_count: aggregates_data.aggregates_count,
})
}

Просмотреть файл

@@ -1,170 +0,0 @@
use js_sys::Reflect::set;
use js_sys::{Array, Function, Object};
use log::{error, info};
use sds_core::processing::aggregator::typedefs::AggregatedCountByLenMap;
use sds_core::processing::aggregator::Aggregator;
use sds_core::processing::evaluator::typedefs::PreservationByCountBuckets;
use sds_core::processing::evaluator::Evaluator;
use sds_core::utils::reporting::ReportProgress;
use sds_core::utils::time::ElapsedDurationLogger;
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsValue;
use crate::aggregator::{aggregate, AggregatedResult};
use crate::utils::js::deserializers::{
deserialize_csv_data, deserialize_sensitive_zeros, deserialize_use_columns,
};
use crate::utils::js::js_progress_reporter::JsProgressReporter;
use crate::utils::js::serializers::serialize_buckets;
use crate::{match_or_return_undefined, set_or_return_undefined};
struct EvaluatedResult<'data_block> {
sensitive_aggregated_result: AggregatedResult<'data_block>,
synthetic_aggregated_result: AggregatedResult<'data_block>,
leakage_count_by_len: AggregatedCountByLenMap,
fabricated_count_by_len: AggregatedCountByLenMap,
preservation_by_count_buckets: PreservationByCountBuckets,
combination_loss: f64,
record_expansion: f64,
}
impl<'data_block> From<EvaluatedResult<'data_block>> for JsValue {
#[inline]
fn from(evaluated_result: EvaluatedResult) -> Self {
let _duration_logger =
ElapsedDurationLogger::new(String::from("EvaluatedResult conversion to JsValue"));
let result = Object::new();
set_or_return_undefined!(
&result,
&"sensitiveAggregatedResult".into(),
&evaluated_result.sensitive_aggregated_result.into(),
);
set_or_return_undefined!(
&result,
&"syntheticAggregatedResult".into(),
&evaluated_result.synthetic_aggregated_result.into(),
);
set_or_return_undefined!(
&result,
&"leakageCountByLen".into(),
&match_or_return_undefined!(JsValue::from_serde(
&evaluated_result.leakage_count_by_len
))
);
set_or_return_undefined!(
&result,
&"fabricatedCountByLen".into(),
&match_or_return_undefined!(JsValue::from_serde(
&evaluated_result.fabricated_count_by_len
))
);
set_or_return_undefined!(
&result,
&"preservationByCountBuckets".into(),
&serialize_buckets(&evaluated_result.preservation_by_count_buckets)
);
set_or_return_undefined!(
&result,
&"combinationLoss".into(),
&match_or_return_undefined!(JsValue::from_serde(&evaluated_result.combination_loss))
);
set_or_return_undefined!(
&result,
&"recordExpansion".into(),
&match_or_return_undefined!(JsValue::from_serde(&evaluated_result.record_expansion))
);
result.into()
}
}
#[wasm_bindgen]
#[allow(clippy::too_many_arguments)]
pub fn evaluate(
sensitive_csv_data: Array,
synthetic_csv_data: Array,
use_columns: Array,
sensitive_zeros: Array,
record_limit: usize,
reporting_length: usize,
resolution: usize,
progress_callback: Function,
) -> JsValue {
let _duration_logger = ElapsedDurationLogger::new(String::from("evaluation process"));
let use_columns_vec = match_or_return_undefined!(deserialize_use_columns(use_columns));
let sensitive_zeros_vec =
match_or_return_undefined!(deserialize_sensitive_zeros(sensitive_zeros));
info!("aggregating sensitive data...");
let sensitive_data_block = match_or_return_undefined!(deserialize_csv_data(
sensitive_csv_data,
&use_columns_vec,
&sensitive_zeros_vec,
record_limit,
));
let mut sensitive_aggregated_result = match_or_return_undefined!(aggregate(
&sensitive_data_block,
reporting_length,
resolution,
&mut Some(JsProgressReporter::new(&progress_callback, &|p| p * 0.40))
));
info!("aggregating synthetic data...");
let synthetic_data_block = match_or_return_undefined!(deserialize_csv_data(
synthetic_csv_data,
&use_columns_vec,
&sensitive_zeros_vec,
record_limit,
));
let synthetic_aggregated_result = match_or_return_undefined!(aggregate(
&synthetic_data_block,
reporting_length,
resolution,
&mut Some(JsProgressReporter::new(&progress_callback, &|p| {
40.0 + (p * 0.40)
}))
));
let mut evaluator_instance = Evaluator::default();
info!("evaluating synthetic data based on sensitive data...");
let buckets = evaluator_instance.calc_preservation_by_count(
&sensitive_aggregated_result.aggregates_count,
&synthetic_aggregated_result.aggregates_count,
resolution,
);
let leakage_count_by_len = evaluator_instance.calc_leakage_count(
&sensitive_aggregated_result.aggregates_count,
&synthetic_aggregated_result.aggregates_count,
resolution,
);
let fabricated_count_by_len = evaluator_instance.calc_fabricated_count(
&sensitive_aggregated_result.aggregates_count,
&synthetic_aggregated_result.aggregates_count,
);
let combination_loss = evaluator_instance.calc_combination_loss(&buckets);
Aggregator::protect_aggregates_count(
&mut sensitive_aggregated_result.aggregates_count,
resolution,
);
JsProgressReporter::new(&progress_callback, &|p| p).report(100.0);
EvaluatedResult {
leakage_count_by_len,
fabricated_count_by_len,
combination_loss,
preservation_by_count_buckets: buckets,
record_expansion: (synthetic_aggregated_result
.privacy_risk
.total_number_of_records as f64)
/ (sensitive_aggregated_result
.privacy_risk
.total_number_of_records as f64),
sensitive_aggregated_result,
synthetic_aggregated_result,
}
.into()
}

Просмотреть файл

@@ -1,45 +0,0 @@
use js_sys::{Array, Function};
use log::error;
use sds_core::{processing::generator::Generator, utils::time::ElapsedDurationLogger};
use wasm_bindgen::prelude::*;
use crate::{
match_or_return_undefined,
utils::js::{
deserializers::{
deserialize_csv_data, deserialize_sensitive_zeros, deserialize_use_columns,
},
js_progress_reporter::JsProgressReporter,
},
};
#[wasm_bindgen]
pub fn generate(
csv_data: Array,
use_columns: Array,
sensitive_zeros: Array,
record_limit: usize,
resolution: usize,
cache_size: usize,
progress_callback: Function,
) -> JsValue {
let _duration_logger = ElapsedDurationLogger::new(String::from("generation process"));
let data_block = match_or_return_undefined!(deserialize_csv_data(
csv_data,
&match_or_return_undefined!(deserialize_use_columns(use_columns)),
&match_or_return_undefined!(deserialize_sensitive_zeros(sensitive_zeros)),
record_limit,
));
let mut generator = Generator::new(&data_block);
match_or_return_undefined!(JsValue::from_serde(
&generator
.generate(
resolution,
cache_size,
"",
&mut Some(JsProgressReporter::new(&progress_callback, &|p| p))
)
.synthetic_data
))
}

Просмотреть файл

@@ -1,141 +1,3 @@
//! This crate will generate wasm bindings for the main features
//! of the `sds_core` library:
//!
//! # Init Logger
//! ```typescript
//! init_logger(
//! level_str: string
//! ): boolean
//! ```
//!
//! Initializes logging using a particular log level.
//! Returns `true` if successfully initialized, `false` otherwise
//!
//! ## Arguments:
//! * `level_str` - String containing a valid log level
//! (`off`, `error`, `warn`, `info`, `debug` or `trace`)
//!
//! # Generate
//! ```typescript
//! generate(
//! csv_data: string[][],
//! use_columns: string[],
//! sensitive_zeros: string[],
//! record_limit: number,
//! resolution: number,
//! cache_size: number,
//! progress_callback: (progress: number) => void
//! ): string[][]
//! ```
//!
//! Synthesizes the `csv_data` using the configured parameters and returns it
//!
//! ## Arguments
//! * `csv_data` - Data to be synthesized
//! * `csv_data[0]` - Should be the headers
//! * `csv_data[1...]` - Should be the records
//! * `use_columns` - Column names to be used (if `[]` use all columns)
//! * `sensitive_zeros` - Column names containing sensitive zeros
//! (if `[]` no columns are considered to have sensitive zeros)
//! * `record_limit` - Use only the first `record_limit` records (if `0` use all records)
//! * `resolution` - Reporting resolution to be used
//! * `cache_size` - Maximum cache size used during the synthesis process
//! * `progress_callback` - Callback that informs the processing percentage (0.0 % - 100.0 %)
//!
//! # Evaluate
//! ```typescript
//! interface IAggregatedCombination {
//! combination_key: number
//! count: number
//! length: number
//! }
//!
//! interface IAggregatedCombinations {
//! [name: string]: IAggregatedCombination
//! }
//!
//! interface IAggregatedCountByLen {
//! [length: number]: number
//! }
//!
//! interface IPrivacyRiskSummary {
//! totalNumberOfRecords: number
//! totalNumberOfCombinations: number
//! recordsWithUniqueCombinationsCount: number
//! recordsWithRareCombinationsCount: number
//! uniqueCombinationsCount: number
//! rareCombinationsCount: number
//! recordsWithUniqueCombinationsProportion: number
//! recordsWithRareCombinationsProportion: number
//! uniqueCombinationsProportion: number
//! rareCombinationsProportion: number
//! }
//!
//! interface IAggregatedResult {
//! aggregatedCombinations?: IAggregatedCombinations
//! rareCombinationsCountByLen?: IAggregatedCountByLen
//! combinationsCountByLen?: IAggregatedCountByLen
//! combinationsSumByLen?: IAggregatedCountByLen
//! privacyRisk?: IPrivacyRiskSummary
//! }
//!
//! interface IPreservationByCountBucket {
//! size: number
//! preservationSum: number
//! lengthSum: number
//! }
//!
//! interface IPreservationByCountBuckets {
//! [bucket_index: number]: IPreservationByCountBucket
//! }
//!
//! interface IEvaluatedResult {
//! sensitiveAggregatedResult?: IAggregatedResult
//! syntheticAggregatedResult?: IAggregatedResult
//! leakageCountByLen?: IAggregatedCountByLen
//! fabricatedCountByLen?: IAggregatedCountByLen
//! preservationByCountBuckets?: IPreservationByCountBuckets
//! combinationLoss?: number
//! recordExpansion?: number
//! }
//!
//! evaluate(
//! sensitive_csv_data: string[][],
//! synthetic_csv_data: string[][],
//! use_columns: string[],
//! sensitive_zeros: string[],
//! record_limit: number,
//! reporting_length: number,
//! resolution: number,
//! progress_callback: (progress: number) => void
//! ): IEvaluatedResult
//! ```
//! Evaluates the synthetic data based on the sensitive data and produces a `IEvaluatedResult`
//!
//! ## Arguments
//! * `sensitive_csv_data` - Sensitive data to be evaluated
//! * `sensitive_csv_data[0]` - Should be the headers
//! * `sensitive_csv_data[1...]` - Should be the records
//! * `synthetic_csv_data` - Synthetic data produced from the sensitive data
//! * `synthetic_csv_data[0]` - Should be the headers
//! * `synthetic_csv_data[1...]` - Should be the records
//! * `use_columns` - Column names to be used (if `[]` use all columns)
//! * `sensitive_zeros` - Column names containing sensitive zeros
//! (if `[]` no columns are considered to have sensitive zeros)
//! * `record_limit` - Use only the first `record_limit` records (if `0` use all records)
//! * `reporting_length` - Maximum length to compute attribute combinations
//! for analysis
//! * `resolution` - Reporting resolution to be used
//! * `progress_callback` - Callback that informs the processing percentage (0.0 % - 100.0 %)
pub mod processing;
#[doc(hidden)]
pub mod aggregator;
#[doc(hidden)]
pub mod evaluator;
#[doc(hidden)]
pub mod generator;
#[doc(hidden)]
pub mod utils;

Просмотреть файл

@@ -0,0 +1,31 @@
use serde::Serialize;
use std::convert::TryFrom;
use wasm_bindgen::{JsCast, JsValue};
use crate::utils::js::ts_definitions::JsAggregateCountAndLength;
#[derive(Serialize)]
pub struct WasmAggregateCountAndLength {
count: usize,
length: usize,
}
impl WasmAggregateCountAndLength {
#[inline]
pub fn new(count: usize, length: usize) -> WasmAggregateCountAndLength {
WasmAggregateCountAndLength { count, length }
}
}
impl TryFrom<WasmAggregateCountAndLength> for JsAggregateCountAndLength {
type Error = JsValue;
#[inline]
fn try_from(
aggregate_count_and_length: WasmAggregateCountAndLength,
) -> Result<Self, Self::Error> {
JsValue::from_serde(&aggregate_count_and_length)
.map_err(|err| JsValue::from(err.to_string()))
.map(|c| c.unchecked_into())
}
}

Просмотреть файл

@@ -0,0 +1,228 @@
use super::aggregate_count_and_length::WasmAggregateCountAndLength;
use js_sys::{Object, Reflect::set};
use sds_core::{
processing::aggregator::aggregated_data::AggregatedData, utils::time::ElapsedDurationLogger,
};
use std::{
convert::TryFrom,
ops::{Deref, DerefMut},
};
use wasm_bindgen::{prelude::*, JsCast};
use crate::utils::js::ts_definitions::{
JsAggregateCountAndLength, JsAggregateCountByLen, JsAggregateResult, JsAggregatesCount,
JsPrivacyRiskSummary, JsResult,
};
#[wasm_bindgen]
pub struct WasmAggregateResult {
aggregated_data: AggregatedData,
}
impl WasmAggregateResult {
#[inline]
pub fn default() -> WasmAggregateResult {
WasmAggregateResult {
aggregated_data: AggregatedData::default(),
}
}
#[inline]
pub fn new(aggregated_data: AggregatedData) -> WasmAggregateResult {
WasmAggregateResult { aggregated_data }
}
}
#[wasm_bindgen]
impl WasmAggregateResult {
#[wasm_bindgen(getter)]
#[wasm_bindgen(js_name = "reportingLength")]
pub fn reporting_length(&self) -> usize {
self.aggregated_data.reporting_length
}
#[wasm_bindgen(js_name = "aggregatesCountToJs")]
pub fn aggregates_count_to_js(
&self,
combination_delimiter: &str,
) -> JsResult<JsAggregatesCount> {
let result = Object::new();
for (agg, count) in self.aggregated_data.aggregates_count.iter() {
set(
&result,
&agg.format_str_using_headers(
&self.aggregated_data.data_block.headers,
combination_delimiter,
)
.into(),
&JsAggregateCountAndLength::try_from(WasmAggregateCountAndLength::new(
count.count,
agg.len(),
))?
.into(),
)?;
}
Ok(result.unchecked_into::<JsAggregatesCount>())
}
#[wasm_bindgen(js_name = "rareCombinationsCountByLenToJs")]
pub fn rare_combinations_count_by_len_to_js(
&self,
resolution: usize,
) -> JsResult<JsAggregateCountByLen> {
let count = self
.aggregated_data
.calc_rare_combinations_count_by_len(resolution);
Ok(JsValue::from_serde(&count)
.map_err(|err| JsValue::from(err.to_string()))?
.unchecked_into::<JsAggregateCountByLen>())
}
#[wasm_bindgen(js_name = "combinationsCountByLenToJs")]
pub fn combinations_count_by_len_to_js(&self) -> JsResult<JsAggregateCountByLen> {
let count = self.aggregated_data.calc_combinations_count_by_len();
Ok(JsValue::from_serde(&count)
.map_err(|err| JsValue::from(err.to_string()))?
.unchecked_into::<JsAggregateCountByLen>())
}
#[wasm_bindgen(js_name = "combinationsSumByLenToJs")]
pub fn combinations_sum_by_len_to_js(&self) -> JsResult<JsAggregateCountByLen> {
let count = self.aggregated_data.calc_combinations_sum_by_len();
Ok(JsValue::from_serde(&count)
.map_err(|err| JsValue::from(err.to_string()))?
.unchecked_into::<JsAggregateCountByLen>())
}
#[wasm_bindgen(js_name = "privacyRiskToJs")]
pub fn privacy_risk_to_js(&self, resolution: usize) -> JsResult<JsPrivacyRiskSummary> {
let pr = self.aggregated_data.calc_privacy_risk(resolution);
let result = Object::new();
set(
&result,
&"totalNumberOfRecords".into(),
&pr.total_number_of_records.into(),
)?;
set(
&result,
&"totalNumberOfCombinations".into(),
&pr.total_number_of_combinations.into(),
)?;
set(
&result,
&"recordsWithUniqueCombinationsCount".into(),
&pr.records_with_unique_combinations_count.into(),
)?;
set(
&result,
&"recordsWithRareCombinationsCount".into(),
&pr.records_with_rare_combinations_count.into(),
)?;
set(
&result,
&"uniqueCombinationsCount".into(),
&pr.unique_combinations_count.into(),
)?;
set(
&result,
&"rareCombinationsCount".into(),
&pr.rare_combinations_count.into(),
)?;
set(
&result,
&"recordsWithUniqueCombinationsProportion".into(),
&pr.records_with_unique_combinations_proportion.into(),
)?;
set(
&result,
&"recordsWithRareCombinationsProportion".into(),
&pr.records_with_rare_combinations_proportion.into(),
)?;
set(
&result,
&"uniqueCombinationsProportion".into(),
&pr.unique_combinations_proportion.into(),
)?;
set(
&result,
&"rareCombinationsProportion".into(),
&pr.rare_combinations_proportion.into(),
)?;
Ok(result.unchecked_into::<JsPrivacyRiskSummary>())
}
#[wasm_bindgen(js_name = "toJs")]
pub fn to_js(
&self,
combination_delimiter: &str,
resolution: usize,
include_aggregates_count: bool,
) -> JsResult<JsAggregateResult> {
let _duration_logger =
ElapsedDurationLogger::new(String::from("aggregate result serialization"));
let result = Object::new();
set(
&result,
&"reportingLength".into(),
&self.reporting_length().into(),
)?;
if include_aggregates_count {
set(
&result,
&"aggregatesCount".into(),
&self.aggregates_count_to_js(combination_delimiter)?.into(),
)?;
}
set(
&result,
&"rareCombinationsCountByLen".into(),
&self
.rare_combinations_count_by_len_to_js(resolution)?
.into(),
)?;
set(
&result,
&"combinationsCountByLen".into(),
&self.combinations_count_by_len_to_js()?.into(),
)?;
set(
&result,
&"combinationsSumByLen".into(),
&self.combinations_sum_by_len_to_js()?.into(),
)?;
set(
&result,
&"privacyRisk".into(),
&self.privacy_risk_to_js(resolution)?.into(),
)?;
Ok(JsValue::from(result).unchecked_into::<JsAggregateResult>())
}
#[wasm_bindgen(js_name = "protectAggregatesCount")]
pub fn protect_aggregates_count(&mut self, resolution: usize) {
self.aggregated_data.protect_aggregates_count(resolution)
}
}
impl Deref for WasmAggregateResult {
type Target = AggregatedData;
fn deref(&self) -> &Self::Target {
&self.aggregated_data
}
}
impl DerefMut for WasmAggregateResult {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.aggregated_data
}
}

Просмотреть файл

@@ -0,0 +1,3 @@
pub mod aggregate_count_and_length;
pub mod aggregate_result;

Просмотреть файл

@@ -0,0 +1,184 @@
use super::preservation_by_count::WasmPreservationByCount;
use js_sys::{Object, Reflect::set};
use sds_core::{processing::evaluator::Evaluator, utils::time::ElapsedDurationLogger};
use wasm_bindgen::{prelude::*, JsCast};
use crate::{
processing::aggregator::aggregate_result::WasmAggregateResult,
utils::js::ts_definitions::{
JsAggregateCountByLen, JsAggregateResult, JsEvaluateResult, JsResult,
},
};
#[wasm_bindgen]
pub struct WasmEvaluateResult {
pub(crate) sensitive_aggregate_result: WasmAggregateResult,
pub(crate) synthetic_aggregate_result: WasmAggregateResult,
evaluator: Evaluator,
}
impl WasmEvaluateResult {
#[inline]
pub fn default() -> WasmEvaluateResult {
WasmEvaluateResult {
sensitive_aggregate_result: WasmAggregateResult::default(),
synthetic_aggregate_result: WasmAggregateResult::default(),
evaluator: Evaluator::default(),
}
}
#[inline]
pub fn new(
sensitive_aggregate_result: WasmAggregateResult,
synthetic_aggregate_result: WasmAggregateResult,
) -> WasmEvaluateResult {
WasmEvaluateResult {
sensitive_aggregate_result,
synthetic_aggregate_result,
evaluator: Evaluator::default(),
}
}
}
#[wasm_bindgen]
impl WasmEvaluateResult {
#[wasm_bindgen(constructor)]
pub fn from_aggregate_results(
sensitive_aggregate_result: WasmAggregateResult,
synthetic_aggregate_result: WasmAggregateResult,
) -> WasmEvaluateResult {
WasmEvaluateResult::new(sensitive_aggregate_result, synthetic_aggregate_result)
}
#[wasm_bindgen(js_name = "sensitiveAggregateResultToJs")]
pub fn sensitive_aggregate_result_to_js(
&self,
combination_delimiter: &str,
resolution: usize,
include_aggregates_count: bool,
) -> JsResult<JsAggregateResult> {
self.sensitive_aggregate_result.to_js(
combination_delimiter,
resolution,
include_aggregates_count,
)
}
#[wasm_bindgen(js_name = "syntheticAggregateResultToJs")]
pub fn synthetic_aggregate_result_to_js(
&self,
combination_delimiter: &str,
resolution: usize,
include_aggregates_count: bool,
) -> JsResult<JsAggregateResult> {
self.synthetic_aggregate_result.to_js(
combination_delimiter,
resolution,
include_aggregates_count,
)
}
#[wasm_bindgen(js_name = "leakageCountByLenToJs")]
pub fn leakage_count_by_len_to_js(&self, resolution: usize) -> JsResult<JsAggregateCountByLen> {
let count = self.evaluator.calc_leakage_count(
&self.sensitive_aggregate_result,
&self.synthetic_aggregate_result,
resolution,
);
Ok(JsValue::from_serde(&count)
.map_err(|err| JsValue::from(err.to_string()))?
.unchecked_into::<JsAggregateCountByLen>())
}
#[wasm_bindgen(js_name = "fabricatedCountByLenToJs")]
pub fn fabricated_count_by_len_to_js(&self) -> JsResult<JsAggregateCountByLen> {
let count = self.evaluator.calc_fabricated_count(
&self.sensitive_aggregate_result,
&self.synthetic_aggregate_result,
);
Ok(JsValue::from_serde(&count)
.map_err(|err| JsValue::from(err.to_string()))?
.unchecked_into::<JsAggregateCountByLen>())
}
#[wasm_bindgen(js_name = "preservationByCount")]
pub fn preservation_by_count(&self, resolution: usize) -> WasmPreservationByCount {
WasmPreservationByCount::new(self.evaluator.calc_preservation_by_count(
&self.sensitive_aggregate_result,
&self.synthetic_aggregate_result,
resolution,
))
}
#[wasm_bindgen(js_name = "recordExpansion")]
pub fn record_expansion(&self) -> f64 {
(self
.synthetic_aggregate_result
.data_block
.number_of_records() as f64)
/ (self
.sensitive_aggregate_result
.data_block
.number_of_records() as f64)
}
#[wasm_bindgen(js_name = "toJs")]
pub fn to_js(
&self,
combination_delimiter: &str,
resolution: usize,
include_aggregates_count: bool,
) -> JsResult<JsEvaluateResult> {
let _duration_logger =
ElapsedDurationLogger::new(String::from("evaluate result serialization"));
let result = Object::new();
set(
&result,
&"sensitiveAggregateResult".into(),
&self
.sensitive_aggregate_result_to_js(
combination_delimiter,
resolution,
include_aggregates_count,
)?
.into(),
)?;
set(
&result,
&"syntheticAggregateResult".into(),
&self
.synthetic_aggregate_result_to_js(
combination_delimiter,
resolution,
include_aggregates_count,
)?
.into(),
)?;
set(
&result,
&"leakageCountByLen".into(),
&self.leakage_count_by_len_to_js(resolution)?.into(),
)?;
set(
&result,
&"fabricatedCountByLen".into(),
&self.fabricated_count_by_len_to_js()?.into(),
)?;
set(
&result,
&"preservationByCount".into(),
&self.preservation_by_count(resolution).to_js()?.into(),
)?;
set(
&result,
&"recordExpansion".into(),
&self.record_expansion().into(),
)?;
Ok(JsValue::from(result).unchecked_into::<JsEvaluateResult>())
}
}

Просмотреть файл

@@ -0,0 +1,3 @@
pub mod preservation_by_count;
pub mod evaluate_result;

Просмотреть файл

@@ -0,0 +1,93 @@
use js_sys::{Object, Reflect::set};
use sds_core::{
processing::evaluator::preservation_by_count::PreservationByCountBuckets,
utils::time::ElapsedDurationLogger,
};
use std::ops::{Deref, DerefMut};
use wasm_bindgen::{prelude::*, JsCast};
use crate::utils::js::ts_definitions::{
JsPreservationByCount, JsPreservationByCountBuckets, JsResult,
};
#[wasm_bindgen]
pub struct WasmPreservationByCount {
preservation_by_count_buckets: PreservationByCountBuckets,
}
impl WasmPreservationByCount {
#[inline]
pub fn new(
preservation_by_count_buckets: PreservationByCountBuckets,
) -> WasmPreservationByCount {
WasmPreservationByCount {
preservation_by_count_buckets,
}
}
}
#[wasm_bindgen]
impl WasmPreservationByCount {
#[wasm_bindgen(js_name = "bucketsToJs")]
pub fn buckets_to_js(&self) -> JsResult<JsPreservationByCountBuckets> {
let result = Object::new();
for (bucket_index, b) in self.preservation_by_count_buckets.iter() {
let serialized_bucket = Object::new();
set(&serialized_bucket, &"size".into(), &b.size.into())?;
set(
&serialized_bucket,
&"preservationSum".into(),
&b.preservation_sum.into(),
)?;
set(
&serialized_bucket,
&"lengthSum".into(),
&b.length_sum.into(),
)?;
set(
&serialized_bucket,
&"combinationCountSum".into(),
&b.combination_count_sum.into(),
)?;
set(&result, &(*bucket_index).into(), &serialized_bucket.into())?;
}
Ok(result.unchecked_into::<JsPreservationByCountBuckets>())
}
#[wasm_bindgen(js_name = "combinationLoss")]
pub fn combination_loss(&self) -> f64 {
self.preservation_by_count_buckets.calc_combination_loss()
}
#[wasm_bindgen(js_name = "toJs")]
pub fn to_js(&self) -> JsResult<JsPreservationByCount> {
let _duration_logger =
ElapsedDurationLogger::new(String::from("preservation by count serialization"));
let result = Object::new();
set(&result, &"buckets".into(), &self.buckets_to_js()?.into())?;
set(
&result,
&"combinationLoss".into(),
&self.combination_loss().into(),
)?;
Ok(JsValue::from(result).unchecked_into::<JsPreservationByCount>())
}
}
impl Deref for WasmPreservationByCount {
type Target = PreservationByCountBuckets;
fn deref(&self) -> &Self::Target {
&self.preservation_by_count_buckets
}
}
impl DerefMut for WasmPreservationByCount {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.preservation_by_count_buckets
}
}

Просмотреть файл

@@ -0,0 +1,77 @@
use js_sys::{Object, Reflect::set};
use sds_core::{
processing::generator::generated_data::GeneratedData, utils::time::ElapsedDurationLogger,
};
use std::ops::{Deref, DerefMut};
use wasm_bindgen::{prelude::*, JsCast};
use crate::utils::js::ts_definitions::{JsCsvData, JsGenerateResult, JsResult};
#[wasm_bindgen]
pub struct WasmGenerateResult {
generated_data: GeneratedData,
}
impl WasmGenerateResult {
#[inline]
pub fn default() -> WasmGenerateResult {
WasmGenerateResult {
generated_data: GeneratedData::default(),
}
}
#[inline]
pub fn new(generated_data: GeneratedData) -> WasmGenerateResult {
WasmGenerateResult { generated_data }
}
}
#[wasm_bindgen]
impl WasmGenerateResult {
#[wasm_bindgen(getter)]
#[wasm_bindgen(js_name = "expansionRatio")]
pub fn expansion_ratio(&self) -> f64 {
self.generated_data.expansion_ratio
}
#[wasm_bindgen(js_name = "syntheticDataToJs")]
pub fn synthetic_data_to_js(&self) -> JsResult<JsCsvData> {
Ok(JsValue::from_serde(&self.generated_data.synthetic_data)
.map_err(|err| JsValue::from(err.to_string()))?
.unchecked_into::<JsCsvData>())
}
#[wasm_bindgen(js_name = "toJs")]
pub fn to_js(&self) -> JsResult<JsGenerateResult> {
let _duration_logger =
ElapsedDurationLogger::new(String::from("generate result serialization"));
let result = Object::new();
set(
&result,
&"expansionRatio".into(),
&self.expansion_ratio().into(),
)?;
set(
&result,
&"syntheticData".into(),
&self.synthetic_data_to_js()?.into(),
)?;
Ok(JsValue::from(result).unchecked_into::<JsGenerateResult>())
}
}
impl Deref for WasmGenerateResult {
type Target = GeneratedData;
fn deref(&self) -> &Self::Target {
&self.generated_data
}
}
impl DerefMut for WasmGenerateResult {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.generated_data
}
}

Просмотреть файл

@@ -0,0 +1 @@
pub mod generate_result;

Просмотреть файл

@@ -0,0 +1,11 @@
pub mod aggregator;
pub mod evaluator;
pub mod generator;
pub mod navigator;
pub mod sds_processor;
pub mod sds_context;

Просмотреть файл

@@ -0,0 +1,75 @@
use js_sys::{Array, Object, Reflect::set};
use std::{collections::HashMap, convert::TryFrom, sync::Arc};
use wasm_bindgen::{JsCast, JsValue};
use crate::utils::js::ts_definitions::{
JsAttributesIntersection, JsAttributesIntersectionByColumn,
};
pub struct WasmAttributesIntersection {
pub(crate) value: Arc<String>,
pub(crate) estimated_count: usize,
pub(crate) actual_count: Option<usize>,
}
impl WasmAttributesIntersection {
#[inline]
pub fn new(
value: Arc<String>,
estimated_count: usize,
actual_count: Option<usize>,
) -> WasmAttributesIntersection {
WasmAttributesIntersection {
value,
estimated_count,
actual_count,
}
}
}
impl TryFrom<WasmAttributesIntersection> for JsAttributesIntersection {
type Error = JsValue;
fn try_from(attributes_intersection: WasmAttributesIntersection) -> Result<Self, Self::Error> {
let result = Object::new();
set(
&result,
&"value".into(),
&attributes_intersection.value.as_str().into(),
)?;
set(
&result,
&"estimatedCount".into(),
&attributes_intersection.estimated_count.into(),
)?;
if let Some(actual_count) = attributes_intersection.actual_count {
set(&result, &"actualCount".into(), &actual_count.into())?;
}
Ok(result.unchecked_into::<JsAttributesIntersection>())
}
}
pub type WasmAttributesIntersectionByColumn = HashMap<usize, Vec<WasmAttributesIntersection>>;
impl TryFrom<WasmAttributesIntersectionByColumn> for JsAttributesIntersectionByColumn {
type Error = JsValue;
fn try_from(
mut intersections_by_column: WasmAttributesIntersectionByColumn,
) -> Result<Self, Self::Error> {
let result = Object::new();
for (column_index, mut attrs_intersections) in intersections_by_column.drain() {
let intersections: Array = Array::default();
for intersection in attrs_intersections.drain(..) {
intersections.push(&JsAttributesIntersection::try_from(intersection)?.into());
}
set(&result, &column_index.into(), &intersections)?;
}
Ok(result.unchecked_into::<JsAttributesIntersectionByColumn>())
}
}

View file

@ -0,0 +1,5 @@
pub mod attributes_intersection;
pub mod navigate_result;
pub mod selected_attributes;

View file

@ -0,0 +1,229 @@
use super::{
attributes_intersection::{WasmAttributesIntersection, WasmAttributesIntersectionByColumn},
selected_attributes::{WasmSelectedAttributes, WasmSelectedAttributesByColumn},
};
use sds_core::{
data_block::{
block::DataBlock,
typedefs::{
AttributeRows, AttributeRowsByColumnMap, AttributeRowsMap, ColumnIndexByName, CsvRecord,
},
value::DataBlockValue,
},
processing::aggregator::value_combination::ValueCombination,
utils::collections::ordered_vec_intersection,
};
use std::{cmp::Reverse, convert::TryFrom, sync::Arc};
use wasm_bindgen::prelude::*;
use crate::{
processing::{aggregator::aggregate_result::WasmAggregateResult, sds_processor::SDSProcessor},
utils::js::ts_definitions::{
JsAttributesIntersectionByColumn, JsHeaderNames, JsResult, JsSelectedAttributesByColumn,
},
};
#[wasm_bindgen]
pub struct WasmNavigateResult {
synthetic_data_block: Arc<DataBlock>,
attr_rows_by_column: AttributeRowsByColumnMap,
selected_attributes: WasmSelectedAttributesByColumn,
selected_attr_rows: AttributeRows,
all_attr_rows: AttributeRows,
column_index_by_name: ColumnIndexByName,
}
impl WasmNavigateResult {
#[inline]
pub fn default() -> WasmNavigateResult {
WasmNavigateResult::new(
Arc::new(DataBlock::default()),
AttributeRowsByColumnMap::default(),
WasmSelectedAttributesByColumn::default(),
AttributeRows::default(),
AttributeRows::default(),
ColumnIndexByName::default(),
)
}
#[inline]
pub fn new(
synthetic_data_block: Arc<DataBlock>,
attr_rows_by_column: AttributeRowsByColumnMap,
selected_attributes: WasmSelectedAttributesByColumn,
selected_attr_rows: AttributeRows,
all_attr_rows: AttributeRows,
column_index_by_name: ColumnIndexByName,
) -> WasmNavigateResult {
WasmNavigateResult {
synthetic_data_block,
attr_rows_by_column,
selected_attributes,
selected_attr_rows,
all_attr_rows,
column_index_by_name,
}
}
#[inline]
fn intersect_attributes(&self, attributes: &WasmSelectedAttributesByColumn) -> AttributeRows {
let default_empty_attr_rows = AttributeRowsMap::default();
let default_empty_rows = AttributeRows::default();
let mut result = self.all_attr_rows.clone();
for (column_index, values) in attributes.iter() {
let attr_rows = self
.attr_rows_by_column
.get(column_index)
.unwrap_or(&default_empty_attr_rows);
for value in values.iter() {
result = ordered_vec_intersection(
&result,
attr_rows.get(value).unwrap_or(&default_empty_rows),
);
}
}
result
}
#[inline]
fn get_sensitive_aggregates_count(
&self,
attributes: &WasmSelectedAttributesByColumn,
sensitive_aggregate_result: &WasmAggregateResult,
) -> Option<usize> {
let mut combination: Vec<Arc<DataBlockValue>> = attributes
.values()
.flat_map(|values| values.iter().cloned())
.collect();
combination.sort_by_key(|k| k.format_str_using_headers(&self.synthetic_data_block.headers));
sensitive_aggregate_result
.aggregates_count
.get(&ValueCombination::new(combination))
.map(|c| c.count)
}
#[inline]
fn get_selected_but_current_column(
&self,
column_index: usize,
) -> (WasmSelectedAttributesByColumn, AttributeRows) {
let mut selected_attributes_but_current_column = self.selected_attributes.clone();
selected_attributes_but_current_column.remove(&column_index);
let selected_attr_rows_but_current_column =
self.intersect_attributes(&selected_attributes_but_current_column);
(
selected_attributes_but_current_column,
selected_attr_rows_but_current_column,
)
}
#[inline]
fn selected_attributes_contains_value(
&self,
column_index: usize,
value: &Arc<DataBlockValue>,
) -> bool {
self.selected_attributes
.get(&column_index)
.map(|values| values.contains(value))
.unwrap_or(false)
}
}
#[wasm_bindgen]
impl WasmNavigateResult {
#[wasm_bindgen(constructor)]
pub fn from_synthetic_processor(synthetic_processor: &SDSProcessor) -> WasmNavigateResult {
WasmNavigateResult::new(
synthetic_processor.data_block.clone(),
synthetic_processor
.data_block
.calc_attr_rows_with_no_empty_values(),
WasmSelectedAttributesByColumn::default(),
AttributeRows::default(),
(0..synthetic_processor.data_block.number_of_records()).collect(),
synthetic_processor.data_block.calc_column_index_by_name(),
)
}
#[wasm_bindgen(js_name = "selectAttributes")]
pub fn select_attributes(&mut self, attributes: JsSelectedAttributesByColumn) -> JsResult<()> {
self.selected_attributes = WasmSelectedAttributesByColumn::try_from(attributes)?;
self.selected_attr_rows = self.intersect_attributes(&self.selected_attributes);
Ok(())
}
#[wasm_bindgen(js_name = "attributesIntersectionsByColumn")]
pub fn attributes_intersections_by_column(
&mut self,
columns: JsHeaderNames,
sensitive_aggregate_result: &WasmAggregateResult,
) -> JsResult<JsAttributesIntersectionByColumn> {
let column_indexes: AttributeRows = CsvRecord::try_from(columns)?
.iter()
.filter_map(|header_name| self.column_index_by_name.get(header_name))
.cloned()
.collect();
let mut result = WasmAttributesIntersectionByColumn::default();
for column_index in column_indexes {
let result_column_entry = result.entry(column_index).or_insert_with(Vec::default);
if let Some(attr_rows) = self.attr_rows_by_column.get(&column_index) {
let (selected_attributes_but_current_column, selected_attr_rows_but_current_column) =
self.get_selected_but_current_column(column_index);
for (value, rows) in attr_rows {
let current_selected_attr_rows;
let current_selected_attributes;
if self.selected_attributes_contains_value(column_index, value) {
// if the selected attributes contain the current value, the
// selected attr rows are already cached
current_selected_attributes = &self.selected_attributes;
current_selected_attr_rows = &self.selected_attr_rows;
} else {
// if the selected attributes do not contain the current value,
// use the intersection of all selected values except the ones
// in the current column
current_selected_attributes = &selected_attributes_but_current_column;
current_selected_attr_rows = &selected_attr_rows_but_current_column;
};
let estimated_count =
ordered_vec_intersection(current_selected_attr_rows, rows).len();
if estimated_count > 0 {
let mut attributes = current_selected_attributes.clone();
let attributes_entry = attributes
.entry(column_index)
.or_insert_with(WasmSelectedAttributes::default);
attributes_entry.insert(value.clone());
result_column_entry.push(WasmAttributesIntersection::new(
value.value.clone(),
estimated_count,
self.get_sensitive_aggregates_count(
&attributes,
sensitive_aggregate_result,
),
));
}
}
}
}
// sort by estimated count in descending order
for intersections in result.values_mut() {
intersections.sort_by_key(|intersection| Reverse(intersection.estimated_count))
}
JsAttributesIntersectionByColumn::try_from(result)
}
}
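
A rough TypeScript sketch of driving navigation directly; the `sds-wasm` import path, the column names, and the attribute values are placeholders, and it assumes the types are exported under the same names as in Rust (the same flow is also available through SDSContext below):

import { SDSProcessor, WasmAggregateResult, WasmNavigateResult } from 'sds-wasm';

// Sketch only: assumes a synthetic-data processor and the sensitive aggregate
// result have already been built elsewhere.
function inspectIntersections(
    syntheticProcessor: SDSProcessor,
    sensitiveAggregates: WasmAggregateResult,
): void {
    const navigateResult = new WasmNavigateResult(syntheticProcessor);

    // column index -> set of selected attribute values (ISelectedAttributesByColumn)
    navigateResult.selectAttributes({ 0: new Set(['a1']), 2: new Set(['c3']) });

    // estimated counts come from the synthetic rows matching the selection;
    // actual counts (when present) come from the sensitive aggregates
    const intersections = navigateResult.attributesIntersectionsByColumn(
        ['column_a', 'column_b'],
        sensitiveAggregates,
    );
    for (const columnIndex of Object.keys(intersections)) {
        for (const intersection of intersections[Number(columnIndex)]) {
            console.log(
                columnIndex,
                intersection.value,
                intersection.estimatedCount,
                intersection.actualCount ?? 'n/a',
            );
        }
    }
}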

View file

@ -0,0 +1,48 @@
use js_sys::{Array, Object, Set};
use sds_core::data_block::value::DataBlockValue;
use std::{
collections::{HashMap, HashSet},
convert::TryFrom,
sync::Arc,
};
use wasm_bindgen::{JsCast, JsValue};
use crate::utils::js::ts_definitions::JsSelectedAttributesByColumn;
pub type WasmSelectedAttributes = HashSet<Arc<DataBlockValue>>;
pub type WasmSelectedAttributesByColumn = HashMap<usize, WasmSelectedAttributes>;
impl TryFrom<JsSelectedAttributesByColumn> for WasmSelectedAttributesByColumn {
type Error = JsValue;
fn try_from(js_selected_attributes: JsSelectedAttributesByColumn) -> Result<Self, Self::Error> {
let mut result = WasmSelectedAttributesByColumn::default();
for entry_res in Object::entries(&js_selected_attributes.dyn_into::<Object>()?).values() {
let entry = entry_res?.dyn_into::<Array>()?;
let column_index = entry
.get(0)
.as_string()
.ok_or_else(|| {
JsValue::from("invalid column index on selected attributes by column")
})?
.parse::<usize>()
.map_err(|err| JsValue::from(err.to_string()))?;
let result_entry = result.entry(column_index).or_insert_with(HashSet::default);
for value_res in entry.get(1).dyn_into::<Set>()?.keys() {
let value = value_res?;
let value_str = value.as_string().ok_or_else(|| {
JsValue::from("invalid value on selected attributes by column")
})?;
result_entry.insert(Arc::new(DataBlockValue::new(
column_index,
Arc::new(value_str),
)));
}
}
Ok(result)
}
}

View file

@ -0,0 +1,209 @@
use super::{
evaluator::evaluate_result::WasmEvaluateResult, generator::generate_result::WasmGenerateResult,
navigator::navigate_result::WasmNavigateResult, sds_processor::SDSProcessor,
};
use log::debug;
use wasm_bindgen::{prelude::*, JsCast};
use crate::utils::js::ts_definitions::{
JsAttributesIntersectionByColumn, JsCsvData, JsEvaluateResult, JsGenerateResult, JsHeaderNames,
JsReportProgressCallback, JsResult, JsSelectedAttributesByColumn,
};
#[wasm_bindgen]
pub struct SDSContext {
use_columns: JsHeaderNames,
sensitive_zeros: JsHeaderNames,
record_limit: usize,
sensitive_processor: SDSProcessor,
generate_result: WasmGenerateResult,
resolution: usize,
synthetic_processor: SDSProcessor,
evaluate_result: WasmEvaluateResult,
navigate_result: WasmNavigateResult,
}
#[wasm_bindgen]
impl SDSContext {
#[wasm_bindgen(constructor)]
pub fn default() -> SDSContext {
SDSContext {
use_columns: JsHeaderNames::default(),
sensitive_zeros: JsHeaderNames::default(),
record_limit: 0,
sensitive_processor: SDSProcessor::default(),
generate_result: WasmGenerateResult::default(),
resolution: 0,
synthetic_processor: SDSProcessor::default(),
evaluate_result: WasmEvaluateResult::default(),
navigate_result: WasmNavigateResult::default(),
}
}
#[wasm_bindgen(js_name = "clearSensitiveData")]
pub fn clear_sensitive_data(&mut self) {
self.use_columns = JsHeaderNames::default();
self.sensitive_zeros = JsHeaderNames::default();
self.record_limit = 0;
self.sensitive_processor = SDSProcessor::default();
self.clear_generate();
}
#[wasm_bindgen(js_name = "clearGenerate")]
pub fn clear_generate(&mut self) {
self.generate_result = WasmGenerateResult::default();
self.resolution = 0;
self.synthetic_processor = SDSProcessor::default();
self.clear_evaluate()
}
#[wasm_bindgen(js_name = "clearEvaluate")]
pub fn clear_evaluate(&mut self) {
self.evaluate_result = WasmEvaluateResult::default();
self.clear_navigate()
}
#[wasm_bindgen(js_name = "clearNavigate")]
pub fn clear_navigate(&mut self) {
self.navigate_result = WasmNavigateResult::default();
}
#[wasm_bindgen(js_name = "setSensitiveData")]
pub fn set_sensitive_data(
&mut self,
csv_data: JsCsvData,
use_columns: JsHeaderNames,
sensitive_zeros: JsHeaderNames,
record_limit: usize,
) -> JsResult<()> {
debug!("setting sensitive data...");
self.use_columns = use_columns;
self.sensitive_zeros = sensitive_zeros;
self.record_limit = record_limit;
self.sensitive_processor = SDSProcessor::new(
csv_data,
self.use_columns.clone().unchecked_into(),
self.sensitive_zeros.clone().unchecked_into(),
self.record_limit,
)?;
self.clear_generate();
Ok(())
}
pub fn generate(
&mut self,
cache_max_size: usize,
resolution: usize,
empty_value: String,
seeded: bool,
progress_callback: JsReportProgressCallback,
) -> JsResult<()> {
debug!("generating synthetic data...");
self.generate_result = self.sensitive_processor.generate(
cache_max_size,
resolution,
empty_value,
seeded,
progress_callback,
)?;
self.resolution = resolution;
debug!("creating synthetic data processor...");
self.synthetic_processor = SDSProcessor::new(
self.generate_result.synthetic_data_to_js()?,
self.use_columns.clone().unchecked_into(),
self.sensitive_zeros.clone().unchecked_into(),
self.record_limit,
)?;
self.clear_evaluate();
Ok(())
}
pub fn evaluate(
&mut self,
reporting_length: usize,
sensitivity_threshold: usize,
sensitive_progress_callback: JsReportProgressCallback,
synthetic_progress_callback: JsReportProgressCallback,
) -> JsResult<()> {
debug!("aggregating sensitive data...");
let sensitive_aggregate_result = self.sensitive_processor.aggregate(
reporting_length,
sensitivity_threshold,
sensitive_progress_callback,
)?;
debug!("aggregating synthetic data...");
let synthetic_aggregate_result = self.synthetic_processor.aggregate(
reporting_length,
sensitivity_threshold,
synthetic_progress_callback,
)?;
debug!("evaluating synthetic data based on sensitive data...");
self.evaluate_result = WasmEvaluateResult::from_aggregate_results(
sensitive_aggregate_result,
synthetic_aggregate_result,
);
self.clear_navigate();
Ok(())
}
#[wasm_bindgen(js_name = "protectSensitiveAggregatesCount")]
pub fn protect_sensitive_aggregates_count(&mut self) {
debug!("protecting sensitive aggregates count...");
self.evaluate_result
.sensitive_aggregate_result
.protect_aggregates_count(self.resolution);
}
pub fn navigate(&mut self) {
debug!("creating navigate result...");
self.navigate_result =
WasmNavigateResult::from_synthetic_processor(&self.synthetic_processor);
}
#[wasm_bindgen(js_name = "selectAttributes")]
pub fn select_attributes(&mut self, attributes: JsSelectedAttributesByColumn) -> JsResult<()> {
self.navigate_result.select_attributes(attributes)
}
#[wasm_bindgen(js_name = "attributesIntersectionsByColumn")]
pub fn attributes_intersections_by_column(
&mut self,
columns: JsHeaderNames,
) -> JsResult<JsAttributesIntersectionByColumn> {
self.navigate_result.attributes_intersections_by_column(
columns,
&self.evaluate_result.sensitive_aggregate_result,
)
}
#[wasm_bindgen(js_name = "generateResultToJs")]
pub fn generate_result_to_js(&self) -> JsResult<JsGenerateResult> {
self.generate_result.to_js()
}
#[wasm_bindgen(js_name = "evaluateResultToJs")]
pub fn evaluate_result_to_js(
&self,
combination_delimiter: &str,
include_aggregates_count: bool,
) -> JsResult<JsEvaluateResult> {
self.evaluate_result.to_js(
combination_delimiter,
self.resolution,
include_aggregates_count,
)
}
}
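
A hypothetical end-to-end sketch of the context API in TypeScript; the package name `sds-wasm`, the column names, and the numeric parameters (cache size, resolution, reporting length) are illustrative assumptions, not values taken from this change:

import { SDSContext } from 'sds-wasm';

function runSdsFlow(csvData: string[][]): void {
    const context = new SDSContext();
    const onProgress = (progress: number) => console.log(`progress: ${progress}`);

    // empty useColumns/sensitiveZeros and a 0 record limit are assumed here to
    // mean "use everything"
    context.setSensitiveData(csvData, [], [], 0);

    // cacheMaxSize, resolution, emptyValue, seeded, progress callback
    context.generate(100000, 10, '', true, onProgress);

    // reportingLength, sensitivityThreshold, one callback per aggregation pass
    context.evaluate(3, 0, onProgress, onProgress);
    context.protectSensitiveAggregatesCount();

    const evaluateResult = context.evaluateResultToJs(';', false);
    console.log(evaluateResult.recordExpansion);

    // navigation reuses the synthetic processor and sensitive aggregates held
    // by the context
    context.navigate();
    context.selectAttributes({ 0: new Set(['a1']) });
    console.log(context.attributesIntersectionsByColumn(['column_a']));
}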

View file

@ -0,0 +1,23 @@
use js_sys::Array;
use sds_core::data_block::typedefs::CsvRecord;
use std::convert::TryFrom;
use wasm_bindgen::{JsCast, JsValue};
use crate::utils::js::ts_definitions::JsHeaderNames;
impl JsHeaderNames {
#[inline]
pub fn default() -> JsHeaderNames {
Array::default().unchecked_into()
}
}
impl TryFrom<JsHeaderNames> for CsvRecord {
type Error = JsValue;
fn try_from(js_csv_record: JsHeaderNames) -> Result<Self, Self::Error> {
js_csv_record
.into_serde::<CsvRecord>()
.map_err(|err| JsValue::from(err.to_string()))
}
}

View file

@ -0,0 +1,5 @@
pub mod header_names;
mod processor;
pub use processor::SDSProcessor;

View file

@ -0,0 +1,120 @@
use js_sys::Function;
use sds_core::{
data_block::{block::DataBlock, data_block_creator::DataBlockCreator, typedefs::CsvRecord},
processing::{
aggregator::Aggregator,
generator::{Generator, SynthesisMode},
},
utils::time::ElapsedDurationLogger,
};
use std::convert::TryFrom;
use std::sync::Arc;
use wasm_bindgen::{prelude::*, JsCast};
use crate::{
utils::js::js_progress_reporter::JsProgressReporter,
{
processing::{
aggregator::aggregate_result::WasmAggregateResult,
generator::generate_result::WasmGenerateResult,
},
utils::js::{
js_block_creator::JsDataBlockCreator,
ts_definitions::{JsCsvData, JsHeaderNames, JsReportProgressCallback},
},
},
};
#[wasm_bindgen]
pub struct SDSProcessor {
pub(crate) data_block: Arc<DataBlock>,
}
#[wasm_bindgen]
impl SDSProcessor {
pub fn default() -> SDSProcessor {
SDSProcessor {
data_block: Arc::new(DataBlock::default()),
}
}
}
#[wasm_bindgen]
impl SDSProcessor {
#[wasm_bindgen(constructor)]
pub fn new(
csv_data: JsCsvData,
use_columns: JsHeaderNames,
sensitive_zeros: JsHeaderNames,
record_limit: usize,
) -> Result<SDSProcessor, JsValue> {
let _duration_logger = ElapsedDurationLogger::new(String::from("sds processor creation"));
let data_block = JsDataBlockCreator::create(
Ok(csv_data.dyn_into()?),
&CsvRecord::try_from(use_columns)?,
&CsvRecord::try_from(sensitive_zeros)?,
record_limit,
)?;
Ok(SDSProcessor { data_block })
}
#[wasm_bindgen(js_name = "numberOfRecords")]
pub fn number_of_records(&self) -> usize {
self.data_block.number_of_records()
}
#[wasm_bindgen(js_name = "protectedNumberOfRecords")]
pub fn protected_number_of_records(&self, resolution: usize) -> usize {
self.data_block.protected_number_of_records(resolution)
}
#[wasm_bindgen(js_name = "normalizeReportingLength")]
pub fn normalize_reporting_length(&self, reporting_length: usize) -> usize {
self.data_block.normalize_reporting_length(reporting_length)
}
pub fn aggregate(
&self,
reporting_length: usize,
sensitivity_threshold: usize,
progress_callback: JsReportProgressCallback,
) -> Result<WasmAggregateResult, JsValue> {
let _duration_logger =
ElapsedDurationLogger::new(String::from("sds processor aggregation"));
let js_callback: Function = progress_callback.dyn_into()?;
let mut aggregator = Aggregator::new(self.data_block.clone());
Ok(WasmAggregateResult::new(aggregator.aggregate(
reporting_length,
sensitivity_threshold,
&mut Some(JsProgressReporter::new(&js_callback, &|p| p)),
)))
}
pub fn generate(
&self,
cache_max_size: usize,
resolution: usize,
empty_value: String,
seeded: bool,
progress_callback: JsReportProgressCallback,
) -> Result<WasmGenerateResult, JsValue> {
let _duration_logger = ElapsedDurationLogger::new(String::from("sds processor generation"));
let js_callback: Function = progress_callback.dyn_into()?;
let mut generator = Generator::new(self.data_block.clone());
let mode = if seeded {
SynthesisMode::Seeded
} else {
SynthesisMode::Unseeded
};
Ok(WasmGenerateResult::new(generator.generate(
resolution,
cache_max_size,
empty_value,
mode,
&mut Some(JsProgressReporter::new(&js_callback, &|p| p)),
)))
}
}
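
For standalone use outside SDSContext, a small TypeScript sketch of the processor API; the import path, column data, and parameter values are placeholders, and the meaning of normalizeReportingLength is an assumption based on its name:

import { SDSProcessor } from 'sds-wasm';

function aggregateAndGenerate(csvData: string[][]): void {
    // csvData, useColumns, sensitiveZeros, recordLimit
    const processor = new SDSProcessor(csvData, [], [], 0);
    console.log(`records: ${processor.numberOfRecords()}`);

    // presumably adjusts the requested reporting length to what the data supports
    const reportingLength = processor.normalizeReportingLength(3);
    const aggregates = processor.aggregate(reportingLength, 0, p => console.log(p));
    console.log(aggregates);

    // seeded synthesis; '' is used as the empty value
    const generated = processor.generate(100000, 10, '', true, p => console.log(p));
    console.log(`expansion ratio: ${generated.expansionRatio}`);
}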

View file

@ -1,41 +0,0 @@
use super::js_block_creator::JsDataBlockCreator;
use js_sys::Array;
use sds_core::{
data_block::{
block::{DataBlock, DataBlockCreator},
typedefs::CsvRecord,
},
utils::time::ElapsedDurationLogger,
};
#[inline]
pub fn deserialize_use_columns(use_columns: Array) -> Result<CsvRecord, String> {
let _duration_logger = ElapsedDurationLogger::new(String::from("deserialize_use_columns"));
match use_columns.into_serde::<CsvRecord>() {
Ok(v) => Ok(v),
_ => Err(String::from("use_columns should be an Array<string>")),
}
}
#[inline]
pub fn deserialize_sensitive_zeros(sensitive_zeros: Array) -> Result<CsvRecord, String> {
let _duration_logger = ElapsedDurationLogger::new(String::from("deserialize_sensitive_zeros"));
match sensitive_zeros.into_serde::<CsvRecord>() {
Ok(v) => Ok(v),
_ => Err(String::from("sensitive_zeros should be an Array<string>")),
}
}
#[inline]
pub fn deserialize_csv_data(
csv_data: Array,
use_columns: &[String],
sensitive_zeros: &[String],
record_limit: usize,
) -> Result<DataBlock, String> {
let _duration_logger = ElapsedDurationLogger::new(String::from("deserialize_csv_data"));
JsDataBlockCreator::create(Ok(csv_data), use_columns, sensitive_zeros, record_limit)
}

View file

@ -1,33 +1,34 @@
use js_sys::Array;
use sds_core::data_block::{block::DataBlockCreator, typedefs::CsvRecord};
use sds_core::data_block::{data_block_creator::DataBlockCreator, typedefs::CsvRecord};
use wasm_bindgen::JsValue;
pub struct JsDataBlockCreator;
impl DataBlockCreator for JsDataBlockCreator {
type InputType = Array;
type ErrorType = String;
type ErrorType = JsValue;
fn get_headers(csv_data: &mut Array) -> Result<CsvRecord, String> {
fn get_headers(csv_data: &mut Self::InputType) -> Result<CsvRecord, Self::ErrorType> {
let headers = csv_data.get(0);
if headers.is_undefined() {
Err(String::from("missing headers"))
Err(JsValue::from("csv data missing headers"))
} else {
match headers.into_serde::<Vec<String>>() {
Ok(h) => Ok(h),
Err(_) => Err(String::from("headers should an Array<string>")),
}
headers
.into_serde::<CsvRecord>()
.map_err(|err| JsValue::from(err.to_string()))
}
}
fn get_records(csv_data: &mut Self::InputType) -> Result<Vec<CsvRecord>, String> {
fn get_records(csv_data: &mut Self::InputType) -> Result<Vec<CsvRecord>, Self::ErrorType> {
csv_data
.slice(1, csv_data.length())
.iter()
.map(|record| match record.into_serde::<Vec<String>>() {
Ok(h) => Ok(h),
Err(_) => Err(String::from("records should an Array<string>")),
.map(|record| {
record
.into_serde::<CsvRecord>()
.map_err(|err| JsValue::from(err.to_string()))
})
.collect::<Result<Vec<CsvRecord>, String>>()
.collect::<Result<Vec<CsvRecord>, Self::ErrorType>>()
}
}

View file

@ -1,32 +0,0 @@
#[doc(hidden)]
#[macro_export]
macro_rules! match_or_return_undefined {
($result_to_match: expr) => {
match $result_to_match {
Ok(result) => result,
Err(err) => {
error!("{}", err);
return JsValue::undefined();
}
}
};
}
#[doc(hidden)]
#[macro_export]
macro_rules! set_or_return_undefined {
($($arg:tt)*) => {
match set($($arg)*) {
Ok(has_been_set) => {
if !has_been_set {
error!("unable to set object value from wasm");
return JsValue::undefined();
}
}
_ => {
error!("unable to set object value from wasm");
return JsValue::undefined();
}
};
};
}

View file

@ -1,16 +1,9 @@
pub fn set_panic_hook() {
// When the `console_error_panic_hook` feature is enabled, we can call the
// `set_panic_hook` function at least once during initialization, and then
// we will get better error messages if our code ever panics.
//
// For more details see
// https://github.com/rustwasm/console_error_panic_hook#readme
#[cfg(feature = "console_error_panic_hook")]
console_error_panic_hook::set_once();
}
pub mod deserializers;
pub mod js_block_creator;
pub mod js_progress_reporter;
pub mod macros;
pub mod serializers;
mod set_panic_hook;
pub use set_panic_hook::set_panic_hook;
pub mod ts_definitions;

View file

@ -1,124 +0,0 @@
use js_sys::{Object, Reflect::set};
use log::error;
use sds_core::{
data_block::typedefs::DataBlockHeadersSlice,
processing::{
aggregator::{typedefs::AggregatesCountMap, Aggregator, PrivacyRiskSummary},
evaluator::typedefs::PreservationByCountBuckets,
},
utils::time::ElapsedDurationLogger,
};
use wasm_bindgen::JsValue;
use crate::{aggregator::AggregatedCombination, set_or_return_undefined};
#[inline]
pub fn serialize_aggregates_count(
aggregates_count: &AggregatesCountMap,
headers: &DataBlockHeadersSlice,
) -> JsValue {
let _duration_logger = ElapsedDurationLogger::new(String::from("serialize_aggregates_count"));
let aggregated_values = Object::new();
for (agg, count) in aggregates_count.iter() {
let combination_key = Aggregator::format_aggregate_str(headers, agg);
set_or_return_undefined!(
&aggregated_values,
&combination_key.clone().into(),
&AggregatedCombination::new(combination_key, count.count, agg.len()).into()
);
}
aggregated_values.into()
}
#[inline]
pub fn serialize_buckets(buckets: &PreservationByCountBuckets) -> JsValue {
let _duration_logger = ElapsedDurationLogger::new(String::from("serialize_buckets"));
let serialized_buckets = Object::new();
for (bucket_index, b) in buckets.iter() {
let serialized_bucket = Object::new();
set_or_return_undefined!(&serialized_bucket, &"size".into(), &b.size.into());
set_or_return_undefined!(
&serialized_bucket,
&"preservationSum".into(),
&b.preservation_sum.into()
);
set_or_return_undefined!(
&serialized_bucket,
&"lengthSum".into(),
&b.length_sum.into()
);
set_or_return_undefined!(
&serialized_buckets,
&(*bucket_index).into(),
&serialized_bucket.into()
);
}
serialized_buckets.into()
}
#[inline]
pub fn serialize_privacy_risk(privacy_risk: &PrivacyRiskSummary) -> JsValue {
let _duration_logger = ElapsedDurationLogger::new(String::from("serialize_privacy_risk"));
let serialized_privacy_risk = Object::new();
set_or_return_undefined!(
&serialized_privacy_risk,
&"totalNumberOfRecords".into(),
&privacy_risk.total_number_of_records.into()
);
set_or_return_undefined!(
&serialized_privacy_risk,
&"totalNumberOfCombinations".into(),
&privacy_risk.total_number_of_combinations.into()
);
set_or_return_undefined!(
&serialized_privacy_risk,
&"recordsWithUniqueCombinationsCount".into(),
&privacy_risk.records_with_unique_combinations_count.into()
);
set_or_return_undefined!(
&serialized_privacy_risk,
&"recordsWithRareCombinationsCount".into(),
&privacy_risk.records_with_rare_combinations_count.into()
);
set_or_return_undefined!(
&serialized_privacy_risk,
&"uniqueCombinationsCount".into(),
&privacy_risk.unique_combinations_count.into()
);
set_or_return_undefined!(
&serialized_privacy_risk,
&"rareCombinationsCount".into(),
&privacy_risk.rare_combinations_count.into()
);
set_or_return_undefined!(
&serialized_privacy_risk,
&"recordsWithUniqueCombinationsProportion".into(),
&privacy_risk
.records_with_unique_combinations_proportion
.into()
);
set_or_return_undefined!(
&serialized_privacy_risk,
&"recordsWithRareCombinationsProportion".into(),
&privacy_risk
.records_with_rare_combinations_proportion
.into()
);
set_or_return_undefined!(
&serialized_privacy_risk,
&"uniqueCombinationsProportion".into(),
&privacy_risk.unique_combinations_proportion.into()
);
set_or_return_undefined!(
&serialized_privacy_risk,
&"rareCombinationsProportion".into(),
&privacy_risk.rare_combinations_proportion.into()
);
serialized_privacy_risk.into()
}

View file

@ -0,0 +1,10 @@
pub fn set_panic_hook() {
// When the `console_error_panic_hook` feature is enabled, we can call the
// `set_panic_hook` function at least once during initialization, and then
// we will get better error messages if our code ever panics.
//
// For more details see
// https://github.com/rustwasm/console_error_panic_hook#readme
#[cfg(feature = "console_error_panic_hook")]
console_error_panic_hook::set_once();
}

View file

@ -0,0 +1,146 @@
use wasm_bindgen::prelude::*;
#[wasm_bindgen(typescript_custom_section)]
const TS_APPEND_CONTENT: &'static str = r#"
export type ReportProgressCallback = (progress: number) => void;
export type HeaderNames = string[];
export type CsvRecord = string[];
export type CsvData = CsvRecord[];
export interface IGenerateResult {
expansionRatio: number;
syntheticData: CsvData;
}
export interface IPrivacyRiskSummary {
totalNumberOfRecords: number
totalNumberOfCombinations: number
recordsWithUniqueCombinationsCount: number
recordsWithRareCombinationsCount: number
uniqueCombinationsCount: number
rareCombinationsCount: number
recordsWithUniqueCombinationsProportion: number
recordsWithRareCombinationsProportion: number
uniqueCombinationsProportion: number
rareCombinationsProportion: number
}
export interface IAggregateCountByLen {
[length: number]: number
}
export interface IAggregateCountAndLength {
count: number;
length: number;
}
export interface IAggregatesCount {
[name: string]: IAggregateCountAndLength;
}
export interface IAggregateResult {
reportingLength: number;
aggregatesCount?: IAggregatesCount;
rareCombinationsCountByLen: IAggregateCountByLen;
combinationsCountByLen: IAggregateCountByLen;
combinationsSumByLen: IAggregateCountByLen;
privacyRisk: IPrivacyRiskSummary;
}
export interface IPreservationByCountBucket {
size: number;
preservationSum: number;
lengthSum: number;
combinationCountSum: number;
}
export interface IPreservationByCountBuckets {
[bucket_index: number]: IPreservationByCountBucket;
}
export interface IPreservationByCount {
buckets: IPreservationByCountBuckets;
combinationLoss: number;
}
export interface IEvaluateResult {
sensitiveAggregateResult: IAggregateResult;
syntheticAggregateResult: IAggregateResult;
leakageCountByLen: IAggregateCountByLen;
fabricatedCountByLen: IAggregateCountByLen;
preservationByCount: IPreservationByCount;
recordExpansion: number;
}
export interface ISelectedAttributesByColumn {
[columnIndex: number]: Set<string>;
}
export interface IAttributesIntersection {
value: string;
estimatedCount: number;
actualCount?: number;
}
export interface IAttributesIntersectionByColumn {
[columnIndex: number]: IAttributesIntersection[];
}"#;
#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(typescript_type = "ReportProgressCallback")]
pub type JsReportProgressCallback;
#[wasm_bindgen(typescript_type = "HeaderNames")]
pub type JsHeaderNames;
#[wasm_bindgen(typescript_type = "CsvRecord")]
pub type JsCsvRecord;
#[wasm_bindgen(typescript_type = "CsvData")]
pub type JsCsvData;
#[wasm_bindgen(typescript_type = "IGenerateResult")]
pub type JsGenerateResult;
#[wasm_bindgen(typescript_type = "IPrivacyRiskSummary")]
pub type JsPrivacyRiskSummary;
#[wasm_bindgen(typescript_type = "IAggregateCountByLen")]
pub type JsAggregateCountByLen;
#[wasm_bindgen(typescript_type = "IAggregateCountAndLength")]
pub type JsAggregateCountAndLength;
#[wasm_bindgen(typescript_type = "IAggregatesCount")]
pub type JsAggregatesCount;
#[wasm_bindgen(typescript_type = "IAggregateResult")]
pub type JsAggregateResult;
#[wasm_bindgen(typescript_type = "IPreservationByCountBucket")]
pub type JsPreservationByCountBucket;
#[wasm_bindgen(typescript_type = "IPreservationByCountBuckets")]
pub type JsPreservationByCountBuckets;
#[wasm_bindgen(typescript_type = "IPreservationByCount")]
pub type JsPreservationByCount;
#[wasm_bindgen(typescript_type = "IEvaluateResult")]
pub type JsEvaluateResult;
#[wasm_bindgen(typescript_type = "ISelectedAttributesByColumn")]
pub type JsSelectedAttributesByColumn;
#[wasm_bindgen(typescript_type = "IAttributesIntersection")]
pub type JsAttributesIntersection;
#[wasm_bindgen(typescript_type = "IAttributesIntersectionByColumn")]
pub type JsAttributesIntersectionByColumn;
}
pub type JsResult<T> = Result<T, JsValue>;
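
Since the interfaces above are emitted into the generated .d.ts, they can be used to type downstream code. A small sketch, assuming the interfaces are importable from the `sds-wasm` package and that mean preservation per bucket can be read as preservationSum / size:

import type { IEvaluateResult } from 'sds-wasm';

function meanPreservationByBucket(result: IEvaluateResult): Map<number, number> {
    const means = new Map<number, number>();
    const buckets = result.preservationByCount.buckets;
    for (const index of Object.keys(buckets)) {
        const bucket = buckets[Number(index)];
        means.set(Number(index), bucket.size > 0 ? bucket.preservationSum / bucket.size : 0);
    }
    return means;
}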

View file

@ -1,2 +1,3 @@
pub mod js;
pub mod logger;

View file

@ -30,11 +30,13 @@ def aggregate(config):
sensitive_zeros = config['sensitive_zeros']
output_dir = config['output_dir']
prefix = config['prefix']
sensitive_aggregated_data_json = path.join(
output_dir, f'{prefix}_sensitive_aggregated_data.json')
logging.info(f'Aggregate {sensitive_microdata_path}')
start_time = time.time()
data_block = sds.create_data_block_from_file(
sds_processor = sds.SDSProcessor(
sensitive_microdata_path,
sensitive_microdata_delimiter,
use_columns,
@ -42,13 +44,32 @@ def aggregate(config):
max(record_limit, 0)
)
len_to_combo_count, len_to_rare_count = sds.aggregate_and_write(
data_block,
aggregated_data = sds_processor.aggregate(
reporting_length,
0
)
len_to_combo_count = aggregated_data.calc_combinations_count_by_len()
len_to_rare_count = aggregated_data.calc_rare_combinations_count_by_len(
reporting_resolution)
aggregated_data.write_aggregates_count(
sensitive_aggregates_path,
'\t',
';',
reporting_resolution,
False
)
aggregated_data.write_to_json(sensitive_aggregated_data_json)
aggregated_data.protect_aggregates_count(reporting_resolution)
aggregated_data.write_aggregates_count(
reportable_aggregates_path,
'\t',
reporting_length,
reporting_resolution
';',
reporting_resolution,
True
)
leakage_tsv = path.join(

View file

@ -2,8 +2,247 @@ import time
import datetime
import logging
from os import path
from collections import defaultdict
import util as util
import sds
class Evaluator:
def __init__(self, config):
self.sds_evaluator = sds.Evaluator()
self.use_columns = config['use_columns']
self.record_limit = max(config['record_limit'], 0)
self.reporting_length = max(config['reporting_length'], 0)
self.reporting_resolution = config['reporting_resolution']
self.sensitive_microdata_path = config['sensitive_microdata_path']
self.sensitive_microdata_delimiter = config['sensitive_microdata_delimiter']
self.synthetic_microdata_path = config['synthetic_microdata_path']
self.sensitive_zeros = config['sensitive_zeros']
self.output_dir = config['output_dir']
self.prefix = config['prefix']
self.sensitive_aggregated_data_json = path.join(
self.output_dir, f'{self.prefix}_sensitive_aggregated_data.json'
)
self.record_analysis_tsv = path.join(
self.output_dir, f'{self.prefix}_sensitive_analysis_by_length.tsv'
)
self.synthetic_rare_combos_tsv = path.join(
self.output_dir, f'{self.prefix}_synthetic_rare_combos_by_length.tsv'
)
self.parameters_tsv = path.join(
self.output_dir, f'{self.prefix}_parameters.tsv'
)
self.leakage_tsv = path.join(
self.output_dir, f'{self.prefix}_synthetic_leakage_by_length.tsv'
)
self.leakage_svg = path.join(
self.output_dir, f'{self.prefix}_synthetic_leakage_by_length.svg'
)
self.preservation_by_count_tsv = path.join(
self.output_dir, f'{self.prefix}_preservation_by_count.tsv'
)
self.preservation_by_count_svg = path.join(
self.output_dir, f'{self.prefix}_preservation_by_count.svg'
)
self.preservation_by_length_tsv = path.join(
self.output_dir, f'{self.prefix}_preservation_by_length.tsv'
)
self.preservation_by_length_svg = path.join(
self.output_dir, f'{self.prefix}_preservation_by_length.svg'
)
def _load_sensitive_aggregates(self):
if not path.exists(self.sensitive_aggregated_data_json):
logging.info('Computing sensitive aggregates...')
self.sen_sds_processor = sds.SDSProcessor(
self.sensitive_microdata_path,
self.sensitive_microdata_delimiter,
self.use_columns,
self.sensitive_zeros,
self.record_limit
)
self.sen_aggregated_data = self.sen_sds_processor.aggregate(
self.reporting_length,
0
)
else:
logging.info('Loading sensitive aggregates...')
self.sen_aggregated_data = sds.AggregatedData.read_from_json(
self.sensitive_aggregated_data_json
)
self.sen_sds_processor = sds.SDSProcessor.from_aggregated_data(
self.sen_aggregated_data
)
self.reporting_length = self.sen_sds_processor.normalize_reporting_length(
self.reporting_length
)
def _load_synthetic_aggregates(self):
logging.info('Computing synthetic aggregates...')
self.syn_sds_processor = sds.SDSProcessor(
self.synthetic_microdata_path,
"\t",
self.use_columns,
self.sensitive_zeros,
self.record_limit
)
self.syn_aggregated_data = self.syn_sds_processor.aggregate(
self.reporting_length,
0
)
def _do_records_analysis(self):
logging.info('Performing records analysis on sensitive aggregates...')
self.records_analysis_data = self.sen_aggregated_data.calc_records_analysis_by_len(
self.reporting_resolution, True
)
self.records_analysis_data.write_records_analysis(
self.record_analysis_tsv, '\t'
)
def _compare_synthetic_and_sensitive_rare(self):
logging.info(
'Comparing synthetic and sensitive aggregates rare combinations...')
rare_combinations_data = self.sds_evaluator.compare_synthetic_and_sensitive_rare(
self.syn_aggregated_data, self.sen_aggregated_data, self.reporting_resolution, ' AND ', True
)
rare_combinations_data.write_rare_combinations(
self.synthetic_rare_combos_tsv, '\t'
)
def _write_parameters(self):
logging.info('Writing evaluation parameters...')
total_sen = self.sen_sds_processor.protected_number_of_records(
self.reporting_resolution
)
with open(self.parameters_tsv, 'w') as f:
f.write('\t'.join(['parameter', 'value'])+'\n')
f.write(
'\t'.join(['resolution', str(self.reporting_resolution)])+'\n'
)
f.write('\t'.join(['limit', str(self.reporting_length)])+'\n')
f.write(
'\t'.join(['total_sensitive_records', str(total_sen)])+'\n'
)
f.write('\t'.join(['unique_identifiable', str(
self.records_analysis_data.get_total_unique())])+'\n'
)
f.write('\t'.join(['rare_identifiable', str(
self.records_analysis_data.get_total_rare())])+'\n'
)
f.write('\t'.join(['risky_identifiable', str(
self.records_analysis_data.get_total_risky())])+'\n'
)
f.write(
'\t'.join(['risky_identifiable_pct', str(
100*self.records_analysis_data.get_total_risky()/total_sen)])+'\n'
)
def _find_leakages(self):
logging.info(
'Looking for leakages from the sensitive dataset on the synthetic dataset...')
comb_counts = self.syn_aggregated_data.calc_combinations_count_by_len()
leakage_counts = self.sds_evaluator.calc_leakage_count(
self.sen_aggregated_data, self.syn_aggregated_data, self.reporting_resolution
)
with open(self.leakage_tsv, 'w') as f:
f.write('\t'.join(
['syn_combo_length', 'combo_count', 'leak_count', 'leak_proportion'])+'\n'
)
for length in range(1, self.reporting_length + 1):
combo_count = comb_counts.get(length, 0)
leak_count = leakage_counts.get(length, 0)
# by design there should be no leakage
assert leak_count == 0
leak_prop = leak_count/combo_count if combo_count > 0 else 0
f.write('\t'.join(
[str(length), str(combo_count), str(leak_count), str(leak_prop)])+'\n'
)
util.plotStats(
x_axis='syn_combo_length',
x_axis_title='Length of Synthetic Combination',
y_bar='combo_count',
y_bar_title='Count of Combinations',
y_line='leak_proportion',
y_line_title=f'Proportion of Leaked (<{self.reporting_resolution}) Combinations',
color='violet',
darker_color='darkviolet',
stats_tsv=self.leakage_tsv,
stats_svg=self.leakage_svg,
delimiter='\t',
style='whitegrid',
palette='magma'
)
def _calculate_preservation_by_count(self):
logging.info('Calculating preservation by count...')
preservation_by_count = self.sds_evaluator.calc_preservation_by_count(
self.sen_aggregated_data, self.syn_aggregated_data, self.reporting_resolution
)
preservation_by_count.write_preservation_by_count(
self.preservation_by_count_tsv, '\t'
)
util.plotStats(
x_axis='syn_count_bucket',
x_axis_title='Count of Filtered Synthetic Records',
y_bar='mean_combo_length',
y_bar_title='Mean Length of Combinations',
y_line='count_preservation',
y_line_title='Count Preservation (Synthetic/Sensitive)',
color='lightgreen',
darker_color='green',
stats_tsv=self.preservation_by_count_tsv,
stats_svg=self.preservation_by_count_svg,
delimiter='\t',
style='whitegrid',
palette='magma'
)
def _calculate_preservation_by_length(self):
logging.info('Calculating preservation by length...')
preservation_by_length = self.sds_evaluator.calc_preservation_by_length(
self.sen_aggregated_data, self.syn_aggregated_data, self.reporting_resolution
)
preservation_by_length.write_preservation_by_length(
self.preservation_by_length_tsv, '\t'
)
util.plotStats(
x_axis='syn_combo_length',
x_axis_title='Length of Combination',
y_bar='mean_combo_count',
y_bar_title='Mean Synthetic Count',
y_line='count_preservation',
y_line_title='Count Preservation (Synthetic/Sensitive)',
color='cornflowerblue',
darker_color='mediumblue',
stats_tsv=self.preservation_by_length_tsv,
stats_svg=self.preservation_by_length_svg,
delimiter='\t',
style='whitegrid',
palette='magma'
)
def run(self):
logging.info(
f'Evaluate {self.synthetic_microdata_path} vs {self.sensitive_microdata_path}'
)
start_time = time.time()
self._load_sensitive_aggregates()
self._load_synthetic_aggregates()
self._do_records_analysis()
self._compare_synthetic_and_sensitive_rare()
self._write_parameters()
self._find_leakages()
self._calculate_preservation_by_count()
self._calculate_preservation_by_length()
logging.info(
f'Evaluated {self.synthetic_microdata_path} vs {self.sensitive_microdata_path}, took {datetime.timedelta(seconds = time.time() - start_time)}s')
def evaluate(config):
@ -15,309 +254,4 @@ def evaluate(config):
Args:
config: options from the json config file, else default values.
"""
use_columns = config['use_columns']
record_limit = config['record_limit']
reporting_length = config['reporting_length']
reporting_resolution = config['reporting_resolution']
sensitive_microdata_path = config['sensitive_microdata_path']
sensitive_microdata_delimiter = config['sensitive_microdata_delimiter']
synthetic_microdata_path = config['synthetic_microdata_path']
sensitive_zeros = config['sensitive_zeros']
parallel_jobs = config['parallel_jobs']
output_dir = config['output_dir']
prefix = config['prefix']
logging.info(
f'Evaluate {synthetic_microdata_path} vs {sensitive_microdata_path}')
start_time = time.time()
sen_counts = None
sen_records = None
sen_df = util.loadMicrodata(
path=sensitive_microdata_path, delimiter=sensitive_microdata_delimiter, record_limit=record_limit,
use_columns=use_columns)
sen_records = util.genRowList(sen_df, sensitive_zeros)
if not path.exists(config['sensitive_aggregates_path']):
logging.info('Computing sensitive aggregates...')
if reporting_length == -1:
reporting_length = max([len(row) for row in sen_records])
sen_counts = util.countAllCombos(
sen_records, reporting_length, parallel_jobs)
else:
logging.info('Loading sensitive aggregates...')
sen_counts = util.loadSavedAggregates(
config['sensitive_aggregates_path'])
if reporting_length == -1:
reporting_length = max(sen_counts.keys())
if use_columns != []:
reporting_length = min(reporting_length, len(use_columns))
filtered_sen_counts = {length: {combo: count for combo, count in combo_to_counts.items(
) if count >= reporting_resolution} for length, combo_to_counts in sen_counts.items()}
syn_df = util.loadMicrodata(path=synthetic_microdata_path,
delimiter='\t', record_limit=-1, use_columns=use_columns)
syn_records = util.genRowList(syn_df, sensitive_zeros)
syn_counts = util.countAllCombos(
syn_records, reporting_length, parallel_jobs)
len_to_syn_count = {length: len(combo_to_count)
for length, combo_to_count in syn_counts.items()}
len_to_sen_rare = {length: {combo: count for combo, count in combo_to_count.items() if count < reporting_resolution}
for length, combo_to_count in sen_counts.items()}
len_to_syn_rare = {length: {combo: count for combo, count in combo_to_count.items() if count < reporting_resolution}
for length, combo_to_count in syn_counts.items()}
len_to_syn_leak = {length: len([1 for rare in rares if rare in syn_counts.get(length, {}).keys()])
for length, rares in len_to_sen_rare.items()}
sen_unique_to_records, sen_rare_to_records, _ = util.mapShortestUniqueRareComboLengthToRecords(
sen_records, len_to_sen_rare)
sen_rare_to_sen_count = {length: util.protect(len(records), reporting_resolution)
for length, records in sen_rare_to_records.items()}
sen_unique_to_sen_count = {length: util.protect(len(records), reporting_resolution)
for length, records in sen_unique_to_records.items()}
total_sen = util.protect(len(sen_records), reporting_resolution)
unique_total = sum(
[v for k, v in sen_unique_to_sen_count.items() if k > 0])
rare_total = sum([v for k, v in sen_rare_to_sen_count.items() if k > 0])
risky_total = unique_total + rare_total
risky_total_pct = 100*risky_total/total_sen
record_analysis_tsv = path.join(
output_dir, f'{prefix}_sensitive_analysis_by_length.tsv')
with open(record_analysis_tsv, 'w') as f:
f.write('\t'.join(['combo_length', 'sen_rare', 'sen_rare_pct',
'sen_unique', 'sen_unique_pct', 'sen_risky', 'sen_risky_pct'])+'\n')
for length in sen_counts.keys():
sen_rare = sen_rare_to_sen_count.get(length, 0)
sen_rare_pct = 100*sen_rare / total_sen if total_sen > 0 else 0
sen_unique = sen_unique_to_sen_count.get(length, 0)
sen_unique_pct = 100*sen_unique / total_sen if total_sen > 0 else 0
sen_risky = sen_rare + sen_unique
sen_risky_pct = 100*sen_risky / total_sen if total_sen > 0 else 0
f.write(
'\t'.join(
[str(length),
str(sen_rare),
str(sen_rare_pct),
str(sen_unique),
str(sen_unique_pct),
str(sen_risky),
str(sen_risky_pct)]) + '\n')
_, _, syn_length_to_combo_to_rare = util.mapShortestUniqueRareComboLengthToRecords(
syn_records, len_to_syn_rare)
combos_tsv = path.join(
output_dir, f'{prefix}_synthetic_rare_combos_by_length.tsv')
with open(combos_tsv, 'w') as f:
f.write('\t'.join(['combo_length', 'combo',
'record_id', 'syn_count', 'sen_count'])+'\n')
for length, combo_to_rare in syn_length_to_combo_to_rare.items():
for combo, rare_ids in combo_to_rare.items():
syn_count = len(rare_ids)
for rare_id in rare_ids:
sen_count = util.protect(sen_counts.get(length, {})[
combo], reporting_resolution)
f.write(
'\t'.join(
[str(length),
util.comboToString(combo).replace(';', ' AND '),
str(rare_id),
str(syn_count),
str(sen_count)]) + '\n')
parameters_tsv = path.join(output_dir, f'{prefix}_parameters.tsv')
with open(parameters_tsv, 'w') as f:
f.write('\t'.join(['parameter', 'value'])+'\n')
f.write('\t'.join(['resolution', str(reporting_resolution)])+'\n')
f.write('\t'.join(['limit', str(reporting_length)])+'\n')
f.write('\t'.join(['total_sensitive_records', str(total_sen)])+'\n')
f.write('\t'.join(['unique_identifiable', str(unique_total)])+'\n')
f.write('\t'.join(['rare_identifiable', str(rare_total)])+'\n')
f.write('\t'.join(['risky_identifiable', str(risky_total)])+'\n')
f.write(
'\t'.join(['risky_identifiable_pct', str(risky_total_pct)])+'\n')
leakage_tsv = path.join(
output_dir, f'{prefix}_synthetic_leakage_by_length.tsv')
leakage_svg = path.join(
output_dir, f'{prefix}_synthetic_leakage_by_length.svg')
with open(leakage_tsv, 'w') as f:
f.write('\t'.join(['syn_combo_length', 'combo_count',
'leak_count', 'leak_proportion'])+'\n')
for length, leak_count in len_to_syn_leak.items():
combo_count = len_to_syn_count.get(length, 0)
leak_prop = leak_count/combo_count if combo_count > 0 else 0
f.write('\t'.join([str(length), str(combo_count),
str(leak_count), str(leak_prop)])+'\n')
util.plotStats(
x_axis='syn_combo_length',
x_axis_title='Length of Synthetic Combination',
y_bar='combo_count',
y_bar_title='Count of Combinations',
y_line='leak_proportion',
y_line_title=f'Proportion of Leaked (<{reporting_resolution}) Combinations',
color='violet',
darker_color='darkviolet',
stats_tsv=leakage_tsv,
stats_svg=leakage_svg,
delimiter='\t',
style='whitegrid',
palette='magma')
compareDatasets(filtered_sen_counts, syn_counts, output_dir, prefix)
logging.info(
f'Evaluated {synthetic_microdata_path} vs {sensitive_microdata_path}, took {datetime.timedelta(seconds = time.time() - start_time)}s')
def compareDatasets(sensitive_length_to_combo_to_count, synthetic_length_to_combo_to_count, output_dir, prefix):
"""Evaluates the error in the synthetic microdata with respect to the sensitive microdata.
Produces output statistics and graphics binned by attribute count.
Args:
sensitive_length_to_combo_to_count: counts from sensitive microdata.
synthetic_length_to_combo_to_count: counts from synthetic microdata.
output_dir: where to save output statistics and graphics.
prefix: prefix to add to output files.
"""
all_count_length_preservation = []
max_syn_count = 0
all_combos = set()
for length, combo_to_count in sensitive_length_to_combo_to_count.items():
for combo in combo_to_count.keys():
all_combos.add((length, combo))
for length, combo_to_count in synthetic_length_to_combo_to_count.items():
for combo in combo_to_count.keys():
all_combos.add((length, combo))
tot = len(all_combos)
for i, (length, combo) in enumerate(all_combos):
if i % 10000 == 0:
logging.info(f'{100*i/tot:.1f}% through comparisons')
sen_count = sensitive_length_to_combo_to_count.get(
length, {}).get(combo, 0)
syn_count = synthetic_length_to_combo_to_count.get(
length, {}).get(combo, 0)
max_syn_count = max(syn_count, max_syn_count)
preservation = 0
preservation = syn_count / sen_count if sen_count > 0 else 0
if sen_count == 0:
logging.error(
f'Error: For {combo}, synthetic count is {syn_count} but no sensitive count')
all_count_length_preservation.append((syn_count, length, preservation))
max_syn_count = max(syn_count, max_syn_count)
generateStatsAndGraphics(output_dir, max_syn_count,
all_count_length_preservation, prefix)
def generateStatsAndGraphics(output_dir, max_syn_count, count_length_preservation, prefix):
"""Generates count error statistics for buckets of user-observed counts (post-filtering).
Outputs the files preservation_by_length and preservation_by_count tsv and svg files.
Args:
output_dir: the folder in which to output summary files.
max_syn_count: the maximum observed count of synthetic records matching a single attribute value.
count_length_preservation: list of (count, length, preservation) tuples for observed preservation of sensitive counts after filtering by a combo of length.
"""
sorted_counts = sorted(count_length_preservation,
key=lambda x: x[0], reverse=False)
buckets = []
next_bucket = 10
while next_bucket < max_syn_count:
buckets.append(next_bucket)
next_bucket *= 2
buckets.append(next_bucket)
bucket_to_preservations = defaultdict(list)
bucket_to_counts = defaultdict(list)
bucket_to_lengths = defaultdict(list)
length_to_preservations = defaultdict(list)
length_to_counts = defaultdict(list)
for (count, length, preservation) in sorted_counts:
bucket = buckets[0]
for bi in range(len(buckets)-1, -1, -1):
if count > buckets[bi]:
bucket = buckets[bi+1]
break
bucket_to_preservations[bucket].append(preservation)
bucket_to_lengths[bucket].append(length)
bucket_to_counts[bucket].append(count)
length_to_counts[length].append(count)
length_to_preservations[length].append(preservation)
bucket_to_mean_count = {bucket: sum(counts)/len(counts) if len(counts) >
0 else 0 for bucket, counts in bucket_to_counts.items()}
bucket_to_mean_preservation = {bucket: sum(preservations) / len(preservations)
if len(preservations) > 0 else 0 for bucket,
preservations in bucket_to_preservations.items()}
bucket_to_mean_length = {bucket: sum(lengths)/len(lengths) if len(lengths) >
0 else 0 for bucket, lengths in bucket_to_lengths.items()}
counts_tsv = path.join(output_dir, f'{prefix}_preservation_by_count.tsv')
counts_svg = path.join(output_dir, f'{prefix}_preservation_by_count.svg')
with open(counts_tsv, 'w') as f:
f.write('\t'.join(['syn_count_bucket', 'mean_combo_count',
'mean_combo_length', 'count_preservation'])+'\n')
for bucket in reversed(buckets):
f.write(
'\t'.join(
[str(bucket),
str(bucket_to_mean_count.get(bucket, 0)),
str(bucket_to_mean_length.get(bucket, 0)),
str(bucket_to_mean_preservation.get(bucket, 0))]) + '\n')
util.plotStats(
x_axis='syn_count_bucket',
x_axis_title='Count of Filtered Synthetic Records',
y_bar='mean_combo_length',
y_bar_title='Mean Length of Combinations',
y_line='count_preservation',
y_line_title='Count Preservation (Synthetic/Sensitive)',
color='lightgreen',
darker_color='green',
stats_tsv=counts_tsv,
stats_svg=counts_svg,
delimiter='\t',
style='whitegrid',
palette='magma')
length_to_mean_preservation = {length: sum(preservations) / len(preservations)
if len(preservations) > 0 else 0 for length,
preservations in length_to_preservations.items()}
length_to_mean_count = {length: sum(counts)/len(counts) if len(counts) >
0 else 0 for length, counts in length_to_counts.items()}
lengths_tsv = path.join(output_dir, f'{prefix}_preservation_by_length.tsv')
lengths_svg = path.join(output_dir, f'{prefix}_preservation_by_length.svg')
with open(lengths_tsv, 'w') as f:
f.write(
'\t'.join(['syn_combo_length', 'mean_combo_count', 'count_preservation'])+'\n')
for length in sorted(length_to_preservations.keys()):
f.write(
'\t'.join(
[str(length),
str(length_to_mean_count.get(length, 0)),
str(length_to_mean_preservation.get(length, 0))]) + '\n')
util.plotStats(
x_axis='syn_combo_length',
x_axis_title='Length of Combination',
y_bar='mean_combo_count',
y_bar_title='Mean Synthetic Count',
y_line='count_preservation',
y_line_title='Count Preservation (Synthetic/Sensitive)',
color='cornflowerblue',
darker_color='mediumblue',
stats_tsv=lengths_tsv,
stats_svg=lengths_svg,
delimiter='\t',
style='whitegrid',
palette='magma')
Evaluator(config).run()

View file

@ -1,12 +1,6 @@
import time
import datetime
import logging
import joblib
import psutil
import pandas as pd
from math import ceil
from random import random, shuffle
import util as util
import sds
@ -20,14 +14,12 @@ def generate(config):
"""
use_columns = config['use_columns']
parallel_jobs = config['parallel_jobs']
record_limit = config['record_limit']
sensitive_microdata_path = config['sensitive_microdata_path']
sensitive_microdata_delimiter = config['sensitive_microdata_delimiter']
synthetic_microdata_path = config['synthetic_microdata_path']
sensitive_zeros = config['sensitive_zeros']
resolution = config['reporting_resolution']
memory_limit = config['memory_limit_pct']
cache_max_size = config['cache_max_size']
seeded = config['seeded']
@ -36,215 +28,26 @@ def generate(config):
if seeded:
logging.info(f'Generating from seeds')
data_block = sds.create_data_block_from_file(
else:
logging.info(f'Generating unseeded')
sds_processor = sds.SDSProcessor(
sensitive_microdata_path,
sensitive_microdata_delimiter,
use_columns,
sensitive_zeros,
max(record_limit, 0)
)
syn_ratio = sds.generate_seeded_and_write(
data_block,
synthetic_microdata_path,
'\t',
generated_data = sds_processor.generate(
cache_max_size,
resolution
resolution,
"",
seeded
)
else:
df = util.loadMicrodata(path=sensitive_microdata_path, delimiter=sensitive_microdata_delimiter,
record_limit=record_limit, use_columns=use_columns)
columns = df.columns.values
num = len(df)
logging.info(f'Prepared data')
records = []
chunk = ceil(num/(parallel_jobs))
logging.info(f'Generating unseeded')
chunks = [chunk for i in range(parallel_jobs)]
col_val_ids = util.genColValIdsDict(df, sensitive_zeros)
res = joblib.Parallel(
n_jobs=parallel_jobs, backend='loky', verbose=1)(
joblib.delayed(synthesizeRowsUnseeded)(
chunk, num, columns, col_val_ids, resolution, memory_limit)
for chunk in chunks)
for rows in res:
records.extend(rows)
# trim any overgenerated records because of uniform chunk size
records = records[:num]
records.sort()
records.sort(key=lambda x: len(
[y for y in x if y != '']), reverse=True)
sdf = pd.DataFrame(records, columns=df.columns)
syn_ratio = len(sdf) / len(df)
sdf.to_csv(synthetic_microdata_path, sep='\t', index=False)
generated_data.write_synthetic_data(synthetic_microdata_path, '\t')
syn_ratio = generated_data.expansion_ratio
config['expansion_ratio'] = syn_ratio
logging.info(
f'Generated {synthetic_microdata_path} from {sensitive_microdata_path} with synthesis ratio {syn_ratio}, took {datetime.timedelta(seconds = time.time() - start_time)}s')
def synthesizeRowsUnseeded(chunk, num_rows, columns, col_val_ids, resolution, memory_limit):
"""Create synthetic records through unconstrained sampling of attribute distributions.
Args:
chunk: how many records to synthesize.
num_rows: how many rows/records in the sensitive dataset.
columns: the columns of the sensitive dataset.
atts_to_ids: a dict mapping attributes to row ids.
resolution: the minimum count of sensitive attribute combinations for inclusion in synthetic records.
memory_limit: the percentage memory use not to exceed.
Returns:
rows: synthetic records created through unconstrained sampling of attribute distributions.
overall_cache_hits: the number of times the cache was hit.
overall_cache_misses: the number of times the cache was missed.
"""
rows = []
filter_cache = {}
overall_cache_hits = 0
overall_cache_misses = 0
for i in range(chunk):
if i % 100 == 99:
logging.info(
f'{(100*i/chunk):.1f}% through row synthesis, cache utilization {100*overall_cache_hits/(overall_cache_hits + overall_cache_misses):.1f}%')
filters, cache_hits, cache_misses = synthesizeRowUnseeded(
filter_cache, num_rows, columns, col_val_ids, resolution, memory_limit)
overall_cache_hits += cache_hits
overall_cache_misses += cache_misses
row = normalize(columns, filters)
rows.append(row)
return rows
def synthesizeRowUnseeded(filter_cache, num_rows, columns, col_val_ids, resolution, memory_limit):
"""Creates a synthetic record through unconstrained sampling of attribute distributions.
Args:
num_rows: how many rows/records in the sensitive dataset.
columns: the columns of the sensitive dataset.
atts_to_ids: a dict mapping attributes to row ids.
resolution: the minimum count of sensitive attribute combinations for inclusion in synthetic records.
memory_limit: the percentage memory use not to exceed.
Returns:
row: a synthetic record created through unconstrained sampling of attribute distributions.
row_cache_hits: the number of times the cache was hit.
row_cache_misses: the number of times the cache was missed.
"""
shuffled_columns = list(columns)
shuffle(shuffled_columns)
output_atts = []
residual_ids = set(range(num_rows))
row_cache_hits = 0
row_cache_misses = 0
# do not add to cache once memory limit is reached
use_cache = psutil.virtual_memory()[2] <= memory_limit
for col in shuffled_columns:
vals_to_sample = {}
for val, ids in col_val_ids[col].items():
next_filters = tuple(
sorted((*output_atts, (col, val)), key=lambda x: f'{x[0]}:{x[1]}'.lower()))
if next_filters in filter_cache.keys():
vals_to_sample[val] = set(filter_cache[next_filters])
row_cache_hits += 1
else:
row_cache_misses += 1
res = set(ids).intersection(residual_ids)
vals_to_sample[val] = res
if use_cache:
filter_cache[next_filters] = res
if '' not in vals_to_sample.keys():
vals_to_sample[''] = set()
vals_to_remove = {k: v for k,
v in vals_to_sample.items() if len(v) < resolution}
for val, ids in vals_to_remove.items():
vals_to_sample[''].update(ids)
if val != '':
del vals_to_sample[val]
val_to_counts = {k: len(v) for k, v in vals_to_sample.items()}
sampled_val = sampleFromCounts(val_to_counts, True)
if sampled_val != None:
output_atts.append((col, sampled_val))
residual_ids = set(vals_to_sample[sampled_val])
filters = tuple(
sorted(output_atts, key=lambda x: f'{x[0]}:{x[1]}'.lower()))
return filters, row_cache_hits, row_cache_misses
def sampleFromCounts(counts, preferNotNone):
"""Samples from a dict of counts based on their relative frequency
Args:
counts: the dict of counts.
preferNotNone: whether to avoid sampling None if possible
Returns:
sampled_value: the sampled value
"""
dist = convertCountsToCumulativeDistribution(counts)
sampled_value = None
r = random()
keys = sorted(dist.keys())
for p in keys:
v = dist[p]
if r < p:
if preferNotNone and v == None:
continue
else:
sampled_value = v
break
sampled_value = v
return sampled_value
def convertCountsToCumulativeDistribution(att_to_count):
"""Converts a dict of counts to a cumulative probability distribution for sampling.
Args:
att_to_count: a dict of attribute counts.
Returns:
dist a cumulative probability distribution mapping cumulative probabilities [0,1] to attributes.
"""
dist = {}
total = sum(att_to_count.values())
if total == 0:
return dist
cumulative = 0
for att, count in att_to_count.items():
if count > 0:
p = count/total
cumulative += p
dist[cumulative] = att
return dist
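# Worked example (illustrative): {'x': 2, 'y': 2} -> {0.5: 'x', 1.0: 'y'};
# zero-count attributes are skipped and an all-zero input yields {}.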
def normalize(columns, atts):
"""Creates an output record according to the columns schema from a set of attributes.
Args:
columns: the columns of the output record.
atts: the attribute values of the output record.
Returns:
row: a normalized row ready for dataframe integration.
"""
row = []
for c in columns:
added = False
for a in atts:
if a[0] == c:
row.append(a[1])
added = True
break
if not added:
row.append('')
return row
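# Illustrative example: normalize(['a', 'b', 'c'], (('a', '1'), ('c', '7'))) returns
# ['1', '', '7'] -- columns without a sampled attribute are filled with ''.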

View file

@@ -1,6 +1,7 @@
import json
import logging
import argparse
import sds
from os import path, mkdir
import aggregator as aggregator
import generator as generator
@@ -102,6 +103,9 @@ def runPipeline(config):
config: options from the json config file, else default values.
"""
# set number of threads to be used for processing
sds.set_number_of_threads(config['parallel_jobs'])
if config['aggregate']:
aggregator.aggregate(config)

View file

@@ -1,14 +1,10 @@
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
import logging
import numpy as np
import pandas as pd
import joblib
from itertools import combinations
from collections import defaultdict
import seaborn as sns
from math import ceil, floor
from math import ceil
import matplotlib
# fixes matplotlib + joblib bug "RuntimeError: main thread is not in main loop Tcl_AsyncDelete: async handler deleted by the wrong thread"
matplotlib.use('Agg')
@@ -36,219 +32,6 @@ def loadMicrodata(path, delimiter, record_limit, use_columns):
return df
def genRowList(df, sensitive_zeros):
"""Converts a dataframe to a list of rows.
Args:
df: the dataframe.
sensitive_zeros: columns for which zero values are meaningful (sensitive) and should be retained rather than treated as missing.
Returns:
row_list: the list of rows.
"""
row_list = []
for _, row in df.iterrows():
res = []
for c in df.columns:
val = str(row[c])
if val != '' and (c in sensitive_zeros or val != '0'):
res.append((c, val))
row2 = sorted(res, key=lambda x: f'{x[0]}:{x[1]}'.lower())
row_list.append(row2)
return row_list
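# Illustrative example: a row {'a': '0', 'b': 'x'} becomes [('b', 'x')] when 'a' is not
# in sensitive_zeros (zeros are dropped), and [('a', '0'), ('b', 'x')] when it is.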
def genColValIdsDict(df, sensitive_zeros):
"""Converts a dataframe to a dict of col->val->ids.
Args:
df: the dataframe.
sensitive_zeros: columns for which zero values are meaningful and should be retained.
Returns:
colValIds: the dict of col->val->ids.
"""
colValIds = {}
for rid, row in df.iterrows():
for c in df.columns:
val = str(row[c])
if c not in colValIds.keys():
colValIds[c] = defaultdict(list)
if val == '0':
if c in sensitive_zeros:
colValIds[c][val].append(rid)
else:
colValIds[c][''].append(rid)
else:
colValIds[c][val].append(rid)
return colValIds
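# Illustrative example: a two-row frame with column 'a' = ['0', '5'] and no sensitive
# zeros yields {'a': {'': [0], '5': [1]}} (zeros map to the empty value, ids are row indexes).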
def countAllCombos(row_list, length_limit, parallel_jobs):
"""Counts all combinations in the given rows up to a limit.
Args:
row_list: a list of rows.
length_limit: the maximum length to compute counts for (all lengths if -1).
parallel_jobs: the number of processor cores to use for parallelized counting.
Returns:
length_to_combo_to_count: a dict mapping combination lengths to dicts mapping combinations to counts.
"""
length_to_combo_to_count = {}
if length_limit == -1:
length_limit = max([len(x) for x in row_list])
for length in range(1, length_limit+1):
logging.info(f'counting combos of length {length}')
res = joblib.Parallel(n_jobs=parallel_jobs, backend='loky', verbose=1)(
joblib.delayed(genAllCombos)(row, length) for row in row_list)
length_to_combo_to_count[length] = defaultdict(int)
for combos in res:
for combo in combos:
length_to_combo_to_count.get(length, {})[combo] += 1
return length_to_combo_to_count
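# Illustrative example: for row_list = [[('a', '1'), ('b', '2')]] and length_limit = 2,
# the result maps length 1 to {(('a', '1'),): 1, (('b', '2'),): 1} and
# length 2 to {(('a', '1'), ('b', '2')): 1}.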
def genAllCombos(row, length):
"""Generates all combos from row up to and including size length.
Args:
row: the row to extract combinations from.
length: the maximum combination length to extract.
Returns:
combos: list of combinations extracted from row.
"""
res = []
if len(row) == length:
canonical_combo = tuple(
sorted(list(row), key=lambda x: f'{x[0]}:{x[1]}'.lower()))
res.append(canonical_combo)
else:
for combo in combinations(row, length):
canonical_combo = tuple(
sorted(list(combo), key=lambda x: f'{x[0]}:{x[1]}'.lower()))
res.append(canonical_combo)
return res
def mapShortestUniqueRareComboLengthToRecords(records, length_to_rare):
"""
Maps each record to the shortest combination length that isolates it within a rare group (i.e., below resolution).
Args:
records: the input records.
length_to_rare: a dict of length to rare combo to count.
Returns:
unique_to_records: dict mapping each combination length to the records whose shortest unique combination has that length (0 if none).
rare_to_records: dict mapping each combination length to the records matching a rare (below-resolution, non-unique) combination of that length (0 if they match neither rare nor unique combinations).
length_to_combo_to_rare: dict mapping lengths to rare combinations to the sets of matching record ids.
"""
rare_to_records = defaultdict(set)
unique_to_records = defaultdict(set)
length_to_combo_to_rare = {length: defaultdict(
set) for length in length_to_rare.keys()}
for i, record in enumerate(records):
matchedRare = False
matchedUnique = False
for length in sorted(length_to_rare.keys()):
if matchedUnique:
break
for combo in combinations(record, length):
canonical_combo = tuple(
sorted(list(combo), key=lambda x: f'{x[0]}:{x[1]}'.lower()))
if canonical_combo in length_to_rare.get(length, {}).keys():
# unique
if length_to_rare.get(length, {})[canonical_combo] == 1:
unique_to_records[length].add(i)
matchedUnique = True
length_to_combo_to_rare.get(
length, {})[canonical_combo].add(i)
if i in rare_to_records[length]:
rare_to_records[length].remove(i)
else:
if i not in unique_to_records[length]:
rare_to_records[length].add(i)
matchedRare = True
length_to_combo_to_rare.get(
length, {})[canonical_combo].add(i)
if matchedUnique:
break
if not matchedRare and not matchedUnique:
rare_to_records[0].add(i)
if not matchedUnique:
unique_to_records[0].add(i)
return unique_to_records, rare_to_records, length_to_combo_to_rare
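# Reading the outputs (as implemented above): a record lands in unique_to_records[l] at the
# first length l where one of its combinations is unique, in rare_to_records[l] at every
# length where it matches a rare non-unique combination before that, and in the length-0
# buckets when no unique (or no rare) combination matches at all.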
def protect(value, resolution):
"""Protects a value from a privacy perspective by rounding down to the closest multiple of the supplied resolution.
Args:
value: the value to protect.
resolution: round values down to the closest multiple of this.
"""
rounded = floor(value / resolution) * resolution
return rounded
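# Examples: protect(17, 10) == 10 and protect(9, 10) == 0.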
def comboToString(combo_tuple):
"""Creates a string from a tuple of (attribute, value) tuples.
Args:
combo_tuple: tuple of (attribute, value) tuples.
Returns:
combo_string: a ;-delimited string of ':'-concatenated attribute-value pairs.
"""
combo_string = ''
for (col, val) in combo_tuple:
combo_string += f'{col}:{val};'
return combo_string[:-1]
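# Illustrative example: comboToString((('age', '30'), ('city', 'NYC'))) == 'age:30;city:NYC'.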
def loadSavedAggregates(path):
"""Loads sensitive aggregates from file to speed up evaluation.
Args:
path: location of the saved sensitive aggregates.
Returns:
length_to_combo_to_count: a dict mapping combination lengths to dicts mapping combinations to counts
"""
length_to_combo_to_count = defaultdict(dict)
with open(path, 'r') as f:
for i, line in enumerate(f):
if i == 0:
continue # skip header
parts = [x.strip() for x in line.split('\t')]
if len(parts[0]) > 0:
length, combo = stringToLengthAndCombo(parts[0])
length_to_combo_to_count[length][combo] = int(parts[1])
return length_to_combo_to_count
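# Assumed file layout (inferred from the parsing above): a header line followed by
# tab-separated rows of 'col:val;col:val' combination strings and their counts.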
def stringToLengthAndCombo(combo_string):
"""Creates a tuple of (attribute, value) tuples from a given string.
Args:
combo_string: string representation of (attribute, value) tuples.
Returns:
length: the number of attribute-value pairs in the combo string.
combo_tuple: tuple of (attribute, value) tuples.
"""
length = len(combo_string.split(';'))
combo_list = []
for col_vals in combo_string.split(';'):
parts = col_vals.split(':')
if len(parts) == 2:
combo_list.append((parts[0], parts[1]))
combo_tuple = tuple(combo_list)
return length, combo_tuple
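# Illustrative example: stringToLengthAndCombo('age:30;city:NYC')
# == (2, (('age', '30'), ('city', 'NYC'))).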
def plotStats(
x_axis, x_axis_title, y_bar, y_bar_title, y_line, y_line_title, color, darker_color, stats_tsv, stats_svg,
delimiter, style='whitegrid', palette='magma'):

View file

@@ -2,31 +2,24 @@
* Copyright (c) Microsoft. All rights reserved.
* Licensed under the MIT license. See LICENSE file in the project.
*/
import { Stack, Spinner, Label } from '@fluentui/react'
import _ from 'lodash'
import { Stack, Label, Spinner } from '@fluentui/react'
import { memo, useCallback, useEffect, useRef, useState } from 'react'
import { useOnAttributeSelection } from './hooks'
import { IAttributesIntersection } from 'sds-wasm'
import { useMaxCount } from './hooks'
import { AttributeIntersectionValueChart } from '~components/Charts/AttributeIntersectionValueChart'
import { useStopPropagation } from '~components/Charts/hooks'
import { CsvRecord, IAttributesIntersectionValue } from '~models'
import {
useEvaluatedResultValue,
useNavigateResultValue,
useWasmWorkerValue,
} from '~states'
import {
useSelectedAttributeRows,
useSelectedAttributes,
} from '~states/dataShowcaseContext'
import { SetSelectedAttributesCallback } from '~components/Pages/DataShowcasePage/DataNavigation'
import { useWasmWorkerValue } from '~states'
export interface ColumnAttributeSelectorProps {
headers: CsvRecord
headerName: string
columnIndex: number
height: string | number
width: number | number
chartHeight: number
width: string | number
chartBarHeight: number
minHeight?: string | number
selectedAttributes: Set<string>
onSetSelectedAttributes: SetSelectedAttributesCallback
}
// fixed value for rough axis height so charts don't squish if selected down to 1
@@ -34,87 +27,49 @@ const AXIS_HEIGHT = 16
export const ColumnAttributeSelector: React.FC<ColumnAttributeSelectorProps> =
memo(function ColumnAttributeSelector({
headers,
headerName,
columnIndex,
height,
width,
chartHeight,
chartBarHeight,
minHeight,
}: ColumnAttributeSelectorProps) {
const [selectedAttributeRows, setSelectedAttributeRows] =
useSelectedAttributeRows()
const [selectedAttributes, setSelectedAttributes] = useSelectedAttributes()
const navigateResult = useNavigateResultValue()
const [items, setItems] = useState<IAttributesIntersectionValue[]>([])
const [isLoading, setIsLoading] = useState(false)
const isMounted = useRef(true)
const evaluatedResult = useEvaluatedResultValue()
const worker = useWasmWorkerValue()
const selectedAttribute = selectedAttributes[columnIndex]
const maxCount =
_.max([
_.maxBy(items, item => item.estimatedCount)?.estimatedCount,
_.maxBy(items, item => item.actualCount ?? 0)?.actualCount,
]) ?? 1
const onAttributeSelection = useOnAttributeSelection(
setIsLoading,
selectedAttributes,
worker,
navigateResult,
isMounted,
setSelectedAttributes,
setSelectedAttributeRows,
)
onSetSelectedAttributes,
}: ColumnAttributeSelectorProps) {
const [items, setItems] = useState<IAttributesIntersection[]>([])
const [isLoading, setIsLoading] = useState(false)
const worker = useWasmWorkerValue()
const maxCount = useMaxCount(items)
const isMounted = useRef(true)
const stopPropagation = useStopPropagation()
const handleSelection = useCallback(
(item: IAttributesIntersectionValue | undefined) => {
async (item: IAttributesIntersection | undefined) => {
const newValue = item?.value
// toggle off with re-click
if (newValue === selectedAttribute) {
onAttributeSelection(columnIndex, undefined)
if (newValue === undefined || selectedAttributes.has(newValue)) {
await onSetSelectedAttributes(columnIndex, undefined)
} else {
onAttributeSelection(columnIndex, newValue)
await onSetSelectedAttributes(columnIndex, item)
}
},
[selectedAttribute, columnIndex, onAttributeSelection],
[selectedAttributes, onSetSelectedAttributes, columnIndex],
)
const stopPropagation = useStopPropagation()
useEffect(() => {
if (worker) {
setIsLoading(true)
worker
.intersectAttributesInColumnsWith(
headers,
navigateResult.allRows,
navigateResult.attrsInColumnsMap[columnIndex] ?? new Set(),
selectedAttributeRows,
selectedAttributes,
navigateResult.attrRowsMap,
columnIndex,
evaluatedResult.sensitiveAggregatedResult?.aggregatedCombinations,
)
.then(newItems => {
if (isMounted.current) {
setItems(newItems ?? [])
setIsLoading(false)
.attributesIntersectionsByColumn([headerName])
.then(intersections => {
if (!isMounted.current || !intersections) {
return
}
setItems(intersections[columnIndex] ?? [])
setIsLoading(false)
})
}
}, [
worker,
navigateResult,
headers,
selectedAttributeRows,
selectedAttributes,
columnIndex,
evaluatedResult,
setItems,
setIsLoading,
])
}, [worker, setIsLoading, setItems, headerName, columnIndex])
useEffect(() => {
return () => {
@@ -151,8 +106,8 @@ export const ColumnAttributeSelector: React.FC<ColumnAttributeSelectorProps> =
items={items}
onClick={handleSelection}
maxCount={maxCount}
height={chartHeight * Math.max(items.length, 1) + AXIS_HEIGHT}
selectedValue={selectedAttribute}
height={chartBarHeight * Math.max(items.length, 1) + AXIS_HEIGHT}
selectedAttributes={selectedAttributes}
/>
</Stack.Item>
)}

Some files were not shown because too many files have changed in this diff.