Mirror of https://github.com/microsoft/O-CNN.git

Update pytorch

Parent: 3386df2e8f
Commit: 224d614637
@@ -10,36 +10,38 @@ using torch::Tensor;
 vector<float> bounding_sphere(Tensor data_in, string method);
 Tensor normalize_points(Tensor data_in, float radius, vector<float> center);
 Tensor transform_points(Tensor data_in, vector<float> angle, vector<float> scale,
-                        vector<float> jitter, float offset);
+                        vector<float> jitter, float offset, string normal_axis);
 vector<Tensor> clip_points(Tensor data_in, vector<float> bbmin, vector<float> bbmax);

 Tensor octree_batch(vector<Tensor> tensors_in);
 vector<Tensor> octree_samples(vector<string> names);
 Tensor octree_property(Tensor octree_in, string property, int depth);
 Tensor octree_set_property(Tensor octree_in, Tensor data_in, int depth);
 Tensor points2octree(Tensor points, int depth, int full_depth, bool node_dis,
                      bool node_feature, bool split_label, bool adaptive,
                      int adp_depth, float th_normal, float th_distance,
                      bool extrapolate, bool save_pts, bool key2xyz);

 Tensor octree2col(Tensor data_in, Tensor octree, int depth,
-                  vector<int> kernel_size, int stride);
+                  vector<int> kernel_size, int stride, bool nempty);
 Tensor col2octree(Tensor grad_in, Tensor octree, int depth,
-                  vector<int> kernel_size, int stride);
-Tensor octree2colP(Tensor data_in, Tensor octree, int depth,
-                   vector<int> kernel_size, int stride);
-Tensor col2octreeP(Tensor grad_in, Tensor octree, int depth,
-                   vector<int> kernel_size, int stride);
+                  vector<int> kernel_size, int stride, bool nempty);

 Tensor octree_conv(Tensor data_in, Tensor weights, Tensor octree, int depth,
-                   int num_output, vector<int> kernel_size, int stride);
+                   int num_output, vector<int> kernel_size, int stride,
+                   bool nempty);
 Tensor octree_deconv(Tensor data_in, Tensor weights, Tensor octree, int depth,
-                     int num_output, vector<int> kernel_size, int stride);
+                     int num_output, vector<int> kernel_size, int stride,
+                     bool nempty);
 vector<Tensor> octree_conv_grad(Tensor data_in, Tensor weights, Tensor octree,
                                 Tensor grad_in, int depth, int num_output,
-                                vector<int> kernel_size, int stride);
+                                vector<int> kernel_size, int stride,
+                                bool nempty);
 vector<Tensor> octree_deconv_grad(Tensor data_in, Tensor weights, Tensor octree,
                                   Tensor grad_in, int depth, int num_output,
-                                  vector<int> kernel_size, int stride);
+                                  vector<int> kernel_size, int stride,
+                                  bool nempty);

 Tensor octree_pad(Tensor data_in, Tensor octree, int depth, float val = 0.0f);
 Tensor octree_depad(Tensor data_in, Tensor octree, int depth);
@@ -52,7 +54,17 @@ Tensor octree_encode_key(Tensor xyz);
 Tensor octree_decode_key(Tensor key);
 Tensor octree_key2xyz(Tensor key, int depth);
 Tensor octree_xyz2key(Tensor xyz, int depth);
-Tensor octree_search_key(Tensor key, Tensor data_in, int depth, bool is_in_xyz);
+Tensor octree_search_key(Tensor key, Tensor octree, int depth, bool key_is_xyz,
+                         bool nempty);
+
+Tensor octree_scan(Tensor octree, vector<float> axis, float scale);
+Tensor octree_grow(Tensor octree_in, int target_depth, bool full_octree);
+Tensor octree_update(Tensor octree_in, Tensor label_in, int depth, int split);
+Tensor octree_new(int batch_size, int channel, bool node_dis, int adaptive_layer);
+
+vector<Tensor> octree_align(Tensor src_data, Tensor src_octree,
+                            Tensor des_octree, int depth);
+Tensor octree_align_grad(Tensor des_grad, Tensor idx_tensor);

 Tensor points_new(Tensor pts, Tensor normals, Tensor features, Tensor labels);
 Tensor points_property(Tensor points, string property);
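Editor's note: the header changes above thread a new `nempty` flag through the octree-convolution entry points, selecting between the all-nodes and non-empty-nodes code paths. A minimal calling sketch, assuming hypothetical tensors `octree`, `weights`, and feature tensors `data` / `data_ne` of shape `[1, C, H, 1]` prepared elsewhere:

```cpp
// Sketch only. With nempty = false the op runs on all nodes at `depth`,
// so H must equal node_num(depth); with nempty = true it runs on the
// non-empty nodes only, so H must equal node_num_nempty(depth).
Tensor out_all = octree_conv(data, weights, octree, /*depth=*/5,
                             /*num_output=*/64, /*kernel_size=*/{3, 3, 3},
                             /*stride=*/1, /*nempty=*/false);
Tensor out_ne = octree_conv(data_ne, weights, octree, /*depth=*/5,
                            /*num_output=*/64, /*kernel_size=*/{3, 3, 3},
                            /*stride=*/1, /*nempty=*/true);
```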
@@ -5,6 +5,18 @@
 namespace {

+Tensor get_ichild(const OctreeParser& octree_, const int depth,
+                  const torch::TensorOptions& options) {
+  int node_num = octree_.info().node_num(depth);
+  int node_num_ne = octree_.info().node_num_nempty(depth);
+  const int* child = octree_.children_gpu(depth);
+  Tensor tmp = torch::arange(node_num, options.dtype(torch::kInt32));
+  Tensor ichild = torch::zeros(node_num_ne, options.dtype(torch::kInt32));
+  pad_backward_gpu(ichild.data_ptr<int>(), node_num_ne, 1, tmp.data_ptr<int>(),
+                   node_num, child);
+  return ichild;
+}
+
 class Octree2ColBase {
  public:
   explicit Octree2ColBase(int depth, std::vector<int> kernel_size, int stride)
@@ -48,6 +60,7 @@ class OctreeToColOp : public Octree2ColBase {
     int btm_depth = this->depth_;
     int channel = data_in.size(1);
     int btm_height = data_in.size(2);
+    data_in = data_in.contiguous();
     CHECK_EQ(octree_.info().node_num(btm_depth), btm_height);

     // output data
@@ -58,7 +71,8 @@ class OctreeToColOp : public Octree2ColBase {
       CHECK_EQ(top_height, octree_.info().node_num_nempty(top_depth));
     }
     int kernel_sdim = num_elements(this->kernel_size_);
-    Tensor data_out = torch::zeros({channel, kernel_sdim, top_height}, data_in.options());
+    Tensor data_out =
+        torch::zeros({channel, kernel_sdim, top_height}, data_in.options());

     // execute
     octree2col_gpu(data_out.data_ptr<float>(), data_in.data_ptr<float>(),
@@ -83,6 +97,7 @@ class ColToOctreeOp : public Octree2ColBase {
     // in grad shape, data format: [channel, kernel_sdim, top_height]
     int channel = grad_in.size(0);
     int top_height = grad_in.size(2);
+    grad_in = grad_in.contiguous();

     // out grad
     int btm_depth = this->depth_;
@@ -90,7 +105,8 @@ class ColToOctreeOp : public Octree2ColBase {
     if (this->stride_ == 2) {
       CHECK_EQ(top_height, octree_.info().node_num_nempty(btm_depth - 1));
     }
-    Tensor grad_out = torch::zeros({1, channel, btm_height, 1}, grad_in.options());
+    Tensor grad_out =
+        torch::zeros({1, channel, btm_height, 1}, grad_in.options());

     // execute
     int kernel_sdim = num_elements(this->kernel_size_);
@@ -118,6 +134,7 @@ class OctreeToColPOp : public Octree2ColBase {
     int btm_depth = this->depth_;
     int channel = data_in.size(1);
     int btm_height = data_in.size(2);
+    data_in = data_in.contiguous();
     int node_num = octree_.info().node_num(btm_depth);
     int node_num_ne = octree_.info().node_num_nempty(btm_depth);
     CHECK_EQ(node_num_ne, btm_height);
@@ -163,6 +180,7 @@ class ColToOctreePOp : public Octree2ColBase {
     torch::TensorOptions options = grad_in.options();
     int channel = grad_in.size(0);
     int top_height = grad_in.size(2);
+    grad_in = grad_in.contiguous();

     // child pointer
     int btm_depth = this->depth_;
@@ -174,7 +192,6 @@ class ColToOctreePOp : public Octree2ColBase {
     int* ichild = t1.data_ptr<int>();
     pad_backward_gpu(ichild, node_num_ne, 1, t0.data_ptr<int>(), node_num, child);
-

     // out grad
     int btm_height = node_num_ne;
     if (this->stride_ == 2) {
@@ -187,36 +204,34 @@ class ColToOctreePOp : public Octree2ColBase {
     // execute
     int kernel_sdim = num_elements(this->kernel_size_);
     col2octreeP_gpu(grad_in.data_ptr<float>(), grad_out.data_ptr<float>(),
-        channel, top_height, btm_height, kernel_sdim, this->stride_,
-        octree_.neighbor_gpu(btm_depth), ni_ptr, child, ichild,
-        top_height, 0);
+                    channel, top_height, btm_height, kernel_sdim, this->stride_,
+                    octree_.neighbor_gpu(btm_depth), ni_ptr, child, ichild,
+                    top_height, 0);
     return grad_out;
   }
 };

-} // anonymous namespace
+}  // anonymous namespace

 // API implementation
 Tensor octree2col(Tensor data_in, Tensor octree, int depth,
-                  std::vector<int> kernel_size, int stride) {
-  OctreeToColOp op(depth, kernel_size, stride);
-  return op.compute(data_in, octree);
+                  std::vector<int> kernel_size, int stride, bool nempty) {
+  if (!nempty) {
+    OctreeToColOp op(depth, kernel_size, stride);
+    return op.compute(data_in, octree);
+  } else {
+    OctreeToColPOp op(depth, kernel_size, stride);
+    return op.compute(data_in, octree);
+  }
 }

 Tensor col2octree(Tensor grad_in, Tensor octree, int depth,
-                  std::vector<int> kernel_size, int stride) {
-  ColToOctreeOp op(depth, kernel_size, stride);
-  return op.compute(grad_in, octree);
-}
-
-Tensor octree2colP(Tensor data_in, Tensor octree, int depth,
-                   std::vector<int> kernel_size, int stride) {
-  OctreeToColPOp op(depth, kernel_size, stride);
-  return op.compute(data_in, octree);
-}
-
-Tensor col2octreeP(Tensor grad_in, Tensor octree, int depth,
-                   std::vector<int> kernel_size, int stride) {
-  ColToOctreePOp op(depth, kernel_size, stride);
-  return op.compute(grad_in, octree);
+                  std::vector<int> kernel_size, int stride, bool nempty) {
+  if (!nempty) {
+    ColToOctreeOp op(depth, kernel_size, stride);
+    return op.compute(grad_in, octree);
+  } else {
+    ColToOctreePOp op(depth, kernel_size, stride);
+    return op.compute(grad_in, octree);
+  }
 }
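Editor's note: the new `get_ichild` helper gathers an `arange` tensor through `pad_backward_gpu`, using the children array as the mask, to obtain the full-octree index of every non-empty node. A rough CPU equivalent, under my reading that `child[i] >= 0` marks node `i` as non-empty and stores its compact index (not code from the commit):

```cpp
#include <vector>

// CPU sketch of what get_ichild computes: for each non-empty node, in
// compact (non-empty) order, the node's index in the full node list.
std::vector<int> get_ichild_cpu(const std::vector<int>& child) {
  std::vector<int> ichild;
  for (int i = 0; i < static_cast<int>(child.size()); ++i) {
    if (child[i] >= 0) ichild.push_back(i);  // node i is non-empty
  }
  return ichild;
}
```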
@@ -0,0 +1,78 @@
+#ifndef KEY64
+#define KEY64
+#endif
+#include <octree/octree_nn.h>
+#include <octree/octree_parser.h>
+
+#include "ocnn.h"
+
+vector<Tensor> octree_align(Tensor src_data, Tensor src_octree,
+                            Tensor des_octree, int depth) {
+  // in data
+  src_data = src_data.contiguous();
+  float* src_ptr = src_data.data_ptr<float>();
+  int src_h = src_data.size(2);
+  int channel = src_data.size(1);
+
+  // octrees
+  OctreeParser src_parser, des_parser;
+  src_parser.set_gpu(src_octree.data_ptr<uint8_t>());
+  des_parser.set_gpu(des_octree.data_ptr<uint8_t>());
+  int des_h = des_parser.info().node_num(depth);
+  CHECK_EQ(src_parser.info().node_num(depth), src_h);
+
+  // get key
+  torch::TensorOptions options = src_octree.options();
+  const uintk* src_key = src_parser.key_gpu(depth);
+  Tensor src_key_tensor;
+  if (src_parser.info().is_key2xyz()) {
+    src_key_tensor = torch::zeros({src_h}, options.dtype(torch::kInt64));
+    uintk* ptr = (uintk*)src_key_tensor.data_ptr<int64_t>();
+    xyz2key_gpu(ptr, src_key, src_h, depth);
+    src_key = ptr;
+  }
+
+  const uintk* des_key = des_parser.key_gpu(depth);
+  Tensor des_key_tensor;
+  if (des_parser.info().is_key2xyz()) {
+    des_key_tensor = torch::zeros({des_h}, options.dtype(torch::kInt64));
+    uintk* ptr = (uintk*)des_key_tensor.data_ptr<int64_t>();
+    xyz2key_gpu(ptr, des_key, des_h, depth);
+    des_key = ptr;
+  }
+
+  // binary search
+  Tensor idx_tensor = torch::zeros({src_h}, options.dtype(torch::kInt32));
+  int* idx_ptr = idx_tensor.data_ptr<int>();
+  search_key_gpu(idx_ptr, des_key, des_h, src_key, src_h);
+
+  // out data
+  Tensor des_tensor =
+      torch::zeros({1, channel, des_h, 1}, options.dtype(torch::kFloat32));
+  float* des_ptr = des_tensor.data_ptr<float>();
+
+  // exec
+  align_forward_gpu(des_ptr, des_h, channel, src_ptr, src_h, idx_ptr);
+  return {des_tensor, idx_tensor};
+}
+
+Tensor octree_align_grad(Tensor des_grad, Tensor idx_tensor) {
+  // gradients
+  des_grad = des_grad.contiguous();
+  float* des_ptr = des_grad.data_ptr<float>();
+  int channel = des_grad.size(1);
+  int des_h = des_grad.size(2);
+
+  // index
+  int src_h = idx_tensor.size(0);
+  int* idx_ptr = idx_tensor.data_ptr<int>();
+
+  // grad out
+  torch::TensorOptions options = des_grad.options();
+  Tensor src_tensor = torch::zeros({1, channel, src_h, 1}, options);
+  float* src_ptr = src_tensor.data_ptr<float>();
+
+  // exec
+  align_backward_gpu(des_ptr, des_h, channel, src_ptr, src_h, idx_ptr);
+  return src_tensor;
+}
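Editor's note: in plain terms, this new file matches nodes of a source octree to nodes of a destination octree at the same depth by their (sorted) keys, scatters the source features into the destination layout, and returns the match indices so the gradient can be gathered back. A CPU sketch of the forward scatter for one channel, assuming `idx[i]` is the destination slot of source node `i`, with a negative value when the key is absent (my reading of `search_key_gpu`):

```cpp
#include <vector>

// CPU sketch of align_forward: scatter source features into destination
// node order; destination nodes with no matching source key stay zero.
void align_forward_cpu(std::vector<float>& des,        // size des_h, zeroed
                       const std::vector<float>& src,  // size src_h
                       const std::vector<int>& idx) {  // size src_h
  for (size_t i = 0; i < src.size(); ++i) {
    if (idx[i] >= 0) des[idx[i]] = src[i];
  }
}
```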
@@ -8,6 +8,25 @@ namespace {
 using octree::OctreeBaseConv;

+// used for debug
+template <typename dtype>
+void dump_tensor(const Tensor tensor, string filename = "") {
+  int dim = tensor.dim();
+  filename += "_shape";
+  for (int j = 0; j < dim; ++j) {
+    filename += "_" + std::to_string(tensor.size(j));
+  }
+
+  std::cout << filename << std::endl;
+
+  Tensor t = tensor.cpu();
+  int n = t.numel();
+  std::ofstream outfile(filename, std::ios::binary);
+  const dtype* ptr = t.data_ptr<dtype>();
+  outfile.write((char*)ptr, n * sizeof(dtype));
+  outfile.close();
+}
+
 class THGpuGemm : public octree::GEMMEngine<float> {
  public:
   virtual void gemm(const bool TransA, const bool TransB, const int M,
@@ -27,11 +46,15 @@ class THGpuGemm : public octree::GEMMEngine<float> {

 class OctreeConvTH : public OctreeBaseConv<float> {
  public:
-  explicit OctreeConvTH(int depth, int num_output, vector<int> kernel_size, int stride)
-      : depth_(depth), num_output_(num_output), kernel_size_(kernel_size), stride_(stride) {
+  explicit OctreeConvTH(int depth, int num_output, vector<int> kernel_size,
+                        int stride, bool nempty)
+      : depth_(depth), num_output_(num_output), kernel_size_(kernel_size),
+        stride_(stride), non_empty_(nempty) {
     CHECK_GT(depth_, 0) << "The depth should be larger than 0";
     CHECK_GT(num_output_, 0) << "The num_output should be larger than 0";
-    for (auto k : kernel_size_) { CHECK(0 < k && k < 4) << "Invalide kernel size"; }
+    for (auto k : kernel_size_) {
+      CHECK(0 < k && k < 4) << "Invalide kernel size";
+    }
     CHECK(stride_ == 1 || stride_ == 2) << "Unsupport stride";
   }

@@ -44,10 +67,11 @@ class OctreeConvTH : public OctreeBaseConv<float> {

     // setup octree conv
     int channel_in = data_in.size(1), height_btm = data_in.size(2);
-    OctreeBaseConv<float>::setup(kernel_size_, stride_, depth_, channel_in,
-                                 num_output_);
-    if (stride_ == 2 && is_deconvolution_layer()) {
-      CHECK_EQ(height_btm, this->octree_.info().node_num_nempty(depth_));
+    OctreeBaseConv<float>::setup(
+        kernel_size_, stride_, depth_, channel_in, num_output_, non_empty_);
+    if ((stride_ == 2 && is_deconvolution_layer()) || non_empty_) {
+      CHECK_EQ(height_btm, this->octree_.info().node_num_nempty(depth_))
+          << ", d: " << depth_ << ", channel_in: " << channel_in;
     } else {
-      CHECK_EQ(height_btm, this->octree_.info().node_num(depth_));
+      CHECK_EQ(height_btm, this->octree_.info().node_num(depth_))
+          << ", d: " << depth_ << ", channel_in: " << channel_in;
@@ -72,15 +96,6 @@ class OctreeConvTH : public OctreeBaseConv<float> {
       this->result_buffer_ = nullptr;
     }

-    count = num_elements(this->data_buffer_shape_);
-    if (count != 0) {
-      Tensor data_buffer = torch::zeros({count}, options);
-      this->data_buffer_ = data_buffer.data_ptr<float>();
-      tmp_tensors.push_back(data_buffer);
-    } else {
-      this->data_buffer_ = nullptr;
-    }
-
     vector<int>& ni_cpu = NeighHelper::get_ni(kernel_size_);
     count = ni_cpu.size();
     if (count != 0) {
@@ -90,6 +105,17 @@ class OctreeConvTH : public OctreeBaseConv<float> {
       this->ni_gpu_ptr_ = ni_ptr;
       tmp_tensors.push_back(ni_gpu);
     }

+    if (non_empty_) {
+      this->child_ = octree_.children_gpu(this->workspace_depth_);
+      Tensor t0 = torch::arange(this->child_h_, options.dtype(torch::kInt32));
+      Tensor t1 = torch::zeros(this->ichild_h_, options.dtype(torch::kInt32));
+      this->ichild_ = t1.data_ptr<int>();
+      pad_backward_gpu(t1.data_ptr<int>(), this->ichild_h_, 1,
+                       t0.data_ptr<int>(), this->child_h_, this->child_);
+      tmp_tensors.push_back(t1);
+    }
+
     return tmp_tensors;
   }

@@ -98,22 +124,26 @@ class OctreeConvTH : public OctreeBaseConv<float> {
   int num_output_;
   vector<int> kernel_size_;
   int stride_;
+  bool non_empty_;
   THGpuGemm th_gpu_gemm_;
 };

 class OctreeConvOp : public OctreeConvTH {
  public:
-  explicit OctreeConvOp(int depth, int num_output, vector<int> kernel_size, int stride)
-      : OctreeConvTH(depth, num_output, kernel_size, stride) {}
+  explicit OctreeConvOp(int depth, int num_output, vector<int> kernel_size,
+                        int stride, bool nempty)
+      : OctreeConvTH(depth, num_output, kernel_size, stride, nempty) {}

   Tensor compute(Tensor data_in, Tensor weights, Tensor octree) {
     // init
     this->setup_op(data_in, octree);
     torch::TensorOptions options = data_in.options();
     vector<Tensor> tmp_tensors = this->alloc_temp_memory(options);
-    Tensor data_out = torch::zeros({1, this->top_shape_[1], this->top_shape_[2], 1}, options);
+    Tensor data_out =
+        torch::zeros({1, this->top_shape_[1], this->top_shape_[2], 1}, options);

     // get pointers
     data_in = data_in.contiguous();
     const float* btm_data = data_in.data_ptr<float>();
     const float* weights_data = weights.data_ptr<float>();
     float* top_data = data_out.data_ptr<float>();
@@ -128,10 +158,12 @@ class OctreeConvOp : public OctreeConvTH {

 class OctreeConvGradOp : public OctreeConvTH {
  public:
-  explicit OctreeConvGradOp(int depth, int num_output, vector<int> kernel_size, int stride)
-      : OctreeConvTH(depth, num_output, kernel_size, stride) {}
+  explicit OctreeConvGradOp(int depth, int num_output, vector<int> kernel_size,
+                            int stride, bool nempty)
+      : OctreeConvTH(depth, num_output, kernel_size, stride, nempty) {}

-  vector<Tensor> compute(Tensor data_in, Tensor weights, Tensor octree, Tensor diff_in) {
+  vector<Tensor> compute(Tensor data_in, Tensor weights, Tensor octree,
+                         Tensor diff_in) {
     // init
     this->setup_op(data_in, octree);
     vector<Tensor> tmp_tensors = this->alloc_temp_memory(data_in.options());
@@ -139,6 +171,8 @@ class OctreeConvGradOp : public OctreeConvTH {
     Tensor weights_out = torch::zeros_like(weights);

     // get points
+    data_in = data_in.contiguous();
+    diff_in = diff_in.contiguous();
     auto btm_data = data_in.data_ptr<float>();
     auto weights_data = weights.data_ptr<float>();
     auto top_diff = diff_in.data_ptr<float>();
@@ -156,17 +190,20 @@ class OctreeConvGradOp : public OctreeConvTH {

 class OctreeDeconvOp : public OctreeConvTH {
  public:
-  explicit OctreeDeconvOp(int depth, int num_output, vector<int> kernel_size, int stride)
-      : OctreeConvTH(depth, num_output, kernel_size, stride) {}
+  explicit OctreeDeconvOp(int depth, int num_output, vector<int> kernel_size,
+                          int stride, bool nempty)
+      : OctreeConvTH(depth, num_output, kernel_size, stride, nempty) {}

   Tensor compute(Tensor data_in, Tensor weights, Tensor octree) {
     // init
     this->setup_op(data_in, octree);
     torch::TensorOptions options = data_in.options();
     vector<Tensor> tmp_tensors = this->alloc_temp_memory(options);
-    Tensor data_out = torch::zeros({1, this->top_shape_[1], this->top_shape_[2], 1}, options);
+    Tensor data_out =
+        torch::zeros({1, this->top_shape_[1], this->top_shape_[2], 1}, options);

     // get pointers
     data_in = data_in.contiguous();
     const float* btm_data = data_in.data_ptr<float>();
     const float* weights_data = weights.data_ptr<float>();
     float* top_data = data_out.data_ptr<float>();
@@ -181,10 +218,12 @@ class OctreeDeconvOp : public OctreeConvTH {

 class OctreeDeconvGradOp : public OctreeConvTH {
  public:
-  explicit OctreeDeconvGradOp(int depth, int num_output, vector<int> kernel_size, int stride)
-      : OctreeConvTH(depth, num_output, kernel_size, stride) {}
+  explicit OctreeDeconvGradOp(int depth, int num_output,
+                              vector<int> kernel_size, int stride, bool nempty)
+      : OctreeConvTH(depth, num_output, kernel_size, stride, nempty) {}

-  vector<Tensor> compute(Tensor data_in, Tensor weights, Tensor octree, Tensor diff_in) {
+  vector<Tensor> compute(Tensor data_in, Tensor weights, Tensor octree,
+                         Tensor diff_in) {
     // init
     this->setup_op(data_in, octree);
     vector<Tensor> tmp_tensors = this->alloc_temp_memory(data_in.options());
@@ -192,6 +231,8 @@ class OctreeDeconvGradOp : public OctreeConvTH {
     Tensor weights_out = torch::zeros_like(weights);

     // get points
+    data_in = data_in.contiguous();
+    diff_in = diff_in.contiguous();
     auto btm_data = data_in.data_ptr<float>();
     auto weights_data = weights.data_ptr<float>();
     auto top_diff = diff_in.data_ptr<float>();
@@ -208,31 +249,35 @@ class OctreeDeconvGradOp : public OctreeConvTH {
   virtual bool is_deconvolution_layer() { return true; }
 };

-} // anonymous namespace
+}  // anonymous namespace

 // API implementation
 Tensor octree_conv(Tensor data_in, Tensor weights, Tensor octree, int depth,
-                   int num_output, vector<int> kernel_size, int stride) {
-  OctreeConvOp conv_op(depth, num_output, kernel_size, stride);
+                   int num_output, vector<int> kernel_size, int stride,
+                   bool nempty) {
+  OctreeConvOp conv_op(depth, num_output, kernel_size, stride, nempty);
   return conv_op.compute(data_in, weights, octree);
 }

 Tensor octree_deconv(Tensor data_in, Tensor weights, Tensor octree, int depth,
-                     int num_output, vector<int> kernel_size, int stride) {
-  OctreeDeconvOp deconv_op(depth, num_output, kernel_size, stride);
+                     int num_output, vector<int> kernel_size, int stride,
+                     bool nempty) {
+  OctreeDeconvOp deconv_op(depth, num_output, kernel_size, stride, nempty);
   return deconv_op.compute(data_in, weights, octree);
 }

 vector<Tensor> octree_conv_grad(Tensor data_in, Tensor weights, Tensor octree,
                                 Tensor grad_in, int depth, int num_output,
-                                vector<int> kernel_size, int stride) {
-  OctreeConvGradOp grad_op(depth, num_output, kernel_size, stride);
+                                vector<int> kernel_size, int stride,
+                                bool nempty) {
+  OctreeConvGradOp grad_op(depth, num_output, kernel_size, stride, nempty);
   return grad_op.compute(data_in, weights, octree, grad_in);
 }

 vector<Tensor> octree_deconv_grad(Tensor data_in, Tensor weights, Tensor octree,
-                                  Tensor grad_in, int depth, int num_output,
-                                  vector<int> kernel_size, int stride) {
-  OctreeDeconvGradOp grad_op(depth, num_output, kernel_size, stride);
+                                  Tensor grad_in, int depth, int num_output,
+                                  vector<int> kernel_size, int stride,
+                                  bool nempty) {
+  OctreeDeconvGradOp grad_op(depth, num_output, kernel_size, stride, nempty);
   return grad_op.compute(data_in, weights, octree, grad_in);
 }
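Editor's note: the `dump_tensor` helper added at the top of this file writes a tensor's raw bytes to disk with the shape encoded in the file name, which is handy when diffing intermediate activations against another implementation. A hypothetical call (`top_data` is a stand-in for any float tensor in scope):

```cpp
// Writes a binary file named e.g. "conv_out_shape_1_64_4096_1" holding the
// raw float32 contents; the "_shape_..." suffix is appended by dump_tensor.
dump_tensor<float>(top_data, "conv_out");
```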
@@ -0,0 +1,190 @@
+#include <octree/octree_nn.h>
+#include <octree/octree_parser.h>
+
+#include "ocnn.h"
+
+namespace {
+
+class OctreeGrowOp {
+ public:
+  explicit OctreeGrowOp(int target_depth, bool full_octree)
+      : target_depth_(target_depth), full_octree_(full_octree) {}
+
+  Tensor compute(Tensor tensor_in) {
+    // in octree
+    OctreeParser octree_in;
+    octree_in.set_gpu(tensor_in.data_ptr<uint8_t>());
+
+    // out info
+    batch_size_ = octree_in.info().batch_size();
+    node_num_ = octree_in.info().node_num_nempty(target_depth_ - 1) << 3;
+    OctreeInfo oct_info = octree_in.info();
+    update_octreeinfo(oct_info);
+
+    // out octree
+    torch::TensorOptions options = tensor_in.options();
+    Tensor tensor_out = torch::zeros({oct_info.sizeof_octree()}, options);
+
+    // copy octree
+    OctreeParser octree_out;
+    octree_out.set_gpu(tensor_out.data_ptr<uint8_t>(), &oct_info);
+    copy_octree_gpu(octree_out, octree_in);
+
+    // grow octree
+    if (full_octree_) {
+      calc_neigh_gpu(octree_out.mutable_neighbor_gpu(target_depth_),
+                     target_depth_, batch_size_);
+      generate_key_gpu(octree_out.mutable_key_gpu(target_depth_), target_depth_,
+                       batch_size_);
+      sequence_gpu(octree_out.mutable_children_gpu(target_depth_), node_num_);
+    } else {
+      vector<Tensor> tmp = init_neigh_ptrs(options);
+      const int* label_ptr = octree_out.children_gpu(target_depth_ - 1);
+      calc_neigh_gpu(octree_out.mutable_neighbor_gpu(target_depth_),
+                     octree_out.neighbor_gpu(target_depth_ - 1), label_ptr,
+                     octree_out.info().node_num(target_depth_ - 1), ptr_parent_,
+                     ptr_dis_);
+      generate_key_gpu(octree_out.mutable_key_gpu(target_depth_),
+                       octree_out.key_gpu(target_depth_ - 1), label_ptr,
+                       octree_out.info().node_num(target_depth_ - 1));
+      sequence_gpu(octree_out.mutable_children_gpu(target_depth_), node_num_);
+    }
+
+    return tensor_out;
+  }
+
+ private:
+  void update_octreeinfo(OctreeInfo& oct_info) {
+    oct_info.set_depth(target_depth_);
+    if (full_octree_) {
+      oct_info.set_full_layer(target_depth_);
+    }
+    float width = 1 << target_depth_;
+    float bbmin[] = {0, 0, 0};
+    float bbmax[] = {width, width, width};
+    oct_info.set_bbox(bbmin, bbmax);
+    oct_info.set_nnum(target_depth_, node_num_);
+    // Just set the non-empty node number as node_num_;
+    // it needs to be updated by the new node-splitting label.
+    oct_info.set_nempty(target_depth_, node_num_);
+    oct_info.set_nnum_cum();
+    oct_info.set_ptr_dis();
+  }
+
+  void copy_octree_gpu(OctreeParser& octree_o, const OctreeParser& octree_i) {
+    int node_num_cum = octree_i.info().node_num_cum(target_depth_);
+    int num = node_num_cum * octree_i.info().channel(OctreeInfo::kKey);
+    memcpy_gpu(num, octree_i.key_gpu(0), octree_o.mutable_key_gpu(0));
+
+    num = node_num_cum * octree_i.info().channel(OctreeInfo::kChild);
+    memcpy_gpu(num, octree_i.children_gpu(0), octree_o.mutable_children_gpu(0));
+
+    num = node_num_cum * octree_i.info().channel(OctreeInfo::kNeigh);
+    memcpy_gpu(num, octree_i.neighbor_gpu(0), octree_o.mutable_neighbor_gpu(0));
+
+    num = node_num_cum * octree_i.info().channel(OctreeInfo::kFeature);
+    memcpy_gpu(num, octree_i.feature_gpu(0), octree_o.mutable_feature_gpu(0));
+  }
+
+  vector<Tensor> init_neigh_ptrs(torch::TensorOptions options) {
+    const vector<int>& dis_cpu = NeighHelper::Get().get_dis_array();
+    int count = dis_cpu.size();
+    Tensor dis_gpu = torch::zeros({count}, options.dtype(torch::kInt32));
+    ptr_dis_ = dis_gpu.data_ptr<int>();
+    memcpy_gpu(count, dis_cpu.data(), ptr_dis_);
+
+    const vector<int>& parent_cpu = NeighHelper::Get().get_parent_array();
+    count = parent_cpu.size();
+    Tensor parent_gpu = torch::zeros({count}, options.dtype(torch::kInt32));
+    ptr_parent_ = parent_gpu.data_ptr<int>();
+    memcpy_gpu(count, parent_cpu.data(), ptr_parent_);
+
+    return vector<Tensor>{dis_gpu, parent_gpu};
+  }
+
+ private:
+  int batch_size_;
+  int target_depth_;
+  int node_num_;
+  bool full_octree_;
+  int* ptr_parent_;
+  int* ptr_dis_;
+};
+
+}  // anonymous namespace
+
+// API implementation
+Tensor octree_grow(Tensor octree_in, int target_depth, bool full_octree) {
+  OctreeGrowOp op(target_depth, full_octree);
+  return op.compute(octree_in);
+}
+
+Tensor octree_new(int batch_size, int channel, bool node_dis, int adaptive_layer) {
+  CHECK_GE(batch_size, 1);
+  int node_num = batch_size;
+  int depth = 0;
+
+  // octree info
+  OctreeInfo oct_info;
+  oct_info.set_batch_size(batch_size);
+  oct_info.set_depth(depth);
+  oct_info.set_full_layer(depth);
+  oct_info.set_node_dis(node_dis);
+  if (adaptive_layer > 1) {
+    oct_info.set_adaptive(true);
+    oct_info.set_adaptive_layer(adaptive_layer);
+  } else {
+    oct_info.set_adaptive(false);
+  }
+  oct_info.set_key2xyz(true);  // !!! NOTE: set_key2xyz with true
+  oct_info.set_property(OctreeInfo::kKey, 1, -1);
+  oct_info.set_property(OctreeInfo::kChild, 1, -1);
+  oct_info.set_property(OctreeInfo::kNeigh, 8, -1);
+  oct_info.set_property(OctreeInfo::kFeature, channel, -1);
+  float bbmin[] = {0, 0, 0};
+  float bbmax[] = {2, 2, 2};
+  oct_info.set_bbox(bbmin, bbmax);
+  oct_info.set_nnum(depth, node_num);
+  oct_info.set_nnum_cum();
+  oct_info.set_nempty(depth, node_num);
+  oct_info.set_ptr_dis();
+
+  // init output tensor
+  Tensor tensor_out = torch::zeros({oct_info.sizeof_octree()}, torch::kUInt8);
+
+  // set octree, skip the properties neigh and feature
+  OctreeParser octree_out;
+  octree_out.set_cpu(tensor_out.data_ptr<uint8_t>(), &oct_info);
+  sequence_cpu(octree_out.mutable_key_cpu(depth), node_num);  // !!! NOTE: inconsistent with L140
+  sequence_cpu(octree_out.mutable_children_cpu(depth), node_num);
+
+  return tensor_out.cuda();
+}
+
+Tensor octree_update(Tensor octree_in, Tensor label_in, int depth, int split) {
+  Tensor tensor_out = octree_in.clone();
+  uint8_t* out_ptr = tensor_out.data_ptr<uint8_t>();
+  OctreeParser octree_;
+  octree_.set_gpu(out_ptr);
+  int node_num = octree_.info().node_num(depth);
+
+  label_in = label_in.contiguous();
+  int* label_ptr = label_in.data_ptr<int>();
+  CHECK_EQ(node_num, label_in.numel());
+
+  // update children
+  int split_num = 0;  // non-empty node number
+  int* children = octree_.mutable_children_gpu(depth);
+  generate_label_gpu(children, split_num, label_ptr, node_num, split);
+
+  // deal with the degenerated case
+  if (split_num == 0) {
+    split_num = 1;
+    memset_gpu(1, 0, children);
+    LOG(INFO) << "Warning: split_num == 0 in octree update layer.";
+  }
+
+  octree_.mutable_info().set_nempty(depth, split_num);
+  memcpy_gpu(sizeof(OctreeInfo), (const char*)&octree_.info(), (char*)out_ptr);
+  return tensor_out;
+}
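Editor's note: `octree_new`, `octree_grow`, and `octree_update` together enable decoder-style octree generation: start from a depth-0 octree, then repeatedly grow one level, predict per-node split labels, and update the children and non-empty count. A heavily hedged sketch of such a loop; `predict_split_labels` is a hypothetical network call, not part of this commit, and the exact grow/update ordering may differ in the actual training code:

```cpp
// Sketch of a generation loop under the assumptions stated above.
Tensor octree = octree_new(/*batch_size=*/1, /*channel=*/4,
                           /*node_dis=*/false, /*adaptive_layer=*/0);
for (int d = 1; d <= 6; ++d) {
  octree = octree_grow(octree, /*target_depth=*/d, /*full_octree=*/d <= 2);
  Tensor label = predict_split_labels(octree, d);  // hypothetical predictor
  octree = octree_update(octree, label, d, /*split=*/1);
}
```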
@@ -9,6 +9,7 @@
 // if KEY64 is defined, uintk is uint64

 Tensor octree_encode_key(Tensor xyz) {
+  xyz = xyz.contiguous();  // !!! make sure the Tensor is contiguous
   auto ptr_in = xyz.data_ptr<int16_t>();
   int num = xyz.size(0);
   int channel = xyz.size(1);
@@ -27,6 +28,7 @@ Tensor octree_encode_key(Tensor xyz) {
 }

 Tensor octree_decode_key(Tensor key) {
+  key = key.contiguous();  // !!! make sure the Tensor is contiguous
   auto ptr_in = key.data_ptr<int64_t>();
   int num = key.size(0);
   CHECK_EQ(key.dim(), 1) << "The dim of input tensor must be 1.";
@@ -43,6 +45,7 @@ Tensor octree_decode_key(Tensor key) {
 }

 Tensor octree_xyz2key(Tensor xyz, int depth) {
+  xyz = xyz.contiguous();  // !!! make sure the Tensor is contiguous
   auto ptr_in = xyz.data_ptr<int64_t>();
   int num = xyz.numel();
   CHECK_GE(num, 1) << "The numel of the input tensor must be larger than 1.";
@@ -54,6 +57,7 @@ Tensor octree_xyz2key(Tensor xyz, int depth) {
 }

 Tensor octree_key2xyz(Tensor key, int depth) {
+  key = key.contiguous();  // !!! make sure the Tensor is contiguous
   auto ptr_in = key.data_ptr<int64_t>();
   int num = key.numel();
   CHECK_GE(num, 1) << "The numel of the input tensor must be larger than 1.";
@@ -64,14 +68,17 @@ Tensor octree_key2xyz(Tensor key, int depth) {
   return xyz;
 }

-Tensor octree_search_key(Tensor key, Tensor octree, int depth, bool is_in_xyz) {
+Tensor octree_search_key(Tensor key, Tensor octree, int depth, bool key_is_xyz,
+                         bool nempty) {
+  key = key.contiguous();
+  octree = octree.contiguous();  // !!! make sure the Tensor is contiguous
   int64_t* src_key = key.data_ptr<int64_t>();
   int src_h = key.numel();
   CHECK_GE(src_h, 1) << "The numel of the input tensor must be larger than 1.";
   torch::TensorOptions options = key.options();

   Tensor src_key_tmp;
-  if (is_in_xyz) {
+  if (key_is_xyz) {
     src_key_tmp = torch::zeros_like(key);
     int64_t* tmp = src_key_tmp.data_ptr<int64_t>();
     xyz2key_gpu((uintk*)tmp, (uintk*)src_key, src_h, depth);
@@ -83,6 +90,16 @@ Tensor octree_search_key(Tensor key, Tensor octree, int depth, bool is_in_xyz) {
   int des_h = octree_.info().node_num(depth);
   const uintk* des_key = octree_.key_gpu(depth);

+  Tensor key_tmp;
+  if (nempty) {  // Search the non-empty octree nodes only
+    int top_h = des_h;                              // cache old des_h
+    des_h = octree_.info().node_num_nempty(depth);  // update des_h
+    key_tmp = torch::zeros({des_h}, options);
+    int64_t* tmp = key_tmp.data_ptr<int64_t>();
+    pad_backward_gpu((uintk*)tmp, des_h, 1, des_key, top_h, octree_.children_gpu(depth));
+    des_key = (const uintk*)tmp;
+  }
+
   Tensor des_key_tmp;
   if (octree_.info().is_key2xyz()) {
     des_key_tmp = torch::zeros({des_h}, options);
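Editor's note: with `nempty = true`, the candidate key list is first depadded to the non-empty nodes, so the returned indices refer to the compact non-empty ordering rather than the full node list. A hypothetical query sketch (`xyz_keys` is an assumed int64 tensor of encoded keys; a negative index for a missing key is my reading of `search_key_gpu`):

```cpp
// Look up xyz-encoded keys among the non-empty nodes at depth 5.
Tensor idx = octree_search_key(xyz_keys, octree, /*depth=*/5,
                               /*key_is_xyz=*/true, /*nempty=*/true);
// idx[i] indexes into the depth-5 non-empty node order, matching the
// layout of depadded feature tensors at that depth.
```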
@@ -3,89 +3,58 @@

 #include "ocnn.h"

-namespace {
+Tensor octree_pad(Tensor data_in, Tensor octree_in, int depth, float val) {
+  CHECK_GE(depth, 1) << "Depth should be larger than 1";

-class OctreePadOp {
- public:
-  explicit OctreePadOp(int depth, float val = 0.0f)
-      : depth_(depth), dval_(val) {
-    CHECK_GE(depth_, 1) << "Depth should be larger than 1";
-  }
+  // in octree
+  OctreeParser octree_;
+  octree_.set_gpu(octree_in.data_ptr<uint8_t>());

-  Tensor compute(Tensor btm_data, Tensor octree) {
-    // in octree
-    OctreeParser octree_;
-    octree_.set_gpu(octree.data_ptr<uint8_t>());
+  // btm data
+  Tensor btm_data = data_in.contiguous();
+  const float* btm_ptr = btm_data.data_ptr<float>();
+  int channel = btm_data.size(1);
+  int btm_h = btm_data.size(2);

-    // btm data
-    const float* btm_ptr = btm_data.data_ptr<float>();
-    int channel = btm_data.size(1);
-    int btm_h = btm_data.size(2);
+  // check
+  CHECK_EQ(octree_.info().node_num_nempty(depth), btm_h)
+      << ", pad, d = " << depth << ", channel = " << channel;

-    // check
-    CHECK_EQ(octree_.info().node_num_nempty(depth_), btm_h)
-        << ", pad, d = " << depth_ << ", channel = " << channel;
+  // top data
+  int top_h = octree_.info().node_num(depth);
+  Tensor top_data = torch::zeros({1, channel, top_h, 1}, btm_data.options());
+  float* top_ptr = top_data.data_ptr<float>();

-    // top data
-    int top_h = octree_.info().node_num(depth_);
-    Tensor top_data = torch::zeros({1, channel, top_h, 1}, btm_data.options());
-    float* top_ptr = top_data.data_ptr<float>();
-
-    // padding data
-    pad_forward_gpu(top_ptr, top_h, channel, btm_ptr, btm_h,
-                    octree_.children_gpu(depth_), dval_);
-    return top_data;
-  }
-
- protected:
-  int depth_;
-  float dval_;
-};
-
-class OctreeDepadOp {
- public:
-  explicit OctreeDepadOp(int depth) : depth_(depth) {
-    CHECK_GE(depth_, 1) << "Depth should be larger than 1";
-  }
-
-  Tensor compute(Tensor top_data, Tensor octree) {
-    // in octree
-    OctreeParser octree_;
-    octree_.set_gpu(octree.data_ptr<uint8_t>());
-
-    // top grad
-    const float* top_ptr = top_data.data_ptr<float>();
-    int channel = top_data.size(1);
-    int top_h = top_data.size(2);
-
-    // check
-    CHECK_EQ(octree_.info().node_num(depth_), top_h)
-        << ", depad, d = " << depth_ << ", channel = " << channel;
-
-    // btm grad
-    int btm_h = octree_.info().node_num_nempty(depth_);
-    Tensor btm_data = torch::zeros({1, channel, btm_h, 1}, top_data.options());
-    float* btm_ptr = btm_data.data_ptr<float>();
-
-    // padding data
-    pad_backward_gpu(btm_ptr, btm_h, channel, top_ptr, top_h,
-                     octree_.children_gpu(depth_));
-    return btm_data;
-  }
-
- protected:
-  int depth_;
-};
-
-} // anonymous namespace
-
-// API implementation
-Tensor octree_pad(Tensor data_in, Tensor octree, int depth, float val) {
-  OctreePadOp pad_op(depth, val);
-  return pad_op.compute(data_in, octree);
+  // padding data
+  pad_forward_gpu(top_ptr, top_h, channel, btm_ptr, btm_h,
+                  octree_.children_gpu(depth), val);
+  return top_data;
 }

-Tensor octree_depad(Tensor data_in, Tensor octree, int depth) {
-  OctreeDepadOp depad_op(depth);
-  return depad_op.compute(data_in, octree);
+Tensor octree_depad(Tensor data_in, Tensor octree_in, int depth) {
+  CHECK_GE(depth, 1) << "Depth should be larger than 1";
+
+  // in octree
+  OctreeParser octree_;
+  octree_.set_gpu(octree_in.data_ptr<uint8_t>());
+
+  // top grad
+  Tensor top_data = data_in.contiguous();
+  const float* top_ptr = top_data.data_ptr<float>();
+  int channel = top_data.size(1);
+  int top_h = top_data.size(2);
+
+  // check
+  CHECK_EQ(octree_.info().node_num(depth), top_h)
+      << ", depad, d = " << depth << ", channel = " << channel;
+
+  // btm grad
+  int btm_h = octree_.info().node_num_nempty(depth);
+  Tensor btm_data = torch::zeros({1, channel, btm_h, 1}, top_data.options());
+  float* btm_ptr = btm_data.data_ptr<float>();
+
+  // padding data
+  pad_backward_gpu(btm_ptr, btm_h, channel, top_ptr, top_h,
+                   octree_.children_gpu(depth));
+  return btm_data;
 }
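Editor's note: `octree_pad` scatters features of the non-empty nodes back into the full node list at `depth` (empty slots receive `val`), and `octree_depad` is the inverse gather. A CPU sketch of the pair for one channel, again assuming `child[i] >= 0` stores the compact index of non-empty node `i` (not code from the commit):

```cpp
#include <vector>

// pad: non-empty order -> full order; empty slots receive `val`.
std::vector<float> pad_cpu(const std::vector<float>& btm,
                           const std::vector<int>& child, float val) {
  std::vector<float> top(child.size(), val);
  for (size_t i = 0; i < child.size(); ++i) {
    if (child[i] >= 0) top[i] = btm[child[i]];
  }
  return top;
}

// depad: full order -> non-empty order (drops values at empty nodes).
std::vector<float> depad_cpu(const std::vector<float>& top,
                             const std::vector<int>& child, int btm_h) {
  std::vector<float> btm(btm_h, 0.0f);
  for (size_t i = 0; i < child.size(); ++i) {
    if (child[i] >= 0) btm[child[i]] = top[i];
  }
  return btm;
}
```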
@@ -9,6 +9,7 @@ vector<Tensor> octree_max_pool(Tensor btm_data, Tensor octree, int depth) {
   octree_.set_gpu(octree.data_ptr<uint8_t>());

   // btm data
+  btm_data = btm_data.contiguous();
   const float* btm_ptr = btm_data.data_ptr<float>();
   int channel = btm_data.size(1);
   int btm_h = btm_data.size(2);
@@ -38,12 +39,14 @@ Tensor octree_max_unpool(Tensor top_data, Tensor mask, Tensor octree, int depth)
   octree_.set_gpu(octree.data_ptr<uint8_t>());

   // top data
+  top_data = top_data.contiguous();
   const float* top_ptr = top_data.data_ptr<float>();
   int channel = top_data.size(1);
   int top_h = top_data.size(2);
   CHECK_EQ(top_h, octree_.info().node_num_nempty(depth - 1));

   // mask
+  mask = mask.contiguous();
   const int* mask_ptr = mask.data_ptr<int>();
   CHECK(mask.size(1) == channel && mask.size(2) == top_h);

@@ -63,11 +66,13 @@ Tensor octree_mask_pool(Tensor btm_data, Tensor mask, Tensor octree, int depth)
   octree_.set_gpu(octree.data_ptr<uint8_t>());

   // btm data
+  btm_data = btm_data.contiguous();
   const float* btm_ptr = btm_data.data_ptr<float>();
   int channel = btm_data.size(1);
   int btm_h = btm_data.size(2);

   // mask
+  mask = mask.contiguous();
   auto mask_ptr = mask.data_ptr<int>();
   int top_h = mask.size(2);
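Editor's note: a recurring change in this commit is calling `.contiguous()` before `data_ptr<T>()`. `data_ptr` hands out raw memory with no stride information, so a non-contiguous view (e.g. after a transpose) would be read in the wrong element order; `.contiguous()` materializes a dense copy only when one is actually needed. A small standalone illustration, not from the commit:

```cpp
#include <torch/torch.h>

void contiguous_example() {
  torch::Tensor a = torch::rand({2, 3});
  torch::Tensor t = a.transpose(0, 1);   // a view; not contiguous
  // Reading t.data_ptr<float>() linearly would traverse a's layout,
  // i.e. the wrong logical order for t.
  torch::Tensor c = t.contiguous();      // dense copy in t's logical order
  const float* p = c.data_ptr<float>();  // now safe to read linearly
  (void)p;
}
```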
@@ -12,87 +12,133 @@ Tensor octree_property_gpu(Tensor octree_in, string property, int depth) {
   OctreeParser octree_;
   octree_.set_gpu(octree_in.data_ptr<uint8_t>());

+  int octree_depth = octree_.info().depth();
   int node_num = octree_.info().node_num(depth);
+  int total_node_num = octree_.info().total_nnum();
+  int nnum = depth > 0 ? node_num : total_node_num;

   torch::TensorOptions options = octree_in.options();
   Tensor data_out = torch::zeros({1}, options);

   if (property == "key") {
     const uintk* ptr = octree_.key_gpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kKey);  // = 1
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kInt64));
     memcpy_gpu(total_num, ptr, (uintk*)data_out.data_ptr<int64_t>());
   }

-  if (property == "xyz") {
+  else if (property == "xyz") {
     const uintk* ptr = octree_.key_gpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kKey);  // = 1
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kInt64));
+    uintk* des_ptr = (uintk*)data_out.data_ptr<int64_t>();
     if (!octree_.info().is_key2xyz()) {
-      key2xyz_gpu((uintk*)data_out.data_ptr<int64_t>(), ptr, total_num, depth);
+      if (depth > 0) {
+        key2xyz_gpu(des_ptr, ptr, total_num, depth);
+      } else {
+        for (int d = 1; d < octree_depth + 1; d++) {
+          int nnum_d = octree_.info().node_num(d);
+          int ncum_d = octree_.info().node_num_cum(d);
+          key2xyz_gpu(des_ptr + ncum_d, ptr + ncum_d, nnum_d, d);
+        }
+      }
     } else {
-      memcpy_gpu(total_num, ptr, (uintk*)data_out.data_ptr<int64_t>());
+      memcpy_gpu(total_num, ptr, des_ptr);
     }
   }

-  if (property == "index") {
+  else if (property == "index") {
     const uintk* key_ptr = octree_.key_gpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kKey);  // = 1
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kInt32));
     key2idx_gpu(data_out.data_ptr<int>(), key_ptr, total_num);
   }

-  if (property == "child") {
+  else if (property == "child") {
     const int* child_ptr = octree_.children_gpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kChild);  // = 1
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kInt32));
     memcpy_gpu(total_num, child_ptr, data_out.data_ptr<int>());
   }

-  if (property == "neigh") {
+  else if (property == "neigh") {
     const int* neigh_ptr = octree_.neighbor_gpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kNeigh);
-    int total_num = channel * node_num;
-    data_out = torch::zeros({node_num, channel}, options.dtype(torch::kInt32));
+    int total_num = channel * nnum;
+    data_out = torch::zeros({nnum, channel}, options.dtype(torch::kInt32));
     memcpy_gpu(total_num, neigh_ptr, data_out.data_ptr<int>());
   }

-  if (property == "feature") {
+  else if (property == "feature") {
     const float* feature_ptr = octree_.feature_gpu(depth);
     CHECK(feature_ptr != nullptr) << "The features do not exist: d = " << depth;
     int channel = octree_.info().channel(OctreeInfo::kFeature);
-    int total_num = channel * node_num;
-    data_out = torch::zeros({1, channel, node_num, 1}, options.dtype(torch::kFloat32));
+    int total_num = channel * nnum;
+    data_out = torch::zeros({1, channel, nnum, 1}, options.dtype(torch::kFloat32));
     memcpy_gpu(total_num, feature_ptr, data_out.data_ptr<float>());
   }

-  if (property == "label") {
+  else if (property == "label") {
     const float* label_ptr = octree_.label_gpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kLabel);
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kFloat32));
     memcpy_gpu(total_num, label_ptr, data_out.data_ptr<float>());
   }

-  if (property == "split") {
+  else if (property == "split") {
     const float* split_ptr = octree_.split_gpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kSplit);
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kFloat32));
     memcpy_gpu(total_num, split_ptr, data_out.data_ptr<float>());
   }

-  if (property == "node_num") {
-    data_out = torch::zeros({1}, options.dtype(torch::kInt32));
-    memcpy_gpu(1, &node_num, data_out.data_ptr<int>());
+  else if (property == "node_num") {
+    int num = depth > 0 ? 1 : octree_depth + 1;
+    data_out = torch::zeros({num}, options.dtype(torch::kInt32));
+    const int* ptr = octree_.info().node_num_ptr();
+    memcpy_gpu(num, ptr + depth, data_out.data_ptr<int>());
   }

-  if (property == "node_num_ne" || property == "node_num_nempty") {
-    int num = octree_.info().node_num_nempty(depth);
-    data_out = torch::zeros({1}, options.dtype(torch::kInt32));
-    memcpy_gpu(1, &num, data_out.data_ptr<int>());
+  else if (property == "node_num_ne" || property == "node_num_nempty") {
+    int num = depth > 0 ? 1 : octree_depth + 1;
+    data_out = torch::zeros({num}, options.dtype(torch::kInt32));
+    const int* ptr = octree_.info().node_nempty_ptr();
+    memcpy_gpu(num, ptr + depth, data_out.data_ptr<int>());
+  }
+
+  else if (property == "node_num_cum") {
+    int num = depth > 0 ? 1 : octree_depth + 2;
+    const int* ptr = octree_.info().node_num_cum_ptr();
+    data_out = torch::zeros({num}, options.dtype(torch::kInt32));
+    memcpy_gpu(num, ptr + depth, data_out.data_ptr<int>());
+  }
+
+  else if (property == "batch_size") {
+    int batch_size = octree_.info().batch_size();
+    data_out = torch::zeros({1}, options.dtype(torch::kInt32));
+    memcpy_gpu(1, &batch_size, data_out.data_ptr<int>());
+  }
+
+  else if (property == "depth") {
+    int depth = octree_.info().depth();
+    data_out = torch::zeros({1}, options.dtype(torch::kInt32));
+    memcpy_gpu(1, &depth, data_out.data_ptr<int>());
+  }
+
+  else if (property == "full_depth") {
+    int full_depth = octree_.info().full_layer();
+    data_out = torch::zeros({1}, options.dtype(torch::kInt32));
+    memcpy_gpu(1, &full_depth, data_out.data_ptr<int>());
+  }
+
+  else {
+    LOG(FATAL) << "Unsupport octree property: " << property;
   }

   return data_out;
@@ -102,92 +148,155 @@ Tensor octree_property_cpu(Tensor octree_in, string property, int depth) {
   OctreeParser octree_;
   octree_.set_cpu(octree_in.data_ptr<uint8_t>());

+  int octree_depth = octree_.info().depth();
   int node_num = octree_.info().node_num(depth);
+  int total_node_num = octree_.info().total_nnum();
+  int nnum = depth > 0 ? node_num : total_node_num;

   torch::TensorOptions options = octree_in.options();
   Tensor data_out = torch::zeros({1}, options);

   if (property == "key") {
     const uintk* ptr = octree_.key_cpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kKey);  // = 1
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kInt64));
     memcpy_cpu(total_num, ptr, (uintk*)data_out.data_ptr<int64_t>());
   }

-  if (property == "xyz") {
+  else if (property == "xyz") {
     const uintk* ptr = octree_.key_cpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kKey);  // = 1
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kInt64));
+    uintk* des_ptr = (uintk*)data_out.data_ptr<int64_t>();
     if (!octree_.info().is_key2xyz()) {
-      key2xyz_cpu((uintk*)data_out.data_ptr<int64_t>(), ptr, total_num, depth);
+      if (depth > 0) {
+        key2xyz_cpu(des_ptr, ptr, total_num, depth);
+      } else {
+        for (int d = 1; d < octree_depth + 1; d++) {
+          int nnum_d = octree_.info().node_num(d);
+          int ncum_d = octree_.info().node_num_cum(d);
+          key2xyz_cpu(des_ptr + ncum_d, ptr + ncum_d, nnum_d, d);
+        }
+      }
     } else {
-      memcpy_cpu(total_num, ptr, (uintk*)data_out.data_ptr<int64_t>());
+      memcpy_cpu(total_num, ptr, des_ptr);
     }
   }

-  if (property == "index") {
+  else if (property == "index") {
     const uintk* key_ptr = octree_.key_cpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kKey);  // = 1
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kInt32));
     key2idx_cpu(data_out.data_ptr<int>(), key_ptr, total_num);
   }

-  if (property == "child") {
+  else if (property == "child") {
     const int* child_ptr = octree_.children_cpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kChild);  // = 1
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kInt32));
     memcpy_cpu(total_num, child_ptr, data_out.data_ptr<int>());
   }

-  if (property == "neigh") {
+  else if (property == "neigh") {
     const int* neigh_ptr = octree_.neighbor_cpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kNeigh);
-    int total_num = channel * node_num;
-    data_out = torch::zeros({node_num, channel}, options.dtype(torch::kInt32));
+    int total_num = channel * nnum;
+    data_out = torch::zeros({nnum, channel}, options.dtype(torch::kInt32));
     memcpy_cpu(total_num, neigh_ptr, data_out.data_ptr<int>());
   }

-  if (property == "feature") {
+  else if (property == "feature") {
     const float* feature_ptr = octree_.feature_cpu(depth);
     CHECK(feature_ptr != nullptr) << "The features do not exist: d = " << depth;
     int channel = octree_.info().channel(OctreeInfo::kFeature);
-    int total_num = channel * node_num;
-    data_out = torch::zeros({1, channel, node_num, 1}, options.dtype(torch::kFloat32));
+    int total_num = channel * nnum;
+    data_out = torch::zeros({1, channel, nnum, 1}, options.dtype(torch::kFloat32));
     memcpy_cpu(total_num, feature_ptr, data_out.data_ptr<float>());
   }

-  if (property == "label") {
+  else if (property == "label") {
     const float* label_ptr = octree_.label_cpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kLabel);
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kFloat32));
     memcpy_cpu(total_num, label_ptr, data_out.data_ptr<float>());
   }

-  if (property == "split") {
+  else if (property == "split") {
     const float* split_ptr = octree_.split_cpu(depth);
     int channel = octree_.info().channel(OctreeInfo::kSplit);
-    int total_num = channel * node_num;
+    int total_num = channel * nnum;
     data_out = torch::zeros({total_num}, options.dtype(torch::kFloat32));
     memcpy_cpu(total_num, split_ptr, data_out.data_ptr<float>());
   }

-  if (property == "node_num") {
-    data_out = torch::zeros({1}, options.dtype(torch::kInt32));
-    memcpy_cpu(1, &node_num, data_out.data_ptr<int>());
+  else if (property == "node_num") {
+    int num = depth > 0 ? 1 : octree_depth + 1;
+    data_out = torch::zeros({num}, options.dtype(torch::kInt32));
+    const int* ptr = octree_.info().node_num_ptr();
+    memcpy_cpu(num, ptr + depth, data_out.data_ptr<int>());
   }

-  if (property == "node_num_ne" || property == "node_num_nempty") {
-    int num = octree_.info().node_num_nempty(depth);
-    data_out = torch::zeros({1}, options.dtype(torch::kInt32));
-    memcpy_cpu(1, &num, data_out.data_ptr<int>());
+  else if (property == "node_num_ne" || property == "node_num_nempty") {
+    int num = depth > 0 ? 1 : octree_depth + 1;
+    data_out = torch::zeros({num}, options.dtype(torch::kInt32));
+    const int* ptr = octree_.info().node_nempty_ptr();
+    memcpy_cpu(num, ptr + depth, data_out.data_ptr<int>());
+  }
+
+  else if (property == "node_num_cum") {
+    int num = depth > 0 ? 1 : octree_depth + 2;
+    const int* ptr = octree_.info().node_num_cum_ptr();
+    data_out = torch::zeros({num}, options.dtype(torch::kInt32));
+    memcpy_cpu(num, ptr + depth, data_out.data_ptr<int>());
+  }
+
+  else if (property == "batch_size") {
+    int batch_size = octree_.info().batch_size();
+    data_out = torch::zeros({1}, options.dtype(torch::kInt32));
+    memcpy_cpu(1, &batch_size, data_out.data_ptr<int>());
+  }
+
+  else if (property == "depth") {
+    int depth = octree_.info().depth();
+    data_out = torch::zeros({1}, options.dtype(torch::kInt32));
+    memcpy_cpu(1, &depth, data_out.data_ptr<int>());
+  }
+
+  else if (property == "full_depth") {
+    int full_depth = octree_.info().full_layer();
+    data_out = torch::zeros({1}, options.dtype(torch::kInt32));
+    memcpy_cpu(1, &full_depth, data_out.data_ptr<int>());
+  }
+
+  else {
+    LOG(FATAL) << "Unsupport octree property: " << property;
   }

   return data_out;
 }

+Tensor octree_set_property_gpu(Tensor octree_in, Tensor data_in, int depth) {
+  Tensor octree_out = octree_in.clone();
+
+  OctreeParser octree_;
+  octree_.set_gpu(octree_out.data_ptr<uint8_t>());
+  float* property_ptr = octree_.mutable_feature_gpu(depth);
+
+  int length = octree_.info().node_num(depth);
+  int channel = octree_.info().channel(OctreeInfo::kFeature);
+  int count = length * channel;
+  data_in = data_in.contiguous();
+  CHECK_EQ(count, data_in.numel()) << "Wrong Property Size";
+  memcpy_gpu(count, data_in.data_ptr<float>(), property_ptr);
+
+  return octree_out;
+}
+
 }  // anonymous namespace

 // API implementation
@@ -197,4 +306,10 @@ Tensor octree_property(Tensor octree_in, string property, int depth) {
   } else {
     return octree_property_cpu(octree_in, property, depth);
   }
 }
+
+Tensor octree_set_property(Tensor octree_in, Tensor data_in, int depth) {
+  CHECK(octree_in.is_cuda());
+  CHECK(data_in.is_cuda());
+  return octree_set_property_gpu(octree_in, data_in, depth);
+}
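Editor's note: the behavioral change in this file is that `depth = 0` (now the pybind default) makes the getters return data over all depths at once: `nnum` falls back to the total node count, and the count-style queries return one entry per depth instead of a scalar. A hypothetical query sketch:

```cpp
// With depth = 0 the result holds octree_depth + 1 per-depth node counts
// rather than a single scalar.
Tensor nnum_all = octree_property(octree, "node_num", /*depth=*/0);
// Features at a specific depth, shaped [1, C, node_num(5), 1]:
Tensor feat_d5 = octree_property(octree, "feature", /*depth=*/5);
```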
@@ -9,10 +9,11 @@ Tensor points2octree(Tensor points, int depth, int full_depth, bool node_dis,
   // init the points
   Points point_cloud;
   point_cloud.set(points.data_ptr<uint8_t>());
-  // // check the points
-  // string msg;
-  // bool succ = point_cloud.info().check_format(msg);
-  // CHECK(succ) << msg;
+
+  // check the points
+  string msg;
+  bool succ = point_cloud.info().check_format(msg);
+  CHECK(succ) << msg;

   // init the octree info
   OctreeInfo octree_info;
@@ -22,21 +22,25 @@ PointsInfo::PropType get_ptype(const string property) {
   return ptype;
 }

+vector<float> tensor2vector(const Tensor& data_in) {
+  vector<float> vec;
+  Tensor data = data_in.contiguous();  // !!! make sure the Tensor is contiguous
+  const int64_t num = data.numel();
+  if (num > 0) {
+    const float* ptr = data.data_ptr<float>();
+    vec.assign(ptr, ptr + num);
+  }
+  return vec;
+}
+
 }  // anonymous namespace

 Tensor points_new(Tensor pts, Tensor normals, Tensor features, Tensor labels) {
-  auto get_data = [](vector<float>& vec, const Tensor& data_in) {
-    const int64_t num = data_in.numel();
-    if (num > 0) {
-      const float* ptr = data_in.data_ptr<float>();
-      vec.assign(ptr, ptr + num);
-    }
-  };
-  vector<float> pts_in, normals_in, features_in, labels_in;
-  get_data(pts_in, pts);
-  get_data(normals_in, normals);
-  get_data(features_in, features);
-  get_data(labels_in, labels);
+  // input
+  vector<float> pts_in = tensor2vector(pts);
+  vector<float> normals_in = tensor2vector(normals);
+  vector<float> features_in = tensor2vector(features);
+  vector<float> labels_in = tensor2vector(labels);

   // create the point cloud
   Points point_cloud;
@@ -55,6 +59,7 @@ Tensor points_set_property(Tensor points_in, Tensor data, string property) {
   CHECK_EQ(data.dim(), 2);
   int num = data.size(0);
   int channel = data.size(1);
+  data = data.contiguous();  // !!! make sure the Tensor is contiguous

   // init the points
   Points pts;
@@ -6,8 +6,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("points2octree", &points2octree, "convert points to octree");
  m.def("octree2col", &octree2col, "octree2col");
  m.def("col2octree", &col2octree, "col2octree");
  m.def("octree2colP", &octree2colP, "octree2colP");
  m.def("col2octreeP", &col2octreeP, "col2octreeP");
  m.def("octree_conv", &octree_conv, "octree based convolution");
  m.def("octree_deconv", &octree_deconv, "octree based deconvolution");
  m.def("octree_conv_grad", &octree_conv_grad, "octree based convolution");

@@ -17,16 +15,33 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("octree_max_pool", &octree_max_pool, "octree max pooling");
  m.def("octree_max_unpool", &octree_max_unpool, "octree max unpooling");
  m.def("octree_mask_pool", &octree_mask_pool, "octree mask pooling");
  m.def("octree_property", &octree_property, "get the octree property");
  m.def("octree_property", &octree_property, "get the octree property",
        py::arg("octree"), py::arg("property"), py::arg("depth") = 0);
  m.def("octree_set_property", &octree_set_property, "set the octree property",
        py::arg("octree"), py::arg("data"), py::arg("depth"));
  m.def("transform_points", &transform_points, "transform the point cloud");
  m.def("clip_points", &clip_points, "clip the points with unit bounding box");
  m.def("normalize_points", &normalize_points, "normalize the point cloud");
  m.def("bounding_sphere", &bounding_sphere, "calc the bounding sphere");
  m.def("octree_encode_key", &octree_encode_key, "encode xyz-id to octree key");
  m.def("octree_decode_key", &octree_decode_key, "decode octree key to xyz-id");
  m.def("octree_xyz2key", &octree_xyz2key, "convert key from xyz order");
  m.def("octree_key2xyz", &octree_key2xyz, "convert key to xyz order");
  m.def("octree_search_key", &octree_search_key, "search key from octree");

  m.def("octree_search_key", &octree_search_key, "search key from octree",
        py::arg("key"), py::arg("octree"), py::arg("depth"),
        py::arg("key_is_xyz") = true, py::arg("nempty") = false);
  m.def("octree_scan", &octree_scan, "octree scanning", py::arg("octree"),
        py::arg("axis"), py::arg("scale") = 1.0);
  m.def("octree_grow", &octree_grow, "octree growing", py::arg("octree"),
        py::arg("target_depth"), py::arg("full_octree") = false);
  m.def("octree_update", &octree_update, "update octree via splitting labels",
        py::arg("octree"), py::arg("label"), py::arg("depth"), py::arg("split") = 1);
  m.def("octree_new", &octree_new, "create a new octree",
        py::arg("batch_size") = 1, py::arg("channel") = 4,
        py::arg("node_dis") = true, py::arg("adaptive_layer") = 0);
  m.def("octree_align", &octree_align, "align octree data", py::arg("src_data"),
        py::arg("src_octree"), py::arg("des_octree"), py::arg("depth"));
  m.def("octree_align_grad", &octree_align_grad, "backward of align-octree-data");
  m.def("points_batch_property", &points_batch_property, "get batch of points' property");
  m.def("points_property", &points_property, "get the points' property");
  m.def("points_set_property", &points_set_property, "set the points' property");
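The key helpers round-trip between (x, y, z, id) int16 tuples and packed keys; a small sketch under that assumption (the int16 layout matches the .short() usage in octree_nearest_pts later in this change):

import torch
import ocnn

xyz = torch.tensor([[1, 2, 3, 0]], dtype=torch.int16)  # (x, y, z, id)
key = ocnn.octree_encode_key(xyz)
assert torch.equal(ocnn.octree_decode_key(key), xyz)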
@@ -0,0 +1,21 @@
#include <octree/transform_octree.h>

#include "ocnn.h"

Tensor octree_scan(Tensor octree, vector<float> axis, float scale) {
  // input
  OctreeParser parser;
  parser.set_cpu(octree.data_ptr<uint8_t>());

  // scan
  ScanOctree scan_octree(scale);
  vector<char> octree_out;
  scan_octree.scan(octree_out, parser, axis);

  // output
  torch::TensorOptions options = octree.options();
  Tensor output = torch::zeros(octree_out.size(), options);
  memcpy(output.data_ptr<uint8_t>(), octree_out.data(), octree_out.size());

  return output;
}
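A hedged sketch of calling the new binding from Python; the parser is set up with set_cpu above, so a CPU octree tensor is assumed, and treating axis as a 3-vector view direction is an inference from ScanOctree:

import ocnn

octree = ocnn.octree_batch(ocnn.octree_samples(['octree_1']))  # CPU tensor
scanned = ocnn.octree_scan(octree, axis=[0.0, 0.0, 1.0], scale=1.0)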
@@ -34,6 +34,7 @@ vector<float> bounding_sphere(Tensor data_in, string method) {
}

namespace {
// TODO: Tensor.clone()
void setup_transform(Tensor data_in, Tensor& data_out, Points& pts) {
  data_out = torch::zeros_like(data_in);
  uint8_t* out_ptr = data_out.data_ptr<uint8_t>();

@@ -70,7 +71,7 @@ Tensor normalize_points(Tensor data_in, float radius, vector<float> center) {
}

Tensor transform_points(Tensor data_in, vector<float> angle, vector<float> scale,
                        vector<float> jitter, float offset) {
                        vector<float> jitter, float offset, string normal_axis) {
  // copy the data out of the input tensor
  Tensor data_out; Points pts;
  setup_transform(data_in, data_out, pts);

@@ -99,11 +100,26 @@ Tensor transform_points(Tensor data_in, vector<float> angle, vector<float> scale
    pts.scale(scale.data());
  }

  // clip the points to the box [-1, 1] ^ 3
  const float bbmin[] = {-1.0f, -1.0f, -1.0f};
  const float bbmax[] = {1.0f, 1.0f, 1.0f};
  pts.clip(bbmin, bbmax);
  // orient normal
  if (!normal_axis.empty()) {
    pts.orient_normal(normal_axis);
  }

  // output
  return data_out;
}

vector<Tensor> clip_points(Tensor data_in, vector<float> bbmin, vector<float> bbmax) {
  // copy the data out of the input tensor
  Tensor data_out; Points pts;
  setup_transform(data_in, data_out, pts);

  // clip the points into the box [bbmin, bbmax]
  const vector<int> inbox_mask_buffer = pts.clip(bbmin.data(), bbmax.data());
  size_t sz = inbox_mask_buffer.size();
  Tensor inbox_mask = torch::zeros({(int64_t)sz}, torch::dtype(torch::kInt32));
  memcpy(inbox_mask.data_ptr<int>(), inbox_mask_buffer.data(), sz*sizeof(int));

  // output
  return {data_out, inbox_mask};
}
@@ -3,34 +3,40 @@ import torch
# low level api
from . import nn
from .nn import octree_batch, octree_samples, points2octree, octree_property, \
    bounding_sphere, transform_points, normalize_points, \
    octree_set_property, bounding_sphere, normalize_points, \
    octree_scan, transform_points, clip_points, \
    octree_encode_key, octree_decode_key, octree_search_key, \
    octree_xyz2key, octree_key2xyz, \
    octree_grow, octree_new, octree_update, \
    points_property, points_batch_property, \
    points_new, points_set_property

# transforms
from .transforms import NormalizePoints, TransformPoints, Points2Octree, \
    TransformCompose, CollateOctrees, collate_octrees
    TransformCompose, collate_octrees

# octree-based cnn layers
from .octree2voxel import FullOctree2Voxel
from .octree2col import octree2col, Octree2Col, col2octree, Col2Octree, \
    octree2colP, Octree2ColP, col2octreeP, Col2OctreeP
from .octree_align import octree_align, OctreeAlign
from .octree2col import octree2col, Octree2Col, col2octree, Col2Octree
from .octree_pad import octree_pad, OctreePad, octree_depad, OctreeDepad
from .octree_conv import octree_conv, OctreeConv, OctreeConvFast
from .octree_deconv import octree_deconv, OctreeDeconv, OctreeDeconvFast
from .octree_conv import octree_conv, OctreeConv, OctreeConvFast, \
    octree_deconv, OctreeDeconv, OctreeDeconvFast
from .octree_pool import octree_max_pool, OctreeMaxPool, octree_max_unpool, \
    OctreeMaxUnpool, OctreeAvgPool, FullOctreeGlobalPool

# octree-base modules
# octree-based modules
from .modules import OctreeConvBnRelu, OctreeDeConvBnRelu, \
    FcBnRelu, OctreeConv1x1, OctreeConv1x1BnRelu, \
    OctreeResBlock, OctreeResBlock2, \
    OctreeResBlocks, OctreeResBlocks2, \
    OctreeTile, octree_trilinear_pts, octree_trilinear
    OctreeResBlock, OctreeResBlock2, OctreeResBlocks, \
    OctreeTile, octree_trilinear_pts, octree_trilinear, \
    octree_nearest_pts, OctreeInterp, create_full_octree, \
    octree_feature

# networks
from .lenet import LeNet
from .resnet import ResNet
from .segnet import SegNet
from .unet import UNet
from .ounet import OUNet
from .mlp import MLP
@@ -0,0 +1,45 @@
import torch
import torch.nn


class FC(torch.nn.Module):
  def __init__(self, in_features, out_features, act=torch.nn.ReLU(inplace=True)):
    super().__init__()
    self.linear = torch.nn.Linear(in_features, out_features, bias=True)
    self.act = act

  def forward(self, input):
    output = self.linear(input)
    output = self.act(output)
    return output


class FcBn(torch.nn.Module):
  def __init__(self, in_features, out_features, act=torch.nn.ReLU(inplace=True)):
    super().__init__()
    self.linear = torch.nn.Linear(in_features, out_features, bias=False)
    self.norm = torch.nn.BatchNorm1d(out_features)
    self.act = act

  def forward(self, input):
    output = self.linear(input)
    output = self.norm(output)
    output = self.act(output)
    return output


class MLP(torch.nn.Module):
  def __init__(self, in_features, out_features, hidden_features=256,
               hidden_layers=2, layer=FcBn, act=torch.nn.ReLU(inplace=True)):
    super().__init__()

    layers = [layer(in_features, hidden_features, act)]
    for _ in range(hidden_layers):
      layers.append(layer(hidden_features, hidden_features, act))
    layers.append(FC(hidden_features, out_features, act=torch.nn.Identity()))

    self.layers = torch.nn.Sequential(*layers)

  def forward(self, input):
    output = self.layers(input)
    return output
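A quick shape check for the MLP above (a sketch; hidden sizes chosen arbitrarily):

import torch

mlp = MLP(in_features=3, out_features=4, hidden_features=64, hidden_layers=2)
out = mlp(torch.rand(8, 3))
assert out.shape == (8, 4)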
@@ -1,15 +1,17 @@
import torch
import ocnn
import torch.utils.checkpoint

bn_momentum, bn_eps = 0.01, 0.001
# bn_momentum, bn_eps = 0.1, 1e-05


class OctreeConvBn(torch.nn.Module):
  def __init__(self, depth, channel_in, channel_out, kernel_size=[3], stride=1):
    super(OctreeConvBn, self).__init__()
  def __init__(self, depth, channel_in, channel_out, kernel_size=[3], stride=1,
               nempty=False):
    super().__init__()
    self.conv = ocnn.OctreeConv(
        depth, channel_in, channel_out, kernel_size, stride)
        depth, channel_in, channel_out, kernel_size, stride, nempty)
    self.bn = torch.nn.BatchNorm2d(channel_out, bn_eps, bn_momentum)

  def forward(self, data_in, octree):

@@ -19,10 +21,11 @@ class OctreeConvBn(torch.nn.Module):


class OctreeConvBnRelu(torch.nn.Module):
  def __init__(self, depth, channel_in, channel_out, kernel_size=[3], stride=1):
    super(OctreeConvBnRelu, self).__init__()
  def __init__(self, depth, channel_in, channel_out, kernel_size=[3], stride=1,
               nempty=False):
    super().__init__()
    self.conv = ocnn.OctreeConv(
        depth, channel_in, channel_out, kernel_size, stride)
        depth, channel_in, channel_out, kernel_size, stride, nempty)
    self.bn = torch.nn.BatchNorm2d(channel_out, bn_eps, bn_momentum)
    self.relu = torch.nn.ReLU(inplace=True)

@@ -34,10 +37,11 @@ class OctreeConvBnRelu(torch.nn.Module):


class OctreeDeConvBnRelu(torch.nn.Module):
  def __init__(self, depth, channel_in, channel_out, kernel_size=[3], stride=1):
    super(OctreeDeConvBnRelu, self).__init__()
  def __init__(self, depth, channel_in, channel_out, kernel_size=[3], stride=1,
               nempty=False):
    super().__init__()
    self.deconv = ocnn.OctreeDeconv(
        depth, channel_in, channel_out, kernel_size, stride)
        depth, channel_in, channel_out, kernel_size, stride, nempty)
    self.bn = torch.nn.BatchNorm2d(channel_out, bn_eps, bn_momentum)
    self.relu = torch.nn.ReLU(inplace=True)

@@ -50,7 +54,7 @@ class OctreeDeConvBnRelu(torch.nn.Module):

class FcBnRelu(torch.nn.Module):
  def __init__(self, channel_in, channel_out):
    super(FcBnRelu, self).__init__()
    super().__init__()
    self.flatten = torch.nn.Flatten(start_dim=1)
    self.fc = torch.nn.Linear(channel_in, channel_out, bias=False)
    self.bn = torch.nn.BatchNorm1d(channel_out, bn_eps, bn_momentum)

@@ -66,7 +70,7 @@ class FcBnRelu(torch.nn.Module):

class OctreeConv1x1(torch.nn.Module):
  def __init__(self, channel_in, channel_out, use_bias=False):
    super(OctreeConv1x1, self).__init__()
    super().__init__()
    self.conv1x1 = torch.nn.Conv1d(
        channel_in, channel_out, kernel_size=1, bias=use_bias)

@@ -79,7 +83,7 @@ class OctreeConv1x1(torch.nn.Module):

class OctreeConv1x1Bn(torch.nn.Module):
  def __init__(self, channel_in, channel_out, use_bias=False):
    super(OctreeConv1x1Bn, self).__init__()
    super().__init__()
    self.conv1x1 = OctreeConv1x1(channel_in, channel_out, use_bias)
    self.bn = torch.nn.BatchNorm2d(channel_out, bn_eps, bn_momentum)

@@ -91,7 +95,7 @@ class OctreeConv1x1Bn(torch.nn.Module):

class OctreeConv1x1BnRelu(torch.nn.Module):
  def __init__(self, channel_in, channel_out, use_bias=False):
    super(OctreeConv1x1BnRelu, self).__init__()
    super().__init__()
    self.conv1x1 = OctreeConv1x1(channel_in, channel_out, use_bias)
    self.bn = torch.nn.BatchNorm2d(channel_out, bn_eps, bn_momentum)
    self.relu = torch.nn.ReLU(inplace=True)

@@ -104,18 +108,20 @@ class OctreeConv1x1BnRelu(torch.nn.Module):


class OctreeResBlock(torch.nn.Module):
  def __init__(self, depth, channel_in, channel_out, stride=1, bottleneck=4):
    super(OctreeResBlock, self).__init__()
  def __init__(self, depth, channel_in, channel_out, stride=1, bottleneck=4,
               nempty=False):
    super().__init__()
    self.channel_in = channel_in
    self.channel_out = channel_out
    self.bottleneck = bottleneck
    self.stride = stride
    channelb = int(channel_out / bottleneck)
    self.depth = depth
    channelb = int(channel_out / bottleneck)
    if self.stride == 2:
      self.maxpool = ocnn.OctreeMaxPool(self.depth)
      self.depth = self.depth - 1
    self.conv1x1a = OctreeConv1x1BnRelu(channel_in, channelb)
    self.conv3x3 = OctreeConvBnRelu(self.depth, channelb, channelb)
    self.conv3x3 = OctreeConvBnRelu(self.depth, channelb, channelb, nempty=nempty)
    self.conv1x1b = OctreeConv1x1Bn(channelb, channel_out)
    if self.channel_in != self.channel_out:
      self.conv1x1c = OctreeConv1x1Bn(channel_in, channel_out)

@@ -134,17 +140,21 @@ class OctreeResBlock(torch.nn.Module):


class OctreeResBlock2(torch.nn.Module):
  def __init__(self, depth, channel_in, channel_out, stride=1):
    super(OctreeResBlock2, self).__init__()
  def __init__(self, depth, channel_in, channel_out, stride=1, bottleneck=1,
               nempty=False):
    super().__init__()
    self.channel_in = channel_in
    self.channel_out = channel_out
    self.stride = stride
    self.depth = depth
    channelb = int(channel_out / bottleneck)
    if self.stride == 2:
      self.maxpool = ocnn.OctreeMaxPool(self.depth)
      self.depth = self.depth - 1
    self.conv3x3a = OctreeConvBnRelu(self.depth, channel_in, channel_out)
    self.conv3x3b = OctreeConvBn(self.depth, channel_out, channel_out)
    self.conv3x3a = OctreeConvBnRelu(
        self.depth, channel_in, channelb, nempty=nempty)
    self.conv3x3b = OctreeConvBn(
        self.depth, channelb, channel_out, nempty=nempty)
    if self.channel_in != self.channel_out:
      self.conv1x1 = OctreeConv1x1Bn(channel_in, channel_out)
    self.relu = torch.nn.ReLU(inplace=True)

@@ -161,32 +171,22 @@ class OctreeResBlock2(torch.nn.Module):


class OctreeResBlocks(torch.nn.Module):
  def __init__(self, depth, channel_in, channel_out, resblk_num, bottleneck=4):
    super(OctreeResBlocks, self).__init__()
  def __init__(self, depth, channel_in, channel_out, resblk_num, bottleneck=4,
               nempty=False, resblk=OctreeResBlock, use_checkpoint=False):
    super().__init__()
    self.resblk_num = resblk_num
    self.use_checkpoint = use_checkpoint
    channels = [channel_in] + [channel_out] * resblk_num
    self.resblocks = torch.nn.ModuleList(
        [OctreeResBlock(depth, channels[i], channels[i+1], 1, bottleneck)
    self.resblks = torch.nn.ModuleList(
        [resblk(depth, channels[i], channels[i+1], 1, bottleneck, nempty)
         for i in range(self.resblk_num)])

  def forward(self, data, octree):
    for i in range(self.resblk_num):
      data = self.resblocks[i](data, octree)
    return data


class OctreeResBlocks2(torch.nn.Module):
  def __init__(self, depth, channel_in, channel_out, resblk_num):
    super(OctreeResBlocks2, self).__init__()
    self.resblk_num = resblk_num
    channels = [channel_in] + [channel_out] * resblk_num
    self.resblocks = torch.nn.ModuleList(
        [OctreeResBlock2(depth, channels[i], channels[i+1], stride=1)
         for i in range(self.resblk_num)])

  def forward(self, data, octree):
    for i in range(self.resblk_num):
      data = self.resblocks[i](data, octree)
      if self.use_checkpoint:
        data = torch.utils.checkpoint.checkpoint(self.resblks[i], data, octree)
      else:
        data = self.resblks[i](data, octree)
    return data


@@ -195,7 +195,7 @@ class OctreeTile(torch.nn.Module):
  '''

  def __init__(self, depth):
    super(OctreeTile, self).__init__()
    super().__init__()
    self.depad = ocnn.OctreeDepad(depth)

  def forward(self, data_in, octree):

@@ -206,10 +206,25 @@ class OctreeTile(torch.nn.Module):
    return out


def octree_trilinear_pts(data, octree, depth, pts):
def octree_nearest_pts(data, octree, depth, pts, nempty=False):
  key = pts.short()  # (x, y, z, id)
  key = ocnn.octree_encode_key(key).long()  # (N, )

  idx = ocnn.octree_search_key(key, octree, depth, True, nempty)
  flgs = idx > -1  # valid indices
  idx = idx * flgs

  data = torch.squeeze(data).t()  # (1, C, H, 1) -> (H, C)
  output = data[idx.long()] * flgs.unsqueeze(-1)
  output = torch.unsqueeze((torch.unsqueeze(output.t(), dim=0)), dim=-1)
  return output


def octree_trilinear_pts(data, octree, depth, pts, nempty=False):
  ''' Linear interpolation with input points.
  pts: (N, 4), i.e. N x (x, y, z, id).
  data: (1, C, H, 1)
  nempty: the data only contains features of non-empty octree nodes
  !!! Note: the pts should be scaled into [0, 2^depth]
  '''

@@ -222,15 +237,15 @@ def octree_trilinear_pts(data, octree, depth, pts):

  # 1. Neighborhood searching
  xyzf, ids = torch.split(pts, [3, 1], 1)
  xyzf = xyzf - 0.5  # since the value is defined on the center of each voxel
  xyzf = xyzf - 0.5  # the value is defined on the center of each voxel
  xyzi = torch.floor(xyzf)  # the integer part (N, 3)
  frac = xyzf - xyzi  # the fraction part (N, 3)

  key = torch.cat([xyzi, ids], dim=1).short()  # (N, 4)
  key = ocnn.octree_encode_key(key).long()  # (N, )
  key = (torch.unsqueeze(key, dim=1) + masku).view(-1)  # (N, 1)->(N, 8)->(8*N,)

  idx = ocnn.octree_search_key(key, octree, depth, True)

  key = torch.cat([xyzi, ids], dim=1).short()  # (N, 4)
  key = ocnn.octree_encode_key(key).long()  # (N, )
  key = (torch.unsqueeze(key, dim=1) + masku).view(-1)  # (N, 1)->(N, 8)->(8*N,)

  idx = ocnn.octree_search_key(key, octree, depth, True, nempty)
  flgs = idx > -1  # valid indices
  idx = idx[flgs]

@@ -239,8 +254,7 @@ def octree_trilinear_pts(data, octree, depth, pts):
  ids = torch.arange(npt).cuda()
  ids = ids.view(-1, 1).repeat(1, 8).view(-1)
  ids = ids[flgs]
  indices = torch.cat([torch.unsqueeze(ids, dim=1),
                       torch.unsqueeze(idx, dim=1)], dim=1).long()
  indices = torch.stack([ids, idx], dim=1).long()

  maskc = 1 - mask
  frac = maskc - torch.unsqueeze(frac, dim=1)

@@ -262,6 +276,8 @@ def octree_trilinear_pts(data, octree, depth, pts):


def octree_trilinear(data, octree, depth, target_depth):
  ''' Interpolate data from octree `depth` to `target_depth`
  '''
  xyz = ocnn.octree_property(octree, 'xyz', target_depth)
  xyz = ocnn.octree_decode_key(xyz).float()
  scale = 2.0**(depth-target_depth)

@@ -273,9 +289,49 @@ def octree_trilinear(data, octree, depth, target_depth):

class OctreeTrilinear(torch.nn.Module):
  def __init__(self, depth):
    super(OctreeTrilinear, self).__init__()
    super().__init__()
    self.depth = depth

  def forward(self, data_in, octree):
    out = octree_trilinear(data_in, octree, self.depth, self.depth + 1)
  def forward(self, data, octree):
    out = octree_trilinear(data, octree, self.depth, self.depth + 1)
    return out


class OctreeInterp(torch.nn.Module):
  def __init__(self, depth, method='linear', nempty=False):
    super().__init__()
    self.depth = depth
    self.method = method
    self.nempty = nempty

  def forward(self, data, octree, pts):
    # Input pts in [-1, 1], convert pts to [0, 2^depth]
    xyz = (pts[:, :3] + 1.0) * (2 ** (self.depth - 1))
    pts = torch.cat([xyz, pts[:, 3:]], dim=1)

    if self.method == 'nearest':
      out = octree_nearest_pts(data, octree, self.depth, pts, self.nempty)
    elif self.method == 'linear':
      out = octree_trilinear_pts(data, octree, self.depth, pts, self.nempty)
    else:
      raise ValueError
    return out

  def extra_repr(self) -> str:
    return ('depth={}, method={}, nempty={}').format(
        self.depth, self.method, self.nempty)

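A hedged usage sketch for OctreeInterp: pts hold (x, y, z, id) with xyz in [-1, 1], and the bundled sample octrees are assumed to be depth 5 with per-node features:

import torch
import ocnn

octree = ocnn.octree_batch(ocnn.octree_samples(['octree_1'])).cuda()
data = ocnn.octree_property(octree, 'feature', 5)  # (1, C, H, 1)
pts = torch.rand(100, 4, device='cuda') * 2 - 1
pts[:, 3] = 0  # all points query the first octree in the batch
out = OctreeInterp(depth=5, method='linear')(data, octree, pts)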
def create_full_octree(depth, channel, batch_size=1, node_dis=True):
  assert depth > 1
  octree = ocnn.octree_new(batch_size, channel, node_dis)
  for target_depth in range(1, depth+1):
    octree = ocnn.octree_grow(octree, target_depth, full_octree=True)
  return octree

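create_full_octree simply grows full levels one depth at a time; a one-line sketch (assuming the same GPU octree path used elsewhere in this commit):

octree = create_full_octree(depth=4, channel=4)  # every node split up to depth 4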
def octree_feature(octree, depth, nempty=False):
  output = ocnn.octree_property(octree, 'feature', depth)
  if nempty:
    output = ocnn.nn.octree_depad(output, octree, depth)
  return output
@@ -6,123 +6,76 @@ import ocnn

class Octree2ColFunction(Function):
  @staticmethod
  def forward(ctx, data_in, octree, depth, kernel_size, stride):
  def forward(ctx, data_in, octree, depth, kernel_size, stride, nempty):
    ctx.save_for_backward(octree)
    ctx.depth = depth
    ctx.kernel_size = kernel_size
    ctx.stride = stride
    ctx.nempty = nempty

    data_in = data_in.contiguous()
    data_out = ocnn.nn.octree2col(data_in, octree, depth, kernel_size, stride)
    data_out = ocnn.nn.octree2col(
        data_in, octree, depth, kernel_size, stride, nempty)
    return data_out

  @staticmethod
  def backward(ctx, grad_in):
    octree, = ctx.saved_tensors
    grad_in = grad_in.contiguous()
    grad_out = ocnn.nn.col2octree(grad_in, octree,
                                  ctx.depth, ctx.kernel_size, ctx.stride)
    return grad_out, None, None, None, None
    grad_out = ocnn.nn.col2octree(grad_in, octree, ctx.depth, ctx.kernel_size,
                                  ctx.stride, ctx.nempty)
    return grad_out, None, None, None, None, None


class Col2OctreeFunction(Function):
  @staticmethod
  def forward(ctx, data_in, octree, depth, kernel_size, stride):
  def forward(ctx, data_in, octree, depth, kernel_size, stride, nempty):
    ctx.save_for_backward(octree)
    ctx.depth = depth
    ctx.kernel_size = kernel_size
    ctx.stride = stride
    ctx.nempty = nempty

    data_in = data_in.contiguous()
    data_out = ocnn.nn.col2octree(data_in, octree, depth, kernel_size, stride)
    data_out = ocnn.nn.col2octree(
        data_in, octree, depth, kernel_size, stride, nempty)
    return data_out

  @staticmethod
  def backward(ctx, grad_in):
    octree, = ctx.saved_tensors
    grad_in = grad_in.contiguous()
    grad_out = ocnn.nn.octree2col(grad_in, octree,
                                  ctx.depth, ctx.kernel_size, ctx.stride)
    return grad_out, None, None, None, None


class Octree2ColPFunction(Function):
  @staticmethod
  def forward(ctx, data_in, octree, depth, kernel_size, stride):
    ctx.save_for_backward(octree)
    ctx.depth = depth
    ctx.kernel_size = kernel_size
    ctx.stride = stride

    data_in = data_in.contiguous()
    data_out = ocnn.nn.octree2colP(data_in, octree, depth, kernel_size, stride)
    return data_out

  @staticmethod
  def backward(ctx, grad_in):
    octree, = ctx.saved_tensors
    grad_in = grad_in.contiguous()
    grad_out = ocnn.nn.col2octreeP(grad_in, octree,
                                   ctx.depth, ctx.kernel_size, ctx.stride)
    return grad_out, None, None, None, None


class Col2OctreePFunction(Function):
  @staticmethod
  def forward(ctx, data_in, octree, depth, kernel_size, stride):
    ctx.save_for_backward(octree)
    ctx.depth = depth
    ctx.kernel_size = kernel_size
    ctx.stride = stride

    data_in = data_in.contiguous()
    data_out = ocnn.nn.col2octreeP(data_in, octree, depth, kernel_size, stride)
    return data_out

  @staticmethod
  def backward(ctx, grad_in):
    octree, = ctx.saved_tensors
    grad_in = grad_in.contiguous()
    grad_out = ocnn.nn.octree2colP(grad_in, octree,
                                   ctx.depth, ctx.kernel_size, ctx.stride)
    return grad_out, None, None, None, None
    grad_out = ocnn.nn.octree2col(grad_in, octree, ctx.depth, ctx.kernel_size,
                                  ctx.stride, ctx.nempty)
    return grad_out, None, None, None, None, None


# alias
octree2col = Octree2ColFunction.apply
col2octree = Col2OctreeFunction.apply
octree2colP = Octree2ColPFunction.apply
col2octreeP = Col2OctreePFunction.apply


# module
class Octree2ColBase(nn.Module):
  def __init__(self, depth, kernel_size, stride):
  def __init__(self, depth, kernel_size, stride, nempty=False):
    super(Octree2ColBase, self).__init__()
    self.depth = depth
    self.kernel_size = kernel_size
    self.stride = stride
    self.nempty = nempty

  def extra_repr(self) -> str:
    return 'depth={}, kernel_size={}, stride={}'.format(
        self.depth, self.kernel_size, self.stride)
    return 'depth={}, kernel_size={}, stride={}, nempty={}'.format(
        self.depth, self.kernel_size, self.stride, self.nempty)


class Octree2Col(Octree2ColBase):
  def forward(self, data_in, octree):
    return octree2col(data_in, octree, self.depth, self.kernel_size, self.stride)
    return octree2col(data_in, octree, self.depth, self.kernel_size,
                      self.stride, self.nempty)


class Col2Octree(Octree2ColBase):
  def forward(self, data_in, octree):
    return col2octree(data_in, octree, self.depth, self.kernel_size, self.stride)


class Octree2ColP(Octree2ColBase):
  def forward(self, data_in, octree):
    return octree2colP(data_in, octree, self.depth, self.kernel_size, self.stride)


class Col2OctreeP(Octree2ColBase):
  def forward(self, data_in, octree):
    return col2octreeP(data_in, octree, self.depth, self.kernel_size, self.stride)
    return col2octree(data_in, octree, self.depth, self.kernel_size,
                      self.stride, self.nempty)
@@ -0,0 +1,39 @@
from torch import nn
from torch.autograd import Function

import ocnn


class OctreeAlignFunction(Function):
  @staticmethod
  def forward(ctx, src_data, src_octree, des_octree, depth):
    src_data = src_data.contiguous()
    des_data, index = ocnn.nn.octree_align(src_data, src_octree, des_octree, depth)
    ctx.save_for_backward(index)
    return des_data, index

  @staticmethod
  def backward(ctx, des_grad, index_grad):
    index, = ctx.saved_tensors
    des_grad = des_grad.contiguous()
    grad_out = ocnn.nn.octree_align_grad(des_grad, index)
    return grad_out, None, None, None


# alias
octree_align = OctreeAlignFunction.apply


# module
class OctreeAlign(nn.Module):
  def __init__(self, depth, return_index=False):
    super().__init__()
    self.depth = depth
    self.return_index = return_index

  def forward(self, src_data, src_octree, des_octree):
    output = octree_align(src_data, src_octree, des_octree, self.depth)
    return output if self.return_index else output[0]

  def extra_repr(self) -> str:
    return 'depth={}'.format(self.depth)
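A hedged sketch of OctreeAlign: features from a source octree are re-indexed onto a destination octree containing the same nodes (depth 5 assumed for the bundled samples):

import ocnn

src = ocnn.octree_batch(ocnn.octree_samples(['octree_1'])).cuda()
des = ocnn.octree_batch(ocnn.octree_samples(['octree_1'])).cuda()
data = ocnn.octree_property(src, 'feature', 5)
aligned = OctreeAlign(depth=5)(data, src, des)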
@@ -14,67 +14,145 @@ def resize_with_last_val(list_in, num=3):

class OctreeConvFunction(Function):
  @staticmethod
  def forward(ctx, data_in, weights, octree, depth, channel_out, kernel_size, stride):
  def forward(ctx, data_in, weights, octree, depth, channel_out, kernel_size,
              stride, nempty):
    data_in = data_in.contiguous()
    ctx.save_for_backward(data_in, weights, octree)
    ctx.depth = depth
    ctx.channel_out = channel_out
    ctx.kernel_size = resize_with_last_val(kernel_size)
    ctx.stride = stride
    ctx.nempty = nempty

    data_out = ocnn.nn.octree_conv(data_in, weights, octree,
                                   depth, channel_out, kernel_size, stride)
    data_out = ocnn.nn.octree_conv(
        data_in, weights, octree, depth, channel_out,
        kernel_size, stride, nempty)
    return data_out

  @staticmethod
  def backward(ctx, grad_in):
    grad_in = grad_in.contiguous()
    data_in, weights, octree = ctx.saved_tensors
    grad_out, grad_w = ocnn.nn.octree_conv_grad(data_in, weights, octree, grad_in,
        ctx.depth, ctx.channel_out, ctx.kernel_size, ctx.stride)
    return (grad_out, grad_w) + (None,) * 5
    grad_out, grad_w = ocnn.nn.octree_conv_grad(
        data_in, weights, octree, grad_in, ctx.depth, ctx.channel_out,
        ctx.kernel_size, ctx.stride, ctx.nempty)
    return (grad_out, grad_w) + (None,) * 6


class OctreeDeconvFunction(Function):
  @staticmethod
  def forward(ctx, data_in, weights, octree, depth, channel_out, kernel_size,
              stride, nempty):
    data_in = data_in.contiguous()
    ctx.save_for_backward(data_in, weights, octree)
    ctx.depth = depth
    ctx.channel_out = channel_out
    ctx.kernel_size = resize_with_last_val(kernel_size)
    ctx.stride = stride
    ctx.nempty = nempty

    data_out = ocnn.nn.octree_deconv(
        data_in, weights, octree, depth, channel_out,
        kernel_size, stride, nempty)
    return data_out

  @staticmethod
  def backward(ctx, grad_in):
    grad_in = grad_in.contiguous()
    data_in, weights, octree = ctx.saved_tensors
    grad_out, grad_w = ocnn.nn.octree_deconv_grad(
        data_in, weights, octree, grad_in, ctx.depth, ctx.channel_out,
        ctx.kernel_size, ctx.stride, ctx.nempty)
    return (grad_out, grad_w) + (None,) * 6


# alias
octree_conv = OctreeConvFunction.apply
octree_deconv = OctreeDeconvFunction.apply


# module
class OctreeConvBase(nn.Module):
  def __init__(self, depth, channel_in, channel_out, kernel_size=[3], stride=1):
  def __init__(self, depth, channel_in, channel_out, kernel_size=[3], stride=1,
               nempty=False):
    super(OctreeConvBase, self).__init__()
    self.depth = depth
    self.channel_out = channel_out
    self.kernel_size = resize_with_last_val(kernel_size)
    self.stride = stride
    self.channel_in = channel_in
    self.nempty = nempty

    kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
    self.dim = channel_in * kdim
    self.weights = nn.Parameter(torch.Tensor(self.channel_out, self.dim))
    self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
    conv_in = channel_in if self.is_conv_layer() else channel_out
    conv_out = channel_out if self.is_conv_layer() else channel_in
    self.cdim = conv_in * self.kdim
    self.weights = nn.Parameter(torch.Tensor(conv_out, self.cdim))
    nn.init.xavier_uniform_(self.weights)

  def is_conv_layer(self):
    raise NotImplementedError

  def extra_repr(self) -> str:
    return 'depth={}, channel_in={}, channel_out={}, kernel_size={}, stride={}'.format(
        self.depth, self.channel_in, self.channel_out, self.kernel_size, self.stride)
    return ('depth={}, channel_in={}, channel_out={}, kernel_size={}, '
            'stride={}, nempty={}').format(self.depth, self.channel_in,
        self.channel_out, self.kernel_size, self.stride, self.nempty)


class OctreeConv(OctreeConvBase):
  def forward(self, data_in, octree):
    conv = octree_conv(data_in, self.weights, octree, self.depth,
                       self.channel_out, self.kernel_size, self.stride)
    if self.stride == 2:
  def is_conv_layer(self):
    return True

  def forward(self, data, octree):
    assert data.size(1) == self.channel_in
    conv = octree_conv(
        data, self.weights, octree, self.depth, self.channel_out,
        self.kernel_size, self.stride, self.nempty)
    if self.stride == 2 and not self.nempty:
      conv = ocnn.octree_pad(conv, octree, self.depth-1)
    return conv


class OctreeDeconv(OctreeConvBase):
  def is_conv_layer(self):
    return False

  def forward(self, data, octree):
    assert data.size(1) == self.channel_in
    if self.stride == 2 and not self.nempty:
      data = ocnn.octree_depad(data, octree, self.depth)
    deconv = octree_deconv(
        data, self.weights, octree, self.depth, self.channel_out,
        self.kernel_size, self.stride, self.nempty)
    return deconv


class OctreeConvFast(OctreeConvBase):
  def forward(self, data_in, octree):
    col = ocnn.octree2col(data_in, octree, self.depth,
                          self.kernel_size, self.stride)
    col = col.view([self.dim, -1])
  def is_conv_layer(self):
    return True

  def forward(self, data, octree):
    depth = self.depth
    col = ocnn.octree2col(data, octree, depth, self.kernel_size, self.stride, False)
    col = col.view([self.cdim, -1])
    conv = torch.mm(self.weights, col)
    conv = torch.unsqueeze(torch.unsqueeze(conv, 0), -1)  # [C,H] -> [1,C,H,1]
    if self.stride == 2:
      conv = ocnn.octree_pad(conv, octree, self.depth-1)
      conv = ocnn.octree_pad(conv, octree, depth-1)
    return conv


class OctreeDeconvFast(OctreeConvBase):
  def is_conv_layer(self):
    return False

  def forward(self, data, octree):
    depth = self.depth
    if self.stride == 2:
      data = ocnn.octree_depad(data, octree, depth)
      depth = depth + 1
    data = torch.squeeze(torch.squeeze(data, dim=0), dim=-1)
    col = torch.mm(self.weights.t(), data)
    col = col.view(self.channel_out, self.kdim, -1)
    deconv = ocnn.col2octree(col, octree, depth, self.kernel_size, self.stride, False)
    return deconv
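A hedged sketch of the module path (sample octrees assumed to be depth 5 with 3-channel input features). OctreeConvFast computes the same result through an explicit octree2col plus a matmul, which trades memory for a simpler kernel:

import ocnn

octree = ocnn.octree_batch(ocnn.octree_samples(['octree_1'])).cuda()
data = ocnn.octree_property(octree, 'feature', 5)
conv = ocnn.OctreeConv(depth=5, channel_in=3, channel_out=8).cuda()
out = conv(data, octree)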
@@ -1,80 +0,0 @@
import torch
from torch import nn
from torch.autograd import Function

import ocnn


def resize_with_last_val(list_in, num=3):
  assert (type(list_in) is list and len(list_in) < num + 1)
  for i in range(len(list_in), num):
    list_in.append(list_in[-1])
  return list_in


class OctreeDeconvFunction(Function):
  @staticmethod
  def forward(ctx, data_in, weights, octree, depth, channel_out, kernel_size, stride):
    data_in = data_in.contiguous()
    ctx.save_for_backward(data_in, weights, octree)
    ctx.depth = depth
    ctx.channel_out = channel_out
    ctx.kernel_size = resize_with_last_val(kernel_size)
    ctx.stride = stride

    data_out = ocnn.nn.octree_deconv(data_in, weights, octree,
                                     depth, channel_out, kernel_size, stride)
    return data_out

  @staticmethod
  def backward(ctx, grad_in):
    grad_in = grad_in.contiguous()
    data_in, weights, octree = ctx.saved_tensors
    grad_out, grad_w = ocnn.nn.octree_deconv_grad(data_in, weights, octree,
        grad_in, ctx.depth, ctx.channel_out, ctx.kernel_size, ctx.stride)
    return (grad_out, grad_w) + (None,) * 5


# alias
octree_deconv = OctreeDeconvFunction.apply


# module. TODO: merge code with OctreeConvBase to avoid redundancy
class OctreeDeconvBase(nn.Module):
  def __init__(self, depth, channel_in, channel_out, kernel_size=[3], stride=1):
    super(OctreeDeconvBase, self).__init__()
    self.depth = depth
    self.channel_out = channel_out
    self.kernel_size = resize_with_last_val(kernel_size)
    self.stride = stride
    self.channel_in = channel_in

    self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
    self.dim = channel_out * self.kdim
    self.weights = nn.Parameter(torch.Tensor(self.channel_in, self.dim))
    nn.init.xavier_uniform_(self.weights)

  def extra_repr(self) -> str:
    return 'depth={}, channel_in={}, channel_out={}, kernel_size={}, stride={}'.format(
        self.depth, self.channel_in, self.channel_out, self.kernel_size, self.stride)


class OctreeDeconv(OctreeDeconvBase):
  def forward(self, data, octree):
    if self.stride == 2:
      data = ocnn.octree_depad(data, octree, self.depth)
    deconv = ocnn.octree_deconv(data, self.weights, octree, self.depth,
                                self.channel_out, self.kernel_size, self.stride)
    return deconv


class OctreeDeconvFast(OctreeDeconvBase):
  def forward(self, data, octree):
    depth = self.depth
    if self.stride == 2:
      data = ocnn.octree_depad(data, octree, depth)
      depth = depth + 1
    data = torch.squeeze(torch.squeeze(data, dim=0), dim=-1)
    col = torch.mm(self.weights.t(), data)
    col = col.view(self.channel_out, self.kdim, -1)
    deconv = ocnn.col2octree(col, octree, depth, self.kernel_size, self.stride)
    return deconv
@@ -48,7 +48,7 @@ octree_depad = OctreeDepadFunction.apply
# module
class OctreePad(nn.Module):
  def __init__(self, depth, val=0.0):
    super(OctreePad, self).__init__()
    super().__init__()
    self.depth = depth
    self.val = val

@@ -61,7 +61,7 @@ class OctreePad(nn.Module):

class OctreeDepad(nn.Module):
  def __init__(self, depth):
    super(OctreeDepad, self).__init__()
    super().__init__()
    self.depth = depth

  def forward(self, data_in, octree):
@@ -0,0 +1,142 @@
import torch
import ocnn
import torch.nn
import torch.nn.functional as F


class OUNet(torch.nn.Module):
  def __init__(self, depth, channel_in, nout, full_depth=2):
    super().__init__()
    self.depth = depth
    self.channel_in = channel_in
    self.nout = nout
    self.full_depth = full_depth
    self.nempty = False
    self.resblk_num = 3
    self.channels = [4, 512, 512, 256, 128, 64, 32, 16]

    # encoder
    self.conv1 = ocnn.OctreeConvBnRelu(
        depth, channel_in, self.channels[depth], nempty=self.nempty)
    self.encoder = torch.nn.ModuleList(
        [ocnn.OctreeResBlocks(d, self.channels[d],
         self.channels[d], self.resblk_num, nempty=self.nempty)
         for d in range(depth, full_depth-1, -1)])
    self.downsample = torch.nn.ModuleList(
        [ocnn.OctreeConvBnRelu(d, self.channels[d],
         self.channels[d-1], kernel_size=[2], stride=2, nempty=self.nempty)
         for d in range(depth, full_depth, -1)])

    # decoder
    self.upsample = torch.nn.ModuleList(
        [ocnn.OctreeDeConvBnRelu(d-1, self.channels[d-1],
         self.channels[d], kernel_size=[2], stride=2, nempty=self.nempty)
         for d in range(full_depth+1, depth + 1)])
    self.decoder = torch.nn.ModuleList(
        [ocnn.OctreeResBlocks(d, self.channels[d],
         self.channels[d], self.resblk_num, nempty=self.nempty)
         for d in range(full_depth+1, depth + 1)])

    # header
    self.predict = torch.nn.ModuleList(
        [self._make_predict_module(self.channels[d], 2)
         for d in range(full_depth, depth + 1)])
    self.header = self._make_predict_module(self.channels[depth], nout)

  def _make_predict_module(self, channel_in, channel_out=2, num_hidden=32):
    return torch.nn.Sequential(
        ocnn.OctreeConv1x1BnRelu(channel_in, num_hidden),
        ocnn.OctreeConv1x1(num_hidden, channel_out, use_bias=True))

  def get_input_feature(self, octree):
    data = ocnn.octree_property(octree, 'feature', self.depth)
    assert data.size(1) == self.channel_in
    return data

  def ocnn_encoder(self, octree):
    depth, full_depth = self.depth, self.full_depth
    data = self.get_input_feature(octree)

    convs = dict()
    convs[depth] = self.conv1(data, octree)
    for i, d in enumerate(range(depth, full_depth-1, -1)):
      convs[d] = self.encoder[i](convs[d], octree)
      if d > full_depth:
        convs[d-1] = self.downsample[i](convs[d], octree)

    return convs

  def ocnn_decoder(self, convs, octree_out, octree, return_deconvs=False):
    output, deconvs = dict(), dict()
    depth, full_depth = self.depth, self.full_depth

    deconvs[full_depth] = convs[full_depth]
    for i, d in enumerate(range(full_depth, depth+1)):
      if d > full_depth:
        deconvd = self.upsample[i-1](deconvs[d-1], octree_out)
        skip, _ = ocnn.octree_align(convs[d], octree, octree_out, d)
        deconvd = deconvd + skip
        deconvs[d] = self.decoder[i-1](deconvd, octree_out)

      # predict the splitting label
      logit = self.predict[i](deconvs[d])
      logit = logit.squeeze().t()  # (1, C, H, 1) -> (H, C)

      # classification loss
      label_gt = ocnn.octree_property(octree_out, 'split', d).long()
      output['loss_%d' % d] = F.cross_entropy(logit, label_gt)
      output['accu_%d' % d] = logit.argmax(1).eq(label_gt).float().mean()

      if d == depth:
        # predict the signal
        signal = self.header(deconvs[d])
        signal = torch.tanh(signal)

        # regression loss
        signal_gt = ocnn.octree_property(octree_out, 'feature', d)
        output['loss_reg%d' % d] = torch.mean((signal_gt - signal)**2)

    return (output, deconvs) if return_deconvs else output

  def decode_shape(self, convs, octree, return_deconvs=False):
    deconvs = dict()
    depth, full_depth = self.depth, self.full_depth
    octree_out = ocnn.create_full_octree(full_depth, self.nout)

    deconvs[full_depth] = convs[full_depth]
    for i, d in enumerate(range(full_depth, depth+1)):
      if d > full_depth:
        deconvd = self.upsample[i-1](deconvs[d-1], octree_out)
        skip, _ = ocnn.octree_align(convs[d], octree, octree_out, d)
        deconvd = deconvd + skip
        deconvs[d] = self.decoder[i-1](deconvd, octree_out)

      # predict the splitting label
      logit = self.predict[i](deconvs[d])
      logit = logit.squeeze().t()  # (1, C, H, 1) -> (H, C)

      # octree splitting
      label = logit.argmax(1).to(torch.int32)
      octree_out = ocnn.octree_update(octree_out, label, d, split=1)
      if d < depth:
        octree_out = ocnn.octree_grow(octree_out, target_depth=d+1)
      # predict the signal
      else:
        signal = self.header(deconvs[d])  # (1, C, H, 1)
        signal = torch.tanh(signal)
        normal = F.normalize(signal[:, :3], dim=1)
        signal = torch.cat([normal, signal[:, 3:]], dim=1)
        octree_out = ocnn.octree_set_property(octree_out, signal, d)

    return (octree_out, deconvs) if return_deconvs else octree_out

  def forward(self, octree_in, octree_gt=None, run='compute_loss'):
    convs = self.ocnn_encoder(octree_in)
    if 'compute_loss' == run:
      assert octree_gt is not None
      output = self.ocnn_decoder(convs, octree_gt, octree_in)
    elif 'decode_shape' == run:
      output = self.decode_shape(convs, octree_in)
    else:
      raise ValueError
    return output
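A hedged sketch of the two run modes wired up in forward; octree_in and octree_gt stand for batched input and ground-truth octrees and are hypothetical here:

model = OUNet(depth=6, channel_in=4, nout=4).cuda()
losses = model(octree_in, octree_gt, run='compute_loss')  # dict of loss/accu terms
octree_out = model(octree_in, run='decode_shape')         # grown octree with predicted signal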
@@ -3,7 +3,7 @@ import ocnn


class SegNet(torch.nn.Module):
  def __init__(self, depth, channel_in, nout):
  def __init__(self, depth, channel_in, nout, interp='linear'):
    super(SegNet, self).__init__()
    self.depth, self.channel_in = depth, channel_in
    channels = [2 ** max(10 - i, 2) for i in range(depth + 1)]

@@ -24,38 +24,34 @@ class SegNet(torch.nn.Module):
        [ocnn.OctreeMaxUnpool(d) for d in range(2, depth)])
    self.deconv = ocnn.OctreeConvBnRelu(depth, channels[depth], channels[depth])

    self.header = torch.nn.Sequential(
        ocnn.OctreeConv1x1BnRelu(channels[depth], 64),  # fc1
        ocnn.OctreeConv1x1(64, nout, use_bias=True))    # fc2
    self.octree_interp = ocnn.OctreeInterp(self.depth, interp, nempty=False)

  def forward(self, octree):
    self.header = torch.nn.Sequential(
        ocnn.OctreeConv1x1BnRelu(channels[depth], 64),  # fc1
        ocnn.OctreeConv1x1(64, nout, use_bias=True))    # fc2

  def forward(self, octree, pts):
    depth = self.depth
    data = ocnn.octree_property(octree, 'feature', depth)
    data = ocnn.octree_feature(octree, depth)
    assert data.size(1) == self.channel_in

    # encoder
    pool_idx = [None] * (depth + 1)
    for i, d in enumerate(range(depth, 2, -1)):
      data = self.convs[i](data, octree)
      data, pool_idx[d] = self.pools[i](data, octree)

    # decoder
    for i, d in enumerate(range(2, depth)):
      data = self.deconvs[i](data, octree)
      data = self.unpools[i](data, pool_idx[d+1], octree)

    data = self.deconv(data, octree)
    data = self.header(data)
    return data

    # point/voxel feature
    feature = self.deconv(data, octree)
    if pts is not None:
      feature = self.octree_interp(feature, octree, pts)

if __name__ == '__main__':
  from torch.utils.tensorboard import SummaryWriter

  writer = SummaryWriter('logs/segnet')
  octree = ocnn.octree_batch(ocnn.octree_samples(['octree_1', 'octree_2']))
  model = SegNet(depth=5, channel_in=3, nout=4)
  print(model)

  octree = octree.cuda()
  model = model.cuda()
  writer.add_graph(model, octree)
  writer.flush()
    # header
    logits = self.header(feature)
    logits = logits.squeeze().t()  # (1, C, H, 1) -> (H, C)
    return logits
@@ -9,7 +9,7 @@ class Points2Octree:

  def __init__(self, depth, full_depth=2, node_dis=False, node_feature=False,
               split_label=False, adaptive=False, adp_depth=4, th_normal=0.1,
               th_distance=1.0, extrapolate=False, save_pts=False, key2xyz=False,
               th_distance=2.0, extrapolate=False, save_pts=False, key2xyz=False,
               **kwargs):
    self.depth = depth
    self.full_depth = full_depth

@@ -36,16 +36,23 @@ class NormalizePoints:
  ''' Normalize a point cloud with its bounding sphere

  Args:
    method: The method used to calculate the bounding sphere, choose from
        'sphere' (bounding sphere) or 'box' (bounding box).
    bsphere: The method used to calculate the bounding sphere, choose from
        'sphere' (bounding sphere) or 'box' (bounding box).
    radius: Manually specify the radius of the bounding sphere, -1 means
        that the bounding sphere is not provided.
  '''

  def __init__(self, method='sphere'):
    self.method = method
  def __init__(self, bsphere='sphere', radius=-1.0, center=(-1.0,), **kwargs):
    self.bsphere = bsphere
    self.radius = radius
    self.center = center

  def __call__(self, points):
    bsphere = ocnn.bounding_sphere(points, self.method)
    radius, center = bsphere[0], bsphere[1:]
    if self.radius < 0:
      bsphere = ocnn.bounding_sphere(points, self.bsphere)
      radius, center = bsphere[0], bsphere[1:]
    else:
      radius, center = self.radius, self.center
    points = ocnn.normalize_points(points, radius, center)
    return points

@@ -58,7 +65,7 @@ class TransformPoints:

  def __init__(self, distort, angle=[0, 180, 0], scale=0.25, jitter=0.25,
               offset=0.0, angle_interval=[1, 1, 1], uniform_scale=False,
               **kwargs):
               normal_axis='', **kwargs):
    self.distort = distort
    self.angle = angle
    self.scale = scale

@@ -66,8 +73,9 @@ class TransformPoints:
    self.offset = offset
    self.angle_interval = angle_interval
    self.uniform_scale = uniform_scale
    self.normal_axis = normal_axis

  def __call__(self, points):
    rnd_angle = [0.0, 0.0, 0.0]
    rnd_scale = [1.0, 1.0, 1.0]
    rnd_jitter = [0.0, 0.0, 0.0]

@@ -80,45 +88,61 @@ class TransformPoints:

      minval, maxval = 1 - self.scale, 1 + self.scale
      rnd_scale = np.random.uniform(low=minval, high=maxval, size=(3)).tolist()
      if self.uniform_scale: rnd_scale = [rnd_scale[0]]*3
      if self.uniform_scale:
        rnd_scale = [rnd_scale[0]]*3

      minval, maxval = -self.jitter, self.jitter
      rnd_jitter = np.random.uniform(low=minval, high=maxval, size=(3)).tolist()

    # The range of points is [-1, 1]
    points = ocnn.transform_points(points, rnd_angle, rnd_scale, rnd_jitter, self.offset)
    return points
    points = ocnn.transform_points(
        points, rnd_angle, rnd_scale, rnd_jitter, self.offset, self.normal_axis)
    # clip the points into [-1, 1]
    points, inbox_mask = ocnn.clip_points(points, [-1.0]*3, [1.0]*3)
    return points, inbox_mask


class TransformCompose:
  def __init__(self, flags, return_pts=False):
  def __init__(self, flags):
    self.flags = flags
    self.return_pts = return_pts

  def __call__(self, points):
    points = NormalizePoints('sphere')(points)
    points = TransformPoints(**self.flags)(points)
    octree = Points2Octree(**self.flags)(points)
    return octree if not self.return_pts else (octree, points)

    self.normalize_points = NormalizePoints(**flags)
    self.transform_points = TransformPoints(**flags)
    self.points2octree = Points2Octree(**flags)

class CollateOctrees:
  def __init__(self, return_pts=False):
    self.return_pts = return_pts
  def __call__(self, points, idx):
    # Normalize the points into one unit sphere in [-1, 1]
    points = self.normalize_points(points)

  def __call__(self, batch):
    ''' Merge a batch of octrees into one super octree
    '''
    assert type(batch) == list
    octrees = [b[0] for b in batch]
    octree = ocnn.octree_batch(octrees)
    labels = torch.tensor([b[1] for b in batch])
    # Apply the general transformations provided by ocnn: the augmentations
    # include rotation, scaling, and jittering, and input points that fall
    # out of [-1, 1] are clipped
    points, inbox_mask = self.transform_points(points)

    outputs = [octree, labels]
    if self.return_pts:
      points = [b[2] for b in batch]
      outputs.append(points)
    return outputs
    # Convert the points in [-1, 1] to an octree
    octree = self.points2octree(points)

    return {'octree': octree, 'points': points, 'inbox_mask': inbox_mask}


collate_octrees = CollateOctrees(return_pts=False)
def collate_octrees(batch):
  assert type(batch) == list

  outputs = {}
  for key in batch[0].keys():
    outputs[key] = [b[key] for b in batch]

    # Merge a batch of octrees into one super octree
    if 'octree' in key:
      outputs[key] = ocnn.octree_batch(outputs[key])

    # Convert the labels to a Tensor
    if 'label' in key:
      outputs['label'] = torch.tensor(outputs[key])

    # # Concat the inbox_mask
    # if 'inbox_mask' in key:
    #   pt_num = [mk.numel() for mk in outputs['inbox_mask']]
    #   outputs['pt_num'] = torch.tensor(pt_num)
    #   outputs['inbox_mask'] = torch.cat(outputs['inbox_mask'], dim=0)

  return outputs
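Since TransformCompose now returns a dict per sample, collate_octrees can batch by key; a small sketch with hand-made labels:

import ocnn

samples = [{'octree': o, 'label': i}
           for i, o in enumerate(ocnn.octree_samples(['octree_1', 'octree_2']))]
batch = ocnn.collate_octrees(samples)
# batch['octree'] is one merged super octree; batch['label'] is a tensor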
@ -0,0 +1,92 @@
import torch
import ocnn
import torch.nn


class UNet(torch.nn.Module):
  def __init__(self, depth, channel_in, nout, nempty=False, interp='linear',
               use_checkpoint=False):
    super(UNet, self).__init__()
    self.depth = depth
    self.channel_in = channel_in
    self.nempty = nempty
    self.use_checkpoint = use_checkpoint
    self.config_network()
    self.stages = len(self.encoder_blocks)

    # encoder
    self.conv1 = ocnn.OctreeConvBnRelu(
        depth, channel_in, self.encoder_channel[0], nempty=nempty)
    self.downsample = torch.nn.ModuleList(
        [ocnn.OctreeConvBnRelu(depth - i, self.encoder_channel[i],
         self.encoder_channel[i+1], kernel_size=[2], stride=2, nempty=nempty)
         for i in range(self.stages)])
    self.encoder = torch.nn.ModuleList(
        [ocnn.OctreeResBlocks(depth - i - 1, self.encoder_channel[i+1],
         self.encoder_channel[i+1], self.encoder_blocks[i], self.bottleneck,
         nempty, self.resblk, self.use_checkpoint) for i in range(self.stages)])

    # decoder
    depth = depth - self.stages
    channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2]
               for i in range(self.stages)]
    self.upsample = torch.nn.ModuleList(
        [ocnn.OctreeDeConvBnRelu(depth + i, self.decoder_channel[i],
         self.decoder_channel[i+1], kernel_size=[2], stride=2, nempty=nempty)
         for i in range(self.stages)])
    self.decoder = torch.nn.ModuleList(
        [ocnn.OctreeResBlocks(depth + i + 1, channel[i],
         self.decoder_channel[i+1], self.decoder_blocks[i], self.bottleneck,
         nempty, self.resblk, self.use_checkpoint) for i in range(self.stages)])

    # interpolation
    self.octree_interp = ocnn.OctreeInterp(self.depth, interp, nempty)

    # header
    self.header = self.make_predict_module(self.decoder_channel[-1], nout)

  def config_network(self):
    self.encoder_channel = [32, 32, 64, 128, 256]
    self.decoder_channel = [256, 256, 128, 96, 96]
    self.encoder_blocks = [2, 3, 4, 6]
    self.decoder_blocks = [2, 2, 2, 2]
    self.bottleneck = 1
    self.resblk = ocnn.OctreeResBlock2

  def make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
    return torch.nn.Sequential(
        ocnn.OctreeConv1x1BnRelu(channel_in, num_hidden),
        ocnn.OctreeConv1x1(num_hidden, channel_out, use_bias=True))

  def forward(self, octree, pts=None):
    depth = self.depth
    data = ocnn.octree_feature(octree, depth, self.nempty)
    assert data.size(1) == self.channel_in

    # encoder
    convd = [None] * 16
    convd[depth] = self.conv1(data, octree)
    stages = len(self.encoder_blocks)
    for i in range(stages):
      depth_i = depth - i - 1
      conv = self.downsample[i](convd[depth_i+1], octree)
      convd[depth_i] = self.encoder[i](conv, octree)

    # decoder
    depth = depth - stages
    deconv = convd[depth]
    for i in range(stages):
      depth_i = depth + i + 1
      deconv = self.upsample[i](deconv, octree)
      deconv = torch.cat([convd[depth_i], deconv], dim=1)  # skip connections
      deconv = self.decoder[i](deconv, octree)

    # point/voxel feature
    feature = deconv
    if pts is not None:
      feature = self.octree_interp(feature, octree, pts)

    # header
    logits = self.header(feature)
    logits = logits.squeeze().t()  # (1, C, H, 1) -> (H, C)
    return logits
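For orientation, a minimal forward-pass sketch for the UNet above. The sample names given to `ocnn.octree_samples` and the input channel count are assumptions, not part of this commit:

import torch
import ocnn

model = UNet(depth=5, channel_in=3, nout=4).cuda()

# Merge two bundled sample octrees into one super octree
# (the sample names 'octree_1'/'octree_2' are assumptions).
samples = ocnn.octree_samples(['octree_1', 'octree_2'])
octree = ocnn.octree_batch(samples).cuda()

logits = model(octree)  # (H, C): one row per octree node at the output depth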
@ -1,108 +1,46 @@
import os
import torch
import ocnn
from tqdm import tqdm
from config import parse_args
from modelnet import ModelNet40
from torch.utils.tensorboard import SummaryWriter
import torch

from solver import Solver, Dataset, parse_args


def get_dataloader(flags, train=True):
  transform = ocnn.TransformCompose(flags)
  dataset = ModelNet40(flags.location, train, transform, in_memory=True)
  data_loader = torch.utils.data.DataLoader(
      dataset, batch_size=flags.batch_size, shuffle=train, pin_memory=True,
      num_workers=flags.num_workers, collate_fn=ocnn.collate_octrees)
  return data_loader

class ClsSolver(Solver):
  def get_model(self, flags):
    if flags.name.lower() == 'lenet':
      model = ocnn.LeNet(flags.depth, flags.channel, flags.nout)
    elif flags.name.lower() == 'resnet':
      model = ocnn.ResNet(flags.depth, flags.channel, flags.nout,
                          flags.resblock_num)
    else:
      raise ValueError
    return model

  def get_dataset(self, flags):
    transform = ocnn.TransformCompose(flags)
    dataset = Dataset(flags.location, flags.filelist, transform, in_memory=True)
    return dataset, ocnn.collate_octrees

  def train_step(self, batch):
    octree, label = batch['octree'].cuda(), batch['label'].cuda()
    logits = self.model(octree)
    log_softmax = torch.nn.functional.log_softmax(logits, dim=1)
    loss = torch.nn.functional.nll_loss(log_softmax, label)
    return {'train/loss': loss}

  def test_step(self, batch):
    octree, label = batch['octree'].cuda(), batch['label'].cuda()
    logits = self.model(octree)
    log_softmax = torch.nn.functional.log_softmax(logits, dim=1)
    loss = torch.nn.functional.nll_loss(log_softmax, label)
    pred = torch.argmax(logits, dim=1)
    accu = pred.eq(label).float().mean()
    return {'test/loss': loss, 'test/accu': accu}


def get_model(flags):
  if flags.name.lower() == 'lenet':
    model = ocnn.LeNet(flags.depth, flags.channel, flags.nout)
  elif flags.name.lower() == 'resnet':
    model = ocnn.ResNet(flags.depth, flags.channel, flags.nout, flags.resblock_num)
  else:
    raise ValueError
  return model


def train():
  model.train()

  running_loss = 0.0
  for i, data in enumerate(train_loader, 0):
    # get the inputs
    octrees, labels = data[0].cuda(), data[1].cuda()

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    logits = model(octrees)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss += loss.item()
    if i % 100 == 99:
      tqdm.write('[Train iter: %5d] loss: %.3f' % (i + 1, running_loss / i))
  return running_loss / i


def test():
  model.eval()

  accuracy = 0
  for data in test_loader:
    octrees, labels = data[0].cuda(), data[1].cuda()

    with torch.no_grad():
      logits = model(octrees)
      pred = logits.argmax(dim=1)
      accuracy += pred.eq(labels).sum().item()

  accuracy /= len(test_loader.dataset)
  tqdm.write('[Test] accuracy: %.3f' % accuracy)
  return accuracy

def main(TheSolver):
  FLAGS = parse_args()
  Solver.main(FLAGS, TheSolver)


if __name__ == "__main__":
  # configs
  FLAGS = parse_args()

  # data
  train_loader = get_dataloader(FLAGS.DATA.train, train=True)
  test_loader = get_dataloader(FLAGS.DATA.test, train=False)

  # model
  model = get_model(FLAGS.MODEL)
  model.cuda()
  print(model)

  # loss and optimizer
  flags_solver = FLAGS.SOLVER
  criterion = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(
      model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
  scheduler = torch.optim.lr_scheduler.MultiStepLR(
      optimizer, milestones=flags_solver.step_size, gamma=0.1)

  # summary
  logdir = flags_solver.logdir
  writer = SummaryWriter(logdir)
  ckpt_dir = os.path.join(logdir, 'checkpoints')
  if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir)
  # writer.add_graph(model, next(iter(test_loader))[0].cuda())

  # train and test
  for epoch in tqdm(range(1, flags_solver.max_epoch+1), ncols=80):
    tqdm.write('[Epoch: %5d]' % epoch)
    train_loss = train()
    writer.add_scalar('train_loss', train_loss, epoch)
    if epoch % flags_solver.test_every_epoch == 0:
      test_accu = test()
      writer.add_scalar('test_accu', test_accu, epoch)
      ckpt_name = os.path.join(ckpt_dir, 'model_%05d.pth' % epoch)
      torch.save(model.state_dict(), ckpt_name)
    scheduler.step()

main(ClsSolver)
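A side note on `train_step`/`test_step` above: applying `log_softmax` and then `nll_loss` is mathematically identical to `cross_entropy`, so the loss could be written in a single call. A minimal sketch with random stand-in data:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 40)             # 8 samples, 40 classes
label = torch.randint(0, 40, (8,))

loss_a = F.nll_loss(F.log_softmax(logits, dim=1), label)
loss_b = F.cross_entropy(logits, label)  # fuses the two steps
assert torch.allclose(loss_a, loss_b)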
@ -0,0 +1,57 @@
import os
import ocnn
import torch

from solver import Solver, parse_args, get_config
from datasets import get_completion_dataset


class CompletionSolver(Solver):
  def get_model(self, flags):
    return ocnn.OUNet(flags.depth, flags.channel, flags.nout, flags.full_depth)

  def get_dataset(self, flags):
    return get_completion_dataset(flags)

  def model_forward(self, batch):
    octree_in, octree_gt = batch['octree_in'].cuda(), batch['octree'].cuda()
    output = self.model(octree_in, octree_gt, run='compute_loss')
    losses = [val for key, val in output.items() if 'loss' in key]
    output['loss'] = torch.sum(torch.stack(losses))
    return output

  def train_step(self, batch):
    output = self.model_forward(batch)
    output = {'train/' + key: val for key, val in output.items()}
    return output

  def test_step(self, batch):
    output = self.model_forward(batch)
    output = {'test/' + key: val for key, val in output.items()}
    return output

  def eval_step(self, batch):
    octree_in = batch['octree_in'].cuda()
    octree_out = self.model(octree_in, run='decode_shape')

    iter_num = batch['iter_num']
    filename = os.path.join(self.logdir, '%04d.input.octree' % iter_num)
    octree_in.cpu().numpy().tofile(filename)
    filename = os.path.join(self.logdir, '%04d.output.octree' % iter_num)
    octree_out.cpu().numpy().tofile(filename)


def main(TheSolver):
  get_config().DATA.train.camera_path = '_'  # used to generate partial scans
  get_config().DATA.test.camera_path = '_'
  get_config().DATA.train.scan = True
  get_config().DATA.test.scan = True
  get_config().MODEL.skip_connections = True
  get_config().MODEL.full_depth = 2

  FLAGS = parse_args()
  Solver.main(FLAGS, TheSolver)


if __name__ == '__main__':
  main(CompletionSolver)
@ -1,134 +0,0 @@
import os
import sys
import shutil
import argparse
from yacs.config import CfgNode as CN

_C = CN()

# SOLVER related parameters
_C.SOLVER = CN()
_C.SOLVER.gpu = (0,)              # The gpu ids
_C.SOLVER.logdir = 'logs'         # Directory where to write event logs
_C.SOLVER.ckpt = ''               # Restore weights from checkpoint file
_C.SOLVER.run = 'train'           # Choose from train or test
_C.SOLVER.type = 'sgd'            # Choose from sgd or adam
_C.SOLVER.max_epoch = 300         # Maximum training iterations
_C.SOLVER.test_iter = 100         # Test steps in testing phase
_C.SOLVER.test_every_epoch = 10   # Test model every n training epochs
_C.SOLVER.lr_type = 'step'        # Learning rate type: step or cos
_C.SOLVER.learning_rate = 0.1     # Initial learning rate
_C.SOLVER.gamma = 0.1             # Learning rate step-wise decay
_C.SOLVER.step_size = (40000,)    # Learning rate step size
_C.SOLVER.ckpt_num = 100          # The number of checkpoints kept
_C.SOLVER.verbose = False         # Whether to output some messages


# DATA related parameters
_C.DATA = CN()
_C.DATA.train = CN()
_C.DATA.train.name = ''           # The name of the dataset

_C.DATA.train.depth = 5           # The octree depth
_C.DATA.train.full_depth = 2      # The full depth
_C.DATA.train.node_dis = False    # Save the node displacement
_C.DATA.train.split_label = False # Save the split label
_C.DATA.train.adaptive = False    # Build the adaptive octree
_C.DATA.train.node_feat = False   # Calculate the node feature

_C.DATA.train.distort = False     # Whether to apply data augmentation
_C.DATA.train.offset = 0.016      # Offset used to displace the points
_C.DATA.train.scale = 0.0         # Scale the points
_C.DATA.train.uniform = False     # Generate uniform scales
_C.DATA.train.jitter = 0.0        # Jitter the points
_C.DATA.train.interval = (1, 1, 1)  # Use interval & angle to generate random angles
_C.DATA.train.angle = (180, 180, 180)

_C.DATA.train.location = ''       # The data location
_C.DATA.train.filelist = ''       # The data filelist
_C.DATA.train.batch_size = 32     # Training data batch size
_C.DATA.train.num_workers = 8     # Number of workers to load the data


_C.DATA.test = _C.DATA.train.clone()


# MODEL related parameters
_C.MODEL = CN()
_C.MODEL.name = ''                # The name of the model
_C.MODEL.depth = 5                # The input octree depth
_C.MODEL.depth_out = 5            # The output feature depth
_C.MODEL.channel = 3              # The input feature channel
_C.MODEL.factor = 1               # The factor used to widen the network
_C.MODEL.nout = 40                # The output feature channel
_C.MODEL.nouts = 40,              # The output feature channels
_C.MODEL.resblock_num = 3         # The resblock number
_C.MODEL.bottleneck = 4           # The bottleneck factor of one resblock
_C.MODEL.dropout = (0.0,)         # The dropout ratio
_C.MODEL.signal_abs = False       # Use the absolute value of signal
_C.MODEL.upsample = 'nearest'     # The method used for upsampling


# loss related parameters
_C.LOSS = CN()
_C.LOSS.num_class = 40            # The class number for the cross-entropy loss
_C.LOSS.weight_decay = 0.0005     # The weight decay on model weights
_C.LOSS.sigma = 0.1               # Used for MID training
_C.LOSS.momentum = 0.5            # Used for MID training
_C.LOSS.inst_num = 57449          # The object number in MID training
_C.LOSS.seg_num = 100             # The clustering number in MID training
_C.LOSS.weights = (1.0, 1.0)      # The weight factors for different losses
_C.LOSS.label_smoothing = 0.0     # The factor of label smoothing


# backup the commands
_C.SYS = CN()
_C.SYS.cmds = ''                  # Used to backup the commands

FLAGS = _C


def _update_config(FLAGS, args):
  FLAGS.defrost()
  if args.config:
    FLAGS.merge_from_file(args.config)
  if args.opts:
    FLAGS.merge_from_list(args.opts)
  FLAGS.SYS.cmds = ' '.join(sys.argv)
  FLAGS.freeze()


def _backup_config(FLAGS, args):
  logdir = FLAGS.SOLVER.logdir
  if not os.path.exists(logdir):
    os.makedirs(logdir)
  # copy the file to logdir
  if args.config:
    shutil.copy2(args.config, logdir)
  # dump all configs
  filename = os.path.join(logdir, 'all_configs.yaml')
  with open(filename, 'w') as fid:
    fid.write(FLAGS.dump())


def _set_env_var(FLAGS):
  gpus = ','.join([str(a) for a in FLAGS.SOLVER.gpu])
  os.environ['CUDA_VISIBLE_DEVICES'] = gpus


def parse_args(backup=True):
  parser = argparse.ArgumentParser(description='The configs')
  parser.add_argument('--config',
                      help='experiment configure file name',
                      type=str)
  parser.add_argument('opts',
                      help="Modify config options using the command-line",
                      nargs=argparse.REMAINDER)
  args = parser.parse_args()
  _update_config(FLAGS, args)
  if backup: _backup_config(FLAGS, args)
  _set_env_var(FLAGS)
  return FLAGS


if __name__ == '__main__':
  flags = parse_args(backup=False)
  print(flags)
@ -1,6 +1,6 @@
SOLVER:
  gpu: 0,
  logdir: logs/m40/0311_lenet
  logdir: logs/m40/m40
  run: train
  max_epoch: 300
  test_every_epoch: 5
@ -15,14 +15,18 @@ DATA:
    interval: (1, 1, 1)
    scale: 0.25
    jitter: 0.125
    location: dataset/ModelNet40.points
    location: data/ModelNet40/ModelNet40.points
    filelist: data/ModelNet40/m40_train_points_list.txt
    batch_size: 32
    shuffle: True

  test:
    distort: False
    depth: 5
    location: dataset/ModelNet40.points
    location: data/ModelNet40/ModelNet40.points
    filelist: data/ModelNet40/m40_test_points_list.txt
    batch_size: 32
    shuffle: False

MODEL:
  name: lenet
@ -31,5 +35,4 @@ MODEL:
  depth: 5

LOSS:
  num_class: 40
  weight_decay: 0.0005
  num_class: 40
@ -0,0 +1,31 @@
SOLVER:
  gpu: 0,
  logdir: logs/completion/skip_connections_test
  ckpt: logs/completion/skip_connectinos_07191553/checkpoints/model_00200.pth
  run: evaluate


DATA:
  test:
    name: completion
    distort: False
    location: data/ocnn_completion/test.scans.points
    filelist: data/ocnn_completion/filelist_test_scans.txt
    batch_size: 1
    depth: 6
    full_depth: 2
    offset: 0.0
    node_dis: True
    split_label: True
    radius: 64.0
    center: (64.0,64.0,64.0)
    scan: False
    shuffle: False


MODEL:
  channel: 4
  depth: 6
  nout: 4
  full_depth: 2
  skip_connections: True
@ -0,0 +1,52 @@
SOLVER:
  gpu: 0,
  logdir: logs/completion/skip_connectinos
  run: train
  max_epoch: 200
  test_every_epoch: 10
  step_size: (100,150)
  ckpt_num: 20


DATA:
  train:
    name: completion
    distort: False
    location: data/ocnn_completion/shape.points
    filelist: data/ocnn_completion/filelist_train.txt
    camera_path: data/ocnn_completion/completion_train_points.camera_path.dict
    batch_size: 16
    depth: 6
    offset: 0.0
    full_depth: 2
    node_dis: True
    split_label: True
    radius: 64.0
    center: (64.0,64.0,64.0)
    shuffle: True
    # num_workers: 0

  test:
    name: completion
    distort: False
    location: data/ocnn_completion/shape.points
    filelist: data/ocnn_completion/filelist_test.txt
    camera_path: data/ocnn_completion/completion_train_points.camera_path.dict
    batch_size: 16
    depth: 6
    offset: 0.0
    full_depth: 2
    node_dis: True
    split_label: True
    radius: 64.0
    center: (64.0,64.0,64.0)
    shuffle: False
    # num_workers: 0


MODEL:
  channel: 4
  depth: 6
  nout: 4
  full_depth: 2
  skip_connections: True
@ -0,0 +1,62 @@
SOLVER:
  gpu: 0,
  logdir: logs/seg_partnet/bed_pts
  alias: unet
  run: train
  max_epoch: 800
  test_every_epoch: 20
  step_size: (400,600)
  ckpt_num: 20

DATA:
  train:
    # octree building
    depth: 6
    node_dis: True

    # points transform
    offset: 0.0      # do not offset points along the normal direction
    normal_axis: y   # re-orient normals along the y axis

    # data augmentations
    distort: True
    angle: (0, 5, 0)
    interval: (1, 1, 1)
    scale: 0.25
    jitter: 0.125
    uniform: True

    # data loading
    location: data/partnet_segmentation/data/Bed
    filelist: data/partnet_segmentation/data/Bed_train_level3.txt
    batch_size: 32
    shuffle: True

  test:
    # octree building
    depth: 6
    node_dis: True

    # points transform
    offset: 0.0      # do not offset points along the normal direction
    normal_axis: y   # re-orient normals along the y axis

    # no data augmentation
    distort: False

    # data loading
    location: data/partnet_segmentation/data/Bed
    filelist: data/partnet_segmentation/data/Bed_test_level3.txt
    batch_size: 1
    shuffle: False

MODEL:
  name: unet
  channel: 4
  nout: 15
  depth: 6

LOSS:
  mask: 0
  num_class: 15
  point_wise: True
@ -0,0 +1,72 @@
SOLVER:
  gpu: 0,
  run: train
  logdir: logs/scannet/D9_2cm
  max_epoch: 400
  test_every_epoch: 10
  weight_decay: 0.0005

  # learning rate
  lr: 0.05
  lr_type: poly
  step_size: (200,300)   # has no effect for `poly`

DATA:
  train:
    name: scannet

    # octree building
    depth: 9
    node_dis: True
    offset: 0.0

    # data augmentations
    distort: True
    angle: (0, 0, 180)
    scale: 0.1
    jitter: 0.1
    uniform: True

    # data loading
    location: data/scannet/train
    filelist: data/scannet/scannetv2_train_new.txt
    batch_size: 4
    shuffle: True
    in_memory: False

  test:
    name: scannet

    # octree building
    depth: 9
    node_dis: True
    offset: 0.0

    # data augmentations
    distort: False   # no data augmentation
    angle: (0, 0, 180)
    scale: 0.1
    jitter: 0.1
    uniform: True

    # data loading
    location: data/scannet/train
    filelist: data/scannet/scannetv2_val_new.txt
    batch_size: 1
    shuffle: False
    in_memory: False

MODEL:
  name: unet
  channel: 7
  nout: 21
  depth: 9
  nempty: True
  interp: nearest
  sync_bn: False
  use_checkpoint: False

LOSS:
  mask: 0
  point_wise: True
  num_class: 21
@ -0,0 +1,68 @@
SOLVER:
  gpu: 0,
  run: train
  logdir: logs/scannet/D8_4cm
  max_epoch: 400
  test_every_epoch: 10
  weight_decay: 0.0005

  # learning rate
  lr: 0.05
  lr_type: poly
  step_size: (200,300)   # has no effect for `poly`

DATA:
  train:
    name: scannet

    # octree building
    depth: 8
    node_dis: True
    offset: 0.0

    # data augmentations
    distort: True
    angle: (0, 0, 180)
    scale: 0.1
    jitter: 0.1
    uniform: True

    # data loading
    location: data/scannet/train
    filelist: data/scannet/scannetv2_train_new.txt
    batch_size: 4
    shuffle: True
    in_memory: False

  test:
    name: scannet

    # octree building
    depth: 8
    node_dis: True
    offset: 0.0

    # data augmentations
    distort: False   # no data augmentation

    # data loading
    location: data/scannet/train
    filelist: data/scannet/scannetv2_val_new.txt
    batch_size: 1
    shuffle: False
    in_memory: False

MODEL:
  name: unet
  channel: 7
  nout: 21
  depth: 8
  nempty: True
  interp: nearest
  sync_bn: False
  use_checkpoint: False

LOSS:
  mask: 0
  point_wise: True
  num_class: 21
@ -0,0 +1,37 @@
SOLVER:
  gpu: 0,
  logdir: logs/scannet/D9_2cm_eval
  run: evaluate
  eval_epoch: 1
  ckpt: logs/scannet/D9_2cm_poly_b6_ep600_w1e-4_lr0.1_0807/checkpoints/00600.model.pth

DATA:
  test:
    name: scannet

    # octree building
    depth: 9
    node_dis: True
    offset: 0.0

    # data augmentations
    distort: False   # no data augmentation
    angle: (0, 0, 180)
    scale: 0.1
    jitter: 0.1
    uniform: True

    location: data/scannet/test
    filelist: data/scannet/scannetv2_test_new.txt
    batch_size: 1
    shuffle: False
    in_memory: False
    # num_workers: 0

MODEL:
  name: unet
  channel: 7
  nout: 21
  depth: 9
  nempty: True
  interp: nearest
@ -1,5 +1,3 @@
# Parameters for the airplane

SOLVER:
  gpu: 0,
  logdir: logs/seg/02691156_airplane
@ -9,6 +7,7 @@ SOLVER:
  step_size: (120,180,240)
  ckpt_num: 20


DATA:
  train:
    distort: True
@ -19,25 +18,29 @@ DATA:
    jitter: 0.25
    uniform: True
    node_dis: True
    location: dataset/shapenet_segmentation/points
    filelist: dataset/shapenet_segmentation/train_test_split/02691156_train_val.txt
    location: data/shapenet_segmentation/points
    filelist: data/shapenet_segmentation/train_test_split/02691156_train_val.txt
    batch_size: 32
    shuffle: True

  test:
    distort: False   # no data augmentation
    depth: 6
    node_dis: True
    location: dataset/shapenet_segmentation/points
    filelist: dataset/shapenet_segmentation/train_test_split/02691156_test.txt
    location: data/shapenet_segmentation/points
    filelist: data/shapenet_segmentation/train_test_split/02691156_test.txt
    batch_size: 1
    shuffle: False


MODEL:
  name: segnet
  channel: 4
  nout: 4
  depth: 6
  depth_out: 6


LOSS:
  mask: -1
  point_wise: True
  num_class: 4
  weight_decay: 0.0005
@ -1,38 +0,0 @@
import os
import torch
import numpy as np
from tqdm import tqdm


class Dataset(torch.utils.data.Dataset):
  def __init__(self, root, filelist, transform=None, in_memory=True):
    super(Dataset, self).__init__()
    self.root = root
    self.filelist = filelist
    self.transform = transform
    self.in_memory = in_memory
    self.samples, self.labels = self.load_dataset()

  def __len__(self):
    return len(self.samples)

  def __getitem__(self, idx):
    sample = self.samples[idx] if self.in_memory else \
             np.fromfile(self.samples[idx], dtype=np.uint8)
    sample = torch.from_numpy(sample)  # convert it to torch.tensor
    if self.transform:  # transform the sample
      sample = self.transform(sample)
    return sample, self.labels[idx]

  def load_dataset(self):
    samples, labels = [], []
    tqdm.write('Load from ' + self.filelist)
    with open(self.filelist) as fid:
      lines = fid.readlines()
    for line in tqdm(lines, ncols=80):
      filename, label = line.split()
      filename_abs = os.path.join(self.root, filename)
      samples.append(np.fromfile(filename_abs, dtype=np.uint8) \
                     if self.in_memory else filename_abs)
      labels.append(int(label))
    return samples, labels
@ -0,0 +1,2 @@
from .scannet import get_scannet_dataset
from .completion import get_completion_dataset
@ -0,0 +1,106 @@
import ocnn
import torch
import pickle
import numpy as np

from solver import Dataset


class ScanOctree:
  def __init__(self, camera_path, scan=True):
    self.scan = scan
    self.camera_path = camera_path
    if self.scan:
      with open(camera_path, 'rb') as fid:
        self.camera = pickle.load(fid)

  def generate_scan_axis(self, i):
    j = np.random.randint(0, 8)
    key = '%d_%d' % (i, j)
    axes = np.array(self.camera[key])

    # perturb the axes
    rnd = np.random.random(axes.shape) * 0.4 - 0.2
    axes = np.reshape(axes + rnd, (-1, 3))

    # normalize the axes
    axes = axes / np.sqrt(np.sum(axes ** 2, axis=1, keepdims=True) + 1.0e-6)
    axes = np.reshape(axes, (-1)).astype(np.float32).tolist()
    return axes

  def __call__(self, octree, idx):
    if self.scan:
      scan_axis = self.generate_scan_axis(idx)
      partial_octree = ocnn.octree_scan(octree, scan_axis)
      return partial_octree
    else:
      return octree


class ScanTransform(ocnn.TransformCompose):
  def __init__(self, flags):
    super().__init__(flags)
    self.scan_octree = ScanOctree(flags.camera_path, flags.scan)

  def __call__(self, points, idx):
    # apply the default transformations provided by ocnn
    output = super().__call__(points, idx)
    # generate the partial octree via virtual scanning
    output['octree_in'] = self.scan_octree(output['octree'], idx)
    return output


class Noise2cleanTransform:
  # Follow the [preprocessing steps](https://github.com/autonomousvision/occupancy_networks#Building-the-dataset)
  # of `Occupancy Networks` to generate the training data.
  def __init__(self, flags):
    self.points_number = 3000
    self.points_scale = 0.95
    self.noise_std = 0.01 * self.points_scale

    self.points2octree = ocnn.Points2Octree(**flags)

  def __call__(self, point_cloud, idx):
    # get the input
    points, normals = point_cloud['points'], point_cloud['normals']

    # normalize the points
    bbmin, bbmax = np.min(points, axis=0), np.max(points, axis=0)
    center = (bbmin + bbmax) / 2.0
    radius = 2.0 / (np.max(bbmax - bbmin) + 1.0e-6)
    points = (points - center) * radius  # normalize to [-1, 1]
    points *= self.points_scale  # normalize to [-points_scale, points_scale]

    # randomly sample points and add noise
    noise = self.noise_std * np.random.randn(self.points_number, 3)
    rand_idx = np.random.choice(points.shape[0], size=self.points_number)
    points_noise = points[rand_idx] + noise

    # transform points to octrees
    points_gt = ocnn.points_new(
        torch.from_numpy(points).float(), torch.from_numpy(normals).float(),
        torch.Tensor(), torch.Tensor())
    points_gt, _ = ocnn.clip_points(points_gt, [-1.0]*3, [1.0]*3)
    octree_gt = self.points2octree(points_gt)

    points_in = ocnn.points_new(
        torch.from_numpy(points_noise).float(), torch.Tensor(),
        torch.ones(self.points_number).float(), torch.Tensor())
    points_in, _ = ocnn.clip_points(points_in, [-1.0]*3, [1.0]*3)
    octree_in = self.points2octree(points_in)

    return {'octree_in': octree_in, 'points_in': points_in,
            'octree': octree_gt, 'points': points_gt}


def get_completion_dataset(flags):
  if flags.name == 'completion':
    transform = ScanTransform(flags)
  elif flags.name == 'noise2clean':
    transform = Noise2cleanTransform(flags)
  else:
    raise ValueError

  dataset = Dataset(flags.location, flags.filelist, transform,
                    in_memory=flags.in_memory)
  return dataset, ocnn.collate_octrees
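To make the normalization step in `Noise2cleanTransform.__call__` concrete, here is a numpy-only sketch of the same bounding-box scaling, with randomly generated stand-in coordinates:

import numpy as np

points = np.random.rand(1000, 3) * 10.0          # stand-in raw coordinates
bbmin, bbmax = points.min(axis=0), points.max(axis=0)
center = (bbmin + bbmax) / 2.0
radius = 2.0 / (np.max(bbmax - bbmin) + 1.0e-6)
points = (points - center) * radius              # now inside [-1, 1]
assert np.abs(points).max() <= 1.0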
@ -0,0 +1,154 @@
import ocnn
import torch
import numpy as np
import scipy.interpolate
import scipy.ndimage
import random
from plyfile import PlyData

from solver import Dataset


def read_file(filename):
  def _read_ply(filename):
    plydata = PlyData.read(filename)
    vtx = plydata['vertex']
    xyz = np.stack([vtx['x'], vtx['y'], vtx['z']], axis=1).astype(np.float32)
    color = np.stack([vtx['red'], vtx['green'], vtx['blue']], axis=1).astype(np.float32)
    label = np.asarray(vtx['label']).astype(np.float32)
    normal = np.stack([vtx['nx'], vtx['ny'], vtx['nz']], axis=1).astype(np.float32)
    return xyz, color, label, normal

  def _read_points(filename):
    points = torch.from_numpy(np.fromfile(filename, dtype=np.uint8))
    xyz = ocnn.points_property(points, 'xyz').numpy()
    label = ocnn.points_property(points, 'label').squeeze().numpy()
    normal = ocnn.points_property(points, 'normal').numpy()
    color = ocnn.points_property(points, 'feature').numpy() * 255  # !!! RGB
    return xyz, color, label, normal

  if filename.endswith('.ply'):
    return _read_ply(filename)
  elif filename.endswith('.points'):
    return _read_points(filename)
  else:
    raise ValueError


def color_distort(color, trans_range_ratio, jitter_std):
  def _color_autocontrast(color, rand_blend_factor=True, blend_factor=0.5):
    assert color.shape[1] >= 3
    lo = color[:, :3].min(0, keepdims=True)
    hi = color[:, :3].max(0, keepdims=True)
    assert hi.max() > 1

    scale = 255 / (hi - lo)
    contrast_feats = (color[:, :3] - lo) * scale

    blend_factor = random.random() if rand_blend_factor else blend_factor
    color[:, :3] = (1 - blend_factor) * color + blend_factor * contrast_feats
    return color

  def _color_translation(color, trans_range_ratio=0.1):
    assert color.shape[1] >= 3
    if random.random() < 0.95:
      tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * trans_range_ratio
      color[:, :3] = np.clip(tr + color[:, :3], 0, 255)
    return color

  def _color_jiter(color, std=0.01):
    if random.random() < 0.95:
      noise = np.random.randn(color.shape[0], 3)
      noise *= std * 255
      color[:, :3] = np.clip(noise + color[:, :3], 0, 255)
    return color

  color = color * 255.0
  color = _color_autocontrast(color)
  color = _color_translation(color, trans_range_ratio)
  color = _color_jiter(color, jitter_std)
  color = color / 255.0
  return color


def elastic_distort(points, distortion_params):
  def _elastic_distort(coords, granularity, magnitude):
    blurx = np.ones((3, 1, 1, 1)).astype('float32') / 3
    blury = np.ones((1, 3, 1, 1)).astype('float32') / 3
    blurz = np.ones((1, 1, 3, 1)).astype('float32') / 3
    coords_min = coords.min(0)

    noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
    noise = np.random.randn(*noise_dim, 3).astype(np.float32)

    # Smoothing.
    convolve = scipy.ndimage.filters.convolve
    for _ in range(2):
      noise = convolve(noise, blurx, mode='constant', cval=0)
      noise = convolve(noise, blury, mode='constant', cval=0)
      noise = convolve(noise, blurz, mode='constant', cval=0)

    # Trilinearly interpolate the smoothed noise for each spatial dimension.
    ax = [np.linspace(d_min, d_max, d)
          for d_min, d_max, d in zip(coords_min - granularity,
                                     coords_min + granularity*(noise_dim - 2),
                                     noise_dim)]

    interp = scipy.interpolate.RegularGridInterpolator(
        ax, noise, bounds_error=0, fill_value=0)
    coords += interp(coords) * magnitude
    return coords

  assert distortion_params.shape[1] == 2
  if random.random() < 0.95:
    for granularity, magnitude in distortion_params:
      points = _elastic_distort(points, granularity, magnitude)
  return points


class TransformScanNet:
  def __init__(self, flags):
    self.flags = flags
    self.scale_factor = 5.12
    self.color_trans_ratio = 0.10
    self.color_jit_std = 0.05
    self.elastic_params = np.array([[0.2, 0.4], [0.8, 1.6]], np.float32)

  def transform_scannet(self, sample):
    # read ply file
    # xyz, color, label, normal = read_file(filename)
    xyz, color, label, normal = sample

    # normalization
    center = (xyz.min(axis=0) + xyz.max(axis=0)) / 2.0
    xyz = (xyz - center) / self.scale_factor  # xyz in [-1, 1]
    color = color / 255.0

    # data augmentation
    if self.flags.distort:
      color = color_distort(color, self.color_trans_ratio, self.color_jit_std)
      xyz = elastic_distort(xyz, self.elastic_params)

    points = ocnn.points_new(torch.from_numpy(xyz), torch.from_numpy(normal),
                             torch.from_numpy(color), torch.from_numpy(label))
    return points

  def __call__(self, sample, idx=None):
    # transformation specific to scannet
    points = self.transform_scannet(sample)

    # general transformations provided by ocnn: the augmentations include
    # rotation, scaling, and jittering, and input points outside [-1, 1]
    # are clipped
    points, inbox_mask = ocnn.TransformPoints(**self.flags)(points)

    # transform points to an octree
    octree = ocnn.Points2Octree(**self.flags)(points)
    return {'octree': octree, 'points': points, 'inbox_mask': inbox_mask}


def get_scannet_dataset(flags):
  transform = TransformScanNet(flags)
  dataset = Dataset(flags.location, flags.filelist, transform,
                    read_file=read_file, in_memory=flags.in_memory)
  return dataset, ocnn.collate_octrees
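A quick sanity-check sketch for the `color_distort` function defined above, fed with random RGB values in [0, 1]. The output stays in [0, 1] because the autocontrast blend stays within [0, 255] and the translation/jitter steps clip before the final division:

import numpy as np

color = np.random.rand(2048, 3).astype(np.float32)   # RGB in [0, 1]
color = color_distort(color, trans_range_ratio=0.10, jitter_std=0.05)
assert color.shape == (2048, 3)
assert color.min() >= 0.0 and color.max() <= 1.0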
@ -1,42 +0,0 @@
import os
import torch
import numpy as np


class ModelNet40(torch.utils.data.Dataset):
  def __init__(self, root, train=True, transform=None, in_memory=True):
    super(ModelNet40, self).__init__()
    self.root = root
    self.train = train
    self.transform = transform
    self.in_memory = in_memory
    self.points, self.labels, self.category = self.load_modelnet40()

  def __len__(self):
    return len(self.points)

  def __getitem__(self, idx):
    points = self.points[idx] if self.in_memory else \
             np.fromfile(self.points[idx], dtype=np.uint8)
    points = torch.from_numpy(points)  # convert it to torch.tensor
    if self.transform:
      octree = self.transform(points)
    return octree, self.labels[idx]

  def load_modelnet40(self, suffix='points'):
    points, labels = [], []
    folders = sorted(os.listdir(self.root))
    assert len(folders) == 40
    for idx, folder in enumerate(folders):
      subfolder = 'train' if self.train else 'test'
      current_folder = os.path.join(self.root, folder, subfolder)
      filenames = sorted(os.listdir(current_folder))
      for filename in filenames:
        if filename.endswith(suffix):
          filename_abs = os.path.join(current_folder, filename)
          if self.in_memory:
            points.append(np.fromfile(filename_abs, dtype=np.uint8))
          else:
            points.append(filename_abs)
          labels.append(idx)
    return points, labels, folders
@ -0,0 +1,144 @@
import os
import ocnn
import torch
import numpy as np

from solver import Solver, Dataset, parse_args, get_config
from datasets import get_scannet_dataset


def loss_function(logit, label):
  criterion = torch.nn.CrossEntropyLoss()
  loss = criterion(logit, label.long())
  return loss


def accuracy(logit, label):
  pred = logit.argmax(dim=1)
  accu = pred.eq(label).float().mean()
  return accu


def IoU_per_shape(logit, label, class_num):
  pred = logit.argmax(dim=1)

  IoU, valid_part_num, esp = 0.0, 0.0, 1.0e-10
  intsc, union = [None] * class_num, [None] * class_num
  for k in range(class_num):
    pk, lk = pred.eq(k), label.eq(k)
    intsc[k] = torch.sum(torch.logical_and(pk, lk).float())
    union[k] = torch.sum(torch.logical_or(pk, lk).float())

    valid = torch.sum(lk.any()) > 0
    valid_part_num += valid.item()
    IoU += valid * intsc[k] / (union[k] + esp)

  # Calculate the shape IoU for ShapeNet
  IoU /= valid_part_num + esp
  return IoU, intsc, union


class SegSolver(Solver):
  def get_model(self, flags):
    if flags.name.lower() == 'segnet':
      model = ocnn.SegNet(flags.depth, flags.channel, flags.nout, flags.interp)
    elif flags.name.lower() == 'unet':
      model = ocnn.UNet(flags.depth, flags.channel, flags.nout, flags.nempty,
                        flags.interp, flags.use_checkpoint)
    else:
      raise ValueError
    return model

  def get_dataset(self, flags):
    if flags.name.lower() == 'scannet':
      return get_scannet_dataset(flags)
    else:
      transform = ocnn.TransformCompose(flags)
      dataset = Dataset(flags.location, flags.filelist, transform,
                        in_memory=flags.in_memory)
      return dataset, ocnn.collate_octrees

  def parse_batch(self, batch):
    octree = batch['octree'].cuda()
    if self.FLAGS.LOSS.point_wise:
      points = batch['points']
      pts = ocnn.points_batch_property(points, 'xyzi').cuda()
      label = ocnn.points_batch_property(points, 'label').squeeze().cuda()
    else:
      pts = None
      label = ocnn.octree_property(octree, 'label', self.FLAGS.MODEL.depth)
      if self.FLAGS.MODEL.nempty:
        child = ocnn.octree_property(octree, 'child', self.FLAGS.MODEL.depth)
        label = label[child >= 0]
    return octree, pts, label

  def model_forward(self, batch):
    octree, pts, label = self.parse_batch(batch)
    logit = self.model(octree, pts)
    label_mask = label > self.FLAGS.LOSS.mask  # filter labels
    return logit[label_mask], label[label_mask]

  def train_step(self, batch):
    logit, label = self.model_forward(batch)
    loss = loss_function(logit, label)
    return {'train/loss': loss}

  def test_step(self, batch):
    logit, label = self.model_forward(batch)
    if logit.shape[0] == 0:
      return None  # degenerate case
    loss = loss_function(logit, label)
    accu = accuracy(logit, label)
    num_class = self.FLAGS.LOSS.num_class
    IoU, insc, union = IoU_per_shape(logit, label, num_class)

    names = ['test/loss', 'test/accu', 'test/mIoU'] + \
            ['test/intsc_%d' % i for i in range(num_class)] + \
            ['test/union_%d' % i for i in range(num_class)]
    tensors = [loss, accu, IoU] + insc + union
    return dict(zip(names, tensors))

  def eval_step(self, batch):
    octree = batch['octree'].cuda()
    pts = ocnn.points_batch_property(batch['points'], 'xyzi').cuda()
    logit = self.model(octree, pts)
    prob = torch.nn.functional.softmax(logit, dim=1)
    label = prob.argmax(dim=1)

    assert len(batch['inbox_mask']) == 1, 'The batch_size must be 1'
    filename = '%02d.%04d.npz' % (batch['epoch'], batch['iter_num'])
    np.savez(os.path.join(self.logdir, filename),
             prob=prob.cpu().numpy(),
             label=label.cpu().numpy(),
             inbox_mask=batch['inbox_mask'][0].numpy().astype(bool))

  def result_callback(self, avg_tracker, epoch):
    ''' Calculates the part mIoU for PartNet and ScanNet. '''
    avg = avg_tracker.average()

    iou_part = 0.0
    # Labels smaller than `mask` are ignored. Points with label 0 in
    # PartNet are background points, i.e., unlabeled points.
    mask = self.FLAGS.LOSS.mask + 1
    num_class = self.FLAGS.LOSS.num_class
    for i in range(mask, num_class):
      instc_i = avg['test/intsc_%d' % i]
      union_i = avg['test/union_%d' % i]
      iou_part += instc_i / (union_i + 1.0e-10)
    iou_part = iou_part / (num_class - mask)
    if self.summry_writer:
      self.summry_writer.add_scalar('test/mIoU_part', iou_part, epoch)
    else:
      print('Epoch: %d, test/mIoU_part: %f' % (epoch, iou_part))


def main(TheSolver):
  get_config().LOSS.mask = -1  # mask the invalid labels
  get_config().LOSS.point_wise = False  # point-wise loss or voxel-wise loss

  FLAGS = parse_args()
  Solver.main(FLAGS, TheSolver)


if __name__ == "__main__":
  main(SegSolver)
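To see what `IoU_per_shape` above computes, a toy example with three classes: class 0 is predicted perfectly, class 1 half-overlaps, and class 2 is missed, so the averaged shape IoU is 0.5:

import torch

logit = torch.tensor([[9., 0., 0.],    # predicted class 0
                      [0., 9., 0.],    # predicted class 1
                      [0., 9., 0.]])   # predicted class 1 (label is 2)
label = torch.tensor([0, 1, 2])

IoU, intsc, union = IoU_per_shape(logit, label, class_num=3)
print(IoU)  # (1/1 + 1/2 + 0/1) / 3 = 0.5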
@ -1,150 +0,0 @@
import os
import torch
import ocnn
from tqdm import tqdm
from config import parse_args
from dataset import Dataset
from torch.utils.tensorboard import SummaryWriter


def get_dataloader(flags, train=True):
  transform = ocnn.TransformCompose(flags)
  dataset = Dataset(flags.location, flags.filelist, transform, in_memory=True)
  data_loader = torch.utils.data.DataLoader(
      dataset, batch_size=flags.batch_size, shuffle=train, pin_memory=True,
      num_workers=flags.num_workers, collate_fn=ocnn.collate_octrees)
  return data_loader


def get_model(flags):
  if flags.name.lower() == 'segnet':
    model = ocnn.SegNet(flags.depth, flags.channel, flags.nout)
  else:
    raise ValueError
  return model


def train():
  model.train()

  running_loss = 0.0
  for i, data in enumerate(train_loader, 0):
    # get the inputs
    octrees = data[0].cuda()
    labels = ocnn.octree_property(octrees, 'label', FLAGS.MODEL.depth)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    logits = model(octrees)
    logits = logits.squeeze().transpose(0, 1)  # N x C
    loss = loss_functions_seg(logits, labels)
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss += loss.item()
    if i % 100 == 99:
      tqdm.write('[Train iter: %5d] loss: %.3f' % (i + 1, running_loss / i))
  return running_loss / i


def test():
  model.eval()

  accu, mIoU, counter = 0, 0, 0
  for data in test_loader:
    octrees = data[0].cuda()
    labels = ocnn.octree_property(octrees, 'label', FLAGS.MODEL.depth)

    with torch.no_grad():
      logits = model(octrees)
      logits = logits.squeeze().transpose(0, 1)  # N x C

    counter += 1
    accu += accuracy(logits, labels)
    mIoU += IoU_per_shape(logits, labels, FLAGS.LOSS.num_class)

  accu /= counter
  mIoU /= counter
  tqdm.write('[Test] accuracy: %.3f, mIoU: %.3f' % (accu, mIoU))
  return accu, mIoU


def loss_functions_seg(logit, label, mask=-1):
  label_mask = label > mask  # filter label -1
  masked_logit = logit[label_mask, :]
  masked_label = label[label_mask]
  criterion = torch.nn.CrossEntropyLoss()
  loss = criterion(masked_logit, masked_label.long())
  return loss


def accuracy(logit, label, mask=-1):
  label_mask = label > mask  # filter label -1
  masked_logit = logit[label_mask, :]
  label_mask = label[label_mask]
  pred = masked_logit.argmax(dim=1)
  accu = pred.eq(label_mask).float().mean()
  return accu.item()


def IoU_per_shape(logit, label, class_num, mask=-1):
  label_mask = label > mask  # filter label -1
  masked_logit = logit[label_mask, :]
  masked_label = label[label_mask]
  pred = masked_logit.argmax(dim=1)

  IoU, valid_part_num, esp = 0.0, 0.0, 1.0e-10
  for k in range(class_num):
    pk, lk = pred.eq(k), masked_label.eq(k)
    intsc = torch.sum(pk & lk)
    union = torch.sum(pk | lk)
    valid = torch.sum(lk.any()) > 0
    valid_part_num += valid.item()
    IoU += valid * intsc / (union + esp)
  IoU /= valid_part_num + esp
  return IoU.item()


if __name__ == "__main__":
  # configs
  FLAGS = parse_args()

  # data
  train_loader = get_dataloader(FLAGS.DATA.train, train=True)
  test_loader = get_dataloader(FLAGS.DATA.test, train=False)

  # model
  model = get_model(FLAGS.MODEL)
  model.cuda()
  print(model)

  # loss and optimizer
  flags_solver = FLAGS.SOLVER
  optimizer = torch.optim.SGD(
      model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
  scheduler = torch.optim.lr_scheduler.MultiStepLR(
      optimizer, milestones=flags_solver.step_size, gamma=0.1)

  # summary
  logdir = flags_solver.logdir
  writer = SummaryWriter(logdir)
  ckpt_dir = os.path.join(logdir, 'checkpoints')
  if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)
  # writer.add_graph(model, next(iter(test_loader))[0].cuda())

  # train and test
  for epoch in tqdm(range(1, flags_solver.max_epoch+1), ncols=80):
    tqdm.write('[Epoch: %5d]' % epoch)
    train_loss = train()
    writer.add_scalar('train_loss', train_loss, epoch)
    if epoch % flags_solver.test_every_epoch == 0:
      test_accu, test_mIoU = test()
      writer.add_scalar('test_accu', test_accu, epoch)
      writer.add_scalar('test_mIoU', test_mIoU, epoch)
      ckpt_name = os.path.join(ckpt_dir, 'model_%05d.pth' % epoch)
      torch.save(model.state_dict(), ckpt_name)
    scheduler.step()
@ -0,0 +1,8 @@
from . import config
from .config import get_config, parse_args

from . import solver
from .solver import Solver

from . import dataset
from .dataset import Dataset
@ -0,0 +1,171 @@
import os
import sys
import shutil
import argparse
from datetime import datetime
from yacs.config import CfgNode as CN

_C = CN()

# SOLVER related parameters
_C.SOLVER = CN()
_C.SOLVER.alias = ''              # The experiment alias
_C.SOLVER.gpu = (0,)              # The gpu ids
_C.SOLVER.run = 'train'           # Choose from train or test

_C.SOLVER.logdir = 'logs'         # Directory where to write event logs
_C.SOLVER.ckpt = ''               # Restore weights from checkpoint file
_C.SOLVER.ckpt_num = 10           # The number of checkpoints kept

_C.SOLVER.type = 'sgd'            # Choose from sgd or adam
_C.SOLVER.weight_decay = 0.0005   # The weight decay on model weights
_C.SOLVER.max_epoch = 300         # Maximum training epoch
_C.SOLVER.eval_epoch = 1          # Maximum evaluating epoch
_C.SOLVER.test_every_epoch = 10   # Test model every n training epochs

_C.SOLVER.lr_type = 'step'        # Learning rate type: step or cos
_C.SOLVER.lr = 0.1                # Initial learning rate
_C.SOLVER.gamma = 0.1             # Learning rate step-wise decay
_C.SOLVER.step_size = (120,60,)   # Learning rate step size
_C.SOLVER.lr_power = 0.9          # Used in poly learning rate

_C.SOLVER.dist_url = 'tcp://localhost:10001'
_C.SOLVER.progress_bar = True


# DATA related parameters
_C.DATA = CN()
_C.DATA.train = CN()
_C.DATA.train.name = ''           # The name of the dataset

# For octree building.
# If node_dis = True and there are normals, the octree feature is 4 channels:
# the average normal and a 1-channel displacement. If node_dis = True and
# there are no normals, the feature is also 4 channels: a 3-channel
# displacement of the average points relative to the node centers, plus a
# constant last channel.
_C.DATA.train.depth = 5           # The octree depth
_C.DATA.train.full_depth = 2      # The full depth
_C.DATA.train.node_dis = False    # Save the node displacement
_C.DATA.train.split_label = False # Save the split label
_C.DATA.train.adaptive = False    # Build the adaptive octree
_C.DATA.train.node_feat = False   # Calculate the node feature

# For normalization.
# If radius < 0, the method will compute a bounding sphere.
_C.DATA.train.bsphere = 'sphere'  # The method used to compute the bounding sphere
_C.DATA.train.radius = -1.        # The radius and center of the bounding sphere
_C.DATA.train.center = (-1., -1., -1.)

# For transformation
_C.DATA.train.offset = 0.016      # Used to displace the points when building the octree
_C.DATA.train.normal_axis = ''    # Used to re-orient normal directions

# For data augmentation
_C.DATA.train.disable = False     # Disable this dataset or not
_C.DATA.train.distort = False     # Whether to apply data augmentation
_C.DATA.train.scale = 0.0         # Scale the points
_C.DATA.train.uniform = False     # Generate uniform scales
_C.DATA.train.jitter = 0.0        # Jitter the points
_C.DATA.train.interval = (1, 1, 1)  # Use interval & angle to generate random angles
_C.DATA.train.angle = (180, 180, 180)

# For data loading
_C.DATA.train.location = ''       # The data location
_C.DATA.train.filelist = ''       # The data filelist
_C.DATA.train.batch_size = 32     # Training data batch size
_C.DATA.train.num_workers = 8     # Number of workers to load the data
_C.DATA.train.shuffle = False     # Shuffle the input data
_C.DATA.train.in_memory = False   # Load the training data into memory


_C.DATA.test = _C.DATA.train.clone()


# MODEL related parameters
_C.MODEL = CN()
_C.MODEL.name = ''                # The name of the model
_C.MODEL.depth = 5                # The input octree depth
_C.MODEL.full_depth = 2           # The input octree full depth layer
_C.MODEL.depth_out = 5            # The output feature depth
_C.MODEL.channel = 3              # The input feature channel
_C.MODEL.factor = 1               # The factor used to widen the network
_C.MODEL.nout = 40                # The output feature channel
_C.MODEL.resblock_num = 3         # The resblock number
_C.MODEL.bottleneck = 4           # The bottleneck factor of one resblock
_C.MODEL.dropout = (0.0,)         # The dropout ratio
_C.MODEL.upsample = 'nearest'     # The method used for upsampling
_C.MODEL.interp = 'linear'        # The interpolation method: linear or nearest
_C.MODEL.nempty = False           # Perform octree conv on non-empty octree nodes only
_C.MODEL.sync_bn = False          # Use sync_bn when training the network
_C.MODEL.use_checkpoint = False   # Use gradient checkpointing to save memory

# loss related parameters
_C.LOSS = CN()
_C.LOSS.num_class = 40            # The class number for the cross-entropy loss
_C.LOSS.weights = (1.0, 1.0)      # The weight factors for different losses
_C.LOSS.label_smoothing = 0.0     # The factor of label smoothing


# backup the commands
_C.SYS = CN()
_C.SYS.cmds = ''                  # Used to backup the commands

FLAGS = _C


def _update_config(FLAGS, args):
  FLAGS.defrost()
  if args.config:
    FLAGS.merge_from_file(args.config)
  if args.opts:
    FLAGS.merge_from_list(args.opts)
  FLAGS.SYS.cmds = ' '.join(sys.argv)
  # update logdir
  alias = FLAGS.SOLVER.alias.lower()
  if 'time' in alias:
    alias = alias.replace('time', datetime.now().strftime('%m%d%H%M'))  # %S
  if alias != '':  # note: `is not ''` compares identity and is a bug
    FLAGS.SOLVER.logdir += '_' + alias
  FLAGS.freeze()


def _backup_config(FLAGS, args):
  logdir = FLAGS.SOLVER.logdir
  if not os.path.exists(logdir):
    os.makedirs(logdir)
  # copy the file to logdir
  if args.config:
    shutil.copy2(args.config, logdir)
  # dump all configs
  filename = os.path.join(logdir, 'all_configs.yaml')
  with open(filename, 'w') as fid:
    fid.write(FLAGS.dump())


def _set_env_var(FLAGS):
  gpus = ','.join([str(a) for a in FLAGS.SOLVER.gpu])
  os.environ['CUDA_VISIBLE_DEVICES'] = gpus


def get_config():
  return FLAGS


def parse_args(backup=True):
  parser = argparse.ArgumentParser(description='The configs')
  parser.add_argument('--config', type=str,
                      help='experiment configure file name')
  parser.add_argument('opts', nargs=argparse.REMAINDER,
                      help="Modify config options using the command-line")

  args = parser.parse_args()
  _update_config(FLAGS, args)
  if backup:
    _backup_config(FLAGS, args)
  _set_env_var(FLAGS)
  return FLAGS


if __name__ == '__main__':
  flags = parse_args(backup=False)
  print(flags)
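A minimal sketch of how this yacs-based config is consumed by the scripts above; the YAML filename and the override in the comment are placeholders:

# Command line (placeholders):
#   python classification.py --config m40.yaml SOLVER.alias time
from solver import parse_args, get_config

FLAGS = parse_args(backup=False)  # merge YAML, then CLI overrides, then freeze
print(FLAGS.SOLVER.logdir)

# get_config() hands out the same global CfgNode, which is how the
# completion/segmentation mains register extra keys before parsing.
assert get_config() is FLAGS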
@ -0,0 +1,47 @@
import os
import torch
import torch.utils.data
import numpy as np
from tqdm import tqdm


def read_file(filename):
  points = np.fromfile(filename, dtype=np.uint8)
  return torch.from_numpy(points)  # convert it to torch.tensor


class Dataset(torch.utils.data.Dataset):
  def __init__(self, root, filelist, transform, read_file=read_file, in_memory=True):
    super(Dataset, self).__init__()
    self.root = root
    self.filelist = filelist
    self.transform = transform
    self.in_memory = in_memory
    self.read_file = read_file
    self.filenames, self.labels = self.load_filenames()
    if self.in_memory:
      print('Load files into memory from ' + self.filelist)
      self.samples = [self.read_file(f)
                      for f in tqdm(self.filenames, ncols=80, leave=False)]

  def __len__(self):
    return len(self.filenames)

  def __getitem__(self, idx):
    sample = self.samples[idx] if self.in_memory else \
             self.read_file(self.filenames[idx])
    output = self.transform(sample, idx)  # data augmentation + octree building
    output['label'] = self.labels[idx]
    return output

  def load_filenames(self):
    filenames, labels = [], []
    with open(self.filelist) as fid:
      lines = fid.readlines()
    for line in lines:
      tokens = line.split()
      filename = tokens[0]
      label = tokens[1] if len(tokens) == 2 else 0
      filenames.append(os.path.join(self.root, filename))
      labels.append(int(label))
    return filenames, labels
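The contract of the Dataset above: `transform(sample, idx)` must return a dict, to which the integer label from the filelist is added. A minimal sketch with a placeholder transform and placeholder paths:

def identity_transform(sample, idx):
  # A real transform would augment the points and build an octree here.
  return {'points': sample}

# 'data/root' and 'filelist.txt' are placeholders.
dataset = Dataset('data/root', 'filelist.txt', identity_transform,
                  in_memory=False)
item = dataset[0]  # {'points': <uint8 tensor>, 'label': <int>}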
@ -0,0 +1,50 @@
import torch
from torch.utils.data import Sampler, DistributedSampler, Dataset


class InfSampler(Sampler):
  def __init__(self, dataset: Dataset, shuffle: bool = True) -> None:
    self.dataset = dataset
    self.shuffle = shuffle
    self.reset_sampler()

  def reset_sampler(self):
    num = len(self.dataset)
    indices = torch.randperm(num) if self.shuffle else torch.arange(num)
    self.indices = indices.tolist()
    self.iter_num = 0

  def __iter__(self):
    return self

  def __next__(self):
    value = self.indices[self.iter_num]
    self.iter_num = self.iter_num + 1

    if self.iter_num >= len(self.indices):
      self.reset_sampler()
    return value

  def __len__(self):
    return len(self.dataset)


class DistributedInfSampler(DistributedSampler):
  def __init__(self, dataset: Dataset, shuffle: bool = True) -> None:
    super().__init__(dataset, shuffle=shuffle)
    self.reset_sampler()

  def reset_sampler(self):
    self.indices = list(super().__iter__())
    self.iter_num = 0

  def __iter__(self):
    return self

  def __next__(self):
    value = self.indices[self.iter_num]
    self.iter_num = self.iter_num + 1

    if self.iter_num >= len(self.indices):
      self.reset_sampler()
    return value
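Because neither sampler ever raises StopIteration, a single iterator over the DataLoader serves all epochs; this is exactly how the Solver below drives training. A sketch:

# Sketch only: `dataset` is any map-style torch dataset.
import torch.utils.data

loader = torch.utils.data.DataLoader(
    dataset, batch_size=32, sampler=InfSampler(dataset, shuffle=True))
data_iter = iter(loader)          # never exhausted
for step in range(1000):          # loop over steps, not epochs
  batch = next(data_iter)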
@ -0,0 +1,395 @@
import os
import torch
import torch.nn
import torch.optim
import torch.distributed
import torch.multiprocessing
import torch.utils.data
import warnings
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from .sampler import InfSampler, DistributedInfSampler


warnings.filterwarnings("ignore", module="torch.optim.lr_scheduler")


class AverageTracker:
  def __init__(self):
    self.value = None
    self.num = 0.0
    self.max_len = 70

  def update(self, value):
    if not value:
      return    # empty input, return
    value = {key: val.detach() for key, val in value.items()}
    if self.value is None:
      self.value = value
    else:
      for key, val in value.items():
        self.value[key] += val
    self.num += 1

  def average(self):
    return {key: val.item() / self.num for key, val in self.value.items()}

  @torch.no_grad()
  def average_all_gather(self):
    for key, tensor in self.value.items():
      tensors_gather = [torch.ones_like(tensor)
                        for _ in range(torch.distributed.get_world_size())]
      torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
      tensors = torch.stack(tensors_gather, dim=0)
      self.value[key] = torch.mean(tensors)

  def log(self, epoch, summary_writer=None, log_file=None):
    if not self.value:
      return    # empty, return
    avg = self.average()
    msg = 'Epoch: %d' % epoch
    for key, val in avg.items():
      msg += ', %s: %.3f' % (key, val)
      if summary_writer:
        summary_writer.add_scalar(key, val, epoch)

    if log_file:
      with open(log_file, 'a') as fid:
        fid.write(msg + '\n')
    if len(msg) > self.max_len:
      msg = msg[:self.max_len] + ' ...'
    tqdm.write(msg)


class Solver:
  def __init__(self, FLAGS, is_master=True):
    self.FLAGS = FLAGS
    self.is_master = is_master
    self.world_size = len(FLAGS.SOLVER.gpu)
    self.device = torch.cuda.current_device()
    self.disable_tqdm = not (is_master and FLAGS.SOLVER.progress_bar)
    self.start_epoch = 1

    self.model = None             # torch.nn.Module
    self.optimizer = None         # torch.optim.Optimizer
    self.scheduler = None         # torch.optim.lr_scheduler._LRScheduler
    self.summary_writer = None    # torch.utils.tensorboard.SummaryWriter
    self.log_file = None          # str, used to save training logs

  def get_model(self, flags):
    raise NotImplementedError

  def get_dataset(self, flags):
    raise NotImplementedError

  def train_step(self, batch):
    raise NotImplementedError

  def test_step(self, batch):
    raise NotImplementedError

  def eval_step(self, batch):
    raise NotImplementedError

  def result_callback(self, avg_tracker: AverageTracker, epoch):
    pass    # additional operations based on the avg_tracker

  def config_dataloader(self, disable_train_data=False):
    flags_train, flags_test = self.FLAGS.DATA.train, self.FLAGS.DATA.test

    if not disable_train_data and not flags_train.disable:
      self.train_loader = self.get_dataloader(flags_train)
      self.train_iter = iter(self.train_loader)

    if not flags_test.disable:
      self.test_loader = self.get_dataloader(flags_test)
      self.test_iter = iter(self.test_loader)

  def get_dataloader(self, flags):
    dataset, collate_fn = self.get_dataset(flags)

    if self.world_size > 1:
      sampler = DistributedInfSampler(dataset, shuffle=flags.shuffle)
    else:
      sampler = InfSampler(dataset, shuffle=flags.shuffle)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=flags.batch_size, num_workers=flags.num_workers,
        sampler=sampler, collate_fn=collate_fn, pin_memory=True)
    return data_loader

  def config_model(self):
    model = self.get_model(self.FLAGS.MODEL)
    model.cuda(device=self.device)
    if self.world_size > 1:
      if self.FLAGS.MODEL.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
      model = torch.nn.parallel.DistributedDataParallel(
          module=model, device_ids=[self.device],
          output_device=self.device, broadcast_buffers=False,
          find_unused_parameters=False)
    if self.is_master:
      print(model)
    self.model = model

  def configure_optimizer(self):
    flags = self.FLAGS.SOLVER
    # The learning rate scales with the world_size
    lr = flags.lr * self.world_size
    if flags.type == 'sgd':
      self.optimizer = torch.optim.SGD(
          self.model.parameters(), lr=lr, weight_decay=flags.weight_decay,
          momentum=0.9)
    elif flags.type == 'adam':
      self.optimizer = torch.optim.Adam(
          self.model.parameters(), lr=lr, weight_decay=flags.weight_decay)
    else:
      raise ValueError

    if flags.lr_type == 'step':
      self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
          self.optimizer, milestones=flags.step_size, gamma=0.1)
    elif flags.lr_type == 'cos':
      self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
          self.optimizer, flags.max_epoch, eta_min=0.001)
    elif flags.lr_type == 'poly':
      def poly(epoch): return (1 - epoch / flags.max_epoch) ** flags.lr_power
      self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, poly)
    elif flags.lr_type == 'constant':
      self.scheduler = torch.optim.lr_scheduler.LambdaLR(
          self.optimizer, lambda epoch: 1)
    else:
      raise ValueError

  def configure_log(self, set_writer=True):
    self.logdir = self.FLAGS.SOLVER.logdir
    self.ckpt_dir = os.path.join(self.logdir, 'checkpoints')
    self.log_file = os.path.join(self.logdir, 'log.csv')

    if self.is_master:
      tqdm.write('Logdir: ' + self.logdir)

    if self.is_master and set_writer:
      self.summary_writer = SummaryWriter(self.logdir, flush_secs=20)
      if not os.path.exists(self.ckpt_dir):
        os.makedirs(self.ckpt_dir)

  def train_epoch(self, epoch):
    self.model.train()
    if self.world_size > 1:
      self.train_loader.sampler.set_epoch(epoch)

    train_tracker = AverageTracker()
    rng = range(len(self.train_loader))
    for it in tqdm(rng, ncols=80, leave=False, disable=self.disable_tqdm):
      self.optimizer.zero_grad()

      # forward
      batch = next(self.train_iter)
      batch['iter_num'] = it
      batch['epoch'] = epoch
      output = self.train_step(batch)

      # backward
      output['train/loss'].backward()
      self.optimizer.step()

      # track the averaged tensors
      train_tracker.update(output)

    # save logs
    if self.world_size > 1:
      train_tracker.average_all_gather()
    if self.is_master:
      train_tracker.log(epoch, self.summary_writer)

  def test_epoch(self, epoch):
    self.model.eval()
    test_tracker = AverageTracker()
    rng = range(len(self.test_loader))
    for it in tqdm(rng, ncols=80, leave=False, disable=self.disable_tqdm):
      # forward
      batch = next(self.test_iter)
      batch['iter_num'] = it
      batch['epoch'] = epoch
      with torch.no_grad():
        output = self.test_step(batch)

      # track the averaged tensors
      test_tracker.update(output)

    if self.world_size > 1:
      test_tracker.average_all_gather()
    if self.is_master:
      test_tracker.log(epoch, self.summary_writer, self.log_file)
      self.result_callback(test_tracker, epoch)

  def eval_epoch(self, epoch):
    self.model.eval()
    for it in tqdm(range(len(self.test_loader)), ncols=80, leave=False):
      batch = next(self.test_iter)
      batch['iter_num'] = it
      batch['epoch'] = epoch
      with torch.no_grad():
        self.eval_step(batch)

  def save_checkpoint(self, epoch):
    if not self.is_master:
      return

    # clean up
    ckpts = sorted(os.listdir(self.ckpt_dir))
    ckpts = [ck for ck in ckpts if ck.endswith('.pth') or ck.endswith('.tar')]
    if len(ckpts) > self.FLAGS.SOLVER.ckpt_num:
      for ckpt in ckpts[:-self.FLAGS.SOLVER.ckpt_num]:
        os.remove(os.path.join(self.ckpt_dir, ckpt))

    # save ckpt
    model_dict = self.model.module.state_dict() \
                 if self.world_size > 1 else self.model.state_dict()
    ckpt_name = os.path.join(self.ckpt_dir, '%05d' % epoch)
    torch.save(model_dict, ckpt_name + '.model.pth')
    torch.save({'model_dict': model_dict, 'epoch': epoch,
                'optimizer_dict': self.optimizer.state_dict(),
                'scheduler_dict': self.scheduler.state_dict()},
               ckpt_name + '.solver.tar')

  def load_checkpoint(self):
    ckpt = self.FLAGS.SOLVER.ckpt
    if not ckpt:
      # If ckpt is empty, then get the latest checkpoint from ckpt_dir
      if not os.path.exists(self.ckpt_dir):
        return
      ckpts = sorted(os.listdir(self.ckpt_dir))
      ckpts = [ck for ck in ckpts if ck.endswith('solver.tar')]
      if len(ckpts) > 0:
        ckpt = os.path.join(self.ckpt_dir, ckpts[-1])
    if not ckpt:
      return    # return if ckpt is still empty

    # load the trained model
    # check: map_location = {'cuda:0' : 'cuda:%d' % self.rank}
    trained_dict = torch.load(ckpt, map_location='cuda')
    if ckpt.endswith('.solver.tar'):
      model_dict = trained_dict['model_dict']
      self.start_epoch = trained_dict['epoch'] + 1    # !!! add 1
      if self.optimizer:
        self.optimizer.load_state_dict(trained_dict['optimizer_dict'])
      if self.scheduler:
        self.scheduler.load_state_dict(trained_dict['scheduler_dict'])
    else:
      model_dict = trained_dict
    model = self.model.module if self.world_size > 1 else self.model
    model.load_state_dict(model_dict)

    # print messages
    if self.is_master:
      tqdm.write('Load the checkpoint: %s' % ckpt)
      tqdm.write('The start_epoch is %d' % self.start_epoch)

  def train(self):
    self.config_model()
    self.config_dataloader()
    self.configure_optimizer()
    self.configure_log()
    self.load_checkpoint()

    rng = range(self.start_epoch, self.FLAGS.SOLVER.max_epoch + 1)
    for epoch in tqdm(rng, ncols=80, disable=self.disable_tqdm):
      # training epoch
      self.train_epoch(epoch)

      # update the learning rate
      self.scheduler.step()
      if self.is_master:
        lr = self.scheduler.get_last_lr()    # lr is a list
        self.summary_writer.add_scalar('train/lr', lr[0], epoch)

      # testing or not
      if epoch % self.FLAGS.SOLVER.test_every_epoch != 0:
        continue

      # testing epoch
      self.test_epoch(epoch)

      # checkpoint
      self.save_checkpoint(epoch)

    # sync and exit
    if self.world_size > 1:
      torch.distributed.barrier()

  def test(self):
    self.config_model()
    self.configure_log(set_writer=False)
    self.config_dataloader(disable_train_data=True)
    self.load_checkpoint()
    self.test_epoch(epoch=0)

  def evaluate(self):
    self.config_model()
    self.configure_log(set_writer=False)
    self.config_dataloader(disable_train_data=True)
    self.load_checkpoint()
    for epoch in tqdm(range(self.FLAGS.SOLVER.eval_epoch), ncols=80):
      self.eval_epoch(epoch)

  def profile(self):
    ''' Set `DATA.train.num_workers 0` when using this function. '''
    self.config_model()
    self.config_dataloader()

    # warm up
    batch = next(iter(self.train_loader))
    for _ in range(3):
      output = self.train_step(batch)
      output['train/loss'].backward()

    # profile
    with torch.autograd.profiler.profile(
        use_cuda=True, profile_memory=True,
        with_stack=True, record_shapes=True) as prof:
      output = self.train_step(batch)
      output['train/loss'].backward()

    json = os.path.join(self.FLAGS.SOLVER.logdir, 'trace.json')
    print('Save the profile into: ' + json)
    prof.export_chrome_trace(json)
    print(prof.key_averages(group_by_stack_n=10)
              .table(sort_by="cuda_time_total", row_limit=10))
    print(prof.key_averages(group_by_stack_n=10)
              .table(sort_by="cuda_memory_usage", row_limit=10))

  def run(self):
    eval('self.%s()' % self.FLAGS.SOLVER.run)

  @staticmethod
  def main_worker(gpu, FLAGS, TheSolver):
    world_size = len(FLAGS.SOLVER.gpu)
    if world_size > 1:
      # Set the GPU to use.
      torch.cuda.set_device(gpu)
      # Initialize the process group. Currently, the code only supports the
      # `single node + multiple GPUs` mode, so the rank is equal to the gpu id.
      torch.distributed.init_process_group(
          backend='nccl', init_method=FLAGS.SOLVER.dist_url,
          world_size=world_size, rank=gpu)
      # The master process is responsible for logging and for writing and
      # loading checkpoints. In the multi-GPU setting, we assign the master
      # role to the rank-0 process.
      is_master = gpu == 0
      solver = TheSolver(FLAGS, is_master)
    else:
      solver = TheSolver(FLAGS, is_master=True)
    solver.run()

  @staticmethod
  def main(FLAGS, TheSolver):
    num_gpus = len(FLAGS.SOLVER.gpu)
    if num_gpus > 1:
      torch.multiprocessing.spawn(
          Solver.main_worker, nprocs=num_gpus, args=(FLAGS, TheSolver))
    else:
      Solver.main_worker(0, FLAGS, TheSolver)
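To make the base-class contract concrete, here is a hedged sketch of a Solver subclass; `MyNet`, `my_transform`, and the flag fields are placeholders, while the overridden hooks and the returned 'train/loss' key are what the loops above actually require:

# Sketch only: MyNet and my_transform are hypothetical.
class MySolver(Solver):
  def get_model(self, flags):
    return MyNet(flags.channel, flags.nout)        # any torch.nn.Module

  def get_dataset(self, flags):
    dataset = Dataset(flags.location, flags.filelist, my_transform)
    return dataset, None                           # None -> default collate_fn

  def train_step(self, batch):
    logits = self.model(batch['octree'].cuda())
    loss = torch.nn.functional.cross_entropy(logits, batch['label'].cuda())
    return {'train/loss': loss}                    # key read by train_epoch

  def test_step(self, batch):
    logits = self.model(batch['octree'].cuda())
    loss = torch.nn.functional.cross_entropy(logits, batch['label'].cuda())
    return {'test/loss': loss}


if __name__ == '__main__':
  MySolver.main(get_config(), MySolver)   # one worker per GPU in SOLVER.gpu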
@ -0,0 +1,157 @@
import os
import sys
import argparse


parser = argparse.ArgumentParser()
parser.add_argument('--run', type=str, required=True)
parser.add_argument('--octree', type=str, required=False,
                    default='logs/completion/skip_connections_test')
args = parser.parse_args()

abs_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
root_folder = os.path.join(abs_path, 'data/ocnn_completion')

points2ply, ply2points, octree2pts = 'points2ply', 'ply2points', 'octree2points'


def download_point_clouds():
  # download via wget
  if not os.path.exists(root_folder):
    os.makedirs(root_folder)
  url = 'https://www.dropbox.com/s/z2x0mw4ai18f855/ocnn_completion.zip?dl=0'
  cmd = 'wget %s -O %s.zip' % (url, root_folder)
  print(cmd)
  os.system(cmd)

  # unzip
  cmd = 'unzip %s.zip -d %s' % (root_folder, root_folder)
  print(cmd)
  os.system(cmd)


def _convert_ply_to_points(prefix='shape'):
  ply_folder = os.path.join(root_folder, prefix + '.ply')
  points_folder = os.path.join(root_folder, prefix + '.points')

  folders = os.listdir(ply_folder)
  for folder in folders:
    curr_folder = os.path.join(ply_folder, folder)

    # write the filelist to disk
    filenames = os.listdir(curr_folder)
    filelist_name = os.path.join(curr_folder, 'filelist.txt')
    with open(filelist_name, 'w') as fid:
      for filename in filenames:
        if filename.endswith('.ply'):
          fid.write(os.path.join(curr_folder, filename) + '\n')

    # run ply2points
    output_path = os.path.join(points_folder, folder)
    if not os.path.exists(output_path):
      os.makedirs(output_path)
    cmd = '%s --filenames %s --output_path %s --verbose 0' % \
          (ply2points, filelist_name, output_path)
    print(cmd)
    os.system(cmd)
    os.remove(filelist_name)


def convert_ply_to_points():
  _convert_ply_to_points('shape')
  _convert_ply_to_points('test.scans')


def _convert_points_to_ply(prefix='shape'):
  ply_folder = os.path.join(root_folder, prefix + '.ply')
  points_folder = os.path.join(root_folder, prefix + '.points')

  folders = os.listdir(points_folder)
  for folder in folders:
    curr_folder = os.path.join(points_folder, folder)

    # write the filelist to disk
    filenames = os.listdir(curr_folder)
    filelist_name = os.path.join(curr_folder, 'filelist.txt')
    with open(filelist_name, 'w') as fid:
      for filename in filenames:
        if filename.endswith('.points'):
          fid.write(os.path.join(curr_folder, filename) + '\n')

    # run points2ply
    output_path = os.path.join(ply_folder, folder)
    if not os.path.exists(output_path):
      os.makedirs(output_path)
    cmd = '%s --filenames %s --output_path %s --verbose 0' % \
          (points2ply, filelist_name, output_path)
    print(cmd)
    os.system(cmd)
    os.remove(filelist_name)


def convert_points_to_ply():
  _convert_points_to_ply('shape')
  _convert_points_to_ply('test.scans')


def generate_dataset():
  download_point_clouds()
  convert_ply_to_points()


def rename_output_octree():
  filelist = os.path.join(root_folder, 'filelist_test_scans.txt')
  filenames = []
  with open(filelist, 'r') as fid:
    for line in fid:
      filename = line.split()[0]
      filenames.append(filename[:-6] + 'octree')

  idx = 0
  folder_in = args.octree
  octree_in = os.listdir(folder_in)
  octree_in.sort()
  folder_out = os.path.join(root_folder, 'output.octree')
  for o in octree_in:
    if o.endswith('output.octree'):
      name_in = os.path.join(folder_in, o)
      name_out = os.path.join(folder_out, filenames[idx])
      os.renames(name_in, name_out)
      idx += 1
  assert idx == 1200


def _convert_octree_to_points(suffix='ply'):
  octree_folder = os.path.join(root_folder, 'output.octree')
  points_folder = os.path.join(root_folder, 'output.' + suffix)

  folders = os.listdir(octree_folder)
  for folder in folders:
    curr_folder = os.path.join(octree_folder, folder)

    # write the filelist to disk
    filenames = os.listdir(curr_folder)
    filelist_name = os.path.join(curr_folder, 'filelist.txt')
    with open(filelist_name, 'w') as fid:
      for filename in filenames:
        if filename.endswith('.octree'):
          fid.write(os.path.join(curr_folder, filename) + '\n')

    # run octree2points
    output_path = os.path.join(points_folder, folder)
    if not os.path.exists(output_path):
      os.makedirs(output_path)
    cmd = '%s --filenames %s --output_path %s --verbose 0 --suffix %s' % \
          (octree2pts, filelist_name, output_path, suffix)
    print(cmd)
    os.system(cmd)
    os.remove(filelist_name)


def convert_octree_to_points():
  _convert_octree_to_points('points')
  _convert_octree_to_points('ply')


if __name__ == '__main__':
  eval('%s()' % args.run)
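Assuming this script is saved as completion.py (its file name is not visible in this view), the dataset is built with `python completion.py --run generate_dataset`; after inference, `--run rename_output_octree` followed by `--run convert_octree_to_points` post-processes the network output. Any top-level function can be named in --run, since dispatch goes through eval.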
@ -0,0 +1,233 @@
import os
import math
import argparse


parser = argparse.ArgumentParser()
parser.add_argument('--run', type=str, required=True,
                    help='The command to run.')
parser.add_argument('--scanner', type=str, required=False,
                    help='The path of the virtual_scanner')
parser.add_argument('--simplify_points', type=str, required=False,
                    default='simplify_points',
                    help='The path of the simplify_points')
parser.add_argument('--transform_points', type=str, required=False,
                    default='transform_points',
                    help='The path of the transform_points')
parser.add_argument('--align_y', type=str, required=False, default='false',
                    help='Align the points with the y axis')

abs_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
root_folder = os.path.join(abs_path, 'data/ModelNet40')

args = parser.parse_args()
virtual_scanner = args.scanner
simplify = args.simplify_points
transform = args.transform_points


def download_m40():
  # download via wget
  if not os.path.exists(root_folder):
    os.makedirs(root_folder)
  url = 'http://modelnet.cs.princeton.edu/ModelNet40.zip'
  cmd = 'wget %s -O %s/ModelNet40.zip' % (url, root_folder)
  print(cmd)
  os.system(cmd)

  # unzip
  cmd = 'unzip %s/ModelNet40.zip -d %s' % (root_folder, root_folder)
  print(cmd)
  os.system(cmd)


def download_m40_points():
  # download via wget
  if not os.path.exists(root_folder):
    os.makedirs(root_folder)
  url = 'https://www.dropbox.com/s/m233s9eza3acj2a/ModelNet40.points.zip?dl=0'
  zip_file = os.path.join(root_folder, 'ModelNet40.points.zip')
  cmd = 'wget %s -O %s' % (url, zip_file)
  print(cmd)
  os.system(cmd)

  # unzip
  cmd = 'unzip %s -d %s/ModelNet40.points' % (zip_file, root_folder)
  print(cmd)
  os.system(cmd)


def clean_off_file(filename):
  # read the contents of the file
  with open(filename) as fid:
    file_str = fid.read()
  # fix the file
  if file_str[0:3] != 'OFF':
    print('Error: not an OFF file: ' + filename)
  elif file_str[0:4] != 'OFF\n':
    print('Info: fix an OFF file: ' + filename)
    new_str = file_str[0:3] + '\n' + file_str[3:]
    with open(filename, 'w') as f_rewrite:
      f_rewrite.write(new_str)


def get_filelist(root_folder, train=True, suffix='off', ratio=1.0):
  filelist, category = [], []
  folders = sorted(os.listdir(root_folder))
  assert len(folders) == 40
  for idx, folder in enumerate(folders):
    subfolder = 'train' if train else 'test'
    current_folder = os.path.join(root_folder, folder, subfolder)
    filenames = sorted(os.listdir(current_folder))
    filenames = [fname for fname in filenames if fname.endswith(suffix)]
    total_num = math.ceil(len(filenames) * ratio)
    for i in range(total_num):
      filelist.append(os.path.join(folder, subfolder, filenames[i]))
      category.append(idx)
  return filelist, category


def move_files(src_folder, des_folder, suffix):
  folders = os.listdir(src_folder)
  for folder in folders:
    for subfolder in ['train', 'test']:
      curr_src_folder = os.path.join(src_folder, folder, subfolder)
      curr_des_folder = os.path.join(des_folder, folder, subfolder)
      if not os.path.exists(curr_des_folder):
        os.makedirs(curr_des_folder)
      filenames = os.listdir(curr_src_folder)
      for filename in filenames:
        if filename.endswith(suffix):
          os.rename(os.path.join(curr_src_folder, filename),
                    os.path.join(curr_des_folder, filename))


def convert_mesh_to_points():
  mesh_folder = os.path.join(root_folder, 'ModelNet40')
  # Delete the following 3 files since the virtual scanner cannot deal with them
  filelist = ['cone/train/cone_0117.off',
              'curtain/train/curtain_0066.off',
              'car/train/car_0021.off.off']
  for filename in filelist:
    filename = os.path.join(mesh_folder, filename)
    if os.path.exists(filename):
      os.remove(filename)

  # clean the off files
  train_list, _ = get_filelist(mesh_folder, train=True, suffix='off')
  test_list, _ = get_filelist(mesh_folder, train=False, suffix='off')
  filelist = train_list + test_list
  for filename in filelist:
    clean_off_file(os.path.join(mesh_folder, filename))

  # run the virtual scanner
  folders = os.listdir(mesh_folder)
  for folder in folders:
    for subfolder in ['train', 'test']:
      curr_folder = os.path.join(mesh_folder, folder, subfolder)
      cmd = '%s %s 14' % (virtual_scanner, curr_folder)
      print(cmd)
      os.system(cmd)

  # move the points
  move_files(mesh_folder, mesh_folder + '.points', 'points')


def simplify_points(resolution=64):
  # rename and backup the original folders
  points_folder = os.path.join(root_folder, 'ModelNet40.points')
  original_folder = points_folder + '.dense'
  if os.path.exists(points_folder):
    os.rename(points_folder, original_folder)

  folders = os.listdir(original_folder)
  for folder in folders:
    for subfolder in ['train', 'test']:
      curr_folder = os.path.join(original_folder, folder, subfolder)
      # write the filelist to disk
      filenames = os.listdir(curr_folder)
      filelist_name = os.path.join(curr_folder, 'list.txt')
      with open(filelist_name, 'w') as fid:
        for filename in filenames:
          if filename.endswith('.points'):
            fid.write(os.path.join(curr_folder, filename) + '\n')
      # run simplify_points
      output_path = os.path.join(points_folder, folder, subfolder)
      if not os.path.exists(output_path):
        os.makedirs(output_path)
      cmd = '%s --filenames %s --output_path %s --dim %d' % \
            (simplify, filelist_name, output_path, resolution)
      print(cmd)
      os.system(cmd)
      os.remove(filelist_name)


def transform_points():
  points_folder = os.path.join(root_folder, 'ModelNet40.points')
  output_folder = os.path.join(root_folder, 'ModelNet40.points.y')
  folders = os.listdir(points_folder)
  for folder in folders:
    for subfolder in ['train', 'test']:
      curr_folder = os.path.join(points_folder, folder, subfolder)
      output_path = os.path.join(output_folder, folder, subfolder)
      if not os.path.exists(output_path):
        os.makedirs(output_path)

      # write the filelist to disk
      filenames = os.listdir(curr_folder)
      filelist_name = os.path.join(curr_folder, 'list.txt')
      with open(filelist_name, 'w') as fid:
        for filename in filenames:
          if filename.endswith('.points'):
            fid.write(os.path.join(curr_folder, filename) + '\n')

      # write the transformation matrix
      mat = '0 0 1 1 0 0 0 1 0'
      mat_name = os.path.join(curr_folder, 'mat.txt')
      with open(mat_name, 'w') as fid:
        fid.write(mat)

      # run transform_points
      cmd = '%s --filenames %s --output_path %s --mat %s' % \
            (transform, filelist_name, output_path, mat_name)
      print(cmd)
      os.system(cmd)
      os.remove(filelist_name)
      os.remove(mat_name)


def generate_points_filelist():
  points_folder = os.path.join(root_folder, 'ModelNet40.points')

  for folder in ['train', 'test']:
    train = folder == 'train'
    filelist, idx = get_filelist(points_folder, train=train, suffix='points')
    prefix = 'm40_' + folder
    filename = os.path.join(root_folder, '%s_points_list.txt' % prefix)
    print('Save to %s' % filename)
    with open(filename, 'w') as fid:
      for i in range(len(filelist)):
        fid.write('%s %d\n' % (filelist[i], idx[i]))


def generate_points_filelist_ratios():
  ratios = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0]
  points_folder = os.path.join(root_folder, 'ModelNet40.points.y')

  for folder in ['train', 'test']:
    train = folder == 'train'
    for ratio in ratios:
      if not train and ratio < 1:
        continue
      prefix = 'm40_y_%.02f_%s' % (ratio, folder)
      filename = os.path.join(root_folder, '%s_points_list.txt' % prefix)
      filelist, idx = get_filelist(points_folder, train=train,
                                   suffix='points', ratio=ratio)
      print('Save to %s' % filename)
      with open(filename, 'w') as fid:
        for i in range(len(filelist)):
          fid.write('%s %d\n' % (filelist[i], idx[i]))


if __name__ == '__main__':
  eval('%s()' % args.run)
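The ModelNet40 pipeline implied by these functions runs in order: download_m40 (or download_m40_points to skip scanning), convert_mesh_to_points with --scanner pointing at the compiled virtual_scanner binary, simplify_points, transform_points, and finally generate_points_filelist or generate_points_filelist_ratios; each stage is selected via --run, and the simplify/transform binaries are assumed to be on PATH by their defaults.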
@ -1,6 +1,5 @@
import torch
import ocnn
import numpy as np


names = ['octree_%d' % i for i in range(1, 7)]
@ -0,0 +1,257 @@
import os
import argparse
import numpy as np
from pathlib import Path
from tqdm import tqdm
from plyfile import PlyData, PlyElement

parser = argparse.ArgumentParser()
parser.add_argument('--path_in', type=str, default='data/scannet/ScanNet_data')
parser.add_argument('--path_out', type=str, default='data/scannet/scans')
parser.add_argument('--path_pred', type=str, default='logs/scannet/D9_2cm_eval')
parser.add_argument('--filelist', type=str, default='scannetv2_test_new.txt')
parser.add_argument('--run', type=str, default='generate_output_seg')
parser.add_argument('--label_remap', type=str, default='true')
args = parser.parse_args()

label_remap = args.label_remap.lower() == 'true'

suffix = '_vh_clean_2.ply'
subsets = {'train': 'scans', 'test': 'scans_test'}
class_labels = ('wall', 'floor', 'cabinet', 'bed', 'chair',
                'sofa', 'table', 'door', 'window', 'bookshelf',
                'picture', 'counter', 'desk', 'curtain',
                'refrigerator', 'shower curtain', 'toilet', 'sink',
                'bathtub', 'otherfurniture')
class_ids = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
             12, 14, 16, 24, 28, 33, 34, 36, 39)
label_dict = dict(zip(class_ids, np.arange(0, 21)))
ilabel_dict = dict(zip(np.arange(0, 21), class_ids))


def read_ply(filename, compute_normal=True):
  with open(filename, 'rb') as fid:
    plydata = PlyData.read(fid)
  vertex, face = plydata['vertex'].data, plydata['face'].data

  props = [vertex[name].astype(np.float32) for name in vertex.dtype.names]
  vertex = np.stack(props[:3], axis=1)
  props = np.stack(props[3:], axis=1)
  face = np.stack(face['vertex_indices'], axis=0)

  nv = vertex_normal(vertex, face) if compute_normal else np.zeros_like(vertex)
  vertex_with_props = np.concatenate([vertex, nv, props], axis=1)
  return vertex_with_props


def face_normal(vertex, face):
  v01 = vertex[face[:, 1]] - vertex[face[:, 0]]
  v02 = vertex[face[:, 2]] - vertex[face[:, 0]]
  vec = np.cross(v01, v02)
  length = np.sqrt(np.sum(vec**2, axis=1, keepdims=True)) + 1.0e-8
  nf = vec / length
  area = length * 0.5
  return nf, area


def vertex_normal(vertex, face):
  nf, area = face_normal(vertex, face)
  nf = nf * area

  nv = np.zeros_like(vertex)
  for i in range(face.shape[0]):
    nv[face[i]] += nf[i]

  length = np.sqrt(np.sum(nv**2, axis=1, keepdims=True)) + 1.0e-8
  nv = nv / length
  return nv


def save_ply(point_cloud, filename):
  ncols = point_cloud.shape[1]
  py_types = (float, float, float, float, float, float,
              int, int, int, int)[:ncols]
  npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
               ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
               ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
               ('label', 'u1')][:ncols]

  # format into a NumPy structured array
  vertices = []
  for row_idx in range(point_cloud.shape[0]):
    point = point_cloud[row_idx]
    vertices.append(tuple(dtype(val) for dtype, val in zip(py_types, point)))
  structured_array = np.array(vertices, dtype=npy_types)
  el = PlyElement.describe(structured_array, 'vertex')

  # write the ply file
  PlyData([el]).write(filename)
  print('Save:', filename)


def generate_chunks(filename, point_cloud, cropsize=10.0, stride=5.0):
  vertices = point_cloud[:, :3]
  bbmin = np.min(vertices, axis=0)
  bbmax = np.max(vertices, axis=0)
  bbox = bbmax - bbmin
  inbox = bbox < cropsize
  if np.all(inbox):
    return

  chunk_id = 0
  min_size = 3000
  chunk_num = np.ceil(np.maximum(bbox - cropsize, 0) / stride).astype(np.int32) + 1
  for i in range(chunk_num[0]):
    for j in range(chunk_num[1]):
      for k in range(chunk_num[2]):
        cmin = bbmin + np.array([i, j, k]) * stride
        cmax = cmin + cropsize
        inbox_mask = (vertices <= cmax) & (vertices >= cmin)
        inbox_mask = np.all(inbox_mask, axis=1)
        if np.sum(inbox_mask) < min_size:
          continue
        filename_out = filename.stem + '.chunk_%d.ply' % chunk_id
        save_ply(point_cloud[inbox_mask], filename.parent / filename_out)
        filename_mask = filename.stem + '.chunk_%d.mask.npy' % chunk_id
        np.save(filename.parent / filename_mask, inbox_mask)
        chunk_id += 1


def process_scannet():
  for path_out, path_in in subsets.items():
    curr_path_out = Path(args.path_out) / path_out
    curr_path_out.mkdir(parents=True, exist_ok=True)

    filenames = (Path(args.path_in) / path_in).glob('*/*' + suffix)
    for filename in filenames:
      pointcloud = read_ply(filename)
      # Make sure the alpha value is meaningless.
      assert np.unique(pointcloud[:, -1]).size == 1

      # Load the label file
      label = np.zeros((pointcloud.shape[0], 1))
      filename_label = filename.parent / (filename.stem + '.labels.ply')
      if filename_label.is_file():
        label_data = read_ply(filename_label, compute_normal=False)
        # Sanity check that the pointcloud and its labels have the same vertices.
        assert pointcloud.shape[0] == label_data.shape[0]
        assert np.allclose(pointcloud[:, :3], label_data[:, :3])

        label = label_data[:, -1:]
        filename_out = curr_path_out / (filename.name[:-len(suffix)] + '.txt')
        np.savetxt(filename_out, label, fmt='%d')
        if label_remap:    # remap the labels
          for i in range(label.shape[0]):
            label[i] = label_dict.get(int(label[i]), 0)

      filename_out = curr_path_out / (filename.name[:-len(suffix)] + '.ply')
      processed = np.concatenate((pointcloud[:, :-1], label), axis=-1)
      # save the original file
      save_ply(processed, filename_out)
      # save the cropped chunks in the 10x10x10 box
      generate_chunks(filename_out, processed)


def fix_bug_files():
  bug_files = {
      'train/scene0270_00.ply': 50,
      'train/scene0270_02.ply': 50,
      'train/scene0384_00.ply': 149}
  for files, bug_index in bug_files.items():
    print(files)
    for f in Path(args.path_out).glob(files):
      pointcloud = read_ply(f)
      bug_mask = pointcloud[:, -1] == bug_index
      print(f'Fixing {f} bugged label {bug_index} x {bug_mask.sum()}')
      pointcloud[bug_mask, -1] = 0
      save_ply(pointcloud, f)


def generate_output_seg():
  # load the filelist
  filename_scans = []
  with open(args.filelist, 'r') as fid:
    for line in fid:
      filename = line.split()[0]
      filename_scans.append(filename[:-4])    # remove '.ply'

  # input files
  pred_files = sorted(os.listdir(args.path_pred))
  pred_files = [f for f in pred_files if f.endswith('.npz')]
  assert len(pred_files) % len(filename_scans) == 0

  # process
  probs = {}
  for i in tqdm(range(len(pred_files)), ncols=80):
    filename_scan = filename_scans[i % len(filename_scans)]

    pred = np.load(os.path.join(args.path_pred, pred_files[i]))
    prob, inbox_mask = pred['prob'], pred['inbox_mask']
    prob0 = np.zeros([inbox_mask.shape[0], prob.shape[1]])
    prob0[inbox_mask] = prob

    if 'chunk' in filename_scan:
      filename_mask = filename_scan + '.mask.npy'
      mask = np.load(os.path.join(args.path_in, filename_mask))
      prob1 = np.zeros([mask.shape[0], prob0.shape[1]])
      prob1[mask] = prob0

      # update prob0 and filename_scan
      prob0 = prob1
      filename_scan = filename_scan[:-8]    # remove '.chunk_x'

    probs[filename_scan] = probs.get(filename_scan, 0) + prob0

  # output
  if not os.path.exists(args.path_out):
    os.makedirs(args.path_out)

  for filename, prob in tqdm(probs.items(), ncols=80):
    filename_label = filename + '.txt'
    label = np.argmax(prob, axis=1)
    for i in range(label.shape[0]):
      label[i] = ilabel_dict[label[i]]
    np.savetxt(os.path.join(args.path_out, filename_label), label, fmt='%d')


def calc_iou():
  # init
  intsc, union, accu = {}, {}, 0
  for k in class_ids[1:]:
    intsc[k] = 0
    union[k] = 0

  # load files
  pred_files = sorted(os.listdir(args.path_pred))
  pred_files = [f for f in pred_files if f.endswith('.txt')]
  for filename in tqdm(pred_files, ncols=80):
    label_pred = np.loadtxt(os.path.join(args.path_pred, filename))
    label_gt = np.loadtxt(os.path.join(args.path_in, filename))

    # omit labels out of class_ids[1:]
    mask = np.zeros_like(label_gt).astype(bool)
    for i in range(label_gt.shape[0]):
      mask[i] = label_gt[i] in class_ids[1:]
    label_pred = label_pred[mask]
    label_gt = label_gt[mask]

    ac = (label_gt == label_pred).mean()
    tqdm.write("Accu: %s, %.4f" % (filename, ac))
    accu += ac

    for k in class_ids[1:]:
      pk, lk = label_pred == k, label_gt == k
      intsc[k] += np.sum(np.logical_and(pk, lk).astype(np.float32))
      union[k] += np.sum(np.logical_or(pk, lk).astype(np.float32))

  # iou
  iou_part = 0
  for k in class_ids[1:]:
    iou_part += intsc[k] / (union[k] + 1.0e-10)
  iou = iou_part / len(class_ids[1:])
  print('Accu: %.6f' % (accu / len(pred_files)))
  print('IoU: %.6f' % iou)


if __name__ == '__main__':
  eval('%s()' % args.run)
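The per-face Python loop in vertex_normal dominates preprocessing time on large scans. An equivalent vectorized form (a sketch with the same area-weighted semantics, relying only on NumPy's scatter-add) is:

import numpy as np

def vertex_normal_vectorized(vertex, face):
  # Same averaging as vertex_normal above, without the Python loop.
  nf, area = face_normal(vertex, face)
  nf = nf * area
  nv = np.zeros_like(vertex)
  # np.add.at accumulates correctly even when a vertex index repeats.
  np.add.at(nv, face.reshape(-1), np.repeat(nf, 3, axis=0))
  length = np.sqrt(np.sum(nv**2, axis=1, keepdims=True)) + 1.0e-8
  return nv / length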
@ -0,0 +1,13 @@
import ocnn
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs/resnet')
octree = ocnn.octree_batch(ocnn.octree_samples(['octree_1', 'octree_2']))
model = ocnn.ResNet(depth=5, channel_in=3, nout=4, resblk_num=2)
print(model)

octree = octree.cuda()
model = model.cuda()
writer.add_graph(model, octree)
writer.flush()
@ -0,0 +1,9 @@
torch
numpy
tqdm
yacs
scipy
plyfile
tensorboard
scikit-image
trimesh
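No versions are pinned here; a plain `pip install -r requirements.txt` (the file name is an assumption based on the list's format) satisfies the Python side of this commit.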
@ -8,6 +8,7 @@ from test_octree_property import OctreePropertyTest
from test_octree_key import OctreeKeyTest
from test_points_property import PointsPropertyTest
from test_octree_trilinear import OctreeTrilinearTest
from test_octree_align import OctreeAlignTest

# Run 16 tests in total
if __name__ == "__main__":
@ -61,7 +61,7 @@ class Octree2ColTest(unittest.TestCase):
        octree = self.octree.cuda()
        data_in = torch.from_numpy(self.data_in).cuda()
        data_out = ocnn.octree2col(data_in, octree,
                                   self.depth, kernel_size[j], stride[i])
                                   self.depth, kernel_size[j], stride[i], False)

        data_out = data_out.cpu().detach().numpy()
        self.assertTrue(np.array_equal(data_out, out_gt))
@ -76,7 +76,7 @@ class Octree2ColTest(unittest.TestCase):
        octree = self.octree.cuda()
        data_in = torch.from_numpy(self.data_in).cuda().requires_grad_()

        params = [data_in, octree, self.depth, kernel_size[j], stride[i]]
        params = [data_in, octree, self.depth, kernel_size[j], stride[i], False]
        succ = gradcheck(ocnn.octree2col, params, eps=1.0)
        self.assertTrue(succ)
@ -89,8 +89,8 @@ class Octree2ColTest(unittest.TestCase):
      for j in range(len(vi)):
        octree = self.octree.cuda()
        data_in = torch.from_numpy(self.data_in).cuda()
        data_out = ocnn.nn.octree2colP(data_in, octree,
                                       self.depth, kernel_size[j], stride[i])
        data_out = ocnn.nn.octree2col(data_in, octree,
                                      self.depth, kernel_size[j], stride[i], True)
        data_out = data_out.cpu().detach().numpy()

        out_gt = self.forward(kernel_size[j], stride[i], self.idx_maps[vi[j]])
@ -113,13 +113,13 @@ class Octree2ColTest(unittest.TestCase):
    # octree2colP = octree2col + depad
    for i in range(len(stride)):
      for j in range(len(kernel_size)):
        out1 = ocnn.octree2col(data1, octree, depth, kernel_size[j], stride[i])
        out1 = ocnn.octree2col(data1, octree, depth, kernel_size[j], stride[i], False)
        if stride[i] == 1:
          ks, height = out1.size(1), out1.size(2)
          out1 = out1.view(1, -1, height, 1)
          out1 = ocnn.octree_depad(out1, octree, depth)
          out1 = out1.view(channel, ks, -1)
        out2 = ocnn.octree2colP(data_in2, octree, depth, kernel_size[j], stride[i])
        out2 = ocnn.octree2col(data_in2, octree, depth, kernel_size[j], stride[i], True)

        pseudo_grad = torch.rand(out1.shape, dtype=out1.dtype, device=out1.device)
        out1.backward(pseudo_grad, retain_graph=True)
@ -0,0 +1,76 @@
import os
import torch
import ocnn
import unittest
import numpy as np


class OctreeAlignTest(unittest.TestCase):
  def get_octree(self, filelist):
    batch = ocnn.octree_samples(filelist)
    return ocnn.octree_batch(batch).cuda()

  def test_forward_backward1(self):
    depth = 5
    octree = self.get_octree(['octree_1', 'octree_1'])
    data_in = torch.rand(1, 3, 16, 1).cuda().requires_grad_()
    data_out, idx = ocnn.octree_align(data_in, octree, octree, depth)
    idx_gt = torch.arange(16, dtype=torch.int32).cuda()

    out = data_out.sum()
    out.backward()
    grad_gt = np.ones([1, 3, 16, 1])

    self.assertTrue(np.array_equal(data_out.cpu().detach().numpy(),
                                   data_in.cpu().detach().numpy()))
    self.assertTrue(np.array_equal(idx.cpu().detach().numpy(),
                                   idx_gt.cpu().detach().numpy()))
    self.assertTrue(np.array_equal(data_in.grad.cpu().numpy(), grad_gt))

  def test_forward_backward2(self):
    depth = 5
    octree_in = self.get_octree(['octree_1'])
    octree_out = self.get_octree(['octree_1', 'octree_1'])

    data_in = torch.rand(1, 3, 8, 1).cuda().requires_grad_()
    data_out, idx = ocnn.octree_align(data_in, octree_in, octree_out, depth)
    zeros = torch.zeros(1, 3, 8, 1, dtype=torch.float32).cuda()
    data_gt = torch.cat([data_in, zeros], dim=2)
    idx_gt = torch.arange(8, dtype=torch.int32)

    out = data_out.sum()
    out.backward()
    grad_gt = np.ones([1, 3, 8, 1])

    self.assertTrue(np.array_equal(data_out.cpu().detach().numpy(),
                                   data_gt.cpu().detach().numpy()))
    self.assertTrue(np.array_equal(idx.cpu().detach().numpy(),
                                   idx_gt.cpu().detach().numpy()))
    self.assertTrue(np.array_equal(data_in.grad.cpu().numpy(), grad_gt))

  def test_forward_backward3(self):
    depth = 5
    octree_in = self.get_octree(['octree_1', 'octree_1'])
    octree_out = self.get_octree(['octree_1'])
    data_in = torch.rand(1, 3, 16, 1).cuda().requires_grad_()
    data_out, idx = ocnn.octree_align(data_in, octree_in, octree_out, depth)
    data_gt = data_in[:, :, :8, :]
    idx_gt = np.array(list(range(8)) + [-1] * 8)

    out = data_out.sum()
    out.backward()
    grad_gt = torch.cat([torch.ones(1, 3, 8, 1), torch.zeros(1, 3, 8, 1)], 2)

    self.assertTrue(np.array_equal(data_out.cpu().detach().numpy(),
                                   data_gt.cpu().detach().numpy()))
    self.assertTrue(np.array_equal(idx.cpu().detach().numpy(), idx_gt))
    self.assertTrue(np.array_equal(data_in.grad.cpu().numpy(),
                                   grad_gt.numpy()))


if __name__ == "__main__":
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  unittest.main()
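The three cases above fix the semantics of octree_align: each input node is scattered to the matching node of the target octree, unmatched target nodes stay zero, and idx maps input node i to its output slot (-1 when dropped). A node-major NumPy sketch of that contract, for reference only:

import numpy as np

def octree_align_ref(data_in, idx, num_out):
  # data_in: (N, C) node-major features; idx: (N,) slots into the output.
  out = np.zeros((num_out,) + data_in.shape[1:], dtype=data_in.dtype)
  keep = idx >= 0
  out[idx[keep]] = data_in[keep]
  return out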
@ -18,24 +18,48 @@ class OctreeConvTest(unittest.TestCase):
    # forward
    conv1 = ocnn.OctreeConv(depth, channel, num_outputs, kernel_size, stride)
    conv2 = ocnn.OctreeConvFast(depth, channel, num_outputs, kernel_size, stride)
    conv3 = ocnn.OctreeConv(depth, channel, num_outputs, kernel_size, stride, True)
    conv4 = ocnn.OctreeConv(depth, channel, num_outputs, kernel_size, stride)

    # use the same initialization
    with torch.no_grad():
      conv2.weights.data = conv1.weights.data
      conv2.weights.data.copy_(conv1.weights.data)
      conv3.weights.data.copy_(conv1.weights.data)
      conv4.weights.data.copy_(conv1.weights.data)

    # forward
    octree = octree.to('cuda')
    conv1.to('cuda')
    data1 = torch.from_numpy(data).to('cuda').requires_grad_()
    # forward - compare OctreeConv and OctreeConvFast
    octree = octree.cuda()
    conv1.cuda()
    data1 = torch.from_numpy(data).cuda().requires_grad_()
    out1 = conv1(data1, octree)
    conv2.to('cuda')
    data2 = torch.from_numpy(data).to('cuda').requires_grad_()

    conv2.cuda()
    data2 = torch.from_numpy(data).cuda().requires_grad_()
    out2 = conv2(data2, octree)

    # forward - compare OctreeConv with nempty = True and False
    conv3.cuda()
    mask3 = ocnn.octree_property(octree, 'child', depth) >= 0
    data3 = torch.from_numpy(data).cuda().requires_grad_()
    tmp3 = data3[:, :, mask3]
    out3 = conv3(tmp3, octree)

    conv4.cuda()
    depth_out = depth if stride == 1 else depth - 1
    mask4 = ocnn.octree_property(octree, 'child', depth_out) >= 0
    data4 = torch.from_numpy(data).cuda().requires_grad_()
    tmp4 = data4 * mask3.unsqueeze(-1).float()
    tmp4 = conv4(tmp4, octree)
    out4 = tmp4[:, :, mask4]

    # backward
    pseudo_grad = torch.rand(out1.shape, dtype=out1.dtype, device=out1.device)
    out1.backward(pseudo_grad)
    out2.backward(pseudo_grad)
    pseudo_grad1 = torch.rand(out1.shape, dtype=out1.dtype, device=out1.device)
    out1.backward(pseudo_grad1)
    out2.backward(pseudo_grad1)

    pseudo_grad2 = torch.rand(out3.shape, dtype=out3.dtype, device=out3.device)
    out3.backward(pseudo_grad2)
    out4.backward(pseudo_grad2)

    # test
    self.assertTrue(np.array_equal(out1.cpu().detach().numpy(),
@ -47,6 +71,16 @@ class OctreeConvTest(unittest.TestCase):
                                conv2.weights.grad.cpu().numpy(),
                                atol=1e-06))

    self.assertTrue(np.allclose(out3.cpu().detach().numpy(),
                                out4.cpu().detach().numpy(),
                                atol=1e-06))
    self.assertTrue(np.allclose(data3.grad.cpu().numpy(),
                                data4.grad.cpu().numpy(),
                                atol=1e-06))
    self.assertTrue(np.allclose(conv3.weights.grad.cpu().numpy(),
                                conv4.weights.grad.cpu().numpy(),
                                atol=1e-06))

  def test_forward_and_backward(self):
    stride = [1, 2]
    kernel_size = [[3, 3, 3], [2, 2, 2], [3, 1, 1], [3, 3, 1], [1, 1, 1]]
@ -6,36 +6,61 @@ import numpy as np


class OctreeDeconvTest(unittest.TestCase):

  def forward_and_backward(self, kernel_size, stride, idx=0):
    depth = 4
    channel = 3
    height = 152
    num_outputs = 2
    num_outputs = 5
    octree = ocnn.octree_batch(ocnn.octree_samples(['octree_1', 'octree_2']))
    data = np.random.uniform(-1.0, 1.0, [1, channel, height, 1]).astype('float32')

    # forward
    deconv1 = ocnn.OctreeDeconv(depth, channel, num_outputs, kernel_size, stride)
    deconv2 = ocnn.OctreeDeconvFast(depth, channel, num_outputs, kernel_size, stride)
    conv1 = ocnn.OctreeDeconv(depth, channel, num_outputs, kernel_size, stride)
    conv2 = ocnn.OctreeDeconvFast(depth, channel, num_outputs, kernel_size, stride)
    conv3 = ocnn.OctreeDeconv(depth, channel, num_outputs, kernel_size, stride, True)
    conv4 = ocnn.OctreeDeconv(depth, channel, num_outputs, kernel_size, stride)

    # use the same initialization
    with torch.no_grad():
      deconv2.weights.data = deconv1.weights.data
      conv2.weights.data.copy_(conv1.weights.data)
      conv3.weights.data.copy_(conv1.weights.data)
      conv4.weights.data.copy_(conv1.weights.data)

    # forward
    octree = octree.to('cuda')
    deconv1.to('cuda')
    data1 = torch.from_numpy(data).to('cuda').requires_grad_()
    out1 = deconv1(data1, octree)
    deconv2.to('cuda')
    data2 = torch.from_numpy(data).to('cuda').requires_grad_()
    out2 = deconv2(data2, octree)
    # forward - compare OctreeDeconv and OctreeDeconvFast
    octree = octree.cuda()
    conv1.cuda()
    data1 = torch.from_numpy(data).cuda().requires_grad_()
    out1 = conv1(data1, octree)

    conv2.cuda()
    data2 = torch.from_numpy(data).cuda().requires_grad_()
    out2 = conv2(data2, octree)

    # forward - compare OctreeDeconv with nempty = True and False
    conv3.cuda()
    mask3 = ocnn.octree_property(octree, 'child', depth) >= 0
    data3 = torch.from_numpy(data).cuda().requires_grad_()
    tmp3 = data3[:, :, mask3]
    out3 = conv3(tmp3, octree)

    conv4.cuda()
    depth_out = depth if stride == 1 else depth + 1
    mask4 = ocnn.octree_property(octree, 'child', depth_out) >= 0
    data4 = torch.from_numpy(data).cuda().requires_grad_()
    tmp4 = data4 * mask3.unsqueeze(-1).float()
    tmp4 = conv4(tmp4, octree)
    out4 = tmp4[:, :, mask4]

    # backward
    pseudo_grad = torch.rand(out1.shape, dtype=out1.dtype, device=out1.device)
    out1.backward(pseudo_grad)
    out2.backward(pseudo_grad)

    pseudo_grad2 = torch.rand(out3.shape, dtype=out3.dtype, device=out3.device)
    out3.backward(pseudo_grad2)
    out4.backward(pseudo_grad2)

    # test
    self.assertTrue(np.allclose(out1.cpu().detach().numpy(),
                                out2.cpu().detach().numpy(),
@ -43,8 +68,18 @@ class OctreeDeconvTest(unittest.TestCase):
    self.assertTrue(np.allclose(data1.grad.cpu().numpy(),
                                data2.grad.cpu().numpy(),
                                atol=1e-06))
    self.assertTrue(np.allclose(deconv1.weights.grad.cpu().numpy(),
                                deconv2.weights.grad.cpu().numpy(),
    self.assertTrue(np.allclose(conv1.weights.grad.cpu().numpy(),
                                conv2.weights.grad.cpu().numpy(),
                                atol=1e-06))

    self.assertTrue(np.allclose(out3.cpu().detach().numpy(),
                                out4.cpu().detach().numpy(),
                                atol=1e-06))
    self.assertTrue(np.allclose(data3.grad.cpu().numpy(),
                                data4.grad.cpu().numpy(),
                                atol=1e-06))
    self.assertTrue(np.allclose(conv3.weights.grad.cpu().numpy(),
                                conv4.weights.grad.cpu().numpy(),
                                atol=1e-06))

  def test_forward_and_backward(self):
@ -0,0 +1,27 @@
import ocnn
import torch
import numpy as np


depth, channel = 5, 3
octree = ocnn.octree_new(batch_size=1, channel=channel, node_dis=False)
octree = ocnn.octree_grow(octree, target_depth=1, full_octree=True)
octree = ocnn.octree_grow(octree, target_depth=2, full_octree=True)
octree_gt = ocnn.octree_samples(['octree_2'])[0].cuda()

for d in range(2, depth + 1):
  child = ocnn.octree_property(octree_gt, 'child', depth=d)
  label = (child > -1).to(torch.int32)
  octree = ocnn.octree_update(octree, label, depth=d, split=1)
  if d < depth:
    octree = ocnn.octree_grow(octree, target_depth=d+1, full_octree=False)

feature = ocnn.octree_property(octree_gt, 'feature', depth)
octree = ocnn.octree_set_property(octree, feature, depth)

print('Please check the output files: `octree_1.octree` and `octree_2.octree`.\n'
      'The MD5 of `octree_1.octree`: FEB7C4AF43669EB0FC62632C71D1C938\n'
      'The MD5 of `octree_2.octree`: D569D5BB23D34795C5FD81397F56275B')
octree_gt.cpu().numpy().tofile('octree_1.octree')
octree.cpu().numpy().tofile('octree_2.octree')
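To check the dumped files against the MD5 values printed above, hashing them with the standard library is enough:

import hashlib

for name in ('octree_1.octree', 'octree_2.octree'):
  with open(name, 'rb') as fid:
    print(name, hashlib.md5(fid.read()).hexdigest().upper())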
@ -31,20 +31,27 @@ class OctreeKeyTest(unittest.TestCase):
  def test_search_key(self):
    samples = ocnn.octree_samples(['octree_1', 'octree_1'])
    octree = ocnn.octree_batch(samples).cuda()

    key = torch.cuda.LongTensor([28673, 281474976739335, 10])
    idx_gt = torch.cuda.IntTensor([1, 15, -1])
    idx = ocnn.octree_search_key(key, octree, 5, False)
    idx = ocnn.octree_search_key(key, octree, 5, key_is_xyz=False, nempty=False)
    self.assertTrue((idx == idx_gt).cpu().numpy().all())

    key = torch.cuda.LongTensor([28672, 28673, 281474976739328, 10])
    idx_gt = torch.cuda.IntTensor([0, -1, 1, -1])
    idx = ocnn.octree_search_key(key, octree, 5, key_is_xyz=False, nempty=True)
    self.assertTrue((idx == idx_gt).cpu().numpy().all())

  def test_xyz_key_64(self):
    # the length of the key is over 32 bits
    xyz = torch.cuda.ShortTensor([[2049, 4095, 8011, 1], [511, 4095, 8011, 0]])
    xyz_encode = ocnn.octree_encode_key(xyz)
    key = ocnn.octree_xyz2key(xyz_encode, 13)
    xyz_out = ocnn.octree_key2xyz(key, 13)
    xyz_decode = ocnn.octree_decode_key(xyz_out)
    self.assertTrue((xyz == xyz_decode).cpu().numpy().all())


if __name__ == "__main__":
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  unittest.main()
@ -7,7 +7,8 @@ import numpy as np

class OctreePropertyTest(unittest.TestCase):
  def octree_property(self, on_cuda=True):
    octree = ocnn.octree_batch(ocnn.octree_samples(['octree_1'] * 2))
    batch_size = 2
    octree = ocnn.octree_batch(ocnn.octree_samples(['octree_1'] * batch_size))
    if on_cuda:
      octree = octree.cuda()
@ -30,6 +31,27 @@ class OctreePropertyTest(unittest.TestCase):
    out_gt[0] = 0
    out_gt[8] = 1
    self.assertTrue(np.array_equal(out.cpu().numpy(), out_gt))
    # test child from depth=0
    out = torch.cat([ocnn.octree_property(octree, 'child', d) for d in range(1, 6)])
    outs = ocnn.octree_property(octree, 'child')
    self.assertTrue(np.array_equal(outs[batch_size:].cpu().numpy(), out.cpu().numpy()))

    # test node numbers
    nnums = np.array([2, 16, 128, 16, 16, 16])
    nnum_cums = np.array([0, 2, 18, 146, 162, 178, 194])
    node_num = ocnn.octree_property(octree, 'node_num', 5)
    node_nums = ocnn.octree_property(octree, 'node_num')
    node_num_cum = ocnn.octree_property(octree, 'node_num_cum', 5)
    node_nums_cum = ocnn.octree_property(octree, 'node_num_cum')
    self.assertTrue(node_num.item() == nnums[5])
    self.assertTrue(node_num_cum.item() == nnum_cums[5])
    self.assertTrue(np.array_equal(node_nums.cpu().numpy(), nnums))
    self.assertTrue(np.array_equal(node_nums_cum.cpu().numpy(), nnum_cums))

    # test batch_size, depth, full_depth
    self.assertTrue(ocnn.octree_property(octree, 'batch_size').item() == batch_size)
    self.assertTrue(ocnn.octree_property(octree, 'depth').item() == 5)
    self.assertTrue(ocnn.octree_property(octree, 'full_depth').item() == 2)

    # TODO: test key, xyz, and label
    # out = ocnn.octree_property(octree, 'key', 5)