// Simply-typed lambda terms in de Bruijn notation
// with type inference
// Author: Andreas Abel

// Exceptions

/**
 * Thrown by a context lookup when the requested de Bruijn index is not
 * bound in that context.  Carries no detail message.
 */
class ContextLookupException extends Exception {
    ContextLookupException() {
        // no detail message: the caller knows which index it asked for
    }
}

/** Signals a failure during type inference, described by a message. */
class InferException extends Exception {
    InferException(final String msg) {
        super(msg);
    }
}

// Types
//
// T ::= A          Base    named base type, e.g. "A"
//     | Bool       Bool    the base type named "Bool"
//     | T -> T'    Fun     function type

/**
 * A simple type.  Note that {@link #equals(Type)} is an overload of, not an
 * override of, {@link Object#equals(Object)}; it is selected only when the
 * argument's static type is {@code Type}.
 */
interface Type {
    /** Structural comparison with another type. */
    public boolean equals(Type type);

    /** Human-readable rendering of this type. */
    public String toString();
}

/**
 * Named base type, e.g. "A".  Two base types are equal iff their names are
 * equal, so {@code new Base("Bool")} equals {@code new Bool()}.
 */
class Base implements Type {
    final String name;

    Base (final String name) { this.name = name; }

    /** Prints just the name. */
    public String toString() { return name; }

    /** Structural equality on types: another Base with the same name. */
    public boolean equals(Type type) {
        if (type instanceof Base) return name.equals(((Base) type).name);
        else return false;
    }

    /**
     * equals(Type) above only overloads Object.equals(Object); callers that
     * hold a type through an Object reference (collections, Objects.equals, ...)
     * would otherwise fall back to identity comparison.  Delegate so both agree.
     */
    @Override
    public boolean equals(Object o) {
        return o instanceof Type && equals((Type) o);
    }

    /** Consistent with equals: equal names give equal hash codes. */
    @Override
    public int hashCode() { return name.hashCode(); }
}

/** The distinguished base type Bool: simply a Base named "Bool". */
class Bool extends Base {
    Bool() {
        super("Bool");
    }
}

/** Function type s -> t, with domain s and codomain t. */
class Fun implements Type {
    final Type s, t;

    Fun (final Type s, final Type t) { this.s = s; this.t = t; }

    /** Fully parenthesised rendering, e.g. "(A -> (B -> A))". */
    public String toString() {
        return "(" + s.toString() + " -> " + t.toString() + ")";
    }

    /** Structural equality: same domain and same codomain. */
    public boolean equals (Type type) {
        if (type instanceof Fun)
            return s.equals(((Fun) type).s) && t.equals(((Fun) type).t);
        else return false;
    }

    /**
     * equals(Type) above only overloads Object.equals(Object); delegate so
     * callers holding a type as an Object get structural equality too.
     */
    @Override
    public boolean equals(Object o) {
        return o instanceof Type && equals((Type) o);
    }

    /**
     * Hashes the printed form.  Structurally equal Base/Fun types print
     * identically, so this agrees with equals without relying on the component
     * types overriding hashCode.  (Assumes any other Type implementation also
     * prints equal types identically.)
     */
    @Override
    public int hashCode() { return toString().hashCode(); }
}

// Contexts
//
// A context is a linked list storing the names and types of free variables;
// the most recently added entry has de Bruijn index 0.
//
// G ::= .        EmptyContext
//     | G, x:T   ExtContext

/** One context entry: the name and the type of a bound variable. */
interface Entry {
    /** Declared type of the variable (used by type inference). */
    public Type type();

    /** Name of the variable (used when printing terms). */
    public String name();
}

/** A typing context addressed by de Bruijn index. */
interface Context {
    /**
     * Returns the entry bound at de Bruijn index {@code n}.
     *
     * @throws ContextLookupException if no entry is bound at that index
     */
    public Entry lookup(final int n) throws ContextLookupException;
}

// Contexts implementation

/** Entry pairing a variable name with its type; both fields are final. */
class TypeEntry implements Entry {
    final String name;
    final Type type;

    TypeEntry(final String name, final Type type) {
        this.type = type;
        this.name = name;
    }

    public Type type() {
        return type;
    }

    public String name() {
        return name;
    }
}
    
/** The empty context "."; every lookup fails. */
class EmptyContext implements Context {
    EmptyContext() {
    }

    public Entry lookup(final int i) throws ContextLookupException {
        // nothing is bound here, whatever the index
        throw new ContextLookupException();
    }
}

/**
 * Context extension "G, x:T": entry e sits at index 0, and index i+1 refers
 * to index i of the enclosing context gamma.
 */
class ExtContext implements Context {
    final Context gamma;
    final Entry e;

    ExtContext (final Context gamma, final Entry e) {
        this.gamma = gamma;
        this.e = e;
    }

    /**
     * Returns the entry at de Bruijn index i.
     *
     * @throws ContextLookupException if i is negative or exceeds the context length
     */
    public Entry lookup (final int i) throws ContextLookupException {
        // A negative index is never bound; without this guard it would fall
        // through to index 0 and silently resolve to the innermost entry.
        if (i < 0) throw new ContextLookupException();
        if (i == 0) return e;
        return gamma.lookup (i - 1);
    }
}


// Expressions (Terms)
//
// Abstract expression trees:
//
// e ::= n           Var     de Bruijn variable with index n
//     | x:T.e       Bind    bind variable 0 with name x and type T in e
//     | f e1...ek   Symbol  node labelled "f" with child nodes e1...ek
//
// We use expression trees to implement terms
//
// t ::= n           Var     inherited from expression trees
//     | \x:T.t      Lam     implemented as Symbol("\",[Bind(x,T,t)])
//     | t1 t2       App     implemented as Symbol("", [t1,t2])

/**
 * An expression tree whose de Bruijn indices are interpreted relative to a
 * context.  Extends Cloneable so that Object.clone() (used by
 * Symbol.shallowCopy) is permitted on implementations.
 */
interface Exp extends Cloneable {
    /** Infers the type of this expression in context gamma. */
    public Type infer(Context gamma) throws InferException;

    /** Renders this expression, using gamma to recover variable names. */
    public String toString(Context gamma);
}

/** Variable given by its de Bruijn index n (0 = nearest enclosing binder). */
class Var implements Exp {
    final int n;

    Var(final int n) {
        this.n = n;
    }

    /** Prints the name bound at index n, or "!UNBOUND_VAR!" if none is. */
    public String toString(Context gamma) {
        Entry binding;
        try {
            binding = gamma.lookup(n);
        } catch (ContextLookupException unbound) {
            return "!UNBOUND_VAR!";
        }
        return binding.name();
    }

    /** The variable has whatever type the context declares at index n. */
    public Type infer(Context gamma) throws InferException {
        Entry binding;
        try {
            binding = gamma.lookup(n);
        } catch (ContextLookupException unbound) {
            throw new InferException("Unbound variable.");
        }
        return binding.type();
    }
}

/**
 * Binder "x:T.e": binds de Bruijn index 0 to a variable named x of type T
 * inside the body e.  Lam wraps a Bind to form a lambda abstraction.
 */
class Bind implements Exp {
    final String name;
    final Type type;
    final Exp e;

    Bind (final String name, final Type type, final Exp e) {
        this.name = name; this.type = type; this.e = e;
    }

    /** Renders "x : T -> body", printing the body in the extended context. */
    public String toString (Context gamma) {
        // Index 0 of the extended context is exactly (name, type), so print
        // them directly; looking them up again could never fail.
        Context inner = new ExtContext (gamma, new TypeEntry (name, type));
        return name + " : " + type.toString() + " -> " + e.toString(inner);
    }

    /**
     * Returns the type of the body under the context extended with x:T.
     * The binder's own type is not part of the result; Lam.infer combines
     * it with this result into a function type.
     */
    public Type infer (Context gamma) throws InferException {
        return e.infer (new ExtContext (gamma, new TypeEntry (name, type)));
    }
}

/**
 * Generic labelled node "f e1 ... ek".  Concrete term constructors (Lam, App,
 * True, False, If) are subclasses that fill in symbol and args.
 */
class Symbol implements Exp {
    // Not final: subclasses call the no-argument constructor and assign
    // these fields afterwards in their own constructors.
    String symbol;
    Exp[] args;

    Symbol () {}
    Symbol (final String symbol, final Exp[] args) {
        this.symbol = symbol;
        this.args = args;
    }

    /**
     * Renders "(f e1 ... ek)", or just "f" when there are no arguments.
     * A word-like symbol (ending in a letter or digit, e.g. "if") is separated
     * from its first argument by a space so the two do not fuse together
     * ("(if x y false)", not "(ifx y false)").  Punctuation symbols such as
     * the lambda backslash, and the empty application symbol, attach directly:
     * "(\x : Bool -> x)", "(f x)".
     */
    public String toString (Context gamma) {
        if (args == null || args.length == 0) return symbol;
        StringBuilder s = new StringBuilder ();
        s.append ("(").append (symbol);
        boolean wordLike = symbol != null && symbol.length() > 0
            && Character.isLetterOrDigit (symbol.charAt (symbol.length() - 1));
        for (int i = 0; i < args.length; i++) {
            if (i > 0 || wordLike) s.append (" ");
            s.append (args[i].toString (gamma));
        }
        s.append (")");
        return s.toString();
    }

    /**
     * Returns a field-by-field copy of this node.  The symbol string and the
     * args array are shared with the original, not copied.
     */
    Symbol shallowCopy () {
        try {
            return (Symbol) this.clone();
        } catch (CloneNotSupportedException except) {
            // Unreachable: Exp extends Cloneable, so Object.clone() succeeds.
            // Fail loudly with the cause instead of terminating the JVM.
            AssertionError err = new AssertionError ("Cloning of " + this + " failed");
            err.initCause (except);
            throw err;
        }
    }

    /** A bare Symbol has no typing rule; subclasses override this. */
    public Type infer (Context gamma) throws InferException {
        throw new InferException ("Symbol: infer not implemented");
    }
}

// Terms

/**
 * Lambda abstraction \x:T.t, represented as a Symbol labelled with a
 * backslash whose single child is the binder Bind(x, T, t).
 */
class Lam extends Symbol {
    Lam (final String name, final Type type, final Exp e) {
        super();
        args = new Exp[] { new Bind (name, type, e) };
        symbol = "\\";
    }

    /** Type T -> T', where T is the binder's type and T' the body's type under x:T. */
    public Type infer (Context gamma) throws InferException {
        final Bind binder = (Bind) args[0];
        final Type codomain = binder.infer (gamma);
        return new Fun (binder.type, codomain);
    }
}

/** Application "t1 t2", represented as a Symbol with empty label and children [t1, t2]. */
class App extends Symbol {
    App (final Exp f, final Exp e) {
        super();
        symbol = "";
        args = new Exp[] { f, e };
    }

    /**
     * If t1 : S -> T and t2 : S then t1 t2 : T.
     *
     * @throws InferException if t1 is not of function type, or t2's type is not S
     */
    public Type infer (Context gamma) throws InferException {
        final Type fnType = args[0].infer (gamma);
        if (!(fnType instanceof Fun))
            throw new InferException ("App: Not of function type.");
        final Fun arrow = (Fun) fnType;
        final Type argType = args[1].infer (gamma);
        if (!argType.equals (arrow.s))
            throw new InferException ("App: Argument type does not match.");
        return arrow.t;
    }
}

/** The boolean constant true (a Symbol with no children). */
class True extends Symbol {
    True () {
        super ("true", new Exp[0]);
    }

    /** true has type Bool in every context. */
    public Type infer (Context gamma) throws InferException {
        final Type result = new Bool();
        return result;
    }
}

/** The boolean constant false (a Symbol with no children). */
class False extends Symbol {
    False () {
        super ("false", new Exp[0]);
    }

    /** false has type Bool in every context. */
    public Type infer (Context gamma) throws InferException {
        final Type result = new Bool();
        return result;
    }
}

/** Conditional "if e1 e2 e3", represented as Symbol("if", [e1, e2, e3]). */
class If extends Symbol {
    If (final Exp e1, final Exp e2, final Exp e3) {
        super ("if", new Exp[] { e1, e2, e3 });
    }

    /**
     * The condition must have type Bool and both branches the same type T;
     * the whole conditional then has type T.
     */
    public Type infer (Context gamma) throws InferException {
        final Type condType = args[0].infer (gamma);
        if (!condType.equals (new Bool()))
            throw new InferException ("Condition not a Boolean.");
        final Type thenType = args[1].infer (gamma);
        final Type elseType = args[2].infer (gamma);
        if (!thenType.equals (elseType))
            throw new InferException ("Types of branches in if do not match.");
        return thenType;
    }
}

// Test driver: prints example terms and their inferred types

/** Test driver: prints several closed example terms and their inferred types. */
class infer {
    static public void main (String[] args) throws ContextLookupException, InferException {
        // Named "bool" rather than "Bool" so the local does not shadow class Bool.
        Type bool = new Bool();
        Type A = new Base("A");
        Type B = new Base("B");
        Type C = new Base("C");

        // id  = \x:Bool. x
        Exp id = new Lam ("x", bool, new Var (0));
        // k   = \x:A. \y:B. x
        Exp k  = new Lam ("x", A, new Lam ("y", B, new Var (1)));
        // s   = \x:A->B->C. \y:A->B. \z:A. x z (y z)
        Exp s  = new Lam ("x", new Fun (A, new Fun (B, C)),
                     new Lam ("y", new Fun (A, B),
                         new Lam ("z", A,
                             new App (new App (new Var (2), new Var (0)),
                                      new App (new Var (1), new Var (0))))));
        // and = \x:Bool. \y:Bool. if x y false
        Exp and = new Lam ("x", bool,
                      new Lam ("y", bool, new If (new Var (1), new Var (0), new False())));
        // Type of binary boolean functions: Bool -> Bool -> Bool
        Type BinBoolF = new Fun (bool, new Fun (bool, bool));
        // AND = \f. \g. \x':Bool. \y':Bool. and (f x' y') (g x' y')
        Exp AND = new Lam ("f", BinBoolF,
                      new Lam ("g", BinBoolF,
                          new Lam ("x'", bool,
                              new Lam ("y'", bool,
                                  new App (new App (and,
                                                    new App (new App (new Var (3), new Var (1)), new Var (0))),
                                           new App (new App (new Var (2), new Var (1)), new Var (0)))))));

        Context g = new EmptyContext ();
        System.out.println ("id  = " + id.toString(g));
        System.out.println ("id  : " + id.infer(g).toString());
        System.out.println ("k   = " + k.toString(g));
        System.out.println ("k   : " + k.infer(g).toString());
        System.out.println ("s   = " + s.toString(g));
        System.out.println ("s   : " + s.infer(g).toString());
        System.out.println ("and = " + and.toString(g));
        System.out.println ("and : " + and.infer(g).toString());
        System.out.println ("AND = " + AND.toString(g));
        System.out.println ("AND : " + AND.infer(g).toString());
    }
}
