import Monad
import Maybe
import Lang

-- Here we add error handling.
-- The MonadPlus class introduces the functions
-- mplus :: m a -> m a -> m a
-- whose interpretation depends on the monad, for example
--    1) if the first branch fails, try the second one (our case), or
--    2) try both branches,
-- and mzero :: m a  (the neutral element for mplus).

-- | Outcome of a step: either an error carrying a message, or a value.
data Result x
  = Chyba String   -- ^ error alternative, carries the error message
  | Hodnota x      -- ^ successful value
  deriving (Show)

-- | Error monad: an error ('Chyba') short-circuits every later bind, and
-- 'fail' turns its message into an error.
instance Monad Result where
  return = Hodnota
  fail = Chyba
  m >>= k =
    case m of
      Hodnota a -> k a
      Chyba s   -> Chyba s

-- | Difference list: a list represented as a function that prepends its
-- elements to whatever tail it is given.  Appending to the end is then just
-- function composition; apply the function to @[]@ to get the plain list.
type Stream w = [w] -> [w]

-- | A computation with state, output and errors.  One step receives the
-- current state together with the output produced so far (as a 'Stream'
-- difference list).  It returns the new state, the extended output and the
-- 'Result' of the step.
--
-- This is a @newtype@ rather than @data@ because it only wraps a single
-- function.  The wrapper therefore has no runtime cost, and matching on 'V'
-- is free.  'unV' unwraps the step function.
newtype Vypocet w s x = V { unV :: (s, Stream w) -> (s, Stream w, Result x) }

-- | Sequencing threads the state and the output through both steps.  If
-- the first step fails, the state and output it left behind are kept.  The
-- error is propagated without running the continuation.
instance Monad (Vypocet w s) where
  return x = V (\(s, w) -> (s, w, Hodnota x))
  fail ch  = V (\(s, w) -> (s, w, Chyba ch))
  V step >>= k = V $ \stav ->
    case step stav of
      (stav', out', Chyba ch)  -> (stav', out', Chyba ch)
      (stav', out', Hodnota x) -> unV (k x) (stav', out')

-- | Choice: run the left branch first.  If it ends in an error, run the
-- right branch, starting from the state and output the left branch left
-- behind.  Otherwise keep the left branch's outcome.
-- 'mzero' fails with the message "mzero" and does not touch state or output.
instance MonadPlus (Vypocet w s) where
  mzero = V (\(s, w) -> (s, w, Chyba "mzero"))
  V left `mplus` V right = V choose
    where
      -- To discard the state changes and output of a failed left branch,
      -- call @right (s, w)@ instead of @right (s', w')@.
      choose (s, w) =
        case left (s, w) of
          (s', w', Chyba _) -> right (s', w')
          success           -> success

-- | Monads from which a value of type @s@ can be read; the monad @m@
-- determines @s@ (functional dependency).
class MonadRead m s | m -> s where
  get :: m s

-- | Reading yields the current state and leaves state and output untouched.
instance MonadRead (Vypocet w s) s where
  get = V (\(st, out) -> (st, out, Hodnota st))

-- | Monads whose value of type @s@ can be replaced; the monad @m@
-- determines @s@ (functional dependency).
class MonadState m s | m -> s where
  put :: s -> m ()

-- | Replaces the whole state with the given value and keeps the output.
instance MonadState (Vypocet w s) s where
  put new = V (\(_, out) -> (new, out, Hodnota ()))

-- | Monads that can emit output items of type @w@; the monad @m@
-- determines @w@ (functional dependency).
class MonadWrite m w | m -> w where
  write :: w -> m ()

-- | Appends one item after everything written so far.  The output is a
-- difference list, so this is a composition with @(item :)@.
instance MonadWrite (Vypocet w s) w where
  write item = V (\(st, out) -> (st, out . (item :), Hodnota ()))

-- | Runs a computation from the given initial state with empty output.
-- Returns the final state, the produced output as an ordinary list, and
-- the result.
runVypocet :: Vypocet w s x -> s -> (s, [w], Result x)
runVypocet (V step) initial =
  let (final, out, res) = step (initial, id)
  in (final, out [], res)

-- | Evaluates an expression in any monad that supports:
--
-- * failure and choice ('MonadPlus');
-- * reading and replacing the variable environment ('MonadRead' /
--   'MonadState' over 'Values');
-- * emitting integers ('MonadWrite').
--
-- Operands are always evaluated left to right.  Failures:
--
-- * "Deleni nulou" (division by zero) when the divisor of 'Div' or 'Mod'
--   is 0;
-- * "Neznama promenna <name>" (unknown variable) when a 'Var' is unbound.
eval :: (Monad m, MonadPlus m, MonadRead m Values, MonadState m Values, MonadWrite m Integer)
     => Expr -> m Integer
eval (Plus e1 e2)  = liftM2 (+) (eval e1) (eval e2)
eval (Minus e1 e2) = liftM2 (-) (eval e1) (eval e2)
eval (Mul e1 e2)   = liftM2 (*) (eval e1) (eval e2)
eval (Div e1 e2)   =
  eval e1 >>= \n ->
  eval e2 >>= \d ->
  if d == 0 then fail "Deleni nulou" else return (n `div` d)
eval (Mod e1 e2)   =
  eval e1 >>= \n ->
  eval e2 >>= \d ->
  if d == 0 then fail "Deleni nulou" else return (n `mod` d)
eval (Negate e)    = liftM negate (eval e)
eval (Num n)       = return n
eval (Var name)    =
  get >>= \env ->
  maybe (fail ("Neznama promenna " ++ name)) return (lookup name env)
-- The new value is computed before the environment is read, so an
-- assignment inside @e@ is visible in the environment that gets updated.
eval (Assign name e) =
  eval e >>= \r ->
  get >>= \env ->
  put (update env name r) >> return r
-- Emits the value and also yields it as the expression's result.
eval (Output e)    =
  eval e >>= \r ->
  write r >> return r
-- What "try" means is delegated entirely to the monad's 'mplus'.
eval (Try e1 e2)   = eval e1 `mplus` eval e2

-- | Binds @name@ to @val@ in an environment.  The first existing binding
-- of the name is replaced in place.  If the name is not bound, the new
-- binding is appended at the end.  All other bindings keep their order.
update :: Values -> Variable -> Integer -> Values
update env name val = go env
  where
    go [] = [(name, val)]
    go (binding@(key, _) : rest)
      | name == key = (name, val) : rest
      | otherwise   = binding : go rest

