{-# OPTIONS --experimental-lossy-unification --safe #-}

open import Prelude
open import FreeMonad.PackagedTheory
open import FinitenessConditions
open import Signatures
open import FreeMonad.Syntax
open import FreeMonad.Theory

module Effects.State.Tensor
  (S : Type)
  (choiceState : SplitQuotientedChoiceω S)
  (setState : isSet S)
  (ℱ-theory : FullTheory 0)
  where

import Effects.State S choiceState setState as 𝒮
open import FreeMonad.Combinations.Tensor ℱ-theory 𝒮.𝕊-theory

open FullTheory ℱ-theory using (𝔽) renaming (𝒯 to 𝒯𝒻)

open 𝒮 using (𝕊) renaming (StateTheory to 𝒯𝓈)

-- Carrier for state passing over raw syntax: built from the state type S and
-- `Syntax 𝔽 (A × S)`, i.e. syntax over the signature 𝔽 whose leaves pair an
-- A with an S.
-- NOTE(review): presumably the function type out of S — confirm.
State-Synt : Type a  Type _
State-Synt A = S  Syntax 𝔽 (A × S)

import FreeMonad.Quotiented ℱ-theory as 
import FreeMonad.Quotiented 𝒮.𝕊-theory as 𝒮t

-- Same shape as State-Synt, but over `ℱ.Term (A × S)` instead of raw syntax.
-- This is the carrier of every algebra defined below and the target of
-- runStateT.
-- NOTE(review): ℱ is presumably the FreeMonad.Quotiented instance for
-- ℱ-theory, so ℱ.Term would be terms quotiented by that theory — confirm.
State-Term : Type a  Type _
State-Term A = S  ℱ.Term (A × S)

-- Algebras on State-Term A for the state signature 𝕊 and for 𝔽, the proofs
-- that they satisfy the respective theories, and their combination into a
-- single algebra that models the combined theory 𝒯.
module _ {A : Type a} where
  -- Interpretation of the state operations on State-Term A:
  --   * `get   passes the current state s to the continuation and continues
  --            from the same state s;
  --   * `put s ignores the current state and continues from the stored s.
  runState-alg : 𝕊 -Alg[ State-Term A ]
  runState-alg (𝒮.`get   , k) s = k s  s
  runState-alg (𝒮.`put s , k) _ = k tt s

  -- Interpretation of an 𝔽-operation: every subcomputation is applied to the
  -- current state s (via cmap), and the result is handed to ℱ.opₜ.
  inj-alg : 𝔽 -Alg[ State-Term A ]
  inj-alg xs s = ℱ.opₜ (cmap  r  r s) xs)

  -- inj-alg satisfies the equations of 𝒯𝒻. Pointwise in the state s, both
  -- sides of equation i are rewritten with interp-fusion into
  -- interpretations by ℱ.opₜ, which are related by ℱ.opₜ-resp.
  inj-alg-resp : inj-alg Models 𝒯𝒻
  inj-alg-resp i t k = funExt λ s 
    let lhs  rhs = 𝒯𝒻 .Theory.Eqns i .eqn t
    in  interp 𝔽 inj-alg k lhs s       ≡⟨ interp-fusion 𝔽 (_$ s) _  _  refl) lhs 
        interp 𝔽 ℱ.opₜ (flip k s) lhs  ≡⟨ ℱ.opₜ-resp i t (flip k s) 
        interp 𝔽 ℱ.opₜ (flip k s) rhs  ≡˘⟨ interp-fusion 𝔽 (_$ s) _  _  refl) rhs 
        interp 𝔽 inj-alg k rhs s 

  -- runState-alg satisfies the state theory 𝒯𝓈: each of the three laws
  -- (put-put, put-get, get-put) is proved by refl.
  runState-resp : (runState-alg  𝕊 -Alg[ State-Term A ]) Models 𝒯𝓈
  runState-resp 𝒮.Synt.Laws.put-put t k = refl
  runState-resp 𝒮.Synt.Laws.put-get t k = refl
  runState-resp 𝒮.Synt.Laws.get-put t k = refl

  -- Algebra for the combined signature, assembled from inj-alg (for 𝔽) and
  -- runState-alg (for 𝕊) with _-⊞-_.
  runStateT-alg : (𝔽  𝕊) -Alg[ State-Term A ]
  runStateT-alg = inj-alg -⊞- runState-alg

  -- Interpretation of a pure value: pair it with the incoming state and
  -- return that pair in ℱ.
  runStateT-ret : A  State-Term A
  runStateT-ret x s = ℱ.return (x , s)

  -- runStateT-alg respects every instance of the `commutative` equation;
  -- both the `get and the `put case are proved by refl.
  runStateT-commutes :  t  runStateT-alg Respects commutative .eqn t
  runStateT-commutes (lift (xs , 𝒮.`get   )) k = refl
  runStateT-commutes (lift (xs , 𝒮.`put s )) k = refl

  -- runStateT-alg models the combined theory 𝒯, obtained from ⊗-resp applied
  -- to the two component algebras, the commutation proof and the two
  -- component model proofs.
  runStateT-resp : runStateT-alg Models 𝒯
  runStateT-resp = ⊗-resp inj-alg runState-alg  runStateT-commutes inj-alg-resp runState-resp 

-- Raw-syntax (unquotiented) versions of the state operations and of the
-- interpreter, over the combined signature of 𝔽 and 𝕊. SyntaxBind for that
-- signature is re-exported publicly, so its names are available as Synt.…
-- (the file uses Synt.>>=, Synt.>> and Synt.return).
module Synt where
  open SyntaxBind (𝔽  𝕊) public

  -- A single `put x node (injected with inr) whose continuation is var.
  put : S  Syntax (𝔽  𝕊) 
  put x = op (inr (𝒮.`put x) , var)

  -- A single `get node (injected with inr) whose continuation is var.
  get : Syntax (𝔽  𝕊) S
  get = op (inr 𝒮.`get , var)

  -- Interpret raw syntax into State-Term A using runStateT-alg for
  -- operations and runStateT-ret for variables.
  runStateT : Syntax (𝔽  𝕊) A  State-Term A
  runStateT = interp _ runStateT-alg runStateT-ret

open import FreeMonad.Quotiented combinedTheory renaming (Term to State)

-- The `put operation as an element of State: the class [_] of the raw
-- syntax tree Synt.put x.
put : S  State 
put x = [ Synt.put x ]

-- The `get operation as an element of State: the class [_] of the raw
-- syntax tree Synt.get.
get : State S
get = [ Synt.get ]

-- Two consecutive puts collapse to the second one, followed by any
-- continuation k. Obtained from the put-put law of the state theory
-- (instantiated at (s₁ , s₂)) via eq/, then extended along (_>> k).
put-put :  s₁ s₂ (k : State A)  put s₁ >> (put s₂ >> k)  put s₂ >> k
put-put s₁ s₂ k = cong (_>> k) (eq/ _ _ (lawₜ (inl (inr 𝒮.Synt.put-put)) (s₁ , s₂)))

-- Reading the state and immediately writing it back is a no-op in front of
-- any continuation k. Obtained from the get-put law of the state theory via
-- eq/, then extended along (_>> k).
get-put :  (k : State A)  (do s  get ; put s ; k)  k
get-put k = cong (_>> k) (eq/ _ _ (lawₜ (inl (inr 𝒮.Synt.get-put)) _))

-- A get directly after `put s` yields s: the continuation k may be applied
-- to s without reading the state. Obtained from the put-get law of the state
-- theory (instantiated at s) via eq/, then extended along (_>>= k).
put-get :  s (k : S  State A)  (do put s ; s′  get ; k s′)  (do put s ; k s)
put-get s k = cong (_>>= k) (eq/ _ _ (lawₜ (inl (inr 𝒮.Synt.put-get)) s))

-- A `put s` of the value just read by get can be dropped even when the
-- continuation k depends on s.
-- Proof outline, following the chain below:
--   1. introduce `s′ from return s` (holds definitionally, ≡⟨⟩);
--   2. under the outer get binder, use put-get backwards to turn that
--      return into a get;
--   3. remove the leading get/put pair with get-put, applied to the
--      continuation (get >>= k).
get-put-get : (k : S  State A) 
              (do s  get ; put s ; k s)   (do s  get ; k s)
get-put-get k =

  (do s  get
      put s
      k s)                                 ≡⟨⟩

  (do s  get
      put s
      s′  return s
      k s′)                                ≡˘⟨ s  get ;/ put-get s k 

  (do s  get
      put s
      s′  get
      k s′)                                ≡⟨ get-put (get >>= k) 

  (do s  get
      k s) 

-- Two consecutive gets read the same value, so they can be merged into one.
-- Proof outline, following the chain below:
--   1. prepend a get/put pair using get-put backwards;
--   2. under the outer get binder, apply put-get to resolve the first inner
--      get;
--   3. apply put-get again under the binder to resolve the remaining inner
--      get;
--   4. drop the now redundant put with get-put-get.
get-get :  (k : S  S  State B) 
          (do s₁  get
              s₂  get
              k s₁ s₂)   (do s  get
                              k s s)

get-get k =

  (do s₁  get
      s₂  get
      k s₁ s₂)          ≡˘⟨ get-put _ 

  (do s  get
      put s
      s₁  get
      s₂  get
      k s₁ s₂)          ≡⟨ s  get ;/ put-get s _ 

  (do s₁  get
      put s₁
      s₂  get
      k s₁ s₂)          ≡⟨ s₁  get ;/ put-get s₁ _ 

  (do s  get
      put s
      k s s)            ≡⟨ get-put-get _ 

  (do s  get
      k s s) 

-- A get whose result is discarded can be removed.
-- Proof outline: use get-put-get backwards (at a continuation that ignores
-- its argument) to insert a put of the value just read, then remove the
-- resulting get/put pair with get-put.
get-nop : (k : State A)  (get >> k)  k
get-nop k =

  (do get
      k)                  ≡˘⟨ get-put-get  _  k) 

  (do s  get
      put s
      k)                  ≡⟨ get-put k 

  k 

-- Embed an ℱ.Term into State. This is just lift₁ under a local name.
-- NOTE(review): lift₁ presumably comes from FreeMonad.Combinations.Tensor
-- (opened at the top of the file) — confirm.
liftT : ℱ.Term A  State A
liftT = lift₁

-- Relates liftT on a class [ xs ] to the class of the injected raw syntax
-- inj₁ xs. Proved with elim-s over xs:
--   * var case: refl;
--   * op case: sym (opₜ-com-[] _) combined (via `;`) with the hypothesis
--     P⟨xs⟩ on the subterms, lifted under the operation by cong and funExt.
liftT-lemma : (xs : Syntax 𝔽 A)  [ inj₁ xs ]  liftT [ xs ]
liftT-lemma = elim-s _ _ λ { (var x)  refl ; (op (O , xs) P⟨xs⟩)  sym (opₜ-com-[] _) ; cong (opₜ  _,_ (inl O)) (funExt P⟨xs⟩) }

-- A state operation Sₒ (injected with inr) and an 𝔽-operation O (injected
-- with inl) can be swapped: on the left Sₒ is outermost, on the right O is
-- outermost with the continuation's arguments flipped. This is `commutes`
-- with its arguments reordered and its direction reversed by sym.
state-commutes :  Sₒ O  (k : Signature.Arity 𝕊 Sₒ  Signature.Arity 𝔽 O  State A)  opₜ (inr Sₒ , λ s  opₜ (inl O , k s))  opₜ (inl O , λ i  opₜ (inr Sₒ , flip k i))
state-commutes Sₒ o k = sym (commutes o Sₒ (flip k))

-- Term-level commutation: a lifted state term (lift₂ xs) and a lifted
-- ℱ-term (lift₁ ys) can be sequenced in either order before the
-- continuation k. This is `commutes-mult` with its arguments reordered and
-- its direction reversed by sym.
state-commutes-mult : (xs : 𝒮t.Term A) (ys : ℱ.Term B) (k : A  B  State C)
                     (do x  lift₂ xs ; y  lift₁ ys ; k x y)  (do y  lift₁ ys ; x  lift₂ xs ; k x y)
state-commutes-mult xs ys k = sym (commutes-mult ys xs (flip k))

-- Pattern synonyms for the state operations as they appear in the combined
-- signature, i.e. wrapped in inr.
pattern `put s = inr (𝒮.`put s)
pattern `get   = inr  𝒮.`get

-- Correspondence between State A and the state-passing representation
-- State-Term A: the two conversion functions, and the proofs that they are
-- mutually inverse, packaged as stateT-iso.
module _ {A : Type a} where
  -- Turn a state-passing function into a State computation: read the
  -- current state s, run the lifted `f s` to obtain a result y and a state
  -- s′, write s′ back, and return y.
  toStateT : State-Term A  State A
  toStateT f =
    do s  get
       y , s′  liftT (f s)
       put s′
       return y

  -- Interpreter on State, defined with interpₜ from runStateT-alg,
  -- runStateT-ret, the model proof runStateT-resp, and a set-ness proof for
  -- the target built from isSetΠ and squash/.
  runStateT : State A  State-Term A
  runStateT = interpₜ runStateT-alg runStateT-ret runStateT-resp (isSetΠ λ _   squash/)

  -- Packages runStateT (field fun) and toStateT (field inv) together with
  -- the two round-trip proofs rightInv and leftInv.
  stateT-iso : State A  State-Term A
  stateT-iso .fun = runStateT
  stateT-iso .inv = toStateT

  -- rightInv: runStateT after toStateT is the identity. The argument is
  -- first reduced, via SplitQuotientedChoiceAt.elim~canonical at choiceState,
  -- to a function k : S-indexed family of raw syntax trees; the chain then
  -- works pointwise in the initial state s₁. Named lemmas used, in order:
  -- liftT-lemma, syntactic-bind, interp-comp, interp₁, interp-fusion (with
  -- ℱ.opₜ-com-[]), interp-id.
  stateT-iso .rightInv =
    SplitQuotientedChoiceAt.elim~canonical
      choiceState _  _  isSetΠ  _  squash/) _ _)
      λ k  funExt λ s₁ 

      runStateT (do s₂  get
                    y , s₃  liftT [ k s₂ ]
                    put s₃
                    return y) s₁                                                ≡˘⟨ cong (flip runStateT s₁) (s₂  get ;/ cong (_>>= _) (liftT-lemma (k s₂))) 

      runStateT (do s₂  get
                    y , s₃  [ inj₁ (k s₂) ]
                    put s₃
                    return y) s₁                                                ≡˘⟨ cong (flip runStateT s₁) (syntactic-bind _ Synt.get ; s₂  get ;/ syntactic-bind _ (inj₁ (k s₂)) ) 


      runStateT [ Synt.get Synt.>>=  s₂ 
                  inj₁ (k s₂) Synt.>>= λ { (y , s₃) 
                  Synt.put s₃ Synt.>>
                  Synt.return y
                }) ] s₁                                                         ≡⟨⟩

      Synt.runStateT (inj₁ (k s₁) Synt.>>=  { (y , s₃) 
                      Synt.put s₃ Synt.>>
                      Synt.return y })) s₁                                      ≡⟨⟩

      Synt.runStateT
        (interp _ op  { (y , s₃)  Synt.put s₃ Synt.>> Synt.return y })
          (inj₁ (k s₁))) s₁                                                     ≡⟨ cong (_$ s₁) (interp-comp _ runStateT-alg runStateT-ret _ (inj₁ (k s₁))) 

      interp _ runStateT-alg
        (Synt.runStateT   { (y , s₃)  Synt.put s₃ Synt.>> Synt.return y }))
        (inj₁ (k s₁)) s₁                                                        ≡⟨⟩

      interp _ runStateT-alg (const  ℱ.return) (inj₁ (k s₁)) s₁                ≡⟨ cong (_$ s₁) (interp₁ inj-alg runState-alg _ (k s₁)) 

      interp _ inj-alg (const  ℱ.return) (k s₁) s₁                             ≡˘⟨ cong (_$ s₁) (interp-fusion _ (const  [_]) Syntax.var  xs  funExt λ s  sym (ℱ.opₜ-com-[] _)) (k s₁)) 

      [ interp _ op Syntax.var (k s₁) ]                                         ≡⟨ cong [_] (interp-id _ (k s₁)) 

      [ k s₁ ] 

  -- leftInv: toStateT after runStateT is the identity. elimProp/ reduces the
  -- goal to representatives [ xs ] (the goal is a proposition by squash/),
  -- and elim-s then proceeds by induction on the raw syntax xs using the
  -- dependent algebra leftInv-alg below.
  stateT-iso .leftInv = elimProp/  _  squash/ _ _) (elim-s _ _ leftInv-alg)
    where
    -- One case per node shape: var, `get, `put, and an 𝔽-operation inl O.
    -- P⟨xs⟩ is the induction hypothesis for the subterms.
    leftInv-alg : DepAlg (𝔽  𝕊) A λ xs  toStateT (runStateT [ xs ])  [ xs ]

    -- var: after unfolding, what remains is a get followed by a put of the
    -- same state, which get-put removes.
    leftInv-alg (var x) =

      toStateT (runStateT (return x))              ≡⟨⟩

      toStateT  s  [ Syntax.var (x , s) ])      ≡⟨⟩

      (do s  get
          y , s′  liftT [ Syntax.var (x , s) ]
          put s′
          return y)                                ≡⟨⟩

      (do s  get
          put s
          return x)                                ≡⟨ get-put _ 

      return x 

    -- `get: split the single get into two with get-get (backwards), apply
    -- the induction hypothesis under the outer get, and fold the result back
    -- into a `get node with syntactic-bind.
    leftInv-alg (op (`get , k) P⟨xs⟩) =

      toStateT (runStateT [ Syntax.op (`get , k) ])                   ≡⟨⟩

      (do s  get
          y , s′  liftT (runStateT [ Syntax.op (`get , k) ] s)
          put s′
          return y)                                                   ≡⟨⟩

      (do s  get
          y , s′  liftT (runStateT [ k s ] s)
          put s′
          return y)                                                   ≡˘⟨ get-get _ 

      (do s₁  get
          s₂  get
          y , s′  liftT (runStateT [ k s₁ ] s₂)
          put s′
          return y)                                                   ≡⟨ s₁  get ;/ P⟨xs⟩ s₁ 

      (do s₁  get
          [ k s₁ ])                                                   ≡˘⟨ syntactic-bind k Synt.get 

      [ Syntax.op (`get , k) ] 


    -- `put x: drop the unused leading get (get-nop), insert `put x` before
    -- the final put (put-put backwards), move it in front of the lifted
    -- ℱ-term (state-commutes-mult backwards), reintroduce a get after it
    -- (put-get backwards), then apply the induction hypothesis under
    -- `put x`.
    leftInv-alg (op (`put x , k) P⟨xs⟩) =

      toStateT (runStateT [ op (`put x , k) ])                        ≡⟨⟩

      (do s₁  get
          y , s₂  liftT (runStateT [ op (`put x , k) ] s₁)
          put s₂
          return y)                                                   ≡⟨⟩

      (do s₁  get
          y , s₂  liftT (runStateT [ k tt ] x)
          put s₂
          return y)                                                   ≡⟨ get-nop _ 

      (do y , s₂  liftT (runStateT [ k tt ] x)
          put s₂
          return y)                                                   ≡˘⟨ (y , s₂)  liftT (runStateT [ k tt ] x) ;/ put-put _ _ (return y) 

      (do y , s₂  liftT (runStateT [ k tt ] x)
          put x
          put s₂
          return y)                                                   ≡˘⟨ state-commutes-mult [ 𝒮.Synt.put x ] (runStateT [ k tt ] x) _ 

      (do put x
          y , s₂  liftT (runStateT [ k tt ] x)
          put s₂
          return y)                                                   ≡˘⟨ put-get _ _ 

      (do put x
          s₁  get
          y , s₂  liftT (runStateT [ k tt ] s₁)
          put s₂
          return y)                                                   ≡⟨ cong (put x >>_) (P⟨xs⟩ tt) 

      (do put x
          [ k tt ])                                                   ≡⟨⟩

      [ op (`put x , k) ] 

    -- inl O (an 𝔽-operation): push liftT through the operation
    -- (lift₁-com), push the trailing bind under the operation (algebraic),
    -- swap the outer get with the operation (state-commutes), apply the
    -- induction hypothesis in every branch, and fold the operation back into
    -- the class with opₜ-com-[].
    leftInv-alg (op (inl O , k) P⟨xs⟩) =

      toStateT (runStateT [ Syntax.op (inl O , k) ])                     ≡⟨⟩

      (do s  get
          y , s′  liftT (ℱ.opₜ (O , λ i  runStateT [ k i ] s))
          put s′
          return y)                                                      ≡⟨ s  get ;/ cong (_>>= λ { (y , s′)  put s′ >> return y }) ((lift₁-com O (flip runStateT s  [_]  k))) 

      (do s  get
          y , s′  opₜ (inl O , λ i  liftT (runStateT [ k i ] s))
          put s′
          return y)                                                      ≡⟨ s  get ;/ algebraic (inl O)  { (y , s′)  put s′ >> return y }) (liftT  flip runStateT s  [_]  k) 

      (do s  get
          opₜ (inl O , λ i  do y , s′  liftT (runStateT [ k i ] s)
                                put s′
                                return y))                               ≡⟨ state-commutes 𝒮.`get O  s i  liftT (runStateT [ k i ] s) >>= λ ys′   put (snd ys′) >> return (fst ys′)) 

      opₜ (inl O , λ i  do s  get
                            y , s′  liftT (runStateT [ k i ] s)
                            put s′
                            return y)                                    ≡⟨ cong (opₜ  _,_ (inl O)) (funExt P⟨xs⟩) 

      opₜ (inl O , λ i  [ k i ])                                        ≡⟨ opₜ-com-[] k 

      [ Syntax.op (inl O , k) ]