Momentum-based Gradient Descent from Scratch
With a recap of vanilla/normal/batch gradient descent
Gradient Descent laid the groundwork for optimization in machine learning, but its vanilla form often struggles with inefficiency and erratic convergence. Momentum Gradient Descent addresses these issues by smoothing updates, making optimization faster and more stable. In this article, we’ll break down Momentum Gradient Descent, its math, and why it’s a go-to strategy for modern optimization.
Shortcomings of Gradient Descent
Computational Overhead: Calculating the gradient over the entire dataset is expensive for large datasets. With N samples, every single parameter update requires N per-sample gradient computations.
Slow Convergence: On ill-conditioned loss landscapes, such as long narrow valleys, the learning rate must stay small enough to avoid oscillating or diverging along the steepest direction, so progress along the shallow directions is slow.
Local Minima or Saddle Points: It may get stuck in a local minimum or stall on plateaus and near saddle points, where the gradient is close to zero.
Stochastic Gradient Descent addresses the first limitation by estimating the gradient from a single sample or a small mini-batch, which makes each update much cheaper. Momentum, the focus of this article, targets the other two: slow convergence with oscillating paths, and stalling in flat regions.
NOTE: The “normal” gradient descent is also called vanilla gradient descent or batch gradient descent (BGD). It is called “batch” because each parameter update is computed from all data points in the training set, treated as one batch.
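To make the “batch” in batch gradient descent concrete, here is a minimal sketch for a linear model trained with mean squared error. The model, the synthetic data, and the names `batch_gradient` and `batch_gd` are illustrative assumptions, not part of the original article. Each call to `batch_gradient` touches all N rows of `X`, which is the per-update cost described above.

```python
import numpy as np

def mse_loss(w, X, y):
    """Mean squared error of the linear model y_hat = X @ w over all N samples."""
    residuals = X @ w - y
    return np.mean(residuals ** 2)

def batch_gradient(w, X, y):
    """Full-batch MSE gradient: every one of the N rows contributes."""
    n_samples = X.shape[0]
    residuals = X @ w - y                 # shape (N,): one residual per sample
    return (2.0 / n_samples) * X.T @ residuals

def batch_gd(X, y, eta=0.1, epochs=200):
    """Vanilla (batch) gradient descent: one update per full pass over the data."""
    w = np.zeros(X.shape[1])
    for _ in range(epochs):
        w -= eta * batch_gradient(w, X, y)
    return w

# Hypothetical noiseless data generated from known weights [2, -1].
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
true_w = np.array([2.0, -1.0])
y = X @ true_w
w_hat = batch_gd(X, y)
```

Replacing `X` and `y` inside the loop with a randomly sampled subset of rows would turn this into mini-batch SGD.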
Momentum Gradient Descent: The Core Idea
What Is Momentum Gradient Descent?
Momentum Gradient Descent introduces a "velocity" term that accumulates gradients over time, allowing the optimizer to maintain direction in flat regions and reduce oscillations in steep ones. Imagine rolling a ball down a hill: instead of stopping at every bump, it builds speed and momentum to move smoothly toward the valley.
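The momentum update rule (in the exponential-moving-average form used by the implementation at the end of this article) is:
$$v_t = \beta\, v_{t-1} + (1 - \beta)\, \nabla J(\theta_{t-1}), \qquad v_0 = 0$$
Here:
v_t: Velocity term at iteration t
β: Momentum coefficient (typically 0.9)
η: Learning rate
θ_t: Model parameters after iteration t
∇J(θ_{t-1}): Gradient of the loss J at the current parameters
Parameter Update:
$$\theta_t = \theta_{t-1} - \eta\, v_t$$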
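One momentum update can be sketched as a small function. It follows the same exponential-moving-average convention as the implementation at the end of this article; the helper name `momentum_step` and its default values are assumptions chosen for illustration.

```python
import numpy as np

def momentum_step(theta, v, grad, eta=0.05, beta=0.9):
    """One momentum update in exponential-moving-average form (hypothetical helper).

    v_t     = beta * v_{t-1} + (1 - beta) * grad
    theta_t = theta_{t-1} - eta * v_t
    """
    v_new = beta * v + (1.0 - beta) * grad
    theta_new = theta - eta * v_new
    return theta_new, v_new

theta = np.array([1.5, 1.5])
v = np.zeros_like(theta)            # velocity starts at zero
grad = np.array([3.0, 30.0])        # gradient of x**2 + 10*y**2 at (1.5, 1.5)
theta, v = momentum_step(theta, v, grad)
```

Returning the new velocity alongside the parameters keeps the function stateless, so the caller decides where the velocity is stored between steps.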
How Momentum Works
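In Flat Regions: The accumulated velocity v_t keeps the parameters moving even when the current gradient is small, speeding up convergence.
In Narrow Valleys: Gradient components that flip sign from step to step (bouncing between the valley walls) largely cancel in the average, while the component pointing consistently along the valley floor accumulates, so oscillations shrink.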
Analogy:
Imagine pedaling a bicycle downhill. Even if the slope becomes flat or uneven, your built-up momentum keeps you moving forward without constant effort.
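The flat-region behavior can be checked with a toy 1-D example: a hypothetical gradient sequence that is constant for 30 steps and then drops to exactly zero, as on a perfectly flat plateau. Plain gradient descent stops the moment the gradient vanishes, while the momentum velocity decays by a factor of β per step and keeps the parameter moving. The function `run` and the gradient sequence are illustrative assumptions.

```python
import numpy as np

def run(gradients, eta=0.1, beta=0.9, use_momentum=True):
    """Apply a fixed sequence of 1-D gradients; return the position after each step."""
    theta, v = 0.0, 0.0
    positions = []
    for g in gradients:
        if use_momentum:
            v = beta * v + (1.0 - beta) * g   # EMA velocity
            theta -= eta * v
        else:
            theta -= eta * g                  # plain gradient step
        positions.append(theta)
    return np.array(positions)

# Hypothetical gradients: a slope for 30 steps, then a perfectly flat plateau.
grads = [1.0] * 30 + [0.0] * 30
gd = run(grads, use_momentum=False)
mom = run(grads, use_momentum=True)

# Distance travelled while on the plateau (steps 31 to 60).
gd_plateau = abs(gd[-1] - gd[29])
mom_plateau = abs(mom[-1] - mom[29])
```

On the initial slope the momentum run actually moves more slowly, because its velocity starts at zero and needs several steps to build up.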
Mathematical Intuition
Effective Gradient with Momentum
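Momentum replaces the raw gradient with an exponential moving average of past gradients. Unrolling the recursion from v_0 = 0 gives:
$$v_t = (1 - \beta) \sum_{k=1}^{t} \beta^{\,t-k}\, \nabla J(\theta_{k-1})$$
so a gradient from j steps ago carries weight (1 − β)β^j, shrinking geometrically with age.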
This cumulative term smooths updates, emphasizing consistent directions while diminishing random fluctuations.
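The unrolled sum can be verified numerically: running the recursive update over a sequence of random gradients gives the same vector as the explicit weighted sum. The random gradients are an assumption of this sketch, standing in for real ones.

```python
import numpy as np

rng = np.random.default_rng(42)
beta = 0.9
grads = rng.normal(size=(50, 2))   # hypothetical g_1, ..., g_50 (g_k is the gradient used at step k)

# Recursive form: v_t = beta * v_{t-1} + (1 - beta) * g_t, starting from v_0 = 0.
v = np.zeros(2)
for g in grads:
    v = beta * v + (1 - beta) * g

# Unrolled form: v_t = (1 - beta) * sum_k beta**(t - k) * g_k.
t = len(grads)
weights = (1 - beta) * beta ** (t - np.arange(1, t + 1))
v_unrolled = weights @ grads

# The weights sum to 1 - beta**t rather than exactly 1, because v_0 = 0.
weight_total = weights.sum()
```

Because the weights add up to 1 − β^t, early in training v_t is biased toward its zero initialization; Adam, for example, corrects for this with a bias-correction term.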
Damping Oscillations
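In directions with steep gradients, plain gradient descent tends to overshoot and reverse sign on every step. Because v_t averages those alternating gradients, their contributions partially cancel, which shrinks the effective step in that direction.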
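A minimal sketch of this damping effect, assuming hypothetical gradients where the x component is constant (the valley floor) and the y component flips sign every step (bouncing between steep walls):

```python
import numpy as np

beta = 0.9
steps = 100
v = np.zeros(2)
for t in range(steps):
    # Hypothetical gradients: x points the same way every step,
    # y alternates between +1 and -1.
    g = np.array([1.0, 1.0 if t % 2 == 0 else -1.0])
    v = beta * v + (1 - beta) * g
```

After 100 steps the x component of the velocity has grown to nearly 1, while the y component hovers around ±(1 − β)/(1 + β), roughly 0.05 for β = 0.9, so the oscillating direction is damped by a factor of about 19.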
Advantages of Momentum Gradient Descent
Faster Convergence: Momentum accelerates updates in consistent directions, reducing time to reach the optimum.
Reduced Oscillations: By incorporating past gradients, it dampens erratic movements in steep regions.
Less Sensitive to Learning Rate: Momentum provides stability, allowing for slightly higher learning rates.
Challenges of Momentum
While powerful, Momentum Gradient Descent isn’t perfect:
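Tuning β: Choosing the right momentum coefficient β is crucial. Typical values range between 0.8 and 0.99. Too high a value lets velocity build up and carry the parameters far past the optimum; too low a value makes the method behave almost like plain gradient descent (with β = 0, momentum reduces exactly to it).
Overshooting: Accumulated velocity can carry the parameters past the optimum before it decays, producing a path that spirals or oscillates around the minimum, especially with large β or in noisy loss landscapes.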
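Overshooting is easy to reproduce on the 1-D quadratic f(θ) = θ²/2, whose gradient is simply θ. The sketch below (the function name and settings are illustrative assumptions) uses the same EMA update as the article's code; setting β = 0 turns it back into plain gradient descent, since the velocity then equals the current gradient.

```python
import numpy as np

def trajectory(eta, beta, steps=200, theta0=1.0):
    """Minimize f(theta) = theta**2 / 2 with EMA momentum; return every iterate."""
    theta, v = theta0, 0.0
    path = [theta]
    for _ in range(steps):
        g = theta                          # derivative of theta**2 / 2
        v = beta * v + (1 - beta) * g      # beta = 0 gives v = g (plain GD)
        theta -= eta * v
        path.append(theta)
    return np.array(path)

plain = trajectory(eta=0.1, beta=0.0)
heavy = trajectory(eta=0.1, beta=0.9)
```

With η = 0.1 the plain run shrinks θ by a factor of 0.9 each step and never crosses zero, while the β = 0.9 run carries enough velocity to swing past the minimum into negative values before settling.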
Comparison with Other Optimizers
Versus Standard Gradient Descent: Momentum is faster and more stable, especially in steep or flat regions.
Versus Stochastic Gradient Descent: Momentum reduces the noise of SGD by incorporating gradient history, making updates smoother.
Versus Adaptive Methods (Adam, RMSprop): Momentum doesn’t adapt learning rates per parameter but remains competitive due to its simplicity and effectiveness.
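To illustrate the smoothing relative to raw SGD, this sketch feeds hypothetical noisy gradient estimates (a true gradient of 1 plus unit Gaussian noise, an assumption standing in for mini-batch noise) through the momentum velocity and compares their spread.

```python
import numpy as np

rng = np.random.default_rng(0)
beta = 0.9
true_grad = 1.0
# Hypothetical stochastic gradients: true gradient plus zero-mean noise.
noisy = true_grad + rng.normal(scale=1.0, size=2000)

v = 0.0
smoothed = []
for g in noisy:
    v = beta * v + (1 - beta) * g          # momentum velocity (EMA)
    smoothed.append(v)
smoothed = np.array(smoothed)

# Skip the warm-up period, where v is still biased toward its zero start.
raw_var = noisy[200:].var()
smoothed_var = smoothed[200:].var()
```

For independent noise with variance σ², the velocity's steady-state variance is (1 − β)/(1 + β)·σ², about 0.05σ² for β = 0.9, while its mean still tracks the true gradient.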
Tuning Momentum
Key Hyperparameters
Momentum Coefficient β: Typically 0.9; higher values provide smoother updates but may overshoot.
Learning Rate η: Works well with moderately higher learning rates than vanilla Gradient Descent.
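The learning-rate advice depends on which momentum convention is in use. The article's code uses the EMA form; another widely used convention drops the (1 − β) factor and applies the learning rate inside the velocity. The sketch below, with hypothetical function names, shows that the two produce identical trajectories when α = η(1 − β).

```python
import numpy as np

def grad(theta):
    # Gradient of the article's quadratic x**2 + 10*y**2.
    return np.array([2.0 * theta[0], 20.0 * theta[1]])

def ema_momentum(theta, eta, beta, steps):
    """EMA form (as in the article's code): v = beta*v + (1-beta)*g; theta -= eta*v."""
    v = np.zeros_like(theta)
    for _ in range(steps):
        v = beta * v + (1 - beta) * grad(theta)
        theta = theta - eta * v
    return theta

def classic_momentum(theta, alpha, beta, steps):
    """Alternative form: u = beta*u + alpha*g; theta -= u."""
    u = np.zeros_like(theta)
    for _ in range(steps):
        u = beta * u + alpha * grad(theta)
        theta = theta - u
    return theta

start = np.array([1.5, 1.5])
eta, beta = 0.05, 0.9
a = ema_momentum(start, eta, beta, steps=50)
b = classic_momentum(start, alpha=eta * (1 - beta), beta=beta, steps=50)
```

With β = 0.9, the article's η = 0.05 corresponds to α = 0.005 in the other convention, so learning-rate values are only comparable within one convention, and frameworks differ in which one they implement.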
Conclusion
Momentum Gradient Descent is a versatile and efficient optimization technique that addresses many of the shortcomings of standard Gradient Descent. By incorporating gradient history, it smooths the optimization path, accelerates convergence, and reduces oscillations, especially in challenging loss landscapes.
Whether you’re training a linear model or a deep neural network, Momentum Gradient Descent provides the stability and speed needed to optimize complex systems. As you continue your journey in machine learning, momentum ensures that every step forward builds on the lessons of the past—literally.
Code implementation of Momentum Gradient Descent
```python
import numpy as np
import matplotlib.pyplot as plt

def quadratic_loss(x, y):
    # Elongated bowl: 10x steeper along y than along x, minimum at (0, 0).
    return x**2 + 10 * y**2

def quadratic_grad(x, y):
    # Analytic gradient of quadratic_loss.
    dx = 2 * x
    dy = 20 * y
    return np.array([dx, dy])

def batch_gradient_descent(grad_func, eta, epochs, start_point):
    x, y = start_point
    path = [(x, y)]
    losses = [quadratic_loss(x, y)]
    for _ in range(epochs):
        grad = grad_func(x, y)
        # Plain update: step against the current gradient only.
        x -= eta * grad[0]
        y -= eta * grad[1]
        path.append((x, y))
        losses.append(quadratic_loss(x, y))
    return np.array(path), losses

def gradient_descent_momentum(grad_func, eta, beta, epochs, start_point):
    x, y = start_point
    v = np.zeros(2)  # velocity starts at zero (float array)
    path = [(x, y)]
    losses = [quadratic_loss(x, y)]
    for _ in range(epochs):
        grad = grad_func(x, y)
        # Velocity: exponential moving average of past gradients.
        v = beta * v + (1 - beta) * grad
        # Step along the velocity instead of the raw gradient.
        x -= eta * v[0]
        y -= eta * v[1]
        path.append((x, y))
        losses.append(quadratic_loss(x, y))
    return np.array(path), losses

def plot_paths(function, paths, labels, title):
    X, Y = np.meshgrid(np.linspace(-2, 2, 400), np.linspace(-2, 2, 400))
    Z = function(X, Y)
    plt.figure(figsize=(8, 6))
    plt.contour(X, Y, Z, levels=50, cmap='jet')
    for i, (path, label) in enumerate(zip(paths, labels)):
        plt.plot(path[:, 0], path[:, 1], label=label)
        # Label the start/end markers only once to avoid duplicate legend entries.
        plt.scatter(path[0, 0], path[0, 1], color='green',
                    label="Start" if i == 0 else None)
        plt.scatter(path[-1, 0], path[-1, 1], color='red',
                    label="End" if i == 0 else None)
    plt.title(title)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()
    plt.show()

def plot_losses(losses, labels, title):
    plt.figure(figsize=(8, 6))
    for loss, label in zip(losses, labels):
        plt.plot(loss, label=label)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

eta_bgd = 0.05       # Learning rate for BGD
eta_momentum = 0.05  # Learning rate for Momentum
beta = 0.9           # Momentum coefficient
epochs = 150
start_point = (1.5, 1.5)  # Initial point far from the minimum

path_bgd, losses_bgd = batch_gradient_descent(quadratic_grad, eta_bgd, epochs, start_point)
path_momentum, losses_momentum = gradient_descent_momentum(quadratic_grad, eta_momentum, beta, epochs, start_point)

plot_paths(quadratic_loss, [path_bgd, path_momentum],
           ["Batch Gradient Descent", "Gradient Descent with Momentum"],
           "Oscillations in BGD vs Momentum")
plot_losses([losses_bgd, losses_momentum],
            ["Batch Gradient Descent", "Gradient Descent with Momentum"],
            "Loss vs Epochs for BGD and Momentum")
```
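Whether momentum converges faster than plain gradient descent on this quadratic depends on η, β, and the curvature along each axis. As a cross-check on the plots, the sketch below computes each method's per-step contraction factor along x and y; smaller means faster convergence, and a value above 1 means divergence. It assumes the EMA update used above, and the helper names are hypothetical.

```python
import numpy as np

def gd_factor(eta, curvature):
    """Per-step contraction |1 - eta*lambda| of plain GD along one quadratic axis."""
    return abs(1.0 - eta * curvature)

def momentum_factor(eta, beta, curvature):
    """Spectral radius of the EMA-momentum iteration along one quadratic axis.

    On f = lambda * z**2 / 2 the state (z, v) evolves linearly:
        v' = beta*v + (1-beta)*lambda*z
        z' = z - eta*v'
    """
    c = (1.0 - beta) * curvature
    M = np.array([[1.0 - eta * c, -eta * beta],
                  [c,              beta]])
    return max(abs(np.linalg.eigvals(M)))

# Curvatures (second derivatives) of x**2 + 10*y**2 along x and y.
eta, beta = 0.05, 0.9
for name, lam in [("x", 2.0), ("y", 20.0)]:
    print(name, gd_factor(eta, lam), momentum_factor(eta, beta, lam))
```

For this EMA form the determinant of `M` is exactly β, so whenever its eigenvalues are complex (the oscillating regime) the momentum factor equals √β regardless of η. Changing `eta` and `beta` in the loop is a quick way to see which settings favor momentum.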