Search code examples
Tags: typeclass, associated-types, lean

Associated types in Lean


In Rust I can define the following traits:

/// A minimal, hand-rolled iterator trait with an associated element type.
///
/// Note: this user-defined trait shadows the prelude's `std::iter::Iterator`
/// within this module, so `for` loops cannot drive it directly.
trait Iterator {
    /// Type of the elements this iterator yields.
    type Item;
    /// Advances the iterator, returning the next element or `None`.
    fn next(&mut self) -> Option<Self::Item>;
}

/// A type that can be turned (by value) into an iterator over `Self::Item`.
trait Iterable {
    /// Type of the elements produced by the associated iterator.
    type Item;
    /// The iterator type; the bound forces its `Item` to equal `Self::Item`.
    type Iterator: Iterator<Item=Self::Item>;
    /// Consumes `self` and returns its iterator.
    fn iterator(self) -> Self::Iterator;
}

I can then define my own type `CountToTen` that implements the `Iterable` trait:

/// Counter state: the field holds the most recently yielded value
/// (`CountToTen::iterator` creates it with 0).
struct CountToTenIterator(u32);

/// Yields `counter + 1` on each call while the counter is below 10.
impl Iterator for CountToTenIterator {
    type Item = u32;

    /// Returns `None` once the counter has reached 10 (or started at or
    /// above it); otherwise increments the counter and yields its new value.
    fn next(&mut self) -> Option<u32> {
        if self.0 >= 10 {
            return None;
        }
        self.0 += 1;
        Some(self.0)
    }
}

/// Unit type whose `Iterable` impl yields the numbers 1 through 10.
struct CountToTen;

/// `CountToTen` yields `u32`s by handing out a fresh `CountToTenIterator`.
impl Iterable for CountToTen {
    type Item = u32;
    // Must satisfy `Iterator<Item = u32>` to match `Item` above.
    type Iterator = CountToTenIterator;
    /// Starts the counter at 0, so the first `next()` yields 1.
    fn iterator(self) -> CountToTenIterator {
        CountToTenIterator(0)
    }
}

And I can use it like this:

/// Prints every item produced by `iterable`, one per line.
///
/// The `Iterable<Item = u32>` bound pins the element type; through the bound
/// on `Iterable::Iterator`, that also fixes the iterator's `Item`.
fn print_items<I: Iterable<Item=u32>>(iterable: I) {
    let mut source = iterable.iterator();
    // Adapt the hand-rolled iterator to std's so a plain `for` loop can drive it;
    // `from_fn` stops at the first `None`, exactly like a `while let` loop.
    for item in std::iter::from_fn(|| source.next()) {
        println!("{}", item);
    }
}

/// Entry point: prints the numbers 1 through 10, one per line.
fn main() {
    print_items(CountToTen);
}

How can I achieve the same thing in Lean? I can define a type class Iterator like this:

/-- First attempt at an iterator type class: the element type `Item` is a field
of the class, playing the role of a Rust associated type. -/
class Iterator (Self : Type) where
    /-- Type of the elements this iterator yields. -/
    Item : Type
    /-- Returns the advanced iterator paired with `some` next element, or `none`. -/
    next : Self -> Prod Self (Option Item)

and CountToTenIterator like this:

/-- Counter state: `i` holds the most recently yielded value. -/
structure CountToTenIterator where
    i : UInt32

/-- Counts upward, yielding `i + 1` on each step while `i` is below 10. -/
instance : Iterator CountToTenIterator where
    Item := UInt32
    -- Exhausted once the counter is at least 10; otherwise bump it and yield the new value.
    next cur :=
        if cur.i >= 10 then
            (cur, none)
        else
            let n := cur.i + 1
            ({ i := n }, some n)

But how do I define `Iterable` so that its member `Iterator` is constrained to be an instance of the `Iterator` class, with an `Item` that matches the `Iterable`'s own `Item`?

/-- Attempted `Iterable` class: `Item` and `Iterator` are fields standing in for
Rust associated types, but nothing here yet requires the `Iterator` field to be
an instance of the `Iterator` class. -/
class Iterable (Self : Type) where
    Item : Type
    Iterator : Type -- how do I constrain this to be an instance of Iterator?
    iterator : Self -> Iterator

Solution

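  • I figured out how to do it. The associated types have to become output parameters (`outParam`) of the classes instead of fields, so that instance resolution infers them from `Self`. The requirement that the iterator type implements `Iterator` with a matching `Item` is then expressed as an instance-implicit parameter of `Iterable`: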

    /-- Iterator class with the element type as an `outParam`: instance search
    determines `Item` from `Self` instead of requiring it as an input. -/
    class Iterator (Self : Type) (Item : outParam Type) where
        /-- Returns the advanced iterator paired with `some` next element, or `none`. -/
        next : Self -> Prod Self (Option Item)
    
    /-- Something that can produce an iterator. Both the element type and the
    iterator type are `outParam`s, and the instance-implicit parameter
    `[Iterator IterableIterator Item]` requires the iterator type to be an
    `Iterator` yielding exactly `Item`, the counterpart of Rust's
    `type Iterator: Iterator<Item = Self::Item>`. -/
    class Iterable (Self : Type) (Item : outParam Type) (IterableIterator : outParam Type) [Iterator IterableIterator Item] where
        /-- Produces the iterator for a value of type `Self`. -/
        iterator : Self -> IterableIterator
    
    /-- Counter state: `i` holds the most recently yielded value. -/
    structure CountToTenIterator where
        i : UInt32
    
    /-- Counts upward, yielding `i + 1` on each step while `i` is below 10. -/
    instance : Iterator CountToTenIterator UInt32 where
        -- Exhausted once the counter is at least 10; otherwise bump it and yield the new value.
        next cur :=
            if cur.i >= 10 then
                (cur, none)
            else
                let n := cur.i + 1
                ({ i := n }, some n)
    
    /-- Field-less marker type whose `Iterable` instance counts from 1 to 10. -/
    structure CountToTen
    
    /-- Every `CountToTen` value starts a fresh counter at 0, so the first step yields 1. -/
    instance : Iterable CountToTen UInt32 CountToTenIterator where
        iterator := fun _ => ⟨0⟩
    
    /-- Prints every `UInt32` produced by `iterable`, one per line. The iterator
    type `I2` is an `outParam` of `Iterable`, so callers never name it. -/
    def print_items {I I2 : Type} [Iterator I2 UInt32] [Iterable I UInt32 I2] (iterable : I) : IO Unit := do
        let mut cursor := Iterable.iterator iterable
        repeat
            -- Each step yields the advanced iterator and possibly an element;
            -- stop at the first `none`.
            match Iterator.next cursor with
            | (rest, some x) =>
                println! "{x}"
                cursor := rest
            | (_, none) => break
    
    /-- Entry point: prints the numbers 1 through 10, one per line. -/
    def main : IO Unit := print_items CountToTen.mk