{-
    Kaya - My favourite toy language.
    Copyright (C) 2004-2007 Edwin Brady

    This file is distributed under the terms of the GNU General
    Public Licence. See COPYING for licence.
-}

-- Pattern matching compiler, generating simple case trees from
-- case on complex expressions. See notes/pmcomp.txt for details

module PMComp where

import Language

import Control.Monad.State
import Debug.Trace
import List(group,sort)

-- | Compile a case over a single scrutinee @v@ into a simple case tree.
-- Unless some alternative already consists of a single wildcard or variable
-- pattern, an alternative throwing 'missingCase' is appended at the end.
-- Returns the compiled term and the next unused fresh-variable index.
mkCase :: String -> Int -> Context ->
          Name ->  -- presumably the current module; used for context lookups
                   -- to tell constructor names from variables
          Int ->   -- first fresh variable index to use; several cases may be
                   -- compiled per function, so the names must stay unique
          Raw -> [MatchAlt] -> (Raw, Int)
mkCase f l ctxt mod name v ms = (tm, nextVar final)
  where
    (tm, final) = runState (caseComp f l ctxt mod [v] (withCatchAll ms))
                           (CS name)

    withCatchAll alts
        | any isCatchAll alts = alts
        | otherwise           = alts ++ [catchAll]

    -- A lone wildcard, qualified name, or non-constructor name matches anything.
    isCatchAll (MAlt _ _ [RUnderscore _ _] _) = True
    isCatchAll (MAlt _ _ [RQVar _ _ _] _)     = True
    isCatchAll (MAlt _ _ [RVar _ _ x] _)      = isVarName mod ctxt x
    isCatchAll _                              = False

    catchAll = MAlt f l [RUnderscore f l]
                    (RThrow f l (RApply f l (RQVar f l missingCase) []))

-- Fail if the alternative is non-linear (i.e. has a repeated name)

checkLinear :: Monad m => Name -> Context -> MatchAlt -> m ()
-- Collect every pattern variable bound by the alternative (names the context
-- does not mark as constructors), and fail on the first one, in sorted
-- order, that occurs more than once.
checkLinear mod ctxt (MAlt f l ps _) =
    case filter ((> 1) . length) nameRuns of
      []      -> return ()
      (dup:_) -> fail $ f ++ ":" ++ show l ++
                        ":Name '" ++
                        showuser (head dup) ++
                        "' is repeated in pattern"
  where
    nameRuns = group (sort (filter (isVarName mod ctxt) (concatMap getPvars ps)))

-- Get the names used in a raw term representing a pattern

-- | Every name mentioned in a pattern: bare and qualified names, and the
-- names inside applications (both head and arguments). Other pattern forms
-- contribute nothing.
getPvars :: Raw -> [Name]
getPvars pat = case pat of
    RVar _ _ x         -> [x]
    RQVar _ _ x        -> [x]
    RApply _ _ fn args -> concatMap getPvars (fn : args)
    _                  -> []

-- | State threaded through the compiler: a counter for fresh variable names.
data CaseState = CS
    { nextVar :: Int  -- ^ index the next fresh @pv@ variable will receive
    }

-- | Allocate a fresh machine name @MN ("pv", k)@, where @k@ is the current
-- counter value, and advance the counter by one.
getVar :: State CaseState Name
getVar = do
    st <- get
    let k = nextVar st
    put st { nextVar = k + 1 }
    return (MN ("pv", k))

-- | Pick a name for each constructor-argument pattern. Qualified names and
-- non-constructor names keep their own name. Anything else gets a fresh
-- variable from 'getVar'. Fresh names are allocated left to right.
getNewVars :: Name -> Context -> [Raw] -> State CaseState [Name]
getNewVars mod ctxt = mapM nameFor
  where
    nameFor (RQVar _ _ n)                        = return n
    nameFor (RVar _ _ n) | isVarName mod ctxt n  = return n
    nameFor _                                    = getVar

-- A scrutinee needs to be a variable for the algorithm to work. If it's
-- not, make a new variable for it. Return a triple of the term, the variable
-- to examine, and whether we made a new variable for it.
-- (getScrutinee below is currently commented out.)

{-
getScrutinee :: Name -> Context -> Raw -> State CaseState (Raw, Name, Bool)
getScrutinee mod ctxt r@(RQVar _ _ n) = return (r, n, False)
getScrutinee mod ctxt r@(RVar _ _ n) 
    | isVarName mod ctxt n = return (r, n, False)
getScrutinee _ _ r = do v <- getVar
                        return (r, v, True)
-}

-- | Compile the alternatives against the scrutinees. The result is wrapped
-- in an 'RDeclare' for every fresh @pv@ variable allocated during
-- compilation. Unmatched input reaches a throw of 'missingCase'.
caseComp :: String -> Int -> Context -> Name ->
            [Raw] -> [MatchAlt] ->
            State CaseState Raw
caseComp f l ctxt n rs ms = do
    firstv <- fmap nextVar get
    tm <- match f l ctxt n rs ms missing
    lastv <- fmap nextVar get
    -- Outermost declaration is pv(lastv-1), innermost is pv(firstv).
    return (foldr declare tm [lastv - 1, lastv - 2 .. firstv])
  where
    declare k body = RDeclare f l (MN ("pv", k), False) UnknownType body
    missing = RThrow f l (RApply f l (RQVar f l missingCase) [])

-- The match alternatives could either be all variables, all constructors,
-- or a mixture.

-- | Compile alternatives against a list of scrutinees. If no scrutinees
-- remain and the first alternative has no patterns left, its result is the
-- answer. Otherwise, split the alternatives into partitions and hand them
-- to 'mixture', with @err@ as the final fallthrough.
match :: String -> Int -> Context -> Name ->
         [Raw] -> [MatchAlt] -> Raw ->
         State CaseState Raw
match _ _ _ _ [] (MAlt _ _ [] res : _) _ = return res
match f l ctxt n vs ms err = mixture f l ctxt n vs (partition n ctxt ms) err

-- Mixture rule applies the Variable and Constructor rules in order
-- as appropriate. For each one, compute what to do (the fallthrough) if
-- the given partition fails to match (i.e. what the remaining partitions
-- compile to) then run the constructor rule or variable rule.

-- | Compile a sequence of partitions. Each partition's fallthrough is the
-- compiled form of the partitions after it, so work proceeds from the right.
-- The last partition falls through to @err@.
mixture :: String -> Int -> Context -> Name -> [Raw] ->
             [Partition] -> Raw -> State CaseState Raw
mixture f l ctxt n vs ps err = foldr compileNext (return err) ps
  where
    compileNext (Cons ms) later = later >>= conRule f l ctxt n vs ms
    compileNext (Vars ms) later = later >>= varRule f l ctxt n vs ms

-- | A run of consecutive alternatives, classified by their first pattern.
data Partition
    = Cons [MatchAlt]  -- ^ first pattern is constructor-like (see 'isCon')
    | Vars [MatchAlt]  -- ^ first pattern is a variable or wildcard (see 'isVar')
    deriving Show

-- | Split alternatives, in order, into maximal runs whose first patterns are
-- all variable-like ('Vars') or all constructor-like ('Cons').
-- If an alternative's first pattern is neither, this is an error. The
-- message starts with that alternative's source location (file:line,
-- matching the style used by 'groupCons'), followed by the offending
-- alternatives.
partition :: Name -> Context -> [MatchAlt] -> [Partition]
partition mod ctxt = go
  where
    go [] = []
    go alts@(a:_)
      | isVar mod ctxt a = run Vars (isVar mod ctxt) alts
      | isCon mod ctxt a = run Cons (isCon mod ctxt) alts
      | otherwise = error $ location a ++
                      ":unrecognised pattern in case alternatives: " ++
                      show alts

    -- Take the longest prefix satisfying keep, tag it, and continue.
    run mk keep alts = let (same, rest) = span keep alts in mk same : go rest

    location (MAlt f l _ _) = f ++ ":" ++ show l

-- True when every alternative is variable-like / constructor-like, judged
-- by its first pattern (see isVar / isCon).
allVars mod ctxt alts = all (isVar mod ctxt) alts
allCons mod ctxt alts = all (isCon mod ctxt) alts

-- | Is the alternative's first pattern a wildcard, or a (qualified) name
-- that the context does not mark as a constructor?
-- An alternative with no patterns is not variable-like.
isVar :: Name -> Context -> MatchAlt -> Bool
isVar mod ctxt (MAlt _ _ (p:_) _) =
    case p of
      RVar _ _ v      -> isVarName mod ctxt v
      RQVar _ _ v     -> isVarName mod ctxt v
      RUnderscore _ _ -> True
      _               -> False
isVar _ _ _ = False

-- A name is a variable unless some context entry for it (via lookupname)
-- carries the Constructor option. Unknown names are therefore variables.
isVarName mod ctxt v = not (isConName mod ctxt v)

-- A name is a constructor if any context entry for it carries the
-- Constructor option.
isConName mod ctxt v = any hasConOpt (lookupname mod v ctxt)
  where hasConOpt (_, (_, opts)) = Constructor `elem` opts

-- | Is the alternative's first pattern constructor-like? That covers an
-- application headed by a name, a constant, an array literal, or a bare
-- name the context marks as a constructor.
isCon :: Name -> Context -> MatchAlt -> Bool
isCon mod ctxt (MAlt _ _ (p:_) _) =
    case p of
      RApply _ _ (RVar _ _ _) _  -> True
      RApply _ _ (RQVar _ _ _) _ -> True
      RVar _ _ v                 -> isConName mod ctxt v
      RQVar _ _ v                -> isConName mod ctxt v
      RConst _ _ _               -> True
      RArrayInit _ _ _           -> True
      _                          -> False
isCon _ _ _ = False
                         
-- | Variable rule: every alternative starts with a variable or wildcard.
-- Drop that pattern, substitute the scrutinee for the bound name in the
-- result, and continue with the remaining scrutinees.
varRule :: String -> Int -> Context -> Name ->
           [Raw] -> [MatchAlt] -> Raw ->
           State CaseState Raw
varRule f l c n (v:vs) alts err = match f l c n vs (map bindHead alts) err
  where
    bindHead (MAlt _ _ (p:rest) res) = case p of
        RVar pf pl x      -> MAlt pf pl rest (rawSubst x v res)
        RQVar pf pl x     -> MAlt pf pl rest (rawSubst x v res)
        RUnderscore pf pl -> MAlt pf pl rest res

-- | The pattern heads that are grouped like constructors.
data ConType
    = CName Name    -- ^ an ordinary named constructor
    | CConst Const  -- ^ a constant pattern
    | CArray Int    -- ^ an array pattern, keyed on its length
    deriving Show

-- | The alternatives of a constructor partition that share one head.
data Group
    = ConGroup
        ConType              -- ^ the shared constructor / constant / array length
        [([Raw], MatchAlt)]  -- ^ per alternative: its sub-patterns under the
                             --   head, and the alternative minus that pattern
    deriving Show

-- | Constructor rule: group the alternatives by head, then emit one case
-- over the first scrutinee with an alternative per group.
conRule :: String -> Int -> Context -> Name ->
           [Raw] -> [MatchAlt] -> Raw ->
           State CaseState Raw
conRule f l ctxt mod scrutinees@(_:_) alts err =
    groupCons alts ctxt mod >>= \groups ->
        caseGroups f l ctxt mod scrutinees groups err

-- | Build an 'RCase' on the first scrutinee, with one alternative per group
-- in group order, followed by a default that evaluates to @err@.
-- Inside each alternative, matching continues on the remaining scrutinees.
-- For named-constructor and array groups, the newly bound argument
-- variables go in front of those scrutinees.
caseGroups :: String -> Int -> Context -> Name ->
              [Raw] -> [Group] -> Raw ->
              State CaseState Raw
caseGroups f l ctxt mod (v:vs) gs err = do
    alts <- mapM compileGroup gs
    return (RCase f l v (alts ++ [RDefault f l err]))
  where
    continueWith newArgs nextMs =
        match f l ctxt mod (map (RQVar f l) newArgs ++ vs) nextMs err

    compileGroup (ConGroup (CName con) members) = do
        (newArgs, nextMs) <- argsToAlt mod ctxt members
        body <- continueWith newArgs nextMs
        return (RAlt f l con newArgs body)
    compileGroup (ConGroup (CConst c) members) = do
        -- Constants bind nothing; only the remaining scrutinees are matched.
        (_, nextMs) <- argsToAlt mod ctxt members
        body <- match f l ctxt mod vs nextMs err
        return (RConstAlt f l c body)
    compileGroup (ConGroup (CArray _) members) = do
        (newArgs, nextMs) <- argsToAlt mod ctxt members
        body <- continueWith newArgs nextMs
        return (RArrayAlt f l newArgs body)

-- | Prepare one group's members for matching beneath its head. Returns the
-- argument names, chosen from the first member's sub-patterns by
-- 'getNewVars', and each member's alternative with its sub-patterns
-- prepended to its remaining patterns.
--
-- NOTE(review): when the first member binds plain variables, those user
-- names become the argument names for every member of the group. A later
-- member whose body mentions an outer variable of the same name might be
-- captured. Confirm how RAlt scopes its arguments.
argsToAlt :: Name -> Context ->
             [([Raw], MatchAlt)] -> State CaseState ([Name], [MatchAlt])
argsToAlt _ _ [] = return ([], [])
argsToAlt mod ctxt members@((firstArgs, _):_) = do
    names <- getNewVars mod ctxt firstArgs
    let widen (args, MAlt f l ps res) = MAlt f l (args ++ ps) res
    return (names, map widen members)

-- | Divide constructor-like alternatives into groups by head: a named
-- constructor, a constant, or an array length. Groups are kept in order of
-- first appearance, and alternatives within a group keep source order.
-- Fails, with file:line, on a first pattern that 'isPatt' cannot classify
-- as one of those.
--
-- Groups of different kinds (e.g. a constructor alongside a constant) may
-- coexist in the result. Before this change, that crashed the compiler with
-- a non-exhaustive pattern. Rejecting such ill-formed matches is presumably
-- left to later phases.
groupCons :: Monad m => [MatchAlt] -> Context -> Name -> m [Group]
groupCons ms ctxt mod = gc [] ms
  where
   gc acc [] = return acc
   gc acc ((MAlt f l (p:ps) res):ms) = do
       acc' <- addGroup f l p ps res acc
       gc acc' ms

   addGroup f l p ps rval acc =
       let rest = MAlt f l ps rval in
       case isPatt p ctxt mod of
         ConPatt con conargs   -> return $ insertInto (CName con) conargs rest acc
         Constant cval         -> return $ insertInto (CConst cval) [] rest acc
         ArrayPatt len conargs -> return $ insertInto (CArray len) conargs rest acc
         _ -> fail $ f ++ ":" ++ show l ++ ":I don't understand this pattern "

   -- Append to the group with the same head, or start a new group at the end.
   -- Groups with a different head, of any kind, are skipped over.
   insertInto key args alt [] = [ConGroup key [(args, alt)]]
   insertInto key args alt (g@(ConGroup key' cs):gs)
      | sameCon key key' = ConGroup key' (cs ++ [(args, alt)]) : gs
      | otherwise        = g : insertInto key args alt gs

   sameCon (CName a)  (CName b)  = a == b
   sameCon (CConst a) (CConst b) = a == b
   sameCon (CArray a) (CArray b) = a == b
   sameCon _          _          = False

-- match f l ctxt mod [] _ = fail "Can't match with no scrutinee" 
-- match f l ctxt mod (e:[]) ms = do
--      (unsimple, groups) <- groupCons ms ctxt mod
--      trace (show groups) $
--       if (unsimple == 0) then return $ mkSimpleCase f l e groups
--         else do fail "unfinished"

{-
mkSimpleCase :: String -> Int -> Raw -> [Group] -> Raw
mkSimpleCase f l e gs = RCase f l e (map mkAlt gs)
   where mkAlt (Simple f l c args ret) = RAlt f l c args ret
         mkAlt _ = error "Can't happen PMComp mkAlt"
-}

-- | Classification of a single pattern, as produced by 'isPatt'.
data Patt
    = ConPatt Name [Raw]   -- ^ constructor name and its sub-patterns
    | VarPatt Name         -- ^ a name that is not in the context
    | Constant Const       -- ^ a constant pattern
    | ArrayPatt Int [Raw]  -- ^ array literal: length and element patterns
    | DefaultPatt          -- NOTE(review): not produced by the visible code
    | NoPatt               -- ^ anything else
    deriving (Show, Eq)

-- Hand-built fixtures: list-style patterns over the names Cons and Nil,
-- all carrying the dummy location ("foo", line 1).
pcons hd tl = RApply "foo" 1 (RVar "foo" 1 (UN "Cons")) [hd, tl]
pvar name = RVar "foo" 1 (UN name)
pnil = RVar "foo" 1 (UN "Nil")
pint k = RConst "foo" 1 (Num k)

-- Alternatives shaped like:
--   Cons x (Cons y ys) -> 2;  Cons x Nil -> 1;  Nil -> 0
testAlts = [ fixtureAlt (pcons (pvar "x") (pcons (pvar "y") (pvar "ys"))) 2
           , fixtureAlt (pcons (pvar "x") pnil) 1
           , fixtureAlt pnil 0
           ]
  where fixtureAlt pat k = MAlt "foo" 1 [pat] (pint k)

-- Classify a pattern. Applications headed by a name are constructor
-- patterns. Bare names are resolved against the context by vPatt. Constants
-- and array literals get their own classes. Everything else is NoPatt.
isPatt pat ctxt mod = case pat of
    RApply _ _ (RVar _ _ con) args  -> ConPatt con args
    RApply _ _ (RQVar _ _ con) args -> ConPatt con args
    RVar _ _ v                      -> vPatt v ctxt mod
    RQVar _ _ v                     -> vPatt v ctxt mod
    RConst _ _ c                    -> Constant c
    RArrayInit _ _ ps               -> ArrayPatt (length ps) ps
    _                               -> NoPatt

-- A name found in the context cannot be a local variable, so treat it as a
-- nullary constructor pattern. If it is not actually a constructor, this
-- fails harmlessly and the typechecker reports it, so nothing more elaborate
-- is needed. Names not in the context are pattern variables.
vPatt v ctxt mod =
    maybe (VarPatt v) (const (ConPatt v [])) (ctxtlookup mod v ctxt Nothing [])

-- If all the groups are in simple case expression form, we've won. Each
-- group is guaranteed to begin with a different constructor, so no need
-- to worry there. Returns the number of groups still to deal with, and
-- the new groupings after simplification.
-- (simplify below is currently commented out.)

{-
simplify :: [Group] -> (Int, [Group])
simplify gs = simpl gs [] (length gs)
   where simpl [] acc i = (i, reverse acc)
         simpl ((ConGroup n [(args, MAlt f l [] res)]):gs) acc i
             | Just argnames <- mapM getName args 
                 = simpl gs ((Simple f l n argnames res):acc) (i-1)
         simpl (g:gs) acc i = simpl gs (g:acc) i

getName (RVar _ _ n) = Just n
getName (RQVar _ _ n) = Just n
getName _ = Nothing
-}