# File: YOLaT-VectorGraphicsRecogni.../Datasets/svg2.py
#
# 281 lines
# 9.8 KiB
# Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import os
import numpy as np
from xml.dom.minidom import parse, Node, parseString
from torch_geometric.data import Data
from Datasets.svg_parser import SVGGraphBuilderShape as SVGGraphBuilder
from Datasets.svg_parser import SVGParser
from sklearn.metrics.pairwise import euclidean_distances
#from a2c import a2c
class SESYDFloorPlan(torch.utils.data.Dataset):
    """SESYD floor-plan dataset yielding one graph per SVG drawing.

    Each item is built from an SVG file (parsed with ``SVGParser`` and turned
    into a graph by ``SVGGraphBuilder``) plus the XML annotation file that
    sits next to it, which supplies the ground-truth bounding boxes and class
    labels. All box coordinates are normalised by the image width/height.
    """

    def __init__(self, root, opt, partition='train', data_aug=False):
        """Read the file list for *partition* and set up the label mapping.

        Args:
            root: Directory holding ``<partition>_list.txt``; every entry in
                that list is joined onto ``root``.
            opt: Options object; only ``opt.pos_edge_th`` is read here.
            partition: Prefix of the list file to load.
            data_aug: Stored on the instance. NOTE(review): nothing in this
                class consults it, so ``random_transfer`` is never applied
                by ``__getitem__``.
        """
        super().__init__()

        list_path = os.path.join(root, partition + '_list.txt')
        # Close the list file deterministically and skip blank lines, which
        # would otherwise become entries pointing at ``root`` itself.
        with open(list_path) as list_file:
            svg_list = [
                os.path.join(root, line.strip())
                for line in list_file
                if line.strip()
            ]

        self.graph_builder = SVGGraphBuilder()
        self.pos_edge_th = opt.pos_edge_th
        self.data_aug = data_aug
        self.svg_list = svg_list

        # Fine-grained labelling: every symbol variant is its own class.
        # (An earlier, disabled variant merged the variants into 8 classes:
        # armchair, bed, door, sink, sofa, table, tub, window.)
        self.class_dict = {
            'armchair': 0,
            'bed': 1,
            'door1': 2,
            'door2': 3,
            'sink1': 4,
            'sink2': 5,
            'sink3': 6,
            'sink4': 7,
            'sofa1': 8,
            'sofa2': 9,
            'table1': 10,
            'table2': 11,
            'table3': 12,
            'tub': 13,
            'window1': 14,
            'window2': 15,
        }

        # Hard-coded object count. It was originally obtained by summing
        # ``_get_bbox(...)[0].shape[0]`` over ``svg_list``.
        # NOTE(review): presumably valid for one specific partition only —
        # confirm before relying on it for other partitions.
        self.n_objects = 13238

    def __len__(self):
        """Return the number of SVG files in the loaded list."""
        return len(self.svg_list)

    def _get_bbox(self, path, width, height):
        """Load normalised ground-truth boxes and labels for one drawing.

        The annotation file is the SVG path with its ``.svg`` suffix replaced
        by ``.xml``. Every element child of every ``<a>`` / ``<o>`` element
        contributes one box read from its ``x0``, ``y0``, ``x1``, ``y1`` and
        ``label`` attributes.

        Args:
            path: Path of the SVG file.
            width: Image width used to normalise x coordinates.
            height: Image height used to normalise y coordinates.

        Returns:
            ``(bbox, labels)`` where ``bbox`` is a float array of shape
            ``(N, 4)`` holding ``(x0, y0, x1, y1)`` and ``labels`` is an
            array of ``N`` class indices from ``self.class_dict``.

        Raises:
            KeyError: If an annotation carries a label that is not in
                ``self.class_dict``.
        """
        # Swap only the file suffix; a blanket ``str.replace`` would also
        # rewrite any '.svg' occurring inside a directory name. Paths without
        # a '.svg' suffix keep the legacy replace behaviour.
        stem, ext = os.path.splitext(path)
        if ext == '.svg':
            xml_path = stem + '.xml'
        else:
            xml_path = path.replace('.svg', '.xml')

        dom = parse(xml_path)
        root = dom.documentElement

        nodes = []
        for tagname in ['a', 'o']:
            nodes += root.getElementsByTagName(tagname)

        bbox = []
        labels = []
        for node in nodes:
            for n in node.childNodes:
                # Skip whitespace/text nodes between the annotation elements.
                if n.nodeType != Node.ELEMENT_NODE:
                    continue
                x0 = float(n.getAttribute('x0')) / width
                y0 = float(n.getAttribute('y0')) / height
                x1 = float(n.getAttribute('x1')) / width
                y1 = float(n.getAttribute('y1')) / height
                label = n.getAttribute('label')
                bbox.append((x0, y0, x1, y1))
                labels.append(self.class_dict[label])

        # Force an (N, 4) layout so that an annotation file without any
        # objects yields shape (0, 4) rather than (0,), which would break
        # the column slicing done by the consumers of this array.
        bbox = np.array(bbox, dtype=np.float64).reshape(-1, 4)
        return bbox, np.array(labels)

    def gen_y(self, graph_dict, bbox, labels, width, height):
        """Assign every graph node to the ground-truth object containing it.

        A node belongs to a box when it lies inside it, with a tolerance of
        ``1e-3`` in normalised coordinates. If several boxes contain the
        node, the one whose ``(x0, y0)`` corner is closest to the node wins.

        Args:
            graph_dict: Graph dictionary; only ``graph_dict['pos']`` (one
                2-D point per node) is read.
            bbox: ``(N, 4)`` array of normalised ``(x0, y0, x1, y1)`` boxes.
            labels: ``N`` class indices aligned with ``bbox``.
            width: Image width, used only for the diagnostic message.
            height: Image height, used only for the diagnostic message.

        Returns:
            ``(gt_bb, gt_cls, gt_object)``: per node, the box it was assigned
            to, that box's class, and that box's index into ``bbox``.

        Raises:
            SystemExit: If some node is not inside any box (a diagnostic line
                is printed first).
        """
        pos = graph_dict['pos']
        th = 1e-3

        gt_bb = []
        gt_cls = []
        gt_object = []
        for p in pos:
            diff_0 = p[None, :] - bbox[:, 0:2]   # offset from top-left corners
            diff_1 = p[None, :] - bbox[:, 2:]    # offset from bottom-right corners
            in_object = (
                (diff_0[:, 0] >= -th) & (diff_0[:, 1] >= -th)
                & (diff_1[:, 0] <= th) & (diff_1[:, 1] <= th)
            )
            object_index = np.where(in_object)[0]

            if len(object_index) > 1:
                # Ambiguous node: rank the containing boxes by the distance
                # between the node and each box's (x0, y0) corner.
                candidates = bbox[object_index]
                s = euclidean_distances(p[None, :], candidates[:, 0:2])[0]
                object_index = object_index[np.argsort(s)]
            elif len(object_index) == 0:
                print('node', p[0] * width, p[1] * height, 'outside all object')
                raise SystemExit

            best = object_index[0]
            gt_bb.append(bbox[best])
            gt_cls.append(labels[best])
            gt_object.append(best)

        return np.array(gt_bb), np.array(gt_cls), np.array(gt_object)

    def __transform__(self, pos, scale, angle, translate):
        """Rotate points by *angle* around the image centre ``(0.5, 0.5)``.

        The input is never modified; a new array is returned.

        Args:
            pos: ``(N, 2)`` array-like of normalised points.
            scale: Accepted for interface compatibility; currently ignored
                (scaling is disabled).
            angle: Rotation angle in radians.
            translate: Accepted for interface compatibility; currently
                ignored (translation is disabled).

        Returns:
            ``(N, 2)`` float64 array of rotated points.
        """
        rot_m = np.array([
            [np.cos(angle), np.sin(angle)],
            [-np.sin(angle), np.cos(angle)],
        ])
        center = np.array((0.5, 0.5))[None, :]

        # Out-of-place subtraction: the previous ``pos -= center`` silently
        # shifted the caller's array (and any array it was a view of).
        centered = np.asarray(pos, dtype=np.float64) - center
        rotated = np.matmul(centered, rot_m)
        rotated += center  # safe: ``rotated`` is a fresh array owned here
        return rotated

    def __transform_bbox__(self, bbox, scale, angle, translate):
        """Transform boxes and return their axis-aligned bounding rectangles.

        All four corners of each box go through ``__transform__``; the result
        is the smallest axis-aligned rectangle enclosing the moved corners.
        The input is never modified.

        Args:
            bbox: ``(N, 4)`` array-like of ``(x0, y0, x1, y1)`` boxes.
            scale: Forwarded to ``__transform__``.
            angle: Forwarded to ``__transform__``.
            translate: Forwarded to ``__transform__``.

        Returns:
            ``(N, 4)`` float64 array of ``(x_min, y_min, x_max, y_max)``.
        """
        bbox = np.asarray(bbox, dtype=np.float64).reshape(-1, 4)
        x0, y0, x1, y1 = bbox[:, 0], bbox[:, 1], bbox[:, 2], bbox[:, 3]

        # corners[i] holds the four corners of box i, shape (N, 4, 2).
        corners = np.stack(
            [
                np.stack([x0, y0], axis=1),
                np.stack([x1, y0], axis=1),
                np.stack([x1, y1], axis=1),
                np.stack([x0, y1], axis=1),
            ],
            axis=1,
        )
        moved = self.__transform__(
            corners.reshape(-1, 2), scale, angle, translate
        ).reshape(-1, 4, 2)

        lower = moved.min(axis=1)   # (x_min, y_min) per box
        upper = moved.max(axis=1)   # (x_max, y_max) per box
        return np.concatenate([lower, upper], axis=1)

    def random_transfer(self, pos, bbox, gt_bbox):
        """Apply one random transform to node positions and both box sets.

        A scale in ``[0.9, 1.0)``, an angle in ``[0, 2*pi)`` and a translation
        in ``[-0.1, 0.1)`` per axis are sampled, but ``__transform__``
        currently applies the rotation only.

        Args:
            pos: ``(N, 2)`` node positions.
            bbox: ``(M, 4)`` boxes.
            gt_bbox: ``(K, 4)`` ground-truth boxes.

        Returns:
            ``(pos, bbox, gt_bbox)`` after the transform, as new arrays.
        """
        # Sampling order is kept stable so seeded runs stay reproducible.
        scale = np.random.random() * 0.1 + 0.9
        angle = np.random.random() * np.pi * 2
        translate = [0, 0]
        translate[0] = np.random.random() * 0.2 - 0.1
        translate[1] = np.random.random() * 0.2 - 0.1

        pos = self.__transform__(pos, scale, angle, translate)
        bbox = self.__transform_bbox__(bbox, scale, angle, translate)
        gt_bbox = self.__transform_bbox__(gt_bbox, scale, angle, translate)
        return pos, bbox, gt_bbox

    def __getitem__(self, idx):
        """Build the graph sample for the ``idx``-th SVG file.

        Returns:
            A ``torch_geometric.data.Data`` object with ``x`` (node features)
            and ``pos`` plus these extra attributes: ``edge``, ``e_weight``,
            ``is_control`` (all-False ``(N, 1)`` mask), per-node targets
            ``bbox`` / ``labels`` / ``gt_object``, per-object ``gt_bbox`` /
            ``gt_labels``, and the ``filepath``, ``width`` and ``height`` of
            the drawing.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        filepath = self.svg_list[idx]
        p = SVGParser(filepath)
        width, height = p.get_image_size()
        graph_dict = self.graph_builder.buildGraph(p.get_all_shape())

        gt_bbox, gt_labels = self._get_bbox(filepath, width, height)
        bbox, labels, gt_object = self.gen_y(
            graph_dict, gt_bbox, gt_labels, width, height
        )

        feats = torch.tensor(graph_dict['f'], dtype=torch.float32)
        pos = torch.tensor(graph_dict['pos'], dtype=torch.float32)

        data = Data(x=feats, pos=pos)
        # Connectivity is stored under ``edge`` (not ``edge_index``).
        data.edge = torch.tensor(graph_dict['edge'], dtype=torch.long)
        # No node is flagged as a control point for this dataset.
        data.is_control = torch.zeros((pos.shape[0], 1), dtype=torch.bool)
        data.bbox = torch.tensor(bbox, dtype=torch.float32)
        data.labels = torch.tensor(labels, dtype=torch.long)
        data.gt_bbox = torch.tensor(gt_bbox, dtype=torch.float32)
        data.gt_labels = torch.tensor(gt_labels, dtype=torch.long)
        data.gt_object = torch.tensor(gt_object, dtype=torch.long)
        data.filepath = filepath
        data.width = width
        data.height = height
        data.e_weight = torch.tensor(
            graph_dict['edge_weight'], dtype=torch.float32
        )
        return data
if __name__ == '__main__':
    # Smoke test: build the graph for every SVG in the training list so that
    # parser/graph-builder failures surface on a specific, printed file.
    data_root = '/home/xinyangjiang/Datasets/SESYD/FloorPlans/'

    # Close the list file deterministically and ignore blank lines, which
    # would otherwise be turned into the bare directory path.
    with open(data_root + 'train_list.txt') as list_file:
        svg_list = [
            data_root + line.strip()
            for line in list_file
            if line.strip()
        ]

    builder = SVGGraphBuilder()
    for line in svg_list:
        print(line)
        p = SVGParser(line)
        builder.buildGraph(p.get_all_shape())