Tags: flutter, tensorflow, tensorflow2.0, tensorflow-lite

Flutter tflite image classification how to not show wrong results


I'm trying to understand TensorFlow, and I wanted to create a Flutter app that does image classification for cats and dogs.

I used https://teachablemachine.withgoogle.com/ to train my model with 100 epochs and a batch size of 128. The model reports 97% accuracy for both the cat and dog classes. I used a dataset from Kaggle with 4000 images each of cats and dogs.

My code:

import 'dart:io';
import 'package:tflite/tflite.dart';
import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';

class MyHomePage extends StatefulWidget {
  @override
  _MyHomePageState createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  File? _image;
  bool _loading = false;
  List<dynamic>? _output;
  final _picker = ImagePicker();

  pickImage() async {
    var image = await _picker.getImage(source: ImageSource.camera);

    if (image == null) {
      return null;
    }

    setState(() {
      _image = File(image.path);
    });
    classifyImage(_image);
  }

  pickGalleryImage() async {
    var image = await _picker.getImage(source: ImageSource.gallery);

    if (image == null) {
      return null;
    }

    setState(() {
      _image = File(image.path);
    });
    classifyImage(_image);
  }

  @override
  void initState() {
    super.initState();
    _loading = true;
    loadModel().then((value) {
      // setState(() {});
    });
  }

  @override
  void dispose() {
    Tflite.close();
    super.dispose();
  }

  classifyImage(File? image) async {
    var output = await Tflite.runModelOnImage(
      path: image!.path,
      numResults: 2,
      threshold: 0.5,
      imageMean: 127.5,
      imageStd: 127.5,
    );
    setState(() {
      _loading = false;
      _output = output;
    });
  }

  loadModel() async {
    await Tflite.loadModel(
      model: 'assets/model_unquant.tflite',
      labels: 'assets/labels.txt',
    );
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('Cat vs Dog Classifier'),
      ),
      body: Center(
        child: Column(
          children: [
            SizedBox(height: 160.0),
            _image == null
                ? Text('No image selected')
                : Container(
                    child: Image.file(_image!),
                    height: 250.0, // Fixed height for image
                  ),
            SizedBox(height: 20.0),
            // Guard against an empty result list as well as a missing one.
            _output != null && _output!.isNotEmpty
                ? Text('${_output![0]['label']}')
                : Container(),
            SizedBox(height: 50.0),
            ElevatedButton(
              onPressed: pickImage,
              child: Text('Take Picture'),
            ),
            ElevatedButton(
              onPressed: pickGalleryImage,
              child: Text('Camera Roll'),
            ),
          ],
        ),
      ),
    );
  }
}

My question:

If I pick an image that isn't a cat or a dog, I still get a 100% cat or dog result most of the time. How can I avoid showing these wrong results? What can be done about this?


Solution

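  • A classifier trained only on cat and dog has no way to answer "neither", so every image gets assigned to one of the two. To handle unrelated images, train your original model on 3 classes:

    • cat
    • dog
    • other

    Then convert it to a tflite model and use it in your Flutter project.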
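
Once the model has a third class, the app still has to decide what to display. The following is a minimal sketch of that step, not part of the original answer. It assumes the result list has the shape returned by Tflite.runModelOnImage (maps with 'label' and 'confidence' keys), that the third class is named "other" in labels.txt, and that labels may carry an index prefix such as "2 other". The displayLabel helper, the 0.8 cutoff, and the fallback text are all hypothetical choices for illustration.

```dart
/// Returns the text to show for a classifier result, or [fallback] when the
/// image should not be reported as a cat or a dog.
///
/// Assumption: [output] looks like the list returned by
/// Tflite.runModelOnImage, i.e. maps with 'label' and 'confidence' keys.
String displayLabel(
  List<dynamic>? output, {
  String otherLabel = 'other', // assumed name of the third class
  double minConfidence = 0.8, // illustrative cutoff, tune on real images
  String fallback = 'Not a cat or dog',
}) {
  if (output == null || output.isEmpty) return fallback;

  // Pick the highest-confidence entry rather than relying on list order.
  var best = output.first as Map;
  for (final entry in output) {
    final candidate = entry as Map;
    if ((candidate['confidence'] as num) > (best['confidence'] as num)) {
      best = candidate;
    }
  }

  final label = best['label'].toString();
  final confidence = (best['confidence'] as num).toDouble();

  // Labels may look like "2 other", so match on the suffix (assumption).
  if (label.toLowerCase().endsWith(otherLabel.toLowerCase())) return fallback;
  // Also reject weak predictions for cat or dog.
  if (confidence < minConfidence) return fallback;
  return label;
}
```

In the widget from the question, Text(displayLabel(_output)) could then replace the direct lookup of _output![0]['label']. The confidence cutoff on its own is unlikely to be enough with a two-class model, since the question reports near-100% scores on unrelated images, which is why the sketch combines it with the "other" class.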