Cross-Entropy vs Mean Squared Error
The cross-entropy is used almost exclusively as the loss function for classification tasks, but it is not obvious why we cannot use the mean squared error instead. Actually we can, but as we will see shortly, the cross-entropy is a more convenient measure of loss for classification tasks.
For this discussion we will deal with a single sample and distinguish between the different values of the true label.
| Loss Function | True Label | Loss |
|---|---|---|
| MSE: $(y-\hat{y})^2$ | $y = 0$ | $(0-\hat{y})^2$ |
| MSE: $(y-\hat{y})^2$ | $y = 1$ | $(1-\hat{y})^2$ |
| CE: $-\Big[y \log (\hat{y}) + (1 - y) \log ( 1 - \hat{y})\Big]$ | $y = 0$ | $-\log ( 1 - \hat{y})$ |
| CE: $-\Big[y \log (\hat{y}) + (1 - y) \log ( 1 - \hat{y})\Big]$ | $y = 1$ | $-\log ( \hat{y})$ |
If the true label equals 0, both losses increase as the predicted probability $\hat{y}$ grows. If the true label is 1, on the other hand, both losses decrease as the predicted probability grows.
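To make the table above concrete, here is a minimal Python sketch (the helper names are ours) that evaluates both losses for a single sample at a few predicted probabilities.

```python
import numpy as np

def mse_loss(y, y_hat):
    """Mean squared error for a single sample."""
    return (y - y_hat) ** 2

def cross_entropy_loss(y, y_hat):
    """Binary cross-entropy for a single sample."""
    return -(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

# Compare both losses for each value of the true label.
for y in (0, 1):
    for y_hat in (0.1, 0.5, 0.9):
        print(f"y={y}, y_hat={y_hat}: "
              f"MSE={mse_loss(y, y_hat):.3f}, "
              f"CE={cross_entropy_loss(y, y_hat):.3f}")
```

For a true label of 0, for example, a confident but wrong prediction of $\hat{y} = 0.9$ yields an MSE of 0.81 but a cross-entropy of roughly 2.30.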
Below we plot the mean squared error and the cross-entropy as functions of the predicted probability $\hat{y}$. The red curve depicts the mean squared error, while the blue curve depicts the cross-entropy. There are two plots for each of the losses, one for each value of the target.
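A sketch along the following lines (assuming matplotlib is available; the exact layout is our choice) produces a comparable figure.

```python
import numpy as np
import matplotlib.pyplot as plt

y_hat = np.linspace(0.001, 0.999, 500)  # stay away from log(0)

fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
for ax, y in zip(axes, (0, 1)):
    mse = (y - y_hat) ** 2
    ce = -(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
    ax.plot(y_hat, mse, color="red", label="MSE")
    ax.plot(y_hat, ce, color="blue", label="Cross-Entropy")
    ax.set_title(f"True label y = {y}")
    ax.set_xlabel(r"$\hat{y}$")
    ax.legend()
axes[0].set_ylabel("Loss")
plt.show()
```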
The mean squared error and the cross-entropy start at the same value, but the gap between the two grows as the predicted probability deviates from the true label. The cross-entropy punishes misclassifications with a much higher loss than the mean squared error. When we deal with probabilities, the difference between the label and the predicted probability cannot be larger than 1, which means that the mean squared error also cannot grow beyond 1. The logarithm, on the other hand, explodes towards infinity as its argument approaches 0.
This behaviour can also be observed when we plot the derivative of the loss function against the predicted probability. While the derivative of the mean squared error, $2(\hat{y} - y)$, is linear in the prediction error, the derivative of the cross-entropy, $-\frac{y}{\hat{y}} + \frac{1 - y}{1 - \hat{y}}$, grows without bound as the quality of the prediction deteriorates.
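A short numerical check illustrates this; the gradient helpers below are ours, derived directly from the formulas above.

```python
def mse_grad(y, y_hat):
    """d/d y_hat of (y - y_hat)^2 -- linear in the prediction error."""
    return 2 * (y_hat - y)

def ce_grad(y, y_hat):
    """d/d y_hat of the binary cross-entropy -- blows up for confident, wrong predictions."""
    return -y / y_hat + (1 - y) / (1 - y_hat)

# True label is 1 while the prediction drifts towards 0.
for y_hat in (0.5, 0.1, 0.01, 0.001):
    print(f"y_hat={y_hat}: MSE grad={mse_grad(1, y_hat):+.2f}, "
          f"CE grad={ce_grad(1, y_hat):+.2f}")
```

As the prediction moves from 0.5 to 0.001, the MSE derivative merely approaches -2, while the cross-entropy derivative grows from -2 to -1000.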
The steep growth of the cross-entropy derivative implies that the gradient descent algorithm will take much larger steps than it would with the mean squared error when the classification predictions are way off, thereby converging at a higher rate.