
Why is the sum of a float32 array and a float64 scalar a float32 array in NumPy?


If one adds a float32 scalar and a float64 scalar, the result is float64: the float32 is promoted to float64.
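For two dtypes on their own, this promotion can be checked with np.promote_types, which looks only at dtypes and never at values. A minimal sketch, reusing the scalars a and b from the script below:

```python
import numpy as np

a = np.float32(1.1)
b = np.float64(2.2)

# promote_types combines two dtypes without inspecting any values.
print(np.promote_types(np.float32, np.float64))  # float64

# Adding the two NumPy scalars yields the same promoted dtype.
print((a + b).dtype)  # float64
```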

However, I find that adding a float32 array and a float64 scalar gives a float32 array, rather than the float64 array one might expect.

I have written some code to reproduce the problem.

My question is: why is the dtype of np.add(np.array([a]), b) float32?

import numpy as np
import sys

print(sys.version_info) # sys.version_info(major=3, minor=10, micro=9, releaselevel='final', serial=0)
print(np.__version__)   # '1.24.1'

a = np.float32(1.1)
b = np.float64(2.2)

print((a+b).dtype)                                # float64
print(np.add(a, np.array([b])).dtype)             # float64
print(np.add(np.array([a]), np.array([b])).dtype) # float64
print(np.add(np.array([a]), b).dtype)             # float32 ?? expected float64

This contradicts the documentation of np.add (https://numpy.org/doc/stable/reference/generated/numpy.add.html), which says in the Notes:

Equivalent to x1 + x2 in terms of array broadcasting.

Yet broadcasting the operands explicitly first gives float64:

x, y = np.broadcast_arrays(np.array([a]), b)
print(np.add(x, y).dtype) # float64
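The difference is that np.broadcast_arrays turns the 0-d scalar b into a 1-d float64 array, so np.add no longer sees a scalar operand at all and the scalar-specific promotion rules do not come into play. A small check of what broadcast_arrays returns, reusing a and b from the question:

```python
import numpy as np

a = np.float32(1.1)
b = np.float64(2.2)

x, y = np.broadcast_arrays(np.array([a]), b)

# broadcast_arrays keeps each input's dtype but gives both the broadcast shape.
print(x.dtype, x.shape)  # float32 (1,)
print(y.dtype, y.shape)  # float64 (1,)

# Both operands are now 1-d arrays, so ordinary dtype promotion applies.
print(np.add(x, y).dtype)  # float64
```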

Solution

  • Why is the dtype of np.add(np.array([a]), b) float32?

    Because, before NEP 50, NumPy's promotion rules involving scalars were confusing*. If you turn on the NEP 50 rules, which are already available as an opt-in in your version of NumPy, the behavior is as you expect:

    import numpy as np
    import sys
    
    # Use this to turn on the NEP 50 promotion rules
    np._set_promotion_state("weak")
    
    print(sys.version_info) # sys.version_info(major=3, minor=10, micro=9, releaselevel='final', serial=0)
    print(np.__version__)   # '1.24.0'
    
    a = np.float32(1.1)
    b = np.float64(2.2)
    
    print((a+b).dtype)                                # float64
    print(np.add(a, np.array([b])).dtype)             # float64
    print(np.add(np.array([a]), np.array([b])).dtype) # float64
    print(np.add(np.array([a]), b).dtype)             # float64
    

    These rules are the default from NumPy 2.0 onward, so no opt-in is needed there.

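    *: More specifically, the old rules depend on the values, not just the dtypes. If you keep the types from the script but give both scalars the value np.finfo(np.float32).max (that is, a = np.float32(np.finfo(np.float32).max) and b = np.float64(np.finfo(np.float32).max)), the result would overflow if kept in float32, so you would get float64 even with the old rules.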
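The value dependence described in the footnote can be modelled roughly as follows. This is a simplified sketch, not NumPy's actual implementation: the hypothetical helpers legacy_result_dtype and nep50_result_dtype only cover a float array combined with a float NumPy scalar, and np.min_scalar_type (the smallest dtype able to hold a scalar's value) stands in for the old value-based step.

```python
import numpy as np


def legacy_result_dtype(array_dtype, scalar):
    """Hypothetical model of the pre-NEP 50 rule (float array, float scalar).

    The scalar's own dtype is ignored; only the smallest dtype that can
    hold its value takes part in the promotion.
    """
    return np.promote_types(array_dtype, np.min_scalar_type(scalar))


def nep50_result_dtype(array_dtype, scalar):
    """Hypothetical model of the NEP 50 rule for a NumPy scalar.

    The scalar keeps its own dtype and its value is ignored.
    """
    return np.promote_types(array_dtype, scalar.dtype)


small = np.float64(2.2)    # value fits in float16, so the array's float32 wins
large = np.float64(1e300)  # value does not fit in float32, so float64 wins

print(legacy_result_dtype(np.float32, small))  # float32
print(legacy_result_dtype(np.float32, large))  # float64
print(nep50_result_dtype(np.float32, small))   # float64
```

Under this model the old outcome hinges on the scalar's value, while the NEP 50 outcome depends only on the scalar's dtype, which is why np.float64(2.2) keeps the result at float64 once the new rules are on.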