{-# LANGUAGE RankNTypes #-}

module OCO where

-- From Linear.Vector, we get the Additive typeclass.
-- Additive exposes vector subtraction and scalar multiplication/division
-- for arbitrary memory representations.
import Linear.Vector (Additive, (^-^), (^/), (*^), (^*))

-- We'll need Linear.Metric later on for dot products and squared norms.
import Linear.Metric (Metric, norm, dot)

-- | Derivative of a loss function w.r.t. the action w.
type LossDerivative v a =
  v a ->  -- ^ w
  v a     -- ^ dc/dw

-- | Projection operator.
type Projection v a =
  v a ->  -- ^ input vector
  v a     -- ^ output vector

-- | Learning rate schedule.
type LearningRate a =
  Int ->  -- ^ Round counter
  a       -- ^ Rate value

-- | Inverse square-root learning rate: @eta / sqrt t@.
--   Assumes the round counter starts at 1 (as 'initialOGDState' ensures);
--   @t = 0@ would divide by zero.
rootRate :: (Floating a) => a -> LearningRate a
rootRate eta t = eta / sqrt (fromIntegral t)

-- | Fixed learning rate, ignoring the round counter.
fixedRate :: a -> LearningRate a
fixedRate = const

-- | Algorithm state for OGD. Encodes both the round count and the last
-- action.
data Ogdstate v a = Ogdstate
  Int   -- ^ round counter
  (v a) -- ^ weights

-- | The OGD algorithm. Performs one step of projected online gradient
--   descent, @w' = proj (w - rate t * dc w)@, and advances the round
--   counter.
ogd :: (Additive v, Floating a) =>
  (LearningRate a) ->      -- ^ Learning rate schedule
  (Projection v a) ->      -- ^ Projection operator
  (Ogdstate v a) ->        -- ^ Last state/action
  (LossDerivative v a) ->  -- ^ Derivative of the loss at this round
  (Ogdstate v a)           -- ^ New state/action
ogd rate projection (Ogdstate t w) dc =
  Ogdstate (t + 1) $ projection $ w ^-^ (rate t *^ dc w)

-- | Smart constructor for @Ogdstate@.
initialOGDState :: v a -> Ogdstate v a
initialOGDState w = Ogdstate 1 w

-- | Differential of a regression loss.
type RegressionLossDerivative v a =
  v a ->                -- ^ x^t
  a ->                  -- ^ y^t
  (LossDerivative v a)  -- ^ c^t

-- | Prediction function: the linear model @w . x@.
predictReg :: (Metric v, Floating a) =>
  v a ->           -- ^ Feature vector
  Ogdstate v a ->  -- ^ Learner
  a                -- ^ Prediction
predictReg x (Ogdstate _ w) = w `dot` x

-- | Regression fit call.
type FitReg v a =
  v a ->           -- ^ Feature vector x^t
  a ->             -- ^ Regression target y^t
  Ogdstate v a ->  -- ^ Old learner state
  Ogdstate v a

-- | Generates a fitter for one regression data point using OGD:
--   specializes the loss derivative to the observed (x, y) pair and
--   takes one 'ogd' step.
fitOGDReg :: (Additive v, Floating a) =>
  (LearningRate a) ->                -- ^ Learning rate
  (Projection v a) ->                -- ^ Projection operator
  (RegressionLossDerivative v a) ->  -- ^ Regression loss derivative
  FitReg v a                         -- ^ Fitting function
fitOGDReg rate projection differential x y state =
  ogd rate projection state $ differential x y

-- | Argument name for Lambda.
newtype Lambda a = Lambda a

-- | Projection onto an L2 ball of radius lambda: rescales w onto the
--   ball when its norm exceeds the radius, otherwise leaves it alone.
l2Projection :: (Metric v, Ord a, Floating a) =>
  Lambda a ->     -- ^ radius
  Projection v a  -- ^ projection operator
l2Projection (Lambda lam) w
  | q <= lam  = w
  | otherwise = lam *^ w ^/ q
  where
    q = norm w

-- | Derivative of the square loss function: @(w . x - y) x@.
mseDiff :: (Metric v, Floating a) => RegressionLossDerivative v a
mseDiff x y w = x ^* (w `dot` x - y)

-- | Online ridge regressor via OGD: square loss with iterates
--   constrained to an L2 ball.
fitRidge :: (Metric v, Ord a, Floating a) =>
  LearningRate a ->  -- ^ Rate schedule
  Lambda a ->        -- ^ Constraint set radius
  FitReg v a
fitRidge rate lambda = fitOGDReg rate (l2Projection lambda) mseDiff