Polars: Groupby and idxmin#

Welcome back to Cameron’s Corner! It’s the third week of January, and, instead of talking about graphs, I want to take a dive into Polars. I recently addressed a question on Polars’ Discord server, diving into the different ways to perform an “index minimum” operation across groups.

Sure, Polars has a built-in Expr.arg_min() (the counterpart of pandas’ idxmin), but it operates a little differently than its pandas cousin. Let’s take a look:

Data#

from polars import (
    DataFrame as pl_DataFrame,
    Categorical, 
    Config,
    datetime_range,
    col,
    all as pl_all
)
from datetime import datetime, timedelta
from numpy.random import default_rng
from string import ascii_uppercase

def make_df(ngroups, group_size):
    size = ngroups * group_size

    rng = default_rng(0)

    return pl_DataFrame({
        'date' : datetime_range(
            start := datetime(2000, 1, 1), 
            end=start + timedelta(hours=size-1), 
            interval='1h', 
            eager=True
        ),
        'group': [*ascii_uppercase[:ngroups]] * group_size,
        'value': rng.integers(0, 1_000, size=size),
    }).cast({'group': Categorical})

pl_df = make_df(3, 10)

with Config(tbl_rows=6):
    display(pl_df)
shape: (30, 3)
┌─────────────────────┬───────┬───────┐
│ date                ┆ group ┆ value │
│ ---                 ┆ ---   ┆ ---   │
│ datetime[μs]        ┆ cat   ┆ i64   │
╞═════════════════════╪═══════╪═══════╡
│ 2000-01-01 00:00:00 ┆ "A"   ┆ 850   │
│ 2000-01-01 01:00:00 ┆ "B"   ┆ 636   │
│ 2000-01-01 02:00:00 ┆ "C"   ┆ 511   │
│ …                   ┆ …     ┆ …     │
│ 2000-01-02 03:00:00 ┆ "A"   ┆ 33    │
│ 2000-01-02 04:00:00 ┆ "B"   ┆ 764   │
│ 2000-01-02 05:00:00 ┆ "C"   ┆ 729   │
└─────────────────────┴───────┴───────┘

Approaches#

pandas#

from pandas import DataFrame as pd_DataFrame

pd_df = pl_df.to_pandas()
pd_df.groupby('group', observed=True)['value'].idxmin()
group
A    27
B     7
C    23
Name: value, dtype: int64

Polars#

pl_df.group_by('group').agg(col('value').arg_min()).head()
shape: (3, 2)
┌───────┬───────┐
│ group ┆ value │
│ ---   ┆ ---   │
│ cat   ┆ u32   │
╞═══════╪═══════╡
│ "A"   ┆ 9     │
│ "B"   ┆ 2     │
│ "C"   ┆ 7     │
└───────┴───────┘

As you can see, the results are different! If you had to guess at what is happening here, you might wonder, “Is pandas doing something with the index?”

Of course it is!

The difference is that pandas returns index values in reference to the original DataFrame, or, more specifically, to its original .index. Our index was a simple RangeIndex from 0 to the length of our data, so the idxmin operation returns the index label at which the minimum occurs in our Series ('value'). In contrast, Polars returns the position of the minimum relative to each group.

While this difference might seem superficial—we all know that Polars does not like the .index—it does lead to an interesting question:

“How do I find the row in each group where the minimum occurred?”

In pandas, we can simply wrap the result in a .loc:

pd_df.loc[
    # same operation as before!
    pd_df.groupby('group', observed=True)['value'].idxmin()
]
                   date group  value
27  2000-01-02 03:00:00     A     33
7   2000-01-01 07:00:00     B     16
23  2000-01-01 23:00:00     C      2

However, in Polars, we don’t have as direct a way to map these group-relative arg_min results back to rows of the original data. Instead, we need to alter our original expression. I have noted a few ways to arrive at this result:

registry = []
def register(func):
    registry.append(func)
    return func

@register
def filter_over_argmin(df):
    group_row_ids = col('group').cum_count().over(col('group')) - 1
    group_arg_min = col('value').arg_min().over(col('group'))
    return df.filter(group_row_ids == group_arg_min)

@register
def sort_keep_first(df):
    return df.sort(['group', 'value']).unique('group', keep='first')

@register
def groupby_sort_first(df):
    return df.group_by('group').agg(pl_all().sort_by('value').first())

@register
def groupby_get(df):
    return (
        df.group_by('group').agg(pl_all().get(col('value').arg_min()))
    )

Now, let’s see these queries in action. I’ll use LazyFrames to ensure that any available optimizations can be made in advance.

I’m also going to up the ante by working with a much larger data set:

pl_df = make_df(25, 10_000)
pd_df = pl_df.to_pandas()

with Config(tbl_rows=6):
    display(pl_df)
shape: (250_000, 3)
┌─────────────────────┬───────┬───────┐
│ date                ┆ group ┆ value │
│ ---                 ┆ ---   ┆ ---   │
│ datetime[μs]        ┆ cat   ┆ i64   │
╞═════════════════════╪═══════╪═══════╡
│ 2000-01-01 00:00:00 ┆ "A"   ┆ 850   │
│ 2000-01-01 01:00:00 ┆ "B"   ┆ 636   │
│ 2000-01-01 02:00:00 ┆ "C"   ┆ 511   │
│ …                   ┆ …     ┆ …     │
│ 2028-07-08 13:00:00 ┆ "W"   ┆ 665   │
│ 2028-07-08 14:00:00 ┆ "X"   ┆ 392   │
│ 2028-07-08 15:00:00 ┆ "Y"   ┆ 992   │
└─────────────────────┴───────┴───────┘
results = {}
for func in registry:
    res = func(pl_df.lazy())
    print(f'{func.__name__:-^40}')
    %timeit -o -n 10 -r 5 res.collect()
    results[func] = res.collect()
    print()
-----------filter_over_argmin-----------
4.28 ms ± 141 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)

------------sort_keep_first-------------
10.3 ms ± 330 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)

-----------groupby_sort_first-----------
4.04 ms ± 338 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)

--------------groupby_get---------------
1.72 ms ± 91.8 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)

And, for reference, here is how pandas performs:

%timeit -n 10 -r 5 pd_df.loc[pd_df.groupby('group', observed=True)['value'].idxmin()]
3.57 ms ± 131 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)

Finally, let’s verify that each of our Polars approaches arrived at the same answer:

from polars import align_frames
from itertools import pairwise

for left, right in pairwise(align_frames(*results.values(), on='group')):
    left, right = left.sort('group'), right.sort('group')
    assert left['value'].equals(right['value'])

Wrap-Up#

There we have it: how to solve a group_by…idxmin problem in both pandas and Polars. As usual with Polars, leverage the expression syntax as best you can and leave the optimizations to the engine. That said, as seen above, you may see some variance in speed depending on the approach you take: try to hit parallel-processing fast paths, such as group-by operations (depending on your cardinality) and working across columns.

What do you think? Let us know on the DUTC Discord server.

That’s all for today. Until next time!