Search code examples
pythonalgorithmfunctional-programminglambda-calculuscombinators

How to implement a fast type inference procedure for SKI combinators in Python?


How to implement a fast simple type inference procedure for SKI combinators in Python?

I am interested in 2 functions:

  1. typable: returns True if the given SKI term has a type (I suppose this can be faster than computing a concrete type).

  2. principle_type: returns the principal type of the term if it exists, and False otherwise.


typable(SKK) = True
typable(SII) = False  # where I = SKK; SII has no simple type, just like \x.xx

principle_type(S) = (t1 -> t2 -> t3) -> (t1 -> t2) -> t1 -> t3
principle_type(K) = t1 -> t2 -> t1
principle_type(SK) = (t3 -> t2) -> t3 -> t3
principle_type(SKK) = principle_type(I) = t1 -> t1  

Theoretical questions:

  1. I have read about the Hindley–Milner type system, which comes with two inference algorithms: Algorithm J and Algorithm W. Am I right that they are meant for a richer type system, such as System F or another system with parametric polymorphism? Are there combinators that are typable in System F but not in the simple type system?

  2. As I understand it, finding a principal type requires solving a system of equations between symbolic type expressions (i.e. unification). Could the algorithm be simplified, or sped up, by using an SMT solver such as Z3?

My implementation of basic combinators, reduction and parsing:

from __future__ import annotations
import typing
from dataclasses import dataclass


@dataclass(eq=True, frozen=True)
class S:
    """The S combinator as a leaf of a term tree (``S x y z -> x z (y z)``)."""

    def __len__(self):
        # Length of the printed form: a bare combinator is one character.
        return 1

    def __str__(self):
        return "S"


@dataclass(eq=True, frozen=True)
class K:
    """The K combinator as a leaf of a term tree (``K x y -> x``)."""

    def __len__(self):
        # Length of the printed form: a bare combinator is one character.
        return 1

    def __str__(self):
        return "K"


@dataclass(eq=True, frozen=True)
class App:
    """Application node: the term ``left`` applied to the term ``right``."""

    left: Term   # the function part
    right: Term  # the argument part

    def __str__(self):
        # Fully parenthesised, e.g. App(App(S, K), K) prints as "((SK)K)".
        return "(" + str(self.left) + str(self.right) + ")"

    def __len__(self):
        # Length of the printed form, parentheses included.
        return len(str(self))


# Any SKI term: a leaf combinator (S or K) or an application node (App).
Term = typing.Union[S, K, App]


def parse_ski_string(s):
    """Parse an SKI term written with ``S``, ``K`` and parentheses.

    Application is juxtaposition and associates to the left, so ``"SKK"``,
    ``"(SK)K"`` and ``"((SK)K)"`` all parse to ``App(App(S(), K()), K())``.
    Redundant parentheses such as ``"(S)"`` are allowed. Whitespace is ignored.
    Only ``S`` and ``K`` are recognised as combinators.

    Raises:
        Exception: on an unknown character, unbalanced parentheses, an empty
            group ``"()"``, or an empty input.
    """
    # One frame per open parenthesis group. Each frame holds the term built
    # so far inside that group, or None while the group is still empty.
    # frames[0] is the top level.
    frames = [None]

    def push(term):
        # Juxtaposition: apply whatever is already in the current group to `term`.
        current = frames[-1]
        frames[-1] = term if current is None else App(current, term)

    for c in ''.join(s.split()):  # remove whitespace
        if c == '(':
            frames.append(None)
        elif c == ')':
            if len(frames) == 1:
                raise Exception('unbalanced ")" in', s)
            group = frames.pop()
            if group is None:
                raise Exception('empty parentheses in', s)
            push(group)
        elif c == 'S':
            push(S())
        elif c == 'K':
            push(K())
        else:
            raise Exception('wrong c = ', c)

    if len(frames) != 1:
        raise Exception('unbalanced "(" in', s)
    if frames[0] is None:
        raise Exception('empty term', s)
    return frames[0]


def simplify(expr: Term):
    """Reduce an SKI term by rewriting S and K redexes until none remain.

    A redex at the root is contracted first (``K x y -> x`` and
    ``S x y z -> x z (y z)``). Otherwise both subterms are simplified, and if
    either one changed, the rebuilt application is simplified again.

    There is no step limit. If reduction never finishes (e.g. for a term
    with no normal form), the recursion keeps going until Python raises
    RecursionError.

    Raises:
        Exception: if ``expr`` is not an S, K or App.
    """
    if isinstance(expr, (S, K)):
        return expr
    if not isinstance(expr, App):
        raise Exception('Wrong type of combinator', expr)

    head = expr.left
    if isinstance(head, App):
        if isinstance(head.left, K):
            # K x y -> x
            return simplify(head.right)
        if isinstance(head.left, App) and isinstance(head.left.left, S):
            # S x y z -> x z (y z)
            x, y, z = head.left.right, head.right, expr.right
            return simplify(App(App(x, z), App(y, z)))

    new_left = simplify(expr.left)
    new_right = simplify(expr.right)
    if expr.left == new_left and expr.right == new_right:
        # Nothing changed below the root: the term is already in normal form.
        return App(expr.left, expr.right)
    return simplify(App(new_left, new_right))

# simplify(App(App(K(),S()),K())) = S
# simplify(parse_ski_string('(((SK)K)S)')) = S

Solution

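  • A simple approach (maybe not the fastest, but reasonably fast when the types are small): give every combinator occurrence a fresh copy of its type, and type each application f x by unifying type(f) with type(x) -> t for a fresh variable t. Unification uses a union-find map with an occurs check; a failed occurs check means the term has no simple type.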

    from dataclasses import dataclass
    
    
    class OccursError(Exception):
        """Raised when unification would bind a type variable to a type containing it."""
    
    
    # Union-find map over type variables. Each Var maps either to itself (an
    # unbound root) or to the Var/Fun it has been unified with. It is global,
    # and nothing in this snippet resets it, so variable ids keep growing
    # across inferences.
    parent = {}
    
    # Type variables are plain integer ids handed out by new_var().
    Var = int
    
    
    def new_var() -> Var:
        """Allocate a fresh type variable and register it as its own (unbound) root."""
        # The id is the current size of `parent`, so each call yields a new id.
        fresh = Var(len(parent))
        parent[fresh] = fresh
        return fresh
    
    
    @dataclass
    class Fun:
        """Function type ``dom -> cod`` over type variables and nested Fun nodes."""

        dom: "Var | Fun"  # argument (domain) type
        cod: "Var | Fun"  # result (codomain) type
    
    
    def S() -> Fun:
        """Return a fresh copy of S's type: (a -> b -> c) -> (a -> b) -> a -> c."""
        a, b, c = new_var(), new_var(), new_var()
        first_arg = Fun(a, Fun(b, c))
        rest = Fun(Fun(a, b), Fun(a, c))
        return Fun(first_arg, rest)
    
    
    def K() -> Fun:
        """Return a fresh copy of K's type: a -> b -> a."""
        a, b = new_var(), new_var()
        ignore_b = Fun(b, a)
        return Fun(a, ignore_b)
    
    
    def I() -> Fun:
        """Return a fresh copy of I's type: a -> a."""
        a = new_var()
        identity = Fun(a, a)
        return identity
    
    
    def find(t1: Var | Fun) -> Var | Fun:
        """Resolve ``t1`` through ``parent`` into a type built only from root variables.

        A variable is followed to its representative, with path compression,
        so ``parent`` is updated along the way. A function type is rebuilt
        with both sides resolved (domain first).

        Raises:
            TypeError: if ``t1`` is neither a ``Var`` nor a ``Fun``.
        """
        if isinstance(t1, Fun):
            return Fun(find(t1.dom), find(t1.cod))
        if not isinstance(t1, Var):
            raise TypeError
        bound_to = parent[t1]
        if bound_to == t1:
            return t1  # unbound: t1 is its own representative
        representative = find(bound_to)
        parent[t1] = representative  # path compression
        return representative
    
    
    def occurs(t1: Var, t2: Var | Fun) -> bool:
        """Return True if variable ``t1`` appears anywhere inside type ``t2``.

        Bindings in ``parent`` are not followed here, so ``t2`` should already
        be resolved (``unify`` passes the output of ``find``).

        Raises:
            TypeError: if ``t2`` contains something other than Var/Fun.
        """
        if isinstance(t2, Fun):
            # Short-circuits: the codomain is searched only if the domain misses.
            return any(occurs(t1, part) for part in (t2.dom, t2.cod))
        if isinstance(t2, Var):
            return t1 == t2
        raise TypeError
    
    
    def unify(t1: Var | Fun, t2: Var | Fun):
        """Constrain ``t1`` and ``t2`` to be the same type by updating ``parent``.

        Both arguments are resolved with ``find`` first. Two function types
        are unified component-wise (domain, then codomain). Otherwise the
        variable side is bound to the other side, after an occurs check when
        the other side is a function type.

        Raises:
            OccursError: a variable would have to contain itself, so no
                simple type exists.
            TypeError: an argument is neither a ``Var`` nor a ``Fun``.
        """
        left = find(t1)
        right = find(t2)

        if isinstance(left, Fun) and isinstance(right, Fun):
            unify(left.dom, right.dom)
            unify(left.cod, right.cod)
            return

        if isinstance(left, Var) and isinstance(right, (Var, Fun)):
            var, target = left, right
        elif isinstance(left, Fun) and isinstance(right, Var):
            var, target = right, left
        else:
            raise TypeError

        # An equation t = ... t ... has no finite solution.
        if isinstance(target, Fun) and occurs(var, target):
            raise OccursError
        parent[var] = target
    
    
    def apply(t1: Var | Fun, t2: Var | Fun) -> Var | Fun:
        """Return the type of applying a term of type ``t1`` to one of type ``t2``.

        The function type must have the shape ``t2 -> result`` for a fresh
        variable ``result``. Unification enforces that shape, and the
        (possibly later-refined) ``result`` variable is returned.

        Raises:
            OccursError: if the application has no simple type.
        """
        result = new_var()
        expected_fn = Fun(t2, result)
        unify(t1, expected_fn)
        return result
    
    
    # Demo 1: infer the type of SKK = ((S K) K). Each combinator occurrence
    # gets its own fresh type variables; each application is typed by apply().
    try:
        a = S()
        b = K()
        ab = apply(a, b)
        c = K()
        abc = apply(ab, c)
        print("#", find(abc))
    except OccursError:
        print("# no type")
    
    # Demo 2: SII = ((S I) I). Unification fails the occurs check, so the
    # term has no simple type (like \x.xx).
    try:
        a = S()
        b = I()
        ab = apply(a, b)
        c = I()
        abc = apply(ab, c)
        print("#", find(abc))
    except OccursError:
        print("# no type")
    
    # Expected output. The variable ids depend on the order of new_var() calls:
    # Fun(dom=6, cod=6)
    # no type