Tags: python, machine-learning, deep-learning, pytorch, haar-wavelet

Doing PyWavelets calculation on GPU


I am currently working on a classifier using PyWavelets; here is my calculation block:

import pywt
import torch
import torch.nn as nn


class WaveletLayer(nn.Module):
    def __init__(self):
        super(WaveletLayer, self).__init__()

    def forward(self, x):
        def wavelet_transform(img):
            coeffs = pywt.dwt2(img.cpu().numpy(), "haar")
            LL, (LH, HL, HH) = coeffs
            return (
                torch.from_numpy(LL).to(img.device),
                torch.from_numpy(LH).to(img.device),
                torch.from_numpy(HL).to(img.device),
                torch.from_numpy(HH).to(img.device),
            )

        # Apply wavelet transform to each channel separately
        LL, LH, HL, HH = zip(
            *[wavelet_transform(x[:, i : i + 1]) for i in range(x.shape[1])]
        )

        # Concatenate the results
        LL = torch.cat(LL, dim=1)
        LH = torch.cat(LH, dim=1)
        HL = torch.cat(HL, dim=1)
        HH = torch.cat(HH, dim=1)

        return torch.cat([LL, LH, HL, HH], dim=1)

The output from this module goes to a ResNet block for learning. While doing this, I find my CPU clogged, which slows down my training process.

I am trying to use the GPUs for these calculations.


Solution

  • Since you only seem to be interested in the Haar wavelet, you can pretty much implement it yourself:

    • The high-frequency component of the Haar wavelet along each dimension can be written as a pairwise difference.
    • The low-frequency component of the Haar wavelet along each dimension can be written as a pairwise sum.
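
    To make this concrete, here is a quick sanity check on a single 2x2 block (a minimal sketch with arbitrary numbers, not part of the layer itself):

    import numpy as np
    import pywt

    # One 2x2 block [[a, b], [c, d]] with a=1, b=2, c=3, d=4
    block = np.array([[1.0, 2.0],
                      [3.0, 4.0]])

    LL, (LH, HL, HH) = pywt.dwt2(block, "haar")
    # LL is the pairwise sum along both dimensions: (a + b + c + d) / 2 = 5.
    # LH and HL mix a sum along one dimension with a difference along the other,
    # and HH is the pairwise difference along both dimensions.
    print(LL, LH, HL, HH)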

    The following code achieves this in pure PyTorch:

    import torch
    import torch.nn as nn

    class HaarWaveletLayer(nn.Module):
        
        def l_0(self, t):  # sum ("low") along cols
            t = torch.cat([t, t[..., -1:, :]], dim=-2) if t.shape[-2] % 2 else t
            return (t[..., ::2, :] + t[..., 1::2, :])
        def l_1(self, t):  # sum ("low") along rows
            t = torch.cat([t, t[..., :, -1:]], dim=-1) if t.shape[-1] % 2 else t
            return (t[..., :, ::2] + t[..., :, 1::2])
        def h_0(self, t):  # diff ("hi") along cols
            t = torch.cat([t, t[..., -1:, :]], dim=-2) if t.shape[-2] % 2 else t
            return (t[..., ::2, :] - t[..., 1::2, :])
        def h_1(self, t):  # diff ("hi") along rows
            t = torch.cat([t, t[..., :, -1:]], dim=-1) if t.shape[-1] % 2 else t
            return (t[..., :, ::2] - t[..., :, 1::2])
        
        def forward(self, x):
            
            x = .5 * x
            l_1 = self.l_1(x)
            h_1 = self.h_1(x)
            ll = self.l_0(l_1)
            lh = self.h_0(l_1)
            hl = self.l_0(h_1)
            hh = self.h_0(h_1)
            
            return torch.cat([ll, lh, hl, hh], dim=1)
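
    A minimal usage sketch (the Conv2d is just a hypothetical stand-in for your ResNet block, and the shapes are arbitrary): since the layer only uses tensor ops, everything stays on whatever device the input lives on.

    import torch
    import torch.nn as nn

    model = nn.Sequential(
        HaarWaveletLayer(),                           # 3 -> 12 channels, H x W -> H/2 x W/2
        nn.Conv2d(12, 64, kernel_size=3, padding=1),  # hypothetical stand-in for the ResNet block
        nn.ReLU(),
    ).to("cuda:0")

    x = torch.rand(8, 3, 224, 224, device="cuda:0")
    out = model(x)       # no CPU round trip anywhere
    print(out.shape)     # torch.Size([8, 64, 112, 112])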
    

    In combination with your given code, you can convince yourself of the equivalence as follows:

    t = torch.rand((7, 3, 127, 128)).to("cuda:0")
    result_given = WaveletLayer()(t)
    result_proposed = HaarWaveletLayer()(t)
    
    # Same result?
    assert (result_given - result_proposed).abs().max() < 1e-5
    
    # Time comparison
    from timeit import Timer
    num_timings = 100
    print("time given:   ", Timer(lambda: WaveletLayer()(t)).timeit(num_timings))
    print("time proposed:", Timer(lambda: HaarWaveletLayer()(t)).timeit(num_timings))
    

    The timing shows a speedup of more than a factor of 10 on my machine.
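
    One caveat when timing GPU code: CUDA kernels are launched asynchronously, so timeit may not capture the full execution time of the pure-PyTorch version. For tighter numbers you can synchronize before stopping the clock (a small sketch, reusing t and num_timings from above and assuming a CUDA device):

    def run_synced():
        HaarWaveletLayer()(t)
        torch.cuda.synchronize()  # wait for the launched kernels before the timer stops

    print("time proposed (synced):", Timer(run_synced).timeit(num_timings))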

    Notes

    • The t = torch.cat... parts are only necessary if you want to be able to handle odd-shaped images: in that case, we pad by replicating the last row and column, respectively, mimicking the default padding of PyWavelets.
    • Multiplying x by .5 is done for normalization; see this discussion on the Signal Processing Stack Exchange for more details. A quick check of this factor follows below.
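
    If you want to double-check that factor: the pywt Haar analysis low-pass filter is [1/sqrt(2), 1/sqrt(2)], so applying it along both image axes multiplies every coefficient by 1/2, which the single x = .5 * x up front reproduces:

    import pywt

    dec_lo = pywt.Wavelet("haar").dec_lo   # [0.7071..., 0.7071...]
    print(dec_lo)
    print(dec_lo[0] * dec_lo[1])           # ~0.5, the per-axis factors combined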