зеркало из https://github.com/microsoft/archai.git
Started changing over various complicated plots to plotly!
This commit is contained in:
Родитель
9d9c612f45
Коммит
cf4c9f91fc
|
@ -12,7 +12,7 @@ nas:
|
|||
model_desc:
|
||||
num_edges_to_sample: 2
|
||||
loader:
|
||||
train_batch: 256 # natsbench uses 256
|
||||
train_batch: 2048 # natsbench uses 256
|
||||
aug: '' # random flip and crop are already there in default params
|
||||
trainer: # matching natsbench paper closely
|
||||
plotsdir: ''
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
import plotly.express as px
|
||||
|
||||
from plotly.subplots import make_subplots
|
||||
import plotly.graph_objects as go
|
||||
from plotly.validators.scatter.marker import SymbolValidator
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
|
||||
symbols = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
|
||||
|
||||
# subplot
|
||||
fig = make_subplots(rows=1, cols=2, subplot_titles=("Plot 1", "Plot 2"))
|
||||
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(x=[1], y=[4], mode='markers', name='reg', marker_symbol=0, marker_color='red'),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(x=[1], y=[9], mode='markers', name='ft', marker_symbol=1, marker_color='blue'),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(x=[20], y=[50], mode='markers', name='reg', marker_symbol=0, marker_color='red', showlegend=False),
|
||||
row=1, col=2
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(x=[20], y=[100], mode='markers', name='ft', marker_symbol=1, marker_color='blue', showlegend=False),
|
||||
row=1, col=2
|
||||
)
|
||||
|
||||
|
||||
fig.update_layout(height=600, width=800, title_text="Side By Side Subplots")
|
||||
fig.show()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -2,11 +2,15 @@ import os
|
|||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import random
|
||||
from typing import List, Dict
|
||||
from itertools import cycle
|
||||
from cycler import cycler
|
||||
from collections import OrderedDict
|
||||
import math as ma
|
||||
|
||||
import plotly.express as px
|
||||
from plotly.subplots import make_subplots
|
||||
import plotly.graph_objects as go
|
||||
|
||||
|
||||
def parse_raw_data(root_exp_folder:str, exp_list:List[str])->Dict:
|
||||
|
@ -137,7 +141,7 @@ def main():
|
|||
savename = os.path.join(exp_folder, f'aggregate_spe.png')
|
||||
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
|
||||
|
||||
# plot spe vs. time per top percent of architectures
|
||||
# plot spe and common ratio vs. time per top percent of architectures
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# assuming that all experiments are reporting on the same
|
||||
|
@ -170,80 +174,49 @@ def main():
|
|||
|
||||
tp_info[tp] = this_tp_info
|
||||
|
||||
# now plot each top percent time vs. spe
|
||||
fig, axs = plt.subplots(5, 10)
|
||||
handles = None
|
||||
labels = None
|
||||
for tp_key, ax in zip(tp_info.keys(), axs.flat):
|
||||
# now plot each top percent time vs. spe and common ratio
|
||||
num_plots = len(tp_info)
|
||||
num_plots_per_row = 5
|
||||
num_plots_per_col = ma.ceil(num_plots / num_plots_per_row)
|
||||
subplot_titles = [f'Top {x} %' for x in tp_info.keys()]
|
||||
fig = make_subplots(rows=num_plots_per_row, cols=num_plots_per_col, subplot_titles=subplot_titles, shared_yaxes=True)
|
||||
fig_cr = make_subplots(rows=num_plots_per_row, cols=num_plots_per_col, subplot_titles=subplot_titles, shared_yaxes=True)
|
||||
|
||||
for ind, tp_key in enumerate(tp_info.keys()):
|
||||
counter = 0
|
||||
counter_reg = 0
|
||||
for exp in tp_info[tp_key].keys():
|
||||
duration = tp_info[tp_key][exp][0]
|
||||
spe = tp_info[tp_key][exp][1]
|
||||
cr = tp_info[tp_key][exp][2]
|
||||
|
||||
if 'ft_fb' in exp:
|
||||
marker = markers[counter]
|
||||
if 'ft_fb' in exp or 'ft_c100' in exp:
|
||||
marker = counter
|
||||
marker_color = 'red'
|
||||
counter += 1
|
||||
elif 'ft_c100' in exp:
|
||||
marker = markers[counter]
|
||||
counter += 1
|
||||
elif 'nb_reg' in exp:
|
||||
marker = mathy_markers[counter_reg]
|
||||
counter_reg += 1
|
||||
elif 'nb_c100_reg' in exp:
|
||||
marker = mathy_markers[counter_reg]
|
||||
counter_reg += 1
|
||||
elif 'nb_reg' in exp or 'nb_c100_reg' in exp:
|
||||
marker = counter_reg
|
||||
marker_color = 'blue'
|
||||
counter_reg += 1
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
ax.scatter(duration, spe, label=exp, marker=marker)
|
||||
ax.set_title(str(tp_key))
|
||||
#ax.set(xlabel='Duration (s)', ylabel='SPE')
|
||||
ax.set_ylim([0, 1])
|
||||
|
||||
|
||||
handles, labels = axs.flat[-1].get_legend_handles_labels()
|
||||
fig.legend(handles, labels, loc='center right')
|
||||
row_num = ma.floor(ind/num_plots_per_col) + 1
|
||||
col_num = ind % num_plots_per_col + 1
|
||||
showlegend = True if ind == 0 else False
|
||||
fig.add_trace(go.Scatter(x=[duration], y=[spe], mode='markers', name=exp,
|
||||
marker_symbol=marker, marker_color=marker_color, showlegend=showlegend),
|
||||
row=row_num, col=col_num)
|
||||
#fig.update_xaxes(title_text="Duration (s)", row=row_num, col=col_num)
|
||||
#fig.update_yaxes(title_text="SPE", row=row_num, col=col_num)
|
||||
fig_cr.add_trace(go.Scatter(x=[duration], y=[cr], mode='markers', name=exp,
|
||||
marker_symbol=marker, marker_color=marker_color, showlegend=showlegend),
|
||||
row=row_num, col=col_num)
|
||||
|
||||
|
||||
# now plot each top percent time vs. common ratio
|
||||
fig, axs = plt.subplots(5, 10)
|
||||
handles = None
|
||||
labels = None
|
||||
for tp_key, ax in zip(tp_info.keys(), axs.flat):
|
||||
counter = 0
|
||||
counter_reg = 0
|
||||
for exp in tp_info[tp_key].keys():
|
||||
duration = tp_info[tp_key][exp][0]
|
||||
common_ratio = tp_info[tp_key][exp][2]
|
||||
|
||||
if 'ft_fb' in exp:
|
||||
marker = markers[counter]
|
||||
counter += 1
|
||||
elif 'ft_c100' in exp:
|
||||
marker = markers[counter]
|
||||
counter += 1
|
||||
elif 'nb_reg' in exp:
|
||||
marker = mathy_markers[counter_reg]
|
||||
counter_reg += 1
|
||||
elif 'nb_c100_reg' in exp:
|
||||
marker = mathy_markers[counter_reg]
|
||||
counter_reg += 1
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
ax.scatter(duration, common_ratio, label=exp, marker=marker)
|
||||
ax.set_title(str(tp_key))
|
||||
#ax.set(xlabel='Duration (s)', ylabel='Common Ratio')
|
||||
ax.set_ylim([0, 1])
|
||||
|
||||
|
||||
handles, labels = axs.flat[-1].get_legend_handles_labels()
|
||||
fig.legend(handles, labels, loc='center right')
|
||||
|
||||
|
||||
|
||||
plt.show()
|
||||
fig.update_layout(title_text="Duration vs. Spearman Rank Correlation vs. Top %")
|
||||
fig.show()
|
||||
fig_cr.update_layout(title_text="Duration vs. Common Ratio vs. Top %")
|
||||
fig_cr.show()
|
||||
|
||||
|
||||
# plot timing information vs. top percent of architectures
|
||||
|
|
Загрузка…
Ссылка в новой задаче