python matplotlib matplotlib-3d

Unable to make axis logarithmic in 3D plot


I am trying to make a 3D plot using matplotlib inside Jupyter Notebook. I am using a dataset from Kaggle.

The schema is the following:

LotArea   SalePrice   YrSold   PoolArea
8450      208500      2008     0
9600      181500      2007     0
...       ...         ...      ...

When I plot with linear axes, everything is OK:

import matplotlib.pyplot as plt
import pandas as pd

dataset_chosen = pd.read_csv('train.csv')  # the Kaggle dataset described above
fig = plt.figure(figsize=(10, 15))
ax = plt.axes(projection='3d')

area_data = dataset_chosen["LotArea"]
price_data = dataset_chosen["SalePrice"]
year_data = dataset_chosen["YrSold"]

cmhot = plt.get_cmap("hot")

ax.scatter3D(xs=area_data, ys=price_data, zs=year_data, c=dataset_chosen["PoolArea"])

#ax.set_xscale("log")

ax.set_xlabel("Area")
ax.set_ylabel("Price")
ax.set_zlabel("Year")

plt.show()

[Image: plot with all linear scales]

But when I try to make the x scale logarithmic (by uncommenting #ax.set_xscale("log")), the result no longer looks like a proper plot.

How can I make the x axis logarithmic?


Solution

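  • If you check here, there is a discussion of the same problem: this is a known limitation (or bug) of 3D plots in matplotlib, and set_xscale("log") does not behave correctly on 3D axes. As mentioned there, the workaround is to do the scaling manually: plot the logarithm of the data yourself, then set the tick positions and tick labels by hand. The updated code below does that. Hope this is what you are looking for. Note that I used log10 because the tick values then line up nicely with powers of ten.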

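    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on matplotlib < 3.2)

    dataset_chosen = pd.read_csv('train.csv')
    fig = plt.figure(figsize=(10, 15))
    ax = plt.axes(projection='3d')

    area_data = np.log10(dataset_chosen["LotArea"])  # plot log10 of the area instead of the raw values
    price_data = dataset_chosen["SalePrice"]
    year_data = dataset_chosen["YrSold"]

    cmhot = plt.get_cmap("hot")

    ax.scatter3D(xs=area_data, ys=price_data, zs=year_data, c=dataset_chosen["PoolArea"])

    # Place the ticks at log10 of the values you want to show,
    # then label them with the original (un-logged) values
    xticks = [100, 1000, 10000, 100000]
    ax.set_xticks(np.log10(xticks))
    ax.set_xticklabels(xticks)

    # ax.set_xscale('log') is left out on purpose: it does not work correctly on 3D axes
    ax.set_xlabel("Area")
    ax.set_ylabel("Price")
    ax.set_zlabel("Year")

    plt.show()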
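If you would rather not hard-code xticks=[100,1000,10000,100000], the tick list can be derived from the data instead. Below is a minimal sketch of the same manual-scaling workaround, assuming the column holds only positive values (log10 is undefined for zero or negative numbers). decade_ticks is a hypothetical helper name, and the random arrays are synthetic stand-ins for the LotArea, SalePrice and YrSold columns so the snippet runs on its own.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; omit this in Jupyter
import matplotlib.pyplot as plt
import numpy as np


def decade_ticks(values):
    """Hypothetical helper: the powers of ten that bracket the (positive) values."""
    values = np.asarray(values, dtype=float)
    if np.any(values <= 0):
        raise ValueError("log10 needs strictly positive values")
    lo = int(np.floor(np.log10(values.min())))
    hi = int(np.ceil(np.log10(values.max())))
    return [10 ** k for k in range(lo, hi + 1)]


# Synthetic stand-ins for LotArea, SalePrice and YrSold (assumption: not the real data)
rng = np.random.default_rng(0)
area = rng.uniform(1_300, 215_000, size=200)
price = rng.uniform(35_000, 755_000, size=200)
year = rng.integers(2006, 2011, size=200)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection="3d")

# Plot log10 of the data, put ticks at log10 of "nice" values,
# and label those ticks with the original values
ax.scatter(np.log10(area), price, year)
xticks = decade_ticks(area)
ax.set_xticks(np.log10(xticks))
ax.set_xticklabels([str(t) for t in xticks])

ax.set_xlabel("Area")
ax.set_ylabel("Price")
ax.set_zlabel("Year")
```

Because the helper rounds the exponents down and up, the first and last ticks always lie at or beyond the data range, so every point falls between labelled decades.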
    

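Another way to get the same effect without listing tick values at all is to keep plotting log10 of the data and let a locator and formatter from matplotlib.ticker do the relabelling: MultipleLocator(1) places a tick at every whole exponent, and FuncFormatter displays each tick position v as 10**v. This is a sketch of the same manual-scaling workaround, not real log-axis support, and the data below is synthetic (an assumption standing in for the Kaggle columns).

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; omit this in Jupyter
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter, MultipleLocator

# Synthetic stand-ins for the LotArea / SalePrice / YrSold columns (assumption)
rng = np.random.default_rng(1)
area = 10 ** rng.uniform(3, 5, size=200)  # values spread over two decades
price = rng.uniform(35_000, 755_000, size=200)
year = rng.integers(2006, 2011, size=200)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection="3d")
ax.scatter(np.log10(area), price, year)

# The x data are exponents, so put a tick at every whole exponent...
ax.xaxis.set_major_locator(MultipleLocator(1))
# ...and show each tick position v as the original value 10**v
log_label = FuncFormatter(lambda v, pos: f"{10 ** v:g}")
ax.xaxis.set_major_formatter(log_label)

ax.set_xlabel("Area")
ax.set_ylabel("Price")
ax.set_zlabel("Year")
```

The advantage over a fixed tick list is that the labels follow the axis limits automatically if the data or the view changes.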