
Getting ValueError when trying to use np.piecewise for multivariate function


I'm trying to define a multivariate piecewise function using np.piecewise as follows:

import numpy as np

X = np.array([
    [1, 2],
    [3, 4],
    [5, 6]
])

pw = np.piecewise(
    X,
    [
        np.abs(X[:, 0] - X[:, 1]) < 1,
        np.abs(X[:, 0] - X[:, 1]) >= 1
    ],
    [
        lambda X: 1 + 2 * X[:, 0] + 3 * X[:, 1],
        lambda X: 1.5 + 2.5 * X[:, 0] + 3.5 * X[:, 1]
    ]
)

Running this snippet gives the following error:

ValueError: shape mismatch: value array of shape (3,) could not be broadcast to indexing result of shape (3,2)

For context, I'm attempting to represent a map f: R^2 -> R in this example, evaluating it on each of the rows of X at once.
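Written out with $x_1, x_2$ as the two columns of X, the map described by the two conditions and lambdas in the snippet reads:

```latex
f(x_1, x_2) =
\begin{cases}
1 + 2x_1 + 3x_2, & |x_1 - x_2| < 1,\\
1.5 + 2.5x_1 + 3.5x_2, & |x_1 - x_2| \ge 1.
\end{cases}
```

For example, the row (1, 2) has $|x_1 - x_2| = 1$, so the second branch applies: $1.5 + 2.5 \cdot 1 + 3.5 \cdot 2 = 11$.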

Any ideas? Do I need to define the last argument (the list of functions) differently so that the indexing broadcasts correctly?
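The two shapes in the error message can be traced with a simplified re-enactment of the assignment np.piecewise performs for each condition. This is a sketch, not NumPy's actual source: it assumes, as the np.piecewise documentation describes, that the output starts out with the same shape and dtype as the input and that each function is called on `x[cond]`. The helper name `second_branch` is only for illustration. With this X every row has |x1 - x2| = 1, so the first condition selects nothing and only the second branch reaches the assignment.

```python
import numpy as np

X = np.array([[1, 2], [3, 4], [5, 6]])
cond = np.abs(X[:, 0] - X[:, 1]) >= 1   # one flag per row, shape (3,)

def second_branch(R):
    # same formula as the second lambda in the question
    return 1.5 + 2.5 * R[:, 0] + 3.5 * R[:, 1]

y = np.zeros_like(X)              # output shaped like the input: (3, 2), integer dtype
selected = X[cond]                # a 1-D mask on a 2-D array selects whole rows: (3, 2)
vals = second_branch(selected)    # the formula collapses each row to one value: (3,)
print(y[cond].shape, vals.shape)  # (3, 2) (3,)

try:
    y[cond] = vals                # (3,) values cannot broadcast into the (3, 2) slot
except ValueError as exc:
    print("ValueError:", exc)
```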


Solution

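  • IMO np.piecewise is not a good fit here. Its documentation says each boolean array in condlist should have the same shape as x, each function is called on the portion of x that its condition selects, and the output has the same shape and dtype as x. Your conditions have shape (3,) while X has shape (3, 2), so X[cond] picks out whole rows of shape (k, 2), and the lambda's (k,)-shaped result cannot be written back into that (k, 2) slot; that is the shape mismatch in your error. (Even with matching shapes, the output would inherit X's integer dtype, so values such as 1.5 + ... would be truncated.) np.piecewise is more suitable when the conditions and the input array share one shape, e.g. grids built with np.meshgrid.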

    In your case, to represent a piecewise map $f:\mathbb{R}^2 \to \mathbb{R}$ on an input of shape (n, 2), evaluated row by row (each column is one variable), the simplest vectorized approach is np.select, which evaluates every choice on all rows and, for each row, takes the value from the first condition that is True:

    import numpy as np

    def pw(X):
        d = np.abs(X[:, 0] - X[:, 1])  # per-row |x1 - x2|, shape (n,)
        return np.select([d < 1, d >= 1],
                         [1 + 2 * X[:, 0] + 3 * X[:, 1], 1.5 + 2.5 * X[:, 0] + 3.5 * X[:, 1]])

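    and pw(X) returns one value per row. For the X above every row has |x1 - x2| = 1, so the second branch applies and the result is array([11., 23., 35.]).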
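If you would rather keep the question's lambdas, each receiving a block of rows, a small row-wise analogue of np.piecewise is one option. `piecewise_rows` below is a hypothetical helper, not a NumPy function, and is only a sketch under these assumptions: each condition has shape (n,), each function maps a (k, 2) block of rows to k values, rows matched by no condition get a default of 0.0, and when conditions overlap the later one wins because it is assigned last. Unlike np.select, each function runs only on the rows its condition selects.

```python
import numpy as np

def piecewise_rows(X, condlist, funclist, default=0.0):
    """Row-wise analogue of np.piecewise (hypothetical helper, not part of NumPy).

    X        -- array of shape (n, d), one point per row
    condlist -- boolean arrays of shape (n,), one flag per row
    funclist -- callables mapping a (k, d) block of rows to k values
    Rows matched by no condition keep `default`; if conditions overlap,
    the later one wins because it is assigned last.
    """
    X = np.asarray(X)
    out = np.full(X.shape[0], default, dtype=float)  # float output, one value per row
    for cond, func in zip(condlist, funclist):
        cond = np.asarray(cond, dtype=bool)
        if cond.any():
            out[cond] = func(X[cond])  # func sees only its own rows
    return out


X = np.array([[1, 2], [3, 4], [5, 6]])
d = np.abs(X[:, 0] - X[:, 1])
result = piecewise_rows(
    X,
    [d < 1, d >= 1],
    [
        lambda R: 1 + 2 * R[:, 0] + 3 * R[:, 1],
        lambda R: 1.5 + 2.5 * R[:, 0] + 3.5 * R[:, 1],
    ],
)
print(result)  # [11. 23. 35.]
```

With the question's X this gives the same values as the np.select version; the float output also avoids the integer truncation that np.piecewise's input-shaped output would cause.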