
Regularization loss in ‘slim’ library of Tensorflow

My Python code, using the slim library to train a classification model in TensorFlow:

Python
with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(weight_decay=0.001)):
    logits, _ = mobilenet_v2.mobilenet(images, NUM_CLASSES)
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
cross_entropy = tf.reduce_mean(cross_entropy)
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = tf.contrib.slim.learning.create_train_op(cross_entropy, opt, global_step=global_step)
...
sess.run(train_op)

It works fine. However, no matter what value ‘weight_decay’ takes, the training accuracy of the model easily reaches more than 90%. It seems ‘weight_decay’ just doesn’t work.
To find out the reason, I reviewed the TensorFlow code for ‘tf.losses.sparse_softmax_cross_entropy()’:

Python
# tensorflow/python/ops/losses/losses_impl.py
@tf_export("losses.sparse_softmax_cross_entropy")
def sparse_softmax_cross_entropy(
    labels, logits, weights=1.0, scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  ...
  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                      (logits, labels, weights)) as scope:
    # As documented above in Args, labels contain class IDs and logits contains
    # 1 probability per class ID, so we expect rank(logits) - rank(labels) == 1;
    # therefore, expected_rank_diff=1.
    labels, logits, weights = _remove_squeezable_dimensions(
        labels, logits, weights, expected_rank_diff=1)
    losses = nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name="xentropy")
    return compute_weighted_loss(
        losses, weights, scope, loss_collection, reduction=reduction)

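As a minimal sketch of the shapes this function expects (my own example, not from the TensorFlow sources): ‘labels’ holds one class ID per example and ‘logits’ holds one score per class, so rank(logits) - rank(labels) == 1, which matches the ‘expected_rank_diff=1’ above.

Python
import tensorflow as tf

labels = tf.constant([1, 0, 3])     # class IDs, shape [batch_size]
logits = tf.random_normal([3, 4])   # scores, shape [batch_size, num_classes]

# With the default reduction the returned loss is already a scalar.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
print(loss.shape)                   # ()
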
‘losses.sparse_softmax_cross_entropy()’ simply calls ‘tf.nn.sparse_softmax_cross_entropy_with_logits()’ and hands the result to ‘compute_weighted_loss()’. Let’s look into the implementation of ‘compute_weighted_loss()’:

Python
# tensorflow/python/ops/losses/losses_impl.py
@tf_export("losses.compute_weighted_loss")
def compute_weighted_loss(
    losses, weights=1.0, scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  ...
  loss = math_ops.cast(loss, input_dtype)
  util.add_loss(loss, loss_collection)
  return loss

What’s the secret in ‘util.add_loss()’?

Python
# tensorflow/python/ops/losses/util.py
@tf_export("losses.add_loss")
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  ...
  if loss_collection:
    ops.add_to_collection(loss_collection, loss)

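To see this side effect from user code (a small sketch of my own, using the public ‘tf.losses.get_losses()’ helper rather than the training script above):

Python
import tensorflow as tf

labels = tf.constant([2, 1])
logits = tf.random_normal([2, 5])
xent = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

# The returned loss tensor was also recorded in GraphKeys.LOSSES,
# so it can be fetched back from the collection:
print(tf.losses.get_losses())                   # [the xent loss tensor]
print(tf.get_collection(tf.GraphKeys.LOSSES))   # the same list
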
The loss from ‘losses.sparse_softmax_cross_entropy()’ is added to the ‘GraphKeys.LOSSES’ collection. Then where do the regularization losses on the weights go? Are they added to the same collection? Let’s check. Every layer written with ‘tf.layers’ or ‘tf.contrib.slim’ inherits from ‘class Layer’ and calls ‘add_loss()’ when the layer calls ‘add_variable()’ with a regularizer. Let’s check ‘add_loss()’ of the base class ‘Layer’:

Python
@tf_export('layers.Layer')
class Layer(checkpointable.CheckpointableBase):
  ...
  def add_loss(self, losses, inputs=None):
    ...
    _add_elements_to_collection(losses, ops.GraphKeys.REGULARIZATION_LOSSES)

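To confirm where a layer’s weight-decay term actually lands (a minimal sketch of my own, using a single slim layer with an explicit ‘l2_regularizer’, which plays the same role as ‘weight_decay’ above):

Python
import tensorflow as tf
slim = tf.contrib.slim

inputs = tf.random_normal([8, 16])
net = slim.fully_connected(
    inputs, 4,
    weights_regularizer=slim.l2_regularizer(0.001))

# The L2 penalty on the layer's weights is NOT in GraphKeys.LOSSES ...
print(tf.get_collection(tf.GraphKeys.LOSSES))                  # []
# ... it is collected in GraphKeys.REGULARIZATION_LOSSES instead.
print(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))   # [the l2 penalty tensor]
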
It’s weird. The regularization loss from the weights is not added to ‘GraphKeys.LOSSES’ but to ‘GraphKeys.REGULARIZATION_LOSSES’. Then how can we get all the losses at training time? After grepping for ‘REGULARIZATION_LOSSES’ in the whole TensorFlow codebase, I found ‘get_total_loss()’:

Python
# tensorflow/python/ops/losses/util.py
@tf_export("losses.get_total_loss")
def get_total_loss(add_regularization_losses=True, name="total_loss"):
  ...
  losses = get_losses()
  if add_regularization_losses:
    losses += get_regularization_losses()
  return math_ops.add_n(losses, name=name)

That is the secret of losses in ‘tf.layers’ and ‘tf.contrib.slim’: we should use ‘get_total_loss()’ to fetch model loss and regularization loss together!
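
Conceptually (my own sketch, not the library’s exact code), ‘get_total_loss()’ just concatenates the two collections and sums them:

Python
import tensorflow as tf

# Assumes a graph with at least one model loss has already been built.
model_losses = tf.losses.get_losses()               # from GraphKeys.LOSSES
reg_losses = tf.losses.get_regularization_losses()  # from GraphKeys.REGULARIZATION_LOSSES
total_loss = tf.add_n(model_losses + reg_losses, name="total_loss")
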
After changing my code:

Python
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
cross_entropy = tf.reduce_mean(cross_entropy)
global_step = tf.contrib.framework.get_or_create_global_step()
loss = tf.contrib.slim.losses.get_total_loss()
train_op = tf.contrib.slim.learning.create_train_op(loss, opt, global_step=global_step)
...
sess.run(train_op)

The ‘weight_decay’ works well now (which means the training accuracy can no longer reach a high value so easily).
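
If you want to watch the effect during training, one option (a hypothetical addition that reuses ‘loss’, ‘train_op’ and ‘sess’ from the snippet above; it is not part of my actual script) is to fetch the regularization part next to the train op:

Python
# Hypothetical: a separate tensor holding only the regularization part.
reg_loss = tf.add_n(tf.losses.get_regularization_losses(), name="regularization_loss")
...
_, total_value, reg_value = sess.run([train_op, loss, reg_loss])
print('total loss: %f, regularization loss: %f' % (total_value, reg_value))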