# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse
import html
import json

from graphviz import *

# Fill colour of each node, keyed by layer type; unknown types fall back to grey.
_LAYER_COLORS = {
    "Input": "royalblue",
    "Embedding": "orange",
    "Linear": "tan",
    "LinearAttention": "tan",
    "BiGRU": "salmon",
    "BiLSTM": "salmon",
    "BiLSTMAtt": "salmon1",
    "BiGRULast": "salmon",
    "Conv": "sandybrown",
    "ConvPooling": "sandybrown",
    "Pooling": "skyblue",
    "Dropout": "yellowgreen",
    "Combination": "purple",
    "EncoderDecoder": "lightsalmon",
    "FullAttention": "lightsalmon",
    "Seq2SeqAttention": "lightsalmon",
}

# Parameters shown in the node label for each known layer type (in this order).
# Layer types not listed here show every key of their "conf" dict.
_LAYER_CONF_KEYS = {
    "Linear": ["hidden_dim", "activation", "last_hidden_activation", "last_hidden_softmax",
               "batch_normalization"],
    "LinearAttention": ["keep_dim"],
    "BiGRU": ["hidden_dim", "dropout"],
    "BiGRULast": ["hidden_dim", "dropout"],
    "BiLSTM": ["hidden_dim", "dropout", "num_layers"],
    "BiLSTMAtt": ["hidden_dim", "dropout", "num_layers"],
    "Conv": ["stride", "padding", "window_sizes", "input_channel_num", "output_channel_num",
             "activation", "batch_normalization"],
    "ConvPooling": ["stride", "padding", "window_sizes", "input_channel_num", "output_channel_num",
                    "batch_normalization", "activation", "pool_type", "pool_axis"],
    "Pooling": ["pool_axis", "pool_type"],
    "Dropout": ["dropout"],
    "Combination": ["operations"],
    "EncoderDecoder": ["encoder", "decoder"],
    "FullAttention": ["hidden_dim", "activation"],
    "Seq2SeqAttention": ["attention_dropout"],
}


def _load_config(json_path):
    """Read and parse the JSON config at *json_path*.

    Raises:
        ValueError: if the file does not contain valid JSON. Nothing downstream
            can work without a parsed config, so fail here with the file name
            instead of continuing with an undefined config.
    """
    with open(json_path, 'r', encoding='utf-8') as file:
        json_str = file.read()
    try:
        return json.loads(json_str)
    except ValueError as e:
        raise ValueError("Failed to parse JSON config %s: %s" % (json_path, e)) from e


def _html_label(rows):
    """Build a Graphviz HTML-like label: a borderless table with one row per tuple.

    The first row is rendered in bold as the node header. A two-element tuple
    becomes two cells; a one-element tuple becomes one cell spanning both
    columns. Every cell is converted with ``str`` and HTML-escaped so text
    containing ``&``, ``<`` or ``>`` cannot break the label markup.
    """
    parts = ['<<table border="0" cellborder="0" cellspacing="0" cellpadding="2">']
    for index, cells in enumerate(rows):
        span = ' colspan="2"' if len(cells) == 1 else ''
        tds = []
        for cell in cells:
            text = html.escape(str(cell), quote=False)
            if index == 0:
                text = "<b>" + text + "</b>"
            tds.append('<td align="left"%s>%s</td>' % (span, text))
        parts.append("<tr>" + "".join(tds) + "</tr>")
    # The outer '<' ... '>' marks the whole string as an HTML-like label.
    parts.append("</table>>")
    return "".join(parts)


def _add_embedding_nodes(model, architecture):
    """Add one node per embedding column of the first "Embedding" layer found."""
    for item in architecture:
        if item.get('layer') != "Embedding":
            continue
        # conf maps an embedding name to a dict holding "dim" and "cols".
        for emb_conf in item['conf'].values():
            dim = emb_conf['dim']
            for col in emb_conf['cols']:
                label = _html_label([(col, "Embedding"), ("dim:" + str(dim),)])
                model.node(name=col, label=label, fillcolor=_LAYER_COLORS["Embedding"])
        break


def _add_input_nodes(model, model_inputs):
    """Add a node per model input and an edge into it from each column it consumes."""
    for inp, cols in model_inputs.items():
        model.node(name=inp, label=inp, fillcolor=_LAYER_COLORS['Input'])
        for col in cols:
            model.edge(col, inp)


def _add_layer_nodes(model, architecture):
    """Add a node for every entry with a "layer_id", plus edges from its "inputs"."""
    # Fully specified layers; a later entry may reuse one by putting its
    # layer_id in its own "layer" field, inheriting that layer's type and conf.
    defined = {
        item['layer_id']: (item['layer'], item['conf'])
        for item in architecture
        if 'layer_id' in item and 'layer' in item and 'conf' in item
    }
    for item in architecture:
        if 'layer_id' not in item:
            continue
        layer_id = item['layer_id']
        if item['layer'] in defined:
            layer_type, layer_params = defined[item['layer']]
        else:
            layer_type, layer_params = item['layer'], item.get('conf', {})

        keys = _LAYER_CONF_KEYS.get(layer_type)
        if keys is None:
            keys = list(layer_params)  # unknown type: show every parameter
        rows = [(layer_id, layer_type)]
        rows.extend((key, layer_params[key]) for key in keys if key in layer_params)

        model.node(name=layer_id, label=_html_label(rows),
                   fillcolor=_LAYER_COLORS.get(layer_type, "grey"))
        for inp in item.get('inputs', []):
            model.edge(inp, layer_id)


def json2graph(json_path, graph_path):
    """Render the model architecture described by a JSON config as an SVG graph.

    Args:
        json_path: path to the JSON config. It must contain:
            - "architecture": a list of layer entries. Each entry has a "layer"
              type and optionally "layer_id", "conf" and "inputs".
            - "inputs"."model_inputs": a mapping from input name to a list of
              column names.
        graph_path: output path without extension. ".gv" is appended. graphviz
            ``render`` writes the DOT source under that name and the rendered
            SVG next to it with an extra ".svg" suffix.

    Raises:
        ValueError: if the config file is not valid JSON.
    """
    conf_dic = _load_config(json_path)
    graph_path += '.gv'

    model = Digraph(format='svg',
                    node_attr={"style": "rounded, filled", "shape": "box", "fontcolor": "white"})
    # Bottom-to-top layout: embeddings and inputs at the bottom, later layers above.
    model.attr(rankdir="BT")

    _add_embedding_nodes(model, conf_dic['architecture'])
    _add_input_nodes(model, conf_dic['inputs']['model_inputs'])
    _add_layer_nodes(model, conf_dic['architecture'])

    model.render(graph_path, view=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='get model graph')
    parser.add_argument("--conf_path", type=str, required=True, help="JSON config path")
    parser.add_argument("--graph_path", type=str, default="graph", help="Model graph path")
    args = parser.parse_args()
    json2graph(args.conf_path, args.graph_path)
    print("The model graph has been successfully generated!")