
How to extend pyWavelets to work with N-dimensional data?


This may be a question for a different forum; if so, please let me know. I noticed that only 14 people follow the wavelet tag.

I have here an elegant way of extending the wavelet decomposition in pywt (the PyWavelets package) to multiple dimensions. It should run out of the box if pywt is installed. Test 1 shows the decomposition and recomposition of a 3D array. All one has to do is increase the number of dimensions, and the code will decompose/recompose data with 4, 6 or even 18 dimensions.

I have replaced the pywt.wavedec and pywt.waverec functions here. Also, fn_dec computes the decomposition with pywt's own wavedec, to show that the new wavedec function works just like the old one.

There is one catch, though: it represents the wavelet coefficients as an array of the same shape as the data. As a consequence, with my limited knowledge of wavelets, I have only been able to use it for Haar wavelets. Others, like DB4, bleed coefficients over the edges of these strict bounds (not a problem with the current representation of coefficients as a list of arrays [CA, CD1 ... CDN]). Another catch is that I have only tried this on cuboids of data whose edges are 2^N samples long.

Theoretically, I think it should be possible to make sure that the "bleeding" does not occur. An algorithm for this sort of wavelet decomposition and recomposition is discussed in "Numerical Recipes in C" by William H. Press, Saul A. Teukolsky, William T. Vetterling and Brian P. Flannery (Second Edition). Although this algorithm assumes reflection at the edges rather than other forms of edge extension (like zpd), the method is general enough to work for other forms of extension.
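A quick way to see the "bleeding", and one common way to avoid it, is sketched below. It assumes a recent PyWavelets release, where the modes are named 'zero', 'symmetric' and 'periodization' (the first two were 'zpd' and 'sym' in older releases). With filters longer than two taps, the non-periodic modes return more than N/2 coefficients per band, while 'periodization' returns exactly N/2 and still reconstructs perfectly.

```python
import numpy as np
import pywt

x = np.random.default_rng(0).standard_normal(16)  # N = 16 = 2**4 samples

# Coefficients per band: floor((N + F - 1) / 2) in the non-periodic modes,
# ceil(N / 2) in 'periodization' (F = filter length: 2 for haar, 8 for db4).
for name in ("haar", "db4"):
    for mode in ("zero", "symmetric", "periodization"):
        cA, cD = pywt.dwt(x, name, mode)
        print(f"{name:4s} {mode:13s} len(cA) = {len(cA)}")

# The half-length 'periodization' bands still reconstruct the input exactly.
cA, cD = pywt.dwt(x, "db4", "periodization")
x_rec = pywt.idwt(cA, cD, "db4", "periodization")
print("perfect reconstruction:", np.allclose(x_rec, x))
```

For N = 16, db4 therefore needs 11 slots per band in 'zero' or 'symmetric' mode but only 8 in 'periodization'. The trade-off is that periodization treats the signal as periodic, so samples at opposite ends influence each other's boundary coefficients.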

Any suggestion on how to extend this work to other wavelets?

NOTE: This query is also posted on http://groups.google.com/group/pywavelets

Thanks, Ajo

import sys

import numpy as np
import pywt

# Mode names follow current PyWavelets: 'zero' and 'symmetric' were called
# 'zpd' and 'sym' in older releases.


def waveFn(wavelet):
    """Return a pywt.Wavelet, building one from a name such as 'haar' if needed."""
    if not isinstance(wavelet, pywt.Wavelet):
        return pywt.Wavelet(wavelet)
    return wavelet


# Given a 1-D array, return all wavelet coefficients packed into an array of
# the same length: [cA_n, cD_n, ..., cD_2, cD_1]. Only exact when every level
# halves the length (Haar on 2^N samples).
def wavedec(data, wavelet, mode='symmetric'):
    wavelet = waveFn(wavelet)

    dLen = len(data)
    coeffs = np.zeros(dLen)  # float array, so integer input is not truncated
    level = pywt.dwt_max_level(dLen, wavelet.dec_len)

    a = data
    end_idx = dLen
    for _ in range(level):
        a, d = pywt.dwt(a, wavelet, mode)
        begin_idx = end_idx // 2  # integer division; '/' gives a float in Python 3
        coeffs[begin_idx:end_idx] = d
        end_idx = begin_idx

    coeffs[:end_idx] = a
    return coeffs


# Inverse of wavedec; assumes the coarsest approximation is a single value.
def waverec(data, wavelet, mode='symmetric'):
    wavelet = waveFn(wavelet)

    dLen = len(data)
    level = pywt.dwt_max_level(dLen, wavelet.dec_len)

    end_idx = 1
    a = data[:end_idx]  # coarsest approximation
    d = data[end_idx:end_idx * 2]
    for _ in range(level):
        a = pywt.idwt(a, d, wavelet, mode)
        end_idx *= 2
        d = data[end_idx:end_idx * 2]
    return a


# Reference result: pywt's own wavedec, with each row's coefficient list
# concatenated into one array.
def fn_dec(arr):
    return np.array([np.hstack(pywt.wavedec(row, 'haar', 'zero')) for row in arr])


if __name__ == '__main__':
    test = 1
    np.random.seed(10)
    wavelet = waveFn('haar')
    if test == 0:
        # Single-dimensional test.
        a = np.random.randn(1, 8)
        print("original values A")
        print(a)
        print("decomposition of A by method in pywt")
        print(fn_dec(a))
        print("decomposition of A by my method")
        coeffs = wavedec(a[0], 'haar', 'zero')
        print(coeffs)
        print("recomposition of A by my method")
        print(waverec(coeffs, 'haar', 'zero'))
        sys.exit()
    if test == 1:
        # 3-D test.
        a = np.random.randn(4, 4, 4)
        print("original value of A")
        print(a)

        # Decompose the signal into wavelet coefficients, one axis at a time:
        # move the next axis to the end, then transform every 1-D row along it.
        dimensions = a.shape
        for dim in dimensions:
            a = np.moveaxis(a, 0, -1)
            shape = a.shape
            a = np.array([wavedec(row, wavelet) for row in a.reshape(-1, dim)])
            a = a.reshape(shape)
        print("decomposition of signal into coefficients")
        print(a)

        # Recompose the coefficients into the original signal.
        for dim in dimensions:
            a = np.moveaxis(a, 0, -1)
            shape = a.shape
            a = np.array([waverec(row, wavelet) for row in a.reshape(-1, dim)])
            a = a.reshape(shape)
        print("recomposition of coefficients to signal")
        print(a)
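The per-axis loop in test 1 can also be written with np.apply_along_axis, which removes the manual axis-rolling and reshaping. The sketch below is not the poster's code: the helper names are hypothetical, and it additionally assumes mode='periodization' and edge lengths that are powers of two, so that every level yields exactly half as many coefficients as its input and the same-shape packing also holds for longer filters such as db2. Like the code above, it runs a full multilevel 1-D transform along each axis in turn (the "standard" decomposition), which gives a different layout from repeating a single-level n-D transform on the approximation.

```python
import numpy as np
import pywt


def wavedec_packed(x, wavelet, mode="periodization"):
    """Multilevel 1-D DWT packed as [cA_n, cD_n, ..., cD_1] in len(x) slots.

    Assumes len(x) is a power of two and mode='periodization', so every
    level halves the length exactly and nothing spills past the array.
    """
    w = pywt.Wavelet(wavelet) if isinstance(wavelet, str) else wavelet
    n = len(x)
    level = pywt.dwt_max_level(n, w.dec_len)
    out = np.empty(n, dtype=float)
    a = np.asarray(x, dtype=float)
    end = n
    for _ in range(level):
        a, d = pywt.dwt(a, w, mode)
        begin = end // 2
        out[begin:end] = d          # detail band of this level
        end = begin
    out[:end] = a                   # coarsest approximation
    return out


def waverec_packed(c, wavelet, mode="periodization"):
    """Inverse of wavedec_packed under the same assumptions."""
    w = pywt.Wavelet(wavelet) if isinstance(wavelet, str) else wavelet
    n = len(c)
    level = pywt.dwt_max_level(n, w.dec_len)
    end = n >> level                # length of the coarsest approximation
    a = np.asarray(c[:end], dtype=float)
    for _ in range(level):
        a = pywt.idwt(a, c[end:2 * end], w, mode)
        end *= 2
    return a


def dec_nd(data, wavelet, mode="periodization"):
    """Apply wavedec_packed along every axis in turn."""
    out = np.asarray(data, dtype=float)
    for axis in range(out.ndim):
        out = np.apply_along_axis(wavedec_packed, axis, out, wavelet, mode)
    return out


def rec_nd(coeffs, wavelet, mode="periodization"):
    """Undo dec_nd; the per-axis transforms act on different axes, so order is free."""
    out = np.asarray(coeffs, dtype=float)
    for axis in range(out.ndim):
        out = np.apply_along_axis(waverec_packed, axis, out, wavelet, mode)
    return out


if __name__ == "__main__":
    data = np.random.default_rng(10).standard_normal((8, 16, 4))
    for name in ("haar", "db2"):
        coeffs = dec_nd(data, name)
        print(name, coeffs.shape, np.allclose(rec_nd(coeffs, name), data))
```

Because apply_along_axis hands each 1-D slice to the helper, the same four functions work unchanged for 3, 4 or 18 dimensions.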

Solution

  • First of all, I would like to point you to pywt.dwtn, the function that already implements the single-level multidimensional transform. It returns a dictionary of n-dimensional coefficient arrays. Coefficients are addressed by keys that describe the type of transform (approximation/detail) applied to each of the dimensions.

    For example, in the 2D case the result is a dictionary of approximation and detail coefficient arrays:

    >>> pywt.dwtn([[1,2,3,4],[3,4,5,6],[5,6,7,8],[7,8,9,10]], 'db1')
    {'aa': [[5.0, 9.0], [13.0, 17.0]],
     'ad': [[-1.0, -1.0], [-1.0, -1.0]],
     'da': [[-2.0, -2.0], [-2.0, -2.0]],
     'dd': [[0.0, 0.0], [0.0, -0.0]]}
    

    Here aa is the coefficient array with the approximation transform applied to both dimensions (LL), and da is the coefficient array with the detail transform applied to the first dimension and the approximation transform applied to the second one (HL) (compare with the dwt2 output).

    Based on that it should be fairly easy to extend it to the multi-level case.

    Here's my take on the decomposition part: https://gist.github.com/934166.

    I would also like to address one issue you mention in your question:

    There is one catch though: It represents the wavelet coefficients as an array of the same shape as the data.

    The approach of representing results as an array of the same shape/size as the input data is, in my opinion, harmful. It makes the whole thing unnecessarily complex to understand and work with, because you have to either make assumptions or maintain a secondary data structure with indexes in order to access coefficients in the output array and perform an inverse transform (see Matlab's documentation for wavedec/waverec).

    Also, even though it works great on paper, it does not always fit real-world applications because of the problems you have mentioned: most of the time the input data size is not 2^n, and the decimated result of convolving the signal with a wavelet filter is larger than the "storage space", which in turn can lead to data loss and imperfect reconstruction.

    To avoid these problems I would recommend using more natural data structures to represent the result data hierarchy, like Python's lists, dictionaries and tuples (where available).
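The multi-level extension suggested above can be sketched by repeatedly applying pywt.dwtn to the all-approximation array and keeping the results in a plain list of dicts. This is an illustrative sketch, not the code from the linked gist, and the helper names are hypothetical. Later PyWavelets releases ship pywt.wavedecn and pywt.waverecn, which use the same [cA_n, {details_n}, ..., {details_1}] layout, so the output of this sketch can be checked against them.

```python
import numpy as np
import pywt


def multilevel_dwtn(data, wavelet, level, mode="symmetric"):
    """Return [cA_n, details_n, ..., details_1]; each details_k is a dict
    mapping keys such as 'ad' or 'dda' to n-dimensional arrays."""
    approx_key = "a" * np.ndim(data)
    a = np.asarray(data, dtype=float)
    details = []
    for _ in range(level):
        coeffs = pywt.dwtn(a, wavelet, mode)
        a = coeffs.pop(approx_key)      # keep decomposing the approximation
        details.append(coeffs)          # everything else is detail
    return [a] + details[::-1]          # coarsest level first


def multilevel_idwtn(coeffs, wavelet, mode="symmetric"):
    """Inverse of multilevel_dwtn."""
    a = coeffs[0]
    for details in coeffs[1:]:
        # Odd-length levels can make the reconstruction one sample too long;
        # trim it to the shape of this level's detail arrays.
        shape = next(iter(details.values())).shape
        a = a[tuple(slice(s) for s in shape)]
        level_coeffs = dict(details)
        level_coeffs["a" * a.ndim] = a
        a = pywt.idwtn(level_coeffs, wavelet, mode)
    return a


if __name__ == "__main__":
    cube = np.random.default_rng(0).standard_normal((8, 8, 8))
    coeffs = multilevel_dwtn(cube, "db1", level=2)
    print(coeffs[0].shape, sorted(coeffs[1]))   # 7 detail keys per level in 3-D
    print(np.allclose(multilevel_idwtn(coeffs, "db1"), cube))
```

No sizes are assumed here: each level's arrays keep whatever length the filter and extension mode produce, so non-2^n inputs and longer filters such as db2 need no special handling.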