Tags: haskell, yesod, esqueleto, haskell-persistent

How do I pass a rendered persistent/esqueleto query to another query?


I'd like to use Persistent/Esqueleto to implement count estimates.

One approach recommended in this article is to define a function like this

CREATE FUNCTION count_estimate(query text) RETURNS integer AS $$
DECLARE
  rec   record;
  rows  integer;
BEGIN
  FOR rec IN EXECUTE 'EXPLAIN ' || query LOOP
    rows := substring(rec."QUERY PLAN" FROM ' rows=([[:digit:]]+)');
    EXIT WHEN rows IS NOT NULL;
  END LOOP;
  RETURN rows;
END;
$$ LANGUAGE plpgsql VOLATILE STRICT;

and then use it like this

SELECT count_estimate('SELECT * FROM companies WHERE status = ''Active''');

In order to use the count_estimate function, I'll need (I think?) to render the query that Persistent/Esqueleto generates. However, when I try rendering the query with renderQuerySelect, I get something like this

SELECT "companies"."id", "companies"."name", "companies"."status"
FROM "companies"
WHERE "companies"."status" IN (?)
; [PersistText "Active"]

This of course can't be stuffed into count_estimate as is, because it will syntax error on the ? placeholder. I also can't naïvely replace the ? with "Active", because it will syntax error on that first double quote.

How do I render the query in a way that my count_estimate function will accept?

I tried something like this, but it fails at runtime

getEstimate :: (Text, [PersistValue]) -> DB [Single Int]
getEstimate (query, params) = rawSql [st|
  SELECT count_estimate('#{query}');
  |] params

Solution

  • I managed to figure it out (mostly).

    It's a matter of escaping the single quotes in both the query and the PersistValue parameters. I'm doing it like this at the moment. Because the parameters are passed as unescaped literals, escaping of the values themselves will need to be added back in; otherwise I think this creates a SQL injection vulnerability. I may also need to handle the other PersistValue constructors in some specific way, but I haven't run into problems there yet.

    import qualified Data.Text as T
    import qualified Database.Persist as P
    
    getEstimate :: (Text, [PersistValue]) -> DB (Maybe Int)
    getEstimate (query, params) = fmap unSingle . listToMaybe <$> rawSql [st|
      SELECT count_estimate('#{T.replace "'" "''" query}');
      |] (map replace' params)
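      -- Unescaped literals are spliced into the SQL verbatim, so each value is
      -- wrapped in doubled single quotes to sit inside the quoted query string.
      where literal a = PersistLiteral_ P.Unescaped ("''" <> a <> "''")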
            replace' = \case
              PersistText t -> literal $ encodeUtf8 t
              PersistDay  d -> literal $ encodeUtf8 $ pack $ showGregorian d
              a             -> a
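The escaping that the answer says still needs to be added back in could look roughly like the sketch below. The names quoteLiteral, escapeForNesting and nestedLiteral are hypothetical helpers, not part of persistent or esqueleto. The sketch assumes PostgreSQL's default standard_conforming_strings = on, so that only single quotes need doubling. It is an illustration of the nesting rule, not a vetted defence against injection.

```haskell
{-# LANGUAGE OverloadedStrings #-}

import           Data.Text (Text)
import qualified Data.Text as T

-- | Hypothetical helper: render a value as a PostgreSQL string literal by
-- wrapping it in single quotes and doubling every single quote inside it.
-- Assumes standard_conforming_strings is on, so backslashes are not special.
quoteLiteral :: Text -> Text
quoteLiteral t = "'" <> T.replace "'" "''" t <> "'"

-- | Hypothetical helper: escape text so it can be placed inside an
-- enclosing single-quoted string (every single quote is doubled).
escapeForNesting :: Text -> Text
escapeForNesting = T.replace "'" "''"

-- | A literal destined for the query string handed to count_estimate is
-- nested one level deep: quote it as a literal first, then escape that
-- literal for the enclosing string.
nestedLiteral :: Text -> Text
nestedLiteral = escapeForNesting . quoteLiteral
```

For a value without quotes this produces the same doubled-quote wrapping as the literal helper in the answer, while a quote inside the value ends up as four single quotes. In replace', the PersistText case could then presumably become PersistLiteral_ P.Unescaped (encodeUtf8 (nestedLiteral t)), with the other constructors handled the same way once they are rendered to text.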