The following program is a solution to the Marching Square problem in Python:
from typing import List
def GetCaseId(Point_A_data: float, Point_B_data: float,
Point_C_data: float, Point_D_data: float,
threshold):
caseId = 0
if (Point_A_data >= threshold):
caseId |= 1
if (Point_B_data >= threshold):
caseId |= 2
if (Point_C_data >= threshold):
caseId |= 4
if (Point_D_data >= threshold):
caseId |= 8
return caseId
def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
a: float, b: float, c: float, d: float,
threshold: float):
lines = []
caseId = GetCaseId(a, b, c, d, threshold)
if caseId in (0, 15):
return []
if caseId in (1, 14, 10):
pX = (Point_A[0] + Point_B[0]) / 2
pY = Point_B[1]
qX = Point_D[0]
qY = (Point_A[1] + Point_D[1]) / 2
line = (pX, pY, qX, qY)
lines.append(line)
if caseId in (2, 13, 5):
pX = (Point_A[0] + Point_B[0]) / 2
pY = Point_A[1]
qX = Point_C[0]
qY = (Point_A[1] + Point_D[1]) / 2
line = (pX, pY, qX, qY)
lines.append(line)
if caseId in (3, 12):
pX = Point_A[0]
pY = (Point_A[1] + Point_D[1]) / 2
qX = Point_C[0]
qY = (Point_B[1] + Point_C[1]) / 2
line = (pX, pY, qX, qY)
lines.append(line)
if caseId in (4, 11, 10):
pX = (Point_C[0] + Point_D[0]) / 2
pY = Point_D[1]
qX = Point_B[0]
qY = (Point_B[1] + Point_C[1]) / 2
line = (pX, pY, qX, qY)
lines.append(line)
elif caseId in (6, 9):
pX = (Point_A[0] + Point_B[0]) / 2
pY = Point_A[1]
qX = (Point_C[0] + Point_D[0]) / 2
qY = Point_C[1]
line = (pX, pY, qX, qY)
lines.append(line)
elif caseId in (7, 8, 5):
pX = (Point_C[0] + Point_D[0]) / 2
pY = Point_C[1]
qX = Point_A[0]
qY = (Point_A[1] + Point_D[1]) / 2
line = (pX, pY, qX, qY)
lines.append(line)
return lines
def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
linesList = []
Height = len(y_int_list) # rows
Width = len(x_int_list) # cols
if ((Width == len(data_2d_list[0])) and (Height == len(data_2d_list))):
for j in range(Height - 1): # rows
for i in range(Width - 1): # cols
point_a_data_float = data_2d_list[j + 1][i]
point_b_data_float = data_2d_list[j + 1][i + 1]
point_c_data_float = data_2d_list[j][i + 1]
point_d_data_float = data_2d_list[j][i]
point_A = [x_int_list[i], y_int_list[j + 1]]
point_B = [x_int_list[i + 1], y_int_list[j + 1]]
point_C = [x_int_list[i + 1], y_int_list[j]]
point_D = [x_int_list[i], y_int_list[j]]
for threshold_item in threshold_list:
list = GetLines(point_A, point_B, point_C, point_D,
point_a_data_float, point_b_data_float, point_c_data_float, point_d_data_float,
threshold_item)
linesList = linesList + list
else:
raise AssertionError
return [linesList]
The problem with this source code is - it takes ages to generate an output.
I.e. using the following driver program:
import drawSvg as draw_svg
N_int = 800
N2_float = N_int / 8
x_int_vector = [i for i in range(N_int)]
y_int_vector = [i for i in range(N_int)]
matrix_256x256 = [[(math.sin(i / N2_float) * math.sin(j / N2_float)) for i in range(N_int)] for j in range(N_int)]
fill = "#2591a3"
drawing = draw_svg.Drawing(N_int, N_int, displayInline=False)
threshold_float_list = [0.2, 0.4, 0.6, 0.8]
collection = marching_square(x_int_vector, y_int_vector, matrix_256x256, threshold_float_list)
for line_set in collection:
for line in line_set:
drawing.append(draw_svg.Line(line[0], line[1], line[2], line[3], stroke='red'))
# END of line
# END of line_set
drawing.saveSvg('example.svg')
The code becomes horribly slow for practical use.
How can I speed up the code?
N.B. marching_square()
's signature must not be changed.
Got ~10x speed-up
numba
to GetCaseId
which was second bottleneckfrom typing import List
import numba
import functools
import operator
@numba.jit(nopython=True)
def GetCaseId(Point_A_data: float, Point_B_data: float,
Point_C_data: float, Point_D_data: float,
threshold):
caseId = 0
if (Point_A_data >= threshold):
caseId |= 1
if (Point_B_data >= threshold):
caseId |= 2
if (Point_C_data >= threshold):
caseId |= 4
if (Point_D_data >= threshold):
caseId |= 8
return caseId
def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
a: float, b: float, c: float, d: float,
threshold: float):
lines = []
caseId = GetCaseId(a, b, c, d, threshold)
if caseId in (0, 15):
return None
if caseId in (1, 14, 10):
pX = (Point_A[0] + Point_B[0]) / 2
pY = Point_B[1]
qX = Point_D[0]
qY = (Point_A[1] + Point_D[1]) / 2
line = (pX, pY, qX, qY)
lines.append(line)
if caseId in (2, 13, 5):
pX = (Point_A[0] + Point_B[0]) / 2
pY = Point_A[1]
qX = Point_C[0]
qY = (Point_A[1] + Point_D[1]) / 2
line = (pX, pY, qX, qY)
lines.append(line)
if caseId in (3, 12):
pX = Point_A[0]
pY = (Point_A[1] + Point_D[1]) / 2
qX = Point_C[0]
qY = (Point_B[1] + Point_C[1]) / 2
line = (pX, pY, qX, qY)
lines.append(line)
if caseId in (4, 11, 10):
pX = (Point_C[0] + Point_D[0]) / 2
pY = Point_D[1]
qX = Point_B[0]
qY = (Point_B[1] + Point_C[1]) / 2
line = (pX, pY, qX, qY)
lines.append(line)
elif caseId in (6, 9):
pX = (Point_A[0] + Point_B[0]) / 2
pY = Point_A[1]
qX = (Point_C[0] + Point_D[0]) / 2
qY = Point_C[1]
line = (pX, pY, qX, qY)
lines.append(line)
elif caseId in (7, 8, 5):
pX = (Point_C[0] + Point_D[0]) / 2
pY = Point_C[1]
qX = Point_A[0]
qY = (Point_A[1] + Point_D[1]) / 2
line = (pX, pY, qX, qY)
lines.append(line)
return lines
def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
linesList = []
Height = len(y_int_list) # rows
Width = len(x_int_list) # cols
if ((Width == len(data_2d_list[0])) and (Height == len(data_2d_list))):
for j in range(Height - 1): # rows
for i in range(Width - 1): # cols
point_a_data_float = data_2d_list[j + 1][i]
point_b_data_float = data_2d_list[j + 1][i + 1]
point_c_data_float = data_2d_list[j][i + 1]
point_d_data_float = data_2d_list[j][i]
point_A = [x_int_list[i], y_int_list[j + 1]]
point_B = [x_int_list[i + 1], y_int_list[j + 1]]
point_C = [x_int_list[i + 1], y_int_list[j]]
point_D = [x_int_list[i], y_int_list[j]]
for threshold_item in threshold_list:
list = GetLines(point_A, point_B, point_C, point_D,
point_a_data_float, point_b_data_float, point_c_data_float, point_d_data_float,
threshold_item)
if list:
linesList.append(list)
else:
raise AssertionError
return functools.reduce(operator.iconcat, linesList, [])