From Coq Require Import Wellfounded ssreflect.
From stdpp Require Import strings binders gmap gmultiset ssreflect.

From dislog.lang Require Import syntax.

(******************************************************************************)
(* The set of names bound by a binder: a singleton for a named binder,
   empty for the anonymous one. *)
Definition binder_set (x:binder) : gset string :=
  match x with
  | BNamed name => {[name]}
  | BAnon => ∅
  end.

(* [subst x v e] replaces the occurrences of [Var x] in [e] by the value [v].
   The traversal stops below any binder that rebinds [x], and never looks
   inside a [Val].  Closures are delegated to [subst_clo]. *)
Fixpoint subst (x:string) (v:val) (e:expr) : expr :=
  match e with
  (* Leaves. *)
  | Val _ => e
  | Var y => if decide (x = y) then v else e
  (* Closures. *)
  | Clo c => Clo (subst_clo x v c)
  (* Binding constructs: a body is left untouched when its binder is [x]. *)
  | Let b e1 e2 =>
      Let b
        (subst x v e1)
        (if decide (b = BNamed x) then e2 else subst x v e2)
  | Case e0 bl el br er =>
      Case (subst x v e0)
        bl (if decide (BNamed x = bl) then el else subst x v el)
        br (if decide (BNamed x = br) then er else subst x v er)
  (* Purely structural cases: substitute in every sub-expression. *)
  | Call f args => Call (subst x v f) (subst x v <$> args)
  | CallPrim p e1 e2 => CallPrim p (subst x v e1) (subst x v e2)
  | If e1 e2 e3 => If (subst x v e1) (subst x v e2) (subst x v e3)
  | Alloc e1 e2 => Alloc (subst x v e1) (subst x v e2)
  | Load e1 e2 => Load (subst x v e1) (subst x v e2)
  | Store e1 e2 e3 => Store (subst x v e1) (subst x v e2) (subst x v e3)
  | Length e1 => Length (subst x v e1)
  | Prod e1 e2 => Prod (subst x v e1) (subst x v e2)
  | Fst e1 => Fst (subst x v e1)
  | Snd e1 => Snd (subst x v e1)
  | InL e1 => InL (subst x v e1)
  | InR e1 => InR (subst x v e1)
  | Par e1 e2 => Par (subst x v e1) (subst x v e2)
  | RunPar e1 e2 => RunPar (subst x v e1) (subst x v e2)
  | CAS e1 e2 e3 e4 =>
      CAS (subst x v e1) (subst x v e2) (subst x v e3) (subst x v e4)
  | Fold e1 => Fold (subst x v e1)
  | Unfold e1 => Unfold (subst x v e1)
  end
(* [subst_clo x v c] substitutes in the body of the function [c], unless [x]
   is the function's own name or one of its parameters, in which case [c] is
   returned unchanged. *)
with subst_clo x v (c:function) :=
  match c with
  | Lam self params body =>
      if decide (x ∈ binder_set self ∪ ⋃ (binder_set <$> params))
      then c
      else Lam self params (subst x v body)
  end.

(* Substitution for a binder: the anonymous binder substitutes nothing,
   a named binder behaves as [subst] on its name. *)
Definition subst' (x:binder) (v:val) (e:expr) :=
  match x with
  | BNamed name => subst name v e
  | BAnon => e
  end.

(* Iterated substitution: a right fold of [subst] over the list, so the last
   pair of [xlvs] is applied to [i] first and the head pair is applied last. *)
Definition substs (xlvs : list (string * val)) (i : expr) : expr :=
  foldr (λ '(name, w), subst name w) i xlvs.

(* Iterated substitution for binders: a right fold of [subst'] over the list,
   so the last pair of [xlvs] is applied to [i] first. *)
Definition substs' (xlvs : list (binder * val)) (i : expr) : expr :=
  foldr (λ '(b, w), subst' b w) i xlvs.

(* Substituting an appended list amounts to substituting the second list
   first, then the first list. *)
Lemma substs'_app xs ys e :
  substs' (xs ++ ys) e = substs' xs (substs' ys e).
Proof.
  (* After unfolding, this is exactly [foldr_app]. *)
  by rewrite /substs' foldr_app.
Qed.

(******************************************************************************)
(* [ih_for H x v e] adds to the context the location-inclusion statement for
   [subst x v e], proved by applying the hypothesis [H]; the side condition
   left by [H] is expected to be arithmetic and is discharged by [lia]. *)
Local Ltac ih_for H x v e :=
  assert (locs (subst x v e) ⊆ locs v ∪ locs e) by (apply H; simpl; lia).

(* Only an inclusion, not an equality: when x does not occur free in e,
   nothing is substituted and locs v need not be part of the result. *)
Lemma locs_subst x v e :
  locs (subst x v e) ⊆ locs v ∪ locs e.
Proof.
  (* Well-founded induction on expr_size: H gives the statement for every
     expression of strictly smaller size (presumably needed because Call
     nests a list of expressions -- TODO confirm against the definition). *)
  induction e using (well_founded_induction (wf_inverse_image _ nat _ expr_size PeanoNat.Nat.lt_wf_0)).
  (* Case analysis; for each sub-expression name that may exist in the
     current case, try to record its induction hypothesis. *)
  destruct e; simpl;
    try (ih_for H x v e); try (ih_for H x v e1);
    try (ih_for H x v e2); try (ih_for H x v e3); try (ih_for H x v e4).
  (* Goal carrying a function f (presumably the closure case): either the
     substitution is blocked by a binder, or H applies to the body. *)
  2:{ destruct f. unfold locs,location_expr. simpl. case_decide. set_solver.
    apply H. simpl. lia. }
  (* Two goals guarded by a single decide test. *)
  2,6: case_decide; set_solver.
  (* Call case: inner induction on the argument list l. *)
  2:{ induction l.
    { set_solver. }
    { (* Restate the size-indexed hypothesis for the shorter list so that
         the inner induction hypothesis IHl can be used. *)
      assert  (∀ y : expr, expr_size y < expr_size (Call e l) → locs (subst x v y) ⊆ locs v ∪ locs y) as IHt.
      { intros ? Ht. apply H.
        transitivity (expr_size (Call e l)); try easy.
        simpl. pose proof (expr_size_non_zero a). unfold "<$>". lia. }
      apply IHl in IHt. clear IHl.
      rewrite fmap_cons. unfold locs, locs_expr in *. simpl in *.
      (* Hypothesis for the head argument a. *)
      ih_for H x v a.
      set_solver. } }
  (* Goal with several decide tests (presumably the Case constructor). *)
  9:{ repeat case_decide; set_solver. }
  (* Every remaining goal follows from the recorded hypotheses. *)
  all:set_solver.
Qed.

(* Binder version of [locs_subst]. *)
Lemma locs_subst' x v e :
  locs (subst' x v e) ⊆ locs v ∪ locs e.
Proof.
  (* Anonymous binder: nothing is substituted.  Named binder: [locs_subst]. *)
  destruct x as [|name]; [set_solver | apply locs_subst].
Qed.

(* Iterated version of [locs_subst']: the locations after an iterated
   substitution come from the substituted values or from the expression. *)
Lemma locs_substs' xs e :
  locs (substs' xs e) ⊆ locs xs.*2 ∪ locs e.
Proof.
  (* Generalise over the expression before inducting on the list. *)
  revert e.
  induction xs as [|[b w] xs' IH]; intros e; simpl.
  - set_solver.
  - (* One step of [locs_subst'], then the induction hypothesis. *)
    etrans; [apply locs_subst' | set_solver].
Qed.

(******************************************************************************)

(* Free variables of an expression.  A [Val] contributes no variable; every
   binding construct removes the names of its binder(s) from the free
   variables of the body it scopes over. *)
Fixpoint fv (e:expr) : gset string :=
  match e with
  (* Leaves. *)
  | Val _ => ∅
  | Var y => {[y]}
  (* Closures are handled by [fv_func]. *)
  | Clo c => fv_func c
  (* Binding constructs. *)
  | Let b e1 e2 => fv e1 ∪ (fv e2 ∖ binder_set b)
  | Case e0 bl el br er =>
      fv e0 ∪ (fv el ∖ binder_set bl) ∪ (fv er ∖ binder_set br)
  (* Calls; the primitive of a [CallPrim] is not an expression. *)
  | Call e0 args => fv e0 ∪ ⋃ (fv <$> args)
  | CallPrim _ e1 e2 => fv e1 ∪ fv e2
  (* One sub-expression. *)
  | Length e1 | Fst e1 | Snd e1 | InL e1 | InR e1 => fv e1
  | Fold e1 | Unfold e1 => fv e1
  (* Two sub-expressions. *)
  | Alloc e1 e2 | Load e1 e2 | Prod e1 e2 => fv e1 ∪ fv e2
  | Par e1 e2 | RunPar e1 e2 => fv e1 ∪ fv e2
  (* Three and four sub-expressions. *)
  | If e1 e2 e3 | Store e1 e2 e3 => fv e1 ∪ fv e2 ∪ fv e3
  | CAS e1 e2 e3 e4 => fv e1 ∪ fv e2 ∪ fv e3 ∪ fv e4
  end
(* Free variables of a function: those of its body, minus the function's own
   name and its parameters. *)
with fv_func (f:function) : gset string :=
  match f with
  | Lam self params body =>
      fv body ∖ (binder_set self ∪ ⋃ (binder_set <$> params))
  end.

(* [go H]: peel the head constructor with f_equal, then close every resulting
   goal by applying H, proving the first premise of H with lia and the second
   with set_solver. *)
Local Ltac go H := f_equal; (apply H; simpl; [ lia | set_solver ]).

(* Substituting for a variable that is not free in e leaves e unchanged. *)
Lemma subst_not_in x v e :
  x ∉ fv e ->
  subst x v e = e.
Proof.
  (* Well-founded induction on expr_size: H gives the statement for every
     expression of strictly smaller size. *)
  induction e using (well_founded_induction (wf_inverse_image _ nat _ expr_size PeanoNat.Nat.lt_wf_0)).
  destruct e; simpl; intros ?.
  (* Cases that are a constructor applied to recursive results are closed
     directly; the goals below are the ones left over. *)
  all:try go H.
  (* Goal carrying a function f: either the binders block the substitution,
     or H applies to the body. *)
  { f_equal. destruct f. simpl. case_decide. done. go H. }
  (* Goal guarded by a single decide test; both branches follow from the
     non-membership hypothesis. *)
  { case_decide; set_solver. }
  (* Goal with an argument list l: inner induction on l, restating H for
     the shorter list. *)
  { f_equal. go H. induction l. done.
    simpl in *. f_equal. go H. apply IHl.
    { intros. apply H. unfold "<$>" in *. lia. done. }
    { set_solver. } }
  (* Goal with one binder b: nothing to do if the binder blocks the
     substitution, otherwise H applies to the body. *)
  { f_equal. go H. case_decide. done. destruct b; go H. }
  (* Goal with two binders b and b0, each guarding its own body. *)
  { f_equal. go H.
    { case_decide. done. apply H. simpl. lia.
      unfold binder_set in *.
      destruct b. all: simpl in *;set_solver. }
    { case_decide. done. apply H. simpl. lia.
      unfold binder_set in *.
      destruct b0,b. all: simpl in *;set_solver. } }
Qed.

(* Redefinition of [go] for the lemma below: f_equal is now optional, and
   applying H is expected to leave only goals that lia can solve. *)
Local Ltac go H ::= (try f_equal); apply H; simpl; lia.

(* Substitutions for two distinct variables commute.  No condition relates
   the variables to v1 and v2: subst never modifies a Val. *)
Lemma subst_subst_commut x1 v1 x2 v2 e :
  x1 ≠ x2 ->
  subst x1 v1 (subst x2 v2 e) = subst x2 v2 (subst x1 v1 e).
Proof.
  intros E.
  (* Well-founded induction on expr_size: H gives the statement for every
     expression of strictly smaller size. *)
  induction e using (well_founded_induction (wf_inverse_image _ nat _ expr_size PeanoNat.Nat.lt_wf_0)).
  destruct e; simpl.
  (* Cases that are a constructor applied to recursive results are closed
     directly; the goals below are the ones left over. *)
  all:try go H.
  (* Goal carrying a function f: each of the two substitutions may or may
     not be blocked by the binders, giving four combinations. *)
  { f_equal. destruct f. simpl.
    do 2 case_decide; simpl; try done.
    { rewrite !decide_True //. }
    { rewrite !(@decide_False _ (x1 ∈ _)) // !(@decide_True _ (x2 ∈ _)) //. }
    { rewrite !(@decide_False _ (x2 ∈ _)) // !(@decide_True _ (x1 ∈ _)) //. }
    { rewrite !decide_False //. go H. } }
  (* Goal guarded by equality tests against x1 and x2; since x1 ≠ x2, at
     most one of the tests can succeed. *)
  { case_decide; subst; simpl.
    { rewrite decide_False //. simpl. rewrite decide_True //. }
    { case_decide; try done. simpl. rewrite decide_False //. } }
  (* Goal with an argument list l: H for the head expression, then an inner
     induction on l. *)
  { f_equal.
    { apply H. simpl. lia. }
    { induction l. done. simpl. f_equal. apply H; simpl; lia.
      apply IHl. intros.  apply H. simpl. simpl in *. unfold "<$>" in *. lia. } }
  (* Goal with a single binder guarding one body. *)
  { case_decide; subst.
    { rewrite decide_False //. 2:naive_solver. go H. }
    { case_decide; subst; go H. } }
  (* Goal with two binders, each guarding its own body; the two bodies are
     handled the same way. *)
  { f_equal. apply H; simpl; lia.
    { case_decide; subst.
      { case_decide; try done. }
      { case_decide; try done. apply H; simpl; lia. } }
    { case_decide; subst.
      { case_decide; try done. }
      { case_decide; try done. apply H; simpl; lia. } } }
Qed.
