Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
This commit is contained in:
Keith Battocchi 2024-07-17 20:20:35 -04:00 коммит произвёл Keith Battocchi
Родитель 4db044a0a8
Коммит b577be9a63
3 изменённых файлов: 15 добавлений и 27 удалений

Просмотреть файл

@ -264,16 +264,8 @@ class TestDeepIV(unittest.TestCase):
time = rng.rand(n) * 10 time = rng.rand(n) * 10
emotion_id = rng.randint(0, 7, size=n) emotion_id = rng.randint(0, 7, size=n)
emotion = one_hot(emotion_id, categories=[np.arange(7)]) emotion = one_hot(emotion_id, categories=[np.arange(7)])
if use_images:
idx = np.argsort(emotion_id) emotion_feature = emotion
emotion_feature = np.zeros((0, 28 * 28))
for i in range(7):
img = get_images(i, np.sum(emotion_id == i), seed, test)
emotion_feature = np.vstack([emotion_feature, img])
reorder = np.argsort(idx)
emotion_feature = emotion_feature[reorder, :]
else:
emotion_feature = emotion
# random instrument # random instrument
z = rng.randn(n) z = rng.randn(n)
@ -304,8 +296,8 @@ class TestDeepIV(unittest.TestCase):
y.reshape((-1, 1)), y.reshape((-1, 1)),
g) g)
def datafunction(n, s, images=False, test=False): def datafunction(n, s, test=False):
return demand(n=n, seed=s, ypcor=0.5, use_images=images, test=test) return demand(n=n, seed=s, ypcor=0.5, test=test)
n = 1000 n = 1000
epochs = 50 epochs = 50
@ -397,7 +389,7 @@ Response:{y}".format(**{'x': x.shape, 'z': z.shape,
y = (g - ymu) / ysd y = (g - ymu) / ysd
return y.reshape(-1, 1) return y.reshape(-1, 1)
def demand(n, seed=1, ynoise=1., pnoise=1., ypcor=0.8, use_images=False, test=False): def demand(n, seed=1, ynoise=1., pnoise=1., ypcor=0.8, test=False):
rng = np.random.RandomState(seed) rng = np.random.RandomState(seed)
# covariates: time and emotion # covariates: time and emotion
@ -435,8 +427,8 @@ Response:{y}".format(**{'x': x.shape, 'z': z.shape,
y.reshape((-1, 1)), y.reshape((-1, 1)),
g) g)
def datafunction(n, s, images=False, test=False): def datafunction(n, s, test=False):
return demand(n=n, seed=s, ypcor=0.5, use_images=images, test=test) return demand(n=n, seed=s, ypcor=0.5, test=test)
n = 1000 n = 1000
epochs = 20 epochs = 20

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -150,7 +150,6 @@ ignore = [
"F405", # 'module' may be undefined, or defined from star imports "F405", # 'module' may be undefined, or defined from star imports
"F541", # f-string is missing placeholders "F541", # f-string is missing placeholders
"F811", # Redefinition of unused name from line N "F811", # Redefinition of unused name from line N
"F821", # Undefined name
"E713", # Test for membership should be 'not in' "E713", # Test for membership should be 'not in'
"D100", # Missing docstring in public module "D100", # Missing docstring in public module
"D101", # Missing docstring in public class "D101", # Missing docstring in public class