python arrays variables numpy astropy

Comparing scalars to NumPy arrays


What I am trying to do is make a table based on a piece-wise function in Python. For example, say I wrote this code:

import numpy as np
from astropy.table import Table, Column
from astropy.io import ascii
x = np.array([1, 2, 3, 4, 5])
y = x * 2
data = Table([x, y], names = ['x', 'y'])
ascii.write(data, "xytable.dat")
xytable = ascii.read("xytable.dat")
print(xytable)

This works as expected: it prints a table with x values 1 through 5 and y values 2, 4, 6, 8, 10.

But what if I instead want y to be x * 2 only where x is 3 or less, and x + 2 otherwise?

If I add:

if x > 3: 
    y = x + 2

it says:

The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

How do I build my table so that y follows a piece-wise function? How do I compare a NumPy array to a scalar?


Solution

  • You can use numpy.where():

    In [196]: y = np.where(x > 3, x + 2, y)
    
    In [197]: y
    Out[197]: array([2, 4, 6, 6, 7])
    

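    The code above handles the whole array in one vectorized call: np.where(condition, a, b) takes elements from a where the condition is True and from b everywhere else. This approach is generally more efficient (and arguably more elegant) than using list comprehensions and type conversions.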
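
For reference, here is a minimal, self-contained sketch of the same idea that can be run outside an interactive session. It also shows what the comparison x > 3 actually produces, which is why a plain if statement cannot use it directly. The names mask, y_where, and y_mask are illustrative only and do not come from the question; the boolean-mask assignment is shown as one possible alternative, not as the only way to do it.

```python
import numpy as np

x = np.array([1, 2, 3, 4, 5])

# Comparing an array to a scalar is done element by element and returns
# a boolean array, not a single True/False. That is why `if x > 3:` fails.
mask = x > 3
print(mask)      # [False False False  True  True]

# Option 1: np.where takes x + 2 where mask is True and x * 2 elsewhere.
y_where = np.where(mask, x + 2, x * 2)
print(y_where)   # [2 4 6 6 7]

# Option 2 (alternative sketch): compute the default branch first,
# then overwrite only the entries selected by the mask.
y_mask = x * 2
y_mask[mask] = x[mask] + 2
print(y_mask)    # [2 4 6 6 7]
```

Either result can be passed to Table([x, y], names=['x', 'y']) exactly as in the question, since both are ordinary NumPy arrays of the same length as x.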