I have the following setup:
```idris
import Data.Vect

data Shape : (n : Nat) -> Type where
  Nil  : Shape Z
  (::) : (k : Nat) -> {auto prf : IsSucc k} -> (xs : Shape n) -> Shape (S n)

-- `count` is declared before `Tensor` because the record's `arr` field uses it
count : Shape n -> Nat
count Nil       = 1
count (x :: xs) = x * count xs

record Tensor (n : Nat) (a : Type) where
  constructor MkTensor
  shape : Shape n
  arr   : Vect (count shape) a
```
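As a concrete illustration of how `shape` and `arr` fit together, here is a small self-contained sketch, assuming Idris 1 (the value `t23` is a hypothetical example, not part of the original code). A 2-by-3 tensor stores its entries in a flat `Vect` whose length is `count [2, 3] = 2 * (3 * 1) = 6`.

```idris
import Data.Vect

-- Definitions repeated from above so the block compiles on its own.
data Shape : (n : Nat) -> Type where
  Nil  : Shape Z
  (::) : (k : Nat) -> {auto prf : IsSucc k} -> (xs : Shape n) -> Shape (S n)

count : Shape n -> Nat
count Nil       = 1
count (x :: xs) = x * count xs

record Tensor (n : Nat) (a : Type) where
  constructor MkTensor
  shape : Shape n
  arr   : Vect (count shape) a

-- Hypothetical example: the list literal [2, 3] desugars to 2 :: 3 :: Nil,
-- and `auto` search supplies the IsSucc 2 and IsSucc 3 proofs.
-- The data vector must therefore have exactly 2 * 3 * 1 = 6 entries.
t23 : Tensor 2 Nat
t23 = MkTensor [2, 3] [1, 2, 3, 4, 5, 6]
```

Supplying only five elements would be rejected at compile time, since the vector's length must reduce to `count [2, 3]`.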
I also have a proof of the following theorem:
```idris
countIsNonZero : (shape : Shape n) -> IsSucc (count shape)
countIsNonZero Nil = ItIsSucc
countIsNonZero ((::) x {prf} xs) = helper x {p1 = prf} (count xs) {p2 = countIsNonZero xs}
  where
    helper2 : (k : Nat) -> {auto p : IsSucc k} -> (x : Nat) -> IsSucc (x + k)
    helper2 k {p} Z     = p
    helper2 k {p} (S l) = ItIsSucc

    helper : (a : Nat) -> {auto p1 : IsSucc a} -> (b : Nat) -> {auto p2 : IsSucc b} -> IsSucc (a * b)
    helper (S a) b {p1} {p2} = rewrite plusCommutative b (a * b) in helper2 b {p = p2} (a * b)
```
Now I want to use it to create an indexing function, which informally has the following signature:

```idris
index : Tensor n a -> Fin t1 -> ... -> Fin tn -> a
```

where `t1` ... `tn` are the values in `shape`.
I wrote a simple helper for that purpose, namely:
```idris
getShapeFunctionType : Shape n -> (Nat -> Type) -> Type -> Type
getShapeFunctionType Nil       f r = r
getShapeFunctionType (x :: xs) f r = f x -> getShapeFunctionType xs f r

index : {a : Type} -> (t : Tensor n a) -> getShapeFunctionType (shape t) (\i : Nat => Fin i) a
index {a} tn = getIndexFunction (shape tn)
  where
    getIndexFunction : (sh : Shape n) -> getShapeFunctionType sh (\i : Nat => Fin i) a
    getIndexFunction Nil       = Data.Vect.index (fromInteger 0 {prf = ?h1}) (arr tn)
    getIndexFunction (x :: xs) = ?h2
```
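To make the informal signature concrete, the sketch below (assuming Idris 1; `indexType23` is a hypothetical name) checks that for the shape `[2, 3]`, `getShapeFunctionType` reduces to `Fin 2 -> Fin 3 -> a`, i.e. the `Fin t1 -> ... -> Fin tn -> a` form with `t1 = 2` and `t2 = 3`. It passes `Fin` directly, which is the eta-reduced form of `\i : Nat => Fin i`.

```idris
import Data.Fin

-- Shape and getShapeFunctionType repeated from above so the block stands alone.
data Shape : (n : Nat) -> Type where
  Nil  : Shape Z
  (::) : (k : Nat) -> {auto prf : IsSucc k} -> (xs : Shape n) -> Shape (S n)

getShapeFunctionType : Shape n -> (Nat -> Type) -> Type -> Type
getShapeFunctionType Nil       f r = r
getShapeFunctionType (x :: xs) f r = f x -> getShapeFunctionType xs f r

-- Hypothetical check: unfolding one constructor at a time turns the computed
-- type into Fin 2 -> Fin 3 -> a, so the two sides are definitionally equal.
indexType23 : getShapeFunctionType [2, 3] Fin a = (Fin 2 -> Fin 3 -> a)
indexType23 = Refl
```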
The first problem I ran into is that I don't know how to turn the theorem I have into a term for the hole `?h1`, i.e. how to tell Idris that, since `count shape` is never equal to 0, we can always get the zeroth element of a `Vect (count shape)`.
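One possible way to bridge the two, sketched under the assumption that Idris 1 is used (where `IsSucc` and `ItIsSucc` come from the Prelude), is to avoid `fromInteger` entirely and pattern-match on the `IsSucc` proof itself: matching on `ItIsSucc` refines `m` to `S k`, so `FZ` has the required type `Fin m`. The helpers `zeroFin` and `firstOf` are hypothetical names, not part of the original code.

```idris
import Data.Fin
import Data.Vect

-- Hypothetical helper: a proof that m is a successor yields index 0 of Fin m.
-- Matching on ItIsSucc refines m to S k, and FZ : Fin (S k).
zeroFin : IsSucc m -> Fin m
zeroFin ItIsSucc = FZ

-- Hypothetical helper: the first element of a vector whose length is known,
-- via an explicit IsSucc proof, to be nonzero.
firstOf : IsSucc m -> Vect m a -> a
firstOf p xs = Data.Vect.index (zeroFin p) xs
```

With such a helper, the `Nil` case should type-check as `Data.Vect.index (zeroFin (countIsNonZero (shape tn))) (arr tn)`, because `countIsNonZero (shape tn) : IsSucc (count (shape tn))` and `arr tn : Vect (count (shape tn)) a`.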