import torch
import matplotlib.pyplot as plt
# Initial (deliberately poor) guesses for the quadratic coefficients (a, b, c).
params = torch.tensor([-1., -2., -3.])

# Synthetic data: noisy samples of y = 2x^2 - x + 3 on the grid [-5, 5).
x = torch.arange(-5, 5, 0.2)
y = 2 * x**2 - 1 * x + 3 + torch.randn_like(x) * 5
def f(param, xs=None):
    """Evaluate the quadratic a*xs**2 + b*xs + c.

    param: length-3 sequence/tensor of coefficients (a, b, c).
    xs: points at which to evaluate; when None, falls back to the
        module-level grid ``x`` (the original behavior).
    Returns a tensor of the same shape as ``xs``.
    """
    a, b, c = param
    if xs is None:
        xs = x  # original behavior: evaluate on the global data grid
    return a * xs**2 + b * xs + c
# Visualize the noisy data (blue) against the initial, bad fit (red).
plt.scatter(x, y)
plt.scatter(x, f(params).detach().cpu().numpy(), color='red')
# Turn on gradient tracking for params (in place), and stash a copy of
# the starting values so the optimization can be restarted later.
params.requires_grad_()
origs = params.clone()
def mse(preds=None, targets=None):
    """Mean squared error.

    preds: predictions; defaults to f(params) (the original behavior).
    targets: ground truth; defaults to the module-level ``y``.
    Returns a scalar tensor.
    """
    preds = f(params) if preds is None else preds
    targets = y if targets is None else targets
    return ((targets - preds) ** 2).mean()
# One forward/backward pass: fills in params.grad.
loss = mse()     # notebook showed ~1542.23 (depends on the random noise)
loss.backward()
params.grad      # gradient of the loss w.r.t. (a, b, c)
lr = 0.0005  # learning rate for the manual SGD steps below

def apply_step():
    """Run one manual SGD step on the global ``params`` and print the loss.

    Prints the loss computed *before* the update, so the printed value
    lags the parameters by one step. Side effects: mutates ``params``
    in place and clears its gradient.
    """
    loss = mse()
    loss.backward()
    # Update inside no_grad instead of poking .data (which bypasses
    # autograd's correctness checks); sub_ mutates in place without
    # rebinding the global name.
    with torch.no_grad():
        params.sub_(lr * params.grad)
    params.grad = None  # reset so the next backward() starts fresh
    print(loss.item())
# Restart optimization from the saved starting values, then take nine
# SGD steps, printing the (pre-step) loss each time.
params = origs.detach().requires_grad_()
origs = params.clone()  # re-stash a copy for the plotting run below
for _ in range(9):
    apply_step()
# Printed losses from the nine steps above:
# 1542.23388671875
# 1191.165771484375
# 922.902099609375
# 717.9000244140625
# 561.22900390625
# 441.48260498046875
# 349.94647216796875
# 279.96307373046875
# 226.4463348388672
# Reset params to the original starting values before the plotting run.
params = origs.detach().requires_grad_()
# 3x3 grid of plots: each panel shows the data plus the current fit,
# then takes one SGD step, so the red curve improves panel by panel.
_, axs = plt.subplots(3, 3, figsize=(24, 24))
for panel in axs.flatten():  # row-major, same order as (i//3, i%3)
    panel.scatter(x, y)
    panel.scatter(x, f(params).detach().cpu().numpy(), color='red')
    apply_step()
# Printed losses from the nine plotting-loop steps:
# 1542.23388671875
# 1191.165771484375
# 922.902099609375
# 717.9000244140625
# 561.22900390625
# 441.48260498046875
# 349.94647216796875
# 279.96307373046875
# 226.4463348388672