
How do I find the local threshold for DWT coefficients in image compression in MATLAB?


I'm trying to write an image compression script in MATLAB using a multilevel 3-D DWT (on a color image). Along the way, I want to apply both global and local thresholding to the coefficient matrices. I'd like to use the formula below to calculate the local threshold:

$$T = \sigma \sqrt{2 \log_2 N}$$

where σ is the variance and N is the number of coefficients in the band.
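The formula can be evaluated per band with a one-line anonymous function. This is a minimal sketch under two assumptions: σ is the standard deviation over all coefficients of the band, and the logarithm is base 2 as in the post's code (the classic universal threshold of Donoho and Johnstone uses the natural logarithm, so swap in `log` if that is the intended formula). Flattening with `b(:)` matters: `var(X, 0)` or `std(X)` on a 2-D or 3-D array works along the first dimension and returns an array, not a single value per band. The band values are hypothetical.

```matlab
% Local threshold T = sigma * sqrt(2 * log2(N)) for one coefficient band.
% sigma: standard deviation over ALL coefficients of the band (b(:) flattens
% 2-D or 3-D arrays); log2 mirrors the post's code.
local_thr = @(b) std(double(b(:))) * sqrt(2 * log2(numel(b)));

band = [4 -2; 0 6];          % hypothetical 2x2 band, for illustration only
T = local_thr(band);         % sigma = sqrt(40/3), N = 4, so T = 2*sqrt(40/3)
fprintf('N = %d, sigma = %.4f, T = %.4f\n', numel(band), std(band(:)), T);
```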

Global thresholding works fine, but the calculated local threshold is, most of the time, greater than the largest coefficient in the band, so no thresholding is applied.
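A likely reason for the oversized threshold: with σ taken as the variance, the threshold grows with the square of the coefficients' spread, so for any band whose standard deviation exceeds 1 it is larger than the standard-deviation version, often by far. The synthetic, sparse band below (made-up values, not from the post) uses the post's log2 form and shows the scale difference.

```matlab
% Synthetic, sparse "detail band": mostly zeros plus two large coefficients.
% The values are made up to show the scale difference, not taken from the post.
coeffs = [zeros(1, 98), 50, -50];
N      = numel(coeffs);
scale  = sqrt(2 * log2(N));          % same log2 form as the post's code

t_var = var(coeffs) * scale;         % sigma taken as the variance (question's version)
t_std = std(coeffs) * scale;         % sigma taken as the standard deviation

fprintf('max |c| = %g, T(var) = %.2f, T(std) = %.2f\n', ...
        max(abs(coeffs)), t_var, t_std);
% Here T(var) exceeds every |c|, while T(std) keeps the two large coefficients.
```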

Everything else works and I do get a result, but I suspect the local threshold is miscalculated. Also, the resulting image file is larger than the original! I'd appreciate any help with the correct way to calculate the local threshold, or a pointer to a built-in MATLAB function for it.
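On the file size: the input is a JPEG (lossy), while the result is written to compressed.png (lossless), and zeroing wavelet coefficients does not by itself shrink the file, because waverec3 returns a full-size image that PNG then encodes losslessly without any knowledge of the zeroed coefficients. A more direct indicator of the compression gain is how many coefficients end up zero. The sketch below counts them across a cell array of bands; the bands are made up, and with the post's variables `dc.dec` could be passed instead.

```matlab
% Share of zero coefficients across a set of (thresholded) bands.
% The bands here are made up; with the post's variables, dc.dec could be used instead.
bands = {[4 0; 0 6], [0 0 1 0], zeros(2, 2, 2)};

n_zero  = sum(cellfun(@(b) nnz(b == 0), bands));   % zero-valued coefficients
n_total = sum(cellfun(@numel, bands));             % all coefficients
fprintf('%d of %d coefficients are zero (%.1f%%)\n', ...
        n_zero, n_total, 100 * n_zero / n_total);
```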

Here's an example output:

[Figure: original and thresholded coefficient matrices]

[Figure: input (1.3 MB) and output (5.3 MB!)]

Here's my code:

clear;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%    COMPRESSION    %%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% read base image
% dwt 3/5-L on base images
% quantize coeffs (local/global)
% count zero-valued coeffs
% calculate mse/psnr
% save and show result

% read images
base = imread('circ.jpg');
fam = 'haar'; % wavelet family
lvl = 3; % wavelet depth
% set to 1 to apply global thr
thr_type = 0;
% global threshold value
gthr = 180;

% convert base to grayscale
%base = rgb2gray(base);

% apply dwt on base image
dc = wavedec3(base, lvl, fam);

% extract coeffs
ll_base = dc.dec{1};
lh_base = dc.dec{2};
hl_base = dc.dec{3};
hh_base = dc.dec{4};

ll_var = var(ll_base, 0);
lh_var = var(lh_base, 0);
hl_var = var(hl_base, 0);
hh_var = var(hh_base, 0);

% count number of elements
ll_n = numel(ll_base);
lh_n = numel(lh_base);
hl_n = numel(hl_base);
hh_n = numel(hh_base);

% find local threshold
ll_t = ll_var * (sqrt(2 * log2(ll_n)));
lh_t = lh_var * (sqrt(2 * log2(lh_n)));
hl_t = hl_var * (sqrt(2 * log2(hl_n)));
hh_t = hh_var * (sqrt(2 * log2(hh_n)));

% global
if thr_type == 1
    ll_t = gthr; lh_t = gthr; hl_t = gthr; hh_t = gthr;
end

% get band sizes
ll_size = size(ll_base);
lh_size = size(lh_base);
hl_size = size(hl_base);
hh_size = size(hh_base);

% count zero values in the original band matrices
ll_zeros = sum(ll_base==0,'all');
lh_zeros = sum(lh_base==0,'all');
hl_zeros = sum(hl_base==0,'all');
hh_zeros = sum(hh_base==0,'all');

% initiate new matrices
ll_new = zeros(ll_size);
lh_new = zeros(lh_size);
hl_new = zeros(hl_size);
hh_new = zeros(hh_size);

% apply thresholding on bands
% if coefficient < thr => 0
% otherwise, keep the previous value
for id=1:ll_size(1)
    for idx=1:ll_size(2)
        if ll_base(id,idx) < ll_t
            ll_new(id,idx) = 0;
        else
            ll_new(id,idx) = ll_base(id,idx);
        end
    end
end
for id=1:lh_size(1)
    for idx=1:lh_size(2)
       if lh_base(id,idx) < lh_t
           lh_new(id,idx) = 0;
       else
           lh_new(id,idx) = lh_base(id,idx);
       end
    end
end
for id=1:hl_size(1)
    for idx=1:hl_size(2)
       if hl_base(id,idx) < hl_t
           hl_new(id,idx) = 0;
       else
           hl_new(id,idx) = hl_base(id,idx);
       end
    end
end
for id=1:hh_size(1)
    for idx=1:hh_size(2)
       if hh_base(id,idx) < hh_t
           hh_new(id,idx) = 0;
       else
           hh_new(id,idx) = hh_base(id,idx);
       end
    end
end

% get sizes of the new matrices
ll_new_size = size(ll_new);
lh_new_size = size(lh_new);
hl_new_size = size(hl_new);
hh_new_size = size(hh_new);

% count number of zeros among new values
ll_new_zeros = sum(ll_new==0,'all');
lh_new_zeros = sum(lh_new==0,'all');
hl_new_zeros = sum(hl_new==0,'all');
hh_new_zeros = sum(hh_new==0,'all');


% set new band matrices
dc.dec{1} = ll_new;
dc.dec{2} = lh_new;
dc.dec{3} = hl_new;
dc.dec{4} = hh_new;

% count how many coeff. were thresholded
ll_zeros_diff = ll_new_zeros - ll_zeros;
lh_zeros_diff = lh_new_zeros - lh_zeros;
hl_zeros_diff = hl_new_zeros - hl_zeros;
hh_zeros_diff = hh_new_zeros - hh_zeros;

% show coeff. matrices vs. thresholded version
figure
colormap(gray);
subplot(2,4,1); imagesc(ll_base); title('LL');
subplot(2,4,2); imagesc(lh_base); title('LH');
subplot(2,4,3); imagesc(hl_base); title('HL');
subplot(2,4,4); imagesc(hh_base); title('HH');
subplot(2,4,5); imagesc(ll_new); title({'LL thr';ll_zeros_diff});
subplot(2,4,6); imagesc(lh_new); title({'LH thr';lh_zeros_diff});
subplot(2,4,7); imagesc(hl_new); title({'HL thr';hl_zeros_diff});
subplot(2,4,8); imagesc(hh_new); title({'HH thr';hh_zeros_diff});

% idwt to reconstruct compressed image
cmp = waverec3(dc);
cmp = uint8(cmp);

% calculate mse/psnr
D = (double(cmp) - double(base)) .^ 2; % convert first: uint8 subtraction saturates at 0
mse  = sum(D(:))/numel(base);
psnr = 10*log10(255*255/mse);

% show images and mse/psnr
figure
subplot(1,2,1);
imshow(base); title("Original"); axis square;
subplot(1,2,2);
imshow(cmp); colormap(gray); axis square;
msg = strcat("MSE: ", num2str(mse), " | PSNR: ", num2str(psnr));
title({"Compressed";msg});

% save image locally
imwrite(cmp, 'compressed.png');
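A vectorized alternative to the nested loops, sketched under stated assumptions. With wavedec3, `dc.dec` holds 1 + 7*lvl subbands (one approximation plus seven detail subbands per level), not four, and they are generally 3-D arrays, so loops over only two indices visit just the first page of each band. The sketch takes σ as each band's standard deviation, zeroes coefficients whose magnitude is at or below T (hard thresholding on |c| rather than on the signed value), and uses made-up bands so it runs without the Wavelet Toolbox.

```matlab
% Hard-threshold every band of a decomposition in one pass.
% Assumptions: bands is a cell array (e.g. dc.dec from wavedec3), sigma is the
% standard deviation of each band, and coefficients with |c| <= T are zeroed.
local_thr = @(b) std(double(b(:))) * sqrt(2 * log2(numel(b)));
hard_thr  = @(b, t) b .* (abs(b) > t);

% Made-up bands for illustration (one large coefficient each)
bands = {[0.2 -0.1 9; 0.3 -0.2 0.1], [0.1 -0.1 0.2 -0.2 0.1 -0.1 0.2 -10]};

thr       = cellfun(local_thr, bands);                    % one scalar per band
new_bands = cellfun(hard_thr, bands, num2cell(thr), 'UniformOutput', false);
```

With the post's variables, the same two functions could be applied as `thr = cellfun(local_thr, dc.dec); dc.dec = cellfun(hard_thr, dc.dec, num2cell(thr), 'UniformOutput', false);` before `waverec3(dc)`. Wavelet compression schemes commonly leave the approximation band `dc.dec{1}` untouched and threshold only the detail bands.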

Solution

  • I solved the problem. The σ in the local threshold formula is not the variance; it's the standard deviation. I applied these steps:

    1. Used std2() to find the standard deviation of my coefficient matrices (thanks to @Rotem for pointing this out)
    2. Used numel() to count the number of elements in the coefficient matrices

    This is a summary of the process; it's the same for the other bands (LH, HL, HH):

    [c, s] = wavedec2(image, wname, level);  % apply DWT
    ll = appcoef2(c, s, wname);              % extract LL (coarsest approximation)
    ll_std = std2(ll);                       % standard deviation of the band
    ll_n = numel(ll);                        % number of coefficients in LL
    ll_t = ll_std * sqrt(2 * log2(ll_n));    % local threshold formula
    ll_new = ll .* double(abs(ll) > ll_t);   % hard thresholding: keep |c| > T
    
    3. Replace the LL values in c with the thresholded ones (I used a for loop)
    4. Reconstruct the image by applying the IDWT with waverec2

    This is a sample output: [sample output images]
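The replacement of the LL values in c can be done without a loop. This is a minimal sketch assuming the standard wavedec2 layout, where c begins with the coarsest-level approximation coefficients stored column-wise and s(1,:) gives their size. The c and s values are made up (a 1-level layout for a 4x4 input) so the snippet runs without the Wavelet Toolbox, and std(ll(:)) stands in for std2(ll).

```matlab
% Write a thresholded approximation band back into the wavedec2 vector c
% without a loop. Assumption: standard wavedec2 layout, where c begins with
% the level-N approximation (stored column-wise) and s(1,:) is its size.
% c and s are made up here (1-level layout of a 4x4 input).
s = [2 2; 2 2; 4 4];
c = 1:16;                                    % 4 approximation + 3x4 detail coefficients

n_app  = prod(s(1,:));                       % number of approximation coefficients
ll     = reshape(c(1:n_app), s(1,:));        % what appcoef2 returns at the coarsest level
ll_t   = std(ll(:)) * sqrt(2 * log2(numel(ll)));   % the post's local threshold
ll_new = ll .* (abs(ll) > ll_t);             % hard thresholding
c(1:n_app) = ll_new(:)';                     % put the band back in place
```

Because prod(s(1,:)) also covers the three-column s that wavedec2 returns for truecolor input, the same indexing should work for color images.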