This commit is contained in:
TevenLeScao 2023-02-01 15:30:14 +01:00
Родитель 96d1f404e5
Коммит 3934867cb8
1 изменённых файлов: 24 добавлений и 12 удалений

Просмотреть файл

@ -468,8 +468,9 @@ def get_coord_data(models, dataloader, optimizer='sgd', lr=None, mup=True,
data['lr'] = lr
return data
def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module',
legend='full', name_contains=None, name_not_contains=None,
legend='full', name_contains=None, name_not_contains=None, module_list=None,
loglog=True, logbase=2, face_color=None, subplot_width=5,
subplot_height=4):
'''Plot coord check data `df` obtained from `get_coord_data`.
@ -501,10 +502,10 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
background color of the plot. Default: None (which means white)
subplot_width, subplot_height:
The width and height for each timestep's subplot. More precisely,
the figure size will be
the figure size will be
`(subplot_width*number_of_time_steps, subplot_height)`.
Default: 5, 4
Output:
the `matplotlib` figure object
'''
@ -513,13 +514,17 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
# nn.Sequential has name '', which duplicates the output layer
df = df[df.module != '']
try:
if name_contains is not None:
df = df[df['module'].str.contains(name_contains)]
elif name_not_contains is not None:
df = df[~(df['module'].str.contains(name_not_contains))]
# for nn.Sequential, module names are numerical
df['module'] = pd.to_numeric(df['module'])
if module_list is not None:
df = df[df['module'].isin(module_list)]
else:
if name_contains is not None:
df = df[df['module'].str.contains(name_contains)]
if name_not_contains is not None:
df = df[~(df['module'].str.contains(name_not_contains))]
# for nn.Sequential, module names are numerical
# df['module'] = pd.to_numeric(df['module'])
except Exception as e:
print(f"Filtering not applied: {e}")
pass
ts = df.t.unique()
@ -530,25 +535,32 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
def tight_layout(plt):
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
### plot
fig = plt.figure(figsize=(subplot_width*len(ts), subplot_height))
fig = plt.figure(figsize=(subplot_width * len(ts), subplot_height))
hue_order = sorted(set(df['module']))
if face_color is not None:
fig.patch.set_facecolor(face_color)
ymin, ymax = min(df[y]), max(df[y])
for t in ts:
plt.subplot(1, len(ts), t)
sns.lineplot(x=x, y=y, data=df[df.t==t], hue=hue, legend=legend if t==1 else None)
sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue, hue_order=hue_order, legend=legend if t == 1 else None)
plt.title(f't={t}')
if t != 1:
plt.ylabel('')
else:
plt.legend(labels=hue_order, bbox_to_anchor=(0, 1), loc="upper right")
if loglog:
plt.loglog(base=logbase)
ax = plt.gca()
ax.set_ylim([ymin, ymax])
if suptitle:
plt.suptitle(suptitle)
tight_layout(plt)
if save_to is not None:
plt.savefig(save_to)
print(f'coord check plot saved to {save_to}')
return fig
# example of how to plot coord check results