Adding causal prediction for general datasets (CACM) (#925)

* add causal prediction using cacm + demo notebook

Signed-off-by: jivatneet <jivatneet@gmail.com>

* Adding causal prediction code using CACM with demo notebook

Signed-off-by: jivatneet <jivatneet@gmail.com>

* addressed PR comments for causal prediction

Signed-off-by: jivatneet <jivatneet@gmail.com>

* addressing comments: fixing documentation in base_algorithm and cacm

Signed-off-by: jivatneet <jivatneet@gmail.com>

* addressing comments: fixing documentation in base_dataset

Signed-off-by: jivatneet <jivatneet@gmail.com>

* rename Algorithm class

Signed-off-by: jivatneet <jivatneet@gmail.com>

* updating base_dataset for format check

Signed-off-by: jivatneet <jivatneet@gmail.com>

* adding exception for optimizer; resolving list bug

Signed-off-by: jivatneet <jivatneet@gmail.com>

* adding docs for base algo

Signed-off-by: jivatneet <jivatneet@gmail.com>

* add algo files for general regularization

Signed-off-by: jivatneet <jivatneet@gmail.com>

* modify mnist

Signed-off-by: jivatneet <jivatneet@gmail.com>

* add changes for general cacm api

Signed-off-by: jivatneet <jivatneet@gmail.com>

* format changes general cacm api

Signed-off-by: jivatneet <jivatneet@gmail.com>

* resolve comments for general cacm api

Signed-off-by: jivatneet <jivatneet@gmail.com>

* format check

Signed-off-by: jivatneet <jivatneet@gmail.com>

* Revert "format check"

This reverts commit 6cbc3cb5c9.

Signed-off-by: jivatneet <jivatneet@gmail.com>

* format check mnist

Signed-off-by: jivatneet <jivatneet@gmail.com>

---------

Signed-off-by: jivatneet <jivatneet@gmail.com>
Signed-off-by: Jivat Neet <39404029+jivatneet@users.noreply.github.com>
This commit is contained in:
Jivat Neet 2023-06-15 17:29:01 +05:30 committed by GitHub
Parent c87b511bfd
Commit 0fb1314c5d
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 440 additions and 218 deletions

View file

@@ -7,9 +7,7 @@
"source": [
"# Demo for DoWhy Causal Prediction on MNIST\n",
"\n",
"We're adding prediction functionality to DoWhy. The goal of this notebook is to demonstrate an example of causal prediction using *Causally Adaptive Constraint Minimization (CACM)* [1]. \n",
"\n",
"[1] Kaur, J.N., Kıcıman, E., & Sharma, A. (2022). Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization. ArXiv, abs/2206.07837."
"The goal of this notebook is to demonstrate an example of causal prediction using *Causally Adaptive Constraint Minimization (CACM)* (https://arxiv.org/abs/2206.07837) [1]. "
]
},
{
@@ -21,9 +19,6 @@
"\n",
"Domain generalization literature has largely focused on datasets with a single kind of distribution shift over one attribute. Using MNIST as an example, domains are created either by adding new values of a spurious attribute like rotation (e.g., Rotated-MNIST dataset [2]) or domains exhibit different values of correlation between the class label and a spurious attribute like color (e.g., Colored-MNIST [3]). However, real-world data often has multiple distribution shifts over different attributes. For example, satellite imagery data demonstrates distribution shifts over time as well as the region captured.\n",
"\n",
"[2] Ghifary, M., Kleijn, W., Zhang, M., & Balduzzi, D. (2015). Domain Generalization for Object Recognition with Multi-task Autoencoders. 2015 IEEE International Conference on Computer Vision (ICCV), 2551-2559.<br>\n",
"[3] Arjovsky, M., Bottou, L., Gulrajani, I., & Lopez-Paz, D. (2019). Invariant Risk Minimization. ArXiv, abs/1907.02893.\n",
"\n",
"\n",
"### Multi-attribute MNIST\n",
"\n",
@@ -264,7 +259,9 @@
"id": "adf11498",
"metadata": {},
"source": [
"### Initialize algorithm class: ERM"
"### Initialize algorithm class: ERM\n",
"\n",
"We have implemented Empirical Risk Minimization (ERM) in `dowhy.causal_prediction.algorithms` as a baseline."
]
},
{
@@ -340,9 +337,9 @@
"id": "bb318eea",
"metadata": {},
"source": [
"## Prediction with CACM\n",
"## Prediction with *CACM*\n",
"\n",
"We now train and evaluate the above dataset with CACM. We specify the type of shifts present using list `attr_types` provided as input to CACM. Further instructions regarding using CACM with multi-attribute shifts is provided in the next section."
"We now train and evaluate the above dataset with *CACM*. We specify the type of shifts present using list `attr_types` provided as input to *CACM*. Further instructions regarding using *CACM* with multi-attribute shifts is provided in the next section."
]
},
{
@@ -362,7 +359,7 @@
"metadata": {},
"outputs": [],
"source": [
"# `attr_types` list contains type of attributes present (supports 'causal' and 'ind' currently)\n",
"# `attr_types` list contains type of attributes present (supports 'causal', 'conf', ind', and 'sel' currently)\n",
"algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['causal'], lambda_causal=100.)"
]
},
@@ -404,7 +401,7 @@
"source": [
"### MNIST Independent and Causal+Independent datasets\n",
"\n",
"We show how to perform the above evaluation for `MNISTIndAttribute` and`MNISTCausalIndAttribute` datasets. Additional `attr_types` should be provided to CACM algorithm for handling multiple shifts. We currently support `Causal` and `Independent` distribution shifts in the data."
"We show how to perform the above evaluation for `MNISTIndAttribute` and`MNISTCausalIndAttribute` datasets. Additional `attr_types` should be provided to *CACM* algorithm for handling multiple shifts. We currently support *Causal*, *Confounded*, *Independent*, and *Selected* distribution shifts in the data."
]
},
{
@@ -435,7 +432,7 @@
"metadata": {},
"outputs": [],
"source": [
"algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['ind'], lambda_ind=10.)"
"algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['ind'], lambda_ind=10., E_eq_A=[0])"
]
},
{
@@ -466,8 +463,8 @@
"metadata": {},
"outputs": [],
"source": [
"# `attr_types` can be provided in any order\n",
"algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['ind', 'causal'], lambda_causal=100., lambda_ind=10.)"
"# `attr_types` should be ordered consistent with the attribute order in dataset class\n",
"algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['causal', 'ind'], lambda_causal=100., lambda_ind=10., E_eq_A=[1])"
]
},
{
@@ -480,18 +477,24 @@
"We provide our demo on MNIST using ERM and *CACM* algorithms. It is possible to extend the evaluation to new datasets and algorithms for evaluation.\n",
"\n",
"\n",
"New datasets can be added to `dowhy.causal_prediction.datasets` and imported here, as we did for MNIST. We provide description of the MNIST dataset (and variants) in `dowhy.causal_prediction.datasets.mnist` that will be helpful in creating new dataset classes. We currently support `Causal` and `Independent` distribution shifts in the data.\n",
"New datasets can be added to `dowhy.causal_prediction.datasets` and imported here, as we did for MNIST. We provide description of the MNIST dataset (and variants) in `dowhy.causal_prediction.datasets.mnist` that will be helpful in creating new dataset classes. We currently support *Causal*, *Confounded*, *Independent*, and *Selected* distribution shifts in the data. \n",
"\n",
"We have implemented ERM in `dowhy.causal_prediction.algorithms` as a baseline. Additional algorithms can be added by overriding the `training_step` function in base class `PredictionAlgorithm`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6206a050",
"cell_type": "markdown",
"id": "3b9e1b05",
"metadata": {},
"outputs": [],
"source": []
"source": [
"## References\n",
"\n",
"[1] Kaur, J.N., Kıcıman, E., & Sharma, A. (2022). Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization. ArXiv, abs/2206.07837.\n",
"\n",
"[2] Ghifary, M., Kleijn, W., Zhang, M., & Balduzzi, D. (2015). Domain Generalization for Object Recognition with Multi-task Autoencoders. 2015 IEEE International Conference on Computer Vision (ICCV), 2551-2559.<br>\n",
"\n",
"[3] Arjovsky, M., Bottou, L., Gulrajani, I., & Lopez-Paz, D. (2019). Invariant Risk Minimization. ArXiv, abs/1907.02893.\n"
]
}
],
"metadata": {
@@ -510,7 +513,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
"version": "3.8.16"
}
},
"nbformat": 4,

View file

@@ -45,8 +45,8 @@ class PredictionAlgorithm(pl.LightningModule):
"""
if isinstance(batch[0], list):
x = torch.cat([x for x, y, _, _ in batch])
y = torch.cat([y for x, y, _, _ in batch])
x = torch.cat([x for x, y, _ in batch])
y = torch.cat([y for x, y, _ in batch])
else:
x = batch[0]
y = batch[1]
@@ -67,8 +67,8 @@
"""
if isinstance(batch[0], list):
x = torch.cat([x for x, y, _, _ in batch])
y = torch.cat([y for x, y, _, _ in batch])
x = torch.cat([x for x, y, _ in batch])
y = torch.cat([y for x, y, _ in batch])
else:
x = batch[0]
y = batch[1]
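The unpacking above reflects the new three-tuple environment structure TensorDataset(x, y, a). A toy illustration of the list branch (one (x, y, a) triple per environment), with made-up shapes:

import torch

# two toy environments, each a (x, y, a) triple; shapes are illustrative
batch = [
    (torch.randn(4, 2, 14, 14), torch.randint(0, 2, (4,)), torch.zeros(4, 1)),
    (torch.randn(4, 2, 14, 14), torch.randint(0, 2, (4,)), torch.ones(4, 1)),
]
x = torch.cat([x for x, y, _ in batch])
y = torch.cat([y for x, y, _ in batch])
print(x.shape, y.shape)  # torch.Size([8, 2, 14, 14]) torch.Size([8])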

View file

@@ -3,7 +3,7 @@ from torch import nn
from torch.nn import functional as F
from dowhy.causal_prediction.algorithms.base_algorithm import PredictionAlgorithm
from dowhy.causal_prediction.algorithms.utils import mmd_compute
from dowhy.causal_prediction.algorithms.regularization import Regularizer
class CACM(PredictionAlgorithm):
@@ -19,12 +19,13 @@ class CACM(PredictionAlgorithm):
ci_test="mmd",
attr_types=[],
E_conditioned=True,
E_eq_Aind=True,
E_eq_A=[],
gamma=1e-6,
lambda_causal=1.0,
lambda_conf=1.0,
lambda_ind=1.0,
lambda_sel=1.0,
):
"""Class for Causally Adaptive Constraint Minimization (CACM) Algorithm.
@article{Kaur2022ModelingTD,
title={Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization},
@@ -43,34 +44,31 @@ class CACM(PredictionAlgorithm):
:param momentum: Value of momentum for SGD optimizer
:param kernel_type: Kernel type for MMD penalty. Currently supports "gaussian" (RBF). If None, distance between mean and second-order statistics (covariances) is used.
:param ci_test: Conditional independence metric used for regularization penalty. Currently, MMD is supported.
:param attr_types: list of attribute types (based on relationship with label Y); can be in any order. Currently, only 'causal' and 'ind' are supported.
:param attr_types: list of attribute types (based on relationship with label Y); should be ordered according to the attribute order in the loaded dataset.
Currently, 'causal' (Causal), 'conf' (Confounded), 'ind' (Independent) and 'sel' (Selected) are supported.
For single-shift datasets, use: ['causal'], ['ind']
For multi-shift datasets, use: ['causal', 'ind']
:param E_conditioned: Binary flag indicating whether E-conditioned regularization has to be applied
:param E_eq_Aind: Binary flag indicating whether environment (E) and Aind (Independent attribute) coincide
:param E_eq_A: list indicating indices of attributes that coincide with environment (E) definition; default is empty.
:param gamma: kernel bandwidth for MMD (due to implementation, the kernel bandwidth will actually be the reciprocal of gamma i.e., gamma=1e-6 implies kernel bandwidth=1e6. See `mmd_compute` in utils.py)
:param lambda_causal: MMD penalty hyperparameter for Causal shift
:param lambda_conf: MMD penalty hyperparameter for Confounded shift
:param lambda_ind: MMD penalty hyperparameter for Independent shift
:param lambda_sel: MMD penalty hyperparameter for Selected shift
:returns: an instance of PredictionAlgorithm class
"""
super().__init__(model, optimizer, lr, weight_decay, betas, momentum)
self.kernel_type = kernel_type
self.CACMRegularizer = Regularizer(E_conditioned, ci_test, kernel_type, gamma)
self.attr_types = attr_types
self.E_conditioned = E_conditioned # E-conditioned regularization by default
self.E_eq_Aind = E_eq_Aind
self.gamma = gamma
self.E_eq_A = E_eq_A
self.lambda_causal = lambda_causal
self.lambda_conf = lambda_conf
self.lambda_ind = lambda_ind
def mmd(self, x, y):
"""
Compute MMD penalty between x and y.
"""
return mmd_compute(x, y, self.kernel_type, self.gamma)
self.lambda_sel = lambda_sel
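For reference, hedged instantiation examples for the newly supported shift types; the lambda values are illustrative, not tuned defaults:

# `model` as constructed earlier in the notebook (featurizer + classifier)
# confounded shift on a single attribute
algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=["conf"], lambda_conf=1.0)

# causal shift on attribute 0 and selection shift on attribute 1;
# order must match the attribute order in the dataset's `a` tensor
algorithm = CACM(
    model, lr=1e-3, gamma=1e-2,
    attr_types=["causal", "sel"],
    lambda_causal=100.0, lambda_sel=1.0,
)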
def training_step(self, train_batch, batch_idx):
"""
@@ -85,173 +83,83 @@ class CACM(PredictionAlgorithm):
objective = 0
correct, total = 0, 0
penalty_causal, penalty_ind = 0, 0
penalty_causal, penalty_conf, penalty_ind, penalty_sel = 0, 0, 0, 0
nmb = len(minibatches)
if len(minibatches[0]) == 4:
features = [self.featurizer(xi) for xi, _, _, _ in minibatches]
classifs = [self.classifier(fi) for fi in features]
targets = [yi for _, yi, _, _ in minibatches]
causal_attribute_labels = [ai for _, _, ai, _ in minibatches]
ind_attribute_labels = [ai for _, _, _, ai in minibatches]
elif len(minibatches[0]) == 3: # redundant for now since enforcing 4-dim output from dataset
features = [self.featurizer(xi) for xi, _, _ in minibatches]
classifs = [self.classifier(fi) for fi in features]
targets = [yi for _, yi, _ in minibatches]
causal_attribute_labels = [ai for _, _, ai in minibatches]
features = [self.featurizer(xi) for xi, _, _ in minibatches]
classifs = [self.classifier(fi) for fi in features]
targets = [yi for _, yi, _ in minibatches]
for i in range(nmb):
objective += F.cross_entropy(classifs[i], targets[i])
correct += (torch.argmax(classifs[i], dim=1) == targets[i]).float().sum().item()
total += classifs[i].shape[0]
# Acause regularization
if "causal" in self.attr_types:
if self.E_conditioned:
for i in range(nmb):
unique_labels = torch.unique(targets[i]) # find distinct labels in environment
unique_label_indices = []
for label in unique_labels:
label_ind = [ind for ind, j in enumerate(targets[i]) if j == label]
unique_label_indices.append(label_ind)
nulabels = unique_labels.shape[0]
for idx in range(nulabels):
unique_attrs = torch.unique(
causal_attribute_labels[i][unique_label_indices[idx]]
) # find distinct attributes in environment with same label
unique_attr_indices = []
for attr in unique_attrs:
single_attr = []
for y_attr_idx in unique_label_indices[idx]:
if causal_attribute_labels[i][y_attr_idx] == attr:
single_attr.append(y_attr_idx)
unique_attr_indices.append(single_attr)
nuattr = unique_attrs.shape[0]
for aidx in range(nuattr):
for bidx in range(aidx + 1, nuattr):
penalty_causal += self.mmd(
classifs[i][unique_attr_indices[aidx]], classifs[i][unique_attr_indices[bidx]]
)
else:
overall_label_attr_vindices = {} # storing attribute indices
overall_label_attr_eindices = {} # storing corresponding environment indices
for i in range(nmb):
unique_labels = torch.unique(targets[i]) # find distinct labels in environment
unique_label_indices = []
for label in unique_labels:
label_ind = [ind for ind, j in enumerate(targets[i]) if j == label]
unique_label_indices.append(label_ind)
nulabels = unique_labels.shape[0]
for idx in range(nulabels):
label = unique_labels[idx]
if label not in overall_label_attr_vindices:
overall_label_attr_vindices[label] = {}
overall_label_attr_eindices[label] = {}
unique_attrs = torch.unique(
causal_attribute_labels[i][unique_label_indices[idx]]
) # find distinct attributes in environment with same label
unique_attr_indices = []
for attr in unique_attrs: # storing indices with same attribute value and label
if attr not in overall_label_attr_vindices[label]:
overall_label_attr_vindices[label][attr] = []
overall_label_attr_eindices[label][attr] = []
single_attr = []
for y_attr_idx in unique_label_indices[idx]:
if causal_attribute_labels[i][y_attr_idx] == attr:
single_attr.append(y_attr_idx)
overall_label_attr_vindices[label][attr].append(single_attr)
overall_label_attr_eindices[label][attr].append(i)
unique_attr_indices.append(single_attr)
for (
y_val
) in (
overall_label_attr_vindices
): # applying MMD penalty between distributions P(φ(x)|ai, y), P(φ(x)|aj, y) i.e samples with different attribute values but same label
tensors_list = []
for attr in overall_label_attr_vindices[y_val]:
attrs_list = []
if overall_label_attr_vindices[y_val][attr] != []:
for il_ind, indices_list in enumerate(overall_label_attr_vindices[y_val][attr]):
attrs_list.append(
classifs[overall_label_attr_eindices[y_val][attr][il_ind]][indices_list]
)
if len(attrs_list) > 0:
tensor_attrs = torch.cat(attrs_list, 0)
tensors_list.append(tensor_attrs)
nuattr = len(tensors_list)
for aidx in range(nuattr):
for bidx in range(aidx + 1, nuattr):
penalty_causal += self.mmd(tensors_list[aidx], tensors_list[bidx])
# Aind regularization
if "ind" in self.attr_types:
if self.E_eq_Aind: # Environment (E) and Independent attribute (Aind) coincide
for i in range(nmb):
for j in range(i + 1, nmb):
penalty_ind += self.mmd(classifs[i], classifs[j])
else:
if self.E_conditioned:
for i in range(nmb):
unique_aind_labels = torch.unique(ind_attribute_labels[i])
unique_aind_label_indices = []
for label in unique_aind_labels:
label_ind = [ind for ind, j in enumerate(ind_attribute_labels[i]) if j == label]
unique_aind_label_indices.append(label_ind)
nulabels = unique_aind_labels.shape[0]
for aidx in range(nulabels):
for bidx in range(aidx + 1, nulabels):
penalty_ind += self.mmd(
classifs[i][unique_aind_label_indices[aidx]],
classifs[i][unique_aind_label_indices[bidx]],
)
else: # this currently assumes we have a disjoint set of attributes (Aind) across environments i.e., environment is defined by multiple closely related values of the attribute
overall_nmb_indices, nmb_id = [], []
for i in range(nmb):
unique_attrs = torch.unique(ind_attribute_labels[i])
unique_attr_indices = []
for attr in unique_attrs:
attr_ind = [ind for ind, j in enumerate(ind_attribute_labels[i]) if j == attr]
unique_attr_indices.append(attr_ind)
overall_nmb_indices.append(attr_ind)
nmb_id.append(i)
nuattr = len(overall_nmb_indices)
for aidx in range(nuattr):
for bidx in range(aidx + 1, nuattr):
a_nmb_id = nmb_id[aidx]
b_nmb_id = nmb_id[bidx]
penalty_ind += self.mmd(
classifs[a_nmb_id][overall_nmb_indices[aidx]],
classifs[b_nmb_id][overall_nmb_indices[bidx]],
)
objective /= nmb
if nmb > 1:
penalty_causal /= nmb * (nmb - 1) / 2
penalty_ind /= nmb * (nmb - 1) / 2
# Compile loss
loss = objective
loss += self.lambda_causal * penalty_causal
loss += self.lambda_ind * penalty_ind
if torch.is_tensor(penalty_causal):
penalty_causal = penalty_causal.item()
self.log("penalty_causal", penalty_causal, on_step=False, on_epoch=True, prog_bar=True)
if torch.is_tensor(penalty_ind):
penalty_ind = penalty_ind.item()
self.log("penalty_ind", penalty_ind, on_step=False, on_epoch=True, prog_bar=True)
if self.attr_types != []:
for attr_type_idx, attr_type in enumerate(self.attr_types):
attribute_labels = [
ai for _, _, ai in minibatches
] # [(batch_size, num_attrs)_1, (batch_size, num_attrs)_2, ..., (batch_size, num_attrs)_(num_environments)]
E_eq_A_attr = attr_type_idx in self.E_eq_A
# Acause regularization
if attr_type == "causal":
penalty_causal += self.CACMRegularizer.conditional_reg(
classifs, [a[:, attr_type_idx] for a in attribute_labels], [targets], nmb, E_eq_A_attr
)
# Aconf regularization
elif attr_type == "conf":
penalty_conf += self.CACMRegularizer.unconditional_reg(
classifs, [a[:, attr_type_idx] for a in attribute_labels], nmb, E_eq_A_attr
)
# Aind regularization
elif attr_type == "ind":
penalty_ind += self.CACMRegularizer.unconditional_reg(
classifs, [a[:, attr_type_idx] for a in attribute_labels], nmb, E_eq_A_attr
)
# Asel regularization
elif attr_type == "sel":
penalty_sel += self.CACMRegularizer.conditional_reg(
classifs, [a[:, attr_type_idx] for a in attribute_labels], [targets], nmb, E_eq_A_attr
)
if nmb > 1:
penalty_causal /= nmb * (nmb - 1) / 2
penalty_conf /= nmb * (nmb - 1) / 2
penalty_ind /= nmb * (nmb - 1) / 2
penalty_sel /= nmb * (nmb - 1) / 2
# Compile loss
loss += self.lambda_causal * penalty_causal
loss += self.lambda_conf * penalty_conf
loss += self.lambda_ind * penalty_ind
loss += self.lambda_sel * penalty_sel
if torch.is_tensor(penalty_causal):
penalty_causal = penalty_causal.item()
self.log("penalty_causal", penalty_causal, on_step=False, on_epoch=True, prog_bar=True)
if torch.is_tensor(penalty_conf):
penalty_conf = penalty_conf.item()
self.log("penalty_conf", penalty_conf, on_step=False, on_epoch=True, prog_bar=True)
if torch.is_tensor(penalty_ind):
penalty_ind = penalty_ind.item()
self.log("penalty_ind", penalty_ind, on_step=False, on_epoch=True, prog_bar=True)
if torch.is_tensor(penalty_sel):
penalty_sel = penalty_sel.item()
self.log("penalty_sel", penalty_sel, on_step=False, on_epoch=True, prog_bar=True)
elif self.graph is not None:
pass # TODO
else:
raise ValueError("No attribute types or graph provided.")
acc = correct / total
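As the notebook notes, additional algorithms can be added by overriding training_step in PredictionAlgorithm. A minimal sketch under the constructor signature visible in this diff; the default values and the plain ERM loss are assumptions, with custom penalties left as a placeholder:

import torch
from torch.nn import functional as F
from dowhy.causal_prediction.algorithms.base_algorithm import PredictionAlgorithm

class MyAlgorithm(PredictionAlgorithm):
    def __init__(self, model, optimizer="Adam", lr=1e-3, weight_decay=0.0,
                 betas=(0.9, 0.999), momentum=0.9):
        # default values here are assumptions, mirroring the super() call in CACM
        super().__init__(model, optimizer, lr, weight_decay, betas, momentum)

    def training_step(self, train_batch, batch_idx):
        # train_batch is a list of per-environment (x, y, a) minibatches
        x = torch.cat([x for x, y, _ in train_batch])
        y = torch.cat([y for x, y, _ in train_batch])
        loss = F.cross_entropy(self.model(x), y)  # add custom penalties here
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss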

View file

@@ -27,8 +27,8 @@ class ERM(PredictionAlgorithm):
"""
x = torch.cat([x for x, y, _, _ in train_batch])
y = torch.cat([y for x, y, _, _ in train_batch])
x = torch.cat([x for x, y, _ in train_batch])
y = torch.cat([y for x, y, _ in train_batch])
out = self.model(x)
loss = F.cross_entropy(out, y)

View file

@@ -0,0 +1,300 @@
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from dowhy.causal_prediction.algorithms.utils import mmd_compute
class Regularizer:
"""
Implements methods for applying unconditional and conditional regularization.
"""
def __init__(
self,
E_conditioned,
ci_test,
kernel_type,
gamma,
):
"""
:param E_conditioned: Binary flag indicating whether E-conditioned regularization has to be applied
:param ci_test: Conditional independence metric used for regularization penalty. Currently, MMD is supported.
:param kernel_type: Kernel type for MMD penalty. Currently supports "gaussian" (RBF). If None, distance between mean and second-order statistics (covariances) is used.
:param gamma: kernel bandwidth for MMD (due to implementation, the kernel bandwidth will actually be the reciprocal of gamma i.e., gamma=1e-6 implies kernel bandwidth=1e6. See `mmd_compute` in utils.py)
"""
self.E_conditioned = E_conditioned # E-conditioned regularization by default
self.ci_test = ci_test
self.kernel_type = kernel_type
self.gamma = gamma
def mmd(self, x, y):
"""
Compute MMD penalty between x and y.
"""
return mmd_compute(x, y, self.kernel_type, self.gamma)
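A quick numeric check of the bandwidth convention documented above; the tensor shapes are arbitrary:

import torch
from dowhy.causal_prediction.algorithms.regularization import Regularizer

reg = Regularizer(E_conditioned=True, ci_test="mmd", kernel_type="gaussian", gamma=1e-6)
x, y = torch.randn(32, 16), torch.randn(32, 16)
print(float(reg.mmd(x, y)))  # gamma=1e-6 corresponds to kernel bandwidth 1e6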
def unconditional_reg(self, classifs, attribute_labels, num_envs, E_eq_A=False):
"""
Implement unconditional regularization φ(x) ⊥ A_i
:param classifs: feature representations output from classifier layer (φ(x))
:param attribute_labels: attribute labels loaded with the dataset for attribute A_i
:param num_envs: number of environments/domains
:param E_eq_A: Binary flag indicating whether attribute (A_i) coincides with environment (E) definition
"""
penalty = 0
if E_eq_A: # Environment (E) and attribute (A) coincide
if self.E_conditioned is False: # there is no correlation between E and X_c
for i in range(num_envs):
for j in range(i + 1, num_envs):
penalty += self.mmd(classifs[i], classifs[j])
else:
if self.E_conditioned:
for i in range(num_envs):
unique_attr_labels = torch.unique(attribute_labels[i])
unique_attr_label_indices = []
for label in unique_attr_labels:
label_ind = [ind for ind, j in enumerate(attribute_labels[i]) if j == label]
unique_attr_label_indices.append(label_ind)
nulabels = unique_attr_labels.shape[0]
for aidx in range(nulabels):
for bidx in range(aidx + 1, nulabels):
penalty += self.mmd(
classifs[i][unique_attr_label_indices[aidx]],
classifs[i][unique_attr_label_indices[bidx]],
)
else: # this currently assumes we have a disjoint set of attributes (Aind) across environments i.e., environment is defined by multiple closely related values of the attribute
overall_nmb_indices, nmb_id = [], []
for i in range(num_envs):
unique_attrs = torch.unique(attribute_labels[i])
unique_attr_indices = []
for attr in unique_attrs:
attr_ind = [ind for ind, j in enumerate(attribute_labels[i]) if j == attr]
unique_attr_indices.append(attr_ind)
overall_nmb_indices.append(attr_ind)
nmb_id.append(i)
nuattr = len(overall_nmb_indices)
for aidx in range(nuattr):
for bidx in range(aidx + 1, nuattr):
a_nmb_id = nmb_id[aidx]
b_nmb_id = nmb_id[bidx]
penalty += self.mmd(
classifs[a_nmb_id][overall_nmb_indices[aidx]],
classifs[b_nmb_id][overall_nmb_indices[bidx]],
)
return penalty
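A toy invocation of unconditional_reg with two environments and a binary attribute; the classifier outputs are random stand-ins:

import torch
from dowhy.causal_prediction.algorithms.regularization import Regularizer

reg = Regularizer(E_conditioned=True, ci_test="mmd", kernel_type="gaussian", gamma=1e-2)
num_envs = 2
classifs = [torch.randn(8, 2) for _ in range(num_envs)]  # per-environment φ(x) outputs
attribute_labels = [torch.randint(0, 2, (8,)) for _ in range(num_envs)]
# E-conditioned: MMD between attribute groups within each environment
penalty = reg.unconditional_reg(classifs, attribute_labels, num_envs, E_eq_A=False)
print(float(penalty))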
def conditional_reg(self, classifs, attribute_labels, conditioning_subset, num_envs, E_eq_A=False):
"""
Implement conditional regularization φ(x) ⊥ A_i | A_s
:param classifs: feature representations output from classifier layer (φ(x))
:param attribute_labels: attribute labels loaded with the dataset for attribute A_i
:param conditioning_subset: list of subset of observed variables A_s (attributes + targets) such that (X_c, A_i) are d-separated conditioned on this subset
:param num_envs: number of environments/domains
:param E_eq_A: Binary flag indicating whether attribute (A_i) coincides with environment (E) definition
Find group indices for conditional regularization based on conditioning subset by taking all possible combinations
e.g., conditioning_subset = [A1, Y], where A1 is in {0, 1} and Y is in {0, 1, 2},
we assign groups in the following way:
A1 = 0, Y = 0 -> group 0
A1 = 1, Y = 0 -> group 1
A1 = 0, Y = 1 -> group 2
A1 = 1, Y = 1 -> group 3
A1 = 0, Y = 2 -> group 4
A1 = 1, Y = 2 -> group 5
Code snippet for computing group indices adapted from WILDS: https://github.com/p-lambda/wilds
@inproceedings{wilds2021,
title = {{WILDS}: A Benchmark of in-the-Wild Distribution Shifts},
author = {Pang Wei Koh and Shiori Sagawa and Henrik Marklund and Sang Michael Xie and Marvin Zhang and Akshay Balsubramani and Weihua Hu and Michihiro Yasunaga and Richard Lanas Phillips and Irena Gao and Tony Lee and Etienne David and Ian Stavness and Wei Guo and Berton A. Earnshaw and Imran S. Haque and Sara Beery and Jure Leskovec and Anshul Kundaje and Emma Pierson and Sergey Levine and Chelsea Finn and Percy Liang},
booktitle = {International Conference on Machine Learning (ICML)},
year = {2021}
}
"""
penalty = 0
if E_eq_A: # Environment (E) and attribute (A) coincide
if self.E_conditioned is False: # there is no correlation between E and X_c
overall_group_vindices = {} # storing group indices
overall_group_eindices = {} # storing corresponding environment indices
for i in range(num_envs):
conditioning_subset_i = [subset_var[i] for subset_var in conditioning_subset]
conditioning_subset_i_uniform = [
ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in conditioning_subset_i
]
grouping_data = torch.cat(conditioning_subset_i_uniform, 1)
assert grouping_data.min() >= 0, "Group numbers cannot be negative."
cardinality = 1 + torch.max(grouping_data, dim=0)[0]
cumprod = torch.cumprod(cardinality, dim=0)
n_groups = cumprod[-1].item()
factors_np = np.concatenate(([1], cumprod[:-1]))
factors = torch.from_numpy(factors_np)
group_indices = grouping_data @ factors
for group_idx in range(n_groups):
group_idx_indices = [
gp_idx for gp_idx in range(len(group_indices)) if group_indices[gp_idx] == group_idx
]
if group_idx not in overall_group_vindices:
overall_group_vindices[group_idx] = {}
overall_group_eindices[group_idx] = {}
unique_attrs = torch.unique(
attribute_labels[i][group_idx_indices]
) # find distinct attributes in environment with same group_idx_indices
unique_attr_indices = []
for attr in unique_attrs: # storing indices with same attribute value and group label
if attr not in overall_group_vindices[group_idx]:
overall_group_vindices[group_idx][attr] = []
overall_group_eindices[group_idx][attr] = []
single_attr = []
for group_idx_indices_attr in group_idx_indices:
if attribute_labels[i][group_idx_indices_attr] == attr:
single_attr.append(group_idx_indices_attr)
overall_group_vindices[group_idx][attr].append(single_attr)
overall_group_eindices[group_idx][attr].append(i)
unique_attr_indices.append(single_attr)
for (
group_label
) in (
overall_group_vindices
): # applying MMD penalty between distributions P(φ(x)|ai, g), P(φ(x)|aj, g) i.e., samples with different attribute values but the same group label
tensors_list = []
for attr in overall_group_vindices[group_label]:
attrs_list = []
if overall_group_vindices[group_label][attr] != []:
for il_ind, indices_list in enumerate(overall_group_vindices[group_label][attr]):
attrs_list.append(
classifs[overall_group_eindices[group_label][attr][il_ind]][indices_list]
)
if len(attrs_list) > 0:
tensor_attrs = torch.cat(attrs_list, 0)
tensors_list.append(tensor_attrs)
nuattr = len(tensors_list)
for aidx in range(nuattr):
for bidx in range(aidx + 1, nuattr):
penalty += self.mmd(tensors_list[aidx], tensors_list[bidx])
else:
if self.E_conditioned:
for i in range(num_envs):
conditioning_subset_i = [subset_var[i] for subset_var in conditioning_subset]
conditioning_subset_i_uniform = [
ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in conditioning_subset_i
]
grouping_data = torch.cat(conditioning_subset_i_uniform, 1)
assert grouping_data.min() >= 0, "Group numbers cannot be negative."
cardinality = 1 + torch.max(grouping_data, dim=0)[0]
cumprod = torch.cumprod(cardinality, dim=0)
n_groups = cumprod[-1].item()
factors_np = np.concatenate(([1], cumprod[:-1]))
factors = torch.from_numpy(factors_np)
group_indices = grouping_data @ factors
for group_idx in range(n_groups):
group_idx_indices = [
gp_idx for gp_idx in range(len(group_indices)) if group_indices[gp_idx] == group_idx
]
unique_attrs = torch.unique(
attribute_labels[i][group_idx_indices]
) # find distinct attributes in environment with same group_idx_indices
unique_attr_indices = []
for attr in unique_attrs:
single_attr = []
for group_idx_indices_attr in group_idx_indices:
if attribute_labels[i][group_idx_indices_attr] == attr:
single_attr.append(group_idx_indices_attr)
unique_attr_indices.append(single_attr)
nuattr = unique_attrs.shape[0]
for aidx in range(nuattr):
for bidx in range(aidx + 1, nuattr):
penalty += self.mmd(
classifs[i][unique_attr_indices[aidx]], classifs[i][unique_attr_indices[bidx]]
)
else:
overall_group_vindices = {} # storing group indices
overall_group_eindices = {} # storing corresponding environment indices
for i in range(num_envs):
conditioning_subset_i = [subset_var[i] for subset_var in conditioning_subset]
conditioning_subset_i_uniform = [
ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in conditioning_subset_i
]
grouping_data = torch.cat(conditioning_subset_i_uniform, 1)
assert grouping_data.min() >= 0, "Group numbers cannot be negative."
cardinality = 1 + torch.max(grouping_data, dim=0)[0]
cumprod = torch.cumprod(cardinality, dim=0)
n_groups = cumprod[-1].item()
factors_np = np.concatenate(([1], cumprod[:-1]))
factors = torch.from_numpy(factors_np)
group_indices = grouping_data @ factors
for group_idx in range(n_groups):
group_idx_indices = [
gp_idx for gp_idx in range(len(group_indices)) if group_indices[gp_idx] == group_idx
]
if group_idx not in overall_group_vindices:
overall_group_vindices[group_idx] = {}
overall_group_eindices[group_idx] = {}
unique_attrs = torch.unique(
attribute_labels[i][group_idx_indices]
) # find distinct attributes in environment with same group_idx_indices
unique_attr_indices = []
for attr in unique_attrs: # storing indices with same attribute value and group label
if attr not in overall_group_vindices[group_idx]:
overall_group_vindices[group_idx][attr] = []
overall_group_eindices[group_idx][attr] = []
single_attr = []
for group_idx_indices_attr in group_idx_indices:
if attribute_labels[i][group_idx_indices_attr] == attr:
single_attr.append(group_idx_indices_attr)
overall_group_vindices[group_idx][attr].append(single_attr)
overall_group_eindices[group_idx][attr].append(i)
unique_attr_indices.append(single_attr)
for (
group_label
) in (
overall_group_vindices
): # applying MMD penalty between distributions P(φ(x)|ai, g), P(φ(x)|aj, g) i.e., samples with different attribute values but the same group label
tensors_list = []
for attr in overall_group_vindices[group_label]:
attrs_list = []
if overall_group_vindices[group_label][attr] != []:
for il_ind, indices_list in enumerate(overall_group_vindices[group_label][attr]):
attrs_list.append(
classifs[overall_group_eindices[group_label][attr][il_ind]][indices_list]
)
if len(attrs_list) > 0:
tensor_attrs = torch.cat(attrs_list, 0)
tensors_list.append(tensor_attrs)
nuattr = len(tensors_list)
for aidx in range(nuattr):
for bidx in range(aidx + 1, nuattr):
penalty += self.mmd(tensors_list[aidx], tensors_list[bidx])
return penalty
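A worked example of the group-index computation used above (cardinality, cumprod, factors), reproducing the A1 in {0, 1}, Y in {0, 1, 2} mapping from the docstring:

import numpy as np
import torch

# conditioning_subset = [A1, Y] for 6 samples covering all combinations
A1 = torch.tensor([0, 1, 0, 1, 0, 1])
Y = torch.tensor([0, 0, 1, 1, 2, 2])
grouping_data = torch.stack((A1, Y), dim=1)

cardinality = 1 + torch.max(grouping_data, dim=0)[0]  # tensor([2, 3])
cumprod = torch.cumprod(cardinality, dim=0)           # tensor([2, 6])
n_groups = cumprod[-1].item()                         # 6
factors = torch.from_numpy(np.concatenate(([1], cumprod[:-1])))
group_indices = grouping_data @ factors
print(group_indices)  # tensor([0, 1, 2, 3, 4, 5]), matching the table above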

View file

@@ -108,7 +108,6 @@ def get_loaders(
holdout_fraction=0.2,
trial_seed=0,
):
"""Return training, validation, and test dataloaders.
:param dataset: dataset class containing list of environments

View file

@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torchvision
from PIL import Image
from torch.utils.data import Subset, TensorDataset
from torchvision import transforms
@@ -22,12 +23,12 @@ from dowhy.causal_prediction.datasets.base_dataset import MultipleDomainDataset
* datasets initialized from torchvision.datasets.MNIST
* We assume causal attribute (Acause) = color, independent attribute (Aind) = rotation
* Environments/domains stored in list self.datasets (required for all datasets)
* Default env structure is TensorDataset(x, y, Acause, Aind)
* We do not need Aind (rotation) explicitly here since E=Aind
* In Aind case, we return (x, y, _, _)
* '_' is replaced by any dummy vector (this is because the training loop assumes 4 inputs in TensorDataset)
* Default env structure is TensorDataset(x, y, a)
* a is a combined tensor for all attributes (metadata) a1, a2, ..., ak
* a's shape is (n, k) where n is the number of samples in the environment
"""
# single-attribute Causal
class MNISTCausalAttribute(MultipleDomainDataset):
N_STEPS = 5001
@@ -103,8 +104,9 @@ class MNISTCausalAttribute(MultipleDomainDataset):
x = images.float().div_(255.0)
y = labels.view(-1).long()
a = torch.unsqueeze(colors, 1)
return TensorDataset(x, y, colors, colors)
return TensorDataset(x, y, a)
def torch_bernoulli_(self, p, size):
return (torch.rand(size) < p).float()
@@ -148,21 +150,21 @@ class MNISTIndAttribute(MultipleDomainDataset):
for i, env in enumerate(angles[:-1]):
images = original_images[:50000][i::2]
labels = original_labels[:50000][i::2]
self.datasets.append(self.rotate_dataset(images, labels, angles[i]))
self.datasets.append(self.rotate_dataset(images, labels, i, angles[i]))
images = original_images[50000:]
labels = original_labels[50000:]
self.datasets.append(self.rotate_dataset(images, labels, angles[-1]))
self.datasets.append(self.rotate_dataset(images, labels, len(angles) - 1, angles[-1]))
# test environment
original_dataset_te = MNIST(root, train=False, download=download)
original_images = original_dataset_te.data
original_labels = original_dataset_te.targets
self.datasets.append(self.rotate_dataset(original_images, original_labels, angles[-1]))
self.datasets.append(self.rotate_dataset(original_images, original_labels, len(angles) - 1, angles[-1]))
self.input_shape = self.INPUT_SHAPE
self.num_classes = 2
def rotate_dataset(self, images, labels, angle):
def rotate_dataset(self, images, labels, env_id, angle):
"""
Transform MNIST dataset by applying rotation to images.
Attribute (rotation angle) is independent of label Y.
@@ -170,12 +172,16 @@ class MNISTIndAttribute(MultipleDomainDataset):
:param images: original MNIST images
:param labels: original MNIST labels
:param env_id: Index of the environment; stored as the rotation attribute value in the returned dataset
:param angle: Value of rotation angle used for transforming the image
:returns: TensorDataset containing transformed images and labels
:returns: TensorDataset containing transformed images, labels, and attributes (angle)
"""
rotation = transforms.Compose(
[
transforms.ToPILImage(),
transforms.Lambda(lambda x: rotate(x, int(angle), fill=(0,), resample=Image.BICUBIC)),
transforms.Lambda(
lambda x: rotate(
x, int(angle), fill=(0,), interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
),
transforms.ToTensor(),
]
)
@@ -192,8 +198,10 @@ class MNISTIndAttribute(MultipleDomainDataset):
x[i] = rotation(images[i].float().div_(255.0))
y = labels.view(-1).long()
a = torch.full((y.shape[0],), env_id, dtype=torch.float32)
a = torch.unsqueeze(a, 1)
return TensorDataset(x, y, y, y)
return TensorDataset(x, y, a)
def torch_bernoulli_(self, p, size):
return (torch.rand(size) < p).float()
@@ -238,21 +246,23 @@ class MNISTCausalIndAttribute(MultipleDomainDataset):
for i, env in enumerate(environments[:-1]):
images = original_images[:50000][i::2]
labels = original_labels[:50000][i::2]
self.datasets.append(self.color_rot_dataset(images, labels, env, angles[i]))
self.datasets.append(self.color_rot_dataset(images, labels, env, i, angles[i]))
images = original_images[50000:]
labels = original_labels[50000:]
self.datasets.append(self.color_rot_dataset(images, labels, environments[-1], angles[-1]))
self.datasets.append(self.color_rot_dataset(images, labels, environments[-1], len(angles) - 1, angles[-1]))
# test environment
original_dataset_te = MNIST(root, train=False, download=download)
original_images = original_dataset_te.data
original_labels = original_dataset_te.targets
self.datasets.append(self.color_rot_dataset(original_images, original_labels, environments[-1], angles[-1]))
self.datasets.append(
self.color_rot_dataset(original_images, original_labels, environments[-1], len(angles) - 1, angles[-1])
)
self.input_shape = self.INPUT_SHAPE
self.num_classes = 2
def color_rot_dataset(self, images, labels, environment, angle):
def color_rot_dataset(self, images, labels, environment, env_id, angle):
"""
Transform MNIST dataset by (i) applying rotation to images, then (ii) introducing correlation between attribute (color) and label.
Attribute (rotation angle) is independent of label Y; there is a direct-causal relationship between label Y and color.
@@ -261,7 +271,7 @@ class MNISTCausalIndAttribute(MultipleDomainDataset):
:param labels: original MNIST labels
:param environment: Value of correlation between color and label
:param env_id: Index of the environment; stored as the rotation attribute value in the returned dataset
:param angle: Value of rotation angle used for transforming the image
:returns: TensorDataset containing transformed images, labels, and attributes (color)
:returns: TensorDataset containing transformed images, labels, and attributes (color, angle)
"""
# Subsample 2x for computational convenience
images = images.reshape((-1, 28, 28))[:, ::2, ::2]
@@ -272,8 +282,10 @@ class MNISTCausalIndAttribute(MultipleDomainDataset):
x = images # .float().div_(255.0)
y = labels.view(-1).long()
angles = torch.full((y.shape[0],), env_id, dtype=torch.float32)
a = torch.stack((colors, angles), 1)
return TensorDataset(x, y, colors, colors)
return TensorDataset(x, y, a)
def color_dataset(self, images, labels, environment):
"""