# Verifying Programming Contest Problems

Mon 05 December 2022

In programming contests, one of the most frustrating (and common!) verdicts to receive is Wrong Answer (WA). Part of the reason is that a WA can have many possible sources that a programmer has to hunt down: a typo in the implementation, a greedy algorithm that is not always correct, and many others.

In this tutorial, I will give a brief overview of how you can prove your programs correct by writing specifications and proofs in the Coq proof assistant. Coq is based on a dependent type theory that allows you to define your programs and write proofs about them. Writing proofs by hand can be quite laborious, so Coq also has an interactive proving system based on tactics. Tactics let you automate away tedious parts of proofs and write proofs in a manner that is much more natural for us humans. This very document is a (literate) Coq file, so you can download the source and follow along. By the end of this tutorial, you should be able to compile this Coq file, extract it to OCaml, and submit your solution to an online judge!

## Some Disclaimers

There are, of course, several caveats I should point out. The first is that writing a specification and proving an algorithm correct takes a significant amount of time, even for experienced users. It would be utterly impractical in a contest setting, where "proofs by AC" are quite common. My goal is to use contest problems as a vehicle to introduce formal verification in a small-scale, concrete way, although program verification is of course also used on much larger projects.

Secondly, we may be limited in the kinds of problems we can (easily) verify and write efficient solutions for. For instance, many contest problems involve mutable arrays, whereas in Coq we tend to prefer lists and other functional data structures because they are easier to reason about.

## A Simple Example

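For our example, we'll solve the Codeforces problem Red and Blue Beans. We have r red beans and b blue beans, and we must decide whether all of them can be distributed into packets. Every packet must contain at least one red bean and at least one blue bean, and the numbers of red and blue beans in each packet must differ by at most d.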

This problem involves only basic arithmetic. The answer is YES if and only if b <= r * (d+1) and r <= b * (d+1).

How can we prove this is the right condition? Without loss of generality, we can assume that r <= b. Every packet needs at least one red bean, so there are at most r packets. A packet with k red beans can hold at most k + d blue beans, so all the packets together hold at most r + r * d = r * (d+1) blue beans. This bound is reached by creating r packets, each with 1 red bean and d+1 blue beans. Therefore, if we have more than r * (d+1) blue beans, we cannot distribute all the beans; otherwise, we can.
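To see the extreme case concretely, the following standalone Coq snippet builds the packing from the argument above for r = 2 and d = 3, then adds up its blue beans. The name extreme_packets is a hypothetical helper introduced only for this illustration. It is not part of the article's development.

```coq
Require Import List.
Import ListNotations.

(* Hypothetical helper: r packets, each with 1 red bean and d + 1 blue beans. *)
Definition extreme_packets (r d : nat) : list (nat * nat) :=
  repeat (1, d + 1) r.

Compute extreme_packets 2 3.
(* = [(1, 4); (1, 4)] *)

(* Total number of blue beans in that packing: 2 * (3 + 1) = 8. *)
Compute fold_right Nat.add 0 (map snd (extreme_packets 2 3)).
(* = 8 *)
```

Each of these packets has red and blue counts differing by exactly d, so it is valid. By the counting argument, no valid packing of 2 red beans can hold more than 8 blue beans when d = 3.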

Now, let's get to programming our solution!

Require Import Lia.
Require Import Bool.
Require Import List.
Require Import Arith.Arith.
Import ListNotations.
Require Extraction.

Definition can_distribute (r b d : nat) : bool :=
(b <=? (r * (d + 1))) && (r <=? (b * (d + 1))).
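 
Before moving on, we can sanity-check can_distribute on a few small inputs with Compute. This standalone sketch repeats the definition so that it compiles on its own. The inputs were picked by hand, and each expected result follows directly from the two inequalities.

```coq
Require Import Bool.
Require Import Arith.

(* Same definition as above, repeated so the snippet compiles on its own. *)
Definition can_distribute (r b d : nat) : bool :=
  (b <=? (r * (d + 1))) && (r <=? (b * (d + 1))).

Compute can_distribute 1 1 0.  (* = true *)
Compute can_distribute 2 7 3.  (* = true:  7 <= 2 * 4 and 2 <= 7 * 4 *)
Compute can_distribute 6 1 4.  (* = false: 6 > 1 * 5 *)
Compute can_distribute 5 4 0.  (* = false: 5 > 4 * 1 *)
```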

That wasn't too bad! If we were programming in another language like Java or C++, we would stop here and submit our solution. However, even after passing hundreds of tests, we still can't be sure that it is correct for all inputs. The only way to be certain of a program's correctness is to prove it, as we did informally above. But how do we know our proof is correct? People make mistakes in proofs all the time. Luckily, Coq's type system is expressive enough to let you write propositions about your programs as types. If we can find a term t that has type T, we say that T is inhabited. Interpreting T as a proposition, we say that t is evidence for T being true. This effectively means that we can use Coq's type-checker as a proof-checker! This is known as the Curry-Howard correspondence.
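Here is a tiny illustration of this idea, separate from the beans problem: a proof written directly as a term rather than with tactics. The name and_swap is just an illustrative choice. Its type is the proposition "A and B implies B and A", and the function that swaps the two components of a conjunction inhabits that type.

```coq
(* A proposition used as a type: the function below is its proof. *)
Definition and_swap : forall A B : Prop, A /\ B -> B /\ A :=
  fun A B (H : A /\ B) =>
    match H with
    | conj a b => conj b a   (* take the two pieces of evidence apart and swap them *)
    end.

Check and_swap.
(* prints the type: forall A B : Prop, A /\ B -> B /\ A *)
```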

Okay, here's a concrete example of writing a proposition and its proof in Coq. When writing a specification for our original problem, we'll need to define a function abs that takes two natural numbers a and b and returns the absolute value |a-b|.

Definition abs (a b : nat) := max (b-a) (a-b).

Note that since we are working with natural numbers, subtraction is capped at zero, so 3 - 5 = 0, for instance.
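A quick way to see truncated subtraction in action is to evaluate a few expressions. This standalone sketch repeats the definition of abs, written with Nat.max, which is what max refers to here. It shows that abs returns the distance regardless of argument order.

```coq
(* Natural-number subtraction truncates at zero. *)
Compute 3 - 5.   (* = 0 *)
Compute 5 - 3.   (* = 2 *)

(* Same abs as above: exactly one of the two differences can be nonzero. *)
Definition abs (a b : nat) := Nat.max (b - a) (a - b).

Compute abs 3 5. (* = 2 *)
Compute abs 5 3. (* = 2 *)
```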

To show that abs is defined correctly, we can prove that it satisfies the properties we want for all possible inputs. For instance, we would expect that abs a b is the same as abs b a. We can write this proposition in Coq like so:

Theorem abs_comm : forall (a b : nat), abs a b = abs b a.

We can read this like a sentence in first-order logic: "For all a b, (abs a b) and (abs b a) are equal."

By writing Proof., we enter into the interactive proving mode, where we can prove our goal by entering a series of tactics.

Proof.

First, we move a and b into our local context using the intros tactic. This is like writing "Let a and b be arbitrary natural numbers." in a written proof.

  intros a b.
  (* Context: a, b : nat.  Goal: abs a b = abs b a *)

Next, we unfold the definition of abs using the unfold tactic. Since abs is defined in terms of the max function, we can use properties about max to prove our goal.

  unfold abs.
  (* Goal: Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a) *)

The theorem Nat.max_comm should be useful. We can check its type using the Check directive:

  Check Nat.max_comm.
  (* Nat.max_comm : forall n m : nat, Nat.max n m = Nat.max m n *)

We can introduce a new hypothesis using assert, prove it, and then use it to help prove our original goal. This is like introducing a lemma in a written proof.

  assert (H : max (b-a) (a-b) = max (a-b) (b-a)).
  (* Two goals now: first the asserted equation itself, then the original goal
     with H : Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)
     added to the context. *)

Our lemma follows directly from the Nat.max_comm theorem, so we can use the apply tactic. The braces focus on the first goal (the assertion) until it is solved:

  { apply Nat.max_comm. }

The most important thing about equalities is that if x = y, then x and y are interchangeable in every context. We can use the rewrite tactic to change max (b-a) (a-b) in our goal to max (a-b) (b-a).

  rewrite H.
  (* Goal: Init.Nat.max (a - b) (b - a) = Init.Nat.max (a - b) (b - a) *)

Now, both sides of the equality are exactly the same, so we can discharge the goal using reflexivity.

  reflexivity.

Use Qed. to declare victory.

Qed.
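The assert above mainly demonstrates the tactic. The unfolded goal is already an instance of Nat.max_comm, so a shorter script should also work. Here is a standalone sketch; abs_comm' is a hypothetical name chosen to avoid clashing with abs_comm.

```coq
Require Import Arith.

Definition abs (a b : nat) := Nat.max (b - a) (a - b).

Theorem abs_comm' : forall (a b : nat), abs a b = abs b a.
Proof.
  intros a b.
  unfold abs.
  (* Goal: Nat.max (b - a) (a - b) = Nat.max (a - b) (b - a),
     which is Nat.max_comm with n := b - a and m := a - b. *)
  apply Nat.max_comm.
Qed.
```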

It's important to stress that tactics are only high-level directions that tell Coq how to create the proof. The proof term itself is often much longer and harder to understand, and it is what is checked by the typechecker to verify that you have proven your theorem. You can view the proof abs_comm using the Print directive.

Print abs_comm.
(* abs_comm =
   fun a b : nat =>
   (let H :
      Init.Nat.max (b - a) (a - b) =
      Init.Nat.max (a - b) (b - a) :=
      Nat.max_comm (b - a) (a - b) in
    eq_ind_r
      (fun n : nat => n = Init.Nat.max (a - b) (b - a))
      eq_refl H)
   : abs a b = abs b a
        : forall a b : nat, abs a b = abs b a

   Arguments abs_comm (a b)%nat_scope *)
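We can also skip tactics entirely and write a proof term for the same statement by hand. The following standalone sketch uses the illustrative name abs_comm_term. It relies on Coq comparing types up to unfolding of definitions: abs a b unfolds to Nat.max (b - a) (a - b), so Nat.max_comm (b - a) (a - b) already has the required type.

```coq
Require Import Arith.

Definition abs (a b : nat) := Nat.max (b - a) (a - b).

(* No tactics: the type checker unfolds abs on both sides and accepts
   Nat.max_comm (b - a) (a - b) as evidence for abs a b = abs b a. *)
Definition abs_comm_term : forall a b : nat, abs a b = abs b a :=
  fun a b => Nat.max_comm (b - a) (a - b).
```

This term is much shorter than the one Print displayed above, which shows that the term produced by a tactic script is not necessarily the simplest one.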

Let's prove one more property about abs. abs a b should be equal to the distance between a and b, so if a < b, then a + abs a b = b, and if b < a, then b + abs a b = a.

To prove this, we can use the lia tactic, which is a decision procedure for linear integer arithmetic. As our original problem involves a lot of arithmetic, lia will frequently come in handy.
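To get a feel for what lia can do, here are two small standalone lemmas that it closes in one step. Both involve truncated subtraction, and the second also involves Nat.max. The lemma names are illustrative.

```coq
Require Import Lia.

(* lia reasons about nat, including truncated subtraction... *)
Lemma sub_then_add : forall a b : nat, a < b -> a + (b - a) = b.
Proof. intros; lia. Qed.

(* ...and about max, which is what abs unfolds to. *)
Lemma max_of_diffs : forall a b : nat,
  a <= b -> Nat.max (b - a) (a - b) = b - a.
Proof. intros; lia. Qed.
```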

Theorem abs_correct : forall (a b : nat),
  (a < b -> a + abs a b = b) /\
  (b < a -> b + abs a b = a).
Proof.
  intros.
  unfold abs.
  lia.
Qed.

## Specification Using Lists

Now let's return to our original problem. We can write a proposition that defines whether there is a valid distribution of beans. For any r, b, d, there exists a correct distribution of beans if and only if there exists a set of packets such that

• the sum of all the red beans is r,
• the sum of all the blue beans is b,
• each packet contains a positive number of red and blue beans,
• for each packet, the number of red and blue beans should not differ by more than d.
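Before encoding these conditions as a proposition, it can help to see them as an executable check on one concrete list of packets. The helpers below (packet_ok, red_total, blue_total, check_packets) are hypothetical and not part of the article's development. A boolean check like this can test a candidate list, but it cannot replace the existential specification, which is about whether any suitable list exists.

```coq
Require Import Arith.
Require Import Bool.
Require Import List.
Import ListNotations.

Definition abs (a b : nat) := Nat.max (b - a) (a - b).

(* The per-packet conditions: difference at most d, both counts positive. *)
Definition packet_ok (d : nat) (p : nat * nat) : bool :=
  let '(x, y) := p in
  (abs x y <=? d) && (1 <=? x) && (1 <=? y).

Definition red_total (l : list (nat * nat)) : nat := fold_right Nat.add 0 (map fst l).
Definition blue_total (l : list (nat * nat)) : nat := fold_right Nat.add 0 (map snd l).

(* All four bullet points, checked for one given list l. *)
Definition check_packets (r b d : nat) (l : list (nat * nat)) : bool :=
  (red_total l =? r) && (blue_total l =? b) && forallb (packet_ok d) l.

Compute check_packets 2 7 3 [(1, 4); (1, 3)].  (* = true *)
Compute check_packets 2 7 3 [(1, 6); (1, 1)].  (* = false: |1 - 6| = 5 > 3 *)
```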

The simplest way to represent a set of packets is a list. We'll define a packet as a pair of natural numbers, and we'll define the function packet_sum which adds up the number of red and blue beans in a list of packets.

Module ListSpec.

Definition packet : Set := nat * nat.

Fixpoint packet_sum (l : list packet) := match l with
| [] => (0,0)
| (r,b) :: rest => let (x,y) := packet_sum rest in (r+x,b+y)
end.
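One way to convince ourselves that packet_sum computes what we intend is to relate it to two separate sums, one over the red counts and one over the blue counts. The following standalone sketch repeats packet and packet_sum and proves this by induction. sum_list is a hypothetical helper.

```coq
Require Import List.
Import ListNotations.

Definition packet : Set := nat * nat.

Fixpoint packet_sum (l : list packet) := match l with
| [] => (0,0)
| (r,b) :: rest => let (x,y) := packet_sum rest in (r+x,b+y)
end.

(* Hypothetical helper: the sum of a list of naturals. *)
Definition sum_list (l : list nat) : nat := fold_right Nat.add 0 l.

(* packet_sum computes the red total and the blue total in one pass. *)
Lemma packet_sum_split : forall l : list packet,
  packet_sum l = (sum_list (map fst l), sum_list (map snd l)).
Proof.
  induction l as [| [r b] rest IH].
  - reflexivity.
  - simpl. rewrite IH. reflexivity.
Qed.

Compute packet_sum [(1, 4); (1, 3)].  (* = (2, 7) *)
```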

Now we can define correct_distribution as described above. Here, we use the existential quantifier exists, meaning that to prove this proposition, we must supply a list that satisfies all the conditions.

Definition correct_distribution (r b d : nat) :=
exists l, packet_sum l = (r,b) /\ Forall (fun '(x,y) => abs x y <= d /\ x > 0 /\ y > 0) l.

Example ex1 : correct_distribution 1 1 0.
Proof.
  exists [(1,1)].
  split; unfold abs; auto.
  repeat constructor; lia.
Qed.

Example ex2 : correct_distribution 2 7 3.
Proof.
  exists [(1,4);(1,3)].
  split; unfold abs; auto.
  repeat constructor; lia.
Qed.

Finally, we can relate our algorithm can_distribute to our specification correct_distribution. Our algorithm should return true if and only if correct_distribution is provable, meaning that there exists a list of packets that meets the conditions defined in correct_distribution.

Definition algorithm_iff_correct_distribution :=
forall r b d,
can_distribute r b d = true <-> correct_distribution r b d.

If we can prove this theorem (meaning we can find a term which has this type), we know that our algorithm is correct. Moreover, other people do not even need to read our proof to trust that our algorithm is correct. They can simply read the specification and make sure the program typechecks.

Here are a few arithmetic lemmas that will come in handy. We make heavy use of lia to avoid doing most of these proofs by hand, but in some cases (multiplication by non-constants), we need to help guide the solver:

Lemma d_leq_mult_d : forall d r, r > 0 -> d <= r * d.
Proof.
  intros.
  induction d; lia.
Qed.
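Case analysis is another way to guide lia here. This standalone sketch is an alternative to the proof above, and d_leq_mult_d' is an illustrative name. It destructs r so that S r' * d simplifies to d + r' * d, after which the standard library lemma Nat.le_add_r closes the goal without any multiplication reasoning.

```coq
Require Import Arith.
Require Import Lia.

Lemma d_leq_mult_d' : forall d r, r > 0 -> d <= r * d.
Proof.
  intros d r H.
  destruct r as [| r'].
  - lia.                        (* r = 0 contradicts H : 0 > 0 *)
  - simpl.                      (* goal: d <= d + r' * d *)
    apply Nat.le_add_r.
Qed.
```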

Lemma d_bound : forall r b d x y,
  r > 0 ->
  b <= r + d ->
  y <= x * (d+1) ->
  b + y <= (r + x) * (d + 1).
Proof.
  intros.
  assert ((r+x)*(d+1) = x*(d+1) + r*(d+1)). { lia. }
  rewrite H2.
  assert (r+d <= r*(d+1)); try lia.
  pose proof (d_leq_mult_d d r). lia.
Qed.

(** This lemma says that we can exchange the uses of r and b, meaning that if we can distribute r red beans and b blue beans, we must also be able to distribute b red beans and r blue beans, and vice versa.
Although the proof is a bit complicated, the intuition is simple: we can flip the number of blue and red beans in each packet.
*)
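The computational core of this intuition can be sketched on its own. The helpers flip, sum_red, and sum_blue below are hypothetical; the article's proof works with packet_sum directly. Flipping every packet turns the red total into the old blue total and vice versa. Each packet's constraint is unchanged because abs x y = abs y x.

```coq
Require Import List.
Import ListNotations.

(* Swap the red and blue counts in every packet. *)
Fixpoint flip (l : list (nat * nat)) : list (nat * nat) :=
  match l with
  | [] => []
  | (x, y) :: rest => (y, x) :: flip rest
  end.

Fixpoint sum_red (l : list (nat * nat)) : nat :=
  match l with
  | [] => 0
  | (x, _) :: rest => x + sum_red rest
  end.

Fixpoint sum_blue (l : list (nat * nat)) : nat :=
  match l with
  | [] => 0
  | (_, y) :: rest => y + sum_blue rest
  end.

(* After flipping, the red total is the old blue total... *)
Lemma sum_red_flip : forall l, sum_red (flip l) = sum_blue l.
Proof.
  induction l as [| [x y] rest IH]; simpl.
  - reflexivity.
  - rewrite IH. reflexivity.
Qed.

(* ...and the blue total is the old red total. *)
Lemma sum_blue_flip : forall l, sum_blue (flip l) = sum_red l.
Proof.
  induction l as [| [x y] rest IH]; simpl.
  - reflexivity.
  - rewrite IH. reflexivity.
Qed.

Compute flip [(1, 4); (1, 3)].  (* = [(4, 1); (3, 1)] *)
```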

Lemma distribution_flip : forall r b d,
  correct_distribution r b d <-> correct_distribution b r d.
Proof.
  assert (H:forall r b d, correct_distribution r b d -> correct_distribution b r d). {
    unfold correct_distribution.
    intros r b d [l [H2 H3]].
    generalize dependent r.
    generalize dependent b.
    induction l; intros.
    + simpl in H2. inversion H2; subst. exists []; auto.
    + destruct a as [r' b']. simpl in H2.
      remember (packet_sum l) as X. destruct X as [x y].
      inversion H2; subst.
      specialize IHl with y x.
      destruct IHl; auto. inversion H3; auto.
      destruct H.
      exists ((b',r')::x0).
      split. simpl. rewrite H. auto.
      constructor; auto. inversion H3; subst. destruct H5 as [Ha [Hr Hb]]. split; auto; unfold abs in *; try lia.
  }
  split; apply H.
Qed.

Here is the crucial lemma proving that our algorithm's condition is sufficient. Note that we assume that r <= b, as we did in the informal proof.

Lemma can_make_distr : forall r b d,
  r <= b ->
  b <= r * (d+1) ->
  correct_distribution r b d.
Proof.
  intros r b d r_leq_b H.
  generalize dependent b.
  induction r; intros.
  - assert (b = 0). lia. subst. exists []. auto.
  - remember (min (b - r) (d + 1)) as t.
    specialize IHr with (b-t).
    assert (Hr : r <= b - t). { lia. } apply IHr in Hr.
    destruct Hr as [l [Hl1 Hl2]].
    exists ((1,t) :: l); split.
    + simpl. rewrite Hl1.
      assert (Htb: t + (b - t) = b). { lia. } rewrite Htb.
      reflexivity.
    + constructor; try auto; split; try lia. unfold abs; lia.
    + simpl in H.
      pose proof (Nat.min_spec (b-r) (d+1)) as [[Hmin1 minEq] | [Hmin2 minEq]]; rewrite minEq in Heqt; lia.
Qed.

Next, we'll prove that our algorithm's condition is necessary: if some list l witnesses that correct_distribution r b d holds, then b <= r * (d + 1) and r <= b * (d + 1).

The proof proceeds by induction on the list l, after generalizing over r and b so that the induction hypothesis applies to the totals of the tail of the list.
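Written out by hand, the arithmetic the inductive step has to establish is short. Suppose the head packet (r', b') has r', b' > 0 and |r' - b'| <= d. Suppose also that the induction hypothesis gives y <= x(d + 1) and x <= y(d + 1) for the tail's totals (x, y). Because r' >= 1, we have d <= r'd, so

$$
\begin{aligned}
b' &\le r' + d \le r' + r'd = r'(d+1), \\
b' + y &\le r'(d+1) + x(d+1) = (r' + x)(d+1).
\end{aligned}
$$

The other bound, r' + x <= (b' + y)(d + 1), follows symmetrically from b' >= 1. The base case is just 0 <= 0 · (d + 1).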

Theorem algorithm_condition_necessary : forall r b d,
  correct_distribution r b d -> can_distribute r b d = true.
Proof.
  unfold correct_distribution, can_distribute.
  intros.
  destruct H as [l [H1 H2]].
  unfold can_distribute.
  rewrite andb_true_iff.
  repeat rewrite Nat.leb_le.
  generalize dependent b.
  generalize dependent r.
  induction l; intros.
  + simpl in H1. inversion H1; subst. lia.
  + destruct a as [r' b'].
    simpl in H1.
    remember (packet_sum l) as X.
    destruct X as [x y].
    inversion H1; subst; clear H1.
    inversion H2; subst.
    apply IHl with x y in H3; auto.
    unfold abs in H1.
    split; apply d_bound; lia.
Qed.

### The proof of correctness!

Theorem algorithm_correct : algorithm_iff_correct_distribution.
Proof.
  unfold algorithm_iff_correct_distribution.
  intros.
  split; intros.
  - unfold can_distribute in H.
    rewrite andb_true_iff in H.
    repeat rewrite Nat.leb_le in H.
    destruct H as [H1 H2].
    (* Case on whether r <= b or b <= r. *)
    assert (r <= b \/ b <= r) as [Hrb | Hbr]. lia.
    + apply can_make_distr; auto.
    + apply distribution_flip. apply can_make_distr; auto.
  - apply algorithm_condition_necessary; auto.
Qed.
End ListSpec.

## Specification Using Inductive Relations

We've done it! We have defined our specification using lists and proven that our algorithm always computes the correct answer according to our specification.

Now that we know our program is correct with respect to our specification, we can move on to extracting it to OCaml. Before we do, however, I want to present another specification for this problem. This time, correct_distribution is defined as an inductive proposition.

A list-based specification may be the most natural reading of the problem, but it makes the definitions more complicated, and that makes the proofs longer and harder to read. This alternative specification is equivalent (in fact, proving that the two specifications are equivalent is a good exercise). It requires much less unfolding and destructing, which simplifies the proofs.

Module InductiveSpec.
Inductive correct_distribution : nat -> nat -> nat -> Prop :=
| no_packets : forall d,
correct_distribution 0 0 d
| add_packet : forall r b r' b' d,
correct_distribution r b d ->
r' > 0 ->
b' > 0 ->
abs r' b' <= d ->
correct_distribution (r'+r) (b'+b) d.
#[export] Hint Constructors correct_distribution : core.
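To get a feel for what the inductive definition says, here is a hypothetical OCaml sketch that decides it by brute force on small inputs. It is not part of the Coq development.

- The base case mirrors no_packets.
- The search over a last packet (r', b') mirrors add_packet read backwards.

It takes exponential time, so it is only useful as a sanity check on tiny values. There it can be compared against the closed-form condition computed by can_distribute. The names correct_distribution_bf and the OCaml copy of can_distribute are illustrative assumptions.

```ocaml
(* Hypothetical brute-force decider for the inductive correct_distribution.
   Exponential time: for illustration on tiny inputs only. *)
let rec correct_distribution_bf r b d =
  (* no_packets: zero beans of each color is always a valid distribution. *)
  if r = 0 && b = 0 then true
  else begin
    (* add_packet read backwards: try every packet (r', b') with r', b' > 0
       and |r' - b'| <= d, then check the remaining beans recursively. *)
    let found = ref false in
    for r' = 1 to r do
      for b' = 1 to b do
        if (not !found) && abs (r' - b') <= d
           && correct_distribution_bf (r - r') (b - b') d
        then found := true
      done
    done;
    !found
  end

(* Hand-written copy of the algorithm's condition, on OCaml ints. *)
let can_distribute r b d = b <= r * (d + 1) && r <= b * (d + 1)

(* Compare the two on every small input; an assertion fails on mismatch. *)
let () =
  for r = 0 to 6 do
    for b = 0 to 6 do
      for d = 0 to 3 do
        assert (correct_distribution_bf r b d = can_distribute r b d)
      done
    done
  done
```

A check like this proves nothing beyond the inputs it tries. It is the kind of evidence the Coq theorem below replaces with a proof for all r, b and d.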

Definition algorithm_iff_correct_distribution :=
forall r b d,
can_distribute r b d = true <-> correct_distribution r b d.

Lemma distribution_flip : forall r b d,
  correct_distribution r b d <-> correct_distribution b r d.
Proof.
  assert (H: forall r b d, correct_distribution r b d -> correct_distribution b r d). {
    intros r b d H.
    induction H.
    - auto.
    - constructor; auto. unfold abs in *. lia.
  }
  split; apply H.
Qed.

Lemma can_make_distr : forall r b d,
  r <= b ->
  b <= r * (d+1) ->
  correct_distribution r b d.
Proof.
  intros r b d r_leq_b H.
  generalize dependent b.
  induction r; intros.
  - assert (b = 0). lia. subst. auto.
  - remember (min (b - r) (d + 1)) as t.
    specialize IHr with (b-t).
    assert (Hb: b = t+(b-t)). { lia. } rewrite Hb.
    assert (Hr: S r = 1 + r). { lia. } rewrite Hr.
    constructor; unfold abs; try apply IHr; try lia.
Qed.
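The proof of can_make_distr is constructive. Each step adds a packet with one red bean and t blue beans, then recurses on the remaining beans. A hypothetical OCaml sketch that builds the same packets explicitly looks like this; make_packets, valid_packet and packet_sum are illustrative names.

The proof's t = min (b - r) (d + 1) is stated for a goal with S r red beans. So when the current number of red beans is r, the sketch uses min (b - (r - 1)) (d + 1). It assumes 0 <= r <= b <= r * (d + 1), the same hypotheses as the lemma.

```ocaml
(* Hypothetical sketch: build the packets that can_make_distr's proof adds.
   Assumes 0 <= r <= b <= r * (d + 1). Each step puts one red bean with
   t blue beans, where t = min (b - (r - 1)) (d + 1). *)
let rec make_packets r b d =
  if r = 0 then []  (* the assumptions force b = 0 here *)
  else
    let t = min (b - (r - 1)) (d + 1) in
    (1, t) :: make_packets (r - 1) (b - t) d

(* Helpers for checking a result against the specification. *)
let valid_packet d (x, y) = x > 0 && y > 0 && abs (x - y) <= d

let packet_sum ps =
  List.fold_left (fun (sr, sb) (x, y) -> (sr + x, sb + y)) (0, 0) ps

let () =
  make_packets 2 7 3
  |> List.iter (fun (x, y) -> Printf.printf "(%d, %d) " x y);
  print_newline ()
(* prints: (1, 4) (1, 3) *)
```

Giving every packet exactly one red bean is what makes the recursion simple. The assumption r <= b guarantees t >= 1, and t <= d + 1 keeps each packet within the allowed difference.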

Theorem algorithm_correct : algorithm_iff_correct_distribution.
Proof.
  unfold algorithm_iff_correct_distribution.
  split; intros.
  - unfold can_distribute in H.
    rewrite andb_true_iff in H.
    repeat rewrite Nat.leb_le in H.
    destruct H as [H1 H2].
    (* Case on whether r <= b or b <= r. *)
    assert (r <= b \/ b <= r) as [Hrb | Hbr]. lia.
    + apply can_make_distr; auto.
    + apply distribution_flip. apply can_make_distr; auto.
  - unfold can_distribute.
    rewrite andb_true_iff.
    repeat rewrite Nat.leb_le.
    induction H.
    + lia.
    + split; apply ListSpec.d_bound; unfold abs in H2; try lia.
Qed.

End InductiveSpec.

## Extraction

Now we can extract our verified algorithm to OCaml. We don't extract our proofs, since they are not meant to be run. So in this case, we only need to extract the function can_distribute, which is quite simple. We can try running the extraction command now:

Extraction "imp.ml" can_distribute.

If you look in the file imp.ml, you will see the following datatype definitions at the top:

type bool =
| True
| False

type nat =
| O
| S of nat


Without any directions on how to perform the extraction, Coq will generate its own definitions for all the datatypes that are used, including booleans and nat. This is a big problem for nat, because Peano-style numbers are quite inefficient. If you look at how addition is defined, it recurses on its first argument, so adding two numbers takes time linear in the value of the first number. We can tell Coq to extract nat to int or int64 instead, but this can be quite dangerous, because theorems about nat may no longer hold. For instance, x + y >= x is a theorem for nat, but it does not hold for int or int64, since the addition may overflow.
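Both problems can be seen in a few lines of plain OCaml. The sketch below is hand-written, not Coq's actual extraction output. It defines Peano naturals shaped like the extracted type, with an addition that recurses on its first argument like Coq's Nat.add. It then shows int64 addition wrapping around. The helper names add, of_int and to_int are illustrative.

```ocaml
(* Peano naturals, shaped like the type Coq generates when extracting nat. *)
type nat = O | S of nat

(* Addition that recurses on its first argument, like Coq's Nat.add:
   one recursive call per S in n, so time is linear in the value of n. *)
let rec add n m =
  match n with
  | O -> m
  | S p -> S (add p m)

let rec of_int k = if k = 0 then O else S (of_int (k - 1))
let rec to_int n = match n with O -> 0 | S p -> 1 + to_int p

(* Int64 arithmetic wraps around on overflow, so x + y >= x can fail. *)
let x = Int64.max_int
let y = 1L

let () =
  Printf.printf "2 + 3 = %d\n" (to_int (add (of_int 2) (of_int 3)));
  Printf.printf "x + y >= x? %b\n" (Int64.compare (Int64.add x y) x >= 0)
(* prints: 2 + 3 = 5
           x + y >= x? false *)
```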

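In this case, we know that the inputs are less than 10^9, so we will not run into overflow. The largest values can_distribute computes are the products r * (d + 1) and b * (d + 1). These stay below 10^18, well under the int64 maximum of 2^63 - 1 (about 9.2 × 10^18). So we can safely extract nat to int64. If you would like, you can also extract to an arbitrary-precision integer type like big_int.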
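For comparison, here is the same boolean test written by hand over int64, together with a check of the product bound. This is a sketch for illustration only. It is not what Coq's extraction produces, and product_bound is an illustrative name.

```ocaml
(* Hand-written int64 version of can_distribute's formula
   (b <= r * (d + 1)) && (r <= b * (d + 1)); a sketch, not Coq's output. *)
let can_distribute r b d =
  let d1 = Int64.add d 1L in
  Int64.compare b (Int64.mul r d1) <= 0
  && Int64.compare r (Int64.mul b d1) <= 0

(* With every input below 10^9, each product is below 10^9 * 10^9 = 10^18. *)
let product_bound = Int64.mul 1_000_000_000L 1_000_000_000L

let () =
  Printf.printf "bound fits: %b\n" (Int64.compare product_bound Int64.max_int < 0);
  Printf.printf "%b %b\n" (can_distribute 2L 7L 3L) (can_distribute 6L 1L 4L)
(* prints: bound fits: true
           true false *)
```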

Module ExtractionDefs.
Module ExtrNatToInt64.
Extract Inductive nat => "int64"
[ "Int64.zero" "(fun x -> Int64.succ x)" ]
"(fun zero succ n -> if Int64.compare n Int64.zero = 0 then zero () else succ (Int64.pred n))".
Extraction "imp.ml" can_distribute.
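The last string in the Extract Inductive command is the function that extraction uses in place of a pattern match on nat. It receives a function taking unit for the O branch, a function for the S branch (applied to the predecessor), and the number being matched. Note that the comparison must be against an int64 value such as Int64.zero; a bare 0 is an int and would not type-check.

Below is that function as a standalone OCaml definition. The name nat_case is hypothetical. It is used to write a Peano-style recursion over int64, roughly the shape of code that extraction emits when a function matches on nat.

```ocaml
(* The match function from the Extract Inductive directive, as a named
   definition (the name nat_case is hypothetical). *)
let nat_case zero succ n =
  if Int64.compare n Int64.zero = 0 then zero () else succ (Int64.pred n)

(* A Peano-style function written against nat_case: is n even?
   Like a match on Peano nat, it takes one step per unit of n. *)
let rec even n =
  nat_case (fun () -> true) (fun p -> not (even p)) n

let () = Printf.printf "%b %b\n" (even 10L) (even 7L)
(* prints: true false *)
```

The comment on even points at the performance catch: a match on nat still peels off one unit at a time, even when the numbers are stored as int64.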