I am trying to load a tflite model and run it on an image.
My tflite model has the dimensions you see in the image.
Right now, I am receiving:
Cannot copy to a TensorFlowLite tensor (input_1) with 49152 bytes from a Java Buffer with 175584 bytes.
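Both numbers in this error are byte counts. Assuming FLOAT32 tensors (4 bytes per element), 49152 = 1 * 64 * 64 * 3 * 4, which is what `input_1` expects. 175584 = 1 * 118 * 124 * 3 * 4, which is exactly a 124 * 118 RGB image that was never resized to 64 * 64. The helper below only reproduces that arithmetic; `tensorByteSize` is a hypothetical name used for illustration, not a TensorFlow Lite API.

```kotlin
// Bytes occupied by a tensor: the product of its dimensions times the
// element size (assumed 4 bytes for FLOAT32, 1 byte for UINT8).
fun tensorByteSize(shape: IntArray, bytesPerElement: Int): Int =
    shape.fold(bytesPerElement) { acc, dim -> acc * dim }
```

For example, `tensorByteSize(intArrayOf(1, 64, 64, 3), 4)` returns 49152 and `tensorByteSize(intArrayOf(1, 118, 124, 3), 4)` returns 175584.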
I don't understand how to work with the input and output tensor sizes. Right now, I initialize the input from the input image size, and the output image size is the input size * 4.
At which point do I have to "add" the 1 * 64 * 64 * 3 dimensions, given that I need to handle input images of any size?
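The leading 1 is the batch dimension and adds no data: a 1 * 64 * 64 * 3 tensor holds the same 64 * 64 * 3 values as a single image, stored row by row with the three channels of each pixel next to each other (NHWC). What has to match is the height and width, either by resizing the image to 64 * 64 or by resizing the tensor. The sketch below illustrates that layout by packing ARGB pixels (in the row-major form `Bitmap.getPixels` fills them) into a flat FloatArray. `packNhwc` is a hypothetical helper, and it keeps the raw 0..255 values without any normalization.

```kotlin
// Packs row-major ARGB pixels into a flat FloatArray laid out as NHWC with
// batch size 1: the element for (row y, column x, channel c) sits at
// index (y * width + x) * 3 + c. Alpha is dropped.
fun packNhwc(pixels: IntArray, width: Int, height: Int): FloatArray {
    require(pixels.size == width * height) {
        "expected ${width * height} pixels, got ${pixels.size}"
    }
    val out = FloatArray(width * height * 3)
    for (y in 0 until height) {
        for (x in 0 until width) {
            val p = pixels[y * width + x]
            val base = (y * width + x) * 3
            out[base] = ((p shr 16) and 0xFF).toFloat()     // R
            out[base + 1] = ((p shr 8) and 0xFF).toFloat()  // G
            out[base + 2] = (p and 0xFF).toFloat()          // B
        }
    }
    return out
}
```

A 64 * 64 image packed this way yields 12288 floats, i.e. 49152 bytes, matching the size in the error message.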
try {
tflitemodel = loadModelFile()
tflite = Interpreter(tflitemodel, options)
} catch (e: IOException) {
Log.e(TAG, "Fail to load model", e)
}
val imageTensorIndex = 0
val imageShape: IntArray =
tflite.getInputTensor(imageTensorIndex).shape()
val imageDataType: DataType = tflite.getInputTensor(imageTensorIndex).dataType()
// Build a TensorImage object
val inputImageBuffer = TensorImage(imageDataType)
// Load the Bitmap
inputImageBuffer.load(bitmap)
// Preprocess image
val imgprocessor = ImageProcessor.Builder()
.add(ResizeOp(inputImageBuffer.height,
inputImageBuffer.width,
ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
//.add(NormalizeOp(127.5f, 127.5f))
//.add(QuantizeOp(128.0f, 1 / 128.0f))
.build()
// Process the image
val processedImage = imgprocessor.process(inputImageBuffer)
// Access the buffer ( byte[] ) of the processedImage
val imageBuffer = processedImage.buffer
val imageTensorBuffer = processedImage.tensorBuffer
// output result
val outputImageBuffer = TensorBuffer.createFixedSize(
intArrayOf( inputImageBuffer.height * 4 ,
inputImageBuffer.width * 4 ) ,
DataType.FLOAT32 )
// Normalize image
val tensorProcessor = TensorProcessor.Builder()
// Normalize the tensor given the mean and the standard deviation
.add( NormalizeOp( 127.5f, 127.5f ) )
.add( CastOp( DataType.FLOAT32 ) )
.build()
val processedOutputTensor = tensorProcessor.process(outputImageBuffer)
tflite.run(imageTensorBuffer.buffer, processedOutputTensor.buffer)
I tried casting the output tensor to both FLOAT32 and UINT8.
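For reference, the TensorFlow Lite Support `NormalizeOp(mean, stddev)` computes `(x - mean) / stddev` per element, so `NormalizeOp(127.5f, 127.5f)` maps 0..255 to -1..1 and `NormalizeOp(0.0f, 255.0f)` maps it to 0..1. Normalization only affects data that already exists: applying it to a freshly created output buffer just transforms zeros, and `run` then overwrites that buffer. Which input range this model expects depends on how it was trained, which the post does not say; if it expects normalized input, the output would likely need the inverse mapping. A plain-Kotlin sketch of both directions, with hypothetical function names:

```kotlin
// Same per-element formula as NormalizeOp: (x - mean) / std.
fun normalize(values: FloatArray, mean: Float, std: Float): FloatArray =
    FloatArray(values.size) { i -> (values[i] - mean) / std }

// Inverse mapping, e.g. to bring a model's output back to 0..255.
fun denormalize(values: FloatArray, mean: Float, std: Float): FloatArray =
    FloatArray(values.size) { i -> values[i] * std + mean }
```

`normalize(floatArrayOf(0f, 127.5f, 255f), 127.5f, 127.5f)` gives `[-1.0, 0.0, 1.0]`, and `denormalize` with the same parameters maps it back.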
UPDATE
I also tried this:
try {
tflitemodel = loadModelFile()
tflite = Interpreter(tflitemodel, options)
} catch (e: IOException) {
Log.e(TAG, "Fail to load model", e)
}
val imageTensorIndex = 0
val imageDataType: DataType = tflite.getInputTensor(imageTensorIndex).dataType()
val imgprocessor = ImageProcessor.Builder()
.add(ResizeOp(64,
64,
ResizeOp.ResizeMethod.NEAREST_NEIGHBOR)
)
.add( NormalizeOp( 0.0f, 255.0f ) )
.add( CastOp( DataType.FLOAT32 ) )
.build()
val inpIm = TensorImage(imageDataType)
inpIm.load(bitmap)
val processedImage = imgprocessor.process(inpIm)
val output = TensorBuffer.createFixedSize(
intArrayOf(
124 * 4,
118 * 4,
3,
1
),
DataType.FLOAT32
)
val tensorProcessor = TensorProcessor.Builder()
.add( NormalizeOp( 0.0f, 255.0f ) )
.add( CastOp( DataType.FLOAT32 ) )
.build()
val processedOutputTensor = tensorProcessor.process(output)
tflite.run(processedImage.buffer, processedOutputTensor.buffer)
which produces:
Note that the input image I am currently using has dimensions 124 * 118 * 3.
The output image will have dimensions (124 * 4) * (118 * 4) * 3.
The model's input layer expects 64 * 64 * 3.
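Since the input image size varies, the shapes involved follow directly from it, as long as the model's spatial dimensions can be resized (the answer relies on `Interpreter.resizeInput` for this). TensorFlow Lite image tensors are NHWC, so height comes before width, and a 4x model's output is 4 times larger in both. A sketch under those assumptions, with a hypothetical `srShapes` helper:

```kotlin
// Input/output shapes for a super-resolution model with resizable spatial
// dimensions. Layout is NHWC: [batch, height, width, channels].
class SrShapes(val input: IntArray, val output: IntArray)

fun srShapes(width: Int, height: Int, scale: Int = 4, channels: Int = 3): SrShapes =
    SrShapes(
        input = intArrayOf(1, height, width, channels),
        output = intArrayOf(1, height * scale, width * scale, channels)
    )
```

For the 124 * 118 image this gives an input of `[1, 118, 124, 3]` and an output of `[1, 472, 496, 3]`. A model fixed at 64 * 64 would instead need the image resized to 64 * 64 and would produce `[1, 256, 256, 3]`.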
I took a look at your project. Your class should look something like this:
class MainActivity : AppCompatActivity() {
private val TAG = "SuperResolution"
private val MODEL_NAME = "model_edsr.tflite"
private val LR_IMAGE_HEIGHT = 24
private val LR_IMAGE_WIDTH = 24
private val UPSCALE_FACTOR = 4
private val SR_IMAGE_HEIGHT = LR_IMAGE_HEIGHT * UPSCALE_FACTOR
private val SR_IMAGE_WIDTH = LR_IMAGE_WIDTH * UPSCALE_FACTOR
private lateinit var photoButton: Button
private lateinit var srButton: Button
private lateinit var colorizeButton: Button
private var FILE_NAME = "photo.jpg"
private lateinit var filename:String
private var resultImg: Bitmap? = null
private lateinit var gpuSwitch: Switch
private lateinit var tflite: Interpreter
private lateinit var tflitemodel: ByteBuffer
private val INPUT_SIZE: Int = 96
private val PIXEL_SIZE: Int = 3
private val IMAGE_MEAN = 0
private val IMAGE_STD = 255.0f
private var bitmap: Bitmap? = null
private var bitmapResult: Bitmap? = null
/** A ByteBuffer to hold image data, to be fed into TensorFlow Lite as input/output */
private lateinit var imgDataInput: ByteBuffer
private lateinit var imgDataOutput: ByteBuffer
/** Dimensions of inputs. */
private val DIM_BATCH_SIZE = 1
private val DIM_PIXEL_SIZE = 3
private val DIM_IMG_SIZE_X = 64
private val DIM_IMG_SIZE_Y = 64
private lateinit var catBitmap: Bitmap
/* Preallocated buffers for storing image data in. */
private val intValues = IntArray(DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y)
private lateinit var superImage: ImageView
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
superImage = findViewById(R.id.super_resolution_image)
//val assetManager = assets
catBitmap = getBitmapFromAsset("cat.png")
srButton = findViewById(R.id.super_resolution)
srButton.setOnClickListener { view: View ->
val intent = Intent(this, SelectedImage::class.java)
getImageResult.launch(intent)
}
}
private fun getBitmapFromAsset(filePath: String?): Bitmap {
val assetManager = assets
val istr: InputStream
var bitmap: Bitmap? = null
try {
istr = assetManager.open(filePath!!)
bitmap = BitmapFactory.decodeStream(istr)
} catch (e: IOException) {
// handle exception
Log.e("Bitmap_except", e.toString())
}
if (bitmap != null) {
bitmap = Bitmap.createScaledBitmap(bitmap,64,64,true)
}
return bitmap?: Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
}
private val getImageResult =
registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { result ->
if (result.resultCode == Activity.RESULT_OK) {
val theImageUri: Uri? = result.data?.getParcelableExtra<Uri>("imageuri")
filename = "SR_" + theImageUri?.getOriginalFileName(this).toString()
bitmap = uriToBitmap(theImageUri!!)!!//catBitmap//
Log.v("width", bitmap!!.width.toString())
if (bitmap != null) {
// call DL
val options = Interpreter.Options()
options.setNumThreads(5)
options.setUseNNAPI(true)
try {
tflitemodel = loadModelFile()
tflite = Interpreter(tflitemodel, options)
val index = tflite.getInputIndex("input_1")
// Image tensors are NHWC: [batch, height, width, channels]
tflite.resizeInput(
index,
intArrayOf(1, bitmap!!.height, bitmap!!.width, 3)
)
} catch (e: IOException) {
Log.e(TAG, "Fail to load model", e)
}
val imgprocessor = ImageProcessor.Builder()
.add(
// ResizeOp takes (targetHeight, targetWidth, method)
ResizeOp(bitmap!!.height,
bitmap!!.width,
ResizeOp.ResizeMethod.NEAREST_NEIGHBOR)
)
.add( CastOp( DataType.FLOAT32 ) )
.build()
val inpIm = TensorImage(DataType.FLOAT32)
inpIm.load(bitmap)
// Process the image
val processedImage = imgprocessor.process(inpIm)
// The 4x model outputs [1, 4 * height, 4 * width, 3]
val output2 = Array(1) { Array(4*bitmap!!.height) { Array(4*bitmap!!.width) { FloatArray(3) } } }
tflite.run(processedImage.buffer, output2)
bitmapResult = convertArrayToBitmap(output2, 4*bitmap!!.width, 4*bitmap!!.height)
Log.v("widthHR", bitmapResult!!.height.toString())
superImage.setImageBitmap(bitmapResult)
}
}
}
@Throws(IOException::class)
private fun loadModelFile(): MappedByteBuffer {
val fileDescriptor = assets.openFd(MODEL_NAME)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}
private fun uriToBitmap(selectedFileUri: Uri): Bitmap? {
try {
val parcelFileDescriptor = contentResolver.openFileDescriptor(selectedFileUri, "r")
val fileDescriptor: FileDescriptor = parcelFileDescriptor!!.fileDescriptor
val image = BitmapFactory.decodeFileDescriptor(fileDescriptor)
parcelFileDescriptor.close()
return image
} catch (e: IOException) {
e.printStackTrace()
}
return null
}
private fun getOutputImage(output: ByteBuffer): Bitmap? {
output.rewind()
val outputWidth = 124 * 4
val outputHeight = 118 * 4
val bitmap = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888)
val pixels = IntArray(outputWidth * outputHeight)
for (i in 0 until outputWidth * outputHeight) {
val a = 0xFF
val r = output.float * 255.0f
val g = output.float * 255.0f
val b = output.float * 255.0f
pixels[i] = a shl 24 or (r.toInt() shl 16) or (g.toInt() shl 8) or b.toInt()
}
bitmap.setPixels(pixels, 0, outputWidth, 0, 0, outputWidth, outputHeight)
return bitmap
}
// save bitmap image to gallery
private fun saveToGallery(context: Context, bitmap: Bitmap, albumName: String) {
//val filename = "${System.currentTimeMillis()}.png"
val write: (OutputStream) -> Boolean = {
bitmap.compress(Bitmap.CompressFormat.PNG, 100, it)
}
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
val contentValues = ContentValues().apply {
put(MediaStore.MediaColumns.DISPLAY_NAME, filename)
put(MediaStore.MediaColumns.MIME_TYPE, "image/png")
put(MediaStore.MediaColumns.RELATIVE_PATH, "${Environment.DIRECTORY_DCIM}/$albumName")
}
context.contentResolver.let {
it.insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, contentValues)?.let { uri ->
it.openOutputStream(uri)?.use(write)
}
}
} else {
val imagesDir = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DCIM).toString() + File.separator + albumName
val file = File(imagesDir)
if (!file.exists()) {
file.mkdir()
}
val image = File(imagesDir, filename)
FileOutputStream(image).use(write)
}
}
// get the filename from an image uri
private fun Uri.getOriginalFileName(context: Context): String? {
return context.contentResolver.query(this,
null,
null,
null,
null)?.use {
val nameColumnIndex = it.getColumnIndex(OpenableColumns.DISPLAY_NAME)
it.moveToFirst()
it.getString(nameColumnIndex)
}
}
fun convertArrayToBitmap(
imageArray: Array<Array<Array<FloatArray>>>,
imageWidth: Int,
imageHeight: Int
): Bitmap {
val conf = Bitmap.Config.ARGB_8888 // see other conf types
val bitmap = Bitmap.createBitmap(imageWidth, imageHeight, conf)
for (x in imageArray[0].indices) {
for (y in imageArray[0][0].indices) {
// Clamp each channel: Color.rgb expects values in 0..255
val color = Color.rgb(
imageArray[0][x][y][0].toInt().coerceIn(0, 255),
imageArray[0][x][y][1].toInt().coerceIn(0, 255),
imageArray[0][x][y][2].toInt().coerceIn(0, 255)
)
// x is the row and y the column; setPixel takes (column, row)
bitmap.setPixel(y, x, color)
}
}
return bitmap
}
}
Take a look at how the model's input is resized inside Android, how the input buffer and output array are created, and how the produced array is converted to a Bitmap. For these steps, check whether you can use the phone's GPU to get roughly a 3x speedup; the official documentation covers all of this in much more detail.
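Calling `setPixel` once per pixel tends to be slow for large outputs. A sketch of an alternative, assuming the output is read as a flat FloatArray (for example via `TensorBuffer.getFloatArray()`) in NHWC order with values roughly in 0..255, as `convertArrayToBitmap` assumes: build all ARGB ints in one pass and hand them to `Bitmap.setPixels`, which `getOutputImage` above already uses. `nhwcToArgb` is a hypothetical name.

```kotlin
import kotlin.math.roundToInt

// Converts a flattened 1 x H x W x 3 float output (NHWC, row-major) into
// ARGB_8888 pixel ints. Values outside 0..255 (and NaN) are clamped so they
// cannot spill into neighboring channels of the packed int.
fun nhwcToArgb(output: FloatArray, width: Int, height: Int): IntArray {
    require(output.size == width * height * 3) {
        "expected ${width * height * 3} floats, got ${output.size}"
    }
    fun channel(v: Float): Int = if (v.isNaN()) 0 else v.roundToInt().coerceIn(0, 255)
    return IntArray(width * height) { i ->
        val r = channel(output[i * 3])
        val g = channel(output[i * 3 + 1])
        val b = channel(output[i * 3 + 2])
        (0xFF shl 24) or (r shl 16) or (g shl 8) or b
    }
}
```

On Android the result would then go into an `ARGB_8888` bitmap of the same size with `bitmap.setPixels(pixels, 0, width, 0, 0, width, height)`.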