How can I use a self-made TensorFlow CNN model in a Gradio app?


I have a self-made TensorFlow CNN that I trained for handwritten digit recognition. The code is as follows:

import tensorflow as tf
from keras.datasets import mnist

(train_data, train_labels), (test_data, test_labels) = mnist.load_data()

model_2 = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

model_2.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model
model_2.fit(train_data, train_labels, epochs=15)

# Evaluate the model
model_2.evaluate(test_data, test_labels)

# Save model 2
model_2.save("mnist_model.h5")

I would like to use it in a Gradio app, hosted locally only.

I tried creating a simple Gradio interface to run the prediction.

import tensorflow as tf
import gradio as gr

# Load the model
model = tf.keras.models.load_model("mnist_model.h5")

app = gr.Interface(model, inputs=gr.Image(shape=(28, 28, 1)), outputs=gr.outputs.Label(num_top_classes=1))
app.launch()

When I upload an image and press the predict button, I get an error:

Traceback (most recent call last):
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\gradio\routes.py", line 488, in run_predict
    output = await app.get_blocks().process_api(
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\gradio\blocks.py", line 1428, in process_api        
    inputs = self.preprocess_data(fn_index, inputs, state)
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\gradio\blocks.py", line 1245, in preprocess_data    
    processed_input.append(block.preprocess(inputs[i]))
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\gradio\components\image.py", line 279, in preprocess
    im = processing_utils.resize_and_crop(im, self.shape)
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\gradio\processing_utils.py", line 135, in resize_and_crop
    return ImageOps.fit(img, resize, centering=center)  # type: ignore
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\PIL\ImageOps.py", line 495, in fit
    return image.resize(size, method, box=crop)
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\PIL\Image.py", line 2082, in resize
    return self._new(self.im.resize(size, resample, box))
TypeError: argument 1 must be sequence of length 2, not 3

I also tried changing the shape to (28, 28), but then I got this error:

Traceback (most recent call last):
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\gradio\routes.py", line 488, in run_predict
    output = await app.get_blocks().process_api(
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\gradio\blocks.py", line 1431, in process_api
    result = await self.call_function(
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\gradio\blocks.py", line 1109, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\anyio\to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\anyio\_backends\_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\anyio\_backends\_asyncio.py", line 807, in run
    result = context.run(func, *args)
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\gradio\utils.py", line 706, in wrapper
    response = f(*args, **kwargs)
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\matof\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\input_spec.py", line 295, in assert_input_compatibility
    raise ValueError(
ValueError: Input 0 of layer "sequential" is incompatible with the layer: expected shape=(None, 28, 28, 1), found shape=(28, 28, 3)

How could I fix this?


Solution

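  • After a lot of research and trial and error, I found the solution to my problem. Simply converting the image to grayscale was not enough. The model expects a float batch of shape (1, 28, 28, 1), so the image also has to be resized, cast to float32, given a batch dimension, and normalized. I am posting the solution here since it may still be useful to others who want to try this.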

    import tensorflow as tf
    import gradio as gr
    
    # Load the model
    model = tf.keras.models.load_model("mnist_model_1.h5")
    
    def predict_number(image):
        """Predicts the number from the image."""
        # Convert the image to grayscale: (H, W, 3) -> (H, W, 1).
        image_gray = tf.image.rgb_to_grayscale(image)

        # Convert the image to a tensor.
        image_tensor = tf.convert_to_tensor(image_gray)

        # Resize the image to 28x28.
        image_tensor = tf.image.resize(image_tensor, (28, 28))

        # Cast the data to float32.
        image_tensor = tf.cast(image_tensor, tf.float32)

        # Add a batch dimension: (28, 28, 1) -> (1, 28, 28, 1).
        image_tensor = tf.expand_dims(image_tensor, 0)

        # Normalize the data.
        image_tensor = image_tensor / 255.0

        # Get the prediction.
        prediction = model.predict(image_tensor)

        # Convert the prediction to a string label.
        prediction_label = str(prediction.argmax())

        # Return the prediction label.
        return prediction_label
    
    
    app = gr.Interface(predict_number, 
                     inputs=gr.Image(shape=(28, 28)), 
                     outputs=gr.Label(num_classes=10),
                     examples=[
                        '1_1.png', '1_2.png', '2_1.png', '2_2.png', '3_1.png',
                        '3_2.png', '4_1.png', '4_2.png', '5_1.png', '5_2.png',
                        '6_1.png', '6_2.png', '7_1.png', '7_2.png', '8_1.png',
                        '8_2.png', '9_1.png', '9_2.png', '0_1.png', '0_2.png']
                     )
    app.launch()
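
The shape changes made in predict_number can be checked without loading the model. What follows is a minimal NumPy-only sketch of the same steps, assuming Gradio hands the function an RGB array of shape (28, 28, 3) with values 0-255, as the second traceback in the question suggests. The helper name `to_model_input` and the luminance weights are illustrative assumptions, not part of the original answer, and the resize step is left out because the input is assumed to be 28x28 already.

```python
import numpy as np


def to_model_input(rgb_image: np.ndarray) -> np.ndarray:
    """Turn an (H, W, 3) uint8 image into a (1, H, W, 1) float32 batch."""
    # Weighted sum over the channel axis. These are common luminance
    # weights, used here as an assumption for illustration.
    weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)
    gray = rgb_image.astype(np.float32) @ weights  # shape (H, W)

    gray = gray[..., np.newaxis]   # add channel axis: (H, W, 1)
    batch = gray[np.newaxis, ...]  # add batch axis: (1, H, W, 1)
    return batch / 255.0           # scale values to roughly [0, 1]


# A dummy all-white image standing in for what Gradio would pass in.
dummy = np.full((28, 28, 3), 255, dtype=np.uint8)
batch = to_model_input(dummy)
print(batch.shape, batch.dtype)
```

If the printed shape is not (1, 28, 28, 1), the model would likely reject the input with the same kind of ValueError shown in the question.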
    

    I have since extended this further for my own use case, but the answer here should still be helpful!
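
One detail worth checking when adapting this: the training script in the question passes the raw MNIST pixel values (0-255) to fit, while predict_number divides by 255 before predicting. A model usually behaves best when training and inference use the same scaling. Below is a minimal sketch of a shared helper, assuming you are free to retrain; `scale_pixels` is a hypothetical name, not something from the original code.

```python
import numpy as np


def scale_pixels(images: np.ndarray) -> np.ndarray:
    """Map uint8 pixel values 0-255 to float32 values in [0, 1]."""
    return images.astype(np.float32) / 255.0


# Random stand-in for a small batch of MNIST-shaped training images.
fake_train = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)
scaled = scale_pixels(fake_train)
print(scaled.min() >= 0.0, scaled.max() <= 1.0)
```

In the training script this could be applied as `model_2.fit(scale_pixels(train_data), train_labels, epochs=15)`, and likewise before `evaluate`. Alternatively, if the existing model is kept as trained, removing the division in predict_number may give a closer match to what the model saw during training.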