I want to prove `trivial_request`, and I think we need `request_eq` to prove it.
(* request_eq (unfinished attempt): for I = (fun _ => id) and p = Vnil,
   entry S k of [request ... a b] equals entry k of
   [request ... (Vtail a) (Vtail b)].
   NOTE(review): this proof is not finished -- it ends with [Abort], so
   [request_eq] is NOT added to the environment, and any later proof that
   rewrites with [request_eq] will fail.
   NOTE(review): [λ _ : Euc 0, id] is the notation from Coq.Unicode.Utf8;
   none of the imports shown load it explicitly -- confirm it is in scope,
   or write [fun _ : Euc 0 => id] instead. *)
Lemma request_eq : forall (n k:nat)(a b:Euc (S n))(H:(S k < S n)%nat),
Vnth (request (λ _ : Euc 0, id) Vnil a b) H =
Vnth (request (λ _ : Euc 0, id) Vnil (Vtail a) (Vtail b)) (lt_S_n H).
Proof.
(* Case split on n. For n = 0, H : S k < 1 is refuted by the two
   inversions on the next line. *)
destruct n; intros.
inversion H; inversion H1.
(* Remaining case (vectors of length S (S n)): state an auxiliary claim
   [main] about updateRnA. With ssreflect's [suff], the current goal is
   proved first (assuming [main]) and [main] itself is the second goal. *)
suff main : (forall (k m n:nat)(v1 v2:Euc (S (S m)))(v3:Euc (S (S n)))(H3:((S k) < (S (S n)))%nat),
(k <= (S n) <= (S m))%nat ->
Vnth (updateRnA (λ _ : Euc 0, id) Vnil v1 v2 v3) H3 =
Vnth (updateRnA (λ _ : Euc 0, id) Vnil (Vtail v1) (Vtail v2) (Vtail v3)) (lt_S_n H3)).
(* Split on whether k <= S n or S n < k. *)
have [klen | ngtk] := (Compare_dec.le_lt_dec k (S n)).
(* k <= S n: rewrite with [main] (v1 := a, v2 := b); intuition closes
   what is left. *)
rewrite (main k n n a b); intuition.
(* S n < k: assert k < S n (by lia from H); lia then closes both goals. *)
have : (k < S n)%nat; lia.
(* Proof of [main] by induction on its first nat argument. *)
induction k0; intros.
(* Presumably VSntac rewrites a vector v as Vcons (Vhead v) (Vtail v)
   (CoLoR tactic -- confirm); then simplify. *)
VSntac v3; VSntac (Vtail v3); rewrite /=.
rewrite 2!Derive_deriveA_id 2!Vnth_tail.
(* Replace the kLess-based index proofs on v1 and v2; index side
   conditions are discharged by lia. *)
rewrite (Vnth_eq v1 (kLess (S m) (n0 - 0)) (lt_n_S (kLess m (n0 - 0)))); try lia.
rewrite (Vnth_eq v2 (kLess (S m) (n0 - 0)) (lt_n_S (kLess m (n0 - 0)))) //; try lia.
(* Prepare the hypotheses needed to instantiate the induction hypothesis. *)
have : ((S k0) < S (S n0))%nat. move: H3; lia. move => H4.
have : (k0 <= S n0 <= S m)%nat. lia. move => H5.
move: (IHk0 m n0 v1 v2 v3 H4 H5) => H6.
(* NOTE(review): the proof stops here; the goal is still open. *)
Abort.
(* deriveA_pi: [deriveA] depends only on the index k, not on the particular
   proof of k < A it receives. Proofs of [<] on nat are unique
   (Peano_dec.le_unique; Coq.Arith.Peano_dec must be Required), so two
   calls with equal indices are equal. *)
Lemma deriveA_pi {P A B} (I : Euc P -> Euc A -> Euc B) (p : Euc P)
      (input : Euc A) (train : Euc B) (i j : nat)
      (Hi : (i < A)%nat) (Hj : (j < A)%nat) :
  i = j -> deriveA I p input train Hi = deriveA I p input train Hj.
Proof.
  intros ->. f_equal. apply Peano_dec.le_unique.
Qed.

(* updateRnA_nth: entry k of [updateRnA I p input train v] (v of length m)
   is entry k of [v] minus [deriveA] at the index i determined by
   i + m = S A + k. The equation holds for any proof Hi of i < S A, so the
   [kLess] proof terms inside updateRnA never have to be inspected. *)
Lemma updateRnA_nth {P A B} (I : Euc P -> Euc (S A) -> Euc B) (p : Euc P)
      (input : Euc (S A)) (train : Euc B) :
  forall (m : nat) (v : Euc m) (k : nat) (Hk : (k < m)%nat)
         (i : nat) (Hi : (i < S A)%nat),
    (i + m = S A + k)%nat ->
    Vnth (updateRnA I p input train v) Hk = Vnth v Hk - deriveA I p input train Hi.
Proof.
  induction v as [| x m' xs IH]; intros k Hk i Hi Hsum.
  - (* v = Vnil: Hk : k < 0 is impossible. *)
    inversion Hk.
  - destruct k as [| k'].
    + (* Head entry: updateRnA used index A - (S m' - 1), which equals i. *)
      simpl. f_equal. apply deriveA_pi. lia.
    + (* Later entry: both sides step into the tail [xs]. *)
      simpl. apply IH. lia.
Qed.

(* request_nth: for every I and p, entry k of [request I p a b] is the
   k-th entry of [a] minus [deriveA I p a b] at index k. *)
Lemma request_nth {P A B} (I : Euc P -> Euc A -> Euc B) (p : Euc P)
      (a : Euc A) (b : Euc B) (k : nat) (H : (k < A)%nat) :
  Vnth (request I p a b) H = Vnth a H - deriveA I p a b H.
Proof.
  destruct A as [| A'].
  - (* A = 0: H : k < 0 is impossible. *)
    inversion H.
  - (* A = S A': [request I p a b] reduces to [updateRnA I p a b a]
       (request is transparent), so use updateRnA_nth with m = S A', i = k. *)
    exact (updateRnA_nth I p a b (S A') a k H k H (Nat.add_comm k (S A'))).
Qed.

(* trivial_request: the special case I = (fun _ => id), p = Vnil.
   It follows directly from request_nth; neither request_eq nor
   Derive_deriveA_id is required. *)
Lemma trivial_request : forall (k A:nat)(a b:Euc A)(H:(k < A)%nat),
Vnth (request (fun p:Euc 0 => id) Vnil a b) H =
Vnth a H - deriveA (fun p:Euc 0 => id) Vnil a b H.
Proof.
  intros k A a b H. apply request_nth.
Qed.
`request` updates each element of a vector using the gradient of an arbitrary function with respect to that element. `deriveA` returns the gradient with respect to each element.
I have to restrict `k` and `A` to be at least 1 so that the two vectors can be passed to `Vtail`. Induction on `A` may be useless, because the size of the vector is restricted to be at least 1, but I don't know any other useful approach.
I only need to prove `trivial_request`, so if there is a better way, please let me know. Please tell me your solution.
Require Import Psatz.
Require Coq.Arith.Peano_dec.
From mathcomp Require Import ssreflect.
Require Import Coq.Reals.Reals.
Require Import CoLoR.Util.Vector.VecUtil.
Require Import Coquelicot.Coquelicot.
Require Import Coq.Init.Datatypes.
(* Euc n: a vector of n real numbers (the [vector] type in scope,
   presumably CoLoR's, specialised to [R]). *)
Definition Euc (n:nat) := vector R n.
(* EucSum e: the sum of the entries of [e], as a right fold of [Rplus]
   starting from 0. Proofs unfold it with [rewrite /EucSum], so the exact
   shape of this body matters. *)
Definition EucSum {A}(e:Euc A) :R:= Vfold_right Rplus e 0.
(* QE r1 r2 = (/ 2) * (r1 - r2)^2: the squared error of one entry.
   The derivative proofs unfold it ([rewrite /QE]) and rely on this exact
   form (a scalar times a power of a difference). *)
Definition QE (r1 r2:R):R:= (/ 2)*((r1 - r2)^2).
(* QuadraticError v1 v2: the entrywise [QE] of two vectors of equal length. *)
Definition QuadraticError {n : nat} (v1 v2: Euc n) :Euc n:= Vmap2 QE v1 v2.
(* kLess P k: a proof that P - k < S P (truncated subtraction on nat).
   The body is a term produced by lia and stays transparent, as with a
   [Defined] proof. updateRnA uses it to pick a valid index into its
   input vector. *)
Definition kLess (P k : nat) : (P - k < S P)%nat := ltac:(lia).
(* deriveA I p input train H: the derivative of the total quadratic error
     x |-> EucSum (QuadraticError (I p input_x) train),
   where input_x is [input] with its k-th entry replaced by x (H : k < A),
   evaluated at the current k-th entry [Vnth input H]. *)
Definition deriveA {P A B}(I:Euc P -> Euc A -> Euc B)
(p :Euc P)(input:Euc A)(train:Euc B){k:nat}(H:(k < A)%nat):R:=
Derive (fun x => EucSum (QuadraticError (I p (@Vreplace R A input k H x)) train))
(@Vnth R A input k H).
(* updateRnA I p input train v: walks over [v] and subtracts a [deriveA]
   value from every entry.
   Inside the recursive [fr], [n] is the length of the current suffix of
   [v]. The head of a suffix of length n uses index A - (n - 1) of [input]
   (the proof [kLess A (n-1)] witnesses A - (n - 1) < S A).
   When [v] is [input] itself (length S A), the entry at position j
   therefore uses index j. *)
Definition updateRnA {P A B} :(Euc P -> Euc (S A) -> Euc B) ->
Euc P -> Euc (S A) -> Euc B -> forall {n:nat}, Euc n -> Euc n:=
fun I p input train => fix fr {n} v:=
match v with
|Vnil => Vnil
|Vcons x _ xs => Vcons (x - (deriveA I p input train (kLess A (n-1)))) (fr xs)
end.
(* request I p input train:
   - A = 0: returns the empty vector.
   - A = S A': returns [updateRnA I p input train input], i.e. each entry
     of [input] minus the [deriveA] value at that entry's index.
   It is closed with [Defined] (transparent), which lets simpl / rewrite /=
   reduce it once A is known to be a successor. *)
Definition request {P A B} :(Euc P -> Euc A -> Euc B) -> Euc P -> Euc A -> Euc B -> Euc A.
refine (match A with
| O => fun _ _ _ _ => Vnil
| S A' => _
end);
intros I p input train;
exact (updateRnA I p input train input).
Defined.
We may need `Derive_deriveA_id` to prove `request_eq`.
(* ex_deriveA_id: with the identity as the function I, the total quadratic
   error EucSum (QuadraticError (Vreplace a H x) b) is differentiable in x
   at x = Vnth a H. *)
Lemma ex_deriveA_id k : forall (n:nat)(a b: Euc n)(H:(k < n)%nat),
ex_derive (fun x : R => EucSum (QuadraticError (id (Vreplace a H x)) b)) (Vnth a H).
Proof.
(* Induction on the index k, then case analysis on the vector a. *)
induction k; intros; destruct a; rewrite /EucSum/=/QE.
(* k = 0, a = Vnil: H : 0 < 0 is impossible. *)
inversion H.
(* k = 0, a = Vcons: the sum is (/2)*(x - c)^2 plus a term constant in x. *)
apply (ex_derive_plus _ _ _ (ex_derive_scal _ _ _ (ex_derive_pow _ _ _
(ex_derive_minus _ _ _ (ex_derive_id _) (ex_derive_const _ _))))
(ex_derive_const _ _)).
(* k = S k, a = Vnil: H : S k < 0 is impossible. *)
inversion H.
(* k = S k, a = Vcons: the head term is constant in x; the rest is the
   induction hypothesis on the tails. *)
rewrite /EucSum/= in IHk.
apply (ex_derive_plus _ _ _ (ex_derive_const _ _) (IHk n a (Vtail b) (lt_S_n H))).
Qed.
(* Derive_deriveA_id: with the identity as the function I, [deriveA] at
   index k equals Vnth a H - Vnth b H, which is the derivative of
   (/2)*(x - b_k)^2 at x = a_k. *)
Lemma Derive_deriveA_id k : forall (n:nat)(a b: Euc n)(H:(k < n)%nat),
(deriveA (fun p:Euc 0 => id) Vnil a b H) = Vnth a H - Vnth b H.
Proof.
(* Induction on the index k, then case analysis on the vector a. *)
induction k; intros; destruct a; rewrite /deriveA/EucSum/=.
(* k = 0, a = Vnil: H : 0 < 0 is impossible. *)
inversion H.
(* k = 0, a = Vcons: compute the derivative with Derive_plus / Derive_scal /
   Derive_pow / Derive_minus, simplify the real arithmetic, and identify
   Vnth b H with Vhead b. *)
rewrite /=/QE;
rewrite (Derive_plus _ _ _ (ex_derive_scal _ _ _ (ex_derive_pow _ _ _
(ex_derive_minus _ _ _ (ex_derive_id _) (ex_derive_const _ _))))
(ex_derive_const _ _))
(Derive_const _ _) Derive_scal
(Derive_pow _ _ _ (ex_derive_minus _ _ _ (ex_derive_id _) (ex_derive_const _ _)))
(Derive_minus _ _ _ (ex_derive_id _) (ex_derive_const _ _)) Derive_id (Derive_const _ _)
/= Rplus_comm Rplus_0_l Rminus_0_r -Rmult_assoc 2!Rmult_1_r Rinv_l; auto; rewrite Rmult_1_l
(Vnth_eq b H (Nat.lt_0_succ n)); auto; rewrite -Vhead_nth //.
(* k = S k, a = Vnil: H : S k < 0 is impossible. *)
inversion H.
(* k = S k, a = Vcons: the head term is constant in x (derivative 0); the
   rest is the induction hypothesis on the tails, and Vnth b H is related to
   Vnth (Vtail b) via Vnth_eq / Vnth_tail. *)
rewrite /deriveA/EucSum/id in IHk.
rewrite (Derive_plus _ _ _ (ex_derive_const _ _) (ex_deriveA_id k n a (Vtail b) (lt_S_n H)))
(Derive_const _ _) /EucSum/id Rplus_comm Rplus_0_r
(IHk n a (Vtail b) (lt_S_n H)) -(Vnth_eq b (lt_n_S (lt_S_n H)) H (eq_refl (S k)))
(Vnth_tail b (lt_S_n H)) //.
Qed.