
NaN type in Python after SymPy substitution


I am trying to substitute values into a SymPy expression, but after the substitution I get NaN and I don't understand why.

Here is the code:

import sympy as sp
import copy
import numpy as np
import itertools as it
import matplotlib.pyplot as plt

alpha_set_values = np.linspace(0, 5, 10000)
beta_set_values = np.linspace(0, 6, 10000)

# alpha, beta (SymPy symbols), expr1 and points are defined in earlier notebook cells
def plot_expr(exprVal, points):
    for point in points:
        # each point is (beta value, alpha value)
        value = exprVal.subs([(beta, point[0]), (alpha, point[1])])
        print(type(value))
        # plot the numeric point, not the symbols beta and alpha
        if value > 0:
            plt.scatter([point[0]], [point[1]], color='r')
        else:
            plt.scatter([point[0]], [point[1]], color='b')

    plt.show()

plot_expr(expr1, points)

expr1 is a SymPy expression in the symbols alpha and beta: α*(1 - 0.1/β) + α - 0.3*α/β + 2 - 1.9*(α*β - 0.1*α - β)/β. After the substitution, the type of value is NaN.

For the full code, here is a Google Colab link. The last two cells are the relevant ones and must be run; the error occurs in the last cell.


Solution

  • You are getting NaN because your first point is (0, 0), i.e. both beta and alpha are 0. Take a look at your expression, expr1:

    α*(1 - 0.1/β) + α - 0.3*α/β + 2 - 1.9*(α*β - 0.1*α - β)/β
    

    In particular, the terms:

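    • α*(1 - 0.1/β): after the substitution, SymPy evaluates 0*(1 - zoo). Here 0.1/0 gives zoo (SymPy's complex infinity), 1 - zoo is still zoo, and 0*zoo is undefined, so the result is NaN.
    • Similarly, α/β becomes 0/0, which is also NaN.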
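To see this directly, here is a minimal sketch that rebuilds expr1 from the expression quoted above (assumed to match the one in the notebook) and substitutes (0, 0) in the same order as the question, beta first and then alpha. It also shows the underlying SymPy rule that zero times zoo is NaN.

```python
import sympy as sp

# Assumption: expr1 as quoted in the answer; the notebook version may differ
alpha, beta = sp.symbols("alpha beta")
expr1 = (alpha*(1 - 0.1/beta) + alpha - 0.3*alpha/beta
         + 2 - 1.9*(alpha*beta - 0.1*alpha - beta)/beta)

# Same ordered substitution as in the question: beta first, then alpha
value = expr1.subs([(beta, 0), (alpha, 0)])
print(value, type(value))

# The underlying rule: zero times complex infinity (zoo) is undefined
print(sp.S.Zero * sp.zoo)  # nan

# NaN has no sign, so check for it explicitly before comparing with 0
if value is sp.nan:
    print("expr1 is undefined at (0, 0); skip this point")
```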

    I see what you are trying to do: to show the region where expr1 > 0, the simplest approach is plot_implicit:

    sp.plot_implicit(expr1 > 0, (beta, 0, 6), (alpha, 0, 5))
    

    Alternatively, if you want to keep your approach, start the ranges from a small positive value instead of zero. Also, to speed up computation and plotting, use sp.lambdify to convert the symbolic expression into a NumPy function that evaluates the whole grid in one call:

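    # start slightly above zero to avoid dividing by zero at beta = 0
    n = 100
    alpha_set_values = np.linspace(1e-06, 5, n)
    beta_set_values = np.linspace(1e-06, 6, n)
    alpha_set_values, beta_set_values = np.meshgrid(alpha_set_values, beta_set_values)
    
    # convert the symbolic expression to a NumPy function and evaluate the whole grid at once
    f = sp.lambdify([alpha, beta], expr1)
    res = f(alpha_set_values, beta_set_values)
    
    # flatten the 2-D grids into 1-D arrays for plt.scatter
    alpha_set_values = alpha_set_values.flatten()
    beta_set_values = beta_set_values.flatten()
    res = res.flatten()
    idx = res > 0  # boolean mask: True where expr1 > 0
    
    plt.figure()
    plt.scatter(beta_set_values[idx], alpha_set_values[idx], color="r")  # expr1 > 0
    plt.scatter(beta_set_values[~idx], alpha_set_values[~idx], color="b")  # expr1 <= 0
    plt.show()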
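For reference, here is a self-contained, plot-free sketch of the same lambdify approach. It assumes expr1 is the expression quoted above, and the helper name positive_mask is hypothetical. It also lets you confirm that starting the grid at 1e-06 leaves no NaN values.

```python
import numpy as np
import sympy as sp

# Assumption: symbols and expression reconstructed from the question
alpha, beta = sp.symbols("alpha beta")
expr1 = (alpha*(1 - 0.1/beta) + alpha - 0.3*alpha/beta
         + 2 - 1.9*(alpha*beta - 0.1*alpha - beta)/beta)

def positive_mask(expr, n=100, eps=1e-06):
    """Hypothetical helper: evaluate expr on an n x n grid of (alpha, beta).

    Returns flat arrays (beta values, alpha values, expr values, mask of expr > 0).
    """
    # start just above zero so 1/beta stays finite everywhere on the grid
    a, b = np.meshgrid(np.linspace(eps, 5, n), np.linspace(eps, 6, n))
    f = sp.lambdify([alpha, beta], expr, modules="numpy")
    res = f(a, b)
    return b.ravel(), a.ravel(), res.ravel(), res.ravel() > 0

b_vals, a_vals, res, mask = positive_mask(expr1)
print(mask.sum(), "of", mask.size, "grid points have expr1 > 0")
print("any NaN:", np.isnan(res).any())
```

The returned arrays can be passed to plt.scatter with mask and ~mask in the same way as in the snippet above.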