From iris.proofmode Require Import base proofmode.
From iris.base_logic.lib Require Export fancy_updates.

From dislog.utils Require Import graph.
From dislog.lang Require Import syntax.
From dislog.newlang Require Import semantics.

(* Generic part of the development: everything in this section is stated for
   an arbitrary countable type [A], which is the type of the labels carried by
   task trees and the codomain of the location maps [gmap loc A] below. *)
Section gen.
Context `{Countable A}.

(* ------------------------------------------------------------------------ *)
(* [rootsde]: compatibility of a task tree with a term. *)

(* [rootsde G α T e] relates a graph [G], a map [α] from locations to [A],
   a task tree [T] and an expression [e]. It is generated by three rules, one
   for leaves, one for plugging into a context item, and one for parallel
   pairs. The predicates [all_abef], [locs] and [frontier] are defined
   elsewhere. *)
Inductive rootsde : graph.graph A -> gmap loc A -> semantics_cycle.task_tree A -> expr -> Prop :=
(* A leaf [Leaf t] is compatible with [e] when [all_abef] holds at [t] for
   the locations of [e]. *)
| RDeLeaf : forall G α (t:A) (e:expr),
    all_abef G α t (locs e) ->
    rootsde G α (Leaf t) e
(* Compatibility survives plugging [e] into a context item [K], provided
   [all_abef] holds for the locations of [K] at every element of the frontier
   of [T]. *)
| RDeCtx : forall G α T (K:ctx) (e:expr),
    rootsde G α T e ->
    (set_Forall (fun t => all_abef G α t (locs K)) (frontier T)) ->
    rootsde G α T (fill_item K e)
(* A node is compatible with [RunPar e1 e2] when each subtree is compatible
   with the corresponding branch; the label [t] of the node is unconstrained. *)
| RDePar : forall G α T1 T2 (e1 e2:expr) t,
    rootsde G α T1 e1 ->
    rootsde G α T2 e2 ->
    rootsde G α (Node t T1 T2) (RunPar e1 e2).

(* [rootsde] is monotone in the graph: enlarging [G] to [G'] preserves
   compatibility. *)
Lemma rootsde_mon_graph G G' α T e :
  G ⊆ G' ->
  rootsde G α T e ->
  rootsde G' α T e.
Proof.
  (* Induction on the [rootsde] derivation; automation closes what it can. *)
  induction 2;
    eauto using RDeLeaf, RDeCtx, RDePar, all_abef_mon_graph.
  (* Remaining case, [RDeCtx]: each frontier obligation is weakened with
     [all_abef_mon_graph]. *)
  eapply RDeCtx; eauto.
  eapply set_Forall_impl; eauto. intros. eauto using all_abef_mon_graph.
Qed.

(* [rootsde] is monotone in the location map: extending [α] to [α'] preserves
   compatibility. *)
Lemma rootsde_mon_amap G α α' T e :
  α ⊆ α' ->
  rootsde G α T e ->
  rootsde G α' T e.
Proof.
  (* Induction on the [rootsde] derivation; automation closes what it can. *)
  induction 2;
    eauto using RDeLeaf, RDeCtx, RDePar, all_abef_mon_amap.
  (* Remaining case, [RDeCtx]: each frontier obligation is weakened with
     [all_abef_mon_amap]. *)
  eapply RDeCtx; eauto.
  eapply set_Forall_impl; eauto. intros. eauto using all_abef_mon_amap.
Qed.

(* If a node is compatible with a parallel pair whose two branches are values
   [v1] and [v2], then both subtrees of the node are leaves. *)
Lemma rootsde_join G α T1 T2 (v1 v2:val) t :
  rootsde G α (Node t T1 T2) (RunPar v1 v2) ->
  exists n1 n2, T1=Leaf n1 /\ T2=Leaf n2.
Proof.
  (* NOTE(review): this script refers to the hypothesis names [H4], [H8], [H]
     and [H1] generated by [inversion]; it will break if those names shift.
     [elim_ctx] is a tactic defined elsewhere; presumably it discharges the
     goals arising from the [RDeCtx] rule -- confirm. *)
  intros E. inversion E; subst; elim_ctx.
  (* Invert compatibility of each subtree in turn. *)
  inversion H4; subst.
  2:{ inversion H; elim_ctx. }
  inversion H8; subst.
  2:{ inversion H1; elim_ctx. }
  naive_solver.
Qed.

(* Inversion principle for parallel pairs: compatibility of [Node t T1 T2]
   with [RunPar e1 e2] yields compatibility of each subtree with the
   corresponding branch. *)
Lemma rootsde_par G α T1 T2 e1 e2 t :
  rootsde G α (Node t T1 T2) (RunPar e1 e2) ->
  rootsde G α T1 e1 /\ rootsde G α T2 e2.
Proof. inversion 1; elim_ctx. naive_solver. Qed.

(* An expression that is compatible with a [Node] is never a value. *)
Lemma rootsde_node_no_val G α T1 T2 e t :
  rootsde G α (Node t T1 T2) e ->
  ¬ is_val e.
Proof. inversion 1; intros ?; elim_ctx. eauto. Qed.

(* Inversion principle for leaves: if [Leaf t] is compatible with [e], then
   [all_abef] holds at [t] for all the locations of [e], whichever rules were
   used to derive the compatibility. *)
Lemma rootsde_leaf_inv G α t e:
  rootsde G α (Leaf t) e ->
  all_abef G α t (locs e).
Proof.
  (* Generalize the tree so that induction keeps the equation [T = Leaf t]. *)
  remember (Leaf t) as T.
  induction 1.
  (* [RDeLeaf]: the premise is the goal. *)
  { naive_solver. }
  (* [RDeCtx]: the locations of [fill_item K e] are split with
     [locs_fill_item]; the frontier of a leaf simplifies to a singleton. *)
  { subst. rewrite locs_fill_item.
    simpl in *. rewrite set_Forall_singleton in H1.
    apply all_abef_union; eauto. }
  (* [RDePar]: impossible, a [Node] is not a [Leaf]. *)
  { inversion HeqT. }
Qed.

(* Inversion under a context item: if [e] is not a value and [T] is compatible
   with [fill_item K e], then the frontier obligation for [K] holds and [T] is
   compatible with [e] itself. *)
Lemma rootsde_inv_ctx G α T K e :
  ¬ is_val e ->
  rootsde G α T (fill_item K e) ->
  (set_Forall (fun t => all_abef G α t (locs K)) (frontier T)) /\ rootsde G α T e.
Proof.
  intros Ht Hcomp.
  inversion Hcomp; subst.
  (* Derived by [RDeLeaf]: split the locations of the filled context. *)
  { rewrite locs_fill_item all_abef_union in H0. rewrite set_Forall_singleton.
    split; first naive_solver. apply RDeLeaf. naive_solver. }
  (* Derived by [RDeCtx]: case analysis on the tree. *)
  { destruct T.
    (* Leaf: go through [rootsde_leaf_inv] and rebuild with [RDeLeaf]. *)
    { apply rootsde_leaf_inv in Hcomp.
      rewrite locs_fill_item all_abef_union in Hcomp. simpl in *.
      rewrite set_Forall_singleton. destruct Hcomp.
      split; first done. apply RDeLeaf. naive_solver. }
    (* Node: the two decompositions agree by injectivity of [fill_item]; its
       side conditions follow from [rootsde_node_no_val]. *)
    { apply fill_item_inj in H0. destruct H0 as (?,?); subst.
      all:eauto using rootsde_node_no_val. } }
  (* Derived by [RDePar]: closed by [elim_ctx] after one more inversion. *)
  { inversion H; elim_ctx. }
Qed.

(* A location [l] is compatible with the leaf [Leaf t] in the map obtained by
   binding [l] to [t]. *)
Lemma rootsde_alloc G α t (l:loc) :
  rootsde G (<[l:=t]> α) (Leaf t) l.
Proof.
  apply RDeLeaf.
  (* The only location of the value [l] is [l] itself. *)
  replace (locs (Val l)) with ({[l]}:gset loc) by set_solver.
  apply set_Forall_singleton.
  apply abef_insert.
Qed.

(* [rootsde] coincides with the predicate [disentangled] (defined elsewhere). *)
Lemma disentangled_is_rootsde G α T e :
  disentangled G α T e <-> rootsde G α T e.
Proof.
  (* Both directions go by induction on the given derivation; most cases are
     closed by mapping one set of constructors to the other. *)
  split; intros X; induction X;
    eauto using DELeaf, RDeLeaf, DEBind, RDeCtx, DEPar, RDePar.
  (* Remaining case: case analysis on the tree. *)
  destruct T.
  { constructor. inversion IHX; try done. subst.
    simpl in *. rewrite set_Forall_singleton in H0.
    rewrite locs_fill_item. by apply set_Forall_union. }
  { constructor. naive_solver. done. }
Qed.

End gen.

(* ------------------------------------------------------------------------ *)
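(* [pureinv]: a record bundling the agreement of the domains of the store and
   of the [amap] with the [rootsde] compatibility of the task tree and the
   expression. *)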

(* [pureinv G α σ T e] packages two facts about a graph [G], a map [α], a
   store [σ], a task tree [T] and an expression [e]. *)
Record pureinv (G:graph) (α:amap) (σ:store) (T:task_tree) (e:expr) :=
  { (* The store and the map are defined on the same locations. *)
    pdom : dom σ = dom α;
    (* The task tree is compatible with the expression. *)
    pcmp : rootsde G α T e;
  }.

(* [pureinv] is monotone in the graph. The domain equation does not mention
   the graph, and compatibility is transported by [rootsde_mon_graph]. *)
Lemma pureinv_mon_graph G G' α σ T e :
  G ⊆ G' ->
  pureinv G  α σ T e ->
  pureinv G' α σ T e.
Proof.
  intros Hincl [Hdom Hcomp].
  constructor.
  { exact Hdom. }
  { eapply rootsde_mon_graph; eassumption. }
Qed.

(* [pureinv] is preserved when the expression is plugged into a context item
   [K], provided the frontier obligation of [RDeCtx] holds for [K]. *)
Lemma pureinv_ctx G α σ T K e :
  (set_Forall (λ t, all_abef G α t (locs K)) (frontier T)) ->
  pureinv G α σ T e ->
  pureinv G α σ T (fill_item K e).
Proof.
  intros Hfrontier [Hdom Hcomp].
  constructor.
  { exact Hdom. }
  { apply RDeCtx; assumption. }
Qed.

(* Reducing a conditional on a boolean value preserves [pureinv] at a leaf:
   whichever branch is selected, its locations are among those of the whole
   conditional. *)
Lemma pureinv_if G α σ t (b:bool) e1 e2:
  pureinv G α σ (Leaf t) (If b e1 e2) ->
  pureinv G α σ (Leaf t) (if b then e1 else e2).
Proof.
  intros [? Hcomp].
  constructor; eauto.
  apply rootsde_leaf_inv in Hcomp.
  apply RDeLeaf.
  (* Shrink the set of locations to those of the selected branch. *)
  eapply all_abef_mon_set. 2:eauto.
  destruct b; set_solver.
Qed.

(* Reducing [Unfold (VFold v)] to [v] preserves [pureinv] at a leaf. *)
Lemma pureinv_unfold_fold G α σ t (v:val):
  pureinv G α σ (Leaf t) (Unfold (VFold v)) ->
  pureinv G α σ (Leaf t) v.
Proof.
  intros [? X]. constructor. done.
  inversion X; subst.
  (* Derived by [RDeLeaf]. *)
  { constructor. naive_solver. }
  (* Derived by [RDeCtx]: the context item is determined by case analysis on
     [K]. NOTE(review): relies on the generated names [H], [H0], [H1]. *)
  { destruct K; inversion H. subst. constructor.
    inversion H0; subst. naive_solver. destruct K; inversion H1. }
Qed.

(* Reducing [Let x v e] to [subst' x v e] preserves [pureinv] at a leaf; the
   locations of the substituted body are bounded using [locs_subst']. *)
Lemma pureinv_let_val G α σ t x (v:val) (e:expr):
  pureinv G α σ (Leaf t) (Let x v e) ->
  pureinv G α σ (Leaf t) (subst' x v e).
Proof.
  intros [? Hcomp].
  constructor; eauto.
  apply rootsde_leaf_inv in Hcomp.
  apply RDeLeaf.
  eapply all_abef_mon_set; eauto using locs_subst'.
Qed.

(* Forking preserves [pureinv]: if [G] has edges from [t] to [v] and from [t]
   to [w], and [Par e1 e2] satisfies the invariant at [Leaf t], then the pair
   of calls [RunPar (Call e1 [()]) (Call e2 [()])] satisfies it at the node
   [Node t (Leaf v) (Leaf w)]. *)
Lemma pureinv_fork G α σ t e1 e2 v w :
  (t,v) ∈ G -> (t,w) ∈ G ->
  pureinv G α σ (Leaf t) (Par e1 e2) ->
  pureinv G  α σ (Node t (Leaf v) (Leaf w)) (RunPar (Call e1 [Val VUnit]) (Call e2 [Val VUnit])).
Proof.
  intros ?? [? Hcomp].
  apply rootsde_leaf_inv in Hcomp.
  (* Split the obligation on [Par e1 e2] into one per branch. *)
  rewrite locs_par all_abef_union in Hcomp. destruct Hcomp as (?&?).
  constructor; eauto.
  (* Both branches are handled identically: move the obligation from [t] to
     the child along the edge, after noting that calling on unit adds no
     location.
     NOTE(review): the source and target graphs of [all_abef_mon_graph] are
     both [G] here, so the [graph_fork_incl] hint looks unnecessary -- confirm. *)
  { apply RDePar; apply RDeLeaf.
    { apply all_abef_pre_reachable with t.
      { apply edge_reachable. done. }
      { replace (locs (Call e1 [Val VUnit])) with (locs e1) by set_solver.
        apply all_abef_mon_graph with G; eauto using graph_fork_incl. } }
    { apply all_abef_pre_reachable with t.
      { apply edge_reachable. done. }
      { replace (locs (Call e2 [Val VUnit])) with (locs e2) by set_solver.
        apply all_abef_mon_graph with G; eauto using graph_fork_incl. } } }
Qed.

(* Inversion of [pureinv] on a parallel pair: each branch satisfies the
   invariant with its own subtree, in the same store and map. *)
Lemma pureinv_par_inv G α σ T1 T2 e1 e2 t :
  pureinv G α σ (Node t T1 T2) (RunPar e1 e2) -> pureinv G α σ T1 e1 /\ pureinv G α σ T2 e2.
Proof.
  intros [Hdom Hcomp].
  apply rootsde_par in Hcomp as [Hleft Hright].
  split; constructor; assumption.
Qed.

(* Stepping the left branch of a parallel pair preserves [pureinv]: given a
   step of the left branch and the invariant for its result in the new state,
   the whole pair satisfies the invariant in the new state. *)
Lemma pureinv_par_l G α T1 T2 e1 e2 σ σ' α' T1' e1' t:
  step G σ α T1 e1 σ' α' T1' e1' ->
  pureinv G α σ (Node t T1 T2) (RunPar e1 e2) ->
  pureinv G α' σ' T1' e1' ->
  pureinv G α' σ' (Node t T1' T2) (RunPar e1' e2).
Proof.
  intros Hstep [? Hcomp] [] .
  constructor; eauto.
  { apply rootsde_par in Hcomp. destruct Hcomp.
    apply RDePar; eauto.
    (* The untouched right branch is transported from [α] to [α'] by
       monotonicity, using [step_inv_amap]. *)
    apply rootsde_mon_amap with α; eauto using step_inv_amap. }
Qed.

(* Stepping the right branch of a parallel pair preserves [pureinv]; this is
   the mirror image of [pureinv_par_l]. *)
Lemma pureinv_par_r G α T1 T2 e1 e2 σ σ' α' T2' e2' t:
  step G σ α T2 e2 σ' α' T2' e2' ->
  pureinv G α σ (Node t T1 T2) (RunPar e1 e2) ->
  pureinv G α' σ' T2' e2' ->
  pureinv G α' σ' (Node t T1 T2') (RunPar e1 e2').
Proof.
  intros Hstep [? Hcomp] [].
  constructor; eauto.
  { apply rootsde_par in Hcomp. destruct Hcomp.
    apply RDePar; eauto.
    (* The untouched left branch is transported from [α] to [α'] by
       monotonicity, using [step_inv_amap]. *)
    apply rootsde_mon_amap with α; eauto using step_inv_amap. }
Qed.

(* A value that is not a location has no locations. *)
Local Lemma locs_no_loc (v:val) :
  ¬ is_loc v ->
  locs (Val v) = ∅.
Proof. induction v; naive_solver. Qed.

(* A value [v] satisfying [vabef G α t v] satisfies [pureinv] at [Leaf t], as
   soon as the store and the map have the same domain. *)
Lemma pureinv_leaf_val G α σ t v :
  dom σ = dom α ->
  vabef G α t v ->
  pureinv G α σ (Leaf t) v.
Proof.
  intros.
  constructor; eauto.
  { apply RDeLeaf.
    (* Distinguish whether [v] is a location. *)
    destruct_decide (decide (is_loc v)) as Hv.
    (* NOTE(review): this branch depends on the generated name [H0]. *)
    { induction v; try done.
      rewrite /locs /location_expr. simpl.
      apply set_Forall_singleton. eauto. apply IHv. simpl in H0. done. set_solver. }
    (* Not a location: the set of locations is empty by [locs_no_loc]. *)
    { rewrite locs_no_loc //. } }
Qed.

(* Reducing a [Case] on a location [l] preserves [pureinv] at a leaf: for a
   value [v] satisfying [vabef G α t v], either branch with [v] substituted
   for its binder satisfies the invariant. The boolean [b] selects the branch
   and is not constrained by the statement. *)
Lemma pureinv_case G α σ t (b:bool) (l:loc) (v:val) xl el xr er:
  vabef G α t v ->
  pureinv G α σ (Leaf t) (Case l xl el xr er) ->
  pureinv G α σ (Leaf t) (if b then subst' xl v el else subst' xr v er).
Proof.
  intros ? [? Hcomp].
  constructor; eauto.
  apply rootsde_leaf_inv in Hcomp.
  apply RDeLeaf.
  (* Over-approximate the locations of the chosen branch by those of [v] and
     of both branch bodies. *)
  apply all_abef_mon_set with (L' := locs v ∪ locs el ∪ locs er).
  { pose proof locs_subst'. destruct b; set_solver. }
  rewrite -assoc_L. apply all_abef_union.
  split.
  (* Locations of [v]: by cases on whether [v] is a location. *)
  { destruct_decide (decide (is_loc v)) as Hv.
    { induction v; try done.
      rewrite /locs /location_expr. simpl.
      apply set_Forall_singleton. eauto. apply IHv; set_solver. }
    { assert (locs v = ∅) as ->. by apply locs_no_loc. done. } }
  (* Locations of the branch bodies: included in those of the [Case]. *)
  { eapply all_abef_mon_set. 2:eauto. set_solver. }
Qed.

(* Reducing the expression [Fold v] to the value [VFold v] preserves [pureinv]
   at a leaf. *)
Lemma pureinv_fold G α σ t (v:val) :
  pureinv G α σ (Leaf t) (Fold v) ->
  pureinv G α σ (Leaf t) (VFold v).
Proof.
  intros [X1 X2]. constructor. done.
  apply rootsde_leaf_inv in X2. constructor. done.
Qed.

(* A location [l] that [α] maps to [t] satisfies [pureinv] at [Leaf t], as
   soon as the store and the map have the same domain. *)
Lemma pureinv_immut G α σ t (l:loc) :
  dom σ = dom α ->
  α !! l = Some t ->
  pureinv G α σ (Leaf t) l.
Proof.
  intros ? Hl.
  constructor; eauto. constructor.
  rewrite /locs /all_abef.
  (* The only location of the value [l] is [l] itself. *)
  replace (location_expr l) with ({[l]} : gset loc) by set_solver.
  (* After unfolding [abef] and rewriting the lookup, the goal is reflexive. *)
  apply set_Forall_singleton. rewrite /abef Hl. reflexivity.
Qed.

(* An expression without locations satisfies [pureinv] at any leaf, as soon as
   the two maps have the same domain.
   NOTE(review): the conclusion is written [pureinv G σ α], whereas the record
   is declared with the arguments in the order [(α:amap) (σ:store)]. Since the
   binders [σ] and [α] carry no type annotation here, it looks like [σ] is
   inferred to be the [amap] and [α] the [store], i.e. the names are swapped
   relative to the rest of the file -- confirm before relying on the names. *)
Lemma pureinv_init G t e σ α :
  dom σ = dom α ->
  locs e = ∅ ->
  pureinv G σ α (Leaf t) e.
Proof.
  intros ? Ht.
  constructor.
  { done. }
  (* With no location, the [set_Forall] obligation is vacuous. *)
  { apply RDeLeaf. rewrite Ht. apply set_Forall_empty. }
Qed.

(* Inversion of [pureinv] under a context item [K]: assuming that the tree is
   a leaf whenever [e] is a value, the invariant for [fill_item K e] yields
   the frontier obligation for [K] and the invariant for [e]. *)
Lemma pureinv_bind G α σ T K e :
  (is_val e -> is_leaf T) ->
  pureinv G α σ T (fill_item K e) ->
  (set_Forall (fun t => all_abef G α t (locs K)) (frontier T)) /\ pureinv G α σ T e.
Proof.
  intros C [? Hcomp].
  destruct_decide (decide (is_val e)).
  (* [e] is a value: the tree is a leaf by hypothesis, so go through
     [rootsde_leaf_inv] and split the locations of the filled context. *)
  { destruct T; last naive_solver.
    apply rootsde_leaf_inv in Hcomp.
    rewrite locs_fill_item all_abef_union in Hcomp. simpl. rewrite set_Forall_singleton.
    split; first naive_solver.
    constructor; eauto.  apply RDeLeaf. naive_solver. }
  (* [e] is not a value: this is [rootsde_inv_ctx]. *)
  { apply rootsde_inv_ctx in Hcomp; eauto.
    split; first naive_solver.
    constructor; naive_solver.  }
Qed.


(* Calling through a location preserves compatibility at a leaf: if the call
   [Call l vs] is compatible with [Leaf t] and [all_abef] holds at [t] for the
   locations of [body], then so is [body] with [self] bound to [VLoc l] and
   [args] bound to the values [vs]. *)
Lemma roots_call `{Countable A} (l:loc) self args body (G:graph.graph A) α t vs :
  length args = length vs ->
  all_abef G α t (locs body) ->
  rootsde G α (Leaf t) (Call l (Val <$> vs)) ->
  rootsde G α (Leaf t) (substs' (zip (self :: args) (VLoc l :: vs)) body).
Proof.
  intros ? Hm Hr.
  apply rootsde_leaf_inv in Hr.
  (* The locations of the call are [l] together with those of the arguments. *)
  replace (locs (Call l (Val <$> vs))) with ({[l]} ∪ locs vs) in Hr.
  2:{ rewrite /locs /locs_expr. simpl. f_equal. rewrite location_list_val //. }
  apply RDeLeaf.
  (* Bound the locations of the substituted body with [locs_substs']. *)
  eapply all_abef_mon_set. apply locs_substs'.
  (* [snd_zip] needs the length side condition, discharged by [lia]. *)
  rewrite snd_zip.
  2:{ simpl. lia. }
  replace (locs (VLoc l::vs)) with ({[l]} ∪ locs vs) by set_solver.
  apply set_Forall_union; eauto.
Qed.

(* Lifting of [roots_call] to [pureinv]: calling through a location [l]
   preserves the invariant at a leaf. The store and the map are unchanged, so
   the domain equation carries over as is. *)
Lemma pureinv_call_clo (l:loc) self args body G α σ t vs :
  length args = length vs ->
  all_abef G α t (locs body) ->
  pureinv G α σ (Leaf t) (Call l (Val <$> vs)) ->
  pureinv G α σ (Leaf t) (substs' (zip (self :: args) (VLoc l :: vs)) body).
Proof.
  intros ? ? [].
  constructor; eauto using roots_call.
Qed.

(* Calling a code value preserves compatibility at a leaf: if the body has no
   locations and the call of [VCode (Lam self args body)] on the values [vs]
   is compatible with [Leaf t], then so is [body] with [self] bound to the
   code value and [args] bound to [vs].
   NOTE(review): the name [rootde_call] lacks the "s" of [rootsde] and
   [roots_call]; renaming would affect users of this lemma, so it is kept. *)
Lemma rootde_call `{Countable A} self args body (G:graph.graph A) α t vs :
  length args = length vs ->
  locs body = ∅ ->
  rootsde G α (Leaf t) (Call (VCode (Lam self args body)) (Val <$> vs)) ->
  rootsde G α (Leaf t) (substs' (zip (self :: args) ((VCode (Lam self args body)) :: vs)) body).
Proof.
  intros ? Hlb Hr.
  apply rootsde_leaf_inv in Hr.
  (* The locations of the call are those of the arguments only. *)
  replace (locs (Call (VCode _) (Val <$> vs))) with (locs vs) in Hr.
  2:{ rewrite /locs /locs_expr. simpl. f_equal. rewrite location_list_val. set_solver. }
  apply RDeLeaf.
  (* Bound the locations of the substituted body with [locs_substs']. *)
  eapply all_abef_mon_set. apply locs_substs'.
  (* [snd_zip] needs the length side condition, discharged by [lia]. *)
  rewrite snd_zip.
  2:{ simpl. lia. }
  (* The body contributes no location. *)
  rewrite Hlb right_id_L.
  apply set_Forall_union; eauto.
  rewrite /locs /location_val. simpl. easy.
Qed.

(* Lifting of [rootde_call] to [pureinv]: calling a code value whose body has
   no locations preserves the invariant at a leaf. The store and the map are
   unchanged, so the domain equation carries over as is. *)
Lemma pureinv_call self args body G α σ t vs :
  locs body = ∅ ->
  length args = length vs ->
  pureinv G α σ (Leaf t) (Call (VCode (Lam self args body)) (Val <$> vs)) ->
  pureinv G α σ (Leaf t) (substs' (zip (self :: args) ((VCode (Lam self args body)) :: vs)) body).
Proof.
  intros ? ? [].
  (* Only [rootde_call] is needed: the conclusion of [roots_call] substitutes
     [VLoc l] for [self], which cannot match the [VCode] value of this goal. *)
  constructor; eauto using rootde_call.
Qed.

(* A primitive operation never returns a location: whenever
   [eval_call_prim p v1 v2] succeeds with [v], the value [v] is not a
   location. *)
Lemma eval_call_prim_is_no_loc p v1 v2 v :
  eval_call_prim p v1 v2 = Some v -> ¬ is_loc v.
Proof.
  (* Case analysis on the primitive and on both arguments. *)
  intros E. destruct p,v1,v2; simpl in E; naive_solver.
Qed.

(* Allocation preserves [pureinv]: after binding a location [l] to [t] in the
   map and to [s] in the store, the location [l] satisfies the invariant at
   [Leaf t].
   NOTE(review): the freshness hypothesis [l ∉ dom σ] does not seem to be
   required by the steps below -- confirm whether it is kept for the
   convenience of callers. *)
Lemma pureinv_alloc G (l:loc) t α σ s :
  l ∉ dom σ ->
  dom σ = dom α ->
  pureinv G (<[l:=t]> α) (<[l:=s]> σ) (Leaf t) l.
Proof.
  intros.
  constructor; eauto using rootsde_alloc.
  (* Domain equation: both domains grow by the same location [l]. *)
  rewrite !dom_insert_L. set_solver.
Qed.

(* From sets to values: if the locations of a value [v] are contained in a set
   [g] for which [all_abef] holds at [t], then [vabef] holds at [t] for [v]. *)
Lemma all_abef_vabef `{Countable A} (G:graph.graph A) α t g v :
  locs v ⊆ g ->
  all_abef G α t g ->
  vabef G α t v.
Proof.
  (* All shapes of value but one are closed by [done]. *)
  induction v; try done.
  (* Remaining case: pick the element of [g] with [all_abef_elem].
     NOTE(review): relies on the generated name [H0]. *)
  intros. eapply all_abef_elem; last done.
  rewrite /locs /location_val in H0.
  set_solver.
Qed.

(* Frontier obligations survive a step: if [all_abef G α · S] holds on the
   whole frontier of [T] and [T] steps to [T'], then it holds for every
   element of the frontier of [T']. Note that the conclusion still refers to
   the graph [G] and the map [α] from before the step. *)
Lemma step_inv_reach t S σ α G T e σ' α' T' e' :
  step G σ α T e σ' α' T' e'  ->
  set_Forall (λ t : timestamp, all_abef G α t S) (frontier T) ->
  t ∈ frontier T'  ->
  all_abef G α t S.
Proof.
  (* Induction on the step derivation; some cases are closed by [eauto]. *)
  induction 1; intros Hforall Hu; eauto.
  (* Case with a premise [H] that is inverted. Two sub-cases remain, using
     [graph_fork_incl] and [graph_join_incl] respectively (presumably fork
     and join -- confirm): the obligation is moved along one edge with
     [all_abef_pre_reachable] and [rtc_once]. *)
  { inversion H; subst; eauto.
    { rewrite set_Forall_singleton in Hforall.
      eapply all_abef_pre_reachable with t0.
      { simpl in Hu. rewrite elem_of_union in Hu.
        apply rtc_once. set_solver. }
      { eapply all_abef_mon_graph; eauto using graph_fork_incl. } }
    { simpl in *. rewrite elem_of_singleton in Hu. subst.
      eapply all_abef_pre_reachable with t1.
      { apply rtc_once. set_solver. }
      { eapply all_abef_mon_graph; eauto using graph_join_incl.
        apply Hforall. set_solver. } } }
  all:simpl in Hu.
  (* The frontier is a union: [t] comes either from the side that stepped
     (induction hypothesis) or from the unchanged side (hypothesis on the
     old frontier). Here the first component stepped. *)
  { rewrite !elem_of_union in Hu.
    destruct Hu as [?|?].
    { apply IHstep; eauto. eapply set_Forall_union_inv_1. done. }
    { apply Hforall. set_solver. } }
  (* Symmetric case: the second component stepped. *)
  { rewrite !elem_of_union in Hu.
    destruct Hu as [?|?].
    { apply Hforall. set_solver. }
    { apply IHstep; eauto. eapply set_Forall_union_inv_2. done. } }
Qed.
