Package org.apache.spark.ml.ann
Interface LayerModel
All Superinterfaces:
Serializable
Trait that holds layer weights (or parameters).
It implements the functions needed for forward propagation and for computing the delta and the gradient during back propagation.
It can also return its weights in vector format (see weights()).
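The following minimal Java sketch makes the contract concrete. It is not Spark's code. SimpleLayerModel, ScaleLayerModel and LayerModelSketch are hypothetical names. Plain double[][] and double[] arrays stand in for breeze.linalg.DenseMatrix<Object> and DenseVector<Object>, and the sketch assumes one example per matrix column. The toy layer has a single weight w and computes output = w * data element-wise.

```java
// Hypothetical stand-in for LayerModel: plain arrays replace
// breeze.linalg.DenseMatrix<Object> / DenseVector<Object> so the sketch runs
// with the JDK alone. Matrices are [rows][cols], one example per column.
final class LayerModelSketch {

    interface SimpleLayerModel {
        double[] weights();
        void eval(double[][] data, double[][] output);
        void computePrevDelta(double[][] delta, double[][] output, double[][] prevDelta);
        void grad(double[][] delta, double[][] input, double[] cumGrad);
    }

    /** Toy layer with a single weight w: output = w * data, element-wise. */
    static final class ScaleLayerModel implements SimpleLayerModel {
        private final double[] w;

        ScaleLayerModel(double weight) {
            this.w = new double[] {weight};
        }

        @Override
        public double[] weights() {
            return w; // the layer's parameters as a flat vector
        }

        @Override
        public void eval(double[][] data, double[][] output) {
            for (int i = 0; i < data.length; i++) {
                for (int j = 0; j < data[i].length; j++) {
                    output[i][j] = w[0] * data[i][j]; // written in place
                }
            }
        }

        @Override
        public void computePrevDelta(double[][] delta, double[][] output, double[][] prevDelta) {
            // d(w * x)/dx = w, so this particular layer does not need its output here.
            for (int i = 0; i < delta.length; i++) {
                for (int j = 0; j < delta[i].length; j++) {
                    prevDelta[i][j] = w[0] * delta[i][j];
                }
            }
        }

        @Override
        public void grad(double[][] delta, double[][] input, double[] cumGrad) {
            // d(w * x)/dw = x; average over the batch (columns) and accumulate.
            int batch = input[0].length;
            double g = 0.0;
            for (int i = 0; i < delta.length; i++) {
                for (int j = 0; j < delta[i].length; j++) {
                    g += delta[i][j] * input[i][j];
                }
            }
            cumGrad[0] += g / batch;
        }
    }

    public static void main(String[] args) {
        SimpleLayerModel layer = new ScaleLayerModel(2.0);
        double[][] x = {{1.0, 3.0}};               // 1 feature, batch of 2
        double[][] y = new double[1][2];
        layer.eval(x, y);                           // y = {{2.0, 6.0}}
        double[][] delta = {{1.0, 1.0}};
        double[][] prev = new double[1][2];
        layer.computePrevDelta(delta, y, prev);     // prev = {{2.0, 2.0}}
        double[] cum = new double[1];
        layer.grad(delta, x, cum);                  // cum = {(1*1 + 1*3) / 2}
        System.out.println(java.util.Arrays.deepToString(y) + " "
            + java.util.Arrays.deepToString(prev) + " " + java.util.Arrays.toString(cum));
    }
}
```

Running main prints [[2.0, 6.0]] [[2.0, 2.0]] [2.0]. The caller allocates every output buffer, and the methods only write into it. This mirrors the in-place convention described in the method details.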
Method Summary
Modifier and Type | Method | Description
void | computePrevDelta(breeze.linalg.DenseMatrix<Object> delta, breeze.linalg.DenseMatrix<Object> output, breeze.linalg.DenseMatrix<Object> prevDelta) | Computes the delta for back propagation.
void | eval(breeze.linalg.DenseMatrix<Object> data, breeze.linalg.DenseMatrix<Object> output) | Evaluates the data (processes the data through the layer).
void | grad(breeze.linalg.DenseMatrix<Object> delta, breeze.linalg.DenseMatrix<Object> input, breeze.linalg.DenseVector<Object> cumGrad) | Computes the gradient.
breeze.linalg.DenseVector<Object> | weights() |
Method Details
computePrevDelta
void computePrevDelta(breeze.linalg.DenseMatrix<Object> delta, breeze.linalg.DenseMatrix<Object> output, breeze.linalg.DenseMatrix<Object> prevDelta)
Computes the delta for back propagation. The delta matrix is allocated based on the size provided by the LayerModel implementation and the stack (batch) size. The developer is responsible for checking the size of prevDelta when writing to it.
Parameters:
delta - delta of this layer
output - output of this layer
prevDelta - the previous delta (modified in place)
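An element-wise sigmoid activation shows why computePrevDelta receives the layer's output. With output o = sigmoid(z), the sigmoid's derivative can be written as o * (1 - o). The sketch below assumes such a layer, so treat it as an illustration rather than Spark's implementation. SigmoidPrevDeltaSketch is a hypothetical name, and double[][] stands in for DenseMatrix<Object>. The sketch also performs the size check on prevDelta that the documentation leaves to the developer.

```java
// Hypothetical sketch: computePrevDelta for an element-wise sigmoid layer,
// with double[][] in place of breeze.linalg.DenseMatrix<Object>.
final class SigmoidPrevDeltaSketch {

    static void computePrevDelta(double[][] delta, double[][] output, double[][] prevDelta) {
        // prevDelta is allocated by the caller; check its shape before writing into it.
        if (prevDelta.length != delta.length || output.length != delta.length) {
            throw new IllegalArgumentException("row count mismatch");
        }
        for (int i = 0; i < delta.length; i++) {
            if (prevDelta[i].length != delta[i].length || output[i].length != delta[i].length) {
                throw new IllegalArgumentException("column count mismatch in row " + i);
            }
            for (int j = 0; j < delta[i].length; j++) {
                double o = output[i][j];
                // sigmoid'(z) = o * (1 - o), where o = sigmoid(z) is this layer's output
                prevDelta[i][j] = delta[i][j] * o * (1.0 - o);
            }
        }
    }

    public static void main(String[] args) {
        double[][] output = {{0.5, 0.8}};
        double[][] delta = {{1.0, 2.0}};
        double[][] prevDelta = new double[1][2];
        computePrevDelta(delta, output, prevDelta); // modified in place
        System.out.println(java.util.Arrays.deepToString(prevDelta));
    }
}
```

The derivative is computed from the stored output, so the layer never has to keep its pre-activation values z. For other activations, the output argument may simply go unused.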
eval
void eval(breeze.linalg.DenseMatrix<Object> data, breeze.linalg.DenseMatrix<Object> output)
Evaluates the data (processes the data through the layer). The output matrix is allocated based on the size provided by the LayerModel implementation and the stack (batch) size. The developer is responsible for checking the size of output when writing to it.
Parameters:
data - the data to process
output - output (modified in place)
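As an illustration of eval, a fully connected (affine) layer could compute output = W * data + b. The sketch assumes one example per column. Under that layout, data is numIn x batch and the caller-allocated output is numOut x batch. AffineEvalSketch and its w and b parameters are hypothetical, and double[][] replaces DenseMatrix<Object>.

```java
// Hypothetical sketch: eval for an affine layer, output = W * data + b.
// double[][] stands in for breeze.linalg.DenseMatrix<Object>; one example per column.
final class AffineEvalSketch {

    static void eval(double[][] w, double[] b, double[][] data, double[][] output) {
        int numOut = w.length;
        int numIn = data.length;
        int batch = data[0].length;
        // The caller allocates output; check its shape before writing into it.
        if (output.length != numOut) {
            throw new IllegalArgumentException("output must have " + numOut + " rows");
        }
        for (int r = 0; r < numOut; r++) {
            if (output[r].length != batch) {
                throw new IllegalArgumentException("output must have " + batch + " columns");
            }
            for (int c = 0; c < batch; c++) {
                double sum = b[r];
                for (int k = 0; k < numIn; k++) {
                    sum += w[r][k] * data[k][c];
                }
                output[r][c] = sum; // written in place
            }
        }
    }

    public static void main(String[] args) {
        double[][] w = {{1.0, 2.0}};       // 1 output, 2 inputs
        double[] b = {0.5};
        double[][] data = {{1.0, 0.0},     // feature 0 of examples 0 and 1
                           {1.0, 3.0}};    // feature 1 of examples 0 and 1
        double[][] output = new double[1][2];
        eval(w, b, data, output);
        System.out.println(java.util.Arrays.deepToString(output)); // [[3.5, 6.5]]
    }
}
```

Because output is passed in rather than returned, the same buffer can be reused across batches of equal size.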
grad
void grad(breeze.linalg.DenseMatrix<Object> delta, breeze.linalg.DenseMatrix<Object> input, breeze.linalg.DenseVector<Object> cumGrad)
Computes the gradient. cumGrad is a wrapper on the part of the weight vector that belongs to this layer; its size is based on the weightSize provided by the LayerModel implementation.
Parameters:
delta - delta for this layer
input - input data
cumGrad - cumulative gradient (modified in place)
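Because cumGrad wraps only part of a larger vector, each layer adds its gradient into its own slice of one shared array. Java arrays have no such views, so this hypothetical sketch passes the shared array plus an offset instead. It also assumes an affine layer (output = W * data + b) and a slice laid out as W in row-major order followed by b. Gradients are averaged over the batch. None of these layout or scaling details come from the interface itself.

```java
// Hypothetical sketch: grad for an affine layer (output = W * data + b).
// (cumGrad, offset) emulates a view on this layer's part of a shared vector,
// assumed to hold W row-major followed by b.
final class AffineGradSketch {

    static void grad(double[][] delta, double[][] input, double[] cumGrad, int offset) {
        int numOut = delta.length;
        int numIn = input.length;
        int batch = input[0].length;
        // This layer's slice size (its weight count) is numOut * numIn + numOut.
        if (offset < 0 || offset + numOut * numIn + numOut > cumGrad.length) {
            throw new IllegalArgumentException("cumGrad slice out of range");
        }
        for (int r = 0; r < numOut; r++) {
            double biasSum = 0.0;
            for (int c = 0; c < batch; c++) {
                biasSum += delta[r][c];
            }
            // dL/db[r], averaged over the batch and added (not assigned)
            cumGrad[offset + numOut * numIn + r] += biasSum / batch;
            for (int k = 0; k < numIn; k++) {
                double s = 0.0;
                for (int c = 0; c < batch; c++) {
                    s += delta[r][c] * input[k][c];
                }
                // dL/dW[r][k], averaged over the batch and added
                cumGrad[offset + r * numIn + k] += s / batch;
            }
        }
    }

    public static void main(String[] args) {
        double[][] delta = {{1.0, 3.0}};   // 1 output, batch of 2
        double[][] input = {{2.0, 4.0}};   // 1 input, batch of 2
        double[] shared = new double[5];   // indices 0-1 could belong to another layer
        grad(delta, input, shared, 2);
        System.out.println(java.util.Arrays.toString(shared)); // [0.0, 0.0, 7.0, 2.0, 0.0]
    }
}
```

A second call to grad adds to the same slice rather than overwriting it, which is what "cumulative" refers to. Entries outside the slice are left untouched.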
weights
breeze.linalg.DenseVector<Object> weights()
Returns the weights (parameters) of this layer as a vector.