From Coq Require Import Wellfounded ssreflect.

From stdpp Require Import base strings gmap ssreflect.

From dislog.utils Require Import more_stdpp.
From dislog.lang Require Import semantics.
From dislog.types Require Import substmap.

(******************************************************************************)
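(* [msubsts vs e] is an environment-based substitution: every variable [y] of
   [e] with [vs !! y = Some v] is replaced by [Val v]. Names bound by [Let],
   [Case] and closures are removed from [vs] before descending under the
   binder, so occurrences bound inside [e] are left untouched. *)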

(* [msubsts vs e] replaces each variable [y] of [e] such that
   [vs !! y = Some v] by [Val v]; variables absent from [vs] are kept.
   Before going under a binder (the [Let] binder, the two [Case] branch
   binders, and the binders [f :: ys] of a closure) the bound name(s) are
   removed from [vs] with [bdelete]/[bdeletes], so bound occurrences are not
   substituted. Values are returned as-is.
   [subst_clo] is the companion function for closures ([function]). *)
Fixpoint msubsts (vs:gmap string val) (e:expr) : expr :=
  match e with
  (* Values are not traversed. In particular, the [Val v] inserted for a
     variable is never substituted again; see [msubsts_union]. *)
  | Val _ => e
  | Clo c => Clo (subst_clo vs c)
  | Var y =>
      match vs !! y with
      | Some v => Val v
      (* Variable not in the environment: left unchanged. *)
      | None => e end
  (* The argument list is substituted element-wise through [fmap]. *)
  | Call e ts =>
      Call (msubsts vs e) (msubsts vs <$> ts)
  | CallPrim p e1 e2 =>
      CallPrim p (msubsts vs e1) (msubsts vs e2)
  | If e1 e2 e3 =>
      If (msubsts vs e1) (msubsts vs e2) (msubsts vs e3)
  (* [y] scopes over [e2] only, so it is removed from [vs] just there. *)
  | Let y e1 e2 =>
      Let y (msubsts vs e1) (msubsts (bdelete y vs) e2)
  | Alloc e1 e2 =>
      Alloc (msubsts vs e1) (msubsts vs e2)
  | Load e1 e2 =>
      Load (msubsts vs e1) (msubsts vs e2)
  | Store e1 e2 e3 =>
      Store (msubsts vs e1) (msubsts vs e2) (msubsts vs e3)
  | Length e1 =>
      Length (msubsts vs e1)
  | Prod e1 e2 =>
      Prod (msubsts vs e1) (msubsts vs e2)
  | Fst e1 => Fst (msubsts vs e1)
  | Snd e1 => Snd (msubsts vs e1)
  | InL e1 => InL (msubsts vs e1)
  | InR e1 => InR (msubsts vs e1)
  (* [x2] scopes over [e2] and [x3] over [e3]; the scrutinee [e1] sees the
     full environment. *)
  | Case e1 x2 e2 x3 e3 =>
      Case (msubsts vs e1)
        x2 (msubsts (bdelete x2 vs) e2)
        x3 (msubsts (bdelete x3 vs) e3)
  | Par e1 e2 =>
      Par (msubsts vs e1) (msubsts vs e2)
  | RunPar e1 e2 =>
      RunPar (msubsts vs e1) (msubsts vs e2)
  | CAS e1 e2 e3 e4 =>
      CAS (msubsts vs e1) (msubsts vs e2) (msubsts vs e3) (msubsts vs e4)
  | Fold e => Fold (msubsts vs e)
  | Unfold e => Unfold (msubsts vs e)
  end
(* Closures: [f] (presumably the closure's self-name) and the parameters [ys]
   are bound in the body, so all of them are removed from [vs]. *)
with subst_clo (vs:gmap string val) (c:function) :=
  match c with
  | Lam f ys e => Lam f ys (msubsts (bdeletes (f::ys) vs) e) end.

(* Substitution leaves values unchanged. This holds by computation: [msubsts]
   returns a [Val] as-is. The [v] below is implicitly coerced to an [expr],
   presumably via [Val]. *)
Lemma msubsts_val vs (v:val) :
  msubsts vs v = v.
Proof. done. Qed.

(* The empty environment substitutes nothing. *)
Lemma msubsts_empty e :
  msubsts ∅ e = e.
Proof.
  (* Strong induction on [expr_size]. It yields a hypothesis [IH] for every
     strictly smaller expression, including the arguments of [Call], which
     live inside a list. *)
  induction e using (well_founded_induction (wf_inverse_image _ nat _ expr_size PeanoNat.Nat.lt_wf_0)).
  rename H into IH.
  destruct e; simpl; try done.
  (* Closes the cases whose subterms are all substituted with [∅] itself. *)
  all:try (f_equal; apply IH; simpl; lia).
  (* Closure: removing the binders from [∅] yields [∅] again. *)
  { f_equal. destruct f. simpl. f_equal.
    rewrite bdeletes_empty bdelete_empty. naive_solver by lia. }
  (* [Call]: the head by [IH], then the argument list by induction on it. *)
  { rewrite IH; last (simpl; lia).
    do 2 f_equal. induction l; first done.
    simpl. do 2 f_equal; naive_solver by lia. }
  (* [Let]: one binder removed from [∅]. *)
  { f_equal. naive_solver by lia. rewrite bdelete_empty. naive_solver by lia. }
  (* [Case]: one binder per branch. *)
  { f_equal. naive_solver by lia.
    all:rewrite bdelete_empty; naive_solver by lia. }
Qed.

(* Single-variable substitution [subst x v] coincides with [msubsts] on the
   singleton environment [{[x:=v]}]. *)
Lemma subst_msubsts x v e :
  subst x v e = msubsts {[x:=v]} e.
Proof.
  (* Same strong induction on [expr_size] as in [msubsts_empty]. *)
  induction e using (well_founded_induction (wf_inverse_image _ nat _ expr_size PeanoNat.Nat.lt_wf_0)).
  rename H into IH.
  destruct e; simpl; first done.
  (* Closure: case on whether [x] is among the closure's binders. *)
  { f_equal. destruct f; simpl. rewrite -bdeletes_cons.
    case_decide; f_equal.
    (* [x] is bound: deleting it empties the singleton map, so the body is
       left unchanged on both sides. *)
    { rewrite -(bdeletes_already_in x); last set_solver.
      rewrite bdelete_bdeletes. simpl.
      rewrite delete_singleton bdeletes_empty bdelete_empty msubsts_empty //. }
    (* [x] is not bound: the deletions leave the singleton map intact. *)
    { rewrite bdeletes_disj. eauto. rewrite binders_set_cons. set_solver. } }
  (* [Var]: compare the variable with [x]. *)
  { rewrite lookup_insert_case. case_decide; done. }
  (* [Call]: the head by [IH], then the argument list by induction on it. *)
  { rewrite IH. 2:(simpl; lia).
    f_equal. induction l. done. simpl. f_equal; naive_solver by lia. }
  (* Remaining constructors: rewrite every subterm with [IH]; side goals on
     sizes are closed by [lia]. *)
  all: try rewrite !IH; try (simpl; lia); try done.
  (* [Let]: either the binder shadows [x] (the map becomes empty), or it
     differs from [x] (the singleton survives [bdelete]). *)
  { f_equal. case_decide; subst; try done.
    { simpl. rewrite delete_singleton msubsts_empty //. }
    { rewrite bdelete_singleton_ne //. } }
  (* [Case]: the same reasoning, once for each branch binder. *)
  { f_equal.
    { case_decide; subst; simpl.
      { rewrite delete_singleton msubsts_empty //. }
      { rewrite bdelete_singleton_ne //. } }
    { case_decide; subst; simpl.
      { rewrite delete_singleton msubsts_empty //. }
      { rewrite bdelete_singleton_ne //. } } }
Qed.

(* Substituting with [m1 ∪ m2] is the same as substituting with [m1] first and
   then with [m2]. On shared keys the binding from [m1] wins (see the
   [lookup_union]/[union_Some_l] step). This works because the values put in
   place by [m1] are not traversed again by [msubsts m2]. *)
Lemma msubsts_union m1 m2 e :
  msubsts (m1 ∪ m2) e = msubsts m2 (msubsts m1 e).
Proof.
  revert m1 m2.
  (* Strong induction on [expr_size], generalized over both maps so that [IH]
     can be used with the shrunk maps under binders. *)
  induction e using (well_founded_induction (wf_inverse_image _ nat _ expr_size PeanoNat.Nat.lt_wf_0)); intros m1 m2.
  rename H into IH.
  destruct e; simpl; first done.
  all: try rewrite !IH; try (simpl; lia); try done.
  (* Closure: binder deletion distributes over the union. *)
  { f_equal. destruct f. simpl. rewrite bdeletes_union bdelete_union IH //.
    simpl. lia. }
  (* [Var]: if [m1] binds the variable, the resulting value is untouched by
     [m2]; otherwise the lookup falls through to [m2]. *)
  { rewrite lookup_union.
    destruct (m1!! s) eqn:E.
    { rewrite union_Some_l. destruct (m2!!s); simpl; rewrite ?E //. }
    { rewrite left_id //. } }
  (* [Call] argument list: induction on the list, using [IH] per element. *)
  { f_equal. induction l. done. simpl.
    rewrite IH; last (simpl; lia). f_equal.
    apply IHl. intros. apply IH. simpl in *. unfold "<$>" in *. lia. }
  (* [Let] and [Case]: binder deletion distributes over the union. *)
  { f_equal. rewrite bdelete_union IH //. simpl. lia. }
  { f_equal. all:rewrite bdelete_union IH //; simpl; lia. }
Qed.

(* Substituting with [<[s:=v]> m] amounts to substituting [s] first and then
   [m]. The inserted binding takes priority, since [<[s:=v]> m] is rewritten as
   the left-biased union [{[s:=v]} ∪ m]. *)
Lemma msubsts_insert s v m e :
  msubsts (<[s:=v]> m) e = msubsts m (subst s v e).
Proof.
  rewrite insert_union_singleton_l msubsts_union subst_msubsts //.
Qed.

(* Substituting with [delete s m] and then [s] agrees with [msubsts] on
   [<[s:=v]> m]. Once [s] is deleted from [m], the two maps are disjoint, so
   their union commutes. *)
Lemma insert_msubsts s v m e :
 subst s v (msubsts (delete s m) e) = msubsts (<[s:=v]>m) e.
Proof.
  rewrite -insert_delete_insert insert_union_singleton_l subst_msubsts -msubsts_union.
  (* Disjointness is shown through the domains. *)
  f_equal. apply map_union_comm. apply map_disjoint_dom_2.
  rewrite dom_delete_L dom_singleton_L. set_solver.
Qed.

(* Binder-level version of [insert_msubsts]. The first binder case is closed
   by computation; presumably [subst'], [bdelete] and [binsert] are no-ops on
   the anonymous binder. *)
Lemma binsert_msubsts s v m e :
 subst' s v (msubsts (bdelete s m) e) = msubsts (binsert s v m) e.
Proof.
  destruct s; first done. simpl. apply insert_msubsts.
Qed.

(* When [s] is not already bound in [m], the new binding can be substituted
   last instead of first (compare [msubsts_insert]). *)
Lemma msubsts_insert_notin s v m e :
  s ∉ dom m ->
  msubsts (<[s:=v]> m) e = subst s v (msubsts m e).
Proof.
  intros.
  (* Swap the union to put [m] first; this is valid because the singleton
     and [m] are disjoint. *)
  rewrite insert_union_singleton_l map_union_comm.
  2:{ apply map_disjoint_singleton_l_2. by apply not_elem_of_dom. }
  rewrite !msubsts_union subst_msubsts //.
Qed.

(* Binder-level version of [msubsts_insert]. The first binder case is closed
   by computation; presumably both sides ignore the anonymous binder. *)
Lemma msubsts_binsert x v m e :
  msubsts (binsert x v m) e = msubsts m (subst' x v e).
Proof.
  destruct x; first done. simpl. rewrite msubsts_insert //.
Qed.

(* Substituting with [m] and then with the pairs [zip xs ys] agrees with a
   single [msubsts] on the environment [extend xs ys m]. Requirements: [xs] and
   [ys] have the same length, and none of the binders in [xs] is in [dom m].
   NOTE(review): despite its name, the statement involves [extend], not
   [bdeletes]; the name is kept so that existing callers still work. *)
Lemma substs_msubsts_bdeletes xs ys m e :
  length xs = length ys ->
  binders_set xs ## dom m ->
  substs' (zip xs ys) (msubsts m e) =
  msubsts (extend xs ys m) e.
Proof.
  (* Reverse induction: peel off the last binder [x] of [xs]. *)
  revert ys e; induction xs using rev_ind; intros ys e Hl Hdisj.
  { destruct ys; done. }
  destruct (last ys) eqn:Hlast; last first.
  (* [ys] cannot be empty: it has as many elements as [xs ++ [x]]. *)
  { rewrite app_length in Hl. apply last_None in Hlast. subst.
    simpl in *. lia. }
  (* Split [ys] as [ys' ++ [_]] and align it with [xs ++ [x]]. *)
  { apply last_Some in Hlast. destruct Hlast as (ys'&->).
    rewrite !app_length in Hl. simpl in Hl.
    rewrite zip_app; last lia. simpl.
    rewrite substs'_app. simpl.
    rewrite /extend rev_zip.
    2:{ rewrite !app_length. simpl. lia. }
    rewrite !rev_app_distr zip_app; last by (simpl; lia).
    simpl. rewrite -rev_zip; last lia.
    rewrite binders_set_app in Hdisj.
    destruct x; simpl.
    (* First binder case (no insertion happens here): the induction
       hypothesis applies directly. *)
    { rewrite IHxs //.
      lia. set_solver. }
    (* Named binder: its insertion is moved past [m], which is allowed by
       [msubsts_insert_notin] because the name is not in [dom m] (from
       [Hdisj]); then the induction hypothesis applies. *)
    rewrite msubsts_insert -msubsts_insert_notin.
    2:{ rewrite /binders_set in Hdisj. simpl in *. set_solver. }
    rewrite -IHxs; only 2:lia; last set_solver.
    rewrite msubsts_insert //. }
Qed.
