Search code examples
r raster prediction xgboost xgbregressor

Predict at a finer spatial scale using XGBoost regression


I want to make a prediction at a finer spatial scale using XGBoost regression. I created a model at a coarse spatial scale and now I want to apply the model parameters at a finer scale. The issue is that at the finer spatial scale the xgb.DMatrix doesn't include the dependent variable so the predict function returns this error:

Error in predict.xgb.Booster(m, xgb_p, na.rm = TRUE): Feature names stored in `object` and `newdata` are different!

I have seen this post but I still get the same error. Also, I tried to create a data.frame with the independent variables and an extra empty column with the name of the dependent variable, but still the same error.

How can I use the XGB model parameters found in the coarse spatial scale to make predictions at a fine spatial scale?

I have checked the names of the variables at both the coarse and the fine spatial scales and they are the same.

Here is the code:

library(xgboost)
library(raster)

# Training matrix at the coarse scale.
# NOTE(review): data.matrix(block.data) converts the whole data.frame, so every
# column of block.data (x, y, ntl, pop, tirs, agbh) is passed as a feature,
# while ntl is additionally supplied as the label.
xgb_m <- xgb.DMatrix(data = data.matrix(block.data), label = block.data$ntl)

# Boosted regression model fitted on xgb_m: 1000 rounds, tree depth 2,
# squared-error objective.
m = xgb.train(data = xgb_m, 
            max.depth = 2,  
            nrounds = 1000, 
            min_child_weight = 1, 
            subsample = 0.75, 
            eta = 0.015, 
            gamma = 0.5,
            colsample_bytree = 1, 
            objective = "reg:squarederror")

# Independent variables at the fine spatial scale, read from GeoTIFF files.
# `wd` is not defined in this snippet; it is presumably a directory path that
# ends with a path separator, since it is pasted directly onto the file name.
pop = raster(paste0(wd, "pop.tif"))
tirs = raster(paste0(wd, "tirs.tif"))
agbh = raster(paste0(wd, "agbh.tif"))
vars = stack(pop, tirs, agbh)

# Prediction matrix built from the three-layer stack only (pop, tirs, agbh);
# unlike the training matrix it has no x, y or ntl columns.
xgb_p <- xgb.DMatrix(data = data.matrix(vars))

# This call raises the "Feature names stored in `object` and `newdata` are
# different!" error quoted above.
xgb_pred <- predict(m, xgb_p, na.rm = TRUE) # the error

The data.frame (block.data):

block.data = structure(list(x = c(11880750L, 11879250L, 11879750L, 11880250L, 
11880750L, 11881250L, 11879250L, 11879750L, 11880250L, 11880750L, 
11881250L, 11878750L, 11879250L, 11879750L, 11880250L, 11880750L, 
11881250L, 11879250L, 11879750L, 11880250L, 11880750L, 11881250L, 
11881750L, 11882250L, 11879250L, 11879750L, 11880250L, 11880750L, 
11881250L, 11881750L, 11882250L, 11882750L, 11879250L, 11879750L
), y = c(1802250L, 1801750L, 1801750L, 1801750L, 1801750L, 1801750L, 
1801250L, 1801250L, 1801250L, 1801250L, 1801250L, 1800750L, 1800750L, 
1800750L, 1800750L, 1800750L, 1800750L, 1800250L, 1800250L, 1800250L, 
1800250L, 1800250L, 1800250L, 1800250L, 1799750L, 1799750L, 1799750L, 
1799750L, 1799750L, 1799750L, 1799750L, 1799750L, 1799250L, 1799250L
), ntl = c(18.7969169616699, 25.7222957611084, 23.4188251495361, 
25.4322757720947, 16.4593601226807, 12.7868213653564, 30.9337253570557, 
29.865758895874, 30.4080600738525, 29.5479888916016, 24.3493347167969, 
35.2427635192871, 38.989933013916, 34.6536979675293, 29.4607238769531, 
30.7469024658203, 34.3946380615234, 42.8660278320312, 34.7930717468262, 
30.9516315460205, 32.20654296875, 39.999755859375, 46.6002235412598, 
38.6480979919434, 60.5214920043945, 33.1799964904785, 31.8498134613037, 
30.9209423065186, 32.2269744873047, 53.7062034606934, 45.5225944519043, 
38.3570976257324, 123.040382385254, 73.0528182983398), pop = c(19.6407718658447, 
610.009216308594, 654.812622070312, 426.475830078125, 66.3839492797852, 
10.6471328735352, 443.848846435547, 602.677429199219, 488.478454589844, 
387.470947265625, 58.2341117858887, 413.888488769531, 315.057678222656, 
354.082946777344, 602.827758789062, 463.518829345703, 296.713928222656, 
923.920593261719, 434.436645507812, 799.562927246094, 404.709564208984, 
265.043304443359, 366.697235107422, 399.851684570312, 952.2314453125, 
870.356994628906, 673.406616210938, 493.521606445312, 273.841888427734, 
371.428619384766, 383.057830810547, 320.986755371094, 991.131225585938, 
1148.87768554688), tirs = c(39.7242431640625, 44.9583969116211, 
41.4048385620117, 42.6056709289551, 40.0976028442383, 38.7490005493164, 
44.2747650146484, 43.5645370483398, 41.6180191040039, 40.3799781799316, 
38.8664817810059, 44.9089202880859, 44.414306640625, 44.560977935791, 
43.1288986206055, 40.9315185546875, 38.8918418884277, 46.3063850402832, 
45.5805702209473, 44.9196586608887, 42.2495613098145, 39.3051452636719, 
38.7914810180664, 38.6069412231445, 44.6782455444336, 46.4024772644043, 
44.4720573425293, 41.7361183166504, 42.3378067016602, 41.0018348693848, 
39.3579216003418, 41.6303863525391, 43.8207550048828, 46.0460357666016
), agbh = c(3.32185006141663, 4.98925733566284, 4.35699367523193, 
4.94798421859741, 3.14325952529907, 2.93211793899536, 4.52736520767212, 
4.99723243713379, 5.13944292068481, 3.92965626716614, 3.43465113639832, 
3.55617475509644, 3.4659411907196, 5.24469566345215, 5.36995029449463, 
4.61549234390259, 4.82002925872803, 4.20452928543091, 4.71502685546875, 
5.20452785491943, 5.05676746368408, 5.9952244758606, 6.16778612136841, 
4.69053316116333, 2.62325501441956, 4.74775457382202, 4.93133020401001, 
5.02366256713867, 5.74016952514648, 6.28353786468506, 4.67424774169922, 
4.56812858581543, 1.88153350353241, 4.31531000137329)), class = "data.frame", row.names = c(NA, 
-34L))

The fine-resolution data (vars):

vars = new("RasterBrick", file = new(".RasterFile", name = "", datanotation = "FLT4S", 
    byteorder = "little", nodatavalue = -Inf, NAchanged = FALSE, 
    nbands = 1L, bandorder = "BIL", offset = 0L, toptobottom = TRUE, 
    blockrows = 0L, blockcols = 0L, driver = "", open = FALSE), 
    data = new(".MultipleRasterData", values = structure(c(NA, 
    NA, NA, NA, 18.7969169616699, NA, NA, NA, NA, NA, 25.7222957611084, 
    23.4188251495361, 25.4322757720947, 16.4593601226807, 12.7868213653564, 
    NA, NA, NA, NA, 30.9337253570557, 29.865758895874, 30.4080600738525, 
    29.5479888916016, 24.3493347167969, NA, NA, NA, 35.2427635192871, 
    38.989933013916, 34.6536979675293, 29.4607238769531, 30.7469024658203, 
    34.3946380615234, NA, NA, NA, NA, 42.8660278320312, 34.7930717468262, 
    30.9516315460205, 32.20654296875, 39.999755859375, 46.6002235412598, 
    38.6480979919434, NA, NA, 60.5214920043945, 33.1799964904785, 
    31.8498134613037, 30.9209423065186, 32.2269744873047, 53.7062034606934, 
    45.5225944519043, 38.3570976257324, NA, 123.040382385254, 
    73.0528182983398, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, 
    19.6407718658447, NA, NA, NA, NA, NA, 610.009216308594, 654.812622070312, 
    426.475830078125, 66.3839492797852, 10.6471328735352, NA, 
    NA, NA, NA, 443.848846435547, 602.677429199219, 488.478454589844, 
    387.470947265625, 58.2341117858887, NA, NA, NA, 413.888488769531, 
    315.057678222656, 354.082946777344, 602.827758789062, 463.518829345703, 
    296.713928222656, NA, NA, NA, NA, 923.920593261719, 434.436645507812, 
    799.562927246094, 404.709564208984, 265.043304443359, 366.697235107422, 
    399.851684570312, NA, NA, 952.2314453125, 870.356994628906, 
    673.406616210938, 493.521606445312, 273.841888427734, 371.428619384766, 
    383.057830810547, 320.986755371094, NA, 991.131225585938, 
    1148.87768554688, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, 
    39.7242431640625, NA, NA, NA, NA, NA, 44.9583969116211, 41.4048385620117, 
    42.6056709289551, 40.0976028442383, 38.7490005493164, NA, 
    NA, NA, NA, 44.2747650146484, 43.5645370483398, 41.6180191040039, 
    40.3799781799316, 38.8664817810059, NA, NA, NA, 44.9089202880859, 
    44.414306640625, 44.560977935791, 43.1288986206055, 40.9315185546875, 
    38.8918418884277, NA, NA, NA, NA, 46.3063850402832, 45.5805702209473, 
    44.9196586608887, 42.2495613098145, 39.3051452636719, 38.7914810180664, 
    38.6069412231445, NA, NA, 44.6782455444336, 46.4024772644043, 
    44.4720573425293, 41.7361183166504, 42.3378067016602, 41.0018348693848, 
    39.3579216003418, 41.6303863525391, NA, 43.8207550048828, 
    46.0460357666016, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, 
    3.32185006141663, NA, NA, NA, NA, NA, 4.98925733566284, 4.35699367523193, 
    4.94798421859741, 3.14325952529907, 2.93211793899536, NA, 
    NA, NA, NA, 4.52736520767212, 4.99723243713379, 5.13944292068481, 
    3.92965626716614, 3.43465113639832, NA, NA, NA, 3.55617475509644, 
    3.4659411907196, 5.24469566345215, 5.36995029449463, 4.61549234390259, 
    4.82002925872803, NA, NA, NA, NA, 4.20452928543091, 4.71502685546875, 
    5.20452785491943, 5.05676746368408, 5.9952244758606, 6.16778612136841, 
    4.69053316116333, NA, NA, 2.62325501441956, 4.74775457382202, 
    4.93133020401001, 5.02366256713867, 5.74016952514648, 6.28353786468506, 
    4.67424774169922, 4.56812858581543, NA, 1.88153350353241, 
    4.31531000137329, NA, NA, NA, NA, NA, NA), .Dim = c(63L, 
    4L)), offset = 0, gain = 1, inmemory = TRUE, fromdisk = FALSE, 
        nlayers = 4L, dropped = NULL, isfactor = c(FALSE, FALSE, 
        FALSE, FALSE), attributes = list(), haveminmax = TRUE, 
        min = c(12.7868213653564, 10.6471328735352, 38.6069412231445, 
        1.88153350353241), max = c(123.040382385254, 1148.87768554688, 
        46.4024772644043, 6.28353786468506), unit = "", names = c("ntl", 
        "pop", "tirs", "agbh")), legend = new(".RasterLegend", 
        type = character(0), values = logical(0), color = logical(0), 
        names = logical(0), colortable = logical(0)), title = character(0), 
    extent = new("Extent", xmin = 11878500, xmax = 11883000, 
        ymin = 1799000, ymax = 1802500), rotated = FALSE, rotation = new(".Rotation", 
        geotrans = numeric(0), transfun = function () 
        NULL), ncols = 9L, nrows = 7L, crs = new("CRS", projargs = NA_character_), 
    srs = character(0), history = list(), z = list())

Solution

  • The issue you have is that the model is trained on all columns present in block.data, so also on x, y and ntl itself. In your later rasterstack, those variables are not present and thus you get an error.

    We can get around this by creating the xgb.DMatrix differently:

    # Predictor columns shared by the coarse table and the fine rasters.
    # x, y and the response (ntl) are deliberately left out: the model must be
    # trained on exactly the features that will be available at prediction time.
    predictors <- c("pop", "tirs", "agbh")
    
    # Training matrix: predictors only as features, ntl as the label.
    xgb_m <- xgb.DMatrix(
      data = data.matrix(block.data[, predictors]),
      label = block.data$ntl
    )
    
    # Your model. Hyperparameters go in `params`, the documented interface of
    # xgb.train(); passing them as loose arguments (and the dotted alias
    # `max.depth`) is not accepted cleanly by newer xgboost releases.
    m <- xgb.train(
      params = list(
        max_depth = 2,
        min_child_weight = 1,
        subsample = 0.75,
        eta = 0.015,
        gamma = 0.5,
        colsample_bytree = 1,
        objective = "reg:squarederror"
      ),
      data = xgb_m,
      nrounds = 1000
    )
    
    # Load the fine-scale rasters.
    library(terra)
    pop <- rast("pop.tif")
    tirs <- rast("tirs.tif")
    agbh <- rast("agbh.tif")
    
    # Combine the layers in the same order as `predictors` and name them
    # explicitly, so the feature names match the training matrix regardless of
    # how the files happen to be named.
    vars <- c(pop, tirs, agbh)
    names(vars) <- predictors
    
    # Predict: values() returns one row per cell and one column per layer.
    xgb_p <- xgb.DMatrix(values(vars))
    predict(m, xgb_p)