I'm writing a very simple network:
import tensorflow as tf
import numpy as np
# Six training samples, each with three features (X, Y, Z).
training_data = np.array([[1, 1, 1], [2, 3, 1], [0, -1, 4], [0, 3, 0], [10, -6, 8], [-3, -12, 4]])
# Target output for each training sample.
training_targets = np.array([6, 11, 1, 9, 10, -38])
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=1, activation=tf.keras.activations.relu, input_shape=(3,)))
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001), loss=tf.keras.losses.mean_squared_error, metrics=[tf.keras.metrics.mean_squared_error])
model.summary()
model.fit(training_data, training_targets, epochs=1, verbose=0)
print("Training completed.")
model.predict(np.array([1, 1, 1]))  # raises the ValueError below
The goal is to learn the weights a, b, c such that aX + bY + cZ = output.
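As a sanity check on what the network is supposed to learn, the same linear relationship can be fitted in closed form with NumPy least squares. This is not what the Keras code does internally; it is a sketch that adds a bias column d, assuming you want to mirror the bias a Keras Dense layer includes by default. One thing worth flagging: the question's layer uses activation=relu, which clips negative outputs to zero, so it can never reproduce the -38 target. The linear form aX + bY + cZ corresponds to a Dense layer with no activation.

```python
import numpy as np

# Inputs (X, Y, Z) per row and the target outputs from the question.
features = np.array([[1, 1, 1], [2, 3, 1], [0, -1, 4], [0, 3, 0], [10, -6, 8], [-3, -12, 4]], dtype=float)
targets = np.array([6, 11, 1, 9, 10, -38], dtype=float)

# Append a column of ones so the fit also learns a bias term d,
# mirroring the bias that a Keras Dense layer adds by default.
design = np.hstack([features, np.ones((features.shape[0], 1))])

# Ordinary least squares: minimise ||design @ w - targets||^2.
w, residuals, rank, _ = np.linalg.lstsq(design, targets, rcond=None)
a, b, c, d = w
print(f"a={a:.3f}, b={b:.3f}, c={c:.3f}, d={d:.3f}")

# Predict for a single sample; the input is 2-D: one row, three features.
sample = np.array([[1, 1, 1]], dtype=float)
print(sample @ w[:3] + d)
```

The fitted a, b, c, d give a reference point for what the trained Dense weights should approach with a linear activation and enough epochs.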
But I get the error
ValueError: Input 0 of layer sequential_54 is incompatible with the layer: expected axis -1 of input shape to have value 3 but received input with shape [None, 1]
I can't make sense of the dimensions. What am I doing wrong?
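In Keras, the input_shape you declare excludes the batch dimension (please refer here for more details). Your declaration input_shape=(3,) is therefore correct, but every array you pass to the model, including at inference time, must also carry the batch dimension. np.array([1, 1, 1]) has shape (3,), with no batch axis, so Keras treats it as a batch of three samples with one feature each, i.e. shape (3, 1); that is where the [None, 1] in the error comes from. Add the extra dimension instead: np.array([[1, 1, 1]]) has shape (1, 3), one sample with three features.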
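The shape difference can be checked in plain NumPy. The sketch below shows three equivalent ways to add a leading batch axis, plus the trailing-axis layout (3, 1) that matches the [None, 1] shape reported in the error. The variable names are illustrative only.

```python
import numpy as np

single = np.array([1, 1, 1])
print(single.shape)  # (3,): no batch axis

# Three equivalent ways to add a leading batch axis of size 1.
batch_a = np.array([[1, 1, 1]])
batch_b = single.reshape(1, -1)
batch_c = np.expand_dims(single, axis=0)
print(batch_a.shape, batch_b.shape, batch_c.shape)  # (1, 3) (1, 3) (1, 3)

# Adding the axis at the end instead gives (3, 1): three samples with one
# feature each, which matches the [None, 1] shape in the error message.
wrong = single[:, np.newaxis]
print(wrong.shape)  # (3, 1)
```

Any of batch_a, batch_b or batch_c can be passed to model.predict. reshape(1, -1) is handy when the single sample already exists as a 1-D array.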
import tensorflow as tf
import numpy as np
# Six training samples, each with three features (X, Y, Z).
training_data = np.array([[1, 1, 1], [2, 3, 1], [0, -1, 4], [0, 3, 0], [10, -6, 8], [-3, -12, 4]])
# Target output for each training sample.
training_targets = np.array([6, 11, 1, 9, 10, -38])
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=1, activation=tf.keras.activations.relu, input_shape=(3,)))
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001), loss=tf.keras.losses.mean_squared_error, metrics=[tf.keras.metrics.mean_squared_error])
model.summary()
model.fit(training_data, training_targets, epochs=1, verbose=0)
print("Training completed.")
# Note the extra brackets: the input has shape (1, 3), one sample with three features.
model.predict(np.array([[1, 2, 1]]))
Output (the exact value varies between runs because the weights are randomly initialized):
array([[0.08026636]], dtype=float32)
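To see why the batch axis matters, here is a NumPy sketch of the linear part of what a Dense layer with one unit computes, input @ kernel + bias, with the relu activation omitted. The function dense_forward and the random weights are hypothetical stand-ins, not Keras internals. The sketch also shows that the batch axis can have any size, so several samples can be evaluated in one call.

```python
import numpy as np

# Hypothetical stand-in for the weights of Dense(units=1) with 3 inputs:
# a kernel of shape (3, 1) and a bias of shape (1,). Values are arbitrary.
rng = np.random.default_rng(seed=0)
kernel = rng.normal(size=(3, 1))
bias = np.zeros(1)

def dense_forward(x):
    """Linear part of a Dense layer (activation omitted); x must be (batch, 3)."""
    return x @ kernel + bias

# Two samples at once: the batch axis can be any size.
batch = np.array([[1, 2, 1], [0, 3, 0]], dtype=float)
out = dense_forward(batch)
print(out.shape)  # (2, 1)

# A (3, 1) input means three samples with one feature each; the inner
# dimensions (1 vs 3) do not line up, so the matrix product fails.
try:
    dense_forward(np.array([1.0, 2.0, 1.0])[:, np.newaxis])
    mismatch = False
except ValueError as err:
    mismatch = True
    print("shape mismatch:", err)
```

In the same way, model.predict(training_data) would return one prediction per training row, with shape (6, 1).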