Tags: android, kotlin, tensorflow2.0, tensorflow-lite

Image produced is incomplete - Cannot copy to a TensorFlowLite tensor (input_1) with bytes


I am trying to load a tflite model and run it on an image.

My tflite model has the dimensions shown in the screenshot attached to the original post (its input layer expects 1 * 64 * 64 * 3).

Right now, I am receiving:

Cannot copy to a TensorFlowLite tensor (input_1) with 49152 bytes from a Java Buffer with 175584 bytes.

I can't figure out how to work with the input and output tensor sizes. Right now I am initializing the buffers from the input image size, and the output image size is the input size * 4.

At which point do I have to apply the 1 * 64 * 64 * 3 dimensions, given that I need to handle input images of arbitrary size?
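For context, the two byte counts in the error line up exactly with these shapes. A quick back-of-the-envelope check (a sketch; it assumes the buffers are FLOAT32, i.e. 4 bytes per channel value, and uses the 124 * 118 input image mentioned further down):

    // The model's input tensor is 1 x 64 x 64 x 3 FLOAT32:
    val modelInputBytes = 1 * 64 * 64 * 3 * 4    // = 49152 bytes, what the tensor expects

    // A 124 x 118 RGB image converted to FLOAT32 occupies:
    val imageBufferBytes = 124 * 118 * 3 * 4     // = 175584 bytes, what the Java buffer holds

So the interpreter is being handed a buffer sized for the original image rather than one resized to 64 x 64.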

    try {
        tflitemodel = loadModelFile()
        tflite = Interpreter(tflitemodel, options)
    } catch (e: IOException) {
        Log.e(TAG, "Fail to load model", e)
    }

    val imageTensorIndex = 0
    val imageShape: IntArray =
        tflite.getInputTensor(imageTensorIndex).shape()
    val imageDataType: DataType = tflite.getInputTensor(imageTensorIndex).dataType()

    // Build a TensorImage object
    var inputImageBuffer = TensorImage(imageDataType)

    // Load the Bitmap
    inputImageBuffer.load(bitmap)

    // Preprocess image
    val imgprocessor = ImageProcessor.Builder()
        .add(ResizeOp(inputImageBuffer.height,
            inputImageBuffer.width,
            ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
        //.add(NormalizeOp(127.5f, 127.5f))
        //.add(QuantizeOp(128.0f, 1 / 128.0f))
        .build()

    // Process the image
    val processedImage = imgprocessor.process(inputImageBuffer)

    // Access the buffer (byte[]) of the processedImage
    val imageBuffer = processedImage.buffer
    val imageTensorBuffer = processedImage.tensorBuffer

    // Output result
    val outputImageBuffer = TensorBuffer.createFixedSize(
        intArrayOf(inputImageBuffer.height * 4,
            inputImageBuffer.width * 4),
        DataType.FLOAT32)

    // Normalize the tensor given the mean and the standard deviation
    val tensorProcessor = TensorProcessor.Builder()
        .add(NormalizeOp(127.5f, 127.5f))
        .add(CastOp(DataType.FLOAT32))
        .build()
    val processedOutputTensor = tensorProcessor.process(outputImageBuffer)

    tflite.run(imageTensorBuffer.buffer, processedOutputTensor.buffer)

I tried to cast the output tensor either to FLOAT32 or UINT8.

UPDATE

I also tried this:

    try {
        tflitemodel = loadModelFile()
        tflite = Interpreter(tflitemodel, options)
    } catch (e: IOException) {
        Log.e(TAG, "Fail to load model", e)
    }

    val imageTensorIndex = 0
    val imageDataType: DataType = tflite.getInputTensor(imageTensorIndex).dataType()

    val imgprocessor = ImageProcessor.Builder()
        .add(ResizeOp(64,
            64,
            ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
        .add(NormalizeOp(0.0f, 255.0f))
        .add(CastOp(DataType.FLOAT32))
        .build()

    val inpIm = TensorImage(imageDataType)
    inpIm.load(bitmap)

    val processedImage = imgprocessor.process(inpIm)

    val output = TensorBuffer.createFixedSize(
        intArrayOf(
            124 * 4,
            118 * 4,
            3,
            1
        ),
        DataType.FLOAT32
    )

    val tensorProcessor = TensorProcessor.Builder()
        .add(NormalizeOp(0.0f, 255.0f))
        .add(CastOp(DataType.FLOAT32))
        .build()

    val processedOutputTensor = tensorProcessor.process(output)

    tflite.run(processedImage.buffer, processedOutputTensor.buffer)

which produces an incomplete image (see the screenshot in the original post).

Note that the current image I am using as input has 124 * 118 * 3 dimensions.

The output image will have (124 * 4) * (118 * 4) * 3 dimensions.

The model's input layer expects 64 * 64 * 3.
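Before allocating any buffers it helps to ask the interpreter directly what shapes it expects. A minimal sketch (the example shapes in the comments are assumptions based on the model described above):

    // Inspect the model's actual input/output shapes before sizing buffers.
    val inputShape = tflite.getInputTensor(0).shape()    // e.g. [1, 64, 64, 3]
    val outputShape = tflite.getOutputTensor(0).shape()  // e.g. [1, 256, 256, 3] for a x4 model
    Log.d(TAG, "input=${inputShape.contentToString()}, output=${outputShape.contentToString()}")

Whatever these return is what the buffers passed to tflite.run() must match, byte for byte.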


Solution

  • I took a look at your project; your class will look like the following:

    class MainActivity : AppCompatActivity() {
    
    
        private val TAG = "SuperResolution"
        private val MODEL_NAME = "model_edsr.tflite"
        private val LR_IMAGE_HEIGHT = 24
        private val LR_IMAGE_WIDTH = 24
        private val UPSCALE_FACTOR = 4
        private val SR_IMAGE_HEIGHT = LR_IMAGE_HEIGHT * UPSCALE_FACTOR
        private val SR_IMAGE_WIDTH = LR_IMAGE_WIDTH * UPSCALE_FACTOR
    
        private lateinit var photoButton: Button
        private lateinit var srButton: Button
        private lateinit var colorizeButton: Button
        private var FILE_NAME = "photo.jpg"
    
        private lateinit var filename:String
        private var resultImg: Bitmap? = null
    
        private lateinit var gpuSwitch: Switch
    
        private lateinit var tflite: Interpreter
        private lateinit var tflitemodel: ByteBuffer
    
        private val INPUT_SIZE: Int = 96
        private val PIXEL_SIZE: Int = 3
        private val IMAGE_MEAN = 0
        private val IMAGE_STD = 255.0f
    
    
        private var bitmap: Bitmap? = null
        private var bitmapResult: Bitmap? = null
    
        /** A ByteBuffer to hold image data, to be fed into TensorFlow Lite as input/output */
        private lateinit var imgDataInput: ByteBuffer
        private lateinit var imgDataOutput: ByteBuffer
    
        /** Dimensions of inputs.  */
        private val DIM_BATCH_SIZE = 1
    
        private val DIM_PIXEL_SIZE = 3
    
        private val DIM_IMG_SIZE_X = 64
        private val DIM_IMG_SIZE_Y = 64
        private lateinit var catBitmap: Bitmap
        /* Preallocated buffers for storing image data in. */
        private val intValues = IntArray(DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y)
        private lateinit var superImage: ImageView
    
        override fun onCreate(savedInstanceState: Bundle?) {
            super.onCreate(savedInstanceState)
            setContentView(R.layout.activity_main)
            superImage = findViewById(R.id.super_resolution_image)
    
            //val assetManager = assets
            catBitmap = getBitmapFromAsset("cat.png")
    
    
            srButton = findViewById(R.id.super_resolution)
            srButton.setOnClickListener { view: View ->
    
                val intent = Intent(this, SelectedImage::class.java)
                getImageResult.launch(intent)
            }
    
    
        }
    
        private fun getBitmapFromAsset(filePath: String?): Bitmap {
            val assetManager = assets
            val istr: InputStream
            var bitmap: Bitmap? = null
            try {
                istr = assetManager.open(filePath!!)
                bitmap = BitmapFactory.decodeStream(istr)
            } catch (e: IOException) {
                // handle exception
                Log.e("Bitmap_except", e.toString())
    
            }
    
            if (bitmap != null) {
                bitmap = Bitmap.createScaledBitmap(bitmap,64,64,true)
            }
    
            return bitmap?: Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
        }
    
        private val getImageResult =
            registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { result ->
                if (result.resultCode == Activity.RESULT_OK) {
                    var theImageUri: Uri? = null
                    theImageUri = result.data?.getParcelableExtra<Uri>("imageuri")
    
                    filename = "SR_" + theImageUri?.getOriginalFileName(this).toString()
    
                    bitmap = uriToBitmap(theImageUri!!)!!  // or use catBitmap for testing
                    Log.v("width", bitmap!!.width.toString())
    
                    if (bitmap != null) {
                        // call DL
                        val options = Interpreter.Options()
                        options.setNumThreads(5)
                        options.setUseNNAPI(true)
                        try {
                            tflitemodel = loadModelFile()
                            tflite = Interpreter(tflitemodel, options)
                            val index = tflite.getInputIndex("input_1")
                            tflite.resizeInput(
                                index,
                                intArrayOf(1, bitmap!!.width, bitmap!!.height, 3)
                            )
                        } catch (e: IOException) {
                            Log.e(TAG, "Fail to load model", e)
                        }
    
                        val imgprocessor = ImageProcessor.Builder()
                            .add(
                               ResizeOp(bitmap!!.width,
                                    bitmap!!.height,
                                    ResizeOp.ResizeMethod.NEAREST_NEIGHBOR)
                            )
                            .add( CastOp( DataType.FLOAT32 ) )
                            .build()
    
                        val inpIm = TensorImage(DataType.FLOAT32)
                        inpIm.load(bitmap)
    
                        // Process the image
                        val processedImage = imgprocessor.process(inpIm)
    
                        val output2 = Array(1) { Array(4*bitmap!!.width) { Array(4*bitmap!!.height) { FloatArray(3) } } }
    
                        tflite.run(processedImage.buffer, output2)
    
                        bitmapResult = convertArrayToBitmap(output2, 4*bitmap!!.height, 4*bitmap!!.width)
    
                        Log.v("widthHR", bitmapResult!!.height.toString())
                        superImage.setImageBitmap(bitmapResult)
    
                    }
                }
            }
    
    
        @Throws(IOException::class)
        private fun loadModelFile(): MappedByteBuffer {
            val fileDescriptor = assets.openFd(MODEL_NAME)
            val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
            val fileChannel = inputStream.channel
            val startOffset = fileDescriptor.startOffset
            val declaredLength = fileDescriptor.declaredLength
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
        }
    
    
        private fun uriToBitmap(selectedFileUri: Uri): Bitmap? {
            try {
                val parcelFileDescriptor = contentResolver.openFileDescriptor(selectedFileUri, "r")
                val fileDescriptor: FileDescriptor = parcelFileDescriptor!!.fileDescriptor
                val image = BitmapFactory.decodeFileDescriptor(fileDescriptor)
                parcelFileDescriptor.close()
                return image
            } catch (e: IOException) {
                e.printStackTrace()
            }
            return null
        }
    
        private fun getOutputImage(output: ByteBuffer): Bitmap? {
            output.rewind()
            val outputWidth = 124 * 4
            val outputHeight = 118 * 4
            val bitmap = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888)
            val pixels = IntArray(outputWidth * outputHeight)
            for (i in 0 until outputWidth * outputHeight) {
                val a = 0xFF
                val r = output.float * 255.0f
                val g = output.float * 255.0f
                val b = output.float * 255.0f
                pixels[i] = a shl 24 or (r.toInt() shl 16) or (g.toInt() shl 8) or b.toInt()
            }
            bitmap.setPixels(pixels, 0, outputWidth, 0, 0, outputWidth, outputHeight)
            return bitmap
        }
    
        // save bitmap image to gallery
        private fun saveToGallery(context: Context, bitmap: Bitmap, albumName: String) {
            //val filename = "${System.currentTimeMillis()}.png"
            val write: (OutputStream) -> Boolean = {
                bitmap.compress(Bitmap.CompressFormat.PNG, 100, it)
            }
    
            if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
                val contentValues = ContentValues().apply {
                    put(MediaStore.MediaColumns.DISPLAY_NAME, filename)
                    put(MediaStore.MediaColumns.MIME_TYPE, "image/png")
                    put(MediaStore.MediaColumns.RELATIVE_PATH, "${Environment.DIRECTORY_DCIM}/$albumName")
                }
    
                context.contentResolver.let {
                    it.insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, contentValues)?.let { uri ->
                        it.openOutputStream(uri)?.let(write)
                    }
                }
            } else {
                val imagesDir = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DCIM).toString() + File.separator + albumName
                val file = File(imagesDir)
                if (!file.exists()) {
                    file.mkdir()
                }
                val image = File(imagesDir, filename)
                write(FileOutputStream(image))
            }
        }
    
        // get the filename from an image uri
        private fun Uri.getOriginalFileName(context: Context): String? {
            return context.contentResolver.query(this,
                null,
                null,
                null,
                null)?.use {
                val nameColumnIndex = it.getColumnIndex(OpenableColumns.DISPLAY_NAME)
                it.moveToFirst()
                it.getString(nameColumnIndex)
            }
        }
        fun convertArrayToBitmap(
            imageArray: Array<Array<Array<FloatArray>>>,
            imageWidth: Int,
            imageHeight: Int
        ): Bitmap {
    
            val conf = Bitmap.Config.ARGB_8888 // see other conf types
            val bitmap = Bitmap.createBitmap(imageWidth, imageHeight, conf)
    
            for (x in imageArray[0].indices) {
                for (y in imageArray[0][0].indices) {
                    // Create bitmap to show on screen after inference
                    val color = Color.rgb(
                        (imageArray[0][x][y][0]).toInt(),
                        (imageArray[0][x][y][1]).toInt(),
                        (imageArray[0][x][y][2]).toInt()
                    )
    
                    // this y, x is in the correct order!!!
                    bitmap.setPixel(y, x, color)
                }
            }
            return bitmap
        }
    
    }
    

    Take a look at how we resize the model's inputs inside Android, how we create the input buffer and the output array, and how we convert the produced array to a Bitmap. For these procedures, check whether you can use the phone's GPU to get roughly a 3x speedup (see the sketch below), and of course there is plenty more to read in the official documentation.
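    As a rough sketch of the GPU point (this assumes you add the org.tensorflow:tensorflow-lite-gpu dependency and import org.tensorflow.lite.gpu.CompatibilityList and org.tensorflow.lite.gpu.GpuDelegate; it is not part of the code above):

        // Prefer the GPU delegate when the device supports it,
        // otherwise fall back to multi-threaded CPU execution.
        val compatList = CompatibilityList()
        val options = Interpreter.Options().apply {
            if (compatList.isDelegateSupportedOnThisDevice) {
                addDelegate(GpuDelegate(compatList.bestOptionsForThisDevice))
            } else {
                setNumThreads(4)
            }
        }
        val tflite = Interpreter(loadModelFile(), options)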