From e5698c08a72b5326f7f672e3adc946c78dba9352 Mon Sep 17 00:00:00 2001 From: dakirsa Date: Wed, 4 Oct 2017 09:16:52 -0400 Subject: [PATCH] Added grayscale support; fixed partitioning --- .../org/apache/spark/image/ImageSchema.scala | 44 ++++++++++++++----- .../apache/spark/image/TestImageSchema.scala | 6 +-- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/main/scala/org/apache/spark/image/ImageSchema.scala b/src/main/scala/org/apache/spark/image/ImageSchema.scala index b3341ee..bf2ae00 100644 --- a/src/main/scala/org/apache/spark/image/ImageSchema.scala +++ b/src/main/scala/org/apache/spark/image/ImageSchema.scala @@ -9,6 +9,7 @@ import java.awt.image.BufferedImage import java.awt.{Color, Image} import java.io.ByteArrayInputStream import javax.imageio.ImageIO +import java.awt.color.ColorSpace object ImageSchema{ @@ -74,27 +75,45 @@ object ImageSchema{ None } else { + val is_gray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY + val has_alpha = img.getColorModel.hasAlpha + val height = img.getHeight val width = img.getWidth - val (nChannels, mode) = if(img.getColorModel().hasAlpha()) (4, "CV_8UC4") else (3, "CV_8UC3") //TODO: grayscale + val (nChannels, mode) = if(is_gray) (1, "CV_8UC1") + else if(has_alpha) (4, "CV_8UC4") + else (3, "CV_8UC3") assert(height*width*nChannels < 1e9, "image is too large") val decoded = Array.ofDim[Byte](height*width*nChannels) - var offset = 0 - for(h <- 0 until height) { - for (w <- 0 until width) { - val color = new Color(img.getRGB(w, h)) - - decoded(offset) = color.getBlue.toByte - decoded(offset+1) = color.getGreen.toByte - decoded(offset+2) = color.getRed.toByte - if(nChannels == 4){ - decoded(offset+3) = color.getAlpha.toByte + // grayscale images in Java require special handling to get the correct intensity + if(is_gray){ + var offset = 0 + val raster = img.getRaster + for(h <- 0 until height) { + for (w <- 0 until width) { + decoded(offset) = raster.getSample(w, h, 0).toByte + offset += 1 } - offset += nChannels } } + else{ + var offset = 0 + for (h <- 0 until height) { + for (w <- 0 until width) { + val color = new Color(img.getRGB(w, h)) + + decoded(offset) = color.getBlue.toByte + decoded(offset + 1) = color.getGreen.toByte + decoded(offset + 2) = color.getRed.toByte + if (nChannels == 4) { + decoded(offset + 3) = color.getAlpha.toByte + } + offset += nChannels + } + } + } // the internal "Row" is needed, because the image is a single dataframe column Some(Row(Row(origin, height, width, nChannels, mode, decoded))) @@ -132,6 +151,7 @@ object ImageSchema{ var result: DataFrame = null try { val streams = session.sparkContext.binaryFiles(path, partitions) + .repartition(partitions) val images = if(dropImageFailures){ streams.flatMap{ diff --git a/src/test/scala/org/apache/spark/image/TestImageSchema.scala b/src/test/scala/org/apache/spark/image/TestImageSchema.scala index 914bdbc..e8f0791 100644 --- a/src/test/scala/org/apache/spark/image/TestImageSchema.scala +++ b/src/test/scala/org/apache/spark/image/TestImageSchema.scala @@ -53,8 +53,7 @@ class TestImageSchemaSuite extends FunSuite with TestSparkContext { assert(count50 > 0.2 * count100 && count50 < 0.8 * count100) } - // TODO: fix the partition test - ignore("readImages partition test") { + test("readImages partition test") { val df = readImages(imagePath, recursive = true, dropImageFailures = true, numPartitions = 3) assert(df.rdd.getNumPartitions == 3) } @@ -88,12 +87,11 @@ class TestImageSchemaSuite extends FunSuite with TestSparkContext { } } - // TODO: fix grayscale test // number of channels and first 20 bytes of OpenCV representation // - default representation for 3-channel RGB images is BGR row-wise: (B00, G00, R00, B10, G10, R10, ...) // - default representation for 4-channel RGB images is BGRA row-wise: (B00, G00, R00, A00, B10, G10, R10, A00, ...) private val firstBytes20 = Map( - //"grayscale.png" -> (("CV_8UC1", Array[Byte](0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 3, 5, 2, 1))), + "grayscale.png" -> (("CV_8UC1", Array[Byte](0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 3, 5, 2, 1))), "RGB.png" -> (("CV_8UC3", Array[Byte](-34, -66, -98, -38, -69, -98, -62, -90, -117, -70, -98, -124, -34, -63, -90, -20, -48, -74, -18, -45))), "RGBA.png" -> (("CV_8UC4", Array[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128, -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))) )