I am trying to label the axes of 3D subplots created with Plotly, as shown below. It works for the first (left) subplot, but I cannot change the axis labels of the other subplots, and I don't understand why. How can I fix the code?
def bern(theta, z, N):
    """Return the Bernoulli likelihood of z successes in N trials.

    Computes ``theta**z * (1 - theta)**(N - z)`` element-wise and clips the
    result to the interval [0, 1].
    """
    successes_term = theta ** z
    failures_term = (1 - theta) ** (N - z)
    return np.clip(successes_term * failures_term, 0, 1)
def bern2(theta1, theta2, z1, z2, N1, N2):
    """Return the joint likelihood of two Bernoulli experiments.

    Experiment 1 has N1 trials with z1 successes (parameter theta1).
    Experiment 2 has N2 trials with z2 successes (parameter theta2).
    The result is the product of the two single-experiment likelihoods
    returned by ``bern``.
    """
    first = bern(theta1, z1, N1)
    second = bern(theta2, z2, N2)
    return first * second
def make_thetas(xmin, xmax, n):
    """Return the midpoints between consecutive points of np.linspace(xmin, xmax, n).

    The n grid points are treated as cell edges, giving n - 1 midpoints.
    Each midpoint is a left edge plus half the gap to the next edge.
    """
    edges = np.linspace(xmin, xmax, n)
    left, right = edges[:-1], edges[1:]
    half_gap = (right - left) / 2.0
    return left + half_gap
# Grid of (theta1, theta2) cell midpoints: 101 linspace points on [0, 1] give 100 midpoints per axis.
thetas1 = make_thetas(0, 1, 101)
thetas2 = make_thetas(0, 1, 101)
X, Y = np.meshgrid(thetas1, thetas2)
# Beta(a, b) prior parameters, shared by both thetas.
a = 2
b = 3
# Observed data: z1 successes in N1 trials, z2 successes in N2 trials.
z1 = 11
N1 = 14
z2 = 7
N2 = 14
# Product of Beta(a, b) densities for theta1 and theta2.
prior = stats.beta(a, b).pdf(X) * stats.beta(a, b).pdf(Y)
likelihood = bern2(X, Y, z1, z2, N1, N2)
# Conjugate update: each theta's posterior is Beta(a + z, b + N - z).
posterior = stats.beta(a + z1, b + N1 - z1).pdf(X) * stats.beta(a + z2, b + N2 - z2).pdf(Y)
# One row of three 3D subplots. Each 3D subplot gets its own scene
# (the first is `scene`, then `scene2`, `scene3`).
fig = make_subplots(rows=1, cols=3, specs= [[{'is_3d': True}, {'is_3d': True}, {'is_3d': True}]], subplot_titles=('Prior', 'Likelihood', 'Posterior'))
fig.add_trace(go.Surface(z= prior, showscale= True), 1, 1)
fig.add_trace(go.Surface(z= likelihood, showscale= True), 1, 2)
# NOTE(review): the trailing positional False appears to be add_trace's
# `secondary_y` argument; confirm against the plotly version in use.
fig.add_trace(go.Surface(z= posterior, showscale= True),1,3, False)
# NOTE: `scene=` configures only the first subplot's scene. Subplots 2 and 3
# live in `scene2` and `scene3`, so their axis titles stay at the defaults.
# This is the behavior the question is asking about.
fig.update_layout(title='Prior - Likelihood - Posterior', autosize= True, scene = dict(
xaxis_title='theta1',
yaxis_title='theta2',
zaxis_title='Probability Density'),
width= 1300, height=600,
margin=dict(l=65, r=50, b=65, t=90))
fig.show()
The output of the code is shown below; only the left subplot shows the custom axis titles:
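Each 3D subplot has its own scene (`scene1`, `scene2`, `scene3`), and the `scene=` argument only configures the first one. If you replace the `update_layout` call at the bottom with the following code, it will work: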
fig.update_layout(
    title='Prior - Likelihood - Posterior',
    autosize=True,
    # Every 3D subplot has its own scene, so the same axis titles are
    # supplied to each of the three scenes.
    **{
        scene_id: dict(
            xaxis_title='theta1',
            yaxis_title='theta2',
            zaxis_title='Probability Density',
        )
        for scene_id in ('scene1', 'scene2', 'scene3')
    },
    width=1300,
    height=600,
    margin=dict(l=65, r=50, b=65, t=90),
)
The official Plotly layout reference does say that you can specify `layout.scene`, `layout.scene2`, and so on. This isn't obvious from the examples it provides, though.