зеркало из https://github.com/microsoft/mup.git
coord check plot improvements
This commit is contained in:
Родитель
96d1f404e5
Коммит
3934867cb8
|
@ -468,8 +468,9 @@ def get_coord_data(models, dataloader, optimizer='sgd', lr=None, mup=True,
|
|||
data['lr'] = lr
|
||||
return data
|
||||
|
||||
|
||||
def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module',
|
||||
legend='full', name_contains=None, name_not_contains=None,
|
||||
legend='full', name_contains=None, name_not_contains=None, module_list=None,
|
||||
loglog=True, logbase=2, face_color=None, subplot_width=5,
|
||||
subplot_height=4):
|
||||
'''Plot coord check data `df` obtained from `get_coord_data`.
|
||||
|
@ -501,10 +502,10 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
|
|||
background color of the plot. Default: None (which means white)
|
||||
subplot_width, subplot_height:
|
||||
The width and height for each timestep's subplot. More precisely,
|
||||
the figure size will be
|
||||
the figure size will be
|
||||
`(subplot_width*number_of_time_steps, subplot_height)`.
|
||||
Default: 5, 4
|
||||
|
||||
|
||||
Output:
|
||||
the `matplotlib` figure object
|
||||
'''
|
||||
|
@ -513,13 +514,17 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
|
|||
# nn.Sequential has name '', which duplicates the output layer
|
||||
df = df[df.module != '']
|
||||
try:
|
||||
if name_contains is not None:
|
||||
df = df[df['module'].str.contains(name_contains)]
|
||||
elif name_not_contains is not None:
|
||||
df = df[~(df['module'].str.contains(name_not_contains))]
|
||||
# for nn.Sequential, module names are numerical
|
||||
df['module'] = pd.to_numeric(df['module'])
|
||||
if module_list is not None:
|
||||
df = df[df['module'].isin(module_list)]
|
||||
else:
|
||||
if name_contains is not None:
|
||||
df = df[df['module'].str.contains(name_contains)]
|
||||
if name_not_contains is not None:
|
||||
df = df[~(df['module'].str.contains(name_not_contains))]
|
||||
# for nn.Sequential, module names are numerical
|
||||
# df['module'] = pd.to_numeric(df['module'])
|
||||
except Exception as e:
|
||||
print(f"Filtering not applied: {e}")
|
||||
pass
|
||||
|
||||
ts = df.t.unique()
|
||||
|
@ -530,25 +535,32 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
|
|||
|
||||
def tight_layout(plt):
|
||||
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
|
||||
|
||||
### plot
|
||||
fig = plt.figure(figsize=(subplot_width*len(ts), subplot_height))
|
||||
fig = plt.figure(figsize=(subplot_width * len(ts), subplot_height))
|
||||
hue_order = sorted(set(df['module']))
|
||||
if face_color is not None:
|
||||
fig.patch.set_facecolor(face_color)
|
||||
ymin, ymax = min(df[y]), max(df[y])
|
||||
for t in ts:
|
||||
plt.subplot(1, len(ts), t)
|
||||
sns.lineplot(x=x, y=y, data=df[df.t==t], hue=hue, legend=legend if t==1 else None)
|
||||
sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue, hue_order=hue_order, legend=legend if t == 1 else None)
|
||||
plt.title(f't={t}')
|
||||
if t != 1:
|
||||
plt.ylabel('')
|
||||
else:
|
||||
plt.legend(labels=hue_order, bbox_to_anchor=(0, 1), loc="upper right")
|
||||
if loglog:
|
||||
plt.loglog(base=logbase)
|
||||
ax = plt.gca()
|
||||
ax.set_ylim([ymin, ymax])
|
||||
if suptitle:
|
||||
plt.suptitle(suptitle)
|
||||
tight_layout(plt)
|
||||
if save_to is not None:
|
||||
plt.savefig(save_to)
|
||||
print(f'coord check plot saved to {save_to}')
|
||||
|
||||
|
||||
return fig
|
||||
|
||||
# example of how to plot coord check results
|
||||
|
|
Загрузка…
Ссылка в новой задаче