зеркало из https://github.com/microsoft/mup.git
Merge pull request #38 from TevenLeScao/coord_check_plot_features
coord check plot improvements
This commit is contained in:
Коммит
a33ea802bc
|
@ -468,8 +468,9 @@ def get_coord_data(models, dataloader, optimizer='sgd', lr=None, mup=True,
|
|||
data['lr'] = lr
|
||||
return data
|
||||
|
||||
|
||||
def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module',
|
||||
legend='full', name_contains=None, name_not_contains=None,
|
||||
legend='full', name_contains=None, name_not_contains=None, module_list=None,
|
||||
loglog=True, logbase=2, face_color=None, subplot_width=5,
|
||||
subplot_height=4):
|
||||
'''Plot coord check data `df` obtained from `get_coord_data`.
|
||||
|
@ -489,10 +490,10 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
|
|||
the column of `df` to represent as color. Default: `'module'`
|
||||
legend:
|
||||
'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`.
|
||||
name_contains:
|
||||
only plot modules whose name contains `name_contains`
|
||||
name_not_contains:
|
||||
only plot modules whose name does not contain `name_not_contains`
|
||||
name_contains, name_not_contains:
|
||||
only plot modules whose name contains `name_contains` and does not contain `name_not_contains`
|
||||
module_list:
|
||||
only plot modules that are given in the list, overrides `name_contains` and `name_not_contains`
|
||||
loglog:
|
||||
whether to use loglog scale. Default: True
|
||||
logbase:
|
||||
|
@ -512,14 +513,17 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
|
|||
df = copy(df)
|
||||
# nn.Sequential has name '', which duplicates the output layer
|
||||
df = df[df.module != '']
|
||||
try:
|
||||
if module_list is not None:
|
||||
df = df[df['module'].isin(module_list)]
|
||||
else:
|
||||
if name_contains is not None:
|
||||
df = df[df['module'].str.contains(name_contains)]
|
||||
elif name_not_contains is not None:
|
||||
if name_not_contains is not None:
|
||||
df = df[~(df['module'].str.contains(name_not_contains))]
|
||||
# for nn.Sequential, module names are numerical
|
||||
# for nn.Sequential, module names are numerical
|
||||
try:
|
||||
df['module'] = pd.to_numeric(df['module'])
|
||||
except Exception as e:
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
ts = df.t.unique()
|
||||
|
@ -530,19 +534,24 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
|
|||
|
||||
def tight_layout(plt):
|
||||
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
|
||||
|
||||
### plot
|
||||
fig = plt.figure(figsize=(subplot_width*len(ts), subplot_height))
|
||||
fig = plt.figure(figsize=(subplot_width * len(ts), subplot_height))
|
||||
hue_order = sorted(set(df['module']))
|
||||
if face_color is not None:
|
||||
fig.patch.set_facecolor(face_color)
|
||||
ymin, ymax = min(df[y]), max(df[y])
|
||||
for t in ts:
|
||||
t = int(t)
|
||||
plt.subplot(1, len(ts), t)
|
||||
sns.lineplot(x=x, y=y, data=df[df.t==t], hue=hue, legend=legend if t==1 else None)
|
||||
sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue, hue_order=hue_order, legend=legend if t == 1 else None)
|
||||
plt.title(f't={t}')
|
||||
if t != 1:
|
||||
plt.ylabel('')
|
||||
if loglog:
|
||||
plt.loglog(base=logbase)
|
||||
ax = plt.gca()
|
||||
ax.set_ylim([ymin, ymax])
|
||||
if suptitle:
|
||||
plt.suptitle(suptitle)
|
||||
tight_layout(plt)
|
||||
|
|
Загрузка…
Ссылка в новой задаче