diagram graph build bug fixed
This commit is contained in:
Родитель
4183041aa7
Коммит
feafbf5098
|
@ -20,14 +20,16 @@ a) Download and unzip the [Floorplans dataset](http://mathieu.delalandre.free.fr
|
|||
b) Run the following scripts to prepare the dataset for training/inference.
|
||||
|
||||
```sh
|
||||
python utils/svg_utils/build_graph_bbox.py
|
||||
cd utils
|
||||
python svg_utils/build_graph_bbox.py
|
||||
```
|
||||
#### Diagrams
|
||||
a) Download and unzip the [Diagrams dataset](http://mathieu.delalandre.free.fr/projects/sesyd/symbols/diagrams.html) to the dataset folder: `data/diagrams`
|
||||
|
||||
b) Run the following scripts to prepare the dataset for training/inference.
|
||||
```sh
|
||||
python utils/svg_utils/build_graph_bbox_diagram.py
|
||||
cd utils
|
||||
python svg_utils/build_graph_bbox_diagram.py
|
||||
```
|
||||
|
||||
### 2. Training & Inference
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os, sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../'))
|
||||
|
||||
|
@ -49,19 +46,17 @@ def shape2Path(type_dict):
|
|||
return paths
|
||||
|
||||
def getConnnectedComponent(node_dict):
|
||||
#edges = node_dict['edge']['shape']
|
||||
edges = np.concatenate([node_dict['edge']['shape'], node_dict['edge']['control']], axis = 0)
|
||||
edges = node_dict['edge']['shape']
|
||||
pos = node_dict['pos']['spatial']
|
||||
is_control = node_dict['attr']['is_control']
|
||||
#print(edges)
|
||||
adj = np.eye(pos.shape[0], pos.shape[0]).astype(bool)
|
||||
adj = np.eye(pos.shape[0], pos.shape[0]).astype(np.bool)
|
||||
for e in edges:
|
||||
adj[e[0], e[1]] = True
|
||||
adj[e[1], e[0]] = True
|
||||
|
||||
n_node = pos.shape[0]
|
||||
#visited = [False if not is_control[i] else True for i in range(n_node) ]
|
||||
visited = [False for i in range(n_node) ]
|
||||
visited = [False if not is_control[i] else True for i in range(n_node) ]
|
||||
clusters = []
|
||||
|
||||
for start_node in range(0, n_node):
|
||||
|
@ -110,7 +105,6 @@ def draw_cluster_graph(svg_path, save_path, width, height, bboxs, pos, is_contro
|
|||
facecolor="none",
|
||||
)
|
||||
ax.add_patch(bbox)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.0, dpi=600)
|
||||
|
||||
def mergeCluster(cc, bboxs, ratio=None, expand_length=None):
|
||||
|
@ -201,19 +195,15 @@ def mergeCC(node_dict, svg_path, width, height):
|
|||
|
||||
#cc, bboxs = mergeCluster(cc, bboxs, ratio=0.5, expand_length=None)
|
||||
# hardcode 70 for expand_length
|
||||
cc, bboxs = mergeCluster(cc, bboxs, ratio=None, expand_length=(70 / width, 70 / height))
|
||||
#cc, bboxs = mergeCluster(cc, bboxs, ratio=None, expand_length=(40 / width, 40 / height))
|
||||
#cc, bboxs = mergeCluster(cc, bboxs, ratio=1)
|
||||
|
||||
#cc, bboxs = mergeCluster(cc, bboxs, ratio=None, expand_length=(70 / width, 70 / height))
|
||||
cc, bboxs = mergeCluster(cc, bboxs, ratio=None, expand_length=(40 / width, 40 / height))
|
||||
|
||||
paths = []
|
||||
shape_shape_edges = []
|
||||
for i, cluster in enumerate(cc):
|
||||
for idx in cluster:
|
||||
if is_control[idx]: continue
|
||||
for idx_j in cluster:
|
||||
if idx == idx_j: continue
|
||||
if is_control[idx_j]: continue
|
||||
shape_shape_edges.append((idx, idx_j))
|
||||
|
||||
if True:
|
||||
|
@ -233,11 +223,12 @@ def mergeCC(node_dict, svg_path, width, height):
|
|||
))
|
||||
|
||||
svg_path_list = svg_path.split('/')
|
||||
draw_cluster_graph(svg_path, "data/diagram_expand_len_graph/{}_{}".format(svg_path_list[-2], svg_path_list[-1].replace('.svg', '.pdf')), width, height, bboxs, pos, is_control, edges)
|
||||
print("draw bbox and node of {}".format(svg_path))
|
||||
|
||||
#draw_cluster_graph(svg_path, "/home/v-luliu1/datasets/diagram_expand_len_graph/{}_{}".format(svg_path_list[-2], svg_path_list[-1].replace('.svg', '.pdf')), width, height, bboxs, pos, is_control, edges)
|
||||
#print("draw bbox and node of {}".format(svg_path))
|
||||
|
||||
cross_shape_edges = []
|
||||
same_cc = np.zeros((len(bboxs), len(bboxs))).astype(bool)
|
||||
same_cc = np.zeros((len(bboxs), len(bboxs))).astype(np.bool)
|
||||
for i, parent_bb in enumerate(bboxs):
|
||||
for j, child_bb in enumerate(bboxs):
|
||||
if i == j: continue
|
||||
|
@ -263,9 +254,7 @@ def mergeCC(node_dict, svg_path, width, height):
|
|||
|
||||
if is_parent_child:
|
||||
for parent_idx in cc[i]:
|
||||
if is_control[parent_idx]: continue
|
||||
for child_idx in cc[j]:
|
||||
if is_control[child_idx]: continue
|
||||
cross_shape_edges.append((parent_idx, child_idx))
|
||||
same_cc[i, j] = True
|
||||
same_cc[j, i] = True
|
||||
|
@ -282,7 +271,7 @@ def mergeCC(node_dict, svg_path, width, height):
|
|||
visited[i] = True
|
||||
get_all_neighboors(i, ret)
|
||||
|
||||
visited = np.zeros(same_cc.shape[0]).astype(bool)
|
||||
visited = np.zeros(same_cc.shape[0]).astype(np.bool)
|
||||
merged_cc = []
|
||||
for i, all_neighbors in enumerate(same_cc):
|
||||
if visited[i]: continue
|
||||
|
@ -318,9 +307,6 @@ def mergeCC(node_dict, svg_path, width, height):
|
|||
|
||||
shape_shape_edge_attr = get_attr(shape_shape_edges)
|
||||
cross_shape_edge_attr = get_attr(cross_shape_edges)
|
||||
|
||||
print(new_cc)
|
||||
|
||||
|
||||
return np.array(shape_shape_edges), np.array(cross_shape_edges), np.array(shape_shape_edge_attr), np.array(cross_shape_edge_attr), paths, new_cc
|
||||
|
||||
|
@ -328,9 +314,10 @@ def mergeCC(node_dict, svg_path, width, height):
|
|||
|
||||
if __name__ == '__main__':
|
||||
graph_builder = SVGGraphBuilderBezier()
|
||||
|
||||
input_dir = 'data/diagrams/'
|
||||
output_dir = 'data/diagrams/'
|
||||
#input_dir = '/home/v-luliu1/datasets/floorplans_test'
|
||||
#output_dir = '/home/v-luliu1/datasets/floorplans_test'
|
||||
input_dir = '/data/xinyangjiang/Datasets/SESYD/diagram2'
|
||||
output_dir = '/data/xinyangjiang/Datasets/SESYD/diagram2'
|
||||
dir_list = os.listdir(input_dir)
|
||||
|
||||
angles = []
|
||||
|
@ -346,7 +333,6 @@ if __name__ == '__main__':
|
|||
p = SVGParser(filepath)
|
||||
# split_cross splits the segments into multiple small segments if there is a cross-point
|
||||
type_dict = split_cross(p.get_all_shape())
|
||||
|
||||
width, height = p.get_image_size()
|
||||
paths = shape2Path(type_dict)
|
||||
# {'pos': {'spatial': the positions of nodes}, {'attr': {'color': color of every node, 'stroke_width': stroke width, 'is_control': if the node is control node}}, {'edge': {'control': control edge, 'spatial': spatial edge}}, 'edge_attr': N * 6}
|
||||
|
@ -393,7 +379,7 @@ if __name__ == '__main__':
|
|||
node_dict['edge']['super'] = np.concatenate([shape_shape_edges, cross_shape_edges], axis = 0)
|
||||
#node_dict['attr']['is_control'] = np.concatenate([node_dict['attr']['is_control'], np.zeros((super_pos.shape[0], 1)).astype(np.bool)], axis = 0)
|
||||
#node_dict['attr']['is_super'] = np.concatenate([np.zeros((start_end_size, 1)).astype(np.bool), np.ones((super_pos.shape[0], 1)).astype(np.bool)], axis = 0)
|
||||
node_dict['attr']['is_super'] = np.zeros((start_end_size, 1)).astype(bool)
|
||||
node_dict['attr']['is_super'] = np.zeros((start_end_size, 1)).astype(np.bool)
|
||||
if len(cross_shape_edge_attr) == 0:
|
||||
node_dict['edge_attr']['super'] = shape_shape_edge_attr
|
||||
else:
|
||||
|
|
Загрузка…
Ссылка в новой задаче