{-# LANGUAGE RankNTypes #-}
-- | Module      : OCO
-- Description : Online Convex Optimization algorithms.
module OCO (Ogdstate(..),Lambda(..),predictReg,fitRidge,initialOGDState
           ,squaredRate,fixedRate,LearningRate,FitReg) where
-- From Linear.Vector, we get the Additive typeclass.
-- Additive exposes vector subtraction and scalar multiplication/division
-- for arbitrary memory representations.
import Linear.Vector  (Additive, (^-^), (^/), (*^), (^*))
-- We'll need Linear.Metric later on for dot products and squared norms.
import Linear.Metric (Metric, norm, dot)

-- | Derivative of a loss function w.r.t the action w.
-- Given the current action @w@, yields the gradient dc/dw —
-- the direction used by 'ogd' for its descent step.
type LossDerivative v x = (Additive v, Floating x) =>
  v x ->  -- w
  v x     -- dc/dw

-- | Projection operator. Maps a candidate action back into the
-- feasible set (see 'l2Projection' for the L2-ball instance).
type Projection v x = (Additive v, Floating x) => v x -> v x

-- | Learning Rate. Maps the round counter t (starting at 1,
-- see 'initialOGDState') to the step size 'ogd' uses that round.
type LearningRate x = (Floating x) => Int -> x

-- | Inverse squared learning rate: eta / t^2 at round t.
squaredRate :: (Floating x) => x -> LearningRate x
squaredRate eta t = eta / tSquared
  where tSquared = fromIntegral (t * t)

-- | Fixed learning rate: the same step size eta at every round.
fixedRate :: (Floating x) => x -> LearningRate x
fixedRate eta = const eta

-- | Algorithm state for OGD. Encodes both the round
-- count and the last action w.
--
-- The round counter is strict: 'ogd' stores it as @(t+1)@, and a
-- learning rate that ignores its argument (e.g. 'fixedRate') never
-- forces it, so a lazy field would accumulate an unbounded chain of
-- @(+1)@ thunks across fitting rounds (a space leak).
data Ogdstate v x = Ogdstate !Int (v x)

-- | The OGD (online gradient descent) algorithm: take one gradient
-- step against the current loss derivative, project the result back
-- onto the feasible set, and advance the round counter.
ogd :: (Additive v, Floating x) =>
  (LearningRate x)       ->
  (Projection v x)       ->
  (Ogdstate v x)         -> -- Last state/action w^t
  (LossDerivative v x)   -> -- Derivative of c^t
  (Ogdstate v x)            -- New state/action w^(t+1)
ogd rate projection (Ogdstate t w) dc = Ogdstate (t + 1) w'
  where
    -- Gradient step scaled by the round-dependent learning rate.
    step = rate t *^ dc w
    w'   = projection (w ^-^ step)

-- | Helper. Builds an initial state for 'ogd': round counter 1
-- paired with the given starting action w.
initialOGDState :: (Additive v, Floating x) => v x -> Ogdstate v x
initialOGDState = Ogdstate 1

-- | Differential of a regression loss. Specializing the feature
-- vector x^t and target y^t yields that round's 'LossDerivative'.
type RegressionLossDerivative v x =
  v x ->                  -- x^t
  x   ->                  -- y^t
  (LossDerivative v x)  -- c^t

-- | Prediction function: the linear prediction <w, x> of the
-- learner's current action w on a feature vector x.
predictReg :: (Metric v, Floating x) => Ogdstate v x -> v x -> x
predictReg (Ogdstate _ w) = dot w

-- | Regression fit call: consumes one (x^t, y^t) observation and
-- returns the updated learner state.
type FitReg v x =
  Ogdstate v x                     ->  --the old learner state
  v x                              ->  --the x^t feature vector
  x                                ->  --the y^t regression target
  Ogdstate v x                         --the new learner state

-- | Fitting one regression data point using OGD. The first three
-- arguments parametrize the problem; the rest follow the 'FitReg'
-- calling convention.
fitOGDReg :: (Additive v, Floating x) =>
  (LearningRate x) ->
  (Projection v x) ->
  (RegressionLossDerivative v x) ->
  FitReg v x
fitOGDReg eta proj mkGrad state xt yt =
  ogd eta proj state (mkGrad xt yt)

-- | Argument name for Lambda, the radius of the L2 ball used by
-- 'l2Projection' (newtype wrapper to avoid positional mix-ups).
newtype Lambda x = Lambda x
-- | Projection on a L2 Ball of radius lambda: actions already
-- inside the ball pass through unchanged; anything outside is
-- rescaled onto the ball's surface.
l2Projection :: (Metric v, Ord x, Floating x) =>
  Lambda x ->
  Projection v x
l2Projection (Lambda radius) w
  | wNorm <= radius = w
  | otherwise       = radius *^ w ^/ wNorm
  where
    wNorm = norm w

-- | Derivative of the square loss function: the feature vector
-- scaled by the prediction residual (<w, x^t> - y^t).
mseDiff :: (Metric v) => RegressionLossDerivative v x
mseDiff feats target w = feats ^* residual
  where residual = (w `dot` feats) - target

-- | Online ridge regressor via OGD: squared-error gradient steps,
-- with the iterate kept inside an L2 ball of the given radius.
fitRidge :: (Metric v, Ord x, Floating x) =>
  LearningRate x ->
  Lambda x ->
  FitReg v x
fitRidge eta ball state xt yt =
  fitOGDReg eta (l2Projection ball) mseDiff state xt yt