
Using tf.contrib.factorization.KMeansClustering

Following the example at this link, I am practicing clustering with tf.contrib.factorization.KMeansClustering. The following simple code works fine:

import numpy as np
import tensorflow as tf

# ---- Create Data Sample -----
k = 5
n = 100
variables = 5
points = np.random.uniform(0, 1000, [n, variables])

# ---- Clustering -----
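input_fn=lambda: tf.train.limit_epochs(tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1)
# Build the k-means estimator with k clusters
kmeans = tf.contrib.factorization.KMeansClustering(num_clusters=k)
# Train until the input raises OutOfRangeError (here: after 1 epoch)
kmeans.train(input_fn)
centers = kmeans.cluster_centers()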

# ---- Print out -----
cluster_indices = list(kmeans.predict_cluster_index(input_fn))
for i, point in enumerate(points):
  cluster_index = cluster_indices[i]
  print('point:', point, 'is in cluster', cluster_index, 'centered at', centers[cluster_index])
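For reference, the cluster index printed by the loop above is the index of the nearest center. The NumPy sketch below reproduces that assignment for a given set of centers, assuming KMeansClustering's default squared Euclidean distance metric. `assign_clusters` is a hypothetical helper, not part of the TensorFlow API; it only mirrors what predict_cluster_index returns, not how TensorFlow computes it.

```python
import numpy as np


def assign_clusters(points, centers):
    """Return the index of the nearest center for each point (squared Euclidean)."""
    # diff has shape (n_points, n_centers, n_features)
    diff = points[:, np.newaxis, :] - centers[np.newaxis, :, :]
    sq_dist = np.sum(diff ** 2, axis=2)  # shape (n_points, n_centers)
    return np.argmin(sq_dist, axis=1)     # nearest center per point


centers = np.array([[0.0, 0.0], [10.0, 10.0]])
points = np.array([[1.0, 1.0], [9.0, 8.0], [4.0, 3.0]])
print(assign_clusters(points, centers))  # prints [0 1 0]
```

With trained centers from kmeans.cluster_centers(), the same helper could be applied to the original points array for comparison.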

My question is: why does this input_fn do the trick? If I change it to the following, training runs into an infinite loop. Why?

input_fn=lambda:tf.convert_to_tensor(points, dtype=tf.float32)

From the documentation (here), it seems that train() simply expects an input_fn that returns the data as an object such as Tensor(X). So why do I have to do these tricky things with lambda: tf.train.limit_epochs()?

Can anyone familiar with the fundamentals of TensorFlow Estimators explain this? Many thanks!


  • My question is: why does this input_fn do the trick? If I change it to the following, training runs into an infinite loop. Why?

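    The documentation states that, unless a step limit is given, train() keeps running until the input built by input_fn raises tf.errors.OutOfRangeError. Wrapping your tensor with tf.train.limit_epochs ensures that this error is eventually raised: it returns the tensor num_epochs times and then raises OutOfRangeError, which signals to KMeans that it should stop training. A bare tf.convert_to_tensor(points) can be evaluated indefinitely, so the error never comes and training never ends.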
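The stopping mechanism can be modeled without TensorFlow. The sketch below is a plain-Python analogy, not TensorFlow's implementation: `OutOfRangeError`, `limit_epochs`, `unlimited` and `run_training` are hypothetical stand-ins for tf.errors.OutOfRangeError, tf.train.limit_epochs, a bare tf.convert_to_tensor input, and the loop inside train(). It only illustrates that the loop ends because the input raises, not because the data is "used up".

```python
# Plain-Python analogy of the stopping mechanism; no TensorFlow involved.
# All names here are hypothetical stand-ins, not TensorFlow APIs.


class OutOfRangeError(Exception):
    """Stand-in for tf.errors.OutOfRangeError."""


def limit_epochs(value, num_epochs):
    """Return a 'next batch' callable that yields value num_epochs times, then raises."""
    remaining = num_epochs

    def next_batch():
        nonlocal remaining
        if remaining <= 0:
            raise OutOfRangeError("epoch limit reached")
        remaining -= 1
        return value

    return next_batch


def unlimited(value):
    """Models a bare tensor: the same value can be read forever."""
    return lambda: value


def run_training(next_batch, max_steps=None):
    """Run 'training steps' until the input raises or max_steps is hit; return step count."""
    steps = 0
    while max_steps is None or steps < max_steps:
        try:
            next_batch()  # a real trainer would update the centers with this batch
        except OutOfRangeError:
            break  # the input signals that training should stop
        steps += 1
    return steps


points = [[1.0, 2.0], [3.0, 4.0]]
print(run_training(limit_epochs(points, num_epochs=1)))  # prints 1
print(run_training(limit_epochs(points, num_epochs=3)))  # prints 3
# run_training(unlimited(points)) would never return; only an explicit cap stops it:
print(run_training(unlimited(points), max_steps=10))     # prints 10
```

In this analogy, max_steps plays the role of the steps/max_steps arguments of train(), assuming KMeansClustering inherits the TF 1.x tf.estimator.Estimator.train signature; bounding the step count is the other way to keep training from running forever.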