working sampler
This commit is contained in:
Родитель
18a03b0601
Коммит
441d71450a
|
@ -27,10 +27,10 @@ class CausalAccessor(object):
|
|||
|
||||
def do(self, x, method=None, num_cores=1, variable_types={}, outcome=None, params=None, dot_graph=None,
|
||||
common_causes=None, instruments=None, estimand_type='ate', proceed_when_unidentifiable=False,
|
||||
keep_original_treatment=False):
|
||||
keep_original_treatment=False, use_previous_sampler=False):
|
||||
if not method:
|
||||
raise Exception("You must specify a do sampling method.")
|
||||
if not self._obj._causal_model:
|
||||
if not self._obj._causal_model or not use_previous_sampler:
|
||||
self._obj._causal_model = CausalModel(self._obj,
|
||||
[xi for xi in x.keys()][0],
|
||||
outcome,
|
||||
|
@ -41,7 +41,7 @@ class CausalAccessor(object):
|
|||
proceed_when_unidentifiable=proceed_when_unidentifiable)
|
||||
self._obj._identified_estimand = self._obj._causal_model.identify_effect()
|
||||
do_sampler_class = do_samplers.get_class_object(method + "_sampler")
|
||||
if not self._obj._sampler:
|
||||
if not self._obj._sampler or not use_previous_sampler:
|
||||
self._obj._sampler = do_sampler_class(self._obj,
|
||||
self._obj._identified_estimand,
|
||||
self._obj._causal_model._treatment,
|
||||
|
|
|
@ -109,25 +109,16 @@ class McmcSampler(DoSampler):
|
|||
|
||||
def do_sample(self, x):
|
||||
self.reset()
|
||||
print(self._df.sample(10))
|
||||
g_for_surgery = nx.DiGraph(self.g)
|
||||
g_modified = self.do_x_surgery(g_for_surgery, x)
|
||||
print(self._df.sample(10))
|
||||
|
||||
self._df = self.make_intervention_effective(x)
|
||||
print(self._df.sample(10))
|
||||
|
||||
g_modified, trace = self.sample_prior_causal_model(g_modified,
|
||||
self._df,
|
||||
self._variable_types,
|
||||
initialization_trace=self.fit_trace)
|
||||
print(self._df.sample(10))
|
||||
|
||||
for col in self._df:
|
||||
if col in trace and col not in self._treatment_names:
|
||||
self._df[col] = trace[col]
|
||||
print(self._df.sample(10))
|
||||
|
||||
return self._df.copy()
|
||||
|
||||
def _construct_sampler(self):
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Загрузка…
Ссылка в новой задаче