
Robust image segmentation in OpenCV


I'm trying to write an OpenCV program that counts fish eggs for someone else. It currently takes their uploaded image, normalizes, blurs, thresholds, dilates, distance transforms, thresholds again, and then finds contours (like in a typical watershed tutorial).

The problem I'm having is that the lighting conditions can vary quite a bit, so even with my adaptive threshold values, the accuracy of the algorithm varies wildly. If there's a brightness gradient across the image, it does especially poorly. Sometimes the objects are very bright against the background, and other times they have almost the same luminosity. Are there any particularly effective ways to find objects under varying lighting conditions?

Sample images: img gif


Solution

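  • Because anything larger than 100 pixels isn't relevant to your image, I would construct a Fourier band-pass filter to remove these structures. A slow brightness gradient across the image is itself a large-scale structure, so the same filter should suppress much of it.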

    Here is an implementation I use, based on the one in ImageJ. In this implementation the input image is mirror-padded to reduce edge artifacts.

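    #include <algorithm>   // std::max
    #include <cmath>       // exp, isinf
    #include <stdexcept>   // std::runtime_error
    #include <thrust/host_vector.h>

    // BandPassSettings and FrameSize are project types that are not shown here.
    static void GenerateBandFilter(thrust::host_vector<float>& filter, const BandPassSettings& band, const FrameSize& frame)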
        {
            //From https://imagej.nih.gov/ij/plugins/fft-filter.html
            if (band.do_band_pass == false)
            {
                return;
            }
            if (frame.width != frame.height)
            {
                throw std::runtime_error("Frame height and width should be the same");
            }
            auto maxN = static_cast<int>(std::max(frame.width, frame.height));// width == height is enforced above
    
            auto filterLargeC = 2.0f*band.max_dx / maxN;
            auto filterSmallC = 2.0f*band.min_dx / maxN;
            auto scaleLargeC = filterLargeC*filterLargeC;
            auto scaleSmallC = filterSmallC*filterSmallC;
    
            auto filterLargeR = 2.0f*band.max_dy / maxN;
            auto filterSmallR = 2.0f*band.min_dy / maxN;
            auto scaleLargeR = filterLargeR*filterLargeR;
            auto scaleSmallR = filterSmallR*filterSmallR;
    
            // loop over rows
            for (auto j = 1; j < maxN / 2; j++)
            {
                auto row = j * maxN;
                auto backrow = (maxN - j)*maxN;
                auto rowFactLarge = exp(-(j*j) * scaleLargeR);
                auto rowFactSmall = exp(-(j*j) * scaleSmallR);
                // loop over columns
                for (auto col = 1; col < maxN / 2; col++)
                {
                    auto backcol = maxN - col;
                    auto colFactLarge = exp(-(col*col) * scaleLargeC);
                    auto colFactSmall = exp(-(col*col) * scaleSmallC);
                    auto factor = (((1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall));
                    filter[col + row] *= factor;
                    filter[col + backrow] *= factor;
                    filter[backcol + row] *= factor;
                    filter[backcol + backrow] *= factor;
                }
            }
            auto fixy = [&](float t){return isinf(t) ? 0 : t; };
            auto rowmid = maxN * (maxN / 2);
            auto rowFactLarge = fixy(exp(-(maxN / 2)*(maxN / 2) * scaleLargeR));
            auto rowFactSmall = fixy(exp(-(maxN / 2)*(maxN / 2) *scaleSmallR));
            filter[maxN / 2] *= ((1 - rowFactLarge) * rowFactSmall);
            filter[rowmid] *= ((1 - rowFactLarge) * rowFactSmall);
            filter[maxN / 2 + rowmid] *= ((1 - rowFactLarge*rowFactLarge) * rowFactSmall*rowFactSmall); //
            rowFactLarge = fixy(exp(-(maxN / 2)*(maxN / 2) *scaleLargeR));
            rowFactSmall = fixy(exp(-(maxN / 2)*(maxN / 2) *scaleSmallR));
            for (auto col = 1; col < maxN / 2; col++){
                auto backcol = maxN - col;
                auto colFactLarge = exp(-(col*col) * scaleLargeC);
                auto colFactSmall = exp(-(col*col) * scaleSmallC);
                filter[col] *= ((1 - colFactLarge) * colFactSmall);
                filter[backcol] *= ((1 - colFactLarge) * colFactSmall);
                filter[col + rowmid] *= ((1 - colFactLarge*rowFactLarge) * colFactSmall*rowFactSmall);
                filter[backcol + rowmid] *= ((1 - colFactLarge*rowFactLarge) * colFactSmall*rowFactSmall);
            }
            // loop along column 0 and expanded_width/2
            auto colFactLarge = fixy(exp(-(maxN / 2)*(maxN / 2) * scaleLargeC));
            auto colFactSmall = fixy(exp(-(maxN / 2)*(maxN / 2) * scaleSmallC));
            for (auto j = 1; j < maxN / 2; j++) {
                auto row = j * maxN;
                auto backrow = (maxN - j)*maxN;
                // j indexes rows, so use the row (dy) scales, as in the main loop above
                rowFactLarge = exp(-(j*j) * scaleLargeR);
                rowFactSmall = exp(-(j*j) * scaleSmallR);
                filter[row] *= ((1 - rowFactLarge) * rowFactSmall);
                filter[backrow] *= ((1 - rowFactLarge) * rowFactSmall);
                filter[row + maxN / 2] *= ((1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall);
                filter[backrow + maxN / 2] *= ((1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall);
            }
            filter[0] = (band.remove_dc) ? 0 : filter[0];
        }
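
The factor applied at each frequency in the function above is the product of a Gaussian high-pass term and a Gaussian low-pass term. As a minimal sketch, here is the same expression along a single axis, which makes it easier to check what a given pair of size limits does. `bandPassWeight` is a hypothetical helper name used only for illustration; it is not part of ImageJ or of the linked repository.

```cpp
#include <cassert>
#include <cmath>

// Weight for integer frequency index k on an axis of n samples.
// minSize / maxSize: smallest and largest structure size (in pixels) to keep.
// Hypothetical helper; the formula mirrors the factor used in GenerateBandFilter.
float bandPassWeight(int k, int n, float minSize, float maxSize)
{
    assert(n > 0);
    const float filterLarge = 2.0f * maxSize / n;
    const float filterSmall = 2.0f * minSize / n;
    const float scaleLarge = filterLarge * filterLarge;
    const float scaleSmall = filterSmall * filterSmall;
    const float k2 = static_cast<float>(k) * static_cast<float>(k);
    const float highPass = 1.0f - std::exp(-k2 * scaleLarge); // damps large structures
    const float lowPass = std::exp(-k2 * scaleSmall);         // damps fine detail
    return highPass * lowPass;
}
```

With example values n = 512, minSize = 3 and maxSize = 100, the weight is 0 at k = 0, roughly 0.14 at k = 1, and about 0.95 at k = 20, so slowly varying background is attenuated while egg-sized detail passes through. Note that `GenerateBandFilter` multiplies into `filter` in place, so the vector has to be initialised (for example to all ones) before the call.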
    


    You can poke around my code that uses it here: https://github.com/kandel3/DPM_PhaseRetrieval
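
Since `GenerateBandFilter` throws unless the frame is square, and the answer mentions mirror padding, here is a rough sketch of how an image could be mirror-padded to a square before the FFT. This is an assumption about one reasonable way to do it, not the code from the repository: `reflectIndex` and `mirrorPadSquare` are hypothetical names, the image is a row-major `std::vector<float>`, and the reflection repeats the edge pixel (the same convention as OpenCV's `BORDER_REFLECT`).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Map any integer index onto [0, n) by reflecting at the borders
// (edge sample repeated: ... 1 0 | 0 1 2 ... n-1 | n-1 n-2 ...).
int reflectIndex(int i, int n)
{
    assert(n > 0);
    const int period = 2 * n;
    int m = i % period;
    if (m < 0) m += period;
    return (m < n) ? m : period - 1 - m;
}

// Centre a width x height image in a paddedN x paddedN canvas and fill
// the border by mirroring. Hypothetical helper for illustration only.
std::vector<float> mirrorPadSquare(const std::vector<float>& src,
                                   int width, int height, int paddedN)
{
    assert(width > 0 && height > 0);
    assert(paddedN >= width && paddedN >= height);
    assert(src.size() == static_cast<std::size_t>(width) * height);
    const int offX = (paddedN - width) / 2;
    const int offY = (paddedN - height) / 2;
    std::vector<float> dst(static_cast<std::size_t>(paddedN) * paddedN);
    for (int y = 0; y < paddedN; ++y) {
        const int sy = reflectIndex(y - offY, height);
        for (int x = 0; x < paddedN; ++x) {
            const int sx = reflectIndex(x - offX, width);
            dst[static_cast<std::size_t>(y) * paddedN + x] =
                src[static_cast<std::size_t>(sy) * width + sx];
        }
    }
    return dst;
}
```

After filtering, the original region can be cropped back out using the same offsets. Padding to a power-of-two size is a common choice for FFT speed, but that is a performance preference rather than something this sketch requires.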