From Coq.ssr Require Import ssreflect.
From stdpp Require Import strings binders gmap gmultiset.

From dislog.utils Require Export graph.
From dislog.lang Require Export syntax syntax_instances substitution head_semantics.

(* A store is a finite map from locations to their stored contents
   ([storable]). Variables named [σ] implicitly have type [store]. *)
Notation store := (gmap loc storable).
Implicit Type σ : store.

Section gen.
(* Everything in this section is parametric in a countable type [A] of
   labels. A label [t : A] annotates each [head_step], and an [amap]
   records a label per location. *)
Context `{Countable A}.

(* [amap] maps each location to a label; [graph] is the [graph] structure
   from dislog.utils.graph instantiated at the label type [A]. *)
Notation amap  := (gmap loc A).
Notation graph := (graph A).

(* A location, a label and a function bundled together.
   NOTE(review): not referenced elsewhere in this file; presumably used by
   client modules. *)
Definition cloalloc : Type := (loc * A * function).

Implicit Type t : A.
Implicit Type α : amap.
Implicit Type G : graph.

(******************************************************************************)
(* Head semantics: one reduction step at the root of an expression. *)

(* [head_step G t σ α e σ' α' e'] holds when expression [e], run under label
   [t] with store [σ] and location-label map [α], reduces in one head step
   to [e'], producing store [σ'] and map [α'].
   - [G] is threaded through every rule but no rule inspects it.
   - Only [HeadAlloc] changes [α] (it records [t] for the fresh location).
     Only [HeadAlloc], [HeadStore] and [HeadCAS] may change [σ].
   - [HeadFunc], [HeadProd] and [HeadIn] do not extend [σ] or [α]. Each
     requires the resulting location to already hold the matching
     closure, pair or injection in [σ] and to be labelled [t] in [α].
   NOTE(review): [head_step_no_ctx] uses the auto-generated hypothesis name
   [H1], so reordering or renaming constructor binders can break it. *)
Inductive head_step : graph -> A -> store -> amap -> expr -> store -> amap -> expr -> Prop :=
(* Conditional on a boolean literal selects a branch. *)
| HeadIf : forall G t σ α (b:bool) e1 e2,
    head_step G t
      σ α (If b e1 e2)
      σ α (if b then e1 else e2)
(* [Let] with an evaluated binding substitutes the value into the body. *)
| HeadLet : forall G t σ α (v:val) x e,
    head_step G t
      σ α (Let x v e)
      σ α (subst' x v e)
(* Calling a code value directly. All arguments must be values and their
   number must match the parameters. The body [code] must satisfy
   [locs code = ∅] (presumably: it mentions no locations). [self] is bound
   to the code value itself and the parameters to the argument values. *)
| HeadCall : forall G t σ α ts vs self args code,
  ts = Val <$> vs ->
  length args = length ts ->
  locs code = ∅ ->
    head_step G t
      σ α (Call (VCode (Lam self args code)) ts)
      σ α (substs' (zip (self::args) (VCode (Lam self args code)::vs)) code)
(* Calling a closure stored at [l]. [self] is bound to the location [l]
   and the parameters to the argument values. *)
| HeadCallClo : forall G t σ α ts vs self args body l,
    σ !! l = Some (SClo self args body) ->
    ts = Val <$> vs ->
    length args = length ts ->
    head_step G t
      σ α (Call (VLoc l) ts)
      σ α (substs' (zip (self::args) (VLoc l::vs)) body)
(* Primitive call. The step is defined only when [eval_call_prim] returns
   a result. *)
| HeadCallPrim : forall G t σ α v1 v2 v p,
    eval_call_prim p v1 v2 = Some v ->
    head_step G t
      σ α (CallPrim p v1 v2)
      σ α v
(* A closure expression evaluates to a location already holding that
   closure and labelled [t]. Neither [σ] nor [α] is extended. *)
| HeadFunc : forall G t σ α (l:loc) self args code,
    α !! l = Some t ->
    σ !! l = Some (SClo self args code) ->
    head_step G t
      σ  α  (Clo (Lam self args code))
      σ  α (VLoc l)
(* Allocate a block of [n > 0] copies of [v] at a location [l] that is
   fresh in both [σ] and [α]. The new location is labelled [t]. *)
| HeadAlloc : forall G t σ σ' α α' (n:Z) (v:val) (l:loc),
    (0 < n)%Z ->
    l ∉ dom σ ->
    l ∉ dom α ->
    σ' = <[l:=SBlock (replicate (Z.to_nat n) v)]> σ ->
    α' = <[l:=t]> α ->
    head_step G t
      σ  α  (Alloc n v)
      σ' α' (Val l)
(* In-bounds read of index [i] of the block at [l]. *)
| HeadLoad : forall G t σ α (l:loc) (bs:list val) (i:Z) (v:val),
    σ !! l = Some (SBlock bs) ->
    (0 <= i < Z.of_nat (length bs))%Z ->
    bs !! (Z.to_nat i) = Some v ->
    head_step G t
      σ α (Load l i)
      σ α v
(* In-bounds write of [v] at index [i] of the block at [l]; returns unit. *)
| HeadStore : forall G t σ σ' α (l:loc) bs i (v:val),
    σ !! l = Some (SBlock bs) ->
    (0 <= i < Z.of_nat (length bs))%Z ->
    σ' = <[l := SBlock (<[Z.to_nat i := v]> bs)]> σ ->
    head_step G t
      σ  α (Store l i v)
      σ' α VUnit
(* A pair expression evaluates to a location already holding that pair
   and labelled [t]. Neither [σ] nor [α] is extended. *)
| HeadProd : forall G t σ α (v1 v2:val) (l:loc),
    α !! l = Some t ->
    σ !! l = Some (SProd v1 v2) ->
    head_step G t
      σ α (Prod v1 v2) σ α l
(* Projections out of a stored pair: [b = true] is [Fst], [b = false] is
   [Snd]. *)
| HeadProj : forall b G t σ α (v1 v2:val) (l:loc),
  σ !! l = Some (SProd v1 v2) ->
  head_step G t σ α (if b then Fst l else Snd l) σ α (if b then v1 else v2)
(* Length of the block stored at [l]. *)
| HeadLength : forall G t σ α (l:loc) bs,
    σ !! l = Some (SBlock bs) ->
    head_step G t
      σ α (Length l)
      σ α (VInt (Z.of_nat (length bs)))
(* Compare-and-swap at in-bounds index [i]. If the current value [v0]
   equals [v], the slot is overwritten with [v'], otherwise [σ] is
   unchanged. The result is the boolean outcome of that comparison. *)
| HeadCAS : forall G t σ σ' α (l:loc) (i:Z) (v v0 v':val) bs,
    (0 <= i < Z.of_nat (length bs))%Z ->
    σ !! l = Some (SBlock bs) ->
    bs !! (Z.to_nat i) = Some v0 ->
    σ' = (if bool_decide (v=v0)
          then (insert l (SBlock (<[Z.to_nat i := v']> bs)) σ) else σ) ->
    head_step G t
      σ  α (CAS l i v v')
      σ' α (Val (bool_decide (v=v0)))
(* Folding a value yields the folded value [VFold v]. *)
| HeadFold : forall G t σ α (v:val),
    head_step G t σ α (Fold v) σ α (VFold v)
(* Unfolding a folded value returns the underlying value. *)
| HeadUnfoldFold : forall G t σ α (v:val),
  head_step G t σ α (Unfold (VFold v)) σ α v
(* An injection ([b = true] for [InL], [b = false] for [InR]) evaluates to
   a location already holding that injection and labelled [t]. Neither
   [σ] nor [α] is extended. *)
| HeadIn : forall b G t σ α (v:val) (l:loc),
  α !! l = Some t ->
  σ !! l = Some (if b then SInL v else SInR v) ->
  head_step G t σ α (if b then InL v else InR v) σ α l
(* Case analysis on the injection stored at [l]. The payload is
   substituted into the matching branch. When [b = true] the variable [vr]
   is unconstrained, and symmetrically for [vl]. *)
| HeadCase : forall b G t σ α vl vr (l:loc) xl el xr er,
  σ !! l = Some (if b then SInL vl else SInR vr) ->
  head_step G t σ α (Case l xl el xr er) σ α (if b then subst' xl vl el else subst' xr vr er)
.

(* If [l] is [l1], then [x], then [l2], looking up index [length l1] in [l]
   returns [x]. *)
Local Lemma middle_list {X:Type} (l:list X) x l1 l2 :
  l = l1 ++ x::l2 ->
  l !! length l1 = Some x.
Proof.
  intros ->.
  rewrite list_lookup_middle; easy.
Qed.

(* If a list of values [Val <$> vs] splits as a prefix of values, an
   element [e] and a suffix, then [e] is itself a value. Used below in
   [head_step_no_ctx]. *)
Lemma must_be_val vs xs e ys :
  Val <$> vs = (Val <$> xs) ++ e :: ys ->
  is_val e.
Proof.
  intros E.
  apply middle_list in E.
  rewrite list_lookup_fmap in E.
  destruct (vs !! length (Val <$> xs)); naive_solver.
Qed.

(* No head step applies to [fill_item K e] (the expression [e] placed in a
   single evaluation-context frame [K]) while [e] is not a value. *)
Lemma head_step_no_ctx G t σ α K e σ' α' e' :
  ¬ is_val e ->
  ¬ head_step G t σ α (fill_item K e) σ' α' e'.
Proof.
  (* Invert the step, then case on the frame [K] (and on [b] for the rules
     indexed by a boolean). Goals left by [naive_solver] are finished with
     [must_be_val]. NOTE(review): [H1] is an auto-generated name and depends
     on the constructor binder order of [head_step]. *)
  intros ? Hstep. inversion Hstep; subst.
  all:destruct K; try destruct b; try naive_solver.
  all:inversion H1; eauto using must_be_val.
Qed.

(* An expression that can take a head step is not a value. *)
Lemma head_step_no_val G t σ α e σ' α' e' :
  head_step G t σ α e σ' α' e' ->
  ¬ is_val e.
Proof. inversion 1; try destruct b; eauto. Qed.

(* Head steps only extend the location-label map: [α ⊆ α']. Only
   [HeadAlloc] changes the map, by inserting a location fresh in [α].
   NOTE(review): [HeadAlloc] already requires [l ∉ dom α], so the [Hdom]
   hypothesis may be unnecessary; the proof only rewrites with it
   optionally. *)
Lemma head_step_inv_amap G t σ α e σ' α' e' :
  dom σ = dom α ->
  head_step G t σ α e σ' α' e' ->
  α ⊆ α'.
Proof.
  intros Hdom.
  inversion 1; eauto; subst.
  (* Remaining goals are closed by [insert_subseteq], using freshness of [l]. *)
  all:apply insert_subseteq; apply not_elem_of_dom; rewrite -?Hdom; eauto.
Qed.

End gen.

(* Register the constructors of [head_step] in the [head_step] hint
   database. The [#[export]] attribute makes the hints available to
   modules that import this one. *)
#[export] Hint Constructors head_step : head_step.

(* [isimmut s] is [true] exactly for closures, pairs and the two sum
   injections, and [false] for every other storable form (for example
   [SBlock], which the store-updating rules overwrite). *)
Definition isimmut (s:storable) : bool :=
  match s with
  | SClo _ _ _ => true
  | SProd _ _ => true
  | SInL _ => true
  | SInR _ => true
  | _ => false
  end.
