{-- | This module provides facilities for building Evaluation Dependency Trees,
    | trusting modules, functions and applications, displaying trees, and
    | finding parts of the tree.
    | --}
module EDT (EDT(..),buildEDT,displayTree,displayTrees,trustIO,trustConstant
           ,trustModule,trustApps ,foldHiddens,leaves,subEDTs,detectCycles
           ,trustUnevaluated,(/==),trustMatchingFunction,rearange)
           where

import NodeExp          (NodeExp(..),children,(===),isomorphicIn,finalResult
                        ,flatEval,fullEval,getNode,result,expandFunction)
import SExp             (QName(..))
import System.IO.Unsafe (unsafePerformIO)
import SExp             (QName(..),SFixity(..))
import List             (nubBy,nub)
import Maybe            (fromJust,isJust)

{-- | The EDT (evaluation dependency tree) data structure.
    |
    | A Branch contains a tag, a NodeExp which represents a reduction within
    | the program (an application or constant when the tree comes from
    | buildEDT), and a list of child reductions.
    |
    | A Cycle contains a tag and a list of EDT nodes; detectCycles creates
    | one in place of a Branch whose flat-evaluated expression matches that
    | of one or more of its ancestors, and stores those ancestors in the list.
    |
    | The tag is the [Int] assigned by buildEDT: the root gets [1] and each
    | child gets its 1-based position consed onto its parent's tag.
    | --}
data EDT = Branch [Int] NodeExp [EDT]
         | Cycle [Int] [EDT]
           deriving Show

-- | Returns the tag (the leading [Int] field) of either kind of EDT node.
tag :: EDT -> [Int]
tag edt =
  case edt of
    Branch label _ _ -> label
    Cycle label _    -> label

-- | Equality ignores tags and children of a Branch: two Branches are equal
-- when their expressions are equal, two Cycles when their member lists are
-- equal, and a Branch never equals a Cycle.
instance Eq EDT where
  Branch _ lhs _ == Branch _ rhs _ = lhs == rhs
  Cycle _ ls     == Cycle _ rs     = ls == rs
  _              == _              = False

-- | Tag inequality: True exactly when the two nodes carry different tags.
-- Unlike the Eq instance, the expressions stored in the nodes are not
-- consulted at all.
(/==) :: EDT -> EDT -> Bool
a /== b = not (tag a == tag b)

-- | Ordering on EDT nodes, kept consistent with the Eq instance.
--
-- Branches are ordered by their expressions (tags and children are
-- ignored), every Cycle sorts before every Branch, and two Cycles are
-- ordered by their member lists.  Comparing the member lists (rather than
-- answering EQ for any pair of Cycles) guarantees that
-- @compare x y == EQ@ holds exactly when @x == y@.
instance Ord EDT where
  compare (Branch _ n _) (Branch _ n' _) = compare n n'
  compare (Branch _ _ _) (Cycle _ _)     = GT
  compare (Cycle _ _)    (Branch _ _ _)  = LT
  compare (Cycle _ xs)   (Cycle _ xs')   = compare xs xs'

-- | Creates an evaluation dependency tree from a trace.
buildEDT :: NodeExp -- ^ The trace to gather an EDT from, in the form of
                    --   a NodeExp
         -> [EDT]
buildEDT = buildEDT' [1]

-- Worker for buildEDT.  The first argument is the tag to give to the node
-- built for the expression.  Applications and constant uses/definitions
-- become a Branch; conditionals, projections and forwards are skipped in
-- favour of their result expression (keeping the same tag); a hidden node
-- contributes only the trees of its children; anything else yields nothing.
buildEDT' :: [Int] -> NodeExp -> [EDT]
buildEDT' t expr =
  case expr of
    NExpApp _ _ _ _     -> [Branch t expr subTrees]
    NExpConstUse _ _ _  -> [Branch t expr subTrees]
    NExpConstDef _ _ _  -> [Branch t expr subTrees]
    NExpCond _ _ _ r    -> buildEDT' t r
    NExpProjection _ r  -> buildEDT' t r
    NExpForward _ r     -> buildEDT' t r
    NExpHidden _ _ _    -> subTrees
    _                   -> []
  where
    -- Trees for the children of expr; the i-th child (counting from 1) is
    -- tagged with i consed onto the current tag.
    subTrees = concatZipWith childTrees (children expr) nats
    childTrees child position = buildEDT' (position : t) child

-- | The infinite list of positive Ints: 1, 2, 3, ...
nats :: [Int]
nats = iterate (+1) 1

-- | Zips two lists with a list-producing function and concatenates the
-- results.  Stops at the end of the shorter list.
concatZipWith :: (a -> b -> [c]) -> [a] -> [b] -> [c]
concatZipWith f xs ys = [z | (x, y) <- zip xs ys, z <- f x y]

-- | Marks repeated reductions in each tree.
--
-- Walking down from each root, a Branch whose flat-evaluated expression
-- matches (with '===') that of one or more of its ancestors is replaced by
-- a Cycle node that keeps the Branch's tag and lists those ancestors; its
-- children are dropped.  All other Branches are kept and their children
-- processed recursively.  Cycle nodes already present in the input are
-- returned unchanged, so the function can safely be applied to a tree that
-- has been processed before.
detectCycles :: [EDT] -> [EDT]
detectCycles = map (detCycles [])
  where
    -- The first argument holds the ancestors of the node being examined,
    -- nearest ancestor first.
    detCycles :: [EDT] -> EDT -> EDT
    detCycles previous x@(Branch t exp chldr) =
      case findExp (flatEval fullEval exp) previous of
        [] -> Branch t exp (map (detCycles (x:previous)) chldr)
        d  -> Cycle t d
    -- A Cycle carries no expression of its own to compare; leave it as is
    -- instead of failing with a pattern-match error.
    detCycles _ c@(Cycle _ _) = c

    -- All Branches in the list whose flat-evaluated expression matches x.
    -- Non-Branch entries are skipped.
    findExp :: NodeExp -> [EDT] -> [EDT]
    findExp x l = [b | b@(Branch _ e _) <- l, x === flatEval fullEval e]

-- | Displays an EDT using a provided formatting.
--
-- A Branch is rendered as its formatted expression on the first line,
-- followed by the rendering of its children (see 'otherQuestions').  A
-- Cycle is rendered as the line @CYCLE!@ followed, when its member list
-- contains a Branch, by the formatted expression of the first such Branch.
displayTree :: Int                        -- ^ Number handed to the formatter;
                                          --   presumably the available width
                                          --   (it shrinks at each level)
            -> (Int -> NodeExp -> String) -- ^ Formatter for one expression
            -> EDT                        -- ^ The tree to display
            -> String
displayTree w f edt =
  case edt of
    Branch _ node kids ->
      f w node ++ "\n" ++ concat (otherQuestions (w - 2) f kids)
    Cycle _ members ->
      "CYCLE!\n" ++ maybe "" (f (w - 2)) (firstBranch members)
  where
    -- Expression of the first Branch in the list, if there is one.
    firstBranch members =
      case [n | Branch _ n _ <- members] of
        (n:_) -> Just n
        []    -> Nothing

-- | Renders a list of sibling trees as indented sub-questions.
--
-- Every tree is displayed with the number reduced by 3 and then passed
-- through 'prependTree'.  Continuation lines of all trees but the last are
-- prefixed with @" | "@; those of the last tree are prefixed with blanks.
otherQuestions :: Int -> (Int -> NodeExp -> String) -> [EDT] -> [String]
otherQuestions w f edts =
     map (`prependTree` " | ") earlier
  ++ map (`prependTree` "   ") final
  where
    rendered = map (displayTree (w - 3) f) edts
    -- 'final' holds at most one tree: the last sibling.
    (earlier, final) = splitAt (length rendered - 1) rendered

-- | Attaches a rendered tree to its parent: the first line of the text is
-- prefixed with @" +-"@ and every following line with the given prefix.
-- Text without any lines yields a fixed error message instead.
prependTree :: String -> String -> String
prependTree text prepend = attach (lines text)
  where
    attach []           = "Error: An empty question was produced"
    attach (first:rest) = unlines ((" +-" ++ first) : map (prepend ++) rest)

-- | Displays several trees one after another by concatenating the result
-- of 'displayTree' for each of them.
displayTrees :: Int -> (Int -> NodeExp -> String) -> [EDT] -> String
displayTrees w f edts = concat [displayTree w f edt | edt <- edts]

-- | Removes every Branch whose expression is an NExpHidden node.  Because
-- 'trustMatchingFunction' is called with False, the children of a removed
-- Branch take its place.
foldHiddens :: [EDT] -> [EDT]
foldHiddens = trustMatchingFunction False isHidden
  where
    isHidden edt =
      case edt of
        Branch _ (NExpHidden _ _ _) _ -> True
        _                             -> False

-- | Removes every Branch for which the identifier @{IO}@ is found
-- (by 'isomorphicIn') in the final result of its expression.  The children
-- of a removed Branch take its place.
trustIO :: [EDT] -> [EDT]
trustIO = trustMatchingFunction False yieldsIO
  where
    -- Only the name of the marker is meaningful; the other two fields are
    -- left undefined, so 'isomorphicIn' must not inspect them.
    ioMarker = NExpIdentifier undefined (Plain "{IO}") undefined

    yieldsIO (Branch _ e _) = ioMarker `isomorphicIn` finalResult e

-- | Removes every Branch whose expression has NExpUneval as its final
-- result.  The children of a removed Branch take its place.
trustUnevaluated :: [EDT] -> [EDT]
trustUnevaluated = trustMatchingFunction False neverEvaluated
  where
    neverEvaluated (Branch _ e _) = finalResult e === NExpUneval

-- | Removes every Branch whose expression is a use or a definition of the
-- constant with the given name.  The children of a removed Branch take its
-- place.
trustConstant :: QName -> [EDT] -> [EDT]
trustConstant name = trustMatchingFunction False isTrusted
  where
    isTrusted edt =
      case edt of
        Branch _ (NExpConstUse _ cName _) _ -> name == cName
        Branch _ (NExpConstDef _ cName _) _ -> name == cName
        _                                   -> False

-- | Removes every Branch that belongs to the given module: an application
-- of an identifier qualified with that module name, or a use or definition
-- of a constant qualified with it.  The children of a removed Branch take
-- its place.
trustModule :: String -> [EDT] -> [EDT]
trustModule modName = trustMatchingFunction False inModule
  where
    inModule (Branch _ node _) =
      case node of
        NExpApp _ (NExpIdentifier _ (Qualified owner _) _) _ _
          -> modName == owner
        NExpConstUse _ (Qualified owner _) _
          -> modName == owner
        NExpConstDef _ (Qualified owner _) _
          -> modName == owner
        _ -> False

-- | Removes from the trees every Branch whose expression matches one of
-- the given applications.  The children of a removed Branch take its place.
trustApps :: [NodeExp] -> [EDT] -> [EDT]
trustApps apps trees = foldr trustApp trees apps
  where
    -- Removes the Branches matching one particular application.
    trustApp :: NodeExp -> [EDT] -> [EDT]
    trustApp target =
      trustMatchingFunction False (\(Branch _ e _) -> compareApps target e)

    -- Two applications match when their fully evaluated functions and
    -- arguments match pairwise.  NOTE: the arguments are paired up with
    -- zip, so when the argument counts differ only the common prefix is
    -- compared.  Constant definitions match on their first field, constant
    -- uses match when their third fields match; everything else differs.
    compareApps :: NodeExp -> NodeExp -> Bool
    compareApps l r =
      case (l, r) of
        (NExpApp _ f args _, NExpApp _ f' args' _) ->
             fullEval f === fullEval f'
          && and [fullEval a === fullEval a' | (a, a') <- zip args args']
        (NExpConstDef n _ _, NExpConstDef n' _ _) -> n == n'
        (NExpConstUse _ _ d, NExpConstUse _ _ d') -> compareApps d d'
        _                                         -> False

{-removeRepeatedQuestions :: [EDT] -> [EDT]
removeRepeatedQuestions xs =
  map goIntoChildren (nubBy compareQuestions xs)
  where
    compareQuestions (Branch _ n _) (Branch _ n' _) = (flatEval fullEval n) === (flatEval fullEval n')
    compareQuestions (Cycle _ xs) (Cycle _ xs') = and (zipWith compareQuestions xs xs')
    compareQuestions _ _ = False
    
    goIntoChildren (Branch t n ch) = Branch t n (removeRepeatedQuestions ch)
    goIntoChildren (Cycle t xs) = Cycle t (map goIntoChildren xs)-}

-- | Removes from a forest every Branch that satisfies the predicate.
--
-- When the flag is True a matching Branch is dropped together with its
-- whole subtree.  When it is False only the Branch itself is removed and
-- its (recursively processed) children are spliced in where it stood.
-- Non-matching Branches are kept with processed children.  The predicate
-- is never applied to a Cycle; a Cycle is kept and its member list is
-- processed in the same way.
trustMatchingFunction :: Bool -> (EDT -> Bool) -> [EDT] -> [EDT]
trustMatchingFunction tc f edts = concatMap step edts
  where
    step b@(Branch t e ch)
      | not (f b) = [Branch t e survivors]
      | tc        = []
      | otherwise = survivors
      where
        survivors = trustMatchingFunction tc f ch
    step (Cycle t xs) = [Cycle t (trustMatchingFunction tc f xs)]

{-rearange :: NodeExp -> [EDT] -> [EDT]
rearange n es =
  concatMap (rearange1 n) es
  where
    rearange1 :: NodeExp -> EDT -> [EDT]
    rearange1 n e@(Branch t node es) =
      [Branch t (expandFunction n node) (rearange n es)]
    rearange1 n e@(Cycle t es) =
      [Cycle t (rearange n es)]-}

-- | Lists the Branches of a forest in pre-order, optionally limited in
-- depth.
--
-- With @Just n@ only the first n levels are visited (@Just 0@ yields
-- nothing); with Nothing the whole forest is visited.  Cycle nodes are
-- skipped, including everything inside them.
subEDTs :: Maybe Int -> [EDT] -> [EDT]
subEDTs _ []       = []
subEDTs (Just 0) _ = []
subEDTs depth edts = concatMap visit edts
  where
    -- Depth limit that applies one level further down.
    below = fmap (subtract 1) depth

    visit e@(Branch _ _ chldrn) = e : subEDTs below chldrn
    visit (Cycle _ _)           = []

-- | Collects the expressions of a tree in pre-order.  Note that, despite
-- the name, the expression of every Branch is returned, not only those of
-- Branches without children.  A Cycle contributes nothing.
leaves :: EDT -> [NodeExp]
leaves edt =
  case edt of
    Branch _ e chldrn -> e : concatMap leaves chldrn
    Cycle _ _         -> []

-- | Rearranges a forest around the given function expression: the forest
-- with the applications of that function stripped out (see 'stripCollect')
-- comes first, followed by the stripped-out applications themselves.
rearange :: NodeExp -> [EDT] -> [EDT]
rearange n xs =
  let (stripped, collected) = stripCollect n xs
  in  stripped ++ collected

-- | Splits a forest into (stripped, collected) with respect to a function
-- expression n.
--
-- A Branch that applies exactly n (compared with '==') is taken out of the
-- forest and put at the front of the collected list.  Any other application
-- stays in place with n expanded in it by 'expandFunction'; non-application
-- Branches stay as they are.  In all cases the children are processed
-- recursively and whatever was collected below is appended to the
-- collected list.  A Cycle keeps its stripped member list; NOTE that
-- applications collected from inside a Cycle are discarded.
stripCollect :: NodeExp -> [EDT] -> ([EDT],[EDT])
stripCollect _ [] = ([], [])
stripCollect n (edt : rest) =
  case edt of
    Branch t app@(NExpApp _ f _ _) _
      | f == n    -> (keptRest, Branch t app keptInner : collectedBelow)
      | otherwise -> ( Branch t (expandFunction n app) keptInner : keptRest
                     , collectedBelow )
    Branch t other _ -> (Branch t other keptInner : keptRest, collectedBelow)
    Cycle t _        -> (Cycle t keptInner : keptRest, collectedRest)
  where
    -- Children of a Branch, or members of a Cycle.
    inner = case edt of
              Branch _ _ cs -> cs
              Cycle _ cs    -> cs
    (keptInner, collectedInner) = stripCollect n inner
    (keptRest, collectedRest)   = stripCollect n rest
    collectedBelow              = collectedInner ++ collectedRest