Let's say you frequently need to solve equalities on monoids, and you find yourself proving lots of goals like the following:
```lean
import Mathlib

variable {α : Type*}

example [Monoid α] : ∀ (a b c d : α),
    a * b * c * d * 1 = a * (b * c) * (1 * 1 * d) := by
  intros a b c d
  rw [mul_one]
  rw [mul_one]
  rw [one_mul]
  rw [mul_assoc a b _]
```
These proofs are straightforward and follow directly from the monoid laws, but they can be long and tedious. You feel like you are applying the same strategy to each goal, but still have to come up with a new proof each time. You get tired of writing out proofs by hand, so you turn to tactics to solve equalities for you. How would you go about doing so?
The purpose of this post is to introduce various approaches to writing such a tactic in Lean, as well as the factors that go into deciding which approach to use. These methods apply to many domains far more interesting than monoids, such as commutative groups, rings, and linear integer arithmetic (cf. tactics such as omega, ring, linarith, and friends).
Equality via Normalization
One common motif is to use normalization as a method of checking equality.
The idea goes something like this:
suppose you want to solve equalities on a type A, and you have a normalization function normA : A -> A that maps every element to a canonical element of its equivalence class.[1]
For instance, if we wish to check whether two lists of integers contain the same multiset of elements, we can normalize them by sorting and then compare the sorted lists for equality.
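As a small illustration of equality via normalization, here is a sketch in plain Lean 4 (no imports) that decides whether two lists of integers contain the same multiset of elements by sorting both and comparing the results. The names insertSorted, sortInts, and sameMultiset are hypothetical. A hand-written insertion sort is used only to keep the example free of library dependencies.

```lean
-- Insert `x` into an already sorted list, keeping it sorted.
def insertSorted (x : Int) : List Int → List Int
  | [] => [x]
  | y :: ys => if x ≤ y then x :: y :: ys else y :: insertSorted x ys

-- Insertion sort: plays the role of the normalization function normA.
def sortInts : List Int → List Int
  | [] => []
  | x :: xs => insertSorted x (sortInts xs)

-- Two lists are equal as multisets iff their sorted forms coincide.
def sameMultiset (l₁ l₂ : List Int) : Bool :=
  sortInts l₁ == sortInts l₂

#eval sameMultiset [3, 1, 2] [2, 3, 1]  -- true
#eval sameMultiset [1, 1, 2] [1, 2, 2]  -- false
```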
If we have a monoid expression, we can find a canonical form for it by associating all the way to the left/right and removing identity elements. For instance,
normA (a * (b * c * 1) * 1) = ((a * b) * c)
and
normA (a * (1 * b) * c) = ((a * b) * c)
We can check if two expressions are equal by normalizing both sides and seeing if their canonical forms are equal.
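Why this check is sound can be stated as a small lemma. The sketch below (the name equiv_of_norm_eq is hypothetical) assumes an arbitrary equivalence relation r on A and a function norm that sends every element to an r-equivalent one. Under these assumptions, equal normal forms imply equivalence. The converse does not follow: equivalent elements need not share a normal form, and in that case the check is sound but incomplete.

```lean
/-- If `norm` maps every element to an `r`-equivalent one, then equal
normal forms are enough to conclude that two elements are equivalent. -/
theorem equiv_of_norm_eq {A : Type} (r : A → A → Prop) (hr : Equivalence r)
    (norm : A → A) (hnorm : ∀ a, r a (norm a)) {x y : A}
    (h : norm x = norm y) : r x y := by
  have h₁ : r x (norm x) := hnorm x
  have h₂ : r y (norm y) := hnorm y
  -- Rewrite `norm x` to `norm y`, so that h₁ : r x (norm y).
  rw [h] at h₁
  -- Chain x ~ norm y ~ y.
  exact hr.trans h₁ (hr.symm h₂)
```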
The problem with applying this idea is that in many cases we cannot write such a function normA in Lean itself. Doing so would require pattern matching on the syntactic shape of an expression of type A, which we cannot do for an arbitrary type.
One way to get around this issue is to use a type B that "reifies" expressions of type A. This lets us write a reification function reify : A -> B and an explicit normalization function normB : B -> B. If we can also write a "denotation" function denote : B -> A, we can compose these to get a normalization function on A, namely denote ∘ normB ∘ reify : A -> A.
For instance, we can define the following inductive type exp α that represents expressions of a monoid of type α:
```lean
inductive exp α where
  | atom : α → exp α
  | mult : exp α → exp α → exp α
  | mempty : exp α
```
Each constructor corresponds to a way of forming a monoid expression. A monoid expression is either the identity element, the product of two subexpressions, or an "atom": a catch-all constructor for any other kind of expression.
We can easily write a denotation function to convert from an exp α to an α:
```lean
def exp.denote [Monoid α] (e : exp α) : α :=
  match e with
  | atom x => x
  | mult x y => exp.denote x * exp.denote y
  | mempty => 1
```
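As a quick sanity check of exp.denote, here is a hand-built exp for a * (b * 1) together with its denotation. This sketch assumes import Mathlib (for Monoid and Type*) and the definitions of exp and exp.denote above. Since exp.denote is ordinary structural recursion, Lean should be able to check the equation by rfl.

```lean
-- `a * (b * 1)` written by hand as an `exp α`; its denotation
-- unfolds definitionally back to the original expression.
example {α : Type*} [Monoid α] (a b : α) :
    (exp.mult (exp.atom a) (exp.mult (exp.atom b) exp.mempty)).denote = a * (b * 1) :=
  rfl
```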
However, writing a reification function α -> exp α is not as straightforward. In fact, it runs into our original problem: we cannot pattern match on an expression of type α to begin with, so such a function cannot be written in Lean itself. Instead, we must use Lean's meta-language, which allows us to directly manipulate the ASTs of expressions.
Reifying Monoid Expressions: An Interlude on Lean Metaprogramming
Lean 4 offers a metaprogramming interface similar to Agda's. Writing tactics, building expressions, and other metaprogramming tasks are done in a set of monads, which give access to things like current declarations and goals. For a further exploration of metaprogramming in Lean, please see the book Metaprogramming in Lean 4.
Expressions at the meta level are terms of the inductive type Expr, which represents the abstract syntax of a Lean program. Exprs can be scrutinized, constructed, and manipulated just like values of any other inductive type.
Let's see how we can use Lean's meta-language to write our reification function:
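```lean
partial def exp.reify (e : Expr) : M Expr := do
  match getAppFnArgs e with
  | (``HMul.hMul, #[_, _, _, _, e₁, e₂]) => do
    -- e = e₁ * e₂: reify both factors and combine them with exp.mult
    let e₁' ← exp.reify e₁
    let e₂' ← exp.reify e₂
    return (← read).app ``exp.mult #[e₁', e₂']
  | _ =>
    -- e is (definitionally) the monoid identity: use exp.mempty
    if ← isDefEq e (← read).α1 then
      return (← read).app ``exp.mempty #[]
    else
      -- anything else is treated as an opaque atom
      return (← read).app ``exp.atom #[e]
```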
We are using a custom monad M, built on top of MetaM, that gives us access to things like the monoid instance we are currently working with. At the heart of this function, we pattern match on the shape of the expression e using the getAppFnArgs function, which returns the name of the head function of an application together with its arguments. For a product e₁ * e₂, the head is HMul.hMul. Its six arguments are the three types, the HMul instance, and the two operands e₁ and e₂.
- In the first case, we check whether e is the product of two subexpressions. This case corresponds to the exp.mult constructor, so we recursively call reify on the subexpressions and use the results as arguments to the exp.mult constructor.
- In the second case, we check whether e is definitionally equal to the identity of our monoid instance. This case corresponds to the exp.mempty constructor.
- The remaining case is our default option, which is to treat e as an atom and use the exp.atom constructor.
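Because Expr is an ordinary inductive type, the recursion pattern used by reify can be tried out in isolation, without the custom monad M. The sketch below is written as its own file and assumes import Mathlib, which also provides the Qq quotation q(...). It flattens a product into the list of its factors using the same getAppFnArgs match. The name collectFactors is hypothetical, and the identity case is omitted for brevity.

```lean
import Mathlib
open Lean Meta Qq

/-- Hypothetical helper: flatten nested `*` into a list of factors,
using the same `HMul.hMul` match as `exp.reify`. -/
partial def collectFactors (e : Expr) : MetaM (List Expr) := do
  match e.getAppFnArgs with
  -- `HMul.hMul` takes three types, an instance, and the two operands.
  | (``HMul.hMul, #[_, _, _, _, e₁, e₂]) =>
    return (← collectFactors e₁) ++ (← collectFactors e₂)
  -- Anything that is not a product is a single factor (an "atom").
  | _ => return [e]

#eval show MetaM Unit from do
  let fs ← collectFactors q((2 : ℕ) * 3 * 4)
  -- Log each factor that was found.
  for f in fs do
    logInfo m!"{f}"
```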
A Reflection Theorem
With this inductive representation exp α in hand, we can now write the normalization function we initially envisioned. Instead of returning an exp α, we can use a more suitable type to represent normalized expressions. Since associativity allows us to effectively ignore all parentheses, we can represent a normalized expression with a List α (technically, a list without identity elements, but this is slightly more difficult to write without much benefit). We can then write the normalization procedure by removing all mempty's and recursively flattening and concatenating subexpressions:
```lean
def exp.normalize : exp α → List α
  | atom x => [x]
  | mult x y => exp.normalize x ++ exp.normalize y
  | mempty => []
```
Finally, we need a way of going back to our original type α, i.e., a denotation function List α -> α:
```lean
def nf.denote [Monoid α] : List α → α
  | [] => 1
  | x :: xs => x * nf.denote xs
```
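The relationship between α, exp α, and List α is as follows. An expression of type α is reified into an exp α, which exp.normalize turns into a List α. Both exp α and List α can be mapped back to α by their respective denotation functions.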
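To make this concrete, here is the pipeline applied to a hand-reified version of a * (1 * b) * 1. This sketch assumes import Mathlib and the definitions above; the name sampleExp is hypothetical. The normalized list drops both identities, and nf.denote turns it back into a right-nested product ending in 1.

```lean
/-- Hypothetical example term: a hand-reified `a * (1 * b) * 1`. -/
def sampleExp {α : Type*} (a b : α) : exp α :=
  exp.mult (exp.mult (exp.atom a) (exp.mult exp.mempty (exp.atom b))) exp.mempty

-- exp α → α: the denotation gives back the original expression.
example {α : Type*} [Monoid α] (a b : α) :
    (sampleExp a b).denote = a * (1 * b) * 1 := rfl

-- exp α → List α: identities are dropped and parentheses forgotten.
example {α : Type*} (a b : α) : exp.normalize (sampleExp a b) = [a, b] := rfl

-- List α → α: the normal form denotes a right-nested product ending in 1.
example {α : Type*} [Monoid α] (a b : α) :
    nf.denote (exp.normalize (sampleExp a b)) = a * (b * 1) := rfl
```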
The most important property of our normalize function is that it preserves the meaning of the reified expression, in the sense that nf.denote (exp.normalize e) = exp.denote e. This is stated as the following theorem:
```lean
theorem normalize_correct [Monoid α] (e : exp α) :
    nf.denote (exp.normalize e) = exp.denote e
```
The proof is by a straightforward induction argument.
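Here is one way the induction could be carried out. This sketch assumes import Mathlib and the definitions above. It needs an auxiliary lemma, called nf.denote_append here (a hypothetical name), stating that nf.denote turns list concatenation into multiplication. That lemma is where associativity of the monoid is used.

```lean
-- Auxiliary lemma (hypothetical name): `nf.denote` sends `++` to `*`.
-- This is the step that uses associativity (`mul_assoc`).
theorem nf.denote_append {α : Type*} [Monoid α] (l₁ l₂ : List α) :
    nf.denote (l₁ ++ l₂) = nf.denote l₁ * nf.denote l₂ := by
  induction l₁ with
  | nil => simp [nf.denote]
  | cons x xs ih => simp [nf.denote, ih, mul_assoc]

theorem normalize_correct {α : Type*} [Monoid α] (e : exp α) :
    nf.denote (exp.normalize e) = exp.denote e := by
  induction e with
  | atom x => simp [exp.normalize, exp.denote, nf.denote]
  | mult x y ihx ihy =>
    simp [exp.normalize, exp.denote, nf.denote_append, ihx, ihy]
  | mempty => simp [exp.normalize, exp.denote, nf.denote]
```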
We derive the following "reflection theorem" as a corollary:
```lean
theorem monoid_reflect [Monoid α] (a b : exp α) :
    nf.denote (exp.normalize a) = nf.denote (exp.normalize b) →
    exp.denote a = exp.denote b := by
  repeat simp [normalize_correct]
```
We can apply this theorem to solve equalities. Suppose that we have a goal a1 = a2, and we have found a, b : exp α such that exp.denote a is definitionally equal to a1 and exp.denote b is definitionally equal to a2. Then our goal is definitionally equal to exp.denote a = exp.denote b, which is the conclusion of monoid_reflect. Applying this theorem gives us the goal nf.denote (exp.normalize a) = nf.denote (exp.normalize b), which is hopefully easier to prove. When the normal forms of a and b are equal, this goal can be trivially discharged by reflexivity.
For instance, consider this lemma:
```lean
lemma ex2' : ∀ (a b c : ℕ),
    a * 1 * b * (a + c) = a * (b * (a + c)) := by
  open exp in
  intros a b c
  exact (monoid_reflect
    (mult (mult (mult (atom a) mempty) (atom b)) (atom (a + c)))
    (mult (atom a) (mult (atom b) (atom (a + c))))
    rfl)
```
Here, we have manually performed the reification of the expressions a * 1 * b * (a + c) and a * (b * (a + c)). This demonstrates that we can use our monoid_reflect theorem without any reliance on tactics or metaprogramming at all, if we are willing to do the reification ourselves.
However, it is much more convenient to use our reification function. Here is how we would combine it with our reflection theorem and package them into a tactic:
```lean
syntax (name := monoid) "monoid" : tactic

elab_rules : tactic | `(tactic| monoid) => withMainContext do
  let some (_, e₁, e₂) := (← whnfR <| ← getMainTarget).eq?
    | throwError "monoid: requires an equality goal"
  let c ← mkContext e₁
  closeMainGoal <| ← AtomM.run .default <| ReaderT.run (r := c) do
    let t₁ ← exp.reify e₁ -- Reify the expressions into exp
    let t₂ ← exp.reify e₂
    -- Build the terms `exp.normalize tᵢ`; isDefEq below evaluates them
    let n₁ := (← read).app ``exp.normalize #[t₁]
    let n₂ := (← read).app ``exp.normalize #[t₂]
    unless ← isDefEq n₁ n₂ do
      throwError "monoid: normalized forms not equal"
    let m₁ := (← read).appInst ``nf.denote #[n₁]
    let eq ← mkAppM ``Eq.refl #[m₁]
    mkAppM ``monoid_reflect #[t₁, t₂, eq] -- apply the reflection theorem
```
This is an instance of "proof by reflection", which is introduced well in the book Certified Programming with Dependent Types. This style of proof has several advantages:
- Every proof using the monoid tactic shares the same monoid_reflect proof. Instead of having to write a new, bespoke proof for each new goal, we have proven one theorem that can be used repeatedly to solve many different goals.
- Having the normalize_correct and monoid_reflect theorems gives us greater confidence that our normalization procedure is correct. In the case of monoids, it's not difficult to convince yourself that this normalization procedure should work. For more complicated theories such as rings or linear integer arithmetic, however, finding a normal form could involve far more complicated algorithms and reasoning.
- There is a high degree of separation of concerns between the metaprogramming used for reification and the "business logic" of performing the actual normalization. This is desirable because writing tactics is less familiar to most Lean users and has far less type safety: it is easy to construct Exprs that are completely meaningless when evaluated. By limiting metaprogramming to the initial reification step, we reduce the number of parts of the tactic that rely on it. In my experience, those parts tend to be harder to write correctly and maintain.
However, proofs by reflection do have some disadvantages as well:
- Proofs by reflection can be thought of as "converting validity of logical statements into symbolic computation", and relying on the kernel to do symbolic computation can be slow, particularly in Lean. This can be partially offset by using the Lean interpreter rather than the kernel. This unlocks much more performant evaluation at the cost of relying on a much larger code base (the Lean compiler and interpreter). For more information, see reduceBool.
- Having to prove that the normalization procedure is correct can be restrictive in some cases. For instance, in certain domains a normal form cannot always be computed directly and has to be found using some form of search. To write a reflective proof, you are limited to normalization procedures for which you can write a correctness theorem.
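The difference between kernel reduction and native evaluation can be seen on a closed Boolean computation. The sketch below (plain Lean 4; the theorem name is hypothetical) uses Lean's native_decide tactic, which runs the decision procedure as compiled code. #print axioms then reports that the proof depends on an axiom for native evaluation, such as Lean.ofReduceBool, instead of being checked purely by the kernel. This only illustrates the mechanism. The monoid goals above contain free variables, so a tactic based on native evaluation would first have to reify them into a closed, computable form, which is not shown here.

```lean
-- The decision procedure is run as compiled code, not reduced by the kernel.
theorem range_sum_native : (List.range 1000).foldl (· + ·) 0 = 499500 := by
  native_decide

-- Shows which axioms the proof depends on; native evaluation is recorded
-- as an axiom (e.g. `Lean.ofReduceBool`) rather than checked by the kernel.
#print axioms range_sum_native
```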
An Alternative Approach
Another way of writing a monoid tactic is exemplified by the abel tactic from mathlib. Here, in what I'll call the "proof-producing" style, we write a normalization function, but we do it in the meta-language. Given an expression, we construct a proof that it is equal to its normal form. If two expressions have the same normal form, we can prove they are equal to each other by chaining together the equality proofs with symmetry and transitivity.
(Note: I will use the notation e : Q(α) to mean that e : Expr is a quoted form of an expression of type α.)
More precisely, we write a function eval : Expr -> M (NF × Expr), where NF is a type that represents normal forms, and M is a monad with access to the standard metaprogramming capabilities (MetaM). eval e returns a pair (nf, p), where nf is an element of NF such that nf.e : Q(α) is an Expr that is the normalized form of e, and p : Q(e = nf.e) is a proof that e = nf.e. In other words, eval takes an expression and returns its normal form, along with a proof that the original expression is equal to the expression that the normal form represents.
Let's look at how we might write such a function for our monoid example. We first define a type of normal forms NF:
```lean
inductive NF where
  | mult : Expr → Expr → NF → NF
  | mempty : Expr → NF
  deriving Inhabited

def NF.e : NF → Expr
  | mult e .. => e
  | mempty e .. => e
```
NF is effectively a linked list, where each cell contains an Expr that is the quoted form of the normalized expression it represents.
- mempty e is the normal form of an identity element; its Expr e should represent 1, the identity element of the monoid.
- mult a b c is the normal form constructed by multiplying b : Q(α) with the normal form c : NF. It should hold that (mult a b c).e = b * c.e. Since NF.e returns the first field, this means that a caches the expression b * c.e.
If you squint, you'll notice that this representation is quite similar to the type of normal forms we were using earlier, List α, except that we leave the expressions as quoted Exprs and do some extra bookkeeping to remember the Expr that each normal form represents.
Now we are able to write our eval function:
```lean
partial def eval (e : Expr) : M (NF × Expr) := do
  match getAppFnArgs e with
  | (``HMul.hMul, #[_, _, _, _, e₁, e₂]) => do
    -- p₁ : e₁ = e₁'.e and p₂ : e₂ = e₂'.e
    let (e₁', p₁) ← eval e₁
    let (e₂', p₂) ← eval e₂
    -- merge the two normal forms; p' : e₁'.e * e₂'.e = e'.e
    let (e', p') ← evalMult e₁' e₂'
    return (
      e',
      (← read).appInst ``subst_into_mult #[e₁, e₂, e₁'.e, e₂'.e, e'.e, p₁, p₂, p'])
  | _ =>
    if ← isDefEq e (← read).α1 then
      return (NF.mempty e, ← mkEqRefl (← read).α1)
    else
      mkSingle e
```
The shape of this function is very similar to the reify function we wrote earlier. Instead of directly reifying e into an inductive type, however, we perform the normalization while recursing on the structure of e.
Let's focus on the first and only non-trivial case, where e = e₁ * e₂. We first recursively call eval on e₁ to obtain e₁' : NF and p₁ : Q(e₁ = e₁'.e), and analogously for e₂. Then we merge these two normal forms using evalMult[2] to obtain e' : NF and p' : Q(e₁'.e * e₂'.e = e'.e). Now e' is our desired normal form, but recall that we also need to produce a proof that e = e'.e. This can be shown with a series of rewriting steps using p₁, p₂, and p'.
This can be packaged up into the following lemma:
```lean
lemma subst_into_mult {α} [Monoid α] (l r tl tr t)
    (prl : (l : α) = tl)
    (prr : r = tr)
    (prt : tl * tr = t) :
    l * r = t := by
  rw [prl, prr, prt]
```
We then call this lemma (instantiating the monoid instance from our context) to produce the desired proof.
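To see what eval assembles, the lemma can also be applied by hand. In this sketch, which assumes import Mathlib and subst_into_mult as stated above, the two factors of (a * 1) * (1 * b) are normalized separately with mul_one and one_mul. The final argument proves that the product of the two normal forms is the target.

```lean
example {α : Type*} [Monoid α] (a b : α) : a * 1 * (1 * b) = a * b :=
  subst_into_mult (a * 1) (1 * b) a b (a * b)
    (mul_one a)   -- prl : a * 1 = a
    (one_mul b)   -- prr : 1 * b = b
    rfl           -- prt : a * b = a * b
```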
Let's look at how we would tie this together into a tactic:
```lean
syntax (name := monoid) "monoid" : tactic

elab_rules : tactic | `(tactic| monoid) => withMainContext do
  let some (_, e₁, e₂) := (← whnfR <| ← getMainTarget).eq?
    | throwError "monoid: requires an equality goal"
  let c ← mkContext e₁
  closeMainGoal <| ← AtomM.run .default <| ReaderT.run (r := c) do
    -- Normalize both sides: p₁ : e₁ = e₁'.e and p₂ : e₂ = e₂'.e
    let (e₁', p₁) ← eval e₁
    let (e₂', p₂) ← eval e₂
    unless ← isDefEq e₁'.e e₂'.e do
      throwError "monoid: could not find that the two sides are equal"
    -- Chain the proofs: e₁ = e₁'.e = e₂'.e = e₂
    mkEqTrans p₁ (← mkEqSymm p₂)
```
Similar to our previous tactic, we first identify the two expressions we want to show are equal, and then attempt to synthesize a Monoid instance.
We then compute the normal forms of e₁ and e₂ and check whether they are definitionally equal. If so, then e₁'.e = e₂'.e. Since p₁ : Q(e₁ = e₁'.e) and p₂ : Q(e₂ = e₂'.e), we can combine these two proofs by symmetry and transitivity to obtain a proof that e₁ = e₂, which we use to close the goal.
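At the level of ordinary proof terms, this last step is just transitivity and symmetry. In the sketch below (plain Lean 4, hypothetical names), n stands for the common normal form that both sides were shown to be equal to. In the tactic, the two normal forms only need to agree up to definitional equality, which the kernel checks when it type-checks the combined proof.

```lean
-- p₁ : x = n and p₂ : y = n combine into x = y, mirroring
-- `mkEqTrans p₁ (← mkEqSymm p₂)` in the tactic.
example {α : Type _} (x y n : α) (p₁ : x = n) (p₂ : y = n) : x = y :=
  p₁.trans p₂.symm
```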
Differences:
- I find this style much more difficult to write. The primary reason is that it relies on constructing quoted equality proofs, but doing so in an essentially untyped manner. For instance, when writing the eval function, one could easily provide too many or too few arguments to the subst_into_mult lemma, or swap the order of two arguments. There would be no way to detect the error until the tactic is actually run.
- Proof terms can be much larger. Compare the outputs of the two tactics on the goal ∀ (a b c : ℕ), a * 1 * b * (a + c) = a * (b * (a + c)):

Reflective tactic:
```lean
theorem Monoid.ex2 : ∀ (a b c : ℕ), a * 1 * b * (a + c) = a * (b * (a + c)) := fun a b c ↦ monoid_reflect (exp.mult (exp.mult (exp.mult (exp.atom a) exp.mempty) (exp.atom b)) (exp.atom (a + c))) (exp.mult (exp.atom a) (exp.mult (exp.atom b) (exp.atom (a + c)))) (Eq.refl (nf.denote (exp.normalize (exp.mult (exp.mult (exp.mult (exp.atom a) exp.mempty) (exp.atom b)) (exp.atom (a + c))))))
```

Proof-producing tactic:

```lean
theorem Monoid.ex2 : ∀ (a b c : ℕ), a * 1 * b * (a + c) = a * (b * (a + c)) := fun a b c ↦ (subst_into_mult (a * 1 * b) (a + c) (a * (b * 1)) ((a + c) * 1) (a * (b * ((a + c) * 1))) (subst_into_mult (a * 1) b (a * 1) (b * 1) (a * (b * 1)) (subst_into_mult a 1 (a * 1) 1 (a * 1) (Monoid.mul_one a).symm (Eq.refl 1) (Monoid.mul_one (a * 1))) (Monoid.mul_one b).symm (left_mult a 1 (b * 1) (b * 1) (Monoid.one_mul (b * 1)))) (Monoid.mul_one (a + c)).symm (left_mult a (b * 1) ((a + c) * 1) (b * ((a + c) * 1)) (left_mult b 1 ((a + c) * 1) ((a + c) * 1) (Monoid.one_mul ((a + c) * 1))))).trans (subst_into_mult a (b * (a + c)) (a * 1) (b * ((a + c) * 1)) (a * (b * ((a + c) * 1))) (Monoid.mul_one a).symm (subst_into_mult b (a + c) (b * 1) ((a + c) * 1) (b * ((a + c) * 1)) (Monoid.mul_one b).symm (Monoid.mul_one (a + c)).symm (left_mult b 1 ((a + c) * 1) ((a + c) * 1) (Monoid.one_mul ((a + c) * 1)))) (left_mult a 1 (b * ((a + c) * 1)) (b * ((a + c) * 1)) (Monoid.one_mul (b * ((a + c) * 1))))).symm
```
The size of the proof term produced by the reflective tactic is linear in the size of the terms. By contrast, the proof term of the proof-producing tactic scales with the number of rewriting steps performed; essentially, it is no different from writing out all the rewriting steps explicitly. Note, however, that the size of a proof term is not always indicative of how efficiently it can be checked.
- As mentioned before, proof-producing tactics are more flexible, since they don't need a proof of their own correctness in order to be used.
- Meta code (running outside the kernel) is more efficient than evaluation within the kernel.
Conclusion
I have presented three different methods of writing tactics for solving equalities: (1) reflective tactics, (2) reflective tactics using native evaluation, and (3) proof producing tactics. (For brevity I will refer to these as P1, P2, and P3 respectively.)
I propose several metrics to evaluate these methods:
- Trusted code base, i.e., which systems we must trust to be correct in order to accept a proof: Both P1 and P3 rely only on the correctness of the Lean kernel, while P2 also relies on the entire Lean compiler and interpreter, a much larger code base. As noted here, however, you will most likely have to trust the Lean compiler and interpreter anyway, so by using P2 we primarily lose the ability to check our proofs with third-party checkers.
- Efficiency: Because P1 relies heavily on Lean's kernel reduction, it suffers in performance. In a comparison between P1- and P3-style tactics for ring, the P1 version was found to be 20-50% slower. As far as I know, no comparison of efficiency between P2- and P3-style tactics has been done.
- Expressiveness: As discussed previously, reflective tactics are restricted by the need to prove their normalization procedure correct. P3 tactics don't have this limitation and can perform arbitrary, unrestricted computation outside the kernel.
- Ease of implementation and maintenance: This is the most subjective metric, but one that I think is quite important. P1- and P2-style tactics limit the amount of metaprogramming code to the bare minimum. Writing a reification function is simpler and less error-prone than writing a P3 tactic, which has to use meta functions from the beginning. Limiting the scope of metaprogramming also makes updating and debugging far easier, in my opinion.
So, which method should you use?
I don't think there is a clear choice to make, but we can look to existing libraries and tactics to get a sense of what others have done.
The mathlib library includes several P3 tactics, such as abel and ring. On the other hand, P2 tactics are used in the leansat library.
In my opinion, if you are writing a tactic for use in your own projects, I would suggest writing a P1 tactic if possible, because I find it the simplest and least error-prone method.
Acknowledgements
Thanks to Scott Morrison, Henrik Böving, and Mario Carneiro for helpful discussions on this topic.
Appendix
See the following code snippets for details on how to write the monoid tactic.
[1] Technically, we only require that the equivalence relation induced by normA be finer than the equality relation (i.e., normA may map equivalent elements to different canonical elements). For instance, a trivial normalization function would map every element to itself. In such cases, we only obtain a semi-decision procedure for checking equality.

[2] Another meta helper function, which produces a normal form by performing the analogous version of list concatenation on NFs. See the code for details.