This commit is contained in:
Xinyang Jiang 2022-01-12 01:50:05 -08:00
Родитель 4183041aa7
Коммит feafbf5098
2 изменённых файлов: 19 добавлений и 31 удалений

Просмотреть файл

@ -20,14 +20,16 @@ a) Download and unzip the [Floorplans dataset](http://mathieu.delalandre.free.fr
b) Run the following scripts to prepare the dataset for training/inference.
```sh
python utils/svg_utils/build_graph_bbox.py
cd utils
python svg_utils/build_graph_bbox.py
```
#### Diagrams
a) Download and unzip the [Diagrams dataset](http://mathieu.delalandre.free.fr/projects/sesyd/symbols/diagrams.html) to the dataset folder: `data/diagrams`
b) Run the following scripts to prepare the dataset for training/inference.
```sh
python utils/svg_utils/build_graph_bbox_diagram.py
cd utils
python svg_utils/build_graph_bbox_diagram.py
```
### 2. Training & Inference

Просмотреть файл

@ -1,6 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../'))
@ -49,19 +46,17 @@ def shape2Path(type_dict):
return paths
def getConnnectedComponent(node_dict):
#edges = node_dict['edge']['shape']
edges = np.concatenate([node_dict['edge']['shape'], node_dict['edge']['control']], axis = 0)
edges = node_dict['edge']['shape']
pos = node_dict['pos']['spatial']
is_control = node_dict['attr']['is_control']
#print(edges)
adj = np.eye(pos.shape[0], pos.shape[0]).astype(bool)
adj = np.eye(pos.shape[0], pos.shape[0]).astype(np.bool)
for e in edges:
adj[e[0], e[1]] = True
adj[e[1], e[0]] = True
n_node = pos.shape[0]
#visited = [False if not is_control[i] else True for i in range(n_node) ]
visited = [False for i in range(n_node) ]
visited = [False if not is_control[i] else True for i in range(n_node) ]
clusters = []
for start_node in range(0, n_node):
@ -110,7 +105,6 @@ def draw_cluster_graph(svg_path, save_path, width, height, bboxs, pos, is_contro
facecolor="none",
)
ax.add_patch(bbox)
os.makedirs(save_path, exist_ok=True)
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.0, dpi=600)
def mergeCluster(cc, bboxs, ratio=None, expand_length=None):
@ -201,19 +195,15 @@ def mergeCC(node_dict, svg_path, width, height):
#cc, bboxs = mergeCluster(cc, bboxs, ratio=0.5, expand_length=None)
# hardcode 70 for expand_length
cc, bboxs = mergeCluster(cc, bboxs, ratio=None, expand_length=(70 / width, 70 / height))
#cc, bboxs = mergeCluster(cc, bboxs, ratio=None, expand_length=(40 / width, 40 / height))
#cc, bboxs = mergeCluster(cc, bboxs, ratio=1)
#cc, bboxs = mergeCluster(cc, bboxs, ratio=None, expand_length=(70 / width, 70 / height))
cc, bboxs = mergeCluster(cc, bboxs, ratio=None, expand_length=(40 / width, 40 / height))
paths = []
shape_shape_edges = []
for i, cluster in enumerate(cc):
for idx in cluster:
if is_control[idx]: continue
for idx_j in cluster:
if idx == idx_j: continue
if is_control[idx_j]: continue
shape_shape_edges.append((idx, idx_j))
if True:
@ -233,11 +223,12 @@ def mergeCC(node_dict, svg_path, width, height):
))
svg_path_list = svg_path.split('/')
draw_cluster_graph(svg_path, "data/diagram_expand_len_graph/{}_{}".format(svg_path_list[-2], svg_path_list[-1].replace('.svg', '.pdf')), width, height, bboxs, pos, is_control, edges)
print("draw bbox and node of {}".format(svg_path))
#draw_cluster_graph(svg_path, "/home/v-luliu1/datasets/diagram_expand_len_graph/{}_{}".format(svg_path_list[-2], svg_path_list[-1].replace('.svg', '.pdf')), width, height, bboxs, pos, is_control, edges)
#print("draw bbox and node of {}".format(svg_path))
cross_shape_edges = []
same_cc = np.zeros((len(bboxs), len(bboxs))).astype(bool)
same_cc = np.zeros((len(bboxs), len(bboxs))).astype(np.bool)
for i, parent_bb in enumerate(bboxs):
for j, child_bb in enumerate(bboxs):
if i == j: continue
@ -263,9 +254,7 @@ def mergeCC(node_dict, svg_path, width, height):
if is_parent_child:
for parent_idx in cc[i]:
if is_control[parent_idx]: continue
for child_idx in cc[j]:
if is_control[child_idx]: continue
cross_shape_edges.append((parent_idx, child_idx))
same_cc[i, j] = True
same_cc[j, i] = True
@ -282,7 +271,7 @@ def mergeCC(node_dict, svg_path, width, height):
visited[i] = True
get_all_neighboors(i, ret)
visited = np.zeros(same_cc.shape[0]).astype(bool)
visited = np.zeros(same_cc.shape[0]).astype(np.bool)
merged_cc = []
for i, all_neighbors in enumerate(same_cc):
if visited[i]: continue
@ -318,9 +307,6 @@ def mergeCC(node_dict, svg_path, width, height):
shape_shape_edge_attr = get_attr(shape_shape_edges)
cross_shape_edge_attr = get_attr(cross_shape_edges)
print(new_cc)
return np.array(shape_shape_edges), np.array(cross_shape_edges), np.array(shape_shape_edge_attr), np.array(cross_shape_edge_attr), paths, new_cc
@ -328,9 +314,10 @@ def mergeCC(node_dict, svg_path, width, height):
if __name__ == '__main__':
graph_builder = SVGGraphBuilderBezier()
input_dir = 'data/diagrams/'
output_dir = 'data/diagrams/'
#input_dir = '/home/v-luliu1/datasets/floorplans_test'
#output_dir = '/home/v-luliu1/datasets/floorplans_test'
input_dir = '/data/xinyangjiang/Datasets/SESYD/diagram2'
output_dir = '/data/xinyangjiang/Datasets/SESYD/diagram2'
dir_list = os.listdir(input_dir)
angles = []
@ -346,7 +333,6 @@ if __name__ == '__main__':
p = SVGParser(filepath)
# split_cross splits the segments into multiple small segments if there is a cross-point
type_dict = split_cross(p.get_all_shape())
width, height = p.get_image_size()
paths = shape2Path(type_dict)
# {'pos': {'spatial': the positions of nodes}, {'attr': {'color': color of every node, 'stroke_width': stroke width, 'is_control': if the node is control node}}, {'edge': {'control': control edge, 'spatial': spatial edge}}, 'edge_attr': N * 6}
@ -393,7 +379,7 @@ if __name__ == '__main__':
node_dict['edge']['super'] = np.concatenate([shape_shape_edges, cross_shape_edges], axis = 0)
#node_dict['attr']['is_control'] = np.concatenate([node_dict['attr']['is_control'], np.zeros((super_pos.shape[0], 1)).astype(np.bool)], axis = 0)
#node_dict['attr']['is_super'] = np.concatenate([np.zeros((start_end_size, 1)).astype(np.bool), np.ones((super_pos.shape[0], 1)).astype(np.bool)], axis = 0)
node_dict['attr']['is_super'] = np.zeros((start_end_size, 1)).astype(bool)
node_dict['attr']['is_super'] = np.zeros((start_end_size, 1)).astype(np.bool)
if len(cross_shape_edge_attr) == 0:
node_dict['edge_attr']['super'] = shape_shape_edge_attr
else: