In programming contests, one of the most frustrating (and common!) verdicts to receive is Wrong Answer (WA). One of the reasons behind this is that there are often many potential sources of WA that a programmer has to hunt down. A WA could come from a typo in the implementation, or it could come from the use of a greedy algorithm which is not always correct, or many other reasons.

In this tutorial, I will try to give a brief overview of how you can *prove your programs correct* by writing specifications and proofs in the Coq proof assistant.
Coq is based on a dependent type theory that allows you to define your programs and write proofs about them.
Writing proof terms by hand can be quite laborious, so Coq also provides an interactive proving system based on *tactics*, which lets you automate away tedious parts of proofs and write proofs in a style that is much more natural for humans.
This very document is a (literate) Coq file, so you can download the source and follow along.
By the end of this tutorial, you should be able to compile this Coq file, extract it to OCaml, and submit your solution to an online judge!

## Some Disclaimers

There are, of course, several caveats I should point out. The first is that writing out a specification and proving an algorithm correct takes a significant amount of time, even for experienced users. It would be utterly impractical in a contest setting, where "proofs by AC" are quite common. My goal is to use contest problems as a vehicle to introduce formal verification in a small-scale, concrete way, although program verification is of course also used on much larger projects.

Secondly, we may be limited in the kinds of problems we can (easily) verify and write efficient solutions for. For instance, many contest problems involve mutable arrays, while when using Coq, we tend to prefer lists and other functional data structures because they are easier to reason about.

## A Simple Example

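For our example, we'll solve the Codeforces problem Red and Blue Beans.
We have `r` red beans and `b` blue beans and must decide whether they can be distributed into packets so that each packet contains at least one red and one blue bean, and the numbers of red and blue beans in each packet differ by at most `d`.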

This problem involves only basic arithmetic.
The answer is `YES` if and only if `b <= r * (d+1)` and `r <= b * (d+1)`.

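How can we prove this is the right condition?
Without loss of generality, we can assume that `r <= b`, since the problem is symmetric in the two colors. Under this assumption, `r <= b * (d+1)` holds automatically, so only `b <= r * (d+1)` matters.
If we wanted to maximize the number of blue beans for a given number of red beans, we would create `r` packets, each with `1` red bean and `d+1` blue beans.
Therefore, if we have more than `r * (d+1)` blue beans, we cannot distribute all the beans. Otherwise, since `r <= b <= r * (d+1)`, we can give each of the `r` packets one red bean and between `1` and `d+1` blue beans.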

Now, let's get to programming our solution!

```
Require Import Lia.
Require Import Bool.
Require Import List.
Require Import Arith.Arith.
Import ListNotations.
Require Extraction.

Definition can_distribute (r b d : nat) : bool :=
  (b <=? (r * (d + 1))) && (r <=? (b * (d + 1))).
```
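Before proving anything, it can be reassuring to run `can_distribute` on a few concrete inputs. The sketch below repeats the definition inside a module (the name `CanDistributeChecks` is only illustrative, so the block compiles on its own) and checks each answer with an `Example` proved by `reflexivity`, which works because `can_distribute` computes to a boolean on concrete numbers. Passing such checks is, of course, still not a proof of correctness.

```coq
Require Import Bool.
Require Import Arith.Arith.

(* Illustrative module: keeps these definitions apart from the top-level ones. *)
Module CanDistributeChecks.

  (* Same definition as can_distribute above. *)
  Definition can_distribute (r b d : nat) : bool :=
    (b <=? (r * (d + 1))) && (r <=? (b * (d + 1))).

  (* A single packet with one bean of each color. *)
  Example one_one : can_distribute 1 1 0 = true.
  Proof. reflexivity. Qed.

  (* 7 <= 2 * (3 + 1) = 8 and 2 <= 7 * (3 + 1). *)
  Example two_seven : can_distribute 2 7 3 = true.
  Proof. reflexivity. Qed.

  (* Boundary case: exactly 2 * (3 + 1) = 8 blue beans still works ... *)
  Example two_eight : can_distribute 2 8 3 = true.
  Proof. reflexivity. Qed.

  (* ... but 9 blue beans do not. *)
  Example two_nine : can_distribute 2 9 3 = false.
  Proof. reflexivity. Qed.

  (* 6 red beans, but 1 * (4 + 1) = 5 < 6. *)
  Example six_one : can_distribute 6 1 4 = false.
  Proof. reflexivity. Qed.

  (* With d = 0 every packet must be balanced, so 5 and 4 cannot work. *)
  Example five_four : can_distribute 5 4 0 = false.
  Proof. reflexivity. Qed.

End CanDistributeChecks.
```

The boundary pair `2 8 3` / `2 9 3` mirrors the informal argument: two red beans can support at most `2 * (3 + 1)` blue beans.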

That wasn't too bad!
If we were programming in another language like Java or C++, we would stop here and submit our solution.
However, even after passing hundreds of tests, we still can't be sure that it is correct for all inputs.
The only way to be certain of a program's correctness is to prove it, as we did informally above.
But how do we know our *proof* is correct?
People make mistakes in proofs all the time.
Luckily, Coq's type system is expressive enough to allow you to write propositions about your programs as types.
If we can find a term `t` that has type `T`, we say that `T` is *inhabited*. Interpreting `T` as a proposition, we say that `t` is *evidence* (a proof) that `T` is true.
This effectively means that we can use Coq's type-checker as a proof-checker!
This is known as the Curry-Howard correspondence.
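As a small illustration (the names `two_plus_two` and `and_swap` are ours, not part of the solution), here are two propositions proved by writing the evidence term directly, without any tactics. `eq_refl` has type `x = x`, and Coq accepts it as evidence for `2 + 2 = 4` because both sides compute to the same number. `and_swap` is an ordinary function that turns evidence for `P /\ Q` into evidence for `Q /\ P`.

```coq
(* eq_refl : 4 = 4, and 2 + 2 computes to 4, so this type-checks. *)
Definition two_plus_two : 2 + 2 = 4 := eq_refl.

(* A proof of an implication is a function on evidence:
   take apart the pair of proofs and rebuild it in the other order. *)
Definition and_swap (P Q : Prop) (H : P /\ Q) : Q /\ P :=
  match H with
  | conj p q => conj q p
  end.

(* The type of and_swap is the proposition it proves. *)
Check and_swap.
```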

Okay, here's a concrete example of writing a proposition and its proof in Coq.
When writing a specification for our original problem, we'll need a function `abs` that takes two natural numbers `a` and `b` and returns their absolute difference `|a-b|`.

```
Definition abs (a b : nat) := max (b-a) (a-b).
```

Note that since we are working with natural numbers, subtraction is truncated at zero: `3 - 5 = 0`, for instance.
So at most one of `b - a` and `a - b` is nonzero, and their maximum is exactly `|a-b|`.
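To see truncated subtraction and `abs` at work, the sketch below evaluates a few cases with `reflexivity`. It repeats the definition of `abs` inside an illustrative module `AbsChecks` so that it compiles on its own; the examples are ours, chosen to cover both argument orders.

```coq
Require Import Arith.Arith.

Module AbsChecks.

  (* Same definition as abs above. *)
  Definition abs (a b : nat) := max (b - a) (a - b).

  (* Subtraction on nat is truncated at zero. *)
  Example sub_trunc : 3 - 5 = 0.
  Proof. reflexivity. Qed.

  (* abs 3 5 = max (5 - 3) (3 - 5) = max 2 0 = 2 *)
  Example abs_3_5 : abs 3 5 = 2.
  Proof. reflexivity. Qed.

  (* abs 5 3 = max (3 - 5) (5 - 3) = max 0 2 = 2 *)
  Example abs_5_3 : abs 5 3 = 2.
  Proof. reflexivity. Qed.

End AbsChecks.
```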

To show that `abs` is defined correctly, we can prove that it satisfies the properties we want for all possible inputs.
For instance, we would expect `abs a b` to be the same as `abs b a`.
We can write this proposition in Coq like so:

```
Theorem abs_comm : forall a b : nat, abs a b = abs b a.
```

We can read this like a sentence in first-order logic:
"For all `a` and `b`, `abs a b` and `abs b a` are equal."

By writing `Proof.`, we enter interactive proving mode, where we prove our goal by entering a series of tactics. At this point, the goal is the statement itself:

```
Proof.
(* ============================
   forall a b : nat, abs a b = abs b a *)
```

First, we move `a` and `b` into our local context using the `intros` tactic.
This is like writing "Let `a` and `b` be arbitrary natural numbers." in a written proof.

```
intros a b.
(* a, b : nat
   ============================
   abs a b = abs b a *)
```

Next, we unfold the definition of `abs` using the `unfold` tactic.
Since `abs` is defined in terms of the `max` function, we can use properties of `max` to prove our goal.

```
unfold abs.
(* a, b : nat
   ============================
   Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a) *)
```

The theorem `Nat.max_comm` should be useful. We can check its type using the `Check` directive:

```
Check Nat.max_comm.
(* Nat.max_comm : forall n m : nat, Nat.max n m = Nat.max m n *)
```

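We can introduce a new hypothesis using `assert`, prove it, and then use it to help prove our original goal.
This is like introducing a lemma in a written proof.

```
assert (H : max (b - a) (a - b) = max (a - b) (b - a)).
(* Two goals remain. First, the assertion itself:
     a, b : nat
     ============================
     Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)
   Second, the original goal, now with H in the context:
     a, b : nat
     H : Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)
     ============================
     Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a) *)
```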

Our lemma follows directly from the `Nat.max_comm` theorem, so we can use the `apply` tactic:

```
apply Nat.max_comm.
(* The assertion is proved; the original goal, with H in the context, remains. *)
```

The most important thing about equalities is that if `x = y`, then `x` and `y` are interchangeable in every context.
We can use the `rewrite` tactic to change `max (b-a) (a-b)` in our goal to `max (a-b) (b-a)`:

```
rewrite H.
(* a, b : nat
   H : Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)
   ============================
   Init.Nat.max (a - b) (b - a) = Init.Nat.max (a - b) (b - a) *)
```

Now both sides of the equality are exactly the same, so we can discharge the goal using `reflexivity`:

```
reflexivity.
```

Use `Qed.` to declare victory:

```
Qed.
```

It's important to stress that tactics are only high-level directions that tell Coq how to construct the proof.
The proof term itself is often much longer and harder to understand, and it is this term that the type-checker checks to verify that you have proven your theorem. You can view the proof term for `abs_comm` using the `Print` directive (`Print abs_comm.`).
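For a theorem as small as `abs_comm`, we can also write the proof term ourselves. The sketch below (inside an illustrative module `AbsCommTerm` that repeats `abs` so it compiles on its own) gives the evidence directly as an application of `Nat.max_comm`. Coq accepts it because, once `abs` is unfolded, `abs a b = abs b a` is exactly the equation `Nat.max_comm (b - a) (a - b)` proves.

```coq
Require Import Arith.Arith.

Module AbsCommTerm.

  Definition abs (a b : nat) := max (b - a) (a - b).

  (* Nat.max_comm (b - a) (a - b) : max (b - a) (a - b) = max (a - b) (b - a),
     which is abs a b = abs b a after unfolding abs. *)
  Definition abs_comm (a b : nat) : abs a b = abs b a :=
    Nat.max_comm (b - a) (a - b).

End AbsCommTerm.
```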

Let's prove one more property about `abs`.
`abs a b` should be equal to the distance between `a` and `b`, so if `a < b`, then `a + abs a b = b`, and if `b < a`, then `b + abs a b = a`.

To prove this, we can use the `lia` tactic, which is a decision procedure for linear integer arithmetic.
As our original problem involves a lot of arithmetic, `lia` will frequently come in handy.

The statement reads:

```
forall a b : nat,
  (a < b -> a + abs a b = b) /\ (b < a -> b + abs a b = a)
```

After introducing `a` and `b` and unfolding `abs`, a single call to `lia` proves both conjuncts:

```
intros a b. unfold abs. lia.
Qed.
```
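To get a feel for what `lia` can do on its own, here are two standalone goals it closes; the example names are ours. The first involves truncated subtraction, and the second is the `b < a` half of the statement above with `abs` already unfolded into `max`. Each goal is stated as an `Example` and proved by introducing the hypotheses and calling `lia`.

```coq
Require Import Lia.
Require Import Arith.Arith.

(* lia understands truncated subtraction on nat ... *)
Example lia_sub : forall a b : nat, a < b -> a + (b - a) = b.
Proof. intros. lia. Qed.

(* ... and max, which is how abs is defined. *)
Example lia_max : forall a b : nat, b < a -> b + max (b - a) (a - b) = a.
Proof. intros. lia. Qed.
```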

## Specification Using Lists

Now let's return to our original problem.
We can write a proposition that defines whether there is a valid distribution of beans.
For any `r`, `b`, `d`, there exists a *correct distribution* of beans if and only if there exists a set of packets such that

- the sum of all the red beans is `r`,
- the sum of all the blue beans is `b`,
- each packet contains a positive number of red and blue beans, and
- in each packet, the numbers of red and blue beans differ by at most `d`.

The simplest way to represent a set of packets is a `list`.
We'll define a `packet` as a pair of natural numbers (red beans, blue beans), and we'll define the function `packet_sum`, which adds up the numbers of red and blue beans in a list of packets.

```
Module ListSpec.

Definition packet : Set := nat * nat.

Fixpoint packet_sum (l : list packet) :=
  match l with
  | [] => (0, 0)
  | (r, b) :: rest =>
      let (x, y) := packet_sum rest in (r + x, b + y)
  end.
```
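A quick way to check that `packet_sum` adds up what we expect is to evaluate it on small lists. The sketch below repeats `packet` and `packet_sum` inside an illustrative module `PacketSumChecks` so it compiles on its own; the two-packet list is the same one used later as a witness for `r = 2`, `b = 7`, `d = 3`.

```coq
Require Import List.
Import ListNotations.

Module PacketSumChecks.

  (* Same definitions as packet and packet_sum above. *)
  Definition packet : Set := nat * nat.

  Fixpoint packet_sum (l : list packet) :=
    match l with
    | [] => (0, 0)
    | (r, b) :: rest => let (x, y) := packet_sum rest in (r + x, b + y)
    end.

  (* No packets, no beans. *)
  Example sum_empty : packet_sum [] = (0, 0).
  Proof. reflexivity. Qed.

  (* (1, 4) and (1, 3) hold 2 red and 7 blue beans in total. *)
  Example sum_two : packet_sum [(1, 4); (1, 3)] = (2, 7).
  Proof. reflexivity. Qed.

End PacketSumChecks.
```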

Now we can define what a `correct_distribution` means, as described above. Here, we use the existential quantifier `exists`, meaning that to prove this proposition, we must supply a list that satisfies all the conditions.

```
Definition correct_distribution (r b d : nat) :=
  exists l, packet_sum l = (r, b) /\
    Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l.
```

For example, `correct_distribution 1 1 0` holds with the witness `[(1, 1)]`, and `correct_distribution 2 7 3` holds with the witness `[(1, 4); (1, 3)]`. In both proofs, we supply the witness with `exists` and finish with `repeat constructor; lia`.

Finally, we can relate our algorithm `can_distribute` to our specification `correct_distribution`.
Our algorithm should return `true` if and only if `correct_distribution` is provable, meaning that there exists a list of packets that meets the conditions defined in `correct_distribution`.

```
Definition algorithm_iff_correct_distribution :=
forall r b d,
can_distribute r b d = true <-> correct_distribution r b d.
```

If we can prove this theorem (meaning we can find a term which has this type), we know that our algorithm is correct. Moreover, other people do not even need to read our proof to trust that our algorithm is correct. They can simply read the specification and make sure the program typechecks.

### Some helpful lemmas

A few lemmas about arithmetic will come in handy.
We make heavy use of the `lia` tactic to avoid having to deal with most of these proofs by hand, but in some cases (multiplication by non-constants) we need to guide the solver.

The first lemma says that multiplying `d` by a positive `r` cannot make it smaller:

```
forall d r : nat, r > 0 -> d <= r * d
```

After introducing the variables, it is proved by `induction d; lia`.

The second lemma, which is applied as `d_bound` in the necessity proof below, says that the bound `y <= x * (d + 1)` is preserved when we add a packet with `r > 0` red beans and at most `r + d` blue beans:

```
forall r b d x y : nat,
  r > 0 -> b <= r + d -> y <= x * (d + 1) ->
  b + y <= (r + x) * (d + 1)
```

Its proof asserts the distributivity fact `(r + x) * (d + 1) = x * (d + 1) + r * (d + 1)`, rewrites the goal with it, reduces it to `r + d <= r * (d + 1)`, and finishes with `lia`, using the first lemma as an extra hypothesis.

The next lemma says that we can exchange the roles of `r` and `b`: if we can distribute `r` red beans and `b` blue beans, we must also be able to distribute `b` red beans and `r` blue beans, and vice versa. Although the proof is a bit complicated, the intuition is simple: we can flip the numbers of red and blue beans in each packet.

```
forall r b d : nat,
  correct_distribution r b d <-> correct_distribution b r d
```

The proof first establishes one direction, `correct_distribution r b d -> correct_distribution b r d`, for all `r`, `b`, `d`; both directions of the equivalence then follow from it (`split; apply H`). The one-directional proof is by induction on the witness list `l`. The empty list is its own witness. For a list `(r', b') :: l`, the induction hypothesis gives a swapped witness `x0` for the tail, and `(b', r') :: x0` is a witness for the whole list; the conditions on the new head packet follow from `abs r' b' <= d` by unfolding `abs` and calling `lia`.

Here is the crucial lemma proving that our algorithm's conditions are sufficient. Note that we assume `r <= b`, as we did in the informal proof.

```
forall r b d : nat,
  r <= b -> b <= r * (d + 1) -> correct_distribution r b d
```

The proof is by induction on `r`, generalizing over `b`. If `r = 0`, then `b <= 0 * (d + 1)` forces `b = 0`, and the empty list is a valid distribution. For `S r` red beans, let `t = min (b - r) (d + 1)`. Since `r <= b - t` and `b - t <= r * (d + 1)`, the induction hypothesis gives a list `l` that distributes `r` red and `b - t` blue beans, and `(1, t) :: l` then distributes `S r` red and `b` blue beans. The new packet has one red bean and `t` blue beans, where `t > 0` and `abs 1 t <= d` (checked by `unfold abs; lia`). The side condition `b - t <= r * (d + 1)` is proved by case analysis on `Nat.min_spec (b - r) (d + 1)` followed by `lia`.
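The induction step above builds the witness one packet at a time: with `S r` red beans, it emits the packet `(1, t)` with `t = min (b - r) (d + 1)` and recurses on `r` red and `b - t` blue beans. The sketch below turns that construction into a function. The module `BuildSketch` and the helper `build` are hypothetical names that do not appear in the original development; `packet` and `packet_sum` are repeated so the block compiles on its own. `build` only produces a correct distribution under the lemma's precondition `r <= b <= r * (d + 1)`.

```coq
Require Import List.
Require Import Arith.Arith.
Import ListNotations.

Module BuildSketch.

  Definition packet : Set := nat * nat.

  Fixpoint packet_sum (l : list packet) :=
    match l with
    | [] => (0, 0)
    | (r, b) :: rest => let (x, y) := packet_sum rest in (r + x, b + y)
    end.

  (* Hypothetical helper mirroring the induction step: with S r' red and
     b blue beans, emit (1, t) with t = min (b - r') (d + 1), then recurse. *)
  Fixpoint build (r b d : nat) : list packet :=
    match r with
    | 0 => []
    | S r' =>
        let t := Nat.min (b - r') (d + 1) in
        (1, t) :: build r' (b - t) d
    end.

  (* Reproduces the witness used for correct_distribution 2 7 3. *)
  Example build_2_7_3 : build 2 7 3 = [(1, 4); (1, 3)].
  Proof. reflexivity. Qed.

  Example build_2_7_3_sum : packet_sum (build 2 7 3) = (2, 7).
  Proof. reflexivity. Qed.

  (* Outside the precondition (5 > 1 * (0 + 1)), 4 blue beans are left over. *)
  Example build_1_5_0_sum : packet_sum (build 1 5 0) = (1, 1).
  Proof. reflexivity. Qed.

End BuildSketch.
```

The last example shows why the lemma needs `b <= r * (d + 1)`: without it, the packets run out before the blue beans do.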

Next, we'll prove that our algorithm's condition is necessary, meaning that if there is a list `l` showing that `correct_distribution r b d` holds, then `b <= r * (d + 1)` and `r <= b * (d + 1)`.

The proof proceeds by induction on the list `l`.

Next, we show that every valid distribution passes the algorithm's check: `forall r b d : nat, correct_distribution r b d -> can_distribute r b d = true`. After both definitions are unfolded, the hypothesis provides a list `l : list packet` with `packet_sum l = (r, b)` and `Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l`. The goal is `(b <=? r * (d + 1)) && (r <=? b * (d + 1)) = true`. We split the boolean conjunction, turn both halves into the arithmetic goal `b <= r * (d + 1) /\ r <= b * (d + 1)`, generalize `r` and `b`, and proceed by induction on `l`.

In the empty case, `packet_sum [] = (0, 0)` forces `r = b = 0`, and `lia` solves the goal. In the case `(r', b') :: l`, we name the sums of the tail `(x, y)`, so that `r = r' + x` and `b = b' + y`. Inverting the `Forall` hypothesis gives `abs r' b' <= d /\ r' > 0 /\ b' > 0` for the head packet, and the induction hypothesis gives `y <= x * (d + 1) /\ x <= y * (d + 1)` for the tail. Once `abs` is unfolded to `Init.Nat.max (b' - r') (r' - b') <= d`, the goal `b' + y <= (r' + x) * (d + 1) /\ r' + x <= (b' + y) * (d + 1)` is closed by `split; apply d_bound; lia.`
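The arithmetic behind the cons case can be checked by hand. As a sketch (not necessarily the exact statement of `d_bound`), consider a single packet $(x, y)$ with $x \ge 1$ and $\lvert x - y \rvert \le d$:

$$
\begin{aligned}
y &\le x + d && \text{since } \lvert x - y \rvert \le d \\
  &\le x + x d && \text{since } x \ge 1 \text{ implies } d \le x d \\
  &= x(d+1).
\end{aligned}
$$

By symmetry, $x \le y(d+1)$ as well. Summing over all packets, whose red and blue totals are $r$ and $b$, gives $b \le r(d+1)$ and $r \le b(d+1)$, which is exactly the condition that `can_distribute` checks.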

### The proof of correctness!

We can now prove `algorithm_iff_correct_distribution`, that is, `forall r b d : nat, can_distribute r b d = true <-> correct_distribution r b d`. For the forward direction, we rewrite the hypothesis `can_distribute r b d = true` into `b <= r * (d + 1) /\ r <= b * (d + 1)` and case on whether `r <= b` or `b <= r`. If `r <= b`, then `apply can_make_distr; auto.` closes the goal. If `b <= r`, we first swap the goal to `correct_distribution b r d` and then apply `can_make_distr` in the same way. The backward direction, from `correct_distribution r b d` to `can_distribute r b d = true`, is handled by `apply algorithm_condition_necessary; auto.`, which completes the proof and, with it, the `ListSpec` module.
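A proof is the real guarantee, but a quick brute-force comparison is a cheap way to catch a mis-stated specification before investing effort in one. The OCaml sketch below uses native `int` instead of `nat`; the helpers `spec_holds` and `first_mismatch` are hypothetical and not part of the Coq development. It enumerates every way of splitting `(r, b)` into packets and compares the result with the formula behind `can_distribute` for all `r`, `b`, `d` up to 6.

```ocaml
(* The check performed by can_distribute, on native OCaml ints. *)
let can_distribute r b d = b <= r * (d + 1) && r <= b * (d + 1)

(* Brute force over the specification: can (r, b) be split into packets
   (x, y) with x > 0, y > 0 and |x - y| <= d? Only usable for tiny inputs. *)
let rec spec_holds r b d =
  if r = 0 && b = 0 then true
  else begin
    let found = ref false in
    for x = 1 to r do
      for y = 1 to b do
        (* Try (x, y) as one packet, then split whatever is left. *)
        if (not !found) && abs (x - y) <= d && spec_holds (r - x) (b - y) d
        then found := true
      done
    done;
    !found
  end

(* First (r, b, d) with every component <= bound where the two disagree. *)
let first_mismatch bound =
  let result = ref None in
  for r = 0 to bound do
    for b = 0 to bound do
      for d = 0 to bound do
        if !result = None && can_distribute r b d <> spec_holds r b d then
          result := Some (r, b, d)
      done
    done
  done;
  !result

let () =
  match first_mismatch 6 with
  | None -> print_endline "no mismatch for r, b, d <= 6"
  | Some (r, b, d) -> Printf.printf "mismatch at r=%d b=%d d=%d\n" r b d
```

Passing such a test says nothing about larger inputs; that is precisely the gap the Coq proof closes.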

## Specification Using Inductive Relations

We've done it! We have defined our specification using lists and proven that our algorithm always computes the correct answer according to our specification.

Now that we know for certain that our program is correct, we can proceed to extracting it into OCaml.
Before we move on, however, I want to present another specification for this problem, this time defining `correct_distribution` as an inductive proposition.

While it may be natural to interpret the specification using lists, doing so makes the definitions more complicated, which affects the length and readability of the proofs. This alternative specification is equivalent (in fact, proving that the two specifications are equivalent is a good exercise), but it requires much less unfolding and destructing, which simplifies the proofs.

```
Module InductiveSpec.

Inductive correct_distribution : nat -> nat -> nat -> Prop :=
| no_packets : forall d, correct_distribution 0 0 d
| add_packet : forall r b r' b' d,
    correct_distribution r b d ->
    r' > 0 -> b' > 0 -> abs r' b' <= d ->
    correct_distribution (r' + r) (b' + b) d.

#[export] Hint Constructors correct_distribution : core.

Definition algorithm_iff_correct_distribution :=
  forall r b d, can_distribute r b d = true <-> correct_distribution r b d.
```

A proof of `correct_distribution r b d` is a derivation: `no_packets` says that zero beans can always be distributed, and `add_packet` says that adding one valid packet of `r'` red and `b'` blue beans to a valid distribution of `r` and `b` beans yields a valid distribution of `r' + r` and `b' + b` beans.

First, the specification is symmetric in the two colours: `forall r b d : nat, correct_distribution r b d <-> correct_distribution b r d`. We prove the forward implication as a hypothesis `H` by induction on the derivation. The `no_packets` case is solved by `auto`. In the `add_packet` case, applying the constructor to the goal `correct_distribution (b' + b) (r' + r) d` leaves `abs b' r' <= d`; once `abs` is unfolded, this is `Init.Nat.max (r' - b') (b' - r') <= d`, which `lia` derives from the hypothesis `Init.Nat.max (b' - r') (r' - b') <= d`. Both directions of the equivalence then follow from `split; apply H.`

Next, we show that the algorithm's condition suffices when there are at least as many blue beans as red ones: `forall r b d : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d`. The proof is by induction on `r`, generalizing `b`. For `r = 0`, the hypothesis `b <= 0 * (d + 1)` forces `b = 0` (by `lia`), and `auto` applies `no_packets`. For `S r`, we set `t = Init.Nat.min (b - r) (d + 1)`, specialize the induction hypothesis to `b - t`, and rewrite the goal using `b = t + (b - t)` and `S r = 1 + r` (both by `lia`). In other words, the new packet holds one red bean and `t` blue beans, and `constructor; unfold abs; try apply IHr; try lia.` discharges the remaining side conditions.

Finally, the main theorem. The forward direction is the same as in the list-based proof: case on `r <= b` or `b <= r` and `apply can_make_distr; auto.`, swapping the goal to `correct_distribution b r d` in the second case. For the backward direction, we reduce the goal to `b <= r * (d + 1) /\ r <= b * (d + 1)` and induct on the derivation. The `no_packets` case leaves `0 <= 0 * (d + 1) /\ 0 <= 0 * (d + 1)`, which `lia` solves. The `add_packet` case leaves `b' + b <= (r' + r) * (d + 1) /\ r' + r <= (b' + b) * (d + 1)` with the induction hypothesis `b <= r * (d + 1) /\ r <= b * (d + 1)`, and `split; apply ListSpec.d_bound; unfold abs in H2; try lia.` closes it, ending the `InductiveSpec` module.
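The lemma `forall r b d : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d` is constructive, so it can also be read as an algorithm for producing the packets. The OCaml below is a sketch of that reading, assuming packets are `(red, blue)` pairs of native ints: each step uses one red bean and `t = min (b - (r - 1)) (d + 1)` blue beans, and the case `b < r` is handled by swapping colours, as the symmetry lemma does. The names `build`, `distribute` and `valid` are illustrative and do not appear in the Coq development.

```ocaml
(* The check performed by can_distribute, on native OCaml ints. *)
let can_distribute r b d = b <= r * (d + 1) && r <= b * (d + 1)

(* Mirrors the constructive lemma; assumes r <= b <= r * (d + 1).
   Each step emits the packet (1, t) and recurses on (r - 1, b - t). *)
let rec build r b d =
  if r = 0 then []
  else
    let r' = r - 1 in
    let t = min (b - r') (d + 1) in
    (1, t) :: build r' (b - t) d

(* Handles b < r by swapping colours, like the symmetry lemma. *)
let distribute r b d =
  if r <= b then build r b d
  else List.map (fun (x, y) -> (y, x)) (build b r d)

(* Does a packet list satisfy the specification for (r, b, d)? *)
let valid r b d packets =
  List.for_all (fun (x, y) -> x > 0 && y > 0 && abs (x - y) <= d) packets
  && List.fold_left (fun (sr, sb) (x, y) -> (sr + x, sb + y)) (0, 0) packets
     = (r, b)

let () =
  List.iter (fun (x, y) -> Printf.printf "(%d, %d) " x y) (distribute 2 7 3);
  print_newline ()
```

For `r = 2`, `b = 7`, `d = 3` this prints `(1, 4) (1, 3)`: two packets, each within the allowed difference of 3.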

## Extraction

Now we can extract our verified algorithm to OCaml.
We don't extract our proofs, since they are not meant to be run.
So in this case, we only need to extract our function `can_distribute`, which is quite simple.
We can try running the extraction command now:

`Extraction "imp.ml" can_distribute.`

If you look in the file `imp.ml`, you will see the following datatype definitions at the top:

```
type bool =
| True
| False
type nat =
| O
| S of nat
```

Without any directions on how to perform the extraction, Coq will redefine all the datatypes that are used, including `bool` and `nat`.
This is a big problem for `nat`, since defining numbers Peano-style is quite inefficient: if you look at how addition is defined, adding two numbers takes time linear in the value of the first number.
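To see where that cost comes from, here is a sketch of Peano naturals in OCaml, in the shape of the extracted `type nat` above, with addition defined by recursion on its first argument as in Coq's `Nat.add`. The `calls` counter and the `of_int`/`to_int` helpers are added purely for illustration.

```ocaml
(* Unary naturals: the shape of Coq's nat and of the extracted type. *)
type nat = O | S of nat

(* Counts calls to add, only to make the cost visible. *)
let calls = ref 0

(* Addition by recursion on the first argument, like Coq's Nat.add:
   every S in n costs one more recursive call. *)
let rec add n m =
  incr calls;
  match n with
  | O -> m
  | S p -> S (add p m)

(* Conversions from and to OCaml ints, for experimenting. *)
let rec of_int k = if k <= 0 then O else S (of_int (k - 1))
let rec to_int = function O -> 0 | S p -> 1 + to_int p

let () =
  calls := 0;
  let sum = add (of_int 5) (of_int 1000) in
  Printf.printf "5 + 1000 = %d using %d calls\n" (to_int sum) !calls;
  calls := 0;
  let sum' = add (of_int 1000) (of_int 5) in
  Printf.printf "1000 + 5 = %d using %d calls\n" (to_int sum') !calls
```

The same sum takes 6 calls one way and 1001 the other, and with inputs as large as `10^9` the unary representation alone would be far too big to build.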
We can tell Coq to extract `nat` to `int` or `int64`, but this can be quite dangerous, because theorems about `nat` may no longer hold. For instance, `x + y >= x` is a theorem for `nat`, but it is not true for `int` or `int64`, since the addition may overflow.
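A minimal demonstration of that failure, using only OCaml's standard `Int64` module: adding 1 to `Int64.max_int` silently wraps around to `Int64.min_int`, so `x + y >= x` is false.

```ocaml
(* Int64 arithmetic wraps around modulo 2^64 instead of failing. *)
let x = Int64.max_int
let y = 1L
let sum = Int64.add x y

let () =
  Printf.printf "%Ld + %Ld = %Ld\n" x y sum;
  (* For Coq's nat this comparison is always true; here it is false. *)
  Printf.printf "x + y >= x: %b\n" (Int64.compare sum x >= 0)
```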

In this case, since we know that the inputs are at most `10^9`, we can determine that we will not run into overflow issues.
So, we can safely extract `nat` to `int64`.
If you would like, you can also extract to an arbitrary-precision integer type like `big_int`.
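The claim that overflow cannot happen can be checked with a short calculation. Assuming, per the problem constraints, that $r$, $b$ and $d$ are each at most $10^9$, the largest values the algorithm computes are the products $r(d+1)$ and $b(d+1)$:

$$
\begin{aligned}
r(d+1) \le 10^9 \cdot (10^9 + 1) &= 10^{18} + 10^9 \\
&< 2^{63} - 1 = 9\,223\,372\,036\,854\,775\,807,
\end{aligned}
$$

and the same bound holds for $b(d+1)$. Since $2^{63} - 1$ is `Int64.max_int`, every intermediate value fits in an `int64`.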

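```
Module ExtractionDefs.
  Module ExtrNatToInt64.
    (* Represent nat as OCaml's int64. The match function compares against
       0L, an int64 literal; a plain 0 would be an int and fail to typecheck. *)
    Extract Inductive nat => "int64"
      [ "Int64.zero" "(fun x -> Int64.succ x)" ]
      "(fun zero succ n -> if Int64.compare n 0L = 0 then zero () else succ (Int64.pred n))".
    (* Map the arithmetic used by can_distribute to Int64 operations. *)
    Extract Constant plus => "Int64.add".
    Extract Constant mult => "Int64.mul".
    Extract Constant leb => "(fun x y -> Int64.compare x y <= 0)".
  End ExtrNatToInt64.
  (* Use OCaml's native booleans instead of a redefined bool type. *)
  Extract Inductive bool => "bool" [ "true" "false" ].
End ExtractionDefs.

Extraction "imp.ml" can_distribute.
```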
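For reference, a hand-written OCaml function with the same logic as `can_distribute` under these directives (`nat` as `int64`, `plus` as `Int64.add`, `mult` as `Int64.mul`, `leb` as a signed comparison) might look like the sketch below. It illustrates the behaviour only; it is not the literal contents of the generated `imp.ml`, whose naming and layout may differ.

```ocaml
(* Signed comparison, matching the directive given for leb. *)
let leb x y = Int64.compare x y <= 0

(* Same condition as can_distribute: b <= r(d+1) and r <= b(d+1). *)
let can_distribute r b d =
  let d1 = Int64.add d 1L in
  leb b (Int64.mul r d1) && leb r (Int64.mul b d1)

let () =
  List.iter
    (fun (r, b, d) ->
      print_endline (if can_distribute r b d then "YES" else "NO"))
    [ (1L, 1L, 0L); (2L, 7L, 3L); (6L, 1L, 4L); (5L, 4L, 0L) ]
```

This prints `YES`, `YES`, `NO`, `NO`, one per line. Inputs at the `10^9` bound, such as `can_distribute 1L 1_000_000_000L 1_000_000_000L`, stay within `int64` range as computed above.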

And there we have it: a fully verified program that you can submit to Codeforces! See here for a version that includes the input/output plumbing.