Search code examples
pythonoptimizationmarching-cubes

How can I speed up this program written in Python?


The following program is an implementation of the Marching Squares algorithm in Python:

from typing import List

def GetCaseId(Point_A_data: float, Point_B_data: float,
              Point_C_data: float, Point_D_data: float,
              threshold):
    """Encode which cell corners reach the threshold as a 4-bit mask.

    Corners A, B, C and D contribute the bits 1, 2, 4 and 8 respectively
    when their sample value is greater than or equal to ``threshold``.

    Returns:
        An integer between 0 (no corner reaches the threshold) and 15
        (every corner reaches it).
    """
    weighted_corners = (
        (Point_A_data, 1),
        (Point_B_data, 2),
        (Point_C_data, 4),
        (Point_D_data, 8),
    )
    mask = 0
    for sample, bit in weighted_corners:
        if sample >= threshold:
            mask |= bit
    return mask


def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
             a: float, b: float, c: float, d: float,
             threshold: float):
    """Return the contour segments crossing one grid cell for one threshold.

    ``Point_A`` .. ``Point_D`` are the cell corners, each indexed as
    ``[x, y]``, and ``a`` .. ``d`` are the sample values at those corners.
    The case id computed by ``GetCaseId`` selects the segments.  Segment end
    points are built only from corner coordinates and averages of corner
    coordinates; the sample values are not used to interpolate positions.

    Returns:
        A list of ``(pX, pY, qX, qY)`` tuples: empty for case ids 0 and 15,
        two segments for case ids 5 and 10, and one segment otherwise.
    """
    case = GetCaseId(a, b, c, d, threshold)

    # All corners on the same side of the threshold: nothing crosses the cell.
    if case == 0 or case == 15:
        return []

    segments = []

    # Corner A is the odd one out (10 is the two-segment case shared with C).
    if case in (1, 14, 10):
        segments.append((
            (Point_A[0] + Point_B[0]) / 2,
            Point_B[1],
            Point_D[0],
            (Point_A[1] + Point_D[1]) / 2,
        ))

    # Corner B is the odd one out (5 is the two-segment case shared with D).
    if case in (2, 13, 5):
        segments.append((
            (Point_A[0] + Point_B[0]) / 2,
            Point_A[1],
            Point_C[0],
            (Point_A[1] + Point_D[1]) / 2,
        ))

    # A and B on one side, C and D on the other.
    if case in (3, 12):
        segments.append((
            Point_A[0],
            (Point_A[1] + Point_D[1]) / 2,
            Point_C[0],
            (Point_B[1] + Point_C[1]) / 2,
        ))

    # Corner C is the odd one out (second segment of case 10).
    if case in (4, 11, 10):
        segments.append((
            (Point_C[0] + Point_D[0]) / 2,
            Point_D[1],
            Point_B[0],
            (Point_B[1] + Point_C[1]) / 2,
        ))

    # B and C on one side, A and D on the other.
    if case in (6, 9):
        segments.append((
            (Point_A[0] + Point_B[0]) / 2,
            Point_A[1],
            (Point_C[0] + Point_D[0]) / 2,
            Point_C[1],
        ))

    # Corner D is the odd one out (second segment of case 5).
    if case in (7, 8, 5):
        segments.append((
            (Point_C[0] + Point_D[0]) / 2,
            Point_C[1],
            Point_A[0],
            (Point_A[1] + Point_D[1]) / 2,
        ))

    return segments


def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
    """Collect the contour segments of every grid cell for every threshold.

    Args:
        x_int_list: x coordinate of each grid column.
        y_int_list: y coordinate of each grid row.
        data_2d_list: sample values indexed as ``data_2d_list[row][col]``.
            It must have ``len(y_int_list)`` rows and its first row must have
            ``len(x_int_list)`` entries.
        threshold_list: thresholds for which segments are generated.

    Returns:
        A one-element list whose only item is the list of all
        ``(pX, pY, qX, qY)`` segments, in row-major cell order and, within a
        cell, in ``threshold_list`` order.

    Raises:
        AssertionError: if the grid dimensions do not match the coordinate
            lists.
    """
    height = len(y_int_list)  # rows
    width = len(x_int_list)  # cols

    if width != len(data_2d_list[0]) or height != len(data_2d_list):
        raise AssertionError(
            "data_2d_list dimensions do not match len(y_int_list) x len(x_int_list)"
        )

    segments = []
    for j in range(height - 1):  # rows
        # Row j holds corners D and C; row j + 1 holds corners A and B.
        row_dc = data_2d_list[j]
        row_ab = data_2d_list[j + 1]
        y_dc = y_int_list[j]
        y_ab = y_int_list[j + 1]

        for i in range(width - 1):  # cols
            point_a = [x_int_list[i], y_ab]
            point_b = [x_int_list[i + 1], y_ab]
            point_c = [x_int_list[i + 1], y_dc]
            point_d = [x_int_list[i], y_dc]

            for threshold_item in threshold_list:
                # extend() grows the list in place; rebuilding it with
                # `segments = segments + ...` copied every segment found so
                # far on each call, which made the whole scan quadratic.
                segments.extend(
                    GetLines(point_a, point_b, point_c, point_d,
                             row_ab[i], row_ab[i + 1], row_dc[i + 1], row_dc[i],
                             threshold_item)
                )

    return [segments]

The problem with this code is that it takes ages to generate any output.

For example, with the following driver program:

import math

import drawSvg as draw_svg

# Driver script: samples a sine-product field on an N_int x N_int grid, runs
# marching_square on it for four thresholds and writes the segments to an SVG.

N_int = 800
N2_float = N_int / 8
x_int_vector = list(range(N_int))
y_int_vector = list(range(N_int))

# N_int x N_int samples of sin(i / N2_float) * sin(j / N2_float), indexed [j][i].
sample_matrix = [[math.sin(i / N2_float) * math.sin(j / N2_float) for i in range(N_int)]
                 for j in range(N_int)]

fill = "#2591a3"  # not referenced anywhere else in this script
drawing = draw_svg.Drawing(N_int, N_int, displayInline=False)

threshold_float_list = [0.2, 0.4, 0.6, 0.8]
# NOTE(review): marching_square is not imported in this listing; it is assumed
# to be in scope (defined in the same module or imported elsewhere).
collection = marching_square(x_int_vector, y_int_vector, sample_matrix, threshold_float_list)
for line_set in collection:
    for pX, pY, qX, qY in line_set:
        drawing.append(draw_svg.Line(pX, pY, qX, qY, stroke='red'))
drawing.saveSvg('example.svg')

The code is far too slow for practical use.

How can I speed up the code?

N.B. marching_square()'s signature must not be changed.


Solution

  • Got a ~10x speed-up:

    1. Removed the repeated list concatenation (`linesList = linesList + list`), which was the biggest bottleneck (using this trick to concatenate a list of lists).
    2. Applied numba to GetCaseId, which was the second-biggest bottleneck.
    from typing import List
    import numba
    import functools
    import operator
    
    @numba.jit(nopython=True)
    def GetCaseId(Point_A_data: float, Point_B_data: float,
                  Point_C_data: float, Point_D_data: float,
                  threshold):
        """Encode which cell corners reach the threshold as a 4-bit mask.

        Corners A, B, C and D contribute the bits 1, 2, 4 and 8 respectively
        when their sample value is greater than or equal to ``threshold``.
        """
        bit_a = 1 if Point_A_data >= threshold else 0
        bit_b = 2 if Point_B_data >= threshold else 0
        bit_c = 4 if Point_C_data >= threshold else 0
        bit_d = 8 if Point_D_data >= threshold else 0
        return bit_a | bit_b | bit_c | bit_d
    
    
    def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
                 a: float, b: float, c: float, d: float,
                 threshold: float):
        """Return the contour segments crossing one grid cell for one threshold.

        ``Point_A`` .. ``Point_D`` are the cell corners, each indexed as
        ``[x, y]``, and ``a`` .. ``d`` are the sample values at those corners.
        The case id computed by ``GetCaseId`` selects the segments.  Segment
        end points are built only from corner coordinates and averages of
        corner coordinates; the sample values are not used to interpolate.

        Returns:
            ``None`` for case ids 0 and 15; otherwise a list of
            ``(pX, pY, qX, qY)`` tuples holding two segments for case ids 5
            and 10 and one segment for every other case id.
        """
        case = GetCaseId(a, b, c, d, threshold)

        # All corners on the same side of the threshold: nothing to report.
        if case == 0 or case == 15:
            return None

        segments = []

        # Corner A is the odd one out (10 is the two-segment case shared with C).
        if case in (1, 14, 10):
            segments.append((
                (Point_A[0] + Point_B[0]) / 2,
                Point_B[1],
                Point_D[0],
                (Point_A[1] + Point_D[1]) / 2,
            ))

        # Corner B is the odd one out (5 is the two-segment case shared with D).
        if case in (2, 13, 5):
            segments.append((
                (Point_A[0] + Point_B[0]) / 2,
                Point_A[1],
                Point_C[0],
                (Point_A[1] + Point_D[1]) / 2,
            ))

        # A and B on one side, C and D on the other.
        if case in (3, 12):
            segments.append((
                Point_A[0],
                (Point_A[1] + Point_D[1]) / 2,
                Point_C[0],
                (Point_B[1] + Point_C[1]) / 2,
            ))

        # Corner C is the odd one out (second segment of case 10).
        if case in (4, 11, 10):
            segments.append((
                (Point_C[0] + Point_D[0]) / 2,
                Point_D[1],
                Point_B[0],
                (Point_B[1] + Point_C[1]) / 2,
            ))

        # B and C on one side, A and D on the other.
        if case in (6, 9):
            segments.append((
                (Point_A[0] + Point_B[0]) / 2,
                Point_A[1],
                (Point_C[0] + Point_D[0]) / 2,
                Point_C[1],
            ))

        # Corner D is the odd one out (second segment of case 5).
        if case in (7, 8, 5):
            segments.append((
                (Point_C[0] + Point_D[0]) / 2,
                Point_C[1],
                Point_A[0],
                (Point_A[1] + Point_D[1]) / 2,
            ))

        return segments
    
    
    def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
        """Collect the contour segments of every grid cell for every threshold.

        Args:
            x_int_list: x coordinate of each grid column.
            y_int_list: y coordinate of each grid row.
            data_2d_list: sample values indexed as ``data_2d_list[row][col]``.
                It must have ``len(y_int_list)`` rows and its first row must
                have ``len(x_int_list)`` entries.
            threshold_list: thresholds for which segments are generated.

        Returns:
            A flat list of ``(pX, pY, qX, qY)`` segments.

            NOTE(review): the earlier listing of this function returns the
            segment list wrapped in an outer one-element list, and the driver
            script iterates two levels deep; callers written against that
            shape need adjusting before they can use this version.

        Raises:
            AssertionError: if the grid dimensions do not match the
                coordinate lists.
        """
        n_rows = len(y_int_list)
        n_cols = len(x_int_list)

        # Guard clause: refuse grids whose shape disagrees with the axes.
        if n_cols != len(data_2d_list[0]) or n_rows != len(data_2d_list):
            raise AssertionError

        per_cell_segments = []
        for j in range(n_rows - 1):
            for i in range(n_cols - 1):
                value_a = data_2d_list[j + 1][i]
                value_b = data_2d_list[j + 1][i + 1]
                value_c = data_2d_list[j][i + 1]
                value_d = data_2d_list[j][i]

                corner_a = [x_int_list[i], y_int_list[j + 1]]
                corner_b = [x_int_list[i + 1], y_int_list[j + 1]]
                corner_c = [x_int_list[i + 1], y_int_list[j]]
                corner_d = [x_int_list[i], y_int_list[j]]

                for level in threshold_list:
                    found = GetLines(corner_a, corner_b, corner_c, corner_d,
                                     value_a, value_b, value_c, value_d,
                                     level)
                    # GetLines yields None when no segment crosses the cell.
                    if found:
                        per_cell_segments.append(found)

        # Flatten the per-cell lists once at the end instead of concatenating
        # inside the loop.
        return functools.reduce(operator.iconcat, per_cell_segments, [])