from fastai.vision.all import *
# Download (if not cached) and extract the MNIST "sample" dataset:
# just the 3s and 7s, split into train/ and valid/ folders.
path = untar_data(URLs.MNIST_SAMPLE)
# Display paths relative to the dataset root when printed (fastai convenience).
Path.BASE_PATH = path
path.ls()  # notebook cell: shows the dataset's top-level contents
# Output: (#3) [Path('train'),Path('labels.csv'),Path('valid')]
def get_dls():
    """Build train/valid DataLoaders for the MNIST 3-vs-7 sample.

    Images are read from ``path/{train,valid}/{3,7}``, stacked into
    (N, 28, 28) tensors, scaled to [0, 1], and labelled 1 for a "3"
    and 0 for a "7".  Relies on the module-level ``path`` set by
    ``untar_data`` above.

    Returns:
        DataLoaders: (train, valid) pair with batch size 256.
    """
    def get_ds(split):
        def get_x(split, digit):
            # Stack every image of one digit into a single (N, 28, 28) tensor.
            files = (path/split/digit).ls()
            return torch.stack([tensor(Image.open(f)) for f in files])
        x3, x7 = get_x(split, '3'), get_x(split, '7')
        x = torch.cat([x3, x7])/255.  # normalise uint8 pixels to [0, 1]
        # Label 1 = "3", 0 = "7"; column vector so it matches the
        # model's (N, 1) sigmoid output.
        y = torch.tensor([1]*len(x3) + [0]*len(x7)).reshape(-1, 1)
        return list(zip(x, y))
    # Shuffle only the training loader; validation order does not matter.
    train_dl = DataLoader(get_ds('train'), bs=256, shuffle=True)
    valid_dl = DataLoader(get_ds('valid'), bs=256)
    return DataLoaders(train_dl, valid_dl)

dls = get_dls()
def loss_fn(y, t):
    """Mean distance of each prediction from its correct probability.

    `y` holds sigmoid outputs in (0, 1); `t` holds 0/1 targets.  For a
    positive target the penalty is 1 - y, for a negative target it is y,
    so a perfect prediction contributes 0 and the worst contributes 1.
    """
    distances = torch.where(t == 1, 1 - y, y)
    return distances.mean()

# A minimal MLP binary classifier: flatten 28x28 pixels to 784 features,
# one hidden layer of 30 ReLU units, then a single sigmoid output that
# can be read as P(digit is a 3).
_layers = [
    nn.Flatten(),          # (N, 28, 28) -> (N, 784)
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1),
    nn.Sigmoid(),          # squash the logit into (0, 1)
]
model = nn.Sequential(*_layers)
def batch_accuracy(y, t):
    """Fraction of predictions that match the 0/1 targets.

    `y` holds sigmoid outputs; anything above 0.5 counts as a positive
    prediction, which is then compared elementwise against `t`.
    """
    predictions = y > 0.5
    hits = (predictions == t).float()
    return hits.mean()
# Train with plain (non-momentum) SGD on the custom loss above;
# batch_accuracy is logged per epoch as a metric.
learn = Learner(dls, model, opt_func=SGD, loss_func=loss_fn, metrics=batch_accuracy)
# 40 epochs at a fixed learning rate of 0.1.
learn.fit(40, 0.1)
# Plot the metric (column 2 of recorder.values = batch_accuracy) per epoch;
# trailing ';' suppresses the notebook's repr of the plot object.
plt.plot(L(learn.recorder.values).itemgot(2));
# epoch train_loss valid_loss batch_accuracy time
# 0 0.251622 0.099996 0.966143 00:00
# 1 0.117085 0.055608 0.969087 00:00
# 2 0.070977 0.044978 0.971050 00:00
# 3 0.049911 0.039824 0.971541 00:00
# 4 0.039623 0.036744 0.972031 00:00
# 5 0.034272 0.034183 0.973994 00:00
# 6 0.030532 0.032801 0.974975 00:00
# 7 0.028858 0.031304 0.975466 00:00
# 8 0.027127 0.030324 0.976938 00:00
# 9 0.025220 0.029530 0.977920 00:00
# 10 0.024591 0.028499 0.977920 00:00
# 11 0.023528 0.027751 0.977920 00:00
# 12 0.022990 0.027213 0.977920 00:00
# 13 0.022170 0.026516 0.978901 00:00
# 14 0.021096 0.026030 0.979392 00:00
# 15 0.020740 0.025526 0.979392 00:00
# 16 0.020556 0.025365 0.978901 00:00
# 17 0.020185 0.024765 0.979882 00:00
# 18 0.019934 0.024219 0.979392 00:00
# 19 0.019707 0.024067 0.979392 00:00
# 20 0.019217 0.023683 0.979392 00:00
# 21 0.018532 0.023350 0.979392 00:00
# 22 0.018405 0.023191 0.979882 00:00
# 23 0.018165 0.022752 0.980373 00:00
# 24 0.017803 0.022621 0.980864 00:00
# 25 0.017138 0.022359 0.980864 00:00
# 26 0.017201 0.022179 0.980864 00:00
# 27 0.017088 0.022125 0.981354 00:00
# 28 0.016755 0.021805 0.981845 00:00
# 29 0.016481 0.021685 0.981845 00:00
# 30 0.016408 0.021350 0.981845 00:00
# 31 0.016280 0.021244 0.982336 00:00
# 32 0.016051 0.020892 0.982336 00:00
# 33 0.015875 0.020904 0.982826 00:00
# 34 0.015870 0.020447 0.982336 00:00
# 35 0.015795 0.020434 0.982336 00:00
# 36 0.015723 0.020223 0.981845 00:00
# 37 0.015477 0.020262 0.982826 00:00
# 38 0.015137 0.020047 0.982826 00:00
# 39 0.015087 0.019813 0.982826 00:00