
Image quantization with Numpy


I wanted to have a look at the example code for image quantization from here. However, it's rather old, and both Python and NumPy have changed since then.

from pylab import imread,imshow,figure,show,subplot
from numpy import reshape,uint8,flipud
from scipy.cluster.vq import kmeans,vq

img = imread('clearsky.jpg')

# reshaping the pixels matrix
pixel = reshape(img,(img.shape[0]*img.shape[1],3))

# performing the clustering
# print(type(pixel))
centroids,_ = kmeans(pixel,6) # six colors will be found
# quantization
qnt,_ = vq(pixel,centroids)

# reshaping the result of the quantization
centers_idx = reshape(qnt,(img.shape[0],img.shape[1]))
clustered = centroids[centers_idx]

figure(1)
subplot(211)
imshow(flipud(img))
subplot(212)
imshow(flipud(clustered))
show()

It's falling over at line 12, centroids,_ = kmeans(pixel,6), with this traceback:

  File "D:\Python30\lib\site-packages\scipy\cluster\vq.py", line 454, in kmeans
    book, dist = _kmeans(obs, guess, thresh=thresh)
  File "D:\Python30\lib\site-packages\scipy\cluster\vq.py", line 309, in _kmeans
    code_book.shape[0])
  File "_vq.pyx", line 340, in scipy.cluster._vq.update_cluster_means
TypeError: type other than float or double not supported

I can change 6 to 6.0, but I'm confused about what to do with the NumPy array being passed into kmeans.

What do I need to do to update the code to get the example up and running?


Solution

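  • Convert the image pixels to floating point. The TypeError is about the data, not about k: imread returns a uint8 array for a JPEG, and the cluster-mean update inside kmeans only handles float or double values (that is exactly what the traceback says). Convert the pixels to a floating-point type before passing them to kmeans; k can stay a plain integer: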

    import matplotlib.pyplot as plt
    import numpy as np
    from scipy.cluster.vq import kmeans, vq
    
    # Reading the image
    img = plt.imread('clearsky.jpg')
    
    # Flatten to one RGB row per pixel and convert to float (kmeans rejects uint8 data)
    pixel = np.reshape(img, (img.shape[0]*img.shape[1], 3)).astype(float)
    
    # Performing the clustering
    centroids, _ = kmeans(pixel, 6)  # up to six colors will be found; k stays an int
    
    # Quantization
    qnt, _ = vq(pixel, centroids)
    
    # Reshaping the result of the quantization
    centers_idx = np.reshape(qnt, (img.shape[0], img.shape[1]))
    clustered = centroids[centers_idx]  # vq already returns integer codes, so index directly
    
    # Displaying the original and quantized images
    plt.figure(figsize=(10, 5))
    plt.subplot(121)
    plt.imshow(img)
    plt.title('Original Image')
    
    plt.subplot(122)
    plt.imshow(clustered.astype(np.uint8))
    plt.title('Quantized Image')
    plt.show()
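The code above assumes a JPEG: imread gives uint8 values in 0-255 for it, but for a PNG it gives float32 values in 0-1, possibly with a fourth alpha channel. In that case the reshape to 3 columns fails, and the final astype(np.uint8) would turn nearly every pixel black. Below is a minimal sketch, assuming only NumPy and SciPy, that wraps the same kmeans/vq steps in a hypothetical helper, quantize_colors (not part of any library), which handles both cases and rounds instead of truncating. It uses a synthetic random image in place of clearsky.jpg so it runs without the file.

```python
import numpy as np
from scipy.cluster.vq import kmeans, vq


def quantize_colors(img, k):
    """Hypothetical helper: reduce an RGB(A) image array to at most k colors.

    Accepts a uint8 image (0-255, as imread returns for JPEG) or a float
    image (0.0-1.0, as imread returns for PNG). Any alpha channel is dropped;
    the result has shape (h, w, 3) and the same dtype as the input.
    """
    rgb = img[..., :3]                                  # drop alpha if present
    h, w, c = rgb.shape
    pixels = rgb.reshape(h * w, c).astype(np.float64)  # kmeans needs float data

    # k is a plain int; kmeans may return fewer centroids if a cluster empties
    centroids, _ = kmeans(pixels, k)
    codes, _ = vq(pixels, centroids)  # integer index of nearest centroid per pixel

    quantized = centroids[codes].reshape(h, w, c)
    if img.dtype == np.uint8:
        # Assumes 8-bit input: round back into 0-255 (a bare astype truncates)
        quantized = np.clip(np.rint(quantized), 0, 255)
    return quantized.astype(img.dtype)


if __name__ == "__main__":
    # Synthetic stand-in for clearsky.jpg: random uint8 RGB pixels
    rng = np.random.default_rng(0)
    fake_jpeg = rng.integers(0, 256, size=(40, 60, 3), dtype=np.uint8)
    out = quantize_colors(fake_jpeg, 6)
    print(out.shape, out.dtype, len(np.unique(out.reshape(-1, 3), axis=0)))
```

With a real file, the call would be quantize_colors(plt.imread('clearsky.jpg'), 6), and the result can go straight to plt.imshow without any further casting.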