python algorithm data-structures catalan

How to sample from all full binary trees?


I am interested in an algorithm for sampling from all full binary trees that also works for large trees (about 1 million nodes).

In other words, I want to sample rooted, ordered full binary trees with a specified list of labels. The number of such trees with n leaves is C_{n-1}, the (n-1)-th Catalan number.
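
As a quick sanity check of that count, here is a minimal sketch (the names `catalan` and `count_trees` are made up for this illustration). It compares the closed form of the Catalan numbers with the recurrence obtained by splitting the leaves between the left and right subtrees:

```python
from math import comb

def catalan(n):
    # Closed form of the n-th Catalan number: C(2n, n) / (n + 1)
    return comb(2 * n, n) // (n + 1)

def count_trees(num_leaves):
    # counts[k] = number of ordered full binary trees with k leaves
    counts = [0, 1]
    for n in range(2, num_leaves + 1):
        # The first k leaves go to the left subtree, the rest to the right
        counts.append(sum(counts[k] * counts[n - k] for k in range(1, n)))
    return counts[num_leaves]

# Both ways of counting agree for small sizes
for leaves in range(1, 12):
    assert count_trees(leaves) == catalan(leaves - 1)
```

For 5 leaves both give 14.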

A similar question was already asked. The following answer enumerates all full binary trees:

# A very simple representation for Nodes. Leaves are anything which is not a Node.
class Node(object):
  def __init__(self, left, right):
    self.left = left
    self.right = right

  def __repr__(self):
    return '(%s %s)' % (self.left, self.right)

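# Given a list of labels, yield each rooted, ordered full binary tree with
# the specified labels (the leaves keep their left-to-right order).
def enum_ordered(labels):
  if len(labels) == 1:
    yield labels[0]
  else:
    # Split the labels into a non-empty left part and a non-empty right part
    for i in range(1, len(labels)):
      for left in enum_ordered(labels[:i]):
        for right in enum_ordered(labels[i:]):
          yield Node(left, right)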

>>> for tree in enum_ordered(("a","b","c","d", "e")): print(tree)
... 
(a (b (c (d e))))
(a (b ((c d) e)))
(a ((b c) (d e)))
(a ((b (c d)) e))
(a (((b c) d) e))
((a b) (c (d e)))
((a b) ((c d) e))
((a (b c)) (d e))
(((a b) c) (d e))
((a (b (c d))) e)
((a ((b c) d)) e)
(((a b) (c d)) e)
(((a (b c)) d) e)
((((a b) c) d) e)

But it breaks in the case of large trees:

next(iter(enum_ordered(1000*['a'])))
# RecursionError: maximum recursion depth exceeded in comparison

As was suggested in the comments, the number of such trees is huge and there is no hope of enumerating them. Can you please suggest how to implement a distribution over all full binary trees and sample from it?


Solution

  • You could design a function that takes as its argument the index of a tree among all trees with the required number of leaves, and builds that tree directly.

    Instead of generating every tree from index 0 up to that index, the function counts, for each way of splitting the labels into a left part and a right part, how many left and right subtrees exist. The Catalan numbers give these counts. By trying the splits one after another, the function finds the split that contains the tree with the requested index, so it "jumps" over big chunks of indices.

    Since you are working with large trees and thus with very large Catalan numbers, these calculations have to happen with big-integer types, as the typical 4-byte or 8-byte integers will not be large enough.

    In Python this is not an issue (its int type is a big integer natively), but it is something to be careful with in some other languages.

    Here is how it could look:

    catalans = [1]  # Memoization of previously calculated Catalan numbers
    
    def get_catalan(n):
        for i in range(len(catalans) - 1, n):
            catalans.append(catalans[-1] * (4*i + 2) // (i + 2))
        return catalans[n]
    
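    def tree(labels, start, end, tree_index):
        # Build the tree with the given index over labels[start:end]
        if end - start == 1:
            return labels[start]  # A single label is a leaf
        # Try each split: labels[start:root] go left, labels[root:end] go right
        for root in range(end - 1, start, -1):
            num_left = get_catalan(root - start - 1)
            num_right = get_catalan(end - root - 1)
            count = num_left * num_right  # Number of trees with this split
            if tree_index < count:
                # The requested tree uses this split; the left index varies fastest
                tree_left = tree(labels, start, root, tree_index % num_left)
                tree_right = tree(labels, root, end, tree_index // num_left)
                return f"({tree_left},{tree_right})"
            tree_index -= count  # Skip all trees with this split
        raise ValueError("tree_index is out of range")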
    
    
    # Demo
    labels = list(map(str, range(1000)))
    num_trees = get_catalan(len(labels) - 1)
    print(f"There are {num_trees} trees for these labels")
    index = num_trees * 7 // 22  # Some selected index of a tree
    print(tree(labels, 0, len(labels), index))  
    

    This demo prepares for generating trees with 1000 leaves and then picks one index from the range 0..num_trees-1. The problem of sampling trees is thereby reduced to sampling indices uniformly from that range: for every selected index, you can call tree to get the corresponding tree.

    The output of the above demo is:

    There are 512294053774259558362972111801106814506359401696197357133662490663268680890966422168317407249277190145438911035517264555381561230116189292650837306095363076178842645481320822198226994485371813976409676367032381831285411152247284028125396742405627998638503788368259307920236258027800099771751391617605088924033394630230806037178021722568614945945597158227817488131642780881551702876651234929533423690387735417418121162690198676382656195692212519230804188796272372873746380773141117366928488415626459630446598074332450038402866155063023175006229242447751399777865500335793470023989772130248615305440 trees for these labels
    (((0,(((((((1,((2,((3,((4,(5,(6,7))),8)),((((((9,((10,11),12)),(13,14)),((((15,16),17),(18,19)),(((20,21),(22,23)),((24,((((25,((26,27),28)),((29,30),(31,32))),((33,34),35)),36)),(37,38))))),((39,40),((41,42),43))),44),45))),((46,47),48))),49),(((((50,(((((51,((52,53),54)),55),56),((57,(58,(59,(((60,61),(62,(63,(((64,65),66),(67,68))))),69)))),((70,(71,(72,(73,74)))),75))),76)),(77,(78,((79,80),(81,((82,(83,84)),(85,86))))))),87),88),((89,90),91))),92),93),(((((((94,95),(96,((((97,(98,(((((99,100),(101,((102,103),((104,(((105,((106,(107,(108,109))),110)),111),(112,(113,((114,((115,(116,(117,(118,((((119,((120,121),122)),123),(124,(125,(126,(127,(128,129)))))),((((130,131),((132,133),(134,(135,(((136,((137,138),139)),140),(141,142)))))),(143,(144,((145,146),147)))),148)))))),149)),150))))),((151,152),((153,154),((155,156),157))))))),158),(((159,160),(161,162)),163)),((((164,((((165,(166,167)),168),((((169,(170,((((171,172),173),174),(175,(((((176,177),(178,(179,180))),181),182),((183,(((184,185),((186,(187,((188,189),(190,191)))),(((192,193),194),((195,196),197)))),(198,(199,200)))),(201,(202,203)))))))),204),(205,(206,207))),((((((((208,209),210),(211,212)),(213,(214,215))),216),((217,(218,219)),220)),(((221,(222,(223,((((224,225),226),(((227,(228,229)),230),((231,(232,233)),234))),(((235,236),(237,238)),239))))),240),241)),(((((242,243),(244,((245,246),247))),(248,249)),250),((251,((252,(253,254)),((255,256),257))),(((258,(259,(260,261))),((262,(263,(((((264,265),((266,267),(268,269))),270),(271,272)),273))),274)),275)))))),276)),277),((278,279),(280,281))),(282,(283,284)))))),(285,(286,287))),288),(289,290)))),291),(292,293)),294),((295,(((((296,297),((((298,((299,((300,((301,302),303)),304)),((305,(306,(307,((308,(309,(((310,(311,312)),(313,314)),315))),316)))),(((317,((318,((319,320),321)),322)),323),((((324,325),326),((327,328),(329,(330,((331,((((((332,333),((334,335),((336,337),338))),(339,(340,341))),342),((((343,344),345),((346,347),(348,((349,((350,(35
1,352)),353)),((((354,(355,356)),357),((358,359),360)),361))))),362)),(((363,364),(365,((((366,(367,368)),(369,((((370,371),((372,373),(374,(375,(376,((377,378),379)))))),(380,((381,(382,383)),384))),385))),386),387))),388))),(389,390)))))),391))))),((((392,(393,((((394,395),(396,397)),398),399))),400),(((((401,402),((403,((404,405),(406,(407,408)))),409)),410),(411,(412,413))),((414,415),416))),(417,((418,(((419,(420,((421,((422,423),((424,(425,426)),((427,(428,429)),(((430,431),432),(433,(434,435))))))),436))),437),(438,((439,(440,441)),((((442,(443,(((((444,(445,(446,(447,(448,449))))),(((450,451),452),453)),454),(455,456)),(((457,(458,(459,((460,(((461,((462,(463,464)),(465,466))),467),468)),469)))),(((((470,((471,472),((473,474),(475,476)))),(((477,478),(((479,(480,481)),482),(483,(484,(((485,((486,(487,488)),((((489,490),491),(492,((493,494),(495,(((((496,497),(498,((((((499,500),501),502),(503,(504,(505,(506,((((507,(508,(((509,510),511),512))),(513,(514,(515,516)))),((517,518),519)),(520,521))))))),(522,523)),524))),(525,526)),(527,(528,529))),(530,531)))))),(532,533)))),(534,535)),536))))),(((537,((((538,539),540),(541,(542,543))),544)),(545,546)),(((547,(((548,549),550),551)),552),(((553,554),555),(556,(557,(558,559)))))))),560),(561,(((562,(563,564)),(565,566)),567))),(568,569))),((((570,571),572),573),574))))),((575,576),577)),578),((579,(580,((((581,(582,583)),584),(585,586)),587))),588)))))),((589,(590,(((591,592),(593,594)),595))),596))))),597),((598,599),600))),(((601,602),((603,(604,605)),(((((606,607),((608,609),610)),((611,(((612,((613,((614,615),616)),(617,618))),((619,((620,621),((((((622,623),624),((((625,(((626,627),(628,(629,630))),631)),(632,(633,((634,635),(((636,((637,((638,639),((640,641),((642,643),(644,645))))),646)),((647,648),649)),((((650,(651,((652,653),(654,((655,(656,(657,(658,659)))),660))))),(661,662)),663),664)))))),((665,666),667)),668)),(((669,670),(((671,(672,((673,674),675))),676),(((677,678),(679,680)),681))),682)),683),(6
84,685)))),((686,((687,((688,(689,(690,((((691,((692,(693,(694,695))),(696,(697,(698,(699,700)))))),701),702),(((703,((704,((705,706),707)),((708,(709,710)),711))),712),((713,714),(((((715,((716,(717,((718,((719,(720,(721,((722,723),724)))),(725,(((726,(((727,((728,(729,730)),731)),732),733)),((((734,735),736),737),738)),739)))),740))),((((741,742),743),744),745))),746),(747,((748,749),(750,751)))),(752,((((753,((754,(((((755,756),757),((758,759),(((760,(((761,((((((762,763),764),765),(((((766,767),(768,((((769,770),771),((((772,773),(774,775)),(776,(777,(778,779)))),(780,(781,(782,783))))),(784,785)))),786),(787,((788,(((789,(((((((790,((((791,(792,793)),794),795),796)),797),798),(799,800)),((((801,(802,803)),804),805),806)),807),((808,((((809,(810,811)),((812,813),(814,(815,816)))),(817,818)),(819,820))),821))),(822,(823,((824,(((825,826),(827,(((828,829),(((830,((831,(832,(((833,((834,(((835,(((836,((837,838),(839,840))),(841,842)),843)),844),845)),846)),847),((848,849),850)))),851)),(852,853)),854)),855))),856)),(857,858))))),859)),860))),((861,((862,((863,864),865)),866)),(867,868)))),869),870)),(871,(872,873))),874)),(875,((876,877),878))),879))),(880,881)),(((((882,883),(884,885)),886),887),888))),(889,(890,891)))),892),893),(894,(895,896))))),((((897,898),(899,900)),901),(902,(((903,((904,((((905,906),(907,908)),909),(((910,(911,912)),(913,((914,915),(916,917)))),918))),919)),920),921)))))))))),922)),923)),((((((924,(925,926)),(927,((928,(929,(930,(931,(932,(((933,(934,(935,936))),937),(938,(939,940)))))))),((((941,(((942,(943,944)),((945,946),((947,(948,949)),950))),951)),((952,953),954)),955),956)))),957),958),959),(960,961))))),(962,963))),964)),965),966))),967)),968),969)),(970,(971,972)))),(973,(974,(975,((((((976,977),978),(979,980)),((981,((982,983),984)),985)),(((986,987),988),989)),990)))))),((991,992),993))),((994,995),996)),((997,998),999))
    

    The recursion depth still equals the height of the generated tree, so if the tree is skewed, this height can be a problem for the stack. In that case, increase the available stack size (in Python, for example, by raising the limit with sys.setrecursionlimit).
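
    If raising the stack size is not an option, the same decoding can be done with an explicit stack. The following is only a sketch under some assumptions: the names `catalan` and `tree_iterative` are invented here, the Catalan numbers come from the closed form via `math.comb` rather than the memoized list, and the result is returned as nested tuples instead of a string. It follows the same index layout as the recursive tree function (splits tried from right to left, left subtree index varying fastest).

```python
from functools import lru_cache
from math import comb

@lru_cache(maxsize=None)
def catalan(n):
    # n-th Catalan number via the closed form C(2n, n) / (n + 1)
    return comb(2 * n, n) // (n + 1)

def tree_iterative(labels, tree_index):
    n = len(labels)
    if n == 0 or not 0 <= tree_index < catalan(n - 1):
        raise ValueError("tree_index is out of range")
    results = []                  # finished subtrees, in post-order
    work = [(0, n, tree_index)]   # explicit stack replacing the call stack
    while work:
        item = work.pop()
        if item is None:
            # Marker: both children are finished, so join the two newest subtrees
            right = results.pop()
            left = results.pop()
            results.append((left, right))
            continue
        start, end, index = item
        if end - start == 1:
            results.append(labels[start])  # a single label is a leaf
            continue
        # Find the split that contains the requested index
        for root in range(end - 1, start, -1):
            num_left = catalan(root - start - 1)
            count = num_left * catalan(end - root - 1)
            if index < count:
                break
            index -= count
        work.append(None)                             # join, handled last
        work.append((root, end, index // num_left))   # right child, handled second
        work.append((start, root, index % num_left))  # left child, handled first
    return results[0]

# All 14 indices for 5 labels give distinct trees
assert len({tree_iterative("abcde", i) for i in range(14)}) == 14
```

    Note that printing or comparing a very deep nested tuple can itself hit the recursion limit, so a flat output format may be preferable for heavily skewed trees.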

    When you use a random number generator to pick the index, make sure it is fine-grained enough that every number in that range has a nonzero (and ideally equal) probability of being generated.
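
    One way to draw such an index in Python is sketched below, assuming the standard library's `secrets.randbelow` as the random source; the helper name `random_tree_index` is made up for this example. `random.randrange` also accepts arbitrarily large integers, but its default Mersenne Twister generator has only 19937 bits of state, so once the number of trees exceeds about 2**19937 (on the order of ten thousand leaves) it cannot reach every index. A generator backed by the operating system's entropy source does not have that fixed-state limit.

```python
import secrets
from math import comb

def random_tree_index(num_leaves):
    # Number of ordered full binary trees with num_leaves leaves: C_{n-1}
    n = num_leaves - 1
    num_trees = comb(2 * n, n) // (n + 1)
    # Uniform integer in range(num_trees), drawn from the OS entropy source
    return secrets.randbelow(num_trees), num_trees

index, num_trees = random_tree_index(1000)
assert 0 <= index < num_trees
```

    The resulting index can then be passed to the tree function as tree_index.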