I am using `sksurv.linear_model.CoxPHSurvivalAnalysis` to fit a Cox PH regression, and I would like to recover the density function $f(t)$. The sksurv class has methods to predict the survival function $S(t) = 1 - F(t)$ (from which the CDF $F(t)$ follows) and the cumulative hazard function $H(t)$, but it doesn't seem to produce the density function.
My use case has no censoring, so here is an example:
```python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sksurv.linear_model import CoxPHSurvivalAnalysis

data = np.random.randint(5, 30, size=10)
X_train = pd.DataFrame(data, columns=['covariate'])
y_train = np.array(np.random.randint(0, 100, size=10) / 100, dtype=[('status', bool), ('target', float)])

estimator = CoxPHSurvivalAnalysis()
estimator.fit(X_train, y_train)

X_test = pd.DataFrame({'covariate': [12, 2]})
chf = estimator.predict_cumulative_hazard_function(X_test)
sf = estimator.predict_survival_function(X_test)

fig, ax = plt.subplots(1, 2)
for fn_h, fn_s in zip(chf, sf):
    ax[0].step(fn_h.x, fn_h(fn_h.x), where='post')
    ax[1].step(fn_s.x, fn_s(fn_s.x), where='post')
ax[0].set_title('Cumulative Hazard Functions')
ax[1].set_title('Survival Functions')
plt.show()
```
The probability density function (PDF) can be obtained from the cumulative distribution function (CDF) as:

$$f(t) = \frac{dF(t)}{dt}$$

In survival analysis, the PDF $f(t)$ can also be expressed in terms of the survival function $S(t)$ and the hazard function $h(t)$:

$$f(t) = h(t)\,S(t)$$

where

$$S(t) = 1 - F(t), \qquad h(t) = -\frac{1}{S(t)}\,\frac{dS(t)}{dt} = \frac{dH(t)}{dt}.$$

So the PDF can be written as:

$$f(t) = \frac{dH(t)}{dt}\,S(t)$$
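As a quick numerical sanity check of these identities, the sketch below uses a Weibull distribution. This check is not part of the original answer. The Weibull is chosen only because $H(t)$, $h(t)$ and $f(t)$ all have closed forms, and the shape and scale values are arbitrary assumptions. The sketch compares $h(t)\,S(t)$ against a finite-difference estimate of $dF/dt$ and against SciPy's `weibull_min.pdf`.

```python
import numpy as np
from scipy.stats import weibull_min

# Illustrative Weibull parameters (arbitrary choices for this check)
shape, scale = 1.5, 2.0
t = np.linspace(0.5, 5.0, 1000)

H = (t / scale) ** shape                           # cumulative hazard H(t)
S = np.exp(-H)                                     # survival S(t) = exp(-H(t))
F = 1.0 - S                                        # CDF F(t) = 1 - S(t)
h = (shape / scale) * (t / scale) ** (shape - 1)   # hazard h(t) = dH/dt (closed form)

f_from_hazard = h * S                              # f(t) = h(t) * S(t)
f_from_cdf = np.gradient(F, t, edge_order=2)       # f(t) = dF/dt via finite differences
f_scipy = weibull_min.pdf(t, shape, scale=scale)   # closed-form Weibull density

max_abs_diff = np.max(np.abs(f_from_hazard - f_from_cdf))
print(f"max |h*S - dF/dt| = {max_abs_diff:.2e}")
print("h*S matches scipy pdf:", np.allclose(f_from_hazard, f_scipy))
```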
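To compute the hazard function $h(t)$, we need the derivative of the cumulative hazard function (CHF) $H(t)$. The CHF returned by sksurv is a step function known only at the event times, so it cannot be differentiated directly. One option is `InterpolatedUnivariateSpline` from SciPy. It fits a smooth spline (cubic by default) that passes through each $(t, H(t))$ point, and that spline can then be differentiated to obtain an estimate of $h(t)$. Note that an interpolating cubic spline is not guaranteed to be monotone. The estimated hazard, and hence the PDF, can therefore dip below zero between irregularly spaced event times.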
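The following self-contained check shows that spline differentiation recovers the hazard when the CHF is smooth. It samples a Weibull CHF at 20 points. The distribution, its parameters and the time grid are illustrative assumptions, not output from the Cox model. On a real Breslow CHF with irregular jumps, the estimate will be noisier.

```python
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

# Illustrative Weibull CHF H(t) = (t / scale) ** shape, sampled at 20 "event times"
shape, scale = 1.5, 2.0
event_times = np.linspace(0.5, 5.0, 20)   # x values must be strictly increasing
chf_values = (event_times / scale) ** shape

# Interpolating spline through (t, H(t)); k=3 (cubic) is the default degree
chf_spline = InterpolatedUnivariateSpline(event_times, chf_values)
hazard_est = chf_spline.derivative()(event_times)   # estimated h(t) = dH/dt

# Closed-form Weibull hazard for comparison
hazard_true = (shape / scale) * (event_times / scale) ** (shape - 1)
max_error = np.max(np.abs(hazard_est - hazard_true))
print(f"max |h_est - h_true| = {max_error:.2e}")
```

If a non-negative hazard matters, `scipy.interpolate.PchipInterpolator` is a monotonicity-preserving alternative. This is a different technique from the one used in this answer. For non-decreasing CHF values it keeps the interpolant non-decreasing, so its `.derivative()` does not go negative.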
Here is a modified version of the code from the question:
```python
# Import the necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sksurv.linear_model import CoxPHSurvivalAnalysis
from scipy.interpolate import InterpolatedUnivariateSpline

# Define a function to compute the probability density function (pdf)
# from the cumulative hazard function (chf) and survival function (sf).
def compute_pdf_from_chf_and_sf(chf, sf):
    # The hazard function is the derivative of the cumulative hazard function.
    # InterpolatedUnivariateSpline fits a smooth (cubic by default) spline that
    # passes through each (t, H(t)) point of the step-function CHF; differentiating
    # that spline gives an estimate of the hazard function.
    chf_spline = InterpolatedUnivariateSpline(chf.x, chf(chf.x))
    hazard_function = chf_spline.derivative()(chf.x)
    # The pdf can be computed using the formula: pdf(t) = hazard(t) * survival(t)
    pdf = hazard_function * sf(chf.x)
    return chf.x, pdf

# Generate random data for demonstration purposes
# Here, we create a random dataset with one covariate and survival times.
np.random.seed(42)  # Setting a fixed seed.
data = np.random.randint(5, 30, size=10)
X_train = pd.DataFrame(data, columns=['covariate'])
# No censoring: every sample is an observed event (status=True) at a positive time.
event_times = np.random.randint(1, 100, size=10) / 100
y_train = np.array([(True, t) for t in event_times], dtype=[('status', bool), ('target', float)])

# Initialize and fit the Cox Proportional Hazards model
estimator = CoxPHSurvivalAnalysis()
estimator.fit(X_train, y_train)

# Predict for new data points
X_test = pd.DataFrame({'covariate': [12, 2]})
cumulative_hazard_functions = estimator.predict_cumulative_hazard_function(X_test)
survival_functions = estimator.predict_survival_function(X_test)

# Plot the Cumulative Hazard, Survival, and PDF side by side in a single row
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for chf, sf in zip(cumulative_hazard_functions, survival_functions):
    # Compute the pdf using our defined function
    times, pdf_values = compute_pdf_from_chf_and_sf(chf, sf)
    # Plotting the cumulative hazard function
    axes[0].step(chf.x, chf(chf.x), where='post')
    # Plotting the survival function
    axes[1].step(sf.x, sf(sf.x), where='post')
    # Plotting the probability density function (a smooth spline-based estimate, so drawn as a line)
    axes[2].plot(times, pdf_values)

# Setting titles for each subplot
axes[0].set_title('Cumulative Hazard Functions')
axes[1].set_title('Survival Functions')
axes[2].set_title('Probability Density Functions')

# Display the plots
plt.tight_layout()
plt.show()
```
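This produces a figure with three panels: Cumulative Hazard Functions, Survival Functions, and Probability Density Functions, with one curve per test sample.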
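Another way to read the model output is sketched below. It is an alternative to the spline approach, not what the answer above does. sksurv builds the Cox baseline with the Breslow estimator, so the predicted $S(t)$ is a step function that drops only at the training event times. The fitted model therefore implies a discrete distribution with probability mass $S(t_{j-1}) - S(t_j)$ at each event time $t_j$, taking $S(t_0) = 1$. The helper `step_survival_to_point_masses` is hypothetical, and the numbers are made up for illustration.

```python
import numpy as np

def step_survival_to_point_masses(times, surv):
    """Hypothetical helper: convert a right-continuous survival step function,
    given by its jump times t_j and the values S(t_j) at those times, into
    probability masses P(T = t_j) = S(t_{j-1}) - S(t_j), with S(t_0) = 1."""
    times = np.asarray(times, dtype=float)
    surv = np.asarray(surv, dtype=float)
    previous = np.concatenate(([1.0], surv[:-1]))  # S just before each jump
    masses = previous - surv
    return times, masses

# Illustrative values only (not output from the model above)
times = np.array([0.1, 0.3, 0.4, 0.7])
surv = np.array([0.75, 0.5, 0.375, 0.125])
mass_times, masses = step_survival_to_point_masses(times, surv)
print(masses)        # masses: 0.25, 0.25, 0.125, 0.25
print(masses.sum())  # equals 1 - S(t_last)
```

With the fitted model from the answer, the same helper could be applied as `step_survival_to_point_masses(sf.x, sf(sf.x))` for each `sf` in `survival_functions`. The masses sum to $1 - S(t_{\text{last}})$ rather than 1 whenever the predicted survival does not reach zero at the last event time.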
Reference: Machine Learning for Survival Analysis: A Survey