Tags: python, matplotlib, tiff, qgis, rasterio

Rasterio: writing tiff with cmap on float values


I have this very simple code:

import os

import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio.enums import ColorInterp
from rasterio.transform import Affine

nx = 5
maxx = 4.0
minx = -4.0
res = (maxx - minx) / nx
maxy = 3.0
miny = -3.0
ny = int((maxy - miny) / res)

x = np.linspace(minx, maxx, nx)
y = np.linspace(miny, maxy, ny)
# 3 rows x 5 columns of float64 values; -1 marks nodata
z = np.array([
    [-1, 10, 15.1, 6.3, 50.4],
    [26.7, -1, 15.7, 40.7, 5],
    [5, -1, 9.0, 38, 40.3],
])
cmap = plt.get_cmap("nipy_spectral")
with rasterio.open(
    os.path.join(os.path.dirname(__file__), "test.tiff"),
    "w",
    driver='GTiff',
    height=z.shape[0],
    width=z.shape[1],
    count=1,
    dtype=z.dtype,
    crs='+proj=latlong',
    transform=Affine.translation(x[0]-res/2, y[0]-res/2) * Affine.scale(res, res),
    nodata=-1,
) as df:
    df.colorinterp = [ColorInterp.palette]
    # df.write_colormap(1, cmap)  # does not work here: expects an int -> RGBA dict
    df.write(z, 1)

It creates a basic image when I drag and drop it into QGIS:

[Screenshot: the raster as displayed in QGIS]

I would like to drag and drop this file into QGIS and have it displayed with the matplotlib colormap named nipy_spectral:

[Image: matplotlib's nipy_spectral colormap]

According to the documentation, the commented-out line # df.write_colormap(1, cmap) only works for uint8 data (with cmap being a dictionary whose keys are integer pixel values), and the documentation says nothing about float data...
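For the uint8 case that does work, write_colormap takes a dict mapping each integer pixel value to an RGBA tuple of 0-255 integers, so a matplotlib colormap has to be sampled first. A minimal sketch, assuming 256 evenly spaced samples; the helper name cmap_to_rasterio_dict is hypothetical and not part of rasterio or matplotlib:

```python
from typing import Callable, Dict, Sequence, Tuple

RGBA = Tuple[int, int, int, int]


def cmap_to_rasterio_dict(
    cmap: Callable[[float], Sequence[float]], n: int = 256
) -> Dict[int, RGBA]:
    """Sample `cmap` at n evenly spaced points in [0, 1] and return {index: (r, g, b, a)}.

    `cmap` is any callable returning RGBA floats in [0, 1] for a float input,
    which is how a matplotlib Colormap behaves when called with a scalar float.
    """
    if n < 2:
        raise ValueError("n must be at least 2")
    table: Dict[int, RGBA] = {}
    for i in range(n):
        r, g, b, a = cmap(i / (n - 1))[:4]
        # Convert 0.0-1.0 floats to the 0-255 integers GDAL color tables use
        table[i] = (round(r * 255), round(g * 255), round(b * 255), round(a * 255))
    return table
```

With the question's objects this would be called as df.write_colormap(1, cmap_to_rasterio_dict(cmap)), on a dataset opened with dtype='uint8'; it does not lift the 256-entry limit for float rasters.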

My question and need are simple, but the documentation says nothing about it: how do I apply this cmap to my df rasterio dataset in the Python code?

For the moment it works when I force the data to uint8, but then I can only have 256 distinct values, which is not enough...
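The uint8 workaround amounts to quantizing the float values into at most 256 classes. A minimal numpy sketch, assuming value 0 is reserved for nodata so that 255 classes remain for valid data; to_uint8 is a hypothetical helper name:

```python
from typing import Tuple

import numpy as np


def to_uint8(z: np.ndarray, nodata: float = -1.0) -> Tuple[np.ndarray, float, float]:
    """Linearly rescale the valid values of z into 1..255; nodata pixels become 0.

    Also returns the (min, max) used, so class indices can be mapped back to
    physical values for a legend.
    """
    valid = z != nodata
    out = np.zeros(z.shape, dtype=np.uint8)
    if not valid.any():
        return out, float("nan"), float("nan")
    vmin = float(z[valid].min())
    vmax = float(z[valid].max())
    span = vmax - vmin if vmax > vmin else 1.0  # avoid division by zero on flat data
    # vmin maps to 1 and vmax maps to 255 (254 steps in between)
    out[valid] = np.rint(1 + (z[valid] - vmin) / span * 254).astype(np.uint8)
    return out, vmin, vmax
```

Each uint8 class then covers (vmax - vmin) / 254 units of the original data, which is exactly the resolution loss described above.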

That is my current workaround.

Another solution is to manually add a predefined color ramp in QGIS, like this:

[Screenshot: applying a predefined color ramp in QGIS]

It is then possible to export the style. Is it possible to apply this style to the TIFF file automatically using the qgis.core Python module?


Solution

  • I finally did the trick using these lines of code:

    import os

    import numpy as np
    from qgis.core import (
        QgsApplication, QgsColorRampShader, QgsCoordinateReferenceSystem, QgsProject,
        QgsRasterBandStats, QgsRasterLayer, QgsRasterShader, QgsSingleBandPseudoColorRenderer,
        QgsStyle)

    # The author's own module: create_tiff() writes the GeoTIFF with rasterio,
    # as in the question's code.
    from calcul import conversion_to_geotiff


    def create_project(path_qgz, tiffs=None, epsg=2154):
        """Write each dataset to a GeoTIFF and save a QGIS project that styles it.

        Each entry of `tiffs` is [x, y, z, path_tif, colormap, precision], where
        `colormap` is the name of a ramp in the QGIS default style and `precision`
        is the data range covered by one color class.
        """
        qgs = QgsApplication([], False)
        qgs.initQgis()
        project = QgsProject.instance()
        project.setTitle('test')
        project.setCrs(QgsCoordinateReferenceSystem(f"EPSG:{epsg}"))
        for data in tiffs or []:
            x, y, z, path_tif, colormap, precision = data
            conversion_to_geotiff.create_tiff(path_tif, x, y, z)
            layer = QgsRasterLayer(path_tif, os.path.splitext(os.path.basename(path_tif))[0])
            # Min/max of band 1 define the range covered by the color ramp
            stats = layer.dataProvider().bandStatistics(1, QgsRasterBandStats.All)
            minimum = stats.minimumValue
            maximum = stats.maximumValue
            delta = maximum - minimum
            # Roughly one color class per `precision` units, with at least two stops
            nclass = max(2, int(delta / precision))
            fractional_steps = [i / (nclass - 1) for i in range(nclass)]
            ramp = QgsStyle.defaultStyle().colorRamp(colormap)
            if ramp is None:
                raise ValueError(f"Unknown color ramp: {colormap}")
            # Sample the ramp and map each fraction onto the data range
            colors = [ramp.color(f) for f in fractional_steps]
            steps = [minimum + f * delta for f in fractional_steps]
            ramp_items = [
                QgsColorRampShader.ColorRampItem(step, color, str(step))
                for step, color in zip(steps, colors)
            ]
            shader_function = QgsColorRampShader()
            shader_function.setClassificationMode(QgsColorRampShader.EqualInterval)
            shader_function.setColorRampItemList(ramp_items)
            raster_shader = QgsRasterShader()
            raster_shader.setRasterShaderFunction(shader_function)
            renderer = QgsSingleBandPseudoColorRenderer(layer.dataProvider(), 1, raster_shader)
            layer.setRenderer(renderer)
            layer.triggerRepaint()
            project.addMapLayer(layer)
        project.write(path_qgz)
        qgs.exitQgis()


    if __name__ == "__main__":
        # Same grid as in the question
        nx, minx, maxx = 5, -4.0, 4.0
        miny, maxy = -3.0, 3.0
        res = (maxx - minx) / nx
        ny = int((maxy - miny) / res)
        x = np.linspace(minx, maxx, nx)
        y = np.linspace(miny, maxy, ny)
        z = np.array([
            [-1, 10, 15.1, 6.3, 50.4],
            [26.7, -1, 15.7, 40.7, 5],
            [5, -1, 9.0, 38, 40.3],
        ])

        create_project(
            "/home/vince/test.qgz",
            tiffs=[
                [x, y, z, "/home/vince/test.tif", "Turbo", 5]
            ]
        )
    

    The hard part is adding the qgis module to your PYTHONPATH. It is very hard on Windows because of DLLs, and much easier on Linux. (Just make sure that the Python version used by QGIS is the same as the Python version running your code.)
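The class-break logic in the answer does not depend on QGIS and can be checked on its own. A minimal plain-Python sketch reproducing its steps (nclass = max(2, int(delta / precision)), then evenly spaced fractions mapped onto [minimum, maximum]); class_breaks is a hypothetical name. In the answer, each fraction is additionally passed to ramp.color(f) to pick the matching color:

```python
from typing import List, Tuple


def class_breaks(minimum: float, maximum: float, precision: float) -> List[Tuple[float, float]]:
    """Return (fraction, value) pairs used as color-ramp stops.

    Mirrors the answer: about one class per `precision` units of data range,
    with at least two stops so both ends of the ramp are present.
    """
    if precision <= 0:
        raise ValueError("precision must be positive")
    delta = maximum - minimum
    nclass = max(2, int(delta / precision))
    fractions = [i / (nclass - 1) for i in range(nclass)]
    return [(f, minimum + f * delta) for f in fractions]
```

For example, with the question's valid values (5 to 50.4) and a precision of 5, this yields 9 stops from 5.0 to 50.4.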