This commit is contained in:
adam 2019-02-10 21:26:29 -05:00
Родитель 18a03b0601
Коммит 441d71450a
3 изменённых файлов: 515 добавлений и 1362 удалений

Просмотреть файл

@ -27,10 +27,10 @@ class CausalAccessor(object):
def do(self, x, method=None, num_cores=1, variable_types={}, outcome=None, params=None, dot_graph=None,
common_causes=None, instruments=None, estimand_type='ate', proceed_when_unidentifiable=False,
keep_original_treatment=False):
keep_original_treatment=False, use_previous_sampler=False):
if not method:
raise Exception("You must specify a do sampling method.")
if not self._obj._causal_model:
if not self._obj._causal_model or not use_previous_sampler:
self._obj._causal_model = CausalModel(self._obj,
[xi for xi in x.keys()][0],
outcome,
@ -41,7 +41,7 @@ class CausalAccessor(object):
proceed_when_unidentifiable=proceed_when_unidentifiable)
self._obj._identified_estimand = self._obj._causal_model.identify_effect()
do_sampler_class = do_samplers.get_class_object(method + "_sampler")
if not self._obj._sampler:
if not self._obj._sampler or not use_previous_sampler:
self._obj._sampler = do_sampler_class(self._obj,
self._obj._identified_estimand,
self._obj._causal_model._treatment,

Просмотреть файл

@ -109,25 +109,16 @@ class McmcSampler(DoSampler):
def do_sample(self, x):
self.reset()
print(self._df.sample(10))
g_for_surgery = nx.DiGraph(self.g)
g_modified = self.do_x_surgery(g_for_surgery, x)
print(self._df.sample(10))
self._df = self.make_intervention_effective(x)
print(self._df.sample(10))
g_modified, trace = self.sample_prior_causal_model(g_modified,
self._df,
self._variable_types,
initialization_trace=self.fit_trace)
print(self._df.sample(10))
for col in self._df:
if col in trace and col not in self._treatment_names:
self._df[col] = trace[col]
print(self._df.sample(10))
return self._df.copy()
def _construct_sampler(self):

Разница между файлами не показана из-за своего большого размера Загрузить разницу