Separate TFRecords code on different file (#514)
* WIP: adding TFRecords Builder class * Added TFRecordsBuilder class * added unit test * Refactored AddFeature and AddFeatureArray interface * removed import
This commit is contained in:
Родитель
f0982c369e
Коммит
8198c41775
|
@ -7,13 +7,11 @@ import { IProject, AssetState, AssetType, IAsset,
|
|||
import { AssetService } from "../../services/assetService";
|
||||
import Guard from "../../common/guard";
|
||||
import HtmlFileReader from "../../common/htmlFileReader";
|
||||
import { itemTemplate, annotationTemplate, objectTemplate } from "./tensorFlowPascalVOC/tensorFlowPascalVOCTemplates";
|
||||
import { itemTemplate } from "./tensorFlowPascalVOC/tensorFlowPascalVOCTemplates";
|
||||
import { strings, interpolate } from "../../common/strings";
|
||||
import { TFRecordsImageMessage, Features, Feature, FeatureList,
|
||||
BytesList, Int64List, FloatList } from "./tensorFlowRecords/tensorFlowRecordsProtoBuf_pb";
|
||||
import { crc32c, maskCrc, getInt64Buffer, getInt32Buffer } from "./tensorFlowRecords/tensorFlowHelpers";
|
||||
import { TFRecordsBuilder, FeatureType } from "./tensorFlowRecords/tensorFlowBuilder";
|
||||
|
||||
/**
|
||||
/**
|
||||
* @name - ITFRecordsJsonExportOptions
|
||||
* @description - Defines the configurable options for the Vott JSON Export provider
|
||||
*/
|
||||
|
@ -81,105 +79,6 @@ export class TFRecordsJsonExportProvider extends ExportProvider<ITFRecordsJsonEx
|
|||
await this.exportRecords(exportFolderName, allAssets);
|
||||
}
|
||||
|
||||
/**
 * Stores a single integer under `key` in the feature map of `features`.
 * The value is wrapped in an Int64List holding exactly one entry.
 */
private addIntFeature(features: Features, key: string, value: number) {
    // A single value is simply a one-element array feature.
    this.addIntArrayFeature(features, key, [value]);
}
|
||||
|
||||
/**
 * Stores every entry of `values`, in order, as an Int64List feature
 * under `key` in the feature map of `features`.
 */
private addIntArrayFeature(features: Features, key: string, values: number[]) {
    const int64Values = new Int64List();
    values.forEach((item) => int64Values.addValue(item));

    const wrapped = new Feature();
    wrapped.setInt64List(int64Values);

    features.getFeatureMap().set(key, wrapped);
}
|
||||
|
||||
/**
 * Stores every entry of `values`, in order, as a FloatList feature
 * under `key` in the feature map of `features`.
 */
private addFloatArrayFeature(features: Features, key: string, values: number[]) {
    const floatValues = new FloatList();
    values.forEach((item) => floatValues.addValue(item));

    const wrapped = new Feature();
    wrapped.setFloatList(floatValues);

    features.getFeatureMap().set(key, wrapped);
}
|
||||
|
||||
/**
 * Encodes `value` with textEncode() and stores the resulting bytes
 * as a binary feature under `key`.
 */
private addStringFeature(features: Features, key: string, value: string) {
    const encoded = this.textEncode(value);
    this.addBinaryArrayFeature(features, key, encoded);
}
|
||||
|
||||
/**
 * Stores `value` as the single entry of a BytesList feature
 * under `key` in the feature map of `features`.
 */
private addBinaryArrayFeature(features: Features, key: string, value: Uint8Array) {
    const rawBytes = new BytesList();
    rawBytes.addValue(value);

    const wrapped = new Feature();
    wrapped.setBytesList(rawBytes);

    features.getFeatureMap().set(key, wrapped);
}
|
||||
|
||||
/**
 * Encodes every string of `values` with textEncode() and stores the results,
 * in order, as a BytesList feature under `key` in the feature map of `features`.
 */
private addStringArrayFeature(features: Features, key: string, values: string[]) {
    const encodedValues = new BytesList();
    values.forEach((text) => encodedValues.addValue(this.textEncode(text)));

    const wrapped = new Feature();
    wrapped.setBytesList(encodedValues);

    features.getFeatureMap().set(key, wrapped);
}
|
||||
|
||||
/**
 * Serializes `features` into a single TFRecord and writes it to `fileNamePath`
 * through the storage provider.
 *
 * The written bytes are, in order: the payload length (getInt64Buffer),
 * the masked CRC of that length, the serialized payload, and the masked CRC
 * of the payload.
 *
 * Any error raised while building or writing the record is logged and
 * swallowed; the returned promise still resolves.
 */
private async writeTFRecord(fileNamePath: string, features: Features) {
    try {
        // Get Protocol Buffer TFRecords object with exported image features
        const imageMessage = new TFRecordsImageMessage();
        imageMessage.setContext(features);

        // Serialize Protocol Buffer in a buffer.
        // Buffer.from() replaces the deprecated `new Buffer(...)` constructor.
        const bufferData = Buffer.from(imageMessage.serializeBinary());

        // Get TFRecords CRCs for TFRecords Header and Footer
        const bufferLength = getInt64Buffer(bufferData.length);
        const bufferLengthMaskedCRC = getInt32Buffer(maskCrc(crc32c(bufferLength)));
        const bufferDataMaskedCRC = getInt32Buffer(maskCrc(crc32c(bufferData)));

        // Concatenate all TFRecords Header, Data and Footer buffer
        const outBuffer = Buffer.concat([
            bufferLength,
            bufferLengthMaskedCRC,
            bufferData,
            bufferDataMaskedCRC,
        ]);

        // Write TFRecords
        await this.storageProvider.writeBinary(fileNamePath, outBuffer);
    } catch (error) {
        // Ignore the error at the moment
        // TODO: Refactor ExportProvider abstract class export() method
        //       to return Promise<object> with an object containing
        //       the number of files successfully exported out of total
        console.log(`Error Writing TFRecords ${fileNamePath} - ${error}`);
    }
}
|
||||
|
||||
private async exportRecords(exportFolderName: string, allAssets: IAssetMetadata[]) {
|
||||
const allImageExports = allAssets.map((element) => {
|
||||
return this.exportSingleRecord(exportFolderName, element);
|
||||
|
@ -218,30 +117,30 @@ export class TFRecordsJsonExportProvider extends ExportProvider<ITFRecordsJsonEx
|
|||
this.updateAssetTagArrays(element, imageInfo);
|
||||
|
||||
// Generate TFRecord
|
||||
const features = new Features();
|
||||
const builder = new TFRecordsBuilder();
|
||||
|
||||
this.addIntFeature(features, "image/height", imageInfo.height);
|
||||
this.addIntFeature(features, "image/width", imageInfo.width);
|
||||
this.addStringFeature(features, "image/filename", element.asset.name);
|
||||
this.addStringFeature(features, "image/source_id", element.asset.name);
|
||||
this.addStringFeature(features, "image/key/sha256", CryptoJS.SHA256(imageBuffer)
|
||||
builder.addFeature("image/height", FeatureType.Int64, imageInfo.height);
|
||||
builder.addFeature("image/width", FeatureType.Int64, imageInfo.width);
|
||||
builder.addFeature("image/filename", FeatureType.String, element.asset.name);
|
||||
builder.addFeature("image/source_id", FeatureType.String, element.asset.name);
|
||||
builder.addFeature("image/key/sha256", FeatureType.String, CryptoJS.SHA256(imageBuffer)
|
||||
.toString(CryptoJS.enc.Base64));
|
||||
this.addBinaryArrayFeature(features, "image/encoded", imageBuffer);
|
||||
this.addStringFeature(features, "image/format", element.asset.name.split(".").pop());
|
||||
this.addFloatArrayFeature(features, "image/object/bbox/xmin", imageInfo.xmin);
|
||||
this.addFloatArrayFeature(features, "image/object/bbox/ymin", imageInfo.ymin);
|
||||
this.addFloatArrayFeature(features, "image/object/bbox/xmax", imageInfo.xmax);
|
||||
this.addFloatArrayFeature(features, "image/object/bbox/ymax", imageInfo.ymax);
|
||||
this.addStringArrayFeature(features, "image/object/class/text", imageInfo.text);
|
||||
this.addIntArrayFeature(features, "image/object/class/label", imageInfo.label);
|
||||
this.addIntArrayFeature(features, "image/object/difficult", imageInfo.difficult);
|
||||
this.addIntArrayFeature(features, "image/object/truncated", imageInfo.truncated);
|
||||
this.addStringArrayFeature(features, "image/object/view", imageInfo.view);
|
||||
builder.addFeature("image/encoded", FeatureType.Binary, imageBuffer);
|
||||
builder.addFeature("image/format", FeatureType.String, element.asset.name.split(".").pop());
|
||||
builder.addArrayFeature("image/object/bbox/xmin", FeatureType.Float, imageInfo.xmin);
|
||||
builder.addArrayFeature("image/object/bbox/ymin", FeatureType.Float, imageInfo.ymin);
|
||||
builder.addArrayFeature("image/object/bbox/xmax", FeatureType.Float, imageInfo.xmax);
|
||||
builder.addArrayFeature("image/object/bbox/ymax", FeatureType.Float, imageInfo.ymax);
|
||||
builder.addArrayFeature("image/object/class/text", FeatureType.String, imageInfo.text);
|
||||
builder.addArrayFeature("image/object/class/label", FeatureType.Int64, imageInfo.label);
|
||||
builder.addArrayFeature("image/object/difficult", FeatureType.Int64, imageInfo.difficult);
|
||||
builder.addArrayFeature("image/object/truncated", FeatureType.Int64, imageInfo.truncated);
|
||||
builder.addArrayFeature("image/object/view", FeatureType.String, imageInfo.view);
|
||||
|
||||
// Save TFRecord
|
||||
// Save TFRecords
|
||||
const fileName = element.asset.name.split(".").slice(0, -1).join(".");
|
||||
const fileNamePath = `${exportFolderName}/${fileName}.tfrecord`;
|
||||
await this.writeTFRecord(fileNamePath, features);
|
||||
await this.writeTFRecords(fileNamePath, [builder.build()]);
|
||||
|
||||
resolve();
|
||||
} catch (error) {
|
||||
|
@ -256,6 +155,14 @@ export class TFRecordsJsonExportProvider extends ExportProvider<ITFRecordsJsonEx
|
|||
});
|
||||
}
|
||||
|
||||
/**
 * Frames `buffers` with TFRecordsBuilder.buildTFRecords() and writes the
 * resulting bytes to `fileNamePath` through the storage provider.
 */
private async writeTFRecords(fileNamePath: string, buffers: Buffer[]) {
    // Build the framed TFRecords payload first, then persist it.
    const framedRecords = TFRecordsBuilder.buildTFRecords(buffers);
    await this.storageProvider.writeBinary(fileNamePath, framedRecords);
}
|
||||
|
||||
private async updateImageSizeInfo(image64: string, imageInfo: IImageInfo) {
|
||||
if (image64.length > 10) {
|
||||
const assetProps = await HtmlFileReader.readAssetAttributesWithBuffer(image64);
|
||||
|
@ -307,13 +214,4 @@ export class TFRecordsJsonExportProvider extends ExportProvider<ITFRecordsJsonEx
|
|||
await this.storageProvider.writeText(pbtxtFileName, items.join(""));
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Returns the UTF-8 bytes of `str` as a plain Uint8Array.
 *
 * Uses Buffer's UTF-8 encoder instead of the deprecated
 * `unescape(encodeURIComponent(...))` idiom, which throws a URIError when the
 * input contains an unpaired surrogate. Such code units are now encoded as
 * U+FFFD (EF BF BD) instead.
 */
private textEncode(str: string): Uint8Array {
    // Copy into a plain Uint8Array so callers do not receive a Buffer subclass.
    return new Uint8Array(Buffer.from(str, "utf8"));
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
import { TFRecordsBuilder, FeatureType } from "./tensorFlowBuilder";
|
||||
|
||||
// Unit tests for TFRecordsBuilder: single-feature serialization is compared
// byte-for-byte, array features and TFRecords framing are checked by length.
describe("TFRecords Builder Functions", () => {
    describe("Check Adding Single Features", () => {
        let builder: TFRecordsBuilder;

        beforeEach(() => {
            builder = new TFRecordsBuilder();
        });

        it("Check addIntFeature", () => {
            builder.addFeature("image/height", FeatureType.Int64, 123);

            expect(builder.build()).toEqual(
                Buffer.from([10, 23, 10, 21, 10, 12, 105, 109, 97, 103, 101, 47, 104,
                             101, 105, 103, 104, 116, 18, 5, 26, 3, 10, 1, 123]));
        });

        it("Check addFloatFeature", () => {
            builder.addFeature("image/height", FeatureType.Float, 123.0);

            expect(builder.build()).toEqual(
                Buffer.from([10, 26, 10, 24, 10, 12, 105, 109, 97, 103, 101, 47, 104,
                             101, 105, 103, 104, 116, 18, 8, 18, 6, 10, 4, 0, 0, 246, 66]));
        });

        it("Check addStringFeature", () => {
            builder.addFeature("image/height", FeatureType.String, "123");

            expect(builder.build()).toEqual(
                Buffer.from([10, 25, 10, 23, 10, 12, 105, 109, 97, 103, 101, 47, 104,
                             101, 105, 103, 104, 116, 18, 7, 10, 5, 10, 3, 49, 50, 51]));
        });
    });

    describe("Check single TFRecord generation with arrays", () => {
        it("Check releaseTFRecord", () => {
            const builder = new TFRecordsBuilder();

            // NOTE: all three calls use the same key, so the expected length
            // below presumably reflects only the last (String) feature, each call
            // replacing the previous map entry. Use distinct keys (and update the
            // expected length) if every type should be present in the record.
            builder.addArrayFeature("image/height", FeatureType.Int64, [1, 2]);
            builder.addArrayFeature("image/height", FeatureType.Float, [1.0, 2.0]);
            builder.addArrayFeature("image/height", FeatureType.String, ["1", "2"]);

            const buffer = builder.build();
            expect(buffer.length).toEqual(28);

            const tfrecords = TFRecordsBuilder.buildTFRecords([buffer]);
            // 16 = 8 bytes for Length + 4 bytes for CRC(Length) + 4 bytes for CRC(buffer)
            const headersSize = 16;
            expect(tfrecords.length).toEqual(28 + headersSize);
        });
    });
});
|
|
@ -0,0 +1,120 @@
|
|||
import { TFRecordsImageMessage, Features, Feature, FeatureList,
|
||||
BytesList, Int64List, FloatList } from "./tensorFlowRecordsProtoBuf_pb";
|
||||
import { crc32c, maskCrc, getInt64Buffer, getInt32Buffer, textEncode } from "./tensorFlowHelpers";
|
||||
|
||||
/**
 * @name - TFRecords Feature Type
 * @description - Selects how a value handed to the TFRecords builder is stored
 * @member String - The value is a string
 * @member Binary - The value is a binary UInt8Array
 * @member Int64 - The value is an Int64
 * @member Float - The value is a Float
 */
export enum FeatureType {
    // Numbering starts at 0 and increments by one per member.
    String = 0,
    Binary,
    Int64,
    Float,
}
|
||||
|
||||
/**
 * @name - TFRecords Builder Class
 * @description - Collects keyed features into a protobuf Features message,
 * serializes them as a single TFRecord payload with build(), and frames
 * payloads into a TFRecords byte stream with buildTFRecords()
 */
export class TFRecordsBuilder {
    /**
     * @records - An Array of TFRecord Buffers, each created with build()
     * @description - Return a Buffer representation of a TFRecords object.
     * Each record is framed, in order, as: length (getInt64Buffer), masked CRC
     * of the length, the record data, masked CRC of the data. The framed
     * records are concatenated in input order.
     */
    public static buildTFRecords(records: Buffer[]): Buffer {
        return Buffer.concat(records.map((record) => {
            // Get TFRecords CRCs for TFRecords Header and Footer
            const bufferLength = getInt64Buffer(record.length);
            const bufferLengthMaskedCRC = getInt32Buffer(maskCrc(crc32c(bufferLength)));
            const bufferDataMaskedCRC = getInt32Buffer(maskCrc(crc32c(record)));

            // Concatenate all TFRecords Header, Data and Footer buffer
            return Buffer.concat([
                bufferLength,
                bufferLengthMaskedCRC,
                record,
                bufferDataMaskedCRC,
            ]);
        }));
    }

    // Protobuf message that accumulates every feature added to this builder.
    private readonly features: Features;

    constructor() {
        this.features = new Features();
    }

    /**
     * @key - Feature Key
     * @type - Feature Type
     * @value - A Int64 | Float | String | Binary value
     * @description - Add a Int64 | Float | String | Binary value feature.
     * Equivalent to addArrayFeature() with a one-element array.
     */
    public addFeature(key: string, type: FeatureType, value: string | number | Uint8Array) {
        this.addArrayFeature(key, type, [value]);
    }

    /**
     * @key - Feature Key
     * @type - Feature Type
     * @value - An Array of Int64 | Float | String | Binary values
     * @description - Add an Array of Int64 | Float | String | Binary values feature.
     * The feature is stored with set(key, ...) on the feature map. The caller
     * must make sure the runtime values match the declared `type`; they are
     * passed to the protobuf lists without validation.
     */
    public addArrayFeature<T extends string | number | Uint8Array>(key: string, type: FeatureType, values: T[]) {
        const feature = new Feature();

        // Each case is wrapped in its own block so its `const` stays scoped to that case.
        switch (type) {
            case FeatureType.String: {
                // Strings are stored as bytes produced by textEncode().
                const stringList = new BytesList();
                values.forEach((value) => {
                    stringList.addValue(textEncode(value as string));
                });
                feature.setBytesList(stringList);
                break;
            }
            case FeatureType.Binary: {
                const byteList = new BytesList();
                values.forEach((value) => {
                    byteList.addValue(value as Uint8Array);
                });
                feature.setBytesList(byteList);
                break;
            }
            case FeatureType.Int64: {
                const intList = new Int64List();
                values.forEach((value) => {
                    intList.addValue(value as number);
                });
                feature.setInt64List(intList);
                break;
            }
            case FeatureType.Float: {
                const floatList = new FloatList();
                values.forEach((value) => {
                    floatList.addValue(value as number);
                });
                feature.setFloatList(floatList);
                break;
            }
            default:
                // Unknown type: the feature is still stored below, with no list set.
                break;
        }

        this.features.getFeatureMap().set(key, feature);
    }

    /**
     * @description - Return a Buffer representation of a single TFRecord:
     * the serialized TFRecordsImageMessage whose context is the accumulated features
     */
    public build(): Buffer {
        // Get Protocol Buffer TFRecords object with exported image features
        const imageMessage = new TFRecordsImageMessage();
        imageMessage.setContext(this.features);

        // Serialize Protocol Buffer in a buffer.
        // Buffer.from() replaces the deprecated `new Buffer(...)` constructor.
        return Buffer.from(imageMessage.serializeBinary());
    }
}
|
|
@ -1,4 +1,4 @@
|
|||
import { crc32c, maskCrc, getInt64Buffer, getInt32Buffer } from "./tensorFlowHelpers";
|
||||
import { crc32c, maskCrc, getInt64Buffer, getInt32Buffer, textEncode } from "./tensorFlowHelpers";
|
||||
|
||||
describe("TFRecords Helper Functions", () => {
|
||||
describe("Run getInt64Buffer method test", () => {
|
||||
|
@ -31,4 +31,10 @@ describe("TFRecords Helper Functions", () => {
|
|||
.toEqual(new Buffer([5, 135, 25, 235]));
|
||||
});
|
||||
});
|
||||
|
||||
describe("Run textEncode method test", () => {
|
||||
it("Check textEncode for string 'ABC123'", async () => {
|
||||
expect(textEncode("ABC123")).toEqual(new Uint8Array([65, 66, 67, 49, 50, 51]));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -109,3 +109,16 @@ export function getInt32Buffer(value: number): Buffer {
|
|||
|
||||
return new Buffer(intArray);
|
||||
}
|
||||
|
||||
/**
 * @str - Input string
 * @description - Get a Uint8Array holding the UTF-8 bytes of an input string value.
 * Uses Buffer's UTF-8 encoder instead of the deprecated
 * `unescape(encodeURIComponent(...))` idiom, which throws a URIError when the
 * input contains an unpaired surrogate. Such code units are encoded as
 * U+FFFD (EF BF BD) instead.
 */
export function textEncode(str: string): Uint8Array {
    // Copy into a plain Uint8Array so callers do not receive a Buffer subclass.
    return new Uint8Array(Buffer.from(str, "utf8"));
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче