Tags: python, tensorflow, machine-learning, k-means, tensorflow-estimator

Tensorflow input function for K-means clustering error


This is a simple script that builds a TensorFlow graph to cluster the Iris data. It uses tf.estimator.inputs.numpy_input_fn to define the input function for a tf.contrib.learn.KMeansClustering estimator.

import os
import numpy as np
import tensorflow as tf

# Data set
IRIS = "iris.csv"

# Load datasets
iris = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS,
    target_dtype=np.int,
    features_dtype=np.float32)

# Build KMeans Clustering model.
num_clusters = 4
estimator = tf.contrib.learn.KMeansClustering(
    num_clusters,
    model_dir="/tmp/iris_model",
    initial_clusters='random',
    distance_metric='squared_euclidean',
    use_mini_batch=True)

# Define the training inputs
input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": np.array(iris.data)},
    y=np.array(iris.target),
    batch_size=4,
    num_epochs=None,
    shuffle=True)

# Fit model.
clusters = estimator.fit(input_fn=input_fn)

However, TensorFlow raises the following error:

...
AssertionError: Tensor("random_shuffle_queue_DequeueMany:2", shape=(4,), dtype=int64)

Do you know why I am getting this error and how to debug it?

iris.csv:

150,4,setosa,versicolor,virginica
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
4.9,2.5,4.5,1.7,2
4.9,3.1,1.5,0.1,0
5.7,3.8,1.7,0.3,0
4.4,3.2,1.3,0.2,0
5.4,3.4,1.5,0.4,0
6.9,3.1,5.1,2.3,2
6.7,3.1,4.4,1.4,1
5.1,3.7,1.5,0.4,0
5.2,2.7,3.9,1.4,1
6.9,3.1,4.9,1.5,1
5.8,4.0,1.2,0.2,0
5.4,3.9,1.7,0.4,0
7.7,3.8,6.7,2.2,2
6.3,3.3,4.7,1.6,1
6.8,3.2,5.9,2.3,2
7.6,3.0,6.6,2.1,2
6.4,3.2,5.3,2.3,2
5.7,4.4,1.5,0.4,0
6.7,3.3,5.7,2.1,2
6.4,2.8,5.6,2.1,2
5.4,3.9,1.3,0.4,0
6.1,2.6,5.6,1.4,2
7.2,3.0,5.8,1.6,2
5.2,3.5,1.5,0.2,0
5.8,2.6,4.0,1.2,1
5.9,3.0,5.1,1.8,2
5.4,3.0,4.5,1.5,1
6.7,3.0,5.0,1.7,1
6.3,2.3,4.4,1.3,1
5.1,2.5,3.0,1.1,1
6.4,3.2,4.5,1.5,1
6.8,3.0,5.5,2.1,2
6.2,2.8,4.8,1.8,2
6.9,3.2,5.7,2.3,2
6.5,3.2,5.1,2.0,2
5.8,2.8,5.1,2.4,2
5.1,3.8,1.5,0.3,0
4.8,3.0,1.4,0.3,0
7.9,3.8,6.4,2.0,2
5.8,2.7,5.1,1.9,2
6.7,3.0,5.2,2.3,2
5.1,3.8,1.9,0.4,0
4.7,3.2,1.6,0.2,0
6.0,2.2,5.0,1.5,2
4.8,3.4,1.6,0.2,0
7.7,2.6,6.9,2.3,2
4.6,3.6,1.0,0.2,0
7.2,3.2,6.0,1.8,2
5.0,3.3,1.4,0.2,0
6.6,3.0,4.4,1.4,1
6.1,2.8,4.0,1.3,1
5.0,3.2,1.2,0.2,0
7.0,3.2,4.7,1.4,1
6.0,3.0,4.8,1.8,2
7.4,2.8,6.1,1.9,2
5.8,2.7,5.1,1.9,2
6.2,3.4,5.4,2.3,2
5.0,2.0,3.5,1.0,1
5.6,2.5,3.9,1.1,1
6.7,3.1,5.6,2.4,2
6.3,2.5,5.0,1.9,2
6.4,3.1,5.5,1.8,2
6.2,2.2,4.5,1.5,1
7.3,2.9,6.3,1.8,2
4.4,3.0,1.3,0.2,0
7.2,3.6,6.1,2.5,2
6.5,3.0,5.5,1.8,2
5.0,3.4,1.5,0.2,0
4.7,3.2,1.3,0.2,0
6.6,2.9,4.6,1.3,1
5.5,3.5,1.3,0.2,0
7.7,3.0,6.1,2.3,2
6.1,3.0,4.9,1.8,2
4.9,3.1,1.5,0.1,0
5.5,2.4,3.8,1.1,1
5.7,2.9,4.2,1.3,1
6.0,2.9,4.5,1.5,1
6.4,2.7,5.3,1.9,2
5.4,3.7,1.5,0.2,0
6.1,2.9,4.7,1.4,1
6.5,2.8,4.6,1.5,1
5.6,2.7,4.2,1.3,1
6.3,3.4,5.6,2.4,2
4.9,3.1,1.5,0.1,0
6.8,2.8,4.8,1.4,1
5.7,2.8,4.5,1.3,1
6.0,2.7,5.1,1.6,1
5.0,3.5,1.3,0.3,0
6.5,3.0,5.2,2.0,2
6.1,2.8,4.7,1.2,1
5.1,3.5,1.4,0.3,0
4.6,3.1,1.5,0.2,0
6.5,3.0,5.8,2.2,2
4.6,3.4,1.4,0.3,0
4.6,3.2,1.4,0.2,0
7.7,2.8,6.7,2.0,2
5.9,3.2,4.8,1.8,1
5.1,3.8,1.6,0.2,0
4.9,3.0,1.4,0.2,0
4.9,2.4,3.3,1.0,1
4.5,2.3,1.3,0.3,0
5.8,2.7,4.1,1.0,1
5.0,3.4,1.6,0.4,0
5.2,3.4,1.4,0.2,0
5.3,3.7,1.5,0.2,0
5.0,3.6,1.4,0.2,0
5.6,2.9,3.6,1.3,1
4.8,3.1,1.6,0.2,0
6.3,2.7,4.9,1.8,2
5.7,2.8,4.1,1.3,1
5.0,3.0,1.6,0.2,0
6.3,3.3,6.0,2.5,2
5.0,3.5,1.6,0.6,0
5.5,2.6,4.4,1.2,1
5.7,3.0,4.2,1.2,1
4.4,2.9,1.4,0.2,0
4.8,3.0,1.4,0.1,0
5.5,2.4,3.7,1.0,1
5.9,3.0,4.2,1.5,1
6.9,3.1,5.4,2.1,2
5.1,3.3,1.7,0.5,0
6.0,3.4,4.5,1.6,1
5.5,2.5,4.0,1.3,1
6.2,2.9,4.3,1.3,1
5.5,4.2,1.4,0.2,0
6.3,2.8,5.1,1.5,2
5.6,3.0,4.1,1.3,1
6.7,2.5,5.8,1.8,2
7.1,3.0,5.9,2.1,2
4.3,3.0,1.1,0.1,0
5.6,2.8,4.9,2.0,2
5.5,2.3,4.0,1.3,1
6.0,2.2,4.0,1.0,1
5.1,3.5,1.4,0.2,0
5.7,2.6,3.5,1.0,1
4.8,3.4,1.9,0.2,0
5.1,3.4,1.5,0.2,0
5.7,2.5,5.0,2.0,2
5.4,3.4,1.7,0.2,0
5.6,3.0,4.5,1.5,1
6.3,2.9,5.6,1.8,2
6.3,2.5,4.9,1.5,1
5.8,2.7,3.9,1.2,1
6.1,3.0,4.6,1.4,1
5.2,4.1,1.5,0.1,0
6.7,3.1,4.7,1.5,1
6.7,3.3,5.7,2.5,2
6.4,2.9,4.3,1.3,1
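The first row of iris.csv is the header that load_csv_with_header expects: the number of samples (150), the number of features (4), and then the class names. Every following row holds four feature values followed by an integer class label in the last column. The sketch below parses that layout with the standard library only, mirroring the loader's default of taking the label from the last column (target_column=-1); the helper name parse_iris_csv is hypothetical, and its row-count checks are an extra sanity check rather than something the TensorFlow loader is claimed to do.

```python
import csv
import io


def parse_iris_csv(text):
    """Parse the header-prefixed CSV layout used by iris.csv (hypothetical helper).

    Header row: n_samples, n_features, then class names.
    Data rows: n_features float values, then an integer label (last column).
    """
    reader = csv.reader(io.StringIO(text))
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])
    class_names = header[2:]

    data, target = [], []
    for row in reader:
        if not row:  # skip blank lines
            continue
        target.append(int(row[-1]))  # last column is the class label
        data.append([float(v) for v in row[:-1]])  # the rest are features

    # Sanity check: the header counts should agree with the file contents.
    if len(data) != n_samples:
        raise ValueError(f"header says {n_samples} rows, found {len(data)}")
    if any(len(r) != n_features for r in data):
        raise ValueError(f"expected {n_features} features per row")
    return data, target, class_names


sample = """3,4,setosa,versicolor,virginica
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
4.9,3.1,1.5,0.1,0
"""
data, target, names = parse_iris_csv(sample)
print(target)  # [2, 1, 0]
```

Checking that the header counts match the rows is worth doing early, since the contrib loader sizes its feature and target arrays from those two numbers.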

Solution

  • You can't pass the input function returned by tf.estimator.inputs.numpy_input_fn to the tf.contrib.learn classes, because it doesn't return a plain features, labels pair but something that encapsulates both.

    When using the older contrib classes, your best option is to write the input function yourself. This is how:

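    import tensorflow as tf  # TensorFlow 1.x (tf.train queues, tf.contrib)

    def make_numpy_input_fn(x, y, batch_size):
      """Build an input_fn that returns shuffled (features, labels) batches."""
      def input_fn():
        # enqueue_many=True treats each row of x and y as a separate example.
        features, labels = tf.train.shuffle_batch(
            [tf.constant(x), tf.constant(y)],
            batch_size=batch_size,
            capacity=50 * batch_size,
            min_after_dequeue=20 * batch_size,
            enqueue_many=True)
        # Wrap the feature tensor in a dict under the same 'x' key as the question.
        features = {'x': features}
        return features, labels
      return input_fn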
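The traceback also offers a clue worth checking when debugging. When Python prints an AssertionError whose message is a bare object, the failing statement was written as assert condition, obj. The tensor shown has shape (4,), matching batch_size=4, and dtype int64, matching the integer labels. One plausible reading, stated here as an assumption rather than confirmed from the TensorFlow source, is that the contrib k-means model function asserts that no labels are passed, since k-means is unsupervised. The pure-Python sketch below imitates that check; kmeans_model_fn, make_input_fn and fit are hypothetical stand-ins, not TensorFlow APIs.

```python
import random


def kmeans_model_fn(features, labels):
    """Hypothetical stand-in for an unsupervised model function.

    Assumption: it rejects labels with `assert labels is None, labels`,
    so Python prints the labels object itself as the error message.
    """
    assert labels is None, labels
    return len(features["x"])  # pretend to "train" on one batch


def make_input_fn(x, y=None, batch_size=4, seed=0):
    """Hypothetical input_fn factory: returns no labels when y is None."""
    rng = random.Random(seed)

    def input_fn():
        idx = rng.sample(range(len(x)), batch_size)  # one shuffled batch
        features = {"x": [x[i] for i in idx]}
        if y is None:
            return features, None
        return features, [y[i] for i in idx]

    return input_fn


def fit(model_fn, input_fn):
    """Mimic an estimator: call input_fn, then pass both parts to model_fn."""
    features, labels = input_fn()
    return model_fn(features, labels)


x = [[6.4, 2.8, 5.6, 2.2], [5.0, 2.3, 3.3, 1.0], [4.9, 3.1, 1.5, 0.1],
     [5.7, 3.8, 1.7, 0.3], [4.4, 3.2, 1.3, 0.2]]
y = [2, 1, 0, 0, 0]

try:
    fit(kmeans_model_fn, make_input_fn(x, y))
except AssertionError as err:
    print("AssertionError:", err)  # the message is the label batch itself

print(fit(kmeans_model_fn, make_input_fn(x)))  # 4
```

If this reading holds for your TensorFlow version, an input function that supplies features only (for example, leaving y out of numpy_input_fn) would be the thing to try next; verify against the estimator's source before relying on it.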