# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import json
from graphviz import *
import argparse
def json2graph(json_path, graph_path):
with open(json_path, 'r') as file:
json_str = file.read()
try:
conf_dic = json.loads(json_str)
except ValueError as e:
print(e)
graph_path += '.gv'
color = {
"Input": "royalblue",
"Embedding": "orange",
"Linear": "tan",
"LinearAttention": "tan",
"BiGRU": "salmon",
"BiLSTM": "salmon",
"BiLSTMAtt": "salmon1",
"BiGRULast": "salmon",
"Conv": "sandybrown",
"ConvPooling": "sandybrown",
"Pooling": "skyblue",
"Dropout": "yellowgreen",
"Combination": "purple",
"EncoderDecoder": "lightsalmon",
"FullAttention": "lightsalmon",
"Seq2SeqAttention": "lightsalmon"
}
layer_conf = {
"Linear": ["hidden_dim", "activation", "last_hidden_activation", "last_hidden_softmax", "batch_normalization"],
"LinearAttention": ["keep_dim"],
"BiGRU": ["hidden_dim", "dropout"],
"BiGRULast": ["hidden_dim", "dropout"],
"BiLSTM": ["hidden_dim", "dropout", "num_layers"],
"BiLSTMAtt": ["hidden_dim", "dropout", "num_layers"],
"Conv": ["stride", "padding", "window_sizes", "input_channel_num", "output_channel_num", "activation",
"batch_normalization"],
"ConvPooling": ["stride", "padding", "window_sizes", "input_channel_num", "output_channel_num",
"batch_normalization",
"activation", "pool_type", "pool_axis"],
"Pooling": ["pool_axis", "pool_type"],
"Dropout": ["dropout"],
"Combination": ["operations"],
"EncoderDecoder": ["encoder", "decoder"],
"FullAttention": ["hidden_dim", "activation"],
"Seq2SeqAttention": ["attention_dropout"]
}
model = Digraph(format='svg',
node_attr={"style": "rounded, filled",
"shape": "box",
"fontcolor": "white"})
model.attr(rankdir="BT")
for item in conf_dic['architecture']:
if item['layer'] == "Embedding":
for c in item['conf']:
dim = item['conf'][c]['dim']
for n in item['conf'][c]['cols']:
label_str = "<" \
+ "
" \
+ "" + n + " | " + "Embedding |
" \
+ "dim: | " + "" + str(dim) + " |
" \
+ "
>"
model.node(name=n, label=label_str, fillcolor=color["Embedding"])
break
for inp in conf_dic['inputs']['model_inputs']:
model.node(name=inp,
label=inp,
fillcolor=color['Input'])
for n in conf_dic['inputs']['model_inputs'][inp]:
model.edge(n, inp)
layer_dic = {}
for item in conf_dic['architecture']:
if 'layer_id' in item.keys() and 'layer' in item.keys() and 'conf' in item.keys():
layer_dic[item['layer_id']] = [item['layer'], item['conf']]
for item in conf_dic['architecture']:
if 'layer_id' in item.keys():
if item['layer'] in layer_dic:
tmp_layer = item['layer']
item['conf'] = layer_dic[tmp_layer][1]
item['layer'] = layer_dic[tmp_layer][0]
label_str = "<" \
+ "" \
+ "" + item['layer_id'] + " | " + "" + item[
'layer'] + " |
"
if item['layer'] in layer_conf:
for c in layer_conf[item['layer']]:
if c in item['conf']:
label_str = label_str + "" + c + " | " + "" + str(
item['conf'][c]) + " |
"
else:
for c in item['conf']:
label_str = label_str + "" + c + " | " + "" + str(
item['conf'][c]) + " |
"
label_str += "
>"
model.node(name=item['layer_id'],
label=label_str,
fillcolor=color.get(item['layer'], "grey"))
for inp in item['inputs']:
model.edge(inp, item['layer_id'])
# model
model.render(graph_path, view=False)
return
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='get model graph')
parser.add_argument("--conf_path", type=str, help="JSON config path")
parser.add_argument("--graph_path", type=str, default="graph", help="Model graph path")
args = parser.parse_args()
json2graph(args.conf_path, args.graph_path)
print("The model graph has been successfully generated!")