
What is the use case for flatMap vs. map in Kotlin?


In https://try.kotlinlang.org/#/Kotlin%20Koans/Collections/FlatMap/Task.kt there is a sample that uses both flatMap and map.

They seem to do the same thing. Is there an example that shows the difference between flatMap and map?

The data types:

data class Shop(val name: String, val customers: List<Customer>)

data class Customer(val name: String, val city: City, val orders: List<Order>) {
    override fun toString() = "$name from ${city.name}"
}

data class Order(val products: List<Product>, val isDelivered: Boolean)

data class Product(val name: String, val price: Double) {
    override fun toString() = "'$name' for $price"
}

data class City(val name: String) {
    override fun toString() = name
}

The samples:

fun Shop.getCitiesCustomersAreFrom(): Set<City> =
    customers.map { it.city }.toSet()
    // would it be the same with customers.flatMap { it.city }.toSet()?

val Customer.orderedProducts: Set<Product> get() {
    return orders.flatMap { it.products }.toSet()
    // would it be the same with: return orders.map { it.products }.toSet()?
}
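
The two commented alternatives can be checked directly. The sketch below uses trimmed copies of the Koans data classes (the custom toString overrides are omitted) and made-up sample data; the function names cities, orderedProducts and orderedProductLists are hypothetical. The deciding factor is what the lambda returns: map accepts any return value and keeps it as one element, while flatMap needs the lambda to return a collection whose elements it can merge.

```kotlin
// Trimmed copies of the Koans data classes (custom toString overrides omitted)
data class City(val name: String)
data class Product(val name: String, val price: Double)
data class Order(val products: List<Product>, val isDelivered: Boolean)
data class Customer(val name: String, val city: City, val orders: List<Order>)

// Made-up sample data for this sketch
val kettle = Product("kettle", 20.0)
val mug = Product("mug", 5.0)
val alice = Customer(
    "Alice", City("Oslo"),
    listOf(Order(listOf(kettle, mug), true), Order(listOf(mug), false))
)
val bob = Customer("Bob", City("Oslo"), emptyList())

// map: the lambda returns ONE City per customer, kept as one element.
// customers.flatMap { it.city } does not compile: flatMap needs the lambda
// to return something it can flatten, such as an Iterable, and City is not one.
fun cities(customers: List<Customer>): Set<City> =
    customers.map { it.city }.toSet()

// flatMap: each Order yields a List<Product>; flatMap merges them into one list.
fun orderedProducts(customer: Customer): Set<Product> =
    customer.orders.flatMap { it.products }.toSet()

// map keeps the nesting, so toSet() gives a set of lists, not a set of products.
fun orderedProductLists(customer: Customer): Set<List<Product>> =
    customer.orders.map { it.products }.toSet()

fun main() {
    // [City(name=Oslo)]
    println(cities(listOf(alice, bob)))
    // [Product(name=kettle, price=20.0), Product(name=mug, price=5.0)]
    println(orderedProducts(alice))
    // [[Product(name=kettle, price=20.0), Product(name=mug, price=5.0)], [Product(name=mug, price=5.0)]]
    println(orderedProductLists(alice))
}
```

As a rule of thumb: use map when the lambda produces one value per element, and flatMap when it produces a collection whose elements should all end up in a single flat result.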

Solution

  • Consider the following example: You have a simple data structure Data with a single property of type List<String>.

    class Data(val items: List<String>)
    
    val dataObjects = listOf(
        Data(listOf("a", "b", "c")), 
        Data(listOf("1", "2", "3"))
    )
    

    flatMap vs. map

    With flatMap, you can "flatten" the Data::items lists of several objects into one collection, as the items variable shows:

    val items: List<String> = dataObjects
        .flatMap { it.items } //[a, b, c, 1, 2, 3]
    

    Using map, on the other hand, keeps each lambda result as a single element, so it simply results in a list of lists:

    val items2: List<List<String>> = dataObjects
        .map { it.items } //[[a, b, c], [1, 2, 3]] 
    

    flatten

    There is also a flatten extension on Iterable<Iterable<T>> and on Array<out Array<out T>>, which you can use instead of flatMap when you already have a nested collection of one of those types:

    val nestedCollections: List<Int> =
        listOf(listOf(1, 2, 3), listOf(5, 4, 3))
            .flatten() //[1, 2, 3, 5, 4, 3]
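
    Put differently, flatMap gives the same list as map followed by flatten. A minimal sketch of that equivalence, reusing the Data class from the answer; the function names and the nested array are made up for illustration, the array showing that flatten on Array<Array<T>> also returns a List:

    ```kotlin
    class Data(val items: List<String>)

    val dataObjects = listOf(
        Data(listOf("a", "b", "c")),
        Data(listOf("1", "2", "3"))
    )

    // One step: flatMap
    fun viaFlatMap(): List<String> = dataObjects.flatMap { it.items }

    // Two steps: map to List<List<String>>, then flatten the nested list
    fun viaMapThenFlatten(): List<String> = dataObjects.map { it.items }.flatten()

    // flatten on a nested Array also returns a List
    fun flattenedArrays(): List<Int> = arrayOf(arrayOf(1, 2), arrayOf(3)).flatten()

    fun main() {
        println(viaFlatMap())                        // [a, b, c, 1, 2, 3]
        println(viaMapThenFlatten() == viaFlatMap()) // true
        println(flattenedArrays())                   // [1, 2, 3]
    }
    ```

    flatMap is the more direct choice when you start from non-nested objects and extract a collection from each; flatten fits when the nested collection already exists.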