Require Import ImpDefs.

From Coq Require Import Equality Ensembles Relations RelationClasses.

(* Lattice structure on labels, layered on top of a [LabelOrder]
   (presumably the source of [flows_to] and [reflect] -- both come from
   ImpDefs and are not visible here).  The fields require [flows_to] to be
   decidable and antisymmetric, with a least element, a greatest element,
   binary least upper bounds and binary greatest lower bounds. *)
Class LabelLattice {Label : Set} `(LabelOrder Label) := {
  (* [flows_to] is decidable and antisymmetric. *)
  flows_to_dec : forall l1 l2, {flows_to l1 l2}+{~(flows_to l1 l2)} ;
  flows_to_antisymm : forall l1 l2, flows_to l1 l2 -> flows_to l2 l1 -> l1 = l2 ;

  (* [bot] flows to every label; every label flows to [top]. *)
  bot : Label ;
  bot_least : forall l, flows_to bot l ;
  top : Label ;
  top_most : forall l, flows_to l top ;

  (* [join l1 l2] is the least upper bound of [l1] and [l2]. *)
  join : Label -> Label -> Label ;
  join_bound_l : forall l1 l2, flows_to l1 (join l1 l2) ;
  join_bound_r : forall l1 l2, flows_to l2 (join l1 l2) ;
  join_lub : forall l1 l2 l, flows_to l1 l -> flows_to l2 l -> flows_to (join l1 l2) l ;

  (* [meet l1 l2] is the greatest lower bound of [l1] and [l2]. *)
  meet : Label -> Label -> Label ;
  meet_bound_l : forall l1 l2, flows_to (meet l1 l2) l1 ;
  meet_bound_r : forall l1 l2, flows_to (meet l1 l2) l2 ;
  meet_glb : forall l1 l2 l, flows_to l l1 -> flows_to l l2 -> flows_to l (meet l1 l2) ;

  (* [reflect] satisfies a symmetric Galois-style equivalence:
     [l1] flows to [reflect l2] exactly when [l2] flows to [reflect l1]. *)
  reflect_galois: forall l1 l2, (flows_to l1 (reflect l2)) <-> (flows_to l2 (reflect l1)) ;
}.

(* Register the bound lemmas of the label lattice in the [core] hint
   database, so [auto] can close goals [flows_to bot l] and [flows_to l top]. *)
#[global] Hint Resolve bot_least top_most : core.

(* Signature for label inference over the Imp language of [ImpDefs].
   Given a labelling [G] of variables and a starting pc label, [infer]
   takes a command built from Skip/Assign/If/Seq/While, decides where
   ProgDown wrappers are needed, labels them, and returns the annotated
   command together with an [nt] label.  The module also defines a
   structural equivalence and a preorder on commands that relate
   commands differing only in their ProgDown annotations. *)
Module Type InferenceDefs (ID : ImpDefs).
  Import ID.

  Open Scope imp_scope.

  (* For inference to work, the label set needs to be a lattice. *)
  Parameter label_lattice : LabelLattice label_order.
  (* Expose [label_lattice] as a global typeclass instance, so that the
     lattice operations (bot, top, join, meet, flows_to_dec, ...) used
     below are resolved automatically. *)
  #[global] Instance label_lattice_inst : LabelLattice label_order.
    apply label_lattice.
  Defined.

  (* Option bind: [[p] <-- e1 | e2] evaluates [e2] with pattern [p] bound
     to the contents of [e1] when [e1] is [Some _], and is [None] when
     [e1] is [None].  Right associative, so binds can be chained. *)
  Local Notation "'[' p ']' '<--' e1 '|' e2" := (match e1 with | Some p => e2 | None => None end) (at level 90, right associativity, p pattern).

  (* [expr_label G e] is the label of expression [e] under the variable
     labelling [G]: literals get [bot], a variable [x] gets [G x], and an
     operator joins the labels of its two operands.  The result is [None]
     as soon as some variable of [e] has no label in [G]. *)
  Fixpoint expr_label G e : option Label :=
    match e with
      | Nat _ => Some bot
      | Var x => G x
      | Op _ e1 e2 => [l1] <-- expr_label G e1 |
                      [l2] <-- expr_label G e2 |
                      Some (join l1 l2)
    end.

  (* Intermediate result of inference.  It has the shape of the final
     command, with the positions of ProgDown wrappers already fixed
     ([PartProgDown]) but their labels not yet computed.  The [Label]
     fields are consumed by [set_pdown_labs], which joins them into the pc:
     - [PartIf e l c1 c2]: [l] is joined into the pc of both branches;
     - [PartSeq c1 c1nt c2]: [c1nt] is joined into the pc of [c2];
     - [PartWhile e inpc c]: [inpc] is joined into the pc of the body. *)
  Inductive PartialInferCmd :=
    | PartSkip : PartialInferCmd
    | PartAssign (x : Varname) (e : Expr) : PartialInferCmd
    | PartIf (e : Expr) (l : Label) (c1 : PartialInferCmd) (c2 : PartialInferCmd) : PartialInferCmd
    | PartSeq (c1 : PartialInferCmd) (c1nt : Label) (c2 : PartialInferCmd) : PartialInferCmd
    | PartWhile (e : Expr) (inpc : Label) (c : PartialInferCmd) : PartialInferCmd
    | PartProgDown (c : PartialInferCmd) : PartialInferCmd.

  (* [infer_pdown_locs G pc c] walks [c] under the pc label [pc] and
     returns [Some (c2, b, nt)], where:
     - [c2] is the [PartialInferCmd] version of [c], with [PartProgDown]
       markers inserted where needed;
     - [b] is built with [meet] from the labels [G x] of the variables
       assigned in [c] ([top] when nothing is assigned).  For loops it is
       additionally met with [reflect] of the loop pc.  In sequences it is
       compared against the [nt] of the preceding command;
     - [nt] is [bot] for Skip and Assign and is accumulated with [join]
       otherwise.  Presumably this is the label leaked through termination
       behaviour -- TODO confirm against the typing rules in ImpDefs.
     The result is [None] in any of these cases, anywhere inside [c]:
     - an expression or an assigned variable has no label in [G];
     - an assignment fails [flows_to (join pc l) gx];
     - a loop pc does not flow to its own [reflect];
     - [c] contains a command other than Skip/Assign/If/Seq/While. *)
  Fixpoint infer_pdown_locs G pc c :=
    match c with
      | Skip => Some (PartSkip, top, bot)
      (* Accepted only if both the pc and the expression label flow to
         the label of the target variable. *)
      | Assign x e => [l] <-- expr_label G e |
                      [gx] <-- G x |
                      if flows_to_dec (join pc l) gx
                      then Some (PartAssign x e, gx, bot)
                      else None
      (* Both branches are inferred under [join pc l].  If the joined
         [nt] of the branches does not flow to its own [reflect], only
         [c1] is wrapped in a ProgDown and only [nt2] is reported.
         NOTE(review): this treats the two branches asymmetrically --
         confirm this is intended. *)
      | If e c1 c2 => [l] <-- expr_label G e |
                      [(c1', b1, nt1)] <-- infer_pdown_locs G (join pc l) c1 |
                      [(c2', b2, nt2)] <-- infer_pdown_locs G (join pc l) c2 |
                      let joinnt := (join nt1 nt2) in
                      if flows_to_dec joinnt (reflect joinnt)
                      then Some (PartIf e l c1' c2', meet b1 b2, joinnt)
                      else Some (PartIf e l (PartProgDown c1') c2', meet b1 b2, nt2)
      (* Both halves are inferred under the same [pc].
         - If [nt1] flows to the bound [b2] of the second command, [nt1] is
           recorded, so [set_pdown_labs] joins it into the pc of [c2].
         - Otherwise the first command is wrapped in a ProgDown, [bot] is
           recorded instead, and the reported [nt] is [join pc nt2]. *)
      | Seq c1 c2 =>  [(c1', b1, nt1)] <-- infer_pdown_locs G pc c1 |
                      [(c2', b2, nt2)] <-- infer_pdown_locs G pc c2 |
                      if flows_to_dec nt1 b2
                      then Some (PartSeq c1' nt1 c2', meet b1 b2, join nt1 nt2)
                      else Some (PartSeq (PartProgDown c1') bot c2', meet b1 b2, join pc nt2)
      (* The loop pc [pcl] must flow to [reflect pcl], otherwise inference
         fails.  The body is inferred under [pcl].  If the [nt] of the body
         does not flow to its bound [b], the body is wrapped in a ProgDown.
         NOTE(review): here and in the Seq case, [set_pdown_labs] later
         labels the body (resp. the second command) with a pc that also
         includes the recorded [nt].  For loops that pc is
         [join pc (join l nt)], which can be higher than the pc used during
         inference -- confirm this is intended. *)
      | While e c1 => [l] <-- expr_label G e |
                      let pcl := join pc l in
                      if flows_to_dec pcl (reflect pcl)
                      then [(c', b, nt)] <-- infer_pdown_locs G pcl c1 |
                          if flows_to_dec nt b
                          then Some (PartWhile e (join l nt) c', meet b (reflect pcl), join nt pcl)
                          else Some (PartWhile e l (PartProgDown c'), meet b (reflect pcl), pcl)
                      else None
      (* Any other command (e.g. Stop, or an existing ProgDown) is not
         handled by inference. *)
      | _ => None
    end.

  (* [set_pdown_labs pc c] turns a [PartialInferCmd] into a [Cmd].  It
     threads the pc through [c], joining the stored labels of PartIf,
     PartSeq and PartWhile into the pc of their sub-commands.  Each
     [PartProgDown] node becomes a ProgDown labelled with the pc in effect
     at that point. *)
  Fixpoint set_pdown_labs pc c : Cmd :=
    match c with
      | PartSkip => Skip
      | PartAssign x e => Assign x e
      | PartIf e l c1 c2 => If e (set_pdown_labs (join pc l) c1) (set_pdown_labs (join pc l) c2)
      | PartSeq c1 c1nt c2 => Seq (set_pdown_labs pc c1) (set_pdown_labs (join pc c1nt) c2)
      | PartWhile e inpc c' => While e (set_pdown_labs (join pc inpc) c')
      | PartProgDown c' => ProgDown pc (set_pdown_labs pc c')
    end.

  (* Full inference.  First locate the ProgDown positions with
     [infer_pdown_locs], then label them with [set_pdown_labs], starting
     from the same [pc].  Returns the annotated command and the [nt]
     label; the [b] component is discarded. *)
  Definition infer G pc c : option (Cmd * Label) :=
    [(partc, _, nt)] <-- infer_pdown_locs G pc c |
    Some (set_pdown_labs pc partc, nt).

  (* Two commands have the same structure when they are equal except for
     the labels carried by ProgDown nodes.  ProgDown wrappers must occur at
     the same positions, but their labels may differ. *)
  Inductive EqCmdStructure : relation Cmd :=
    | EqSkip : EqCmdStructure Skip Skip
    | EqStop : EqCmdStructure Stop Stop
    | EqAssign : forall x e, EqCmdStructure (Assign x e) (Assign x e)
    | EqIf : forall e c1 c1' c2 c2',
        EqCmdStructure c1 c1' -> EqCmdStructure c2 c2' -> EqCmdStructure (If e c1 c2) (If e c1' c2')
    | EqSeq : forall c1 c1' c2 c2',
        EqCmdStructure c1 c1' -> EqCmdStructure c2 c2' -> EqCmdStructure (Seq c1 c2) (Seq c1' c2')
    | EqWhile : forall e c c', EqCmdStructure c c' -> EqCmdStructure (While e c) (While e c')
    | EqPDown : forall l l' c c', EqCmdStructure c c' -> EqCmdStructure (ProgDown l c) (ProgDown l' c').

  (* [EqCmdStructure] is an equivalence relation.  Reflexivity is proved by
     induction on the command, symmetry by induction on the derivation,
     and transitivity by induction on the first derivation with inversion
     on the second. *)
  #[global] Instance eq_cmd_equiv : Equivalence EqCmdStructure.
    split.
    * unfold Reflexive. induction x ; auto using EqCmdStructure.
    * unfold Symmetric. intros ? ? EqStruc. induction EqStruc ; auto using EqCmdStructure.
    * unfold Transitive. intros ? ? z EqStruc01. revert z.
      induction EqStruc01 ; intros ? EqStruc12 ; inversion EqStruc12 ; auto using EqCmdStructure.
  Defined.

  (* [PDownLe c c2] holds when [c2] can be obtained from [c] in two ways,
     possibly combined:
     - relabelling existing ProgDown nodes ([PdLePDownEq]);
     - wrapping sub-commands in additional ProgDown nodes ([PdLePDownLt]). *)
  Inductive PDownLe : relation Cmd :=
    | PdLeSkip : PDownLe Skip Skip
    | PdLeStop : PDownLe Stop Stop
    | PdLeAssign : forall x e, PDownLe (Assign x e) (Assign x e)
    | PdLeIf : forall e c1 c1' c2 c2',
        PDownLe c1 c1' -> PDownLe c2 c2' -> PDownLe (If e c1 c2) (If e c1' c2')
    | PdLeSeq : forall c1 c1' c2 c2',
        PDownLe c1 c1' -> PDownLe c2 c2' -> PDownLe (Seq c1 c2) (Seq c1' c2')
    | PdLeWhile : forall e c c', PDownLe c c' -> PDownLe (While e c) (While e c')
    | PdLePDownEq : forall l l' c c', PDownLe c c' -> PDownLe (ProgDown l c) (ProgDown l' c')
    | PdLePDownLt : forall l c c', PDownLe c c' -> PDownLe c (ProgDown l c').

  (* [PDownLe] is a preorder.  Transitivity needs [dependent induction] on
     the second derivation, because [PdLePDownLt] does not fix the shape of
     its left-hand command. *)
  #[global] Instance pdownle_preorder : PreOrder PDownLe.
    split.
    * unfold Reflexive. induction x ; auto using PDownLe.
    * unfold Transitive. intros x y z PDown01. revert z.
      induction PDown01 ; intros ? PDown12 ; dependent induction PDown12 ; eauto using PDownLe.
  Defined.

  (* Removes every ProgDown wrapper, recursing through If, Seq and While.
     All other commands are returned unchanged. *)
  Fixpoint erase_pdown c : Cmd :=
    match c with
      | If e c1 c2 => If e (erase_pdown c1) (erase_pdown c2)
      | Seq c1 c2 => Seq (erase_pdown c1) (erase_pdown c2)
      | While e c' => While e (erase_pdown c')
      | ProgDown l c' => erase_pdown c'
      | _ => c
    end.

End InferenceDefs.