From Coq Require Import Wellfounded.
From stdpp Require Import strings binders gmap gmultiset.

(* ------------------------------------------------------------------------ *)
(* Syntax of DisLog *)

(* Locations are modeled with Z, any countable set would work. *)
Inductive loc := to_loc : Z -> loc.
(* Inverse of [to_loc]: recovers the integer underlying a location. *)
Definition of_loc (l : loc) : Z :=
  match l with to_loc z => z end.

(* We inherit various instances from Z. *)
(* Decidable equality on [loc]. *)
#[export] Instance loc_eq_dec : EqDecision loc.
Proof. solve_decision. Qed.
(* [loc] is countable: [of_loc] injects it into [Z], with [to_loc] as a left
   inverse.
   NOTE(review): these instances are closed with [Qed] and are therefore
   opaque, so e.g. [decide] or [encode] on locations will not compute by
   reduction; switch to [Defined] if computation is ever needed. *)
#[export] Instance loc_countable : Countable loc.
Proof. apply inj_countable' with of_loc to_loc. now intros []. Qed.
(* [loc] is infinite: [to_loc] injects the infinite type [Z] into it, with
   partial inverse [fun x => Some (of_loc x)]. *)
#[export] Instance loc_infinite : Infinite loc.
Proof. apply inj_infinite with to_loc (fun x => Some (of_loc x)). easy. Qed.
(* Default inhabitant: [to_loc] applied to the inhabitant of [Z]. *)
#[export] Instance loc_inhabited : Inhabited loc := populate (to_loc inhabitant).

(* Primitives *)
(* Integer operators, used through [PrimIntOp]. *)
Inductive int_op := IntAdd | IntMul | IntSub | IntDiv | IntMod | IntMin | IntMax.
(* Integer comparisons, used through [PrimIntCmp]. *)
Inductive int_cmp := IntLt | IntLe | IntGt | IntGe.
(* Boolean connectives, used through [PrimBoolOp]. *)
Inductive bool_op := BoolAnd | BoolOr.

(* Primitive operations. They are applied to two argument expressions by the
   [CallPrim] constructor of [expr]. *)
Inductive prim :=
| PrimEq : prim
| PrimBoolOp : bool_op -> prim
| PrimIntOp : int_op -> prim
| PrimIntCmp : int_cmp -> prim.

(******************************************************************************)
(* The timestamp type is just a wrapper around nat. Any infinite countable type
   would work. *)

Inductive timestamp := to_timestamp : nat -> timestamp.
(* Inverse of [to_timestamp]: recovers the natural number behind a timestamp. *)
Definition of_timestamp (t : timestamp) : nat :=
  match t with to_timestamp n => n end.

(* We inherit various instances from nat. *)
(* Decidable equality on [timestamp]. *)
#[export] Instance eqdec_timestamp : EqDecision timestamp.
Proof. solve_decision. Qed.
(* [timestamp] is countable: [of_timestamp] injects it into [nat], with
   [to_timestamp] as a left inverse.
   NOTE(review): closed with [Qed], hence opaque; [decide]/[encode] on
   timestamps will not compute by reduction. Use [Defined] if needed. *)
#[export] Instance countable_timestamp : Countable timestamp.
Proof. apply inj_countable' with of_timestamp to_timestamp. now intros []. Qed.
(* [timestamp] is infinite: [to_timestamp] injects [nat] into it, with
   partial inverse [fun x => Some (of_timestamp x)]. *)
#[export] Instance infinite_timestamp : Infinite timestamp.
Proof. apply inj_infinite with to_timestamp (fun x => Some (of_timestamp x)). easy. Qed.
(* Default inhabitant: [to_timestamp] applied to the inhabitant of [nat]. *)
#[export] Instance inhabited_timestamp : Inhabited timestamp := populate (to_timestamp inhabitant).

(* Expressions, functions, and values are mutually inductive: an expression
   may embed a value ([Val]) or a function ([Clo]), and a value may embed a
   function ([VCode]). *)
Inductive expr : Set :=
(* Values *)
| Val : val -> expr
(* closures are heap allocated, so they take a step! *)
| Clo : function -> expr
| Var : string -> expr
(* Call of the first expression on the list of argument expressions. *)
| Call : expr -> list expr -> expr
(* A primitive (see [prim]) applied to two argument expressions. *)
| CallPrim : prim -> expr -> expr -> expr
| If : expr -> expr -> expr -> expr
| Let : binder -> expr -> expr -> expr
(* Product *)
| Prod : expr -> expr -> expr
| Fst : expr -> expr
| Snd : expr -> expr
(* Sum *)
| InL : expr -> expr
| InR : expr -> expr
| Case :
  expr -> (* the expression being cased on *)
  binder -> expr -> (* inl branch *)
  binder -> expr -> (* inr branch *)
  expr
(* Memory *)
| Alloc : expr -> expr -> expr
| Load : expr -> expr -> expr
| Store : expr -> expr -> expr -> expr
| Length : expr -> expr
(* Parallelism *)
| Par : expr -> expr -> expr (* A not-yet-running par, available in the user-level syntax. *)
| RunPar : expr -> expr -> expr (* A running par. *)
| CAS : expr -> expr -> expr -> expr -> expr
(* Types *)
| Fold : expr -> expr
| Unfold : expr -> expr
with function : Set :=
  Lam :
    binder -> (* the recursive name *)
    list binder -> (* the arguments *)
    expr -> (* the body *)
    function
with val : Set :=
| VUnit : val
| VBool : bool -> val
| VInt  : Z -> val
| VLoc : loc -> val
| VFold : val -> val
| VCode : function -> val.

(* Implicit coercions: booleans, locations, and integers can be written
   directly where a value is expected, and values and variable names (strings)
   directly where an expression is expected. *)
Coercion VBool : bool >-> val.
Coercion VLoc : loc >-> val.
Coercion VInt : Z >-> val.
Coercion Val : val >-> expr.
Coercion Var : string >-> expr.

(* ------------------------------------------------------------------------ *)
(* [to_val], [is_val], and [is_loc] *)

(* [to_val e] extracts the value carried by [e] when [e] is syntactically a
   value [Val w]; every other expression yields [None]. *)
Definition to_val (e : expr) : option val :=
  match e with
  | Val w => Some w
  | _ => None
  end.

(* [is_val e] is [true] exactly when [e] has the form [Val _]; it is the
   boolean counterpart of [to_val]. *)
Definition is_val (e : expr) : bool :=
  match e with
  | Val _ => true
  | _ => false
  end.

(* [is_val e] holds iff [e] is syntactically a value [Val v]. *)
Lemma is_val_true e :
  is_val e <-> exists v, e = Val v.
Proof. split; destruct e; naive_solver. Qed.
(* [e] is not a value iff [to_val] fails on it. *)
Lemma is_val_false e :
  ¬ is_val e <-> to_val e = None.
Proof. destruct e; naive_solver. Qed.
(* Left-to-right direction of [is_val_false], stated as an implication so it
   can be used directly with [apply] or [eauto]. *)
Lemma is_val_false1 e :
  ¬ is_val e ->
  to_val e = None.
Proof. intros. by apply is_val_false. Qed.

(* [is_loc v] is [true] when [v] is a location, possibly wrapped in any
   number of [VFold] constructors, and [false] otherwise. *)
Fixpoint is_loc (v : val) : bool :=
  match v with
  | VFold inner => is_loc inner
  | VLoc _ => true
  | _ => false
  end.

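(* Note: the lemma below is false as stated, because [is_loc] also accepts
   locations wrapped in [VFold] (e.g. [VFold (VLoc l)]). *)
(* Lemma is_loc_inv v :
  is_loc v -> exists l, v = VLoc l.
Proof. destruct v; naive_solver. Qed.
*)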

(* ------------------------------------------------------------------------ *)
(* The induction principle generated by Coq for [expr] is too weak, because
   [Call] carries a list of subexpressions. We therefore set up a size
   function to support well-founded induction.
 *)

(* [expr_size e] counts the syntax nodes of [e]: every node contributes 1.
   The count descends into closure bodies, whether the closure appears as
   [Clo _] or as a code value [Val (VCode _)]. The result is always positive
   and is used as a measure for well-founded induction on expressions.
   NOTE(review): a code value nested under [VFold] is not traversed, so
   [Val (VFold (VCode c))] has size 1 whatever [c] is. Confirm this is
   intended. *)
Fixpoint expr_size (e : expr) : nat := 1 +
  match e with
  | Call f args => expr_size f + list_sum (expr_size <$> args)
  | CAS a b c d => expr_size a + expr_size b + expr_size c + expr_size d
  | If a b c | Store a b c | Case a _ b _ c =>
      expr_size a + expr_size b + expr_size c
  | Prod a b | CallPrim _ a b | Let _ a b | Alloc a b | Load a b
  | Par a b | RunPar a b =>
      expr_size a + expr_size b
  | Fst a | Snd a | InL a | InR a | Fold a | Unfold a | Length a =>
      expr_size a
  | Clo f => clo_size f
  | Val (VCode f) => clo_size f
  | Val _ | Var _ => 0
  end
with clo_size (f : function) : nat :=
  match f with Lam _ _ body => expr_size body end.

(* Every expression has a positive size, since [expr_size] adds 1 at each
   node. *)
Lemma expr_size_non_zero e :
  expr_size e ≠ 0.
Proof. destruct e; simpl; lia. Qed.

(* ------------------------------------------------------------------------ *)
(* Contexts. *)

(* Contexts are syntactically non-recursive: each [ctx] is a single frame
   with exactly one hole, which [fill_item] plugs with an expression. *)
Inductive ctx : Set :=
| CtxCall1 : val -> list val -> list expr -> ctx (* call v (vs ++ ◻ :: es) *)
| CtxCall2 : list expr -> ctx (* call ◻ ts *)
| CtxCallPrim1 : prim -> expr -> ctx (* call_prim p ◻ e *)
| CtxCallPrim2 : prim -> val -> ctx (* call_prim p v ◻ *)
| CtxIf : expr -> expr -> ctx (* if ◻ then e1 else e2 *)
| CtxLet : binder -> expr -> ctx (* let x = ◻ in e2 *)
| CtxAlloc1 : expr -> ctx (* alloc ◻ e1 *)
| CtxAlloc2 : val -> ctx (* alloc v1 ◻ *)
| CtxLoad1 : expr -> ctx (* load ◻ e1 *)
| CtxLoad2 : val -> ctx (* load v1 ◻ *)
| CtxStore1 : expr -> expr -> ctx (* store ◻ e1 e2 *)
| CtxStore2 : val -> expr -> ctx (* store v1 ◻ e2 *)
| CtxStore3 : val -> val -> ctx (* store v1 v2 ◻ *)
| CtxLength : ctx (* length ◻ *)
| CtxPair1 : expr -> ctx (* prod ◻ e2 *)
| CtxPair2 : val -> ctx (* prod v1 ◻ *)
| CtxFst : ctx (* fst ◻ *)
| CtxSnd : ctx (* snd ◻ *)
| CtxInL : ctx (* inl ◻ *)
| CtxInR : ctx (* inr ◻ *)
| CtxCase : binder -> expr -> binder -> expr -> ctx (* case ◻ with inl xl => el, inr xr => er *)
| CtxCas1 : expr -> expr -> expr -> ctx (* cas ◻ e1 e2 e3 *)
| CtxCas2 : val -> expr -> expr -> ctx (* cas v1 ◻ e2 e3 *)
| CtxCas3 : val -> val -> expr -> ctx (* cas v1 v2 ◻ e3 *)
| CtxCas4 : val -> val -> val -> ctx (* cas v1 v2 v3 ◻ *)
| CtxPar1 : expr -> ctx (* par ◻ e1 *)
| CtxPar2 : val -> ctx (* par v1 ◻ *)
| CtxFold : ctx (* fold ◻ *)
| CtxUnfold : ctx (* unfold ◻ *)
.

(* [fill_item k e] plugs [e] into the hole of the one-frame context [k].
   Values stored in [k] are injected back into expressions with [Val]. *)
Definition fill_item (k : ctx) (e : expr) : expr :=
  match k with
  (* calls and primitives *)
  | CtxCall1 f vs rest => Call (Val f) ((Val <$> vs) ++ e :: rest)
  | CtxCall2 args => Call e args
  | CtxCallPrim1 p rhs => CallPrim p e rhs
  | CtxCallPrim2 p lhs => CallPrim p (Val lhs) e
  (* control *)
  | CtxIf th el => If e th el
  | CtxLet x body => Let x e body
  | CtxCase xl el xr er => Case e xl el xr er
  (* memory *)
  | CtxAlloc1 e2 => Alloc e e2
  | CtxAlloc2 v1 => Alloc (Val v1) e
  | CtxLoad1 e2 => Load e e2
  | CtxLoad2 v1 => Load (Val v1) e
  | CtxStore1 e2 e3 => Store e e2 e3
  | CtxStore2 v1 e3 => Store (Val v1) e e3
  | CtxStore3 v1 v2 => Store (Val v1) (Val v2) e
  | CtxLength => Length e
  | CtxCas1 e2 e3 e4 => CAS e e2 e3 e4
  | CtxCas2 v1 e3 e4 => CAS (Val v1) e e3 e4
  | CtxCas3 v1 v2 e4 => CAS (Val v1) (Val v2) e e4
  | CtxCas4 v1 v2 v3 => CAS (Val v1) (Val v2) (Val v3) e
  (* products, sums, and folds *)
  | CtxPair1 e2 => Prod e e2
  | CtxPair2 v1 => Prod (Val v1) e
  | CtxFst => Fst e
  | CtxSnd => Snd e
  | CtxInL => InL e
  | CtxInR => InR e
  | CtxFold => Fold e
  | CtxUnfold => Unfold e
  (* parallelism *)
  | CtxPar1 e2 => Par e e2
  | CtxPar2 v1 => Par (Val v1) e
  end.

(* Suppose two lists of the form [(Val <$> xs) ++ e :: ys] are equal, and
   neither [e] nor [e'] is a value. Then the value prefixes have the same
   length, i.e. the first non-value element sits at the same position in
   both lists. *)
Lemma ctx_list_length xs xs' e e' ys ys' :
  ¬ is_val e ->
  ¬ is_val e' ->
  (Val <$> xs) ++ e :: ys = (Val <$> xs') ++ e' :: ys' ->
  length xs = length xs'.
Proof.
  (* Generalize everything except [xs] so the induction hypothesis is usable
     for the tails. *)
  revert xs' e e' ys ys'.
  induction xs; intros.
  (* Case analysis on [xs'] in both the nil and cons cases. *)
  all: destruct xs'; naive_solver.
Qed.

(* [fill_item] is injective for non-values. *)
Lemma fill_item_inj K1 K2 e1 e2 :
  ¬ is_val e1 ->
  ¬ is_val e2 ->
  fill_item K1 e1 = fill_item K2 e2 ->
  K1=K2 /\ e1 = e2.
Proof using.
  intros ? ? E.
  (* [Val] is injective; needed for [list_fmap_eq_inj] at the end. *)
  assert (Inj eq eq Val) as Hinj.
  { intros ? ? Heq'. injection Heq'. easy. }
  destruct K1,K2; inversion E; subst; simpl in *; try naive_solver.
  (* The remaining case compares two [CtxCall1] contexts: split the argument
     lists at the hole, using that the value prefixes have equal length.
     NOTE(review): this relies on auto-generated names ([l], [l1], [H3]),
     which is fragile across Coq/stdpp versions. *)
  assert (length l = length l1).
  { eapply (ctx_list_length _ _ e1 e2); eauto. }
  apply app_inj_1 in H3. 2:{ now do 2 rewrite fmap_length. }
  destruct H3  as (Hl1&Hl2). injection Hl2. clear Hl2. intros -> ->.
  split; try easy.
  apply list_fmap_eq_inj in Hl1; naive_solver.
Qed.

(******************************************************************************)
(* Storables: what can be stored in the heap. *)

Inductive storable :=
| SBlock : list val -> storable (* a block of values; presumably created by [Alloc] -- confirm *)
| SProd : val -> val -> storable (* a pair of values *)
| SInL : val -> storable (* a left injection *)
| SInR : val -> storable (* a right injection *)
| SClo : binder -> list binder -> expr -> storable. (* a closure, with the same fields as [Lam] *)

(******************************************************************************)
(* Locations *)

(* We define a typeclass for [locs], a function returning a set of locations. *)
(* A definitional class: an instance is just the function [locs]. *)
Class Location A := locs : A -> gset loc.

(* Lists: the union of the locations of the elements. *)
Global Instance location_list  `{Location A} : Location (list A) :=
  fun xs => ⋃ (locs <$> xs).
(* A location contributes the singleton set containing itself, via stdpp's
   singleton instance for [gset]. *)
Global Instance location_loc : Location loc := gset_singleton.

(* [locs_val v] is the set of locations occurring in the value [v]. It looks
   through [VFold]; every other value, including a code value [VCode _],
   contributes no location.
   NOTE(review): locations inside the body of a [VCode] function are not
   collected here; presumably code values are closed. Confirm. *)
Fixpoint locs_val (v : val) : gset loc :=
  match v with
  | VFold inner => locs_val inner
  | VLoc l => {[l]}
  | _ => ∅
  end.

Global Instance location_val : Location val := locs_val.

(* [locs_expr e] is the set of locations occurring syntactically in [e].
   Values are handled by [locs_val], which skips code values. Closures
   written as [Clo _] are traversed through [locs_func]. Variables contribute
   nothing. *)
Fixpoint locs_expr (e : expr) : gset loc :=
  match e with
  | Var _ => ∅
  | Val v => locs_val v
  | Clo f => locs_func f
  | Call callee args => locs_expr callee ∪ ⋃ (locs_expr <$> args)
  | Fst a | Snd a | InL a | InR a | Fold a | Unfold a | Length a => locs_expr a
  | Prod a b | CallPrim _ a b | Let _ a b | Alloc a b
  | Load a b | Par a b | RunPar a b =>
      locs_expr a ∪ locs_expr b
  | If a b c | Store a b c | Case a _ b _ c =>
      locs_expr a ∪ locs_expr b ∪ locs_expr c
  | CAS a b c d => locs_expr a ∪ locs_expr b ∪ locs_expr c ∪ locs_expr d
  end
with locs_func (f : function) : gset loc :=
  match f with
  | Lam _ _ body => locs_expr body
  end.

Global Instance location_expr : Location expr := locs_expr.
Global Instance location_clo : Location function := locs_func.

(* [locs_ctx k] is the set of locations occurring in the subterms stored in
   the context [k]. The hole itself contributes nothing. *)
Definition locs_ctx (k : ctx) : gset loc :=
  match k with
  (* frames that store no subterm *)
  | CtxLength | CtxFst | CtxSnd | CtxInL | CtxInR | CtxFold | CtxUnfold => ∅
  (* calls and primitives *)
  | CtxCall1 f vs ts => locs f ∪ locs vs ∪ locs ts
  | CtxCall2 ts => locs ts
  | CtxCallPrim1 _ e2 => locs e2
  | CtxCallPrim2 _ v1 => locs v1
  (* control *)
  | CtxIf e2 e3 => locs e2 ∪ locs e3
  | CtxLet _ body => locs body
  | CtxCase _ el _ er => locs el ∪ locs er
  (* memory *)
  | CtxAlloc1 e2 => locs e2
  | CtxAlloc2 v1 => locs v1
  | CtxLoad1 e2 => locs e2
  | CtxLoad2 v1 => locs v1
  | CtxStore1 e2 e3 => locs e2 ∪ locs e3
  | CtxStore2 v1 e3 => locs v1 ∪ locs e3
  | CtxStore3 v1 v2 => locs v1 ∪ locs v2
  | CtxCas1 e2 e3 e4 => locs e2 ∪ locs e3 ∪ locs e4
  | CtxCas2 v1 e3 e4 => locs v1 ∪ locs e3 ∪ locs e4
  | CtxCas3 v1 v2 e4 => locs v1 ∪ locs v2 ∪ locs e4
  | CtxCas4 v1 v2 v3 => locs v1 ∪ locs v2 ∪ locs v3
  (* products and parallelism *)
  | CtxPair1 e2 => locs e2
  | CtxPair2 v1 => locs v1
  | CtxPar1 e2 => locs e2
  | CtxPar2 v1 => locs v1
  end.

Global Instance location_ctx : Location ctx := locs_ctx.

(* The union of the locations of the expressions [Val v], for [v] in [l],
   equals the locations of the value list [l]. Used for the [CtxCall1] case
   of [locs_fill_item]. *)
Local Lemma union_list_locs_particular l :
  ⋃ (locs_expr <$> (Val <$> l)) = locs l.
Proof. induction l; set_solver. Qed.

(* The locations of a filled frame are those of the frame plus those of the
   plugged expression. *)
Lemma locs_fill_item K e :
  locs (fill_item K e) = locs K ∪ locs e.
Proof.
  destruct K; try set_solver; simpl.
  (* Exactly one case survives [set_solver]. The rewrites on list
     concatenation below show it is the [CtxCall1] case. *)
  all: unfold locs,location_expr,location_ctx; simpl.
  rewrite fmap_app, fmap_cons.
  rewrite union_list_app_L. simpl.
  rewrite union_list_locs_particular. set_solver.
Qed.

(* The locations of a [Par] node are those of its two branches; this holds by
   computation alone. *)
Lemma locs_par e1 e2 : locs (Par e1 e2) = locs e1 ∪ locs e2.
Proof. exact eq_refl. Qed.

(* Same fact as the [Local] lemma [union_list_locs_particular], but stated
   with the instance names [location_expr] and [location_list] and exported. *)
Lemma location_list_val vs :
  ⋃ (location_expr <$> (Val <$> vs)) = location_list vs.
Proof. induction vs; set_solver. Qed.

(* Filling a frame never produces a value: every [fill_item] case builds a
   constructor other than [Val], on which [to_val] computes to [None]. *)
Lemma to_val_fill_item K e : to_val (fill_item K e) = None.
Proof. destruct K; reflexivity. Qed.

(* If [to_val e] returns [Some v], then [e] is exactly [Val v]. *)
Lemma to_val_Some_inv e v : to_val e = Some v -> e = Val v.
Proof. destruct e; naive_solver. Qed.

(* ------------------------------------------------------------------------ *)

(* elim_ctx_sure tries to find a context, destruct it, and solve the goal. *)
Ltac elim_ctx_sure :=
  (* If [naive_solver] fails after destructing one [ctx] hypothesis,
     [match goal] backtracks and tries the next [ctx] hypothesis. *)
  match goal with
  | K : ctx |- _ => destruct K; naive_solver end.
(* [elim_ctx] tries to close the current goal by contradiction. It changes
   the goal to [False] and runs [elim_ctx_sure]. If that fails, [try] leaves
   the original goal untouched, so [elim_ctx] itself never fails. *)
Ltac elim_ctx := try (exfalso; elim_ctx_sure).
