Background:
I am trying to implement a function that performs inverse transform sampling. I use sympy to calculate the CDF and to get its inverse function. For some simple PDFs I get correct results, but for a PDF whose CDF has an inverse involving the Lambert W function, the results are wrong.
Example:
Consider the following example CDF:
import sympy as sym
y = sym.Symbol('y')
cdf = (-y - 1) * sym.exp(-y) + 1 # derived from `pdf = x * sym.exp(-x)`
sym.plot(cdf, (y, -1, 5))
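For reference, here is a minimal sketch that checks the CDF above against the PDF mentioned in the comment. It assumes the distribution's support starts at 0; the names `t` and `derived_cdf` are illustrative only and do not appear elsewhere in this post.

```python
import sympy as sym

t = sym.Symbol('t')  # illustrative integration variable (assumption)
y = sym.Symbol('y')

pdf = t * sym.exp(-t)

# Integrate the PDF from 0 (assumed lower bound of the support) up to y.
derived_cdf = sym.integrate(pdf, (t, 0, y))

# The CDF as written in the question.
cdf = (-y - 1) * sym.exp(-y) + 1

# If both expressions agree, their difference simplifies to zero.
difference = sym.simplify(derived_cdf - cdf)
print(difference)
```

Differentiating `cdf` with respect to `y` should likewise give back `y * exp(-y)`.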
Now calculating the inverse of this function:
x = sym.Symbol('x')
inverse = sym.solve(sym.Eq(x, cdf), y)
print(inverse)
Output:
[-LambertW((x - 1)*exp(-1)) - 1]
This, in fact, is only the left branch of the CDF, the one for negative y's:
sym.plot(inverse[0], (x, -0.5, 1))
Question: How can I get the right branch of the CDF, the one for positive y's?
What I tried:
Specifying x and y to be only positive:
x = sym.Symbol('x', positive=True)
y = sym.Symbol('y', positive=True)
This doesn't have any effect, even for the first CDF plot.
Making the CDF a Piecewise function:
cdf = sym.Piecewise((0, y < 0),
((-y - 1) * sym.exp(-y) + 1, True))
Again no effect. The strange thing here is that on another computer, plotting this function gave a proper graph with zero for negative y's, but solving for the positive y's branch doesn't work anywhere. (Different versions? I also had to pass adaptive=False to sympy.plot to make it work there.)
Using sympy.solveset instead of sympy.solve:
This just gives a useless ConditionSet(y, Eq(x*exp(y) + y - exp(y) + 1, 0), Complexes(S.Reals x S.Reals, False)) as a result. Apparently, solveset still doesn't know how to deal with LambertW functions. From the docs:
When cases which are not solved or can only be solved incompletely, a ConditionSet is used and acts as an unevaluated solveset object. <...> There are still a few things solveset can't do, which the old solve can, such as solving non linear multivariate & LambertW type equations.
Is it a bug or am I missing something? Is there any workaround to get the desired result?
The inverse produced by sympy is almost correct. The problem is that the LambertW function has two real branches on the interval (-1/e, 0). By default, sympy uses the upper (principal) branch, k = 0, but your problem requires the lower branch, k = -1. The lower branch can be selected by passing -1 as the second argument to LambertW.
inverse = -sym.LambertW((x - 1)*sym.exp(-1), -1) - 1
sym.plot(inverse, (x, 0, 0.999))
This gives the branch for positive y's.
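As a rough check of the above, the sketch below evaluates both branches numerically, confirms that each one maps back to the same CDF value, and then uses the k = -1 branch for inverse transform sampling. The helper names `evaluate` and `sample`, and the variables `upper` and `lower`, are made up for this illustration; taking the real part is a precaution in case the numeric evaluation returns a negligible imaginary component.

```python
import random
import sympy as sym

x, y = sym.symbols('x y')
cdf = (-y - 1) * sym.exp(-y) + 1

arg = (x - 1) * sym.exp(-1)
upper = -sym.LambertW(arg) - 1       # default branch (k = 0), as returned by solve
lower = -sym.LambertW(arg, -1) - 1   # k = -1 branch


def evaluate(expr, u):
    """Numerically evaluate expr at x = u and return the real part."""
    value = complex(expr.subs(x, u))
    return value.real


def sample(rng):
    """Draw one value by inverse transform sampling with the k = -1 branch."""
    u = rng.random()
    while u == 0.0:  # skip the branch point at u = 0
        u = rng.random()
    return evaluate(lower, u)


u = 0.5
y_upper = evaluate(upper, u)  # negative: the left part of the CDF curve
y_lower = evaluate(lower, u)  # positive: the part needed for sampling
print(y_upper, float(cdf.subs(y, y_upper)))
print(y_lower, float(cdf.subs(y, y_lower)))

rng = random.Random(0)
draws = [sample(rng) for _ in range(200)]
print(sum(draws) / len(draws))
```

Both branches are valid solutions of `Eq(x, cdf)`, which is why `solve` is not wrong as such; only the k = -1 branch lands on y >= 0, where the PDF is defined. For large sample sizes, converting `lower` with `sym.lambdify` would probably be faster than calling `subs` per draw.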