diff --git a/docs/source/example_notebooks/prediction/dowhy_causal_prediction_demo.ipynb b/docs/source/example_notebooks/prediction/dowhy_causal_prediction_demo.ipynb index c63561cfd..9695dfb1e 100644 --- a/docs/source/example_notebooks/prediction/dowhy_causal_prediction_demo.ipynb +++ b/docs/source/example_notebooks/prediction/dowhy_causal_prediction_demo.ipynb @@ -7,9 +7,7 @@ "source": [ "# Demo for DoWhy Causal Prediction on MNIST\n", "\n", - "We're adding prediction functionality to DoWhy. The goal of this notebook is to demonstrate an example of causal prediction using *Causally Adaptive Constraint Minimization (CACM)* [1]. \n", - "\n", - "[1] Kaur, J.N., Kıcıman, E., & Sharma, A. (2022). Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization. ArXiv, abs/2206.07837." + "The goal of this notebook is to demonstrate an example of causal prediction using *Causally Adaptive Constraint Minimization (CACM)* (https://arxiv.org/abs/2206.07837) [1]. " ] }, { @@ -21,9 +19,6 @@ "\n", "Domain generalization literature has largely focused on datasets with a single kind of distribution shift over one attribute. Using MNIST as an example, domains are created either by adding new values of a spurious attribute like rotation (e.g., Rotated-MNIST dataset [2]) or domains exhibit different values of correlation between the class label and a spurious attribute like color (e.g., Colored-MNIST [3]). However, real-world data often has multiple distribution shifts over different attributes. For example, satellite imagery data demonstrates distribution shifts over time as well as the region captured.\n", "\n", - "[2] Ghifary, M., Kleijn, W., Zhang, M., & Balduzzi, D. (2015). Domain Generalization for Object Recognition with Multi-task Autoencoders. 2015 IEEE International Conference on Computer Vision (ICCV), 2551-2559.
\n", - "[3] Arjovsky, M., Bottou, L., Gulrajani, I., & Lopez-Paz, D. (2019). Invariant Risk Minimization. ArXiv, abs/1907.02893.\n", - "\n", "\n", "### Multi-attribute MNIST\n", "\n", @@ -264,7 +259,9 @@ "id": "adf11498", "metadata": {}, "source": [ - "### Initialize algorithm class: ERM" + "### Initialize algorithm class: ERM\n", + "\n", + "We have implemented Empirical Risk Minimization (ERM) in `dowhy.causal_prediction.algorithms` as a baseline." ] }, { @@ -340,9 +337,9 @@ "id": "bb318eea", "metadata": {}, "source": [ - "## Prediction with CACM\n", + "## Prediction with *CACM*\n", "\n", - "We now train and evaluate the above dataset with CACM. We specify the type of shifts present using list `attr_types` provided as input to CACM. Further instructions regarding using CACM with multi-attribute shifts is provided in the next section." + "We now train and evaluate the above dataset with *CACM*. We specify the type of shifts present using list `attr_types` provided as input to *CACM*. Further instructions regarding using *CACM* with multi-attribute shifts is provided in the next section." ] }, { @@ -362,7 +359,7 @@ "metadata": {}, "outputs": [], "source": [ - "# `attr_types` list contains type of attributes present (supports 'causal' and 'ind' currently)\n", + "# `attr_types` list contains type of attributes present (supports 'causal', 'conf', ind', and 'sel' currently)\n", "algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['causal'], lambda_causal=100.)" ] }, @@ -404,7 +401,7 @@ "source": [ "### MNIST Independent and Causal+Independent datasets\n", "\n", - "We show how to perform the above evaluation for `MNISTIndAttribute` and`MNISTCausalIndAttribute` datasets. Additional `attr_types` should be provided to CACM algorithm for handling multiple shifts. We currently support `Causal` and `Independent` distribution shifts in the data." + "We show how to perform the above evaluation for `MNISTIndAttribute` and`MNISTCausalIndAttribute` datasets. Additional `attr_types` should be provided to *CACM* algorithm for handling multiple shifts. We currently support *Causal*, *Confounded*, *Independent*, and *Selected* distribution shifts in the data." ] }, { @@ -435,7 +432,7 @@ "metadata": {}, "outputs": [], "source": [ - "algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['ind'], lambda_ind=10.)" + "algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['ind'], lambda_ind=10., E_eq_A=[0])" ] }, { @@ -466,8 +463,8 @@ "metadata": {}, "outputs": [], "source": [ - "# `attr_types` can be provided in any order\n", - "algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['ind', 'causal'], lambda_causal=100., lambda_ind=10.)" + "# `attr_types` should be ordered consistent with the attribute order in dataset class\n", + "algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['causal', 'ind'], lambda_causal=100., lambda_ind=10., E_eq_A=[1])" ] }, { @@ -480,18 +477,24 @@ "We provide our demo on MNIST using ERM and *CACM* algorithms. It is possible to extend the evaluation to new datasets and algorithms for evaluation.\n", "\n", "\n", - "New datasets can be added to `dowhy.causal_prediction.datasets` and imported here, as we did for MNIST. We provide description of the MNIST dataset (and variants) in `dowhy.causal_prediction.datasets.mnist` that will be helpful in creating new dataset classes. We currently support `Causal` and `Independent` distribution shifts in the data.\n", + "New datasets can be added to `dowhy.causal_prediction.datasets` and imported here, as we did for MNIST. We provide description of the MNIST dataset (and variants) in `dowhy.causal_prediction.datasets.mnist` that will be helpful in creating new dataset classes. We currently support *Causal*, *Confounded*, *Independent*, and *Selected* distribution shifts in the data. \n", "\n", "We have implemented ERM in `dowhy.causal_prediction.algorithms` as a baseline. Additional algorithms can be added by overriding the `training_step` function in base class `PredictionAlgorithm`." ] }, { - "cell_type": "code", - "execution_count": null, - "id": "6206a050", + "cell_type": "markdown", + "id": "3b9e1b05", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "## References\n", + "\n", + "[1] Kaur, J.N., Kıcıman, E., & Sharma, A. (2022). Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization. ArXiv, abs/2206.07837.\n", + "\n", + "[2] Ghifary, M., Kleijn, W., Zhang, M., & Balduzzi, D. (2015). Domain Generalization for Object Recognition with Multi-task Autoencoders. 2015 IEEE International Conference on Computer Vision (ICCV), 2551-2559.
\n", + "\n", + "[3] Arjovsky, M., Bottou, L., Gulrajani, I., & Lopez-Paz, D. (2019). Invariant Risk Minimization. ArXiv, abs/1907.02893.\n" + ] } ], "metadata": { @@ -510,7 +513,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.11" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/dowhy/causal_prediction/algorithms/base_algorithm.py b/dowhy/causal_prediction/algorithms/base_algorithm.py index e2a671e56..9a61ccaa5 100644 --- a/dowhy/causal_prediction/algorithms/base_algorithm.py +++ b/dowhy/causal_prediction/algorithms/base_algorithm.py @@ -45,8 +45,8 @@ class PredictionAlgorithm(pl.LightningModule): """ if isinstance(batch[0], list): - x = torch.cat([x for x, y, _, _ in batch]) - y = torch.cat([y for x, y, _, _ in batch]) + x = torch.cat([x for x, y, _ in batch]) + y = torch.cat([y for x, y, _ in batch]) else: x = batch[0] y = batch[1] @@ -67,8 +67,8 @@ class PredictionAlgorithm(pl.LightningModule): """ if isinstance(batch[0], list): - x = torch.cat([x for x, y, _, _ in batch]) - y = torch.cat([y for x, y, _, _ in batch]) + x = torch.cat([x for x, y, _ in batch]) + y = torch.cat([y for x, y, _ in batch]) else: x = batch[0] y = batch[1] diff --git a/dowhy/causal_prediction/algorithms/cacm.py b/dowhy/causal_prediction/algorithms/cacm.py index 527059386..d985bd315 100644 --- a/dowhy/causal_prediction/algorithms/cacm.py +++ b/dowhy/causal_prediction/algorithms/cacm.py @@ -3,7 +3,7 @@ from torch import nn from torch.nn import functional as F from dowhy.causal_prediction.algorithms.base_algorithm import PredictionAlgorithm -from dowhy.causal_prediction.algorithms.utils import mmd_compute +from dowhy.causal_prediction.algorithms.regularization import Regularizer class CACM(PredictionAlgorithm): @@ -19,12 +19,13 @@ class CACM(PredictionAlgorithm): ci_test="mmd", attr_types=[], E_conditioned=True, - E_eq_Aind=True, + E_eq_A=[], gamma=1e-6, lambda_causal=1.0, + lambda_conf=1.0, lambda_ind=1.0, + lambda_sel=1.0, ): - """Class for Causally Adaptive Constraint Minimization (CACM) Algorithm. @article{Kaur2022ModelingTD, title={Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization}, @@ -43,34 +44,31 @@ class CACM(PredictionAlgorithm): :param momentum: Value of momentum for SGD optimzer :param kernel_type: Kernel type for MMD penalty. Currently, supports "gaussian" (RBF). If None, distance between mean and second-order statistics (covariances) is used. :param ci_test: Conditional independence metric used for regularization penalty. Currently, MMD is supported. - :param attr_types: list of attribute types (based on relationship with label Y); can be in any order. Currently, only 'causal' and 'ind' are supported. + :param attr_types: list of attribute types (based on relationship with label Y); should be ordered according to attribute order in loaded dataset. + Currently, 'causal' (Causal), 'conf' (Confounded), 'ind' (Independent) and 'sel' (Selected) are supported. For single-shift datasets, use: ['causal'], ['ind'] For multi-shift datasets, use: ['causal', 'ind'] :param E_conditioned: Binary flag indicating whether E-conditioned regularization has to be applied - :param E_eq_Aind: Binary flag indicating whether environment (E) and Aind (Independent attribute) coincide + :param E_eq_A: list indicating indices of attributes that coincide with environment (E) definition; default is empty. :param gamma: kernel bandwidth for MMD (due to implementation, the kernel bandwdith will actually be the reciprocal of gamma i.e., gamma=1e-6 implies kernel bandwidth=1e6. See `mmd_compute` in utils.py) :param lambda_causal: MMD penalty hyperparameter for Causal shift + :param lambda_conf: MMD penalty hyperparameter for Confounded shift :param lambda_ind: MMD penalty hyperparameter for Independent shift + :param lambda_sel: MMD penalty hyperparameter for Selected shift :returns: an instance of PredictionAlgorithm class """ super().__init__(model, optimizer, lr, weight_decay, betas, momentum) - self.kernel_type = kernel_type + self.CACMRegularizer = Regularizer(E_conditioned, ci_test, kernel_type, gamma) + self.attr_types = attr_types - self.E_conditioned = E_conditioned # E-conditioned regularization by default - self.E_eq_Aind = E_eq_Aind - self.gamma = gamma + self.E_eq_A = E_eq_A self.lambda_causal = lambda_causal + self.lambda_conf = lambda_conf self.lambda_ind = lambda_ind - - def mmd(self, x, y): - """ - Compute MMD penalty between x and y. - - """ - return mmd_compute(x, y, self.kernel_type, self.gamma) + self.lambda_sel = lambda_sel def training_step(self, train_batch, batch_idx): """ @@ -85,173 +83,83 @@ class CACM(PredictionAlgorithm): objective = 0 correct, total = 0, 0 - penalty_causal, penalty_ind = 0, 0 + penalty_causal, penalty_conf, penalty_ind, penalty_sel = 0, 0, 0, 0 nmb = len(minibatches) - if len(minibatches[0]) == 4: - features = [self.featurizer(xi) for xi, _, _, _ in minibatches] - classifs = [self.classifier(fi) for fi in features] - targets = [yi for _, yi, _, _ in minibatches] - causal_attribute_labels = [ai for _, _, ai, _ in minibatches] - ind_attribute_labels = [ai for _, _, _, ai in minibatches] - elif len(minibatches[0]) == 3: # redundant for now since enforcing 4-dim output from dataset - features = [self.featurizer(xi) for xi, _, _ in minibatches] - classifs = [self.classifier(fi) for fi in features] - targets = [yi for _, yi, _ in minibatches] - causal_attribute_labels = [ai for _, _, ai in minibatches] + features = [self.featurizer(xi) for xi, _, _ in minibatches] + classifs = [self.classifier(fi) for fi in features] + targets = [yi for _, yi, _ in minibatches] for i in range(nmb): objective += F.cross_entropy(classifs[i], targets[i]) correct += (torch.argmax(classifs[i], dim=1) == targets[i]).float().sum().item() total += classifs[i].shape[0] - # Acause regularization - if "causal" in self.attr_types: - if self.E_conditioned: - for i in range(nmb): - unique_labels = torch.unique(targets[i]) # find distinct labels in environment - unique_label_indices = [] - for label in unique_labels: - label_ind = [ind for ind, j in enumerate(targets[i]) if j == label] - unique_label_indices.append(label_ind) - - nulabels = unique_labels.shape[0] - for idx in range(nulabels): - unique_attrs = torch.unique( - causal_attribute_labels[i][unique_label_indices[idx]] - ) # find distinct attributes in environment with same label - unique_attr_indices = [] - for attr in unique_attrs: - single_attr = [] - for y_attr_idx in unique_label_indices[idx]: - if causal_attribute_labels[i][y_attr_idx] == attr: - single_attr.append(y_attr_idx) - unique_attr_indices.append(single_attr) - - nuattr = unique_attrs.shape[0] - for aidx in range(nuattr): - for bidx in range(aidx + 1, nuattr): - penalty_causal += self.mmd( - classifs[i][unique_attr_indices[aidx]], classifs[i][unique_attr_indices[bidx]] - ) - - else: - overall_label_attr_vindices = {} # storing attribute indices - overall_label_attr_eindices = {} # storing corresponding environment indices - - for i in range(nmb): - unique_labels = torch.unique(targets[i]) # find distinct labels in environment - unique_label_indices = [] - for label in unique_labels: - label_ind = [ind for ind, j in enumerate(targets[i]) if j == label] - unique_label_indices.append(label_ind) - - nulabels = unique_labels.shape[0] - for idx in range(nulabels): - label = unique_labels[idx] - if label not in overall_label_attr_vindices: - overall_label_attr_vindices[label] = {} - overall_label_attr_eindices[label] = {} - - unique_attrs = torch.unique( - causal_attribute_labels[i][unique_label_indices[idx]] - ) # find distinct attributes in environment with same label - unique_attr_indices = [] - for attr in unique_attrs: # storing indices with same attribute value and label - if attr not in overall_label_attr_vindices[label]: - overall_label_attr_vindices[label][attr] = [] - overall_label_attr_eindices[label][attr] = [] - single_attr = [] - for y_attr_idx in unique_label_indices[idx]: - if causal_attribute_labels[i][y_attr_idx] == attr: - single_attr.append(y_attr_idx) - overall_label_attr_vindices[label][attr].append(single_attr) - overall_label_attr_eindices[label][attr].append(i) - unique_attr_indices.append(single_attr) - - for ( - y_val - ) in ( - overall_label_attr_vindices - ): # applying MMD penalty between distributions P(φ(x)|ai, y), P(φ(x)|aj, y) i.e samples with different attribute values but same label - tensors_list = [] - for attr in overall_label_attr_vindices[y_val]: - attrs_list = [] - if overall_label_attr_vindices[y_val][attr] != []: - for il_ind, indices_list in enumerate(overall_label_attr_vindices[y_val][attr]): - attrs_list.append( - classifs[overall_label_attr_eindices[y_val][attr][il_ind]][indices_list] - ) - if len(attrs_list) > 0: - tensor_attrs = torch.cat(attrs_list, 0) - tensors_list.append(tensor_attrs) - - nuattr = len(tensors_list) - for aidx in range(nuattr): - for bidx in range(aidx + 1, nuattr): - penalty_causal += self.mmd(tensors_list[aidx], tensors_list[bidx]) - - # Aind regularization - if "ind" in self.attr_types: - if self.E_eq_Aind: # Environment (E) and Independent attribute (Aind) coincide - for i in range(nmb): - for j in range(i + 1, nmb): - penalty_ind += self.mmd(classifs[i], classifs[j]) - - else: - if self.E_conditioned: - for i in range(nmb): - unique_aind_labels = torch.unique(ind_attribute_labels[i]) - unique_aind_label_indices = [] - for label in unique_aind_labels: - label_ind = [ind for ind, j in enumerate(ind_attribute_labels[i]) if j == label] - unique_aind_label_indices.append(label_ind) - - nulabels = unique_aind_labels.shape[0] - for aidx in range(nulabels): - for bidx in range(aidx + 1, nulabels): - penalty_ind += self.mmd( - classifs[i][unique_aind_label_indices[aidx]], - classifs[i][unique_aind_label_indices[bidx]], - ) - - else: # this currently assumes we have a disjoint set of attributes (Aind) across environments i.e., environment is defined by multiple closely related values of the attribute - overall_nmb_indices, nmb_id = [], [] - for i in range(nmb): - unique_attrs = torch.unique(ind_attribute_labels[i]) - unique_attr_indices = [] - for attr in unique_attrs: - attr_ind = [ind for ind, j in enumerate(ind_attribute_labels[i]) if j == attr] - unique_attr_indices.append(attr_ind) - overall_nmb_indices.append(attr_ind) - nmb_id.append(i) - - nuattr = len(overall_nmb_indices) - for aidx in range(nuattr): - for bidx in range(aidx + 1, nuattr): - a_nmb_id = nmb_id[aidx] - b_nmb_id = nmb_id[bidx] - penalty_ind += self.mmd( - classifs[a_nmb_id][overall_nmb_indices[aidx]], - classifs[b_nmb_id][overall_nmb_indices[bidx]], - ) - objective /= nmb - if nmb > 1: - penalty_causal /= nmb * (nmb - 1) / 2 - penalty_ind /= nmb * (nmb - 1) / 2 - - # Compile loss loss = objective - loss += self.lambda_causal * penalty_causal - loss += self.lambda_ind * penalty_ind - if torch.is_tensor(penalty_causal): - penalty_causal = penalty_causal.item() - self.log("penalty_causal", penalty_causal, on_step=False, on_epoch=True, prog_bar=True) - if torch.is_tensor(penalty_ind): - penalty_ind = penalty_ind.item() - self.log("penalty_ind", penalty_ind, on_step=False, on_epoch=True, prog_bar=True) + if self.attr_types != []: + for attr_type_idx, attr_type in enumerate(self.attr_types): + attribute_labels = [ + ai for _, _, ai in minibatches + ] # [(batch_size, num_attrs)_1, batch_size, num_attrs)_2, ..., (batch_size, num_attrs)_(num_environments)] + + E_eq_A_attr = attr_type_idx in self.E_eq_A + + # Acause regularization + if attr_type == "causal": + penalty_causal += self.CACMRegularizer.conditional_reg( + classifs, [a[:, attr_type_idx] for a in attribute_labels], [targets], nmb, E_eq_A_attr + ) + + # Aconf regularization + elif attr_type == "conf": + penalty_conf += self.CACMRegularizer.unconditional_reg( + classifs, [a[:, attr_type_idx] for a in attribute_labels], nmb, E_eq_A_attr + ) + + # Aind regularization + elif attr_type == "ind": + penalty_ind += self.CACMRegularizer.unconditional_reg( + classifs, [a[:, attr_type_idx] for a in attribute_labels], nmb, E_eq_A_attr + ) + + # Asel regularization + elif attr_type == "sel": + penalty_sel += self.CACMRegularizer.conditional_reg( + classifs, [a[:, attr_type_idx] for a in attribute_labels], [targets], nmb, E_eq_A_attr + ) + + if nmb > 1: + penalty_causal /= nmb * (nmb - 1) / 2 + penalty_conf /= nmb * (nmb - 1) / 2 + penalty_ind /= nmb * (nmb - 1) / 2 + penalty_sel /= nmb * (nmb - 1) / 2 + + # Compile loss + loss += self.lambda_causal * penalty_causal + loss += self.lambda_conf * penalty_conf + loss += self.lambda_ind * penalty_ind + loss += self.lambda_sel * penalty_sel + + if torch.is_tensor(penalty_causal): + penalty_causal = penalty_causal.item() + self.log("penalty_causal", penalty_causal, on_step=False, on_epoch=True, prog_bar=True) + if torch.is_tensor(penalty_conf): + penalty_conf = penalty_conf.item() + self.log("penalty_conf", penalty_conf, on_step=False, on_epoch=True, prog_bar=True) + if torch.is_tensor(penalty_ind): + penalty_ind = penalty_ind.item() + self.log("penalty_ind", penalty_ind, on_step=False, on_epoch=True, prog_bar=True) + if torch.is_tensor(penalty_sel): + penalty_sel = penalty_sel.item() + self.log("penalty_sel", penalty_sel, on_step=False, on_epoch=True, prog_bar=True) + + elif self.graph is not None: + pass # TODO + + else: + raise ValueError("No attribute types or graph provided.") acc = correct / total diff --git a/dowhy/causal_prediction/algorithms/erm.py b/dowhy/causal_prediction/algorithms/erm.py index fa15b3921..aa394f0ba 100644 --- a/dowhy/causal_prediction/algorithms/erm.py +++ b/dowhy/causal_prediction/algorithms/erm.py @@ -27,8 +27,8 @@ class ERM(PredictionAlgorithm): """ - x = torch.cat([x for x, y, _, _ in train_batch]) - y = torch.cat([y for x, y, _, _ in train_batch]) + x = torch.cat([x for x, y, _ in train_batch]) + y = torch.cat([y for x, y, _ in train_batch]) out = self.model(x) loss = F.cross_entropy(out, y) diff --git a/dowhy/causal_prediction/algorithms/regularization.py b/dowhy/causal_prediction/algorithms/regularization.py new file mode 100644 index 000000000..3d9bf4215 --- /dev/null +++ b/dowhy/causal_prediction/algorithms/regularization.py @@ -0,0 +1,300 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from dowhy.causal_prediction.algorithms.utils import mmd_compute + + +class Regularizer: + """ + Implements methods for applying unconditional and conditional regularization. + """ + + def __init__( + self, + E_conditioned, + ci_test, + kernel_type, + gamma, + ): + """ + :param E_conditioned: Binary flag indicating whether E-conditioned regularization has to be applied + :param ci_test: Conditional independence metric used for regularization penalty. Currently, MMD is supported. + :param kernel_type: Kernel type for MMD penalty. Currently, supports "gaussian" (RBF). If None, distance between mean and second-order statistics (covariances) is used. + :param gamma: kernel bandwidth for MMD (due to implementation, the kernel bandwdith will actually be the reciprocal of gamma i.e., gamma=1e-6 implies kernel bandwidth=1e6. See `mmd_compute` in utils.py) + """ + + self.E_conditioned = E_conditioned # E-conditioned regularization by default + self.ci_test = ci_test + self.kernel_type = kernel_type + self.gamma = gamma + + def mmd(self, x, y): + """ + Compute MMD penalty between x and y. + + """ + return mmd_compute(x, y, self.kernel_type, self.gamma) + + def unconditional_reg(self, classifs, attribute_labels, num_envs, E_eq_A=False): + """ + Implement unconditional regularization φ(x) ⊥⊥ A_i + + :param classifs: feature representations output from classifier layer (gφ(x)) + :param attribute_labels: attribute labels loaded with the dataset for attribute A_i + :param num_envs: number of environments/domains + :param E_eq_A: Binary flag indicating whether attribute (A_i) coinicides with environment (E) definition + + """ + + penalty = 0 + + if E_eq_A: # Environment (E) and attribute (A) coincide + if self.E_conditioned is False: # there is no correlation between E and X_c + for i in range(num_envs): + for j in range(i + 1, num_envs): + penalty += self.mmd(classifs[i], classifs[j]) + + else: + if self.E_conditioned: + for i in range(num_envs): + unique_attr_labels = torch.unique(attribute_labels[i]) + unique_attr_label_indices = [] + for label in unique_attr_labels: + label_ind = [ind for ind, j in enumerate(attribute_labels[i]) if j == label] + unique_attr_label_indices.append(label_ind) + + nulabels = unique_attr_labels.shape[0] + for aidx in range(nulabels): + for bidx in range(aidx + 1, nulabels): + penalty += self.mmd( + classifs[i][unique_attr_label_indices[aidx]], + classifs[i][unique_attr_label_indices[bidx]], + ) + + else: # this currently assumes we have a disjoint set of attributes (Aind) across environments i.e., environment is defined by multiple closely related values of the attribute + overall_nmb_indices, nmb_id = [], [] + for i in range(num_envs): + unique_attrs = torch.unique(attribute_labels[i]) + unique_attr_indices = [] + for attr in unique_attrs: + attr_ind = [ind for ind, j in enumerate(attribute_labels[i]) if j == attr] + unique_attr_indices.append(attr_ind) + overall_nmb_indices.append(attr_ind) + nmb_id.append(i) + + nuattr = len(overall_nmb_indices) + for aidx in range(nuattr): + for bidx in range(aidx + 1, nuattr): + a_nmb_id = nmb_id[aidx] + b_nmb_id = nmb_id[bidx] + penalty += self.mmd( + classifs[a_nmb_id][overall_nmb_indices[aidx]], + classifs[b_nmb_id][overall_nmb_indices[bidx]], + ) + + return penalty + + def conditional_reg(self, classifs, attribute_labels, conditioning_subset, num_envs, E_eq_A=False): + """ + Implement conditional regularization φ(x) ⊥⊥ A_i | A_s + + :param classifs: feature representations output from classifier layer (gφ(x)) + :param attribute_labels: attribute labels loaded with the dataset for attribute A_i + :param conditioning_subset: list of subset of observed variables A_s (attributes + targets) such that (X_c, A_i) are d-separated conditioned on this subset + :param num_envs: number of environments/domains + :param E_eq_A: Binary flag indicating whether attribute (A_i) coinicides with environment (E) definition + + Find group indices for conditional regularization based on conditioning subset by taking all possible combinations + e.g., conditioning_subset = [A1, Y], where A1 is in {0, 1} and Y is in {0, 1, 2}, + we assign groups in the following way: + A1 = 0, Y = 0 -> group 0 + A1 = 1, Y = 0 -> group 1 + A1 = 0, Y = 1 -> group 2 + A1 = 1, Y = 1 -> group 3 + A1 = 0, Y = 2 -> group 4 + A1 = 1, Y = 2 -> group 5 + + Code snippet for computing group indices adapted from WILDS: https://github.com/p-lambda/wilds + @inproceedings{wilds2021, + title = {{WILDS}: A Benchmark of in-the-Wild Distribution Shifts}, + author = {Pang Wei Koh and Shiori Sagawa and Henrik Marklund and Sang Michael Xie and Marvin Zhang and Akshay Balsubramani and Weihua Hu and Michihiro Yasunaga and Richard Lanas Phillips and Irena Gao and Tony Lee and Etienne David and Ian Stavness and Wei Guo and Berton A. Earnshaw and Imran S. Haque and Sara Beery and Jure Leskovec and Anshul Kundaje and Emma Pierson and Sergey Levine and Chelsea Finn and Percy Liang}, + booktitle = {International Conference on Machine Learning (ICML)}, + year = {2021} + }` + + """ + + penalty = 0 + + if E_eq_A: # Environment (E) and attribute (A) coincide + if self.E_conditioned is False: # there is no correlation between E and X_c + overall_group_vindices = {} # storing group indices + overall_group_eindices = {} # storing corresponding environment indices + + for i in range(num_envs): + conditioning_subset_i = [subset_var[i] for subset_var in conditioning_subset] + conditioning_subset_i_uniform = [ + ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in conditioning_subset_i + ] + grouping_data = torch.cat(conditioning_subset_i_uniform, 1) + assert grouping_data.min() >= 0, "Group numbers cannot be negative." + cardinality = 1 + torch.max(grouping_data, dim=0)[0] + cumprod = torch.cumprod(cardinality, dim=0) + n_groups = cumprod[-1].item() + factors_np = np.concatenate(([1], cumprod[:-1])) + factors = torch.from_numpy(factors_np) + group_indices = grouping_data @ factors + + for group_idx in range(n_groups): + group_idx_indices = [ + gp_idx for gp_idx in range(len(group_indices)) if group_indices[gp_idx] == group_idx + ] + + if group_idx not in overall_group_vindices: + overall_group_vindices[group_idx] = {} + overall_group_eindices[group_idx] = {} + + unique_attrs = torch.unique( + attribute_labels[i][group_idx_indices] + ) # find distinct attributes in environment with same group_idx_indices + unique_attr_indices = [] + for attr in unique_attrs: # storing indices with same attribute value and group label + if attr not in overall_group_vindices[group_idx]: + overall_group_vindices[group_idx][attr] = [] + overall_group_eindices[group_idx][attr] = [] + single_attr = [] + for group_idx_indices_attr in group_idx_indices: + if attribute_labels[i][group_idx_indices_attr] == attr: + single_attr.append(group_idx_indices_attr) + overall_group_vindices[group_idx][attr].append(single_attr) + overall_group_eindices[group_idx][attr].append(i) + unique_attr_indices.append(single_attr) + + for ( + group_label + ) in ( + overall_group_vindices + ): # applying MMD penalty between distributions P(φ(x)|ai, g), P(φ(x)|aj, g) i.e samples with different attribute labelues but same group label + tensors_list = [] + for attr in overall_group_vindices[group_label]: + attrs_list = [] + if overall_group_vindices[group_label][attr] != []: + for il_ind, indices_list in enumerate(overall_group_vindices[group_label][attr]): + attrs_list.append( + classifs[overall_group_eindices[group_label][attr][il_ind]][indices_list] + ) + if len(attrs_list) > 0: + tensor_attrs = torch.cat(attrs_list, 0) + tensors_list.append(tensor_attrs) + + nuattr = len(tensors_list) + for aidx in range(nuattr): + for bidx in range(aidx + 1, nuattr): + penalty += self.mmd(tensors_list[aidx], tensors_list[bidx]) + + else: + if self.E_conditioned: + for i in range(num_envs): + conditioning_subset_i = [subset_var[i] for subset_var in conditioning_subset] + conditioning_subset_i_uniform = [ + ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in conditioning_subset_i + ] + grouping_data = torch.cat(conditioning_subset_i_uniform, 1) + assert grouping_data.min() >= 0, "Group numbers cannot be negative." + cardinality = 1 + torch.max(grouping_data, dim=0)[0] + cumprod = torch.cumprod(cardinality, dim=0) + n_groups = cumprod[-1].item() + factors_np = np.concatenate(([1], cumprod[:-1])) + factors = torch.from_numpy(factors_np) + group_indices = grouping_data @ factors + + for group_idx in range(n_groups): + group_idx_indices = [ + gp_idx for gp_idx in range(len(group_indices)) if group_indices[gp_idx] == group_idx + ] + unique_attrs = torch.unique( + attribute_labels[i][group_idx_indices] + ) # find distinct attributes in environment with same group_idx_indices + unique_attr_indices = [] + for attr in unique_attrs: + single_attr = [] + for group_idx_indices_attr in group_idx_indices: + if attribute_labels[i][group_idx_indices_attr] == attr: + single_attr.append(group_idx_indices_attr) + unique_attr_indices.append(single_attr) + + nuattr = unique_attrs.shape[0] + for aidx in range(nuattr): + for bidx in range(aidx + 1, nuattr): + penalty += self.mmd( + classifs[i][unique_attr_indices[aidx]], classifs[i][unique_attr_indices[bidx]] + ) + + else: + overall_group_vindices = {} # storing group indices + overall_group_eindices = {} # storing corresponding environment indices + + for i in range(num_envs): + conditioning_subset_i = [subset_var[i] for subset_var in conditioning_subset] + conditioning_subset_i_uniform = [ + ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in conditioning_subset_i + ] + grouping_data = torch.cat(conditioning_subset_i_uniform, 1) + assert grouping_data.min() >= 0, "Group numbers cannot be negative." + cardinality = 1 + torch.max(grouping_data, dim=0)[0] + cumprod = torch.cumprod(cardinality, dim=0) + n_groups = cumprod[-1].item() + factors_np = np.concatenate(([1], cumprod[:-1])) + factors = torch.from_numpy(factors_np) + group_indices = grouping_data @ factors + + for group_idx in range(n_groups): + group_idx_indices = [ + gp_idx for gp_idx in range(len(group_indices)) if group_indices[gp_idx] == group_idx + ] + + if group_idx not in overall_group_vindices: + overall_group_vindices[group_idx] = {} + overall_group_eindices[group_idx] = {} + + unique_attrs = torch.unique( + attribute_labels[i][group_idx_indices] + ) # find distinct attributes in environment with same group_idx_indices + unique_attr_indices = [] + for attr in unique_attrs: # storing indices with same attribute value and group label + if attr not in overall_group_vindices[group_idx]: + overall_group_vindices[group_idx][attr] = [] + overall_group_eindices[group_idx][attr] = [] + single_attr = [] + for group_idx_indices_attr in group_idx_indices: + if attribute_labels[i][group_idx_indices_attr] == attr: + single_attr.append(group_idx_indices_attr) + overall_group_vindices[group_idx][attr].append(single_attr) + overall_group_eindices[group_idx][attr].append(i) + unique_attr_indices.append(single_attr) + + for ( + group_label + ) in ( + overall_group_vindices + ): # applying MMD penalty between distributions P(φ(x)|ai, g), P(φ(x)|aj, g) i.e samples with different attribute labelues but same group label + tensors_list = [] + for attr in overall_group_vindices[group_label]: + attrs_list = [] + if overall_group_vindices[group_label][attr] != []: + for il_ind, indices_list in enumerate(overall_group_vindices[group_label][attr]): + attrs_list.append( + classifs[overall_group_eindices[group_label][attr][il_ind]][indices_list] + ) + if len(attrs_list) > 0: + tensor_attrs = torch.cat(attrs_list, 0) + tensors_list.append(tensor_attrs) + + nuattr = len(tensors_list) + for aidx in range(nuattr): + for bidx in range(aidx + 1, nuattr): + penalty += self.mmd(tensors_list[aidx], tensors_list[bidx]) + + return penalty diff --git a/dowhy/causal_prediction/dataloaders/get_data_loader.py b/dowhy/causal_prediction/dataloaders/get_data_loader.py index 7e186728d..aaf3b67b6 100644 --- a/dowhy/causal_prediction/dataloaders/get_data_loader.py +++ b/dowhy/causal_prediction/dataloaders/get_data_loader.py @@ -108,7 +108,6 @@ def get_loaders( holdout_fraction=0.2, trial_seed=0, ): - """Return training, validation, and test dataloaders. :param dataset: dataset class containing list of environments diff --git a/dowhy/causal_prediction/datasets/mnist.py b/dowhy/causal_prediction/datasets/mnist.py index ea2b7e89d..3fe99b834 100644 --- a/dowhy/causal_prediction/datasets/mnist.py +++ b/dowhy/causal_prediction/datasets/mnist.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import torch +import torchvision from PIL import Image from torch.utils.data import Subset, TensorDataset from torchvision import transforms @@ -22,12 +23,12 @@ from dowhy.causal_prediction.datasets.base_dataset import MultipleDomainDataset * datasets initialized from torchvision.datasets.MNIST * We assume causal attribute (Acause) = color, independent attribute (Aind) = rotation * Environments/domains stored in list self.datasets (required for all datasets) - * Default env structure is TensorDataset(x, y, Acause, Aind) - * We do not need Aind (rotation) explicitly here since E=Aind - * In Aind case, we return (x, y, _, _) - * '_' is replaced by any dummy vector (this is because the training loop assumes 4 inputs in TensorDataset) + * Default env structure is TensorDataset(x, y, a) + * a is a combined tensor for all attributes (metadata) a1, a2, ..., ak + * a's shape is (n, k) where n is the number of samples in the environment """ + # single-attribute Causal class MNISTCausalAttribute(MultipleDomainDataset): N_STEPS = 5001 @@ -103,8 +104,9 @@ class MNISTCausalAttribute(MultipleDomainDataset): x = images.float().div_(255.0) y = labels.view(-1).long() + a = torch.unsqueeze(colors, 1) - return TensorDataset(x, y, colors, colors) + return TensorDataset(x, y, a) def torch_bernoulli_(self, p, size): return (torch.rand(size) < p).float() @@ -148,21 +150,21 @@ class MNISTIndAttribute(MultipleDomainDataset): for i, env in enumerate(angles[:-1]): images = original_images[:50000][i::2] labels = original_labels[:50000][i::2] - self.datasets.append(self.rotate_dataset(images, labels, angles[i])) + self.datasets.append(self.rotate_dataset(images, labels, i, angles[i])) images = original_images[50000:] labels = original_labels[50000:] - self.datasets.append(self.rotate_dataset(images, labels, angles[-1])) + self.datasets.append(self.rotate_dataset(images, labels, len(angles) - 1, angles[-1])) # test environment original_dataset_te = MNIST(root, train=False, download=download) original_images = original_dataset_te.data original_labels = original_dataset_te.targets - self.datasets.append(self.rotate_dataset(original_images, original_labels, angles[-1])) + self.datasets.append(self.rotate_dataset(original_images, original_labels, len(angles) - 1, angles[-1])) self.input_shape = self.INPUT_SHAPE self.num_classes = 2 - def rotate_dataset(self, images, labels, angle): + def rotate_dataset(self, images, labels, env_id, angle): """ Transform MNIST dataset by applying rotation to images. Attribute (rotation angle) is independent of label Y. @@ -170,12 +172,16 @@ class MNISTIndAttribute(MultipleDomainDataset): :param images: original MNIST images :param labels: original MNIST labels :param angle: Value of rotation angle used for transforming the image - :returns: TensorDataset containing transformed images and labels + :returns: TensorDataset containing transformed images, labels, and attributes (angle) """ rotation = transforms.Compose( [ transforms.ToPILImage(), - transforms.Lambda(lambda x: rotate(x, int(angle), fill=(0,), resample=Image.BICUBIC)), + transforms.Lambda( + lambda x: rotate( + x, int(angle), fill=(0,), interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + ), transforms.ToTensor(), ] ) @@ -192,8 +198,10 @@ class MNISTIndAttribute(MultipleDomainDataset): x[i] = rotation(images[i].float().div_(255.0)) y = labels.view(-1).long() + a = torch.full((y.shape[0],), env_id, dtype=torch.float32) + a = torch.unsqueeze(a, 1) - return TensorDataset(x, y, y, y) + return TensorDataset(x, y, a) def torch_bernoulli_(self, p, size): return (torch.rand(size) < p).float() @@ -238,21 +246,23 @@ class MNISTCausalIndAttribute(MultipleDomainDataset): for i, env in enumerate(environments[:-1]): images = original_images[:50000][i::2] labels = original_labels[:50000][i::2] - self.datasets.append(self.color_rot_dataset(images, labels, env, angles[i])) + self.datasets.append(self.color_rot_dataset(images, labels, env, i, angles[i])) images = original_images[50000:] labels = original_labels[50000:] - self.datasets.append(self.color_rot_dataset(images, labels, environments[-1], angles[-1])) + self.datasets.append(self.color_rot_dataset(images, labels, environments[-1], len(angles) - 1, angles[-1])) # test environment original_dataset_te = MNIST(root, train=False, download=download) original_images = original_dataset_te.data original_labels = original_dataset_te.targets - self.datasets.append(self.color_rot_dataset(original_images, original_labels, environments[-1], angles[-1])) + self.datasets.append( + self.color_rot_dataset(original_images, original_labels, environments[-1], len(angles) - 1, angles[-1]) + ) self.input_shape = self.INPUT_SHAPE self.num_classes = 2 - def color_rot_dataset(self, images, labels, environment, angle): + def color_rot_dataset(self, images, labels, environment, env_id, angle): """ Transform MNIST dataset by (i) applying rotation to images, then (ii) introducing correlation between attribute (color) and label. Attribute (rotation angle) is independent of label Y; there is a direct-causal relationship between label Y and color. @@ -261,7 +271,7 @@ class MNISTCausalIndAttribute(MultipleDomainDataset): :param labels: original MNIST labels :param environment: Value of correlation between color and label :param angle: Value of rotation angle used for transforming the image - :returns: TensorDataset containing transformed images, labels, and attributes (color) + :returns: TensorDataset containing transformed images, labels, and attributes (color, angle) """ # Subsample 2x for computational convenience images = images.reshape((-1, 28, 28))[:, ::2, ::2] @@ -272,8 +282,10 @@ class MNISTCausalIndAttribute(MultipleDomainDataset): x = images # .float().div_(255.0) y = labels.view(-1).long() + angles = torch.full((y.shape[0],), env_id, dtype=torch.float32) + a = torch.stack((colors, angles), 1) - return TensorDataset(x, y, colors, colors) + return TensorDataset(x, y, a) def color_dataset(self, images, labels, environment): """