Separate TFRecords code on different file (#514)

* WIP: adding TFRecords Builder class

* Added TFRecordsBuilder class

* added unit test

* Refactored AddFeature and AddFeatureArray interface

* removed import
This commit is contained in:
Jacopo Mangiavacchi 2019-01-29 13:57:21 -08:00 коммит произвёл GitHub
Родитель f0982c369e
Коммит 8198c41775
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 224 добавлений и 133 удалений

Просмотреть файл

@ -7,13 +7,11 @@ import { IProject, AssetState, AssetType, IAsset,
import { AssetService } from "../../services/assetService";
import Guard from "../../common/guard";
import HtmlFileReader from "../../common/htmlFileReader";
import { itemTemplate, annotationTemplate, objectTemplate } from "./tensorFlowPascalVOC/tensorFlowPascalVOCTemplates";
import { itemTemplate } from "./tensorFlowPascalVOC/tensorFlowPascalVOCTemplates";
import { strings, interpolate } from "../../common/strings";
import { TFRecordsImageMessage, Features, Feature, FeatureList,
BytesList, Int64List, FloatList } from "./tensorFlowRecords/tensorFlowRecordsProtoBuf_pb";
import { crc32c, maskCrc, getInt64Buffer, getInt32Buffer } from "./tensorFlowRecords/tensorFlowHelpers";
import { TFRecordsBuilder, FeatureType } from "./tensorFlowRecords/tensorFlowBuilder";
/**
/**64
* @name - ITFRecordsJsonExportOptions
* @description - Defines the configurable options for the Vott JSON Export provider
*/
@ -81,105 +79,6 @@ export class TFRecordsJsonExportProvider extends ExportProvider<ITFRecordsJsonEx
await this.exportRecords(exportFolderName, allAssets);
}
private addIntFeature(features: Features, key: string, value: number) {
const intList = new Int64List();
intList.addValue(value);
const feature = new Feature();
feature.setInt64List(intList);
const featuresMap = features.getFeatureMap();
featuresMap.set(key, feature);
}
private addIntArrayFeature(features: Features, key: string, values: number[]) {
const intList = new Int64List();
values.forEach((value) => {
intList.addValue(value);
});
const feature = new Feature();
feature.setInt64List(intList);
const featuresMap = features.getFeatureMap();
featuresMap.set(key, feature);
}
private addFloatArrayFeature(features: Features, key: string, values: number[]) {
const floatList = new FloatList();
values.forEach((value) => {
floatList.addValue(value);
});
const feature = new Feature();
feature.setFloatList(floatList);
const featuresMap = features.getFeatureMap();
featuresMap.set(key, feature);
}
private addStringFeature(features: Features, key: string, value: string) {
this.addBinaryArrayFeature(features, key, this.textEncode(value));
}
private addBinaryArrayFeature(features: Features, key: string, value: Uint8Array) {
const byteList = new BytesList();
byteList.addValue(value);
const feature = new Feature();
feature.setBytesList(byteList);
const featuresMap = features.getFeatureMap();
featuresMap.set(key, feature);
}
private addStringArrayFeature(features: Features, key: string, values: string[]) {
const byteList = new BytesList();
values.forEach((value) => {
byteList.addValue(this.textEncode(value));
});
const feature = new Feature();
feature.setBytesList(byteList);
const featuresMap = features.getFeatureMap();
featuresMap.set(key, feature);
}
private async writeTFRecord(fileNamePath: string, features: Features) {
try {
// Get Protocol Buffer TFRecords object with exported image features
const imageMessage = new TFRecordsImageMessage();
imageMessage.setContext(features);
// Serialize Protocol Buffer in a buffer
const bytes = imageMessage.serializeBinary();
const bufferData = new Buffer(bytes);
const length = bufferData.length;
// Get TFRecords CRCs for TFRecords Header and Footer
const bufferLength = getInt64Buffer(length);
const bufferLengthMaskedCRC = getInt32Buffer(maskCrc(crc32c(bufferLength)));
const bufferDataMaskedCRC = getInt32Buffer(maskCrc(crc32c(bufferData)));
// Concatenate all TFRecords Header, Data and Footer buffer
const outBuffer = Buffer.concat([bufferLength,
bufferLengthMaskedCRC,
bufferData,
bufferDataMaskedCRC]);
// Write TFRecords
await this.storageProvider.writeBinary(fileNamePath, outBuffer);
} catch (error) {
// Ignore the error at the moment
// TODO: Refactor ExportProvider abstract class export() method
// to return Promise<object> with an object containing
// the number of files succesfully exported out of total
console.log(`Error Writing TFRecords ${fileNamePath} - ${error}`);
}
}
private async exportRecords(exportFolderName: string, allAssets: IAssetMetadata[]) {
const allImageExports = allAssets.map((element) => {
return this.exportSingleRecord(exportFolderName, element);
@ -218,30 +117,30 @@ export class TFRecordsJsonExportProvider extends ExportProvider<ITFRecordsJsonEx
this.updateAssetTagArrays(element, imageInfo);
// Generate TFRecord
const features = new Features();
const builder = new TFRecordsBuilder();
this.addIntFeature(features, "image/height", imageInfo.height);
this.addIntFeature(features, "image/width", imageInfo.width);
this.addStringFeature(features, "image/filename", element.asset.name);
this.addStringFeature(features, "image/source_id", element.asset.name);
this.addStringFeature(features, "image/key/sha256", CryptoJS.SHA256(imageBuffer)
builder.addFeature("image/height", FeatureType.Int64, imageInfo.height);
builder.addFeature("image/width", FeatureType.Int64, imageInfo.width);
builder.addFeature("image/filename", FeatureType.String, element.asset.name);
builder.addFeature("image/source_id", FeatureType.String, element.asset.name);
builder.addFeature("image/key/sha256", FeatureType.String, CryptoJS.SHA256(imageBuffer)
.toString(CryptoJS.enc.Base64));
this.addBinaryArrayFeature(features, "image/encoded", imageBuffer);
this.addStringFeature(features, "image/format", element.asset.name.split(".").pop());
this.addFloatArrayFeature(features, "image/object/bbox/xmin", imageInfo.xmin);
this.addFloatArrayFeature(features, "image/object/bbox/ymin", imageInfo.ymin);
this.addFloatArrayFeature(features, "image/object/bbox/xmax", imageInfo.xmax);
this.addFloatArrayFeature(features, "image/object/bbox/ymax", imageInfo.ymax);
this.addStringArrayFeature(features, "image/object/class/text", imageInfo.text);
this.addIntArrayFeature(features, "image/object/class/label", imageInfo.label);
this.addIntArrayFeature(features, "image/object/difficult", imageInfo.difficult);
this.addIntArrayFeature(features, "image/object/truncated", imageInfo.truncated);
this.addStringArrayFeature(features, "image/object/view", imageInfo.view);
builder.addFeature("image/encoded", FeatureType.Binary, imageBuffer);
builder.addFeature("image/format", FeatureType.String, element.asset.name.split(".").pop());
builder.addArrayFeature("image/object/bbox/xmin", FeatureType.Float, imageInfo.xmin);
builder.addArrayFeature("image/object/bbox/ymin", FeatureType.Float, imageInfo.ymin);
builder.addArrayFeature("image/object/bbox/xmax", FeatureType.Float, imageInfo.xmax);
builder.addArrayFeature("image/object/bbox/ymax", FeatureType.Float, imageInfo.ymax);
builder.addArrayFeature("image/object/class/text", FeatureType.String, imageInfo.text);
builder.addArrayFeature("image/object/class/label", FeatureType.Int64, imageInfo.label);
builder.addArrayFeature("image/object/difficult", FeatureType.Int64, imageInfo.difficult);
builder.addArrayFeature("image/object/truncated", FeatureType.Int64, imageInfo.truncated);
builder.addArrayFeature("image/object/view", FeatureType.String, imageInfo.view);
// Save TFRecord
// Save TFRecords
const fileName = element.asset.name.split(".").slice(0, -1).join(".");
const fileNamePath = `${exportFolderName}/${fileName}.tfrecord`;
await this.writeTFRecord(fileNamePath, features);
await this.writeTFRecords(fileNamePath, [builder.build()]);
resolve();
} catch (error) {
@ -256,6 +155,14 @@ export class TFRecordsJsonExportProvider extends ExportProvider<ITFRecordsJsonEx
});
}
private async writeTFRecords(fileNamePath: string, buffers: Buffer[]) {
// Get TFRecords buffer
const tfRecords = TFRecordsBuilder.buildTFRecords(buffers);
// Write TFRecords
await this.storageProvider.writeBinary(fileNamePath, tfRecords);
}
private async updateImageSizeInfo(image64: string, imageInfo: IImageInfo) {
if (image64.length > 10) {
const assetProps = await HtmlFileReader.readAssetAttributesWithBuffer(image64);
@ -307,13 +214,4 @@ export class TFRecordsJsonExportProvider extends ExportProvider<ITFRecordsJsonEx
await this.storageProvider.writeText(pbtxtFileName, items.join(""));
}
}
private textEncode(str: string): Uint8Array {
const utf8 = unescape(encodeURIComponent(str));
const result = new Uint8Array(utf8.length);
for (let i = 0; i < utf8.length; i++) {
result[i] = utf8.charCodeAt(i);
}
return result;
}
}

Просмотреть файл

@ -0,0 +1,54 @@
import { TFRecordsBuilder, FeatureType } from "./tensorFlowBuilder";
describe("TFRecords Builder Functions", () => {
describe("Check Adding Single Features", () => {
let builder: TFRecordsBuilder;
beforeEach(() => {
builder = new TFRecordsBuilder();
});
it("Check addIntFeature", async () => {
builder.addFeature("image/height", FeatureType.Int64, 123);
expect(builder.build()).toEqual(
new Buffer([10, 23, 10, 21, 10, 12, 105, 109, 97, 103, 101, 47, 104,
101, 105, 103, 104, 116, 18, 5, 26, 3, 10, 1, 123]));
});
it("Check addFloatFeature", async () => {
builder.addFeature("image/height", FeatureType.Float, 123.0);
expect(builder.build()).toEqual(
new Buffer([10, 26, 10, 24, 10, 12, 105, 109, 97, 103, 101, 47, 104,
101, 105, 103, 104, 116, 18, 8, 18, 6, 10, 4, 0, 0, 246, 66]));
});
it("Check addStringFeature", async () => {
builder.addFeature("image/height", FeatureType.String, "123");
expect(builder.build()).toEqual(
new Buffer([10, 25, 10, 23, 10, 12, 105, 109, 97, 103, 101, 47, 104,
101, 105, 103, 104, 116, 18, 7, 10, 5, 10, 3, 49, 50, 51]));
});
});
describe("Check single TFRecord generation with arrays", () => {
let builder: TFRecordsBuilder;
it("Check releaseTFRecord", async () => {
builder = new TFRecordsBuilder();
builder.addArrayFeature("image/height", FeatureType.Int64, [1, 2]);
builder.addArrayFeature("image/height", FeatureType.Float, [1.0, 2.0]);
builder.addArrayFeature("image/height", FeatureType.String, ["1", "2"]);
const buffer = builder.build();
expect(buffer.length).toEqual(28);
const tfrecords = TFRecordsBuilder.buildTFRecords([buffer]);
// 16 = 8bytes for Lenght + 4bytes for CRC(Length) + 4bytes CRC(buffer)
const headersSize = 16;
expect(tfrecords.length).toEqual(28 + headersSize);
});
});
});

Просмотреть файл

@ -0,0 +1,120 @@
import { TFRecordsImageMessage, Features, Feature, FeatureList,
BytesList, Int64List, FloatList } from "./tensorFlowRecordsProtoBuf_pb";
import { crc32c, maskCrc, getInt64Buffer, getInt32Buffer, textEncode } from "./tensorFlowHelpers";
/**
* @name - TFRecords Feature Type
* @description - Defines the type of TFRecords Feature
* @member String - Specifies a Feature as a string
* @member Binary - Specifies a Feature as a binary UInt8Array
* @member Int64 - Specifies a Feature as a Int64
* @member Float - Specifies a Feature as a Float
*/
export enum FeatureType {
String = 0,
Binary = 1,
Int64 = 2,
Float = 3,
}
/**
* @name - TFRecords Builder Class
* @description - Create a TFRecords object
*/
export class TFRecordsBuilder {
/**
* @records - An Array of TFRecord Buffer created with releaseTFRecord()
* @description - Return a Buffer representation of a TFRecords object
*/
public static buildTFRecords(records: Buffer[]): Buffer {
return Buffer.concat(records.map((record) => {
const length = record.length;
// Get TFRecords CRCs for TFRecords Header and Footer
const bufferLength = getInt64Buffer(length);
const bufferLengthMaskedCRC = getInt32Buffer(maskCrc(crc32c(bufferLength)));
const bufferDataMaskedCRC = getInt32Buffer(maskCrc(crc32c(record)));
// Concatenate all TFRecords Header, Data and Footer buffer
return Buffer.concat([bufferLength,
bufferLengthMaskedCRC,
record,
bufferDataMaskedCRC]);
}));
}
private features: Features;
constructor() {
this.features = new Features();
}
/**
* @key - Feature Key
* @type - Feature Type
* @value - A Int64 | Float | String | Binary value
* @description - Add a Int64 | Float | String | Binary value feature
*/
public addFeature(key: string, type: FeatureType, value: string | number | Uint8Array) {
this.addArrayFeature(key, type, [value]);
}
/**
* @key - Feature Key
* @type - Feature Type
* @value - An Array of Int64 | Float | String | Binary values
* @description - Add an Array of Int64 | Float | String | Binary values feature
*/
public addArrayFeature<T extends string | number | Uint8Array>(key: string, type: FeatureType, values: T[]) {
const feature = new Feature();
switch (type) {
case FeatureType.String:
const stringList = new BytesList();
values.forEach((value) => {
stringList.addValue(textEncode(value as string));
});
feature.setBytesList(stringList);
break;
case FeatureType.Binary:
const byteList = new BytesList();
values.forEach((value) => {
byteList.addValue(value);
});
feature.setBytesList(byteList);
break;
case FeatureType.Int64:
const intList = new Int64List();
values.forEach((value) => {
intList.addValue(value);
});
feature.setInt64List(intList);
break;
case FeatureType.Float:
const floatList = new FloatList();
values.forEach((value) => {
floatList.addValue(value);
});
feature.setFloatList(floatList);
break;
default:
break;
}
const featuresMap = this.features.getFeatureMap();
featuresMap.set(key, feature);
}
/**
* @description - Return a Buffer representation of a single TFRecord
*/
public build(): Buffer {
// Get Protocol Buffer TFRecords object with exported image features
const imageMessage = new TFRecordsImageMessage();
imageMessage.setContext(this.features);
// Serialize Protocol Buffer in a buffer
const bytes = imageMessage.serializeBinary();
return new Buffer(bytes);
}
}

Просмотреть файл

@ -1,4 +1,4 @@
import { crc32c, maskCrc, getInt64Buffer, getInt32Buffer } from "./tensorFlowHelpers";
import { crc32c, maskCrc, getInt64Buffer, getInt32Buffer, textEncode } from "./tensorFlowHelpers";
describe("TFRecords Helper Functions", () => {
describe("Run getInt64Buffer method test", () => {
@ -31,4 +31,10 @@ describe("TFRecords Helper Functions", () => {
.toEqual(new Buffer([5, 135, 25, 235]));
});
});
describe("Run textEncode method test", () => {
it("Check textEncode for string 'ABC123'", async () => {
expect(textEncode("ABC123")).toEqual(new Uint8Array([65, 66, 67, 49, 50, 51]));
});
});
});

Просмотреть файл

@ -109,3 +109,16 @@ export function getInt32Buffer(value: number): Buffer {
return new Buffer(intArray);
}
/**
* @s - Input string
* @description - Get a Uint8Array representation of an input string value
*/
export function textEncode(str: string): Uint8Array {
const utf8 = unescape(encodeURIComponent(str));
const result = new Uint8Array(utf8.length);
for (let i = 0; i < utf8.length; i++) {
result[i] = utf8.charCodeAt(i);
}
return result;
}