scala, slick

How to express this query in Slick involving GROUP BY, HAVING, and COUNT


Suppose I have a table defined like this in Slick 3.2.3:

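// Any JdbcProfile's api._ import works here; H2Profile is only an example, so swap in your database's profile.
import slick.jdbc.H2Profile.api._

class ATable(tag: Tag) extends Table[(Int, Option[Boolean])](tag, "a_table") {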
  def someInt = column[Int]("some_int")
  def someBool = column[Option[Boolean]]("some_bool")
  def * = (someInt, someBool)
}
object ATable extends TableQuery(new ATable(_))

and a bit of data to go with it:

insert into a_table
values
(1, true),
(2, null),
(2, true),
(2, null),
(3, true),
(3, true),
(3, null);

Now I'd like to find the some_int values for which exactly one row has a non-null some_bool. This is not difficult in SQL:

select some_int
from a_table
group by some_int
having count(some_bool) = 1;

This works just fine. So let's try it with Slick:

ATable
 .groupBy(_.someInt)
 .filter(_._2.map(_.someBool).countDefined === 1)
 .map(_._1)

And while this compiles, it crashes at runtime with the error “slick.SlickTreeException: Cannot convert node to SQL Comprehension”. Is this a known limitation or a bug in Slick? Or am I supposed to write my query differently? It's of course possible to write this with a subquery, but I'd rather understand why the groupBy version doesn't work…


Solution

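  • The Slick equivalent of your SQL should look something like the query in the following example. Instead of aggregating inside `filter`, first `map` each group to its key and its aggregate, then `filter` on the aggregated column; Slick renders that filter as a HAVING clause, as the generated SQL at the end shows: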

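    import scala.concurrent.Future
    import scala.concurrent.ExecutionContext.Implicits.global // needed by Future.flatMap and DBIO.map
    // Also assumes the profile api._ import from the question and an already configured `db: Database`.
    
    val aTable: TableQuery[ATable] = TableQuery[ATable]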
    
    val setupAction: DBIO[Unit] = DBIO.seq(
      aTable.schema.create,
      aTable += (1, Some(true)),
      aTable += (2, None),
      aTable += (2, Some(true)),
      aTable += (2, None),
      aTable += (3, Some(true)),
      aTable += (3, Some(true)),
      aTable += (3, None)
    )
    
    val setupFuture: Future[Unit] = db.run(setupAction)
    
    val f = setupFuture.flatMap{ _ =>
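      val query =
        aTable.groupBy(_.someInt).
          // First turn each group into (key, aggregate) ...
          map{ case (someInt, group) => (someInt, group.map(_.someBool).countDefined) }.
          // ... then filter on the aggregated column, which becomes the HAVING clause.
          filter(_._2 === 1).
          map(_._1)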
    
      println("Generated SQL for query:\n" + query.result.statements)
      db.run(query.result.map(println))
    }
    
    // Generated SQL for query:
    // List(select "some_int" from "a_table" group by "some_int" having count("some_bool") = 1)
    // Vector(1, 2)
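As a database-free sanity check of the `Vector(1, 2)` output, here is a minimal sketch that models the same GROUP BY / HAVING logic with plain Scala collections. The object name `GroupHavingModel` is hypothetical, and modelling SQL NULL as `None` is an assumption of the sketch. It mirrors the shape of the Slick query above: map each group to its aggregate first, then filter on that aggregate, then keep only the key.

```scala
object GroupHavingModel {
  // The sample rows from the question; SQL NULL is modelled as None (an assumption of this sketch).
  val rows: Seq[(Int, Option[Boolean])] = Seq(
    (1, Some(true)),
    (2, None),
    (2, Some(true)),
    (2, None),
    (3, Some(true)),
    (3, Some(true)),
    (3, None)
  )

  // GROUP BY some_int, then count(some_bool): like countDefined, only non-NULL values are counted.
  val countsPerInt: Map[Int, Int] =
    rows.groupBy(_._1).map { case (someInt, group) => someInt -> group.count(_._2.isDefined) }

  // HAVING count(some_bool) = 1, then project the key; sorted because Map iteration order is unspecified.
  val result: List[Int] =
    countsPerInt.filter { case (_, n) => n == 1 }.keys.toList.sorted

  def main(args: Array[String]): Unit =
    println(result) // List(1, 2)
}
```

Group 3 has two non-null values and is dropped, while group 2 has exactly one despite having three rows, which is the same distinction `count(some_bool)` draws compared with `count(*)`.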