What I am trying to do is make a table based on a piece-wise function in Python. For example, say I wrote this code:
import numpy as np
from astropy.table import Table, Column
from astropy.io import ascii
x = np.array([1, 2, 3, 4, 5])
y = x * 2
data = Table([x, y], names = ['x', 'y'])
ascii.write(data, "xytable.dat")
xytable = ascii.read("xytable.dat")
print(xytable)
This works as expected: it prints a table with x values 1 through 5 and y values 2, 4, 6, 8, 10.
But what if I instead want y to be x * 2 only if x is 3 or less, and y to be x + 2 otherwise?
If I add:
if x > 3:
    y = x + 2
it says:
The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
How do I code my table so that it works as a piece-wise function? How do I compare scalars to NumPy arrays?
You can use numpy.where() for this:
In [196]: y = np.where(x > 3, x + 2, y)
In [197]: y
Out[197]: array([2, 4, 6, 6, 7])
The code above gets the job done in a fully vectorized manner. This approach is generally more efficient (and arguably more elegant) than using list comprehensions and type conversions.
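For completeness, here is a minimal self-contained sketch of the same idea, starting from the x array in the question. The name mask is only an illustrative variable, not something from the original code. It also shows why the plain if statement fails: comparing an array to a scalar gives one boolean per element rather than a single True or False.

```python
import numpy as np

x = np.array([1, 2, 3, 4, 5])

# Comparing an array to a scalar is elementwise and yields a boolean array,
# so a plain `if x > 3:` has no single truth value to test.
mask = x > 3

# np.where(condition, a, b) takes elements from a where the condition
# holds and from b everywhere else.
y = np.where(mask, x + 2, x * 2)

print(mask)  # [False False False  True  True]
print(y)     # [2 4 6 6 7]
```

The resulting y can then be passed to Table([x, y], names=['x', 'y']) just as in the question.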