Simple PyTorch SGD
import torch
import matplotlib.pyplot as plt
# Starting guess for the three quadratic coefficients (a, b, c).
params = torch.tensor([-1.0, -2.0, -3.0])
params  # bare expression: echoes the tensor when run as a notebook cell

# Sample inputs from -5 (inclusive) up to 5 (exclusive) in steps of 0.2.
x = torch.arange(-5, 5, 0.2)
# Targets follow 2*x**2 - x + 3 plus standard-normal noise scaled by 5.
y = 2 * x.pow(2) - x + 3 + 5 * torch.randn_like(x)
def f(param):
    """Evaluate the quadratic ``a*x**2 + b*x + c`` at the module-level ``x``.

    ``param`` is unpacked into exactly three coefficients ``(a, b, c)``.
    """
    a, b, c = param
    quadratic_term = a * x**2
    linear_term = b * x
    return quadratic_term + linear_term + c
# Plot the noisy samples, then overlay (in red) the curve produced by the
# current parameter values.
plt.scatter(x, y)
with torch.no_grad():
    plt.scatter(x, f(params).cpu().numpy(), color='red')

# From here on autograd records operations that involve params.
params.requires_grad_(True)
# Copy of the current parameter values.
origs = torch.clone(params)
origs
def mse():
    """Return the mean squared error between ``y`` and ``f(params)``.

    Reads the module-level ``y`` and ``params`` each time it is called.
    """
    residual = y - f(params)
    return residual.pow(2).mean()
# One manual pass: measure the loss, then back-propagate it so that
# params.grad is populated.
loss = mse()
loss
loss.backward()
params.grad

# Step size used by apply_step().
lr = 5e-4
def apply_step():
    """Run one gradient-descent update on the module-level ``params``.

    Computes the current loss with ``mse()``, back-propagates it, moves
    ``params`` against its gradient by ``lr``, clears the gradient, and
    prints the loss that was measured before the update.
    """
    loss = mse()
    loss.backward()
    # Update in place without recording the step in the autograd graph.
    # torch.no_grad() replaces writing through ``.data``, which hides the
    # modification from autograd's in-place checks. ``sub_`` is used rather
    # than ``params -= ...`` so that ``params`` is not turned into a local.
    with torch.no_grad():
        params.sub_(lr * params.grad)
    # Drop the gradient so the next backward() starts from scratch instead
    # of accumulating onto this one.
    params.grad = None
    print(loss.item())
# Restart from the saved values. detach() on its own returns a tensor that
# shares storage with origs, so the in-place updates made by apply_step()
# would also overwrite the snapshot; clone() gives params its own memory.
params = origs.detach().clone().requires_grad_()
origs = params.clone()
for _ in range(9):
    apply_step()
origs

# Reset to the saved values again, then draw the current fit (red) over the
# data before each of nine further steps.
params = origs.detach().clone().requires_grad_()
_, axs = plt.subplots(3, 3, figsize=(24, 24))
for ax in axs.flat:  # row-major walk over the 3x3 grid of axes
    ax.scatter(x, y)
    ax.scatter(x, f(params).detach().cpu().numpy(), color='red')
    apply_step()