  Y_shape_cache_ = X.dims();
  // This is an invariant of canonical_axis, so we can DCHECK.
  DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
  Y_shape_cache_.resize(canonical_axis + 1);
  Y_shape_cache_[canonical_axis] = N;
  Y->Resize(Y_shape_cache_);
  CAFFE_ENFORCE(M * N == Y->size(), dimErrorString());
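To make the shape bookkeeping concrete, here is a small standalone sketch (my own example with made-up dims, not from Caffe2): for X of dims (2, 3, 4) with axis = 1 and N = 5, we get M = 2 and K = 12, and Y is resized to (2, 5): the dims of X before the axis are kept, and everything from the axis onward collapses into a single dim of size N.

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> x_dims = {2, 3, 4}; // hypothetical input dims
  const int canonical_axis = 1;            // axis_ after canonicalization
  const int64_t N = 5;                     // number of output units

  // M = product of dims before the axis, K = product of dims from the axis on.
  const int64_t M = std::accumulate(
      x_dims.begin(), x_dims.begin() + canonical_axis, int64_t{1},
      std::multiplies<int64_t>());
  const int64_t K = std::accumulate(
      x_dims.begin() + canonical_axis, x_dims.end(), int64_t{1},
      std::multiplies<int64_t>());
  assert(M == 2 && K == 12);

  // Mirror of the Y_shape_cache_ logic above: keep outer dims, append N.
  std::vector<int64_t> y_dims = x_dims;
  y_dims.resize(canonical_axis + 1);
  y_dims[canonical_axis] = N;
  assert(y_dims.size() == 2 && y_dims[0] == 2 && y_dims[1] == 5);
  return 0;
}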
  if (X.size() == 0) {
    // skip the rest of the computation if X is empty
    Y->template mutable_data<T_Y>();
    return true;
  }

  // default to FLOAT as math.h does.
  TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
  if (fp16_type<MATH>()) {
    math_type = TensorProto_DataType_FLOAT16;
  }
  // Compute X * W^T.
  math::Gemm<T_X, Context, Engine>(
      CblasNoTrans,
      TransposeWeight ? CblasTrans : CblasNoTrans,
      M, N, K, 1,
      X.template data<T_X>(),
      W.template data<T_W>(),
      0,
      Y->template mutable_data<T_Y>(),
      &context_,
      math_type);
  // Add the bias vector.
  if (bias_multiplier_.size() != M) {
    // If the helper bias multiplier is not M, reshape and fill it with one.
    bias_multiplier_.Resize(M);
    math::Set<T_B, Context>(
        M,
        convert::To<float, T_B>(1),
        bias_multiplier_.template mutable_data<T_B>(),
        &context_);
  }
  math::Gemm<T_B, Context, Engine>(
      CblasNoTrans,
      CblasNoTrans,
      M, N, 1, 1,
      bias_multiplier_.template data<T_B>(),
      b.template data<T_B>(),
      1,
      Y->template mutable_data<T_Y>(),
      &context_,
      math_type);
  return true;
}
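The second Gemm is the standard bias-multiplier trick: the M x 1 column of ones times b viewed as a 1 x N row produces an M x N matrix whose every row equals b, and beta = 1 accumulates that into Y. A naive-loop sketch of what this call computes (my own illustration, not Caffe2 code):

#include <vector>

// Equivalent of Gemm(CblasNoTrans, CblasNoTrans, M, N, 1, 1, ones, b, 1, Y):
// Y[i][j] += ones[i] * b[j], i.e. the bias is broadcast across all M rows.
void add_bias(int M, int N,
              const std::vector<float>& ones, // bias_multiplier_, M ones
              const std::vector<float>& b,    // N bias values
              std::vector<float>& Y) {        // M x N, row-major
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      Y[i * N + j] += ones[i] * b[j];
    }
  }
}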
The following function is used here:
template <>
void Gemm<float, CPUContext>(
    const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB,
    const int M,
    const int N,
    const int K,
    const float alpha,
    const float* A,
    const float* B,
    const float beta,
    float* C,
    CPUContext* context,
    TensorProto::DataType math_type);
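Its semantics are the usual BLAS GEMM contract, C = alpha * op(A) * op(B) + beta * C, with A, B, and C stored row-major; the CPU specialization forwards to an optimized backend (Eigen or a CBLAS implementation, depending on the build). As a reference for what the no-transpose case computes, a naive version might look like this (a sketch for intuition only, not the actual implementation):

// C = alpha * A * B + beta * C for the CblasNoTrans/CblasNoTrans case.
// A is M x K, B is K x N, C is M x N, all row-major.
void gemm_nn(int M, int N, int K, float alpha,
             const float* A, const float* B, float beta, float* C) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * N + j];
      }
      C[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
  }
}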
template <
    typename T_X,
    typename T_W,
    typename T_DY,
    typename T_B,
    typename T_DX,
    typename T_DW,
    typename T_DB,
    typename MATH>
bool DoRunWithType() {
  const auto& X = Input(0);
  const auto& W = Input(1);
  const auto& dY = Input(2);
  // batch size
  const auto canonical_axis = X.canonical_axis_index(axis_);
  const int M = X.size_to_dim(canonical_axis);
  const int K = X.size_from_dim(canonical_axis);
  const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
  const int N = W.size_to_dim(canonical_axis_w);
  CAFFE_ENFORCE(M * K == X.size());
  CAFFE_ENFORCE(K * N == W.size());
  auto* dW = Output(0);
  auto* db = Output(1);
  dW->ResizeLike(W);
  db->Resize(N);

  if (X.size() == 0) {
    // generate a zero blob for db and dW when X is empty
    // skipped
    // ...
    return true;
  }

  // default to FLOAT as math.h does.
  TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
  if (fp16_type<MATH>()) {
    math_type = TensorProto_DataType_FLOAT16;
  }
  // Compute dW
  math::Gemm<T_DY, Context, Engine>(
      CblasTrans,
      CblasNoTrans,
      N, K, M, 1,
      dY.template data<T_DY>(),
      X.template data<T_X>(),
      0,
      dW->template mutable_data<T_DW>(),
      &context_,
      math_type);
  if (bias_multiplier_.size() != M) {
    // If the helper bias multiplier is not M, reshape and fill it
    // with one.
    bias_multiplier_.Resize(M);
    math::Set<T_B, Context>(
        M,
        convert::To<float, T_B>(1),
        bias_multiplier_.template mutable_data<T_B>(),
        &context_);
  }
  // Compute dB
  math::Gemv<T_DY, Context>(
      CblasTrans,
      M, N, 1,
      dY.template data<T_DY>(),
      bias_multiplier_.template data<T_B>(),
      0,
      db->template mutable_data<T_DB>(),
      &context_);
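Both calls follow directly from the forward pass. With Y = X * W^T + ones * b^T, the chain rule gives dW = dY^T * X (an N x K matrix, the same shape as W) and db = dY^T * ones (the column sums of dY). A naive-loop sketch mirroring the Gemm and Gemv arguments above (my own illustration, not Caffe2 code):

// dW = dY^T * X:  (N x M) * (M x K) -> N x K.
// db = dY^T * ones:  column sums of dY, one entry per output unit.
void fc_gradient_reference(int M, int N, int K,
                           const float* dY, // M x N, row-major
                           const float* X,  // M x K, row-major
                           float* dW,       // N x K, row-major
                           float* db) {     // N
  // dW[n][k] = sum over the batch of dY[m][n] * X[m][k].
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K; ++k) {
      float acc = 0.f;
      for (int m = 0; m < M; ++m) {
        acc += dY[m * N + n] * X[m * K + k];
      }
      dW[n * K + k] = acc;
    }
  }
  // db[n] = sum over the batch of dY[m][n].
  for (int n = 0; n < N; ++n) {
    float acc = 0.f;
    for (int m = 0; m < M; ++m) {
      acc += dY[m * N + n];
    }
    db[n] = acc;
  }
}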