Merge pull request #38 from TevenLeScao/coord_check_plot_features

coord check plot improvements
This commit is contained in:
Greg Yang 2023-03-16 17:03:09 -05:00 коммит произвёл GitHub
Родитель 97b411dddf 442f2016c8
Коммит a33ea802bc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённый файл: 23 добавления и 14 удалений

Просмотреть файл

@@ -468,8 +468,9 @@ def get_coord_data(models, dataloader, optimizer='sgd', lr=None, mup=True,
data['lr'] = lr data['lr'] = lr
return data return data
def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module', def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module',
legend='full', name_contains=None, name_not_contains=None, legend='full', name_contains=None, name_not_contains=None, module_list=None,
loglog=True, logbase=2, face_color=None, subplot_width=5, loglog=True, logbase=2, face_color=None, subplot_width=5,
subplot_height=4): subplot_height=4):
'''Plot coord check data `df` obtained from `get_coord_data`. '''Plot coord check data `df` obtained from `get_coord_data`.
@@ -489,10 +490,10 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
the column of `df` to represent as color. Default: `'module'` the column of `df` to represent as color. Default: `'module'`
legend: legend:
'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`. 'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`.
name_contains: name_contains, name_not_contains:
only plot modules whose name contains `name_contains` only plot modules whose name contains `name_contains` and does not contain `name_not_contains`
name_not_contains: module_list:
only plot modules whose name does not contain `name_not_contains` only plot modules that are given in the list, overrides `name_contains` and `name_not_contains`
loglog: loglog:
whether to use loglog scale. Default: True whether to use loglog scale. Default: True
logbase: logbase:
@@ -501,10 +502,10 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
background color of the plot. Default: None (which means white) background color of the plot. Default: None (which means white)
subplot_width, subplot_height: subplot_width, subplot_height:
The width and height for each timestep's subplot. More precisely, The width and height for each timestep's subplot. More precisely,
the figure size will be the figure size will be
`(subplot_width*number_of_time_steps, subplot_height)`. `(subplot_width*number_of_time_steps, subplot_height)`.
Default: 5, 4 Default: 5, 4
Output: Output:
the `matplotlib` figure object the `matplotlib` figure object
''' '''
@@ -512,14 +513,17 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
df = copy(df) df = copy(df)
# nn.Sequential has name '', which duplicates the output layer # nn.Sequential has name '', which duplicates the output layer
df = df[df.module != ''] df = df[df.module != '']
try: if module_list is not None:
df = df[df['module'].isin(module_list)]
else:
if name_contains is not None: if name_contains is not None:
df = df[df['module'].str.contains(name_contains)] df = df[df['module'].str.contains(name_contains)]
elif name_not_contains is not None: if name_not_contains is not None:
df = df[~(df['module'].str.contains(name_not_contains))] df = df[~(df['module'].str.contains(name_not_contains))]
# for nn.Sequential, module names are numerical # for nn.Sequential, module names are numerical
try:
df['module'] = pd.to_numeric(df['module']) df['module'] = pd.to_numeric(df['module'])
except Exception as e: except ValueError:
pass pass
ts = df.t.unique() ts = df.t.unique()
@@ -530,26 +534,31 @@ def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='mod
def tight_layout(plt): def tight_layout(plt):
plt.tight_layout(rect=[0, 0.03, 1, 0.95]) plt.tight_layout(rect=[0, 0.03, 1, 0.95])
### plot ### plot
fig = plt.figure(figsize=(subplot_width*len(ts), subplot_height)) fig = plt.figure(figsize=(subplot_width * len(ts), subplot_height))
hue_order = sorted(set(df['module']))
if face_color is not None: if face_color is not None:
fig.patch.set_facecolor(face_color) fig.patch.set_facecolor(face_color)
ymin, ymax = min(df[y]), max(df[y])
for t in ts: for t in ts:
t = int(t) t = int(t)
plt.subplot(1, len(ts), t) plt.subplot(1, len(ts), t)
sns.lineplot(x=x, y=y, data=df[df.t==t], hue=hue, legend=legend if t==1 else None) sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue, hue_order=hue_order, legend=legend if t == 1 else None)
plt.title(f't={t}') plt.title(f't={t}')
if t != 1: if t != 1:
plt.ylabel('') plt.ylabel('')
if loglog: if loglog:
plt.loglog(base=logbase) plt.loglog(base=logbase)
ax = plt.gca()
ax.set_ylim([ymin, ymax])
if suptitle: if suptitle:
plt.suptitle(suptitle) plt.suptitle(suptitle)
tight_layout(plt) tight_layout(plt)
if save_to is not None: if save_to is not None:
plt.savefig(save_to) plt.savefig(save_to)
print(f'coord check plot saved to {save_to}') print(f'coord check plot saved to {save_to}')
return fig return fig
# example of how to plot coord check results # example of how to plot coord check results