Verifying Programming Contest Problems

Timothy Mou - Mon 05 December 2022

In programming contests, one of the most frustrating (and common!) verdicts to receive is Wrong Answer (WA). Part of the reason is that a WA can have many potential sources that a programmer has to hunt down: a typo in the implementation, a greedy algorithm that is not always correct, and many others.

In this tutorial, I will try to give a brief overview of how you can prove your programs correct by writing specifications and proofs in the Coq proof assistant. Coq is based on a dependent type theory that allows you to define your programs and write proofs about them. Writing proofs by hand can be quite laborious, so Coq also provides an interactive proof mode based on tactics. Tactics let you automate away the tedious parts of a proof and write proofs in a style that is much more natural for us humans. This very document is a (literate) Coq file, so you can download the source and follow along. By the end of this tutorial, you should be able to compile this Coq file, extract it to OCaml, and submit your solution to an online judge!

Some Disclaimers

There are, of course, several caveats I should point out. The first is that writing out a specification and proving an algorithm correct takes a significant amount of time, even for experienced users. This would be utterly impractical in a contest setting, where "proofs by AC" are quite common. My goal is to use contest problems as a vehicle to introduce formal verification in a small-scale, concrete way. Program verification is, of course, also used on much larger projects.

Secondly, we may be limited in the kinds of problems we can (easily) verify and write efficient solutions for. For instance, many contest problems involve mutable arrays, while when using Coq, we tend to prefer lists and other functional data structures because they are easier to reason about.

A Simple Example

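For our example, we'll solve the Codeforces problem Red and Blue Beans. We have r red beans and b blue beans, and we must distribute all of them into packets. Each packet must contain at least one red bean and at least one blue bean, and within each packet the numbers of red and blue beans may differ by at most d. The answer is YES if such a distribution exists and NO otherwise.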

The problem needs only basic arithmetic. The answer is YES if and only if b <= r * (d+1) and r <= b * (d+1).

How can we prove this is the right condition? Without loss of generality, we can assume that r <= b. To maximize the number of blue beans for a given number of red beans, we would create r packets, each with 1 red bean and d+1 blue beans. Therefore, if we have more than r * (d+1) blue beans, we cannot distribute all the beans. Otherwise, we can. Since r <= b <= r * (d+1), we can make r packets with one red bean each and split the b blue beans among them so that every packet gets between 1 and d+1 blue beans.
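The impossibility half of this argument can be written out as a short calculation. Suppose a valid distribution uses k packets, and packet i holds r_i red and b_i blue beans. Each packet has r_i >= 1, so k <= r. The difference constraint gives b_i <= r_i + d. Therefore:

$$
b = \sum_{i=1}^{k} b_i \;\le\; \sum_{i=1}^{k} (r_i + d) \;=\; r + kd \;\le\; r + rd \;=\; r(d+1).
$$

Exchanging the roles of red and blue gives r <= b(d+1) in the same way.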

Now, let's get to programming our solution!

Require Import Lia.
Require Import Bool.
Require Import List.
Require Import Arith.Arith.
Import ListNotations.
Require Extraction.

Definition can_distribute (r b d : nat) : bool :=
  (b <=? (r * (d + 1))) && (r <=? (b * (d + 1))).
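Testing can't replace a proof, but a few quick checks are still worth running before we start. The sketch below evaluates can_distribute on some small inputs, including the two instances used later in this tutorial. Because the function returns a bool, reflexivity simply runs it and compares the result. The module and example names are hypothetical, and the definition is repeated so that the snippet compiles on its own.

```coq
Require Import Bool Arith.

Module CanDistributeTests.
  (* Same definition as above, repeated so this snippet stands alone. *)
  Definition can_distribute (r b d : nat) : bool :=
    (b <=? (r * (d + 1))) && (r <=? (b * (d + 1))).

  (* reflexivity evaluates the boolean and compares it with the expected value. *)
  Example yes_1_1_0 : can_distribute 1 1 0 = true.  Proof. reflexivity. Qed.
  Example yes_2_7_3 : can_distribute 2 7 3 = true.  Proof. reflexivity. Qed.
  Example no_6_1_4  : can_distribute 6 1 4 = false. Proof. reflexivity. Qed.
  Example no_5_4_0  : can_distribute 5 4 0 = false. Proof. reflexivity. Qed.
End CanDistributeTests.
```

Each false case fails exactly one inequality: 6 > 1 * (4 + 1) and 5 > 4 * (0 + 1).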

That wasn't too bad! If we were programming in another language like Java or C++, we would stop here and submit our solution. However, even after passing hundreds of tests, we still can't be sure that it is correct for all inputs. The only way to be certain of a program's correctness is to prove it, as we did above. But how do we know our proof is correct? People make mistakes in proofs all the time. Luckily, Coq's type system is expressive enough to let you write propositions about your programs as types. If we can find a term t that has type T, we say that T is inhabited. Interpreting T as a proposition, we say that t is evidence that T is true. This means that we can use Coq's type-checker as a proof-checker! This is known as the Curry-Howard correspondence.

Okay, here's a concrete example of writing a proposition and its proof in Coq. When writing a specification for our original problem, we'll need to define a function abs that takes two natural numbers a and b and returns the absolute value |a-b|.

Definition abs (a b : nat) := max (b-a) (a-b).

Note that since we are working with natural numbers, subtraction is capped at zero, so 3 - 5 = 0, for instance.
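To see the truncation in action, here is a small sketch that checks a few values by computation. At least one of b - a and a - b is always 0, so taking their max recovers the true distance. The module name TruncatedSub is hypothetical.

```coq
Require Import Arith.

Module TruncatedSub.
  Definition abs (a b : nat) := max (b - a) (a - b).

  (* Subtraction on nat stops at zero instead of going negative. *)
  Example sub_truncates : 3 - 5 = 0. Proof. reflexivity. Qed.
  Example sub_ordinary  : 5 - 3 = 2. Proof. reflexivity. Qed.

  (* At least one of the two differences is 0, so max picks the real gap. *)
  Example abs_3_5 : abs 3 5 = 2. Proof. reflexivity. Qed.
  Example abs_5_3 : abs 5 3 = 2. Proof. reflexivity. Qed.
  Example abs_4_4 : abs 4 4 = 0. Proof. reflexivity. Qed.
End TruncatedSub.
```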

To show that abs is defined correctly, we can prove that it satisfies the properties we want for all possible inputs. For instance, we would expect that abs a b is the same as abs b a. We can write this proposition in Coq like so:


forall a b : nat, abs a b = abs b a

We can read this like a sentence in first-order logic: "For all a and b, (abs a b) and (abs b a) are equal."

By writing Proof., we enter into the interactive proving mode, where we can prove our goal by entering a series of tactics.


forall a b : nat, abs a b = abs b a

First, we move a and b into our local context using the intros tactic. This is like writing "Let a and b be arbitrary natural numbers." in a written proof.

  
a, b: nat

abs a b = abs b a

Next, we unfold the definition of abs using the unfold tactic. Since abs is defined in terms of the max function, we can use properties about max to prove our goal.

  
a, b: nat

Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)

The theorem Nat.max_comm should be useful. We can check its type using the Check directive:

  
Nat.max_comm : forall n m : nat, Nat.max n m = Nat.max m n
a, b: nat

Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)

We can introduce a new hypothesis using assert, prove it, and then use it to help prove our original goal. This is like introducing a lemma in a written proof.

  
a, b: nat

Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)
a, b: nat
H: Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)
Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)

Our lemma follows directly from the Nat.max_comm theorem, so we can use the apply tactic:

  
a, b: nat

Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)
apply Nat.max_comm.
a, b: nat
H: Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)

Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)

The most important thing about equalities is that if x = y, then x and y are interchangeable in every context. We can use the rewrite tactic with our new hypothesis H to change max (b-a) (a-b) in our goal to max (a-b) (b-a).

  
a, b: nat
H: Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a)

Init.Nat.max (a - b) (b - a) = Init.Nat.max (a - b) (b - a)

Now, both sides of the equality are exactly the same, so we can discharge the goal using reflexivity.

  reflexivity.

Use Qed. to declare victory.

Qed.
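Put together, the steps described above form a short script. The sketch below is assembled from that description. The exact form of the asserted hypothesis (for example, max versus Nat.max) is an assumption. The module name AbsCommScript is hypothetical, and abs is repeated so that the snippet compiles on its own.

```coq
Require Import Arith.

Module AbsCommScript.
  Definition abs (a b : nat) := max (b - a) (a - b).

  Theorem abs_comm : forall a b : nat, abs a b = abs b a.
  Proof.
    intros a b.                 (* let a and b be arbitrary *)
    unfold abs.                 (* expose the two max terms *)
    assert (H : max (b - a) (a - b) = max (a - b) (b - a)).
    { apply Nat.max_comm. }     (* the lemma is an instance of max_comm *)
    rewrite H.                  (* replace the left-hand side using H *)
    reflexivity.                (* both sides are now syntactically equal *)
  Qed.
End AbsCommScript.
```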

It's important to stress that tactics are only high-level directions that tell Coq how to create the proof. The proof term itself is often much longer and harder to understand, and it is what is checked by the typechecker to verify that you have proven your theorem. You can view the proof abs_comm using the Print directive.

abs_comm =
fun a b : nat =>
(let H : Init.Nat.max (b - a) (a - b) = Init.Nat.max (a - b) (b - a) :=
   Nat.max_comm (b - a) (a - b) in
 eq_ind_r (fun n : nat => n = Init.Nat.max (a - b) (b - a)) eq_refl H)
  : abs a b = abs b a
     : forall a b : nat, abs a b = abs b a

Arguments abs_comm (a b)%nat_scope
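Because a proof is just a term, we could also skip tactics and write the term ourselves. In the sketch below, abs_comm_direct is a hypothetical name. The term Nat.max_comm (b - a) (a - b) already has type max (b - a) (a - b) = max (a - b) (b - a). Coq accepts it at type abs a b = abs b a because the type-checker unfolds abs when it compares the two types.

```coq
Require Import Arith.

Module AbsCommTerm.
  Definition abs (a b : nat) := max (b - a) (a - b).

  (* A proof written directly as a term. Nat.max_comm (b - a) (a - b) has type
     max (b - a) (a - b) = max (a - b) (b - a); Coq unfolds abs while
     type-checking, so this is accepted as a proof of abs a b = abs b a. *)
  Definition abs_comm_direct (a b : nat) : abs a b = abs b a :=
    Nat.max_comm (b - a) (a - b).
End AbsCommTerm.
```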

Let's prove one more property about abs. abs a b should be equal to the distance between a and b, so if a < b, then a + abs a b = b, and if b < a, then b + abs a b = a.

To prove this, we can use the lia tactic, which is a decision procedure for linear integer arithmetic. As our original problem involves a lot of arithmetic, lia will frequently come in handy.


forall a b : nat, (a < b -> a + abs a b = b) /\ (b < a -> b + abs a b = a)

a, b: nat

(a < b -> a + abs a b = b) /\ (b < a -> b + abs a b = a)
a, b: nat

(a < b -> a + Init.Nat.max (b - a) (a - b) = b) /\ (b < a -> b + Init.Nat.max (b - a) (a - b) = a)
lia. Qed.
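The proof states above correspond to three steps: introduce a and b, unfold abs, and call lia. Here is a minimal sketch of that script. It assumes a Coq version whose lia handles max and truncated subtraction on nat, as the proof above relies on. The name abs_dist is hypothetical.

```coq
Require Import Arith Lia.

Module AbsDist.
  Definition abs (a b : nat) := max (b - a) (a - b).

  Lemma abs_dist : forall a b : nat,
    (a < b -> a + abs a b = b) /\ (b < a -> b + abs a b = a).
  Proof.
    intros a b.
    unfold abs.  (* lia does not look inside abs, but it handles max and nat subtraction *)
    lia.
  Qed.
End AbsDist.
```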

Specification Using Lists

Now let's return to our original problem. We can write a proposition that defines whether there is a valid distribution of beans. For any r, b, and d, there exists a correct distribution of beans if and only if there exists a set of packets such that

  • the sum of all the red beans is r,
  • the sum of all the blue beans is b,
  • each packet contains a positive number of red and blue beans,
  • for each packet, the number of red and blue beans should not differ by more than d.

The simplest way to represent a set of packets is a list. We'll define a packet as a pair of natural numbers, and we'll define the function packet_sum which adds up the number of red and blue beans in a list of packets.

Module ListSpec.

Definition packet : Set := nat * nat.

Fixpoint packet_sum (l : list packet) := match l with
  | [] => (0,0)
  | (r,b) :: rest => let (x,y) := packet_sum rest in (r+x,b+y)
  end.
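We can check packet_sum by computation on small lists. The sketch below sums the two packets used later in this section. The module and example names are hypothetical, and the definitions are repeated so that it compiles on its own.

```coq
Require Import List.
Import ListNotations.

Module PacketSumDemo.
  Definition packet : Set := nat * nat.

  Fixpoint packet_sum (l : list packet) := match l with
    | [] => (0,0)
    | (r,b) :: rest => let (x,y) := packet_sum rest in (r+x,b+y)
    end.

  (* (1 + 1, 4 + 3) = (2, 7) *)
  Example sum_two : packet_sum [(1, 4); (1, 3)] = (2, 7).
  Proof. reflexivity. Qed.

  (* The empty list of packets contains no beans. *)
  Example sum_empty : packet_sum [] = (0, 0).
  Proof. reflexivity. Qed.
End PacketSumDemo.
```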

Now we can define correct_distribution as described above. Here we use the existential quantifier exists, which means that to prove this proposition, we must supply a list that satisfies all the conditions.

Definition correct_distribution (r b d : nat) :=
  exists l, packet_sum l = (r,b) /\ Forall (fun '(x,y) => abs x y <= d /\ x > 0 /\ y > 0) l.
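Because correct_distribution is a Prop with an existential, Coq can't simply run it. A boolean mirror of the per-packet condition is handy for experimenting with candidate lists. The sketch below defines one under the hypothetical names packet_ok and all_ok. It is only a testing aid: nothing here proves that it agrees with the Forall condition.

```coq
Require Import Arith Bool List.
Import ListNotations.

Module PacketCheck.
  Definition abs (a b : nat) := max (b - a) (a - b).

  (* Boolean version of the per-packet condition: abs x y <= d, x > 0, y > 0. *)
  Definition packet_ok (d : nat) (p : nat * nat) : bool :=
    let '(x, y) := p in
    Nat.leb (abs x y) d && Nat.ltb 0 x && Nat.ltb 0 y.

  (* true exactly when every packet in the list passes packet_ok. *)
  Definition all_ok (d : nat) (l : list (nat * nat)) : bool :=
    forallb (packet_ok d) l.

  Example ok_example : all_ok 3 [(1, 4); (1, 3)] = true.
  Proof. reflexivity. Qed.

  (* abs 1 5 = 4 > 3, so the first packet is rejected. *)
  Example bad_example : all_ok 3 [(1, 5); (1, 2)] = false.
  Proof. reflexivity. Qed.
End PacketCheck.
```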


correct_distribution 1 1 0


packet_sum [(1, 1)] = (1, 1) /\ Forall (fun '(x, y) => abs x y <= 0 /\ x > 0 /\ y > 0) [(1, 1)]

Forall (fun '(x, y) => Init.Nat.max (y - x) (x - y) <= 0 /\ x > 0 /\ y > 0) [(1, 1)]
repeat constructor; lia. Qed.

correct_distribution 2 7 3


packet_sum [(1, 4); (1, 3)] = (2, 7) /\ Forall (fun '(x, y) => abs x y <= 3 /\ x > 0 /\ y > 0) [(1, 4); (1, 3)]

Forall (fun '(x, y) => Init.Nat.max (y - x) (x - y) <= 3 /\ x > 0 /\ y > 0) [(1, 4); (1, 3)]
repeat constructor; lia. Qed.
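To prove an existential, the exists tactic supplies the witness, and what remains is to check it. The sketch below proves the second example this way. The module name is hypothetical, and the definitions are repeated so that it compiles on its own. It closes with the same repeat constructor; lia combination used above.

```coq
Require Import Arith Lia List.
Import ListNotations.

Module WitnessDemo.
  Definition abs (a b : nat) := max (b - a) (a - b).
  Definition packet : Set := nat * nat.
  Fixpoint packet_sum (l : list packet) := match l with
    | [] => (0,0)
    | (r,b) :: rest => let (x,y) := packet_sum rest in (r+x,b+y)
    end.
  Definition correct_distribution (r b d : nat) :=
    exists l, packet_sum l = (r,b) /\
      Forall (fun '(x,y) => abs x y <= d /\ x > 0 /\ y > 0) l.

  Example distribute_2_7_3 : correct_distribution 2 7 3.
  Proof.
    exists [(1, 4); (1, 3)].               (* supply the witness list *)
    split.
    - reflexivity.                         (* packet_sum computes to (2, 7) *)
    - unfold abs. repeat constructor; lia. (* check each packet's condition *)
  Qed.
End WitnessDemo.
```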

Finally, we can relate our algorithm can_distribute to our specification correct_distribution. Our algorithm should return true if and only if correct_distribution is provable, meaning that there exists a list of packets that meets the conditions defined in correct_distribution.

Definition algorithm_iff_correct_distribution :=
  forall r b d,
    can_distribute r b d = true <-> correct_distribution r b d.

If we can prove this theorem (meaning we can find a term which has this type), we know that our algorithm is correct. Moreover, other people do not even need to read our proof to trust that our algorithm is correct. They can simply read the specification and make sure the program typechecks.
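One caveat: a file can also typecheck if some lemma was closed with Admitted or if an axiom was assumed. Coq's Print Assumptions command lists everything a theorem ultimately depends on, so a skeptical reader can run it on the final theorem. The sketch below runs it on abs_comm; the module name is hypothetical. For a proof with no axioms, Coq reports "Closed under the global context".

```coq
Require Import Arith.

Module TrustCheck.
  Definition abs (a b : nat) := max (b - a) (a - b).

  Theorem abs_comm : forall a b : nat, abs a b = abs b a.
  Proof. intros a b. unfold abs. apply Nat.max_comm. Qed.
End TrustCheck.

(* Lists the axioms and admitted lemmas the theorem depends on, if any. *)
Print Assumptions TrustCheck.abs_comm.
```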

Some Helpful Lemmas

Here are a few arithmetic lemmas that will come in handy. We rely heavily on the lia tactic so that we don't have to prove most of these facts by hand. In some cases (multiplication by non-constants, which falls outside the linear arithmetic that lia decides), we need to guide the solver:


forall d r : nat, r > 0 -> d <= r * d

d, r: nat
H: r > 0

d <= r * d
induction d; lia. Qed.
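Induction on d gives lia a linear goal in each case. The step case only needs r * S d = r * d + r together with the induction hypothesis, with r * d treated as an opaque atom. As an alternative, Coq's nia tactic (also provided by Require Import Lia) extends lia with some nonlinear reasoning by considering products of pairs of hypotheses, and it should close this particular goal directly. The sketch below shows both versions; the lemma names are hypothetical.

```coq
Require Import Arith Lia.

Module NonlinearDemo.
  (* Induction on d: the step case r * S d = r * d + r is linear in the atom r * d. *)
  Lemma le_mul_by_induction : forall d r : nat, r > 0 -> d <= r * d.
  Proof. intros d r H. induction d; lia. Qed.

  (* nia also considers products of pairs of hypotheses, which is enough here;
     it is incomplete and can be slow on larger goals. *)
  Lemma le_mul_by_nia : forall d r : nat, r > 0 -> d <= r * d.
  Proof. intros d r H. nia. Qed.
End NonlinearDemo.
```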

forall r b d x y : nat, r > 0 -> b <= r + d -> y <= x * (d + 1) -> b + y <= (r + x) * (d + 1)

r, b, d, x, y: nat
H: r > 0
H0: b <= r + d
H1: y <= x * (d + 1)

b + y <= (r + x) * (d + 1)
r, b, d, x, y: nat
H: r > 0
H0: b <= r + d
H1: y <= x * (d + 1)

(r + x) * (d + 1) = x * (d + 1) + r * (d + 1)
r, b, d, x, y: nat
H: r > 0
H0: b <= r + d
H1: y <= x * (d + 1)
H2: (r + x) * (d + 1) = x * (d + 1) + r * (d + 1)
b + y <= (r + x) * (d + 1)
r, b, d, x, y: nat
H: r > 0
H0: b <= r + d
H1: y <= x * (d + 1)

(r + x) * (d + 1) = x * (d + 1) + r * (d + 1)
lia.
r, b, d, x, y: nat
H: r > 0
H0: b <= r + d
H1: y <= x * (d + 1)
H2: (r + x) * (d + 1) = x * (d + 1) + r * (d + 1)

b + y <= (r + x) * (d + 1)
r, b, d, x, y: nat
H: r > 0
H0: b <= r + d
H1: y <= x * (d + 1)
H2: (r + x) * (d + 1) = x * (d + 1) + r * (d + 1)

b + y <= x * (d + 1) + r * (d + 1)
r, b, d, x, y: nat
H: r > 0
H0: b <= r + d
H1: y <= x * (d + 1)
H2: (r + x) * (d + 1) = x * (d + 1) + r * (d + 1)

r + d <= r * (d + 1)
r, b, d, x, y: nat
H: r > 0
H0: b <= r + d
H1: y <= x * (d + 1)
H2: (r + x) * (d + 1) = x * (d + 1) + r * (d + 1)
H3: r > 0 -> d <= r * d

r + d <= r * (d + 1)
lia. Qed.

This lemma says that we can exchange the roles of r and b. If we can distribute r red beans and b blue beans, then we can also distribute b red beans and r blue beans, and vice versa. Although the proof is a bit complicated, the intuition is simple: we can flip the numbers of red and blue beans in each packet.

forall r b d : nat, correct_distribution r b d <-> correct_distribution b r d


forall r b d : nat, correct_distribution r b d -> correct_distribution b r d
H: forall r b d : nat, correct_distribution r b d -> correct_distribution b r d
forall r b d : nat, correct_distribution r b d <-> correct_distribution b r d

forall r b d : nat, correct_distribution r b d -> correct_distribution b r d

forall r b d : nat, (exists l : list packet, packet_sum l = (r, b) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l) -> exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
r, b, d: nat
l: list packet
H2: packet_sum l = (r, b)
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
b, d: nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

forall r : nat, packet_sum l = (r, b) -> exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d: nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

forall b r : nat, packet_sum l = (r, b) -> exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d: nat
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) []
b, r: nat
H2: packet_sum [] = (r, b)

exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d: nat
a: packet
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) (a :: l)
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall b r : nat, packet_sum l = (r, b) -> exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
b, r: nat
H2: packet_sum (a :: l) = (r, b)
exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d: nat
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) []
b, r: nat
H2: packet_sum [] = (r, b)

exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d: nat
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) []
b, r: nat
H2: (0, 0) = (r, b)

exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d: nat
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) []
H2: (0, 0) = (0, 0)

exists l : list packet, packet_sum l = (0, 0) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
exists []; auto.
d: nat
a: packet
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) (a :: l)
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall b r : nat, packet_sum l = (r, b) -> exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
b, r: nat
H2: packet_sum (a :: l) = (r, b)

exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall b r : nat, packet_sum l = (r, b) -> exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
b, r: nat
H2: packet_sum ((r', b') :: l) = (r, b)

exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall b r : nat, packet_sum l = (r, b) -> exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
b, r: nat
H2: (let (x, y) := packet_sum l in (r' + x, b' + y)) = (r, b)

exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
X: (nat * nat)%type
HeqX: X = packet_sum l
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall b r : nat, X = (r, b) -> exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
b, r: nat
H2: (let (x, y) := X in (r' + x, b' + y)) = (r, b)

exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall b r : nat, (x, y) = (r, b) -> exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
b, r: nat
H2: (r' + x, b' + y) = (r, b)

exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall b r : nat, (x, y) = (r, b) -> exists l : list packet, packet_sum l = (b, r) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
H2: (r' + x, b' + y) = (r' + x, b' + y)

exists l : list packet, packet_sum l = (b' + y, r' + x) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> (x, y) = (x, y) -> exists l : list packet, packet_sum l = (y, x) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
H2: (r' + x, b' + y) = (r' + x, b' + y)

exists l : list packet, packet_sum l = (b' + y, r' + x) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)

Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0
exists l : list packet, packet_sum l = (b' + y, r' + x) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0

exists l : list packet, packet_sum l = (b' + y, r' + x) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0

exists l : list packet, packet_sum l = (b' + y, r' + x) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0

packet_sum ((b', r') :: x0) = (b' + y, r' + x) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((b', r') :: x0)
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0

packet_sum ((b', r') :: x0) = (b' + y, r' + x)
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0
Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((b', r') :: x0)
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0

(let (x, y) := packet_sum x0 in (b' + x, r' + y)) = (b' + y, r' + x)
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0
Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((b', r') :: x0)
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0

(b' + y, r' + x) = (b' + y, r' + x)
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0
Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((b', r') :: x0)
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0

Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((b', r') :: x0)
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0

abs b' r' <= d /\ b' > 0 /\ r' > 0
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0
H5: abs r' b' <= d /\ r' > 0 /\ b' > 0
H6: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

abs b' r' <= d /\ b' > 0 /\ r' > 0
d, r', b': nat
l: list packet
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
H2: (r' + x, b' + y) = (r' + x, b' + y)
x0: list packet
H: packet_sum x0 = (y, x)
H0: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) x0
Ha: abs r' b' <= d
Hr: r' > 0
Hb: b' > 0
H6: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

abs b' r' <= d /\ b' > 0 /\ r' > 0
split; auto; unfold abs in *; try lia.
H: forall r b d : nat, correct_distribution r b d -> correct_distribution b r d

forall r b d : nat, correct_distribution r b d <-> correct_distribution b r d
split; apply H. Qed.
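We can try the intuition behind this swap lemma on a concrete list. Swapping each packet's counts swaps the totals. The difference condition is unaffected because abs is symmetric. The sketch below uses the hypothetical names swap_packet and SwapDemo.

```coq
Require Import List.
Import ListNotations.

Module SwapDemo.
  Definition packet : Set := nat * nat.
  Fixpoint packet_sum (l : list packet) := match l with
    | [] => (0,0)
    | (r,b) :: rest => let (x,y) := packet_sum rest in (r+x,b+y)
    end.

  (* Exchange the red and blue counts of a single packet. *)
  Definition swap_packet (p : packet) : packet := (snd p, fst p).

  Example swap_list : map swap_packet [(1, 4); (1, 3)] = [(4, 1); (3, 1)].
  Proof. reflexivity. Qed.

  (* The totals (2, 7) become (7, 2). *)
  Example swap_totals : packet_sum (map swap_packet [(1, 4); (1, 3)]) = (7, 2).
  Proof. reflexivity. Qed.
End SwapDemo.
```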

Here is the crucial lemma proving that our algorithm's condition is sufficient. Note that we assume r <= b, as we did in the informal proof.


forall r b d : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d

r, b, d: nat
r_leq_b: r <= b
H: b <= r * (d + 1)

correct_distribution r b d
r, d: nat

forall b : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d
d, b: nat
r_leq_b: 0 <= b
H: b <= 0 * (d + 1)

correct_distribution 0 b d
r, d: nat
IHr: forall b : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d
b: nat
r_leq_b: S r <= b
H: b <= S r * (d + 1)
correct_distribution (S r) b d
d, b: nat
r_leq_b: 0 <= b
H: b <= 0 * (d + 1)

correct_distribution 0 b d
d, b: nat
r_leq_b: 0 <= b
H: b <= 0 * (d + 1)

b = 0
d, b: nat
r_leq_b: 0 <= b
H: b <= 0 * (d + 1)
H0: b = 0
correct_distribution 0 b d
d, b: nat
r_leq_b: 0 <= b
H: b <= 0 * (d + 1)
H0: b = 0

correct_distribution 0 b d
d: nat
H: 0 <= 0 * (d + 1)
r_leq_b: 0 <= 0

correct_distribution 0 0 d
d: nat
H: 0 <= 0 * (d + 1)
r_leq_b: 0 <= 0

packet_sum [] = (0, 0) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) []
auto.
r, d: nat
IHr: forall b : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d
b: nat
r_leq_b: S r <= b
H: b <= S r * (d + 1)

correct_distribution (S r) b d
r, d: nat
IHr: forall b : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d
b: nat
r_leq_b: S r <= b
H: b <= S r * (d + 1)
t: nat
Heqt: t = Init.Nat.min (b - r) (d + 1)

correct_distribution (S r) b d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)

correct_distribution (S r) b d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)

r <= b - t
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hr: r <= b - t
correct_distribution (S r) b d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)

r <= b - t
lia.
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hr: r <= b - t

correct_distribution (S r) b d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hr: correct_distribution r (b - t) d

correct_distribution (S r) b d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hr: r <= b - t
b - t <= r * (d + 1)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

correct_distribution (S r) b d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hr: r <= b - t
b - t <= r * (d + 1)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

packet_sum ((1, t) :: l) = (S r, b)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((1, t) :: l)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hr: r <= b - t
b - t <= r * (d + 1)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

packet_sum ((1, t) :: l) = (S r, b)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

(let (x, y) := packet_sum l in (S x, t + y)) = (S r, b)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

(S r, t + (b - t)) = (S r, b)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

t + (b - t) = b
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
Htb: t + (b - t) = b
(S r, t + (b - t)) = (S r, b)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

t + (b - t) = b
lia.
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
Htb: t + (b - t) = b

(S r, t + (b - t)) = (S r, b)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l
Htb: t + (b - t) = b

(S r, b) = (S r, b)
reflexivity.
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((1, t) :: l)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
l: list packet
Hl1: packet_sum l = (r, b - t)
Hl2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

abs 1 t <= d
unfold abs; lia.
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hr: r <= b - t

b - t <= r * (d + 1)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= d + 1 + r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hr: r <= b - t

b - t <= r * (d + 1)
pose proof (Nat.min_spec (b-r) (d+1)) as [[Hmin1 minEq] | [Hmin2 minEq]]; rewrite minEq in Heqt; lia. Qed.
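The sufficiency proof above is constructive, so we can run the witness it builds. The OCaml sketch below is our own transcription, not code from the post. The helper names `build_distribution` and `is_witness` are hypothetical, machine ints stand in for nat, and the function assumes 0 <= r <= b <= r * (d + 1) without checking it. With r reds, it peels off the packet (1, t), where t = min (b - (r - 1)) (d + 1). This is the `t` from `Heqt` when the reds are written as `S r`. It then recurses on what is left, just as the induction does.

```ocaml
(* Hypothetical transcription of the witness built in the sufficiency proof.
   Precondition (not checked): 0 <= r <= b <= r * (d + 1). *)
let rec build_distribution r b d =
  if r = 0 then [] (* the precondition forces b = 0 here *)
  else
    let r' = r - 1 in
    (* In the proof the goal has S r' reds and t = min (b - r') (d + 1). *)
    let t = min (b - r') (d + 1) in
    (* One packet with a single red and t blues; the rest is built recursively. *)
    (1, t) :: build_distribution r' (b - t) d

(* Check that packets sum to (r, b) and that every packet is valid:
   both counts positive and differing by at most d. *)
let is_witness r b d packets =
  let sr, sb =
    List.fold_left (fun (x, y) (p, q) -> (x + p, y + q)) (0, 0) packets
  in
  sr = r && sb = b
  && List.for_all (fun (p, q) -> p > 0 && q > 0 && abs (p - q) <= d) packets

let () =
  (* 3 reds, 7 blues, d = 2; prints (1, 3) (1, 3) (1, 1) *)
  let packets = build_distribution 3 7 2 in
  List.iter (fun (p, q) -> Printf.printf "(%d, %d) " p q) packets;
  print_newline ()
```

Every packet has exactly one red, which is why the induction only needs to step r down by one at a time.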

Next, we'll prove that our algorithm's condition is necessary: if some list l witnesses correct_distribution r b d, then b <= r * (d + 1) and r <= b * (d + 1).

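The proof proceeds by induction on the list l. The bounds hold for each valid packet (r', b') on its own: r' > 0 and abs r' b' <= d give b' <= r' + d <= r' * (d + 1), and symmetrically for r'. Adding a packet to a list whose sum already satisfies the bounds preserves them.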
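To make the induction concrete, here is a small OCaml sketch. The names `packet_sum`, `valid_packet`, `bound_holds` and `bound_along` are our own; they mirror, but are not identical to, the Coq definitions, and OCaml's machine ints stand in for nat. The sketch walks a list the way the induction does. It checks the bound for the empty list, then checks it again after each packet is added to the front. This tests a single example and is no substitute for the proof.

```ocaml
(* A packet holds (r, b) counts, mirroring the Coq type `packet`. *)
type packet = int * int

(* Component-wise sum of a list of packets; (0, 0) for the empty list. *)
let packet_sum (l : packet list) : int * int =
  List.fold_left (fun (r, b) (x, y) -> (r + x, b + y)) (0, 0) l

(* The validity condition inside the Forall of the specification. *)
let valid_packet d ((x, y) : packet) = abs (x - y) <= d && x > 0 && y > 0

(* The conclusion of the lemma: b <= r * (d + 1) and r <= b * (d + 1). *)
let bound_holds d (r, b) = b <= r * (d + 1) && r <= b * (d + 1)

(* Follow the induction on l: the bound holds for [] (base case) and again
   after consing a packet onto a list whose sum already satisfies it. *)
let rec bound_along d (l : packet list) =
  match l with
  | [] -> bound_holds d (packet_sum [])
  | _ :: rest -> bound_along d rest && bound_holds d (packet_sum l)

let () =
  let d = 2 in
  let packets = [ (1, 3); (2, 1); (4, 4) ] in
  assert (List.for_all (valid_packet d) packets);
  (* Each valid packet satisfies the bound on its own... *)
  assert (List.for_all (bound_holds d) packets);
  (* ...and so does every suffix sum, including the total (7, 8). *)
  assert (bound_along d packets);
  assert (packet_sum packets = (7, 8))
```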


forall r b d : nat, correct_distribution r b d -> can_distribute r b d = true

forall r b d : nat, (exists l : list packet, packet_sum l = (r, b) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l) -> (b <=? r * (d + 1)) && (r <=? b * (d + 1)) = true
r, b, d: nat
H: exists l : list packet, packet_sum l = (r, b) /\ Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

(b <=? r * (d + 1)) && (r <=? b * (d + 1)) = true
r, b, d: nat
l: list packet
H1: packet_sum l = (r, b)
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

(b <=? r * (d + 1)) && (r <=? b * (d + 1)) = true
r, b, d: nat
l: list packet
H1: packet_sum l = (r, b)
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

(b <=? r * (d + 1)) && (r <=? b * (d + 1)) = true
r, b, d: nat
l: list packet
H1: packet_sum l = (r, b)
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

(b <=? r * (d + 1)) = true /\ (r <=? b * (d + 1)) = true
r, b, d: nat
l: list packet
H1: packet_sum l = (r, b)
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

b <= r * (d + 1) /\ r <= b * (d + 1)
r, d: nat
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

forall b : nat, packet_sum l = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
d: nat
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

forall r b : nat, packet_sum l = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
d: nat
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) []
r, b: nat
H1: packet_sum [] = (r, b)

b <= r * (d + 1) /\ r <= b * (d + 1)
d: nat
a: packet
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) (a :: l)
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall r b : nat, packet_sum l = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
r, b: nat
H1: packet_sum (a :: l) = (r, b)
b <= r * (d + 1) /\ r <= b * (d + 1)
d: nat
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) []
r, b: nat
H1: packet_sum [] = (r, b)

b <= r * (d + 1) /\ r <= b * (d + 1)
d: nat
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) []
r, b: nat
H1: (0, 0) = (r, b)

b <= r * (d + 1) /\ r <= b * (d + 1)
d: nat
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) []
H1: (0, 0) = (0, 0)

0 <= 0 * (d + 1) /\ 0 <= 0 * (d + 1)
lia.
d: nat
a: packet
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) (a :: l)
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall r b : nat, packet_sum l = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
r, b: nat
H1: packet_sum (a :: l) = (r, b)

b <= r * (d + 1) /\ r <= b * (d + 1)
d, r', b': nat
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall r b : nat, packet_sum l = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
r, b: nat
H1: packet_sum ((r', b') :: l) = (r, b)

b <= r * (d + 1) /\ r <= b * (d + 1)
d, r', b': nat
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall r b : nat, packet_sum l = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
r, b: nat
H1: (let (x, y) := packet_sum l in (r' + x, b' + y)) = (r, b)

b <= r * (d + 1) /\ r <= b * (d + 1)
d, r', b': nat
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
X: (nat * nat)%type
HeqX: X = packet_sum l
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall r b : nat, X = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
r, b: nat
H1: (let (x, y) := X in (r' + x, b' + y)) = (r, b)

b <= r * (d + 1) /\ r <= b * (d + 1)
d, r', b': nat
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall r b : nat, (x, y) = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
r, b: nat
H1: (r' + x, b' + y) = (r, b)

b <= r * (d + 1) /\ r <= b * (d + 1)
d, r', b': nat
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall r b : nat, (x, y) = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)

b' + y <= (r' + x) * (d + 1) /\ r' + x <= (b' + y) * (d + 1)
d, r', b': nat
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall r b : nat, (x, y) = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
H1: abs r' b' <= d /\ r' > 0 /\ b' > 0
H3: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l

b' + y <= (r' + x) * (d + 1) /\ r' + x <= (b' + y) * (d + 1)
d, r', b': nat
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall r b : nat, (x, y) = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
H1: abs r' b' <= d /\ r' > 0 /\ b' > 0
H3: y <= x * (d + 1) /\ x <= y * (d + 1)

b' + y <= (r' + x) * (d + 1) /\ r' + x <= (b' + y) * (d + 1)
d, r', b': nat
l: list packet
H2: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) ((r', b') :: l)
x, y: nat
HeqX: (x, y) = packet_sum l
IHl: Forall (fun '(x, y) => abs x y <= d /\ x > 0 /\ y > 0) l -> forall r b : nat, (x, y) = (r, b) -> b <= r * (d + 1) /\ r <= b * (d + 1)
H1: Init.Nat.max (b' - r') (r' - b') <= d /\ r' > 0 /\ b' > 0
H3: y <= x * (d + 1) /\ x <= y * (d + 1)

b' + y <= (r' + x) * (d + 1) /\ r' + x <= (b' + y) * (d + 1)
split; apply d_bound; lia. Qed.

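Finally, the proof of correctness! Sufficiency comes from can_make_distr (applied with r and b swapped when b <= r), and necessity comes from algorithm_condition_necessary.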


algorithm_iff_correct_distribution

forall r b d : nat, can_distribute r b d = true <-> correct_distribution r b d
r, b, d: nat

can_distribute r b d = true <-> correct_distribution r b d
r, b, d: nat
H: can_distribute r b d = true

correct_distribution r b d
r, b, d: nat
H: correct_distribution r b d
can_distribute r b d = true
r, b, d: nat
H: can_distribute r b d = true

correct_distribution r b d
r, b, d: nat
H: (b <=? r * (d + 1)) && (r <=? b * (d + 1)) = true

correct_distribution r b d
r, b, d: nat
H: (b <=? r * (d + 1)) = true /\ (r <=? b * (d + 1)) = true

correct_distribution r b d
r, b, d: nat
H: b <= r * (d + 1) /\ r <= b * (d + 1)

correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)

correct_distribution r b d
(* Case on whether r <= b or b <= r. *)
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)

r <= b \/ b <= r
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hrb: r <= b
correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hbr: b <= r
correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hrb: r <= b

correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hbr: b <= r
correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hrb: r <= b

correct_distribution r b d
apply can_make_distr; auto.
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hbr: b <= r

correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hbr: b <= r

correct_distribution b r d
apply can_make_distr; auto.
r, b, d: nat
H: correct_distribution r b d

can_distribute r b d = true
apply algorithm_condition_necessary; auto. Qed. End ListSpec.

Specification Using Inductive Relations

We've done it! We have defined our specification using lists and proven that our algorithm always computes the correct answer according to our specification.

Now that we know for certain that our program is correct, we can proceed to extract it into OCaml. Before we move on, however, I want to present another specification for this problem, this time defining correct_distribution as an inductive proposition.

While it may be natural to interpret the specification using lists, it does make the definitions more complicated, which affects the length and readability of the proofs. This alternate specification is equivalent (in fact, proving that these two specifications are equivalent is a good exercise), but it requires a lot less unfolding and destructing, which simplifies the proofs.
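The inductive definition has only two constructors, so it translates directly into a brute-force decision procedure. The OCaml sketch below is our own. It uses machine ints instead of nat, and `range` and `by_constructors` are hypothetical names. The search tries every way to split off one packet, as `add_packet` does, and bottoms out at `no_packets`. It then checks that the result agrees with can_distribute on all small inputs. Passing such a check is only evidence; the theorem is what guarantees agreement for every input.

```ocaml
(* can_distribute as the Coq source defines it, here on OCaml ints. *)
let can_distribute r b d = b <= r * (d + 1) && r <= b * (d + 1)

(* The list 1, 2, ..., n (empty when n = 0). *)
let range n = List.init n (fun i -> i + 1)

(* Decide correct_distribution r b d by searching over the constructors:
   no_packets covers (0, 0); add_packet splits off one packet (r', b') with
   r' > 0, b' > 0 and |r' - b'| <= d. Memoized on (r, b) for a fixed d. *)
let by_constructors r b d =
  let memo = Hashtbl.create 64 in
  let rec go r b =
    match Hashtbl.find_opt memo (r, b) with
    | Some v -> v
    | None ->
        let v =
          (r = 0 && b = 0)
          || List.exists
               (fun r' ->
                 List.exists
                   (fun b' -> abs (r' - b') <= d && go (r - r') (b - b'))
                   (range b))
               (range r)
        in
        Hashtbl.add memo (r, b) v;
        v
  in
  go r b

(* Exhaustive agreement check on small inputs. *)
let () =
  for d = 0 to 3 do
    for r = 0 to 8 do
      for b = 0 to 8 do
        assert (can_distribute r b d = by_constructors r b d)
      done
    done
  done;
  print_endline "can_distribute agrees with the brute force on all small inputs"
```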

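Module InductiveSpec.
(* correct_distribution r b d: (r, b) is the component-wise sum of packets
   (r', b') with r' > 0, b' > 0 and abs r' b' <= d. *)
Inductive correct_distribution : nat -> nat -> nat -> Prop :=
  (* No packets at all: the totals are (0, 0). *)
  | no_packets : forall d,
      correct_distribution 0 0 d
  (* Add one valid packet (r', b') to an existing distribution of (r, b). *)
  | add_packet : forall r b r' b' d,
      correct_distribution r b d ->
      r' > 0 ->
      b' > 0 ->
      abs r' b' <= d ->
      correct_distribution (r'+r) (b'+b) d.
(* Let auto try these constructors. *)
#[export] Hint Constructors correct_distribution : core.

(* The same correctness statement as before, stated against this specification. *)
Definition algorithm_iff_correct_distribution :=
  forall r b d,
    can_distribute r b d = true <-> correct_distribution r b d.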


forall r b d : nat, correct_distribution r b d <-> correct_distribution b r d

forall r b d : nat, correct_distribution r b d -> correct_distribution b r d
H: forall r b d : nat, correct_distribution r b d -> correct_distribution b r d
forall r b d : nat, correct_distribution r b d <-> correct_distribution b r d

forall r b d : nat, correct_distribution r b d -> correct_distribution b r d
r, b, d: nat
H: correct_distribution r b d

correct_distribution b r d
d: nat

correct_distribution 0 0 d
r, b, r', b', d: nat
H: correct_distribution r b d
H0: r' > 0
H1: b' > 0
H2: abs r' b' <= d
IHcorrect_distribution: correct_distribution b r d
correct_distribution (b' + b) (r' + r) d
d: nat

correct_distribution 0 0 d
auto.
r, b, r', b', d: nat
H: correct_distribution r b d
H0: r' > 0
H1: b' > 0
H2: abs r' b' <= d
IHcorrect_distribution: correct_distribution b r d

correct_distribution (b' + b) (r' + r) d
r, b, r', b', d: nat
H: correct_distribution r b d
H0: r' > 0
H1: b' > 0
H2: abs r' b' <= d
IHcorrect_distribution: correct_distribution b r d

abs b' r' <= d
r, b, r', b', d: nat
H: correct_distribution r b d
H0: r' > 0
H1: b' > 0
H2: Init.Nat.max (b' - r') (r' - b') <= d
IHcorrect_distribution: correct_distribution b r d

Init.Nat.max (r' - b') (b' - r') <= d
lia.
H: forall r b d : nat, correct_distribution r b d -> correct_distribution b r d

forall r b d : nat, correct_distribution r b d <-> correct_distribution b r d
split; apply H. Qed.

forall r b d : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d
r, b, d: nat
r_leq_b: r <= b
H: b <= r * (d + 1)

correct_distribution r b d
r, d: nat

forall b : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d
d, b: nat
r_leq_b: 0 <= b
H: b <= 0 * (d + 1)

correct_distribution 0 b d
r, d: nat
IHr: forall b : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d
b: nat
r_leq_b: S r <= b
H: b <= S r * (d + 1)
correct_distribution (S r) b d
d, b: nat
r_leq_b: 0 <= b
H: b <= 0 * (d + 1)

correct_distribution 0 b d
d, b: nat
r_leq_b: 0 <= b
H: b <= 0 * (d + 1)

b = 0
d, b: nat
r_leq_b: 0 <= b
H: b <= 0 * (d + 1)
H0: b = 0
correct_distribution 0 b d
d, b: nat
r_leq_b: 0 <= b
H: b <= 0 * (d + 1)
H0: b = 0

correct_distribution 0 b d
d: nat
H: 0 <= 0 * (d + 1)
r_leq_b: 0 <= 0

correct_distribution 0 0 d
auto.
r, d: nat
IHr: forall b : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d
b: nat
r_leq_b: S r <= b
H: b <= S r * (d + 1)

correct_distribution (S r) b d
r, d: nat
IHr: forall b : nat, r <= b -> b <= r * (d + 1) -> correct_distribution r b d
b: nat
r_leq_b: S r <= b
H: b <= S r * (d + 1)
t: nat
Heqt: t = Init.Nat.min (b - r) (d + 1)

correct_distribution (S r) b d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)

correct_distribution (S r) b d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)

b = t + (b - t)
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hb: b = t + (b - t)
correct_distribution (S r) b d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)

b = t + (b - t)
lia.
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hb: b = t + (b - t)

correct_distribution (S r) b d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hb: b = t + (b - t)

correct_distribution (S r) (t + (b - t)) d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hb: b = t + (b - t)

S r = 1 + r
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hb: b = t + (b - t)
Hr: S r = 1 + r
correct_distribution (S r) (t + (b - t)) d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hb: b = t + (b - t)

S r = 1 + r
lia.
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hb: b = t + (b - t)
Hr: S r = 1 + r

correct_distribution (S r) (t + (b - t)) d
r, d, b, t: nat
IHr: r <= b - t -> b - t <= r * (d + 1) -> correct_distribution r (b - t) d
r_leq_b: S r <= b
H: b <= S r * (d + 1)
Heqt: t = Init.Nat.min (b - r) (d + 1)
Hb: b = t + (b - t)
Hr: S r = 1 + r

correct_distribution (1 + r) (t + (b - t)) d
constructor; unfold abs; try apply IHr; try lia. Qed.

algorithm_iff_correct_distribution

forall r b d : nat, can_distribute r b d = true <-> correct_distribution r b d
r, b, d: nat
H: can_distribute r b d = true

correct_distribution r b d
r, b, d: nat
H: correct_distribution r b d
can_distribute r b d = true
r, b, d: nat
H: can_distribute r b d = true

correct_distribution r b d
r, b, d: nat
H: (b <=? r * (d + 1)) && (r <=? b * (d + 1)) = true

correct_distribution r b d
r, b, d: nat
H: (b <=? r * (d + 1)) = true /\ (r <=? b * (d + 1)) = true

correct_distribution r b d
r, b, d: nat
H: b <= r * (d + 1) /\ r <= b * (d + 1)

correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)

correct_distribution r b d
(* Case on whether r <= b or b <= r. *)
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)

r <= b \/ b <= r
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hrb: r <= b
correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hbr: b <= r
correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hrb: r <= b

correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hbr: b <= r
correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hrb: r <= b

correct_distribution r b d
apply can_make_distr; auto.
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hbr: b <= r

correct_distribution r b d
r, b, d: nat
H1: b <= r * (d + 1)
H2: r <= b * (d + 1)
Hbr: b <= r

correct_distribution b r d
apply can_make_distr; auto.
r, b, d: nat
H: correct_distribution r b d

can_distribute r b d = true
r, b, d: nat
H: correct_distribution r b d

(b <=? r * (d + 1)) && (r <=? b * (d + 1)) = true
r, b, d: nat
H: correct_distribution r b d

(b <=? r * (d + 1)) = true /\ (r <=? b * (d + 1)) = true
r, b, d: nat
H: correct_distribution r b d

b <= r * (d + 1) /\ r <= b * (d + 1)
d: nat

0 <= 0 * (d + 1) /\ 0 <= 0 * (d + 1)
r, b, r', b', d: nat
H: correct_distribution r b d
H0: r' > 0
H1: b' > 0
H2: abs r' b' <= d
IHcorrect_distribution: b <= r * (d + 1) /\ r <= b * (d + 1)
b' + b <= (r' + r) * (d + 1) /\ r' + r <= (b' + b) * (d + 1)
d: nat

0 <= 0 * (d + 1) /\ 0 <= 0 * (d + 1)
lia.
r, b, r', b', d: nat
H: correct_distribution r b d
H0: r' > 0
H1: b' > 0
H2: abs r' b' <= d
IHcorrect_distribution: b <= r * (d + 1) /\ r <= b * (d + 1)

b' + b <= (r' + r) * (d + 1) /\ r' + r <= (b' + b) * (d + 1)
split; apply ListSpec.d_bound; unfold abs in H2; try lia. Qed. End InductiveSpec.

Extraction

Now we can extract our verified algorithm to OCaml. We don't extract our proofs, since they are not meant to be run (Coq's extraction erases logical content in Prop anyway). So in this case, we only need to extract our function can_distribute, which is quite simple. We can try running the extraction command now:

Extraction "imp.ml" can_distribute.

If you look in the file imp.ml, you will see the following datatype definitions at the top:

type bool =
| True
| False

type nat =
| O
| S of nat

Without any directions on how to perform the extraction, Coq will redefine all the datatypes that are used, including bool and nat. This is a big problem for nat, since Peano-style numbers are very inefficient. If you look at how addition is defined, adding two numbers takes time linear in the value (not the bit length) of the first number. We can tell Coq to extract nat to int or int64, but this can be dangerous, because theorems about nat may no longer hold for the extracted type. For instance, x + y >= x is a theorem for nat, but it is false for int or int64, where addition can overflow and wrap around.
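To see both problems in miniature, here is an OCaml sketch that Coq did not generate. It uses a hand-written Peano `nat` whose `add` follows the same recursion as Coq's Nat.add, and then shows that int64 addition wraps around. The conversion helpers `of_int` and `to_int` are our own, included only for experimenting.

```ocaml
(* Peano naturals, the same shape as the type in the generated imp.ml. *)
type nat = O | S of nat

(* Mirrors the recursion of Coq's Nat.add: it recurses on the first
   argument, so computing n + m makes n recursive calls. *)
let rec add n m = match n with O -> m | S p -> S (add p m)

(* Conversions for experimenting; both are linear in the value. *)
let rec of_int k = if k <= 0 then O else S (of_int (k - 1))
let rec to_int n = match n with O -> 0 | S p -> 1 + to_int p

let () =
  Printf.printf "3 + 4 = %d\n" (to_int (add (of_int 3) (of_int 4)));
  (* int64 arithmetic wraps around, so x + y >= x can fail: *)
  let x = Int64.max_int in
  let y = Int64.add x 1L in
  Printf.printf "max_int + 1 = %Ld (less than max_int: %b)\n" y
    (Int64.compare y x < 0)
```

Running it prints 3 + 4 = 7, then shows Int64.max_int + 1 wrapping around to Int64.min_int, a concrete counterexample to x + y >= x.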

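In this case, since the inputs are at most 10^9, we can check that we will not run into overflow. The largest value can_distribute computes is r * (d + 1) (or b * (d + 1)), which is at most 10^9 * (10^9 + 1) = 10^18 + 10^9. That is well below the int64 maximum of 2^63 - 1 (about 9.2 * 10^18). So, we can safely extract nat to int64. If you would like, you can also extract to an arbitrary-precision integer type like big_int.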
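The overflow argument is simple arithmetic, and it is easy to check directly. This OCaml sketch assumes every input is at most 10^9. Its `can_distribute` is hand-written over int64 to match what we expect the int64 extraction to compute; it is an assumption, not code copied from the generated imp.ml.

```ocaml
(* Largest input under the assumed bound of 10^9. *)
let max_input = 1_000_000_000L

(* The biggest intermediate value can_distribute computes is r * (d + 1). *)
let largest_product = Int64.mul max_input (Int64.add max_input 1L)

(* Hand-written stand-in for the int64 extraction of can_distribute. *)
let can_distribute r b d =
  let d1 = Int64.add d 1L in
  Int64.compare b (Int64.mul r d1) <= 0 && Int64.compare r (Int64.mul b d1) <= 0

let () =
  Printf.printf "largest product = %Ld\n" largest_product;
  Printf.printf "Int64.max_int   = %Ld\n" Int64.max_int;
  (* Extreme inputs stay far from the overflow boundary. *)
  Printf.printf "%b %b\n"
    (can_distribute max_input max_input max_input)
    (can_distribute 1L max_input 0L)
```

On a 64-bit machine, OCaml's native int (63 bits, maximum 2^62 - 1, about 4.6 * 10^18) would also hold these values. Using int64 keeps the margin independent of the platform.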

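Module ExtractionDefs.
Module ExtrNatToInt64.
(* Represent nat as OCaml's int64: O becomes Int64.zero and S becomes Int64.succ.
   The last string tells Coq how to pattern match on an int64: call zero () when
   n is 0 (compared against Int64.zero, since OCaml's 0 is an int, not an int64),
   otherwise call succ on the predecessor. *)
Extract Inductive nat => "int64"
  [ "Int64.zero" "(fun x -> Int64.succ x)" ]
  "(fun zero succ n -> if Int64.compare n Int64.zero = 0 then zero () else succ (Int64.pred n))".
(* Replace the Peano arithmetic with native 64-bit operations. *)
Extract Constant plus => "Int64.add".
Extract Constant mult => "Int64.mul".
Extract Constant leb => "(fun x y -> Int64.compare x y <= 0)".
End ExtrNatToInt64.

(* Use OCaml's built-in bool instead of a redefined True/False type. *)
Extract Inductive bool => "bool" [ "true" "false" ].
End ExtractionDefs.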

Extraction "imp.ml" can_distribute.

And there we have it: a fully verified program that you can submit to Codeforces! See here for a version that includes the input/output plumbing.
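Here is a hypothetical OCaml sketch of what such plumbing could look like. It assumes the input is a test count t followed by t lines of `r b d`, and that each answer is printed as YES or NO; check the actual problem statement before relying on that format. The `can_distribute` here is a hand-written int64 stand-in for the extracted function, and `solve` and `main` are our own names.

```ocaml
(* Stand-in for the extracted can_distribute, assuming nat was extracted to
   int64 and bool to OCaml's bool. In a real submission you would paste the
   generated imp.ml here instead. *)
let can_distribute r b d =
  let d1 = Int64.add d 1L in
  Int64.compare b (Int64.mul r d1) <= 0 && Int64.compare r (Int64.mul b d1) <= 0

(* Hypothetical plumbing: read a test count t, then t triples "r b d",
   and produce "YES" or "NO" for each one, in input order. *)
let solve (ib : Scanf.Scanning.in_channel) : string list =
  let t = Scanf.bscanf ib " %d" (fun x -> x) in
  let answers = ref [] in
  for _i = 1 to t do
    let answer =
      Scanf.bscanf ib " %Ld %Ld %Ld" (fun r b d ->
          if can_distribute r b d then "YES" else "NO")
    in
    answers := answer :: !answers
  done;
  List.rev !answers

(* Entry point for a submission: read stdin, print one answer per line. *)
let main () = List.iter print_endline (solve Scanf.Scanning.stdin)
```

A submission would finish with `let () = main ()`. For a quick local test, `solve (Scanf.Scanning.from_string "2\n1 1 0\n6 1 4\n")` returns `["YES"; "NO"]`.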