
How do you insert a map-reduce into a Polars method chain?


I’m applying a series of filters and other transformations, including a group_by, to a Polars data frame. The objective is to count HTML tags in a single column per date per publisher. Here is the code:

import polars as pl

def contains_html3(mindate, parquet_file = default_file, fieldname = "text"):
    """ checks if html tags are in field """

    html_tags = [
        "<html>", "</html>", "<head>", "</head>", "<title>", "</title>", "<meta>", "</meta>", "<link>", "</link>", "<style>",
        "<body>", "</body>", "<header>", "</header>", "<footer>", "</footer>", "<nav>", "</nav>", "<main>", "</main>",
        "<section>", "</section>", "<article>", "</article>", "<aside>", "</aside>", "<h1>", "</h1>", "<h2>", "</h2>",
        "<h3>", "</h3>", "<h4>", "</h4>", "<h5>", "</h5>", "<h6>", "</h6>", "<p>", "</p>", "<ul>", "</ul>", "<ol>", "</ol>",
        "<li>", "</li>", "<div>", "</div>", "<span>", "</span>", "<a>", "</a>", "<img>", "</img>", "<table>", "</table>",
        "<thead>", "</thead>", "<tbody>", "</tbody>", "<tr>", "</tr>", "<td>", "</td>", "<th>", "</th>", "<form>", "</form>",
        "<input>", "</input>", "<textarea>", "</textarea>", "<button>", "</button>", "<select>", "</select>", "<option>",
        "<script>", "</script>", "<noscript>", "</noscript>", "<iframe>", "</iframe>", "<canvas>", "</canvas>", "<source>"]

    gg = (pl.scan_parquet(parquet_file)
          .cast({"date": pl.Date})
          .select("publisher", "date", fieldname)
          .drop_nulls()
          .group_by("publisher", "date")
          .agg(pl.col(fieldname).str.contains_any(html_tags).sum().alias(fieldname))
          .filter(pl.col(fieldname) > 0)
          .sort(fieldname, descending = True)).collect()

    return gg

Here is example output for fieldname = "text":

Out[8]:
shape: (22_925, 3)
┌───────────────────────────┬────────────┬──────┐
│ publisher                 ┆ date       ┆ text │
│ ---                       ┆ ---        ┆ ---  │
│ str                       ┆ date       ┆ u64  │
╞═══════════════════════════╪════════════╪══════╡
│ Kronen Zeitung            ┆ 2024-11-20 ┆ 183  │
│ Kronen Zeitung            ┆ 2024-10-25 ┆ 180  │
│ Kronen Zeitung            ┆ 2024-11-14 ┆ 174  │
│ Kronen Zeitung            ┆ 2024-11-06 ┆ 172  │
│ Kronen Zeitung            ┆ 2024-10-31 ┆ 171  │
│ …                         ┆ …          ┆ …    │
│ The Faroe Islands Podcast ┆ 2020-03-31 ┆ 1    │
│ Sunday Standard           ┆ 2024-07-16 ┆ 1    │
│ Stabroek News             ┆ 2024-08-17 ┆ 1    │
│ CivilNet                  ┆ 2024-09-01 ┆ 1    │
│ The Star                  ┆ 2024-06-23 ┆ 1    │
└───────────────────────────┴────────────┴──────┘

The issue is that instead of passing a single fieldname = "text" argument, I would like to pass a list (for example ["text", "text1", "text2", ...]) and run the last three steps of the chain (agg, filter, sort) for each element of the list. I could wrap the whole Polars method chain in a for loop and then join the resulting data frames, but is there a better way? For example, could I insert a map, foreach, or similar construct after the group_by clause and have Polars add a new column for each field name, without using a loop?

What's the best way of handling this?

EDIT WITH REPRODUCIBLE CODE

This will produce a data frame df and the sample output tcr, which has all four columns text1 through text4, summed and sorted, but does not use Polars for the last (join) step.


import datetime as dt
import random
from functools import reduce

import polars as pl

random.seed(8472)


html_tags = [
"<html>", "</html>", "<head>", "</head>", "<title>", "</title>", "<meta>", "</meta>", "<link>", "</link>", "<style>", "</style>",
"<body>", "</body>", "<header>", "</header>", "<footer>", "</footer>", "<nav>", "</nav>", "<main>", "</main>",
"<section>", "</section>", "<article>", "</article>", "<aside>", "</aside>", "<h1>", "</h1>", "<h2>", "</h2>", 
"<h3>", "</h3>", "<h4>", "</h4>", "<h5>", "</h5>", "<h6>", "</h6>", "<p>", "</p>", "<ul>", "</ul>", "<ol>", "</ol>", 
"<li>", "</li>", "<div>", "</div>", "<span>", "</span>", "<a>", "</a>", "<img>", "</img>", "<table>", "</table>", 
"<thead>", "</thead>", "<tbody>", "</tbody>", "<tr>", "</tr>", "<td>", "</td>", "<th>", "</th>", "<form>", "</form>", 
"<input>", "</input>", "<textarea>", "</textarea>", "<button>", "</button>", "<select>", "</select>", "<option>", "</option>",
"<script>", "</script>", "<noscript>", "</noscript>", "<iframe>", "</iframe>", "<canvas>", "</canvas>", "<source>", "</source>"]


def makeword(alphaLength):
    """Make a dummy name if none provided."""
    consonants = "bcdfghjklmnpqrstvwxyz"
    vowels = "aeiou"
    word = ''.join(random.choice(consonants if i % 2 == 0 else vowels) 
                for i in range(alphaLength))
    return word

def makepara(nwords):
    """Make a paragraph of dummy text."""
    words = [makeword(random.randint(3, 10)) for _ in range(nwords)]
    tags = random.choices(html_tags, k=3)
    parawords = random.choices(tags + words, k=nwords)
    para = " ".join(parawords)
    return para

def generate_df_with_tags(rows = 100, numdates = 10, num_publishers = 6):
    publishers = [makeword(5) for _ in range(num_publishers)]
    datesrange = pl.date_range(start := dt.datetime(2024, 2, 1), 
                          end = start + dt.timedelta(days = numdates - 1),
                          eager = True)
    dates = sorted(random.choices(datesrange, k = rows))
    df = pl.DataFrame({
        "publisher": random.choices(publishers, k = rows),
        "date": dates, 
        "text1": [makepara(15) for _ in range(rows)],
        "text2": [makepara(15) for _ in range(rows)],
        "text3": [makepara(15) for _ in range(rows)],
        "text4": [makepara(15) for _ in range(rows)]
    })
    return df


def contains_html_so(parquet_file, fieldname = "text"):
    """ checks if html tags are in field """

    gg = (pl.scan_parquet(parquet_file)
          .select("publisher", "date", fieldname)
          .drop_nulls()
          .group_by("publisher", "date")
          .agg(pl.col(fieldname).str.contains_any(html_tags).sum().alias(fieldname))
          .filter(pl.col(fieldname) > 0)
          .sort(fieldname, descending = True)).collect()
            
    return gg

if __name__ == "__main__":
    df = generate_df_with_tags(100)
    df.write_parquet("/tmp/test.parquet")
    tc = [contains_html_so("/tmp/test.parquet", fieldname = x) for x in ["text1", "text2", "text3", "text4"]]
    tcr = (
        reduce(lambda x, y: x.join(y, how = "full", on = ["publisher", "date"], coalesce = True), tc)
        .with_columns(
            (pl.col("text1").fill_null(0)
             + pl.col("text2").fill_null(0)
             + pl.col("text3").fill_null(0)
             + pl.col("text4").fill_null(0)).alias("sum"))
        .sort("sum", descending = True)
    )
    print(tcr)


The desired output is below. Note that at the bottom of the code I use functools.reduce, outside the Polars ecosystem, to join the four data frames; it is essentially this reduce that I want to move into the Polars method chain. [As an aside, my repeated pl.col("textX").fill_null(0) terms are also a bit clumsy, but I'll leave that for a separate question.]

In [59]: %run so_question.py
shape: (45, 7)
┌───────────┬────────────┬───────┬───────┬───────┬───────┬─────┐
│ publisher ┆ date       ┆ text1 ┆ text2 ┆ text3 ┆ text4 ┆ sum │
│ ---       ┆ ---        ┆ ---   ┆ ---   ┆ ---   ┆ ---   ┆ --- │
│ str       ┆ date       ┆ u64   ┆ u64   ┆ u64   ┆ u64   ┆ u64 │
╞═══════════╪════════════╪═══════╪═══════╪═══════╪═══════╪═════╡
│ desob     ┆ 2024-02-10 ┆ 5     ┆ 5     ┆ 5     ┆ 5     ┆ 20  │
│ qopir     ┆ 2024-02-03 ┆ 5     ┆ 5     ┆ 5     ┆ 4     ┆ 19  │
│ jerag     ┆ 2024-02-04 ┆ 5     ┆ 5     ┆ 5     ┆ 4     ┆ 19  │
│ jerag     ┆ 2024-02-07 ┆ 5     ┆ 4     ┆ 5     ┆ 5     ┆ 19  │
│ wopav     ┆ 2024-02-07 ┆ 4     ┆ 5     ┆ 3     ┆ 5     ┆ 17  │
│ …         ┆ …          ┆ …     ┆ …     ┆ …     ┆ …     ┆ …   │
│ jerag     ┆ 2024-02-06 ┆ 1     ┆ null  ┆ 1     ┆ 1     ┆ 3   │
│ desob     ┆ 2024-02-05 ┆ 1     ┆ 1     ┆ null  ┆ 1     ┆ 3   │
│ cufeg     ┆ 2024-02-04 ┆ 1     ┆ 1     ┆ 1     ┆ null  ┆ 3   │
│ cufeg     ┆ 2024-02-05 ┆ 1     ┆ null  ┆ 1     ┆ 1     ┆ 3   │
│ wopav     ┆ 2024-02-06 ┆ null  ┆ 1     ┆ 1     ┆ 1     ┆ 3   │
└───────────┴────────────┴───────┴───────┴───────┴───────┴─────┘

In short: tag counts for each of the columns ["text1", "text2", "text3", "text4"], summed ignoring nulls, and sorted descending on the sum. The join should be on publisher and date, full outer (how = "full"), with coalesced key columns.


Solution

  • Isn't this the same as aggregating all of the columns at once? Passing a list of names to pl.col() expands into one expression per column:

    fieldnames = ["text1", "text2", "text3", "text4"]
    
    (df.group_by("publisher", "date")
       .agg(pl.col(fieldnames).str.contains_any(html_tags).sum())
       .with_columns(sum = pl.sum_horizontal(fieldnames))
    )
    
    shape: (45, 7)
    ┌───────────┬────────────┬───────┬───────┬───────┬───────┬─────┐
    │ publisher ┆ date       ┆ text1 ┆ text2 ┆ text3 ┆ text4 ┆ sum │
    │ ---       ┆ ---        ┆ ---   ┆ ---   ┆ ---   ┆ ---   ┆ --- │
    │ str       ┆ date       ┆ u32   ┆ u32   ┆ u32   ┆ u32   ┆ u32 │
    ╞═══════════╪════════════╪═══════╪═══════╪═══════╪═══════╪═════╡
    │ desob     ┆ 2024-02-01 ┆ 2     ┆ 2     ┆ 2     ┆ 2     ┆ 8   │
    │ xikoy     ┆ 2024-02-06 ┆ 1     ┆ 1     ┆ 1     ┆ 1     ┆ 4   │
    │ wopav     ┆ 2024-02-03 ┆ 2     ┆ 2     ┆ 2     ┆ 2     ┆ 8   │
    │ jerag     ┆ 2024-02-05 ┆ 3     ┆ 2     ┆ 3     ┆ 3     ┆ 11  │
    │ qopir     ┆ 2024-02-03 ┆ 5     ┆ 5     ┆ 5     ┆ 4     ┆ 19  │
    │ …         ┆ …          ┆ …     ┆ …     ┆ …     ┆ …     ┆ …   │
    │ xikoy     ┆ 2024-02-10 ┆ 2     ┆ 2     ┆ 2     ┆ 2     ┆ 8   │
    │ xikoy     ┆ 2024-02-02 ┆ 1     ┆ 1     ┆ 1     ┆ 1     ┆ 4   │
    │ cufeg     ┆ 2024-02-10 ┆ 1     ┆ 1     ┆ 1     ┆ 1     ┆ 4   │
    │ jerag     ┆ 2024-02-06 ┆ 1     ┆ 0     ┆ 1     ┆ 1     ┆ 3   │
    │ desob     ┆ 2024-02-03 ┆ 2     ┆ 2     ┆ 2     ┆ 2     ┆ 8   │
    └───────────┴────────────┴───────┴───────┴───────┴───────┴─────┘
    

    pl.sum_horizontal() ignores nulls, so it replaces the chain of fill_null(0) additions.