281 строка
9.8 KiB
Python
281 строка
9.8 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
|
|
import torch
|
|
import os
|
|
import numpy as np
|
|
from xml.dom.minidom import parse, Node, parseString
|
|
|
|
from torch_geometric.data import Data
|
|
from Datasets.svg_parser import SVGGraphBuilderShape as SVGGraphBuilder
|
|
from Datasets.svg_parser import SVGParser
|
|
from sklearn.metrics.pairwise import euclidean_distances
|
|
|
|
#from a2c import a2c
|
|
|
|
class SESYDFloorPlan(torch.utils.data.Dataset):
|
|
def __init__(self, root, opt, partition = 'train', data_aug = False):
|
|
super(SESYDFloorPlan, self).__init__()
|
|
|
|
svg_list = open(os.path.join(root, partition + '_list.txt')).readlines()
|
|
svg_list = [os.path.join(root, line.strip()) for line in svg_list]
|
|
self.graph_builder = SVGGraphBuilder()
|
|
#print(svg_list)
|
|
|
|
self.pos_edge_th = opt.pos_edge_th
|
|
self.data_aug = data_aug
|
|
|
|
self.svg_list = svg_list
|
|
|
|
self.class_dict = {
|
|
'armchair':0,
|
|
'bed':1,
|
|
'door1':2,
|
|
'door2':3,
|
|
'sink1':4,
|
|
'sink2':5,
|
|
'sink3':6,
|
|
'sink4':7,
|
|
'sofa1':8,
|
|
'sofa2':9,
|
|
'table1':10,
|
|
'table2':11,
|
|
'table3':12,
|
|
'tub':13,
|
|
'window1':14,
|
|
'window2':15
|
|
}
|
|
|
|
'''
|
|
self.class_dict = {
|
|
'armchair':0,
|
|
'bed':1,
|
|
'door1':2,
|
|
'door2':2,
|
|
'sink1':3,
|
|
'sink2':3,
|
|
'sink3':3,
|
|
'sink4':3,
|
|
'sofa1':4,
|
|
'sofa2':4,
|
|
'table1':5,
|
|
'table2':5,
|
|
'table3':5,
|
|
'tub':6,
|
|
'window1':7,
|
|
'window2':7
|
|
}
|
|
'''
|
|
#self.anchors = self.get_anchor()
|
|
'''
|
|
self.n_objects = 0
|
|
for idx in range(len(self.svg_list)):
|
|
filepath = self.svg_list[idx]
|
|
print(filepath)
|
|
p = SVGParser(filepath)
|
|
width, height = p.get_image_size()
|
|
#graph_dict = self.graph_builder.buildGraph(p.get_all_shape())
|
|
|
|
gt_bbox, gt_labels = self._get_bbox(filepath, width, height)
|
|
self.n_objects += gt_bbox.shape[0]
|
|
print(self.n_objects)
|
|
'''
|
|
self.n_objects = 13238
|
|
|
|
def __len__(self):
|
|
return len(self.svg_list)
|
|
|
|
def _get_bbox(self, path, width, height):
|
|
dom = parse(path.replace('.svg', '.xml'))
|
|
root = dom.documentElement
|
|
|
|
nodes = []
|
|
for tagname in ['a', 'o']:
|
|
nodes += root.getElementsByTagName(tagname)
|
|
|
|
bbox = []
|
|
labels = []
|
|
for node in nodes:
|
|
for n in node.childNodes:
|
|
if n.nodeType != Node.ELEMENT_NODE:
|
|
continue
|
|
x0 = float(n.getAttribute('x0')) / width
|
|
y0 = float(n.getAttribute('y0')) / height
|
|
x1 = float(n.getAttribute('x1')) / width
|
|
y1 = float(n.getAttribute('y1')) / height
|
|
label = n.getAttribute('label')
|
|
bbox.append((x0, y0, x1, y1))
|
|
labels.append(self.class_dict[label])
|
|
|
|
return np.array(bbox), np.array(labels)
|
|
|
|
def gen_y(self, graph_dict, bbox, labels, width, height):
|
|
pos = graph_dict['pos']
|
|
|
|
th = 1e-3
|
|
gt_bb = []
|
|
gt_cls = []
|
|
gt_object = []
|
|
for node_idx, p in enumerate(pos):
|
|
|
|
diff_0 = p[None, :] - bbox[:, 0:2]
|
|
diff_1 = p[None, :] - bbox[:, 2:]
|
|
in_object = (diff_0[:, 0] >= -th) & (diff_0[:, 1] >= -th) & (diff_1[:, 0] <= th) & (diff_1[:, 1] <= th)
|
|
|
|
object_index = np.where(in_object)[0]
|
|
if len(object_index) > 1:
|
|
#print(object_index)
|
|
#print('node', p[0] * width, p[1] * height, 'is inside more than one object')
|
|
candidates = bbox[object_index]
|
|
s = euclidean_distances(p[None, :], candidates[:, 0:2])[0]
|
|
#print(np.argsort(s))
|
|
object_index = object_index[np.argsort(s)]
|
|
#print(candidates, s, object_index)
|
|
elif len(object_index) == 0:
|
|
#print(diff_0 * [width, height], diff_1* [width, height])
|
|
#print(object_index)
|
|
print('node', p[0] * width, p[1] * height, 'outside all object')
|
|
#for i, line in enumerate(bbox[:, 0:2] * [width, height]):
|
|
# print(i, line)
|
|
raise SystemExit
|
|
cls = labels[object_index[0]]
|
|
bb = bbox[object_index[0]]
|
|
'''
|
|
h = bb[3] - bb[1]
|
|
w = bb[2] - bb[0]
|
|
offset_x = bb[0] - p[0]
|
|
offset_y = bb[1] - p[1]
|
|
gt_bb.append((offset_x, offset_y, w, h))
|
|
'''
|
|
gt_bb.append(bb)
|
|
gt_cls.append(cls)
|
|
gt_object.append(object_index[0])
|
|
|
|
return np.array(gt_bb), np.array(gt_cls), np.array(gt_object)
|
|
|
|
def __transform__(self, pos, scale, angle, translate):
|
|
scale_m = np.eye(2)
|
|
scale_m[0, 0] = scale
|
|
scale_m[1, 1] = scale
|
|
|
|
rot_m = np.eye(2)
|
|
rot_m[0, 0:2] = [np.cos(angle), np.sin(angle)]
|
|
rot_m[1, 0:2] = [-np.sin(angle), np.cos(angle)]
|
|
|
|
#print(pos.shape, scale_m[0:2].shape)
|
|
#pos = np.matmul(pos, scale_m[0:2])
|
|
#print(pos.shape)
|
|
center = np.array((0.5, 0.5))[None, :]
|
|
pos -= center
|
|
pos = np.matmul(pos, rot_m[0:2])
|
|
pos += center
|
|
#pos += np.array(translate)[None, :]
|
|
return pos
|
|
|
|
def __transform_bbox__(self, bbox, scale, angle, translate):
|
|
p0 = bbox[:, 0:2]
|
|
p2 = bbox[:, 2:]
|
|
p1 = np.concatenate([p2[:, 0][:, None], p0[:, 1][:, None]], axis = 1)
|
|
p3 = np.concatenate([p0[:, 0][:, None], p2[:, 1][:, None]], axis = 1)
|
|
|
|
p0 = self.__transform__(p0, scale, angle, translate)
|
|
p1 = self.__transform__(p1, scale, angle, translate)
|
|
p2 = self.__transform__(p2, scale, angle, translate)
|
|
p3 = self.__transform__(p3, scale, angle, translate)
|
|
|
|
def bound_rect(p0, p1, p2, p3):
|
|
x = np.concatenate((p0[:, 0][:, None], p1[:, 0][:, None], p2[:, 0][:, None], p3[:, 0][:, None]), axis = 1)
|
|
y = np.concatenate((p0[:, 1][:, None], p1[:, 1][:, None], p2[:, 1][:, None], p3[:, 1][:, None]), axis = 1)
|
|
x_min = x.min(1, keepdims = True)
|
|
x_max = x.max(1, keepdims = True)
|
|
y_min = y.min(1, keepdims = True)
|
|
y_max = y.max(1, keepdims = True)
|
|
|
|
return np.concatenate([x_min, y_min, x_max, y_max], axis = 1)
|
|
return bound_rect(p0, p1, p2, p3)
|
|
|
|
def random_transfer(self, pos, bbox, gt_bbox):
|
|
scale = np.random.random() * 0.1 + 0.9
|
|
angle = np.random.random() * np.pi * 2
|
|
translate = [0, 0]
|
|
translate[0] = np.random.random() * 0.2 - 0.1
|
|
translate[1] = np.random.random() * 0.2 - 0.1
|
|
|
|
pos = self.__transform__(pos, scale, angle, translate)
|
|
bbox = self.__transform_bbox__(bbox, scale, angle, translate)
|
|
gt_bbox = self.__transform_bbox__(gt_bbox, scale, angle, translate)
|
|
|
|
return pos, bbox, gt_bbox
|
|
|
|
def __getitem__(self, idx):
|
|
if torch.is_tensor(idx):
|
|
idx = idx.tolist()
|
|
|
|
#for idx in range(len(self.svg_list)):
|
|
filepath = self.svg_list[idx]
|
|
p = SVGParser(filepath)
|
|
width, height = p.get_image_size()
|
|
graph_dict = self.graph_builder.buildGraph(p.get_all_shape())
|
|
|
|
gt_bbox, gt_labels = self._get_bbox(filepath, width, height)
|
|
bbox, labels, gt_object = self.gen_y(graph_dict, gt_bbox, gt_labels, width, height)
|
|
|
|
feats = graph_dict['f']
|
|
pos = graph_dict['pos']
|
|
is_control = np.zeros((pos.shape[0], 1))
|
|
|
|
edge = graph_dict['edge']
|
|
|
|
feats = torch.tensor(feats, dtype=torch.float32)
|
|
pos = torch.tensor(pos, dtype=torch.float32)
|
|
edge = torch.tensor(edge, dtype=torch.long)
|
|
is_control = torch.tensor(is_control, dtype=torch.bool)
|
|
bbox = torch.tensor(bbox, dtype=torch.float32)
|
|
labels = torch.tensor(labels, dtype=torch.long)
|
|
gt_bbox = torch.tensor(gt_bbox, dtype=torch.float32)
|
|
gt_labels = torch.tensor(gt_labels, dtype=torch.long)
|
|
gt_object = torch.tensor(gt_object, dtype=torch.long)
|
|
|
|
e_weight = torch.tensor(graph_dict['edge_weight'], dtype=torch.float32)
|
|
|
|
#print('bbox', bbox.size())
|
|
#print('labels', labels.size())
|
|
#raise SystemExit
|
|
|
|
data = Data(x = feats, pos = pos)
|
|
data.edge = edge
|
|
#data.edge_control = None
|
|
#data.edge_pos = None
|
|
data.is_control = is_control
|
|
data.bbox = bbox
|
|
data.labels = labels
|
|
data.gt_bbox = gt_bbox
|
|
data.gt_labels = gt_labels
|
|
data.gt_object = gt_object
|
|
data.filepath = filepath
|
|
data.width = width
|
|
data.height = height
|
|
data.e_weight = e_weight
|
|
|
|
return data
|
|
|
|
|
|
if __name__ == '__main__':
|
|
svg_list = open('/home/xinyangjiang/Datasets/SESYD/FloorPlans/train_list.txt').readlines()
|
|
svg_list = ['/home/xinyangjiang/Datasets/SESYD/FloorPlans/' + line.strip() for line in svg_list]
|
|
builder = SVGGraphBuilder()
|
|
for line in svg_list:
|
|
print(line)
|
|
#line = '/home/xinyangjiang/Datasets/SESYD/FloorPlans/floorplans16-01/file_56.svg'
|
|
p = SVGParser(line)
|
|
builder.buildGraph(p.get_all_shape())
|
|
|
|
#train_dataset = SESYDFloorPlan(opt.data_dir, pre_transform=T.NormalizeScale())
|
|
#train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4)
|
|
#for batch in train_loader:
|
|
# pass
|
|
|
|
#paths, attributes, svg_attributes = svg2paths2('/home/xinyangjiang/Datasets/SESYD/FloorPlans/floorplans16-05/file_47.svg')
|
|
#print(paths, attributes, svg_attributes)
|