Should 'y' be 2 dimensional in DataLoader()?
from fastai.vision.all import *
path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path
path.ls()
def dl(p, shuffle=False, reshape_y=True):
def f(x):
return [tensor(Image.open(o)) for o in (x).ls()]
X3 = torch.stack(f(path/p/'3'))
X7 = torch.stack(f(path/p/'7'))
y = tensor([1]*len(X3) + [0]*len(X7))
y = y.reshape(-1, 1) if reshape_y==True else y
X = torch.cat([X3, X7]).view(-1, 28*28)/255.
ds = list(zip(X, y))
dl = DataLoader(ds, batch_size=256, shuffle=shuffle)
return dl
def dls(shuffle=True, reshape_y=True):
return DataLoaders(dl('train', shuffle=shuffle, reshape_y=reshape_y), dl('valid', reshape_y=reshape_y))
def batch_accuracy(xb, yb):
preds = xb.sigmoid()
correct = (preds>0.5) == yb
return correct.float().mean()
def mnist_loss(predictions, targets):
predictions = predictions.sigmoid()
return torch.where(targets==1, 1-predictions, predictions).mean()
simple_net = nn.Sequential(
nn.Linear(28*28,30),
nn.ReLU(),
nn.Linear(30,1)
)
learn = Learner(dls(shuffle=True, reshape_y=True), simple_net, opt_func=SGD, loss_func=mnist_loss, metrics=batch_accuracy)
learn.fit(5, 0.1)
#plt.plot(L(learn.recorder.values).itemgot(0))
#plt.plot(L(learn.recorder.values).itemgot(1))
plt.plot(L(learn.recorder.values).itemgot(2))
learn = Learner(dls(shuffle=True, reshape_y=False), simple_net, opt_func=SGD, loss_func=mnist_loss, metrics=batch_accuracy)
learn.fit(9, 0.1)
#plt.plot(L(learn.recorder.values).itemgot(0))
#plt.plot(L(learn.recorder.values).itemgot(1))
plt.plot(L(learn.recorder.values).itemgot(2))