feat: Add more export options to pascal voc exporter (#705)

Adds options to specify test/train split, export regions with no assigned tags.
Resolves regressions for exporting image sets and duplicate tags

Resolves AB#16708,16532
Resolves issues #692 #685
This commit is contained in:
Phil 2019-04-01 14:11:34 -07:00 коммит произвёл Wallace Breza
Родитель ab3ce6f108
Коммит a5ee17a7c7
24 изменённых файлов: 511 добавлений и 177 удалений

45
package-lock.json сгенерированный
Просмотреть файл

@ -14151,6 +14151,17 @@
"react-lifecycles-compat": "^3.0.4"
}
},
"rc-checkbox": {
"version": "2.1.6",
"resolved": "https://registry.npmjs.org/rc-checkbox/-/rc-checkbox-2.1.6.tgz",
"integrity": "sha512-+VxQbt2Cwe1PxCvwosrAYXT6EQeGwrbLJB2K+IPGCSRPCKnk9zcub/0eW8A4kxjyyfh60PkwsAUZ7qmB31OmRA==",
"requires": {
"babel-runtime": "^6.23.0",
"classnames": "2.x",
"prop-types": "15.x",
"rc-util": "^4.0.4"
}
},
"rc-menu": {
"version": "7.4.21",
"resolved": "https://registry.npmjs.org/rc-menu/-/rc-menu-7.4.21.tgz",
@ -14169,6 +14180,40 @@
"resize-observer-polyfill": "^1.5.0"
}
},
"rc-slider": {
"version": "8.6.7",
"resolved": "https://registry.npmjs.org/rc-slider/-/rc-slider-8.6.7.tgz",
"integrity": "sha512-QIFWMnK1VLc4TtJSZJgjhI6UOhN8eg53EM2La+eRa8rSPZwJT3rIWfZnTZs7OV7zXG/AiLWN4G+oGxuMcEFpsg==",
"requires": {
"babel-runtime": "6.x",
"classnames": "^2.2.5",
"prop-types": "^15.5.4",
"rc-tooltip": "^3.7.0",
"rc-util": "^4.0.4",
"shallowequal": "^1.0.1",
"warning": "^4.0.3"
},
"dependencies": {
"warning": {
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/warning/-/warning-4.0.3.tgz",
"integrity": "sha512-rpJyN222KWIvHJ/F53XSZv0Zl/accqHR8et1kpaMTD/fLCRxtV8iX8czMzY7sVZupTI3zcUTg8eycS2kNF9l6w==",
"requires": {
"loose-envify": "^1.0.0"
}
}
}
},
"rc-tooltip": {
"version": "3.7.3",
"resolved": "https://registry.npmjs.org/rc-tooltip/-/rc-tooltip-3.7.3.tgz",
"integrity": "sha512-dE2ibukxxkrde7wH9W8ozHKUO4aQnPZ6qBHtrTH9LoO836PjDdiaWO73fgPB05VfJs9FbZdmGPVEbXCeOP99Ww==",
"requires": {
"babel-runtime": "6.x",
"prop-types": "^15.5.8",
"rc-trigger": "^2.2.2"
}
},
"rc-trigger": {
"version": "2.6.2",
"resolved": "https://registry.npmjs.org/rc-trigger/-/rc-trigger-2.6.2.tgz",

Просмотреть файл

@ -26,7 +26,9 @@
"lodash": "^4.17.11",
"md5.js": "^1.3.5",
"node-int64": "^0.4.0",
"rc-checkbox": "^2.1.6",
"rc-menu": "^7.4.21",
"rc-slider": "^8.6.7",
"react": "^16.7.0",
"react-appinsights": "^3.0.0-rc.5",
"react-dom": "^16.7.0",

Просмотреть файл

@ -286,6 +286,14 @@ export const english: IAppStrings = {
},
tfPascalVoc: {
displayName: "Tensorflow Pascal VOC",
testTrainSplit: {
title: "Test / Train Split",
description: "The test train split to use for exported data",
},
exportUnassigned: {
title: "Export Unassigned",
description: "Whether or not to include unassigned tags in exported data",
},
},
},
messages: {

Просмотреть файл

@ -287,6 +287,14 @@ export const spanish: IAppStrings = {
},
tfPascalVoc: {
displayName: "Tensorflow Pascal VOC",
testTrainSplit: {
title: "Prueba/tren Split",
description: "La división del tren de prueba que se utilizará para los datos exportados",
},
exportUnassigned: {
title: "Exportar sin asignar",
description: "Si se incluyen o no etiquetas no asignadas en los datos exportados",
},
},
},
messages: {

Просмотреть файл

@ -288,6 +288,14 @@ export interface IAppStrings {
},
tfPascalVoc: {
displayName: string,
testTrainSplit: {
title: string,
description: string,
},
exportUnassigned: {
title: string,
description: string,
},
},
},
messages: {

Просмотреть файл

@ -62,4 +62,13 @@ input[type=file] {
border: solid 1px #62c462;
}
}
.rc-checkbox {
margin-left: 0.5em;
}
.slider-value {
font-weight: 500;
font-size: 90%;
}
}

Просмотреть файл

@ -181,7 +181,7 @@ export interface IExportProviderOptions extends IProviderOptions {
* @description - Defines the settings for how project data is exported into commonly used format
* @member id - Unique identifier for export format
* @member name - Name of export format
* @member providerType - The export format type (TF Records, YOLO, CVS, etc)
* @member providerType - The export format type (TF Records, YOLO, CSV, etc)
* @member providerOptions - The provider specific option required to export data
*/
export interface IExportFormat {

Просмотреть файл

@ -17,6 +17,18 @@
"${strings.export.providers.common.properties.assetState.options.visited}",
"${strings.export.providers.common.properties.assetState.options.tagged}"
]
},
"testTrainSplit": {
"title": "${strings.export.providers.tfPascalVoc.testTrainSplit.title}",
"description": "${strings.export.providers.tfPascalVoc.testTrainSplit.description}",
"type": "number",
"default": 80
},
"exportUnassigned": {
"title": "${strings.export.providers.tfPascalVoc.exportUnassigned.title}",
"description": "${strings.export.providers.tfPascalVoc.exportUnassigned.description}",
"type": "boolean",
"default": true
}
}
}

Просмотреть файл

@ -1,5 +1,5 @@
import _ from "lodash";
import { TFPascalVOCJsonExportProvider } from "./tensorFlowPascalVOC";
import { TFPascalVOCExportProvider, ITFPascalVOCExportProviderOptions } from "./tensorFlowPascalVOC";
import { ExportAssetState } from "./exportProvider";
import registerProviders from "../../registerProviders";
import { ExportProviderFactory } from "./exportProviderFactory";
@ -21,16 +21,6 @@ import { AssetProviderFactory } from "../storage/assetProviderFactory";
registerMixins();
function _base64ToArrayBuffer(base64: string) {
const binaryString = window.atob(base64);
const len = binaryString.length;
const bytes = new Uint8Array(len);
for (let i = 0; i < len; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
return bytes.buffer;
}
describe("TFPascalVOC Json Export Provider", () => {
const testAssets = MockFactory.createTestAssets(10, 1);
const baseTestProject = MockFactory.createTestProject("Test Project");
@ -38,7 +28,6 @@ describe("TFPascalVOC Json Export Provider", () => {
"asset-1": MockFactory.createTestAsset("1", AssetState.Tagged),
"asset-2": MockFactory.createTestAsset("2", AssetState.Tagged),
"asset-3": MockFactory.createTestAsset("3", AssetState.Visited),
"asset-4": MockFactory.createTestAsset("4", AssetState.NotVisited),
};
baseTestProject.sourceConnection = MockFactory.createTestConnection("test", "localFileSystemProxy");
baseTestProject.targetConnection = MockFactory.createTestConnection("test", "localFileSystemProxy");
@ -62,16 +51,18 @@ describe("TFPascalVOC Json Export Provider", () => {
});
it("Is defined", () => {
expect(TFPascalVOCJsonExportProvider).toBeDefined();
expect(TFPascalVOCExportProvider).toBeDefined();
});
it("Can be instantiated through the factory", () => {
const options: IExportProviderOptions = {
const options: ITFPascalVOCExportProviderOptions = {
assetState: ExportAssetState.All,
exportUnassigned: true,
testTrainSplit: 80,
};
const exportProvider = ExportProviderFactory.create("tensorFlowPascalVOC", baseTestProject, options);
expect(exportProvider).not.toBeNull();
expect(exportProvider).toBeInstanceOf(TFPascalVOCJsonExportProvider);
expect(exportProvider).toBeInstanceOf(TFPascalVOCExportProvider);
});
describe("Export variations", () => {
@ -79,27 +70,12 @@ describe("TFPascalVOC Json Export Provider", () => {
const assetServiceMock = AssetService as jest.Mocked<typeof AssetService>;
assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => {
const mockTag = MockFactory.createTestTag();
const mockStartPoint: IPoint = {
x: 1,
y: 2,
};
const mockEndPoint: IPoint = {
x: 3,
y: 4,
};
const mockRegion: IRegion = {
id: "id",
type: RegionType.Rectangle,
tags: [mockTag.name],
points: [mockStartPoint, mockEndPoint],
};
const mockRegion1 = MockFactory.createTestRegion("region-1", [mockTag.name]);
const mockRegion2 = MockFactory.createTestRegion("region-2", [mockTag.name]);
const assetMetadata: IAssetMetadata = {
asset,
regions: [mockRegion],
regions: [mockRegion1, mockRegion2],
version: appInfo.version,
};
@ -111,14 +87,16 @@ describe("TFPascalVOC Json Export Provider", () => {
});
it("Exports all assets", async () => {
const options: IExportProviderOptions = {
const options: ITFPascalVOCExportProviderOptions = {
assetState: ExportAssetState.All,
exportUnassigned: true,
testTrainSplit: 80,
};
const testProject = { ...baseTestProject };
testProject.tags = MockFactory.createTestTags(3);
const exportProvider = new TFPascalVOCJsonExportProvider(testProject, options);
const exportProvider = new TFPascalVOCExportProvider(testProject, options);
await exportProvider.export();
const storageProviderMock = LocalFileSystemProxy as any;
@ -167,14 +145,16 @@ describe("TFPascalVOC Json Export Provider", () => {
});
it("Exports only visited assets (includes tagged)", async () => {
const options: IExportProviderOptions = {
const options: ITFPascalVOCExportProviderOptions = {
assetState: ExportAssetState.Visited,
exportUnassigned: true,
testTrainSplit: 80,
};
const testProject = { ...baseTestProject };
testProject.tags = MockFactory.createTestTags(1);
const exportProvider = new TFPascalVOCJsonExportProvider(testProject, options);
const exportProvider = new TFPascalVOCExportProvider(testProject, options);
await exportProvider.export();
const storageProviderMock = LocalFileSystemProxy as any;
@ -211,14 +191,16 @@ describe("TFPascalVOC Json Export Provider", () => {
});
it("Exports only tagged assets", async () => {
const options: IExportProviderOptions = {
const options: ITFPascalVOCExportProviderOptions = {
assetState: ExportAssetState.Tagged,
exportUnassigned: true,
testTrainSplit: 80,
};
const testProject = { ...baseTestProject };
testProject.tags = MockFactory.createTestTags(3);
const exportProvider = new TFPascalVOCJsonExportProvider(testProject, options);
const exportProvider = new TFPascalVOCExportProvider(testProject, options);
await exportProvider.export();
const storageProviderMock = LocalFileSystemProxy as any;
@ -258,5 +240,156 @@ describe("TFPascalVOC Json Export Provider", () => {
expect(writeTextFileCalls.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 2_train.txt")))
.toBeGreaterThanOrEqual(0);
});
it("Export includes unassigned tags", async () => {
const options: ITFPascalVOCExportProviderOptions = {
assetState: ExportAssetState.Tagged,
exportUnassigned: true,
testTrainSplit: 80,
};
const testProject = { ...baseTestProject };
const testAssets = MockFactory.createTestAssets(10, 0);
testAssets.forEach((asset) => asset.state = AssetState.Tagged);
testProject.assets = _.keyBy(testAssets, (asset) => asset.id);
testProject.tags = MockFactory.createTestTags(3);
const exportProvider = new TFPascalVOCExportProvider(testProject, options);
await exportProvider.export();
const storageProviderMock = LocalFileSystemProxy as any;
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[];
expect(writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 0_val.txt")))
.toBeGreaterThanOrEqual(0);
expect(writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 0_train.txt")))
.toBeGreaterThanOrEqual(0);
expect(writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_val.txt")))
.toBeGreaterThanOrEqual(0);
expect(writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_train.txt")))
.toBeGreaterThanOrEqual(0);
});
it("Export does not include unassigned tags", async () => {
const options: ITFPascalVOCExportProviderOptions = {
assetState: ExportAssetState.Tagged,
exportUnassigned: false,
testTrainSplit: 80,
};
const testProject = { ...baseTestProject };
const testAssets = MockFactory.createTestAssets(10, 0);
testAssets.forEach((asset) => asset.state = AssetState.Tagged);
testProject.assets = _.keyBy(testAssets, (asset) => asset.id);
testProject.tags = MockFactory.createTestTags(3);
const exportProvider = new TFPascalVOCExportProvider(testProject, options);
await exportProvider.export();
const storageProviderMock = LocalFileSystemProxy as any;
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[];
expect(writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 0_val.txt")))
.toEqual(-1);
expect(writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 0_train.txt")))
.toEqual(-1);
expect(writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_val.txt")))
.toBeGreaterThanOrEqual(0);
expect(writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_train.txt")))
.toBeGreaterThanOrEqual(0);
});
describe("Annotations", () => {
it("contains expected XML", async () => {
const options: ITFPascalVOCExportProviderOptions = {
assetState: ExportAssetState.Tagged,
exportUnassigned: false,
testTrainSplit: 80,
};
const testProject = { ...baseTestProject };
const testAssets = MockFactory.createTestAssets(10, 0);
testAssets.forEach((asset) => asset.state = AssetState.Tagged);
testProject.assets = _.keyBy(testAssets, (asset) => asset.id);
testProject.tags = [MockFactory.createTestTag("1")];
const exportProvider = new TFPascalVOCExportProvider(testProject, options);
await exportProvider.export();
const storageProviderMock = LocalFileSystemProxy as any;
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[];
const assetIndex = writeTextFileCalls.findIndex((args) => args[0].endsWith("/Annotations/Asset 1.xml"));
const assetXml = writeTextFileCalls[assetIndex][1] as string;
const objectRegExp = /<object>([\s\S]*?)<\/object>/g;
const folderRegExp = new RegExp(/<filename>(.*?)<\/filename>/g);
const pathRegExp = new RegExp(/<path>(.*?)<\/path>/g);
const widthRegExp = new RegExp(/<width>(.*?)<\/width>/g);
const heightRegExp = new RegExp(/<height>(.*?)<\/height>/g);
const objectMatches = assetXml.match(objectRegExp);
expect(objectMatches).toHaveLength(2);
expect(folderRegExp.exec(assetXml)[1]).toEqual(testAssets[1].name);
expect(pathRegExp.exec(assetXml)[1]).toContain(testAssets[1].name);
expect(widthRegExp.exec(assetXml)[1]).toEqual(testAssets[1].size.width.toString());
expect(heightRegExp.exec(assetXml)[1]).toEqual(testAssets[1].size.height.toString());
});
});
describe("Test Train Splits", () => {
async function testTestTrainSplit(testTrainSplit: number): Promise<void> {
const options: ITFPascalVOCExportProviderOptions = {
assetState: ExportAssetState.Tagged,
exportUnassigned: true,
testTrainSplit,
};
const testProject = { ...baseTestProject };
const testAssets = MockFactory.createTestAssets(10, 0);
testAssets.forEach((asset) => asset.state = AssetState.Tagged);
testProject.assets = _.keyBy(testAssets, (asset) => asset.id);
testProject.tags = [MockFactory.createTestTag("1")];
const exportProvider = new TFPascalVOCExportProvider(testProject, options);
await exportProvider.export();
const storageProviderMock = LocalFileSystemProxy as any;
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[];
const valDataIndex = writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_val.txt"));
const trainDataIndex = writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_train.txt"));
const expectedTrainCount = (testTrainSplit / 100) * testAssets.length;
const expectedTestCount = ((100 - testTrainSplit) / 100) * testAssets.length;
expect(writeTextFileCalls[valDataIndex][1].split("\n")).toHaveLength(expectedTestCount);
expect(writeTextFileCalls[trainDataIndex][1].split("\n")).toHaveLength(expectedTrainCount);
}
it("Correctly generated files based on 50/50 test / train split", async () => {
await testTestTrainSplit(50);
});
it("Correctly generated files based on 60/40 test / train split", async () => {
await testTestTrainSplit(60);
});
it("Correctly generated files based on 80/20 test / train split", async () => {
await testTestTrainSplit(80);
});
it("Correctly generated files based on 90/10 test / train split", async () => {
await testTestTrainSplit(90);
});
});
});
});

Просмотреть файл

@ -5,6 +5,8 @@ import Guard from "../../common/guard";
import HtmlFileReader from "../../common/htmlFileReader";
import { itemTemplate, annotationTemplate, objectTemplate } from "./tensorFlowPascalVOC/tensorFlowPascalVOCTemplates";
import { interpolate } from "../../common/strings";
import { PlatformType } from "../../common/hostProcess";
import os from "os";
interface IObjectInfo {
name: string;
@ -20,14 +22,24 @@ interface IImageInfo {
objects: IObjectInfo[];
}
/**
* Export options for TensorFlow Pascal VOC Export Provider
*/
export interface ITFPascalVOCExportProviderOptions extends IExportProviderOptions {
/** The test / train split ratio for exporting data */
testTrainSplit?: number;
/** Whether or not to include unassigned tags in exported data */
exportUnassigned?: boolean;
}
/**
* @name - TFPascalVOC Json Export Provider
* @description - Exports a project into a single JSON file that include all configured assets
*/
export class TFPascalVOCJsonExportProvider extends ExportProvider {
export class TFPascalVOCExportProvider extends ExportProvider<ITFPascalVOCExportProviderOptions> {
private imagesInfo = new Map<string, IImageInfo>();
constructor(project: IProject, options: IExportProviderOptions) {
constructor(project: IProject, options: ITFPascalVOCExportProviderOptions) {
super(project, options);
Guard.null(options);
}
@ -48,8 +60,15 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider {
await this.exportPBTXT(exportFolderName, this.project);
await this.exportAnnotations(exportFolderName, allAssets);
// TODO: Make testSplit && exportUnassignedTags optional parameter in the UI Exporter configuration
await this.exportImageSets(exportFolderName, allAssets, this.project.tags, 0.2, true);
// TestSplit && exportUnassignedTags are optional parameter in the UI Exporter configuration
const testSplit = (100 - (this.options.testTrainSplit || 80)) / 100;
await this.exportImageSets(
exportFolderName,
allAssets,
this.project.tags,
testSplit,
this.options.exportUnassigned,
);
}
private async exportImages(exportFolderName: string, allAssets: IAssetMetadata[]) {
@ -57,13 +76,9 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider {
const jpegImagesFolderName = `${exportFolderName}/JPEGImages`;
await this.storageProvider.createContainer(jpegImagesFolderName);
try {
await allAssets.mapAsync(async (assetMetadata) => {
await this.exportSingleImage(jpegImagesFolderName, assetMetadata);
});
} catch (err) {
console.log(err);
}
await allAssets.mapAsync(async (assetMetadata) => {
await this.exportSingleImage(jpegImagesFolderName, assetMetadata);
});
}
private async exportSingleImage(jpegImagesFolderName: string, assetMetadata: IAssetMetadata): Promise<void> {
@ -102,22 +117,20 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider {
private getAssetTagArray(element: IAssetMetadata): IObjectInfo[] {
const tagObjects = [];
element.regions.filter((region) => (region.type === RegionType.Rectangle ||
region.type === RegionType.Square) &&
region.points.length === 2)
.forEach((region) => {
region.tags.forEach((tagName) => {
const objectInfo: IObjectInfo = {
name: tagName,
xmin: region.points[0].x,
ymin: region.points[0].y,
xmax: region.points[1].x,
ymax: region.points[1].y,
};
element.regions.forEach((region) => {
region.tags.forEach((tagName) => {
const objectInfo: IObjectInfo = {
name: tagName,
xmin: region.boundingBox.left,
ymin: region.boundingBox.top,
xmax: region.boundingBox.left + region.boundingBox.width,
ymax: region.boundingBox.top + region.boundingBox.height,
};
tagObjects.push(objectInfo);
});
tagObjects.push(objectInfo);
});
});
return tagObjects;
}
@ -199,7 +212,7 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider {
await this.storageProvider.writeText(assetFilePath, interpolate(annotationTemplate, params));
});
} catch (err) {
console.log(err);
console.log("Error writing Pascal VOC annotation file");
}
}
@ -209,6 +222,10 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider {
tags: ITag[],
testSplit: number,
exportUnassignedTags: boolean) {
if (!tags) {
return;
}
// Create ImageSets Sub Folder (Main ?)
const imageSetsFolderName = `${exportFolderName}/ImageSets`;
await this.storageProvider.createContainer(imageSetsFolderName);
@ -216,68 +233,61 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider {
const imageSetsMainFolderName = `${exportFolderName}/ImageSets/Main`;
await this.storageProvider.createContainer(imageSetsMainFolderName);
const tagsDict = new Map<string, string[]>();
if (tags) {
tags.forEach((tag) => {
tagsDict.set(tag.name, []);
});
const assetUsage = new Map<string, Set<string>>();
const tagUsage = new Map<string, number>();
allAssets.forEach((asset) => {
if (asset.regions.length > 0) {
asset.regions.forEach((region) => {
tags.forEach((tag) => {
const array = tagsDict.get(tag.name);
if (region.tags.filter((tagName) => tagName === tag.name).length > 0) {
array.push(`${asset.asset.name} 1`);
} else {
array.push(`${asset.asset.name} -1`);
}
});
});
} else if (exportUnassignedTags) {
// Generate tag usage per asset
allAssets.forEach((assetMetadata) => {
const appliedTags = new Set<string>();
assetUsage.set(assetMetadata.asset.name, appliedTags);
if (assetMetadata.regions.length > 0) {
assetMetadata.regions.forEach((region) => {
tags.forEach((tag) => {
const array = tagsDict.get(tag.name);
array.push(`${asset.asset.name} -1`);
let tagInstances = tagUsage.get(tag.name) || 0;
if (region.tags.filter((tagName) => tagName === tag.name).length > 0) {
appliedTags.add(tag.name);
tagUsage.set(tag.name, tagInstances += 1);
}
});
}
});
});
}
});
// Save ImageSets
tags.forEach(async (tag) => {
if (testSplit > 0 && testSplit <= 1) {
// Shuffle tagsDict sets
tagsDict.forEach((value, key) => {
value = this.shuffle(value);
});
const array = tagsDict.get(tag.name);
// Split in Test and Train sets
const totalAssets = array.length;
const testCount = Math.ceil(totalAssets * testSplit);
const testArray = array.slice(0, testCount);
const trainArray = array.slice(testCount, totalAssets);
const testImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_val.txt`;
await this.storageProvider.writeText(testImageSetFileName, testArray.join("\n"));
const trainImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_train.txt`;
await this.storageProvider.writeText(trainImageSetFileName, trainArray.join("\n"));
// Save ImageSets
await tags.forEachAsync(async (tag) => {
const tagInstances = tagUsage.get(tag.name) || 0;
if (!exportUnassignedTags && tagInstances === 0) {
return;
}
const assetList = [];
assetUsage.forEach((tags, assetName) => {
if (tags.has(tag.name)) {
assetList.push(`${assetName} 1`);
} else {
const imageSetFileName = `${imageSetsMainFolderName}/${tag.name}.txt`;
await this.storageProvider.writeText(imageSetFileName, tagsDict.get(tag.name).join("\n"));
assetList.push(`${assetName} -1`);
}
});
}
}
private shuffle(a: any[]) {
for (let i = a.length - 1; i > 0; i--) {
const j = Math.floor(Math.random() * (i + 1));
[a[i], a[j]] = [a[j], a[i]];
}
return a;
if (testSplit > 0 && testSplit <= 1) {
// Split in Test and Train sets
const totalAssets = assetUsage.size;
const testCount = Math.ceil(totalAssets * testSplit);
const testArray = assetList.slice(0, testCount);
const trainArray = assetList.slice(testCount, totalAssets);
const testImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_val.txt`;
await this.storageProvider.writeText(testImageSetFileName, testArray.join(os.EOL));
const trainImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_train.txt`;
await this.storageProvider.writeText(trainImageSetFileName, trainArray.join(os.EOL));
} else {
const imageSetFileName = `${imageSetsMainFolderName}/${tag.name}.txt`;
await this.storageProvider.writeText(imageSetFileName, assetList.join(os.EOL));
}
});
}
}

Просмотреть файл

@ -1 +1,8 @@
{}
{
"testTrainSplit": {
"ui:widget": "slider"
},
"exportUnassigned": {
"ui:widget": "checkbox"
}
}

Просмотреть файл

@ -1,5 +1,5 @@
import _ from "lodash";
import { TFRecordsJsonExportProvider } from "./tensorFlowRecords";
import { TFRecordsExportProvider } from "./tensorFlowRecords";
import { ExportAssetState } from "./exportProvider";
import registerProviders from "../../registerProviders";
import { ExportProviderFactory } from "./exportProviderFactory";
@ -8,7 +8,6 @@ import {
RegionType, IPoint, IExportProviderOptions,
} from "../../models/applicationState";
import MockFactory from "../../common/mockFactory";
import axios, { AxiosResponse } from "axios";
jest.mock("../../services/assetService");
import { AssetService } from "../../services/assetService";
@ -18,6 +17,7 @@ import { LocalFileSystemProxy } from "../storage/localFileSystemProxy";
import registerMixins from "../../registerMixins";
import { appInfo } from "../../common/appInfo";
import { AssetProviderFactory } from "../storage/assetProviderFactory";
import HtmlFileReader from "../../common/htmlFileReader";
registerMixins();
@ -35,15 +35,7 @@ describe("TFRecords Json Export Provider", () => {
const tagLengthInPbtxt = 31;
axios.get = jest.fn((url, config) => {
return Promise.resolve<AxiosResponse>({
config,
headers: null,
status: 200,
statusText: "OK",
data: [1, 2, 3],
});
});
HtmlFileReader.getAssetArray = jest.fn(() => Promise.resolve(new Uint8Array([1, 2, 3]).buffer));
beforeAll(() => {
AssetProviderFactory.create = jest.fn(() => {
@ -58,7 +50,7 @@ describe("TFRecords Json Export Provider", () => {
});
it("Is defined", () => {
expect(TFRecordsJsonExportProvider).toBeDefined();
expect(TFRecordsExportProvider).toBeDefined();
});
it("Can be instantiated through the factory", () => {
@ -67,7 +59,7 @@ describe("TFRecords Json Export Provider", () => {
};
const exportProvider = ExportProviderFactory.create("tensorFlowRecords", baseTestProject, options);
expect(exportProvider).not.toBeNull();
expect(exportProvider).toBeInstanceOf(TFRecordsJsonExportProvider);
expect(exportProvider).toBeInstanceOf(TFRecordsExportProvider);
});
describe("Export variations", () => {
@ -114,7 +106,7 @@ describe("TFRecords Json Export Provider", () => {
const testProject = { ...baseTestProject };
testProject.tags = MockFactory.createTestTags(3);
const exportProvider = new TFRecordsJsonExportProvider(testProject, options);
const exportProvider = new TFRecordsExportProvider(testProject, options);
await exportProvider.export();
const storageProviderMock = LocalFileSystemProxy as any;
@ -143,7 +135,7 @@ describe("TFRecords Json Export Provider", () => {
const testProject = { ...baseTestProject };
testProject.tags = MockFactory.createTestTags(1);
const exportProvider = new TFRecordsJsonExportProvider(testProject, options);
const exportProvider = new TFRecordsExportProvider(testProject, options);
await exportProvider.export();
const storageProviderMock = LocalFileSystemProxy as any;
@ -171,7 +163,7 @@ describe("TFRecords Json Export Provider", () => {
const testProject = { ...baseTestProject };
testProject.tags = MockFactory.createTestTags(3);
const exportProvider = new TFRecordsJsonExportProvider(testProject, options);
const exportProvider = new TFRecordsExportProvider(testProject, options);
await exportProvider.export();
const storageProviderMock = LocalFileSystemProxy as any;

Просмотреть файл

@ -26,7 +26,7 @@ interface IImageInfo {
* @name - TFRecords Json Export Provider
* @description - Exports a project into a single JSON file that include all configured assets
*/
export class TFRecordsJsonExportProvider extends ExportProvider {
export class TFRecordsExportProvider extends ExportProvider {
constructor(project: IProject, options: IExportProviderOptions) {
super(project, options);
Guard.null(options);

Просмотреть файл

@ -17,7 +17,8 @@ describe("Azure blob functions", () => {
ContainerURL.fromServiceURL = jest.fn(() => new ContainerURL(null, null));
const containerURL = ContainerURL as jest.Mocked<typeof ContainerURL>;
containerURL.prototype.delete = jest.fn(() => Promise.resolve());
containerURL.prototype.create = jest.fn(() => Promise.resolve({ statusCode: 201 }));
containerURL.prototype.delete = jest.fn(() => Promise.resolve({ statusCode: 204 }));
containerURL.prototype.listBlobFlatSegment = jest.fn(() => Promise.resolve(ad.blobs));
BlockBlobURL.fromContainerURL = jest.fn(() => new BlockBlobURL(null, null));
@ -147,9 +148,23 @@ describe("Azure blob functions", () => {
expect(containers).toEqual(ad.containers.containerItems.map((element) => element.name));
});
it("Creates a container in the account", () => {
it("Creates a container in the account", async () => {
const provider: AzureBlobStorage = new AzureBlobStorage(options);
const container = provider.createContainer(null);
await expect(provider.createContainer(null)).resolves.not.toBeNull();
expect(ContainerURL.fromServiceURL).toBeCalledWith(
expect.any(ServiceURL),
ad.containerName,
);
expect(containerURL.prototype.create).toBeCalled();
});
it("Creates a container that already exists", async () => {
containerURL.prototype.create = jest.fn(() => {
return Promise.reject({ statusCode: 409 });
});
const provider: AzureBlobStorage = new AzureBlobStorage(options);
await expect(provider.createContainer(null)).resolves.not.toBeNull();
expect(ContainerURL.fromServiceURL).toBeCalledWith(
expect.any(ServiceURL),
ad.containerName,

Просмотреть файл

@ -165,7 +165,15 @@ export class AzureBlobStorage implements IStorageProvider {
*/
public async createContainer(containerName: string): Promise<void> {
const containerURL = this.getContainerURL();
await containerURL.create(Aborter.none);
try {
await containerURL.create(Aborter.none);
} catch (e) {
if (e.statusCode === 409) {
return;
}
throw e;
}
}
/**

Просмотреть файл

@ -31,7 +31,7 @@ describe("TFRecord Asset Component", () => {
});
HtmlFileReader.getAssetArray = jest.fn((asset) => {
return Promise.resolve<ArrayBuffer>(new Uint8Array(tfrecords));
return Promise.resolve<ArrayBuffer>(new Uint8Array(tfrecords).buffer);
});
const defaultProps: IAssetProps = {

Просмотреть файл

@ -1,5 +1,5 @@
import React from "react";
import { FieldProps } from "react-jsonschema-form";
import { FieldProps, WidgetProps } from "react-jsonschema-form";
import Guard from "../../../../common/guard";
/**
@ -7,11 +7,25 @@ import Guard from "../../../../common/guard";
* @param Widget UI Widget for form
* @param mapProps Function mapping props to an object
*/
export default function CustomField<Props = {}>(Widget: any, mapProps?: (props: FieldProps) => Props) {
export function CustomField<Props = {}>(Widget: any, mapProps?: (props: FieldProps) => Props) {
Guard.null(Widget);
return function render(props: FieldProps) {
const widgetProps = mapProps ? mapProps(props) : props;
return ( <Widget {...widgetProps} /> );
return (<Widget {...widgetProps} />);
};
}
/**
* Custom widget for react-jsonschema-form
* @param Widget UI Widget for form
* @param mapProps Function mapping component props to form widget props
*/
export function CustomWidget<Props = {}>(Widget: any, mapProps?: (props: WidgetProps) => Props) {
Guard.null(Widget);
return function render(props: WidgetProps) {
const widgetProps = mapProps ? mapProps(props) : props;
return (<Widget {...widgetProps} />);
};
}

Просмотреть файл

@ -0,0 +1,36 @@
import React from "react";
import { ISliderProps, Slider } from "./slider";
import { mount, ReactWrapper } from "enzyme";
import RcSlider from "rc-slider";
describe("Slider Component", () => {
const onChangeHandler = jest.fn();
const defaultProps: ISliderProps = {
value: 80,
onChange: onChangeHandler,
};
function createComponent(props?: ISliderProps): ReactWrapper<ISliderProps> {
props = props || defaultProps;
return mount(<Slider {...props} />);
}
let wrapper: ReactWrapper<ISliderProps>;
beforeEach(() => {
wrapper = createComponent();
});
it("renders correctly", () => {
expect(wrapper.find(".slider-value").text()).toEqual(defaultProps.value.toString());
expect(wrapper.find(RcSlider).exists()).toBe(true);
});
it("raises onChange handler when value has changed", () => {
const expectedValue = 60;
const slider = wrapper.find(RcSlider) as ReactWrapper<any>;
slider.props().onChange(expectedValue);
expect(onChangeHandler).toBeCalledWith(expectedValue);
});
});

Просмотреть файл

@ -0,0 +1,25 @@
import React from "react";
import RcSlider from "rc-slider";
import "rc-slider/assets/index.css";
export interface ISliderProps {
value: number;
min?: number;
max?: number;
onChange: (value) => void;
disabled?: boolean;
}
/**
* Slider component to select a value between a min / max range
*/
export class Slider extends React.Component<ISliderProps> {
public render() {
return (
<div className="slider">
<span className="slider-value">{this.props.value}</span>
<RcSlider {...this.props} />
</div>
);
}
}

Просмотреть файл

@ -6,7 +6,7 @@ import CustomFieldTemplate from "../../common/customField/customFieldTemplate";
import { ArrayFieldTemplate } from "../../common/arrayField/arrayFieldTemplate";
import { IAppSettings } from "../../../../models/applicationState";
import { ProtectedInput } from "../../common/protectedInput/protectedInput";
import CustomField from "../../common/customField/customField";
import { CustomField } from "../../common/customField/customField";
import { generateKey } from "../../../../common/crypto";
// tslint:disable-next-line:no-var-requires
const formSchema = addLocValues(require("./appSettingsForm.json"));

Просмотреть файл

@ -8,7 +8,10 @@ import ExportProviderPicker from "../../common/exportProviderPicker/exportProvid
import CustomFieldTemplate from "../../common/customField/customFieldTemplate";
import ExternalPicker from "../../common/externalPicker/externalPicker";
import { ProtectedInput } from "../../common/protectedInput/protectedInput";
import { ExportAssetState } from "../../../../providers/export/exportProvider";
import Checkbox from "rc-checkbox";
import "rc-checkbox/assets/index.css";
import { CustomWidget } from "../../common/customField/customField";
import { Slider } from "../../common/slider/slider";
// tslint:disable-next-line:no-var-requires
const formSchema = addLocValues(require("./exportForm.json"));
@ -48,29 +51,26 @@ export interface IExportFormState {
* @description - Form to view/edit settings for exporting of project
*/
export default class ExportForm extends React.Component<IExportFormProps, IExportFormState> {
public state: IExportFormState = {
classNames: ["needs-validation"],
providerName: this.props.settings ? this.props.settings.providerType : null,
formSchema: { ...formSchema },
uiSchema: { ...uiSchema },
formData: this.props.settings,
};
private widgets = {
externalPicker: (ExternalPicker as any) as Widget,
exportProviderPicker: (ExportProviderPicker as any) as Widget,
protectedInput: (ProtectedInput as any) as Widget,
slider: (Slider as any) as Widget,
checkbox: CustomWidget(Checkbox, (props) => ({
checked: props.value,
onChange: (value) => props.onChange(value.target.checked),
disabled: props.disabled,
})),
};
constructor(props, context) {
super(props, context);
this.state = {
classNames: ["needs-validation"],
providerName: this.props.settings ? this.props.settings.providerType : null,
formSchema: { ...formSchema },
uiSchema: { ...uiSchema },
formData: this.props.settings,
};
this.onFormSubmit = this.onFormSubmit.bind(this);
this.onFormValidate = this.onFormValidate.bind(this);
this.onFormChange = this.onFormChange.bind(this);
this.onFormCancel = this.onFormCancel.bind(this);
}
public componentDidMount() {
if (this.props.settings) {
this.bindForm(this.props.settings);
@ -119,7 +119,7 @@ export default class ExportForm extends React.Component<IExportFormProps, IExpor
}
}
private onFormValidate(exportFormat: IExportFormat, errors: FormValidation) {
private onFormValidate = (exportFormat: IExportFormat, errors: FormValidation): FormValidation => {
if (this.state.classNames.indexOf("was-validated") === -1) {
this.setState({
classNames: [...this.state.classNames, "was-validated"],
@ -129,20 +129,22 @@ export default class ExportForm extends React.Component<IExportFormProps, IExpor
return errors;
}
private onFormSubmit = (args: ISubmitEvent<IExportFormat>) => {
private onFormSubmit = (args: ISubmitEvent<IExportFormat>): void => {
this.props.onSubmit(args.formData);
}
private onFormCancel() {
private onFormCancel = (): void => {
if (this.props.onCancel) {
this.props.onCancel();
}
}
private bindForm(exportFormat: IExportFormat, resetProviderOptions: boolean = false) {
private bindForm = (exportFormat: IExportFormat, resetProviderOptions: boolean = false): void => {
// If no provider type was specified on bind, pick the default provider
const providerType = (exportFormat && exportFormat.providerType) ?
exportFormat.providerType : ExportProviderFactory.defaultProvider.name;
const providerType = (exportFormat && exportFormat.providerType)
? exportFormat.providerType
: ExportProviderFactory.defaultProvider.name;
let newFormSchema: any = this.state.formSchema;
let newUiSchema: any = this.state.uiSchema;

Просмотреть файл

@ -5,7 +5,7 @@ import { addLocValues, strings } from "../../../../common/strings";
import { IConnection, IProject, ITag, IAppSettings } from "../../../../models/applicationState";
import { StorageProviderFactory } from "../../../../providers/storage/storageProviderFactory";
import { ConnectionPickerWithRouter } from "../../common/connectionPicker/connectionPicker";
import CustomField from "../../common/customField/customField";
import { CustomField } from "../../common/customField/customField";
import CustomFieldTemplate from "../../common/customField/customFieldTemplate";
import { ISecurityTokenPickerProps, SecurityTokenPicker } from "../../common/securityTokenPicker/securityTokenPicker";
import "vott-react/dist/css/tagsInput.css";

Просмотреть файл

@ -1,6 +1,6 @@
import { ExportProviderFactory } from "./providers/export/exportProviderFactory";
import { TFPascalVOCJsonExportProvider } from "./providers/export/tensorFlowPascalVOC";
import { TFRecordsJsonExportProvider } from "./providers/export/tensorFlowRecords";
import { TFPascalVOCExportProvider } from "./providers/export/tensorFlowPascalVOC";
import { TFRecordsExportProvider } from "./providers/export/tensorFlowRecords";
import { VottJsonExportProvider } from "./providers/export/vottJson";
import { AssetProviderFactory } from "./providers/storage/assetProviderFactory";
import { AzureBlobStorage } from "./providers/storage/azureBlobStorage";
@ -56,12 +56,12 @@ export default function registerProviders() {
ExportProviderFactory.register({
name: "tensorFlowPascalVOC",
displayName: strings.export.providers.tfPascalVoc.displayName,
factory: (project, options) => new TFPascalVOCJsonExportProvider(project, options),
factory: (project, options) => new TFPascalVOCExportProvider(project, options),
});
ExportProviderFactory.register({
name: "tensorFlowRecords",
displayName: strings.export.providers.tfRecords.displayName,
factory: (project, options) => new TFRecordsJsonExportProvider(project, options),
factory: (project, options) => new TFRecordsExportProvider(project, options),
});
ExportProviderFactory.register({
name: "azureCustomVision",

Просмотреть файл

@ -293,7 +293,7 @@ describe("Asset Service", () => {
});
HtmlFileReader.getAssetArray = jest.fn((asset) => {
return Promise.resolve<Uint8Array>(new Uint8Array(tfrecords));
return Promise.resolve<ArrayBuffer>(new Uint8Array(tfrecords).buffer);
});
it("Loads the asset metadata from the tfrecord file", async () => {