python tabular astropy

Astropy table manipulation: How can I create a new table with rows grouped by values in a column?


I have a table with multiple rows and columns. One of the columns contains numbers, some of which repeat several times. How can I create a new astropy table that keeps only the rows whose value in that column occurs at least, say, 3 times?

example:

[Image: table]

Notice that 0129 appears 3 times in column c and 2780 appears 4 times. I'd like my code to then create the new table:

[Image: modified table]

I'm using the astropy module and specifically:

from astropy.table import Table

I am assuming I need to use a for loop to accomplish this task and ultimately the command

new_table.add_row(table[index]) 

Big picture, what I am trying to accomplish is this:

if column_c_value repeats >=3:
    new_table.add_row(table[index])

Thank you for your help! I'm kind of stuck here and would greatly appreciate insight.


Solution

  • You can use the Table grouping functionality:

    In [1]: import numpy as np
       ...: from astropy.table import Table
    
    In [2]: t = Table([[1, 2, 3, 4, 5, 6, 7, 8],
       ...:            [10, 11, 10, 10, 11, 12, 13, 12]],
       ...:            names=['a', 'id'])
    
    In [3]: tg = t.group_by('id')
    
    In [4]: tg.groups
    Out[4]: <TableGroups indices=[0 3 5 7 8]>
    
    In [6]: tg.groups.keys
    Out[6]: 
    <Table length=4>
      id 
    int64
    -----
       10
       11
       12
       13
    
    In [7]: np.diff(tg.groups.indices)
    Out[7]: array([3, 2, 2, 1])
    
    In [8]: tg
    Out[8]: 
    <Table length=8>
      a     id 
    int64 int64
    ----- -----
        1    10
        3    10
        4    10
        2    11
        5    11
        6    12
        8    12
        7    13
    
    In [9]: ok = np.zeros(len(tg), dtype=bool)  # row mask, initially all False
    
    In [10]: for i0, i1 in zip(tg.groups.indices[:-1], tg.groups.indices[1:]):
        ...:     if (i1 - i0) >= 3:  # group covers rows i0..i1-1; keep it if it has >= 3 rows
        ...:         ok[i0:i1] = True
        ...: tg3 = tg[ok]
        ...: tg3
        ...: 
    Out[10]: 
    <Table length=3>
      a     id 
    int64 int64
    ----- -----
        1    10
        3    10
        4    10
    
    In [12]: for tgg in tg.groups:  # alternatively, loop over the groups directly
        ...:     if len(tgg) >= 2:  # threshold of 2 here, so three groups print
        ...:         print(tgg)  # or do something with it
        ...:         
     a   id
    --- ---
      1  10
      3  10
      4  10
     a   id
    --- ---
      2  11
      5  11
     a   id
    --- ---
      6  12
      8  12