{-# LANGUAGE RecordWildCards,RankNTypes,FlexibleContexts,MonoLocalBinds #-}
-- | Module      : ML
-- Description : ML Pipeline.
module ML (aggregator,learner,P(..),initialState,Prediction(..)) where
import Data.Conduit (ConduitT, await, awaitForever, yield)
import Data.Default.Class (def)
import qualified Data.Vector as V (Vector, fromList, length, replicate)

import Model
import OCO (Lambda(..),Ogdstate(..),predictReg,fitRidge,initialOGDState
           ,LearningRate)

-- | Flattens a window of ticks into a feature vector: a constant 1.0
-- (intercept) followed by, for each tick of 'lastpticks' in order, its
-- newOpen, lastClose, lastHigh, lastLow and lastVolume values.
featureBuilder :: X -> V.Vector Double
featureBuilder x = V.fromList (intercept : perTick)
  where
    intercept = 1.0
    perTick =
      [ field tick
      | tick  <- lastpticks x
      , field <- [newOpen, lastClose, lastHigh, lastLow, lastVolume]
      ]

-- | Aggregation size: the number of past ticks (the @p@ in AR(p)) that
-- 'aggregator' keeps in its sliding window.
newtype P = P Int
-- | AR(p) model builder. Conduit input: new tick information; output:
-- @(y^{t-1}, x^t)@ pairs, or 'Nothing' while the window is still filling.
--
-- The first @p@ ticks are only buffered, and each yields 'Nothing'. Every
-- later tick @t@ yields @Just (Y (lastHigh t), X window)@, where @window@
-- holds the last @p@ ticks including @t@, oldest first. With @p <= 0@ the
-- window is always empty.
aggregator :: Monad m =>
  P ->
  ConduitT NewTick (Maybe (Y, X)) m ()
aggregator (P p) = go []
  where
    -- Explicit await loop. Recursing from inside 'awaitForever' would nest a
    -- fresh 'awaitForever' per tick, each holding on to the enclosing
    -- loop's continuation.
    go window = await >>= maybe (return ()) (step window)

    step window t
      | length window < p = do
          yield Nothing
          go (window ++ [t])
      | otherwise = do
          let window' = keepLast p (window ++ [t])
          yield $ Just (Y (lastHigh t), X window')
          go window'

    -- Last k elements of xs (empty when k <= 0). Total, unlike 'tail',
    -- which crashed on the empty window when p == 0.
    keepLast k xs = drop (length xs - k) xs

-- | Online ridge-regression learner.
--
-- Conduit input: @Maybe (y^{t-1}, x^t)@, as produced by 'aggregator'.
-- Output: one @Prediction@ (an estimate of @y^t@) per 'Just' input, and
-- one @Ingested@ per 'Nothing' input.
--
-- On each @Just (y, x)@ the state is first updated with 'OCO.fitRidge'.
-- The update pairs the previously seen features (if any) with the new
-- target @high y@. The updated state then scores the new features @x@,
-- and @x@ becomes the "previous features" for the next step.
-- A 'Nothing' input leaves the state and previous features untouched.
learner
  :: Monad m
  => LearningRate Double        -- ^ learning rate passed to 'OCO.fitRidge'
  -> Lambda Double              -- ^ lambda (regularizer size)
  -> Ogdstate V.Vector Double   -- ^ w^{t-1}: current model state
  -> Maybe X                    -- ^ x^{t-1}: previous features, if any
  -> ConduitT (Maybe (Y, X)) Prediction m ()
learner rate lambda = loop
  where
    -- Explicit await loop. Recursing from inside 'awaitForever' would nest a
    -- fresh 'awaitForever' per input, each holding on to the enclosing
    -- loop's continuation.
    loop state prevX = await >>= maybe (return ()) (step state prevX)

    step state prevX Nothing = do
      yield Ingested
      loop state prevX
    step state prevX (Just (y, x)) = do
      let refit px = OCO.fitRidge rate lambda state (featureBuilder px) (high y)
          state'   = maybe state refit prevX
      yield $ Prediction $ OCO.predictReg state' (featureBuilder x)
      loop state' (Just x)

-- | Helper function. Builds an initial, all-zero state for 'learner'.
--
-- The argument is the window size @p@ given to 'aggregator'. The weight
-- vector has the same length as the vectors 'featureBuilder' produces for
-- a window of @p@ ticks: one intercept term plus 'featuresPerTick' values
-- per tick. A non-positive @p@ gives an intercept-only model.
initialState :: Int -> Ogdstate V.Vector Double
initialState p = OCO.initialOGDState $ V.replicate n 0.0
  where
    -- newOpen, lastClose, lastHigh, lastLow and lastVolume are the per-tick
    -- fields emitted by featureBuilder. Keep this in sync with it.
    featuresPerTick :: Int
    featuresPerTick = 5
    n = 1 + featuresPerTick * max 0 p