Proof by Reflection (and why not to use it)

Timothy Mou - Wed 27 March 2024 - lean, proof-assistants

Let's say you frequently need to solve equalities on monoids, and you find yourself proving lots of goals like the following:

example [Monoid α] : ∀ (a b c d : α),
  a * b * c * d * 1 = a * (b * c) * (1 * 1 * d) := by
  intros a b c d
  rw [mul_one]
  rw [mul_one]
  rw [one_mul]
  rw [mul_assoc a b _]

These proofs are straightforward and follow directly from the monoid laws, but they can be long and tedious. You feel like you are applying the same strategy to each goal, but still have to come up with a new proof each time. You get tired of writing out proofs by hand, so you turn to tactics to solve equalities for you. How would you go about doing so?

The purpose of this post is to introduce various approaches to writing such a tactic in Lean, as well as the factors that go into deciding which approach to use. These methods apply to many domains far more interesting than monoids, such as commutative groups, rings, and linear integer arithmetic (e.g., omega, ring, linarith, and friends).

Equality via Normalization

One common motif is to use normalization as a method of checking equality. The idea goes something like this: suppose you want to solve equalities on a type A, and you have a normalization function normA : A -> A that maps each element to a canonical element of its equivalence class. 1

Normalization example

For instance, if we wish to check whether two lists of integers contain the same multiset of elements, we can normalize both by sorting and then compare the sorted lists for equality.
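As a concrete sketch of this idea (an insertion sort on lists of naturals; all names here are mine, purely for illustration):

-- Sorting as a normalization function: "same multiset of elements"
-- becomes syntactic equality of normal forms.
def insertSorted (x : Nat) : List Nat → List Nat
  | [] => [x]
  | y :: ys => if x ≤ y then x :: y :: ys else y :: insertSorted x ys

def normList : List Nat → List Nat
  | [] => []
  | x :: xs => insertSorted x (normList xs)

#eval normList [3, 1, 2] == normList [2, 3, 1] -- true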

If we have a monoid expression, we can find a canonical form for it by associating all the way to the left/right and removing identity elements. For instance,

normA (a * (b * c * 1) * 1) = ((a * b) * c)

and

normA (a * (1 * b) * c) = ((a * b) * c)

We can check if two expressions are equal by normalizing both sides and seeing if their canonical forms are equal.

The problem with applying this idea is that in many cases we are not able to write such a function normA. Doing so typically requires pattern matching on an expression of type A, which we cannot do for an arbitrary type. One way around this issue is to use a type B that "reifies" expressions of type A, allowing us to write a reification function A -> B and an explicit normalization function normB : B -> B. If we can also write a "denotation" function denote : B -> A, we can compose these together to get a normalization function A -> A.
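Schematically, the recipe looks like this (the structure and its names are mine, purely for illustration):

-- A hypothetical packaging of the recipe: reify into `B`, normalize
-- there, and denote back. The catch, explained below, is that `reify`
-- cannot actually be written in Lean itself for an arbitrary `A`.
structure Reified (A B : Type) where
  reify  : A → B
  normB  : B → B
  denote : B → A

-- The induced normalization function on `A`.
def Reified.normA (r : Reified A B) : A → A :=
  r.denote ∘ r.normB ∘ r.reify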

For instance, we can define the following inductive type exp α that represents expressions of a monoid of type α:

inductive exp α where
  | atom : α → exp α
  | mult : exp α → exp α → exp α
  | mempty : exp α

Each constructor corresponds to a way of forming a monoid expression: a monoid expression is either the identity element, the combination of two subexpressions with the monoid operation, or an "atom"--a catch-all constructor for any other kind of expression.

We can easily write a denotation function to convert from an exp α to an α:

def exp.denote [Monoid α] (e : exp α) : α :=
  match e with
  | atom x => x
  | mult x y => exp.denote x * exp.denote y
  | mempty => 1
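For instance, a * (1 * b) is reified as mult (atom a) (mult mempty (atom b)), and denote maps it back (a quick sanity check of my own, verified by rfl since denote computes):

example [Monoid α] (a b : α) :
    (exp.mult (exp.atom a) (exp.mult exp.mempty (exp.atom b))).denote = a * (1 * b) :=
  rfl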

However, writing a reification function α -> exp α is not as straightforward. In fact, it runs into our original problem: we cannot pattern match on an expression of type α to begin with. Such a function cannot be written in Lean itself. Instead, we must use Lean's meta-language, which allows us to directly manipulate the ASTs of the expressions.

Reifying Monoid Expressions: An Interlude on Lean Metaprogramming

Lean 4 offers a metaprogramming interface similar to Agda's. Writing tactics, building expressions, and other metaprogramming tasks are done in a set of monads, which give access to things like current declarations and goals. For a further exploration of metaprogramming in Lean, please see the book Metaprogramming in Lean 4.

Expressions at the meta level are terms of the inductive type Expr, which represents the abstract syntax of a Lean program. Exprs can be scrutinized, constructed, and manipulated just like any other inductive type.
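For a small taste (a minimal sketch, not from the post's code), here is the Expr for Nat.succ Nat.zero built by hand:

import Lean
open Lean

-- Constants are referenced by name; applications are built with `mkApp`.
def succZero : Expr :=
  mkApp (mkConst ``Nat.succ) (mkConst ``Nat.zero)

#eval succZero.dbgToString -- prints a textual form of the expression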

Let's see how we can use Lean's meta-language to write our reification function:

partial def exp.reify (e : Expr) : M Expr := do
  match getAppFnArgs e with
  | (``HMul.hMul, #[_, _, _, _, e₁, e₂]) => do
    let e₁' ← exp.reify e₁
    let e₂' ← exp.reify e₂
    return (← read).app ``exp.mult #[e₁', e₂']
  | _ =>
    if ← isDefEq e (← read).α1 then
      return (← read).app ``exp.mempty #[]
    else
      return (← read).app ``exp.atom #[e]

We are using a custom monad built on top of MetaM that keeps things like the current monoid instance in scope (a sketch of this monad follows the case analysis below). At the heart of this function, we pattern match on the shape of the expression e using the getAppFnArgs function, which returns the head function name and the arguments of an expression.

  • In the first case, we check if e is constructed by multiplying two subexpressions. This case corresponds to the exp.mult constructor, so we recursively call reify on the subexpressions and use the results as arguments to the exp.mult constructor.
  • In the second case, we check if e is definitionally equal to the identity of our monoid instance. This case corresponds to the exp.mempty constructor.
  • The remaining case is our default option, which is to call e an atom and use the exp.atom constructor.
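As for the custom monad itself, its actual definition is in the appendix code; here is a rough sketch of its shape, with field and helper names inferred from the calls above (treat them as approximations, not the post's exact code):

open Lean Meta Mathlib.Tactic

-- A sketch only: the reader context caches everything the tactic needs.
structure Context where
  α    : Expr   -- the carrier type
  inst : Expr   -- the `Monoid α` instance
  α1   : Expr   -- the identity element `(1 : α)`
  u    : Level  -- the universe level of `α`

-- `AtomM` is Mathlib's monad for tactics that track a list of atoms.
abbrev M := ReaderT Context AtomM

-- Apply one of our constants to `α` and the given arguments.
def Context.app (c : Context) (n : Name) (args : Array Expr) : Expr :=
  mkAppN (mkApp (mkConst n [c.u]) c.α) args

-- Same, but also supply the cached `Monoid` instance.
def Context.appInst (c : Context) (n : Name) (args : Array Expr) : Expr :=
  mkAppN (mkApp (mkApp (mkConst n [c.u]) c.α) c.inst) args

mkContext (used by the tactics below) would then inspect the type of one side of the equality, synthesize its Monoid instance, and populate this structure.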

A Reflection Theorem

With this inductive representation exp α in hand, we can now write the normalization function we initially envisioned. Instead of returning an exp α, we can use a more suitable type to represent normalized expressions. Since associativity allows us to effectively ignore all parentheses, we can represent a normalized expression with a List α (technically, a list without identity elements, but enforcing this is slightly more difficult for little benefit). We can then write the normalization procedure by removing all mempty's and recursively flattening and combining subexpressions:

def exp.normalize : exp α → List α
  | atom x => [x]
  | mult x y => exp.normalize x ++ exp.normalize y
  | mempty => []
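As a quick check (my example; it holds by rfl since normalize computes on constructors):

-- Two reified forms of the same underlying expression share a normal form.
example (a b : α) :
    exp.normalize (exp.mult (exp.mult (exp.atom a) exp.mempty) (exp.atom b)) =
    exp.normalize (exp.mult (exp.atom a) (exp.mult exp.mempty (exp.atom b))) :=
  rfl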

Finally, we must have a method of going back to our original type α, i.e., a denotation function List α -> α:

def nf.denote [Monoid α] : List α → α
  | [] => 1
  | x :: xs => x * nf.denote xs
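Note that denotation right-associates and tacks on a trailing identity; for instance (a quick check of my own):

example [Monoid α] (a b : α) : nf.denote [a, b] = a * (b * 1) := rfl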

The relationship between α, exp α, and List α is shown in this example:

Reflection diagram

The most important property of our normalize function is that it preserves the meaning of the reified expression, in the sense that nf.denote (exp.normalize e) = exp.denote e. This is stated in the following theorem:

theorem normalize_correct [Monoid α] (e : exp α) : 
  nf.denote (exp.normalize e) = exp.denote e

The proof is by a straightforward induction argument.
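For the curious, here is one way the induction can be carried out; the helper lemma nf.denote_append and the proof details are mine, so treat this as a sketch:

-- Denotation turns list concatenation into multiplication.
theorem nf.denote_append [Monoid α] (l₁ l₂ : List α) :
    nf.denote (l₁ ++ l₂) = nf.denote l₁ * nf.denote l₂ := by
  induction l₁ with
  | nil => simp [nf.denote, one_mul]
  | cons x xs ih => simp [nf.denote, ih, mul_assoc]

theorem normalize_correct [Monoid α] (e : exp α) :
    nf.denote (exp.normalize e) = exp.denote e := by
  induction e with
  | atom x => simp [exp.normalize, exp.denote, nf.denote, mul_one]
  | mult x y ihx ihy =>
    simp [exp.normalize, exp.denote, nf.denote_append, ihx, ihy]
  | mempty => rfl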

We derive the following "reflection theorem" as a corollary:

theorem monoid_reflect [Monoid α] (a b : exp α) :
  nf.denote (exp.normalize a) = nf.denote (exp.normalize b) →
  exp.denote a = exp.denote b := by
  repeat simp [normalize_correct]

We can apply this theorem to solve equalities. Suppose that we have a goal a1 = a2, and we have exp.denote a = a1 and exp.denote b = a2. Then our goal is definitionally equal to exp.denote a = exp.denote b, and is of the form of the conclusion of monoid_reflect. Applying this theorem gives us the goal nf.denote (exp.normalize a) = nf.denote (exp.normalize b), which is hopefully easier to prove. In the case where the normal forms of a and b are equal, this goal can be trivially discharged with reflexivity.

For instance, consider this lemma:

lemma ex2' : ∀ (a b c : ℕ),
  a * 1 * b * (a + c) = a * (b * (a + c)) := by
  open exp in
  intros a b c
  exact (monoid_reflect
    (mult (mult (mult (atom a) mempty) (atom b)) (atom (a + c)))
    (mult (atom a) (mult (atom b) (atom (a + c))))
    rfl)

Here, we have manually performed the reification of the expressions a * 1 * b * (a + c) and a * (b * (a + c)). This demonstrates that we can use our monoid_reflect theorem without any reliance on tactics or metaprogramming at all, if we are willing to do the reification ourselves.

However, it is much more convenient to use our reification function. Here is how we would combine it with our reflection theorem and package them into a tactic:

syntax (name := monoid) "monoid" : tactic
elab_rules : tactic | `(tactic| monoid) => withMainContext do
    let some (_, e₁, e₂) := (← whnfR <| ← getMainTarget).eq?
      | throwError "monoid: requires an equality goal"
    let c ← mkContext e₁
    closeMainGoal <| ← AtomM.run .default <| ReaderT.run (r := c) do
      let t₁ ← exp.reify e₁ -- Reify the expressions into exp
      let t₂ ← exp.reify e₂
      let n₁ := (← read).app ``exp.normalize #[t₁] -- Normalize the expressions
      let n₂ := (← read).app ``exp.normalize #[t₂]
      unless ← isDefEq n₁ n₂ do
        throwError "monoid: normalized forms not equal"
      let m₁ := (← read).appInst ``nf.denote #[n₁]
      let eq ← mkAppM ``Eq.refl #[m₁]
      mkAppM ``monoid_reflect #[t₁, t₂, eq] -- apply the reflect theorem
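With everything wired up, the opening example closes in one step:

example [Monoid α] (a b c d : α) :
    a * b * c * d * 1 = a * (b * c) * (1 * 1 * d) := by
  monoid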

This is an instance of "proof by reflection", which receives an excellent introduction in the book Certified Programming with Dependent Types. This style of proof has several advantages:

  • Every proof using the monoid tactic shares the same monoid_reflect proof. Instead of having to write a new, bespoke proof for each new goal, we have proven one theorem which can be used repeatedly to solve many different goals.

  • Having the normalize_correct and monoid_reflect theorems gives us greater confidence that our normalization procedure is correct. In the case of monoids, it's not difficult to convince yourself that this normalization procedure should work, but we can imagine more complicated theories such as rings or linear integer arithmetic where finding a normal form could involve far more complicated algorithms and reasoning.

  • There is a high degree of separation of concerns between the metaprogramming used for reification and the "business logic" of performing the actual normalization. This is desirable because writing tactics is less familiar to most Lean users and has far less type safety. It is easy to construct Exprs that are completely meaningless when evaluated. By limiting the metaprogramming to only the initial reification step, we are able to reduce the number of parts of the tactic that rely on metaprogramming, which tend to be (in my experience) harder to write correctly and maintain.

However, proofs by reflection do have some disadvantages as well:

  • Proofs by reflection can be thought of as "converting validity of logical statements into symbolic computation", and relying on the kernel to do symbolic computation can be slow, particularly in Lean.

    This can be partially offset by using the Lean interpreter rather than the kernel. This unlocks much more performant evaluation at the cost of relying on a much larger code base (the Lean compiler and interpreter). For more information, see reduceBool.

  • Having to prove that the normalization procedure is correct can be restrictive. For instance, one can imagine domains in which a normal form cannot always be computed directly, but must instead be found by some form of search. In order to write a reflective proof, you are limited to normalization procedures for which you can prove a correctness theorem.

An Alternative Approach

Another way of writing a monoid tactic is exemplified by the abel tactic from mathlib. In what I'll call the "proof-producing" style, we again write a normalization function, but we do it in the meta-language. Given an expression, we construct a proof that it is equal to its normal form. If two expressions have the same normal form, we can prove them equal to each other by chaining the equality proofs together with transitivity.

(Note: I will use the notation e : Q(α) to mean that e : Expr is a quoted form of an expression of type α.)

More precisely, we write a function eval : Expr -> M (NF × Expr), where NF is a type that represents normal forms, and M is a monad giving access to the standard metaprogramming capabilities (MetaM). eval e returns a pair (nf, p), where nf : NF is such that nf.e : Q(α) is an Expr for the normalized form of e, and p : Q(e = nf.e) is a proof of that equality. In other words, eval takes an expression and returns its normal form, along with a proof that the original expression is equal to the expression the normal form represents.

Let's look at how we might write such a function for our monoid example. We first define a type of normal forms NF:

inductive NF where
  | mult : Expr → Expr → NF → NF
  | mempty : Expr → NF
  deriving Inhabited

def NF.e : NF → Expr
  | mult e .. => e
  | mempty e .. => e

NF is effectively a linked list, where each cell contains an Expr that is the quoted form of the normalized expression it represents.

  • mempty is the normal form of an identity element. mempty.e should be an Expr representing 1, the identity element of the monoid.
  • mult a b c is the normal form constructed by multiplying b : Q(α) with the normal form c : NF; a caches the Expr for the result, so it should hold that (mult a b c).e = b * c.e.

If you squint, you'll notice that this representation is quite similar to the type of normal forms we were using earlier, List α, except we are leaving the expressions as quoted Exprs and doing some extra bookkeeping to remember the Expr that each normal form represents.

Now we are able to write our eval function:

partial def eval (e : Expr) : M (NF × Expr) := do
  match getAppFnArgs e with
  | (``HMul.hMul, #[_, _, _, _, e₁, e₂]) => do
    let (e₁', p₁) ← eval e₁
    let (e₂', p₂) ← eval e₂
    let (e', p') ← evalMult e₁' e₂'
    return (
      e',
      (← read).appInst ``subst_into_mult #[e₁, e₂, e₁'.e, e₂'.e, e'.e, p₁, p₂, p'])
  | _ =>
    if ← isDefEq e (← read).α1 then
      return (NF.mempty e, ← mkEqRefl (← read).α1)
    else
      mkSingle e

The shape of this function is very similar to the reify function we wrote earlier, but instead of directly reifying e into an inductive type, we perform the normalization while recursing on the structure of e.

Let's focus on the first and only non-trivial case, where e = e₁ * e₂. We first recursively call eval on e₁ to obtain e₁' : NF and p₁ : Q(e₁ = e₁'.e), and analogously for e₂. Then, we merge these two normal forms together using evalMult 2 to obtain e' : NF and p' : Q(e₁'.e * e₂'.e = e'.e). Now e' is our desired normal form, but recall that we also need to produce a proof that e = e'.e, which can be shown with a series of rewriting steps using p₁, p₂, and p'. This can be packaged up into the following lemma:

lemma subst_into_mult {α} [Monoid α] (l r tl tr t)
  (prl : (l : α) = tl)
  (prr : r = tr)
  (prt : tl * tr = t) :
  l * r = t := by
  rw [prl, prr, prt]
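At the object level, this lemma is just three chained rewrites. For example (a toy instantiation of my own):

example [Monoid α] (a b : α) : a * (1 * b) = a * b :=
  subst_into_mult a (1 * b) a b (a * b) rfl (one_mul b) rfl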

We then call this lemma (instantiating the monoid instance from our context) to produce the desired proof.

Let's look at how we would tie this together into a tactic:

syntax (name := monoid) "monoid" : tactic
elab_rules : tactic | `(tactic| monoid) => withMainContext do
    let some (_, e₁, e₂) := (← whnfR <| ← getMainTarget).eq?
      | throwError "monoid: requires an equality goal"
    let c ← mkContext e₁
    closeMainGoal <| ← AtomM.run .default <| ReaderT.run (r := c) do
      let (e₁', p₁) ← eval e₁
      let (e₂', p₂) ← eval e₂
      unless ← isDefEq e₁'.e e₂'.e do
        throwError "monoid: could not find that the two sides are equal"
      mkEqTrans p₁ (← mkEqSymm p₂)

Similar to our previous tactic, we first identify the two expressions we want to show are equal, and then attempt to synthesize a Monoid instance. We then compute the normal forms of e₁ and e₂ and check if they are definitionally equal. If so, then e₁'.e = e₂'.e, and since p₁ : Q(e₁ = e₁'.e) and p₂ : Q(e₂ = e₂'.e), we can combine the two by symmetry and transitivity to obtain a proof that e₁ = e₂, which we use to close the goal.

Differences between the two styles:

  • I find this style much more difficult to write. The primary reason is that it relies on constructing quoted equality proofs, but doing so in an essentially untyped manner. For instance, when writing the eval function, one could easily provide too many/too few arguments to the subst_into_mult function, or swap the order of two arguments, and there would not be a way to detect the error until the tactic is actually run.
  • Proof terms can be much larger. Compare the outputs of the two tactics on the goal ∀ (a b c : ℕ), a * 1 * b * (a + c) = a * (b * (a + c)):

    Reflective tactic

    theorem Monoid.ex2 : ∀ (a b c : ℕ), a * 1 * b * (a + c) = a * (b * (a + c)) :=
      fun a b c ↦ monoid_reflect (exp.mult (exp.mult (exp.mult (exp.atom a) exp.mempty) (exp.atom b)) (exp.atom (a + c)))
          (exp.mult (exp.atom a) (exp.mult (exp.atom b) (exp.atom (a + c))))
          (Eq.refl
            (nf.denote
              (exp.normalize (exp.mult (exp.mult (exp.mult (exp.atom a) exp.mempty) (exp.atom b)) (exp.atom (a + c))))))
    

    Proof producing tactic

    theorem Monoid.ex2 : ∀ (a b c : ℕ), a * 1 * b * (a + c) = a * (b * (a + c)) :=
      fun a b c ↦ (subst_into_mult (a * 1 * b) (a + c) (a * (b * 1)) ((a + c) * 1) (a * (b * ((a + c) * 1)))
            (subst_into_mult (a * 1) b (a * 1) (b * 1) (a * (b * 1))
              (subst_into_mult a 1 (a * 1) 1 (a * 1) (Monoid.mul_one a).symm (Eq.refl 1) (Monoid.mul_one (a * 1)))
              (Monoid.mul_one b).symm (left_mult a 1 (b * 1) (b * 1) (Monoid.one_mul (b * 1))))
            (Monoid.mul_one (a + c)).symm
            (left_mult a (b * 1) ((a + c) * 1) (b * ((a + c) * 1))
              (left_mult b 1 ((a + c) * 1) ((a + c) * 1) (Monoid.one_mul ((a + c) * 1))))).trans
        (subst_into_mult a (b * (a + c)) (a * 1) (b * ((a + c) * 1)) (a * (b * ((a + c) * 1))) (Monoid.mul_one a).symm
            (subst_into_mult b (a + c) (b * 1) ((a + c) * 1) (b * ((a + c) * 1)) (Monoid.mul_one b).symm
              (Monoid.mul_one (a + c)).symm (left_mult b 1 ((a + c) * 1) ((a + c) * 1) (Monoid.one_mul ((a + c) * 1))))
            (left_mult a 1 (b * ((a + c) * 1)) (b * ((a + c) * 1)) (Monoid.one_mul (b * ((a + c) * 1))))).symm
    

    The size of the proof term produced by the reflective tactic is linear in the size of the terms, while the proof term of the proof-producing tactic scales with the number of rewriting steps performed (essentially, there is no difference from writing out all the rewriting steps explicitly). It should be noted, however, that the size of a proof term is not always indicative of how efficiently it can be checked.

  • As mentioned before, tactics are more flexible, and don't need to prove their correctness in order to be used.

  • Meta code (which runs outside the kernel) is more efficient than evaluation within the kernel.

Conclusion

I have presented three different methods of writing tactics for solving equalities: (1) reflective tactics, (2) reflective tactics using native evaluation, and (3) proof producing tactics. (For brevity I will refer to these as P1, P2, and P3 respectively.)

I propose several metrics to evaluate these methods:

  • Trusted code base, i.e., which systems we must trust in order to accept a proof: Both P1 and P3 rely only on the correctness of the Lean kernel, while P2 also relies on the entire Lean compiler and interpreter, a much larger code base. As noted here, however, you will most likely have to trust the Lean compiler and interpreter anyway, so by using P2, we primarily lose the ability to check our proofs with third-party checkers.

  • Efficiency: Because P1 heavily relies on Lean's kernel reduction, it suffers in performance. In a comparison between P1 and P3-style tactics for ring, it was found that the P1 version was slower by 20-50%. Comparisons in efficiency between P2 and P3-style tactics have not been done, as far as I know.

  • Expressiveness: As discussed previously, reflective tactics are restricted by the need to prove their normalization procedure correct. P3 tactics don't have this limitation and can perform arbitrary computation outside the kernel.

  • Ease of implementation and maintenance: This is the most subjective metric, but one that I think is quite important. P1 and P2-style tactics limit the amount of metaprogramming code to the bare minimum. Writing a reification function is simpler and less error-prone than writing a P3 tactic, which has to use meta functions from the beginning. Limiting the scope of metaprogramming also makes updating and debugging far easier, in my opinion.

So, which method should you use? I don't think there is a clear choice to make, but we can look to existing libraries and tactics to get a sense of what others have done. The mathlib library includes several P3 tactics such as abel and ring. On the other hand, P2 tactics are used in the leansat library.

In my opinion, if you are writing a tactic for use in your own projects, I would suggest writing a P1 tactic if possible, because I feel it is the simplest and least error-prone method.

Acknowledgements

Thanks to Scott Morrison, Henrik Böving, and Mario Carneiro for helpful discussions on this topic.

Appendix

See the following code snippets for details on how to write the monoid tactic.


  1. Technically, we only require that the equivalence relation induced by normA be finer than the equality relation (i.e., normA may map equivalent elements to different canonical elements). For instance, a trivial normalization function would map every element to itself. In such cases, we only obtain a semi-decision procedure for checking equality. 

  2. Another meta helper function, which produces a normal form by performing the analogue of list concatenation on NFs. See the code for details.