diff --git a/package-lock.json b/package-lock.json index 77c4644d..a4c5876f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14151,6 +14151,17 @@ "react-lifecycles-compat": "^3.0.4" } }, + "rc-checkbox": { + "version": "2.1.6", + "resolved": "https://registry.npmjs.org/rc-checkbox/-/rc-checkbox-2.1.6.tgz", + "integrity": "sha512-+VxQbt2Cwe1PxCvwosrAYXT6EQeGwrbLJB2K+IPGCSRPCKnk9zcub/0eW8A4kxjyyfh60PkwsAUZ7qmB31OmRA==", + "requires": { + "babel-runtime": "^6.23.0", + "classnames": "2.x", + "prop-types": "15.x", + "rc-util": "^4.0.4" + } + }, "rc-menu": { "version": "7.4.21", "resolved": "https://registry.npmjs.org/rc-menu/-/rc-menu-7.4.21.tgz", @@ -14169,6 +14180,40 @@ "resize-observer-polyfill": "^1.5.0" } }, + "rc-slider": { + "version": "8.6.7", + "resolved": "https://registry.npmjs.org/rc-slider/-/rc-slider-8.6.7.tgz", + "integrity": "sha512-QIFWMnK1VLc4TtJSZJgjhI6UOhN8eg53EM2La+eRa8rSPZwJT3rIWfZnTZs7OV7zXG/AiLWN4G+oGxuMcEFpsg==", + "requires": { + "babel-runtime": "6.x", + "classnames": "^2.2.5", + "prop-types": "^15.5.4", + "rc-tooltip": "^3.7.0", + "rc-util": "^4.0.4", + "shallowequal": "^1.0.1", + "warning": "^4.0.3" + }, + "dependencies": { + "warning": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/warning/-/warning-4.0.3.tgz", + "integrity": "sha512-rpJyN222KWIvHJ/F53XSZv0Zl/accqHR8et1kpaMTD/fLCRxtV8iX8czMzY7sVZupTI3zcUTg8eycS2kNF9l6w==", + "requires": { + "loose-envify": "^1.0.0" + } + } + } + }, + "rc-tooltip": { + "version": "3.7.3", + "resolved": "https://registry.npmjs.org/rc-tooltip/-/rc-tooltip-3.7.3.tgz", + "integrity": "sha512-dE2ibukxxkrde7wH9W8ozHKUO4aQnPZ6qBHtrTH9LoO836PjDdiaWO73fgPB05VfJs9FbZdmGPVEbXCeOP99Ww==", + "requires": { + "babel-runtime": "6.x", + "prop-types": "^15.5.8", + "rc-trigger": "^2.2.2" + } + }, "rc-trigger": { "version": "2.6.2", "resolved": "https://registry.npmjs.org/rc-trigger/-/rc-trigger-2.6.2.tgz", diff --git a/package.json b/package.json index d4ac763d..07a2ea75 100644 --- a/package.json +++ b/package.json @@ -26,7 +26,9 @@ "lodash": "^4.17.11", "md5.js": "^1.3.5", "node-int64": "^0.4.0", + "rc-checkbox": "^2.1.6", "rc-menu": "^7.4.21", + "rc-slider": "^8.6.7", "react": "^16.7.0", "react-appinsights": "^3.0.0-rc.5", "react-dom": "^16.7.0", diff --git a/src/common/localization/en-us.ts b/src/common/localization/en-us.ts index dbe32961..a33c8cb9 100644 --- a/src/common/localization/en-us.ts +++ b/src/common/localization/en-us.ts @@ -286,6 +286,14 @@ export const english: IAppStrings = { }, tfPascalVoc: { displayName: "Tensorflow Pascal VOC", + testTrainSplit: { + title: "Test / Train Split", + description: "The test train split to use for exported data", + }, + exportUnassigned: { + title: "Export Unassigned", + description: "Whether or not to include unassigned tags in exported data", + }, }, }, messages: { diff --git a/src/common/localization/es-cl.ts b/src/common/localization/es-cl.ts index 8fe93376..d79ad03d 100644 --- a/src/common/localization/es-cl.ts +++ b/src/common/localization/es-cl.ts @@ -287,6 +287,14 @@ export const spanish: IAppStrings = { }, tfPascalVoc: { displayName: "Tensorflow Pascal VOC", + testTrainSplit: { + title: "Prueba/tren Split", + description: "La división del tren de prueba que se utilizará para los datos exportados", + }, + exportUnassigned: { + title: "Exportar sin asignar", + description: "Si se incluyen o no etiquetas no asignadas en los datos exportados", + }, }, }, messages: { diff --git a/src/common/strings.ts b/src/common/strings.ts index 799e0433..8ea12e8a 100644 --- a/src/common/strings.ts +++ b/src/common/strings.ts @@ -288,6 +288,14 @@ export interface IAppStrings { }, tfPascalVoc: { displayName: string, + testTrainSplit: { + title: string, + description: string, + }, + exportUnassigned: { + title: string, + description: string, + }, }, }, messages: { diff --git a/src/index.scss b/src/index.scss index df7d0b58..a1340235 100644 --- a/src/index.scss +++ b/src/index.scss @@ -62,4 +62,13 @@ input[type=file] { border: solid 1px #62c462; } } + + .rc-checkbox { + margin-left: 0.5em; + } + + .slider-value { + font-weight: 500; + font-size: 90%; + } } diff --git a/src/models/applicationState.ts b/src/models/applicationState.ts index 162b813e..c8ad04a9 100644 --- a/src/models/applicationState.ts +++ b/src/models/applicationState.ts @@ -181,7 +181,7 @@ export interface IExportProviderOptions extends IProviderOptions { * @description - Defines the settings for how project data is exported into commonly used format * @member id - Unique identifier for export format * @member name - Name of export format - * @member providerType - The export format type (TF Records, YOLO, CVS, etc) + * @member providerType - The export format type (TF Records, YOLO, CSV, etc) * @member providerOptions - The provider specific option required to export data */ export interface IExportFormat { diff --git a/src/providers/export/tensorFlowPascalVOC.json b/src/providers/export/tensorFlowPascalVOC.json index afc612b8..e9f9ba19 100644 --- a/src/providers/export/tensorFlowPascalVOC.json +++ b/src/providers/export/tensorFlowPascalVOC.json @@ -17,6 +17,18 @@ "${strings.export.providers.common.properties.assetState.options.visited}", "${strings.export.providers.common.properties.assetState.options.tagged}" ] + }, + "testTrainSplit": { + "title": "${strings.export.providers.tfPascalVoc.testTrainSplit.title}", + "description": "${strings.export.providers.tfPascalVoc.testTrainSplit.description}", + "type": "number", + "default": 80 + }, + "exportUnassigned": { + "title": "${strings.export.providers.tfPascalVoc.exportUnassigned.title}", + "description": "${strings.export.providers.tfPascalVoc.exportUnassigned.description}", + "type": "boolean", + "default": true } } } diff --git a/src/providers/export/tensorFlowPascalVOC.test.ts b/src/providers/export/tensorFlowPascalVOC.test.ts index 35539af6..da0f0d6a 100644 --- a/src/providers/export/tensorFlowPascalVOC.test.ts +++ b/src/providers/export/tensorFlowPascalVOC.test.ts @@ -1,5 +1,5 @@ import _ from "lodash"; -import { TFPascalVOCJsonExportProvider } from "./tensorFlowPascalVOC"; +import { TFPascalVOCExportProvider, ITFPascalVOCExportProviderOptions } from "./tensorFlowPascalVOC"; import { ExportAssetState } from "./exportProvider"; import registerProviders from "../../registerProviders"; import { ExportProviderFactory } from "./exportProviderFactory"; @@ -21,16 +21,6 @@ import { AssetProviderFactory } from "../storage/assetProviderFactory"; registerMixins(); -function _base64ToArrayBuffer(base64: string) { - const binaryString = window.atob(base64); - const len = binaryString.length; - const bytes = new Uint8Array(len); - for (let i = 0; i < len; i++) { - bytes[i] = binaryString.charCodeAt(i); - } - return bytes.buffer; -} - describe("TFPascalVOC Json Export Provider", () => { const testAssets = MockFactory.createTestAssets(10, 1); const baseTestProject = MockFactory.createTestProject("Test Project"); @@ -38,7 +28,6 @@ describe("TFPascalVOC Json Export Provider", () => { "asset-1": MockFactory.createTestAsset("1", AssetState.Tagged), "asset-2": MockFactory.createTestAsset("2", AssetState.Tagged), "asset-3": MockFactory.createTestAsset("3", AssetState.Visited), - "asset-4": MockFactory.createTestAsset("4", AssetState.NotVisited), }; baseTestProject.sourceConnection = MockFactory.createTestConnection("test", "localFileSystemProxy"); baseTestProject.targetConnection = MockFactory.createTestConnection("test", "localFileSystemProxy"); @@ -62,16 +51,18 @@ describe("TFPascalVOC Json Export Provider", () => { }); it("Is defined", () => { - expect(TFPascalVOCJsonExportProvider).toBeDefined(); + expect(TFPascalVOCExportProvider).toBeDefined(); }); it("Can be instantiated through the factory", () => { - const options: IExportProviderOptions = { + const options: ITFPascalVOCExportProviderOptions = { assetState: ExportAssetState.All, + exportUnassigned: true, + testTrainSplit: 80, }; const exportProvider = ExportProviderFactory.create("tensorFlowPascalVOC", baseTestProject, options); expect(exportProvider).not.toBeNull(); - expect(exportProvider).toBeInstanceOf(TFPascalVOCJsonExportProvider); + expect(exportProvider).toBeInstanceOf(TFPascalVOCExportProvider); }); describe("Export variations", () => { @@ -79,27 +70,12 @@ describe("TFPascalVOC Json Export Provider", () => { const assetServiceMock = AssetService as jest.Mocked; assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => { const mockTag = MockFactory.createTestTag(); - - const mockStartPoint: IPoint = { - x: 1, - y: 2, - }; - - const mockEndPoint: IPoint = { - x: 3, - y: 4, - }; - - const mockRegion: IRegion = { - id: "id", - type: RegionType.Rectangle, - tags: [mockTag.name], - points: [mockStartPoint, mockEndPoint], - }; + const mockRegion1 = MockFactory.createTestRegion("region-1", [mockTag.name]); + const mockRegion2 = MockFactory.createTestRegion("region-2", [mockTag.name]); const assetMetadata: IAssetMetadata = { asset, - regions: [mockRegion], + regions: [mockRegion1, mockRegion2], version: appInfo.version, }; @@ -111,14 +87,16 @@ describe("TFPascalVOC Json Export Provider", () => { }); it("Exports all assets", async () => { - const options: IExportProviderOptions = { + const options: ITFPascalVOCExportProviderOptions = { assetState: ExportAssetState.All, + exportUnassigned: true, + testTrainSplit: 80, }; const testProject = { ...baseTestProject }; testProject.tags = MockFactory.createTestTags(3); - const exportProvider = new TFPascalVOCJsonExportProvider(testProject, options); + const exportProvider = new TFPascalVOCExportProvider(testProject, options); await exportProvider.export(); const storageProviderMock = LocalFileSystemProxy as any; @@ -167,14 +145,16 @@ describe("TFPascalVOC Json Export Provider", () => { }); it("Exports only visited assets (includes tagged)", async () => { - const options: IExportProviderOptions = { + const options: ITFPascalVOCExportProviderOptions = { assetState: ExportAssetState.Visited, + exportUnassigned: true, + testTrainSplit: 80, }; const testProject = { ...baseTestProject }; testProject.tags = MockFactory.createTestTags(1); - const exportProvider = new TFPascalVOCJsonExportProvider(testProject, options); + const exportProvider = new TFPascalVOCExportProvider(testProject, options); await exportProvider.export(); const storageProviderMock = LocalFileSystemProxy as any; @@ -211,14 +191,16 @@ describe("TFPascalVOC Json Export Provider", () => { }); it("Exports only tagged assets", async () => { - const options: IExportProviderOptions = { + const options: ITFPascalVOCExportProviderOptions = { assetState: ExportAssetState.Tagged, + exportUnassigned: true, + testTrainSplit: 80, }; const testProject = { ...baseTestProject }; testProject.tags = MockFactory.createTestTags(3); - const exportProvider = new TFPascalVOCJsonExportProvider(testProject, options); + const exportProvider = new TFPascalVOCExportProvider(testProject, options); await exportProvider.export(); const storageProviderMock = LocalFileSystemProxy as any; @@ -258,5 +240,156 @@ describe("TFPascalVOC Json Export Provider", () => { expect(writeTextFileCalls.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 2_train.txt"))) .toBeGreaterThanOrEqual(0); }); + + it("Export includes unassigned tags", async () => { + const options: ITFPascalVOCExportProviderOptions = { + assetState: ExportAssetState.Tagged, + exportUnassigned: true, + testTrainSplit: 80, + }; + + const testProject = { ...baseTestProject }; + const testAssets = MockFactory.createTestAssets(10, 0); + testAssets.forEach((asset) => asset.state = AssetState.Tagged); + testProject.assets = _.keyBy(testAssets, (asset) => asset.id); + testProject.tags = MockFactory.createTestTags(3); + + const exportProvider = new TFPascalVOCExportProvider(testProject, options); + await exportProvider.export(); + + const storageProviderMock = LocalFileSystemProxy as any; + const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[]; + + expect(writeTextFileCalls + .findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 0_val.txt"))) + .toBeGreaterThanOrEqual(0); + expect(writeTextFileCalls + .findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 0_train.txt"))) + .toBeGreaterThanOrEqual(0); + expect(writeTextFileCalls + .findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_val.txt"))) + .toBeGreaterThanOrEqual(0); + expect(writeTextFileCalls + .findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_train.txt"))) + .toBeGreaterThanOrEqual(0); + }); + + it("Export does not include unassigned tags", async () => { + const options: ITFPascalVOCExportProviderOptions = { + assetState: ExportAssetState.Tagged, + exportUnassigned: false, + testTrainSplit: 80, + }; + + const testProject = { ...baseTestProject }; + const testAssets = MockFactory.createTestAssets(10, 0); + testAssets.forEach((asset) => asset.state = AssetState.Tagged); + testProject.assets = _.keyBy(testAssets, (asset) => asset.id); + testProject.tags = MockFactory.createTestTags(3); + + const exportProvider = new TFPascalVOCExportProvider(testProject, options); + await exportProvider.export(); + + const storageProviderMock = LocalFileSystemProxy as any; + const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[]; + + expect(writeTextFileCalls + .findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 0_val.txt"))) + .toEqual(-1); + expect(writeTextFileCalls + .findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 0_train.txt"))) + .toEqual(-1); + expect(writeTextFileCalls + .findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_val.txt"))) + .toBeGreaterThanOrEqual(0); + expect(writeTextFileCalls + .findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_train.txt"))) + .toBeGreaterThanOrEqual(0); + }); + + describe("Annotations", () => { + it("contains expected XML", async () => { + const options: ITFPascalVOCExportProviderOptions = { + assetState: ExportAssetState.Tagged, + exportUnassigned: false, + testTrainSplit: 80, + }; + + const testProject = { ...baseTestProject }; + const testAssets = MockFactory.createTestAssets(10, 0); + testAssets.forEach((asset) => asset.state = AssetState.Tagged); + testProject.assets = _.keyBy(testAssets, (asset) => asset.id); + testProject.tags = [MockFactory.createTestTag("1")]; + + const exportProvider = new TFPascalVOCExportProvider(testProject, options); + await exportProvider.export(); + + const storageProviderMock = LocalFileSystemProxy as any; + const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[]; + const assetIndex = writeTextFileCalls.findIndex((args) => args[0].endsWith("/Annotations/Asset 1.xml")); + const assetXml = writeTextFileCalls[assetIndex][1] as string; + const objectRegExp = /([\s\S]*?)<\/object>/g; + const folderRegExp = new RegExp(/(.*?)<\/filename>/g); + const pathRegExp = new RegExp(/(.*?)<\/path>/g); + const widthRegExp = new RegExp(/(.*?)<\/width>/g); + const heightRegExp = new RegExp(/(.*?)<\/height>/g); + const objectMatches = assetXml.match(objectRegExp); + + expect(objectMatches).toHaveLength(2); + expect(folderRegExp.exec(assetXml)[1]).toEqual(testAssets[1].name); + expect(pathRegExp.exec(assetXml)[1]).toContain(testAssets[1].name); + expect(widthRegExp.exec(assetXml)[1]).toEqual(testAssets[1].size.width.toString()); + expect(heightRegExp.exec(assetXml)[1]).toEqual(testAssets[1].size.height.toString()); + }); + }); + + describe("Test Train Splits", () => { + async function testTestTrainSplit(testTrainSplit: number): Promise { + const options: ITFPascalVOCExportProviderOptions = { + assetState: ExportAssetState.Tagged, + exportUnassigned: true, + testTrainSplit, + }; + + const testProject = { ...baseTestProject }; + const testAssets = MockFactory.createTestAssets(10, 0); + testAssets.forEach((asset) => asset.state = AssetState.Tagged); + testProject.assets = _.keyBy(testAssets, (asset) => asset.id); + testProject.tags = [MockFactory.createTestTag("1")]; + + const exportProvider = new TFPascalVOCExportProvider(testProject, options); + await exportProvider.export(); + + const storageProviderMock = LocalFileSystemProxy as any; + const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[]; + + const valDataIndex = writeTextFileCalls + .findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_val.txt")); + const trainDataIndex = writeTextFileCalls + .findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_train.txt")); + + const expectedTrainCount = (testTrainSplit / 100) * testAssets.length; + const expectedTestCount = ((100 - testTrainSplit) / 100) * testAssets.length; + + expect(writeTextFileCalls[valDataIndex][1].split("\n")).toHaveLength(expectedTestCount); + expect(writeTextFileCalls[trainDataIndex][1].split("\n")).toHaveLength(expectedTrainCount); + } + + it("Correctly generated files based on 50/50 test / train split", async () => { + await testTestTrainSplit(50); + }); + + it("Correctly generated files based on 60/40 test / train split", async () => { + await testTestTrainSplit(60); + }); + + it("Correctly generated files based on 80/20 test / train split", async () => { + await testTestTrainSplit(80); + }); + + it("Correctly generated files based on 90/10 test / train split", async () => { + await testTestTrainSplit(90); + }); + }); }); }); diff --git a/src/providers/export/tensorFlowPascalVOC.ts b/src/providers/export/tensorFlowPascalVOC.ts index 166c721a..9fd44979 100644 --- a/src/providers/export/tensorFlowPascalVOC.ts +++ b/src/providers/export/tensorFlowPascalVOC.ts @@ -5,6 +5,8 @@ import Guard from "../../common/guard"; import HtmlFileReader from "../../common/htmlFileReader"; import { itemTemplate, annotationTemplate, objectTemplate } from "./tensorFlowPascalVOC/tensorFlowPascalVOCTemplates"; import { interpolate } from "../../common/strings"; +import { PlatformType } from "../../common/hostProcess"; +import os from "os"; interface IObjectInfo { name: string; @@ -20,14 +22,24 @@ interface IImageInfo { objects: IObjectInfo[]; } +/** + * Export options for TensorFlow Pascal VOC Export Provider + */ +export interface ITFPascalVOCExportProviderOptions extends IExportProviderOptions { + /** The test / train split ratio for exporting data */ + testTrainSplit?: number; + /** Whether or not to include unassigned tags in exported data */ + exportUnassigned?: boolean; +} + /** * @name - TFPascalVOC Json Export Provider * @description - Exports a project into a single JSON file that include all configured assets */ -export class TFPascalVOCJsonExportProvider extends ExportProvider { +export class TFPascalVOCExportProvider extends ExportProvider { private imagesInfo = new Map(); - constructor(project: IProject, options: IExportProviderOptions) { + constructor(project: IProject, options: ITFPascalVOCExportProviderOptions) { super(project, options); Guard.null(options); } @@ -48,8 +60,15 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider { await this.exportPBTXT(exportFolderName, this.project); await this.exportAnnotations(exportFolderName, allAssets); - // TODO: Make testSplit && exportUnassignedTags optional parameter in the UI Exporter configuration - await this.exportImageSets(exportFolderName, allAssets, this.project.tags, 0.2, true); + // TestSplit && exportUnassignedTags are optional parameter in the UI Exporter configuration + const testSplit = (100 - (this.options.testTrainSplit || 80)) / 100; + await this.exportImageSets( + exportFolderName, + allAssets, + this.project.tags, + testSplit, + this.options.exportUnassigned, + ); } private async exportImages(exportFolderName: string, allAssets: IAssetMetadata[]) { @@ -57,13 +76,9 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider { const jpegImagesFolderName = `${exportFolderName}/JPEGImages`; await this.storageProvider.createContainer(jpegImagesFolderName); - try { - await allAssets.mapAsync(async (assetMetadata) => { - await this.exportSingleImage(jpegImagesFolderName, assetMetadata); - }); - } catch (err) { - console.log(err); - } + await allAssets.mapAsync(async (assetMetadata) => { + await this.exportSingleImage(jpegImagesFolderName, assetMetadata); + }); } private async exportSingleImage(jpegImagesFolderName: string, assetMetadata: IAssetMetadata): Promise { @@ -102,22 +117,20 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider { private getAssetTagArray(element: IAssetMetadata): IObjectInfo[] { const tagObjects = []; - element.regions.filter((region) => (region.type === RegionType.Rectangle || - region.type === RegionType.Square) && - region.points.length === 2) - .forEach((region) => { - region.tags.forEach((tagName) => { - const objectInfo: IObjectInfo = { - name: tagName, - xmin: region.points[0].x, - ymin: region.points[0].y, - xmax: region.points[1].x, - ymax: region.points[1].y, - }; + element.regions.forEach((region) => { + region.tags.forEach((tagName) => { + const objectInfo: IObjectInfo = { + name: tagName, + xmin: region.boundingBox.left, + ymin: region.boundingBox.top, + xmax: region.boundingBox.left + region.boundingBox.width, + ymax: region.boundingBox.top + region.boundingBox.height, + }; - tagObjects.push(objectInfo); - }); + tagObjects.push(objectInfo); }); + }); + return tagObjects; } @@ -199,7 +212,7 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider { await this.storageProvider.writeText(assetFilePath, interpolate(annotationTemplate, params)); }); } catch (err) { - console.log(err); + console.log("Error writing Pascal VOC annotation file"); } } @@ -209,6 +222,10 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider { tags: ITag[], testSplit: number, exportUnassignedTags: boolean) { + if (!tags) { + return; + } + // Create ImageSets Sub Folder (Main ?) const imageSetsFolderName = `${exportFolderName}/ImageSets`; await this.storageProvider.createContainer(imageSetsFolderName); @@ -216,68 +233,61 @@ export class TFPascalVOCJsonExportProvider extends ExportProvider { const imageSetsMainFolderName = `${exportFolderName}/ImageSets/Main`; await this.storageProvider.createContainer(imageSetsMainFolderName); - const tagsDict = new Map(); - if (tags) { - tags.forEach((tag) => { - tagsDict.set(tag.name, []); - }); + const assetUsage = new Map>(); + const tagUsage = new Map(); - allAssets.forEach((asset) => { - if (asset.regions.length > 0) { - asset.regions.forEach((region) => { - tags.forEach((tag) => { - const array = tagsDict.get(tag.name); - if (region.tags.filter((tagName) => tagName === tag.name).length > 0) { - array.push(`${asset.asset.name} 1`); - } else { - array.push(`${asset.asset.name} -1`); - } - }); - }); - } else if (exportUnassignedTags) { + // Generate tag usage per asset + allAssets.forEach((assetMetadata) => { + const appliedTags = new Set(); + assetUsage.set(assetMetadata.asset.name, appliedTags); + + if (assetMetadata.regions.length > 0) { + assetMetadata.regions.forEach((region) => { tags.forEach((tag) => { - const array = tagsDict.get(tag.name); - array.push(`${asset.asset.name} -1`); + let tagInstances = tagUsage.get(tag.name) || 0; + if (region.tags.filter((tagName) => tagName === tag.name).length > 0) { + appliedTags.add(tag.name); + tagUsage.set(tag.name, tagInstances += 1); + } }); - } - }); + }); + } + }); - // Save ImageSets - tags.forEach(async (tag) => { - if (testSplit > 0 && testSplit <= 1) { - // Shuffle tagsDict sets - tagsDict.forEach((value, key) => { - value = this.shuffle(value); - }); - - const array = tagsDict.get(tag.name); - - // Split in Test and Train sets - const totalAssets = array.length; - const testCount = Math.ceil(totalAssets * testSplit); - - const testArray = array.slice(0, testCount); - const trainArray = array.slice(testCount, totalAssets); - - const testImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_val.txt`; - await this.storageProvider.writeText(testImageSetFileName, testArray.join("\n")); - - const trainImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_train.txt`; - await this.storageProvider.writeText(trainImageSetFileName, trainArray.join("\n")); + // Save ImageSets + await tags.forEachAsync(async (tag) => { + const tagInstances = tagUsage.get(tag.name) || 0; + if (!exportUnassignedTags && tagInstances === 0) { + return; + } + const assetList = []; + assetUsage.forEach((tags, assetName) => { + if (tags.has(tag.name)) { + assetList.push(`${assetName} 1`); } else { - const imageSetFileName = `${imageSetsMainFolderName}/${tag.name}.txt`; - await this.storageProvider.writeText(imageSetFileName, tagsDict.get(tag.name).join("\n")); + assetList.push(`${assetName} -1`); } }); - } - } - private shuffle(a: any[]) { - for (let i = a.length - 1; i > 0; i--) { - const j = Math.floor(Math.random() * (i + 1)); - [a[i], a[j]] = [a[j], a[i]]; - } - return a; + if (testSplit > 0 && testSplit <= 1) { + // Split in Test and Train sets + const totalAssets = assetUsage.size; + const testCount = Math.ceil(totalAssets * testSplit); + + const testArray = assetList.slice(0, testCount); + const trainArray = assetList.slice(testCount, totalAssets); + + const testImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_val.txt`; + await this.storageProvider.writeText(testImageSetFileName, testArray.join(os.EOL)); + + const trainImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_train.txt`; + await this.storageProvider.writeText(trainImageSetFileName, trainArray.join(os.EOL)); + + } else { + const imageSetFileName = `${imageSetsMainFolderName}/${tag.name}.txt`; + await this.storageProvider.writeText(imageSetFileName, assetList.join(os.EOL)); + } + }); } } diff --git a/src/providers/export/tensorFlowPascalVOC.ui.json b/src/providers/export/tensorFlowPascalVOC.ui.json index 9e26dfee..fc64af7e 100644 --- a/src/providers/export/tensorFlowPascalVOC.ui.json +++ b/src/providers/export/tensorFlowPascalVOC.ui.json @@ -1 +1,8 @@ -{} \ No newline at end of file +{ + "testTrainSplit": { + "ui:widget": "slider" + }, + "exportUnassigned": { + "ui:widget": "checkbox" + } +} diff --git a/src/providers/export/tensorFlowRecords.test.ts b/src/providers/export/tensorFlowRecords.test.ts index 563e0bd3..6ac3ea05 100644 --- a/src/providers/export/tensorFlowRecords.test.ts +++ b/src/providers/export/tensorFlowRecords.test.ts @@ -1,5 +1,5 @@ import _ from "lodash"; -import { TFRecordsJsonExportProvider } from "./tensorFlowRecords"; +import { TFRecordsExportProvider } from "./tensorFlowRecords"; import { ExportAssetState } from "./exportProvider"; import registerProviders from "../../registerProviders"; import { ExportProviderFactory } from "./exportProviderFactory"; @@ -8,7 +8,6 @@ import { RegionType, IPoint, IExportProviderOptions, } from "../../models/applicationState"; import MockFactory from "../../common/mockFactory"; -import axios, { AxiosResponse } from "axios"; jest.mock("../../services/assetService"); import { AssetService } from "../../services/assetService"; @@ -18,6 +17,7 @@ import { LocalFileSystemProxy } from "../storage/localFileSystemProxy"; import registerMixins from "../../registerMixins"; import { appInfo } from "../../common/appInfo"; import { AssetProviderFactory } from "../storage/assetProviderFactory"; +import HtmlFileReader from "../../common/htmlFileReader"; registerMixins(); @@ -35,15 +35,7 @@ describe("TFRecords Json Export Provider", () => { const tagLengthInPbtxt = 31; - axios.get = jest.fn((url, config) => { - return Promise.resolve({ - config, - headers: null, - status: 200, - statusText: "OK", - data: [1, 2, 3], - }); - }); + HtmlFileReader.getAssetArray = jest.fn(() => Promise.resolve(new Uint8Array([1, 2, 3]).buffer)); beforeAll(() => { AssetProviderFactory.create = jest.fn(() => { @@ -58,7 +50,7 @@ describe("TFRecords Json Export Provider", () => { }); it("Is defined", () => { - expect(TFRecordsJsonExportProvider).toBeDefined(); + expect(TFRecordsExportProvider).toBeDefined(); }); it("Can be instantiated through the factory", () => { @@ -67,7 +59,7 @@ describe("TFRecords Json Export Provider", () => { }; const exportProvider = ExportProviderFactory.create("tensorFlowRecords", baseTestProject, options); expect(exportProvider).not.toBeNull(); - expect(exportProvider).toBeInstanceOf(TFRecordsJsonExportProvider); + expect(exportProvider).toBeInstanceOf(TFRecordsExportProvider); }); describe("Export variations", () => { @@ -114,7 +106,7 @@ describe("TFRecords Json Export Provider", () => { const testProject = { ...baseTestProject }; testProject.tags = MockFactory.createTestTags(3); - const exportProvider = new TFRecordsJsonExportProvider(testProject, options); + const exportProvider = new TFRecordsExportProvider(testProject, options); await exportProvider.export(); const storageProviderMock = LocalFileSystemProxy as any; @@ -143,7 +135,7 @@ describe("TFRecords Json Export Provider", () => { const testProject = { ...baseTestProject }; testProject.tags = MockFactory.createTestTags(1); - const exportProvider = new TFRecordsJsonExportProvider(testProject, options); + const exportProvider = new TFRecordsExportProvider(testProject, options); await exportProvider.export(); const storageProviderMock = LocalFileSystemProxy as any; @@ -171,7 +163,7 @@ describe("TFRecords Json Export Provider", () => { const testProject = { ...baseTestProject }; testProject.tags = MockFactory.createTestTags(3); - const exportProvider = new TFRecordsJsonExportProvider(testProject, options); + const exportProvider = new TFRecordsExportProvider(testProject, options); await exportProvider.export(); const storageProviderMock = LocalFileSystemProxy as any; diff --git a/src/providers/export/tensorFlowRecords.ts b/src/providers/export/tensorFlowRecords.ts index e274a63a..359a5a67 100644 --- a/src/providers/export/tensorFlowRecords.ts +++ b/src/providers/export/tensorFlowRecords.ts @@ -26,7 +26,7 @@ interface IImageInfo { * @name - TFRecords Json Export Provider * @description - Exports a project into a single JSON file that include all configured assets */ -export class TFRecordsJsonExportProvider extends ExportProvider { +export class TFRecordsExportProvider extends ExportProvider { constructor(project: IProject, options: IExportProviderOptions) { super(project, options); Guard.null(options); diff --git a/src/providers/storage/azureBlobStorage.test.ts b/src/providers/storage/azureBlobStorage.test.ts index a232584a..1c95d263 100644 --- a/src/providers/storage/azureBlobStorage.test.ts +++ b/src/providers/storage/azureBlobStorage.test.ts @@ -17,7 +17,8 @@ describe("Azure blob functions", () => { ContainerURL.fromServiceURL = jest.fn(() => new ContainerURL(null, null)); const containerURL = ContainerURL as jest.Mocked; - containerURL.prototype.delete = jest.fn(() => Promise.resolve()); + containerURL.prototype.create = jest.fn(() => Promise.resolve({ statusCode: 201 })); + containerURL.prototype.delete = jest.fn(() => Promise.resolve({ statusCode: 204 })); containerURL.prototype.listBlobFlatSegment = jest.fn(() => Promise.resolve(ad.blobs)); BlockBlobURL.fromContainerURL = jest.fn(() => new BlockBlobURL(null, null)); @@ -147,9 +148,23 @@ describe("Azure blob functions", () => { expect(containers).toEqual(ad.containers.containerItems.map((element) => element.name)); }); - it("Creates a container in the account", () => { + it("Creates a container in the account", async () => { const provider: AzureBlobStorage = new AzureBlobStorage(options); - const container = provider.createContainer(null); + await expect(provider.createContainer(null)).resolves.not.toBeNull(); + expect(ContainerURL.fromServiceURL).toBeCalledWith( + expect.any(ServiceURL), + ad.containerName, + ); + expect(containerURL.prototype.create).toBeCalled(); + }); + + it("Creates a container that already exists", async () => { + containerURL.prototype.create = jest.fn(() => { + return Promise.reject({ statusCode: 409 }); + }); + + const provider: AzureBlobStorage = new AzureBlobStorage(options); + await expect(provider.createContainer(null)).resolves.not.toBeNull(); expect(ContainerURL.fromServiceURL).toBeCalledWith( expect.any(ServiceURL), ad.containerName, diff --git a/src/providers/storage/azureBlobStorage.ts b/src/providers/storage/azureBlobStorage.ts index 970590bd..a81510c3 100644 --- a/src/providers/storage/azureBlobStorage.ts +++ b/src/providers/storage/azureBlobStorage.ts @@ -165,7 +165,15 @@ export class AzureBlobStorage implements IStorageProvider { */ public async createContainer(containerName: string): Promise { const containerURL = this.getContainerURL(); - await containerURL.create(Aborter.none); + try { + await containerURL.create(Aborter.none); + } catch (e) { + if (e.statusCode === 409) { + return; + } + + throw e; + } } /** diff --git a/src/react/components/common/assetPreview/tfrecordAsset.test.tsx b/src/react/components/common/assetPreview/tfrecordAsset.test.tsx index 02100252..872eec31 100644 --- a/src/react/components/common/assetPreview/tfrecordAsset.test.tsx +++ b/src/react/components/common/assetPreview/tfrecordAsset.test.tsx @@ -31,7 +31,7 @@ describe("TFRecord Asset Component", () => { }); HtmlFileReader.getAssetArray = jest.fn((asset) => { - return Promise.resolve(new Uint8Array(tfrecords)); + return Promise.resolve(new Uint8Array(tfrecords).buffer); }); const defaultProps: IAssetProps = { diff --git a/src/react/components/common/customField/customField.tsx b/src/react/components/common/customField/customField.tsx index ff568ea9..ab92ef00 100644 --- a/src/react/components/common/customField/customField.tsx +++ b/src/react/components/common/customField/customField.tsx @@ -1,5 +1,5 @@ import React from "react"; -import { FieldProps } from "react-jsonschema-form"; +import { FieldProps, WidgetProps } from "react-jsonschema-form"; import Guard from "../../../../common/guard"; /** @@ -7,11 +7,25 @@ import Guard from "../../../../common/guard"; * @param Widget UI Widget for form * @param mapProps Function mapping props to an object */ -export default function CustomField(Widget: any, mapProps?: (props: FieldProps) => Props) { +export function CustomField(Widget: any, mapProps?: (props: FieldProps) => Props) { Guard.null(Widget); return function render(props: FieldProps) { const widgetProps = mapProps ? mapProps(props) : props; - return ( ); + return (); + }; +} + +/** + * Custom widget for react-jsonschema-form + * @param Widget UI Widget for form + * @param mapProps Function mapping component props to form widget props + */ +export function CustomWidget(Widget: any, mapProps?: (props: WidgetProps) => Props) { + Guard.null(Widget); + + return function render(props: WidgetProps) { + const widgetProps = mapProps ? mapProps(props) : props; + return (); }; } diff --git a/src/react/components/common/slider/slider.test.tsx b/src/react/components/common/slider/slider.test.tsx new file mode 100644 index 00000000..796bb129 --- /dev/null +++ b/src/react/components/common/slider/slider.test.tsx @@ -0,0 +1,36 @@ +import React from "react"; +import { ISliderProps, Slider } from "./slider"; +import { mount, ReactWrapper } from "enzyme"; +import RcSlider from "rc-slider"; + +describe("Slider Component", () => { + const onChangeHandler = jest.fn(); + const defaultProps: ISliderProps = { + value: 80, + onChange: onChangeHandler, + }; + + function createComponent(props?: ISliderProps): ReactWrapper { + props = props || defaultProps; + return mount(); + } + + let wrapper: ReactWrapper; + + beforeEach(() => { + wrapper = createComponent(); + }); + + it("renders correctly", () => { + expect(wrapper.find(".slider-value").text()).toEqual(defaultProps.value.toString()); + expect(wrapper.find(RcSlider).exists()).toBe(true); + }); + + it("raises onChange handler when value has changed", () => { + const expectedValue = 60; + const slider = wrapper.find(RcSlider) as ReactWrapper; + slider.props().onChange(expectedValue); + + expect(onChangeHandler).toBeCalledWith(expectedValue); + }); +}); diff --git a/src/react/components/common/slider/slider.tsx b/src/react/components/common/slider/slider.tsx new file mode 100644 index 00000000..953945b3 --- /dev/null +++ b/src/react/components/common/slider/slider.tsx @@ -0,0 +1,25 @@ +import React from "react"; +import RcSlider from "rc-slider"; +import "rc-slider/assets/index.css"; + +export interface ISliderProps { + value: number; + min?: number; + max?: number; + onChange: (value) => void; + disabled?: boolean; +} + +/** + * Slider component to select a value between a min / max range + */ +export class Slider extends React.Component { + public render() { + return ( +
+ {this.props.value} + +
+ ); + } +} diff --git a/src/react/components/pages/appSettings/appSettingsForm.tsx b/src/react/components/pages/appSettings/appSettingsForm.tsx index d022b5e4..4d3ebff0 100644 --- a/src/react/components/pages/appSettings/appSettingsForm.tsx +++ b/src/react/components/pages/appSettings/appSettingsForm.tsx @@ -6,7 +6,7 @@ import CustomFieldTemplate from "../../common/customField/customFieldTemplate"; import { ArrayFieldTemplate } from "../../common/arrayField/arrayFieldTemplate"; import { IAppSettings } from "../../../../models/applicationState"; import { ProtectedInput } from "../../common/protectedInput/protectedInput"; -import CustomField from "../../common/customField/customField"; +import { CustomField } from "../../common/customField/customField"; import { generateKey } from "../../../../common/crypto"; // tslint:disable-next-line:no-var-requires const formSchema = addLocValues(require("./appSettingsForm.json")); diff --git a/src/react/components/pages/export/exportForm.tsx b/src/react/components/pages/export/exportForm.tsx index e11f547b..7780ddbc 100644 --- a/src/react/components/pages/export/exportForm.tsx +++ b/src/react/components/pages/export/exportForm.tsx @@ -8,7 +8,10 @@ import ExportProviderPicker from "../../common/exportProviderPicker/exportProvid import CustomFieldTemplate from "../../common/customField/customFieldTemplate"; import ExternalPicker from "../../common/externalPicker/externalPicker"; import { ProtectedInput } from "../../common/protectedInput/protectedInput"; -import { ExportAssetState } from "../../../../providers/export/exportProvider"; +import Checkbox from "rc-checkbox"; +import "rc-checkbox/assets/index.css"; +import { CustomWidget } from "../../common/customField/customField"; +import { Slider } from "../../common/slider/slider"; // tslint:disable-next-line:no-var-requires const formSchema = addLocValues(require("./exportForm.json")); @@ -48,29 +51,26 @@ export interface IExportFormState { * @description - Form to view/edit settings for exporting of project */ export default class ExportForm extends React.Component { + public state: IExportFormState = { + classNames: ["needs-validation"], + providerName: this.props.settings ? this.props.settings.providerType : null, + formSchema: { ...formSchema }, + uiSchema: { ...uiSchema }, + formData: this.props.settings, + }; + private widgets = { externalPicker: (ExternalPicker as any) as Widget, exportProviderPicker: (ExportProviderPicker as any) as Widget, protectedInput: (ProtectedInput as any) as Widget, + slider: (Slider as any) as Widget, + checkbox: CustomWidget(Checkbox, (props) => ({ + checked: props.value, + onChange: (value) => props.onChange(value.target.checked), + disabled: props.disabled, + })), }; - constructor(props, context) { - super(props, context); - - this.state = { - classNames: ["needs-validation"], - providerName: this.props.settings ? this.props.settings.providerType : null, - formSchema: { ...formSchema }, - uiSchema: { ...uiSchema }, - formData: this.props.settings, - }; - - this.onFormSubmit = this.onFormSubmit.bind(this); - this.onFormValidate = this.onFormValidate.bind(this); - this.onFormChange = this.onFormChange.bind(this); - this.onFormCancel = this.onFormCancel.bind(this); - } - public componentDidMount() { if (this.props.settings) { this.bindForm(this.props.settings); @@ -119,7 +119,7 @@ export default class ExportForm extends React.Component { if (this.state.classNames.indexOf("was-validated") === -1) { this.setState({ classNames: [...this.state.classNames, "was-validated"], @@ -129,20 +129,22 @@ export default class ExportForm extends React.Component) => { + private onFormSubmit = (args: ISubmitEvent): void => { this.props.onSubmit(args.formData); } - private onFormCancel() { + private onFormCancel = (): void => { if (this.props.onCancel) { this.props.onCancel(); } } - private bindForm(exportFormat: IExportFormat, resetProviderOptions: boolean = false) { + private bindForm = (exportFormat: IExportFormat, resetProviderOptions: boolean = false): void => { // If no provider type was specified on bind, pick the default provider - const providerType = (exportFormat && exportFormat.providerType) ? - exportFormat.providerType : ExportProviderFactory.defaultProvider.name; + const providerType = (exportFormat && exportFormat.providerType) + ? exportFormat.providerType + : ExportProviderFactory.defaultProvider.name; + let newFormSchema: any = this.state.formSchema; let newUiSchema: any = this.state.uiSchema; diff --git a/src/react/components/pages/projectSettings/projectForm.tsx b/src/react/components/pages/projectSettings/projectForm.tsx index 3b157fdd..7eaf7b26 100644 --- a/src/react/components/pages/projectSettings/projectForm.tsx +++ b/src/react/components/pages/projectSettings/projectForm.tsx @@ -5,7 +5,7 @@ import { addLocValues, strings } from "../../../../common/strings"; import { IConnection, IProject, ITag, IAppSettings } from "../../../../models/applicationState"; import { StorageProviderFactory } from "../../../../providers/storage/storageProviderFactory"; import { ConnectionPickerWithRouter } from "../../common/connectionPicker/connectionPicker"; -import CustomField from "../../common/customField/customField"; +import { CustomField } from "../../common/customField/customField"; import CustomFieldTemplate from "../../common/customField/customFieldTemplate"; import { ISecurityTokenPickerProps, SecurityTokenPicker } from "../../common/securityTokenPicker/securityTokenPicker"; import "vott-react/dist/css/tagsInput.css"; diff --git a/src/registerProviders.ts b/src/registerProviders.ts index f8733b8d..4259824f 100644 --- a/src/registerProviders.ts +++ b/src/registerProviders.ts @@ -1,6 +1,6 @@ import { ExportProviderFactory } from "./providers/export/exportProviderFactory"; -import { TFPascalVOCJsonExportProvider } from "./providers/export/tensorFlowPascalVOC"; -import { TFRecordsJsonExportProvider } from "./providers/export/tensorFlowRecords"; +import { TFPascalVOCExportProvider } from "./providers/export/tensorFlowPascalVOC"; +import { TFRecordsExportProvider } from "./providers/export/tensorFlowRecords"; import { VottJsonExportProvider } from "./providers/export/vottJson"; import { AssetProviderFactory } from "./providers/storage/assetProviderFactory"; import { AzureBlobStorage } from "./providers/storage/azureBlobStorage"; @@ -56,12 +56,12 @@ export default function registerProviders() { ExportProviderFactory.register({ name: "tensorFlowPascalVOC", displayName: strings.export.providers.tfPascalVoc.displayName, - factory: (project, options) => new TFPascalVOCJsonExportProvider(project, options), + factory: (project, options) => new TFPascalVOCExportProvider(project, options), }); ExportProviderFactory.register({ name: "tensorFlowRecords", displayName: strings.export.providers.tfRecords.displayName, - factory: (project, options) => new TFRecordsJsonExportProvider(project, options), + factory: (project, options) => new TFRecordsExportProvider(project, options), }); ExportProviderFactory.register({ name: "azureCustomVision", diff --git a/src/services/assetService.test.ts b/src/services/assetService.test.ts index 074b6c2e..cb5173f0 100644 --- a/src/services/assetService.test.ts +++ b/src/services/assetService.test.ts @@ -293,7 +293,7 @@ describe("Asset Service", () => { }); HtmlFileReader.getAssetArray = jest.fn((asset) => { - return Promise.resolve(new Uint8Array(tfrecords)); + return Promise.resolve(new Uint8Array(tfrecords).buffer); }); it("Loads the asset metadata from the tfrecord file", async () => {