Search code examples
Tags: python, numpy, jupyter, cartopy

I have 12 maps made with cartopy, and I want to use `plt.subplots` to show them as a grid (3 rows and 4 columns)


I have 12 cartopy maps, and I would like to arrange them in 3 rows and 4 columns so they are easier to compare, instead of plotting them one below another. I tried `fig, ax = plt.subplots(...)` together with `ax.ravel()`, but I keep getting the error below. Updating cartopy did not help:

# Question's original attempt; it reproduces the TypeError shown below.
# NOTE(review): as pasted, this first line is less indented than the loop that
# follows, so the snippet will not run until the indentation is made consistent.
# NOTE(review): the traceback shows `sss_md.longitude, sss_md.latitude` being
# passed to contourf, so the code that actually ran differs slightly from this.
fig, ax = plt.subplots(nrows=3, ncols=4,figsize=(15,5))
    # BUG: zip() yields (index, Axes) tuples and nothing is unpacked, so `i` is a
    # tuple here. `climatology[i]` then tries to index with an Axes object, which
    # is exactly the "array(<AxesSubplot:>, dtype=object)" TypeError below.
    # Unpack the pair instead, e.g. `for i, <axes name> in zip(...)`.
    for i in zip(range(12),ax.ravel()):
        # BUG: this rebinds `ax` (the subplot array) and plt.axes() adds a new
        # axes to the figure on every pass, so the 3x4 subplot grid is never
        # drawn into. Give the subplots their projection via subplot_kw instead.
        ax = plt.axes(projection=ccrs.PlateCarree())
        ax.set_extent([-90, 10, 5, 85], crs=ccrs.PlateCarree())
        # NOTE(review): `ax=ax` looks like xarray `.plot.contourf` syntax; an Axes
        # method is unlikely to accept it. Drop it, and consider passing
        # transform=ccrs.PlateCarree() so cartopy knows the data are lon/lat.
        x = ax.contourf(longitude,latitude,climatology[i], np.arange(28,39,.3),cmap='jet', ax=ax, extend='both')
        ax.coastlines()
        #ax.add_feature(cfeature.LAND, zorder=100, edgecolor='k')
        gridlines = ax.gridlines(draw_labels=True)
        # NOTE(review): this adds a colorbar on every iteration. With a subplot
        # grid, consider one fig.colorbar(..., ax=<all axes>) after the loop.
        cbar = plt.colorbar(x, fraction=.046, pad=0.04)
        cbar.set_label('psu', labelpad=15, y=.5, rotation=90)
        # Axis captions placed in axes-fraction coordinates (transAxes).
        ax.text(.5,-.12, 'Longitude' , va='bottom' , ha='center', rotation='horizontal', rotation_mode= 'anchor',transform=ax.transAxes)
        ax.text(-.1, .5, 'Latitude' , va='bottom' , ha='center', rotation='vertical', rotation_mode= 'anchor',transform=ax.transAxes)
        #plt.title(title)
        # BUG: plt.show() runs on every iteration; call it once after the loop.
        plt.show()

Error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-57-4fa6b08f3ee5> in <module>
      3     ax = plt.axes(projection=ccrs.PlateCarree())
      4     ax.set_extent([-90, 10, 5, 85], crs=ccrs.PlateCarree())
----> 5     x = ax.contourf(sss_md.longitude,sss_md.latitude,climatology[i], np.arange(28,39,.3),cmap='jet', ax=ax, extend='both')
      6     ax.coastlines()
      7     #ax.add_feature(cfeature.LAND, zorder=100, edgecolor='k')

~/miniconda3/envs/py3_std_maps/lib/python3.8/site-packages/xarray/core/dataarray.py in __getitem__(self, key)
    640         else:
    641             # xarray-style array indexing
--> 642             return self.isel(indexers=self._item_key_to_dict(key))
    643 
    644     def __setitem__(self, key: Any, value: Any) -> None:

~/miniconda3/envs/py3_std_maps/lib/python3.8/site-packages/xarray/core/dataarray.py in isel(self, indexers, drop, missing_dims, **indexers_kwargs)
   1037 
   1038         if any(is_fancy_indexer(idx) for idx in indexers.values()):
-> 1039             ds = self._to_temp_dataset()._isel_fancy(
   1040                 indexers, drop=drop, missing_dims=missing_dims
   1041             )

~/miniconda3/envs/py3_std_maps/lib/python3.8/site-packages/xarray/core/dataset.py in _isel_fancy(self, indexers, drop, missing_dims)
   2011 
   2012             if name in self.indexes:
-> 2013                 new_var, new_index = isel_variable_and_index(
   2014                     name, var, self.indexes[name], var_indexers
   2015                 )

~/miniconda3/envs/py3_std_maps/lib/python3.8/site-packages/xarray/core/indexes.py in isel_variable_and_index(name, variable, index, indexers)
    104         )
    105 
--> 106     new_variable = variable.isel(indexers)
    107 
    108     if new_variable.dims != (name,):

~/miniconda3/envs/py3_std_maps/lib/python3.8/site-packages/xarray/core/variable.py in isel(self, indexers, missing_dims, **indexers_kwargs)
   1116 
   1117         key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
-> 1118         return self[key]
   1119 
   1120     def squeeze(self, dim=None):

~/miniconda3/envs/py3_std_maps/lib/python3.8/site-packages/xarray/core/variable.py in __getitem__(self, key)
    764         array `x.values` directly.
    765         """
--> 766         dims, indexer, new_order = self._broadcast_indexes(key)
    767         data = as_indexable(self._data)[indexer]
    768         if new_order:

~/miniconda3/envs/py3_std_maps/lib/python3.8/site-packages/xarray/core/variable.py in _broadcast_indexes(self, key)
    610         # key can be mapped as an OuterIndexer.
    611         if all(not isinstance(k, Variable) for k in key):
--> 612             return self._broadcast_indexes_outer(key)
    613 
    614         # If all key is 1-dimensional and there are no duplicate labels,

~/miniconda3/envs/py3_std_maps/lib/python3.8/site-packages/xarray/core/variable.py in _broadcast_indexes_outer(self, key)
    686             new_key.append(k)
    687 
--> 688         return dims, OuterIndexer(tuple(new_key)), None
    689 
    690     def _nonzero(self):

~/miniconda3/envs/py3_std_maps/lib/python3.8/site-packages/xarray/core/indexing.py in __init__(self, key)
    407             elif isinstance(k, np.ndarray):
    408                 if not np.issubdtype(k.dtype, np.integer):
--> 409                     raise TypeError(
    410                         f"invalid indexer array, does not have integer dtype: {k!r}"
    411                     )

TypeError: invalid indexer array, does not have integer dtype: array(<AxesSubplot:>, dtype=object)

Solution

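  • Try this edited code. The error comes from `for i in zip(range(12), ax.ravel())`. `zip` yields `(index, axes)` tuples, so `i` is a tuple rather than an integer. Indexing `climatology[i]` with that tuple produces the "invalid indexer array" TypeError. Unpack the pair (`for i, ax in ...`) instead. Also, declare the projection through `subplot_kw` so that every subplot is already a cartopy GeoAxes. Do not create a new axes with `plt.axes()` inside the loop.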

    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    import numpy as np

    # Declare the projection via subplot_kw so every subplot is already a cartopy
    # GeoAxes. Calling plt.axes() inside the loop would stack brand-new axes on
    # top of the grid instead of drawing into it.
    fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(18, 10),
                             subplot_kw=dict(projection=ccrs.PlateCarree()))

    # Same contour levels for every panel, so a single colorbar describes them all.
    levels = np.arange(28, 39, .3)

    # enumerate() yields (index, GeoAxes) pairs. Unpacking them is what fixes the
    # original TypeError: without unpacking, `i` was a tuple containing an Axes,
    # and climatology[i] tried to index the data with that Axes object.
    for i, ax in enumerate(axes.flat):
        ax.set_extent([-90, 10, 5, 85], crs=ccrs.PlateCarree())

        # Assumes `longitude`, `latitude` and `climatology` are defined as in the
        # question, with climatology[i] a 2-D (lat, lon) field -- confirm.
        # GeoAxes.contourf takes no `ax=` keyword (that belongs to xarray's .plot
        # API); `transform` tells cartopy the coordinates are plain lon/lat.
        cf = ax.contourf(longitude, latitude, climatology[i], levels=levels,
                         cmap='jet', extend='both',
                         transform=ccrs.PlateCarree())

        ax.coastlines(lw=0.2)
        # (plot land) ax.add_feature(cfeature.LAND, zorder=100, edgecolor='k')
        ax.gridlines(draw_labels=True)
        # Axis captions, positioned in axes-fraction coordinates.
        ax.text(.5, -.12, 'Longitude', va='bottom', ha='center',
                rotation='horizontal', rotation_mode='anchor',
                transform=ax.transAxes)
        ax.text(-.1, .5, 'Latitude', va='bottom', ha='center',
                rotation='vertical', rotation_mode='anchor',
                transform=ax.transAxes)
        # plt.title(title)

    # One colorbar shared by the whole grid (all panels use the same levels),
    # instead of twelve per-panel colorbars.
    cbar = fig.colorbar(cf, ax=axes, fraction=.046, pad=0.04)
    cbar.set_label('psu', labelpad=15, y=.5, rotation=90)

    plt.show()  # called once, after every panel has been drawn
    

    The output plot:

    (Image: a 3×4 array of map subplots.)