Module invmx

Require Import ssreflect ssrbool ssrfun eqtype ssrnat seq choice fintype.
Require Import div finfun bigop prime binomial ssralg finset fingroup finalg.
Require Import perm zmodp matrix ssrcomplements.
Require Import cssralg seqmatrixCR.


Set Implicit Arguments.
Unset Strict Implicit.
Unset Printing Implicit Defensive.


(** * Inverting lower unitriangular matrices

    This section works over a commutative unit ring [R].  It defines
    [fast_invmx], which peels off the first row and column of a square
    matrix and recurses on the remaining bottom-right block.  It proves
    that on matrices satisfying [lower1] this agrees with the generic
    [invmx]. *)
Section invmx.

Import GRing.Theory.
Local Open Scope ring_scope.

(* Scalars: a commutative ring with units. *)
Variable R : comUnitRingType.

(** Over a commutative unit ring, a right inverse of a square matrix is
    also a left inverse: [M *m M' = 1%:M] implies [M' *m M = 1%:M]. *)
Lemma invmx_left : forall m (M M' : 'M[R]_m), M *m M' = 1%:M -> M' *m M = 1%:M.
Proof.
move=> m M M' H.
(* mulmx1_unit turns the hypothesis into invertibility of M. *)
have hM : (M \in unitmx) by case: (mulmx1_unit H).
(* Rewrite M' as invmx M *m (M *m M') (mulKmx), then use the hypothesis
   to reduce it to invmx M. *)
have h' : (M' = invmx M) by rewrite -(mulKmx hM M') H mulmx1.
by rewrite h' mulVmx.
Qed.

(** A right inverse of [M] is necessarily [invmx M]. *)
Lemma invmx_uniq m (M M' : 'M[R]_m) :
  M *m M' = 1%:M -> M' = invmx M.
Proof.
move=> H.
have hM : (M \in unitmx) by case: (mulmx1_unit H).
(* Chain of equalities:
   M' = M' *m (M *m invmx M) = (M' *m M) *m invmx M = invmx M,
   where the last step uses invmx_left to replace M' *m M by 1. *)
by rewrite -[M']mulmx1 -(mulmxV hM) mulmxA (invmx_left H) mul1mx.
Qed.


(** [fast_invmx M] is defined by recursion on the size of [M].
    - Size [0]: it returns [1%:M].
    - Size [1 + p]: it computes [N := fast_invmx (drsubmx M)] on the
      bottom-right [p x p] block and returns
      [block_mx 1%:M 0 (- N *m dlsubmx M) N].

    Only [drsubmx M] and [dlsubmx M] are read; the top row of [M] is
    ignored.  The result is the inverse of [M] when [lower1 M] holds
    (lemma [fast_invmxP] below); no such claim is made for other
    matrices. *)
Fixpoint fast_invmx (m : nat) : 'M[R]_m -> 'M[R]_m :=
  match m return 'M[R]_m -> 'M[R]_m with
  (* The binder is typed with size 1 + p, which is convertible to S p.
     This lets the block operations see the 1 + p split. *)
  | S p => fun (M : 'M[R]_(1 + p)) =>
           let: N := fast_invmx (drsubmx M) in
           block_mx 1%:M 0 (- N *m dlsubmx M) N
  | O => fun _ => 1%:M
  end.

(** [lower1 M]: every entry of [M] on or above the diagonal ([i <= j])
    coincides with the corresponding entry of the identity matrix.  In
    other words, [M] is lower triangular with ones on its diagonal. *)
Definition lower1 m (M : 'M[R]_m) :=
  forall (i j : 'I_m), i <= j -> M i j = (i == j)%:R.

(** The bottom-right block of a [lower1] matrix is again [lower1]. *)
Lemma drlower1 : forall m (M : 'M[R]_(1 + m)%N),
  lower1 M -> lower1 (drsubmx M).
Proof.
move=> m M H i j hij.
by rewrite !mxE !rshift1 H.
Qed.

(** The top-right [1 x m] block of a [lower1] matrix is zero. *)
Lemma urlower1 : forall m (M : 'M[R]_(1 + m)%N),
  lower1 M -> ursubmx M = 0.
Proof.
move=> m M H.
(* Compare the two row vectors entrywise. *)
apply/rowP => i.
by rewrite !mxE lshift0 rshift1 H.
Qed.

(** The top-left [1 x 1] block of a [lower1] matrix is the identity. *)
Lemma ullower1 : forall m (M : 'M[R]_(1 + m)%N),
  lower1 M -> ulsubmx M = 1%:M.
Proof.
move=> m M H.
apply/rowP=> i.
(* ord1 collapses the single index of a 1 x 1 matrix. *)
by rewrite !mxE !ord1 H.
Qed.

(** On [lower1] matrices, [fast_invmx M] is a right inverse of [M]. *)
Lemma fast_invmxE : forall m (M : 'M[R]_m),
  lower1 M -> M *m fast_invmx M = 1%:M.
Proof.
(* Induction on the size; the size-0 case is closed with thinmx0. *)
elim=> [M |n ih]; first by rewrite !thinmx0.
(* View n.+1 as 1 + n so that the block lemmas apply, then unfold one
   step of fast_invmx. *)
rewrite [n.+1]/(1 + n)%N => M hM /=.
(* Replace the left factor M by its block decomposition and multiply
   the two block matrices blockwise. *)
rewrite -{1}[M]submxK (@mulmx_block _ 1 n 1 n 1 n).
(* Simplify the blocks.  The induction hypothesis applies to
   drsubmx M, which is lower1 by drlower1; urlower1 removes the
   top-right block of M. *)
rewrite !mulmx0 !mulmx1 !add0r ih; last by apply/drlower1.
rewrite (urlower1 hM) !mul0mx addr0 mulmxA mulmxN ih; last by apply/drlower1.
(* The bottom-left block cancels (subrr), the top-left one is the
   identity (ullower1), and the blocks reassemble into 1%:M. *)
by rewrite mulNmx mul1mx subrr ullower1 -?scalar_mx_block.
Qed.

(** On [lower1] matrices, [fast_invmx] agrees with the generic [invmx]. *)
Lemma fast_invmxP m (M : 'M[R]_m) (H : lower1 M) :
  fast_invmx M = invmx M.
Proof.
(* Right inverses are unique (invmx_uniq), and fast_invmxE shows that
   fast_invmx M is one. *)
apply/invmx_uniq.
by rewrite fast_invmxE.
Qed.

End invmx.


(** * Sequence-based counterpart of [fast_invmx]

    [cfast_invmx] follows the same recursion as [fast_invmx], but works
    on the [seqmatrix] representation over [CR].  [cfast_invmxP] shows
    that converting with [seqmx_of_mx] commutes with the two
    functions. *)
Section seqinvmx.

Variable R : comUnitRingType.
(* CR has type cunitRingType R; presumably an executable implementation
   of the ring R (NOTE(review): confirm against cssralg). *)
Variable CR : cunitRingType R.

(** [cfast_invmx m M] mirrors [fast_invmx] on sequence matrices.  The
    type [seqmatrix CR] carries no size, so the size [m] is passed
    explicitly.
    - [m = 0]: it returns [seqmx1 _ O].
    - [m = p.+1]: it recursively computes [N] from
      [drsubseqmx 1 1 M].  It then assembles, with [block_seqmx], the
      blocks [seqmx1 _ 1], [seqmx0 _ 1 p],
      [mulseqmx p 1 (oppseqmx N) (dlsubseqmx 1 1 M)] and [N].

    This matches the block structure [block_mx 1%:M 0 (- N *m dlsubmx M) N]
    of [fast_invmx].  The correspondence is established by
    [cfast_invmxP] below. *)
Fixpoint cfast_invmx (m : nat) (M : seqmatrix CR) :=
  match m with
  | S p =>
   let: N := cfast_invmx p (drsubseqmx 1 1 M) in
   (* The numeric arguments of mulseqmx are presumably the inner
      dimension p and the column count 1 of the product.
      NOTE(review): confirm against seqmatrixCR. *)
   block_seqmx (seqmx1 _ 1) (seqmx0 _ 1 p)
               (mulseqmx p 1 (oppseqmx N) (dlsubseqmx 1 1 M)) N
  | O => seqmx1 _ O
  end.

(** For every [M : 'M_m], converting [fast_invmx M] with [seqmx_of_mx]
    gives the same sequence matrix as running [cfast_invmx m] on the
    conversion of [M]. *)
Lemma cfast_invmxP : forall (m : nat),
  {morph (@seqmx_of_mx _ CR m m) : M / fast_invmx M >-> cfast_invmx m M}.
Proof.
(* Induction on m; the size-0 case follows from seqmx1E. *)
elim=> [M|m ih]; first by rewrite seqmx1E.
(* View m.+1 as 1 + m and unfold one step of both functions. *)
rewrite [m.+1]/(1 + m)%N => M /=.
(* Push seqmx_of_mx through each block operation with its ...E lemma,
   and use ih for the recursive call on the bottom-right block. *)
rewrite -(@block_seqmxE _ _ 1 _ 1) seqmx1E seqmx0E ih -drsubseqmxE -mulseqmxE.
by rewrite oppseqmxE ih -drsubseqmxE -dlsubseqmxE.
Qed.

End seqinvmx.