I'm trying to get the index of the max value in each row of a Spark dataframe. It's straightforward to get the maximum value itself. I do the following:
library(sparklyr)
library(dplyr)
config <- spark_config()
sc <- spark_connect(master = "local", config = config)
df <- replicate(n = 3, sample(x = 0:10, size = 10, rep = TRUE)) %>%
  as.data.frame()
sdf <- sdf_copy_to(sc, df, overwrite = TRUE)

sdf %>% spark_apply(
  function(df) {
    return(pmax(df[1], df[2], df[3]))
  })
I've tried to collect the columns into a vector using ft_vector_assembler, but I am not familiar with the returned data structure. For example, I cannot recover the max from the following code:
sdf %>%
  ft_vector_assembler(
    input_cols = c("V1", "V2", "V3"),
    output_col = "features") %>%
  select(features) %>%
  spark_apply(function(df) pmax(df))
Any help is appreciated.
Let's start with your first problem:
"It's straightforward to get the maximum value."
It indeed is, but spark_apply is not the way to go here. Instead, use Spark's greatest function:
sdf %>% mutate(max = greatest(V1, V2, V3))
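Because greatest is a Spark SQL function rather than an R one, dbplyr passes the call through to Spark untranslated, and nothing runs until you collect. A minimal check (the actual values depend on your random sample):

sdf %>%
  mutate(max = greatest(V1, V2, V3)) %>%
  show_query()   # inspect the generated SQL; greatest() is passed through as-is

sdf %>%
  mutate(max = greatest(V1, V2, V3)) %>%
  collect()      # brings the result back as a local tibble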
The same function can be used for your second problem, but due to sparklyr limitations you'll have to use a SQL expression directly:
# Build one CAST(STRUCT(`Vi`, i) AS struct<value, index>) fragment per
# column, then wrap them all in greatest() and extract the index field.
expr <- c("V1", "V2", "V3") %>%
  paste0(
    "CAST(STRUCT(`", ., "`, ", seq_along(.),
    ") AS struct<value: double, index: double>)",
    collapse = ", ") %>%
  paste0("greatest(", ., ").index AS max_index")
sdf %>%
  spark_dataframe() %>%
  invoke("selectExpr", list("*", expr)) %>%
  sdf_register()
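You can sanity-check the result locally after collecting; base R's max.col gives the per-row position of the maximum, and ties.method = "last" mirrors the tie-breaking of the struct comparison above:

result <- sdf %>%
  spark_dataframe() %>%
  invoke("selectExpr", list("*", expr)) %>%
  sdf_register() %>%
  collect()

# Should be TRUE: Spark's max_index agrees with the local computation.
all(result$max_index ==
      max.col(as.matrix(result[c("V1", "V2", "V3")]), ties.method = "last"))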
In Spark 2.4 (as of now unsupported in sparklyr) it might be possible to do this with the new array functions. Since dplyr cannot translate the struct field access, the extraction is written as literal SQL via sql() here; arrays_zip pairs each value with its position, array_max picks the pair with the largest value, and field `1` of that pair holds the index:

sdf %>% mutate(
  max_index = sql(
    "array_max(arrays_zip(array(V1, V2, V3), array(1, 2, 3))).`1`"
  )
)
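As a sketch under the same Spark 2.4 assumption, array_position plus greatest should express the same idea more compactly; note that array_position returns the position of the first match, so ties resolve to the lowest index, unlike the struct approach above:

sdf %>% mutate(
  # array_position(arr, x) is 1-based; both functions are passed
  # through to Spark SQL untranslated by dbplyr.
  max_index = array_position(array(V1, V2, V3), greatest(V1, V2, V3))
)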