Merge pull request #38 from TevenLeScao/coord_check_plot_features

coord check plot improvements
This commit is contained in:
Greg Yang 2023-03-16 17:03:09 -05:00 коммит произвёл GitHub
Родитель 97b411dddf 442f2016c8
Коммит a33ea802bc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 23 добавлений и 14 удалений

Просмотреть файл

@ -468,8 +468,9 @@ def get_coord_data(models, dataloader, optimizer='sgd', lr=None, mup=True,
data['lr'] = lr
return data
def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module',
legend='full', name_contains=None, name_not_contains=None,
legend='full', name_contains=None, name_not_contains=None, module_list=None,
loglog=True, logbase=2, face_color=None, subplot_width=5,
subplot_height=4):
'''Plot coord check data `df` obtained from `get_coord_data`.
@ -489,10 +490,10 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
the column of `df` to represent as color. Default: `'module'`
legend:
'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`.
name_contains:
only plot modules whose name contains `name_contains`
name_not_contains:
only plot modules whose name does not contain `name_not_contains`
name_contains, name_not_contains:
only plot modules whose name contains `name_contains` and does not contain `name_not_contains`
module_list:
only plot modules that are given in the list, overrides `name_contains` and `name_not_contains`
loglog:
whether to use loglog scale. Default: True
logbase:
@ -512,14 +513,17 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
df = copy(df)
# nn.Sequential has name '', which duplicates the output layer
df = df[df.module != '']
try:
if module_list is not None:
df = df[df['module'].isin(module_list)]
else:
if name_contains is not None:
df = df[df['module'].str.contains(name_contains)]
elif name_not_contains is not None:
if name_not_contains is not None:
df = df[~(df['module'].str.contains(name_not_contains))]
# for nn.Sequential, module names are numerical
# for nn.Sequential, module names are numerical
try:
df['module'] = pd.to_numeric(df['module'])
except Exception as e:
except ValueError:
pass
ts = df.t.unique()
@ -530,19 +534,24 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
def tight_layout(plt):
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
### plot
fig = plt.figure(figsize=(subplot_width*len(ts), subplot_height))
fig = plt.figure(figsize=(subplot_width * len(ts), subplot_height))
hue_order = sorted(set(df['module']))
if face_color is not None:
fig.patch.set_facecolor(face_color)
ymin, ymax = min(df[y]), max(df[y])
for t in ts:
t = int(t)
plt.subplot(1, len(ts), t)
sns.lineplot(x=x, y=y, data=df[df.t==t], hue=hue, legend=legend if t==1 else None)
sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue, hue_order=hue_order, legend=legend if t == 1 else None)
plt.title(f't={t}')
if t != 1:
plt.ylabel('')
if loglog:
plt.loglog(base=logbase)
ax = plt.gca()
ax.set_ylim([ymin, ymax])
if suptitle:
plt.suptitle(suptitle)
tight_layout(plt)