зеркало из https://github.com/microsoft/mup.git
Merge pull request #38 from TevenLeScao/coord_check_plot_features
coord check plot improvements
This commit is contained in:
Коммит
a33ea802bc
|
@ -468,8 +468,9 @@ def get_coord_data(models, dataloader, optimizer='sgd', lr=None, mup=True,
|
||||||
data['lr'] = lr
|
data['lr'] = lr
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module',
|
def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module',
|
||||||
legend='full', name_contains=None, name_not_contains=None,
|
legend='full', name_contains=None, name_not_contains=None, module_list=None,
|
||||||
loglog=True, logbase=2, face_color=None, subplot_width=5,
|
loglog=True, logbase=2, face_color=None, subplot_width=5,
|
||||||
subplot_height=4):
|
subplot_height=4):
|
||||||
'''Plot coord check data `df` obtained from `get_coord_data`.
|
'''Plot coord check data `df` obtained from `get_coord_data`.
|
||||||
|
@ -489,10 +490,10 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
|
||||||
the column of `df` to represent as color. Default: `'module'`
|
the column of `df` to represent as color. Default: `'module'`
|
||||||
legend:
|
legend:
|
||||||
'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`.
|
'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`.
|
||||||
name_contains:
|
name_contains, name_not_contains:
|
||||||
only plot modules whose name contains `name_contains`
|
only plot modules whose name contains `name_contains` and does not contain `name_not_contains`
|
||||||
name_not_contains:
|
module_list:
|
||||||
only plot modules whose name does not contain `name_not_contains`
|
only plot modules that are given in the list, overrides `name_contains` and `name_not_contains`
|
||||||
loglog:
|
loglog:
|
||||||
whether to use loglog scale. Default: True
|
whether to use loglog scale. Default: True
|
||||||
logbase:
|
logbase:
|
||||||
|
@ -501,10 +502,10 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
|
||||||
background color of the plot. Default: None (which means white)
|
background color of the plot. Default: None (which means white)
|
||||||
subplot_width, subplot_height:
|
subplot_width, subplot_height:
|
||||||
The width and height for each timestep's subplot. More precisely,
|
The width and height for each timestep's subplot. More precisely,
|
||||||
the figure size will be
|
the figure size will be
|
||||||
`(subplot_width*number_of_time_steps, subplot_height)`.
|
`(subplot_width*number_of_time_steps, subplot_height)`.
|
||||||
Default: 5, 4
|
Default: 5, 4
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
the `matplotlib` figure object
|
the `matplotlib` figure object
|
||||||
'''
|
'''
|
||||||
|
@ -512,14 +513,17 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
|
||||||
df = copy(df)
|
df = copy(df)
|
||||||
# nn.Sequential has name '', which duplicates the output layer
|
# nn.Sequential has name '', which duplicates the output layer
|
||||||
df = df[df.module != '']
|
df = df[df.module != '']
|
||||||
try:
|
if module_list is not None:
|
||||||
|
df = df[df['module'].isin(module_list)]
|
||||||
|
else:
|
||||||
if name_contains is not None:
|
if name_contains is not None:
|
||||||
df = df[df['module'].str.contains(name_contains)]
|
df = df[df['module'].str.contains(name_contains)]
|
||||||
elif name_not_contains is not None:
|
if name_not_contains is not None:
|
||||||
df = df[~(df['module'].str.contains(name_not_contains))]
|
df = df[~(df['module'].str.contains(name_not_contains))]
|
||||||
# for nn.Sequential, module names are numerical
|
# for nn.Sequential, module names are numerical
|
||||||
|
try:
|
||||||
df['module'] = pd.to_numeric(df['module'])
|
df['module'] = pd.to_numeric(df['module'])
|
||||||
except Exception as e:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
ts = df.t.unique()
|
ts = df.t.unique()
|
||||||
|
@ -530,26 +534,31 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
|
||||||
|
|
||||||
def tight_layout(plt):
|
def tight_layout(plt):
|
||||||
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
|
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
|
||||||
|
|
||||||
### plot
|
### plot
|
||||||
fig = plt.figure(figsize=(subplot_width*len(ts), subplot_height))
|
fig = plt.figure(figsize=(subplot_width * len(ts), subplot_height))
|
||||||
|
hue_order = sorted(set(df['module']))
|
||||||
if face_color is not None:
|
if face_color is not None:
|
||||||
fig.patch.set_facecolor(face_color)
|
fig.patch.set_facecolor(face_color)
|
||||||
|
ymin, ymax = min(df[y]), max(df[y])
|
||||||
for t in ts:
|
for t in ts:
|
||||||
t = int(t)
|
t = int(t)
|
||||||
plt.subplot(1, len(ts), t)
|
plt.subplot(1, len(ts), t)
|
||||||
sns.lineplot(x=x, y=y, data=df[df.t==t], hue=hue, legend=legend if t==1 else None)
|
sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue, hue_order=hue_order, legend=legend if t == 1 else None)
|
||||||
plt.title(f't={t}')
|
plt.title(f't={t}')
|
||||||
if t != 1:
|
if t != 1:
|
||||||
plt.ylabel('')
|
plt.ylabel('')
|
||||||
if loglog:
|
if loglog:
|
||||||
plt.loglog(base=logbase)
|
plt.loglog(base=logbase)
|
||||||
|
ax = plt.gca()
|
||||||
|
ax.set_ylim([ymin, ymax])
|
||||||
if suptitle:
|
if suptitle:
|
||||||
plt.suptitle(suptitle)
|
plt.suptitle(suptitle)
|
||||||
tight_layout(plt)
|
tight_layout(plt)
|
||||||
if save_to is not None:
|
if save_to is not None:
|
||||||
plt.savefig(save_to)
|
plt.savefig(save_to)
|
||||||
print(f'coord check plot saved to {save_to}')
|
print(f'coord check plot saved to {save_to}')
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
# example of how to plot coord check results
|
# example of how to plot coord check results
|
||||||
|
|
Загрузка…
Ссылка в новой задаче