{-# LANGUAGE RankNTypes #-}

-- | Module      : OCO
-- Description   : Online Convex Optimization algorithms.
module OCO
  ( Ogdstate (..)
  , Lambda (..)
  , predictReg
  , fitRidge
  , initialOGDState
  , squaredRate
  , fixedRate
  , LearningRate
  , FitReg
  ) where

-- From Linear.Vector, we get the Additive typeclass.
-- Additive exposes vector substraction and scalar multiplication/division
-- for arbitrary memory representations.
import Linear.Vector (Additive, (^-^), (^/), (*^), (^*))
-- We'll need Linear.Metric later on for dot products and squared norms.
import Linear.Metric (Metric, norm, dot)

-- | Derivative of a loss function w.r.t the action w:
-- given the current action w, yields dc/dw.
type LossDerivative v x = (Additive v, Floating x) => v x -> v x

-- | Projection operator: maps an iterate back into the feasible set.
type Projection v x = (Additive v, Floating x) => v x -> v x

-- | Learning rate, as a function of the round number.
type LearningRate x = (Floating x) => Int -> x

-- | Inverse squared learning rate: eta / t^2.
squaredRate :: (Floating x) => x -> LearningRate x
squaredRate eta round' = eta / fromIntegral (round' * round')

-- | Fixed learning rate: the round number is ignored.
fixedRate :: (Floating x) => x -> LearningRate x
fixedRate eta = const eta

-- | Algorithm state for OGD. Encodes both the round
-- count and the last action w.
data Ogdstate v x = Ogdstate Int (v x)

-- | The OGD algorithm: step against the loss gradient, scaled by the
-- learning rate for the current round, then project back onto the
-- feasible set and advance the round counter.
ogd :: (Additive v, Floating x) =>
       LearningRate x ->
       Projection v x ->
       Ogdstate v x ->       -- Last state/action w^t
       LossDerivative v x -> -- Derivative of c^t
       Ogdstate v x          -- New state/action w^(t+1)
ogd rate project (Ogdstate round' w) gradient =
  Ogdstate (round' + 1) (project (w ^-^ (rate round' *^ gradient w)))

-- | Helper. Builds an initial state for 'ogd'.
initialOGDState :: (Additive v, Floating x) => v x -> Ogdstate v x
initialOGDState = Ogdstate 1

-- | Differential of a regression loss: from the features x^t and the
-- target y^t, build the loss derivative c^t.
type RegressionLossDerivative v x =
  v x ->             -- x^t
  x ->               -- y^t
  LossDerivative v x -- c^t

-- | Prediction function: inner product of the current action with
-- the feature vector.
predictReg :: (Metric v, Floating x) => Ogdstate v x -> v x -> x
predictReg (Ogdstate _ w) features = w `dot` features

-- | Regression fit call.
type FitReg v x =
  Ogdstate v x -> -- the old learner state
  v x ->          -- the x^t feature vector
  x ->            -- the y^t regression target
  Ogdstate v x    -- the new learner state

-- | Fitting one regression data point using OGD.
-- The first three arguments of this function parametrize the problem.
fitOGDReg :: (Additive v, Floating x) =>
             LearningRate x ->
             Projection v x ->
             RegressionLossDerivative v x ->
             FitReg v x
fitOGDReg rate project differential state features target =
  ogd rate project state (differential features target)

-- | Argument name for Lambda.
newtype Lambda x = Lambda x

-- | Projection onto the L2 ball of radius lambda: points already
-- inside the ball are untouched, points outside are rescaled onto
-- the boundary.
l2Projection :: (Metric v, Ord x, Floating x) => Lambda x -> Projection v x
l2Projection (Lambda radius) w
  | q <= radius = w
  | otherwise   = radius *^ w ^/ q
  where q = norm w

-- | Derivative of the square loss function.
-- NOTE(review): x ^* (w.x - y) is the gradient of (w.x - y)^2 / 2,
-- i.e. half the gradient of the unhalved square loss — confirm the
-- intended loss convention.
mseDiff :: (Metric v) => RegressionLossDerivative v x
mseDiff features target w = features ^* (predicted - target)
  where predicted = w `dot` features

-- | Online ridge regressor via OGD: square-loss gradient steps kept
-- inside the L2 ball of radius lambda.
fitRidge :: (Metric v, Ord x, Floating x) =>
            LearningRate x -> Lambda x -> FitReg v x
fitRidge rate radius = fitOGDReg rate (l2Projection radius) mseDiff