Tags: proof, agda

Flattened matrix vs 2D matrix lookup equivalence (proof) - seeking more elegance


I have a proof of the (obvious) statement that looking up an element in a flattened representation of a matrix, stored as a vector of length m * n, gives the same result as looking it up in a vector-of-vectors representation. But my proof feels clunky. [I won't give my proof here, as doing so would bias the search!] To make this question self-contained, below I give an Agda module with a few helpful lemmas. [Some of these lemmas should probably be in the standard library, but are not.]

Basically, I am looking for an elegant way to fill the hole at the bottom: the proof of lookup∘concat. If you can make my lemmas more elegant as well, please feel free!

module NNN where

open import Data.Nat
open import Data.Nat.Properties.Simple
open import Data.Nat.Properties
open import Data.Vec
open import Data.Fin using (Fin; inject≤; fromℕ; toℕ)
open import Data.Fin.Properties using (bounded)
open import Data.Product using (_×_; _,_)
open import Relation.Binary.PropositionalEquality

-- some useful lemmas
-- Adding the same number on the right preserves ≤:
-- from i ≤ j derive i + k ≤ j + k, by induction on the proof of i ≤ j.
-- Matching on the proof alone fixes i (and j in the s≤s case), so no
-- separate absurd clause is needed.
cong+r≤ : ∀ {i j} → i ≤ j → (k : ℕ) → i + k ≤ j + k
cong+r≤ {j = j} z≤n       k = n≤m+n j k          -- 0 + k computes to k
cong+r≤         (s≤s i≤j) k = s≤s (cong+r≤ i≤j k)

-- Adding the same number on the left preserves ≤:
-- from i ≤ j derive k + i ≤ k + j.
-- Induction on k: 0 + x computes to x, and suc k + x computes to
-- suc (k + x), so each step only needs one more s≤s.
cong+l≤ : ∀ {i j} → i ≤ j → (k : ℕ) → k + i ≤ k + j
cong+l≤ i≤j zero    = i≤j
cong+l≤ i≤j (suc k) = s≤s (cong+l≤ i≤j k)

-- Multiplying both sides by the same number on the right preserves ≤:
-- from i ≤ j derive i * k ≤ j * k, by induction on the proof of i ≤ j.
cong*r≤ : ∀ {i j} → i ≤ j → (k : ℕ) → i * k ≤ j * k
cong*r≤ z≤n       _ = z≤n                          -- 0 * k computes to 0
cong*r≤ (s≤s i≤j) k = cong+l≤ (cong*r≤ i≤j k) k    -- suc i * k computes to k + i * k

-- Strip one suc from both sides of a ≤ proof (the inverse of s≤s).
-- A proof of suc i ≤ suc j can only be built with s≤s, so a single
-- clause covers every case.
-- NOTE(review): the standard library appears to provide the same lemma as
-- ≤-pred (Data.Nat.Properties); consider using it — confirm for your version.
sinj≤ : ∀ {i j} → suc i ≤ suc j → i ≤ j
sinj≤ (s≤s i≤j) = i≤j

-- The row-major offset of entry (i, k) in an m × n matrix is in bounds:
-- suc (toℕ i * n + toℕ k) ≤ m * n, i.e. toℕ i * n + toℕ k < m * n.
--
-- No case analysis on m or n is needed. The chain below is:
--   suc (i*n + k) = suc k + i*n       (commutativity)
--                 ≤ n + i*n           (k < n, via bounded k)
--                 = suc i * n ≤ m * n (i < m, scaled by n via cong*r≤)
-- The last step type-checks because suc i * n computes to n + i * n.
i*n+k≤m*n : ∀ {m n} → (i : Fin m) → (k : Fin n) →
            (suc (toℕ i * n + toℕ k) ≤ m * n)
i*n+k≤m*n {m} {n} i k =
  begin (suc (toℕ i * n + toℕ k)
           ≡⟨ cong suc (+-comm (toℕ i * n) (toℕ k)) ⟩
         suc (toℕ k) + toℕ i * n
           ≤⟨ cong+r≤ (bounded k) (toℕ i * n) ⟩
         n + toℕ i * n
           ≤⟨ cong*r≤ (bounded i) n ⟩
         m * n ∎)
  where open ≤-Reasoning

-- Row-major flattening of a 2-D index: entry (i, k) of an m × n matrix is
-- sent to offset toℕ i * n + toℕ k. That offset is converted back to a Fin
-- of the right size with fromℕ, then weakened into Fin (m * n) with inject≤,
-- using the bound established by i*n+k≤m*n.
fwd : {m n : ℕ} → (Fin m × Fin n) → Fin (m * n)
fwd {n = n} (i , k) = inject≤ (fromℕ offset) (i*n+k≤m*n i k)
  where
  offset : ℕ
  offset = toℕ i * n + toℕ k

-- Main goal: indexing the flattened vector `concat xss` (the rows of xss
-- laid end to end) at the offset computed by fwd gives the same element as
-- taking row i of xss and then column j.
--
-- The proof is still an open hole ({!!}).
-- NOTE(review): fwd above is built from fromℕ/inject≤ with an opaque bound
-- proof, which presumably forces any direct induction here to generalise
-- over that proof. A likely simpler route is to define fwd by recursion on i
-- (inject+ for row zero, raise n for later rows). The goal then follows by
-- induction on i and xss — TODO.
lookup∘concat : ∀ {m n} {A : Set} (i : Fin m) (j : Fin n) 
  (xss : Vec (Vec A n) m) → 
  lookup (fwd (i , j)) (concat xss) ≡ lookup j (lookup i xss)
lookup∘concat i j xss = {!!}

Solution

  • It is better to define fwd by induction on the row index; the rest then follows by a short induction as well.

    open import Data.Nat.Base
    open import Data.Fin hiding (_+_)
    open import Data.Vec
    open import Data.Vec.Properties
    open import Relation.Binary.PropositionalEquality
    
    -- Row-major flattening defined by recursion on the row index.
    -- Row zero occupies the first n slots (inject+ keeps the column index as
    -- is). Each later row is shifted past one more block of n slots (raise n).
    fwd : ∀ {m n} -> Fin m -> Fin n -> Fin (m * n)
    fwd {suc rows} {cols} zero      col = inject+ (rows * cols) col
    fwd {n = cols}        (suc row) col = raise cols (fwd row col)
    
    -- Indexing past the prefix xs with `raise m` lands in the suffix ys, at
    -- the original position j. Proved by structural recursion on the prefix.
    -- NOTE(review): newer standard-library releases appear to ship this in
    -- Data.Vec.Properties (as lookup-++ʳ); check before keeping a local copy.
    lookup-++-raise : ∀ {m n} {A : Set} (j : Fin n) (xs : Vec A m) (ys : Vec A n)
                    -> lookup (raise m j) (xs ++ ys) ≡ lookup j ys
    lookup-++-raise _ []         _  = refl
    lookup-++-raise j (_ ∷ rest) ys = lookup-++-raise j rest ys
    
    -- Looking up the flattened index `fwd i j` in `concat xss` equals looking
    -- up column j of row i. Induction on the row index i (and the rows xss):
    --   * row zero: fwd uses inject+, and the first row sits at the front of
    --     the concatenation, so the library lemma lookup-++-inject+ applies;
    --   * row suc i: fwd raises past the first row, lookup-++-raise discards
    --     that row, and the induction hypothesis handles the rest.
    lookup∘concat : ∀ {m n} {A : Set} i j (xss : Vec (Vec A n) m)
                  -> lookup (fwd i j) (concat xss) ≡ lookup j (lookup i xss)
    lookup∘concat zero    j (row ∷ rows) = lookup-++-inject+ row (concat rows) j
    lookup∘concat (suc i) j (row ∷ rows) =
      trans (lookup-++-raise (fwd i j) row (concat rows))
            (lookup∘concat i j rows)
    

    The soundness proof for fwd, showing that the numeric value of fwd i j is the row-major offset toℕ i * n + toℕ j:

    -- Soundness of fwd: the numeric value of the flattened index is the
    -- row-major offset toℕ i * n + toℕ j.
    module Soundness where
      open import Data.Nat.Properties.Simple
      open import Data.Fin.Properties

      soundness : ∀ {m n} (i : Fin m) (j : Fin n) -> toℕ (fwd i j) ≡ toℕ i * n + toℕ j
      -- Row zero: fwd injects j, and toℕ zero * n + toℕ j computes to toℕ j.
      soundness {suc m} {n} zero j = sym (inject+-lemma (m * n) j)
      -- Row suc i: raising by n adds n. The induction hypothesis and
      -- associativity then turn n + (toℕ i * n + toℕ j) into
      -- (n + toℕ i * n) + toℕ j, which is suc (toℕ i) * n + toℕ j by computation.
      soundness {n = n} (suc i) j =
        trans (toℕ-raise n (fwd i j))
              (trans (cong (_+_ n) (soundness i j))
                     (sym (+-assoc n (toℕ i * n) (toℕ j))))