I'm trying to understand TensorFlow, and I want to create a Flutter app that uses image classification to tell whether a picture shows a cat or a dog.
I used https://teachablemachine.withgoogle.com/ to train my model for 100 epochs with a batch size of 128. The resulting model reaches 97% accuracy for both cats and dogs. I used a dataset from Kaggle with 4,000 images each of cats and dogs.
My Code:
import 'dart:io';
import 'package:tflite/tflite.dart';
import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';
/// Home screen of the cat-vs-dog classifier.
///
/// The widget itself has no configuration; all behavior lives in its [State].
class MyHomePage extends StatefulWidget {
  /// Creates the home screen; [key] is forwarded to [StatefulWidget].
  const MyHomePage({Key? key}) : super(key: key);

  // Return the public `State<MyHomePage>` type rather than exposing the
  // library-private `_MyHomePageState` in this widget's signature.
  @override
  State<MyHomePage> createState() => _MyHomePageState();
}
/// State for [MyHomePage].
///
/// Loads the TFLite model once (started from [initState]), lets the user take
/// or pick a photo, runs the model on it, and shows the top predicted label.
class _MyHomePageState extends State<MyHomePage> {
  /// The most recently picked image, or `null` before the first pick.
  File? _image;

  /// Whether the model is still loading or an image is being classified.
  ///
  /// Starts `true` because the model is loaded asynchronously; the pick
  /// buttons are disabled while this is set.
  bool _loading = true;

  /// Results of the last classification, or `null` when none should be shown.
  ///
  /// NOTE(review): per the tflite plugin docs each entry is a map with
  /// `index`, `label` and `confidence` keys; only `label` is read here.
  List<dynamic>? _output;

  final _picker = ImagePicker();

  /// Takes a photo with the camera and classifies it.
  Future<void> pickImage() => _pickAndClassify(ImageSource.camera);

  /// Picks a photo from the gallery and classifies it.
  Future<void> pickGalleryImage() => _pickAndClassify(ImageSource.gallery);

  /// Shared implementation of [pickImage] and [pickGalleryImage].
  Future<void> _pickAndClassify(ImageSource source) async {
    // NOTE(review): `getImage` is deprecated in newer image_picker releases
    // in favour of `pickImage`; migrate when the dependency is upgraded.
    final picked = await _picker.getImage(source: source);
    // Null means the user cancelled. The widget may also have been disposed
    // while the picker was open, and setState would then throw.
    if (picked == null || !mounted) return;

    final file = File(picked.path);
    setState(() {
      _image = file;
      // Clear the previous prediction so it is never shown under the new image.
      _output = null;
    });
    await classifyImage(file);
  }

  @override
  void initState() {
    super.initState();
    _initModel();
  }

  /// Loads the model, then enables the pick buttons.
  ///
  /// If loading throws, [_loading] stays `true` (so the buttons stay
  /// disabled) and the error surfaces as an uncaught async error.
  Future<void> _initModel() async {
    await loadModel();
    if (mounted) setState(() => _loading = false);
  }

  @override
  void dispose() {
    Tflite.close();
    super.dispose();
  }

  /// Runs the model on [image] and stores the results in [_output].
  ///
  /// Does nothing if [image] is `null` or this state is no longer mounted.
  /// NOTE(review): per the tflite plugin docs, results below `threshold` are
  /// dropped, so the stored list may be empty.
  Future<void> classifyImage(File? image) async {
    if (image == null || !mounted) return;
    setState(() => _loading = true);

    List<dynamic>? output;
    try {
      output = await Tflite.runModelOnImage(
        path: image.path,
        numResults: 2,
        threshold: 0.5,
        imageMean: 127.5,
        imageStd: 127.5,
      );
    } finally {
      // Always clear the loading flag, even if inference throws, so the
      // buttons are re-enabled; on failure [_output] is left null.
      if (mounted) {
        setState(() {
          _loading = false;
          _output = output;
        });
      }
    }
  }

  /// Loads the bundled model and its label file into the tflite plugin.
  Future<void> loadModel() async {
    await Tflite.loadModel(
      model: 'assets/model_unquant.tflite',
      labels: 'assets/labels.txt',
    );
  }

  @override
  Widget build(BuildContext context) {
    // Copy nullable fields to locals so they promote without `!`.
    final image = _image;
    final output = _output;

    return Scaffold(
      appBar: AppBar(
        title: const Text('Cat vs Dog Classifier'),
      ),
      body: Center(
        child: Column(
          children: [
            const SizedBox(height: 160.0),
            if (image == null)
              const Text('No image selected')
            else
              SizedBox(
                height: 250.0, // Fixed height for image
                child: Image.file(image),
              ),
            const SizedBox(height: 20.0),
            if (_loading)
              const CircularProgressIndicator()
            else if (output != null)
              // The result list can be empty; indexing [0] unconditionally
              // would throw a RangeError.
              Text(output.isEmpty
                  ? 'Could not classify this image'
                  : '${output[0]['label']}'),
            const SizedBox(height: 50.0),
            ElevatedButton(
              // A null callback disables the button until the model is ready
              // and while an image is being classified.
              onPressed: _loading ? null : pickImage,
              child: const Text('Take Picture'),
            ),
            ElevatedButton(
              onPressed: _loading ? null : pickGalleryImage,
              child: const Text('Camera Roll'),
            ),
          ],
        ),
      ),
    );
  }
}
If I pick an image that isn't a cat or a dog, I still get a 100% "cat" or "dog" result most of the time. How can I avoid showing these wrong results? What can I actually do?
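You have to train your original model on three classes: "cat", "dog", and a third "other" class containing a wide variety of images that are neither cats nor dogs. A two-class model's output probabilities always sum to 100%, so every input must be labelled as either cat or dog, even a photo of a car. The extra class gives the model somewhere to put such images.
Then convert it to a TFLite model and use it in your Flutter project.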