Tags: java, apache-spark, rdd

Spark: flatten a dataset with a mapped column


I have an RDD with the following schema:

{
  "type" : "struct",
  "fields" : [ {
    "name" : "cola",
    "type" : "string",
    "nullable" : true,
    "metadata" : { }
  }, {
    "name" : "mappedcol",
    "type" : {
      "type" : "map",
      "keyType" : "string",
      "valueType" : "string",
      "valueContainsNull" : true
    },
    "nullable" : true,
    "metadata" : { }
  }, {
    "name" : "colc",
    "type" : "string",
    "nullable" : true,
    "metadata" : { }
  } ]
}

Sample value:

{
  cola : A1,
  mappedcol : { mapped1 : M1, mapped2 : M2, mapped3 : M3 },
  colc : C1
}

I want to pull the keys of mappedcol up one level, i.e. flatten the data so that all columns sit at the same level:

cola, mapped1, mapped2, mapped3, colc
A1, M1, M2, M3, C1

Is there an elegant way to do it in Java?


Solution

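  • It is possible to access the individual elements of a nested structure with dot syntax, e.g. selecting mappedcol.mapped1 returns M1. The idea is to transform the schema into a list of column names using this dot syntax. Note that this relies on mappedcol being a struct, which is what spark.read().json infers for a nested JSON object; if the column is a genuine MapType, its keys are not part of the schema and have to be read from the data instead.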

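    import java.util.ArrayList;
    import java.util.List;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;

    // Recursively turns a (possibly nested) schema into dot-separated column names,
    // e.g. a struct field "mapped1" inside "mappedcol" becomes "mappedcol.mapped1".
    // Only StructType is descended into; any other type (including MapType)
    // is returned as a single column name.
    private static List<String> structToColNames(StructField[] fields, String prefix) {
        List<String> columns = new ArrayList<>();
        for (StructField field : fields) {
            String fieldname = field.name();
            if (field.dataType() instanceof StructType) {
                // nested struct: recurse with the extended prefix
                columns.addAll(
                    structToColNames(((StructType) field.dataType()).fields(),
                        prefix + fieldname + "."));
            } else {
                // leaf field
                columns.add(prefix + fieldname);
            }
        }
        return columns;
    }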
    

    The result of this function can then be used to select the data:

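    import static org.apache.spark.sql.functions.col;
    import org.apache.spark.sql.Column;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;

    Dataset<Row> df = spark.read().json("<path to json>"); // spark: an existing SparkSession
    StructField[] fields = df.schema().fields();
    List<String> colNames = structToColNames(fields, "");
    System.out.println(colNames);
    // turn each dot path into a Column and select them all
    Column[] columns = colNames.stream().map(s -> col(s)).toArray(Column[]::new);
    df.select(columns).show();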
    

    prints

    [cola, colc, mappedcol.mapped1, mappedcol.mapped2, mappedcol.mapped3]
    
    +----+----+-------+-------+-------+
    |cola|colc|mapped1|mapped2|mapped3|
    +----+----+-------+-------+-------+
    |  A1|  C1|     M1|     M2|     M3|
    +----+----+-------+-------+-------+
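The schema walk above only finds nested fields when mappedcol is a struct. If the column really is a MapType, as in the question's schema, one possible approach (a minimal sketch, not the answer's original code) is to collect the distinct keys from the data with functions.map_keys and functions.explode, then select each key with Column.getItem. The class FlattenMapColumn and the helpers flattenMap and sampleData are hypothetical names; the sketch assumes Spark 2.3 or later (map_keys was added in 2.3) and builds the sample row from a JSON string read with an explicit MapType schema.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.explode;
import static org.apache.spark.sql.functions.map_keys;

// Hypothetical helper class, not part of Spark.
public class FlattenMapColumn {

    // Replaces the MapType column `mapCol` with one top-level column per distinct key.
    // A map's keys are not part of the schema, so they are collected from the data,
    // which runs an extra Spark job.
    public static Dataset<Row> flattenMap(Dataset<Row> df, String mapCol) {
        List<String> keys = new ArrayList<>();
        for (Row r : df.select(explode(map_keys(col(mapCol)))).distinct().collectAsList()) {
            keys.add(r.getString(0));
        }
        Collections.sort(keys); // deterministic column order

        // keep the other columns in place; expand the map column where it was
        List<Column> columns = new ArrayList<>();
        for (String name : df.columns()) {
            if (name.equals(mapCol)) {
                for (String key : keys) {
                    // null for rows whose map lacks this key
                    columns.add(col(mapCol).getItem(key).as(key));
                }
            } else {
                columns.add(col(name));
            }
        }
        return df.select(columns.toArray(new Column[0]));
    }

    // Builds the question's sample row; the explicit schema makes mappedcol a MapType
    // instead of the struct that JSON schema inference would produce.
    public static Dataset<Row> sampleData(SparkSession spark) {
        StructType schema = DataTypes.createStructType(new StructField[] {
            DataTypes.createStructField("cola", DataTypes.StringType, true),
            DataTypes.createStructField("mappedcol",
                DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType, true), true),
            DataTypes.createStructField("colc", DataTypes.StringType, true)
        });
        String json = "{\"cola\":\"A1\",\"mappedcol\":{\"mapped1\":\"M1\","
            + "\"mapped2\":\"M2\",\"mapped3\":\"M3\"},\"colc\":\"C1\"}";
        Dataset<String> lines = spark.createDataset(Collections.singletonList(json), Encoders.STRING());
        return spark.read().schema(schema).json(lines);
    }

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .master("local[1]").appName("flatten-map").getOrCreate();
        flattenMap(sampleData(spark), "mappedcol").show();
        spark.stop();
    }
}
```

Running main should show a single row with the columns cola, mapped1, mapped2, mapped3, colc. Collecting the keys costs an extra pass over the data; if the set of keys is known in advance, that list can be used directly instead. A key that clashes with an existing top-level column name would need a different alias.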